In [1]:
import torch

In [4]:
torch.manual_seed(1337)
B,T,C = 4,8,2

x = torch.randn(B,T,C)

In [5]:
x

tensor([[[ 0.1808, -0.0700],
         [-0.3596, -0.9152],
         [ 0.6258,  0.0255],
         [ 0.9545,  0.0643],
         [ 0.3612,  1.1679],
         [-1.3499, -0.5102],
         [ 0.2360, -0.2398],
         [-0.9211,  1.5433]],

        [[ 1.3488, -0.1396],
         [ 0.2858,  0.9651],
         [-2.0371,  0.4931],
         [ 1.4870,  0.5910],
         [ 0.1260, -1.5627],
         [-1.1601, -0.3348],
         [ 0.4478, -0.8016],
         [ 1.5236,  2.5086]],

        [[-0.6631, -0.2513],
         [ 1.0101,  0.1215],
         [ 0.1584,  1.1340],
         [-1.1539, -0.2984],
         [-0.5075, -0.9239],
         [ 0.5467, -1.4948],
         [-1.2057,  0.5718],
         [-0.5974, -0.6937]],

        [[ 1.6455, -0.8030],
         [ 1.3514, -0.2759],
         [-1.5108,  2.1048],
         [ 2.7630, -1.7465],
         [ 1.4516, -1.5103],
         [ 0.8212, -0.2115],
         [ 0.7789,  1.5333],
         [ 1.6097, -0.4032]]])

In [1]:
import torch
from torch import nn
import torch.nn.functional as F

learning_rate = 1e-3
epochs = 100000
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 4
context_len = 10
eval_interval = 1000
n_embed = 32

class LinearForward(nn.Module):
    def __init__(self,n_embed):
        super().__init__()
        self.linear = nn.Linear(n_embed,4*n_embed)
        self.relu = nn.ReLU()
        self.proj = nn.Linear(4*n_embed,n_embed)
    def forward(self,X):
        out = self.proj(self.relu(self.linear(X)))
        return out

class Head(nn.Module):
    def __init__(self,head_size):
        super().__init__()
        self.key = nn.Linear(n_embed,head_size, bias=False)
        self.query = nn.Linear(n_embed,head_size, bias=False)
        self.value = nn.Linear(n_embed,head_size, bias=False)
        self.register_buffer("trill",torch.tril(torch.ones(context_len,context_len)))

    def forward(self,x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2,-1) * (C**0.5)
        wei= wei.masked_fill(self.trill[:T, :T] == 0, float('-inf'))
        wei= F.softmax(wei, dim=-1)
        v = self.value(x)
        out = wei @ v
        return out

class MultiHead(nn.Module):
    def __init__(self,n_head,head_size):
        super().__init__()
        self.mha = nn.ModuleList([Head(head_size) for _ in range(n_head)])
        self.proj = nn.Linear(n_embed,n_embed)
    def forward(self,X):
        out = torch.cat([h(X) for h in self.mha],dim=-1)
        out = self.proj(out)
        return out

class TransformerBlock(nn.Module):
    def __init__(self,n_embed):
        super().__init__()
        self.mha = MultiHead(4, n_embed//4)
        self.ff = LinearForward(n_embed)
        self.prenorm = nn.LayerNorm(n_embed)
        self.postnorm = nn.LayerNorm(n_embed)

    def forward(self,X):
        x = self.mha(self.prenorm(X))
        out = self.ff(self.postnorm(x))
        return out


class BigramLanguageModel(nn.Module):
    def __init__(self,vocab_size):
        super().__init__()
        self.lookup_table = nn.Embedding(vocab_size,n_embed)
        self.linear_head = nn.Linear(n_embed,vocab_size)
        self.positional_embeddings = nn.Embedding(context_len,n_embed)
        self.transformer = TransformerBlock(n_embed)

    def forward(self,idx,targets=None):
        B,T = idx.shape
        token_embed = self.lookup_table(idx)
        pos_embed = self.positional_embeddings(torch.arange(T,device=device))
        x = token_embed + pos_embed
        x = self.transformer(x)
        logits = self.linear_head(x)

        B, T, C = logits.shape
        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(logits.view(B*T,C),targets.view(B*T))
        return logits,loss

    def generate(self,idx,max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-context_len:]
            logits,loss = self(idx_cond)
            logits = logits[:,-1,:] # the whole batch, the next token and the whole vocab size
            probs = F.softmax(logits,dim=-1) # this gives us the probabilities
            idx_next = torch.multinomial(probs,num_samples=1)
            idx = torch.cat([idx, idx_next],dim=1)
        return idx

In [6]:
model = BigramLanguageModel(vocab_size=65)
model.load_state_dict(torch.load("../models/gpt.pth"))
with open("../data/input.txt","r",encoding='utf-8') as f:
    text = f.read()
chars = sorted(list(set("".join(text))))
vocab_size = len(chars)
context = torch.zeros((1,1),dtype=torch.long,device=device)
stoi = {s:i for i,s in enumerate(chars)}
itos = {i:s for s,i in stoi.items()}
encode = lambda s:[stoi[c] for c in s]
decode = lambda s:[itos[c] for c in s]
print("".join(decode(model.generate(context,1000)[0].tolist())))


Sway day lonecan
I sarse with I fits our laid that a peayclave theepong; lik me,
In I hat his an, well ton he to theit
LANTIO:
Clikerere sighter layoustime pred
Sir I stry
Thal was chat of you sas pall crulen with to lanion for my them
The are uss to but devenak!
That letion:
I becturespier me them and up suree ane of hink fage! fast wheso wer Cut think mete as mouty ry thal: and out,
Of tharome?
Stion the begenukakes wing oby ling deet thoul us thou datin'd go math:
For olds.

JORKET:

Good,
Andien.

BOLIO:
You buterationg ther
EDend, mary.

Unlin theremost he bear ther, if you anfing the hat younumpark.

The shis he falove
Rom PARLANA:
O pre dot my your be.

MOLALAM:
Is twarwas berrow anten perces
a puray, wing this brinswy, so me hen mans
that to jed thaver
Of the wity
Ands sereve inturgeng cet
LAET:
But to Ped sur gard thou slavent why youstem, tiolf is I of you. welf, freelaterd?
Bep loollonjus one peak,
I feare faralm you to is good, I them
archet!

Rome ort: me on a sa's Rolds,