<a href="https://colab.research.google.com/github/YDayoub/U-transformer/blob/main/U_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
'''
Import required libraries

'''
import torch
from matplotlib import pyplot as plt
from torch import nn
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
from functools import partial
# Use the first CUDA GPU when PyTorch reports one as available, otherwise the CPU.
# `device` is read as a module-level global by the cells below.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Device: cuda:0


In [11]:
# Hyperparameters read as module-level globals by main().
batch_size = 128          # batch size of the training DataLoader
test_batch_size = 128   # batch size of the validation and test DataLoaders
epochs = 15             # number of train + validation passes
lr = 5e-4               # learning rate given to the Adam optimizer
seed = 42               # seed for the torch and numpy random generators
h_dims = 16  # passed as both d_model and dim_feedforward of the model
n_heads = 2  # attention heads per block
n_blocks = 2  # number of encoder blocks (and of decoder blocks)
dropout = 0.1  # dropout probability passed to the model
clip = 5  # gradient-clipping threshold; only referenced by a commented-out line in train()

In [None]:
'''
This code is adapted from 
https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html
'''
class Embedding_with_PosEncoding(nn.Module):
  '''
  Linear embedding of the input features, followed by dropout and the
  addition of a fixed (non-trainable) sinusoidal positional encoding.
  '''
  def __init__(self,input_dim,d_model, max_len=5000,dropout=0):
    '''
    Args:
      input_dim: input space dimensionality
      d_model: hidden space dimensionality for Embedding (even or odd)
      max_len: maximum length of an input sequence
      dropout: probability of an element to be zeroed
    '''
    super(Embedding_with_PosEncoding,self).__init__()
    self.emb = nn.Linear(in_features=input_dim,out_features = d_model)
    self.dropout = nn.Dropout(p=dropout)
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    # For an odd d_model there is one cosine column fewer than sine columns,
    # so the frequency vector is trimmed; for an even d_model the slice is a no-op.
    pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
    pe = pe.unsqueeze(0)  # (1, max_len, d_model), broadcasts over the batch axis
    # A non-persistent buffer follows the module across devices (.to / .cuda)
    # but is neither trained nor written to the state_dict.
    self.register_buffer('pe', pe, persistent=False)

  def forward(self,x):
    '''
    Args:
      x: tensor whose axis 1 is the sequence axis and whose last axis has
         size input_dim, e.g. (batch_size, seq_len, input_dim)
    Returns:
      embedded tensor with the last axis of size d_model
    Raises:
      ValueError: if the sequence is longer than max_len
    '''
    seq_len = x.size(1)
    max_len = self.pe.size(1)
    if seq_len > max_len:
      raise ValueError(
          'sequence length {} exceeds max_len {}'.format(seq_len, max_len))
    x = self.emb(x)
    # Dropout is applied to the embedding only, before the encoding is added.
    x = self.dropout(x)
    return x + self.pe[:, :seq_len]

  def get_pe(self):
    '''Return the positional-encoding table of shape (1, max_len, d_model).'''
    return self.pe

In [None]:
def test_positional_encoding():
  '''
  Smoke test: checks the shape of the embedded output and of the
  positional-encoding table.
  '''
  batch_dim, seq_len, input_dim = 15, 10, 200
  d_model, max_len = 100, 100
  sample = torch.randn(batch_dim, seq_len, input_dim)
  encoder = Embedding_with_PosEncoding(input_dim, d_model, max_len)
  table = encoder.get_pe()
  embedded = encoder(sample)
  assert embedded.shape == torch.Size([batch_dim, seq_len, d_model])
  assert table.shape == torch.Size([1, max_len, d_model])
test_positional_encoding()

In [None]:
def scaled_dot_product(query,key,values,mask=None,scale=True):
  '''
  Dot-product attention: softmax over query/key scores, applied to the values.

      Args:
        query: tensor of queries, feature axis last
        key: tensor of keys, feature axis last
        values: tensor of values
        mask: optional boolean tensor broadcastable to the score matrix;
              positions where it is False are excluded from the attention
        scale (bool): whether to divide the scores by sqrt(query feature size)
      Returns:
        tensor of attention-weighted values
  '''
  scores = torch.matmul(query, key.transpose(-2, -1))
  scores = scores / (query.shape[-1] ** 0.5 if scale else 1)
  if mask is not None:
    # A huge negative score becomes a zero weight after the softmax.
    blocked = scores.new_full((), -9e15)
    scores = torch.where(mask, scores, blocked)
  # Softmax computed as exp(score - logsumexp(score)).
  normaliser = torch.logsumexp(scores, dim=-1, keepdim=True)
  weights = torch.exp(scores - normaliser)
  return torch.matmul(weights, values)
def dot_product_self_attention(q, k, v,device=None):
  '''
  Causal (lower-triangular masked) scaled dot-product attention.

    Args:
        q: queries, sequence axis second to last.
        k: keys.
        v: values.
        device: device on which to build the mask; defaults to the device
                of q so the mask always matches the inputs.
    Returns:
        masked dot product self attention tensor.
  '''
  mask_size = q.shape[-2]
  # Build the mask next to the data instead of on a module-level default:
  # a mask on another device than q/k/v makes the masking step fail.
  mask_device = q.device if device is None else device
  # The mask is square in the query length, so position i only sees positions <= i.
  mask = torch.tril(
      torch.ones((1, mask_size, mask_size), dtype=torch.bool, device=mask_device),
      diagonal=0)
  return scaled_dot_product(q, k, v, mask)



In [None]:
class QKV(nn.Module):
  '''
  Projects the input once and splits the result into per-head q, k and v.

  input:  tensor of shape (batch_size,seq_len,d_model)
  output: three tensors q,k,v of shape (batch_size,n_heads,seq_len,d_model//n_heads)
  '''

  def __init__(self,n_heads,d_model):
    '''
      Args:
        n_heads: number of heads used in multihead attention
        d_model: hidden space dimensions, must be a multiple of n_heads
    '''
    assert d_model%n_heads==0,'d_models should be divisible by n_heads'
    super().__init__()
    self.n_heads = n_heads
    self.d_model = d_model
    self.d_heads = d_model//n_heads
    # One linear layer produces q, k and v together (3*d_model outputs).
    self.qvk = nn.Linear(d_model, 3*d_model)

  def forward(self,x):
    n_batch, n_tokens, _ = x.shape
    fused = self.qvk(x)
    # Split the projection per head first; every head then carries its own q|k|v.
    fused = fused.reshape(n_batch, n_tokens, self.n_heads, 3*self.d_heads)
    fused = fused.transpose(1, 2)
    q, k, v = torch.chunk(fused, 3, dim=-1)
    return q, k, v



In [None]:
def test_QKV():
  '''
  Smoke test: q, k and v must each have shape
  (batch, n_heads, seq_len, d_model // n_heads).
  '''
  batch_dim, seq_len, d_model = 15, 10, 200
  n_heads = 2
  x = torch.randn(batch_dim, seq_len, d_model).to(device)
  qkv = QKV(n_heads=n_heads, d_model=d_model).to(device)
  q, k, v = qkv(x)
  expected = torch.Size([batch_dim, n_heads, seq_len, d_model//n_heads])
  for part in (q, k, v):
    assert part.shape == expected
test_QKV()

In [None]:
class MultiheadAttention(nn.Module):
  '''
  This class implements multihead attention on already-split heads:
  it attends per head, merges the heads and applies an output projection.
  '''
  def __init__(self,d_model,causal_attention=False):
    '''
      Args:
        d_model: hidden space dimensions
        causal_attention: boolean whether to use attention or causal attention
    '''
    super(MultiheadAttention,self).__init__()
    self.d_model = d_model
    self.o = nn.Linear(in_features=d_model,out_features=d_model)
    self.causal_attention = causal_attention

  def forward(self,q,k,v):
    '''
      Args:
        q, k, v: tensors of shape (batch_size,n_heads,seq_len,d_heads)
      Returns:
        tensor of shape (batch_size,seq_len,d_model)
    '''
    batch_size, _, seq_len, _ = q.shape
    if self.causal_attention:
      # Ask for the mask on the queries' device rather than relying on the
      # helper's default, so the module works wherever its inputs live.
      atten = dot_product_self_attention(q, k, v, device=q.device)
    else:
      atten = scaled_dot_product(q,k,v)
    # (batch, heads, seq, d_heads) -> (batch, seq, heads, d_heads) -> (batch, seq, d_model)
    atten = atten.permute(0,2,1,3)
    atten = atten.reshape(batch_size,seq_len,self.d_model)
    return self.o(atten)


In [None]:
def test_MultiheadAttention():
  '''
  Smoke test: plain and causal attention must both return
  (batch, seq_len, d_model).
  '''
  batch_dim, seq_len, d_model = 15, 10, 200
  n_heads = 2
  att = MultiheadAttention(d_model, causal_attention=False).to(device)
  causal_att = MultiheadAttention(d_model, causal_attention=True).to(device)
  x = torch.randn(batch_dim, n_heads, seq_len, 3, d_model//n_heads).to(device)
  # Axis 3 holds the three roles: query, key, value.
  q, k, v = x.unbind(dim=3)
  expected = torch.Size([batch_dim, seq_len, d_model])
  o1 = att(q, k, v)
  o2 = causal_att(q, k, v)
  assert o1.shape == expected
  assert o2.shape == expected
test_MultiheadAttention()

In [None]:
class EncoderBlock(nn.Module):
  '''
  Encoder block: unmasked self-attention followed by a two-layer MLP.
  Each sub-layer is wrapped as LayerNorm(residual + dropout(sub-layer)).
  '''
  def __init__(self,d_model, n_heads, dim_feedforward, dropout=0.0):
    '''
      Args:
        d_model: hidden space dimensions
        n_heads: number of heads
        dim_feedforward: Dimensionality of the hidden layer in the MLP
        dropout: probability of an element to be zeroed
    '''
    super().__init__()
    self.dropout = nn.Dropout(p=dropout)
    self.qkv = QKV(n_heads=n_heads, d_model=d_model)
    self.attention = MultiheadAttention(d_model=d_model, causal_attention=False)
    mlp_layers = [
        nn.Linear(d_model, dim_feedforward),
        nn.Dropout(dropout),
        nn.ReLU(inplace=True),
        nn.Linear(dim_feedforward, d_model),
    ]
    self.feedforward = nn.Sequential(*mlp_layers)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)

  def forward(self,x0):
    # Sub-layer 1: self-attention with residual connection.
    attended = self.attention(*self.qkv(x0))
    hidden = self.norm1(x0 + self.dropout(attended))
    # Sub-layer 2: position-wise MLP with residual connection.
    mlp_out = self.feedforward(hidden)
    return self.norm2(self.dropout(mlp_out) + hidden)



In [None]:
def reshape_tensor(x,n_heads):
  '''
  Split the feature axis of x into attention heads.

    Args:
      x: tensor of shape (batch_size,seq_len,d_model)
      n_heads: number of heads in multihead attention
    Returns:
      tensor of shape (batch_size,n_heads,seq_len,d_model//n_heads)
  '''
  n_batch, n_tokens, width = x.shape
  per_head = x.reshape(n_batch, n_tokens, n_heads, width//n_heads)
  # Move the head axis in front of the sequence axis.
  return per_head.transpose(1, 2)

class DecoderBlock(nn.Module):
  '''
    Decoder block with three sub-layers, each wrapped as
    LayerNorm(residual + dropout(sub-layer)):
      1. unmasked self-attention over the decoder input,
      2. causally masked attention whose keys and values are the skip connection,
      3. a two-layer MLP.
  '''

  def __init__(self,d_model, n_heads, dim_feedforward, dropout=0.0):
    '''
      Args:
        d_model: hidden space dimensions
        n_heads: number of heads
        dim_feedforward: Dimensionality of the hidden layer in the MLP
        dropout: probability of an element to be zeroed
    '''
    super().__init__()
    self.n_heads = n_heads
    self.d_model = d_model
    self.qkv = QKV(n_heads, d_model)
    self.dropout = nn.Dropout(p=dropout)
    # NOTE(review): the self-attention is unmasked while the attention over
    # the skip connection is the causal one -- confirm this is intended.
    self.attention = MultiheadAttention(d_model, causal_attention=False)
    self.causal_attention = MultiheadAttention(d_model, causal_attention=True)
    mlp_layers = [
        nn.Linear(d_model, dim_feedforward),
        nn.Dropout(dropout),
        nn.ReLU(inplace=True),
        nn.Linear(dim_feedforward, d_model),
    ]
    self.feedforward = nn.Sequential(*mlp_layers)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.norm3 = nn.LayerNorm(d_model)

  def forward(self,x0,skip_con):
    '''
      Args:
        x0: decoder input of shape (batch_size,seq_len,d_model)
        skip_con: skip-connection tensor of shape (batch_size,seq_len,d_model)
      Returns:
        tensor of shape (batch_size,seq_len,d_model)
    '''
    # Sub-layer 1: self-attention on the decoder input.
    self_attended = self.attention(*self.qkv(x0))
    hidden = self.norm1(x0 + self.dropout(self_attended))
    # Sub-layer 2: the hidden state queries the skip connection, which is
    # used unprojected as both keys and values.
    queries = reshape_tensor(hidden, self.n_heads)
    memory = reshape_tensor(skip_con, self.n_heads)
    cross_attended = self.causal_attention(queries, memory, memory)
    merged = self.norm2(hidden + self.dropout(cross_attended))
    # Sub-layer 3: position-wise MLP.
    mlp_out = self.feedforward(merged)
    return self.norm3(self.dropout(mlp_out) + merged)

In [None]:
class UnetTransformer(nn.Module):
  '''
    This class implements unet transformer: a stack of encoder blocks
    followed by a stack of decoder blocks, where every decoder block
    receives the input of the mirrored encoder block as skip connection.
  '''
  def __init__(self,n_blocks,input_dim,n_heads,d_model,num_classes,dim_feedforward,dropout=0.0):

    '''
      Args:
        n_blocks: number of encoder/decoder blocks
        input_dim: Dimensionality of the input space
        n_heads: number of heads in MultiHeadAttention
        d_model: Dimensionality of the embedding space
        num_classes: Dimensionality of the output space
        dim_feedforward:  Dimensionality of the hidden layer in the MLP
        dropout: probability of an element to be zeroed
    '''
    super(UnetTransformer,self).__init__()
    self.n_blocks = n_blocks
    self.pos_enc = Embedding_with_PosEncoding(input_dim,d_model,dropout=dropout)
    # Register the blocks through the public nn.Module API instead of writing
    # into the private `_modules` dict; names (enc_i / dec_i), and therefore
    # the state_dict keys, are unchanged.
    for i in range(n_blocks):
      self.add_module('enc_{}'.format(i), EncoderBlock(d_model, n_heads, dim_feedforward, dropout))
    for i in range(n_blocks):
      self.add_module('dec_{}'.format(i), DecoderBlock(d_model, n_heads, dim_feedforward, dropout))
    self.output_layer = nn.Sequential( nn.Linear(d_model, d_model),
            nn.LayerNorm(d_model),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_model, num_classes)
        )

  def forward(self,x):
    '''
      Args:
        x: input tensor of shape (batch_size,seq_len,input_dim)
      Returns:
        tensor of shape (batch_size,seq_len,num_classes)
    '''
    x = self.pos_enc(x)
    skips = []
    # Remember the INPUT of every encoder block ...
    for i in range(self.n_blocks):
      skips.append(x)
      x = getattr(self, 'enc_{}'.format(i))(x)
    # ... and hand them to the decoder blocks in reverse order (deepest first).
    for i in range(self.n_blocks):
      x = getattr(self, 'dec_{}'.format(i))(x, skips.pop())
    return self.output_layer(x)

    

In [None]:
def train(model, device, train_loader, optimizer, criterion):
  '''
  Run one training epoch and print the epoch's mean loss and accuracy.

    Args:
      model: network mapping one-hot inputs to per-position class scores
      device: device the batches are moved to
      train_loader: loader yielding (input ids, target ids) batches; its
                    dataset must expose `num_categories`
      optimizer: optimizer updating the parameters of model
      criterion: loss callable accepting (scores, targets, reduction=...),
                 e.g. torch.nn.functional.cross_entropy
    Returns:
      (mean loss per element, accuracy in percent) over the epoch
    Raises:
      ValueError: if the loader yields no data
  '''
  model.train()
  num_classes = train_loader.dataset.num_categories
  loss_sum = 0.0
  correct = 0
  seen = 0
  # The context manager closes the progress bar even if a step raises.
  with tqdm(total = len(train_loader),position=0,leave=True) as pbar:
    for batch_idx, (inputs, target) in enumerate(train_loader):
      inputs, target = inputs.to(device), target.to(device)
      inputs = F.one_hot(inputs, num_classes=num_classes).float()
      optimizer.zero_grad()
      preds = model(inputs)
      loss = criterion(preds.view(-1,preds.size(-1)), target.view(-1),reduction="mean")
      loss.backward()
      # NOTE: gradient clipping is left disabled, as before; to enable it use
      # nn.utils.clip_grad_norm_(model.parameters(), max_norm).
      optimizer.step()
      current_loss = loss.item()  # single host sync per step
      n_elements = target.numel()
      # Weight the batch mean by its element count so that batches of
      # different size contribute proportionally to the epoch metrics.
      loss_sum += current_loss * n_elements
      correct += (preds.argmax(dim=-1) == target).sum().item()
      seen += n_elements
      pbar.set_description('training_step {} loss:{:.3f}'.format(batch_idx,current_loss))
      pbar.update()
  if seen == 0:
    raise ValueError('train_loader yielded no data')
  mean_loss = loss_sum / seen
  acc = 100. * correct / seen
  print('{0}: loss: {1:.3f} acc {2:.1f}'.format('train',mean_loss,acc))
  return mean_loss, acc


In [None]:
def test( model, device, test_loader,criterion,mode='eval'):
  '''
  Evaluate the model without gradient tracking and print mean loss and accuracy.

    Args:
      model: network mapping one-hot inputs to per-position class scores
      device: device the batches are moved to
      test_loader: loader yielding (input ids, target ids) batches; its
                   dataset must expose `num_categories`
      criterion: loss callable taking (scores, targets); assumed to return
                 the mean loss over the elements of the batch
      mode: label printed in front of the metrics
    Returns:
      (mean loss per element, accuracy in percent) over the whole loader
    Raises:
      ValueError: if the loader yields no data
  '''
  model.eval()
  num_classes = test_loader.dataset.num_categories
  loss_sum = 0.0
  correct = 0
  seen = 0
  with torch.no_grad():
    for inputs, target in test_loader:
      inputs, target = inputs.to(device), target.to(device)
      inputs = F.one_hot(inputs, num_classes=num_classes).float()
      output = model(inputs)
      batch_loss = criterion(output.view(-1,output.size(-1)), target.view(-1)).item()
      n_elements = target.numel()
      # The last batch may be smaller than the others; weighting by element
      # count keeps it from being over-represented in the averages.
      loss_sum += batch_loss * n_elements
      correct += (output.argmax(dim=-1) == target).sum().item()
      seen += n_elements
  if seen == 0:
    raise ValueError('test_loader yielded no data')
  loss = loss_sum / seen
  acc = 100. * correct / seen
  print('{0}: loss: {1:.3f} acc {2:.1f}'.format(mode,loss,acc))
  return loss, acc


      

In [None]:
class ReverseDataset(data.Dataset):
    '''
    Synthetic dataset of random integer sequences; the label of a sequence
    is the same sequence in reverse order.

    Args:
        num_categories: sequence entries are drawn from [0, num_categories)
        seq_len: length of every sequence
        size: number of sequences in the dataset
    '''

    def __init__(self, num_categories, seq_len, size):
        super().__init__()
        self.num_categories, self.seq_len, self.size = num_categories, seq_len, size
        # All sequences are sampled once, at construction time.
        self.data = torch.randint(num_categories, size=(size, seq_len))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        sequence = self.data[idx]
        return sequence, sequence.flip(0)

In [None]:
def main():
    '''
    Train the U-transformer on the sequence-reversal task, validate after
    every epoch, save the weights to "model.h5" and report the test metrics.
    Hyperparameters are read from the module-level globals.
    '''
    torch.manual_seed(seed)
    np.random.seed(seed)
    # The attribute is spelled `deterministic`; the previous misspelling only
    # set an unused attribute, so deterministic cuDNN kernels were never requested.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # 10 categories, sequences of length 16; only the dataset size varies.
    dataset = partial(ReverseDataset, 10, 16)
    train_loader = data.DataLoader(dataset(50000), batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True)
    val_loader   = data.DataLoader(dataset(1000), batch_size=test_batch_size)
    test_loader  = data.DataLoader(dataset(10000), batch_size=test_batch_size)

    # Inputs are one-hot encoded, so input and output sizes both equal the
    # number of categories.
    num_categories = train_loader.dataset.num_categories
    model = UnetTransformer(n_blocks=n_blocks,
                            input_dim=num_categories,
                            n_heads=n_heads,
                            d_model=h_dims,
                            num_classes=num_categories,
                            dim_feedforward=h_dims,
                            dropout=dropout).to(device)
    optimizer = optim.Adam(model.parameters(), lr= lr)
    criterion =  F.cross_entropy

    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer,criterion)
        test(model, device, val_loader,criterion)

    torch.save(model.state_dict(), "model.h5")
    print('------------testing--------------')
    test(model, device, test_loader,criterion,mode='test')


if __name__ == '__main__':
    main()

training_step 389 loss:1.737720: 100%|██████████| 390/390 [00:07<00:00, 54.70it/s]


train: loss: 2.237 acc 15.1
eval: loss: 1.562 acc 48.6


training_step 389 loss:0.198213: 100%|██████████| 390/390 [00:07<00:00, 54.91it/s]


train: loss: 0.660 acc 83.9
eval: loss: 0.083 acc 100.0


training_step 389 loss:0.078665: 100%|██████████| 390/390 [00:07<00:00, 54.66it/s]


train: loss: 0.122 acc 98.6
eval: loss: 0.014 acc 100.0


training_step 389 loss:0.039562: 100%|██████████| 390/390 [00:07<00:00, 53.25it/s]


train: loss: 0.052 acc 99.3
eval: loss: 0.004 acc 100.0


training_step 389 loss:0.033151: 100%|██████████| 390/390 [00:07<00:00, 53.82it/s]


train: loss: 0.033 acc 99.5
eval: loss: 0.002 acc 100.0


training_step 389 loss:0.021332: 100%|██████████| 390/390 [00:07<00:00, 54.47it/s]


train: loss: 0.025 acc 99.6
eval: loss: 0.001 acc 100.0


training_step 389 loss:0.017049: 100%|██████████| 390/390 [00:07<00:00, 54.72it/s]


train: loss: 0.019 acc 99.7
eval: loss: 0.000 acc 100.0


training_step 389 loss:0.015821: 100%|██████████| 390/390 [00:07<00:00, 54.32it/s]


train: loss: 0.015 acc 99.7
eval: loss: 0.000 acc 100.0


training_step 389 loss:0.011333: 100%|██████████| 390/390 [00:07<00:00, 54.02it/s]


train: loss: 0.013 acc 99.7
eval: loss: 0.000 acc 100.0


training_step 389 loss:0.009495: 100%|██████████| 390/390 [00:07<00:00, 54.11it/s]


train: loss: 0.011 acc 99.8
eval: loss: 0.000 acc 100.0


training_step 389 loss:0.007387: 100%|██████████| 390/390 [00:07<00:00, 54.39it/s]


train: loss: 0.010 acc 99.8
eval: loss: 0.000 acc 100.0


training_step 389 loss:0.005797: 100%|██████████| 390/390 [00:07<00:00, 54.78it/s]


train: loss: 0.008 acc 99.8
eval: loss: 0.000 acc 100.0


training_step 389 loss:0.008655: 100%|██████████| 390/390 [00:07<00:00, 53.47it/s]


train: loss: 0.008 acc 99.8
eval: loss: 0.000 acc 100.0


training_step 389 loss:0.008578: 100%|██████████| 390/390 [00:07<00:00, 54.46it/s]


train: loss: 0.007 acc 99.8
eval: loss: 0.000 acc 100.0


training_step 389 loss:0.003324: 100%|██████████| 390/390 [00:07<00:00, 54.42it/s]


train: loss: 0.006 acc 99.8
eval: loss: 0.000 acc 100.0
------------testing--------------
test: loss: 0.000 acc 100.0
