In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from utils import read_qa_json, read_qa_json_generative
from pprint import pprint
from transformers import AutoTokenizer, BertModel, GPT2LMHeadModel, GPT2Tokenizer, GPT2TokenizerFast

In [3]:
bert_model = BertModel.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [4]:
train = read_qa_json_generative(file_name='train_complete.jsonl', verbose=False)
valid = read_qa_json_generative(file_name='dev_complete.jsonl')
test = read_qa_json_generative(file_name='test_complete.jsonl')

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

PAD_TOKEN_INDEX = 0

class TokenQADataset(Dataset):
    def __init__(self, data: list[str], tokenizer: callable, seq_len: int = 512):
        self.data = data
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.pad_token = -1

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        tokenized_sentence = self.tokenizer(self.data[idx])['input_ids']
        question_sequence = tokenized_sentence[:-1]
        answer_sequence = tokenized_sentence[1:]
        n = len(tokenized_sentence)
        if n > self.seq_len:
            question_sequence = question_sequence[:self.seq_len+1]
            answer_sequence = answer_sequence[:self.seq_len+1]
        elif n < self.seq_len:
            question_sequence += [PAD_TOKEN_INDEX for _ in range(self.seq_len+1-n)]
            answer_sequence += [PAD_TOKEN_INDEX for _ in range(self.seq_len+1-n)]
        return torch.tensor(question_sequence, dtype=torch.long).contiguous(), torch.tensor(answer_sequence, dtype=torch.long).contiguous()

train_ds = TokenQADataset(data=train, tokenizer=tokenizer)
val_ds = TokenQADataset(data=valid, tokenizer=tokenizer)
test_ds = TokenQADataset(data=test, tokenizer=tokenizer)

batch_size = 8
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)


In [5]:
tokenizer.pad_token_type_id

print(tokenizer.encode(' A'))
print(tokenizer.encode(' B'))
print(tokenizer.encode(' C'))
print(tokenizer.encode(' D'))
print(tokenizer.encode(' [START]'))

print(tokenizer.decode(46275))
print(tokenizer.decode(33339))
print(tokenizer.decode(347))
print(tokenizer.decode(0))


[317]
[347]
[327]
[360]
[685, 2257, 7227, 60]
 snowball
 responders
 B
!


In [6]:
# NOTE Question 2
from decoderonly import Transformer, train_model, test_model

device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# get pretrained model
# TODO change this
model = Transformer(
    src_vocab=50257, 
    trg_vocab=50257, 
    d_model=512, 
    N=6, 
    heads=8, 
    dropout=0.1, 
    seqlen=512, 
    device=device,
)
model.load_state_dict(torch.load(f'saves/pretrainedwiki103.pth', weights_only=True))
model.to(device)

# baseline
# acc, val_loss = test_model(model=model, test=test_loader)
# print(f'Baseline test accuracy: {acc*100:.4f}')

# train
train_model(
    model=model,
    train=train_loader,
    valid=val_loader,
    epochs=5,
    batch_size=4,
    savename=f'saves/q2fintune.pth',
)

# final accuracy
acc, val_loss = test_model(model=model, test=test_loader)
print(f'Final test accuracy: {acc*100:.4f}')



Loss: 1.276758: 100%|██████████| 620/620 [01:20<00:00,  7.69it/s]
Question 500: 27.6000 percent. Validation Loss: 14.4495: 100%|██████████| 63/63 [00:05<00:00, 11.32it/s]


Epoch 0 validation accuracy: 27.6000. Validation Loss: 14.1719


Loss: 1.068750: 100%|██████████| 620/620 [01:14<00:00,  8.27it/s]
Question 500: 34.0000 percent. Validation Loss: 13.8578: 100%|██████████| 63/63 [00:05<00:00, 11.52it/s]


Epoch 1 validation accuracy: 34.0000. Validation Loss: 13.4979


Loss: 1.071289: 100%|██████████| 620/620 [01:07<00:00,  9.19it/s]
Question 500: 38.6000 percent. Validation Loss: 14.3748: 100%|██████████| 63/63 [00:04<00:00, 12.92it/s]


Epoch 2 validation accuracy: 38.6000. Validation Loss: 13.7587


Loss: 0.171182: 100%|██████████| 620/620 [01:07<00:00,  9.17it/s]
Question 500: 39.6000 percent. Validation Loss: 13.4969: 100%|██████████| 63/63 [00:05<00:00, 11.44it/s]


Epoch 3 validation accuracy: 39.6000. Validation Loss: 12.8150


Loss: 0.707539: 100%|██████████| 620/620 [01:14<00:00,  8.31it/s]
Question 500: 39.6000 percent. Validation Loss: 13.3841: 100%|██████████| 63/63 [00:05<00:00, 12.06it/s]


Epoch 4 validation accuracy: 39.6000. Validation Loss: 12.5696
Saved model as saves/q2fintune.pth


Question 500: 39.8000 percent. Validation Loss: 12.1781: 100%|██████████| 63/63 [00:04<00:00, 12.72it/s]

Final test accuracy: 39.8000





In [7]:
raw_data = test_ds[0][0]
# true_answer = test_ds[0]
# print(raw_data)

exmp = list(test_ds[0][0])
detokenized_exmp = tokenizer.decode(exmp)
print(f'Example inference:\n\n{detokenized_exmp}\n\n')
exmp = list(test_ds[0][1])
detokenized_exmp = tokenizer.decode(exmp)
print(f'Example inference:\n\n{detokenized_exmp}\n\n')



predictions = model.forward(raw_data.to(model.device))
pred_token_idx = torch.argmax(predictions[:, -1], dim=-1).item()
print(f'Model prediction token index: {pred_token_idx}')
pred_next_token = tokenizer.decode(pred_token_idx)
print(f'Model prediction: {pred_next_token}')

Example inference:

 using less resources usually causes money to be saved A person wants to start saving money so that they can afford a nice vacation at the end of the year. After looking over their budget and expenses, they decide the best way to save money is to [A] make more phone calls [B] quit eating lunch out [C] buy less with monopoly money [D] have lunch with friends Answer:!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!


Example inference:

 less resources usually causes money to be saved A person wants to start saving money so that they can afford a nice vacation at the end of the year. After lo