In [None]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

In [None]:
# Tokenization
#
# Vocabulary for the toy example: every token the model can read or emit is
# mapped to an integer id. "<EOS>" (end of sequence) separates the prompt from
# the answer and also marks the end of the answer.

token_to_id = {
  "What": 0,
  "is": 1,
  "KreakxX": 2,
  "cool": 3,
  "<EOS>": 4,  # must be spelled "<EOS>": the other cells look this key up verbatim
}

# Inverse lookup (id -> token), used to turn predicted ids back into text.
id_to_token = {token_id: token for token, token_id in token_to_id.items()}

In [None]:
# Dataset
#
# One training example. `inputs` is the prompt, "<EOS>", then the answer token;
# `labels` is the same sequence shifted left by one position and terminated
# with a final "<EOS>", i.e. the next token expected after each input position.

inputs = torch.tensor([
  [token_to_id[token] for token in ("What", "is", "KreakxX", "<EOS>", "cool")]
])

labels = torch.tensor([
  [token_to_id[token] for token in ("is", "KreakxX", "<EOS>", "cool", "<EOS>")]
])

# Pair each input row with its label row; the DataLoader is created with its
# default settings.
dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)

In [None]:
class PostionalEncoding(nn.Module):
  """Adds fixed sinusoidal position encodings to word embeddings.

  Even embedding columns hold sin(position * div_term) and odd columns hold
  cos(position * div_term), so each position gets a distinct pattern.

  Args:
    d_model: embedding dimension (number of columns in the encoding table).
    max_len: maximum sequence length (number of rows in the encoding table).
  """

  def __init__(self, d_model=2, max_len=6):
    super().__init__()

    # Encoding table: one row per position, one column per embedding dimension.
    pe = torch.zeros(max_len, d_model)

    # Column vector of positions 0 .. max_len - 1, shape (max_len, 1).
    position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)

    # Even embedding indices 0, 2, 4, ...; each sin/cos pair shares one frequency.
    embedding_index = torch.arange(start=0, end=d_model, step=2).float()

    # Frequency for each sin/cos pair: 1 / 10000^(2i / d_model).
    div_term = 1 / torch.tensor(10000.0) ** (embedding_index / d_model)

    pe[:, 0::2] = torch.sin(position * div_term)
    # Odd columns use cosine. When d_model is odd there is one fewer odd column
    # than even columns, so drop the unused trailing frequency.
    pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

    # A buffer moves with the module across devices but is not a trainable parameter.
    self.register_buffer('pe', pe)

  def forward(self, word_embeddings):
    """Return `word_embeddings` plus the encodings of the first rows of the table.

    Dimension 0 of `word_embeddings` is used as the sequence length, so the
    input is expected to be (seq_len, d_model) with seq_len <= max_len.
    """
    return word_embeddings + self.pe[:word_embeddings.size(0), :]

In [None]:
class Attention(nn.Module):
  """Single-head scaled dot-product attention.

  Dimensions 0 and 1 of the inputs are treated as rows and columns, so the
  inputs are expected to be 2-D, one row per token.

  Args:
    d_model: size of each token encoding; the query, key and value
      projections all map d_model -> d_model without a bias term.
  """

  def __init__(self, d_model=2):
    super().__init__()

    # Learnable projections (randomly initialised until trained).
    self.W_q = nn.Linear(d_model, d_model, bias=False)
    self.W_k = nn.Linear(d_model, d_model, bias=False)
    self.W_v = nn.Linear(d_model, d_model, bias=False)

    # Axis bookkeeping used by forward().
    self.row_dim, self.col_dim = 0, 1

  def forward(self, encodings_q, encodings_k, encodings_v, mask=None):
    """Compute attention output for the given query/key/value encodings.

    Args:
      encodings_q, encodings_k, encodings_v: encodings to project into
        queries, keys and values.
      mask: optional boolean tensor; positions where it is True are filled
        with -1e9 before the softmax so they receive ~0 attention weight.

    Returns:
      The attention-weighted combination of the projected values.
    """
    queries = self.W_q(encodings_q)
    keys = self.W_k(encodings_k)
    values = self.W_v(encodings_v)

    # Query/key similarities, scaled by the square root of the key width.
    scale = torch.tensor(keys.size(self.col_dim) ** 0.5)
    scores = (queries @ keys.transpose(self.row_dim, self.col_dim)) / scale

    if mask is not None:
      scores = scores.masked_fill(mask, -1e9)

    # Normalise each row into weights that sum to 1, then mix the values.
    weights = F.softmax(scores, dim=self.col_dim)
    return weights @ values

In [None]:
class DecoderOnlyTransformer(nn.Module):
  """Minimal decoder-only transformer: embedding -> positional encoding ->
  masked self-attention with a residual connection -> linear layer over the
  vocabulary.

  Subclasses nn.Module so that instances are callable (`model(token_ids)`),
  submodules are registered, and `self.parameters()` is available to the
  optimizer.

  Args:
    num_tokens: vocabulary size (embedding rows and number of output logits).
    d_model: embedding dimension.
    max_len: maximum sequence length supported by the positional encoding.
  """

  def __init__(self, num_tokens=4, d_model=2, max_len=6):
    super().__init__()

    self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)
    self.pe = PostionalEncoding(d_model, max_len)
    self.self_attention = Attention(d_model)
    self.fc_layer = nn.Linear(in_features=d_model, out_features=num_tokens)

    self.loss = nn.CrossEntropyLoss()

  def forward(self, token_ids):
    """Return one row of vocabulary logits per input position.

    Dimension 0 of `token_ids` is used as the sequence length, so a 1-D
    tensor of token ids is expected.
    """
    word_embeddings = self.we(token_ids)
    position_encoded = self.pe(word_embeddings)

    # Causal mask: True above the diagonal, i.e. at the later positions a
    # token must not attend to. Built on the input's device.
    seq_len = token_ids.size(dim=0)
    mask = torch.tril(torch.ones((seq_len, seq_len), device=token_ids.device)) == 0

    self_attention_values = self.self_attention(
      position_encoded, position_encoded, position_encoded, mask=mask
    )

    # Residual connection around the attention layer.
    residual_connection = position_encoded + self_attention_values

    return self.fc_layer(residual_connection)

  def configure_optimizers(self):
    """Return an Adam optimizer over all model parameters (lr=0.1)."""
    return Adam(self.parameters(), lr=0.1)

  def training_step(self, batch, batch_idx):
    """Return the cross-entropy loss for the first example of `batch`.

    `batch` is an (inputs, labels) pair; only index 0 of each is used.
    `batch_idx` is accepted but unused.
    """
    input_tokens, labels = batch
    # Call the module rather than .forward() so nn.Module hooks run.
    output = self(input_tokens[0])
    return self.loss(output, labels[0])

In [None]:
# Greedy generation: feed the prompt, take the most likely next token, append
# it to the input and repeat until "<EOS>" is produced or the sequence reaches
# the model's maximum length.
# NOTE(review): no training step is run anywhere in the visible cells, so the
# weights are still randomly initialised here and the output is not meaningful
# until the model has been trained.

max_len = 6

model = DecoderOnlyTransformer(num_tokens=len(token_to_id), d_model=2, max_len=max_len)

# 1-D sequence of token ids: the model treats dimension 0 as the sequence
# length, and torch.cat below appends 1-D predicted ids to it.
model_input = torch.tensor([
  token_to_id["What"],
  token_to_id["is"],
  token_to_id["KreakxX"],
  token_to_id["<EOS>"],
])

input_length = model_input.size(dim=0)

# Inference only: no gradients are needed.
with torch.no_grad():
  predictions = model(model_input)
  # The last row holds the logits for the token that follows the prompt.
  predicted_id = torch.tensor([torch.argmax(predictions[-1, :])])
  predicted_ids = predicted_id

  # Start at input_length so the sequence never grows beyond max_len, the
  # number of positions the positional encoding table provides.
  for _ in range(input_length, max_len):
    if predicted_id == token_to_id["<EOS>"]:
      break

    model_input = torch.cat((model_input, predicted_id))
    predictions = model(model_input)
    predicted_id = torch.tensor([torch.argmax(predictions[-1, :])])
    predicted_ids = torch.cat((predicted_ids, predicted_id))

for token_id in predicted_ids:
  print("\t", id_to_token[token_id.item()])