In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader, Dataset
import torch
import pandas as pd
import re
from datasets import load_dataset, load_from_disk
from pathlib import Path


# Pick the compute device: CUDA when available, otherwise MPS, otherwise CPU.
if torch.cuda.is_available():
    device_name = "cuda"
elif torch.backends.mps.is_available():
    device_name = "mps"
else:
    device_name = "cpu"
device = torch.device(device_name)
print(device)

mps


In [2]:
# student model
# Tokenizer and weights are loaded from the same checkpoint name; the model is
# then moved onto the device selected earlier in the notebook.
STUDENT_CHECKPOINT = "openai-community/gpt2-xl"
student_tokenizer = AutoTokenizer.from_pretrained(STUDENT_CHECKPOINT)
student_model = AutoModelForCausalLM.from_pretrained(STUDENT_CHECKPOINT)
student_model = student_model.to(device)

In [9]:
# Register "<|endoftext|>" as the tokenizer's pad token and mirror the resulting
# pad id into the model's generation config.
pad_spec = {"pad_token": "<|endoftext|>"}
student_tokenizer.add_special_tokens(pad_spec)
gen_cfg = student_model.generation_config
gen_cfg.pad_token_id = student_tokenizer.pad_token_id
# NOTE(review): `truncate` does not look like a documented GenerationConfig
# field -- confirm whether setting it has any effect on generation.
gen_cfg.truncate = True

In [11]:
# Testing cell: greedy-decode a short continuation as a smoke test of the
# student model and tokenizer.
inputs = student_tokenizer('The quadratic formula is', return_tensors="pt").to(device)

# `truncate` is not a documented generate() argument, so it is no longer passed.
# The end-of-sequence id is read from the tokenizer instead of hard-coding 50256,
# and the result is not moved again because generate() already returns tensors
# on the model's device.
outputs = student_model.generate(
    **inputs,
    max_new_tokens=50,
    do_sample=False,
    eos_token_id=student_tokenizer.eos_token_id,
)
output_answer = student_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
output_answer


'The quadratic formula is a very useful tool for solving problems involving the quadratic equation. It is also used to solve problems involving the cubic equation.\n\nThe quadratic formula is a very useful tool for solving problems involving the quadratic equation. It is'

In [18]:
# Tokenize a longer sample passage and show how many token ids it produces.
sample_passage = (
    'The quadratic formula is a very useful tool for solving problems involving the quadratic equation. It is also used to solve problems involving the cubic equation.\n\n'
    'The quadratic formula is a very useful tool for solving problems involving the quadratic equation. It is also used to solve problems involving the cubic equation.\n\n'
    'The quadratic formula is a very useful tool for solving problems involving the quadratic equation. It is also used to solve problems involving the cubic equation.\n\n'
    'The quadratic formula is a very useful tool for solving problems involving the quadratic equation. It is also used to solve problems involving the cubic equation.\n\n'
    'The quadratic formula is a very useful tool for solving problems involving the quadratic equation.'
)
t = student_tokenizer(sample_passage, return_tensors="pt")
token_ids = t['input_ids'][0].tolist()
len(token_ids)

150

In [4]:
# Fetch the "main" configuration of GSM8K and convert both splits to pandas
# DataFrames for row-wise evaluation.
gsm8k = load_dataset("openai/gsm8k", "main")
gsm8k_train, gsm8k_test = (
    gsm8k[split_name].to_pandas() for split_name in ('train', 'test')
)

In [13]:
def save_df(df : pd.DataFrame, model : str, dist_folder : str, edition : str):
    """Write `df` as CSV to ``../logs/<model>/<dist_folder>/<edition>.csv``.

    The path is relative to the current working directory. Missing parent
    directories are created first, so the first save on a fresh checkout does
    not fail because the log folder is absent.

    Args:
        df: Frame to persist (written with its index, as `DataFrame.to_csv` does
            by default).
        model: Name of the model sub-folder under ``../logs``.
        dist_folder: Name of the run sub-folder under the model folder.
        edition: File stem of the CSV, e.g. ``'FINAL'``.
    """
    out_path = Path('..') / 'logs' / model / dist_folder / f'{edition}.csv'
    # to_csv does not create directories, so make sure the target folder exists.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(out_path)

def assess_model(df, model : str, dist_folder : str, subset : str, tokenizer, generator):
    """Generate an answer for every question in `df` and checkpoint the results.

    A ``'model_answer'`` column is added to `df` in place (initialised to the
    empty string) and filled row by row with ``generator(tokenizer(question))``.
    After every fifth processed row the whole frame is written through
    `save_df` as ``intermediate_<count>``; once all rows are done it is written
    as ``FINAL``.

    Progress is counted by row position rather than by index label, so the
    counter and the checkpoint names stay correct for frames whose index is not
    a 0-based range, and each answer is written to exactly one row.

    Args:
        df: Frame with a ``'question'`` column; mutated in place.
        model: Model folder name forwarded to `save_df`.
        dist_folder: Run folder name forwarded to `save_df`.
        subset: Currently unused; kept so existing calls keep working.
        tokenizer: Callable turning one question into the generator's input.
        generator: Callable turning the tokenizer output into the value stored
            in ``'model_answer'``.
    """
    numrows = df.shape[0]
    df['model_answer'] = ''
    answer_col = df.columns.get_loc('model_answer')
    for pos, question in enumerate(df['question']):
        tokens = tokenizer(question)
        # Positional write: independent of the frame's index labels.
        df.iloc[pos, answer_col] = generator(tokens)
        done = pos + 1
        if done % 5 == 0:
            print(f'Completed : {done} out of {numrows}')
            save_df(df, model, dist_folder, f'intermediate_{done}')
    save_df(df, model, dist_folder, 'FINAL')


In [6]:
def student_tokenizerd(prompt):
    """Encode `prompt` with the student tokenizer as PyTorch tensors on `device`."""
    return student_tokenizer(prompt, return_tensors="pt").to(device)

In [7]:
def student_generation(tokens):
    """Run the student model on `tokens` and return the decoded first sequence.

    Generation is deterministic (``do_sample=False``) and limited to 50 new
    tokens; special tokens are stripped while decoding.

    NOTE(review): the notebook's test-cell output begins with its prompt, so the
    returned text presumably contains the prompt followed by the continuation --
    confirm before scoring answers.
    """
    generated = student_model.generate(**tokens, max_new_tokens=50, do_sample=False)
    generated = generated.to(device)
    decoded = student_tokenizer.batch_decode(generated, skip_special_tokens=True)
    return decoded[0]

In [14]:
# Evaluate the student model on the GSM8K test split; results are written
# under the 'gpt2-xl' / 'no-distil-test' log folder by assess_model.
assess_model(
    df=gsm8k_test,
    model='gpt2-xl',
    dist_folder='no-distil-test',
    subset='test',
    tokenizer=student_tokenizerd,
    generator=student_generation,
)

Completed : 5 out of 1319
Completed : 10 out of 1319
Completed : 15 out of 1319


KeyboardInterrupt: 

# Load data

In [2]:
def load_disk(model : str, dist_folder : str, edition : str = 'FINAL'):
    """Load what is stored at ``../logs/<model>/<dist_folder>/<edition>``.

    The path is relative to the current working directory and is handed to
    `load_from_disk` unchanged; whatever that call returns is returned here.

    Args:
        model: Model folder name under ``../logs``.
        dist_folder: Run folder name under the model folder.
        edition: Final path component to load; defaults to ``'FINAL'``.
    """
    target = Path('..', 'logs', model, dist_folder, edition)
    return load_from_disk(target)

In [3]:
# Load the 'intermediate_5' snapshot stored for the gpt2-xl / no-distil-test run.
gsm8k_data = load_disk(
    model='gpt2-xl',
    dist_folder='no-distil-test',
    edition='intermediate_5',
)

In [4]:
# Display the 'test' entry of the loaded data to check its features and row count.
gsm8k_data['test']

Dataset({
    features: ['question', 'answer'],
    num_rows: 1319
})