In [220]:
import torch
from torch import nn
import torch.nn.functional as F

In [323]:
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)
print(f"default device set to --> {device}")

default device set to --> cuda


In [324]:
with open("shakespeare.txt", 'r', encoding="UTF-8") as f:
    text = f.read()
    vocab = sorted(set(text))
    vocab_size = len(vocab)

print(vocab)
print(vocab_size)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '/', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
66


In [325]:
stoi = {char:integer for integer, char in enumerate(vocab)}
itos = {integer:char for integer, char in enumerate(vocab)}

encode = lambda enc: [stoi[c] for c in enc]
decode = lambda dec: "".join([itos[i] for i in dec])

In [326]:
def get_batch(data, batch_size, context_size):
    batch_idx = torch.randint(0, len(data)-context_size, (batch_size,))

    x = torch.stack([data[idx:idx+context_size] for idx in batch_idx])
    y = torch.stack([data[idx+1:idx+1+context_size] for idx in batch_idx])

    return (x, y)

data = torch.tensor(encode(text))
x, y = get_batch(data, 1, 3)
print(x.shape)
print(y.shape)
print(x, y)

torch.Size([1, 3])
torch.Size([1, 3])
tensor([[22,  5, 51]], device='cuda:0') tensor([[ 5, 51, 51]], device='cuda:0')


In [327]:
class Head(nn.Module):
    def __init__(self, head_size, n_embd, context_size):
        super(Head, self).__init__()
        
        self.key = nn.Linear(in_features=n_embd, out_features=head_size, bias=False)
        self.query = nn.Linear(in_features=n_embd, out_features=head_size, bias=False)
        self.value = nn.Linear(in_features=n_embd, out_features=head_size, bias=False)

        self.register_buffer('tril', torch.tril(torch.ones(context_size, context_size)))
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B, T, C = x.shape

        Q = self.query(x)
        K = self.key(x)

        k_transpose = K.transpose(-2, -1)

        score_matrix = Q @ k_transpose
        scaled_matrix = score_matrix * (1/(K.shape[-1]**0.5)) #K.shape[-1] == head_size
        softmaxed_matrix = torch.softmax(scaled_matrix, dim=-1) #--> also called the attention weights
        #wei = Q @ k_transpose * K.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        #wei = F.softmax(scaled_matrix, dim=-1) # (B, T, T)
        # perform the weighted aggregation of the values
        V = self.value(x) # (B,T,hs)
        output = softmaxed_matrix @ V #--> need to be in this order because of the shapes
        #out = softmaxed_matrix @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        #print(output.shape)
        return output



In [328]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, head_size, n_embd):
        super(MultiHeadAttention, self).__init__()
        heads = [Head(head_size, n_embd) for h in range(n_heads)]
        

In [845]:
class GPT(nn.Module):
    def __init__(self, vocab_size, n_embd, context_size):
        super(GPT, self).__init__()

        self.input_token_emb = nn.Embedding(vocab_size, n_embd)
        self.positional_enc_emb = nn.Embedding(context_size, n_embd) #context size because that is the number of words in each input of x
        self.lm_head = nn.Linear(in_features=n_embd, out_features=vocab_size)

        self.head = Head(head_size=512, n_embd=n_embd, context_size=context_size)

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x, targets=None, batched=True):

        if batched:
            B, T = x.shape #--> both targets and inputs are of B X T shape
            n_words = x.shape[1]
        else:
            n_words = x.shape[0]

        tokens_emb = self.input_token_emb(x) #--> turns into B X T X C
        #1 positional encoding for each characters in this (because I am using a character level tokenizer)

        #very simple positional encoding by just counting
        positional_enc_idx = torch.arange(start=0, end=n_words, step=1) #--> 1 X n_words (1 X T)
        #positional_enc = torch.arange(n_words) --> same as above
        
        positional_enc = self.positional_enc_emb(positional_enc_idx) #--> T X C
        
        result_emb = tokens_emb + positional_enc

        #Encoder layer
        head_out = self.head(result_emb)

        logits = self.lm_head(head_out)

        if targets is not None:
            #print(f"logits shape --> {logits.shape}")
            #print(f"labels shape --> {targets.shape}")
            B, T, C = logits.shape
            logits = logits.view(B*T, C) #cross entropy loss expects targets as the second dim
            #targets = targets.view(B*T)

            #print(logits.shape)

            #loss = F.cross_entropy(logits, targets)

            return (logits)
        else:
            return logits

    def generate(self, starting_idx, max_new_tokens):
        for _ in range(max_new_tokens):
            #crop starting_idx to fit in the block size (context size)
            starting_idx = starting_idx[:, -context_size:] #predictions are only going to be made based on the last context block
            logits = self(starting_idx)
            logits = logits[:, -1, :]

            probs = torch.softmax(logits, dim=1)
            sample = torch.multinomial(probs.view(-1, vocab_size), num_samples=1)
            #print(starting_idx.shape)
            #print(sample.shape)
            starting_idx = torch.cat((starting_idx, sample), dim=1)
        return starting_idx












In [846]:
n_embd = 512
context_size = 256
batch_size = 32

x, y = get_batch(data=data, batch_size=batch_size, context_size=context_size)

model = GPT(vocab_size=vocab_size, n_embd=n_embd, context_size=context_size)

#model(x)

"""
#Test for self attention
## only a single sentence
x = x[0]
y = y[0]

print(x.shape)
print(y.shape)
"""

#model(x=x, targets=y, batched=True)

'\n#Test for self attention\n## only a single sentence\nx = x[0]\ny = y[0]\n\nprint(x.shape)\nprint(y.shape)\n'

In [1429]:
from tqdm import tqdm
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)

In [1549]:
epochs = 1

batch_size = 32
context_size = 256

model.train()
for epoch in tqdm(range(epochs)):
    x, y = get_batch(data=data, batch_size=batch_size, context_size=context_size)
    print(f"X --> {x}")
    print(f"Y --> {y}")
    logits = model(x=x, targets=y, batched=True)
    #print(logits)
    loss = F.cross_entropy(logits, y.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
print(loss)


100%|██████████| 1/1 [00:00<00:00, 30.62it/s]

X --> tensor([[ 1, 42, 40,  ..., 44,  1, 47],
        [21, 14, 31,  ..., 51, 44, 44],
        [60, 51,  1,  ..., 64,  1, 43],
        ...,
        [40, 53, 43,  ..., 57,  1, 54],
        [43,  1, 50,  ..., 40, 59,  1],
        [47, 44, 48,  ..., 44, 53,  1]], device='cuda:0')
Y --> tensor([[42, 40, 52,  ...,  1, 47, 44],
        [14, 31, 17,  ..., 44, 44, 43],
        [51,  1, 45,  ...,  1, 43, 54],
        ...,
        [53, 43,  1,  ...,  1, 54, 45],
        [ 1, 50, 53,  ..., 59,  1, 48],
        [44, 48, 57,  ..., 53,  1, 22]], device='cuda:0')
tensor(0.0116, device='cuda:0', grad_fn=<NllLossBackward0>)





In [1548]:
model.eval()
with torch.inference_mode():
    preds = model.generate(starting_idx=torch.zeros((1,1),dtype=torch.long), max_new_tokens=128)


print(preds[0].tolist())
print(decode(preds[0].tolist()))

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


































































































































