In [1]:
from model import *
import torch
from torch import nn as nn
import tiktoken
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
# tiktoken's 'gpt2' encoding; a later cell uses it to tokenise a sample of the text.
enc = tiktoken.get_encoding('gpt2')

# Process the Data

In [3]:
# Read the raw training text once; `text` is reused by a later cell.
# The encoding is spelled out because open() otherwise falls back to a
# platform-dependent default, which can decode the file differently per machine.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
# Keep a short excerpt and show its first 100 characters as a sanity check.
data = text[:1000]
print(data[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [4]:
class DataLoaaderLite:
    """Sequential batch loader over a text file tokenised with tiktoken's 'gpt2' encoding.

    The whole file is read and tokenised once in the constructor. ``next_batch``
    then walks through the token stream in windows of ``B * T`` tokens and wraps
    back to the start once a full window (plus one look-ahead token) no longer fits.

    The misspelled class name is kept on purpose: the notebook instantiates the
    loader under this exact name.
    """

    def __init__(self, B, T, path='input.txt'):
        """Tokenise ``path`` and position the loader at the first token.

        Args:
            B: number of sequences per batch.
            T: number of tokens per sequence.
            path: text file to load. Defaults to ``'input.txt'``, the file that
                was previously hard-coded.

        Raises:
            ValueError: if ``B`` or ``T`` is not positive, or if the file holds
                fewer than ``B * T + 1`` tokens (not enough for a single batch).
        """
        if B <= 0 or T <= 0:
            raise ValueError(f'B and T must be positive, got B={B}, T={T}')
        self.B = B
        self.T = T

        # Explicit encoding: the default used by open() is platform-dependent.
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()
        enc = tiktoken.get_encoding('gpt2')
        tokens = enc.encode(text)
        # dtype is pinned so that an empty token list still yields an integer tensor.
        self.tokens = torch.tensor(tokens, dtype=torch.long)

        # Fail here with a clear message instead of letting the first
        # next_batch() call die inside view() with a shape error.
        needed = B * T + 1
        if len(self.tokens) < needed:
            raise ValueError(
                f'{path!r} holds {len(self.tokens)} tokens but one batch '
                f'needs at least {needed} (B*T + 1)'
            )

        print(f'loaded {len(self.tokens)} tokens')
        print(f'1 epoch = {len(self.tokens) // (B * T)} batches')

        self.current_position = 0

    def next_batch(self):
        """Return the next ``(x, y)`` pair, each viewed as shape ``(B, T)``.

        ``y`` holds the same window as ``x`` shifted by one token, so ``y[i, j]``
        is the token that follows ``x[i, j]`` in the stream. After each call the
        read position advances by ``B * T``; it resets to 0 when the following
        window of ``B * T + 1`` tokens would run past the end of the data.
        """
        B, T = self.B, self.T
        # One extra token is sliced so the shifted targets have a last element.
        buf = self.tokens[self.current_position : self.current_position + B * T + 1]
        x = buf[:-1].view(B, T)
        y = buf[1:].view(B, T)
        self.current_position += B * T
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0
        return x, y

In [5]:
# Hand-build a single (B, T) batch from the first 1000 characters of the text.
B, T = 4, 32
text_test = text[:1000]
tokens = enc.encode(text_test)
# B*T + 1 tokens are taken so the targets can be the inputs shifted by one.
buf = torch.tensor(tokens[: B * T + 1]).to(device)
x, y = buf[:-1].view(B, T), buf[1:].view(B, T)

# Construct the model

In [None]:
# Build the GPT model with a 50304-entry vocabulary and move it to the
# selected device.
# NOTE(review): 50304 is presumably the GPT-2 vocabulary size rounded up to a
# hardware-friendly multiple; confirm against the model module.
mconf = GPTConfig(vocab_size=50304)
model = GPT(mconf).to(device)
# From here on `model` refers to the compiled wrapper returned by torch.compile.
model = torch.compile(model)

number of parameters: 123.65M


# Train Model

In [7]:
# Training batches of 4 sequences x 1024 tokens, i.e. 4096 tokens consumed per step.
train_loader = DataLoaaderLite(B=4, T=1024)

loaded 338025 tokens
1 epoch = 82 batches


In [8]:
# Adam over every model parameter with a fixed learning rate of 3e-4.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

In [9]:
# Run 100 optimisation steps, printing the logits dtype and the loss each step.
for i in range(100):
    # Fetch the next (inputs, targets) pair and move both onto the compute device.
    x, y = (batch.to(device) for batch in train_loader.next_batch())

    # Clear gradients left over from the previous step.
    optimizer.zero_grad()

    # Only the forward pass runs under bfloat16 autocast;
    # backward() and the optimizer step happen outside the context.
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        logits, loss = model(x, y)
    print(logits.dtype)

    loss.backward()
    optimizer.step()
    print(f'step {i}, loss: {loss.item()}')

torch.bfloat16
step 0, loss: 10.902325630187988
torch.bfloat16
step 1, loss: 9.481597900390625
torch.bfloat16
step 2, loss: 8.97270393371582
torch.bfloat16
step 3, loss: 8.71376895904541
torch.bfloat16
step 4, loss: 8.384568214416504
torch.bfloat16
step 5, loss: 8.01504898071289
torch.bfloat16
step 6, loss: 7.911421775817871
torch.bfloat16
step 7, loss: 7.689092636108398
torch.bfloat16
step 8, loss: 7.636319160461426
torch.bfloat16
step 9, loss: 7.368589401245117
torch.bfloat16
step 10, loss: 7.3798112869262695
torch.bfloat16
step 11, loss: 7.400235176086426
torch.bfloat16
step 12, loss: 7.430852890014648
torch.bfloat16
step 13, loss: 7.351262092590332
torch.bfloat16
step 14, loss: 6.9875593185424805
torch.bfloat16
step 15, loss: 6.977277755737305
torch.bfloat16
step 16, loss: 6.749427795410156
torch.bfloat16
step 17, loss: 6.588696002960205
torch.bfloat16
step 18, loss: 6.7233076095581055
torch.bfloat16
step 19, loss: 6.719150543212891
torch.bfloat16
step 20, loss: 6.897652626037598
t

In [18]:
# Repeat one forward pass under bfloat16 autocast and display the dtype of the loss.
# The recorded output is torch.float32, whereas the training loop printed
# torch.bfloat16 for the logits.
# NOTE(review): x and y are whatever the previously executed cell left in the
# kernel, and this cell's execution count (18) is out of sequence — re-check
# under Restart & Run All.
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits, loss = model(x, y)
loss.dtype

torch.float32