# **Distilling Step by Step**

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM, Trainer, TrainingArguments
from datasets import load_dataset

### **Loading Student Model**

In [None]:
# Choose the student model checkpoint (the teacher LLM is configured in a later cell)
STUDENT_MODEL = "google-bert/bert-base-uncased"

from transformers import AutoModelForSequenceClassification

# Load the pretrained student checkpoint through the sequence-classification auto class.
model = AutoModelForSequenceClassification.from_pretrained(STUDENT_MODEL)
# Bare expression: the notebook renders the model's repr as the cell output.
model

In [None]:
# Load a small Hugging Face dataset (change this to your preferred dataset).
# The split expression "train[:10]" restricts the load to the first 10 training
# rows; process_data later calls the teacher once per row, so a small slice
# keeps experimentation cheap.
dataset = load_dataset("sst2", split="train[:10]")

In [None]:
# Bare expression: the notebook shows the dataset's repr as the cell output.
dataset

### TEACHER MODEL

In [None]:
from langchain_ollama import ChatOllama

# Initialize the teacher chat model, reached through an Ollama server
llm_engine = ChatOllama(
    model="gemma3:latest",  # presumably must match a model name shown by `ollama list` — confirm
    base_url="http://localhost:11434",  # Ollama HTTP endpoint on this machine
    temperature=0.3  # NOTE(review): non-zero temperature, so rationales may differ between runs
)

def generate_rationale(input_text):
    """
    Ask the Ollama-served teacher model for step-by-step reasoning about the input.

    The input is embedded in a fixed instruction prompt and sent through
    ``llm_engine.invoke``. The reply's ``content`` attribute is returned when
    it exists; otherwise the reply object itself is returned unchanged.
    """
    reply = llm_engine.invoke(
        f"Explain step-by-step reasoning before answering: {input_text}"
    )
    # Fall back to the reply itself when it carries no `content` attribute.
    return getattr(reply, "content", reply)

# Smoke test: issue one real call to the teacher and print its reply.
print(generate_rationale("Explain AI in one sentence."))

In [None]:
# Prepare dataset with rationales

# Tokenizer for the student checkpoint. process_data depends on this name;
# without it the map() call below fails with NameError.
tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL)


def process_data(example):
    """
    Turn one raw dataset row into tokenized model inputs plus a tokenized rationale.

    Args:
        example: A dataset row providing the text under ``"sentence"`` and
            the class label under ``"label"`` (change the keys to match
            your dataset format).

    Returns:
        dict with ``input_ids`` / ``attention_mask`` for the sentence,
        ``labels`` for the class label, and ``rationale_ids`` /
        ``rationale_mask`` for the teacher-generated rationale. Both texts
        are truncated/padded to 256 tokens.

    Note:
        Every call queries the teacher model once via ``generate_rationale``.
    """
    input_text = example["sentence"]  # Change this depending on your dataset format
    rationale = generate_rationale(input_text)
    label = example["label"]

    # Tokenize input and rationale with identical settings so both have the
    # same fixed length.
    input_enc = tokenizer(input_text, truncation=True, padding="max_length", max_length=256)
    rationale_enc = tokenizer(rationale, truncation=True, padding="max_length", max_length=256)

    return {
        "input_ids": input_enc["input_ids"],
        "attention_mask": input_enc["attention_mask"],
        "labels": label,
        "rationale_ids": rationale_enc["input_ids"],
        "rationale_mask": rationale_enc["attention_mask"],
    }


# Apply function to dataset
processed_dataset = dataset.map(process_data)

In [None]:
# Bare expression: the notebook shows the processed dataset's repr as the cell output.
processed_dataset

In [None]:
from datasets import Dataset

# Persist the processed dataset (tokenized inputs, labels, and rationales) to
# the 'preprocessed_dataset' directory so it can be reloaded later without
# re-running the teacher model.
processed_dataset.save_to_disk('preprocessed_dataset')

In [None]:
from datasets import load_from_disk

# Load the dataset from the saved directory, rebinding `processed_dataset`
# to the copy read back from disk.
processed_dataset = load_from_disk('preprocessed_dataset')

In [None]:
# Bare expression: display the reloaded dataset's repr as the cell output.
processed_dataset

In [None]:
# Hyperparameters and bookkeeping options for fine-tuning the student model.
training_args = TrainingArguments(
    output_dir="./results",  # Directory to save the model and checkpoints
    eval_strategy="epoch",  # run evaluation at the end of every epoch
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    save_strategy="epoch",  # write a checkpoint at the end of every epoch
    push_to_hub=False  # keep artifacts local; nothing is uploaded to the Hub
)

In [None]:
class MultiTaskTrainer(Trainer):
    """
    Trainer whose loss combines the label objective with a rationale objective.

    The student is a sequence classifier, so its logits have shape
    ``(batch, num_labels)`` and cannot be scored against rationale token ids.
    The auxiliary term therefore feeds the teacher's rationale through the
    student (with the rationale's own attention mask) and asks it to predict
    the same gold label. The total loss is
    ``label_loss + rationale_weight * rationale_loss``.

    NOTE(review): with ``remove_unused_columns`` left at its default, Trainer
    is expected to drop dataset columns that are not in the model's forward
    signature, in which case ``rationale_ids`` never arrives here and only the
    label loss is used — confirm against the TrainingArguments in use.
    """

    # Weight of the auxiliary rationale term relative to the label loss.
    rationale_weight = 0.5

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """
        Compute the (optionally multi-task) training loss for one batch.

        Args:
            model: The student model being trained.
            inputs: Batch dict. ``labels`` is required; ``rationale_ids`` and
                ``rationale_mask`` are optional and are removed from the dict
                before the main forward pass.
            return_outputs: When true, also return the outputs of the main
                (sentence) forward pass.
            num_items_in_batch: Accepted for Trainer compatibility; unused.
            **kwargs: Accepted for Trainer compatibility; unused.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs``
            is true.
        """
        labels = inputs.pop("labels")
        rationale_ids = inputs.pop("rationale_ids", None)
        # Always strip the rationale mask as well: model.forward does not
        # accept it, so leaving it in `inputs` would break the call below.
        rationale_mask = inputs.pop("rationale_mask", None)

        outputs = model(**inputs)

        loss_fn = torch.nn.CrossEntropyLoss()
        label_loss = loss_fn(outputs.logits, labels)

        if rationale_ids is None:
            loss = label_loss
        else:
            # Use the rationale's own mask; the sentence mask marks different
            # padding positions. A missing mask falls back to the model default.
            rationale_outputs = model(input_ids=rationale_ids, attention_mask=rationale_mask)
            # Target is the class label: classifier logits are (batch, num_labels),
            # so token ids are not a valid cross-entropy target here.
            rationale_loss = loss_fn(rationale_outputs.logits, labels)
            loss = label_loss + self.rationale_weight * rationale_loss  # Weighted loss

        return (loss, outputs) if return_outputs else loss

# Wire the student model, training arguments, and data into the custom trainer.
trainer = MultiTaskTrainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset,
    # NOTE(review): evaluation reuses the training rows, so eval metrics are not a held-out measure.
    eval_dataset=processed_dataset
)

In [None]:
# Bare expression: display the dataset that was handed to the trainer.
processed_dataset

In [None]:
# Run fine-tuning, then save the final model into "./results" — the same
# path configured as the training output directory.
trainer.train()
trainer.save_model("./results")
print("✅ Distillation Complete! Smaller model saved.")

# Exporting Training Data for AutoTrain

In [None]:
import pandas as pd

# Gather one (text, rationale, target) record per example. The teacher model
# is queried once for every row of `dataset`.
records = [
    (example["sentence"], generate_rationale(example["sentence"]), example["label"])
    for example in dataset
]

# Prepare data for the DataFrame (column-oriented: one list per column)
data = {
    "text": [text for text, _, _ in records],
    "rationale": [rationale for _, rationale, _ in records],
    "target": [target for _, _, target in records],
}

# Create DataFrame
df = pd.DataFrame(data)

# Save to CSV (without the index column)
df.to_csv("train.csv", index=False)