In [24]:
import random
from pathlib import Path
import numpy as np
import torch.nn.functional as F
import pytorch_lightning as pl
from torch import nn

import pandas as pd


import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
from transformers import RobertaTokenizer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from sklearn.metrics import precision_recall_fscore_support, mean_absolute_error

from functools import partial

from aux_relative_text.multilingual_amazon_anchors import MultilingualAmazonAnchors
from typing import *

from modules.relAttention import RelativeAttention
from modules.stitching_module import StitchingModule

from datasets import load_dataset, ClassLabel

# Tensorboard extension (for visualization purposes later)
%load_ext tensorboard

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = Path("./data")
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = Path("./saved_models/rel_multi_vanilla")
RESULT_PATH = Path("./results/rel_multi_vanilla")

PROJECT_ROOT = Path("./")

pd.options.display.max_columns = None
pd.options.display.max_rows = None

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Global seed set to 42


The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard
Device: cuda:0


# Data

In [2]:
fine_grained: bool = True
target_key: str = "class"
data_key: str = "content"
anchor_dataset_name: str = "amazon_translated"  
ALL_LANGS = ("en", "es", "fr")
num_anchors: int = 768
train_perc: float = 0.25

In [3]:
def get_dataset(lang: str, split: str, perc: float, fine_grained: bool):
    pl.seed_everything(42)
    assert 0 < perc <= 1
    dataset = load_dataset("amazon_reviews_multi", lang)[split]

    if not fine_grained:
        dataset = dataset.filter(lambda sample: sample["stars"] != 3)

    # Select a random subset
    indices = list(range(len(dataset)))
    random.shuffle(indices)
    indices = indices[: int(len(indices) * perc)]
    dataset = dataset.select(indices)

    def clean_sample(sample):
        title: str = sample["review_title"].strip('"').strip(".").strip()
        body: str = sample["review_body"].strip('"').strip(".").strip()

        if body.lower().startswith(title.lower()):
            title = ""

        if len(title) > 0 and title[-1].isalpha():
            title = f"{title}."

        sample["content"] = f"{title} {body}".lstrip(".").strip()
        if fine_grained:
            sample[target_key] = str(sample["stars"] - 1)
        else:
            sample[target_key] = sample["stars"] > 3
        return sample

    dataset = dataset.map(clean_sample)
    dataset = dataset.cast_column(
        target_key,
        ClassLabel(num_classes=5 if fine_grained else 2, names=list(map(str, range(1, 6) if fine_grained else (0, 1)))),
    )

    return dataset

def _amazon_translated_get_samples(lang: str, sample_idxs):
    anchor_dataset = MultilingualAmazonAnchors(split="train", language=lang)
    anchors = []
    for anchor_idx in sample_idxs:
        anchor = anchor_dataset[anchor_idx]
        anchor[data_key] = anchor["data"]
        anchors.append(anchor)
    return anchors

In [4]:
train_datasets = {
    lang: get_dataset(lang=lang, split="train", perc=train_perc, fine_grained=fine_grained) for lang in ALL_LANGS
    }

test_datasets = {
    lang: get_dataset(lang=lang, split="test", perc=1, fine_grained=fine_grained) for lang in ALL_LANGS
    }

val_datasets = {
    lang: get_dataset(lang=lang, split="validation", perc=1, fine_grained=fine_grained) for lang in ALL_LANGS
    }

num_labels = list(train_datasets.values())[0].features[target_key].num_classes

Global seed set to 42
Found cached dataset amazon_reviews_multi (/home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/en/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/en/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609/cache-ec0ea0aad8f98192.arrow
Loading cached processed dataset at /home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/en/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609/cache-115fb520e0899335.arrow
Global seed set to 42
Found cached dataset amazon_reviews_multi (/home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/es/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/es/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609/cache-5124f3d24b8cfecb.arrow
Loading cached processed dataset at /home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/es/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609/cache-d9c8d8f2f813d97f.arrow
Global seed set to 42
Found cached dataset amazon_reviews_multi (/home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/fr/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/fr/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609/cache-0985b4a32f5feef9.arrow
Loading cached processed dataset at /home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/fr/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609/cache-39a640cabb5a59c4.arrow
Global seed set to 42
Found cached dataset amazon_reviews_multi (/home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/en/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/en/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609/cache-8553f71d56c9ba4c.arrow
Loading cached processed dataset at /home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/en/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609/cache-2c8384112752703f.arrow
Global seed set to 42
Found cached dataset amazon_reviews_multi (/home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/es/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/es/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609/cache-e3dc951c42308c5b.arrow
Loading cached processed dataset at /home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/es/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609/cache-bf13bb2c70209559.arrow
Global seed set to 42
Found cached dataset amazon_reviews_multi (/home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/fr/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/fr/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609/cache-b471ec6ce2ee1b83.arrow
Loading cached processed dataset at /home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/fr/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609/cache-6f62c30e3bb3f98f.arrow
Global seed set to 42
Found cached dataset amazon_reviews_multi (/home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/en/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/en/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609/cache-e7ff0dc70b32da22.arrow
Loading cached processed dataset at /home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/en/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609/cache-586950f56ae31790.arrow
Global seed set to 42
Found cached dataset amazon_reviews_multi (/home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/es/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/es/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609/cache-fe159db7bc22043d.arrow
Loading cached processed dataset at /home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/es/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609/cache-0c3248cebcb837fb.arrow
Global seed set to 42
Found cached dataset amazon_reviews_multi (/home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/fr/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/fr/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609/cache-c271de877e412081.arrow
Loading cached processed dataset at /home/thepopi300/.cache/huggingface/datasets/amazon_reviews_multi/fr/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609/cache-9106148a706e9181.arrow


In [5]:
train_datasets["es"][5]

{'review_id': 'es_0291786',
 'product_id': 'product_es_0674372',
 'reviewer_id': 'reviewer_es_0553268',
 'stars': 1,
 'review_body': 'Pinzas de malísima calidad. La mayoría vienen astilladas. Resultan hasta peligrosas. A pesar de ser un producto plus, su precio es más caro que las que venden en la calle y de peor calidad. No esperaba esta calidad de un producto vendido por amazon. Totalmente decepcionada.',
 'review_title': 'Malísima calidad',
 'language': 'es',
 'product_category': 'home',
 'content': 'Malísima calidad. Pinzas de malísima calidad. La mayoría vienen astilladas. Resultan hasta peligrosas. A pesar de ser un producto plus, su precio es más caro que las que venden en la calle y de peor calidad. No esperaba esta calidad de un producto vendido por amazon. Totalmente decepcionada',
 'class': 0}

In [6]:
assert len(set(frozenset(train_dataset.features.keys()) for train_dataset in train_datasets.values())) == 1
class2idx = train_datasets["en"].features[target_key].str2int

train_datasets["en"].features

{'review_id': Value(dtype='string', id=None),
 'product_id': Value(dtype='string', id=None),
 'reviewer_id': Value(dtype='string', id=None),
 'stars': Value(dtype='int32', id=None),
 'review_body': Value(dtype='string', id=None),
 'review_title': Value(dtype='string', id=None),
 'language': Value(dtype='string', id=None),
 'product_category': Value(dtype='string', id=None),
 'content': Value(dtype='string', id=None),
 'class': ClassLabel(names=['1', '2', '3', '4', '5'], id=None)}

Get pararel anchors

In [7]:
anchor_dataset2num_samples = 1000
anchor_dataset2first_anchors = [
        776,
        507,
        895,
        922,
        33,
        483,
        85,
        750,
        354,
        523,
        184,
        809,
        418,
        615,
        682,
        501,
        760,
        49,
        732,
        336,
    ]


assert num_anchors <= anchor_dataset2num_samples

pl.seed_everything(42)
anchor_idxs = list(range(anchor_dataset2num_samples))
random.shuffle(anchor_idxs)
anchor_idxs = anchor_idxs[:num_anchors]

assert anchor_idxs[:20] == anchor_dataset2first_anchors  # better safe than sorry
lang2anchors = {
    lang: _amazon_translated_get_samples(lang=lang, sample_idxs=anchor_idxs) for lang in ALL_LANGS
}

Global seed set to 42


This is how we can handdle automatically the tokenizer

In [8]:
def collate_fn(batch, tokenizer, cls=True):
    data = []
    labels = []
    for sample in batch:
        data.append(sample[data_key])
        if cls:
            labels.append(sample[target_key])

    encoding = tokenizer(
        data,
        return_tensors="pt",
        return_special_tokens_mask=True,
        truncation=True,
        max_length=512,
        padding=True,
    )
    del encoding["special_tokens_mask"]

    if cls:
        result = (encoding, torch.tensor(labels))
    else:
        result = encoding

    return  result

# Train

In [9]:
from pl_modules.pl_roberta import LitRelRoberta

In [10]:
lang2transformer_name = {
    "en": "roberta-base",
    "es": "PlanTL-GOB-ES/roberta-base-bne",
    "fr": "ClassCat/roberta-base-french",
    #"ja": "nlp-waseda/roberta-base-japanese",
}
assert set(lang2transformer_name.keys()) == set(ALL_LANGS)

In [11]:
train_lang2dataloader = {}
test_lang2dataloader = {}
val_lang2dataloader = {}
anchors_lang2dataloader = {}

for lang in ALL_LANGS:
    transformer_name = lang2transformer_name[lang]
    print(transformer_name)
    lang_tokenizer = AutoTokenizer.from_pretrained(transformer_name)
    train_lang2dataloader[lang] = DataLoader(train_datasets[lang],
                                       num_workers=4,
                                       collate_fn=partial(collate_fn, tokenizer=lang_tokenizer),
                                       shuffle=True,
                                       pin_memory=True,
                                       drop_last=True,
                                       batch_size=16,
                                       )
    
    test_lang2dataloader[lang] = DataLoader(test_datasets[lang],
                                       num_workers=4,
                                       collate_fn=partial(collate_fn, tokenizer=lang_tokenizer),
                                       batch_size=32,
                                       )
    
    val_lang2dataloader[lang] = DataLoader(val_datasets[lang],
                                       num_workers=4,
                                       collate_fn=partial(collate_fn, tokenizer=lang_tokenizer),
                                       batch_size=32,
                                       )
    
    anchors_lang2dataloader[lang] = DataLoader(lang2anchors[lang],
                                       num_workers=4,
                                       pin_memory=True,
                                       collate_fn=partial(collate_fn, tokenizer=lang_tokenizer, cls=False),
                                       batch_size=48,
                                       )

roberta-base
PlanTL-GOB-ES/roberta-base-bne
ClassCat/roberta-base-french


In [24]:
EPOCHS = 5


def train_network(lang, mode="relative", seed=24, fine_tune=False):
    
    # Create a PyTorch Lightning trainer with the generation callback
    
    if fine_grained:
        title = CHECKPOINT_PATH / 'fine_grained' 
    else:
        title = CHECKPOINT_PATH / 'coarse_grained' 
    
    if fine_tune:
        title = title / f"finetune_{lang}_{mode}_seed{seed}"
    else:
        title = title / f"full_{lang}_{mode}_seed{seed}"
    
    trainer = pl.Trainer(default_root_dir=title, 
                         accelerator="gpu" if str(device).startswith("cuda") else "cpu",
                         devices=1,
                         accumulate_grad_batches=num_labels,
                         max_epochs=EPOCHS, 
                         callbacks=[ModelCheckpoint(save_weights_only=True),
                                    LearningRateMonitor(logging_interval='step')
                                    ])
    
    trainer.logger._log_graph = True         # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None # Optional logging argument that we don't need
    
    transformer_model = lang2transformer_name[lang]
    
    anchor_loader = None
    if mode == "relative":
        anchor_loader = anchors_lang2dataloader[lang]
    
    
    train_loader = train_lang2dataloader[lang]
    
    if fine_tune:
        freq_anchors = len(train_loader)
    else:
        freq_anchors = 100*num_labels
    
    model = LitRelRoberta(num_labels=num_labels,
                          transformer_model=transformer_model,
                          anchor_dataloader=anchor_loader,
                          hidden_size=num_anchors,
                          normalization_mode="batchnorm",
                          output_normalization_mode=None,
                          dropout_prob=0.1,
                          seed=seed,
                          steps=EPOCHS*len(train_loader),
                          weight_decay=0.01, 
                          head_lr=1e-3/num_labels,
                          encoder_lr=1.75e-4/num_labels,
                          layer_decay=0.65,
                          scheduler_act=True,
                          freq_anchors=freq_anchors,
                          device=device,
                          fine_tune=fine_tune
                          )
    
    val_loader = val_lang2dataloader[lang]
   
    trainer.fit(model, train_loader, val_loader)
        
    return model
    

In [25]:
train_network("es", mode="relative", seed=1, fine_tune=True)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Global seed set to 1
Some weights of the model checkpoint at PlanTL-GOB-ES/roberta-base-bne were not used when initializing RobertaModel: ['lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.bias', 'lm_head.decoder.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type             | Pa

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


LitRelRoberta(
  (net): RelRoberta(
    (encoder): RobertaModel(
      (embeddings): RobertaEmbeddings(
        (word_embeddings): Embedding(50262, 768, padding_idx=1)
        (position_embeddings): Embedding(514, 768, padding_idx=1)
        (token_type_embeddings): Embedding(1, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): RobertaEncoder(
        (layer): ModuleList(
          (0): RobertaLayer(
            (attention): RobertaAttention(
              (self): RobertaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): RobertaSelfOutput(
                (dense): Linear(in_features=768, out_feat

In [26]:
train_network("en", mode="relative", seed=1, fine_tune=True)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Global seed set to 1
Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'roberta.pooler.dense.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Typ

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


LitRelRoberta(
  (net): RelRoberta(
    (encoder): RobertaModel(
      (embeddings): RobertaEmbeddings(
        (word_embeddings): Embedding(50265, 768, padding_idx=1)
        (position_embeddings): Embedding(514, 768, padding_idx=1)
        (token_type_embeddings): Embedding(1, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): RobertaEncoder(
        (layer): ModuleList(
          (0): RobertaLayer(
            (attention): RobertaAttention(
              (self): RobertaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): RobertaSelfOutput(
                (dense): Linear(in_features=768, out_feat

In [None]:
SEEDS = list(range(5))

for seed in tqdm(SEEDS, leave=False, desc="seed"):
  for embedding_type in tqdm(["absolute", "relative"], leave=False, desc="embedding_type"):
    for train_lang in tqdm(ALL_LANGS, leave=False, desc="lang"):
      train_network(train_lang, mode=embedding_type, seed=seed)


# Results

In [12]:
def test_model(model, dataloader, title=""):
    preds = []
    model.to(device)
    model.eval()
    with torch.no_grad():
        batch_idx = 0
        for batch, _ in tqdm(dataloader, position=0, leave=True, desc="Computing"+title):
            batch.to(device)
            batch_latents = model(batch_idx=batch_idx, **batch)["prediction"].argmax(-1)
            preds.append(batch_latents)
            batch_idx = 1

    preds = torch.cat(preds, dim=0).cpu()
    test_y = np.array(test_datasets["en"][target_key])

    precision, recall, fscore, _ = precision_recall_fscore_support(test_y, preds, average="weighted")
    mae = mean_absolute_error(y_true=test_y, y_pred=preds)
    return precision, recall, fscore, mae

In [13]:
models = {
    train_mode: {
        seed: {
            embedding_type: {
                train_lang: LitRelRoberta.load_from_checkpoint(
                                              CHECKPOINT_PATH / 
                                              f"{'fine_grained' if fine_grained else 'coarse_grained'}/{train_mode}_{train_lang}_{embedding_type}_seed{seed}" /
                                              "lightning_logs/version_0/checkpoints/epoch=4-step=3125.ckpt" )
                
                for train_lang in ["en", "es"]
            }
            for embedding_type in ["absolute", "relative"]
        }
        for seed in [1]
    }
    for train_mode in tqdm(["finetune", "full"], leave=True, desc="mode")
}

mode:   0%|                                                                              | 0/2 [00:00<?, ?it/s]Global seed set to 1
Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Global seed set to 1
Some weights of the model checkpoint at PlanTL-GOB-ES/roberta-base-bne were not used

In [15]:
numeric_results = {
    "finetune": {
        "seed": [],
        "embed_type": [],
        "enc_lang": [],
        "dec_lang": [],
        "precision": [],
        "recall": [],
        "fscore": [],
        "mae": [],
        "stitched": []
    },
     "full": {
        "seed": [],
        "embed_type": [],
        "enc_lang": [],
        "dec_lang": [],
        "precision": [],
        "recall": [],
        "fscore": [],
        "mae": [],
        "stitched": []
    },
}

for mode in ["finetune", "full"]:
    for seed in [1]:
        for embed_type in ["absolute", "relative"]:
            for enc_lang  in ["en", "es"]:
                for dec_lang  in ["en", "es"]:
                    
                    model = models[mode][seed][embed_type][enc_lang].net
                    if embed_type == "relative":
                        model.anchor_dataloader = anchors_lang2dataloader[enc_lang]
                        
                    if enc_lang != dec_lang:
                        model_dec = models[mode][seed][embed_type][dec_lang].net
                        model = StitchingModule(model, model_dec)
                      
                        
                    # The data is paired with its encoder
                    test_loader = test_lang2dataloader[enc_lang]
                    title = f" {mode}_seed{seed}_{embed_type}_{enc_lang}_{dec_lang}"

                    precision, recall, fscore, mae = test_model(model, test_loader, title)
                    numeric_results[mode]["embed_type"].append(embed_type)
                    numeric_results[mode]["enc_lang"].append(enc_lang)
                    numeric_results[mode]["dec_lang"].append(dec_lang)
                    numeric_results[mode]["precision"].append(precision)
                    numeric_results[mode]["recall"].append(recall)
                    numeric_results[mode]["fscore"].append(fscore)
                    numeric_results[mode]["stitched"].append(enc_lang != dec_lang)
                    numeric_results[mode]["mae"].append(mae)
                    numeric_results[mode]["seed"].append(seed)


Computing finetune_seed1_absolute_en_en: 100%|███████████████████████████████| 157/157 [00:14<00:00, 10.84it/s]
Computing finetune_seed1_absolute_en_es: 100%|███████████████████████████████| 157/157 [00:14<00:00, 10.81it/s]
  _warn_prf(average, modifier, msg_start, len(result))
Computing finetune_seed1_absolute_es_en: 100%|███████████████████████████████| 157/157 [00:11<00:00, 13.74it/s]
  _warn_prf(average, modifier, msg_start, len(result))
Computing finetune_seed1_absolute_es_es: 100%|███████████████████████████████| 157/157 [00:11<00:00, 13.74it/s]
Computing finetune_seed1_relative_en_en: 100%|███████████████████████████████| 157/157 [00:17<00:00,  9.09it/s]
Computing finetune_seed1_relative_en_es: 100%|███████████████████████████████| 157/157 [00:17<00:00,  9.06it/s]
  _warn_prf(average, modifier, msg_start, len(result))
Computing finetune_seed1_relative_es_en: 100%|███████████████████████████████| 157/157 [00:14<00:00, 10.88it/s]
Computing finetune_seed1_relative_es_es: 100%|█████

In [26]:
df = pd.DataFrame(numeric_results["finetune"])


df = df.drop(columns=["stitched", "seed", "precision", "recall"]).groupby(
    ["embed_type", "enc_lang", "dec_lang"]
).agg([np.mean, np.std]).round(3)

df.to_csv(
    RESULT_PATH / f"nlp_multilingual-stitching-amazon-{'fine_grained' if fine_grained else 'coarse_grained'}-finetune-{train_perc}.tsv",
    sep="\t",
)

df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,fscore,fscore,mae,mae
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,mean,std,mean,std
embed_type,enc_lang,dec_lang,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
absolute,en,en,0.64,,0.408,
absolute,en,es,0.067,,1.803,
absolute,es,en,0.233,,1.196,
absolute,es,es,0.623,,0.43,
relative,en,en,0.625,,0.431,
relative,en,es,0.474,,0.881,
relative,es,en,0.527,,0.605,
relative,es,es,0.582,,0.477,


In [25]:
df = pd.DataFrame(numeric_results["full"])

df = df.drop(columns=["stitched", "seed", "precision", "recall"]).groupby(
    ["embed_type", "enc_lang", "dec_lang"]
).agg([np.mean, np.std]).round(3)

df.to_csv(
    RESULT_PATH / f"nlp_multilingual-stitching-amazon-{'fine_grained' if fine_grained else 'coarse_grained'}-full-{train_perc}.tsv",
    sep="\t",
)

df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,fscore,fscore,mae,mae
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,mean,std,mean,std
embed_type,enc_lang,dec_lang,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
absolute,en,en,0.73,,0.292,
absolute,en,es,0.179,,1.67,
absolute,es,en,0.335,,1.124,
absolute,es,es,0.679,,0.343,
relative,en,en,0.724,,0.296,
relative,en,es,0.683,,0.322,
relative,es,en,0.677,,0.346,
relative,es,es,0.686,,0.336,
