In [1]:
import torch
import transformers

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# To force CPU execution regardless, replace this with: device = 'cpu'
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [2]:
from transformers import MarianMTModel, MarianTokenizer
# Pretrained checkpoint identifier shared by the tokenizer (here) and the model (next cell).
en_ROMANCE_model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
en_ROMANCE_tokenizer = MarianTokenizer.from_pretrained(en_ROMANCE_model_name)
# Display the ">>xx<<" language codes the tokenizer reports as supported;
# translate() inputs in this notebook start with one of them (e.g. ">>es<<").
', '.join(en_ROMANCE_tokenizer.supported_language_codes)

'>>fr<<, >>es<<, >>it<<, >>pt<<, >>pt_br<<, >>ro<<, >>ca<<, >>gl<<, >>pt_BR<<, >>la<<, >>wa<<, >>fur<<, >>oc<<, >>fr_CA<<, >>sc<<, >>es_ES<<, >>es_MX<<, >>es_AR<<, >>es_PR<<, >>es_UY<<, >>es_CL<<, >>es_CO<<, >>es_CR<<, >>es_GT<<, >>es_HN<<, >>es_NI<<, >>es_PA<<, >>es_PE<<, >>es_VE<<, >>es_DO<<, >>es_EC<<, >>es_SV<<, >>an<<, >>pt_PT<<, >>frp<<, >>lad<<, >>vec<<, >>fr_FR<<, >>co<<, >>it_IT<<, >>lld<<, >>lij<<, >>lmo<<, >>nap<<, >>rm<<, >>scn<<, >>mwl<<'

In [3]:
# Load the pretrained weights for the same checkpoint and move the model to `device`.
en_ROMANCE = MarianMTModel.from_pretrained(en_ROMANCE_model_name).to(device)

In [4]:
# Checkpoint name and tokenizer for what is, judging by the checkpoint name,
# the reverse direction (ROMANCE -> en). Only the tokenizer is loaded here:
# the matching model load is commented out in the next cell.
ROMANCE_en_model_name = 'Helsinki-NLP/opus-mt-ROMANCE-en'
ROMANCE_en_tokenizer = MarianTokenizer.from_pretrained(ROMANCE_en_model_name)

In [5]:
#ROMANCE_en = MarianMTModel.from_pretrained(ROMANCE_en_model_name).to(device)

# Batch translation

In [6]:
def monkey_patch(model, new_postproc_fn, func_name="postprocess_next_token_scores"):
    """Replace a method on ``model``'s class while keeping the original reachable.

    The patch is applied to ``model.__class__`` rather than to the instance, so
    every object of that class sees the replacement. The first time a given
    method is patched, the existing implementation is stored on the class under
    ``"_orig_" + func_name``; later calls leave that backup untouched, so
    re-running the cell never overwrites the true original with an earlier
    replacement.

    Args:
        model: Any instance whose class should be patched.
        new_postproc_fn: Function installed as the new method. It receives
            ``self`` as its first argument like any other method and can reach
            the saved implementation through ``self._orig_<func_name>``.
        func_name: Name of the method to replace. Defaults to
            ``"postprocess_next_token_scores"``, so existing two-argument
            calls behave exactly as before.

    Raises:
        AttributeError: If no backup exists yet and the class has no attribute
            named ``func_name``.
    """
    cls = model.__class__
    orig_name = "_orig_" + func_name
    # Back up only once: on a second call the current attribute is already a
    # replacement, and saving it would lose the genuine original.
    if not hasattr(cls, orig_name):
        setattr(cls, orig_name, getattr(cls, func_name))
    setattr(cls, func_name, new_postproc_fn)

In [7]:
def postprocess_next_token_scores(self, scores, input_ids, *a, **kw):
    """Tracing stand-in for the model's ``postprocess_next_token_scores`` hook.

    Prints the shapes of ``input_ids`` and ``scores``, then the tokens of every
    hypothesis generated so far, converted with the notebook-level
    ``en_ROMANCE_tokenizer``. When the hypotheses are exactly two tokens long,
    token id 1886 is forced as the next token through
    ``self._force_token_ids_generation``. Finally the end-of-sequence column of
    ``scores`` is printed and the call is handed to the implementation saved
    as ``_orig_postprocess_next_token_scores``, whose result is returned.

    Any extra positional or keyword arguments are forwarded unchanged.
    """
    print(input_ids.shape, scores.shape)
    num_hypotheses, _vocab = scores.shape
    generated_len = input_ids.shape[1]

    # Show what each hypothesis looks like at this decoding step.
    for row in range(num_hypotheses):
        hypothesis = input_ids[row]
        print(en_ROMANCE_tokenizer.convert_ids_to_tokens(hypothesis))

    # Hack the beam: once two tokens exist, every hypothesis must continue
    # with the same forced token.
    if generated_len == 2:
        # 1886 is "cor"; 3675 ("sal") was the alternative tried here.
        self._force_token_ids_generation(scores, token_ids=[1886])

    eos_column = scores[:, self.config.eos_token_id]
    print(eos_column)
    return self._orig_postprocess_next_token_scores(scores, input_ids, *a, **kw)

# Install the tracing hook. monkey_patch() sets it on en_ROMANCE's class, so
# any other model instance of that class picks it up as well.
monkey_patch(en_ROMANCE, postprocess_next_token_scores)

In [8]:
def translate(tokenizer, model, text, num_outputs, max_length=128):
    """Use beam search to get a reasonable translation of 'text'.

    Args:
        tokenizer: Tokenizer providing ``spm_source``, ``spm_target``,
            ``prepare_translation_batch`` and ``decode``.
        model: Model providing ``device`` and ``generate``.
        text: Source sentence. In this notebook it starts with a target
            language code such as ``>>es<<``.
        num_outputs: Number of candidate translations to return; also used as
            the beam width.
        max_length: ``max_length`` passed to ``model.generate``. Defaults to
            128, the value that used to be hard-coded.

    Returns:
        A list with one decoded string per sequence returned by
        ``model.generate``.

    Side effects:
        ``tokenizer.current_spm`` is switched to ``spm_source`` before
        tokenizing and is left set to ``spm_target`` afterwards.
    """
    # Tokenize the source text
    tokenizer.current_spm = tokenizer.spm_source ### HACK!
    batch = tokenizer.prepare_translation_batch([text]).to(model.device)

    # Run model: the beam width equals the number of requested outputs.
    translated = model.generate(
        **batch,
        num_beams=num_outputs,
        num_return_sequences=num_outputs,
        max_length=max_length,
    )

    # Untokenize the output text.
    tokenizer.current_spm = tokenizer.spm_target
    return [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=False) for t in translated]

# Request 10 candidates; the ">>es<<" prefix is one of the tokenizer's
# supported language codes displayed earlier in the notebook.
translate(en_ROMANCE_tokenizer, en_ROMANCE, ">>es<< I ran and I jumped.", 10)
    

torch.Size([10, 1]) torch.Size([10, 65001])
['<pad>']
['<pad>']
['<pad>']
['<pad>']
['<pad>']
['<pad>']
['<pad>']
['<pad>']
['<pad>']
['<pad>']
tensor([-7.2095, -7.2095, -7.2095, -7.2095, -7.2095, -7.2095, -7.2095, -7.2095,
        -7.2095, -7.2095], device='cuda:0')
torch.Size([10, 2]) torch.Size([10, 65001])
['<pad>', '▁Corr']
['<pad>', '▁Corri']
['<pad>', '▁Corre']
['<pad>', '▁corri']
['<pad>', '▁Cor']
['<pad>', '▁He']
['<pad>', '▁Yo']
['<pad>', '▁Me']
['<pad>', '▁Fu']
['<pad>', '▁Hu']
tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
       device='cuda:0')
torch.Size([10, 3]) torch.Size([10, 65001])
['<pad>', '▁Yo', '▁cor']
['<pad>', '▁Me', '▁cor']
['<pad>', '▁He', '▁cor']
['<pad>', '▁Corre', '▁cor']
['<pad>', '▁Corri', '▁cor']
['<pad>', '▁corri', '▁cor']
['<pad>', '▁Corr', '▁cor']
['<pad>', '▁Cor', '▁cor']
['<pad>', '▁Fu', '▁cor']
['<pad>', '▁Hu', '▁cor']
tensor([-11.8194, -13.3916, -10.4630, -12.4498, -12.1174, -12.2610, -12.9104,
        -12.5495, -12.8968, -

['Yo corrí y salté.',
 'Me corrí y salté.',
 'Yo corría y salté.',
 'Yo corrí y me salté.',
 'Yo corrí y he saltado.',
 'Yo corrí y salí.',
 'Yo corrí y Salté.',
 'Yo corrí y pulé.',
 'Yo corrí y salto.',
 'Yo corrí y salté a saltar.']

In [9]:
en_ROMANCE.config

BartConfig {
  "activation_dropout": 0.0,
  "activation_function": "swish",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "MarianMTModel"
  ],
  "attention_dropout": 0.0,
  "bad_words_ids": [
    [
      65000
    ]
  ],
  "bos_token_id": 0,
  "classif_dropout": 0.0,
  "d_model": 512,
  "decoder_attention_heads": 8,
  "decoder_ffn_dim": 2048,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 65000,
  "dropout": 0.1,
  "encoder_attention_heads": 8,
  "encoder_ffn_dim": 2048,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 0,
  "extra_pos_embeddings": 65001,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_length": 512,
  "max_position_embeddings": 512,
  "model_type": "bart",
  "normalize_before": false,
  "normalize_embedding": false,
  "

In [10]:
en_ROMANCE.adjust_logits_during_generation??

[0;31mSignature:[0m [0men_ROMANCE[0m[0;34m.[0m[0madjust_logits_during_generation[0m[0;34m([0m[0mlogits[0m[0;34m,[0m [0mcur_len[0m[0;34m,[0m [0mmax_length[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mSource:[0m   
    [0;32mdef[0m [0madjust_logits_during_generation[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mlogits[0m[0;34m,[0m [0mcur_len[0m[0;34m,[0m [0mmax_length[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m        [0mlogits[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mpad_token_id[0m[0;34m][0m [0;34m=[0m [0mfloat[0m[0;34m([0m[0;34m"-inf"[0m[0;34m)[0m[0;34m[0m
[0;34m[0m        [0;32mif[0m [0mcur_len[0m [0;34m==[0m [0mmax_length[0m [0;34m-[0m [0;36m1[0m [0;32mand[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0meos_token_id[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m
[0;34m[0m            [0ms

In [11]:
en_ROMANCE._force_token_ids_generation??

[0;31mSignature:[0m [0men_ROMANCE[0m[0;34m.[0m[0m_force_token_ids_generation[0m[0;34m([0m[0mscores[0m[0;34m,[0m [0mtoken_ids[0m[0;34m)[0m [0;34m->[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0;31mSource:[0m   
    [0;32mdef[0m [0m_force_token_ids_generation[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mscores[0m[0;34m,[0m [0mtoken_ids[0m[0;34m)[0m [0;34m->[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m
[0;34m[0m        [0;34m"""force one of token_ids to be generated by setting prob of all other tokens to 0"""[0m[0;34m[0m
[0;34m[0m        [0;32mif[0m [0misinstance[0m[0;34m([0m[0mtoken_ids[0m[0;34m,[0m [0mint[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m            [0mtoken_ids[0m [0;34m=[0m [0;34m[[0m[0mtoken_ids[0m[0;34m][0m[0;34m[0m
[0;34m[0m        [0mall_but_token_ids_mask[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtensor[0m[0;34m([0m[0;34m[0m
[0;34m[0m            [0;34m[[0m[0mx[0m [0;32mfor[0m [0mx

In [None]:
# "Yo salté." is output-side text, so select the tokenizer's `spm_target`
# before tokenizing, mirroring what translate() does before decoding.
en_ROMANCE_tokenizer.current_spm = en_ROMANCE_tokenizer.spm_target
tokens = en_ROMANCE_tokenizer.tokenize("Yo salté.")
# Pair each piece with its vocabulary id (presumably to look up ids such as
# the one forced in the tracing hook; confirm intent).
list(zip(en_ROMANCE_tokenizer.convert_tokens_to_ids(tokens), tokens))


In [None]:
en_ROMANCE.postprocess_next_token_scores??