In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv

import talos

from ipywidgets import interact

%load_ext autoreload
%autoreload 2

Torch Hub models can be accessed @ https://pytorch.org/hub/


In [None]:
from Archs import *

In [None]:
vocab_size = 128 + 3 # +3 cuz we use BOS, EOS and PAD tokens extra.
d_model = 64
dim_feedforward = 64
num_encoder_layers=3
num_decoder_layers=3
max_seq_len = 200

In [None]:
cycle_model = create_cycle_transformer(
    vocab_size=vocab_size,
    d_model=d_model,
    nhead=8,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    activation=F.relu,
    dim_feedforward=dim_feedforward,
    max_seq_length=max_seq_len
)
cycle_model.size_info

'1.573M params, 1.573M trainable params (100.0%), 6.291MB total in memory'

In [8]:
import pickle

class TokenBank(torch.utils.data.Dataset):
    def __init__(
            self,
            path : str, max_len : int = max_seq_len, n_samples : int = 50,
            # start_index : int = 0, end_index : int = 10000,
        ):
        super().__init__()
        
        with open(path, 'rb') as file:
            data = pickle.load(file)
        
        self.maxlen = max_len
        self.data = data
        self.ns = n_samples
        # self.data = np.concatenate(list(map(lambda x: np.squeeze(x, 0), self.data_raw)))        
    
    def __len__(self):
        return min(len(self.data), self.ns)

    def __getitem__(self, idx):
        
        sample = torch.tensor(self.data[idx], dtype=torch.int64)
        sample = torch.concatenate([
            torch.tensor([0 for _ in range(3)] + [1]),
            sample[:max_seq_len - 8] + 3,
            torch.tensor([2] + [0 for _ in range(3)]),
        ])
        
        return sample

english_tb = TokenBank('eng_int_tokens.tokens')
swahili_tb = TokenBank('swahili_int_tokens.tokens')

In [None]:
import time

class Trainer:
    
    def __init__(
            self,
            model : CycleTransformer,
            eng_token_bank : torch.utils.data.Dataset,
            swa_token_bank : torch.utils.data.Dataset,
            pad_token : int = None,
        ):
        self.model = model
        self.eng_bank = eng_token_bank
        self.swa_bank = swa_token_bank
        self.ptok = pad_token
        
        self.device = 'cuda' if talos.gpu_exists() else 'cpu'
        
        self.t1_t2opt = torch.optim.Adam(
            list(self.model.ab_transformer.parameters()) +
            list(self.model.ba_transformer.parameters())
        )
        self.t3opt = torch.optim.Adam(self.model.discriminator.parameters())
        
    
    def move_to_device(self):
        self.model = self.model.to(self.device)
    
    
    def train(self, epochs : int = 1, batch_size : int = 16, lam : float = 0.1):
        eng_loader = torch.utils.data.DataLoader(self.eng_bank, batch_size, pin_memory=True)
        swa_loader = torch.utils.data.DataLoader(self.swa_bank, batch_size, shuffle=True, pin_memory=True)
        
        self.move_to_device()
        
        nbatches = len(eng_loader)
        print(f'nbatches: {nbatches}')
        
        losses = {
            'closs' : [],
            'aloss' : [],
            'dloss' : [],
            'lambda' : [],
        }
        
        K = 30
        
        etime = time.time()
        for epoch in range(epochs):
            
            prevtime = time.time()
            for bi, o_eng_ids, o_swa_ids in zip(range(nbatches), eng_loader, swa_loader):
                
                o_eng_ids = o_eng_ids.to(self.device)
                o_swa_ids = o_swa_ids.to(self.device)
                
                # outputs = self.model(o_eng_ids)
                
                p_swahili_ids, p_swahili_logits = self.model.forward_ab(o_eng_ids)
                p_english_ids, p_english_logits = self.model.forward_ba(p_swahili_ids)
                
                d_fake = self.model.discriminator(torch.detach(p_swahili_ids))
                d_real = self.model.discriminator(o_swa_ids)
                
                d_loss = discriminator_loss(d_real, d_fake)
                
                #update Discriminator
                d_loss.backward()
                self.t3opt.step()
                
                c_loss = cyclic_loss(
                    p_english_logits,
                    o_eng_ids,
                    self.ptok
                )
                a_loss = adversarial_loss(self.model.discriminator(
                    p_swahili_ids
                ))
                
                #updating main pipeline
                gen_loss = c_loss + (lam * a_loss)
                gen_loss.backward()
                self.t1_t2opt.step()
                
                losses['closs'].append(c_loss.detach().cpu().item())
                losses['aloss'].append(a_loss.detach().cpu().item())
                losses['dloss'].append(d_loss.detach().cpu().item())
                losses['lambda'].append(lam)
                
                currtime = time.time()
                eta = (currtime - prevtime) * (nbatches - bi)
                eta = time.strftime('%M:%S', time.gmtime(eta))
                f = bi/nbatches
                lb = ('=' * int(f * K)) + '>' + (' ' * (K - int(f * K)))
                print(
                    f'\rEpoch {epoch+1}/{epochs} [{lb}]({f*100:.2f}%) ETA <= {eta}', end=''
                )
                prevtime = time.time()
            
            cetime = time.time()
            cetime = time.strftime("%M:%S", time.gmtime(cetime - etime))
            print(
                f'\rEpoch {epoch+1}/{epochs} [{lb}]({f*100:.2f}%) Took {cetime}'
            )
            etime = time.time()

        return losses


jarvis = create_cycle_transformer(
    vocab_size=vocab_size,
    d_model=d_model,
    nhead=8,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    activation=F.relu,
    dim_feedforward=dim_feedforward,
    max_seq_length=max_seq_len
)
print(jarvis.size_info)
friday = Trainer(jarvis, english_tb, swahili_tb)
hist = friday.train(epochs=10, batch_size=1)

1.573M params, 1.573M trainable params (100.0%), 6.291MB total in memory
[32m[1mGPU(s) exist![0m
nbatches: 50


{'closs': [5.283940315246582,
  5.261687278747559,
  5.1613850593566895,
  5.096944808959961,
  5.002572059631348,
  5.121870517730713,
  4.862157821655273,
  5.0523810386657715,
  4.975968837738037,
  4.983453750610352,
  4.8539862632751465,
  5.231721878051758,
  4.905221462249756,
  4.855556964874268,
  5.064143657684326,
  4.819822788238525,
  5.243703842163086,
  4.935613632202148,
  4.73769998550415,
  4.868610382080078,
  4.6972246170043945,
  4.86647367477417,
  5.003090858459473,
  4.974726676940918,
  4.984422206878662,
  4.932287693023682,
  4.790943622589111,
  4.586535453796387,
  4.599454879760742,
  4.927395343780518,
  4.8277130126953125,
  5.060727596282959,
  4.923603534698486,
  4.884232521057129,
  4.20780611038208,
  4.966693878173828,
  5.046350002288818,
  4.6674370765686035,
  4.622956275939941,
  5.166489124298096,
  5.343782901763916,
  4.698246479034424,
  5.09874153137207,
  4.2636308670043945,
  5.19062614440918,
  4.298125743865967,
  4.3332109451293945,
 