## Imports

In [25]:
#!g1.1
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import numpy as np
import pandas as pd
import seaborn as sns
from transformers import pipeline
from tqdm import tqdm

from torch.utils.data import DataLoader

def nice_df(df, axis=None, reverse=False, **kwargs):
    cm = sns.light_palette("green", as_cmap=True, reverse=reverse)
    return df.style.background_gradient(cmap=cm, axis=axis, **kwargs)

device = torch.device("cuda")



## Loading Data

In [26]:
#!g1.1
from datasets import concatenate_datasets, load_from_disk

BS = 32
lang_list = ['en', 'fr', 'de', 'es']
split_list = ['train', 'validation', 'test']


data = {
    lang: load_from_disk(f'handle_amazon/amazon_ok_tr_{lang}')
    for lang in lang_list
}

dataloader = {
    lang: {
        split: DataLoader(data[lang][split], batch_size=BS, shuffle=(split == 'train'))
        for split in split_list
    }
    for lang in lang_list
}



## Model

In [27]:
#!g1.1
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

# tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2, id2label=id2label, label2id=label2id, output_hidden_states=True)
model.to(device)
for param in model.base_model.parameters():
    param.requires_grad = False

with torch.no_grad():
    for batch in dataloader['en']['test']:
        i_d = batch["input_ids"].to(device)
        a_m = batch["attention_mask"].to(device)
        batch_hs = model(
                input_ids=i_d,
                attention_mask=a_m,
            ).hidden_states[-1].mean(dim=1)    
        print(batch_hs)

        logits = model(
                input_ids=i_d,
                attention_mask=a_m,
            ).logits    
        print(logits)

        print(torch.argmax(
            model(
                input_ids=i_d,
                attention_mask=a_m,
            ).logits,
            axis=-1
        ))
        break


Downloading:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/256M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'pre_classifier.weight', 'pre_classi

tensor([[ 0.3267,  0.0176,  0.1214,  ...,  0.0836, -0.0200, -0.1360],
        [ 0.1976,  0.1497,  0.1829,  ...,  0.0148,  0.0260, -0.0101],
        [ 0.1492,  0.0464,  0.1645,  ..., -0.0861, -0.0884,  0.0062],
        ...,
        [ 0.1911,  0.1054, -0.0288,  ..., -0.0056,  0.0567, -0.1306],
        [ 0.1369,  0.0744,  0.1030,  ...,  0.0234, -0.1261, -0.0686],
        [ 0.1460,  0.2451,  0.0183,  ...,  0.1160, -0.0157, -0.1809]],
       device='cuda:0')
tensor([[ 0.1394, -0.1306],
        [ 0.1361, -0.1175],
        [ 0.1149, -0.0921],
        [ 0.0970, -0.1148],
        [ 0.0732, -0.0850],
        [ 0.0743, -0.0979],
        [ 0.1071, -0.1082],
        [ 0.1054, -0.1558],
        [ 0.0569, -0.0800],
        [ 0.1023, -0.0974],
        [ 0.1134, -0.0870],
        [ 0.1020, -0.0854],
        [ 0.1184, -0.0616],
        [ 0.1070, -0.0862],
        [ 0.1203, -0.1260],
        [ 0.0716, -0.0795],
        [ 0.1182, -0.1321],
        [ 0.0721, -0.1105],
        [ 0.1005, -0.1020],
        [ 

In [28]:
#!g1.1
for name, param in model.named_parameters():
    if 'clas' in name:
        print(name, param.shape, param.requires_grad)
    else:
        assert not param.requires_grad


pre_classifier.weight torch.Size([768, 768]) True
pre_classifier.bias torch.Size([768]) True
classifier.weight torch.Size([2, 768]) True
classifier.bias torch.Size([2]) True


## Training

In [33]:
#!g1.1
# translators = {
#     lang: pipeline("translation", model=f"Helsinki-NLP/opus-mt-{lang}-en", batch_size=8, max_length=150)
#     for lang in ['fr', 'de', 'es']
# }

translators_back = {
    lang: pipeline("translation", model=f"Helsinki-NLP/opus-mt-en-{lang}", max_length=150)
    for lang in ['fr', 'de', 'es']
}


In [34]:
#!g1.1
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
def tokenization(example):
    return tokenizer(example, truncation=True, padding=True, pad_to_multiple_of=512)

extreme_words = [
    'negative',
    'terrible',
    'horrible',
    'awful',
    'dreadful',
    'lousy',
    'abysmal',
    'dismal',
    'unpleasant',
    'repulsive',
    'completely devastated',
    'an utter disaster',
    'can\'t stand it anymore',
    'a total nightmare',
    'feels completely hopeless',
    'at my breaking point',
    'the worst thing ever',
    'can\'t take it anymore',
    'I\'m completely miserable',
    'I\'m absolutely crushed',

    'positive',
    'wonderful',
    'excellent',
    'fantastic',
    'amazing',
    'great',
    'superb',
    'outstanding',
    'perfect',
    'fabulous',
    'absolutely fantastic',
    'feels over the moon',
    'a dream coming true',
    'I\'m so thrilled',
    'amazing news',
    'can\'t believe how wonderful this is',
    'the best thing ever',
    'I\'m ecstatic about this',
    'feels on top of the world',
    'I\'m overjoyed',
]

extreme_dict = {
    'en': extreme_words
}

for lang in ['fr', 'de', 'es']:
    extreme_dict[lang] = [x['translation_text'] for x in translators_back[lang](extreme_words)]


In [35]:
#!g1.1
extreme_data_dict = {
    lang: [tokenization(word) for word in extreme_dict[lang]]
    for lang in lang_list
}

i_d = {
    lang: torch.tensor([word["input_ids"] for word in extreme_data_dict[lang]]).to(device)
    for lang in lang_list
}
a_m = {
    lang: torch.tensor([word["attention_mask"] for word in extreme_data_dict[lang]]).to(device)
    for lang in lang_list
}

batch_hs = {
    lang: model(
        input_ids=i_d[lang],
        attention_mask=a_m[lang],
    ).hidden_states[-1]  
    for lang in lang_list
}

masked_hs = {
    lang: batch_hs[lang] * a_m[lang][..., None]
    for lang in lang_list
}
extreme_embeds = {
    lang: masked_hs[lang].sum(axis=1) / a_m[lang].sum(axis=-1)[..., None]
    for lang in lang_list
}


for lang in lang_list:
    print(extreme_embeds[lang].shape)


torch.Size([40, 768])
torch.Size([40, 768])
torch.Size([40, 768])
torch.Size([40, 768])


An example of extreme embeds to be used:

In [36]:
#!g1.1
l_lang = 'fr'
ex_ex = extreme_embeds[l_lang] @ extreme_embeds[l_lang].T

num_nn = 5
values, indices = torch.topk(ex_ex, k=num_nn, dim=1)
num_extremes = extreme_embeds[l_lang].shape[0] // 2
((indices >= num_extremes).sum(dim=-1) > num_nn / 2).int()



tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0',
       dtype=torch.int32)

## Running

In [37]:
#!g1.1
def run_with_head(dl, model):
    acc = 0
    for batch in tqdm(dl):
        i_d = batch["input_ids"].to(model.device)
        a_m = batch["attention_mask"].to(model.device)
        logits = model(
                input_ids=i_d,
                attention_mask=a_m,
            ).logits.to('cpu')

        preds = torch.argmax(logits, axis=-1)
        labels = batch['bin_label']
        acc += (preds == labels).float().mean()
    acc /= len(dl)
    return acc

def run_knn_handmade(dl, model, extreme_embeds, num_nn=5):
    acc = 0
    for batch in tqdm(dl):
        i_d = batch["input_ids"].to(model.device)
        a_m = batch["attention_mask"].to(model.device)

        batch_hs = model(
                input_ids=i_d,
                attention_mask=a_m,
            ).hidden_states[-1]

        masked_hs = batch_hs * a_m[..., None]
        embeds = masked_hs.sum(axis=1) / a_m.sum(axis=-1)[..., None]

        corr = embeds @ extreme_embeds.T

        _, indices = torch.topk(corr, k=num_nn, dim=1)
        num_extremes = extreme_embeds.shape[0] // 2
        preds = ((indices >= num_extremes).sum(dim=-1) > num_nn / 2).int().to('cpu')

        labels = batch['bin_label']
        acc += (preds == labels).float().mean()
    acc /= len(dl)
    return acc



In [38]:
#!g1.1
# run_with_head(dataloader['en']['test'], model)


In [40]:
#!g1.1
for lang in lang_list:
    ans_to_print = f"LOOK at {lang}\n"

    for num_nn in range(1, 29, 2):
        res = run_knn_handmade(dataloader['en']['test'], model, extreme_embeds[lang], num_nn)
        ans_to_print += f'knn({num_nn}) = {res}\n'
        # print(num_nn, "done")

    print(ans_to_print)



LOOK at en
knn(1) = 0.5882499814033508
knn(3) = 0.6462500095367432
knn(5) = 0.625249981880188
knn(7) = 0.5364999771118164
knn(9) = 0.5195000171661377
knn(11) = 0.5122500061988831
knn(13) = 0.5147500038146973
knn(15) = 0.5172500014305115
knn(17) = 0.6112499833106995
knn(19) = 0.5917500257492065
knn(21) = 0.5214999914169312
knn(23) = 0.5217499732971191
knn(25) = 0.5230000019073486
knn(27) = 0.5295000076293945

LOOK at fr
knn(1) = 0.47850000858306885
knn(3) = 0.48750001192092896
knn(5) = 0.49950000643730164
knn(7) = 0.5007500052452087
knn(9) = 0.5005000233650208
knn(11) = 0.49950000643730164
knn(13) = 0.5019999742507935
knn(15) = 0.503000020980835
knn(17) = 0.5044999718666077
knn(19) = 0.5082499980926514
knn(21) = 0.5082499980926514
knn(23) = 0.5047500133514404
knn(25) = 0.5027499794960022
knn(27) = 0.5027499794960022

LOOK at de
knn(1) = 0.4987500011920929
knn(3) = 0.5049999952316284
knn(5) = 0.5199999809265137
knn(7) = 0.5017499923706055
knn(9) = 0.5027499794960022
knn(11) = 0.495249986

100%|██████████| 125/125 [00:22<00:00,  5.50it/s]
100%|██████████| 125/125 [00:22<00:00,  5.48it/s]
100%|██████████| 125/125 [00:22<00:00,  5.50it/s]
100%|██████████| 125/125 [00:22<00:00,  5.48it/s]
100%|██████████| 125/125 [00:22<00:00,  5.49it/s]
100%|██████████| 125/125 [00:22<00:00,  5.50it/s]
100%|██████████| 125/125 [00:22<00:00,  5.49it/s]
100%|██████████| 125/125 [00:22<00:00,  5.49it/s]
100%|██████████| 125/125 [00:22<00:00,  5.49it/s]
100%|██████████| 125/125 [00:22<00:00,  5.49it/s]
100%|██████████| 125/125 [00:22<00:00,  5.49it/s]
100%|██████████| 125/125 [00:22<00:00,  5.50it/s]
100%|██████████| 125/125 [00:23<00:00,  5.42it/s]
100%|██████████| 125/125 [00:22<00:00,  5.49it/s]
100%|██████████| 125/125 [00:22<00:00,  5.48it/s]
100%|██████████| 125/125 [00:22<00:00,  5.47it/s]
100%|██████████| 125/125 [00:22<00:00,  5.48it/s]
100%|██████████| 125/125 [00:22<00:00,  5.48it/s]
100%|██████████| 125/125 [00:22<00:00,  5.48it/s]
100%|██████████| 125/125 [00:22<00:00,  5.48it/s]


In [None]:
#!g1.1
