In [1]:
import os
from enum import Enum

from outlines import Generator, from_transformers, Template
from pydantic import BaseModel, Field
from rich import print as rprint
from rich.json import JSON
from transformers import AutoModelForCausalLM, AutoTokenizer


ModuleNotFoundError: No module named 'outlines'

In [None]:
# Multiple-choice output type. Passing this Enum as the output type limits
# generation to one of its member values: "Support", "Oppose" or "Neutral".
# Built with the Enum functional API; member order (and therefore iteration
# order) is the dict's insertion order.
Stance = Enum(
    "Stance",
    {
        "support": "Support",
        "oppose": "Oppose",
        "neutral": "Neutral",
    },
)

# Local Llama 3.2 3B Instruct checkpoint. Set the LLAMA_MODEL_PATH environment
# variable to use a copy stored elsewhere without editing the notebook.
model_path = os.environ.get(
    "LLAMA_MODEL_PATH", "/gpfs1/llm/llama-3.2-hf/Meta-Llama-3.2-3B-Instruct"
)

# Wrap the Hugging Face model + tokenizer so outlines can constrain its output.
model = from_transformers(
    AutoModelForCausalLM.from_pretrained(model_path, device_map="cuda"),
    AutoTokenizer.from_pretrained(model_path),
)

# Generate text corresponding to exactly one of the Stance choices.
# NOTE(review): the prompt is the raw text with no instruction, so the model is
# never told this is a stance-classification task. The next cell adds one.
result = model(
    "I used to strongly support breastfeeding, but now I am not so sure. I think it is a personal choice and should be respected either way. So, fed is best.",
    Stance,
)

# The result is the bare choice string (e.g. "Support"), not a JSON document.
# Wrapping it in rich.json.JSON would try to parse it as JSON and fail, so it
# is checked against the allowed values and printed directly.
assert result in {stance.value for stance in Stance}, f"unexpected output: {result!r}"
rprint(result)

Loading checkpoint shards: 100%|██████████| 2/2 [00:19<00:00,  9.98s/it]
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Support


### A richer approach: structured output with a Pydantic schema and a prompt template

In [None]:
# Using this approach, we can build more complex queries.
class Topic(str, Enum):
    """Topics a text can be tagged with; a classification may list several.

    The ``str`` mixin makes each member compare equal to its string value,
    e.g. ``Topic.VACCINES == "Vaccines"``.
    """

    BREASTFEEDING = "Breastfeeding"
    VACCINES = "Vaccines"
    SLEEP_TRAINING = "Sleep_Training"

class Classification(BaseModel):
    """Structured stance classification produced by the model.

    Used as the output type of ``Generator(model, Classification)``, so the
    generated text is a JSON document following this model's schema.
    """

    label: Stance = Field(description="Stance of the text: Support, Oppose, or Neutral.")
    justification: str = Field(description="Why this stance?")
    topic: list[Topic] = Field(description="Topics the text discusses.")

# Reusable generator whose output is constrained to the Classification schema.
# See https://dottxt-ai.github.io/outlines/latest/guide/getting_started/#generators
generator = Generator(model, Classification)

text = "I used to strongly support breastfeeding, but now I am not so sure. I think it is a personal choice and should be respected either way. So, fed is best."

# Template.from_string keeps the prompt inline here. The same template could
# live in a separate .txt file (Template.from_file), which is convenient for
# benchmarking prompt variants. The option lists are rendered from the enums,
# so the prompt stays in sync with the schema. The prompt also asks for every
# field the schema requires: the stance, a justification, and the topics.
stance_template = Template.from_string("""
### Instruction:
Classify the stance of the following text as one of: {{ stances | join(", ") }}.
Briefly justify the stance, and list which of these topics the text discusses: {{ topics | join(", ") }}.

### Input:
{{ text }}

### Response:
""")

prompt = stance_template(
    text=text,
    stances=[stance.value for stance in Stance],
    topics=[topic.value for topic in Topic],
)

# Greedy decoding (do_sample=False) for reproducible output. `temperature` is
# omitted because it only applies when sampling; passing it produced the
# "generation flags are not valid" warning. NOTE(review): the remaining
# `top_p` warning presumably comes from the checkpoint's generation_config.
# TODO: confirm.
result = generator(prompt, max_new_tokens=400, do_sample=False)

# Parse the generated JSON into the schema; raises if the output does not match.
classification = Classification.model_validate_json(result)
rprint(JSON(result))

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
