In [None]:
import os; os.chdir('..')
import numpy as np
from dataclasses import dataclass
import torch
from torch import nn
from torch.nn import functional as F
# from transformers import GPT2LMHeadModel
import matplotlib.pyplot as plt 
import time

# from tqdm import tqdm, trange
from tqdm.notebook import tqdm

from utils import *; from boring_utils.utils import *
from data_structure import add_to_class

from hf_gpt import (
    GPT, 
    GPTConfig,
    GPTConfig_small
)

from dataloader import (
    DataLoaderTiny
)

# Both helpers come from the star imports above (utils / boring_utils.utils).
# NOTE(review): init_graph() presumably configures plotting defaults — confirm in utils.
init_graph()
# `device` is what the model and every batch are moved onto in the cells below.
device = get_device()

def reset_model_weights(model):
    """Re-initialise, in place, every module of ``model`` that can reset itself.

    Walks ``model.modules()`` (which yields ``model`` itself and all nested
    submodules) and calls ``reset_parameters()`` on each module that defines
    it. Modules without that method are left untouched. Returns ``None``.
    """
    # Lazy filter: each module is checked right before it is reset, in traversal order.
    resettable = (m for m in model.modules() if hasattr(m, 'reset_parameters'))
    for submodule in resettable:
        submodule.reset_parameters()

In [None]:
# Build the default-config GPT and place it on the selected device in one expression
# (nn.Module.to returns the module itself, so `model` is the same object either way).
model = GPT(GPTConfig()).to(device)

# NOTE(review): this loader uses B=4, T=32, whereas the config cell further down sets
# B=16, T=1024 and derives grad_accum_steps from those values — confirm which is intended.
train_loader = DataLoaderTiny(T=32, B=4)

In [None]:
# Config

In [None]:
# Learning-rate schedule parameters; the training loop passes these to get_lr().
max_steps = 50
warmup_steps = 10
max_lr = 6e-4
min_lr = 0.1 * max_lr  # 10% of max_lr

# Batch geometry, counted in tokens.
T = 1024  # sequence length
B = 16  # sequences per micro batch
total_batch_size = 1 << 19  # 524288 tokens (~0.5M) per optimizer step

# Micro batches must tile the full batch exactly, otherwise the integer
# division below would silently drop tokens.
assert total_batch_size % (B * T) == 0
grad_accum_steps = total_batch_size // (B * T)  # 32 micro batches per optimizer step
cprint(grad_accum_steps)

In [None]:
# Training run: gradient accumulation, bf16 autocast and a per-step learning rate.
# Relies on names from earlier cells: model, device, train_loader, max_lr, min_lr,
# warmup_steps, max_steps, B, T, grad_accum_steps.

# Derive the device type from the notebook's `device` instead of hard-coding "cuda";
# str() copes with both plain strings ("cuda:0") and torch.device objects.
device_type = "cuda" if str(device).startswith("cuda") else "cpu"

# grad_accum_steps is derived from the configured B and T. If the existing loader was
# built with a different shape, each optimizer step would consume
# train_loader.B * train_loader.T * grad_accum_steps tokens rather than
# total_batch_size, so rebuild it to match the config.
if (train_loader.B, train_loader.T) != (B, T):
    train_loader = DataLoaderTiny(B=B, T=T)

# Start from fresh weights so re-running this cell does not continue a previous run.
reset_model_weights(model)
optimizer = model.configure_optimizers(
    weight_decay=0.1, learning_rate=max_lr, device_type=device_type)

pbar = tqdm(range(max_steps), desc="Training")
for i in pbar:
    t0 = time.time()
    optimizer.zero_grad()

    # Accumulate gradients over grad_accum_steps micro batches before one optimizer step.
    loss_accum = 0.0
    for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)

        # Run the forward pass in mixed precision.
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(x, y)

        # backward() sums gradients across micro batches, so each loss is scaled by
        # 1 / grad_accum_steps to make the accumulated gradient an average
        # (assumes the model returns a mean-reduced loss — TODO confirm in hf_gpt).
        loss = loss / grad_accum_steps
        loss_accum += loss.detach()
        loss.backward()

    # Clip the global gradient norm to 1.0; `norm` holds the value returned by the call.
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    # Apply this step's scheduled learning rate to every parameter group.
    lr = get_lr(i, max_lr, warmup_steps, max_steps, min_lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    optimizer.step()
    # Wait for queued GPU work so the timing below covers the whole step.
    if device_type == "cuda":
        torch.cuda.synchronize()

    t1 = time.time()
    dt = t1 - t0  # time difference in seconds
    # One optimizer step consumes grad_accum_steps micro batches, not just one.
    tokens_processed = train_loader.B * train_loader.T * grad_accum_steps
    tokens_per_sec = tokens_processed / dt

    # Report the loss accumulated over the whole step, not the last (pre-scaled) micro-batch loss.
    pbar.set_description(f"Step {i}, Loss: {loss_accum.item():.4f}, LR: {lr:.1e}, Tokens/s: {tokens_per_sec:.2f}")