# nanoGPT

In [None]:
from time import time
from pathlib import Path

import torch
from torch import nn
from torch.nn import functional as F
from torchsummary import summary
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

## Get text data

In [None]:
from urllib.request import urlretrieve

DATA_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

file = Path("input.txt")
if not file.exists():
    # We always start with a dataset to train on. Download the tiny
    # shakespeare dataset with the standard library so this cell does not
    # depend on IPython's `!` shell escape or on a `wget` binary.
    urlretrieve(DATA_URL, file)

text = file.read_text(encoding='utf-8')

print(text[:100])

# here are all the unique characters that occur in this text
chars = sorted(set(text))
VOCAB_SIZE = len(chars)

print(VOCAB_SIZE)

# create a mapping from characters to integers (and back)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Encoder: take a string, output a list of integers.

    Raises KeyError for a character that does not occur in the corpus.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integers, output a string.

    Raises KeyError for an id outside the vocabulary.
    """
    return ''.join(itos[i] for i in l)


print(encode("hii there"))
print(decode(encode("hii there")))

In [None]:
print(torch.cuda.is_available())

# Prefer the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): this unconditionally overrides the detection above, so
# everything runs on the CPU. Before removing it, confirm that every later
# cell moves both its model and its inputs to `device`.
device = "cpu"

# Encode the whole corpus as one 1-D tensor of token ids.
all_data = torch.tensor(encode(text))
n = int(0.9*len(all_data)) # first 90% will be train, rest val
train_data = all_data[:n]
val_data = all_data[n:]

# Fixed seed so that the batches sampled afterwards are reproducible.
torch.manual_seed(1337)
# Deliberately tiny values, handy for eyeballing a single batch.
CONTEXT_SIZE, BATCH_SIZE = 9, 4

def get_batch(data):
    """Sample a random batch of (input, target) windows from `data`.

    BATCH_SIZE start offsets are drawn at random. Each input is the
    CONTEXT_SIZE-long slice starting at an offset; the matching target is the
    same slice shifted one position to the right. Both stacked tensors are
    moved to `device` before being returned.
    """
    starts = torch.randint(len(data) - CONTEXT_SIZE, (BATCH_SIZE,))
    inputs, targets = [], []
    for start in starts:
        stop = start + CONTEXT_SIZE
        inputs.append(data[start:stop])
        targets.append(data[start + 1:stop + 1])
    return torch.stack(inputs).to(device), torch.stack(targets).to(device)

# Draw one batch and show the shape and contents of inputs, then targets.
xb, yb = get_batch(train_data)

print(xb.shape, xb, sep="\n")
print(yb.shape, yb, sep="\n")

## Architecture and training

In [None]:
# arch
CONTEXT_SIZE = 64 # what is the maximum context length for predictions?
N_EMBD = 128  # embedding width shared by the embedding, linear and layer-norm layers
NUM_HEADS = 4  # attention heads per block; each head is N_EMBD // NUM_HEADS wide
NUM_BLOCKS = 4  # number of stacked transformer blocks
DROPOUT = 0.0  # probability handed to every nn.Dropout; 0.0 disables dropout

# training
BATCH_SIZE = 64 # how many independent sequences will we process in parallel?
MAX_ITERS = 2000  # number of training iterations
LEARNING_RATE = 3e-4  # initial AdamW learning rate
WEIGHT_DECAY = 0.1  # AdamW weight decay
LEARNING_RATE_DECAY = 0.9  # gamma of the exponential learning-rate scheduler


class AttentionHead(nn.Module):  # Head
    """One head of scaled dot-product self-attention.

    Args:
        head_size: Width of the key, query and value projections.
        is_decoder: When true, a causal (lower-triangular) mask stops every
            position from attending to later positions.
        n_embd: Width of the incoming embeddings. Defaults to the
            module-level N_EMBD.
        context_size: Longest sequence the causal mask covers. Defaults to
            the module-level CONTEXT_SIZE.
        dropout: Dropout probability applied to the attention weights.
            Defaults to the module-level DROPOUT.
    """
    def __init__(self, head_size, is_decoder, *, n_embd=None, context_size=None, dropout=None):
        super().__init__()
        # Fall back to the notebook-wide settings only when not overridden.
        n_embd = N_EMBD if n_embd is None else n_embd
        context_size = CONTEXT_SIZE if context_size is None else context_size
        dropout = DROPOUT if dropout is None else dropout

        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(context_size, context_size)))
        self.dropout = nn.Dropout(dropout)
        self._is_decoder = is_decoder

    def forward(self, x):
        """Map x of shape (B, T, n_embd) to attended values of shape (B, T, head_size)."""
        T = x.shape[1]
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        # compute attention scores ("affinities"); the scale is 1/sqrt(d_k)
        # with d_k the per-head key width, not the (larger) embedding width
        scale = k.shape[-1] ** -0.5
        wei = q @ k.transpose(-2, -1) * scale  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        if self._is_decoder:  # only in decoder blocks
            wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
        wei = F.softmax(wei, dim=-1)  # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x)  # (B, T, head_size)
        return wei @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)


class MultiHeadAttention(nn.Module):
    """Several self-attention heads run side by side, then merged.

    NUM_HEADS heads of width `head_size` each process the same input; their
    outputs are concatenated on the last axis, projected back to N_EMBD and
    passed through dropout.
    """
    def __init__(self, head_size, is_decoder=True):
        super().__init__()
        heads = []
        for _ in range(NUM_HEADS):
            heads.append(AttentionHead(head_size, is_decoder))
        self.heads = nn.ModuleList(heads)
        self.proj = nn.Linear(NUM_HEADS * head_size, N_EMBD)
        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x):
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        return self.dropout(self.proj(merged))


class MLP(nn.Module):  # FeedForward
    """Position-wise feed-forward network.

    Expands N_EMBD to four times its width, applies a ReLU, projects back to
    N_EMBD and finishes with dropout.
    """
    def __init__(self):
        super().__init__()
        hidden = 4 * N_EMBD
        layers = [
            nn.Linear(N_EMBD, hidden),
            nn.ReLU(),  # ViT paper uses GELU
            nn.Linear(hidden, N_EMBD),
            nn.Dropout(DROPOUT),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


class TransformerBlock(nn.Module):  # Block
    """Pre-norm transformer block: attention (communication), then MLP (computation).

    Each sub-layer sees a layer-normalized copy of its input and its output
    is added back onto the residual stream.
    """

    def __init__(self, is_decoder=True):
        super().__init__()
        # every head gets an equal slice of the embedding width
        self.mha = MultiHeadAttention(N_EMBD // NUM_HEADS, is_decoder)
        self.mlp = MLP()
        self.ln1 = nn.LayerNorm(N_EMBD)
        self.ln2 = nn.LayerNorm(N_EMBD)

    def forward(self, x):
        # cf eqn 2 of ViT paper: residual around the attention sub-layer
        after_attention = x + self.mha(self.ln1(x))
        # cf eqn 3 of ViT paper: residual around the MLP sub-layer
        return after_attention + self.mlp(self.ln2(after_attention))


class SimpleDecoder(nn.Module):  # BigramLangModel
    """Decoder-only transformer language model over integer token ids.

    Token and position embeddings are summed, passed through NUM_BLOCKS
    causal transformer blocks and a final layer norm, then projected to
    VOCAB_SIZE logits per position.
    """

    def __init__(self):
        super().__init__()
        self._token_embed = nn.Embedding(VOCAB_SIZE, N_EMBD)
        self._posn_embed = nn.Embedding(CONTEXT_SIZE, N_EMBD)
        self.blocks = nn.Sequential(*[TransformerBlock() for _ in range(NUM_BLOCKS)])
        self.ln_f = nn.LayerNorm(N_EMBD) # final layer norm
        self._lm_head = nn.Linear(N_EMBD, VOCAB_SIZE)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02^2) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x, y=None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            x: (B, T) token ids; cast to long and moved to the model's device.
            y: Optional (B, T) target ids.

        Returns:
            (logits, loss). Without `y`, logits are (B, T, vocab) and loss is
            None. With `y`, logits are flattened to (B*T, vocab) and loss is
            a scalar tensor.

        Raises:
            ValueError: If T exceeds the number of position embeddings.
        """
        _, T = x.shape
        max_len = self._posn_embed.num_embeddings
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds context size {max_len}")

        # Follow the model's own parameters instead of a global device name,
        # so the module keeps working wherever it has been moved.
        param_device = self._token_embed.weight.device
        x = x.long().to(param_device)
        token_embed = self._token_embed(x) # B T C
        posn_embed = self._posn_embed(torch.arange(T, device=param_device))  # T C
        x = token_embed + posn_embed  # B T C
        # cf eqn 4 of ViT paper
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self._lm_head(x)  # B T VOCAB_SIZE

        loss = None
        if y is not None:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            y = y.to(param_device).reshape(B*T)
            loss = F.cross_entropy(logits, y)

        return logits, loss

    @torch.no_grad()  # sampling never needs gradients
    def generate(self, x, max_new_tokens):
        """Autoregressively append `max_new_tokens` sampled tokens to `x`.

        Args:
            x: (B, T) tensor of token ids used as the prompt.
            max_new_tokens: How many tokens ahead to predict.

        Returns:
            (B, T + max_new_tokens) tensor: the prompt followed by the samples.
        """
        # The usable context is fixed by the position table built in
        # __init__, not by whatever the global happens to be right now.
        context_size = self._posn_embed.num_embeddings
        was_training = self.training
        self.eval()  # switch dropout off while sampling
        try:
            for _ in range(max_new_tokens):
                x_cond = x[:, -context_size:]
                logits, _ = self(x_cond)
                last_logits = logits[:, -1, :]  # this is where we only use the last token
                last_probs = F.softmax(last_logits, dim=-1)
                next_token = torch.multinomial(last_probs, num_samples=1)  # (B, 1) ids
                x = torch.cat((x, next_token.to(x.device)), dim=1)
        finally:
            self.train(was_training)  # leave the module in the mode we found it
        return x

# Build the language model on the selected device.
model = SimpleDecoder().to(device)

# Sample 100 tokens from the untrained model, starting from a single token with id 0.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated = model.generate(x=context, max_new_tokens=100)
print(decode(generated[0].tolist()))

# Layer-by-layer overview for a length-8 input.
summary(model, (8,))

In [None]:
# Set to False to skip training and only reload the weights saved in foo.torch.
train = True

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=LEARNING_RATE_DECAY)
start_t = time()

if train:
    model.train()
    for steps in range(MAX_ITERS): # increase number of steps for good results...
        # sample a batch of data
        xb, yb = get_batch(train_data)

        # evaluate the loss
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if steps % 100 == 0:
            print(steps, loss.item())

        # Decay the learning rate once per 100 optimizer steps. The scheduler
        # must be stepped after optimizer.step(); stepping it first skips the
        # initial learning rate and triggers a PyTorch warning.
        if (steps + 1) % 100 == 0:
            scheduler.step()

    torch.save(model.state_dict(), "foo.torch")

# Reload from disk (also when train is False); map_location keeps the
# checkpoint loadable on a machine without the device it was saved from.
state_dict = torch.load("foo.torch", map_location=device)
model.load_state_dict(state_dict)
print(f"Took {round(time() - start_t)}s")

# inference
model.eval()
context = torch.zeros((1, 8), dtype=torch.long, device=device)
#context = torch.tensor([encode("Merry Christmas")], dtype=torch.long).to(device)
with torch.no_grad():  # sampling needs no autograd graph
    print(decode(model.generate(context, max_new_tokens=500)[0].tolist()))


# nanoViT

## Get MNIST data

In [None]:
# Convert PIL images to tensors, then standardize them.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # mean and std of mnist data
])

# Both splits are downloaded into ./data when missing.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# drop_last=True discards a final short batch, so every training batch holds exactly BATCH_SIZE images.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

# visualize
# Pull one shuffled training batch to look at.
data_iter = iter(train_loader)
images, labels = next(data_iter)

def imshow(img):
    """Show a normalized image tensor in grayscale with the axes hidden.

    The normalization (std 0.3081, mean 0.1307) is undone before the tensor
    is converted to a NumPy array and drawn on the current pyplot axes.
    """
    pixels = img.mul(0.3081).add(0.1307).numpy()  # Unnormalize
    plt.imshow(pixels, cmap='gray')
    plt.axis('off')

# Show the first five images of the batch side by side, titled with their labels.
fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(10, 2))
for ax, sample_image, sample_label in zip(axes, images, labels):
    ax.imshow(sample_image[0])
    ax.set_title(f'Label: {sample_label.item()}')
plt.show()

## Architecture and training

In [None]:
# arch
PATCH_SIZE = 7  # mnist imgs are 28x28
assert 28 % PATCH_SIZE == 0
NUM_PATCHES = (28 // PATCH_SIZE)**2  # cf CONTEXT_SIZE
CHANNELS = 1
NUM_CLASSES = 10


class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping patches and embed each patch.

    A convolution whose kernel size and stride both equal PATCH_SIZE turns
    every patch into an N_EMBD-channel vector; the resulting spatial grid is
    flattened into a sequence and the channel axis is moved last.
    """
    def __init__(self):
        super().__init__()
        self._conv2d = nn.Conv2d(
            CHANNELS,
            N_EMBD,
            kernel_size=PATCH_SIZE,
            stride=PATCH_SIZE,
            padding=0,
        )
        self._flatten = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self, x):
        patches = self._flatten(self._conv2d(x))  # B C H W -> B C T
        return patches.transpose(1, 2)  # B C T -> B T C

In [None]:
# inspect shapes
print(images.shape)

pl = PatchEmbedding()
out = pl(images)
print(out.shape)

# cf eqn 1 of ViT paper
class_token = nn.Parameter(torch.ones(BATCH_SIZE, 1, N_EMBD), requires_grad=True)
cls_out = torch.cat([class_token, out], dim=1)
print(cls_out.shape)

posn_embed = nn.Parameter(torch.ones(1, NUM_PATCHES + 1, N_EMBD), requires_grad=True)
cls_posn_out = cls_out + posn_embed
print(cls_posn_out.shape)

msa = MultiHeadAttention(N_EMBD // NUM_HEADS, is_decoder=False)
msa_out = msa(cls_posn_out)
print(msa_out.shape)

mlp = MLP()
mlp_out = mlp(msa_out)
print(mlp_out.shape)


In [None]:
class SimpleViT(nn.Module):  # SimpleDecoder
    """Minimal vision transformer classifier.

    The image is embedded patch by patch, a learnable class token is
    prepended, position embeddings are added, and the sequence runs through
    NUM_BLOCKS unmasked transformer blocks. The class-token output is
    layer-normalized and projected to NUM_CLASSES logits.
    """

    def __init__(self):
        super().__init__()
        self._token_embed = PatchEmbedding()
        self._class_embed = nn.Parameter(torch.randn(1, 1, N_EMBD), requires_grad=True)
        self._posn_embed = nn.Embedding(NUM_PATCHES + 1, N_EMBD)

        self.blocks = nn.Sequential(*[TransformerBlock(is_decoder=False) for _ in range(NUM_BLOCKS)])
        self.ln_f = nn.LayerNorm(N_EMBD) # final layer norm
        self._lm_head = nn.Linear(N_EMBD, NUM_CLASSES)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02^2) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x, y=None):
        """Classify a batch of images.

        Args:
            x: Batch of images accepted by the patch embedding; any batch
                size is supported.
            y: Optional (B,) tensor of class indices.

        Returns:
            (logits, loss): logits are (B, NUM_CLASSES); loss is the
            cross-entropy against `y`, or None when `y` is not given.
        """
        token_embed = self._token_embed(x) # B T C
        # Size everything from the actual input rather than from globals, so
        # partial batches and differently sized evaluation batches work too.
        batch_size = token_embed.shape[0]
        class_token = self._class_embed.expand(batch_size, -1, -1)
        token_embed = torch.cat([class_token, token_embed], dim=1)  # B T+1 C
        positions = torch.arange(token_embed.shape[1], device=token_embed.device)
        posn_embed = self._posn_embed(positions)  # T+1 C
        x = token_embed + posn_embed  # B T+1 C
        x = self.blocks(x)
        # cf eqn 4 in ViT paper
        x = x[:, 0]  # B T C -> B C (keep only the class token)
        x = self.ln_f(x)
        logits = self._lm_head(x)

        loss = None
        if y is not None:
            loss = F.cross_entropy(logits, y)

        return logits, loss

# Keep the parameters on the same device the training batches are moved to.
model = SimpleViT().to(device)

In [None]:
# Set to False to skip training and only reload the weights saved in bar.torch.
train = True
LEARNING_RATE = 3e-3
NUM_EPOCHS = 1

# Parameters must live on the target device before the optimizer captures them.
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=LEARNING_RATE_DECAY)
start_t = time()

if train:
    model.train()
    for epoch in range(NUM_EPOCHS):
        for steps, (X, y) in enumerate(train_loader):
            X, y = X.to(device), y.to(device)
            # Safety net for a model that assumes a fixed batch size.
            if X.size(0) != BATCH_SIZE:
                print(f"Skipping batch {steps} with size {X.size(0)}")
                continue

            logits, loss = model(X, y)

            if steps % 100 == 0:
                print(steps, loss.item())

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            # Decay the learning rate once per 100 optimizer steps. The
            # scheduler must be stepped after optimizer.step(); stepping it
            # first skips the initial learning rate and triggers a warning.
            if (steps + 1) % 100 == 0:
                scheduler.step()

    torch.save(model.state_dict(), "bar.torch")

# Reload from disk (also when train is False); map_location keeps the
# checkpoint loadable on a machine without the device it was saved from.
state_dict = torch.load("bar.torch", map_location=device)
model.load_state_dict(state_dict)
print(f"Took {round(time() - start_t)}s")

# inference
model.eval()
with torch.no_grad():  # prediction needs no autograd graph
    pred, _ = model(images.to(device))
probs = torch.softmax(pred, dim=1)
label = torch.argmax(probs, dim=1)
fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(10, 2))
for i, ax in enumerate(axes):
    ax.imshow(images[i][0])
    ax.set_title(f'Prediction: {label[i].item()}')
plt.show()