In [None]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

In [None]:
! pip install lightning



In [None]:
import lightning as L

In [None]:
# Vocabulary: each token's id is its position in this list; '<EOS>' marks end of sequence.
token_to_id = {
    token: index
    for index, token in enumerate(['monkey', 'stole', 'banana', 'from', 'penguin', '<EOS>'])
}
# Reverse lookup, used to turn predicted ids back into words.
id_to_token = {index: token for token, index in token_to_id.items()}

In [None]:
# Training sequences, one per row. Each label row is the matching input row
# shifted left by one position with a trailing <EOS>, so the model learns to
# predict the next token at every position.
_training_sequences = [
    ["monkey", "stole", "banana", "<EOS>", "from"],
    ["penguin", "stole", "banana", "<EOS>", "from"],
]

inputs = torch.tensor(
    [[token_to_id[word] for word in sequence] for sequence in _training_sequences]
)
labels = torch.tensor(
    [[token_to_id[word] for word in sequence[1:] + ["<EOS>"]] for sequence in _training_sequences]
)

dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)

In [None]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Any leading dimensions (e.g. batch, heads) are carried through; the last
    two dimensions of Query/Key/Value are (sequence, feature).

    Args:
        d_k: feature width of Query/Key, used for the 1/sqrt(d_k) scaling.
    """

    def __init__(self, d_k):
        super().__init__()
        self.d_k = d_k
        # Normalise over the key axis (last dim of the score matrix). Using
        # dim=1 normalised across heads for 4-D (batch, heads, q, k) scores,
        # and with the causal mask every masked entry became softmax of all
        # -inf, i.e. NaN, which poisoned the whole forward pass.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Query, Key, Value, masked=False):
        """Return the attention-weighted sum of Value.

        Args:
            Query, Key, Value: tensors whose last two dims are (seq, d_k).
            masked: if True, apply a causal mask so position i only attends
                to positions <= i.
        """
        scores = torch.matmul(Query, Key.mT) / (self.d_k ** 0.5)
        if masked:
            # Additive mask: -inf strictly above the diagonal, 0 elsewhere.
            self.mask = torch.full_like(scores, float("-inf")).triu(diagonal=1)
            scores = scores + self.mask
        attention_pattern = self.softmax(scores)
        delta_E = torch.matmul(attention_pattern, Value)
        return delta_E


class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ScaledDotProductAttention.

    Expects inputs shaped (batch, seq, d_model). Each projection is split into
    `num_heads` heads of width d_model // num_heads, attention runs per head,
    and the heads are concatenated and projected back to d_model.

    Intermediate tensors (Q, K, V, delta_E, output) are kept as attributes
    after each forward call.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        head_width = d_model // num_heads
        inner_width = head_width * num_heads
        self.d_k = head_width
        self.num_heads = num_heads
        # Registration order matters for parameter init under a fixed seed.
        self.query = nn.Linear(d_model, inner_width, bias=False)
        self.key = nn.Linear(d_model, inner_width, bias=False)
        self.value = nn.Linear(d_model, inner_width, bias=False)
        self.output_matrix = nn.Linear(inner_width, d_model, bias=False)

        self.attention = ScaledDotProductAttention(head_width)

    def _split_heads(self, projected, batch_size, device):
        # (batch, seq, heads * d_k) -> (batch, heads, seq, d_k)
        return (
            projected
            .view(batch_size, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
            .to(device)
        )

    def forward(self, Query, Key, Value, masked=False):
        """Attend Query over Key/Value; `masked` enables the causal mask."""
        batch_size = Query.shape[0]
        target_device = Query.device

        self.Q = self._split_heads(self.query(Query), batch_size, target_device)
        self.K = self._split_heads(self.key(Key), batch_size, target_device)
        self.V = self._split_heads(self.value(Value), batch_size, target_device)

        per_head = self.attention(self.Q, self.K, self.V, masked)
        # (batch, heads, seq, d_k) -> (batch, seq, heads * d_k)
        self.delta_E = (
            per_head
            .transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.d_k)
        )

        self.output = self.output_matrix(self.delta_E)
        return self.output

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: embedding width (odd widths are supported).
        max_len: number of positions pre-computed; longer inputs are not supported.
    """

    def __init__(self, d_model=2, max_len=6):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        embedding_index = torch.arange(0, d_model, step=2, dtype=torch.float).unsqueeze(0)
        angle = pos / (10000.0 ** (embedding_index / d_model))
        pe[:, 0::2] = torch.sin(angle)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(angle)[:, : d_model // 2]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add positional encodings to x shaped (..., seq_len, d_model).

        The sequence axis is the second-to-last one, so batched
        (batch, seq, d_model) and unbatched (seq, d_model) inputs both get a
        distinct encoding per position. Slicing by x.size(0) would use the
        batch size for batched input and give every position the same encoding.
        """
        seq_len = x.size(-2)
        return x + self.pe[:seq_len].to(x.device)

In [None]:
class DecoderOnlyTransformer(L.LightningModule):
    """Single-layer decoder-only transformer trained for next-token prediction.

    Pipeline: token embedding -> sinusoidal positional encoding -> causal
    multi-head self-attention -> residual add -> LayerNorm -> linear logits.

    Args:
        num_tokens: vocabulary size (number of output logits).
        d_model: embedding width.
        max_len: longest sequence the positional encoding supports.
        num_heads: number of attention heads (default 2, as before).
        lr: Adam learning rate (default 0.1, as before).
    """

    def __init__(self, num_tokens=4, d_model=2, max_len=6, num_heads=2, lr=0.1):
        super().__init__()
        self.lr = lr
        # Submodule creation order is kept so seeded initialisation is unchanged.
        self.word_embedding = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)
        self.pe = PositionalEncoding(d_model=d_model, max_len=max_len)
        self.attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(in_features=d_model, out_features=num_tokens)

        self.loss = nn.CrossEntropyLoss()

    def forward(self, tokens):
        """Return next-token logits of shape (batch, seq, num_tokens).

        Args:
            tokens: integer token ids shaped (batch, seq). The attention module
                treats dim 0 as the batch, so 1-D input must be unsqueezed first.
        """
        embeddings = self.pe(self.word_embedding(tokens))
        self_attention_vals = self.attention(embeddings, embeddings, embeddings, masked=True)
        # Post-norm residual block.
        hidden = self.norm(embeddings + self_attention_vals)
        return self.fc(hidden)

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.lr)

    def training_step(self, batch, batch_idx):
        """Cross-entropy between per-position logits and next-token labels."""
        input_tokens, labels = batch
        logits = self(input_tokens)
        # CrossEntropyLoss expects (N, C) logits and (N,) targets: flatten batch and sequence.
        loss = self.loss(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

In [None]:
# Seed before building the model so weight initialisation (and therefore the
# training result and predictions below) is reproducible on Restart & Run All.
torch.manual_seed(42)
# max_len must be at least the longest sequence fed to the model during generation.
model = DecoderOnlyTransformer(num_tokens=len(token_to_id), d_model=2, max_len=6)

In [None]:
# Greedy decoding: feed the prompt, then repeatedly append the most likely next
# token until <EOS> is predicted or the sequence reaches max_length tokens.
model_input = torch.tensor([token_to_id["monkey"],
                            token_to_id["stole"],
                            token_to_id["<EOS>"]])

input_length = model_input.size(dim=0)
max_length = 6

with torch.no_grad():  # inference only; no autograd graph needed
    predictions = model(model_input.unsqueeze(0))
    # predictions is (batch=1, seq, vocab): take the logits at the last position.
    # Indexing [-1, :] would select the whole (seq, vocab) matrix, and argmax over
    # it returns a flattened index that is not necessarily a valid token id.
    predicted_id = torch.argmax(predictions[0, -1, :])
    predicted_ids = predicted_id.unsqueeze(0)

    for _ in range(input_length, max_length):
        if predicted_id == token_to_id["<EOS>"]:
            break

        model_input = torch.cat((model_input, predicted_id.unsqueeze(0)))
        predictions = model(model_input.unsqueeze(0))
        predicted_id = torch.argmax(predictions[0, -1, :])
        predicted_ids = torch.cat((predicted_ids, predicted_id.unsqueeze(0)))

print("Predicted Tokens:\n")
for token_id in predicted_ids:
    print("\t", id_to_token[token_id.item()])

Predicted Tokens:

	 monkey
	 <EOS>


In [None]:
# The dataloader yields only a couple of batches per epoch, far fewer than
# Lightning's default logging interval of 50 steps, so log every step.
trainer = L.Trainer(max_epochs=30, log_every_n_steps=1)
trainer.fit(model, train_dataloaders=dataloader)

INFO: You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
INFO:lightning.pytorch.utilities.rank_zero:You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: 
  | Name           | Type               | Params | Mode 
--------------------------------------------------------------
0 | word_embedding | Embedding          | 12     | train
1 | pe             | PositionalEncoding | 0      | train
2 | attention      | MultiHeadAttention | 16     | train
3 | norm          

Training: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=30` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=30` reached.


In [None]:
# Greedy decoding with the trained model: feed the prompt, then repeatedly
# append the most likely next token until <EOS> or max_length tokens.
model_input = torch.tensor([token_to_id["monkey"],
                            token_to_id["stole"],
                            token_to_id["banana"],
                            token_to_id["from"]])

input_length = model_input.size(dim=0)
max_length = 6

with torch.no_grad():  # inference only; no autograd graph needed
    predictions = model(model_input.unsqueeze(0))
    # predictions is (batch=1, seq, vocab): take the logits at the last position.
    # Indexing [-1, :] would select the whole (seq, vocab) matrix, and argmax over
    # it returns a flattened index that is not necessarily a valid token id.
    predicted_id = torch.argmax(predictions[0, -1, :])
    predicted_ids = predicted_id.unsqueeze(0)

    for _ in range(input_length, max_length):
        if predicted_id == token_to_id["<EOS>"]:
            break

        model_input = torch.cat((model_input, predicted_id.unsqueeze(0)))
        predictions = model(model_input.unsqueeze(0))
        predicted_id = torch.argmax(predictions[0, -1, :])
        predicted_ids = torch.cat((predicted_ids, predicted_id.unsqueeze(0)))

print("Predicted Tokens:\n")
for token_id in predicted_ids:
    print("\t", id_to_token[token_id.item()])


Predicted Tokens:

	 monkey
	 monkey
	 monkey
