In [1]:
# Init model
import sys
from pathlib import Path
# Append the directory two levels above the resolved notebook path (i.e. the
# parent of the notebook's folder) to sys.path.
# NOTE(review): presumably this is where the local `model` module lives — confirm.
sys.path.append(str(Path('./train_normal.ipynb').resolve().parent.parent))

from model import GPT
from transformers import GPTNeoXTokenizerFast
# Load the pretrained Pythia-70m weights into the project's GPT class, and the
# matching tokenizer from the same checkpoint name.
model = GPT.from_pretrained('EleutherAI/pythia-70m')
tokenizer = GPTNeoXTokenizerFast.from_pretrained('EleutherAI/pythia-70m')
# Register '<|dense|>' as an additional token in the tokenizer.
tokenizer.add_tokens(['<|dense|>'])
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
# Id of the first token produced when encoding the '<|dense|>' string.
dense_token_id = tokenizer.encode('<|dense|>')[0]

loading weights from pretrained GPTNeoX: EleutherAI/pythia-70m
number of parameters: 70.43M


In [2]:
import torch
from load_data import train, val, process_example, example_to_text

def tokenize_data(data):
    """Convert raw examples into a list of ``{'tokens': ...}`` records.

    Each example is passed through ``process_example`` and then
    ``example_to_text``; the resulting text is encoded with the module-level
    ``tokenizer`` as a PyTorch tensor, truncated to at most 512 tokens.

    Args:
        data: Iterable of raw examples accepted by ``process_example``.

    Returns:
        A list with one dict per example, in input order, each holding the
        encoded tensor under the ``'tokens'`` key. An empty input yields an
        empty list.
    """
    def _encode(raw_example):
        # Raw example -> processed example -> prompt text -> token tensor.
        text = example_to_text(process_example(raw_example))
        return tokenizer.encode(text, return_tensors='pt', max_length=512, truncation=True)

    return [{'tokens': _encode(raw_example)} for raw_example in data]


# Tokenize both splits; unpacking drives the map so `train` is processed before `val`.
train_data, val_data = map(tokenize_data, (train, val))


In [3]:
len(train), len(val)
# 64, 16  -- NOTE(review): presumably the split sizes of an earlier, smaller run; confirm.
# The full train split is about 623 times larger (39905 / 64 ~= 623).
# 623 * 4.5s ~= 2803s ~= 47 minutes
# NOTE(review): looks like a runtime estimate scaled up from a 4.5s run — confirm.

(39905, 10042)

In [4]:
import torch
from torch.utils.data import Dataset

class HellaSwagDataset(Dataset):
    """Map-style dataset over a sequence of pre-tokenized records.

    ``data`` is stored as-is. Every record must provide a ``'tokens'`` entry
    whose ``.shape[1]`` is the token count (a 2-D tensor is assumed, as
    produced by encoding with ``return_tensors='pt'``).
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the token tensor and its length along dimension 1."""
        tokens = self.data[idx]['tokens']
        return dict(input_ids=tokens, length=tokens.shape[1])

# Wrap each tokenized split in a Dataset (train first, then val).
train_dataset, val_dataset = map(HellaSwagDataset, (train_data, val_data))

from torch.utils.data import DataLoader


def collate_fn(batch):
    """Pad a list of dataset items into a single batch.

    Args:
        batch: List of dicts with ``'input_ids'`` (tensor whose leading
            dimension is squeezed away) and ``'length'`` (its token count).

    Returns:
        Dict with ``'input_ids'``, the sequences padded by the module-level
        ``tokenizer`` into one tensor, and ``'label_index'``, a 1-D tensor of
        the unpadded lengths in batch order.
    """
    sequences = []
    lengths = []
    for item in batch:
        # Drop the leading batch dimension so the tokenizer sees 1-D sequences.
        sequences.append(item['input_ids'].squeeze(0))
        lengths.append(item['length'])

    padded = tokenizer.pad({"input_ids": sequences}, return_tensors='pt')

    return {
        'input_ids': padded['input_ids'].contiguous(),
        'label_index': torch.tensor(lengths).contiguous(),
    }

# Training batches are reshuffled on each pass over the loader.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
)
# Validation batches keep dataset order.
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=collate_fn,
)

# we're going to end up with a function that can show both validation and training accuracy too.

In [5]:
def get_accuracy(Y, label_index, logits, doPrint=False):
    """Count how many examples have their label token predicted correctly.

    For each row ``i`` the position ``label_index[i] - 2`` is inspected.
    With ``Y`` being the input ids shifted left by one, that position holds
    the last token of an unpadded sequence of length ``label_index[i]``
    (assumes padding is appended on the right — TODO confirm tokenizer
    ``padding_side``).

    Token ids are compared directly rather than their decoded text, so two
    different ids that happen to decode to the same string are not counted
    as a match, and no decoding work is done unless printing is requested.

    Args:
        Y: Integer tensor of target ids, indexed as ``Y[row][position]``.
        label_index: 1-D integer tensor with the unpadded length of each row.
        logits: Tensor indexed as ``logits[row][position]`` yielding scores
            over the vocabulary.
        doPrint: When True, print ``expected -> predicted`` (decoded with the
            module-level ``tokenizer``) for every row.

    Returns:
        Number of rows (a Python int) whose argmax prediction equals the
        expected token id. Zero for an empty batch.
    """
    positions = label_index - 2
    correct = 0
    for i in range(Y.shape[0]):
        pos = positions[i]
        expected_id = int(Y[i][pos])
        predicted_id = int(logits[i][pos].argmax(dim=-1))

        if doPrint:
            # Decode lazily: only needed for the human-readable trace.
            expected = tokenizer.decode(expected_id)
            received = tokenizer.decode(predicted_id)
            print(expected.__repr__(), "->", received.__repr__())

        if expected_id == predicted_id:
            correct += 1
    return correct

In [6]:
# Hyperparameters handed to model.configure_optimizers below.
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.999
learning_rate = 5e-4 # max learning rate
device_type = 'cpu'
epochs = 1
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)

for epoch in range(epochs):
    # TRAIN
    print("Training Epoch", epoch)
    model.train()
    for batch in train_dataloader:
        # Next-token setup: inputs drop the last column, targets drop the first.
        X = batch['input_ids'][:, :-1]
        Y = batch['input_ids'][:, 1:]
        
        # All-zero tensor of shape (batch, sequence, n_embd) passed as the model's
        # second argument.
        # NOTE(review): presumably a placeholder for an unused "dense" input — confirm against model.forward.
        noop_dense = torch.zeros((X.shape[0], X.shape[1], model.config.n_embd))

        logits, dense, loss = model(X, noop_dense, Y)
        
        # Progress line: correct label predictions in this batch / batch size, then the loss value.
        print('Train', get_accuracy(Y, batch['label_index'], logits), "/", X.shape[0], loss.item())
        
        # Backpropagate, apply the update, then clear gradients for the next batch.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
    # VALIDATE
    model.eval()
    # Running totals over the whole validation loader for this epoch.
    correct = 0
    total = 0
    for batch in val_dataloader:
        # Same input/target shift as in training.
        X = batch['input_ids'][:, :-1]
        Y = batch['input_ids'][:, 1:]
        
        # Same all-zero second argument as in the training loop.
        noop_dense = torch.zeros((X.shape[0], X.shape[1], model.config.n_embd))

        # No gradients are tracked while evaluating.
        with torch.no_grad():
            logits, dense, loss = model(X, noop_dense, Y)
            
            total += X.shape[0]
            # Expected/predicted pairs are printed only during the final epoch.
            correct += get_accuracy(Y, batch['label_index'], logits, doPrint=(epoch == epochs-1))
            
    # Raises ZeroDivisionError if the validation loader yielded no examples.
    print("Validation Accuracy:", correct, "/", total, "=", correct/total)   

You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


using fused AdamW: False
Training Epoch 0
Train 0 / 16 5.279565811157227
Train 1 / 16 4.535447597503662
Train 0 / 16 6.316530227661133
Train 1 / 16 5.514731407165527
Train 0 / 16 4.835117816925049
Train 0 / 16 4.42872953414917
Train 0 / 16 4.118356227874756
Train 4 / 16 4.431439399719238
Train 2 / 16 3.9854283332824707
Train 4 / 16 3.562023878097534
Train 3 / 16 3.300297498703003
Train 3 / 16 3.5383739471435547
Train 5 / 16 3.6128053665161133
Train 2 / 16 3.66767954826355
Train 4 / 16 4.012118339538574
Train 9 / 16 3.6377341747283936
Train 3 / 16 3.364621162414551
Train 2 / 16 3.4047317504882812
Train 7 / 16 3.354626178741455
Train 5 / 16 3.516059637069702
Train 4 / 16 3.128873586654663
Train 0 / 16 3.0958361625671387
Train 3 / 16 3.2147557735443115
Train 5 / 16 3.7532031536102295
Train 2 / 16 3.372769832611084
Train 3 / 16 3.1321141719818115
Train 5 / 16 2.539485454559326
Train 2 / 16 3.3538818359375
Train 5 / 16 3.629117488861084
Train 4 / 16 2.846911907196045
Train 2 / 16 2.74660682