In [None]:
import torch
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = 'cpu'
if torch.cuda.is_available():
  device = 'cuda'

In [None]:
import torch.nn as nn
from torch.nn import functional as F
# Seed PyTorch's global random number generator so runs are repeatable.
torch.manual_seed(1337)

In [None]:
# Hyperparameters read as module-level globals by the model cells below.
# Values marked "TODO confirm" are not read by any cell visible in this review.
batch_size = 64  # presumably sequences per training batch (TODO confirm)
block_size = 256  # rows in the position-embedding table, side of the causal mask, and the context crop in generate()
max_iter = 5000  # presumably the number of training steps (TODO confirm)
eval_interval = 500  # presumably training steps between evaluations (TODO confirm)
learning_rate = 3e-4  # presumably the optimizer step size (TODO confirm)
eval_iters = 200  # presumably batches averaged per evaluation (TODO confirm)
embed_size = 384  # embedding width; also the input width of every attention projection
num_heads = 6  # NOTE(review): the model classes take num_heads as an argument; presumably this is the value callers pass
n_layers = 6  # number of EncoderBlock / DecoderBlock modules stacked in each model
dropout = 0.2  # probability handed to every nn.Dropout created by the model classes

In [None]:
class FeedForwardNetwork(nn.Module):
  """Position-wise MLP: widen 4x, ReLU, project back to the input width, dropout.

  Args:
    inpt_size: size of the last dimension of the input (and of the output).

  The dropout probability comes from the module-level `dropout` setting.
  """

  def __init__(self, inpt_size):
    super().__init__()
    hidden_size = 4 * inpt_size
    layers = [
        nn.Linear(inpt_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, inpt_size),
        nn.Dropout(dropout),
    ]
    self.net = nn.Sequential(*layers)

  def forward(self, x):
    """Run `x` through the MLP; the output has the same shape as `x`."""
    return self.net(x)

In [None]:
class Head(nn.Module):
  """A single scaled dot-product attention head.

  Args:
    head_size: output width of the key / query / value projections.
    masked: when True, apply a causal mask so that position t can only attend
      to positions <= t.
    cross_attention: when True, keys and values are projected from `enc`
      (passed to `forward`) instead of from `x`.

  Reads the module-level `embed_size`, `block_size` and `dropout` settings.
  """

  def __init__(self, head_size, masked = False, cross_attention = False):
    super().__init__()
    self.key = nn.Linear(embed_size, head_size, bias = False)
    self.query = nn.Linear(embed_size, head_size, bias = False)
    self.value = nn.Linear(embed_size, head_size, bias = False)
    self.masked = masked
    self.cross_attention = cross_attention
    if self.masked:
      # Lower-triangular ones: entry (t, s) is 1 when query t may look at key s.
      self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, enc = None):
    """Attend from `x` to itself, or to `enc` for a cross-attention head.

    Args:
      x: tensor of shape (B, T, embed_size); queries are projected from it.
      enc: tensor of shape (B, S, embed_size) supplying keys and values.
        Required when the head was built with `cross_attention=True`,
        ignored otherwise.

    Returns:
      Tensor of shape (B, T, head_size).

    Raises:
      AssertionError: if this is a cross-attention head and `enc` is None.
    """
    _, T, _ = x.shape
    q = self.query(x)
    if self.cross_attention:
      assert enc is not None, "cross-attention head requires the encoder output `enc`"
      source = enc
    else:
      source = x
    k = self.key(source)
    v = self.value(source)
    # Scale by the head width (the dimension the dot product sums over),
    # not by the embedding width of the input.
    wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
    if self.masked:
      # masked_fill needs a boolean mask: hide the entries where tril is 0
      # (the future), cropped to the current sequence length.
      wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
    wei = F.softmax(wei, dim = -1)
    wei = self.dropout(wei)
    return wei @ v

In [None]:
class MultiheadAttention(nn.Module):
  """Several `Head`s run side by side, concatenated and projected to `embed_size`.

  Args:
    num_heads: number of attention heads.
    head_size: output width of each head.
    masked: forwarded to every `Head`.
    cross_attention: forwarded to every `Head`.

  Reads the module-level `embed_size` and `dropout` settings.
  """

  def __init__(self, num_heads, head_size, masked = False, cross_attention = False):
    super().__init__()
    self.head_list = nn.ModuleList([Head(head_size, masked, cross_attention) for _ in range(num_heads)])
    # The concatenated head outputs are num_heads * head_size wide, which only
    # equals embed_size when embed_size is divisible by num_heads.
    self.proj = nn.Linear(num_heads * head_size, embed_size)
    self.dr = nn.Dropout(dropout)

  def forward(self, idx, enc = None):
    """Apply every head to `idx` and merge the results.

    Args:
      idx: input tensor handed to each head as the query source.
      enc: optional tensor handed to each head as the key/value source;
        needed when the heads were built with `cross_attention=True`.

    Returns:
      Tensor whose last dimension is `embed_size`.
    """
    head_out = torch.cat([h(idx, enc) for h in self.head_list], dim = -1)
    out = self.proj(head_out)
    out = self.dr(out)
    return out

In [None]:
class EncoderBlock(nn.Module):
  """One encoder layer: self-attention, then a feed-forward net.

  Each sub-layer normalises its input with LayerNorm first and adds its
  output back onto the running activations (residual connection).

  Args:
    embed_size: width of the activations flowing through the block.
    num_heads: number of attention heads; each gets embed_size // num_heads.
  """

  def __init__(self, embed_size, num_heads):
    super().__init__()
    self.self_attn = MultiheadAttention(num_heads, embed_size // num_heads)
    self.ffn = FeedForwardNetwork(embed_size)
    self.ln1 = nn.LayerNorm(embed_size)
    self.ln2 = nn.LayerNorm(embed_size)

  def forward(self, x):
    """Return `x` after the attention and feed-forward residual updates."""
    attended = self.self_attn(self.ln1(x))
    x = x + attended
    transformed = self.ffn(self.ln2(x))
    return x + transformed


In [None]:
# NOTE(review): these imports and the seed call repeat the ones in an earlier cell.
# Calling manual_seed again resets PyTorch's global RNG to the same seed.
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)


class LanguageTranslatorEncoder(nn.Module):
  """Token + position embeddings, a stack of `EncoderBlock`s, and a vocabulary head.

  Args:
    num_heads: number of attention heads used in every encoder block.

  Reads the module-level `embed_size`, `block_size` and `n_layers` settings,
  plus `vocab_size`, which is not defined in the cells shown here and is
  presumably set in a data-preparation cell (TODO confirm).
  """

  def __init__(self, num_heads):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, embed_size)
    self.position_embedding_table = nn.Embedding(block_size, embed_size)
    # Honour the constructor argument instead of a hard-coded head count.
    self.blocks = nn.Sequential(*[EncoderBlock(embed_size, num_heads) for _ in range(n_layers)])
    self.lm = nn.LayerNorm(embed_size)
    self.ll = nn.Linear(embed_size, vocab_size)

  def forward(self, idx, target = None):
    """Compute vocabulary logits and, when `target` is given, the loss.

    Args:
      idx: integer tensor of shape (B, T) holding token ids; T must not
        exceed `block_size` (the size of the position-embedding table).
      target: optional integer tensor with B*T elements of target token ids.

    Returns:
      `(logits, loss)`. Without `target`, logits has shape (B, T, vocab_size)
      and loss is None. With `target`, logits is flattened to
      (B*T, vocab_size) and loss is the cross-entropy against `target`.
    """
    B, T = idx.shape
    char_embds = self.token_embedding_table(idx)
    # Build the position indices on the input's device so the model works
    # wherever it (and its inputs) have been moved.
    positions = torch.arange(T, device = idx.device)
    pos_embds = self.position_embedding_table(positions)
    hidden = self.blocks(char_embds + pos_embds)
    logits = self.ll(self.lm(hidden))
    if target is None:
      return logits, None

    B, T, C = logits.shape
    logits = logits.view(B * T, C)
    # reshape (not view) so a non-contiguous target slice is accepted too.
    target = target.reshape(B * T)
    loss = F.cross_entropy(logits, target)
    return logits, loss


In [None]:
class DecoderBlock(nn.Module):
  """One decoder layer: masked self-attention, cross-attention, feed-forward.

  Every sub-layer sees LayerNorm-ed inputs and is added back onto the
  running activations (residual connection).

  Args:
    embed_size: width of the activations flowing through the block.
    num_heads: number of attention heads; each gets embed_size // num_heads.
  """

  def __init__(self, embed_size, num_heads):
    super().__init__()
    per_head = embed_size // num_heads
    self.masked_self_attn = MultiheadAttention(num_heads, per_head, masked = True)
    self.cross_attn = MultiheadAttention(num_heads, per_head, cross_attention = True)
    self.ffn = FeedForwardNetwork(embed_size)
    self.ln1 = nn.LayerNorm(embed_size)
    self.ln2 = nn.LayerNorm(embed_size)
    self.ln3 = nn.LayerNorm(embed_size)
    self.ln4 = nn.LayerNorm(embed_size)

  def forward(self, x, enc):
    """Update decoder activations `x` using the encoder activations `enc`."""
    x = x + self.masked_self_attn(self.ln1(x))
    # Decoder side and encoder side each get their own LayerNorm before
    # being handed to the cross-attention module.
    queries = self.ln2(x)
    memory = self.ln3(enc)
    x = x + self.cross_attn(queries, memory)
    return x + self.ffn(self.ln4(x))


In [None]:
class LanguageTranslatorDecoder(nn.Module):
  """Token + position embeddings, a stack of `DecoderBlock`s, and a vocabulary head.

  Args:
    num_heads: number of attention heads used in every decoder block.

  Reads the module-level `embed_size`, `block_size` and `n_layers` settings,
  plus `vocab_size`, which is not defined in the cells shown here and is
  presumably set in a data-preparation cell (TODO confirm).
  """

  def __init__(self, num_heads):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, embed_size)
    self.position_embedding_table = nn.Embedding(block_size, embed_size)
    # DecoderBlock.forward takes two inputs (x, enc). nn.Sequential only
    # chains a single input, so hold the blocks in a ModuleList and loop
    # over them in forward(). Parameter names stay "blocks.<i>...".
    self.blocks = nn.ModuleList([DecoderBlock(embed_size, num_heads) for _ in range(n_layers)])
    self.lm = nn.LayerNorm(embed_size)
    self.ll = nn.Linear(embed_size, vocab_size)

  def forward(self, idx, target = None, enc = None):
    """Compute vocabulary logits and, when `target` is given, the loss.

    Args:
      idx: integer tensor of shape (B, T) holding token ids; T must not
        exceed `block_size`.
      target: optional integer tensor with B*T elements of target token ids.
      enc: encoder activations of shape (B, S, embed_size) consumed by the
        cross-attention layers. NOTE(review): LanguageTranslatorEncoder.forward
        returns vocabulary logits, not embed_size-wide activations, so its
        output cannot be passed here as-is — confirm what the caller feeds in.

    Returns:
      `(logits, loss)`. Without `target`, logits has shape (B, T, vocab_size)
      and loss is None. With `target`, logits is flattened to
      (B*T, vocab_size) and loss is the cross-entropy against `target`.

    Raises:
      ValueError: if `enc` is None.
    """
    if enc is None:
      raise ValueError("LanguageTranslatorDecoder.forward requires the encoder output `enc` for cross-attention")
    B, T = idx.shape
    char_embds = self.token_embedding_table(idx)
    # Build the position indices on the input's device rather than on a
    # module-level global.
    positions = torch.arange(T, device = idx.device)
    pos_embds = self.position_embedding_table(positions)
    hidden = char_embds + pos_embds
    for block in self.blocks:
      hidden = block(hidden, enc)
    logits = self.ll(self.lm(hidden))
    if target is None:
      return logits, None

    B, T, C = logits.shape
    logits = logits.view(B * T, C)
    # reshape (not view) so a non-contiguous target slice is accepted too.
    target = target.reshape(B * T)
    loss = F.cross_entropy(logits, target)
    return logits, loss

  def generate(self, idx, max_new_tokens, num_heads = None, enc = None):
    """Sample `max_new_tokens` tokens and append them to `idx`.

    Args:
      idx: integer tensor of shape (B, T) with the tokens generated so far.
      max_new_tokens: number of sampling steps to run.
      num_heads: unused; kept so existing positional calls keep working.
      enc: encoder activations forwarded to `forward` on every step.

    Returns:
      Integer tensor of shape (B, T + max_new_tokens).
    """
    for _ in range(max_new_tokens):
      # Keep only the most recent block_size tokens as context.
      idx_cond = idx[:, -block_size:]
      logits, _ = self(idx_cond, enc = enc)
      # Only the last position predicts the next token.
      probab = F.softmax(logits[:, -1, :], dim = -1)
      predictions = torch.multinomial(probab, 1)
      idx = torch.cat((idx, predictions), -1)
    return idx