In [1]:
# Development convenience: enable IPython's autoreload extension in mode 2 so
# that edits to imported modules (e.g. the project-local `da` package imported
# below) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from os.path import basename, dirname
from shutil import copyfile

import json
import pickle
from collections import OrderedDict

import torch

In [3]:
# Move the working directory one level up so that the relative "experiments/..."
# paths and the `da` package imports used below resolve.
# NOTE(review): the path is relative, so every re-execution of this cell climbs
# one more level — run it exactly once per kernel session.
os.chdir("..")

In [4]:
from da.fsmt.modeling_fsmt import FSMTForConditionalGeneration
from da.fsmt.tokenization_fsmt import FSMTTokenizer

In [5]:
from types import MethodType

from da.greedy_search_interpret import greedy_search_interpret

In [6]:
def encode_sentence_tokens(src):
    """Run greedy generation for one sentence and collect its hidden states.

    Relies on the notebook-level globals ``tokenizer_hf`` and ``model_hf``,
    which must be bound before this function is called.

    Args:
        src: the source text handed to ``tokenizer_hf.encode_plus``.

    Returns:
        A single list holding the entries of the generation result's
        ``'encoder_hidden_states'`` followed by those of its
        ``'decoder_hidden_states'``, each detached, moved to the CPU and
        converted with ``.numpy()``.

    NOTE(review): the generation result is indexed by string key and every
    entry is treated as a tensor; presumably that layout comes from the
    greedy-search override installed on the model elsewhere in this
    notebook — confirm against ``greedy_search_interpret``.
    """
    # Tokenize into PyTorch tensors (attention mask kept, token-type ids dropped).
    encoding = tokenizer_hf.encode_plus(
        src,
        padding="longest",
        return_tensors="pt",
        return_token_type_ids=False,
        return_attention_mask=True,
    )

    # Place every input tensor on the same device as the model.
    device = model_hf.device
    model_inputs = {name: tensor.to(device) for name, tensor in encoding.items()}

    # Deterministic greedy decoding (no sampling, a single beam), asking the
    # model to also return hidden states and attentions.
    generated = model_hf.generate(
        output_hidden_states=True,
        output_attentions=True,
        do_sample=False,
        num_beams=1,
        **model_inputs,
    )

    def _as_numpy(states):
        # Convert each state to a NumPy array on the host.
        return [state.detach().cpu().numpy() for state in states]

    encoder_states = _as_numpy(generated['encoder_hidden_states'])
    decoder_states = _as_numpy(generated['decoder_hidden_states'])
    return encoder_states + decoder_states

In [7]:
# How many sentences from the start of each domain's validation file are
# encoded; the value is also embedded in the output pickle's file name.
NUM_SENTENCES=300

In [8]:
%%time

# Domains to sample from; each name selects
# experiments/fairseq-data-en-et-<name>-ft/valid.en.
domain_names = ["Europarl", "OpenSubtitles", "JRC-Acquis", "EMEA"]

# For every model variant: load it, encode the first NUM_SENTENCES validation
# sentences of each domain, and pickle the collected hidden states.
for main_name in ['concat1']:
    print()
    print(f"Loading {main_name} model")

    hf_dir = f"experiments/en_et_{main_name}/hf"
    # encode_sentence_tokens() looks these two names up as notebook globals,
    # so they have to be (re)bound here before it is called.
    tokenizer_hf = FSMTTokenizer.from_pretrained(hf_dir)
    model_hf = FSMTForConditionalGeneration.from_pretrained(hf_dir)
    model_hf = model_hf.cuda()
    # Override greedy search on this instance with the project's
    # greedy_search_interpret implementation.
    model_hf.greedy_search = MethodType(greedy_search_interpret, model_hf)

    valid_files = {}

    for domain_name in domain_names:
        fn = f"experiments/fairseq-data-en-et-{domain_name}-ft/valid.en"
        # Explicit encoding so decoding does not depend on the platform locale.
        with open(fn, encoding="utf-8") as f:
            # Strip only the line terminator / trailing whitespace. The previous
            # `l[:-2]` removed the newline *and* the last real character of each
            # sentence (text mode already folds "\r\n" into "\n"), and removed
            # two real characters from a final line without a newline.
            # NOTE(review): this changes the encoder inputs relative to earlier
            # runs; regenerate any pickles that must stay comparable.
            valid_files[domain_name] = [line.rstrip() for line in f]

    data_encoded = {}

    for domain, data in valid_files.items():
        print(f"Encoding {domain} data...")
        data_encoded[domain] = [encode_sentence_tokens(s) for s in data[:NUM_SENTENCES]]

    savedir = f"experiments/en_et_{main_name}/internals"

    # Create the output directory (and any missing parents); a no-op if it
    # already exists.
    os.makedirs(savedir, exist_ok=True)

    print(f"Saving to {savedir}/data_encoded{NUM_SENTENCES}.pkl...")
    with open(f'{savedir}/data_encoded{NUM_SENTENCES}.pkl', 'wb') as f:
        pickle.dump(data_encoded, f)


Loading concat1 model
Encoding Europarl data...
Encoding OpenSubtitles data...
Encoding JRC-Acquis data...
Encoding EMEA data...
Saving to experiments/en_et_concat1/internals/data_encoded300.pkl...
CPU times: user 3min 26s, sys: 1.79 s, total: 3min 27s
Wall time: 3min 26s
