## Colab setup (optional — uncomment when running on Google Colab)

In [110]:
#!g1.1
# from google.colab import drive
# drive.mount('/content/drive')

# !cp -r "/content/drive/MyDrive/Colab Notebooks/Diploma/handle_amazon/amazon_en" .
# !cp -r "/content/drive/MyDrive/Colab Notebooks/Diploma/handle_amazon/amazon_fr" .
# !cp -r "/content/drive/MyDrive/Colab Notebooks/Diploma/handle_amazon/amazon_de" .
# !cp -r "/content/drive/MyDrive/Colab Notebooks/Diploma/handle_amazon/amazon_es" .

# !pip install transformers datasets


DS related:

In [111]:
#!g1.1
# %pip install seaborn
# %pip install transformers datasets



## Imports

In [112]:
#!g1.1
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm


from torch.utils.data import DataLoader

def nice_df(df, axis=None, reverse=False, **kwargs):
    """Return a pandas Styler that shades the cells of ``df`` with a green gradient.

    Args:
        df: DataFrame to style.
        axis: forwarded to ``Styler.background_gradient`` (None = whole table).
        reverse: forwarded to ``sns.light_palette`` to invert the colour ramp.
        **kwargs: any extra ``Styler.background_gradient`` options.
    """
    green_cmap = sns.light_palette("green", reverse=reverse, as_cmap=True)
    styler = df.style
    return styler.background_gradient(axis=axis, cmap=green_cmap, **kwargs)

# Use the GPU when one is visible; fall back to CPU so the notebook still runs
# (slowly) on a CPU-only kernel instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



## Loading Data

In [113]:
#!g1.1

# !unzip handle_amazon


In [114]:
#!g1.1
from datasets import concatenate_datasets, load_from_disk

BS = 32  # batch size shared by every DataLoader built below
lang_list = ['en', 'fr', 'de', 'es']
split_list = ['train', 'validation', 'test']


def _make_split_loaders(dataset_dict):
    """Wrap each split in ``split_list`` of a saved DatasetDict in a DataLoader.

    Only the 'train' split is shuffled; validation/test keep their stored order.
    """
    return {
        split: DataLoader(dataset_dict[split], batch_size=BS, shuffle=(split == 'train'))
        for split in split_list
    }


# One translated dataset per language, loaded from the 'handle_amazon' directory.
tr_data = {lang: load_from_disk(f'handle_amazon/amazon_ok_tr_{lang}') for lang in lang_list}

# Access pattern: dataloader[lang][split] -> DataLoader
dataloader = {lang: _make_split_loaders(tr_data[lang]) for lang in lang_list}



## Model

In [129]:
#!g1.1
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
# Inverse mapping, derived from id2label so the two dicts cannot drift apart.
label2id = {name: idx for idx, name in id2label.items()}

# Load the binary classifier checkpoint and move it to `device`.
# output_hidden_states=True makes every forward pass also return hidden states;
# the evaluation below only reads `.logits`.
model = DistilBertForSequenceClassification.from_pretrained(
    "./models/ft_no_tr_2_en",
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
    output_hidden_states=True,
).to(device)



## Evaluation

In [127]:
#!g1.1
def eval(model, dls, lang, test_split, pref=''):
    """Compute the classification accuracy of ``model`` on one language/split.

    Note: this name shadows Python's built-in ``eval`` for the rest of the
    notebook; it is kept because the cells below call it by this name.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)``
            whose output exposes a ``.logits`` tensor of shape
            (batch, num_classes); must also expose ``.device``.
        dls: nested mapping ``dls[lang][test_split]`` -> iterable of batches
            (e.g. a DataLoader); each batch maps column names to tensors.
        lang: language key into ``dls``.
        test_split: split key into ``dls[lang]``, e.g. ``'test'``.
        pref: prefix prepended to the ``input_ids`` / ``attention_mask``
            column names (e.g. ``'en_'``). Labels are always read from
            ``'bin_label'``.

    Returns:
        Fraction of correctly classified examples (float), or NaN when the
        split yields no examples.
    """
    # Disable dropout etc.; the model is intentionally left in eval mode.
    model.eval()

    dl_to_test = dls[lang][test_split]
    target_device = model.device  # hoisted: constant for the whole loop

    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch in tqdm(dl_to_test):
            input_ids = batch[pref + 'input_ids'].to(target_device)
            attention_mask = batch[pref + 'attention_mask'].to(target_device)
            labels = batch['bin_label'].to(target_device)

            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            preds = logits.argmax(dim=1)
            n_correct += (preds == labels).sum().item()
            # Count real examples rather than assuming full batches: the last
            # batch may be short, and batch_size * num_batches would then
            # under-report accuracy.
            n_seen += labels.size(0)

    test_acc = n_correct / n_seen if n_seen else float('nan')
    print(f'\teval {lang}: {test_acc}')
    return test_acc


## Evaluation on all languages

In [130]:
#!g1.1
# Score the loaded model on every language's test split. pref='en_' selects the
# 'en_input_ids' / 'en_attention_mask' columns of each translated dataset
# (presumably the reviews translated to English — confirm against the data prep).
# The frame is sized from lang_list instead of a hard-coded row count.
eval_res = pd.DataFrame(
    {'finetune_translation': [eval(model, dataloader, lang, 'test', pref='en_') for lang in lang_list]},
    index=lang_list,
)

nice_df(eval_res)


100%|██████████| 125/125 [00:22<00:00,  5.53it/s]


	eval en: 0.8775


100%|██████████| 125/125 [00:22<00:00,  5.55it/s]


	eval fr: 0.8775


100%|██████████| 125/125 [00:22<00:00,  5.50it/s]


	eval de: 0.86625


100%|██████████| 125/125 [00:22<00:00,  5.53it/s]


	eval es: 0.8735


Unnamed: 0,finetune_translation
en,0.8775
fr,0.8775
de,0.86625
es,0.8735


In [None]:
#!g1.1

# NOTE(review): this is the exact checkpoint path already loaded into `model`
# in the Model section, so the "full" comparison below re-evaluates an identical
# model — confirm whether a different checkpoint was intended here.
model_full_path = "./models/ft_no_tr_2_en"
model_full = DistilBertForSequenceClassification.from_pretrained(
    model_full_path,
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
    output_hidden_states=True,
).to(device)


In [None]:
#!g1.1

# Same per-language test evaluation as above, for `model_full`.
# The frame is sized from lang_list instead of a hard-coded row count.
eval_res_full = pd.DataFrame(
    {'finetune_translation': [eval(model_full, dataloader, lang, 'test', pref='en_') for lang in lang_list]},
    index=lang_list,
)

nice_df(eval_res_full)
