In [1]:
from transformers import M2M100Config, M2M100ForConditionalGeneration, M2M100Tokenizer, M2M100Model
from datasets import load_metric
import pickle

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim
import segmentation_models_pytorch as smp
from collections import OrderedDict
import numpy as np
import copy

import os
from collections import OrderedDict
import json
import time

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

import torchvision.models as models
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from PIL import Image
from torchinfo import summary
import datasets
import transformers
from transformers.optimization import Adafactor, AdafactorSchedule
import tqdm
import time
import math
import random

In [2]:
class Model(nn.Module):
    """Wrapper around the pretrained ``facebook/m2m100_418M`` translation model.

    Besides the model itself the wrapper keeps
    * ``weights_backup`` - a deep copy of the model's state dict that
      ``revert_weights`` can restore, and
    * ``tokenizers`` - one ``M2M100Tokenizer`` per source language
      (``'cs'`` and ``'de'``), both configured with ``tgt_lang="en"``.
    """

    # apply_mask splits the first axis of every masked tensor into this many
    # equally sized groups of consecutive rows ("slots").
    _MASK_SLOTS = 256

    def __init__(self):
        super(Model, self).__init__()
        self.model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
        self.weights_backup = copy.deepcopy(self.model.state_dict())

        # Identical tokenizer settings for both languages; only src_lang differs.
        self.tokenizers = {
            src_lang: M2M100Tokenizer.from_pretrained("facebook/m2m100_418M",
                                                      src_lang=src_lang,
                                                      tgt_lang="en",
                                                      padding_side='right',
                                                      truncation_side='right')
            for src_lang in ('cs', 'de')
        }

    def forward_eval(self, batch, lang_code):
        """Generate output token ids for ``batch['x']``.

        ``batch['x']`` is unpacked as keyword arguments to ``generate``; the
        first generated token is forced to the ``"en"`` language id of the
        tokenizer registered under ``lang_code``.
        """
        return self.model.generate(**batch['x'], forced_bos_token_id=self.tokenizers[lang_code].get_lang_id("en"))

    def forward_train(self, batch):
        """Run a forward pass on ``batch['x']`` with ``batch['y']`` as ``labels``.

        Returns the model output object unchanged (callers read ``.loss``).
        """
        return self.model(**batch['x'], labels=batch['y'])

    def apply_mask(self, mask, sizing):
        """Zero out groups of rows in selected weight tensors.

        Args:
            mask: 1-D sequence of numbers. It is consumed left to right, one
                segment per masked tensor; entries that round to 0 select a
                slot to be zeroed.
            sizing: mapping from state-dict key to the number of mask entries
                that belong to that tensor. Keys not present are left alone.

        For a tensor with ``R`` rows, slot ``j`` covers rows
        ``j * (R // 256)`` up to (excluding) ``(j + 1) * (R // 256)``.

        Note:
            The zeroed rows are NOT frozen. ``requires_grad`` is a per-tensor
            flag in PyTorch, so it cannot be switched off for individual rows;
            further training may move the zeroed rows away from zero. (An
            earlier version assigned ``requires_grad = False`` to an indexed
            copy of ``param.data``, which had no effect and has been removed.)
        """
        state = copy.deepcopy(self.model.state_dict())
        start = 0
        for key in state:
            if key not in sizing:
                continue
            end = start + sizing[key]
            segment = np.round(np.asarray(mask[start:end]))
            start = end

            pruned_slots = np.flatnonzero(segment == 0)
            rows_per_slot = state[key].shape[0] // self._MASK_SLOTS
            if pruned_slots.size == 0 or rows_per_slot == 0:
                # Nothing selected (or tensor too small to have whole slots).
                continue

            # Expand each slot index into the consecutive rows it covers.
            rows = (pruned_slots[:, None] * rows_per_slot
                    + np.arange(rows_per_slot)[None, :]).reshape(-1)
            row_index = torch.from_numpy(rows).long().to(state[key].device)
            state[key][row_index] = 0
        self.model.load_state_dict(state)

    def return_model(self):
        """Return the wrapped model object."""
        return self.model

    def return_model_state(self):
        """Return the wrapped model's current state dict."""
        return self.model.state_dict()

    def revert_weights(self):
        """Restore the backed-up weights and mark every parameter trainable."""
        self.model.load_state_dict(self.weights_backup)
        for param in self.model.parameters():
            param.requires_grad = True

    def update_backup(self):
        """Replace the backup with a deep copy of the current weights."""
        self.weights_backup = copy.deepcopy(self.model.state_dict())


In [3]:
def size_mask(state_dict):
    """Decide which state-dict entries are masked and how many mask slots each gets.

    An entry is selected when its key contains none of the substrings
    ``'bias'``, ``'embed'``, ``'norm'``, ``'shared'`` or ``'head'``. Every
    selected entry is assigned 256 mask slots.

    Args:
        state_dict: mapping whose keys are parameter names. Only the keys are
            inspected.

    Returns:
        OrderedDict mapping each selected key to 256, in ``state_dict`` order.

    Side effects:
        Prints the total number of mask slots (256 per selected key).
    """
    slots_per_tensor = 256
    excluded_tags = ('bias', 'embed', 'norm', 'shared', 'head')

    mask_sizing = OrderedDict(
        (key, slots_per_tensor)
        for key in state_dict.keys()
        if not any(tag in key for tag in excluded_tags)
    )
    print(slots_per_tensor * len(mask_sizing))
    return mask_sizing

In [4]:
class Custom_Dataloader:
    """Draws fixed-size batches in random order, without replacement.

    ``data`` must provide ``data['x']['input_ids']``,
    ``data['x']['attention_mask']`` and ``data['y']``; all three are indexed
    along their first axis with the same row indices. The number of rows is
    taken from ``data['x']['input_ids'].shape[0]``.

    Batches are contiguous row ranges of ``batch_size`` rows (the last one may
    be shorter). ``available`` holds the indices of the batches not yet drawn
    in the current pass; once it is empty it is refilled automatically.
    """

    def __init__(self, data, batch_size):
        self.data = data
        self.data_amount = data['x']['input_ids'].shape[0]

        self.batch_size = batch_size
        self.length = int(math.ceil(self.data_amount / batch_size))
        self.available = set(range(self.length))

    def select_subset(self, idxs, data, cuda):
        """Return the rows ``idxs`` of ``data`` as a new batch dict.

        The tensors are moved to the GPU when ``cuda`` is true and returned
        as indexed (on the original device) otherwise.
        """
        subset = {'x': {'input_ids': data['x']['input_ids'][idxs],
                        'attention_mask': data['x']['attention_mask'][idxs]},
                  'y': data['y'][idxs]}
        if cuda:
            subset = {'x': {'input_ids': subset['x']['input_ids'].cuda(),
                            'attention_mask': subset['x']['attention_mask'].cuda()},
                      'y': subset['y'].cuda()}
        return subset

    def sample_batch(self, cuda=True):
        """Draw one not-yet-used batch at random and return it."""
        if len(self.available) == 0:
            # Every batch has been served once; start a new pass.
            self.available = set(range(self.length))
        idx = random.choice(tuple(self.available))
        self.available.remove(idx)

        start = idx * self.batch_size
        # The final batch is clipped to the end of the data.
        stop = min(start + self.batch_size, self.data_amount)
        return self.select_subset(list(range(start, stop)), self.data, cuda)

    def reset(self):
        """Make every batch available again."""
        self.history = set()
        # `available` holds batch indices, so it spans `length`, not the
        # number of samples.
        self.available = set(range(self.length))
        
def dual_sample(ds1_batch, ds2_batch):
    """Concatenate two batches along the first axis and shuffle the rows.

    Both arguments are batch dicts with ``['x']['input_ids']``,
    ``['x']['attention_mask']`` and ``['y']``. The same random permutation is
    applied to all three concatenated tensors, so rows stay aligned.

    The permutation covers the actual number of concatenated rows, so the two
    batches may have different sizes (e.g. when one of them is a shorter
    final batch).
    """
    x_input_ids = torch.cat((ds1_batch['x']['input_ids'], ds2_batch['x']['input_ids']), 0)
    x_attention_masks = torch.cat((ds1_batch['x']['attention_mask'], ds2_batch['x']['attention_mask']), 0)
    y = torch.cat((ds1_batch['y'], ds2_batch['y']))

    perm = torch.randperm(x_input_ids.shape[0])
    return {'x': {'input_ids': x_input_ids[perm],
                  'attention_mask': x_attention_masks[perm]},
            'y': y[perm]}

In [17]:
def train_loop(model, 
              lang1_train_dataloader, 
              lang2_train_dataloader, 
              lang1_test_dataloader, 
              lang2_test_dataloader, 
              epochs, 
               model_save_path,
              steps=300):
    """Train on two languages jointly and checkpoint the best model.

    Each epoch runs ``steps`` training iterations via ``train``, decays the
    learning rate by a factor of 0.95, and evaluates the mean loss on both
    test loaders. Whenever the average of the two test losses improves, the
    model's state dict is written to ``model_save_path``.

    Args:
        model: module exposing ``forward_train``; trained in place.
        lang1_train_dataloader, lang2_train_dataloader: training batch samplers.
        lang1_test_dataloader, lang2_test_dataloader: evaluation batch
            samplers; their losses are logged as "CS" and "DE" respectively.
        epochs: number of epochs to run.
        model_save_path: file path the best state dict is saved to.
        steps: training iterations per epoch.

    Returns:
        A CPU copy of the state dict with the lowest averaged test loss, or
        ``None`` if no epoch was run.
    """
    # Alternative optimizer that was tried:
    # optim= Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    optim = torch.optim.AdamW(model.parameters(), lr=0.0005)
    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optim, lr_lambda=lambda epoch: 0.95)

    best_loss = np.inf
    saved_state = None
    for epoch in range(epochs):
        model.train()
        # Make sure every parameter is trainable, whatever was frozen before.
        for param in model.parameters():
            param.requires_grad = True

        train(model, 
              lang1_train_dataloader, 
              lang2_train_dataloader, 
              optim, 
              steps, epoch)

        scheduler.step()

        torch.cuda.empty_cache()

        cs_loss = test_loss(model, lang1_test_dataloader)
        torch.cuda.empty_cache()
        de_loss = test_loss(model, lang2_test_dataloader)
        torch.cuda.empty_cache()

        cs_loss = cs_loss.detach().item()
        de_loss = de_loss.detach().item()
        averaged_loss = ((0.5*cs_loss) + (0.5*de_loss))

        if averaged_loss < best_loss:
            best_loss = averaged_loss
            torch.save(model.state_dict(), model_save_path)
            # Keep the best weights on the CPU so the copy does not occupy
            # GPU memory during the remaining epochs.
            saved_state = {key: value.detach().cpu().clone()
                           for key, value in model.state_dict().items()}

        # Report this epoch's average and the running best separately.
        print('\n CS Test Loss: ', cs_loss, ', DE Test Loss: ', de_loss,
              ', Averaged Loss: ', averaged_loss, ', Best Averaged Loss: ', best_loss)

    return saved_state

        
def train(model, 
          lang1_train_dataloader, 
          lang2_train_dataloader, 
          optim, 
          steps, epoch, accum_steps=30):
    """Run ``steps`` training iterations with gradient accumulation.

    Every iteration samples one batch from each loader, merges them with
    ``dual_sample``, and back-propagates the loss. Gradients are summed over
    ``accum_steps`` iterations before each optimizer update.

    Args:
        model: module exposing ``forward_train`` whose result has ``.loss``.
        lang1_train_dataloader, lang2_train_dataloader: objects with
            ``sample_batch()``.
        optim: optimizer to step.
        steps: number of iterations to run.
        epoch: epoch number, used only in the progress line.
        accum_steps: iterations accumulated per optimizer update.
    """
    optim.zero_grad()
    for i in range(steps):

        batch1 = lang1_train_dataloader.sample_batch()
        batch2 = lang2_train_dataloader.sample_batch()
        batch = dual_sample(batch1, batch2)
        loss = model.forward_train(batch).loss
        loss.backward()
        # '\r' + end="" keeps the progress on a single, overwritten line.
        print('\rEpoch {:.3f}\tBatch: {:.3f}, Loss: {:.3f}'.format(epoch, 
                                                                   i, 
                                                                   loss.detach().item()), 
              end="")   
        if (i+1) % accum_steps == 0:
            optim.step()
            optim.zero_grad()

    # Apply gradients left over from an incomplete accumulation window so
    # those iterations are not silently thrown away.
    if steps % accum_steps != 0:
        optim.step()
        optim.zero_grad()

def convert_to_string(tokens, test_code):
    """Render a 1-D token-id container as a space-separated string.

    Ids listed in ``skipped_ids`` are left out (presumably special / language
    tokens of the vocabulary - TODO confirm against the tokenizer).

    Args:
        tokens: object with ``.tolist()`` yielding the token ids.
        test_code: when truthy the string is wrapped in a one-element list,
            otherwise the bare string is returned.
    """
    skipped_ids = (1, 2, 128022, 128020, 128017)

    kept = []
    for token_id in tokens.tolist():
        if token_id in skipped_ids:
            continue
        kept.append(str(token_id))

    text = ' '.join(kept)
    return [text] if test_code else text

def stringify(tokens, test_code):
    """Apply ``convert_to_string`` to each row of ``tokens``.

    Rows are taken along the first axis (``tokens.shape[0]`` of them) and the
    per-row results are returned as a list in row order.
    """
    rendered = []
    for row_idx in range(tokens.shape[0]):
        rendered.append(convert_to_string(tokens[row_idx], test_code))
    return rendered

def test_BLEU(model, test_dataloader, metric, lang_code, max_pred_tokens=8):
    """Score generated translations against references with ``metric``.

    One batch is sampled per loader step (``test_dataloader.length`` in
    total). Predictions and references are compared as strings of
    space-separated token ids, not as decoded text.

    Args:
        model: module exposing ``eval()`` and ``forward_eval(batch, lang_code)``.
        test_dataloader: object with ``length`` and ``sample_batch()``.
        metric: object whose ``compute(predictions=..., references=...)``
            returns a mapping with a ``'score'`` entry.
        lang_code: language key forwarded to ``forward_eval``.
        max_pred_tokens: each prediction is cut to its first
            ``max_pred_tokens`` token ids before scoring; ``None`` disables
            the cut.

    Returns:
        The ``'score'`` value reported by ``metric``.
    """
    model.eval()
    all_preds = []
    all_refs = []
    with torch.no_grad():
        for _ in range(test_dataloader.length):
            batch = test_dataloader.sample_batch()
            preds = model.forward_eval(batch, lang_code)
            all_preds += stringify(preds[:, 0:max_pred_tokens], False)
            # References are wrapped in one-element lists (test_code=True).
            all_refs += stringify(batch['y'], True)
        score = metric.compute(predictions=all_preds, references=all_refs)['score']
    return score

def test_loss(model, test_dataloader):
    model.eval()
    loss = 0
    with torch.no_grad():
        for i in range(test_dataloader.length):
            batch = test_dataloader.sample_batch()
            loss += model.forward_train(batch).loss
    return loss/test_dataloader.length

In [7]:

# NOTE: pickle.load can execute arbitrary code while deserializing; only
# point this at files you produced yourself or otherwise trust.
def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


de_test = _load_pickle('./train_data/de_test.pkl')
cs_test = _load_pickle('./train_data/cs_test.pkl')
cs_train = _load_pickle('./train_data/cs_train.pkl')
de_train = _load_pickle('./train_data/de_train.pkl')

In [8]:
# Batch sizes used for evaluation and training respectively.
TEST_BATCH_SIZE = 64
TRAIN_BATCH_SIZE = 96

cs_test_dataloader, de_test_dataloader = (
    Custom_Dataloader(split, TEST_BATCH_SIZE) for split in (cs_test, de_test)
)
cs_train_dataloader, de_train_dataloader = (
    Custom_Dataloader(split, TRAIN_BATCH_SIZE) for split in (cs_train, de_train)
)


In [18]:
# Build the wrapper, then replace its weights with the checkpoint stored in
# ./best_model.pth (that file must already exist).
model = Model()
# map_location='cpu': the model is still on the CPU at this point, so load the
# checkpoint there as well instead of materialising a second copy of the
# weights on the GPU when the file was saved from CUDA tensors.
model.load_state_dict(torch.load('./best_model.pth', map_location='cpu'))
# Make sure every parameter is trainable before moving the model to the GPU.
for param in model.parameters():
    param.requires_grad = True
model = model.cuda()

In [14]:
# Number of batches per pass for the Czech and German training loaders.
print(f"{cs_train_dataloader.length} {de_train_dataloader.length}")

75695 78084


In [19]:
# Hand cached-but-unused GPU memory back to the driver before training starts.
torch.cuda.empty_cache()

In [20]:
# Train for 65 epochs (default number of steps per epoch); train_loop writes
# the best-scoring weights to ./best_model.pth as it goes.
saved_model_state = train_loop(
    model=model,
    lang1_train_dataloader=cs_train_dataloader,
    lang2_train_dataloader=de_train_dataloader,
    lang1_test_dataloader=cs_test_dataloader,
    lang2_test_dataloader=de_test_dataloader,
    epochs=65,
    model_save_path='./best_model.pth',
)

Epoch 0.000	Batch: 299.000, Loss: 3.432
 CS Test Loss:  3.3514182567596436 , DE Test Loss:  3.2881734371185303 , Best Averaged Loss:  3.319795846939087
Epoch 1.000	Batch: 299.000, Loss: 1.063
 CS Test Loss:  2.590630292892456 , DE Test Loss:  2.7482502460479736 , Best Averaged Loss:  2.669440269470215
Epoch 2.000	Batch: 299.000, Loss: 2.356
 CS Test Loss:  2.270009756088257 , DE Test Loss:  2.421527624130249 , Best Averaged Loss:  2.345768690109253
Epoch 3.000	Batch: 299.000, Loss: 1.404
 CS Test Loss:  2.099003314971924 , DE Test Loss:  2.2778680324554443 , Best Averaged Loss:  2.188435673713684
Epoch 4.000	Batch: 299.000, Loss: 1.114
 CS Test Loss:  1.9851433038711548 , DE Test Loss:  2.1842100620269775 , Best Averaged Loss:  2.084676682949066
Epoch 5.000	Batch: 299.000, Loss: 1.224
 CS Test Loss:  1.9092463254928589 , DE Test Loss:  2.119856119155884 , Best Averaged Loss:  2.0145512223243713
Epoch 6.000	Batch: 299.000, Loss: 1.354
 CS Test Loss:  1.8840500116348267 , DE Test Loss:  

KeyboardInterrupt: 