In [1]:
import os

# Expose only GPU index 7 to this process. This runs before torch is imported
# further down, so the restriction is in place when CUDA is first initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

In [2]:
# Masked language model used for augmentation, together with its mask-token string.
# Alternatives kept for reference (as originally noted by the author):
#   model_name = "albert-xxlarge-v2"
#   model_name = "roberta_large"; mask_name = "<mask>"
#   NOTE(review): the Hub id is presumably "roberta-large" (hyphen) - verify before enabling.
model_name, mask_name = "bert-large-cased", "[MASK]"

In [18]:
# Language-pair identifier; it is interpolated into every input/output path
# used by the cells that follow.
lang_pair = "eng_ron"

In [24]:
from transformers import AutoModelForMaskedLM, AutoTokenizer
import random
import numpy as np
import torch
from tqdm import tqdm
import json

In [25]:
# Both input files live in the directory of the selected language pair.
_pair_dir = f"/mounts/work/akoksal/word_alignment_silver/{lang_pair}"
gold_path = f"{_pair_dir}/{lang_pair}.txt"
parallel_path = f"{_pair_dir}/parallel.txt"

In [26]:
def _read_first_field(path):
    """Return, for every line of ``path``, the stripped text before the first ``|||``.

    The file is decoded explicitly as UTF-8: without an ``encoding`` argument
    ``open()`` falls back to the locale's default, which makes the result
    machine-dependent for non-ASCII text.
    NOTE(review): assumes the corpus files are UTF-8 encoded - confirm.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [line.split("|||")[0].strip() for line in f.read().splitlines()]


# Only the first "|||"-separated field of each line is kept from both files.
gold_content = _read_first_field(gold_path)
parallel_content = _read_first_field(parallel_path)


In [27]:
# Fixed seed so the same subset of the parallel corpus is drawn on every run.
random.seed(42)
# Cap the sample size at the corpus size: random.sample raises ValueError when
# asked for more items than the population holds (e.g. a smaller language pair).
parallel_sample_content = random.sample(
    parallel_content, min(1000, len(parallel_content))
)

In [28]:
# Sanity check: sizes of the gold set, the sampled subset and the full parallel corpus.
tuple(len(corpus) for corpus in (gold_content, parallel_sample_content, parallel_content))

(199, 1000, 39101)

## Augmentation

In [29]:
# Use the GPU when one is visible; otherwise fall back to CPU so the notebook
# still runs (slowly) on machines without CUDA instead of failing in `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Shared Hugging Face cache directory for both the tokenizer and the model weights.
hf_cache_dir = "/mounts/work/akoksal/hf_cache"
aug_tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=hf_cache_dir)
aug_model = AutoModelForMaskedLM.from_pretrained(model_name, cache_dir=hf_cache_dir).to(device)
# Put the model in evaluation mode; the trailing ';' suppresses the cell's repr output.
aug_model.eval();

Some weights of the model checkpoint at bert-large-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [30]:
def augmented_sentences(target_data, top_k=10, max_replacements=5):
    """Mask each word of each sentence and collect masked-LM replacement candidates.

    Every whitespace-separated word is replaced in turn by the tokenizer's mask
    token, the masked sentence is run through ``aug_model`` and the probability
    distribution predicted at the mask position is inspected.

    Args:
        target_data: iterable of sentence strings.
        top_k: number of highest-probability vocabulary entries inspected per
            masked position (positive integer).
        max_replacements: maximum number of candidates kept per word.

    Returns:
        A list with one dict per word, in sentence/word order, holding:
        ``sentence_id``, ``word_id``, ``word``, ``masked_sentence``,
        ``one_token`` (True when the word encodes to exactly one token),
        ``prob`` (present only when ``one_token`` is True: probability of the
        original word at the mask, rounded to 3 decimals) and ``replacement``
        (list of ``{"prob", "word", "token"}`` dicts, most probable first).
    """
    final_sentences = []
    mask_token = aug_tokenizer.mask_token
    mask_token_id = aug_tokenizer.mask_token_id

    # Wrap the iterable itself (not the enumerate object) so tqdm can show a total.
    for sentence_id, sentence in enumerate(tqdm(target_data)):
        words = sentence.split()
        for word_id, original_word in enumerate(words):
            masked_words = words[:word_id] + [mask_token] + words[word_id + 1:]
            res = {
                "sentence_id": sentence_id,
                "word_id": word_id,
                "word": original_word,
                "masked_sentence": " ".join(masked_words),
            }

            inputs = aug_tokenizer(res["masked_sentence"], return_tensors="pt").to(device)
            # Inference only: skip building the autograd graph for every forward pass.
            with torch.no_grad():
                outputs = aug_model(**inputs)

            # Position of the first mask token in the encoded sentence.
            masked_id_loc = inputs["input_ids"][0].tolist().index(mask_token_id)
            # Slice out the mask row before the transfer so only one vocabulary-sized
            # vector is copied to the CPU instead of the full logits tensor.
            vals = outputs["logits"][0, masked_id_loc].detach().to("cpu")
            probs = torch.softmax(vals, dim=0)

            # The word is encoded with a leading space, as in the original code.
            # NOTE(review): presumably for tokenizers that distinguish word-initial
            # tokens (e.g. the RoBERTa alternative) - confirm.
            target_ids = aug_tokenizer.encode(" " + original_word, add_special_tokens=False)
            if len(target_ids) == 1:
                res["one_token"] = True
                res["prob"] = round(float(probs[target_ids[0]]), 3)
            else:
                # Multi-token words get no "prob" entry.
                res["one_token"] = False

            # Highest-probability vocabulary entries, most probable first.
            top_ids = probs.argsort()[-top_k:].flip(0)
            top_probs = probs[top_ids]

            possible_vals = []
            added_words = set()
            target_lower = original_word.lower()
            for cand_prob, cand_id in zip(top_probs.tolist(), top_ids.tolist()):
                if len(possible_vals) >= max_replacements:
                    break
                tword = aug_tokenizer.decode([cand_id]).strip()
                # Skip sub-word continuation pieces, which decode with a "##" prefix.
                if tword.startswith("##"):
                    continue
                lowered = tword.lower()
                # Skip case-insensitive duplicates of an already kept candidate.
                if lowered in added_words:
                    continue
                # Skip case variants of the original word; an exact match is kept.
                if tword != original_word and lowered == target_lower:
                    continue
                possible_vals.append(
                    {"prob": round(cand_prob, 3), "word": tword, "token": cand_id}
                )
                added_words.add(lowered)
            res["replacement"] = possible_vals

            final_sentences.append(res)
    return final_sentences


In [31]:
# Augment the gold sentences: augmented_sentences runs one masked-LM forward
# pass per word, so this cell is slow (see the timing in its output).
final_sentences_gold = augmented_sentences(gold_content)

199it [02:02,  1.63it/s]


In [32]:
# Persist the per-word augmentation records of the gold set as JSON.
with open(
    f"/mounts/work/akoksal/word_alignment_silver/{lang_pair}/gold_augmented.json",
    mode="w",
) as json_file:
    json.dump(final_sentences_gold, fp=json_file)

In [33]:
# Augment the sampled parallel sentences: one masked-LM forward pass per word,
# so this is the slowest cell of the notebook (see the timing in its output).
final_sentences_parallel = augmented_sentences(parallel_sample_content)

1000it [09:59,  1.67it/s]


In [34]:
# Persist the per-word augmentation records of the parallel sample as JSON.
with open(
    f"/mounts/work/akoksal/word_alignment_silver/{lang_pair}/parallel_sample_augmented.json",
    mode="w",
) as json_file:
    json.dump(final_sentences_parallel, fp=json_file)