In [2]:
import random
from pathlib import Path
import numpy as np
import torch.nn.functional as F
import pytorch_lightning as pl
from torch import nn

import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
from transformers import RobertaTokenizer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint


from functools import partial

from aux_relative_text.multilingual_amazon_anchors import MultilingualAmazonAnchors
from typing import *

from modules.relAttention import RelativeAttention

from datasets import load_dataset, ClassLabel

# Tensorboard extension (for visualization purposes later)
%load_ext tensorboard

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = Path("./data")
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = Path("./saved_models/rel_multi_vanilla")

PROJECT_ROOT = Path("./")

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device=torch.device("cpu")
print("Device:", device)

Global seed set to 42


Device: cpu


# Data

In [3]:
# Label granularity: True -> 5 star classes; False -> binary labels (stars > 3), 3-star reviews dropped
fine_grained: bool = True

# Column names holding the label and the cleaned review text
target_key: str = "class"
data_key: str = "content"

# Languages handled throughout the notebook
ALL_LANGS = ("en", "es", "fr")

# Anchor configuration (the dataset name also appears in the results file name)
anchor_dataset_name: str = "amazon_translated"
num_anchors: int = 768

# Fraction of each language's training split that is kept
train_perc: float = 0.25

In [4]:
def clean_amazon_sample(sample: dict, fine_grained: bool, label_key: str) -> dict:
    """Build the text and label fields of one Amazon review sample, in place.

    The review title and body are stripped of surrounding quotes, periods and
    whitespace and joined into ``sample["content"]``. A title that is just the
    beginning of the body is dropped.

    The label is written to ``sample[label_key]`` as an integer class index:
    ``stars - 1`` (0..4) when ``fine_grained``, otherwise ``1 if stars > 3 else 0``.

    Returns the same (mutated) ``sample`` dict, as ``Dataset.map`` expects.
    """
    title: str = sample["review_title"].strip('"').strip(".").strip()
    body: str = sample["review_body"].strip('"').strip(".").strip()

    # Avoid duplicated text when the title merely repeats the start of the body
    if body.lower().startswith(title.lower()):
        title = ""

    # Terminate the title with a period so it reads as its own sentence
    if len(title) > 0 and title[-1].isalpha():
        title = f"{title}."

    sample["content"] = f"{title} {body}".lstrip(".").strip()
    # Integer indices line up with the ClassLabel names assigned in `get_dataset`
    # (index i <-> names[i]). Strings such as "0".."4" would instead be looked up
    # by *name* first, so "1" would map to index 0 and collide with "0".
    sample[label_key] = sample["stars"] - 1 if fine_grained else int(sample["stars"] > 3)
    return sample


def get_dataset(lang: str, split: str, perc: float, fine_grained: bool):
    """Load, subsample and preprocess one language split of `amazon_reviews_multi`.

    Args:
        lang: dataset configuration / language code (e.g. "en").
        split: split name (e.g. "train", "test").
        perc: fraction of samples to keep, in (0, 1]. The subset is drawn right
            after seeding all RNGs with 42, so it is reproducible.
        fine_grained: if True, keep the 5 star ratings as classes; otherwise drop
            3-star reviews and use binary labels (stars > 3).

    Returns:
        A `datasets.Dataset` with an added "content" text column and a
        `ClassLabel` column named `target_key`.
    """
    pl.seed_everything(42)
    assert 0 < perc <= 1
    dataset = load_dataset("amazon_reviews_multi", lang)[split]

    if not fine_grained:
        # The binary label is stars > 3, so 3-star reviews would otherwise count as negative
        dataset = dataset.filter(lambda sample: sample["stars"] != 3)

    # Select a random subset (deterministic thanks to the seeding above)
    indices = list(range(len(dataset)))
    random.shuffle(indices)
    indices = indices[: int(len(indices) * perc)]
    dataset = dataset.select(indices)

    def clean_sample(sample):
        return clean_amazon_sample(sample, fine_grained=fine_grained, label_key=target_key)

    dataset = dataset.map(clean_sample)
    # Labels are already integer indices; the cast only attaches the class names
    label_names = [str(stars) for stars in range(1, 6)] if fine_grained else ["0", "1"]
    dataset = dataset.cast_column(
        target_key,
        ClassLabel(num_classes=len(label_names), names=label_names),
    )

    return dataset

def _amazon_translated_get_samples(lang: str, sample_idxs):
    """Fetch the anchor samples at `sample_idxs` from the translated Amazon anchor set.

    Args:
        lang: language of the MultilingualAmazonAnchors "train" split to read.
        sample_idxs: indices into that split, in the desired output order.

    Returns:
        A list with one record per index. Each record is the object returned by the
        anchor dataset, mutated in place so its "data" text is also available
        under ``data_key``.
    """
    source = MultilingualAmazonAnchors(split="train", language=lang)

    def _expose_text(record):
        # collate_fn reads the text from `data_key`, while the anchor set stores it as "data"
        record[data_key] = record["data"]
        return record

    return [_expose_text(source[idx]) for idx in sample_idxs]

In [5]:
# Per-language splits: a `train_perc` fraction of each training split and the full test split.
train_datasets = dict(
    (lang, get_dataset(lang=lang, split="train", perc=train_perc, fine_grained=fine_grained))
    for lang in ALL_LANGS
)
test_datasets = dict(
    (lang, get_dataset(lang=lang, split="test", perc=1, fine_grained=fine_grained))
    for lang in ALL_LANGS
)

# get_dataset applies the same ClassLabel cast to every language, so the first one gives the count
num_labels = train_datasets[ALL_LANGS[0]].features[target_key].num_classes

Global seed set to 42
Found cached dataset amazon_reviews_multi (C:/Users/alexg/.cache/huggingface/datasets/amazon_reviews_multi/en/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at C:\Users\alexg\.cache\huggingface\datasets\amazon_reviews_multi\en\1.0.0\724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609\cache-3476f9b441626d5a.arrow
Loading cached processed dataset at C:\Users\alexg\.cache\huggingface\datasets\amazon_reviews_multi\en\1.0.0\724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609\cache-d574878ad23ad156.arrow
Global seed set to 42
Found cached dataset amazon_reviews_multi (C:/Users/alexg/.cache/huggingface/datasets/amazon_reviews_multi/es/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at C:\Users\alexg\.cache\huggingface\datasets\amazon_reviews_multi\es\1.0.0\724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609\cache-f863bc7387640b81.arrow
Loading cached processed dataset at C:\Users\alexg\.cache\huggingface\datasets\amazon_reviews_multi\es\1.0.0\724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609\cache-ee0b7780b329a439.arrow
Global seed set to 42
Found cached dataset amazon_reviews_multi (C:/Users/alexg/.cache/huggingface/datasets/amazon_reviews_multi/fr/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at C:\Users\alexg\.cache\huggingface\datasets\amazon_reviews_multi\fr\1.0.0\724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609\cache-e1bd444b80bb0045.arrow
Loading cached processed dataset at C:\Users\alexg\.cache\huggingface\datasets\amazon_reviews_multi\fr\1.0.0\724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609\cache-c79029d9a5b26980.arrow
Global seed set to 42
Found cached dataset amazon_reviews_multi (C:/Users/alexg/.cache/huggingface/datasets/amazon_reviews_multi/en/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at C:\Users\alexg\.cache\huggingface\datasets\amazon_reviews_multi\en\1.0.0\724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609\cache-fd9b2139e905e2ea.arrow
Loading cached processed dataset at C:\Users\alexg\.cache\huggingface\datasets\amazon_reviews_multi\en\1.0.0\724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609\cache-73acd18e8f4cf9e6.arrow
Global seed set to 42
Found cached dataset amazon_reviews_multi (C:/Users/alexg/.cache/huggingface/datasets/amazon_reviews_multi/es/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at C:\Users\alexg\.cache\huggingface\datasets\amazon_reviews_multi\es\1.0.0\724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609\cache-5feabbc631f1f845.arrow
Loading cached processed dataset at C:\Users\alexg\.cache\huggingface\datasets\amazon_reviews_multi\es\1.0.0\724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609\cache-efd031a678918e43.arrow
Global seed set to 42
Found cached dataset amazon_reviews_multi (C:/Users/alexg/.cache/huggingface/datasets/amazon_reviews_multi/fr/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at C:\Users\alexg\.cache\huggingface\datasets\amazon_reviews_multi\fr\1.0.0\724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609\cache-ee48d5caa215e866.arrow
Loading cached processed dataset at C:\Users\alexg\.cache\huggingface\datasets\amazon_reviews_multi\fr\1.0.0\724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609\cache-c678b00bf840bf9e.arrow


In [6]:
# Every language must expose exactly the same columns.
assert len({frozenset(ds.features) for ds in train_datasets.values()}) == 1

# Callable mapping a class name (string) to its integer index.
class2idx = train_datasets["en"].features[target_key].str2int

train_datasets["en"].features

{'review_id': Value(dtype='string', id=None),
 'product_id': Value(dtype='string', id=None),
 'reviewer_id': Value(dtype='string', id=None),
 'stars': Value(dtype='int32', id=None),
 'review_body': Value(dtype='string', id=None),
 'review_title': Value(dtype='string', id=None),
 'language': Value(dtype='string', id=None),
 'product_category': Value(dtype='string', id=None),
 'content': Value(dtype='string', id=None),
 'class': ClassLabel(names=['1', '2', '3', '4', '5'], id=None)}

Get parallel anchors (the same anchor indices are selected in every language)

In [7]:
# Size of the translated anchor pool, and the first 20 indices the seeded shuffle
# below is expected to produce (a reproducibility guard).
anchor_dataset2num_samples = 1000
anchor_dataset2first_anchors = [
    776, 507, 895, 922, 33, 483, 85, 750, 354, 523,
    184, 809, 418, 615, 682, 501, 760, 49, 732, 336,
]

assert num_anchors <= anchor_dataset2num_samples

# Draw `num_anchors` distinct indices; the same indices are then used for every language.
pl.seed_everything(42)
anchor_idxs = list(range(anchor_dataset2num_samples))
random.shuffle(anchor_idxs)
anchor_idxs = anchor_idxs[:num_anchors]

assert anchor_idxs[:20] == anchor_dataset2first_anchors  # better safe than sorry

lang2anchors = {
    lang: _amazon_translated_get_samples(lang=lang, sample_idxs=anchor_idxs)
    for lang in ALL_LANGS
}

Global seed set to 42


Collate function: each batch is tokenized on the fly, so every DataLoader can use its own language-specific tokenizer.

In [8]:
def collate_fn(batch, tokenizer, max_length: int = 512, text_key: Optional[str] = None):
    """Tokenize a batch of samples into padded model inputs.

    Args:
        batch: sequence of mapping-like samples holding the raw text.
        tokenizer: a Hugging Face tokenizer (or any callable with the same keyword API).
        max_length: inputs longer than this many tokens are truncated.
        text_key: key of the text field in each sample; defaults to the
            notebook-level ``data_key``.

    Returns:
        The tokenizer output: for HF tokenizers, a ``BatchEncoding`` of PyTorch
        tensors padded to the longest sample in the batch.
    """
    if text_key is None:
        text_key = data_key
    # No special-tokens mask is requested: it is not needed downstream.
    return tokenizer(
        [sample[text_key] for sample in batch],
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding=True,
    )

# Train

In [9]:
from pl_modules.pl_roberta import LitRelRoberta

In [10]:
# Pretrained RoBERTa checkpoint used for each language
lang2transformer_name = dict(
    en="roberta-base",
    es="PlanTL-GOB-ES/roberta-base-bne",
    fr="ClassCat/roberta-base-french",
    # ja="nlp-waseda/roberta-base-japanese",
)
assert lang2transformer_name.keys() == set(ALL_LANGS)

In [11]:
import multiprocessing

# Worker processes started with "spawn"/"forkserver" (the default on Windows and macOS)
# cannot unpickle `collate_fn`, because it lives in the notebook's __main__. Only use
# worker processes when they are forked; otherwise load batches in the main process.
NUM_WORKERS = 4 if multiprocessing.get_start_method() == "fork" else 0


def _make_loader(dataset, tokenizer, shuffle: bool) -> DataLoader:
    """DataLoader yielding batches of 32 samples tokenized on the fly with `tokenizer`."""
    return DataLoader(
        dataset,
        num_workers=NUM_WORKERS,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
        shuffle=shuffle,
        batch_size=32,
    )


train_lang2dataloader = {}
test_lang2dataloader = {}
anchors_lang2dataloader = {}

for lang in ALL_LANGS:
    transformer_name = lang2transformer_name[lang]
    print(transformer_name)
    lang_tokenizer = AutoTokenizer.from_pretrained(transformer_name)
    train_lang2dataloader[lang] = _make_loader(train_datasets[lang], lang_tokenizer, shuffle=True)
    test_lang2dataloader[lang] = _make_loader(test_datasets[lang], lang_tokenizer, shuffle=False)
    # Anchors keep their order so anchor i lines up across languages
    anchors_lang2dataloader[lang] = _make_loader(lang2anchors[lang], lang_tokenizer, shuffle=False)

roberta-base
PlanTL-GOB-ES/roberta-base-bne
ClassCat/roberta-base-french


In [11]:
# Pre-tokenize all English anchors as a single padded batch and move it to `device`.
en_tokenizer = AutoTokenizer.from_pretrained(lang2transformer_name["en"])
anchors = en_tokenizer(
    [anchor[data_key] for anchor in lang2anchors["en"]],
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt",
).to(device)

In [12]:
transformer_model = lang2transformer_name["en"]

# Build the English model from the pre-tokenized `anchors` (inspected in the next cell).
a = LitRelRoberta(
    num_labels=num_labels,
    transformer_model=transformer_model,
    anchors=anchors,
    hidden_size=768,
    similarity_mode="inner",
    normalization_mode="l2",
    output_normalization_mode=None,
    dropout_prob=0.1,
    seed=42,
    epochs=5,
    weight_decay=0.01,
    lr_init=1e-3,
    layer_decay=0.9,
    warmup_steps=500,
    device=device,
)

Global seed set to 42
Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'roberta.pooler.dense.weight', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [16]:
# Shape of the anchor latents at index 0 of dim 1 (presumably the first token of each anchor — TODO confirm)
a.net.anchors_latent[:, 0].shape

torch.Size([768, 768])

In [None]:
EPOCHS = 5


def train_network(lang, seed=24, val_loader=None):
    """Fine-tune a `LitRelRoberta` classifier on one language and evaluate it.

    Args:
        lang: language code; selects the pretrained transformer and the
            train/test/anchor dataloaders built earlier in the notebook.
        seed: forwarded to `LitRelRoberta`; also names the checkpoint folder.
        val_loader: optional validation DataLoader. The notebook has no validation
            split, so by default the model is trained without validation.

    Returns:
        ``(model, result)``. ``result["test"]`` holds the test metrics; when
        `val_loader` is given, ``result["val"]`` holds the validation metrics.
    """
    trainer = pl.Trainer(
        default_root_dir=CHECKPOINT_PATH / f"{lang}_seed{seed}",
        accelerator="gpu" if str(device).startswith("cuda") else "cpu",
        devices=1,
        max_epochs=EPOCHS,
        callbacks=[ModelCheckpoint(save_weights_only=True), LearningRateMonitor("step")],
    )
    trainer.logger._log_graph = True  # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # NOTE(review): the exploratory cell above builds LitRelRoberta with pre-tokenized
    # `anchors=` and an explicit `device=`; confirm `anchor_loader=` is still accepted.
    model = LitRelRoberta(
        num_labels=num_labels,
        transformer_model=lang2transformer_name[lang],
        anchor_loader=anchors_lang2dataloader[lang],
        hidden_size=768,
        similarity_mode="inner",
        normalization_mode="l2",
        output_normalization_mode=None,
        dropout_prob=0.1,
        seed=seed,
        epochs=EPOCHS,
        weight_decay=0.01,
        lr_init=1e-3,
        layer_decay=0.9,
        warmup_steps=500,
    )

    train_loader = train_lang2dataloader[lang]
    test_loader = test_lang2dataloader[lang]

    trainer.fit(model, train_loader, val_loader)

    # Evaluates the in-memory weights after the last epoch: ModelCheckpoint is created
    # without `monitor`, so there is no "best" checkpoint to reload.
    result = {}
    if val_loader is not None:
        result["val"] = trainer.test(model, val_loader, verbose=False)
    result["test"] = trainer.test(model, test_loader, verbose=False)

    return model, result
    

In [None]:
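# TODO: `latent_normalize` is read by the classifier-training and results cells below but is never defined; restore its definition (the draft value here was `True`).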

In [None]:
# Train one classifier per (seed, embedding type, training language) on precomputed latents.
# Result layout: train_classifiers[seed][embedding_type][train_lang] -> whatever `fit` returns
# (the results cell calls it as `classifier(test_latents)`).
# FIXME: `fit`, `lang2train_latents` and `latent_normalize` are not defined anywhere in this
# notebook (presumably from a removed cell), so this cell raises NameError on a fresh kernel.
SEEDS = list(range(5))
train_classifiers = {
    seed: {
        embedding_type: {
            train_lang: fit(
                lang2train_latents[train_lang][embedding_type],
                train_dataset[target_key],
                seed=seed,
                normalize=latent_normalize,
            )
            for train_lang, train_dataset in tqdm(train_datasets.items(), leave=False, desc="lang")
        }
        for embedding_type in tqdm(["absolute", "relative"], leave=False, desc="embedding_type")
    }
    for seed in tqdm(SEEDS, leave=False, desc="seed")
}

# Results

In [None]:
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support, mean_absolute_error

# Evaluate every trained classifier on every language's test set. A run is "stitched"
# when the classifier was trained on one language and tested on another.
# NOTE(review): `langt2test_latents` and `latent_normalize` are not defined in this notebook
# (the former looks like a typo of `lang2test_latents`); define them before running this cell.
numeric_results = {
    "seed": [],
    "embed_type": [],
    "train_lang": [],
    "test_lang": [],
    "precision": [],
    "recall": [],
    "fscore": [],
    "mae": [],
    "stitched": [],
}
for seed, embed_type2train_lang2classifier in train_classifiers.items():
    for embed_type, train_lang2classifier in embed_type2train_lang2classifier.items():
        for train_lang, classifier in train_lang2classifier.items():
            for test_lang, test_latents in langt2test_latents.items():
                test_latents = test_latents[embed_type]
                if latent_normalize:
                    # presumably mirrors `fit(..., normalize=latent_normalize)` on the train side
                    test_latents = F.normalize(test_latents, p=2, dim=-1)
                preds = classifier(test_latents)
                test_y = np.array(test_datasets[test_lang][target_key])

                precision, recall, fscore, _ = precision_recall_fscore_support(test_y, preds, average="weighted")
                mae = mean_absolute_error(y_true=test_y, y_pred=preds)
                numeric_results["embed_type"].append(embed_type)
                numeric_results["train_lang"].append(train_lang)
                numeric_results["test_lang"].append(test_lang)
                numeric_results["precision"].append(precision)
                numeric_results["recall"].append(recall)
                numeric_results["fscore"].append(fscore)
                numeric_results["stitched"].append(train_lang != test_lang)
                numeric_results["mae"].append(mae)
                numeric_results["seed"].append(seed)


pd.options.display.max_columns = None
pd.options.display.max_rows = None

# Keep the raw per-run table in `full_df` (it is saved to disk and filtered by the
# per-language breakdown at the end of the notebook); `df` is the summary shown here.
full_df = pd.DataFrame(numeric_results)
full_df.to_csv(
    f"nlp_multilingual-stitching-amazon-{'fine_grained' if fine_grained else 'coarse_grained'}-{anchor_dataset_name}-{train_perc}.tsv",
    sep="\t",
)

df = full_df.groupby(
    [
        "embed_type",
        "stitched",
        "train_lang",
        "test_lang",
    ]
).agg(["mean"])
df

In [None]:
# Name of the TSV written by the results cell for the current configuration
(
    f"nlp_multilingual-stitching-amazon-"
    f"{'fine_grained' if fine_grained else 'coarse_grained'}-{anchor_dataset_name}-{train_perc}.tsv"
)

In [None]:
import pandas as pd
import numpy as np

# Summarize the saved results. Reading the TSV (instead of re-grouping `df`, which the
# results cell already aggregated, so every "count" would be 1) keeps this cell
# re-runnable. Uncomment the overrides below to summarize another configuration.
# fine_grained: bool = False
# anchor_dataset_name: str = "amazon_translated" # wikimatrix, amazon_translated
# train_perc: float = 0.25

full_df = pd.read_csv(
    f"nlp_multilingual-stitching-amazon-{'fine_grained' if fine_grained else 'coarse_grained'}-{anchor_dataset_name}-{train_perc}.tsv",
    sep="\t",
    index_col=0,
)

# Mean of every metric and the number of runs (one per seed) for each configuration
df = full_df.groupby(
    [
        "embed_type",
        "stitched",
        "train_lang",
        "test_lang",
    ]
).agg(["mean", "count"])
df

In [None]:
# F-score and MAE (mean and std over seeds) of classifiers trained on English, per test language.
# Works on the raw per-run table: in the aggregated `df`, "stitched" is an index level, not a column.
(
    full_df[full_df.train_lang == "en"]
    .drop(columns=["stitched", "seed", "precision", "recall"])
    .groupby(["embed_type", "train_lang", "test_lang"])
    .agg(["mean", "std"])
    .round(3)
)