In [6]:
import os
import numpy as np
from dataclasses import dataclass
import torch
from torch import nn
from torch.nn import functional as F
# from transformers import GPT2LMHeadModel
import matplotlib.pyplot as plt 

from tqdm import tqdm, trange
from utils import *
from data_structure import add_to_class

from hf_gpt import (
    GPT, 
    GPTConfig,
    GPTConfig_small
)

# Notebook setup calls. Neither helper is defined in this notebook;
# NOTE(review): both presumably arrive via the `from utils import *` star-import — confirm.
init_graph()
# Device handle that later cells pass to `.to(device)` for the model and the batch.
device = get_device()

In [2]:
# Tiny Shakespeare dataset. The raw file can be fetched with:
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# The encoding is passed explicitly so decoding does not depend on the
# platform's locale default.
with open('./data/shakespeare_char/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
data = text[:1000]  # first 1,000 characters
print(data[:100])   # preview the first 100 characters

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [3]:
import tiktoken
# Load the encoding that tiktoken registers under the name 'gpt2'.
enc = tiktoken.get_encoding('gpt2')
# Encode the 1,000-character sample into a list of integer token ids.
tokens = enc.encode(data)
# Show the first 24 ids; the next cell reshapes this same prefix into a 4 x 6 batch.
print(tokens[:24])

[5962, 22307, 25, 198, 8421, 356, 5120, 597, 2252, 11, 3285, 502, 2740, 13, 198, 198, 3237, 25, 198, 5248, 461, 11, 2740, 13]


In [4]:
# Build one (B, T) batch of inputs and next-token targets from the token list.
# `torch` is already imported in the notebook's first cell, so it is not re-imported here.
B, T = 4, 6  # rows in the batch, tokens per row
# Fail early with a readable message instead of an opaque `view` shape error.
assert len(tokens) >= B * T + 1, (
    f"need at least {B * T + 1} tokens, got {len(tokens)}"
)
buf = torch.tensor(tokens[:B * T + 1])  # one extra token so the targets can be shifted by one
x = buf[:-1].view(B, T)  # inputs: torch.tensor(tokens[:24]).view(4, 6)
y = buf[1:].view(B, T)   # targets: the same stream shifted left by one position
print(x)
print(y)

tensor([[ 5962, 22307,    25,   198,  8421,   356],
        [ 5120,   597,  2252,    11,  3285,   502],
        [ 2740,    13,   198,   198,  3237,    25],
        [  198,  5248,   461,    11,  2740,    13]])
tensor([[22307,    25,   198,  8421,   356,  5120],
        [  597,  2252,    11,  3285,   502,  2740],
        [   13,   198,   198,  3237,    25,   198],
        [ 5248,   461,    11,  2740,    13,   198]])


In [5]:
# Alternative kept for reference (presumably loads pretrained GPT-2 weights — confirm in hf_gpt):
# model = GPT.from_pretrained('gpt2')
model = GPT(GPTConfig())
model.to(device)
# Calling the model with inputs and targets yields a (logits, loss) pair.
logits, loss = model(x.to(device), y.to(device))
# If the model is randomly initialized, the loss should be around
# -ln(1/50257) = 10.82, i.e. a uniform guess over 50257 tokens
# (assumes the config's vocabulary size is 50257 — TODO confirm).
print(loss)

tensor(11.3493, device='mps:0', grad_fn=<NllLossBackward0>)


In [None]:
# Optimize the model on the single (x, y) batch built earlier in the notebook.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# The batch never changes between steps, so transfer it to the device once
# instead of repeating the `.to(device)` calls on every iteration.
x_dev, y_dev = x.to(device), y.to(device)

for i in tqdm(range(50), desc="Training"):
    optimizer.zero_grad()                # clear gradients left over from the previous step
    logits, loss = model(x_dev, y_dev)   # forward pass returns (logits, loss)
    loss.backward()
    optimizer.step()
    # tqdm.write prints the message without colliding with the progress bar.
    tqdm.write(f"Step {i}, Loss: {loss.item():.4f}")
