In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import faiss

import logging
from rich.logging import RichHandler

import torch
from torch import nn
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader

from sentence_transformers import SentenceTransformer
import datasets
from datasets import load_dataset, Dataset, DatasetDict

from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory
from accelerate import Accelerator

import bitsandbytes as bnb

from transformers import AutoTokenizer, MT5EncoderModel
from transformers import DataCollatorWithPadding
from transformers import get_scheduler

import gc
from functools import partial
from collections import Counter
from tqdm.notebook import tqdm_notebook

In [None]:
# Create the cache directory with the standard library instead of a shell
# escape, so the cell also works where `mkdir -p` is unavailable.
os.makedirs('data', exist_ok=True)

In [None]:
# Notebook-wide configuration.
SEED = 101            # seed passed to dataset shuffling
MAX_LENGTH = 128      # default maximum token length for tokenization / collation
DATA_DIR = './data'   # cached datasets and FAISS indexes are stored here
os.environ.update({'MKL_SERVICE_FORCE_INTEL': '1'})

In [None]:
# Route log records through rich (with rich tracebacks); the root logger is
# configured at ERROR level.
rich_handler = RichHandler(rich_tracebacks=True)
logging.basicConfig(
    level="ERROR",
    format="%(message)s",
    datefmt="[%X]",
    handlers=[rich_handler],
)
logger = logging.getLogger(__name__)

In [None]:
# Target language pairs, extended with the reverse directions 'de-en' and 'en-zh'.
target_lang_pairs = ['en-de', 'zh-en', 'he-en']
lang_pairs = [*target_lang_pairs, 'de-en', 'en-zh']

# Language code -> language name (used as a textual prefix during tokenization).
id2lang = dict(en='english', de='german', zh='chinese', he='hebrew')

# Score-type name -> integer id.
score_type2id = {name: idx for idx, name in enumerate(('da', 'mqm', 'sqm'), start=1)}

In [None]:
# Load the WMT metrics data and keep only the language pairs in `lang_pairs`.
#
# SECURITY: the Hugging Face access token is read from the HF_TOKEN environment
# variable instead of being committed to the notebook. When the variable is not
# set, `token=None` is passed and the library uses its default credential
# lookup. The token that used to be hard-coded here should be revoked.
wmt_dsets = load_dataset(
    'nllg/wmt-metrics-data',
    token=os.environ.get('HF_TOKEN'),
).filter(lambda x: x['lp'] in lang_pairs)


def _encode_score_type(split):
    """Replace the string `score_type` column of `split` by its integer id from `score_type2id`."""
    return (
        split
        .rename_column('score_type', 'score_type_str')
        .map(lambda x: {'score_type': score_type2id[x['score_type_str']]})
        .remove_columns(['score_type_str'])
    )


# Both splits get the same treatment (the test split used to be mapped twice).
train_dset = _encode_score_type(wmt_dsets['train'])
test_dset = _encode_score_type(wmt_dsets['test'])

In [None]:
# The tokenizer and the encoder-only model come from the same checkpoint.
CHECKPOINT = 'bigscience/mt0-large'
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model = MT5EncoderModel.from_pretrained(CHECKPOINT)

## Prepare FAISS for references and translations

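Note that for unit-norm vectors $\|x-y\|^2_2=2-2x^Ty$, and therefore $\|(-x)-y\|^2_2=2+2x^Ty$. Minimizing the L2 distance to the *negated* vector is thus equivalent to minimizing the cosine similarity: $\arg\min\left(x^Ty\right)=\arg\max\left((-x)^Ty\right)=\arg\min\left(\|(-x)-y\|^2_2\right)$. The indexes below store negated, normalized embeddings, so a nearest-neighbour search returns the *least* similar sentences.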

In [None]:
def free_() -> None:
    """Run Python garbage collection, then empty PyTorch's CUDA cache."""
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

In [None]:
def prepare_faiss(
    model,
    pool,
    sentences,
    nlist=100, # voronoi cells for ANN
    m=16, # number of centroid IDs in final compressed vectors
    bits=8,  # number of bits in each centroid
    nprobe=10, # number of cells to search during inference,
    save_filepath=None
):
    """Build an IVF-PQ FAISS index over negated, L2-normalized sentence embeddings.

    Args:
        model: object exposing `encode_multi_process(sentences, pool, batch_size=...)`.
        pool: multi-process pool handed through to `encode_multi_process`.
        sentences: the texts to embed and index.
        nlist, m, bits: forwarded to `faiss.IndexIVFPQ`.
        nprobe: value assigned to `index.nprobe` after the vectors are added.
        save_filepath: when not None, the finished index is written to this path.

    Returns:
        The trained index containing one vector per sentence.
    """
    raw_embs = model.encode_multi_process(sentences, pool, batch_size=128)
    row_norms = np.linalg.norm(raw_embs, axis=1, keepdims=True)
    # Negated unit vectors: an L2 nearest-neighbour query then favours the
    # vectors with the lowest dot product to the query.
    negated_unit_embs = -raw_embs / row_norms

    dim = negated_unit_embs.shape[1]
    coarse_quantizer = faiss.IndexFlatL2(dim)
    ivfpq_index = faiss.IndexIVFPQ(coarse_quantizer, dim, nlist, m, bits)
    ivfpq_index.train(negated_unit_embs)
    ivfpq_index.add(negated_unit_embs)
    ivfpq_index.nprobe = nprobe

    if save_filepath is None:
        return ivfpq_index
    faiss.write_index(ivfpq_index, save_filepath)
    return ivfpq_index

In [None]:
# Per language pair: the shuffled training split (with source embeddings) and
# FAISS indexes over its references and machine translations. Everything is
# cached under DATA_DIR and reused when the files already exist.
lp2train_dset = {}
lp2ref_index = {}
lp2mt_index = {}

st = SentenceTransformer('sentence-transformers/sentence-t5-large')
pool = st.start_multi_process_pool()
try:
    for lp in (pbar := tqdm_notebook(lang_pairs)):
        pbar.set_description(lp)

        dset_path = f'{DATA_DIR}/{lp}_dset.pt'
        if os.path.exists(dset_path):
            # The cache is a pickled Dataset object, not a tensor checkpoint, so
            # it must be loaded with weights_only=False (newer PyTorch releases
            # default to True and would refuse it).
            # SECURITY: this unpickles arbitrary objects; only load cache files
            # that this notebook produced itself.
            lp_dset = torch.load(dset_path, weights_only=False)
        else:
            lp_dset = train_dset.filter(lambda x: x['lp'] == lp).shuffle(seed=SEED)
            src_embs_dset = Dataset.from_dict(
                {'src_emb': st.encode_multi_process(lp_dset['src'], pool, batch_size=128)}
            )
            lp_dset = datasets.concatenate_datasets([lp_dset, src_embs_dset], axis=1)
            torch.save(lp_dset, dset_path)
        lp2train_dset[lp] = lp_dset

        ref_index_path = f'{DATA_DIR}/{lp}_ref_faiss.idx'
        mt_index_path = f'{DATA_DIR}/{lp}_mt_faiss.idx'

        if os.path.exists(ref_index_path):
            lp2ref_index[lp] = faiss.read_index(ref_index_path)
        else:
            lp2ref_index[lp] = prepare_faiss(st, pool, lp_dset['ref'], save_filepath=ref_index_path)

        if os.path.exists(mt_index_path):
            lp2mt_index[lp] = faiss.read_index(mt_index_path)
        else:
            lp2mt_index[lp] = prepare_faiss(st, pool, lp_dset['mt'], save_filepath=mt_index_path)
finally:
    # Always shut the worker processes down, even if an iteration fails.
    st.stop_multi_process_pool(pool)

## Prepare dataset

In [None]:
def preprocess(text, lang=None):
    """Normalize `text` for tokenization.

    The text is lower-cased and stripped, optionally prefixed with
    ``"<lang>: "``, and every run of whitespace is collapsed to a single
    space. ``None`` yields an empty string (no prefix is added in that case).
    """
    if text is None:
        return ''
    prefix = '' if lang is None else f'{lang}: '
    return ' '.join(f'{prefix}{text.lower().strip()}'.split())

def tokenize_(data, tokenizer, max_length=MAX_LENGTH, add_lang=False):
    """Tokenize the `src`, `ref` and `mt` fields of one example.

    The language pair in `data['lp']` ("<src>-<tgt>") is mapped through
    `id2lang`; the source text is prefixed with the source language name and
    the reference / translation with the target language name (see
    `preprocess`). If `max_length` is set and an encoding is shorter than it
    without ending in the EOS token, EOS is appended.

    Returns a dict with `<field>_input_ids` and `<field>_attention_mask` for
    each of the three fields. `add_lang` is accepted but not used.
    """
    src_lang, tgt_lang = (id2lang[code] for code in data['lp'].split('-'))
    field_langs = {'src': src_lang, 'ref': tgt_lang, 'mt': tgt_lang}

    output = {}
    for field, lang in field_langs.items():
        encoded = tokenizer(
            preprocess(data[field], lang),
            truncation=True,
            max_length=max_length,
            padding=False,
        )
        ids = encoded['input_ids']
        mask = encoded['attention_mask']
        missing_eos = (
            max_length is not None
            and ids[-1] != tokenizer.eos_token_id
            and len(ids) < max_length
        )
        if missing_eos:
            ids.append(tokenizer.eos_token_id)
            mask.append(1)
        output[f'{field}_input_ids'] = ids
        output[f'{field}_attention_mask'] = mask
    return output

In [None]:
# Tokenize src/ref/mt for every per-language-pair training split.
tokenize_train_example = partial(tokenize_, tokenizer=tokenizer)
lp2train_dset = DatasetDict(lp2train_dset).map(tokenize_train_example)

In [None]:
# Tokenize src/ref/mt for the test split.
tokenize_test_example = partial(tokenize_, tokenizer=tokenizer)
test_dset = test_dset.map(tokenize_test_example)

In [None]:
class DatasetWMTCL(Dataset):
    """Serves contrastive training groups, or (src, ref) pairs for inference.

    Training mode (`inference=False`): `input` is a list of
    `(dataset, ref_index, mt_index)` triples. One item is a dict holding

    * `score_type` / `score`: the anchor's value repeated `train_batch_size` times;
    * `input_ids` / `attention_mask`: the anchor's src, ref and mt sequences,
      followed by `n_neighbors + 1` reference negatives and `n_neighbors`
      translation negatives retrieved from the two indexes with the anchor's
      `src_emb` as query.

    Inference mode (`inference=True`): `input` is a one-element list with the
    dataset; an item holds the src and ref sequences of the requested rows,
    interleaved as src_0, ref_0, src_1, ref_1, ...

    NOTE(review): indexing is expected to receive a *list* of indices (the
    training path asserts exactly one) — presumably because the base class
    forwards whole DataLoader batches; confirm against the installed versions.
    """

    def __init__(
        self,
        input,
        inference=False,
        train_batch_size=32
    ):
        self.inference = inference
        assert isinstance(input, list) and (not inference or len(input) == 1)
        if inference:
            self.datasets = input
            self.total_len = len(input[0])
            return

        self.total_len = 0
        self.cumsum_lens = []
        self.datasets = []
        self.ref_indexes = []
        self.mt_indexes = []
        self.train_batch_size = train_batch_size
        # 3 anchor sequences + (n + 1) ref negatives + n mt negatives.
        self.n_neighbors = (train_batch_size - 3) // 2

        for part in input:
            dataset, ref_index, mt_index = part
            self.ref_indexes.append(ref_index)
            self.mt_indexes.append(mt_index)
            self.datasets.append(dataset)
            self.total_len += len(dataset)
            self.cumsum_lens.append(self.total_len)

    def __len__(self):
        return self.total_len

    def determine_data_index(self, idx):
        """Map a global index to `(dataset position, index inside that dataset)`.

        Raises:
            ValueError: if `idx` is not smaller than the total length.
        """
        prev_cumsum_len = 0
        for i, cumsum_len in enumerate(self.cumsum_lens):
            if idx < cumsum_len:
                return i, idx - prev_cumsum_len
            prev_cumsum_len = cumsum_len
        raise ValueError(f'Index {idx} is not in valid range')

    @staticmethod
    def _pick_negatives(dataset, candidate_ids, own_idx, count):
        """Return the first `count` rows among `candidate_ids`, skipping invalid ones.

        A candidate is skipped when it is the anchor itself or a negative id
        (used by the index as a "no result" placeholder). Raising on a
        shortfall keeps the group layout fixed, which the loss relies on when
        slicing negatives by position.
        """
        rows = []
        for cand in candidate_ids:
            if len(rows) == count:
                break
            cand = int(cand)
            if cand < 0 or cand == own_idx:
                continue
            rows.append(dataset[cand])
        if len(rows) < count:
            raise RuntimeError(
                f'Only {len(rows)} of {count} negatives could be retrieved for item {own_idx}'
            )
        return rows

    def __getitem_for_train(self, idx):
        assert len(idx) == 1
        idx = idx[0]

        i, idx = self.determine_data_index(idx)
        dataset = self.datasets[i]
        point = dataset[idx]
        output = {
            'score_type': [point['score_type']] * self.train_batch_size,
            'score': [point['score']] * self.train_batch_size,
        }
        for key in ['input_ids', 'attention_mask']:
            output[key] = [point[f'{field}_{key}'] for field in ['src', 'ref', 'mt']]

        # float32 is the dtype the index operates on.
        src_emb = np.asarray(point['src_emb'], dtype=np.float32)[None, :]

        # One spare candidate is requested from each index so that the anchor
        # can be dropped without ending up with too few negatives.
        _, ref_I = self.ref_indexes[i].search(src_emb, k=self.n_neighbors + 1 + 1)
        _, mt_I = self.mt_indexes[i].search(src_emb, k=self.n_neighbors + 1)

        ref_negatives = self._pick_negatives(dataset, ref_I.ravel(), idx, self.n_neighbors + 1)
        mt_negatives = self._pick_negatives(dataset, mt_I.ravel(), idx, self.n_neighbors)

        # Reference negatives come first, translation negatives after them.
        for field, far_points in (('ref', ref_negatives), ('mt', mt_negatives)):
            for far_point in far_points:
                for key in ['input_ids', 'attention_mask']:
                    output[key].append(far_point[f'{field}_{key}'])

        return output

    def __getitem_for_test(self, idx):
        points = self.datasets[0][idx]
        output = {}
        for key in ['input_ids', 'attention_mask']:
            # Interleave src/ref so that rows [0::2] are sources and rows
            # [1::2] their references (plain concatenation would put every
            # source before every reference).
            interleaved = []
            for src_row, ref_row in zip(points[f'src_{key}'], points[f'ref_{key}']):
                interleaved.extend((src_row, ref_row))
            output[key] = interleaved
        return output

    def __getitem__(self, idx):
        if self.inference:
            return self.__getitem_for_test(idx)
        return self.__getitem_for_train(idx)


In [None]:
class DataCollatorWithPaddingAndScore:
    """Collator that pads features with `tokenizer.pad` and carries scores along.

    The `score` and `score_type` entries are removed from every feature (the
    feature dicts are modified in place) before padding, and re-attached to
    the padded batch as plain Python lists.
    """

    def __init__(
        self,
        tokenizer,
        padding=True,
        max_length=None,
        pad_to_multiple_of=None,
        return_tensors="pt"
    ):
        self.tokenizer = tokenizer
        self.padding, self.max_length = padding, max_length
        self.pad_to_multiple_of = pad_to_multiple_of
        self.return_tensors = return_tensors

    def __call__(self, features):
        # Pop both fields feature by feature so only paddable keys remain.
        popped = [(feature.pop('score'), feature.pop('score_type')) for feature in features]

        pad_options = dict(
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch = self.tokenizer.pad(features, **pad_options)

        batch['score'] = [score for score, _ in popped]
        batch['score_type'] = [score_type for _, score_type in popped]
        return batch

In [None]:
# One (dataset, reference index, translation index) triple per language pair.
train_batch_size = 4
train_parts = [
    (lp2train_dset[lp], lp2ref_index[lp], lp2mt_index[lp])
    for lp in lang_pairs
]
wmtcl_train_dset = DatasetWMTCL(train_parts, train_batch_size=train_batch_size)

In [None]:
# Inference mode takes a single-element list holding the test split.
wmtcl_test_dset = DatasetWMTCL([test_dset], inference=True)

In [None]:
# Both collators share the same padding settings; only the training one also
# carries the score fields.
padding_options = dict(pad_to_multiple_of=4, max_length=MAX_LENGTH)
train_collator = DataCollatorWithPaddingAndScore(tokenizer, **padding_options)
test_collator = DataCollatorWithPadding(tokenizer, **padding_options)

In [None]:
# Training draws one dataset item per step (the dataset itself expands it into
# a group of sequences); the test loader walks the data in order, 16 rows at a time.
train_dataloader = DataLoader(
    wmtcl_train_dset,
    batch_size=1,
    shuffle=True,
    pin_memory=True,
    collate_fn=train_collator,
)
test_dataloader = DataLoader(
    wmtcl_test_dset,
    batch_size=16,
    shuffle=False,
    pin_memory=True,
    collate_fn=test_collator,
)

### Finetuning

In [None]:
class ContrastiveLossWMT(nn.Module):
    """Contrastive loss over one (src, ref, mt, negatives...) group.

    Expected row layout of `embs`:

        0                        source
        1                        reference (positive)
        2                        machine translation (score-weighted positive)
        3 .. 3+n                 n + 1 reference negatives
        4+n .. 3+2n              n translation negatives

    with n = `negative_n`.

    Full credit to https://zablo.net/blog/post/understanding-implementing-simclr-guide-eli5-pytorch/.
    """
    def __init__(
        self,
        negative_n, # n_neighbors from dataset initialization
        score_type_weights=None, # prioritization of score types
        temperature=0.05,
        device='cpu'
    ):
        super().__init__()
        self.negative_n = negative_n
        self.score_type_weights = score_type_weights
        self.register_buffer("temperature", torch.tensor(temperature, device=device))

    def forward(self, embs: torch.Tensor, score: float, score_type: int):
        """Return the scalar loss for one group.

        Args:
            embs: 2-D tensor of embeddings laid out as described on the class.
            score: quality score of the translation; it scales the translation
                numerator.
            score_type: key into `score_type_weights`; every type weighs 1 when
                no weights were given.

        NOTE(review): `weight * score` must be positive, otherwise the
        logarithm of the translation term is undefined — confirm the score
        range of the data.
        """
        embs = torch.nn.functional.normalize(embs, dim=1)
        src_emb = embs[0:1]
        tgt_embs = embs[1:]
        # squeeze(0) keeps a 1-D vector even if there is only one target row.
        scaled_sims = (src_emb @ tgt_embs.T).squeeze(0) / self.temperature

        n = self.negative_n
        ref_negatives = scaled_sims[2 : 2 + n + 1]
        mt_negatives = scaled_sims[2 + n + 1 : 2 + 2 * n + 1]

        weight = 1.0 if self.score_type_weights is None else self.score_type_weights[score_type]

        ref_nom = torch.exp(scaled_sims[0])
        ref_denom = ref_nom + torch.exp(ref_negatives).sum()
        # Index 1 is the translation; index 0 (the reference) was used here
        # before, so the translation never contributed to its own term.
        mt_nom = weight * score * torch.exp(scaled_sims[1])
        mt_denom = mt_nom + torch.exp(mt_negatives).sum()
        loss = -torch.log(ref_nom / ref_denom) - torch.log(mt_nom / mt_denom)
        return loss
     

In [None]:
def mean_pooling(token_embeddings, attention_mask):
    """Average token embeddings along dim 1, weighting each token by its mask value.

    The mask is broadcast to the shape of `token_embeddings`; the divisor is
    clamped to at least 1e-9 so that a fully masked row yields zeros.
    """
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    masked_sum = torch.sum(token_embeddings * mask, 1)
    token_count = torch.clamp(mask.sum(1), min=1e-9)
    return masked_sum / token_count

```python
max_memory = get_balanced_memory(
     model,
    max_memory=None,
    no_split_module_classes=['MT5LayerSelfAttention', 'MT5LayerFF'],
    dtype='float16',
    low_zero=True,
)

device_map = infer_auto_device_map(
    model,
    max_memory=max_memory,
    no_split_module_classes=['MT5LayerSelfAttention', 'MT5LayerFF'],
    dtype='float16'
)

model = dispatch_model(model, device_map=device_map)
```

In [None]:
# Turn on gradient checkpointing for the encoder before training starts
# (per the transformers documentation this lowers activation memory at the
# cost of extra computation in the backward pass).
model.gradient_checkpointing_enable()

In [None]:
# Scheduler horizon: one step per training item, for every epoch.
n_epochs = 5
n_steps = len(wmtcl_train_dset) * n_epochs

In [None]:
# AdamW with its default hyper-parameters, driven by a linear schedule with
# 1000 warm-up steps.
optimizer = torch.optim.AdamW(model.parameters())
scheduler = get_scheduler(
    'linear',
    optimizer,
    num_warmup_steps=1000,
    num_training_steps=n_steps,
)

In [None]:
# MQM-scored examples weigh twice as much as DA / SQM ones.
score_type_weights = {
    score_type2id[name]: weight
    for name, weight in (('da', 1), ('mqm', 2), ('sqm', 1))
}
loss = ContrastiveLossWMT(
    wmtcl_train_dset.n_neighbors,
    score_type_weights=score_type_weights,
)

In [None]:
# Let Accelerate wrap the loaders, model, optimizer and scheduler; gradients
# are accumulated over 4 steps.
accelerator = Accelerator(gradient_accumulation_steps=4)
prepared = accelerator.prepare(
    train_dataloader, test_dataloader, model, optimizer, scheduler
)
train_dataloader, test_dataloader, model, optimizer, scheduler = prepared

In [None]:
# Training / evaluation loop.
#
# Training: every batch is one contrastive group; its pooled embeddings go
# through the contrastive loss. Evaluation: mean cosine similarity between the
# source and reference embeddings of the test set.
for epoch in tqdm_notebook(range(n_epochs)):
    model.train()
    for batch in (pbar := tqdm_notebook(train_dataloader)):
        with accelerator.accumulate(model):
            # All rows of a group share the anchor's score and score type.
            score = batch.pop('score')[0]
            score_type = batch.pop('score_type')[0]

            outputs = model(**batch)
            outputs = mean_pooling(outputs.last_hidden_state, batch['attention_mask'])

            loss_ = loss(outputs, score, score_type)

            # Go through the accelerator so the backward pass is handled
            # correctly under gradient accumulation (and mixed precision /
            # distributed setups), instead of calling loss_.backward() directly.
            accelerator.backward(loss_)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            pbar.set_description(f'loss: {loss_.item()}')

        # Releasing cached memory once per step is sufficient; doing it after
        # every single statement only slows the loop down.
        free_()

    model.eval()
    total_correlation = 0.0
    total_ = 0
    if accelerator.is_local_main_process:
        for batch in (pbar := tqdm_notebook(test_dataloader)):
            with torch.no_grad():
                outputs = model(**batch)
                outputs = mean_pooling(outputs.last_hidden_state, batch['attention_mask'])
                outputs = torch.nn.functional.normalize(outputs, dim=1)
                # NOTE(review): assumes rows alternate src, ref, src, ref, ... —
                # confirm that the test dataset emits its rows in that order.
                embs_src = outputs[0::2]
                embs_ref = outputs[1::2]

                correlation_ = (embs_src @ embs_ref.T).diag().sum().item()
                total_correlation += correlation_
                # Count (src, ref) pairs; len(batch) would be the number of
                # keys in the batch mapping, not the number of rows.
                total_ += embs_src.size(0)

                pbar.set_description(f'src & ref correlation: {total_correlation / total_}')

        # Nothing to report (and nothing to divide by) for an empty test loader.
        if total_ > 0:
            print(f'src & ref correlation: {total_correlation / total_}')


---