In [None]:
import torch
import math
from torch.nn import functional as F

# Read the whole training corpus into memory. The context manager guarantees
# the file handle is closed once the text has been read.
with open("/content/input.txt", "r") as f:
  raw_text = f.read()

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else 'cpu'

In [None]:
# Character-level vocabulary: every distinct character of the corpus, sorted.
chars = sorted(set(raw_text))

# Lookup tables between characters and their integer ids.
c_to_i = {ch: ix for ix, ch in enumerate(chars)}
i_to_c = dict(enumerate(chars))


def encode(s):
  """Map a string to the list of integer ids of its characters."""
  return [c_to_i[c] for c in s]


def decode(l):
  """Map a list of integer ids back to the string they spell."""
  return ''.join(i_to_c[i] for i in l)


# Encode the full corpus once and keep it on the compute device.
data = torch.tensor(encode(raw_text), device=device)

# First 90% of the tokens are used for training, the rest for validation.
n = int(0.9 * len(data))
train, val = data[:n], data[n:]

In [None]:
# Batching: number of sequences per batch, and tokens per sequence.
batch_size, block_size = 64, 32

# Model: number of distinct characters, embedding width, attention heads per
# block, and the probability handed to every Dropout layer.
vocab_size = len(chars)
n_embeddings, n_heads = 384, 6
dropout = 0.2

def get_batch(split="train"):
  """Sample a random batch of (input, target) sequences.

  Args:
    split: "train" draws from the training tokens; any other value draws
      from the validation tokens.

  Returns:
    A pair (x, y) of tensors of shape (batch_size, block_size), where y is x
    shifted one position to the right in the underlying token stream.
  """
  source = train if split == "train" else val
  # Upper bound is exclusive, so start + block_size + 1 never runs past the end.
  starts = torch.randint(len(source) - block_size, (batch_size,))
  inputs, targets = [], []
  for start in starts:
    inputs.append(source[start:start + block_size])
    targets.append(source[start + 1:start + block_size + 1])
  return torch.stack(inputs), torch.stack(targets)


In [None]:
class Attention_Head(torch.nn.Module):
  """Single head of causal (masked) scaled dot-product self-attention.

  Args:
    n_embeddings: size of the last dimension of the input.
    head_size: width of the query/key/value projections and of the output.

  The dropout probability is read from the module-level ``dropout`` variable.
  """

  def __init__(self, n_embeddings, head_size):
    super().__init__()
    self.Q_weights = torch.nn.Linear(n_embeddings, head_size, bias=False)
    self.K_weights = torch.nn.Linear(n_embeddings, head_size, bias=False)
    self.V_weights = torch.nn.Linear(n_embeddings, head_size, bias=False)
    self.sqrt_head_size = math.sqrt(head_size)
    self.reg = torch.nn.Dropout(dropout)

  def forward(self, idx):
    """Attend over the sequence axis of ``idx``.

    Args:
      idx: tensor of shape (..., T, n_embeddings).

    Returns:
      Tensor of shape (..., T, head_size). Position t is a weighted average
      of the value vectors at positions 0..t only.
    """
    Q = self.Q_weights(idx) # (B, T, head_size)
    K = self.K_weights(idx) # (B, T, head_size)
    V = self.V_weights(idx) # (B, T, head_size)

    # scores[..., q, k] is the scaled affinity of query position q for key k.
    scores = Q @ K.transpose(-1, -2) / self.sqrt_head_size # (B, T, T)

    # Causal mask: a query may only look at keys at the same or earlier position.
    T = idx.shape[-2]
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=idx.device))
    scores = scores.masked_fill(~causal, float("-inf"))

    # Normalise over the KEY axis so every query's weights sum to 1.
    # Normalising over the query axis (dim=-2) would make each weight depend
    # on the scores of later queries, leaking future tokens into the past.
    W = F.softmax(scores, dim=-1) # (B, T, T)
    W = self.reg(W)

    Att = W @ V # (B, T, head_size)

    return Att

In [None]:
class multi_attention_head(torch.nn.Module):
  """Several attention heads run in parallel, concatenated and projected.

  Args:
    n_heads: number of Attention_Head modules to create.
    head_size: output width of each head.

  The input/output embedding width and the dropout probability are read from
  the module-level ``n_embeddings`` and ``dropout`` variables.
  """

  def __init__(self, n_heads, head_size):
    super().__init__()
    self.heads = torch.nn.ModuleList([Attention_Head(n_embeddings, head_size) for _ in range(n_heads)])
    # The concatenated heads are n_heads * head_size wide; size the projection
    # from that rather than assuming it always equals n_embeddings.
    self.proj = torch.nn.Linear(n_heads * head_size, n_embeddings)
    self.reg = torch.nn.Dropout(dropout)

  def forward(self, idx):
    """Apply every head to ``idx``, concatenate on the last axis, then project.

    Returns a tensor whose last dimension is ``n_embeddings``.
    """
    mha_out = torch.cat([h(idx) for h in self.heads], dim=-1)
    projections = self.proj(mha_out)
    out = self.reg(projections)
    return out

In [None]:
class ffwd(torch.nn.Module):
  """Position-wise feed-forward network: expand 4x, ReLU, project back, dropout.

  The dropout probability is read from the module-level ``dropout`` variable.
  """

  def __init__(self, n_embeddings):
    super().__init__()
    hidden_width = 4 * n_embeddings
    self.linear = torch.nn.Linear(n_embeddings, hidden_width)
    self.activation = torch.nn.ReLU()
    self.proj = torch.nn.Linear(hidden_width, n_embeddings)
    self.reg = torch.nn.Dropout(dropout)

  def forward(self, idx):
    """Return a tensor with the same last-dimension width as ``idx``."""
    hidden = self.activation(self.linear(idx))
    return self.reg(self.proj(hidden))

In [None]:
class Block(torch.nn.Module):
  """Pre-norm transformer block: self-attention then feed-forward, each residual.

  Args:
    n_embeddings: width of the residual stream.
    n_headds: number of attention heads. The misspelled parameter name is
      kept so existing callers keep working; each head gets a width of
      n_embeddings // n_headds.
  """

  def __init__(self, n_embeddings, n_headds):
    super().__init__()
    # Use the constructor argument rather than the module-level n_heads, so
    # the value passed by the caller is actually honoured.
    self.ma_head = multi_attention_head(n_headds, n_embeddings // n_headds)
    self.feed_forward = ffwd(n_embeddings)
    self.ln1 = torch.nn.LayerNorm(n_embeddings)
    self.ln2 = torch.nn.LayerNorm(n_embeddings)

  def forward(self, idx):
    """Apply LayerNorm -> sub-layer -> residual add, twice; shape is preserved."""
    add_norm_mha = idx + self.ma_head(self.ln1(idx)) #(B, T, C) + #(B, T, C)
    add_norm_ffwd = add_norm_mha + self.feed_forward(self.ln2(add_norm_mha))
    return add_norm_ffwd

In [None]:
class BigramModel(torch.nn.Module):
  """Character-level transformer language model.

  Token and learned positional embeddings are summed, passed through three
  transformer blocks and mapped to vocabulary logits. Sizes are read from the
  module-level ``vocab_size``, ``n_embeddings``, ``block_size`` and ``n_heads``.
  """

  def __init__(self):
    super().__init__()
    self.embedding_table = torch.nn.Embedding(vocab_size,n_embeddings)
    self.pos_embeddings = torch.nn.Embedding(block_size, n_embeddings)
    self.blocks = torch.nn.Sequential(
        Block(n_embeddings, n_heads),
        Block(n_embeddings, n_heads),
        Block(n_embeddings, n_heads)
    )
    self.lm_head = torch.nn.Linear(n_embeddings, vocab_size)

  def forward(self, idx, targets=None):
    """Compute next-token logits and, optionally, the cross-entropy loss.

    Args:
      idx: integer tensor of token ids, shape (B, T). T must not exceed the
        number of rows in the positional embedding table.
      targets: optional integer tensor of shape (B, T) with the expected
        next token at every position.

    Returns:
      (logits, loss). Without targets, logits has shape (B, T, vocab_size)
      and loss is None. With targets, logits is flattened to
      (B*T, vocab_size) and loss is the mean cross-entropy.
    """
    B, T = idx.shape
    tok_emb = self.embedding_table(idx) #(B, T, C)
    # Build the position indices on the input's device so the model does not
    # depend on a module-level device variable.
    pos_emb = self.pos_embeddings(torch.arange(T, device=idx.device)) #(T, C)
    x = tok_emb + pos_emb # (B, T, C)

    raw_output = self.blocks(x)

    logits = self.lm_head(raw_output) #(B, T, vocab_size)
    # Identity check: "== None" on a tensor goes through tensor comparison.
    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T,C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)
    return logits, loss

  @torch.no_grad()
  def generate(self, idx, max_new_tokens):
    """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

    Args:
      idx: integer tensor of shape (B, T) holding the prompt token ids.
      max_new_tokens: number of tokens to append.

    Returns:
      Integer tensor of shape (B, T + max_new_tokens).
    """
    # The usable context is bounded by the positional embedding table, not by
    # a hard-coded constant.
    max_context = self.pos_embeddings.num_embeddings
    # Sample with dropout disabled, then restore whatever mode the caller had.
    was_training = self.training
    self.eval()
    try:
      for _ in range(max_new_tokens):
        logits, _ = self(idx[:, -max_context:])
        pred = F.softmax(logits[:, -1, :], dim=-1)
        nv = torch.multinomial(pred, 1)
        idx = torch.cat((idx, nv), dim=1)
    finally:
      self.train(was_training)
    return idx

In [None]:
# Build the model and move its parameters to the selected device; the bare
# name on the last line makes the notebook display the module structure.
model = BigramModel().to(device)
model

BigramModel(
  (embedding_table): Embedding(65, 384)
  (pos_embeddings): Embedding(20, 384)
  (blocks): Sequential(
    (0): Block(
      (ma_head): multi_attention_head(
        (heads): ModuleList(
          (0-5): 6 x Attention_Head(
            (Q_weights): Linear(in_features=384, out_features=64, bias=False)
            (K_weights): Linear(in_features=384, out_features=64, bias=False)
            (V_weights): Linear(in_features=384, out_features=64, bias=False)
            (reg): Dropout(p=0.2, inplace=False)
          )
        )
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (reg): Dropout(p=0.2, inplace=False)
      )
      (feed_forward): ffwd(
        (linear): Linear(in_features=384, out_features=1536, bias=True)
        (activation): ReLU()
        (proj): Linear(in_features=1536, out_features=384, bias=True)
        (reg): Dropout(p=0.2, inplace=False)
      )
      (ln1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (ln2): Laye

In [None]:
# Prompt consisting of a single sequence holding one token with id 0.
init_tensor = torch.zeros((1, 1), dtype=torch.long, device=device)
# Extend the prompt by 1000 sampled tokens.
zed = model.generate(init_tensor, max_new_tokens=1000)

In [None]:
# Show the first generated sequence as text.
print(decode(zed[0].tolist()))


OFo oe o
O a n Aeat: whe toos of thom has rear senctered, coshat king igherare he 's ede hife the is lot save timendw the
wich'd Lo down ad, prakin iar's po sortour, nd gotlied yould achmomay swand to string
Apbacid ound sdtithes sus gord, angleble baight thonbe andand,
Whath sinve of stengwailed, he'r nowranve
BaWo his of his eeqet Low thaly'd trive is de moke my and my as, bre dis weard awite bo sepe lovetow
Good le soughou deag hois ford'ceits y to pasterem gank,
With pekes sigh be ince my lovee.

CORIAR got by swawhion the ave he.

TENVERMLANG! Fowthe
Wo tman is thatand the is kill Eves, whou toon be tho beys
righad harimait, and your egarwancenp
Thoor kigh ofind hathis 'th driee
BOKNGHEO:
Whmaued.

Fith miill ifsabd faimny sat tiy him thal
They dis

inier forme how havn wold hat once!

LOKNG Lidtof;
And Eo dhent no detoone: Low! Wist your sair ang-peit-lal;
Rome, gut-on a I' As waced
lo:
Ad is a prosterj chou our is,
socts; qukes, thond what 'terenldeviir kinder.

Hot frore, ny l

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
val_interval = 300  # report losses every this many optimisation steps

model.train()
for steps in range(5000):
  optimizer.zero_grad()
  batch_x, batch_y = get_batch()
  lgts, loss = model(batch_x, batch_y)
  loss.backward()
  optimizer.step()
  if steps % val_interval == 0:
    # Score the validation batch with dropout disabled and without building
    # an autograd graph, then switch back to training mode.
    model.eval()
    with torch.no_grad():
      val_x, val_y = get_batch("val")
      val_loss = model(val_x, val_y)[1]
    model.train()
    # .item() prints plain numbers instead of tensors carrying grad_fn/device.
    print(f"step {steps}: train loss {loss.item():.4f}, val loss {val_loss.item():.4f}")