In [10]:
import os 
import utils.visulaiser as visulaiser
from datasets import load_dataset, load_from_disk

from torch import nn
from tqdm import tqdm
import numpy as np
import torch
import copy
import matplotlib.pyplot as plt
import re
from torch.utils.data import DataLoader
# from tqdm.auto import tqdm
from torch.optim import AdamW
import torch.nn as nn
from torchvision.transforms import v2
from rouge_score import rouge_scorer
# Logging
from datetime import datetime

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Pick the compute device once: the GPU when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [11]:
# Show which device string ('cuda' or 'cpu') was selected in the import cell.
print(device)

cuda


# Perplexity and Model Size

In [20]:
def evaluate(model, tokenizer, nsamples=10, seq_len=2048):
    """Estimate the model's perplexity on the GSM8K test questions.

    All test questions are joined with blank lines into one text, tokenized
    once, and cut into consecutive windows of ``seq_len`` tokens. For each of
    the first ``nsamples`` windows the next-token cross-entropy is computed,
    and the exponential of the token-weighted mean loss is returned.

    The dataset is read from ``./gsm8k`` when that directory exists; otherwise
    it is downloaded and saved there.

    Args:
        model: causal language model; called as ``model(input_ids)`` and
            expected to return an object with a ``.logits`` attribute.
        tokenizer: tokenizer matching ``model``; called with
            ``return_tensors='pt'``.
        nsamples: maximum number of windows to score (default 10).
        seq_len: number of tokens per window (default 2048).

    Returns:
        A 0-dim tensor holding the perplexity.

    Raises:
        ValueError: if the tokenized text is too short to score any window.
    """
    # Imported locally under a private alias: another cell in this notebook
    # rebinds the global name `tqdm` to the tqdm *module*, so relying on the
    # global makes this function depend on cell execution order.
    from tqdm import tqdm as _progress

    dataset_name = "gsm8k"
    dataset_dir = f"./{dataset_name}"

    if os.path.isdir(dataset_dir):
        print("Using Pre-Downloaded Dataset")
        dataset = load_from_disk(dataset_dir)
    else:
        print("Downloading Dataset")
        dataset = load_dataset(dataset_name, "main")
        print(f"Saving Dataset to {dataset_dir}")
        dataset.save_to_disk(dataset_dir)

    print("Dataset Loaded")

    corpus = "\n\n".join(dataset['test']['question'])
    token_ids = tokenizer(corpus, return_tensors='pt').input_ids.to(model.device)

    model = model.eval()
    loss_fct = nn.CrossEntropyLoss()

    nlls = []
    n_predicted = 0
    for i in _progress(range(nsamples), desc="evaluating..."):
        window = token_ids[:, (i * seq_len):((i + 1) * seq_len)]
        if window.size(1) < 2:
            # Ran out of text: a window needs at least two tokens to have a
            # (context, next-token) pair to score.
            break
        with torch.no_grad():
            # No `temperature` here: it is a sampling option for generation,
            # not an input of the forward pass, and perplexity must be
            # computed from the unscaled logits.
            lm_logits = model(window).logits
        # Position t predicts token t+1, so drop the last logit and first label.
        shift_logits = lm_logits[:, :-1, :].contiguous().float()
        shift_labels = window[:, 1:]
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
        # The loss is a mean over this window's predictions; weight it by the
        # real number of predicted tokens so a short final window is not
        # over-counted. For full windows this equals the plain mean of losses.
        n_tokens = shift_labels.numel()
        nlls.append(loss.float() * n_tokens)
        n_predicted += n_tokens

    if not nlls:
        raise ValueError("Not enough tokens to evaluate a single window")

    return torch.exp(torch.stack(nlls).sum() / n_predicted)


In [None]:
def get_model_size(model: nn.Module, data_width=16, group_size=-1):
    """Return the model's parameter storage size as element count times width.

    Args:
        model: module whose ``parameters()`` are counted.
        data_width: storage width of one parameter element. The notebook
            divides the result by constants built on ``Byte = 8``, so this is
            presumably a width in bits.
        group_size: when not -1, ``(16 + 4) / group_size`` is added to
            ``data_width`` as a per-element overhead (presumably a 16-bit
            scale and 4-bit zero point shared by each group — confirm).

    Returns:
        Total number of parameter elements multiplied by the effective width.
    """
    effective_width = data_width
    if group_size != -1:
        effective_width = effective_width + (16 + 4) / group_size

    element_count = sum(tensor.numel() for tensor in model.parameters())
    return element_count * effective_width

# Size units expressed as multiples of 8 (one byte in bits); each unit is
# 1024 times the previous one.
Byte = 8
KiB = Byte * 2 ** 10
MiB = KiB * 2 ** 10
GiB = MiB * 2 ** 10

# Model

In [19]:
#############################
# Download Model or Load Model
#############################

# model_name = "Qwen/Qwen2-Math-1.5B-Instruct"
model_name = "wzzju/Qwen2.5-1.5B-GRPO-GSM8K"
# Local directory named after the last component of the hub id; it serves as
# both the "already downloaded" marker and the load/save location.
model_pth = f"./{model_name.split('/')[-1]}"

# NOTE(review): the saved output of this cell shows the download branch failing
# with "OSError: The paging file is too small" — the load did not complete in
# that run.
if os.path.isdir(model_pth):
    print("Using Pre-Downloaded Model and Tokenizer")
    # Both loads are restricted to local files so the cached branch never
    # reaches out to the hub (the model load previously lacked the flag).
    tokenizer = AutoTokenizer.from_pretrained(model_pth, local_files_only=True, padding_side="left")
    base_model = AutoModelForCausalLM.from_pretrained(model_pth, local_files_only=True)
else:
    print("Downloading Model and Tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
    base_model = AutoModelForCausalLM.from_pretrained(model_name)

    # Save model and tokenizer to the current directory; reuse model_pth so the
    # save location always matches the directory checked above.
    print(f"Saving Model to {model_pth}")
    base_model.save_pretrained(model_pth)
    tokenizer.save_pretrained(model_pth)

print("Done")

Downloading Model and Tokenizer


OSError: The paging file is too small for this operation to complete. (os error 1455)

In [None]:
# NOTE(review): this rebinds the global name `tqdm` from the class brought in by
# `from tqdm import tqdm` in the first cell to the tqdm *module*. After this
# cell runs, code that calls `tqdm(...)` directly fails, while code that calls
# `tqdm.tqdm(...)` requires this binding — the notebook contains both styles.
import tqdm

# Move the model to the selected device, then measure its perplexity.
base_model = base_model.to(device)
model_perplexity = evaluate(base_model, tokenizer)
# With group_size != -1, get_model_size adds (16 + 4) / group_size to
# data_width, so the effective width used here is 32 + 20/128 per parameter.
# NOTE(review): confirm this per-group overhead is intended for base_model.
model_size = get_model_size(base_model, data_width=32, group_size=128)

Using Pre-Downloaded Dataset
Dataset Loaded


evaluating...: 100%|██████████| 10/10 [00:15<00:00,  1.53s/it]


In [None]:
# Report both metrics in a single call; the newline separator reproduces the
# original two-line layout, including the leading blank line.
print(
    f"\nmodel perplexity: {model_perplexity:.2f}",
    f"model size: {model_size/MiB:.2f} MiB",
    sep="\n",
)


model perplexity: 8.87
model size: 5917.56 MiB


# Data

In [4]:
#################################
# Load or Download GSM8k Dataset
#################################

dataset_name = "gsm8k"
# Single source of truth for the on-disk copy, so the existence check, the load
# and the save can never point at different directories.
dataset_dir = f"./{dataset_name}"

if os.path.isdir(dataset_dir):
    print("Using Pre-Downloaded Dataset")
    dataset = load_from_disk(dataset_dir)
else:
    print("Downloading Dataset")
    dataset = load_dataset(dataset_name, "main")
    print(f"Saving Dataset to {dataset_dir}")
    dataset.save_to_disk(dataset_dir)

print("Dataset Loaded")

Using Pre-Downloaded Dataset
Dataset Loaded


In [5]:
# Directory holding the processed + tokenized dataset; used for both the cache
# check and the save so the two cannot drift apart.
# NOTE: a cache written by an earlier version of this cell keeps its old
# "answer" values — delete the directory to regenerate it.
tokenized_dir = f"./{dataset_name}_tokenized"

if os.path.isdir(tokenized_dir):
    tokenized_data = load_from_disk(tokenized_dir)
else:
    def extract_final_answer(answer):
        """
        Extracts only the numerical value after '####' in the answer field.

        Accepts an optional leading minus sign and comma thousands separators
        (e.g. "#### 1,200" -> 1200.0). Returns 0.0 when no number follows
        '####'.
        """
        # The pattern requires at least one digit, so a lone "." can never be
        # captured and passed to float().
        match = re.search(r"####\s*(-?[\d,]*\.?\d+)", answer)
        return float(match.group(1).replace(",", "")) if match else 0.0

    # Process training and test sets
    for split in ["train", "test"]:
        if "original_answer" in dataset[split].column_names:
            # Already processed by an earlier run of this cell: "answer" is now
            # a float and must not be fed through the regex again.
            continue
        dataset[split] = dataset[split].map(lambda example: {
            "original_answer": example['answer'],  # keep the full worked solution
            "question": example["question"],
            "answer": extract_final_answer(example["answer"]),
        })

    def format_example(example):
        """Wrap one example's question in the fixed answer-only prompt."""
        return "You are a math expert. Now answer this question - " + example["question"] + " Your answer should only contain the final answer as a number. Print final answer here: "

    # Tokenize data
    def preprocess_function(examples):
        """Tokenize the prompt for a single example (map is run unbatched)."""
        texts = format_example(examples)
        # Every prompt is padded/truncated to exactly 128 tokens.
        tokens = tokenizer(texts,
                        padding="max_length",
                        truncation=True,
                        max_length=128,
                        return_tensors="pt")
        return tokens

    tokenized_data = dataset.map(preprocess_function, batched=False)
    # Save processed dataset
    tokenized_data.save_to_disk(tokenized_dir)

# Print an example to verify
# print(tokenized_data["train"][0])

In [None]:
# Pull the two splits out of the tokenized dataset
train_data, test_data = tokenized_data["train"], tokenized_data["test"]

# Training subset: shuffle with a fixed seed, then keep the first 1000 rows
small_train_dataset = train_data.shuffle(seed=42).select(range(1000))
# Evaluation set: the whole test split, shuffled with the same seed
# (append .select(range(5)) for a quick smoke run)
small_eval_dataset = test_data.shuffle(seed=42)

# One example per batch; only the training loader reshuffles on iteration
train_dataloader = DataLoader(small_train_dataset, batch_size=1, shuffle=True)
eval_dataloader = DataLoader(small_eval_dataset, batch_size=1)

In [19]:
def print_model_predictions(model, dataloader, device, display=False):
    """Generate an answer for every sample and print exact-match accuracy.

    For each sample, every field except "question", "answer" and
    "original_answer" is converted to a tensor and passed to
    ``model.generate`` (greedy, at most 16 new tokens). The first number found
    in the newly generated text is compared with ``sample['answer']``.

    Decoding uses the notebook-level ``tokenizer`` variable, which must be
    defined before this function is called.

    Args:
        model: model exposing ``to``, ``eval`` and ``generate``.
        dataloader: yields dict samples; ``sample['answer']`` must support
            ``.item()``. The prompt-stripping logic reads only the first row
            of the batch, so a batch size of 1 is expected.
        device: device the model and inputs are moved to.
        display: when True, print the question, raw generation, target and
            parsed answer for every sample.
    """
    # Imported locally under a private alias: another cell in this notebook
    # rebinds the global name `tqdm` to the tqdm *module*, which would make a
    # direct `tqdm(...)` call fail.
    from tqdm import tqdm as _progress

    model = model.to(device)
    model.eval()

    non_tensor_keys = ("question", "answer", "original_answer")
    # Optional sign, digits with optional comma grouping, optional decimals.
    # At least one digit is required, so float() below cannot be handed "."
    # or "1.2.3" as the previous `[\d\.]+` pattern allowed.
    number_pattern = re.compile(r"-?(?:\d[\d,]*(?:\.\d+)?|\.\d+)")

    accuracy_log = []

    with torch.no_grad():
        # Wrapping the dataloader lets tqdm advance and close the bar itself.
        for i, sample in enumerate(_progress(dataloader)):
            batch = {
                k: torch.tensor(v).to(device)
                for k, v in sample.items()
                if k not in non_tensor_keys
            }

            generated = model.generate(**batch, max_new_tokens=16, do_sample=False)

            # generate() returns prompt + continuation; keep only the new tokens.
            prompt_len = len(batch['input_ids'][0])
            output = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)

            # First number in the generated text; 0 when none is present.
            match = number_pattern.search(output)
            generated_answer = float(match.group(0).replace(",", "")) if match else 0

            if display:
                print(f"Example {i+1}:\n")
                print(f"Input: {sample['question']}\n")
                print(f"Generated Answer: {output}\n")
                print(f"Target Output: {sample['answer'].item()}\n")
                print(f"Output Answer: {generated_answer}")
                print("-" * 50)

            accuracy_log.append(generated_answer == sample['answer'].item())

    if accuracy_log:
        print(f"Accuracy: {np.sum(accuracy_log)/len(accuracy_log)}")
    else:
        # Avoid dividing by zero when the dataloader yields nothing.
        print("Accuracy: n/a (no samples)")
    print("Complete!")

In [None]:
# print_model_predictions(base_model, eval_dataloader, device)

  0%|          | 0/5 [00:00<?, ?it/s]

Accuracy: 0.0
Complete!


In [9]:
def generate_answer(model, tokenizer, sample, device):
    """Greedy-decode up to 16 new tokens for one sample and return them as text.

    Every field of ``sample`` except "question", "answer" and
    "original_answer" is converted with ``torch.tensor`` and moved to
    ``device`` before being passed to ``model.generate``. The prompt tokens
    (measured on the first row of ``input_ids``) are cut off the front of the
    first generated sequence, and the remainder is decoded with special tokens
    skipped.
    """
    text_fields = ("question", "answer", "original_answer")
    model_inputs = {
        name: torch.tensor(value).to(device)
        for name, value in sample.items()
        if name not in text_fields
    }

    generated = model.generate(**model_inputs, max_new_tokens=16, do_sample=False)

    # Keep only what the model appended after the prompt.
    prompt_length = len(model_inputs['input_ids'][0])
    new_tokens = generated[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

In [None]:
def measure_test_accuracy(model, tokenizer, dataloader, device, display=False):
    """Evaluate exact-match accuracy and ROUGE over a dataloader, with logging.

    For every sample an answer is generated with ``generate_answer``. The
    first number in the generated text is compared with ``sample['answer']``,
    and ROUGE-1/2/L F-measures are computed between
    ``sample['original_answer'][0]`` and the generated text. One row per
    sample plus a final "Avg" row is written to a timestamped file under
    ``./logs``; every 100th row is also printed.

    Args:
        model: model exposing ``eval``, ``to`` and ``generate``.
        tokenizer: tokenizer used to decode the generated tokens.
        dataloader: yields dict samples; ``sample['answer']`` must support
            ``.item()`` and ``sample['original_answer'][0]`` must be a string.
        device: device the model and inputs are moved to.
        display: when True, print the details of every sample.
    """
    # Imported locally under a private alias: another cell in this notebook
    # rebinds the global name `tqdm` to the tqdm *module*, which would make a
    # direct `tqdm(...)` call fail.
    from tqdm import tqdm as _progress

    # Make the model eval
    model.eval()
    model = model.to(device)

    total = len(dataloader)

    # Optional sign, digits with optional comma grouping, optional decimals,
    # so "12.5" and "1,200" are parsed whole instead of being cut at the first
    # non-digit as the previous `[\d]+` pattern did.
    number_pattern = re.compile(r"-?(?:\d[\d,]*(?:\.\d+)?|\.\d+)")

    # Evaluate - Basic
    accuracy_log = []

    # ROUGE Scorer
    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
    rouge1_scores, rouge2_scores, rougeL_scores = [], [], []

    # Open File for Logging
    os.makedirs("./logs", exist_ok=True)
    log_path = f"logs/{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"

    # The context manager guarantees the log is flushed and closed, including
    # when generation raises part-way through.
    with open(log_path, "w") as log_file:
        log_file.write("Sample\tMatch\tRouge1\tRouge2\tRougeL\n")

        with torch.no_grad():
            for i, sample in enumerate(_progress(dataloader, total=total)):

                output = generate_answer(model, tokenizer, sample, device)

                # First number in the generated text; 0 when none is present.
                match = number_pattern.search(output)
                generated_answer = float(match.group(0).replace(",", "")) if match else 0

                is_correct = (generated_answer == sample['answer'].item())
                accuracy_log.append(is_correct)

                # Compute ROUGE scores (reference first, prediction second)
                scores = scorer.score(sample['original_answer'][0], output)

                rouge1_scores.append(scores["rouge1"].fmeasure)
                rouge2_scores.append(scores["rouge2"].fmeasure)
                rougeL_scores.append(scores["rougeL"].fmeasure)

                if display:
                    print(f"Example {i+1}:\n")
                    print(f"Input: {sample['question']}\n")
                    print(f"Generated Answer: {output}\n")
                    print(f"Target Output: {sample['answer'].item()}\n")
                    print(f"Output Answer: {generated_answer}")
                    print("-" * 50)

                row = f"{i}\t{is_correct:.2f}\t{scores['rouge1'].fmeasure:.4f}\t{scores['rouge2'].fmeasure:.4f}\t{scores['rougeL'].fmeasure:.4f}\n"
                log_file.write(row)
                if i % 100 == 0:
                    print(row)

        if not accuracy_log:
            # Nothing was evaluated; the averages below would divide by zero.
            print("No samples to evaluate.")
            return

        evaluated = len(accuracy_log)
        accuracy = np.sum(accuracy_log) / evaluated * 100

        # Calculate Average ROUGE Scores
        avg_rouge1 = sum(rouge1_scores) / evaluated
        avg_rouge2 = sum(rouge2_scores) / evaluated
        avg_rougeL = sum(rougeL_scores) / evaluated

        print(f"Model Accuracy on GSM8K: {accuracy:.2f}%")
        print(f"Average ROUGE-1: {avg_rouge1:.4f}")
        print(f"Average ROUGE-2: {avg_rouge2:.4f}")
        print(f"Average ROUGE-L: {avg_rougeL:.4f}")

        log_file.write(f"Avg\t{accuracy:.2f}\t{avg_rouge1:.4f}\t{avg_rouge2:.4f}\t{avg_rougeL:.4f}\n")

In [27]:
# Score base_model on the evaluation dataloader: prints accuracy and average
# ROUGE, and writes a per-sample log file under ./logs.
measure_test_accuracy(base_model, tokenizer, eval_dataloader, device)

  0%|          | 0/5 [00:00<?, ?it/s]

0	0.00	0.1389	0.0857	0.1389

Model Accuracy on GSM8K: 0.00%
Average ROUGE-1: 0.1633
Average ROUGE-2: 0.0842
Average ROUGE-L: 0.1415


In [28]:
# The notebook falls back to 'cpu' when no GPU is visible, and the peak-stats
# reset raises on a machine without CUDA, so only touch the CUDA allocator when
# a GPU is actually available.
if torch.cuda.is_available():
    # Empty the GPU cache
    torch.cuda.empty_cache()

    # Reset the peak memory stats
    torch.cuda.reset_peak_memory_stats()