In [2]:
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)
print(f"Default device set to {device}")

Default device set to cuda


In [3]:
with open("shakespeare.txt", 'r', encoding="UTF-8") as f:
    text = f.read()
    vocab = sorted(set(text))
    vocab_size = len(vocab)

print(vocab)
print(vocab_size)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '/', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
66


In [4]:
stoi = {char:integer for integer, char in enumerate(vocab)}
itos = {integer:char for integer, char in enumerate(vocab)}

encode = lambda enc: [stoi[c] for c in enc]
decode = lambda dec: "".join([itos[i] for i in dec])

In [40]:
#hyperparams
vocab_size = len(stoi)
n_embd = 32
block_size = 128
head_size = 256


data = torch.tensor(encode(text))


In [41]:
def get_batch(data, context_size, batch_size):
    batch_idx = torch.randint(0, len(data)-context_size, (batch_size,)) #need to be idx before len(data)-context_Size cuz if there ins't enough number after the random idx, it will throw an error
    x = torch.stack([data[idx:idx+context_size] for idx in batch_idx], dim=0)
    y = torch.stack([data[idx+1:idx+1+context_size] for idx in batch_idx], dim=0)

    return (x, y)


In [42]:
x, y = get_batch(data, 8, 1)
print(x)
print(y)

tensor([[44,  1, 45, 54, 57,  1, 54, 53]], device='cuda:0')
tensor([[ 1, 45, 54, 57,  1, 54, 53, 44]], device='cuda:0')


In [60]:
class Head(nn.Module):
    def __init__(self, head_size):
        super(Head, self).__init__()

        self.query = nn.Linear(in_features=n_embd, out_features=head_size)
        self.key = nn.Linear(in_features=n_embd, out_features=head_size)
        self.value = nn.Linear(in_features=n_embd, out_features=head_size)

    def forward(self, x):
        q = self.query(x)
        k = self.key(x)

        #print(f"q.shape --> {q.shape}")
        #print(f"k.shape --> {k.shape}")

        score_matrix = q @ (k.transpose(-2, -1))
        #print(f"score_matrix.shape --> {score_matrix.shape}") # --> B, T, T #how much each word related to one another
        
        scaled_matrix = score_matrix * k.shape[-1] ** -0.5 # == score_matrix / k.shape[-1]**0.5
        #print(f"scaled_matrix.shape --> {scaled_matrix.shape}") # --> B, T, T

        attention_weights = torch.softmax(scaled_matrix, dim=-1)
        #print(f"attention_weights.shape --> {attention_weights.shape}")

        v = self.value(x)
        #print(f"v.shape -->")
        output = attention_weights @ v # B, T, T @ B, T, hs --> B, T, hs
        #print(f"output.shape --> {output.shape}")

        return output


In [61]:
class MultiHeadAttention(nn.Module):
    def __init__(self, head_size, n_heads):
        super(MultiHeadAttention, self).__init__()
        self.head_size = head_size
        self.n_heads = n_heads
        self.heads = [Head(head_size=head_size) for _ in range(n_heads)]
    
    def forward(self, x):
        head_outputs = torch.cat([h(x) for h in self.heads])
        #print(head_outputs.shape) # --> T, T, C
        return head_outputs


In [67]:
class GPT(nn.Module):
    def __init__(self):
        super(GPT, self).__init__()

        self.loss_fn = nn.CrossEntropyLoss()

        self.token_emb = nn.Embedding(vocab_size, n_embd)
        self.positional_token_emb = nn.Embedding(block_size, n_embd)

        self.head = Head(head_size=head_size)
        self.multihead = MultiHeadAttention(head_size=head_size, n_heads=4)

        self.lm_head = nn.Linear(in_features=head_size, out_features=vocab_size)

    def forward(self, x, targets=None):
        B, T = x.shape
        embedded_tokens = self.token_emb(x)
        pos_embedded = self.positional_token_emb(torch.arange(start=0, end=T, step=1))

        result_embedded = embedded_tokens + pos_embedded
        #print(result_embedded.shape)

        #Encoder
        logits = self.head(result_embedded)
        #print(logits.shape, "aaa")
        logits = self.lm_head(logits)

        if targets is not None:
            logits = logits.view(B*T, vocab_size)
            labels = tagets.view(B*T)
            loss = loss_fn(logits, labels)
            
            return (logits, loss)
        else:
            return logits










In [68]:
model = GPT()

optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)

In [69]:

x, y = get_batch(data, batch_size=32, context_size=block_size)
print(y.max())


tensor(65, device='cuda:0')


In [70]:
import torch.nn.functional as F
logits = model(x, y)
#print(logits.shape)
#print(y.view(-1).shape)

torch.Size([32, 128, 66])


In [66]:
epochs = 100

for epoch in range(epochs):
    x, y = get_batch(data, batch_size=32, context_size=block_size)
    logits = model(x, y)
    B, T, C = logits.shape
    logits = logits.view(B*T, C)
    labels = y.view(-1)
    #print(logits.shape)
    #print(labels.shape)
    #loss = F.cross_entropy(logits, labels)
    #optimizer.zero_grad(set_to_none=True)
    #loss.backward()
    #optimizer.step()
#print(loss)

torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size([4096])
torch.Size([4096, 66])
torch.Size(