# Training Workflow

### Import Packages

In [1]:
from gpt2 import GPT2, GPT2Config, DataLoaderLite
import torch
from torch.nn import functional as F
import time

### Set Device & Args

In [2]:
# Generation settings. They are not used in the cells visible here;
# presumably a later sampling cell reads them -- TODO confirm.
num_return_sequences = 5
max_length = 30

# Choose the compute backend in priority order: CUDA, then Apple MPS, then CPU.
# The hasattr() guard keeps this working on torch builds without an MPS backend.
if torch.cuda.is_available():
    device = 'cuda'
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f'Using Device: {device}')

Using Device: mps


In [3]:
# Single source of truth for the RNG seed, so the CPU and CUDA seeds below
# cannot drift apart when one of them is edited.
SEED = 1337

# Seed the default generator so weight initialisation is repeatable.
torch.manual_seed(SEED)
if torch.cuda.is_available():
    # Explicit CUDA seeding kept from the original cell.
    torch.cuda.manual_seed(SEED)

### Prepare Dataset

In [5]:
# Batch geometry and data source for the training loader.
# NOTE(review): B is presumably the batch size (sequences per batch), not the
# number of batches -- the loader's log (338025 tokens -> 4225 batches) matches
# B * T = 80 tokens consumed per batch. Confirm against DataLoaderLite.
B = 4 # Batch size (sequences per batch)
T = 20 # Max Sequence Length
input_text = './data/input.txt' # relative path; resolved against the kernel's working directory
train_loader = DataLoaderLite(B=B, T=T, input_text=input_text)

Loaded 338025 tokens
1 Epoch = 4225 batches


### Initialize Model & Run Training Loop

In [7]:
# Initialize model
# Random weights; no pretrained HF checkpoint is loaded.
# NOTE(review): this passes the GPT2Config class itself, not an instance
# (GPT2Config()). It ran as-is, so GPT2 presumably only reads class-level
# defaults -- confirm against gpt2.py before changing.
model = GPT2(GPT2Config)
# The model is optimised right after construction, so put it in training mode.
# The original called model.eval(), which switches mode-dependent layers
# (e.g. dropout, if the model has any) to inference behaviour.
model.train()
model.to(device)

# Run a fixed number of optimisation steps, logging loss and wall time per step.
# NOTE: despite its name, `epochs` counts single-batch optimizer steps, not
# full passes over the dataset; the name is kept so other cells that read it
# keep working.
epochs = 50
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for step in range(epochs):
    # perf_counter() is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    logits, loss = model(x, y)
    loss.backward()
    optimizer.step()
    # GPU backends queue work asynchronously; wait for it to finish before
    # stopping the timer, otherwise dt measures enqueue time, not step time.
    if device == 'cuda':
        torch.cuda.synchronize()
    elif device == 'mps' and hasattr(torch, 'mps'):
        torch.mps.synchronize()
    t1 = time.perf_counter()
    dt = round((t1 - t0)*1000, 3) # time difference in ms.
    print(f'step: {step}, loss: {loss.item()}, dt: {dt}')

step: 0, loss: 11.073359489440918, dt: 219.31
step: 1, loss: 9.915640830993652, dt: 203.756
step: 2, loss: 9.42115306854248, dt: 200.644
step: 3, loss: 8.656221389770508, dt: 199.012
step: 4, loss: 9.198419570922852, dt: 200.294
step: 5, loss: 8.253707885742188, dt: 202.119
step: 6, loss: 8.747230529785156, dt: 200.815
step: 7, loss: 8.346973419189453, dt: 200.871
step: 8, loss: 8.11622428894043, dt: 199.063
step: 9, loss: 8.615663528442383, dt: 200.519
step: 10, loss: 8.294961929321289, dt: 200.924
step: 11, loss: 8.18215560913086, dt: 201.029
step: 12, loss: 7.9577226638793945, dt: 200.342
step: 13, loss: 8.050445556640625, dt: 207.274
step: 14, loss: 7.1966705322265625, dt: 199.133
step: 15, loss: 8.26171588897705, dt: 197.353
step: 16, loss: 7.490786552429199, dt: 207.187
step: 17, loss: 7.020895481109619, dt: 203.03
step: 18, loss: 7.020319938659668, dt: 199.649
step: 19, loss: 7.466334342956543, dt: 198.571
step: 20, loss: 7.3599443435668945, dt: 199.579
step: 21, loss: 6.8030099