In [None]:
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm import tqdm
from torch.cuda.amp import autocast

"""
model_path = "ft_models/ft-gemma-2-2b-it-temp"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch.float16,  # Use float16 for consistency on A100
).to('cuda')

"""
# Load the tokenizer and the causal-LM weights for the checkpoint named by model_id.
model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# device_map="auto" already dispatches the weights onto the available device(s),
# so the model must not be moved again with .to('cuda'): that call is redundant
# on a single GPU and is rejected/warned about when the model is sharded or
# offloaded by the dispatcher.
# NOTE(review): the dtype is kept as float16 per the original comment; confirm
# this is preferred over bfloat16 for this checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,  # Use float16 for consistency on A100
)

In [None]:
# Define the prompt template.
# The template is the complete, literal prompt: it starts directly with <bos>
# (no leading newline) and ends with the opening of the model turn. Because the
# <bos> token is spelled out here, the tokenizer must NOT add special tokens
# again (see generate_text_single).
prompt_template = (
    "<bos><start_of_turn>user\n"
    "{query}\nanswer with either (A) or (B) only<end_of_turn>\n"
    "<start_of_turn>model\n"
)


def generate_text_single(input_prompt):
    """Generate a completion for a single prompt using the global model.

    The prompt is inserted into ``prompt_template``, tokenized, and passed to
    ``model.generate`` with at most 150 new tokens under float16 autocast.

    Args:
        input_prompt: The user query text substituted for ``{query}`` in
            ``prompt_template``.

    Returns:
        The string obtained by decoding ``outputs[0]`` with special tokens
        skipped. This is the whole returned sequence, not only the newly
        generated tokens.
    """
    # Apply the custom template
    formatted_prompt = prompt_template.format(query=input_prompt)

    # Encode the prompt. add_special_tokens=False because the template already
    # contains <bos>; letting the tokenizer prepend another one would feed the
    # model a duplicated BOS token.
    inputs = tokenizer(
        formatted_prompt,
        return_tensors="pt",
        truncation=True,
        add_special_tokens=False,
    ).to(model.device)  # follow the model's device instead of hard-coding "cuda"

    # Use autocast for mixed precision. torch.autocast replaces the deprecated
    # torch.cuda.amp.autocast; float16 is that API's default dtype on CUDA.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        # **inputs forwards the attention mask together with the input ids.
        outputs = model.generate(**inputs, max_new_tokens=150)

    # Decode the generated text
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [None]:
# Smoke test: run generate_text_single on one hand-written A/B question.
single_prompt = """
Choosing technological advancement, even at the cost of personal privacy, means embracing a future where innovation can save lives, cure diseases, and solve global crises. Privacy is important, but the sacrifices we make today could lead to breakthroughs that uplift humanity, creating a safer, healthier world for future generations.
Given the choice, would you prefer... 
(A) A society where technology advances at the cost of personal privacy
(B) A society where privacy is preserved at the cost of technological advancement
"""
generated_text = generate_text_single(single_prompt)

# Show the decoded output beneath a header line.
print("Generated Text:", generated_text, sep="\n")