In [1]:
%load_ext autoreload
%autoreload 2

In [45]:
import torch

In [6]:
import os

os.chdir("..")

In [7]:
from da.fsmt.modeling_fsmt import FSMTForConditionalGeneration
from da.fsmt.tokenization_fsmt import FSMTTokenizer

In [89]:
source_data_path = "experiments/doc-indices/sp-cl-Europarl.en-et.docs.test-cl.both"
target_data_path = "experiments/doc-indices/sp-cl-Europarl.en-et.docs.test-cl.both"
cluster_ids_path = "experiments/en_et_concat60/nmt-clusters-sent/Europarl.test-cl.clust.nmt.sent"

experts_id2path = {
    "0": "experiments/en_et_concat60/hf",
    "1": "experiments/en_et_concat60/hf",
    "2": "experiments/en_et_concat60/hf",
    "3": "experiments/en_et_concat60/hf"
}

In [90]:
pwd

'/home/maksym/research/da'

In [91]:
id2model = {}
id2tokenizer = {}
for id, path in experts_id2path.items():
    tokenizer_hf = FSMTTokenizer.from_pretrained(path)
    model_hf = FSMTForConditionalGeneration.from_pretrained(path)
    model_hf.cuda()
    model_hf.eval()
    
    id2tokenizer[id] = tokenizer_hf 
    id2model[id] = model_hf

In [92]:
sources = []
targets = []
doc_ids = []
with open(source_data_path) as both: 
    for line in both:
        doc_id, src, tgt = line.strip().split('\t')
        doc_ids.append(doc_id)
        sources.append(src)
        targets.append(tgt) 
        
cluster_ids = []
with open(cluster_ids_path) as f: 
    for line in f:
        cluster_ids.append(line.strip())


In [95]:
%%time

translations = []
i = 0
for id, src in zip(cluster_ids, sources):
    
    tokenizer, model = id2tokenizer[id], id2model[id] 

    src = tokenizer.encode_plus(
        src,
        padding="longest", 
        return_tensors="pt",
        return_token_type_ids=False,
        return_attention_mask=True
    )

    for k, v in src.items():
        src[k] = v.to(model.device)

    with torch.no_grad():
        hyp = model.generate(
            **src,
            #return_dict=True,
            #do_sample=False,
            num_beams=5
        )[0]

    hyp = tokenizer._decode(hyp, skip_special_tokens=True)

    translations.append(hyp)

    if i % 100 == 0:
        print(f"{i}/{len(sources)} sentences translated")

    i += 1

0/3107 sentences translated
100/3107 sentences translated
200/3107 sentences translated
300/3107 sentences translated
400/3107 sentences translated
500/3107 sentences translated
600/3107 sentences translated
700/3107 sentences translated
800/3107 sentences translated
900/3107 sentences translated
1000/3107 sentences translated
1100/3107 sentences translated
1200/3107 sentences translated
1300/3107 sentences translated
1400/3107 sentences translated
1500/3107 sentences translated
1600/3107 sentences translated
1700/3107 sentences translated
1800/3107 sentences translated
1900/3107 sentences translated
2000/3107 sentences translated
2100/3107 sentences translated
2200/3107 sentences translated
2300/3107 sentences translated
2400/3107 sentences translated
2500/3107 sentences translated
2600/3107 sentences translated
2700/3107 sentences translated
2800/3107 sentences translated
2900/3107 sentences translated
3000/3107 sentences translated
3100/3107 sentences translated
CPU times: user 7min

In [96]:
import sentencepiece as spm

In [97]:
from sacrebleu import corpus_bleu as _corpus_bleu

In [98]:
def corpus_bleu(sys_stream, ref_streams):
    bleu = _corpus_bleu(sys_stream, ref_streams, tokenize="none")
    return bleu.score

In [99]:
spm_model = spm.SentencePieceProcessor(model_file='data-prep/preproc-models/syscl-en-et.model')

In [111]:
def detok(s):
    return ''.join(s.split()).replace('▁', ' ').strip()

In [115]:
targets_detok = [detok(t) for t in targets]
hyps_detok = [detok(t) for t in translations]

In [119]:
hyps_detok[2]

'Kuna me hilineme, püüame kõigepealt käsitleda lisaküsimusi fraktsioonina iga kord, et saaksime täna pärastlõunal rohkem küsimusi lahendada, ning nagu tavaliselt, kutsuksin ma parlamendiliikmeid esiuksele, et nad infotunnis täpsemalt osaleksid.'

In [121]:
corpus_bleu(targets_detok, [hyps_detok])

23.26039808040434

In [122]:
corpus_bleu(targets, [translations])

20.037047422435965

In [120]:
hyps_detok

['Järgmine päevakorrapunkt on infotund (B6-0384/2007).',
 'Me andsime sellele arutelule selle tähtsuse tõttu pika aja ja loodame, et saame oma infotunniga jätkata.',
 'Kuna me hilineme, püüame kõigepealt käsitleda lisaküsimusi fraktsioonina iga kord, et saaksime täna pärastlõunal rohkem küsimusi lahendada, ning nagu tavaliselt, kutsuksin ma parlamendiliikmeid esiuksele, et nad infotunnis täpsemalt osaleksid.',
 'Volinik, mul on kahju, et te pidite ootama, kuid see oli tähtis arutelu, nagu te kindlasti tunnistate.',
 'Komisjonile on esitatud järgmised küsimused.',
 'Teema: Energeetika - Maailma Kaubandusorganisatsioon (WTO)',
 'Kuigi Maailma Kaubandusorganisatsiooni (WTO) eeskirju ei ole konkreetselt energiatoodete ja -teenustega koostatud, kohaldatakse neid ka nende suhtes ja kaitstakse investeeringuid energiasektorisse.',
 'Sellest tulenevalt on ekspordipiirangud ja toodete diskrimineerimine keelatud ning tuleb tagada vabakaubandus, kuigi ohutuse huvides on olemas erandid meetmetest.'

In [52]:
basename = source_data_path.split('/')[-1]
with open(f"hypothesis/hyp-{basename}.bpe.et", 'w') as f:
    for s in translated:
        f.write(f"{s}\n")