In [2]:
# ! python -m spacy download en_core_web_trf

In [3]:
! rm -rf /tmp/dask-worker-space
! rm -rf ./dask-worker-space

# Dataset preparation

In [4]:
import os
import re
import copy
import spacy
import string
import pickle
import random
import logging
import unicodedata

import pandas as pd

from tqdm.auto import tqdm
from cai_common.data import ParallelTMXLoader
from dask.distributed import Client, LocalCluster

In [5]:
dask_logger = logging.getLogger("distributed.utils_perf")
dask_logger.setLevel(logging.ERROR)

dask_client = Client(LocalCluster(
    n_workers=20,
    threads_per_worker=1
))

In [6]:
parallel_df = ParallelTMXLoader().apply_markup().clean_bad_chars().dataframe.compute()
parallel_df

Unnamed: 0,filename,tohoku,folio,position,tibetan,english
0,Toh_384-Glorious_King_of_Tantras_That_Resolves...,384,F.187.a,1,དཔལ་རྡོ་རྗེ་སེམས་དཔའ་ལ་ཕྱག་འཚལ་ལོ། །,I pay homage to Glorious Vajrasattva!
1,Toh_384-Glorious_King_of_Tantras_That_Resolves...,384,F.187.a,39,འདི་སྐད་བདག་གིས་ཐོས་པའི་དུས་གཅིག་ན། །,Thus have I heard at one time.
2,Toh_384-Glorious_King_of_Tantras_That_Resolves...,384,F.187.a,204,དེ་ནས་བྱང་ཆུབ་སེམས་དཔའ་རྡོ་རྗེ་སྙིང་པོ་ལ་སོགས་...,"Then, the entourage, including bodhisattva Vaj..."
3,Toh_384-Glorious_King_of_Tantras_That_Resolves...,384,F.187.a,322,ཕྱི་ནང་གསང་བའི་མཆོད་པས་མཆོད་ནས་འདི་སྐད་ཅེས་གསོ...,"made outer, inner, and secret offerings, and a..."
4,Toh_384-Glorious_King_of_Tantras_That_Resolves...,384,F.187.a,388,ཀྱེ་ཧོ་བཅོམ་ལྡན་རྡོ་རྗེ་འཛིན། །,O Blessed Vajra Holder!
...,...,...,...,...,...,...
19,Toh_309-The_Sutra_on_Impermanence-v1.tmx,309,F.155.b,650,ནད་མེད་མི་རྟག་ལང་ཚོ་རྟག་མ་ཡིན། །འབྱོར་པ་མི་རྟག...,"“Good health is impermanent, Youth does not la..."
20,Toh_309-The_Sutra_on_Impermanence-v1.tmx,309,F.155.b,757,སྐྱེ་བོ་མི་རྟག་ཉིད་ཀྱིས་ཉེན་གྱུར་ན། །འདོད་པའི་...,"How can beings, afflicted as they are by imper..."
21,Toh_309-The_Sutra_on_Impermanence-v1.tmx,309,F.155.b,858,བཅོམ་ལྡན་འདས་ཀྱིས་དེ་སྐད་ཅེས་བཀའ་སྩལ་ནས། དགེ་ས...,"When the Bhagavān had thus spoken, the monks r..."
22,Toh_309-The_Sutra_on_Impermanence-v1.tmx,309,F.155.b,935,མི་རྟག་པ་ཉིད་ཀྱི་མདོ་རྫོགས་སོ།། །།,This completes “The Sūtra on Impermanence.”


In [7]:
random.seed(42)

In [8]:
spacy.prefer_gpu()
nlp = spacy.load("en_core_web_trf")

## Test

In [None]:
i = 0

In [None]:
example = parallel_df.english.iloc[i]
parsed = nlp(example)

print(example)
print()
for token in parsed:
    print(f"{token.text:<20} {token.lemma_:<20} {token.pos_}")

i += 1

## Run

In [9]:
eligible_pos = {'VERB', 'NOUN', 'ADJ', 'PROPN'}
num_positives = 3500
negative_ratio = 3
validation_frac = 0.1

In [10]:
positive_records = []

records = random.choices(parallel_df.to_dict(orient="records"), k=num_positives)
final_allowed = set(string.ascii_lowercase)

for record in tqdm(records):
    parsed = nlp(record['english'])
    positive_record = []
    for token in filter(lambda token: token.pos_ in eligible_pos, parsed):
        lemma = "".join([c for c in unicodedata.normalize('NFKD', token.lemma_) if not unicodedata.combining(c)]).lower()
        lemma = ''.join([c for c in lemma if c in final_allowed])
        if lemma == '':
            continue
        lemma = re.sub(r'\s+', ' ', lemma).strip()
        positive_record.append({
            'source': record['tibetan'],
            'target': lemma,
            'pos': token.pos_,
            'label': 1
        })
    positive_records.append(positive_record)

positive_examples = [ex for rec in positive_records for ex in rec]

  0%|          | 0/3500 [00:00<?, ?it/s]

In [11]:
from pattern.en import NOUN, VERB, ADJECTIVE
from pattern.en.wordnet import synsets

pos_map = {
    'NOUN': NOUN,
    'PROPN': NOUN,
    'VERB': VERB,
    'ADJ': ADJECTIVE
}

In [12]:
by_tag = {tag: [] for tag in eligible_pos}

for example in tqdm(positive_examples):
    synonyms = synsets(example['target'], pos_map[example['pos']])
    if len(synonyms) == 0:
        by_tag[example['pos']].append([example['target'].lower()])
    else:
        all_synonyms = sorted(list(set([synonym.replace('_', ' ') for synonym in synonyms[0].synonyms])))
        by_tag[example['pos']].append(all_synonyms)

  0%|          | 0/24386 [00:00<?, ?it/s]

In [13]:
negative_records = []

for pos_record in tqdm(positive_records):
    neg_record = []
    for example in pos_record:
        blacked_out_targets = {example['target']}
        for _ in range(negative_ratio):
            example = copy.deepcopy(example)
            while True:
                target = random.choice(
                    random.choice(
                        by_tag[example['pos']]
                    )    # First choose a meaning
                )        #     ...then choose a synonym
                if not target in blacked_out_targets:
                    break
            example['target'] = target
            example['label'] = 0
            blacked_out_targets.add(target)
            neg_record.append(example)
    negative_records.append(neg_record)

  0%|          | 0/3500 [00:00<?, ?it/s]

In [14]:
pos_train_size = int(len(positive_records) * (1 - validation_frac))

pos_train_dataset = [ex for i in range(pos_train_size) for ex in positive_records[i]]
pos_val_dataset = [ex for i in range(pos_train_size, len(positive_records)) for ex in positive_records[i]]
f"Positive part: {len(pos_train_dataset)}+{len(pos_val_dataset)} = {len(pos_train_dataset) + len(pos_val_dataset)} = {len(positive_examples)}"

'Positive part: 21971+2415 = 24386 = 24386'

In [15]:
neg_len = sum([1 for rec in negative_records for ex in rec])
neg_train_size = int(len(negative_records) * (1 - validation_frac))

neg_train_dataset = [ex for i in range(neg_train_size) for ex in negative_records[i]]
neg_val_dataset = [ex for i in range(neg_train_size, len(negative_records)) for ex in negative_records[i]]
f"Negative part: {len(neg_train_dataset)}+{len(neg_val_dataset)} = {len(neg_train_dataset) + len(neg_val_dataset)} = {neg_len}"

'Negative part: 65913+7245 = 73158 = 73158'

In [16]:
dataset_fn = os.path.join(os.environ['CAI_TEMP_PATH'], "temp_data/aligner_dataset")
os.makedirs(dataset_fn, exist_ok=True)
dataset_fn

'/home/eeisenst/workspace/temp/temp_data/aligner_dataset'

In [17]:
train_dataset = pos_train_dataset + neg_train_dataset
val_dataset = pos_val_dataset + neg_val_dataset

pd.DataFrame(train_dataset).to_csv(os.path.join(dataset_fn, "train.csv"))
pd.DataFrame(val_dataset).to_csv(os.path.join(dataset_fn, "validation.csv"))

# Post training analysis

## Load the checkpoint

In [1]:
from cai_common.models.utils import get_local_ckpt, get_cai_config
from cai_garland.models.cai_encoder_decoder_seq_class import CAIEncoderDecoderForSequenceClassification
from cai_garland.models.factory import make_bilingual_tokenizer

In [2]:
local_ckpt = get_local_ckpt("experiments/aligner/olive-cormorant-bart", model_dir=True)
aligner = CAIEncoderDecoderForSequenceClassification.from_pretrained(local_ckpt)

In [3]:
_ = aligner.cuda()

In [4]:
cai_base_config = get_cai_config(local_ckpt)
encoder_name = cai_base_config['encoder_model_name']
encoder_length = cai_base_config['encoder_max_length']
decoder_name = cai_base_config['decoder_model_name']
decoder_length = cai_base_config['decoder_max_length']

tokenizer = make_bilingual_tokenizer(encoder_name, decoder_name)

In [5]:
# source, target, word, pos = '།སྐྱེ་བོ་ཐམས་ཅད་དགའ་བར་འགྱུར་རོ།', 'one will be loved by everyone.', 'love', 'VERB'
# source, target, word, pos = '།སྐྱེ་བོ་ཐམས་ཅད་དགའ་བར་འགྱུར་རོ།', 'one will be loved by everyone.', 'everyone', 'NOUN'

# source, target, word, pos = "ནད་ཐམས་ཅད་སོས་པར་གྱུར་ཏེ།", "he will cure all diseases", 'cure', 'VERB'
# source, target, word, pos = "ནད་ཐམས་ཅད་སོས་པར་གྱུར་ཏེ།", "he will cure all diseases", 'disease', 'NOUN'

source, target, word, pos = '།འཁྲུལ་འཁོར་ཐམས་ཅད་འགེམས་པར་བྱེད་དོ།', 'one will burn all the diagrams.', 'burn', 'VERB'
# source, target, word, pos = '།འཁྲུལ་འཁོར་ཐམས་ཅད་འགེམས་པར་བྱེད་དོ།', 'one will burn all the diagrams.', 'diagram', 'NOUN'

In [6]:
inputs = tokenizer(source, return_tensors="pt").to(aligner.device)
with tokenizer.as_target_tokenizer():
    inputs_tgt = tokenizer(f"{pos}{tokenizer.target_tokenizer.mask_token}{word}", return_tensors="pt").to(aligner.device)
inputs["decoder_input_ids"] = inputs_tgt["input_ids"]
inputs["decoder_attention_mask"] = inputs_tgt["attention_mask"]



In [7]:
int(aligner(**inputs).logits.argmax())

1

## Captum

In [8]:
import torch

from captum.attr import TokenReferenceBase, visualization, LayerIntegratedGradients, LayerDeepLift

In [9]:
def make_viz_record(attribs, inputs, pred, pred_ind, label, delta, tokenizer):
    attribs = attribs / torch.norm(attribs)
    attribs = attribs.cpu().detach().numpy()
    
    with tokenizer.as_target_tokenizer():
        label = tokenizer.decode(label)
    
    return visualization.VisualizationDataRecord(
        attribs,
        0.5,
        pred_ind,
        label,
        label,
        attribs.sum(),
        tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist()),
        delta
    )

def lig_attribs(source, word, pos, aligner, tokenizer):
    aligner.zero_grad()

    inputs = tokenizer(source, return_tensors="pt").to(aligner.device)
    with tokenizer.as_target_tokenizer():
        inputs_tgt = tokenizer(f"{pos}{tokenizer.target_tokenizer.mask_token}{word}", return_tensors="pt").to(aligner.device)
    inputs["decoder_input_ids"] = inputs_tgt["input_ids"]
    inputs["decoder_attention_mask"] = inputs_tgt["attention_mask"]
    del inputs_tgt
    
    token_reference = TokenReferenceBase(reference_token_idx=tokenizer.source_tokenizer.pad_token_id)
    seq_length = inputs['input_ids'].shape[1]
    input_indices = inputs['input_ids']
    reference_indices = token_reference.generate_reference(seq_length, device=aligner.device).unsqueeze(0)
    reference_indices[0][0], reference_indices[0][-1] = tokenizer.bos_token_id, tokenizer.eos_token_id

    def forward_func(input_ids):
        new_inputs = {}
        new_inputs['input_ids'] = input_ids
        new_inputs['attention_mask'] = torch.ones_like(input_ids)
        new_inputs['decoder_input_ids'] = inputs['decoder_input_ids'].repeat(input_ids.shape[0], 1)
        new_inputs['decoder_attention_mask'] = inputs['decoder_attention_mask'].repeat(input_ids.shape[0], 1)
        logits = aligner(**new_inputs).logits[:,1]#.unsqueeze(1)
        # print(new_inputs)
        # print(logits)
        # print()
        return logits

    lig = LayerIntegratedGradients(forward_func, aligner.encoder.embeddings.word_embeddings)
    igs, delta = lig.attribute(input_indices, reference_indices, n_steps=100, return_convergence_delta=True)
    attribs = igs.sum(dim=2).squeeze(0)

    certainty_score = attribs.detach().cpu().numpy()
    certainty_score.sort()
    certainty_score = abs(certainty_score[-1] / certainty_score[:-1].mean())
    
    return inputs, attribs, certainty_score, float(abs(delta[0]))

In [10]:
inputs, attribs, certainty_score, delta = lig_attribs(source, word, pos, aligner, tokenizer)



In [11]:
attribs, certainty_score, delta

(tensor([ 0.0000, -0.0943,  0.0506,  0.2051,  0.4350,  0.8224,  0.1388,  0.0669,
          0.2731,  0.1122,  0.0000], device='cuda:0', dtype=torch.float64),
 6.925712566676644,
 0.06474727249194512)

In [12]:
viz_records = [make_viz_record(attribs, inputs, 1, 1, 1, delta, tokenizer)]
_ = visualization.visualize_text(viz_records)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
,1 (0.50),,1.98,[CLS] ▁ ། འཁྲུལ་འཁོར་ ཐམས་ཅད་ འགེམས་ པར་བྱེད་ ད ོ ། [SEP]
,,,,


## Training bitext

In [13]:
import os

import pandas as pd

In [32]:
df = pd.read_csv(os.path.join(os.environ['CAI_DATA_BASE_PATH'], "experiments/aligner/validation.csv"))
df

Unnamed: 0.1,Unnamed: 0,source,target,pos,label
0,0,འཇིག་རྟེན་འཛིན་བྱང་ཆུབ་སེམས་དཔའ་སེམས་དཔའ་ཆེན་པ...,lokadhara,PROPN,1
1,1,འཇིག་རྟེན་འཛིན་བྱང་ཆུབ་སེམས་དཔའ་སེམས་དཔའ་ཆེན་པ...,bodhisattva,PROPN,1
2,2,འཇིག་རྟེན་འཛིན་བྱང་ཆུབ་སེམས་དཔའ་སེམས་དཔའ་ཆེན་པ...,great,ADJ,1
3,3,འཇིག་རྟེན་འཛིན་བྱང་ཆུབ་སེམས་དཔའ་སེམས་དཔའ་ཆེན་པ...,being,NOUN,1
4,4,འཇིག་རྟེན་འཛིན་བྱང་ཆུབ་སེམས་དཔའ་སེམས་དཔའ་ཆེན་པ...,contemplate,VERB,1
...,...,...,...,...,...
9655,9655,གལ་ཏེ་ཡུལ་འདིར་གྲུ་གུ་ལ་སོགས་པ་གཏུམ་པོ་དད་པ་མ་...,unity,PROPN,0
9656,9656,གལ་ཏེ་ཡུལ་འདིར་གྲུ་གུ་ལ་སོགས་པ་གཏུམ་པོ་དད་པ་མ་...,sugata,PROPN,0
9657,9657,གལ་ཏེ་ཡུལ་འདིར་གྲུ་གུ་ལ་སོགས་པ་གཏུམ་པོ་དད་པ་མ་...,be,VERB,0
9658,9658,གལ་ཏེ་ཡུལ་འདིར་གྲུ་གུ་ལ་སོགས་པ་གཏུམ་པོ་དད་པ་མ་...,respond,VERB,0


In [249]:
row = df[df.label==1].sample(n=1).iloc[0]
print(row.source, row.target, row.pos)

མི་དམན་པ་བདག་དེ་བཞིན་གཤེགས་པ་བདུད་རྩིའི་ཐིགས་པའི་རྒྱལ་པོའི་དྲུང་དུ་སོང་སྟེ། དོན་འདི་ཡོངས་སུ་ཞུ་བར་བྱའོ། ། ask VERB


In [250]:
inputs, attribs, certainty_score, delta = lig_attribs(row.source, row.target, row.pos, aligner, tokenizer)

print(int(aligner(**inputs).logits.argmax()))

print(tokenizer.decode(inputs['input_ids'][0][attribs.argmax()]))
# source_word = inputs['input_ids'][]

print(attribs, certainty_score, delta)
viz_records = [make_viz_record(attribs, inputs, 1, 1, 1, delta, tokenizer)]
_ = visualization.visualize_text(viz_records)

1
ཞུ་
tensor([ 0.0000,  0.3236, -0.1151, -0.0703,  0.1324,  0.0529,  0.1171, -0.0954,
         0.0089,  0.0051, -0.0166, -0.0530,  0.0338, -0.0448, -0.2757, -0.1540,
        -0.0506, -0.0341,  0.2437,  0.0561,  0.3646,  0.1372,  1.4243, -0.0934,
        -0.0045, -0.1467, -0.2210, -0.1689, -0.2439, -0.2170,  0.0000],
       device='cuda:0', dtype=torch.float64) 80.68306842515587 1.5229128805359429


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
,1 (0.50),,0.54,[CLS] ▁མི་ དམན་པ་ བདག་ དེ་ བཞིན་ གཤེགས་པ་ བདུད་ རྩིའི་ ཐིགས་ པའི་ རྒྱལ་ པོའི་ དྲུང་དུ་ སོང་ སྟ ེ ། ▁དོན་ འདི་ ཡོངས་ སུ་ ཞུ་ བར་ བྱ འ ོ ། ▁ ། [SEP]
,,,,
