In [1]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2023-12-17 17:34:27--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2023-12-17 17:34:27 (20.1 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [2]:
import torch
import torch.nn as nn

class MyMultiHeadAttention(nn.Module):
    def __init__(self, my_d_model, my_n_heads):
        super(MyMultiHeadAttention, self).__init__()
        self.my_n_heads = my_n_heads
        self.my_d_head = my_d_model // my_n_heads

        self.W_q = nn.Linear(my_d_model, my_d_model)
        self.W_k = nn.Linear(my_d_model, my_d_model)
        self.W_v = nn.Linear(my_d_model, my_d_model)
        self.W_o = nn.Linear(my_d_model, my_d_model)

    def forward(self, x):
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        Q = Q.view(Q.size(0), -1, self.my_n_heads, self.my_d_head).transpose(1, 2)
        K = K.view(K.size(0), -1, self.my_n_heads, self.my_d_head).transpose(1, 2)
        V = V.view(V.size(0), -1, self.my_n_heads, self.my_d_head).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.my_d_head ** 0.5)
        attn_weights = nn.functional.softmax(scores, dim=-1)

        out = torch.matmul(attn_weights, V).transpose(1, 2).contiguous()
        out = out.view(out.size(0), -1, self.my_n_heads * self.my_d_head)

        return self.W_o(out)

In [3]:

class MyFeedForward(nn.Module):
    def __init__(self, my_d_model, my_d_ff):
        super(MyFeedForward, self).__init__()
        self.linear1 = nn.Linear(my_d_model, my_d_ff)
        self.dropout = nn.Dropout(0.1)
        self.linear2 = nn.Linear(my_d_ff, my_d_model)

    def forward(self, x):
        x = nn.functional.relu(self.linear1(x))
        x = self.dropout(x)
        x = self.linear2(x)
        return x

In [4]:
class MyTransformerBlock(nn.Module):
    def __init__(self, my_d_model, my_n_heads, my_d_ff):
        super(MyTransformerBlock, self).__init__()
        self.self_attention = MyMultiHeadAttention(my_d_model, my_n_heads)
        self.feed_forward = MyFeedForward(my_d_model, my_d_ff)
        self.layer_norm1 = nn.LayerNorm(my_d_model)
        self.layer_norm2 = nn.LayerNorm(my_d_model)

    def forward(self, x):
        attn_output = self.self_attention(x)
        x = x + attn_output
        x = self.layer_norm1(x)

        ff_output = self.feed_forward(x)
        x = x + ff_output
        x = self.layer_norm2(x)

        return x

In [5]:
class MyGPT2Small(nn.Module):
    def __init__(self, my_vocab_size, my_d_model=768, my_n_heads=12, my_n_layers=12):
        super(MyGPT2Small, self).__init__()
        self.my_d_model = my_d_model
        self.my_n_heads = my_n_heads
        self.my_n_layers = my_n_layers
        self.my_vocab_size = my_vocab_size

        self.token_embedding_table = nn.Embedding(my_vocab_size, my_n_embd)
        self.position_embedding_table = nn.Embedding(my_block_size, my_n_embd)
        self.transformer_blocks = nn.Sequential(*[MyTransformerBlock(my_d_model, my_n_heads, 512) for _ in range(my_n_layers)])
        self.linear = nn.Linear(my_d_model, my_vocab_size)

    def forward(self, my_idx, targets=None):
        B, T = my_idx.shape
        tok_emb = self.token_embedding_table(my_idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=my_device))
        x = tok_emb + pos_emb
        x = self.transformer_blocks(x)
        logits = self.linear(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = nn.functional.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, my_idx, max_new_tokens):
        for _ in range(max_new_tokens):
            my_idx_cond = my_idx[:, -my_block_size:]
            logits, loss = self(my_idx_cond)
            logits = logits[:, -1, :]
            probs = nn.functional.softmax(logits, dim=-1)
            my_idx_next = torch.multinomial(probs, num_samples=1)
            my_idx = torch.cat((my_idx, my_idx_next), dim=1)
        return my_idx

In [6]:
my_batch_size = 16
my_block_size = 32
my_max_iters = 20
my_eval_interval = 100
my_learning_rate = 1e-3
my_device = 'cuda' if torch.cuda.is_available() else 'cpu'
my_eval_iters = 200
my_n_embd = 464
my_n_head = 16
my_n_layer = 48

torch.manual_seed(1337)

with open('input.txt', 'r', encoding='utf-8') as my_f:
    my_text = my_f.read()

my_chars = sorted(list(set(my_text)))
my_vocab_size = len(my_chars)
my_stoi = {ch: i for i, ch in enumerate(my_chars)}
my_itos = {i: ch for i, ch in enumerate(my_chars)}
my_encode = lambda s: [my_stoi[c] for c in s]
my_decode = lambda l: ''.join([my_itos[i] for i in l])

my_data = torch.tensor(my_encode(my_text), dtype=torch.long)
my_n = int(0.9 * len(my_data))
my_train_data = my_data[:my_n]
my_val_data = my_data[my_n:]

def get_my_batch(split):
    my_data = my_train_data if split == 'train' else my_val_data
    ix = torch.randint(len(my_data) - my_block_size, (my_batch_size,))
    x = torch.stack([my_data[i:i + my_block_size] for i in ix])
    y = torch.stack([my_data[i + 1:i + my_block_size + 1] for i in ix])
    x, y = x.to(my_device), y.to(my_device)
    return x, y

@torch.no_grad()
def estimate_my_loss():
    out = {}
    losses = {}
    for split in ['train', 'val']:
        for k in range(my_eval_iters):
            x, y = get_my_batch(split)

            logits, loss = my_model(x, y)
            losses[k] = loss.item()
        out[split] = sum(losses.values()) / len(losses)
    return out

In [7]:
import os
my_save_path = 'my_gpt2_task1.pth'

if os.path.isfile(my_save_path):
    my_model = MyGPT2Small(my_vocab_size, my_n_embd, my_n_head, my_n_layer)
    my_model.load_state_dict(torch.load(my_save_path))
    my_model = my_model.to(my_device)
    my_model.eval()
    print(f"My Model loaded from {my_save_path}")
    print(sum(p.numel() for p in my_model.parameters()) / 1e6, 'M parameters')
else:
    my_model = MyGPT2Small(my_vocab_size, my_n_embd, my_n_head, my_n_layer)
    my_model = my_model.to(my_device)
    print(sum(p.numel() for p in my_model.parameters()) / 1e6, 'M parameters')

    my_optimizer = torch.optim.AdamW(my_model.parameters(), lr=my_learning_rate)

    for my_iter in range(my_max_iters):
        if my_iter % my_eval_interval == 0 or my_iter == my_max_iters - 1:
            my_losses = estimate_my_loss()
            print(f"step {my_iter}: train loss {my_losses['train']:.4f}, val loss {my_losses['val']:.4f}")

        my_xb, my_yb = get_my_batch('train')

        my_logits, my_loss = my_model(my_xb, my_yb)
        my_optimizer.zero_grad(set_to_none=True)
        my_loss.backward()
        my_optimizer.step()

    torch.save(my_model.state_dict(), my_save_path)
    print(f"My Model saved to {my_save_path}")

my_model.eval()

my_context = torch.zeros((1, 1), dtype=torch.long, device=my_device)
my_generated_text = my_decode(my_model.generate(my_context, max_new_tokens=2000)[0].tolist())
print(my_generated_text)

64.443617 M parameters
step 0: train loss 4.3992, val loss 4.4008
step 19: train loss 3.3856, val loss 3.4188
My Model saved to my_gpt2_task1.pth

 NeL e,i ih lCn  Cu:ip'utose,n!asri iuWlnmdaaoimer
nil tho
neEg'OrL T    noSedudndaa. a  oeo n.sLc
diu!t' s ow;, Cn ouAe g  pn  nKClh Msic ainetMttuTaf i'gr;nnlt ;  hyE l  .lss e hnlSs yrE'pNosvlhm
ea on,entn t  ymgn ptI d c,'unThO ,v.
t  
 bllyfW! hhOioe oeed l ne

dndet a elnt
HseNi E'nn nJtlndio c! e we te    l  rmekNu eaWe   eHnIF't     e
f  ,   
trnt r
  wos, ibyt Sf  poycT Lonrni h hab  kd  
n dbaoe andTf,oy

!i k ooh'uea  rdBt
t h :a ,n : etikrin;n   hM:
te eogE.s n wIeina Pler eollTKEu uT oTrwtot ptlr 
m  c, u eh'l FnOne nnugonn l
:rn
o rr aipmr tmrI 'aiyrN  r T  eico aaDu'Vsr'  n ug gnt a o Eehiearo 
n tnesee,,han s, i  gantnWceuobnWy?, s, ,;,fatsuehnpnhy ihr Lguona  n ,MTu h eeodLysahfUadhbs Nn siatn    EB, dhrUu:rS  

E
 nnlc lht !
uaRsenoSE s:WLlTl rr'tngip am noslnTaeve E ic
.h  !ESenWl,S oep   be t  trh ab  Ti'fetE nTu 
gnnnntv