In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

In [36]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
# NOTE(review): no visible cell moves the model or data to this device.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [61]:
# Seed PyTorch's RNG so the randomly initialised embedding / linear weights (and
# therefore every prediction shown below) are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Training hyper-parameters.
EPOCHS = 5
LR = 0.1

In [8]:
# Vocabulary: each token the model can read or emit, numbered in listing order.
token_to_id = {
    token: index
    for index, token in enumerate(['what', 'is', 'statquest', 'awesome', '<EOS>'])
}

# Inverse lookup used to turn predicted ids back into words.
id_to_token = {index: token for token, index in token_to_id.items()}

In [18]:
# Two training sequences, each "<question> <EOS> awesome", written as token ids.
inputs = torch.tensor([
    [token_to_id[word] for word in ('what', 'is', 'statquest', '<EOS>', 'awesome')],
    [token_to_id[word] for word in ('statquest', 'is', 'what', '<EOS>', 'awesome')],
])

In [19]:
# Next-token targets: each row is the matching `inputs` row shifted left by one
# position, ending with <EOS> after the answer.
labels = torch.tensor([
    [token_to_id[word] for word in ('is', 'statquest', '<EOS>', 'awesome', '<EOS>')],
    [token_to_id[word] for word in ('is', 'what', '<EOS>', 'awesome', '<EOS>')],
])

In [25]:
# Sanity-check the training pairs: header line, then the tensor on the next line.
print('INPUTS:', inputs, sep='\n')
print('TARGETS:', labels, sep='\n')

INPUTS:
tensor([[0, 1, 2, 4, 3],
        [2, 1, 0, 4, 3]])
TARGETS:
tensor([[1, 2, 4, 3, 4],
        [1, 0, 4, 3, 4]])


In [26]:
# Pair each input row with its target row. batch_size=1 (the DataLoader default)
# yields one (1, seq_len) input/target pair per step, in dataset order.
dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

In [27]:
class PositionEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to word embeddings.

    Even feature columns hold ``sin(pos / 10000**(2i / d_model))`` and odd
    columns the matching ``cos``. The table is precomputed for ``max_len``
    positions and stored as a non-trainable buffer named ``pe``.

    Args:
        d_model: embedding width (odd widths are supported).
        max_len: number of precomputed positions; longer inputs are rejected.
    """

    def __init__(self, d_model=2, max_len=6):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)
        embedding_index = torch.arange(start=0, end=d_model, step=2).float()
        div_term = 1 / torch.tensor(10000.0) ** (embedding_index / d_model)
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd column than even column, so drop the
        # surplus cosine column (this used to be a shape-mismatch error).
        pe[:, 1::2] = torch.cos(angles)[:, : d_model // 2]
        self.register_buffer('pe', pe)

    def forward(self, word_embeddings):
        """Return ``word_embeddings`` plus the encoding of each position.

        Args:
            word_embeddings: ``(seq_len, d_model)`` or batched
                ``(..., seq_len, d_model)``; the sequence axis is the
                second-to-last dimension.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        # Use the sequence axis (not axis 0) so leading batch dims are not
        # mistaken for positions.
        seq_len = word_embeddings.size(-2)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f'sequence length {seq_len} exceeds max_len={self.pe.size(0)}'
            )
        # pe[:seq_len] is (seq_len, d_model) and broadcasts over any batch dims.
        return word_embeddings + self.pe[:seq_len, :]

In [28]:
class Attention(nn.Module):
    """Single-head scaled dot-product attention with bias-free Q/K/V projections.

    Accepts unbatched ``(seq_len, d_model)`` tensors as well as batched
    ``(..., seq_len, d_model)`` tensors: the attention matrix is always formed
    over the last two dimensions.
    """

    def __init__(self, d_model=2):
        super().__init__()

        self.d_model = d_model
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        # Sequence ("row") and feature ("column") axes, counted from the end so the
        # same code works with or without leading batch dimensions. For 2-D inputs
        # these are exactly the previously hard-coded axes 0 and 1.
        self.row_dim = -2
        self.col_dim = -1

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Attend from the query encodings over the key/value encodings.

        Args:
            encodings_for_q: ``(..., q_len, d_model)`` query-side encodings.
            encodings_for_k: ``(..., k_len, d_model)`` key-side encodings.
            encodings_for_v: ``(..., k_len, d_model)`` value-side encodings.
            mask: optional boolean tensor broadcastable to ``(..., q_len, k_len)``;
                ``True`` entries are blocked (filled with -1e9 before softmax).

        Returns:
            ``(..., q_len, d_model)`` attention-weighted sums of the values.
        """
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        # Scale by sqrt(d_k); a plain Python float avoids allocating a CPU tensor.
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)

        if mask is not None:
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)

        # Normalise over the key axis so each query's weights sum to 1.
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores

In [44]:
class DecoderOnlyTransformer(nn.Module):
    """Minimal decoder-only transformer.

    Token embedding + position encoding, one causally-masked self-attention
    layer with a residual connection, and a linear layer that scores every
    vocabulary token at every position.

    Args:
        num_tokens: vocabulary size (number of output logits per position).
            NOTE: the default of 4 is smaller than this notebook's 5-token
            vocabulary, so callers should pass ``len(token_to_id)``.
        d_model: embedding width.
        max_len: longest sequence the position encoding supports.
    """

    def __init__(self, num_tokens=4, d_model=2, max_len=6):
        super().__init__()
        self.we = nn.Embedding(num_embeddings=num_tokens,
                               embedding_dim=d_model)
        self.pe = PositionEncoding(d_model=d_model,
                                   max_len=max_len)
        self.self_attention = Attention(d_model=d_model)
        self.fc_layer = nn.Linear(in_features=d_model, out_features=num_tokens)

    def forward(self, token_ids):
        """Return next-token logits for every position of ``token_ids``.

        Args:
            token_ids: 1-D tensor of token ids, shape ``(seq_len,)`` as used in
                this notebook.

        Returns:
            ``(seq_len, num_tokens)`` tensor of unnormalised scores.
        """
        word_embeddings = self.we(token_ids)
        position_encoded = self.pe(word_embeddings)

        # Causal mask: True strictly above the diagonal, so position i cannot
        # attend to any later position. It is built on the input's device (it was
        # previously pinned to the CPU, which broke the model on CUDA), and sized
        # from the last axis, which is the sequence axis for 1-D inputs.
        seq_len = token_ids.size(-1)
        mask = torch.tril(torch.ones((seq_len, seq_len), device=token_ids.device)) == 0

        self_attention_values = self.self_attention(position_encoded,
                                                    position_encoded,
                                                    position_encoded,
                                                    mask=mask)

        # Residual connection around the attention layer.
        residual_connection_values = position_encoded + self_attention_values

        fc_layer_output = self.fc_layer(residual_connection_values)

        return fc_layer_output

In [45]:
## First, create a model from DecoderOnlyTransformer()
# One output logit per vocabulary entry, hence num_tokens=len(token_to_id).
model = DecoderOnlyTransformer(max_len=6, d_model=2, num_tokens=len(token_to_id))

## Now create the input for the transformer...
# The prompt "what is statquest <EOS>" as a 1-D tensor of token ids.
model_input = torch.tensor([token_to_id[word] for word in ('what', 'is', 'statquest', '<EOS>')])
input_length = len(model_input)

In [64]:
# Display the prompt token ids. Run in order this shows the 4 ids built above; a
# longer tensor means the cell was re-run after a generation loop below appended
# predictions to `model_input`.
model_input

tensor([0, 1, 2, 4, 3, 3])

In [46]:
# Forward pass before any training: one row of next-token logits per input position.
predictions = model(token_ids=model_input)

In [47]:
# Logits with shape (number of prompt tokens, vocabulary size); the last row scores
# the token that should follow the final <EOS>.
predictions

tensor([[ 0.2978,  1.4903,  1.9730, -0.5394,  0.4174],
        [ 0.7757,  1.6035,  1.9735,  0.0046,  0.5067],
        [ 2.0159,  1.2240,  1.1363,  1.7033,  0.8499],
        [ 2.2752,  0.2502, -0.1529,  2.4392,  1.0696]],
       grad_fn=<AddmmBackward0>)

In [48]:
# Greedy pick: highest-scoring token after the last prompt position, wrapped in a
# 1-element tensor so it can be concatenated onto the prompt.
predicted_id = torch.tensor([int(predictions[-1].argmax())])

In [50]:
# Start the generated sequence with the first predicted token; the loop below
# appends one token per step via torch.cat, which returns a new tensor, so this
# alias is never mutated in place.
predicted_ids = predicted_id

In [52]:
# Greedy autoregressive generation: feed the prompt plus everything generated so
# far back into the model and append the highest-scoring next token. Stop at
# <EOS>, or when `model_input` reaches `max_length` tokens (the model's position
# encoding was built with max_len=6).
max_length = 6
with torch.no_grad():  # inference only: don't record an autograd graph
    for _ in range(input_length, max_length):
        if predicted_id.item() == token_to_id["<EOS>"]:  # the prediction is <EOS>, so we are done
            break

        model_input = torch.cat((model_input, predicted_id))

        predictions = model(model_input)
        predicted_id = torch.tensor([torch.argmax(predictions[-1, :])])
        predicted_ids = torch.cat((predicted_ids, predicted_id))

## Now printout the predicted output phrase.
print("Predicted Tokens:")
for token_id in predicted_ids:  # `token_id` avoids shadowing the builtin `id`
    print("\t", id_to_token[token_id.item()])

Predicted Tokens:
	 awesome
	 awesome
	 <EOS>


In [69]:
# AdamW over all model parameters, plus the next-token classification loss
# (CrossEntropyLoss averages over positions; 'mean' is its default reduction).
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss(reduction='mean')

In [71]:
# Train on one (sequence, targets) pair per step. The DataLoader uses its default
# batch_size=1, so x[0] / y[0] strip the batch axis to give the 1-D token ids the
# model's forward expects.
model.train()
optimizer.zero_grad(set_to_none=True)
for epoch in range(EPOCHS):
    epoch_loss = 0.0
    for x, y in dataloader:
        logits = model(x[0])            # (seq_len, num_tokens)
        loss = criterion(logits, y[0])  # use the criterion defined with the optimizer
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        epoch_loss += loss.item()
    print(f"epoch {epoch + 1}/{EPOCHS}: mean loss {epoch_loss / len(dataloader):.4f}")

In [73]:
# Ask the trained model the first question: "what is statquest <EOS>".
model_input = torch.tensor([token_to_id["what"],
                            token_to_id["is"],
                            token_to_id["statquest"],
                            token_to_id["<EOS>"]])
input_length = model_input.size(dim=0)

model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    predictions = model(model_input)
    predicted_id = torch.tensor([torch.argmax(predictions[-1, :])])
    predicted_ids = predicted_id

    # Greedy decoding; `max_length` is the cap set in the earlier generation cell.
    for _ in range(input_length, max_length):
        if predicted_id.item() == token_to_id["<EOS>"]:  # the prediction is <EOS>, so we are done
            break

        model_input = torch.cat((model_input, predicted_id))

        predictions = model(model_input)
        predicted_id = torch.tensor([torch.argmax(predictions[-1, :])])
        predicted_ids = torch.cat((predicted_ids, predicted_id))

print("Predicted Tokens:")
for token_id in predicted_ids:  # `token_id` avoids shadowing the builtin `id`
    print("\t", id_to_token[token_id.item()])

Predicted Tokens:
	 awesome
	 <EOS>


In [74]:
## Now let's ask the other question...
# Prompt: "statquest is what <EOS>".
model_input = torch.tensor([token_to_id["statquest"],
                            token_to_id["is"],
                            token_to_id["what"],
                            token_to_id["<EOS>"]])
input_length = model_input.size(dim=0)

model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    predictions = model(model_input)
    predicted_id = torch.tensor([torch.argmax(predictions[-1, :])])
    predicted_ids = predicted_id

    # Greedy decoding; `max_length` is the cap set in the earlier generation cell.
    for _ in range(input_length, max_length):
        if predicted_id.item() == token_to_id["<EOS>"]:  # the prediction is <EOS>, so we are done
            break

        model_input = torch.cat((model_input, predicted_id))

        predictions = model(model_input)
        predicted_id = torch.tensor([torch.argmax(predictions[-1, :])])
        predicted_ids = torch.cat((predicted_ids, predicted_id))

print("Predicted Tokens:")
for token_id in predicted_ids:  # `token_id` avoids shadowing the builtin `id`
    print("\t", id_to_token[token_id.item()])

Predicted Tokens:
	 awesome
	 <EOS>
