In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from nltk.translate.meteor_score import meteor_score
from nltk.tokenize import word_tokenize
import nltk

# Download necessary NLTK data
nltk.download('punkt')
nltk.download('wordnet')

# Ensure you're using a GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the Llama 3.2 1B model and tokenizer
model_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# Set pad_token_id if it's None
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Load the MMLU-medical-cot-llama31 dataset and select the first 100 examples
dataset_name = "HPAI-BSC/MMLU-medical-cot-llama31"
dataset = load_dataset(dataset_name, split="train")
subset = dataset.select(range(100))  # Select the first 100 examples

# Define the evaluation function
def evaluate_model(model, tokenizer, dataset, max_new_tokens=50):
    total_meteor_score = 0
    total_perplexity = 0
    total = 0

    for data in dataset:
        prompt = data["question"]
        correct_answer = data["response"]

        # Tokenize the prompt and generate the model's response
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Tokenize the reference and candidate sentences for METEOR score
        reference_tokens = word_tokenize(correct_answer.lower())
        candidate_tokens = word_tokenize(generated_text.lower())

        # Calculate METEOR score
        meteor = meteor_score([reference_tokens], candidate_tokens)
        total_meteor_score += meteor

        # Calculate perplexity
        # Encode the correct answer
        target_ids = tokenizer(correct_answer, return_tensors="pt").input_ids.to(device)
        # Get logits from the model
        with torch.no_grad():
            logits = model(target_ids).logits
        # Shift logits and labels for next-token prediction
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = target_ids[..., 1:].contiguous()
        # Compute loss
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id, reduction='sum')
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        # Calculate perplexity
        num_tokens = shift_labels.numel()
        perplexity = torch.exp(loss / num_tokens).item()
        total_perplexity += perplexity

        total += 1

    average_meteor_score = total_meteor_score / total if total > 0 else 0
    average_perplexity = total_perplexity / total if total > 0 else 0
    return average_meteor_score, average_perplexity

# Run the evaluation on the subset
average_meteor, average_perplexity = evaluate_model(model, tokenizer, subset, max_new_tokens=50)
print(f"Average METEOR Score: {average_meteor * 100:.2f}%")
print(f"Average Perplexity: {average_perplexity:.2f}")


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\prati\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\prati\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_toke

Average METEOR Score: 32.74%
Average Perplexity: 7.89
