<a href="https://colab.research.google.com/github/YDayoub/U-transformer/blob/main/U_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
'''
Import required libraries
'''
import torch
from matplotlib import pyplot as plt
from torch import nn
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
from functools import partial
import math
from typing import Tuple
from torch.utils.data import dataset
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Device: cuda:0


In [2]:
# batch_size = 128          
# test_batch_size = 128   
epochs = 100             
lr = 5e-4               
seed = 42               
h_dims = 1024
n_heads = 16
n_blocks = 2
dropout = 0.2
clip = 0.5
batch_size = 20
eval_batch_size = 10
bptt = 256
d_model = 400

In [3]:
'''
This code is adapted from 
https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html
'''
class Embedding_with_PosEncoding(nn.Module):
  def __init__(self,input_dim,d_model, max_len=5000,dropout=0):
    '''
    Args:
      d_model: hidden space dimentionality for Embedding
      input_dim: input space dimentionality
      max_len: maximum length of an input sequence
      drop: probability of an element to be zeroed
    '''
    super(Embedding_with_PosEncoding,self).__init__()
    self.emb = nn.Embedding(input_dim,d_model)
    self.dropout = nn.Dropout(p=dropout)
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    pe = pe.unsqueeze(0)
    # register_buffer is used to save and retrain parameters which don't need to train
    self.register_buffer('pe', pe, persistent=False) 
  def forward(self,x):
    seq_len = x.size(1)
    x = self.emb(x)
    x = self.dropout(x)
    x = x + self.pe[:, :seq_len]
    return x
  def get_pe(self):
    return self.pe

In [4]:
def test_positional_encoding():
  batch_dim,seq_len,input_dim= (15,10,10)
  d_model = 100
  max_len =100
  x = torch.randint(low=0, high=10,size=(batch_dim,seq_len))
  pos_encoder = Embedding_with_PosEncoding(input_dim,d_model,max_len)
  pe = pos_encoder.get_pe()
  res = pos_encoder(x)
  assert res.shape ==  torch.Size([batch_dim,seq_len,d_model])
  assert pe.shape == torch.Size([1, max_len, d_model])
test_positional_encoding()

In [5]:
def scaled_dot_product(query,key,values,mask=None,scale=True):
  '''
      Args:
        query: tensor of queries
        key : tensor of keys
        value: tensor of value
        mask (numpy.ndarray): attention-mask, used to perform self attention when required
        scale (bool): whether to scale the dot product of the query and transposed key
  '''
  if scale:
    depth = query.shape[-1] ** 0.5
  else:
    depth = 1
  dots = torch.matmul(query,torch.swapaxes(key,-1,-2))/depth
  if mask is not None:
    dots = torch.where(mask,dots,torch.full_like(dots, -9e15))
  logsumexp = torch.logsumexp(dots, axis=-1, keepdims=True)
  dots = torch.exp(dots - logsumexp)
  attention = torch.matmul(dots, values)
  return attention
def dot_product_self_attention(q, k, v,device=device):
  '''
    Args:
        q: queries.
        k: keys.
        v: values.
    Returns:
        masked dot product self attention tensor.  
  '''
  mask_size = q.shape[-2]
  mask = torch.tril(torch.ones((1, mask_size, mask_size), dtype=torch.bool), diagonal=0).to(device)        
  return scaled_dot_product(q, k, v, mask)



In [6]:
class QKV(nn.Module):
  '''
  takes as input a tensor of shape (batch_size,seq_len,d_model)
  returns:
  three tensors q,k,v of shape (batch_size,n_heads,seq_len,d_model//n_heads)
  '''

  def __init__(self,n_heads,d_model):
    '''
      Args:
        n_heads: number of heads used in multihead attention
        d_model: hidden space dimensions
    '''
    assert d_model%n_heads==0,'d_models should be divisible by n_heads'
    super(QKV,self).__init__()
    self.qvk = nn.Linear(in_features=d_model,out_features=3*d_model)
    self.d_model = d_model
    self.n_heads = n_heads
    self.d_heads = d_model//n_heads
  def forward(self,x):
    batch_size,seq_len,d_model = x.shape
    x = self.qvk(x)
    x = x.reshape(batch_size,seq_len,self.n_heads,3*self.d_heads)
    x = x.permute(0,2,1,3)
    q,k,v = x.chunk(3,dim=-1)
    return q,k,v



In [7]:
def test_QKV():
  batch_dim,seq_len,d_model= (15,10,200)
  n_heads = 2
  x = torch.randn(batch_dim,seq_len, d_model).to(device)
  qkv = QKV(n_heads=n_heads,d_model=d_model).to(device)
  q,k,v = qkv(x)
  assert q.shape ==  torch.Size([batch_dim, n_heads, seq_len,d_model//n_heads])
  assert k.shape ==  torch.Size([batch_dim, n_heads, seq_len,d_model//n_heads])
  assert v.shape ==  torch.Size([batch_dim, n_heads, seq_len,d_model//n_heads])
test_QKV()

In [8]:
class MultiheadAttention(nn.Module):
  '''
  This class implements mulithead attention
  '''
  def __init__(self,d_model,causal_attention=False):
    '''
      Args:
        d_model: hidden space dimensions
        causal_attention: boolean whether to use attention or causal attention 
    '''
    super(MultiheadAttention,self).__init__()
    self.d_model = d_model
    self.o = nn.Linear(in_features=d_model,out_features=d_model)
    self.causal_attention = causal_attention 

  def forward(self,q,k,v):
    batch_size,n_heads,seq_len,d_heads = q.shape
    if self.causal_attention:
      atten =  dot_product_self_attention(q, k, v)
    else:
      atten = scaled_dot_product(q,k,v)
    atten = atten.permute(0,2,1,3)
    atten = atten.reshape(batch_size,seq_len,self.d_model)
    res = self.o(atten)
    return res


In [9]:
def test_MultiheadAttention():
  batch_dim,seq_len,d_model= (15,10,200)
  n_heads = 2
  att = MultiheadAttention(d_model,causal_attention=False).to(device)
  causal_att = MultiheadAttention(d_model,causal_attention=True).to(device)
  x = torch.randn(batch_dim, n_heads, seq_len,3,d_model//n_heads).to(device)
  q,k,v = x[:,:,:,0,:],x[:,:,:,1,:],x[:,:,:,2,:]
  o1 = att(q,k,v)
  o2 = causal_att(q,k,v)
  assert o1.shape ==  torch.Size([batch_dim, seq_len,d_model])
  assert o2.shape ==  torch.Size([batch_dim,  seq_len,d_model])
test_MultiheadAttention()

In [10]:
class EncoderBlock(nn.Module):
  '''
  This class implements encoder block
  '''
  def __init__(self,d_model, n_heads, dim_feedforward, dropout=0.0):
    '''
      Args:
        d_model: hidden space dimensions
        n_heads: number of heads
        dim_feedforward: Dimensionality of the hidden layer in the MLP  
        drop: probability of an element to be zeroed
    '''
    super(EncoderBlock,self).__init__()
    self.dropout = nn.Dropout(p=dropout)
    self.qkv =  QKV(n_heads=n_heads,d_model=d_model)
    self.attention = MultiheadAttention(d_model=d_model,causal_attention=True)
    self.feedforward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(dim_feedforward, d_model)
        )
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)

  def forward(self,x0):
    q,k,v = self.qkv(x0)
    x1 = self.attention(q,k,v)
    x2 = self.norm1(x0+self.dropout(x1))
    x3 = self.feedforward(x2)
    x4 = self.norm2(self.dropout(x3)+x2)
    return x4



In [11]:
def reshape_tensor(x,n_heads):
  '''
    Args:
      x: tensor of shape (batch_size,seq_len,d_model)
      n_heads: number of heads in mutlihead attention
    Returns:
      reshaped tensor of shape (batch_size,n_heads,seq_len,d_model//n_heads)    
  '''
  batch_size,seq_len,d_model = x.shape
  x = x.reshape(batch_size,seq_len,n_heads,d_model//n_heads)
  x = x.permute(0,2,1,3)
  return x

class DecoderBlock(nn.Module):
  '''
    This class implements decoder block
  '''

  def __init__(self,d_model, n_heads, dim_feedforward, dropout=0.0):
    '''
      Args:
        d_model: hidden space dimensions
        n_heads: number of heads
        dim_feedforward: Dimensionality of the hidden layer in the MLP  
        drop: probability of an element to be zeroed
    '''
    super(DecoderBlock,self).__init__()
    self.n_heads = n_heads
    self.d_model = d_model
    self.qkv = QKV(n_heads,d_model)
    self.dropout = nn.Dropout(p=dropout)
    self.attention = MultiheadAttention(d_model,causal_attention=True)
    self.causal_attention = MultiheadAttention(d_model,causal_attention=True)
    self.feedforward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(dim_feedforward, d_model)
        )
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.norm3 = nn.LayerNorm(d_model)

  def forward(self,x0,skip_con):
    q,k,v = self.qkv(x0)
    x1 = self.causal_attention(q,k,v)
    x2 = self.norm1(x0+self.dropout(x1))
    x3 = reshape_tensor(x2,self.n_heads)
    skip_con = reshape_tensor(skip_con,self.n_heads)
    x4 = self.attention(x3,skip_con,skip_con)
    x5 = self.norm2(x2+self.dropout(x4))
    x6 = self.feedforward(x5)
    x7 = self.norm3(self.dropout(x6)+x5)
    return x7

In [12]:
class UnetTransformer(nn.Module):
  '''
    This class implements unet transformer
  '''
  def __init__(self,n_blocks,n_tokens,n_heads,d_model,dim_feedforward,dropout=0.0):

    '''
      Args:
        n_blocks: number of encoder/decoder blocks
        n_tokens: Dimensionality of the input space
        n_heads: number of heads in MultiHeadAttention
        d_model: Dimensionality of the embedding space
        num_classes: Dimensionality of the output space
        dim_feedforward:  Dimensionality of the hidden layer in the MLP 


    '''
    super(UnetTransformer,self).__init__()
    self.n_blocks = n_blocks
    self.pos_enc = Embedding_with_PosEncoding(n_tokens,d_model,dropout=dropout)
    for i in range(n_blocks):
      vars(self)['_modules']['enc_'+str(i)] = EncoderBlock(d_model, n_heads, dim_feedforward, dropout)
    for i in range(n_blocks):
      vars(self)['_modules']['dec_'+str(i)] = DecoderBlock(d_model, n_heads, dim_feedforward, dropout)
    self.output_layer = nn.Sequential( nn.Linear(d_model, d_model),
            nn.LayerNorm(d_model),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_model, n_tokens)
        )
  def forward(self,x):
    x_encoded = self.pos_enc(x)
    layers = vars(self)['_modules']
    stack = [x_encoded]
    x = layers['enc_0'](x_encoded)
    for i in range(1,self.n_blocks):
      stack.append(x)
      x = layers['enc_'+str(i)](x)
    stack.append(x)
    x = layers['dec_0'](x,stack.pop(0))
    for i in range(1,self.n_blocks):
      x = layers['dec_'+str(i)](x,stack.pop(0))
    return self.output_layer(x)

    

In [13]:
import math
from typing import Tuple
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import dataset

In [43]:
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

train_iter = WikiText2(split='train')
tokenizer = get_tokenizer('spacy')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])

def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
    """Converts raw text into a flat Tensor."""
    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

# train_iter was "consumed" by the process of building the vocab,
# so we have to create it again
train_iter, val_iter, test_iter = WikiText2()
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def batchify(data: Tensor, bsz: int) -> Tensor:
    """Divides the data into bsz separate sequences, removing extra elements
    that wouldn't cleanly fit.

    Args:
        data: Tensor, shape [N]
        bsz: int, batch size

    Returns:
        Tensor of shape [N // bsz, bsz]
    """
    seq_len = data.size(0) // bsz
    data = data[:seq_len * bsz]
    data = data.view(bsz, seq_len).contiguous()
    return data.to(device)


train_data = batchify(train_data, batch_size)  # shape [seq_len, batch_size]
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)

In [44]:
def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
    """
    Args:
        source: Tensor, shape [full_seq_len, batch_size]
        i: int

    Returns:
        tuple (data, target), where data has shape [seq_len, batch_size] and
        target has shape [seq_len * batch_size]
    """
    seq_len = min(bptt, source.shape[1] - 1 - i)
    data = source[:,i:i+seq_len]
    target = source[:,i+1:i+1+seq_len].reshape(-1)
    return data, target
get_batch(train_data,0)[0].shape

torch.Size([20, 256])

In [45]:
ntokens = len(vocab)
print('n_tokens {}'.format(len(vocab)))
model = UnetTransformer(n_blocks=n_blocks,n_tokens=ntokens,\
                        n_heads=n_heads, d_model = d_model,dim_feedforward = h_dims,\
                        dropout=dropout).to(device)
pytorch_total_params = sum(p.numel()
                        for p in model.parameters() if p.requires_grad)
print('-' * 89)
print(
    '#'*12+f" Training model with {pytorch_total_params/1000000:0.2F}M trainable parameters for {epochs:3d} epochs "+'#'*12)
print('-' * 89)

n_tokens 33243
-----------------------------------------------------------------------------------------
############ Training model with 32.97M trainable parameters for 100 epochs ############
-----------------------------------------------------------------------------------------


In [46]:
class BasicOpt:
    def __init__(self, optimizer, schedular):
        self.optimizer = optimizer
        self.schedular = schedular
        self._scalar = 1

    def zero_grad(self):
        self.optimizer.zero_grad()

    @property
    def lr(self):
      return self.optimizer.param_groups[0]['lr']*self._scalar

    @property
    def scalar(self):
        return self._scalar

    @scalar.setter
    def scalar(self, scalar):
        self._scalar = scalar


    def schedule_step(self, val_loss):
        raise NotImplementedError

    def step(self):
        raise NotImplementedError

class linearcycleWarmup(BasicOpt):
    def __init__(self, optimizer, schedular, *args, **kwargs):
        super().__init__(optimizer=optimizer, schedular=schedular)
        self.use_scheduler = True

       
    def step(self):
        lr_s = [p['lr'] for p in self.optimizer.param_groups]
        for p in self.optimizer.param_groups:
            p['lr']  = p['lr']*self._scalar             
        self.optimizer.step()
        for idx, p in enumerate(self.optimizer.param_groups):
            p['lr'] = lr_s[idx]
        try:
          if self.use_scheduler:
            self.schedular.step()
        except Exception as e:
          self.use_scheduler = False
          for idx, p in enumerate(self.optimizer.param_groups):
            p['lr'] = 0.00000088



    def schedule_step(self, *args):
        pass
steps_per_epoch = len(train_data)//bptt+1
total_steps = epochs*(steps_per_epoch)
opt_args = {
    'lr': 0,
    'betas': (0.9, 0.98), 'eps': 1e-9, 'weight_decay': 1e-5
}

linear_args = {
    'total_steps': total_steps,
    'pct_start': 0.3, 'anneal_strategy': 'linear',
    'three_phase': True, 'max_lr': 1e-3
}
opt = torch.optim.RAdam(model.parameters(),
                        **opt_args)
schedular_args = linear_args
schedular = torch.optim.lr_scheduler.OneCycleLR(optimizer=opt, **schedular_args)

In [47]:
import copy
import time

criterion = nn.CrossEntropyLoss()
lr = 5.0  # learning rate
optimizer = linearcycleWarmup(optimizer = opt, schedular=schedular )

def train(model: nn.Module) -> None:
    model.train()  # turn on train mode
    total_loss = 0.
    log_interval = 200
    start_time = time.time()
    #src_mask = generate_square_subsequent_mask(bptt).to(device)

    num_batches = train_data.shape[1] // bptt
    for batch, i in enumerate(range(0, train_data.size(1) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        batch_size = data.size(1)
        # if batch_size != bptt:  # only on last batch
        #     src_mask = src_mask[:batch_size, :batch_size]
        output = model(data)
        loss = criterion(output.view(-1, ntokens), targets)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if batch % log_interval == 0 and batch > 0:
            lr = optimizer.lr
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            ppl = math.exp(cur_loss)
            print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                  f'lr {lr:02.6f} | ms/batch {ms_per_batch:5.2f} | '
                  f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
            total_loss = 0
            start_time = time.time()

def evaluate(model: nn.Module, eval_data: Tensor) -> float:
    model.eval()  # turn on evaluation mode
    total_loss = 0.
    #src_mask = generate_square_subsequent_mask(bptt).to(device)
    with torch.no_grad():
        for i in range(0, eval_data.size(1) - 1, bptt):
            data, targets = get_batch(eval_data, i)
            batch_size = data.size(1)
            # if batch_size != bptt:
            #     src_mask = src_mask[:batch_size, :batch_size]
            output = model(data)
            output_flat = output.view(-1, ntokens)
            total_loss += batch_size * criterion(output_flat, targets).item()
    return total_loss / (eval_data.size(1) - 1)

In [48]:
best_val_loss = float('inf')
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model)
    val_loss = evaluate(model, val_data)
    val_ppl = math.exp(val_loss)
    elapsed = time.time() - epoch_start_time
    print('-' * 89)
    print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
          f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)

    #scheduler.step()

| epoch   1 |   200/  434 batches | lr 0.000001 | ms/batch 111.58 | loss  7.64 | ppl  2087.41
| epoch   1 |   400/  434 batches | lr 0.000001 | ms/batch 111.18 | loss  7.08 | ppl  1191.51
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 50.01s | valid loss  6.52 | valid ppl   678.87
-----------------------------------------------------------------------------------------


KeyboardInterrupt: ignored

# Not important

In [None]:
def train(model, device, train_loader, optimizer, criterion):
  model.train()
  num_classes = train_loader.dataset.num_categories
  l = 0
  acc = 0
  pbar = tqdm(total = len(train_loader),position=0,leave=True)
  for batch_idx, (data, target) in enumerate(train_loader):
    target = target.to(device)

    optimizer.zero_grad() 
    x, output_shifted = data
    preds = model(x.to(device), output_shifted.to(device))       
    loss = criterion(preds.view(-1,preds.size(-1)), target.view(-1),reduction="mean")
    loss.backward() 
    #nn.utils.clip_grad_norm(model.parameters(), clip)       
    optimizer.step()
    with torch.no_grad():
      current_loss = loss.item()
      l+= loss.item()
      acc+=(preds.argmax(dim=-1) == target).float().mean().item()
    pbar.set_description('training_step {} loss:{:3f}'.format(batch_idx,current_loss))
    pbar.update()
  acc = 100. * acc / (len(train_loader))
  l = l/len(train_loader)
  print('{0}: loss: {1:.3f} acc {2:.1f}'.format('train',l,acc))


In [None]:
def test( model, device, test_loader,criterion,mode='eval'):
  model.eval()
  test_loss = 0
  correct = 0
  num_classes=test_loader.dataset.num_categories
  with torch.no_grad():
      for data, target in test_loader:
          target = target.to(device)
          x, output_shifted = data
          output = model(x.to(device), output_shifted.to(device))
          test_loss += criterion(output.view(-1,output.size(-1)),\
                                        target.view(-1)).item()          
          correct += (output.argmax(dim=-1) == target).float().mean().item()

  loss = test_loss/len(test_loader)
  acc = 100. * correct / len(test_loader)
  print('{0}: loss: {1:.3f} acc {2:.1f}'.format(mode,loss,acc))


      

In [None]:
class ReverseDataset(data.Dataset):
    def __init__(self, num_categories, seq_len, size):
        super().__init__()
        self.num_categories = num_categories
        self.seq_len = seq_len
        self.size = size

        self.data = torch.randint(low=1, high=self.num_categories, size=(self.size, self.seq_len))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        inp_data = self.data[idx]
        labels = torch.flip(inp_data, dims=(0,))
        labels_shifted = labels.roll(1,0)
        labels_shifted[0] = torch.tensor(0)
        return (inp_data,labels_shifted), labels

In [None]:
def main():
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.determinstic = True
    torch.backends.cudnn.benchmark = False
    dataset = partial(ReverseDataset, 10, 16)
    train_loader = data.DataLoader(dataset(50000), batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True)
    
    val_loader   = data.DataLoader(dataset(1000), batch_size=test_batch_size)
    test_loader  = data.DataLoader(dataset(10000), batch_size=test_batch_size)
    model = UnetTransformer(n_blocks=n_blocks,input_dim=train_loader.dataset.num_categories,\
                            n_heads=n_heads,d_model = h_dims,num_classes=\
                            train_loader.dataset.num_categories,dim_feedforward = h_dims,\
                            dropout=dropout).to(device)
    optimizer = optim.Adam(model.parameters(), lr= lr)
    criterion =  F.cross_entropy
    


    for epoch in range(1, epochs + 1):
       train(model, device, train_loader, optimizer,criterion)
       test(model, device, val_loader,criterion)
        
    torch.save(model.state_dict(), "model.h5")
    print('------------testing--------------')
    test(model, device, test_loader,criterion,mode='test')


if __name__ == '__main__':
    pass
    #main()