## Import libs

In [1]:
%cd ..

/home/sasha/effective-inference


In [2]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, GPT2LMHeadModel
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
from sklearn.metrics import classification_report
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
import torch
#from progressbar import progressbar
from tqdm.auto import tqdm
from utils.prepare_dataset import load_datasets, cut_datasets
from collections import defaultdict
import numpy as np

## Define hyperparams

In [14]:
# Define datasets
#['mrpc', 'sst2', 'cola', 'rte', 'qnli']
model_name = 'gpt2'

tokenizer = AutoTokenizer.from_pretrained(model_name, max_length=1024)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

model = GPT2LMHeadModel.from_pretrained(model_name)
model.eval()

translation_datasets_artifacts = {
    "fr-en": ('Translate from English to French: ', 'fr', 'French'),
    "ru-en": ('Translate from English to Russian: ', 'ru', 'Russian')
}

# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device == 'cpu':
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model.to(device)

DEBUG_FLAG = True
CUT_SIZE = None if not DEBUG_FLAG else 6

## Load datasets

In [4]:
translation_datasets = load_datasets('wmt14', ["fr-en", "ru-en"], CUT_SIZE)

all_datasets = {'translation': translation_datasets}

Downloading builder script:   0%|          | 0.00/2.97k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/15.3k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/9.37k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/41.2k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/7 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/658M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/919M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/2.37G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/80.5M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/2.60G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/38.7M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/7 [00:00<?, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Generating train split:   0%|          | 0/40836715 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3003 [00:00<?, ? examples/s]

Downloading data files:   0%|          | 0/5 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/9.49M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/5 [00:00<?, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Generating train split:   0%|          | 0/1486965 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3003 [00:00<?, ? examples/s]

In [6]:
translation_datasets['fr-en']['train'][0], translation_datasets['fr-en']['train'][1]

({'translation': {'en': 'Resumption of the session',
   'fr': 'Reprise de la session'}},
 {'translation': {'en': 'I declare resumed the session of the European Parliament adjourned on Friday 17 December 1999, and I would like once again to wish you a happy new year in the hope that you enjoyed a pleasant festive period.',
   'fr': 'Je déclare reprise la session du Parlement européen qui avait été interrompue le vendredi 17 décembre dernier et je vous renouvelle tous mes vux en espérant que vous avez passé de bonnes vacances.'}})

In [16]:
tqdm_pbar = lambda x, y: tqdm(x, leave=True, position=0, total=len(x), desc=f'{y}')
def get_translations_for_dataset(
    dataset_name, dataset, 
    artifacts, model, tokenizer, 
    pbar_func=tqdm_pbar, device=device, CUT_SIZE=CUT_SIZE
):
    collected_translations = defaultdict(list)
    
    for split, data in dataset.items():
        if split != 'validation':
            continue
        
        pbar = pbar_func(data, f"{split} {dataset_name}") if pbar_func is not None else data
        for example in pbar:
            # Encode the input sentences
            ex1, g1 = dataset['train'][0]["translation"]["en"], dataset['train'][0]["translation"][artifacts[1]] 
            ex2, g2 = dataset['train'][1]["translation"]["en"], dataset['train'][1]["translation"][artifacts[1]]
            ex3, g3 = dataset['train'][2]["translation"]["en"], dataset['train'][2]["translation"][artifacts[1]]
            ex4, g4 = dataset['train'][3]["translation"]["en"], dataset['train'][3]["translation"][artifacts[1]]
            ex5, g5 = dataset['train'][4]["translation"]["en"], dataset['train'][4]["translation"][artifacts[1]]
            target_name = artifacts[2]
            sample_src = example["translation"]["en"]
            prompt = f"English: {ex1}\n{target_name}: {g1}" + \
                     f"\n\nEnglish: {ex2}\n{target_name}: {g2}" + \
                     f"\n\nEnglish: {ex3}\n{target_name}: {g3}" + \
                     f"\n\nEnglish: {ex4}\n{target_name}: {g4}" + \
                     f"\n\nEnglish: {ex5}\n{target_name}: {g5}" + \
                     f"\n\nEnglish: {sample_src}\n{target_name}: "
            
            input_ids = tokenizer.encode(prompt, 
                                         truncation=True,  return_tensors="pt").to(device)
            attention_mask = torch.ones_like(input_ids).to(device)  # Create attention mask with ones

            # Perform the generation
            with torch.no_grad():
                translation = model.generate(inputs=input_ids.to(device), attention_mask=attention_mask.to(device),
                                             max_new_tokens=256, pad_token_id=tokenizer.eos_token_id)

            # Decode the generated output
            translation = tokenizer.decode(translation[0], skip_special_tokens=True)
            #print(translation)
            if f'English: {sample_src}' in translation:
                translation = translation.split(f'English: {sample_src}')[1].strip()
            #print('----->', '\n', translation)
            if 'English: ' in translation:
                translation = translation.split('English: ')[0].strip()
        
            if f'{target_name}: ' in translation:
                translation = translation.split(f'{target_name}: ')[1].strip()
            #print(translation)

            collected_translations[split].append(translation)
         
    return collected_translations

In [12]:
def get_scores(src, prediction, gold, give_example=True):
    prediction = np.array(prediction)
    gold = np.array(gold)
    # Calculate BLEU score
    smoothie = SmoothingFunction().method3
    bleu_score = corpus_bleu(gold, prediction, smoothing_function=smoothie)
    print("BLEU score:", bleu_score)

    # Calculate exact match (EM) score
    exact_match = np.mean(prediction == gold)
    print("Exact Match (EM) score:", exact_match)
    
    print(f'Example:\nSrc:  {src[0]}\nTgt:  {gold[0]}\nPred: {prediction[0]}\n\n')
    return bleu_score, exact_match

In [18]:
for dn, datasets in all_datasets.items():
    for dataset_name, dataset in datasets.items():
        print(f"{dn.upper()} / {dataset_name}\n")
        dataset_translations = get_translations_for_dataset(
            dataset_name,
            dataset, 
            translation_datasets_artifacts[dataset_name], 
            model, 
            tokenizer)
        
        to_language = translation_datasets_artifacts[dataset_name][1]
        
        val_gold_translation = [el['translation'][to_language].strip() for el in dataset['validation']]
        src_val = [el['translation']['en'].strip() for el in dataset['validation']]
        pred = [el.strip() for  el in dataset_translations['validation']]
        
        get_scores(src_val, pred, val_gold_translation)

TRANSLATION / fr-en



validation fr-en:   0%|          | 0/6 [00:00<?, ?it/s]

BLEU score: 0.0004335440242797244
Exact Match (EM) score: 0.0
Example:
Src:  A Republican strategy to counter the re-election of Obama
Tgt:  Une stratégie républicaine pour contrer la réélection d'Obama
Pred: était en ce n'est pas, comme un certain nombre de collègues me l'ont demandé, que nous observions une minute de silence pour toutes les victimes, des tempêtes notamment, dans les différents pays de l'Union européenne qui ont été touchés.


TRANSLATION / ru-en



validation ru-en:   0%|          | 0/6 [00:00<?, ?it/s]

BLEU score: 0.0013973130863582958
Exact Match (EM) score: 0.0
Example:
Src:  A Republican strategy to counter the re-election of Obama
Tgt:  Республиканская стратегия сопротивления повторному избранию Обамы
Pred: это готовая к использованию паста, которая наносится шпателем или пальцами в виде закругленного перехода в углы сталелитейного кокиль от горячего абразивного стального литья.


