In [1]:
from transformers import M2M100Config, M2M100ForConditionalGeneration, M2M100Tokenizer, M2M100Model
from datasets import load_metric
import pickle

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim
import segmentation_models_pytorch as smp
from collections import OrderedDict
import numpy as np
import copy

import os
from collections import OrderedDict
import json
import time

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

import torchvision.models as models
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from PIL import Image
from torchinfo import summary
import datasets
import transformers
from transformers.optimization import Adafactor, AdafactorSchedule
import tqdm
import time
import math
import random

In [2]:

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
        self.weights_backup = copy.deepcopy(self.model.state_dict())
        
        self.tokenizers = {'cs': M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", 
                                                            src_lang='cs',
                                                            tgt_lang="en", 
                                                            padding_side='right', 
                                                            truncation_side='right'), 
                           'de':M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", 
                                                            src_lang='de',
                                                            tgt_lang="en", 
                                                            padding_side='right', 
                                                            truncation_side='right')}

    def forward_eval(self, batch, lang_code):
        return self.model.generate(**batch['x'], forced_bos_token_id=self.tokenizers[lang_code].get_lang_id("en"))
    
    def forward_train(self, batch):
        return self.model(**batch['x'], labels=batch['y']) 
        
    def apply_mask(self, mask, sizing):
        start = 0
        copy_state = copy.deepcopy(self.model.state_dict())
        segments = {}
        for i in copy_state:
            if i in sizing:
                end = start + sizing[i]
                segment = np.round(mask[start:end])
                index = np.where(segment == 0)
                
                final_indices = []
                divisor = int(copy_state[i].shape[0]/sizing[i])
                for j in index[0]:
                    final_indices += [*range(j*divisor, (j*divisor)+divisor)]
                # print(final_indices)
                copy_state[i].data[np.array(final_indices)] = 0
                segments.update({i:index})
                start = end
        self.model.load_state_dict(copy_state)
        # for name, param in self.model.named_parameters():
        #     if name in segments:
        #         param.data[segments[name]].requires_grad = False
        #         start = end

    def return_model(self):
        return self.model

    def return_model_state(self):
        return self.model.state_dict()

    def revert_weights(self):
        self.model.load_state_dict(self.weights_backup)
        for name, param in self.model.named_parameters():
            param.requires_grad = True

    def update_backup(self):
        self.weights_backup = copy.deepcopy(self.model.state_dict())


In [3]:
def size_mask(state_dict):
    total = 0
    mask_sizing = OrderedDict()
    total_params = 0
    total_considered = 0
    uniques = set()
    size = 0
    for i in list(state_dict.keys()):
        shape = torch.tensor(state_dict[i].shape)
        total_params += torch.prod(shape)
        uniques.add(state_dict[i])
        # print(state_dict[i][0].type())
        if 'bias' not in i and 'embed' not in i and 'norm' not in i and 'shared' not in i and 'head' not in i:
            # print('------')
            # print(i)
            # print(shape)
            if shape[0] == 4096:
                total += 1024
                mask_sizing.update({i:1024})
            else:
                total += 256
                mask_sizing.update({i:256})
            total_considered += torch.prod(shape)
        # else:
        #     if 'embed' in i or 'shared' in i or 'head' in i:
        #         print('------')
        #         print(i)
        #         print(shape)
        #         print(state_dict[i])
    print(total_params)
    print(total_considered)
    print(total)
    return mask_sizing

In [4]:
def count_active_params(state_dict):
    total = 0
    for i in state_dict:
        flattened = torch.flatten(state_dict[i])
        total += torch.count_nonzero(flattened)
    return total.detach().item()

class Custom_Dataloader:
    def __init__(self, data, batch_size):
        
        self.data = data 
        self.data_amount = data['x']['input_ids'].shape[0]
        
        self.batch_size = batch_size
        self.length = int(math.ceil(self.data_amount/batch_size))
        self.available = set([*range(self.length)])
        
        
    def select_subset(self, idxs, data, cuda):
        if cuda:
            return {'x':{'input_ids': data['x']['input_ids'][idxs].cuda(), 
                         'attention_mask': data['x']['attention_mask'][idxs].cuda()}, 
                    'y':data['y'][idxs].cuda()}
        
    def sample_batch(self, cuda=True):
        
        if len(self.available)==0:
            self.available = set([*range(self.length)])
        idx = random.choice(tuple(self.available))
        self.available.remove(idx) 
        start = idx*self.batch_size
        if idx == self.length-1:
            diff = self.data_amount - (idx*self.batch_size)
            batch = self.select_subset([i for i in range(start,start+diff)], self.data, cuda)
            return batch
        else: 
            batch = self.select_subset([i for i in range(start,start+self.batch_size)], self.data, cuda)
            return batch
        
    def reset(self):
        self.history = set()
        self.available = set([*range(self.data_amount)])
        
def dual_sample(ds1_batch, ds2_batch):
    perm = torch.randperm(len(ds1_batch['x']['input_ids'])*2)
    x_input_ids = torch.cat((ds1_batch['x']['input_ids'], ds2_batch['x']['input_ids']), 0)
    x_attention_masks = torch.cat((ds1_batch['x']['attention_mask'], ds2_batch['x']['attention_mask']), 0)
    y = torch.cat((ds1_batch['y'], ds2_batch['y']))
    
    return {'x':{'input_ids': x_input_ids[perm], 
                 'attention_mask': x_attention_masks[perm]},
            'y':y[perm]}

In [5]:
from transformers import get_scheduler
def train_loop(model, 
               mask,
               mask_sizing,
              lang1_train_dataloader, 
              lang1_test_dataloader, 
              epochs, lang_code, model_idx,
              steps=300):

    optim= Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
#     optim = torch.optim.AdamW(model.parameters(), lr=0.0005)

#     lmbda = lambda epoch: 0.99
#     scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optim, lr_lambda=lmbda)

    # optim = torch.optim.AdamW(model.parameters(), lr=0.00002)
    # lmbda = lambda epoch: 0.99
    # scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optim, lr_lambda=lmbda)

    best_loss = np.inf
    saved_state = None
    
    delay = 0
    for epoch in range(epochs):
        model.train()
        for name, param in model.named_parameters():
            if param.requires_grad is False:
                param.requires_grad =  True
                
        loss = test_loss(model, lang1_test_dataloader)
        torch.cuda.empty_cache()

        loss = loss.detach().item()
        best_loss = loss
        # scheduler.step()

        print('\n Starting Loss: ', loss, ', Lang Code: ', lang_code, ', Model: ', str(model_idx))
        
        train(model, 
              mask, mask_sizing,
              lang1_train_dataloader, 
              optim, 
              steps, epoch)
        
        # scheduler.step()

        torch.cuda.empty_cache()
        
        loss = test_loss(model, lang1_test_dataloader)
        torch.cuda.empty_cache()

        loss = loss.detach().item()
        
        # scheduler.step()

        print('\n Best Averaged Loss: ', loss, ', Lang Code: ', lang_code, ', Model: ', str(model_idx))
        
        
        if loss < best_loss:
            delay = 0
            best_loss = loss
            torch.save(model.state_dict(), './MFEA/final_models/'+lang_code+'_'+str(model_idx)+'.pth')
            
        delay += 1
        
        if delay == 10:
            print('Training ended early: 10 Epochs: No improvement')
            break

        
def train(model, 
          mask, mask_sizing,
          lang1_train_dataloader, 
          optim, 
          steps, epoch):
    optim.zero_grad()
    for i in range(steps):
        batch = lang1_train_dataloader.sample_batch()
        loss = model.forward_train(batch).loss
        loss.backward()
        print('\rEpoch {:.3f}\tBatch: {:.3f}, Loss: {:.3f}'.format(epoch, 
                                                                   i, 
                                                                   loss.detach().item()), 
              end="")
        if (i+1)%30 == 0:
            optim.step()
            optim.zero_grad()
            model.apply_mask(mask, mask_sizing)

def convert_to_string(tokens, test_code):
    prohib = [1, 2, 128022, 128020, 128017]
    if test_code:
        return [' '.join([str(i) for i in tokens.tolist() if i not in prohib])]
    else:
        return ' '.join([str(i) for i in tokens.tolist() if i not in prohib])

def stringify(tokens, test_code):
    return [convert_to_string(tokens[i], test_code) for i in range(tokens.shape[0])]

def test_BLEU(model, test_dataloader, metric, lang_code):
    start = time.time()

    model.eval()
    all_preds = []
    all_refs = []
    with torch.no_grad():
        for i in range(test_dataloader.length):
            batch = test_dataloader.sample_batch()
            preds = model.forward_eval(batch, lang_code)
            all_preds += stringify(preds, False)
            all_refs += stringify(batch['y'], True)
        score = metric.compute(predictions = all_preds, references = all_refs)['score']
    return score

def test_loss(model, test_dataloader):
    model.eval()
    loss = 0
    with torch.no_grad():
        for i in range(test_dataloader.length):
            batch = test_dataloader.sample_batch()
            loss += model.forward_train(batch).loss
    return loss/test_dataloader.length

In [6]:
# checkpoint_size = 16
# with open('de_dataset_test_full_'+str(checkpoint_size)+'.pkl', 'rb') as handle:
#     de_test = pickle.load(handle)
    
# with open('cs_dataset_test_full_'+str(checkpoint_size)+'.pkl', 'rb') as handle:
#     cs_test = pickle.load(handle)
    
# with open('cs_dataset_train_full_'+str(checkpoint_size)+'.pkl', 'rb') as handle:
#     cs_train = pickle.load(handle)
# with open('de_dataset_train_full_'+str(checkpoint_size)+'.pkl', 'rb') as handle:
#     de_train = pickle.load(handle)

# with open('./MFEA/data/cs_finetune.pkl', 'rb') as handle:
#     cs_train = pickle.load(handle)
# with open('./MFEA/data/de_finetune.pkl', 'rb') as handle:
#     de_train = pickle.load(handle)

with open('./train_data/de_test.pkl', 'rb') as handle:
    de_test = pickle.load(handle)
    
with open('./train_data/cs_test.pkl', 'rb') as handle:
    cs_test = pickle.load(handle)
    
with open('./train_data/cs_train.pkl', 'rb') as handle:
    cs_train = pickle.load(handle)
with open('./train_data/de_train.pkl', 'rb') as handle:
    de_train = pickle.load(handle)

In [7]:
cs_test_dataloader = Custom_Dataloader(cs_test, 64)
de_test_dataloader = Custom_Dataloader(de_test, 64)

cs_train_dataloader = Custom_Dataloader(cs_train, 128)
de_train_dataloader = Custom_Dataloader(de_train, 128)


In [8]:
model = Model()
model.load_state_dict(torch.load('./MFEA/models/base.pth'))
model = model.cuda()
model.update_backup()

In [9]:
with open('./MFEA/results/MT/Test_Run2/mask_checkpoint.pkl', 'rb') as f:
    masks = pickle.load(f)

In [10]:
torch.cuda.empty_cache()

In [11]:
cs_selected_models = [0, 1, 2, 3, 4, 6, 10, 11, 18]

In [12]:
de_selected_models = [0, 1, 2, 3, 4, 5, 6, 8, 9]

In [13]:
[(i, masks[1][i]['objs_T1']) for i in range(0, 20) if i in cs_selected_models]

[(0, [21.68546485900879, 627631104]),
 (1, [1.4215970039367676, 859735040]),
 (2, [6.353278923034668, 726033408]),
 (3, [11.179492092132568, 689091584]),
 (4, [8.54193468093872, 708367360]),
 (6, [14.507681369781494, 660734976]),
 (10, [5.093519449234009, 742331392]),
 (11, [2.751480221748352, 763823104]),
 (18, [2.1428685188293457, 781595648])]

In [14]:
[(i, masks[2][i]['objs_T2']) for i in range(0, 20) if i in de_selected_models]

[(0, [21.964793968200684, 626664448]),
 (1, [1.6490003705024718, 858719232]),
 (2, [14.56126823425293, 659530752]),
 (3, [8.766735458374024, 705111040]),
 (4, [10.607507133483887, 683234304]),
 (5, [1.865936553478241, 810816512]),
 (6, [3.129116940498352, 761234432]),
 (8, [1.7327796578407288, 833205248]),
 (9, [4.683143568038941, 737731584])]

In [18]:
model.revert_weights()

In [34]:
# for key in model.return_model_state():
#     print(key, model.return_model_state()[key].shape)

In [24]:
count_active_params(model.return_model_state()) - (128112 * 1024 * 3) - (1026 * 1024 * 2)

483902464

In [27]:
((483902464*32) / (1024*8))/1024

1845.94140625

In [15]:
mask_sizing = size_mask(model.return_model_state())

for skill_factor in masks:
    for i, candidate in enumerate(masks[skill_factor]):
        
        
        mask = candidate['rnvec']
        model.apply_mask(mask, mask_sizing)
        
        for name, param in model.named_parameters():
            param.requires_grad = True
            
        if skill_factor == 1:
            if i in cs_selected_models:
                lang_code = 'cs'
                if os.path.isfile('./MFEA/final_models/'+lang_code+'_'+str(i)+'.pth'):
                    print("Starting from Checkpoint")
                    model.load_state_dict(torch.load('./MFEA/final_models/'+lang_code+'_'+str(i)+'.pth'))

                    for name, param in model.named_parameters():
                        param.requires_grad = True

                print('----------')
                print('Lang Code: ', lang_code, ', Candidate: ', str(i), ', Model Size: ', str(count_active_params(model.return_model_state()) - (128112 * 1024 * 3) - (1026 * 1024 * 2)))
                train_loop(model, 
                           mask,
                           mask_sizing,
                          cs_train_dataloader, 
                          cs_test_dataloader, 
                          3, lang_code, i)
        else:
            if i in de_selected_models:
                lang_code = 'de'
                if os.path.isfile('./MFEA/final_models/'+lang_code+'_'+str(i)+'.pth'):
                    print("Starting from Checkpoint")
                    model.load_state_dict(torch.load('./MFEA/final_models/'+lang_code+'_'+str(i)+'.pth'))

                    for name, param in model.named_parameters():
                        param.requires_grad = True

                print('----------')
                print('Lang Code: ', lang_code, ', Candidate: ', str(i), ', Model Size: ', str(count_active_params(model.return_model_state()) - (128112 * 1024 * 3) - (1026 * 1024 * 2)))
                train_loop(model, 
                           mask,
                           mask_sizing,
                          de_train_dataloader, 
                          de_test_dataloader, 
                          3, lang_code, i)            

        model.revert_weights()

tensor(879566848)
tensor(352321536)
67584
Starting from Checkpoint
----------
Lang Code:  cs , Candidate:  0 , Model Size:  231969792

 Starting Loss:  1.8520716428756714 , Lang Code:  cs , Model:  0
Epoch 0.000	Batch: 1199.000, Loss: 0.455
 Best Averaged Loss:  1.8515033721923828 , Lang Code:  cs , Model:  0

 Starting Loss:  1.8515034914016724 , Lang Code:  cs , Model:  0
Epoch 1.000	Batch: 1199.000, Loss: 3.142
 Best Averaged Loss:  1.8501688241958618 , Lang Code:  cs , Model:  0

 Starting Loss:  1.8501687049865723 , Lang Code:  cs , Model:  0
Epoch 2.000	Batch: 1199.000, Loss: 3.658
 Best Averaged Loss:  1.847960352897644 , Lang Code:  cs , Model:  0
Starting from Checkpoint
----------
Lang Code:  cs , Candidate:  1 , Model Size:  464073728

 Starting Loss:  1.2566149234771729 , Lang Code:  cs , Model:  1
Epoch 0.000	Batch: 1199.000, Loss: 1.054
 Best Averaged Loss:  1.2566672563552856 , Lang Code:  cs , Model:  1

 Starting Loss:  1.2566674947738647 , Lang Code:  cs , Model:  1
E

KeyboardInterrupt: 

In [16]:
metric = datasets.load_metric('sacrebleu')


In [30]:
model.revert_weights()

In [31]:
score = test_BLEU(model, cs_test_dataloader, metric, 'cs')
mem = count_active_params(model.return_model_state()) - (128112 * 1024 * 3) - (1026 * 1024 * 2)
mem = ((mem*32) / (1024*8))/1024

print(score, mem)

31.658459055339222 1845.94140625


In [32]:
score = test_BLEU(model, de_test_dataloader, metric, 'de')
mem = count_active_params(model.return_model_state()) - (128112 * 1024 * 3) - (1026 * 1024 * 2)
mem = ((mem*32) / (1024*8))/1024

print(score, mem)

23.571392328999416 1845.94140625


In [33]:
mask_sizing = size_mask(model.return_model_state())

results = {}

for skill_factor in masks:
    for i, candidate in enumerate(masks[skill_factor]):
        
        mask = candidate['rnvec']
        model.apply_mask(mask, mask_sizing)
        
        for name, param in model.named_parameters():
            param.requires_grad = True
            
        if skill_factor == 1:
            if i in cs_selected_models:
                lang_code = 'cs'
                if os.path.isfile('./MFEA/final_models/'+lang_code+'_'+str(i)+'.pth'):
                    print("Starting from Checkpoint")
                    model.load_state_dict(torch.load('./MFEA/final_models/'+lang_code+'_'+str(i)+'.pth'))

                    for name, param in model.named_parameters():
                        param.requires_grad = True

                print('----------')
                task_name = 'cs_only'
                score = test_BLEU(model, cs_test_dataloader, metric, 'cs')
                mem = count_active_params(model.return_model_state()) - (128112 * 1024 * 3) - (1026 * 1024 * 2)
                mem = ((mem*32) / (1024*8))/1024
                print('Score: ', score, ', Lang Code: ', lang_code, ', Candidate: ', str(i), ', Model Size: ', mem)
                if task_name not in results:
                    results.update({task_name:[(score, mem)]})
                else:
                    curr = results[task_name]
                    curr.append((score, mem))
                    results.update({task_name:curr})
        else:
            if i in de_selected_models:
                lang_code = 'de'
                if os.path.isfile('./MFEA/final_models/'+lang_code+'_'+str(i)+'.pth'):
                    print("Starting from Checkpoint")
                    model.load_state_dict(torch.load('./MFEA/final_models/'+lang_code+'_'+str(i)+'.pth'))

                    for name, param in model.named_parameters():
                        param.requires_grad = True
                task_name = 'de_only'
                print('----------')
                score = test_BLEU(model, de_test_dataloader, metric, 'de')
                mem = count_active_params(model.return_model_state()) - (128112 * 1024 * 3) - (1026 * 1024 * 2)
                mem = ((mem*32) / (1024*8))/1024
                print('Score: ', score, ', Lang Code: ', lang_code, ', Candidate: ', str(i), ', Model Size: ', mem)
                if task_name not in results:
                    results.update({task_name:[(score, mem)]})
                else:
                    curr = results[task_name]
                    curr.append((score, mem))
                    results.update({task_name:curr})
        model.revert_weights()

tensor(879566848)
tensor(352321536)
67584
Starting from Checkpoint
----------
Score:  23.854002187191682 , Lang Code:  cs , Candidate:  0 , Model Size:  884.89453125
Starting from Checkpoint
----------
Score:  32.22747264598261 , Lang Code:  cs , Candidate:  1 , Model Size:  1770.30078125
Starting from Checkpoint
----------
Score:  26.898700478784054 , Lang Code:  cs , Candidate:  2 , Model Size:  1260.26953125
Starting from Checkpoint
----------
Score:  31.91246970741459 , Lang Code:  cs , Candidate:  3 , Model Size:  1119.34765625
Starting from Checkpoint
----------
Score:  32.1336120762841 , Lang Code:  cs , Candidate:  4 , Model Size:  1192.87890625
Starting from Checkpoint
----------
Score:  24.685699230399294 , Lang Code:  cs , Candidate:  6 , Model Size:  1011.17578125
Starting from Checkpoint
----------
Score:  28.457910455709765 , Lang Code:  cs , Candidate:  10 , Model Size:  1322.44140625
Starting from Checkpoint
----------
Score:  30.152389736476827 , Lang Code:  cs , Candi

In [35]:
results

{'cs_only': [(23.854002187191682, 884.89453125),
  (32.22747264598261, 1770.30078125),
  (26.898700478784054, 1260.26953125),
  (31.91246970741459, 1119.34765625),
  (32.1336120762841, 1192.87890625),
  (24.685699230399294, 1011.17578125),
  (28.457910455709765, 1322.44140625),
  (30.152389736476827, 1404.42578125),
  (30.321455461369865, 1472.22265625)],
 'de_only': [(21.29226708687752, 881.20703125),
  (23.99233217870648, 1766.42578125),
  (22.21373807430573, 1006.58203125),
  (23.333850317280728, 1180.45703125),
  (22.738922753998107, 1097.00390625),
  (24.064876510828327, 1583.69140625),
  (23.896947833654597, 1394.55078125),
  (24.331468331673822, 1669.09765625),
  (24.131666338433572, 1304.89453125)]}

In [38]:
cs = results['cs_only']
cs = [i for i in cs if i[1] not in [1119.34765625, 1192.87890625]]

In [39]:
cs

[(23.854002187191682, 884.89453125),
 (32.22747264598261, 1770.30078125),
 (26.898700478784054, 1260.26953125),
 (24.685699230399294, 1011.17578125),
 (28.457910455709765, 1322.44140625),
 (30.152389736476827, 1404.42578125),
 (30.321455461369865, 1472.22265625)]

In [45]:
de = results['de_only']
de = [i for i in de if i[1] not in [1766.42578125, 1304.89453125]]

In [46]:
final_results = {'cs_only':cs, 'de_only':de}

In [47]:
with open('./MFEA/results/MT/Test_Run2/finetuned_results.pkl', 'wb') as f:
    pickle.dump(final_results, f)