In [2]:
%load_ext autoreload
%autoreload 2

In [4]:
import random
from pathlib import Path
import numpy as np
import torch.nn.functional as F
from pytorch_lightning import seed_everything
from torch import nn

In [5]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

In [12]:
from attention import RelativeAttention, AttentionOutput

In [13]:
# --- Notebook configuration -------------------------------------------------
# Use the GPU when one is available, otherwise fall back to the CPU so the notebook still runs.
device: str = "cuda" if torch.cuda.is_available() else "cpu"
# True: 5-way star-rating classification; False: binary task (see get_dataset).
fine_grained: bool = True
# Name of the dataset column holding the classification target.
target_key: str = "class"
# Name of the sample field holding the text fed to the tokenizer.
data_key: str = "content"
anchor_dataset_name: str = "amazon_translated"  # wikimatrix, amazon_translated
# Languages processed throughout the notebook.
ALL_LANGS = ("en", "es", "fr", "ja")
# Number of anchor samples used for the relative representation.
num_anchors: int = 768
# Fraction of each training split that is kept.
train_perc: float = 0.25

In [14]:
from datasets import load_dataset, ClassLabel


def get_dataset(lang: str, split: str, perc: float, fine_grained: bool):
    """Load a seeded random subset of one configuration of ``amazon_reviews_multi``.

    Every sample gets two extra columns: the cleaned review text (title + body) under ``data_key``
    and the classification target under ``target_key``, which is finally cast to a ``ClassLabel``.

    Args:
        lang: configuration name passed to ``load_dataset`` (a language code such as "en").
        split: name of the split to read from the loaded dataset.
        perc: fraction of the split to keep, in ``(0, 1]``.
        fine_grained: if True the target is the star rating (5 classes named "1".."5"); otherwise
            3-star reviews are removed and the target is whether the review has more than 3 stars.

    Returns:
        The subsampled dataset with the additional text and target columns.

    Raises:
        AssertionError: if ``perc`` is not in ``(0, 1]``.
    """
    seed_everything(42)
    assert 0 < perc <= 1
    dataset = load_dataset("amazon_reviews_multi", lang)[split]

    if not fine_grained:
        dataset = dataset.filter(lambda sample: sample["stars"] != 3)

    # Select a random subset (the shuffle follows the seeding above, so it is repeatable)
    indices = list(range(len(dataset)))
    random.shuffle(indices)
    indices = indices[: int(len(indices) * perc)]
    dataset = dataset.select(indices)

    def clean_sample(sample):
        title: str = sample["review_title"].strip('"').strip(".").strip()
        body: str = sample["review_body"].strip('"').strip(".").strip()

        # Do not repeat the title when the body already starts with it
        if body.lower().startswith(title.lower()):
            title = ""

        # Terminate the title so that it reads as its own sentence
        if len(title) > 0 and title[-1].isalpha():
            title = f"{title}."

        sample[data_key] = f"{title} {body}".lstrip(".").strip()
        if fine_grained:
            # The label string must be one of the ClassLabel names ("1".."5"). Emitting
            # ``stars - 1`` ("0".."4") is ambiguous: a string is matched against the names before
            # being parsed as an index, so "1" (two stars) would land on index 0 together with
            # "0" (one star) and the last class would never be used.
            sample[target_key] = str(sample["stars"])
        else:
            sample[target_key] = sample["stars"] > 3
        return sample

    dataset = dataset.map(clean_sample)
    class_names = [str(stars) for stars in range(1, 6)] if fine_grained else ["0", "1"]
    dataset = dataset.cast_column(
        target_key,
        ClassLabel(num_classes=len(class_names), names=class_names),
    )

    return dataset

In [15]:
# One training subset per language, each drawn with the same fraction and granularity.
train_datasets = dict(
    (lang, get_dataset(lang=lang, split="train", perc=train_perc, fine_grained=fine_grained))
    for lang in ALL_LANGS
)
# Show the resulting schema for one language.
train_datasets["en"].features

Global seed set to 42


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading metadata: 0.00B [00:00, ?B/s]

Downloading readme: 0.00B [00:00, ?B/s]

Downloading and preparing dataset amazon_reviews_multi/en to C:/Users/alexg/.cache/huggingface/datasets/amazon_reviews_multi/en/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/82.0M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/2.06M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/2.05M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/200000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5000 [00:00<?, ? examples/s]

Dataset amazon_reviews_multi downloaded and prepared to C:/Users/alexg/.cache/huggingface/datasets/amazon_reviews_multi/en/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/50000 [00:00<?, ? examples/s]

Global seed set to 42


Downloading and preparing dataset amazon_reviews_multi/es to C:/Users/alexg/.cache/huggingface/datasets/amazon_reviews_multi/es/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/77.5M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/1.93M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/1.94M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/200000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5000 [00:00<?, ? examples/s]

Dataset amazon_reviews_multi downloaded and prepared to C:/Users/alexg/.cache/huggingface/datasets/amazon_reviews_multi/es/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/50000 [00:00<?, ? examples/s]

Global seed set to 42


Downloading and preparing dataset amazon_reviews_multi/fr to C:/Users/alexg/.cache/huggingface/datasets/amazon_reviews_multi/fr/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/81.9M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/2.02M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/2.04M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/200000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5000 [00:00<?, ? examples/s]

Dataset amazon_reviews_multi downloaded and prepared to C:/Users/alexg/.cache/huggingface/datasets/amazon_reviews_multi/fr/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/50000 [00:00<?, ? examples/s]

Global seed set to 42


Downloading and preparing dataset amazon_reviews_multi/ja to C:/Users/alexg/.cache/huggingface/datasets/amazon_reviews_multi/ja/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/169M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/4.19M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/4.21M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/200000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5000 [00:00<?, ? examples/s]

Dataset amazon_reviews_multi downloaded and prepared to C:/Users/alexg/.cache/huggingface/datasets/amazon_reviews_multi/ja/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/50000 [00:00<?, ? examples/s]

{'review_id': Value(dtype='string', id=None),
 'product_id': Value(dtype='string', id=None),
 'reviewer_id': Value(dtype='string', id=None),
 'stars': Value(dtype='int32', id=None),
 'review_body': Value(dtype='string', id=None),
 'review_title': Value(dtype='string', id=None),
 'language': Value(dtype='string', id=None),
 'product_category': Value(dtype='string', id=None),
 'content': Value(dtype='string', id=None),
 'class': ClassLabel(names=['1', '2', '3', '4', '5'], id=None)}

In [16]:
# All languages must expose exactly the same set of feature names.
feature_name_sets = {frozenset(train_dataset.features.keys()) for train_dataset in train_datasets.values()}
assert len(feature_name_sets) == 1
# Bound ``str2int`` of the target feature, taken from the English dataset.
target_feature = train_datasets["en"].features[target_key]
class2idx = target_feature.str2int
train_datasets["en"].features[target_key], class2idx

(ClassLabel(names=['1', '2', '3', '4', '5'], id=None),
 <bound method ClassLabel.str2int of ClassLabel(names=['1', '2', '3', '4', '5'], id=None)>)

In [19]:
def load_transformer(transformer_name):
    """Load a pretrained encoder and its tokenizer.

    The encoder is configured to return hidden states as a dict-like output, has gradients
    disabled on all of its parameters and is put in eval mode.

    Returns:
        ``(transformer, tokenizer)``
    """
    encoder = AutoModel.from_pretrained(transformer_name, output_hidden_states=True, return_dict=True)
    encoder.requires_grad_(False)
    encoder.eval()
    tokenizer = AutoTokenizer.from_pretrained(transformer_name)
    return encoder, tokenizer

In [20]:
# Full test split (perc=1) for every language.
test_datasets = dict(
    (lang, get_dataset(lang=lang, split="test", perc=1, fine_grained=fine_grained)) for lang in ALL_LANGS
)

Global seed set to 42
Found cached dataset amazon_reviews_multi (C:/Users/alexg/.cache/huggingface/datasets/amazon_reviews_multi/en/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609)


  0%|          | 0/3 [00:00<?, ?it/s]

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/5000 [00:00<?, ? examples/s]

Global seed set to 42
Found cached dataset amazon_reviews_multi (C:/Users/alexg/.cache/huggingface/datasets/amazon_reviews_multi/es/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609)


  0%|          | 0/3 [00:00<?, ?it/s]

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/5000 [00:00<?, ? examples/s]

Global seed set to 42
Found cached dataset amazon_reviews_multi (C:/Users/alexg/.cache/huggingface/datasets/amazon_reviews_multi/fr/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609)


  0%|          | 0/3 [00:00<?, ?it/s]

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/5000 [00:00<?, ? examples/s]

Global seed set to 42
Found cached dataset amazon_reviews_multi (C:/Users/alexg/.cache/huggingface/datasets/amazon_reviews_multi/ja/1.0.0/724e94f4b0c6c405ce7e476a6c5ef4f87db30799ad49f765094cf9770e0f7609)


  0%|          | 0/3 [00:00<?, ?it/s]

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/5000 [00:00<?, ? examples/s]

In [21]:
@torch.no_grad()
def call_transformer(batch, transformer):
    """Run ``transformer`` on a collated batch and return one vector per sequence.

    ``batch["encoding"]`` is moved to ``device`` and unpacked as keyword arguments of the
    transformer; the vector returned for each sequence is the last hidden state at position 0
    (the CLS position).
    """
    model_inputs = batch["encoding"].to(device)
    last_hidden_state = transformer(**model_inputs)["hidden_states"][-1]
    # TODO: aggregation mode, e.g. the mean of the positions selected by batch["mask"]
    # instead of the first position only.
    first_position = 0  # CLS
    return last_hidden_state[:, first_position, :]

In [26]:
from multilingual_amazon_anchors import MultilingualAmazonAnchors
from typing import *

# Number of samples available in each anchor dataset.
anchor_dataset2num_samples = dict(wikimatrix=3338, amazon_translated=1000)

# First 20 shuffled anchor indices expected for each anchor dataset; compared against the freshly
# shuffled indices as a reproducibility check.
# fmt: off
anchor_dataset2first_anchors = dict(
    wikimatrix=[
        361, 2192, 1855, 1163, 1434, 3065, 1329, 2381, 2366, 466,
        1488, 3007, 1749, 2332, 2463, 2180, 1790, 3328, 2865, 1457,
    ],
    amazon_translated=[
        776, 507, 895, 922, 33, 483, 85, 750, 354, 523,
        184, 809, 418, 615, 682, 501, 760, 49, 732, 336,
    ],
)
# fmt: on


def _amazon_translated_get_samples(lang: str, sample_idxs) -> Sequence:
    """Fetch the anchor samples at ``sample_idxs`` from the translated Amazon anchors of ``lang``.

    Each returned sample additionally exposes its ``"data"`` field under ``data_key``, the field
    name the rest of the notebook reads text from.
    """
    anchor_dataset = MultilingualAmazonAnchors(split="train", language=lang)

    def _expose_text(sample):
        # Mirror the text under the key used by the collate function.
        sample[data_key] = sample["data"]
        return sample

    return [_expose_text(anchor_dataset[sample_idx]) for sample_idx in sample_idxs]




# Registry mapping an anchor dataset name to the function that fetches its samples.
# NOTE(review): only "amazon_translated" is registered here; selecting "wikimatrix" passes the
# checks below but then fails with a KeyError on the registry lookup.
anchor_dataset2sampling = dict(amazon_translated=_amazon_translated_get_samples)

num_available_anchors = anchor_dataset2num_samples[anchor_dataset_name]
assert num_anchors <= num_available_anchors

# Draw the anchor indices: seeded shuffle of all available indices, then keep the first ones.
seed_everything(42)
anchor_idxs = list(range(num_available_anchors))
random.shuffle(anchor_idxs)
anchor_idxs = anchor_idxs[:num_anchors]

# better safe than sorry: the shuffle must reproduce the recorded first anchors
expected_first_anchors = anchor_dataset2first_anchors[anchor_dataset_name]
assert anchor_idxs[:20] == expected_first_anchors

# The same anchor indices are used for every language.
sample_anchors = anchor_dataset2sampling[anchor_dataset_name]
lang2anchors: Mapping[str, Sequence] = dict(
    (lang, sample_anchors(lang=lang, sample_idxs=anchor_idxs)) for lang in ALL_LANGS
)

Global seed set to 42


In [None]:
# Pretrained encoder used for each language (one independent model per language).
lang2transformer_name = dict(
    en="roberta-base",
    es="PlanTL-GOB-ES/roberta-base-bne",
    fr="ClassCat/roberta-base-french",
    ja="nlp-waseda/roberta-base-japanese",
)
# Every configured language needs an encoder, and no encoder may be left unused.
assert set(lang2transformer_name) == set(ALL_LANGS)

In [None]:
# Module used to turn absolute latents into similarities with respect to the anchors.
num_target_classes = train_datasets["en"].features[target_key].num_classes
relative_projection = RelativeAttention(
    n_anchors=num_anchors,
    n_classes=num_target_classes,
    normalization_mode="l2",
    similarity_mode="inner",
    values_mode="similarities",
    output_normalization_mode=None,
)
relative_projection = relative_projection.to(device)

In [None]:
def collate_fn(batch, tokenizer):
    """Tokenize the ``data_key`` text of every sample in ``batch`` into one padded encoding.

    Sequences are truncated to 512 tokens and padded to the longest one in the batch.

    Returns:
        A dict with a single ``"encoding"`` entry holding the tokenizer output (PyTorch tensors)
        without its special-tokens mask.
    """
    texts = [sample[data_key] for sample in batch]
    encoding = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=512,
        return_special_tokens_mask=True,
        return_tensors="pt",
    )
    # The special-tokens mask would only be needed to build a per-token mask
    # (attention_mask * ~special_tokens_mask) for a masked aggregation, which is currently
    # disabled. It is removed because the encoding is later unpacked as model keyword arguments.
    del encoding["special_tokens_mask"]
    return dict(encoding=encoding)

In [None]:
def get_latents(dataloader, anchors, split: str, transformer) -> Dict[str, torch.Tensor]:
    """Encode every batch of ``dataloader`` and collect the resulting latents on the CPU.

    Args:
        dataloader: iterable of batches accepted by ``call_transformer``.
        anchors: anchor latents used by ``relative_projection.encode``, or ``None`` to skip the
            relative projection.
        split: label shown in the progress bar description.
        transformer: encoder; moved to ``device`` for the duration of the call and moved back to
            the CPU before returning, even when encoding fails.

    Returns:
        A dict with the concatenated ``"absolute"`` latents and the concatenated ``"relative"``
        latents. When ``anchors`` is ``None`` the ``"relative"`` entry is an empty list.
    """
    absolute_batches: List[torch.Tensor] = []
    relative_batches: List[torch.Tensor] = []

    transformer = transformer.to(device)
    try:
        for batch in tqdm(dataloader, desc=f"[{split}] Computing latents"):
            with torch.no_grad():
                batch_latents = call_transformer(batch=batch, transformer=transformer)
                absolute_batches.append(batch_latents.cpu())

                if anchors is not None:
                    batch_rel_latents = relative_projection.encode(x=batch_latents, anchors=anchors)[
                        AttentionOutput.SIMILARITIES
                    ]
                    relative_batches.append(batch_rel_latents.cpu())
    finally:
        # Release the accelerator memory held by the encoder, also if a batch raised.
        transformer.cpu()

    absolute_latents = torch.cat(absolute_batches, dim=0)
    # Every collected batch is already on the CPU, so no further transfer is needed. The empty
    # list is kept as the "no anchors" value to preserve the existing return contract.
    relative_latents = torch.cat(relative_batches, dim=0) if len(relative_batches) > 0 else relative_batches

    return {
        "absolute": absolute_latents,
        "relative": relative_latents,
    }

In [None]:
# Echo the anchor dataset selected in the configuration cell.
anchor_dataset_name

In [None]:
from rae import PROJECT_ROOT

# Directory where computed latents are cached, keyed by the configuration that produced them.
# NOTE(review): `num_anchors` is not part of this path, so latents cached with a different number
# of anchors would presumably be picked up again — confirm before changing `num_anchors`.
granularity_dirname = "fine_grained" if fine_grained else "coarse_grained"
LATENTS_DIR: Path = PROJECT_ROOT.joinpath(
    "data",
    "latents",
    "multilingual_amazon",
    str(train_perc),
    anchor_dataset_name,
    granularity_dirname,
)
LATENTS_DIR.mkdir(parents=True, exist_ok=True)
LATENTS_DIR

In [None]:
def load_latents(split: str, langs: Sequence[str]):
    """Load the cached latents of ``split`` for every language in ``langs`` that has a cache file.

    The cache file of a language is ``LATENTS_DIR / split / lang / <transformer name>.pt``, where
    ``/`` in the transformer name is replaced by ``-``. Languages without a cache file are absent
    from the returned dict.
    """

    def _cache_file(lang: str) -> Path:
        model_name = lang2transformer_name[lang]
        return LATENTS_DIR / split / lang / f"{model_name.replace('/', '-')}.pt"

    lang2cache_file = {lang: _cache_file(lang) for lang in langs}
    # torch.load unpickles the file: only point LATENTS_DIR at trusted data.
    return {lang: torch.load(cache_file) for lang, cache_file in lang2cache_file.items() if cache_file.exists()}

In [None]:
from functools import partial


def encode_latents(langs, lang2dataset, lang2latents, split: str):
    """Compute the latents of every language in ``langs`` and store them in ``lang2latents``.

    For each language the anchors are encoded first (absolute latents only); the dataset of that
    language is then encoded with those anchor latents, which yields absolute and relative latents.
    ``lang2latents[lang]`` is set to a dict with the keys ``"anchors_latents"``, ``"absolute"`` and
    ``"relative"``. When the module-level ``CACHE_LATENTS`` flag is truthy the entry is also saved
    under ``LATENTS_DIR``.

    Args:
        langs: languages to encode.
        lang2dataset: samples to encode for each language.
        lang2latents: dict updated in place with the computed latents.
        split: split name, used in progress descriptions and in the cache path.
    """

    def _make_loader(samples, tokenizer):
        # Both the anchors and the dataset are batched in exactly the same way.
        return DataLoader(
            samples,
            num_workers=4,
            pin_memory=True,
            collate_fn=partial(collate_fn, tokenizer=tokenizer),
            batch_size=32,
        )

    for lang in langs:
        transformer_name: str = lang2transformer_name[lang]
        lang_transformer, lang_tokenizer = load_transformer(transformer_name=transformer_name)

        anchors_latents = get_latents(
            dataloader=_make_loader(lang2anchors[lang], lang_tokenizer),
            split=f"{transformer_name}, anchor, {split}",
            anchors=None,
            transformer=lang_transformer,
        )["absolute"]
        dataset_latents = get_latents(
            dataloader=_make_loader(lang2dataset[lang], lang_tokenizer),
            split=f"{split}/{lang}",
            anchors=anchors_latents.to(device),
            transformer=lang_transformer,
        )
        lang2latents[lang] = {"anchors_latents": anchors_latents, **dataset_latents}

        # Save latents
        if CACHE_LATENTS:
            cache_file = LATENTS_DIR / split / lang / f"{transformer_name.replace('/', '-')}.pt"
            cache_file.parent.mkdir(exist_ok=True, parents=True)
            torch.save(lang2latents[lang], cache_file)

In [None]:
# Compute test latents: reuse what is cached on disk and encode only what is missing.

FORCE_RECOMPUTE: bool = False  # True: encode every language again, ignoring the cache
CACHE_LATENTS: bool = True  # read by encode_latents: save freshly computed latents to disk

langt2test_latents: Dict[str, Mapping[str, torch.Tensor]] = load_latents(split="test", langs=ALL_LANGS)
if FORCE_RECOMPUTE:
    missing_langs = ALL_LANGS
else:
    missing_langs = [lang for lang in ALL_LANGS if lang not in langt2test_latents]
encode_latents(langs=missing_langs, lang2dataset=test_datasets, lang2latents=langt2test_latents, split="test")

In [None]:
# Compute train latents: reuse what is cached on disk and encode only what is missing.

FORCE_RECOMPUTE: bool = False  # True: encode every language again, ignoring the cache
CACHE_LATENTS: bool = True  # read by encode_latents: save freshly computed latents to disk

lang2train_latents: Dict[str, Mapping[str, torch.Tensor]] = load_latents(split="train", langs=train_datasets.keys())
if FORCE_RECOMPUTE:
    missing_langs = train_datasets.keys()
else:
    missing_langs = [lang for lang in train_datasets.keys() if lang not in lang2train_latents]
encode_latents(langs=missing_langs, lang2dataset=train_datasets, lang2latents=lang2train_latents, split="train")

In [None]:
# Whether latents are L2-normalized before they are used to train and to evaluate the classifiers.
latent_normalize: bool = True

In [None]:
from torch.nn import CrossEntropyLoss
from torch.optim import Adam


# def fit(X, y, seed, **kwargs):
#     classifier = make_pipeline(
#         Normalizer(), StandardScaler(), SVC(gamma="auto", kernel="linear", max_iter=200, random_state=seed)
#     )  # , class_weight="balanced"))
#     classifier.fit(X, y)
#     return lambda x: classifier.predict(x)


class Lambda(nn.Module):
    """Module wrapper around an arbitrary callable, so that it can be used as a layer."""

    def __init__(self, func):
        """Store ``func``, the callable applied by ``forward``."""
        super().__init__()
        self.func = func

    def forward(self, x):
        """Return ``func(x)``."""
        return self.func(x)


def fit(X: torch.Tensor, y, seed, normalize: bool):
    """Train a small MLP classifier on latents and return its prediction function.

    Args:
        X: latents, one row per sample (2D tensor).
        y: class index of every row of ``X`` (anything accepted by ``torch.as_tensor``).
        seed: seed passed to ``seed_everything`` before the model and the loader are created.
        normalize: if True every row of ``X`` is L2-normalized first.

    Returns:
        A callable mapping a 2D latent tensor (on the CPU) to a CPU tensor with the predicted
        class index of every row.
    """
    seed_everything(seed)
    if normalize:
        X = F.normalize(X, p=2, dim=-1)
    dataset = TensorDataset(X, torch.as_tensor(y))
    loader = DataLoader(dataset, batch_size=32, pin_memory=True, shuffle=True, num_workers=4)

    # Size the input layers from the data instead of assuming that every latent has
    # `num_anchors` features: that holds for relative latents, but for absolute latents only when
    # the encoder width happens to equal `num_anchors`.
    in_features = X.shape[-1]
    hidden_features = num_anchors
    num_classes = next(iter(train_datasets.values())).features[target_key].num_classes

    # NOTE(review): the permute/InstanceNorm1d/permute blocks appear to normalize every feature
    # across the samples of the batch, so predictions would depend on the other rows passed
    # together with a sample — confirm this is intended.
    # NOTE(review): the final ReLU clamps the logits fed to the cross-entropy loss to be
    # non-negative; kept as is to preserve the trained behavior.
    model = nn.Sequential(
        nn.LayerNorm(normalized_shape=in_features),
        nn.Linear(in_features=in_features, out_features=hidden_features),
        nn.SiLU(),
        Lambda(lambda x: x.permute(1, 0)),
        nn.InstanceNorm1d(num_features=hidden_features),
        Lambda(lambda x: x.permute(1, 0)),
        nn.Linear(in_features=hidden_features, out_features=hidden_features),
        nn.SiLU(),
        Lambda(lambda x: x.permute(1, 0)),
        nn.InstanceNorm1d(num_features=hidden_features),
        Lambda(lambda x: x.permute(1, 0)),
        nn.Linear(in_features=hidden_features, out_features=num_classes),
        nn.ReLU(),
    ).to(device)
    opt = Adam(model.parameters(), lr=1e-3)
    loss_fn = CrossEntropyLoss()
    for _ in tqdm(range(5 if fine_grained else 3), leave=False, desc="epoch"):
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            pred_y = model(batch_x)
            loss = loss_fn(pred_y, batch_y)
            loss.backward()
            opt.step()
            opt.zero_grad()
    model = model.cpu().eval()

    def predict(x):
        # Inference only: do not record an autograd graph for the whole evaluation set.
        with torch.no_grad():
            return model(x).argmax(-1).detach().cpu()

    return predict

In [None]:
SEEDS = list(range(5))


def _fit_all_for_seed(seed):
    """Train one classifier per embedding type and training language with the given seed."""
    embedding_type2classifiers = {}
    for embedding_type in tqdm(["absolute", "relative"], leave=False, desc="embedding_type"):
        lang2classifier = {}
        for train_lang, train_dataset in tqdm(train_datasets.items(), leave=False, desc="lang"):
            lang2classifier[train_lang] = fit(
                lang2train_latents[train_lang][embedding_type],
                train_dataset[target_key],
                seed=seed,
                normalize=latent_normalize,
            )
        embedding_type2classifiers[embedding_type] = lang2classifier
    return embedding_type2classifiers


# seed -> embedding type -> training language -> prediction function
train_classifiers = {seed: _fit_all_for_seed(seed) for seed in tqdm(SEEDS, leave=False, desc="seed")}

In [None]:
# Display the nested mapping of trained classifiers (seed -> embedding type -> training language).
train_classifiers

In [None]:
from sklearn.metrics import precision_recall_fscore_support, mean_absolute_error

import pandas as pd

# The ground-truth labels depend only on the test language: materialize them once instead of
# once per (seed, embedding type, training language) combination.
lang2test_labels = {test_lang: np.array(test_datasets[test_lang][target_key]) for test_lang in langt2test_latents}

# One entry per (seed, embedding type, training language, test language) evaluation.
numeric_results = {
    "seed": [],
    "embed_type": [],
    "train_lang": [],
    "test_lang": [],
    "precision": [],
    "recall": [],
    "fscore": [],
    "mae": [],
    "stitched": [],
}
for seed, embed_type2train_lang2classifier in train_classifiers.items():
    for embed_type, train_lang2classifier in embed_type2train_lang2classifier.items():
        for train_lang, classifier in train_lang2classifier.items():
            for test_lang, test_lang_latents in langt2test_latents.items():
                test_latents = test_lang_latents[embed_type]
                if latent_normalize:
                    test_latents = F.normalize(test_latents, p=2, dim=-1)
                preds = classifier(test_latents)
                test_y = lang2test_labels[test_lang]

                precision, recall, fscore, _ = precision_recall_fscore_support(test_y, preds, average="weighted")
                mae = mean_absolute_error(y_true=test_y, y_pred=preds)
                numeric_results["embed_type"].append(embed_type)
                numeric_results["train_lang"].append(train_lang)
                numeric_results["test_lang"].append(test_lang)
                numeric_results["precision"].append(precision)
                numeric_results["recall"].append(recall)
                numeric_results["fscore"].append(fscore)
                # "stitched": the classifier is evaluated on a language it was not trained on
                numeric_results["stitched"].append(train_lang != test_lang)
                numeric_results["mae"].append(mae)
                numeric_results["seed"].append(seed)


pd.options.display.max_columns = None
pd.options.display.max_rows = None

# Raw per-seed results. They are kept under their own name so that the aggregation below does not
# overwrite them and this cell can be re-run safely.
full_df = pd.DataFrame(numeric_results)
full_df.to_csv(
    f"nlp_multilingual-stitching-amazon-{'fine_grained' if fine_grained else 'coarse_grained'}-{anchor_dataset_name}-{train_perc}.tsv",
    sep="\t",
)

# Average over the seeds. The "mean" alias is used because passing NumPy callables such as
# np.mean to agg is deprecated in recent pandas versions.
df = full_df.groupby(
    [
        "embed_type",
        "stitched",
        "train_lang",
        "test_lang",
    ]
).agg(["mean"])
df

In [None]:
# Display the name of the TSV file that holds the results of the current configuration.
f"nlp_multilingual-stitching-amazon-{'fine_grained' if fine_grained else 'coarse_grained'}-{anchor_dataset_name}-{train_perc}.tsv"

In [None]:
import pandas as pd
import numpy as np

# To analyse a saved run from a fresh kernel, set the configuration of that run first, e.g.:
# fine_grained: bool = False
# anchor_dataset_name: str = "amazon_translated" # wikimatrix, amazon_translated
# train_perc: float = 0.25

# Reload the raw per-seed results from disk. Aggregating the already aggregated `df` again would
# report a count of 1 for every group and would nest the column levels further on every re-run.
full_df = pd.read_csv(
    f"nlp_multilingual-stitching-amazon-{'fine_grained' if fine_grained else 'coarse_grained'}-{anchor_dataset_name}-{train_perc}.tsv",
    sep="\t",
    index_col=0,
)

# Mean over the seeds, together with the number of runs that entered each mean.
df = full_df.groupby(
    [
        "embed_type",
        "stitched",
        "train_lang",
        "test_lang",
    ]
).agg(["mean", "count"])
df

In [None]:
# Mean and standard deviation over the seeds of fscore and mae for the classifiers trained on
# English. The statistics must be computed on the raw per-seed results, so they are read back from
# the TSV: the aggregated `df` no longer has the dropped columns nor a row per seed.
full_df = pd.read_csv(
    f"nlp_multilingual-stitching-amazon-{'fine_grained' if fine_grained else 'coarse_grained'}-{anchor_dataset_name}-{train_perc}.tsv",
    sep="\t",
    index_col=0,
)
(
    full_df[full_df.train_lang == "en"]
    .drop(columns=["stitched", "seed", "precision", "recall"])
    .groupby(["embed_type", "train_lang", "test_lang"])
    .agg(["mean", "std"])
    .round(3)
)

In [None]:
# it_dataset = get_samples(lang="it", sample_idxs=list(range(1000)))
# it_transformer_name: str = "dbmdz/bert-base-italian-cased"
# transformer, tokenizer = load_transformer(transformer_name=it_transformer_name)
# it_anchor_latents = get_latents(
#     dataloader=DataLoader(
#         get_samples("it", sample_idxs=anchor_idxs),
#         num_workers=16,
#         pin_memory=True,
#         collate_fn=partial(collate_fn, tokenizer=tokenizer),
#         batch_size=32,
#     ),
#     split=f"{it_transformer_name}",
#     anchors=None,
#     transformer=transformer,
# )
# it_latents = get_latents(
#     dataloader=DataLoader(
#         it_dataset,
#         num_workers=16,
#         pin_memory=True,
#         collate_fn=partial(collate_fn, tokenizer=tokenizer),
#         batch_size=32,
#     ),
#     split=f"{it_transformer_name}",
#     anchors=it_anchor_latents["absolute"].to(device),
#     transformer=transformer,
# )
# subsample_anchors = it_latents["relative"][:31, :]
# for i_sample, sample in enumerate(it_samples):
#     if sample["target"] == 3:
#         continue
#     for embed_type in ("relative", "absolute"):
#         latents = it_latents[embed_type]
#         latents = torch.cat([latents[i_sample, :].unsqueeze(0), subsample_anchors], dim=0)
#         classifier = train_classifiers[SEEDS[0]][embed_type]["en"]
#         print(
#             embed_type,
#             classifier(latents)[0].item(),
#             sample["class"],
#         )
#     print()
#     if i_sample > 100:
#         break