In [None]:
# Only when running on Google Colab
# !pip install datasets
#import os
#os.environ["WANDB_DISABLED"] = "true"

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread

# Checkpoint to load from the Hugging Face Hub (per the original note, this
# needs a Hugging Face account that has been granted access to the model)
model_name = "google/gemma-2-2b-it"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
)

# Register a padding token when the tokenizer ships without one, then grow the
# embedding matrix so the new token id has a row
if tokenizer.pad_token is None:
    print ("Pad token not set; setting to '[PAD]'")
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    model.resize_token_embeddings(len(tokenizer))

# Preferred device for input tensors: MPS first, then CUDA, otherwise CPU
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
def chat(model, prompt, length, attention_mask=None, pad_token_id=None):
    """Generate a sampled completion for `prompt` and stream it to stdout.

    Generation runs in a background thread while the calling thread prints
    text chunks as the streamer yields them. Because the streamer is created
    with ``skip_prompt=False``, the prompt itself is echoed before the
    completion.

    Args:
        model: Model exposing ``generate`` and ``device`` (the base causal LM
            or a PEFT-wrapped version of it).
        prompt: Text to tokenize with the module-level ``tokenizer``.
        length: Maximum number of new tokens to generate.
        attention_mask: Optional attention mask forwarded to ``generate``.
            Defaults to the mask produced by the tokenizer for ``prompt``.
        pad_token_id: Optional padding token id forwarded to ``generate``.
            Defaults to the tokenizer's EOS token id.

    Returns:
        None. The generated text is printed, not returned.
    """
    # Put the inputs on the device the model actually lives on; the model is
    # loaded with device_map="auto", so a separately chosen global device may
    # not match it.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    if attention_mask is None:
        attention_mask = inputs["attention_mask"]
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Set up the streamer
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

    # Generate tokens in a separate thread to allow streaming
    generation_kwargs = dict(
        inputs=inputs["input_ids"],
        attention_mask=attention_mask,  # previously computed but never passed on
        pad_token_id=pad_token_id,      # previously computed but never passed on
        max_new_tokens=length,  # Increase this for longer output
        temperature=0.7,
        do_sample=True,  # Sample instead of greedy decoding
        top_k=50,        # Limits sampling to top 50 tokens
        top_p=0.95,      # Nucleus sampling for diversity
        streamer=streamer,
    )

    # Start generation in a new thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Stream output chunk by chunk as the generation thread produces it
    for new_text in streamer:
        print(new_text, end="", flush=True)

    # Wait for the thread to finish
    thread.join()

# Generating content

In [None]:
# Smoke-test the base model: stream up to 200 new tokens for a simple question
prompt = "What is AIUK?"
chat(model=model, prompt=prompt, length=200)

# Fine-Tuning for Reasoning

In [None]:
import os
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling

# Function to format examples
def format_example(example):
    """Flatten one chat-style record into a single tagged training string.

    Reads ``example["reannotated_messages"]``, a sequence of dicts with
    ``"role"`` and ``"content"`` keys. The first user message becomes the
    prompt; all assistant messages are joined with single spaces, every
    ``</think>`` is followed by an opening ``<answer>`` tag, and a closing
    ``</answer>`` is appended.

    Returns:
        A dict with one key, ``"text"``, holding
        ``<prompt>...</prompt>`` followed by the tagged assistant text.

    Raises:
        StopIteration: if the record contains no user message.
    """
    turns = example["reannotated_messages"]

    # Only the first user turn is used as the prompt
    user_contents = (turn["content"] for turn in turns if turn["role"] == "user")
    question = next(user_contents)

    # Reasoning and final answer: all assistant turns, space-separated
    assistant_contents = [turn["content"] for turn in turns if turn["role"] == "assistant"]
    reply = " ".join(assistant_contents).replace("</think>", "</think><answer>")

    pieces = ["<prompt>", question, "</prompt>", reply, "</answer>"]
    return {"text": "".join(pieces)}

# Load the full training split first; its length is needed to size the subset
dataset = load_dataset("ServiceNow-AI/R1-Distill-SFT", "v1", split="train")

# Share of the dataset to fine-tune on, in percent (it is divided by 100 below)
percentage = 0.0001

# Output directory for the adapter; its existence also marks "already trained"
filename = f"trained_models/{percentage}_fine_tuned_{model_name.split('/')[-1]}"

if os.path.exists(filename):
    print(f"File {filename} already exists, skipping fine-tuning")

else:
    # int() truncates towards zero, so a small percentage could select nothing
    # and leave the Trainer with an empty dataset. Keep at least one example,
    # and never ask for more rows than the dataset holds.
    num_examples = int(len(dataset) * percentage / 100)
    num_examples = min(len(dataset), max(1, num_examples))

    # Select the first `num_examples` from the dataset
    dataset = dataset.select(range(num_examples))

    # Add the formatted "text" column (one pass; the original ran this map twice)
    dataset = dataset.map(format_example, num_proc=4)

    # Define LoRA configuration
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8,  # Rank
        lora_alpha=16,
        lora_dropout=0.1
    )

    # Apply LoRA to the model
    # NOTE(review): PEFT appears to wrap the base model rather than copy it, so
    # `model` in later cells may already carry these LoRA layers -- confirm
    # before loading the saved adapter onto it again.
    peft_model = get_peft_model(model, lora_config)

    # Tokenize the dataset: every example is padded/truncated to 256 tokens
    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=256)

    # Tokenize one example at a time (batched=False), spread over 4 processes
    tokenized_datasets = dataset.map(tokenize_function, batched=False, num_proc=4)

    # Define data collator (mlm=False selects causal-LM rather than masked-LM collation)
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    # Define training arguments
    training_args = TrainingArguments(
        output_dir="trained_models/results",
        per_device_train_batch_size=4,
        num_train_epochs=1,
        logging_dir="./logs",
        logging_steps=10
    )

    # Define the Trainer
    trainer = Trainer(
        model=peft_model,
        args=training_args,
        train_dataset=tokenized_datasets,
        data_collator=data_collator
    )

    # Fine-tune the model
    trainer.train()

    # Save the adapter and tokenizer side by side so they can be reloaded together
    peft_model.save_pretrained(filename)
    tokenizer.save_pretrained(filename)

In [None]:
from peft import PeftModel

# Attach the adapter saved under `filename` to the base model
reload_peft_model = PeftModel.from_pretrained(model=model, model_id=filename)

# Logic puzzle wrapped in the same <prompt> tags used for the training text
prompt = (
    "<prompt>Alice, Bob, and Charlie are in a room. "
    "Alice always tells the truth, Bob always lies, "
    "and Charlie sometimes lies and sometimes tells the truth. "
    "You ask each of them, ‘Is Charlie a truth-teller?’ "
    "Alice says, ‘No.’ Bob says, ‘Yes.’ Charlie says, ‘I sometimes lie.’ "
    "Who is telling the truth?</prompt>"
)

# Stream up to 500 new tokens from the fine-tuned model
chat(model=reload_peft_model, prompt=prompt, length=500)