In [None]:
# Small notebook to examine the token-level log-probabilities that a base and a
# finetuned NLLB model assign to a bad vs. a good German translation.

# English source sentence.
SENT_SRC = "I like to eat pizza."
# Candidate translations: word-for-word with wrong verb placement ("bad")
# vs. idiomatic German ("good").
SENT_TGT_BAD, SENT_TGT_GOOD = "Ich mag essen Pizza.", "Ich esse gerne Pizza."

from functools import partial

In [None]:
import torch
DEVICE = torch.device('cpu')
from metrics_domain_adaptation import utils
from metrics_domain_adaptation.metrics.prism2 import PRISMModel

def _load_prism_nllb(model_name, finetuned):
    """Build an en->de PRISMModel backed by the NLLB checkpoint at `model_name`.

    `finetuned` is stored as the instance attribute `nllb_finetuned`; the
    `_score_1way_nllb` override defined later in this notebook reads it to pick
    its token-alignment strategy.
    """
    prism = PRISMModel(
        lang1="en",
        lang2="de",
        device=DEVICE,
        model_type="nllb",
        model_name=model_name,
    )
    prism.nllb_finetuned = finetuned
    return prism


# Baseline: off-the-shelf distilled NLLB-200 (600M) from the Hugging Face hub.
model_1 = _load_prism_nllb("facebook/nllb-200-distilled-600M", finetuned=False)
# Locally trained checkpoint (the path suggests bio-domain finetuning).
model_2 = _load_prism_nllb(
    f"{utils.ROOT}/models/trained/nllb/ende/600M_bio_lr1e-6_bs4/checkpoint-75000/",
    finetuned=True,
)

In [None]:
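# Replace `_score_1way_nllb` on both model instances (see the end of this cell)
# with the local implementation below, for quick iteration on the token scoring.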

def _score_1way_nllb(self, input_text, output_text, input_lang_id, input_lang):
    """Per-token log-probabilities of `output_text` given `input_text` under the NLLB model.

    Notebook replacement for PRISMModel._score_1way_nllb; it is bound onto the
    model instances at the end of this cell.

    Args:
        self: the PRISMModel instance. Reads tokenizer, model, device, bos_id,
            lang2nllb_id, max_src_len_tokens, max_tgt_len_tokens, nllb_finetuned.
        input_text: source sentence.
        output_text: target sentence whose tokens are scored.
        input_lang_id: token id prepended to the source ids (presumably the NLLB
            language-code token for `input_lang` -- confirm at the caller).
        input_lang: key into utils.LANG_TO_NLLB; only used to set tokenizer.src_lang.

    Returns:
        list of log-probabilities (numpy scalars), one per scored target position.
        Base model (nllb_finetuned False): one per target subword token plus the
        closing token. Finetuned: see the NOTE in that branch.
    """
    # need to import locally (kept from the original; torch is also imported in an earlier cell)
    import torch
    # this shouldn't have an effect because we are not using tokenizer's special tokens but just for consistency
    self.tokenizer.src_lang = utils.LANG_TO_NLLB[input_lang]

    # Special tokens are added by hand rather than by the tokenizer:
    #   source = [input_lang_id, <src subwords>, bos_id]
    #   target = [bos_id, lang2nllb_id, <tgt subwords>, bos_id]
    # self.bos_id serves as both start and end marker. The branch comments below
    # call it </s>/eos, so it is presumably the </s> id -- confirm.
    # Subwords are truncated to max_*_len_tokens before the markers are added.
    src_ids_list = [
        [input_lang_id, ] +
        self.tokenizer.encode(input_text, add_special_tokens=False)[:self.max_src_len_tokens] +
        [self.bos_id, ]
    ]
    src_ids = torch.tensor(src_ids_list, dtype=torch.int64)
    tgt_ids_list = [
        [self.bos_id, self.lang2nllb_id] +
        self.tokenizer.encode(output_text, add_special_tokens=False)[:self.max_tgt_len_tokens] +
        [self.bos_id, ]
    ]
    tgt_ids = torch.tensor(tgt_ids_list, dtype=torch.int64)
    # Teacher-forced pass with a batch of one sentence pair. Logits row i is the
    # distribution for target position i + 1 (the slicing below relies on this).
    # NOTE(review): no torch.no_grad() here; unless gradients are disabled
    # elsewhere, this builds an autograd graph that is only discarded by .detach().
    logits = self.model.forward(
        input_ids=src_ids.to(self.device),
        decoder_input_ids=tgt_ids.to(self.device)
    )['logits']
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    # squeeze() drops the batch dimension of 1 -> rows = target positions,
    # columns indexed by token id.
    log_probs2 = log_probs.squeeze().cpu().detach().numpy()

    # best for finetuned
    if self.nllb_finetuned:
        # Drop only the last row (the prediction after the closing token).
        log_probs2 = log_probs2[0:-1, :]
        # lose nothing
        idxs = tgt_ids.flatten().detach().numpy()[1:]
        # Row i is paired with target position i + 1: [lang code, t1, ..., tn, close].
        # NOTE(review): the trailing [2:] drops the language-code score AND the
        # score of the first target subword t1, despite the "lose nothing"
        # comment. This branch therefore returns one score fewer than the base
        # branch, so per-token labels shift by one. TODO confirm intentional.
        scores = [log_probs2[ii, jj] for ii, jj in enumerate(idxs)][2:]
    else:
        # lose the language code prediction and the garbage prediction at the end
        log_probs2 = log_probs2[1:-1, :]
        # lose </s> and language code (keep eos)
        idxs = tgt_ids.flatten().detach().numpy()[2:]
        # Row k of the slice is original row k + 1, which predicts target
        # position k + 2: scores cover [t1, ..., tn, close].
        scores = [log_probs2[ii, jj] for ii, jj in enumerate(idxs)]

    return scores

# Bind the override as a real bound method on each instance.
# functools.partial(..., self=model) only works if every caller passes the
# remaining arguments by keyword. A positional call would bind its first
# argument to `self` too and raise "TypeError: got multiple values for
# argument 'self'". MethodType handles positional and keyword calls alike.
from types import MethodType

model_1._score_1way_nllb = MethodType(_score_1way_nllb, model_1)
model_2._score_1way_nllb = MethodType(_score_1way_nllb, model_2)

In [None]:
# Scores from the base model (model_1) for the bad and the good candidate.
# Each target is scored as its tokens followed by <EOS>:
#   Ich mag essen Pizza. <EOS>
#   Ich esse gerne Pizza. <EOS>
tuple(
    model_1.score_w_src(SENT_SRC, candidate)
    for candidate in (SENT_TGT_BAD, SENT_TGT_GOOD)
)

In [None]:
# Same comparison with the finetuned model (model_2).
tuple(
    model_2.score_w_src(SENT_SRC, candidate)
    for candidate in (SENT_TGT_BAD, SENT_TGT_GOOD)
)

In [None]:
import numpy as np


def format_sequence(sent_src, sent_tgt, model, min, max):
    """Print per-token scores of `sent_tgt` as colour-cell markup, then their average.

    One line per token of the form `#color_cell(<normalized>, <score>)[<label>]`
    (presumably Typst markup for a report -- confirm), where <normalized> maps
    the score linearly from [min, max] onto [0, 1]. The token lines are
    followed by a `#h(1fr)` spacer, an average cell labelled `=avg`, and a
    blank line.

    Args:
        sent_src: source sentence passed to `model.score_w_src`.
        sent_tgt: target sentence to score and label.
        model: object with `score_w_src(src, tgt)` returning a sequence of
            per-token scores.
        min: score mapped to 0.0 (shadows the builtin; name kept for callers).
        max: score mapped to 1.0; must be greater than `min`.

    Raises:
        ValueError: if `max <= min` (the normalization would divide by zero).
    """
    import warnings

    if max <= min:
        raise ValueError(f"max ({max}) must be greater than min ({min})")

    def normalize(value):
        return (value - min) / (max - min)

    scores = model.score_w_src(sent_src, sent_tgt)
    # Labels are whitespace-separated words with '.' split off, plus the closing
    # token. They only line up with `scores` when every word is one subword.
    labels = sent_tgt.replace('.', ' .').split() + ["</s>"]
    if len(labels) != len(scores):
        # zip() below silently truncates to the shorter sequence.
        warnings.warn(
            f"{len(labels)} labels but {len(scores)} scores for {sent_tgt!r}; "
            "labels may be misaligned with scores",
            stacklevel=2,
        )

    for word, score in zip(labels, scores):
        print(f"   #color_cell({normalize(score):.2f}, {score:.3f})[{word}]")

    avg = np.average(scores)
    print("   #h(1fr)")
    print(f"   #color_cell({normalize(avg):.2f}, {avg:.2f})[=avg]")
    print()

# Colour scale for the per-token log-probabilities: -11 maps to 0.0, 0 maps to 1.0.
# Order: model_1 bad, model_1 good, model_2 bad, model_2 good.
for _model in (model_1, model_2):
    for _sent_tgt in (SENT_TGT_BAD, SENT_TGT_GOOD):
        format_sequence(SENT_SRC, _sent_tgt, _model, min=-11, max=0)

In [None]:
# How does the finetuned model's tokenizer split the bad candidate? Show the ids
# (without special tokens) next to their round-trip decoding.
_tokenizer = model_2.tokenizer
x = _tokenizer.encode(SENT_TGT_BAD, add_special_tokens=False)
x, _tokenizer.decode(x)

In [None]:
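# Shell commands, not Python: run them from a terminal at the repository root.
# They compare the finetuned checkpoint against the base NLLB model with
# run_metric.py (prism2-src / prism2-ref metrics, bio domain, en-de).
# ${ADAPTATION_ROOT} presumably points at the same directory as utils.ROOT above.
# ./metrics_domain_adaptation/run_metric.py --metric prism2-src --model-name "${ADAPTATION_ROOT}/models/trained/nllb/ende/600M_bio_lr1e-6_bs4/checkpoint-75000/" --domain bio --count 100 --lang en-de 2>/dev/null | grep JSON;\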
# ./metrics_domain_adaptation/run_metric.py --metric prism2-src --model-name "facebook/nllb-200-distilled-600M" --domain bio --count 100 --lang en-de 2>/dev/null | grep JSON


# ./metrics_domain_adaptation/run_metric.py --metric prism2-src --model-name "${ADAPTATION_ROOT}/models/trained/nllb/ende/600M_bio_lr1e-6_bs4/checkpoint-75000/" --domain bio --lang en-de 2>/dev/null | grep JSON;\
# ./metrics_domain_adaptation/run_metric.py --metric prism2-src --model-name "facebook/nllb-200-distilled-600M" --domain bio --lang en-de 2>/dev/null | grep JSON


# ./metrics_domain_adaptation/run_metric.py --metric prism2-ref --model-name "${ADAPTATION_ROOT}/models/trained/nllb/ende/600M_bio_lr1e-6_bs4/checkpoint-75000/" --domain bio --lang en-de 2>/dev/null | grep JSON;\
# ./metrics_domain_adaptation/run_metric.py --metric prism2-ref --model-name "facebook/nllb-200-distilled-600M" --domain bio --lang en-de 2>/dev/null | grep JSON


# ./metrics_domain_adaptation/run_metric.py --metric prism2-src --model-name "facebook/nllb-200-distilled-600M" --domain bio --lang en-de 2>/dev/null | grep JSON 