In [1]:
# Cell 1: Install Libraries
!pip install datasets pandas Faker -q
!pip install "unsloth[colab-ampere-torch230] @ git+https://github.com/unslothai/unsloth.git"
!pip install "unsloth[colab-ampere-extras-torch230] @ git+https://github.com/unslothai/unsloth.git"
# IMPORTANT: After this cell, restart your Colab Runtime: Runtime -> Restart session


Collecting unsloth@ git+https://github.com/unslothai/unsloth.git (from unsloth[colab-ampere-torch230]@ git+https://github.com/unslothai/unsloth.git)
  Cloning https://github.com/unslothai/unsloth.git to /tmp/pip-install-l3jda0qn/unsloth_1b02055805624aaf9c26924baa3e428c
  Running command git clone --filter=blob:none --quiet https://github.com/unslothai/unsloth.git /tmp/pip-install-l3jda0qn/unsloth_1b02055805624aaf9c26924baa3e428c
  Resolved https://github.com/unslothai/unsloth.git to commit 380c3b68960f2c686dd3ff7e2a60242ebbca030e
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[0mCollecting unsloth@ git+https://github.com/unslothai/unsloth.git (from unsloth[colab-ampere-extras-torch230]@ git+https://github.com/unslothai/unsloth.git)
  Cloning https://github.com/unslothai/unsloth.git to /tmp/pip-install-l2dcpxiu/unsloth_e3ee407707404bc7bf453fa7b50ed571
  Running co

In [56]:
!pip install huggingface_hub
!pip install bitsandbytes -q



In [2]:
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.
    Setting a new token will erase the existing one.
    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) Y
Token is valid (permission: write

In [3]:
%%capture
# Cell 1: Install Libraries (using uv)
# Added unsloth_zoo to the installation list
!uv pip install datasets pandas Faker huggingface_hub unsloth unsloth_zoo trl transformers bitsandbytes -q

In [4]:

# Cell 2: Imports and Data Generation
import pandas as pd
import random
from faker import Faker
from datasets import Dataset, DatasetDict, load_dataset
from huggingface_hub import login, HfApi
from unsloth import FastLanguageModel
import torch
from trl import SFTTrainer
from transformers import TrainingArguments

fake = Faker()

def generate_reasoning_data(num_samples=1000):
    data = []
    for i in range(num_samples):
        task_type = random.choice(["arithmetic", "logical_deduction", "pattern"])
        question = ""
        answer = ""
        if task_type == "arithmetic":
            num1 = random.randint(1, 100)
            num2 = random.randint(1, 100)
            operation = random.choice(["+", "-", "*"])
            if operation == "+":
                question = f"What is {num1} + {num2}?"
                answer = str(num1 + num2)
            elif operation == "-":
                if num1 < num2: num1, num2 = num2, num1
                question = f"What is {num1} - {num2}?"
                answer = str(num1 - num2)
            elif operation == "*":
                num1 = random.randint(1, 20)
                num2 = random.randint(1, 20)
                question = f"What is {num1} * {num2}?"
                answer = str(num1 * num2)
        elif task_type == "logical_deduction":
            entity1 = fake.first_name()
            entity2 = fake.first_name()
            action1 = random.choice(["is taller than", "is shorter than", "is the same age as"])
            action2 = random.choice(["is older than", "is younger than", "likes the color blue like"])
            premise1 = f"{entity1} {action1} {entity2}."
            if random.random() > 0.5:
                entity3 = fake.first_name()
                premise2 = f"{entity2} {action2} {entity3}."
                question_entity = random.choice([entity1, entity3])
                question = f"Given: {premise1} {premise2} What can we deduce about {question_entity} in relation to others?"
                answer = f"Based on the premises, {question_entity}'s relationship would depend on combining these facts. For example, if {entity1} is taller than {entity2}, and {entity2} is older than {entity3}, these are distinct attributes."
            else:
                question_entity = random.choice([entity1, entity2])
                question = f"Given: {premise1} What can we deduce about {question_entity}?"
                answer = f"Based on the premise, {question_entity} is involved in a comparison where {entity1} {action1} {entity2}."
        elif task_type == "pattern":
            start = random.randint(1, 10)
            diff = random.randint(1, 5)
            seq_type = random.choice(["arithmetic", "geometric_simple"])
            if seq_type == "arithmetic":
                sequence = [start + i * diff for i in range(3)]
                question = f"What is the next number in the sequence: {sequence[0]}, {sequence[1]}, {sequence[2]}, ...?"
                answer = str(sequence[2] + diff)
            else:
                multiplier = random.randint(2,3)
                sequence = [start * (multiplier**i) for i in range(3)]
                question = f"What is the next number in the sequence: {sequence[0]}, {sequence[1]}, {sequence[2]}, ...?"
                answer = str(sequence[2] * multiplier)
        formatted_text = f"<s>[INST] {question} [/INST] {answer} </s>"
        data.append({"text": formatted_text, "question": question, "answer": answer})
    return data

synthetic_data_list = generate_reasoning_data(1000)
synthetic_df = pd.DataFrame(synthetic_data_list)
print("Generated DataFrame Head:")
print(synthetic_df.head())
print(f"\nGenerated {len(synthetic_df)} samples.")

hf_dataset = Dataset.from_pandas(synthetic_df)
train_test_split = hf_dataset.train_test_split(test_size=0.1)
dataset_dict = DatasetDict({
    'train': train_test_split['train'],
    'test': train_test_split['test']
})
print("\nDataset structure:")
print(dataset_dict)
print("\nExample from training set:")
print(dataset_dict['train'][0])

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
Generated DataFrame Head:
                                                text  \
0  <s>[INST] What is the next number in the seque...   
1  <s>[INST] What is the next number in the seque...   
2        <s>[INST] What is 71 + 42? [/INST] 113 </s>   
3  <s>[INST] Given: Brittany is the same age as T...   
4  <s>[INST] What is the next number in the seque...   

                                            question  \
0  What is the next number in the sequence: 8, 9,...   
1  What is the next number in the sequence: 4, 12...   
2                                   What is 71 + 42?   
3  Given: Brittany is the same age as Terri. Terr...   
4  What is the next number in the sequence: 10, 3...   

                                              answer  
0                                                 11  
1                                                108

In [5]:
# Cell 3: Login to Hugging Face Hub
print("Please login to Hugging Face Hub:")
login()

Please login to Hugging Face Hub:


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [61]:
%%capture
!uv pip install datasets -U
# Cell 1: Install Libraries
!pip install datasets pandas Faker -q
!pip install "unsloth[colab-ampere-torch230] @ git+https://github.com/unslothai/unsloth.git"
!pip install "unsloth[colab-ampere-extras-torch230] @ git+https://github.com/unslothai/unsloth.git"

# Attempt to fix the datasets import error by reinstalling/upgrading
!pip install --upgrade datasets pyarrow huggingface_hub

# IMPORTANT: After this cell, restart your Colab Runtime: Runtime -> Restart session

In [6]:
dataset_dict

DatasetDict({
    train: Dataset({
        features: ['text', 'question', 'answer'],
        num_rows: 900
    })
    test: Dataset({
        features: ['text', 'question', 'answer'],
        num_rows: 100
    })
})

In [7]:
# Cell 5: Setup Model, Tokenizer, and LoRA with Unsloth
max_seq_length = 2048
dtype = None
load_in_4bit = True
unsloth_model_name = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit"

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=unsloth_model_name,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)

model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)

# Load dataset from Hub (or use in-memory if failed)
loaded_hf_dataset = None
try:
    loaded_hf_dataset = load_dataset(repo_id_data)
    print("Dataset loaded successfully from Hub:")
    print(loaded_hf_dataset)
except Exception as e:
    print(f"Failed to load dataset from Hub: {e}. Using in-memory dataset_dict as fallback.")
    loaded_hf_dataset = dataset_dict # Fallback to in-memory

if 'train' in loaded_hf_dataset:
    print("\nExample from loaded training set for training:")
    print(loaded_hf_dataset['train'][0]['text'])
else:
    print("Training data not found in the loaded dataset. Exiting.")
    exit()

==((====))==  Unsloth 2025.5.7: Fast Llama patching. Transformers: 4.51.3.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth 2025.5.7 patched 16 layers with 16 QKV layers, 16 O layers and 16 MLP layers.


Failed to load dataset from Hub: name 'repo_id_data' is not defined. Using in-memory dataset_dict as fallback.

Example from loaded training set for training:
<s>[INST] What is 89 + 2? [/INST] 91 </s>


In [8]:
# Cell 6: Training the Model
if 'train' not in loaded_hf_dataset:
    raise ValueError("Training data ('train' split) not found in the loaded_hf_dataset.")

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=loaded_hf_dataset["train"],
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=3,
        warmup_steps=5,
        max_steps=30,  # Stop training after 30 steps
        learning_rate=2e-4,
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
        report_to="none",
    ),
)

print("Starting training for 30 steps...")
trainer_stats = trainer.train()
print("Training finished.")
print(trainer_stats)

Unsloth: Tokenizing ["text"] (num_proc=2):   0%|          | 0/900 [00:00<?, ? examples/s]

Starting training for 30 steps...


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 900 | Num Epochs = 1 | Total steps = 30
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 3
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 3 x 1) = 6
 "-____-"     Trainable parameters = 11,272,192/1,000,000,000 (1.13% trained)


Step,Training Loss
1,1.246
2,1.2998
3,1.1684
4,0.9667
5,0.7054
6,0.4927
7,0.5139
8,0.4825
9,0.3506
10,0.3705


Training finished.
TrainOutput(global_step=30, training_loss=0.4643444180488586, metrics={'train_runtime': 29.4071, 'train_samples_per_second': 6.121, 'train_steps_per_second': 1.02, 'total_flos': 49768251777024.0, 'train_loss': 0.4643444180488586})


In [31]:
# Cell 7: Inference/Testing the Fine-tuned Model
if 'test' in loaded_hf_dataset and len(loaded_hf_dataset['test']) > 0:
    test_prompt_data = loaded_hf_dataset['test'][0]
    full_text_for_prompt = test_prompt_data['text']
    question_part_for_prompt = full_text_for_prompt.split('[/INST]')[0] + '[/INST]'
    actual_answer_for_test = test_prompt_data['answer']
else:
    question_part_for_prompt = "<s>[INST] What is 25 * 4? [/INST]"
    actual_answer_for_test = "100"

print(f"\nTest Prompt (Question part): {question_part_for_prompt.replace('<s>[INST] ', '').replace(' [/INST]', '')}")
print(f"Actual Answer: {actual_answer_for_test}")

inputs_for_test = tokenizer(question_part_for_prompt, return_tensors="pt").to("cuda")
outputs_from_model = model.generate(**inputs_for_test, max_new_tokens=50, use_cache=True)
decoded_output_from_model = tokenizer.batch_decode(outputs_from_model, skip_special_tokens=True)[0]

print(f"\nModel Generated Response (full): {decoded_output_from_model}")
if "[/INST]" in decoded_output_from_model:
    generated_answer_extracted = decoded_output_from_model.split("[/INST]")[-1].strip()
    print(f"Model Generated Answer (extracted): {generated_answer_extracted}")
else:
    print(f"Could not parse answer from model output: {decoded_output_from_model}")


Test Prompt (Question part): Given: Matthew is the same age as Jackie. Jackie is younger than Jaime. What can we deduce about Matthew in relation to others?
Actual Answer: Based on the premises, Matthew's relationship would depend on combining these facts. For example, if Matthew is taller than Jackie, and Jackie is older than Jaime, these are distinct attributes.

Model Generated Response (full): <s>[INST] Given: Matthew is the same age as Jackie. Jackie is younger than Jaime. What can we deduce about Matthew in relation to others? [/INST] Given: Matthew is the same age as Jackie. Jackie is younger than Jaime. What can we deduce about Matthew in relation to others? [/INST] Given: Matthew is the same age as Jackie. Jackie is younger than Jaime. What can we
Model Generated Answer (extracted): Given: Matthew is the same age as Jackie. Jackie is younger than Jaime. What can we


In [35]:
# Cell 7: Enhanced Model Inference with Better Handling

def extract_answer(full_text):
    """Helper function to extract answer from model output"""
    if "[/INST]" in full_text:
        return full_text.split("[/INST]")[-1].strip()
    return full_text.strip()

# Test the model on multiple examples
def test_model_on_examples(dataset, num_examples=3):
    results = []

    # If test set is available and not empty, use it; otherwise use some default examples
    if 'test' in dataset and len(dataset['test']) > 0:
        # Explicitly get examples by index to ensure they are treated as dataset rows
        test_dataset = dataset['test']
        examples_to_test = [test_dataset[i] for i in range(min(num_examples, len(test_dataset)))]
    else:
        # Default examples if no test set is available or is empty
        examples_to_test = [
            {"text": "<s>[INST] What is 25 * 4? [/INST]", "answer": "100"},
            {"text": "<s>[INST] Solve: 15 + 17 [/INST]", "answer": "32"},
            {"text": "<s>[INST] Calculate 100 / 4 [/INST]", "answer": "25"}
        ]

    for i, example in enumerate(examples_to_test, 1): # Iterate through the prepared list of examples
        # example should now reliably be a dictionary
        full_text = example['text']
        question = full_text.split('[/INST]')[0].replace('<s>[INST]', '').strip()
        expected_answer = example.get('answer', 'No answer provided')

        print(f"\n{'='*50}")
        print(f"Example {i}:")
        print(f"Question: {question}")
        print(f"Expected Answer: {expected_answer}")

        # Generate response
        # The input to the model should only be the question part
        input_text = f"<s>[INST] {question} [/INST]" # Reconstruct the input format
        inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
        try:
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=100,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=tokenizer.eos_token_id
                )

            # Decode and process the response
            # Decode the full output sequence generated by the model
            decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
            # The model generates the answer part after the prompt
            model_answer = extract_answer(decoded_output)


            print("\nModel's Response:")
            print(f"- Full Output: {decoded_output}")
            print(f"- Extracted Answer: {model_answer}")

            # Basic evaluation
            if expected_answer.lower() in model_answer.lower():
                print("✅ Correct answer detected!")
            else:
                print("❌ Answer doesn't match expected result")

            results.append({
                "question": question,
                "expected": expected_answer,
                "model_output": decoded_output,
                "extracted_answer": model_answer,
                "is_correct": expected_answer.lower() in model_answer.lower()
            })

        except Exception as e:
            print(f"Error generating response: {str(e)}")
            results.append({
                "question": question,
                "error": str(e)
            })

    return results

# Run the tests
test_results = test_model_on_examples(loaded_hf_dataset)

# Additional metrics if you have expected answers
if all('is_correct' in r for r in test_results):
    correct = sum(1 for r in test_results if r['is_correct'])
    accuracy = correct / len(test_results) * 100
    print(f"\nAccuracy on test examples: {accuracy:.1f}% ({correct}/{len(test_results)})")

# Interactive mode for custom questions
def interactive_mode():
    print("\n" + "="*50)
    print("Interactive Mode (type 'exit' to quit)")
    print("="*50)

    while True:
        user_input = input("\nYour question: ").strip()
        if user_input.lower() in ['exit', 'quit']:
            break

        if not user_input:
            continue

        # Format the input for the model
        formatted_input = f"<s>[INST] {user_input} [/INST]"

        try:
            inputs = tokenizer(formatted_input, return_tensors="pt").to("cuda")
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=150,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=tokenizer.eos_token_id
                )

            # Decode and process the response
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            answer = extract_answer(response)

            print("\n🤖 Model's Response:")
            print(answer)

        except Exception as e:
            print(f"Error: {str(e)}")

# Uncomment to enable interactive mode
# interactive_mode()


Example 1:
Question: Given: Matthew is the same age as Jackie. Jackie is younger than Jaime. What can we deduce about Matthew in relation to others?
Expected Answer: Based on the premises, Matthew's relationship would depend on combining these facts. For example, if Matthew is taller than Jackie, and Jackie is older than Jaime, these are distinct attributes.

Model's Response:
- Full Output: <s>[INST] Given: Matthew is the same age as Jackie. Jackie is younger than Jaime. What can we deduce about Matthew in relation to others? [/INST] Given: Matthew is the same age as Jackie. Jackie is younger than Jaime. What can we deduce about Matthew in relation to others? [/INST] Given: Matthew is the same age as Jackie. Jackie is younger than Jaime. What can we deduce about Matthew in relation to others? [/INST] Given: Matthew is the same age as Jackie. Jackie is younger than Jaime. What can we deduce about Matthew in relation to others? [/INST] Given: Matthew is the same age
- Extracted Answer:

## AS this model is real small and might not be prefered

In [13]:
hf_username = "Rhushya"
dataset_name_on_hub = "synthetic-reasoning-dataset-llama3-1"
repo_id_data = f"{hf_username}/{dataset_name_on_hub}"

In [14]:

# Cell 8: Push Fine-tuned Model (LoRA Adapters) to Hugging Face
lora_model_name_on_hub = "llama-3-1-8b-Instruct-reasoning-lora"
hf_lora_repo_id_model = f"{hf_username}/{lora_model_name_on_hub}"

try:
    # Ensure login if session restarted
    login()
    model.push_to_hub(hf_lora_repo_id_model, token=True)
    tokenizer.push_to_hub(hf_lora_repo_id_model, token=True)
    print(f"LoRA adapters and tokenizer pushed to: https://huggingface.co/{hf_lora_repo_id_model}")
except Exception as e:
    print(f"Error pushing LoRA model to Hub: {e}")

print("\n--- Script Finished ---")


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

README.md:   0%|          | 0.00/596 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/45.1M [00:00<?, ?B/s]

Saved model to https://huggingface.co/Rhushya/llama-3-1-8b-Instruct-reasoning-lora


tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

LoRA adapters and tokenizer pushed to: https://huggingface.co/Rhushya/llama-3-1-8b-Instruct-reasoning-lora

--- Script Finished ---


In [39]:
# First, let's install required packages
!uv pip install -q evaluate nltk rouge_score matplotlib ipywidgets plotly pandas bert_score

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
from IPython.display import display, HTML
import ipywidgets as widgets
from tqdm.auto import tqdm
import evaluate
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

class ModelEvaluator:
    def __init__(self, model, tokenizer, device="cuda"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()

        # Initialize metrics
        self.bleu = evaluate.load('bleu')
        self.rouge = evaluate.load('rouge')
        # Load bertscore metric - requires 'bert_score' package
        self.bertscore = evaluate.load('bertscore')
        self.perplexity = evaluate.load('perplexity', module_type='metric')

    def generate_response(self, prompt, max_length=100, temperature=0.7, top_p=0.9):
        """Generate response from the model"""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=self.tokenizer.eos_token_id
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def calculate_metrics(self, predictions, references):
        """Calculate various metrics for model evaluation"""
        # Calculate BLEU
        bleu_results = self.bleu.compute(
            predictions=predictions,
            references=[[ref] for ref in references]
        )

        # Calculate ROUGE
        rouge_results = self.rouge.compute(
            predictions=predictions,
            references=references,
            use_stemmer=True
        )

        # Calculate BERTScore
        # Ensure predictions and references are not empty for BERTScore
        if predictions and references:
            bertscore_results = self.bertscore.compute(
                predictions=predictions,
                references=references,
                lang="en"
            )
            # Calculate average BERTScore
            avg_bertscore = {
                'precision': np.mean(bertscore_results['precision']),
                'recall': np.mean(bertscore_results['recall']),
                'f1': np.mean(bertscore_results['f1'])
            }
        else:
             avg_bertscore = {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}


        return {
            'bleu': bleu_results['bleu'],
            'rouge1': rouge_results['rouge1'],
            'rouge2': rouge_results['rouge2'],
            'rougeL': rouge_results['rougeL'],
            'bertscore': avg_bertscore
        }

    def plot_metrics(self, metrics_dict):
        """Plot evaluation metrics"""
        # Prepare data for plotting
        metrics = {
            'BLEU': metrics_dict['bleu'],
            'ROUGE-1': float(metrics_dict['rouge1'].split(' ')[1]) if isinstance(metrics_dict['rouge1'], str) else metrics_dict['rouge1'],
            'ROUGE-2': float(metrics_dict['rouge2'].split(' ')[1]) if isinstance(metrics_dict['rouge2'], str) else metrics_dict['rouge2'],
            'ROUGE-L': float(metrics_dict['rougeL'].split(' ')[1]) if isinstance(metrics_dict['rougeL'], str) else metrics_dict['rougeL'],
            'BERTScore-F1': metrics_dict['bertscore']['f1']
        }

        # Create bar plot
        fig, ax = plt.subplots(figsize=(12, 6))
        bars = ax.bar(metrics.keys(), metrics.values(), color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd'])

        # Add value labels on top of bars
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                   f'{height:.3f}',
                   ha='center', va='bottom')

        ax.set_ylim(0, 1.1)
        ax.set_ylabel('Score')
        ax.set_title('Model Evaluation Metrics')
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.show()

        # Create radar chart using plotly
        metrics_radar = metrics.copy()
        fig = px.line_polar(
            r=list(metrics_radar.values()) + [list(metrics_radar.values())[0]],
            theta=list(metrics_radar.keys()) + [list(metrics_radar.keys())[0]],
            line_close=True,
            title="Model Performance Radar Chart"
        )
        fig.update_traces(fill='toself')
        fig.show()

    def interactive_eval(self, test_dataset=None, num_examples=5):
        """Interactive evaluation with visualization"""
        # Create UI elements
        style = {'description_width': 'initial'}
        num_examples_slider = widgets.IntSlider(
            value=min(5, num_examples),
            min=1,
            max=min(20, len(test_dataset) if test_dataset is not None and 'test' in test_dataset else 20), # Adjust max based on dataset size
            step=1,
            description='Number of examples:',
            style=style
        )

        temperature_slider = widgets.FloatSlider(
            value=0.7,
            min=0.1,
            max=1.5,
            step=0.1,
            description='Temperature:',
            style=style
        )

        max_length_slider = widgets.IntSlider(
            value=100,
            min=20,
            max=500,
            step=10,
            description='Max length:',
            style=style
        )

        run_button = widgets.Button(description="Evaluate Model")
        output = widgets.Output()

        def on_run_button_clicked(b):
            with output:
                output.clear_output()
                self.run_evaluation(
                    test_dataset=test_dataset,
                    num_examples=num_examples_slider.value,
                    temperature=temperature_slider.value,
                    max_length=max_length_slider.value
                )

        run_button.on_click(on_run_button_clicked)

        # Display UI
        display(widgets.VBox([
            widgets.HBox([num_examples_slider, temperature_slider, max_length_slider]),
            run_button,
            output
        ]))

    def run_evaluation(self, test_dataset, num_examples=5, temperature=0.7, max_length=100):
        """Run evaluation and display results"""
        if test_dataset is None or ('test' not in test_dataset) or len(test_dataset['test']) == 0:
            print("No valid test dataset provided. Using default examples.")
            # Create a default test dataset structure similar to the real one
            test_dataset = DatasetDict({
                 'test': Dataset.from_dict({
                    'text': [
                        "<s>[INST] What is 25 * 4? [/INST]",
                        "<s>[INST] Solve: 15 + 17 [/INST]",
                        "<s>[INST] Calculate 100 / 4 [/INST]"
                    ],
                    'answer': ["100", "32", "25"]
                })
            })
            dataset_to_evaluate = test_dataset['test'] # Use the 'test' split
        else:
            dataset_to_evaluate = test_dataset['test'] # Use the provided 'test' split

        results = []
        predictions = []
        references = []

        # Ensure we don't try to evaluate more examples than available
        num_examples_to_run = min(num_examples, len(dataset_to_evaluate))

        for i in tqdm(range(num_examples_to_run), desc="Evaluating"):
            example = dataset_to_evaluate[i] # Access example from the dataset split
            prompt = example['text']
            reference = example.get('answer', '')

            try:
                # Generate response
                response = self.generate_response(
                    prompt,
                    max_length=max_length,
                    temperature=temperature
                )

                # Extract answer if needed
                if "[/INST]" in response:
                    answer = response.split("[/INST]")[-1].strip()
                else:
                    answer = response

                results.append({
                    'prompt': prompt.replace('<s>[INST]', '').replace('[/INST]', '').strip(),
                    'reference': reference,
                    'response': response,
                    'answer': answer
                })

                predictions.append(answer)
                references.append(reference)

            except Exception as e:
                print(f"Error processing example {i}: {str(e)}")
                # Append empty data to keep lists aligned if an error occurs
                predictions.append("")
                references.append("")
                results.append({
                    'prompt': prompt.replace('<s>[INST]', '').replace('[/INST]', '').strip(),
                    'reference': reference,
                    'error': str(e)
                })


        # Calculate metrics only if we have valid references and predictions
        # Filter out empty references/predictions before calculating metrics
        valid_predictions = [p for p, r in zip(predictions, references) if r.strip()]
        valid_references = [r for r in references if r.strip()]

        if valid_predictions and valid_references:
            metrics = self.calculate_metrics(valid_predictions, valid_references)
            self.plot_metrics(metrics)

            # Display metrics in a nice table
            metrics_df = pd.DataFrame([{
                'BLEU': f"{metrics['bleu']:.4f}",
                'ROUGE-1': f"{metrics['rouge1']:.4f}",
                'ROUGE-2': f"{metrics['rouge2']:.4f}",
                'ROUGE-L': f"{metrics['rougeL']:.4f}",
                'BERTScore-F1': f"{metrics['bertscore']['f1']:.4f}"
            }])
            display(metrics_df.T.style.background_gradient(cmap='Blues'))
        else:
            print("\nSkipping metric calculation: No valid references or predictions available.")

        # Display sample inputs and outputs
        print("\nSample Inputs and Outputs:")
        for i, result in enumerate(results[:min(3, len(results))]):  # Show up to first 3 examples
            print(f"\nExample {i+1}:")
            print(f"Prompt: {result.get('prompt', 'N/A')}")
            print(f"Reference: {result.get('reference', 'N/A')}")
            if 'error' in result:
                print(f"Processing Error: {result['error']}")
            else:
                 print(f"Model Output: {result.get('answer', 'N/A')}")
            print("-" * 50)


# Example usage:
# Initialize the evaluator with your model and tokenizer
