In [36]:
import requests
import os
from tqdm import tqdm
import pandas as pd
import numpy as np
import json
import tiktoken
import torch
from transformers import AutoModelForCausalLM
from  torch.nn import functional as F

In [2]:
def iterate_examples():
    with open('benchmarks/hellaswag/hellaswag_val.jsonl', 'rb') as file:
        for line in file:
            example = json.loads(line)
            yield example

In [3]:
examples = iterate_examples()

In [49]:
for example in examples[::2]:
    print(example)
    break

TypeError: 'generator' object is not subscriptable

In [4]:
def prepare_example(example):
    ctx = example['ctx']
    endings = example['endings']
    label = example['label']
    
    enc = tiktoken.get_encoding('gpt2')
    ctx_tokens = enc.encode(ctx)
    tok_rows = []
    mask_rows = []
    for end in endings:
        end_tokens = enc.encode(' '+ end)
        tok_rows.append(ctx_tokens + end_tokens)
        mask_rows.append([0] * len(ctx_tokens) + [1] * len(end_tokens))

    tokens = torch.zeros([4,1024], dtype=torch.long)
    mask = torch.zeros([4,1024], dtype=torch.long)
    for i,(tok_row, mask_row) in enumerate(zip(tok_rows, mask_rows)):
        tokens[i,:len(tok_row)] = torch.tensor([tok_row]) 
        mask[i,:len(mask_row)] = torch.tensor([mask_row])
    
    
    return tokens, mask, label

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
model.to(device)
hella_val = iterate_examples()
example = next(examples)

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [15]:
tokens, mask, lable = prepare_example(example)
tokens, mask = tokens.to(device), mask.to(device)
logits = model(tokens).logits

In [29]:
shift_logits = (logits[:, :-1, :]).contiguous()
shift_tokens = (tokens[:, 1:]).contiguous()

In [31]:
flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
flat_shift_tokens = shift_tokens.view(-1)

In [38]:
shift_loss = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')

In [40]:
shift_loss = shift_loss.view(tokens.size(0), -1)

In [41]:
shift_loss

tensor([[7.3022e+00, 3.3173e+00, 6.1764e+00,  ..., 4.3073e-03, 4.3036e-03,
         4.2672e-03],
        [7.3022e+00, 3.3173e+00, 6.1764e+00,  ..., 4.5666e-03, 4.5673e-03,
         4.5271e-03],
        [7.3022e+00, 3.3173e+00, 6.1764e+00,  ..., 4.4836e-03, 4.4847e-03,
         4.4465e-03],
        [7.3022e+00, 3.3173e+00, 6.1764e+00,  ..., 4.5019e-03, 4.5048e-03,
         4.4662e-03]], device='cuda:0', grad_fn=<ViewBackward0>)

In [42]:
shift_mask = (mask[:, 1:]).contiguous()

In [44]:
masked_shift_losses = shift_loss * shift_mask
sum_loss = masked_shift_losses.sum(dim=1)
avg_loss = sum_loss / shift_mask.sum(dim=1)
pred = sum_loss.argmin().item()
pred_norm = avg_loss.argmin().item()

In [48]:
pred

2