In [1]:
import os
import numpy as np
from dataclasses import dataclass
import torch
from torch import nn
from torch.nn import functional as F
# from transformers import GPT2LMHeadModel
import matplotlib.pyplot as plt 

# from tqdm import tqdm, trange
from tqdm.notebook import tqdm

from utils import *
from data_structure import add_to_class

from hf_gpt import (
    GPT, 
    GPTConfig,
    GPTConfig_small
)

# Project-level setup from `utils` (brought in by the star import above).
# NOTE(review): presumably init_graph() configures plotting defaults and
# get_device() picks the best available accelerator — confirm in utils.py.
# In the recorded run `device` resolved to CUDA (the cell-5 loss is on 'cuda:0').
init_graph()
device = get_device()

In [2]:
# tiny shakespeare dataset
# Download it into the same path that is read below (wget -P creates the directory):
# !wget -P ./data/shakespeare_char https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# Explicit encoding so the read does not depend on the platform's locale default.
with open('./data/shakespeare_char/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
data = text[:1000]  # first 1,000 characters
print(data[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [3]:
# GPT-2 byte-pair encoding via OpenAI's tiktoken; `enc` is kept for later decoding.
import tiktoken

enc = tiktoken.get_encoding("gpt2")
tokens = enc.encode(data)
# peek at the first 24 token ids (exactly B*T for the 4x6 example batch)
print(tokens[0:24])

[5962, 22307, 25, 198, 8421, 356, 5120, 597, 2252, 11, 3285, 502, 2740, 13, 198, 198, 3237, 25, 198, 5248, 461, 11, 2740, 13]


In [4]:
import torch

# Take B*T + 1 tokens and view them twice: x is the (B, T) input grid and y is
# the same grid shifted left by one token, so y[b, t] is the next token after x[b, t].
B, T = 4, 6
buf = torch.tensor(tokens[:B * T + 1])
x, y = buf[:-1].view(B, T), buf[1:].view(B, T)
print(x)
print(y)

tensor([[ 5962, 22307,    25,   198,  8421,   356],
        [ 5120,   597,  2252,    11,  3285,   502],
        [ 2740,    13,   198,   198,  3237,    25],
        [  198,  5248,   461,    11,  2740,    13]])
tensor([[22307,    25,   198,  8421,   356,  5120],
        [  597,  2252,    11,  3285,   502,  2740],
        [   13,   198,   198,  3237,    25,   198],
        [ 5248,   461,    11,  2740,    13,   198]])


In [5]:
# Randomly initialised model. Alternative (presumably loads pretrained GPT-2 weights):
# model = GPT.from_pretrained('gpt2')
model = GPT(GPTConfig()).to(device)  # nn.Module.to returns the module itself
logits, loss = model(x.to(device), y.to(device))
# With a random init the loss should be around -ln(1/50257) = ln(50257) ≈ 10.82.
print(loss)

tensor(11.0501, device='cuda:0', grad_fn=<NllLossBackward0>)


In [7]:
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Every step reuses the same fixed batch (x, y), so transfer it to the device
# once instead of on every iteration.
x_dev, y_dev = x.to(device), y.to(device)

pbar = tqdm(range(50), desc="Training")
for i in pbar:
    optimizer.zero_grad()
    logits, loss = model(x_dev, y_dev)
    loss.backward()
    optimizer.step()
    # .item() syncs with the device; fine at this tiny scale, and it keeps the bar readable.
    pbar.set_description(f"Step {i}, Loss: {loss.item():.4f}")

Training:   0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
def load_tokens(filename):
    """Read one ``np.save``-d token shard and return it as a ``torch.long`` tensor.

    Args:
        filename: path to a ``.npy`` file holding integer token ids.

    Returns:
        A freshly allocated int64 tensor with the same shape and values as the array.
    """
    # NOTE(review): the int32 hop was "added after video"; presumably shards are
    # stored as uint16, which some torch versions cannot convert directly.
    as_int32 = np.load(filename).astype(np.int32)
    return torch.from_numpy(as_int32).to(torch.long)

class DataLoaderLite:
    """Sequential ``(inputs, targets)`` batch loader over sharded token files.

    Shards are the entries of ``data_root`` whose names contain ``split``, read in
    sorted filename order through ``load_tokens``. With ``num_processes`` ranks,
    rank ``r`` starts at token offset ``B*T*r`` and advances by
    ``B*T*num_processes`` per batch, so ranks read interleaved, non-overlapping
    slices of the current shard. When the rest of a shard cannot hold another
    full stride, the loader moves to the next shard, wrapping after the last one.

    Args:
        B: batch size (rows per batch).
        T: sequence length (tokens per row).
        process_rank: index of this process; expected to be in ``[0, num_processes)``.
        num_processes: total number of processes sharing the data.
        split: ``'train'`` or ``'val'``; substring-matched against shard file names.
        data_root: directory containing the shard files. Defaults to
            ``"edu_fineweb10B"``, the previously hard-coded location.
    """

    def __init__(self, B, T, process_rank, num_processes, split, data_root="edu_fineweb10B"):
        self.B = B
        self.T = T
        self.process_rank = process_rank
        self.num_processes = num_processes
        assert split in {'train', 'val'}

        # get the shard filenames; sorted because os.listdir order is arbitrary,
        # and every rank must walk the shards in the same order
        self.data_root = data_root
        self.shards = [
            os.path.join(data_root, name)
            for name in sorted(os.listdir(data_root))
            if split in name
        ]
        assert len(self.shards) > 0, f"no shards found for split {split}"
        print(f"found {len(self.shards)} shards for split {split}")
        self.reset()

    def reset(self):
        """Rewind to the first shard and this rank's starting offset."""
        self.current_shard = 0
        self.tokens = load_tokens(self.shards[self.current_shard])
        self.current_position = self.B * self.T * self.process_rank

    def next_batch(self):
        """Return the next ``(x, y)`` pair, each viewed as a ``(B, T)`` tensor.

        ``y`` is ``x`` shifted by one token, i.e. the next-token targets. The
        read position then advances past the slices of all ranks. If the
        following read would run past the end of the shard, the next shard is
        loaded and the position resets to this rank's offset.
        """
        B, T = self.B, self.T
        # B*T + 1 tokens: the extra one supplies the target for the last input.
        buf = self.tokens[self.current_position : self.current_position + B * T + 1]
        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)   # targets
        # advance the position in the tensor, skipping the other ranks' slices
        self.current_position += B * T * self.num_processes
        # if loading the next batch would be out of bounds, advance to next shard
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = load_tokens(self.shards[self.current_shard])
            self.current_position = B * T * self.process_rank
        return x, y
