## Imports

In [72]:
#!g1.1
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import pickle


from torch.utils.data import Dataset, DataLoader

def nice_df(df, axis=None, reverse=False, **kwargs):
    cm = sns.light_palette("green", as_cmap=True, reverse=reverse)
    return df.style.background_gradient(cmap=cm, axis=axis, **kwargs)

device = torch.device("cuda")



## Data

In [75]:
#!g1.1

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
def tokenization(examples: np.ndarray):
    return tokenizer(examples.tolist(), truncation=True, padding=True, pad_to_multiple_of=40, max_length=40, return_tensors='pt')

class CustomDataset(Dataset):
    def __init__(self, data):
        self.en_id = data["en_id"]
        self.en_am = data["en_am"]
        self.fr_id = data["fr_id"]
        self.fr_am = data["fr_am"]
    
    def __len__(self):
        return len(self.en_id)
    
    def __getitem__(self, index):
        return {
            "en_id": self.en_id[index], 
            "en_am": self.en_am[index],
            "fr_id": self.fr_id[index],
            "fr_am": self.fr_am[index],
        }

def get_dataloaer(lang, batch_size = 32):
    done = 100 * 1000
    file_path = f'saved_tr_lists/list_{lang}_{done}.pkl'

    with open(file_path, 'rb') as file:
        tr_pairs = pickle.load(file)
        file.close()

    np_pairs = np.array(tr_pairs)
    fr_torch  = tokenization(np_pairs[:, 0])
    en_torch = tokenization(np_pairs[:, 1])

    data = {
        "en_id": en_torch['input_ids'],
        "en_am": en_torch['attention_mask'],
        "fr_id": fr_torch['input_ids'],
        "fr_am": fr_torch['attention_mask'],
    }

    dataset = CustomDataset(data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader

lang_list_tr = ['fr', 'de', 'es']
dataloaders_tr = {
    lang: get_dataloaer(lang)
    for lang in lang_list_tr
}

for lang in lang_list_tr:
    sum = 0
    # Iterate over the dataloader
    for batch in dataloaders_tr[lang]:
        en_id = batch["en_id"]
        en_am = batch["en_am"]
        fr_id = batch["fr_id"]
        fr_am = batch["fr_am"]

        # Use the batched data for further processing
        if False:
            print("Input IDs:", en_id.shape)
            print("Attention Mask:", en_am.shape)
            print("Input IDs:", fr_id.shape)
            print("Attention Mask:", fr_am.shape)
            print("Input IDs:", en_id[:2, :5])
            print("Input IDs:", fr_id[:2, :5])
            break

        sum += en_id.shape[0]

    print(sum)


100000
100000
100000


In [76]:
#!g1.1
from datasets import concatenate_datasets, load_from_disk

BS = 32
lang_list = ['en', 'fr', 'de', 'es']
split_list = ['train', 'validation', 'test']


# data = {
#     lang: load_from_disk(f'handle_amazon/amazon_{lang}')
#     for lang in lang_list
# }

tr_data = {
    lang: load_from_disk(f'handle_amazon/amazon_ok_tr_{lang}')
    for lang in lang_list
}

dataloader = {
    lang: {
        split: DataLoader(tr_data[lang][split], batch_size=BS, shuffle=(split == 'train'))
        for split in split_list
    }
    for lang in lang_list
}



## Models

In [94]:
#!g1.1
# from transformers import AutoModelForSequenceClassification

id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

models = dict()
for lang in lang_list:
    models[lang] = DistilBertForSequenceClassification.from_pretrained(f'models/ft_ht_{lang}')
    models[lang].to(device)

for lang in lang_list:
    for param in models[lang].base_model.parameters():
        param.requires_grad = True


In [95]:
#!g1.1

for lang in lang_list:
    for name, param in models[lang].named_parameters():
        if 'clas' in name:
            print(name, param.shape, param.requires_grad)
        assert param.requires_grad
    print(f'{lang} is ok')


pre_classifier.weight torch.Size([768, 768]) True
pre_classifier.bias torch.Size([768]) True
classifier.weight torch.Size([2, 768]) True
classifier.bias torch.Size([2]) True
en is ok
pre_classifier.weight torch.Size([768, 768]) True
pre_classifier.bias torch.Size([768]) True
classifier.weight torch.Size([2, 768]) True
classifier.bias torch.Size([2]) True
fr is ok
pre_classifier.weight torch.Size([768, 768]) True
pre_classifier.bias torch.Size([768]) True
classifier.weight torch.Size([2, 768]) True
classifier.bias torch.Size([2]) True
de is ok
pre_classifier.weight torch.Size([768, 768]) True
pre_classifier.bias torch.Size([768]) True
classifier.weight torch.Size([2, 768]) True
classifier.bias torch.Size([2]) True
es is ok


## Training


In [79]:
#!g1.1
def eval(model, dls, lang, test_split):
    # put model in eval mode
    model.eval()

    # get needful data slice
    dl_to_test = dls[lang][test_split]
    
    test_loss = 0
    test_acc = 0
    
    with torch.no_grad():
        for batch in tqdm(dl_to_test):
            # move batch to device
            input_ids = batch['input_ids'].to(model.device)
            attention_mask = batch['attention_mask'].to(model.device)
            labels = batch['bin_label'].to(model.device)

            # forward pass
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            # calculate loss and accuracy
            preds = logits.argmax(dim=1)
            test_acc += (preds == labels).sum().item()

    test_acc /= BS * len(dl_to_test)
    print(f'\teval {lang}: {test_acc}')
    return test_acc


In [98]:
#!g1.1

import random
def epoch_sentiment(model, dls, dl, lang, validation_split, optimizer, sample_size=-1, debug=False):
    model.train()
    loss_fn = torch.nn.CrossEntropyLoss()
    train_loss = 0
    train_acc = 0
    
    if debug:
        if sample_size == -1:
            sample_size = 100
        else:
            sample_size = sample_size // 10
    if sample_size != -1:
        sampled_batches = random.sample(list(dl), sample_size)
    else:
        sampled_batches = dl

    batch_num = 0
    for batch in tqdm(sampled_batches):
        # move batch to device
        input_ids = batch['input_ids'].to(model.device)
        attention_mask = batch['attention_mask'].to(model.device)
        labels = batch['bin_label'].to(model.device)

        # zero out gradients
        optimizer.zero_grad()

        # forward pass
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        # calculate loss and accuracy
        loss = loss_fn(logits, labels)
        train_loss += loss.item()
        preds = logits.argmax(dim=1)
        train_acc += (preds == labels).sum().item()

        loss.backward()
        optimizer.step()
        
        batch_num += 1
        if batch_num > 200 and debug:
            break

    train_acc /= BS * batch_num
    valid_acc = eval(model, dls, lang, validation_split)
    print(f'train {lang}: {train_acc} (val {valid_acc})')
    return (train_acc, valid_acc)

def epoch_translation(model, dls, dl, lang, validation_split, optimizer, debug=False):
    model.train()
    loss_fn = torch.nn.MSELoss()
    train_loss = 0
    

    batch_num = 0
    for batch in tqdm(dl):
        # move batch to device
        en_id = batch["en_id"].to(model.device)
        en_am = batch["en_am"].to(model.device)
        en_batch_hs = model(input_ids=en_id, attention_mask=en_am).hidden_states[1:]
        fr_id = batch["fr_id"].to(model.device)
        fr_am = batch["fr_am"].to(model.device)
        fr_batch_hs = model(input_ids=fr_id, attention_mask=fr_am).hidden_states[1:]

        get_embeds = lambda batch_hs, a_m: (batch_hs * a_m[..., None]).sum(axis=1) / a_m.sum(axis=-1)[..., None]
        
        # zero out gradients
        optimizer.zero_grad()

        loss = 0
        for en_hs, fr_hs in zip(en_batch_hs, fr_batch_hs):
            en_embeds = get_embeds(en_hs, en_am)
            fr_embeds = get_embeds(fr_hs, fr_am)
            loss += loss_fn(en_embeds, fr_embeds)
        
        # calculate loss and accuracy
        train_loss += loss.item()

        loss.backward()
        optimizer.step()
        
        batch_num += 1
        if batch_num > 100 and debug:
            break

    train_loss /= BS * batch_num
    valid_acc = eval(model, dls, lang, validation_split)
    print(f'TRANSLATE {lang}: {train_loss} (val {valid_acc})')
    return (train_loss, valid_acc)

def train_translation(model, dls, dl_tr, lang, train_split, validation_split, num_epochs=2, need_tr=True, device='mps', debug=False):
    # put model on mps device
    model.to(device)
    
    # get needful data slice
    dl_sentiment = dls[lang][train_split]
    dl_translation = dl_tr[lang]

    # define our optimizer and loss function
    learning_rate_bert = 1e-6
    learning_rate_classifier = 2e-5

    # Set up parameter groups for different parts of the model
    optimizer_grouped_parameters = [
        {"params": model.distilbert.parameters(), "lr": learning_rate_bert},
        {"params": model.classifier.parameters(), "lr": learning_rate_classifier},
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters)

    collected_data = list()
    for epoch in range(num_epochs):
        # train loops
        tr_res = epoch_translation(model, dls, dl_translation, lang, validation_split, optimizer, debug=debug)
        collected_data.append(tr_res)

        s_res = epoch_sentiment(model, dls, dl_sentiment, lang, validation_split, optimizer, 1000, debug=debug)
        collected_data.append(s_res)

    return collected_data

def train_sentiment(model, dls, dl_tr, lang, train_split, validation_split, num_epochs=2, need_tr=True, device='mps', debug=False):
    # put model on mps device
    model.to(device)
    
    dl_sentiment = dls[lang][train_split]
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

    collected_data = list()
    for epoch in range(num_epochs):
        # train loop
        s_res = epoch_sentiment(model, dls, dl_sentiment, lang, validation_split, optimizer, debug=debug)
        collected_data.append(s_res)

    return collected_data



In [99]:
#!g1.1

def assert_use_grad(model):          
    for name, param in model.named_parameters():
        if 'clas' in name:
            assert param.requires_grad
        else:
            assert param.requires_grad
    print(f'model is ok (use_grad)')

def assert_no_grad(model):       
    for name, param in model.named_parameters():
        if 'clas' in name:
            assert param.requires_grad
        else:
            assert not param.requires_grad
    print(f'model is ok (no_grad)')

def use_grad(model):
    for lang in lang_list:
        for param in model.base_model.parameters():
            param.requires_grad = True
    assert_use_grad(model)

def no_grad(model):
    for lang in lang_list:
        for param in model.base_model.parameters():
            param.requires_grad = False
    assert_no_grad(model)

def training_pipeline(lang, epochs=2, debug=False):
    for lang in [lang]:
        for i in range(epochs):
            use_grad(models[lang])
            cd_1 = train_translation(models[lang], dataloader, dataloaders_tr, lang, 'train', 'validation',
                                     num_epochs=(3 if debug else 5), need_tr=True, device=device, debug=debug)
            print(f'{i}: ', cd_1)

            no_grad(models[lang])
            cd_2 = train_sentiment(models[lang], dataloader, dataloaders_tr, lang, 'train', 'validation', 
                                   num_epochs=(1 if debug else 2), need_tr=True, device=device, debug=debug)    
            print(f'{i}: ', cd_2)



In [92]:
#!g1.1

for lang in lang_list:
    if lang != 'en':
        training_pipeline(lang, epochs=2, debug=False)
        
        

model is ok (use_grad)
	eval fr: 0.821
TRANSLATE fr: 0.004487292901013453 (val 0.821)
	eval fr: 0.82975
train fr: 0.8165625 (val 0.82975)
	eval fr: 0.68975
TRANSLATE fr: 0.004004931234118372 (val 0.68975)
	eval fr: 0.8315
train fr: 0.8246875 (val 0.8315)
	eval fr: 0.838
TRANSLATE fr: 0.00397068147529101 (val 0.838)
	eval fr: 0.844
train fr: 0.8359375 (val 0.844)
0:  [(0.004487292901013453, 0.821), (0.8165625, 0.82975), (0.004004931234118372, 0.68975), (0.8246875, 0.8315), (0.00397068147529101, 0.838), (0.8359375, 0.844)]
model is ok (no_grad)
	eval fr: 0.8435
train fr: 0.8376865671641791 (val 0.8435)
0:  [(0.8376865671641791, 0.8435)]
model is ok (use_grad)
	eval fr: 0.8285
TRANSLATE fr: 0.003365031314442063 (val 0.8285)
	eval fr: 0.8445
train fr: 0.814375 (val 0.8445)
	eval fr: 0.83575
TRANSLATE fr: 0.003034878704272727 (val 0.83575)
	eval fr: 0.84875
train fr: 0.8253125 (val 0.84875)
	eval fr: 0.837
TRANSLATE fr: 0.0030063006515144417 (val 0.837)
	eval fr: 0.852
train fr: 0.839375 (v

  3%|▎         | 100/3125 [00:08<04:12, 11.99it/s]
100%|██████████| 125/125 [00:22<00:00,  5.49it/s]
100%|██████████| 100/100 [00:49<00:00,  2.04it/s]
100%|██████████| 125/125 [00:22<00:00,  5.48it/s]
  3%|▎         | 100/3125 [00:08<04:11, 12.03it/s]
100%|██████████| 125/125 [00:22<00:00,  5.48it/s]
100%|██████████| 100/100 [00:49<00:00,  2.04it/s]
100%|██████████| 125/125 [00:22<00:00,  5.48it/s]
  3%|▎         | 100/3125 [00:08<04:17, 11.75it/s]
100%|██████████| 125/125 [00:22<00:00,  5.46it/s]
100%|██████████| 100/100 [00:49<00:00,  2.04it/s]
100%|██████████| 125/125 [00:22<00:00,  5.47it/s]
  4%|▍         | 200/5000 [00:41<16:30,  4.85it/s]
100%|██████████| 125/125 [00:22<00:00,  5.46it/s]
  3%|▎         | 100/3125 [00:08<04:12, 12.00it/s]
100%|██████████| 125/125 [00:22<00:00,  5.47it/s]
100%|██████████| 100/100 [00:49<00:00,  2.03it/s]
100%|██████████| 125/125 [00:22<00:00,  5.46it/s]
  3%|▎         | 100/3125 [00:08<04:13, 11.93it/s]
100%|██████████| 125/125 [00:22<00:00,  5.46

In [100]:
#!g1.1

training_pipeline('fr', epochs=2, debug=False)
        


model is ok (use_grad)
	eval fr: 0.5
TRANSLATE fr: 0.0018882102750334888 (val 0.5)
	eval fr: 0.81175
train fr: 0.73271875 (val 0.81175)
	eval fr: 0.79775
TRANSLATE fr: 0.0003935722163692117 (val 0.79775)
	eval fr: 0.83725
train fr: 0.815625 (val 0.83725)
	eval fr: 0.8285
TRANSLATE fr: 0.0003471001837681979 (val 0.8285)
	eval fr: 0.852
train fr: 0.83534375 (val 0.852)
	eval fr: 0.83725
TRANSLATE fr: 0.0003154149005515501 (val 0.83725)
	eval fr: 0.859
train fr: 0.84553125 (val 0.859)
	eval fr: 0.8375
TRANSLATE fr: 0.00029923732010181994 (val 0.8375)
	eval fr: 0.8645
train fr: 0.8541875 (val 0.8645)
0:  [(0.0018882102750334888, 0.5), (0.73271875, 0.81175), (0.0003935722163692117, 0.79775), (0.815625, 0.83725), (0.0003471001837681979, 0.8285), (0.83534375, 0.852), (0.0003154149005515501, 0.83725), (0.84553125, 0.859), (0.00029923732010181994, 0.8375), (0.8541875, 0.8645)]
model is ok (no_grad)
	eval fr: 0.869
train fr: 0.8637 (val 0.869)
	eval fr: 0.8685
train fr: 0.86400625 (val 0.8685)
0

100%|██████████| 5000/5000 [17:15<00:00,  4.83it/s]
100%|██████████| 125/125 [00:22<00:00,  5.45it/s]
100%|██████████| 5000/5000 [17:09<00:00,  4.85it/s]
100%|██████████| 125/125 [00:22<00:00,  5.47it/s]


In [102]:
#!g1.1

training_pipeline('de', epochs=2, debug=False)


model is ok (use_grad)
	eval de: 0.49975
TRANSLATE de: 0.0022574793294258414 (val 0.49975)
	eval de: 0.7875
train de: 0.70653125 (val 0.7875)
	eval de: 0.779
TRANSLATE de: 0.0003889742057584226 (val 0.779)
	eval de: 0.80975
train de: 0.79275 (val 0.80975)
	eval de: 0.75325
TRANSLATE de: 0.00032856806331314146 (val 0.75325)
	eval de: 0.8245
train de: 0.8146875 (val 0.8245)
	eval de: 0.8025
TRANSLATE de: 0.00029973368672188373 (val 0.8025)
	eval de: 0.831
train de: 0.82315625 (val 0.831)
	eval de: 0.79175
TRANSLATE de: 0.0002834173326473683 (val 0.79175)
	eval de: 0.84525
train de: 0.8360625 (val 0.84525)
0:  [(0.0022574793294258414, 0.49975), (0.70653125, 0.7875), (0.0003889742057584226, 0.779), (0.79275, 0.80975), (0.00032856806331314146, 0.75325), (0.8146875, 0.8245), (0.00029973368672188373, 0.8025), (0.82315625, 0.831), (0.0002834173326473683, 0.79175), (0.8360625, 0.84525)]
model is ok (no_grad)
	eval de: 0.84775
train de: 0.845225 (val 0.84775)
	eval de: 0.84975
train de: 0.844081

100%|██████████| 5000/5000 [16:56<00:00,  4.92it/s]
100%|██████████| 125/125 [00:22<00:00,  5.53it/s]
100%|██████████| 5000/5000 [16:57<00:00,  4.92it/s]
100%|██████████| 125/125 [00:22<00:00,  5.53it/s]


In [103]:
#!g1.1

training_pipeline('es', epochs=2, debug=False)


model is ok (use_grad)
	eval es: 0.50375
TRANSLATE es: 0.0020478245209529997 (val 0.50375)
	eval es: 0.792
train es: 0.72803125 (val 0.792)
	eval es: 0.67525
TRANSLATE es: 0.00039636514275334775 (val 0.67525)
	eval es: 0.8215
train es: 0.80309375 (val 0.8215)
	eval es: 0.65775
TRANSLATE es: 0.0003367180655570701 (val 0.65775)
	eval es: 0.83375
train es: 0.82740625 (val 0.83375)
	eval es: 0.65825
TRANSLATE es: 0.0003086267004720867 (val 0.65825)
	eval es: 0.8475
train es: 0.83746875 (val 0.8475)
	eval es: 0.6525
TRANSLATE es: 0.0002886756160762161 (val 0.6525)
	eval es: 0.85475
train es: 0.84875 (val 0.85475)
0:  [(0.0020478245209529997, 0.50375), (0.72803125, 0.792), (0.00039636514275334775, 0.67525), (0.80309375, 0.8215), (0.0003367180655570701, 0.65775), (0.82740625, 0.83375), (0.0003086267004720867, 0.65825), (0.83746875, 0.8475), (0.0002886756160762161, 0.6525), (0.84875, 0.85475)]
model is ok (no_grad)
	eval es: 0.85575
train es: 0.8551 (val 0.85575)
	eval es: 0.854
train es: 0.85

100%|██████████| 5000/5000 [16:57<00:00,  4.92it/s]
100%|██████████| 125/125 [00:22<00:00,  5.54it/s]
100%|██████████| 5000/5000 [16:57<00:00,  4.91it/s]
100%|██████████| 125/125 [00:22<00:00,  5.52it/s]


In [104]:
#!g1.1

eval_res = pd.DataFrame(data = np.zeros((4, 1)), columns = ['full'], index=lang_list)

# for lang in lang_list:
for lang in lang_list:
    test_res = eval(models[lang], dataloader, lang, 'test')
    eval_res.at[lang, 'full'] = test_res

nice_df(eval_res)


100%|██████████| 125/125 [00:22<00:00,  5.51it/s]
100%|██████████| 125/125 [00:22<00:00,  5.52it/s]
100%|██████████| 125/125 [00:22<00:00,  5.49it/s]
100%|██████████| 125/125 [00:22<00:00,  5.52it/s]


	eval en: 0.87675
	eval fr: 0.87575
	eval de: 0.864
	eval es: 0.87675


Unnamed: 0,full
en,0.87675
fr,0.87575
de,0.864
es,0.87675


## Saving Models


In [None]:
#!g1.1

for lang in lang_list:
    if lang != 'en':
        models[lang].save_pretrained(f'models/ft_best_{lang}')
