In [2]:
import torch.nn as nn
import torch

import numpy as np
import matplotlib.pyplot as plt
import copy

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Linear check [X]
N = 2
length = 4
seq = torch.rand((2, 4))
seq = torch.unsqueeze(seq, dim=2)

d_model = 2
embedding = nn.Linear(1, 4)

print(seq)
print(embedding(seq))


# Out check
N = 5
target_len = 3
d_model = 2

decoder_out = torch.rand((N, target_len, d_model))
print(decoder_out)

out = nn.Linear(d_model, 1)

model_out = out(decoder_out)

print(model_out)

tensor([[[0.2190],
         [0.3532],
         [0.0474],
         [0.3066]],

        [[0.2560],
         [0.8752],
         [0.1437],
         [0.8498]]])
tensor([[[ 1.5899e-02, -5.1423e-02, -3.0294e-01, -4.0232e-01],
         [-1.1244e-01,  2.8466e-02, -4.0180e-01, -2.9895e-01],
         [ 1.8001e-01, -1.5358e-01, -1.7652e-01, -5.3452e-01],
         [-6.7916e-02,  7.5148e-04, -3.6751e-01, -3.3481e-01]],

        [[-1.9509e-02, -2.9382e-02, -3.3022e-01, -3.7380e-01],
         [-6.1161e-01,  3.3920e-01, -7.8634e-01,  1.0315e-01],
         [ 8.7929e-02, -9.6261e-02, -2.4745e-01, -4.6034e-01],
         [-5.8730e-01,  3.2407e-01, -7.6761e-01,  8.3563e-02]]],
       grad_fn=<AddBackward0>)
tensor([[[0.9862, 0.6988],
         [0.9182, 0.1094],
         [0.4937, 0.5580]],

        [[0.5667, 0.1993],
         [0.7072, 0.6829],
         [0.6924, 0.2154]],

        [[0.2458, 0.5928],
         [0.8755, 0.0358],
         [0.9921, 0.5045]],

        [[0.1643, 0.4938],
         [0.0528, 0.7670],
  

In [21]:
# Generate data
N = 2000
seq_len = 4
SOS_token = -1e2
EOS_token = 1e2

# Generate data
data = torch.rand((N, seq_len))

# Generate target
target = torch.zeros((N, seq_len-1))
for i in range(seq_len-1):
    target[:, i] = data[:, i] + data[:, i+1]
    
# Append SOS and EOS token
SOS_ = SOS_token*torch.ones((N, 1))
EOS_ = EOS_token*torch.ones((N, 1))

data = torch.cat((SOS_, data), axis=1)
data = torch.cat((data, EOS_), axis=1)

target = torch.cat((SOS_, target), axis=1)
target = torch.cat((target, EOS_), axis=1)

# Unsqueeze so each entry becomes its own dimension, needed for the nn.Linear embedding
data = torch.unsqueeze(data, dim=2)
target = torch.unsqueeze(target, dim=2)

In [22]:
import math

def batchify(data, target):
    numMiniBatch = int(math.floor(data.shape[0]/100.0))
    inputMiniBatches = data.chunk(numMiniBatch)
    outputMiniBatches = target.chunk(numMiniBatch)
    
    return numMiniBatch, inputMiniBatches, outputMiniBatches

num_batch, data_batches, target_batches = batchify(data, target)

In [23]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout_p, max_len):
        super().__init__()

        # Info
        self.dropout = nn.Dropout(dropout_p)
        
        # Encoding - From formula
        pos_encoding = torch.zeros(max_len, d_model)
        positions_list = torch.arange(0, max_len, dtype=torch.long).view(-1, 1) # 0, 1, 2, 3, 4, 5
        division_term = torch.exp(torch.arange(0, d_model, 2).long() * (-math.log(10000.0)) / d_model) # 1000^(2i/dim_model)
        
        # PE(pos, 2i) = sin(pos/1000^(2i/d_model))
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        
        # PE(pos, 2i + 1) = cos(pos/1000^(2i/d_model))
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)
        
        # Saving buffer (same as parameter without gradients needed)
        pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pos_encoding",pos_encoding)
        
    def forward(self, token_embedding: torch.tensor) -> torch.tensor:
        # Residual connection + pos encoding
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])

    
class TransformerModel(nn.Module):
    def __init__(self, d_model, n_head, n_layers):
        super().__init__()
        
        self.d_model = d_model
        self.positional_encoder = PositionalEncoding(
            d_model = d_model, 
            dropout_p = 0.1,
            max_len = 100)
        self.embedding = nn.Linear(1, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=n_head,
            num_encoder_layers=n_layers,
            num_decoder_layers=n_layers)
        
        # For floating point with one output, then self.out = nn.Linear(d_model, 1)
        self.out = nn.Linear(d_model, 1) # Learned linear at the end where output of decoder is run through
        
    def get_tgt_mask(self, size):
    
        mask = torch.tril(torch.ones(size, size) == 1) # Lower triangular matrix
        mask = mask.float()
        mask = mask.masked_fill(mask == 0, float('-inf')) # Convert zeros to -inf
        mask = mask.masked_fill(mask == 1, float(0.0)) # Convert ones to 0
        
        return mask
        
    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None):
        
        src = (self.embedding(src) * math.sqrt(self.d_model)).to(device)
        tgt = (self.embedding(tgt) * math.sqrt(self.d_model)).to(device)
        src = self.positional_encoder(src)
        tgt = self.positional_encoder(tgt)
        
        src = src.permute(1, 0, 2)
        tgt = tgt.permute(1, 0, 2)
        
        transformer_output = self.transformer(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask).to(device)
        output = self.out(transformer_output).to(device)
        
        output = output.permute(1, 2, 0).to(device)
        
        return output

In [24]:
# Define the model
d_model = 16
n_head = 2
n_layers = 2

model = TransformerModel(d_model, n_head, n_layers)
opt = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.MSELoss()

In [25]:
# Training
def train(model, data_batches, target_batches, num_batch):
    model.train()
    total_loss = 0

    for i in range(num_batch):
        data = data_batches[i].to(device)
        target = target_batches[i].to(device)
        
        target_in = target[:, :-1]
        target_expected = target[:, 1:]
        
        sequence_length = target_in.size(1)

        tgt_mask = model.get_tgt_mask(sequence_length).to(device)
        src_mask = model.get_tgt_mask(data.size(1)).to(device)
        
        pred = model(data, target_in, src_mask, tgt_mask)
        pred = torch.permute(pred, (0, 2, 1))
        
        loss = (criterion(pred, target_expected).type(torch.float))

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.detach().item()
        
    return total_loss
    
def train_loop(model, n_epochs, data, target, num_batch):
    best_loss = 10000.0
    best_model = copy.deepcopy(model).to(device)
    loss_ = np.array([])
    epochs = np.array([])
    
    for i in range(n_epochs):
        loss = train(model, data, target, num_batch)
        loss_ = np.append(loss_, loss)
        epochs = np.append(epochs, i)
        
        if loss < best_loss:
            best_loss = loss
            best_model = copy.deepcopy(model).to(device)
                     
        if i % 10 == 0:
            print(f'Epoch: {i}\nTotal Loss: {loss}')
            print(f'----------------------------------')
    
    return best_model, loss_, epochs

In [26]:
n_epochs = 110
best_model, loss_, epochs = train_loop(model, n_epochs, data_batches, target_batches, num_batch)

Epoch: 0
Total Loss: 49057.34521484375
----------------------------------
Epoch: 10
Total Loss: 47502.649658203125
----------------------------------
Epoch: 20
Total Loss: 47058.110595703125
----------------------------------
Epoch: 30
Total Loss: 46585.136962890625
----------------------------------
Epoch: 40
Total Loss: 46048.794189453125
----------------------------------
Epoch: 50
Total Loss: 45511.4765625
----------------------------------
Epoch: 60
Total Loss: 44871.1435546875
----------------------------------
Epoch: 70
Total Loss: 44248.941650390625
----------------------------------
Epoch: 80
Total Loss: 43585.76904296875
----------------------------------
Epoch: 90
Total Loss: 42881.96337890625
----------------------------------
Epoch: 100
Total Loss: 42167.56982421875
----------------------------------


In [None]:
example_in = torch.tensor([[[np.pi], [1.0], [2.0], [3.0]]])
example_out = 