In [1]:
import torch
import warnings

import torch.nn as nn
import numpy as np

from tqdm import tqdm
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import random_split
from torch import Tensor

# Use the GPU when PyTorch can see one, otherwise stay on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Silence all Python warnings for the rest of the session
warnings.filterwarnings(action='ignore')

  from .autonotebook import tqdm as notebook_tqdm


### What does this notebook do?

A simple transformer-based encoder-decoder is trained to reverse the order of a sequence of binary values (a left-to-right flip, not a bit inversion), e.g. [1, 1, 0, 0, 0, 1] -> [1, 0, 0, 0, 1, 1]

### Helper functions

In [7]:
# Positional encoding
def sinusoids(length, channels, max_timescale=5000):
    """Return a (length, channels) table of sinusoidal positional encodings.

    The first ``channels // 2`` columns hold sines and the remaining columns
    hold cosines of the position index scaled by geometrically spaced inverse
    timescales running from 1 down to ``1 / max_timescale``.

    Args:
        length: number of positions (rows) to generate.
        channels: encoding width; must be even.
        max_timescale: largest timescale of the geometric progression.

    Returns:
        A float tensor of shape ``(length, channels)`` on the CPU.

    Raises:
        AssertionError: if ``channels`` is odd.
    """
    assert channels % 2 == 0
    half = channels // 2
    # With channels == 2 there is a single frequency and `half - 1` is 0; dividing
    # by it would turn the whole table into NaN, so clamp the denominator to 1
    # (the single timescale is then exactly 1).
    log_timescale_increment = np.log(max_timescale) / max(half - 1, 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    # Outer product: one row per position, one column per frequency
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


def create_mask(tgt):
    """Build the causal (look-ahead) mask for a batch of target sequences.

    Args:
        tgt: tensor whose second dimension is the target sequence length.

    Returns:
        A square boolean tensor of side ``tgt.shape[1]``, located on the same
        device as ``tgt``. Entry ``[i, j]`` is True when ``j > i`` (a future
        position that must be hidden from position ``i``) and False otherwise.
    """
    size = tgt.shape[1]
    # Strict upper triangle == "future" positions. Building the mask on
    # tgt.device keeps it on the same device as the data it will be used with,
    # independent of any module-level device setting.
    return torch.ones((size, size), device=tgt.device).triu(diagonal=1).bool()


# Helper class to track losses
class AverageMeter(object):
    """Keeps the latest value and the weighted running mean of a metric.

    Attributes:
        val: the most recent value passed to ``update``.
        sum: weighted sum of all values seen since the last reset.
        count: total weight (e.g. number of samples) seen since the last reset.
        avg: ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.sum += n * val if False else val * n
        self.count += n
        self.avg = self.sum / self.count

### Data preprocessing

In [8]:
# Dataset definition
#   N          : how many training examples to generate
#   seq_length : number of binary values in every example
#   inputs     : (N, seq_length) tensor of random 0/1 values
#   outputs    : each row of `inputs` in reversed (left-to-right flipped) order;
#                this is what the model should learn to predict
N = 5000
seq_length = 10
inputs = torch.randint(low=0, high=2, size=(N, seq_length))
outputs = torch.flip(inputs, dims=[1])

In [9]:
class ToDataset(Dataset):
    """Map-style dataset that pairs row ``i`` of ``x`` with row ``i`` of ``y``."""

    def __init__(self, x, y):
        # Store float copies so the samples can be fed straight into linear layers
        self.x, self.y = x.float(), y.float()

    def __len__(self):
        # Number of examples == size of the leading dimension
        return self.x.size(0)

    def __getitem__(self, index):
        sample, target = self.x[index], self.y[index]
        return sample, target

In [10]:
# Serve the (input, reversed-output) pairs in mini-batches of 32, in dataset order
train_loader = DataLoader(
    ToDataset(inputs, outputs),
    batch_size=32,
)

### Neural network model

In [11]:
class Model(nn.Module):
    """Transformer encoder-decoder over sequences of scalar values.

    Every scalar of a ``(batch, seq)`` input is lifted to ``emb_dim`` features
    by a linear layer, layer-normalised, and summed with a sinusoidal
    positional encoding before entering ``nn.Transformer`` (batch-first).
    The decoder output is projected back to one logit per position.

    Args:
        emb_dim: model width (``d_model``); must be even for the positional
            encoding and divisible by the 4 attention heads.
    """

    # Largest timescale used for the sinusoidal positional encoding
    _POS_MAX_TIMESCALE = 200

    def __init__(self, emb_dim = 256):
        super(Model, self).__init__()

        self.emb_dim = emb_dim

        # Encoder embedding
        self.encoder_embedding = nn.Linear(1, self.emb_dim)

        # Transformer
        self.transformer = nn.Transformer(d_model=self.emb_dim,
                                          dim_feedforward=256,
                                          nhead=4,
                                          num_encoder_layers=2,
                                          num_decoder_layers=2,
                                          batch_first=True)

        # Decoder embedding
        self.decoder_embedding = nn.Linear(1, self.emb_dim)

        # Decoder projection
        self.decoder_projection = nn.Linear(self.emb_dim, 1)

        # Normalization layers
        self.src_norm = nn.LayerNorm(self.emb_dim)
        self.tgt_norm = nn.LayerNorm(self.emb_dim)

    def _add_positions(self, emb: Tensor) -> Tensor:
        """Add the sinusoidal positional encoding to a (batch, seq, emb_dim) tensor."""
        pos = sinusoids(emb.shape[1], self.emb_dim, max_timescale=self._POS_MAX_TIMESCALE)
        # The table is created on the CPU in float32; move it next to the
        # embeddings so the model also works after .to('cuda') / .double().
        return emb + pos.to(device=emb.device, dtype=emb.dtype)

    def _embed_src(self, src: Tensor) -> Tensor:
        """Embed a (batch, seq) source tensor for the encoder."""
        src = torch.unsqueeze(src, -1)
        return self._add_positions(self.src_norm(self.encoder_embedding(src)))

    def _embed_tgt(self, tgt: Tensor) -> Tensor:
        """Embed a (batch, seq) target tensor for the decoder."""
        tgt = torch.unsqueeze(tgt, -1)
        return self._add_positions(self.tgt_norm(self.decoder_embedding(tgt)))

    def forward(self, src: Tensor, tgt: Tensor, tgt_mask: Tensor):
        """Run encoder and decoder in one pass.

        Args:
            src: source sequences, shape (batch, src_len).
            tgt: decoder input sequences, shape (batch, tgt_len).
            tgt_mask: mask handed to ``nn.Transformer`` as ``tgt_mask``.

        Returns:
            Logits of shape (batch, tgt_len, 1).
        """
        src_emb = self._embed_src(src)
        tgt_emb = self._embed_tgt(tgt)

        # Transformer encoder-decoder
        outs = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask)

        # Decoder projection
        return self.decoder_projection(outs)

    def encode(self, src: Tensor):
        """Return the encoder memory for ``src`` of shape (batch, src_len)."""
        return self.transformer.encoder(self._embed_src(src))

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """Decode ``tgt`` against a precomputed encoder ``memory``.

        Returns:
            Logits of shape (batch, tgt_len, 1).
        """
        outs = self.transformer.decoder(self._embed_tgt(tgt), memory, tgt_mask=tgt_mask)

        return self.decoder_projection(outs)
    
        
# Initialize model
model = Model(emb_dim=256)

# Parameters that will receive gradient updates
trainables = list(filter(lambda weight: weight.requires_grad, model.parameters()))

# Report the model size in millions, counting every parameter of the model
print('Total parameter number is : {:.3f} million'.format(
    sum(weight.numel() for weight in model.parameters()) / 1e6
))

Total parameter number is : 2.114 million


### Training

In [12]:
def trainModel(train_loader, model, epochs=10, lr=0.0001, log_every=40):
    """Train ``model`` with teacher forcing and binary cross-entropy on logits.

    Args:
        train_loader: iterable yielding ``(src, tgt)`` batches, each of shape
            (batch, seq).
        model: module called as ``model(src, tgt_input, tgt_mask)`` that
            returns logits of shape (batch, seq, 1).
        epochs: number of passes over ``train_loader``.
        lr: learning rate for the Adam optimizer.
        log_every: print the running loss every this many batches (positive).

    Returns:
        None. The model is updated in place; progress is printed to stdout.
    """
    # Switch to train mode
    model.train()

    # Binary cross entropy with logits loss, running-loss tracker and optimizer
    criterion = nn.BCEWithLogitsLoss()
    losses = AverageMeter()
    optimizer = Adam(model.parameters(), lr=lr)

    # Run every batch on whatever device the model's weights live on
    model_device = next(model.parameters()).device

    for epoch in range(epochs):

        # Train in mini-batches
        for batch_idx, (src, tgt) in enumerate(train_loader):
            src = src.to(model_device)
            tgt = tgt.to(model_device)

            # Teacher forcing: the decoder input is the target shifted right by
            # one position, with 0 prepended as the start-of-sequence token
            sos = torch.zeros((tgt.shape[0], 1), device=tgt.device)
            tgt_input = torch.cat((sos, tgt), -1)[:, :-1]
            tgt_mask = create_mask(tgt_input)

            # Forward + Backward
            optimizer.zero_grad()
            logits = model(src, tgt_input, tgt_mask)
            loss = criterion(logits, torch.unsqueeze(tgt, -1))
            loss.backward()
            optimizer.step()

            # Update metrics (weighted by batch size so the average is per-sample)
            losses.update(loss.item(), tgt.size(0))

            # Print info
            if batch_idx % log_every == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\t'
                 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch+1, batch_idx, len(train_loader), 100. * batch_idx / len(train_loader), loss=losses))
        print('*****')

In [13]:
# Train model: runs the training loop on the DataLoader built earlier.
# Each '*****' line in the output marks the end of one epoch.
trainModel(train_loader, model)

*****
*****
*****
*****
*****
*****
*****
*****
*****
*****


### Testing

In [14]:
# Draw one random binary sequence `x` of shape (1, 10); `true` is the same
# sequence in reversed order, flattened to a 1-D NumPy array for comparison
x = torch.randint(low=0, high=2, size=(1, 10)).float()
true = torch.flip(x, dims=[1]).flatten().numpy()

In [15]:
# This is x: the random binary input sequence generated in the previous cell
x

tensor([[1., 0., 1., 0., 0., 0., 0., 0., 0., 1.]])

In [16]:
# This is the flipped array: x in reversed order, i.e. the output the model should produce
true

array([1., 0., 0., 0., 0., 0., 0., 1., 0., 1.], dtype=float32)

In [17]:
# Now we will flip x using the transformer and check if we get the same result as "true"
model.eval()
sigmoid = nn.Sigmoid()

# Inference only: disable autograd so no computation graph is built during the roll-out
with torch.no_grad():

    # Get the memory (the encoder runs once; its output is reused at every step)
    memory = model.encode(x)

    # SOS (Start Of Sentence token) is 0
    ys = torch.zeros(1, 1, device=x.device)

    # Roll out one sample at a time, one step per position of the input
    for i in tqdm(range(x.shape[1])):
        tgt_mask = create_mask(ys)
        out = sigmoid(model.decode(ys, memory, tgt_mask))
        out = torch.squeeze(out, -1)
        # Keep only the newest position; its probability (not the thresholded
        # bit) is appended and fed back as the next decoder input
        out = out[:, -1].view(1, 1)
        ys = torch.cat((ys, out), -1)

# A token with pr > 0.5 is 1, else 0
predicted = np.where(ys.cpu().numpy() > 0.5, 1.0, 0.0).flatten()
# Drop the leading SOS token
predicted = predicted[1:]

100%|██████████████████████████████████████████| 10/10 [00:00<00:00, 225.65it/s]


In [18]:
# Print the predicted array (thresholded model output with the SOS token removed) - inspect visually
predicted

array([1., 0., 0., 0., 0., 0., 0., 1., 0., 1.])

In [19]:
# Confirm the arrays are equal: True means the model's roll-out matches the reversed input exactly
np.array_equal(predicted, true)

True