In [None]:
import pip

try:
    __import__('lightning')
except:
    pip.main(['install', 'lightning'])

In [60]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

import lightning as L

In [61]:
# "What you cannot create you do not understand."
# — Richard Feynman

In [62]:
import torch
from torch.utils.data import TensorDataset, DataLoader

token_to_id = {
    'what': 0,
    'you': 1,
    'cannot': 2,
    'create': 3,
    'do': 4,
    'not': 5,
    'understand': 6,
    '<EOS>': 7,
}

id_to_token = dict(map(reversed, token_to_id.items()))

# Tokenized sentence with <EOS>
full_sequence = [token_to_id[tok] for tok in ['what', 'you', 'cannot', 'create', 'you', 'do', 'not', 'understand', '<EOS>']]

# Split into input-label pairs using shifting
inputs = torch.tensor([full_sequence[:-1]])   # all tokens except the last
labels = torch.tensor([full_sequence[1:]])    # all tokens except the first

# Create dataset and dataloader
dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset, batch_size=1)

# For demonstration, let's print
for batch in dataloader:
    x, y = batch
    print("Inputs: ", x)
    print("Labels: ", y)


Inputs:  tensor([[0, 1, 2, 3, 1, 4, 5, 6]])
Labels:  tensor([[1, 2, 3, 1, 4, 5, 6, 7]])


In [63]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model=2, max_len=6):
        super().__init__()

        pe = torch.zeros(max_len, d_model) # positional enconding

        position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)

        embedding_index = torch.arange(start=0, end=d_model, step=2).float()
        div_term = 1/torch.tensor(10000.0)**(embedding_index / d_model)

        pe[:, 0::2] = torch.sin(position * div_term) ## every other column, starting with the 1st, has sin() values
        pe[:, 1::2] = torch.cos(position * div_term) ## every other column, starting with the 2nd, has cos() values

        self.register_buffer('pe', pe)

    def forward(self, word_embeddings):
        return word_embeddings + self.pe[:word_embeddings.size(0)]

In [64]:
class Attention(nn.Module):
    def __init__(self, d_model=2):
        super().__init__()

        self.d_model = d_model

        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        self.row_dim = 0
        self.col_dim = 1

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask):
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)

        # if mask is not None:
        scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)

        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        attention_scores = torch.matmul(attention_percents, v)
        
        return attention_scores

In [65]:
class DecoderOnlyTransformer(L.LightningModule):
    
    def __init__(self, num_tokens=4, d_model=2, max_len=7):
        
        super().__init__()
        
        # L.seed_everything(seed=42)
        
        self.we = nn.Embedding(num_embeddings=num_tokens, 
                               embedding_dim=d_model)     
        
        self.pe = PositionalEncoding(d_model=d_model, 
                                   max_len=max_len)

        self.self_attention = Attention(d_model=d_model)
        self.self_attention_2 = Attention(d_model=d_model)
        self.self_attention_3 = Attention(d_model=d_model)
        self.self_attention_4 = Attention(d_model=d_model)

        self.reduce_attention_dim = nn.Linear(in_features=(4*d_model), out_features=d_model)

        self.fc_layer = nn.Linear(in_features=d_model, out_features=num_tokens)
        
        self.loss = nn.CrossEntropyLoss()
        # self.loss = F.softmax()
        
        
    def forward(self, token_ids):
                
        word_embeddings = self.we(token_ids)        
        position_encoded = self.pe(word_embeddings)
        
        mask = torch.tril(torch.ones((token_ids.size(dim=0), token_ids.size(dim=0)), device=self.device))
        mask = mask == 0

        self_attention_values = self.self_attention(position_encoded, 
                                                    position_encoded, 
                                                    position_encoded, 
                                                    mask=mask)
                
        residual_connection_values = position_encoded + self_attention_values
        
        fc_layer_output = self.fc_layer(residual_connection_values)
        
        return fc_layer_output
    
    
    def configure_optimizers(self): 
        return Adam(self.parameters(), lr=0.1)
    
    
    def training_step(self, batch, batch_idx):
        input_tokens, labels = batch # collect input
        output = self.forward(input_tokens[0])
        loss = self.loss(output, labels[0])
        print(f'Loss: {loss}')
        return loss

In [66]:
## First, create a model from DecoderOnlyTransformer()
model = DecoderOnlyTransformer(num_tokens=len(token_to_id), d_model=2, max_len=12)

## Now create the input for the transformer...
model_input = torch.tensor([token_to_id["what"], 
                           #  token_to_id["is"], 
                           #  token_to_id["statquest"], 
                           #  token_to_id["<EOS>"]
                           ])
input_length = model_input.size(dim=0)

## Now get get predictions from the model
predictions = model(model_input) 

predicted_id = torch.tensor([torch.argmax(predictions[-1,:])])

predicted_ids = predicted_id


max_length = 7
for i in range(input_length, max_length):
    if (predicted_id == token_to_id["<EOS>"]): # if the prediction is <EOS>, then we are done
        break
    
    model_input = torch.cat((model_input, predicted_id))
    
    predictions = model(model_input) 
    predicted_id = torch.tensor([torch.argmax(predictions[-1,:])])
    predicted_ids = torch.cat((predicted_ids, predicted_id))
        
## Now printout the predicted output phrase.
print("Predicted Tokens:\n") 
for id in predicted_ids: 
    print("\t", id_to_token[id.item()])

Predicted Tokens:

	 create
	 create
	 create
	 create
	 create
	 create
	 create


In [67]:
predicted_id = torch.tensor([torch.argmax(predictions[-1,:])])
predicted_ids = predicted_id
print(predicted_ids)

tensor([3])


In [68]:
trainer = L.Trainer(max_epochs=30)
trainer.fit(model, train_dataloaders=dataloader)

INFO: 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name                 | Type               | Params | Mode 
--------------------------------------------------------------------
0 | we                   | Embedding          | 16     | train
1 | pe                   | PositionalEncoding | 0      | train
2 | self_attention       | Attention          | 12     | train
3 | self_attention_2     | Attention          | 12     | train
4 | self_attention_3     | Attention          | 12     | train
5 | self_attention_4     | Attention          | 12     | train
6 | reduce_attention_dim | Linear             | 18     | train
7 | fc_layer             | 

Training: |          | 0/? [00:00<?, ?it/s]

Loss: 2.132384777069092
Loss: 1.8171823024749756
Loss: 1.5985180139541626
Loss: 1.4378249645233154
Loss: 1.3075172901153564
Loss: 1.1973940134048462
Loss: 1.1041547060012817
Loss: 1.0170232057571411
Loss: 0.9212038516998291
Loss: 0.8178309202194214
Loss: 0.7340945601463318
Loss: 0.6448743939399719
Loss: 0.5580506324768066
Loss: 0.4920317232608795
Loss: 0.4466918110847473
Loss: 0.4077514410018921
Loss: 0.36753010749816895


INFO: `Trainer.fit` stopped: `max_epochs=30` reached.


Loss: 0.32929039001464844
Loss: 0.28941813111305237
Loss: 0.2590929865837097
Loss: 0.2363891750574112
Loss: 0.21618640422821045
Loss: 0.20074868202209473
Loss: 0.18827491998672485
Loss: 0.17244865000247955
Loss: 0.15360501408576965
Loss: 0.13558350503444672
Loss: 0.11649902164936066
Loss: 0.09833530336618423
Loss: 0.0826854482293129


In [69]:

model_input = torch.tensor([token_to_id["what"], 
                            token_to_id["you"], 
                            # token_to_id["statquest"], 
                            # token_to_id["<EOS>"]
                           ])
input_length = model_input.size(dim=0)

predictions = model(model_input) 
predicted_id = torch.tensor([torch.argmax(predictions[-1,:])])
predicted_ids = predicted_id

for i in range(input_length, max_length):
    # if (predicted_id == token_to_id["<EOS>"]): # if the prediction is <EOS>, then we are done
    #     break
    
    model_input = torch.cat((model_input, predicted_id))
    
    predictions = model(model_input) 
    predicted_id = torch.tensor([torch.argmax(predictions[-1,:])])
    predicted_ids = torch.cat((predicted_ids, predicted_id))
        
print("Predicted Tokens:\n") 
for id in predicted_ids: 
    print("\t", id_to_token[id.item()])

Predicted Tokens:

	 cannot
	 create
	 you
	 do
	 not
	 understand
