In [1]:
import os
import json
import torch
import numpy as np
import torch.nn.functional as F
import itertools
from tqdm import tqdm
import unidecode
from datetime import datetime
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import DataLoader, TensorDataset, SequentialSampler

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
print('using device:', device)

using device: cuda


In [4]:
language = 'hi'

In [6]:
!mkdir -p /scratch/tabhishek
!scp -r ada:/share1/tushar.abhishek/project_copernicus/datasets/$language /scratch/tabhishek

hi-stage-I-input.jsonl                        100%  170MB  85.1MB/s   00:02    
hi-stage-I-output-meta-data.jsonl             100%  641KB 641.5KB/s   00:00    
hi-stage-II-output.jsonl                      100%   47MB  46.6MB/s   00:00    
hi-stage-I-output.jsonl                       100%  304MB 101.5MB/s   00:03    
hi-stage-II-internal-test.jsonl               100%  800KB 799.9KB/s   00:00    


In [7]:
root_dir='/scratch/tabhishek'

sentence_test_file = os.path.join(root_dir, language, '%s-stage-II-internal-test.jsonl' % language)
sentence_file = os.path.join(root_dir, language, '%s-stage-I-output.jsonl' % language)
# sentence_file = os.path.join(root_dir, 'datasets', 'sampled-first-sent-hi-stage-I-output.jsonl')

In [8]:
# train_data = []
# with open(sentence_file) as dfile:
#     for item in dfile.readlines():
#         train_data.append(json.loads(item.strip()))

test_data = []
with open(sentence_test_file) as dfile:
    for item in dfile.readlines():
        test_data.append(json.loads(item.strip()))

In [9]:
# print('total number of train samples', len(train_data))
print('total number of test samples', len(test_data))

total number of test samples 424


In [10]:
import random
random_ex = random.choice(test_data)

print(random_ex)

{'entity_name': 'Jane Goodall', 'sentence': 'चिम्पांजी के लिये दुनिया की सबसे चर्चित विशेषज्ञ माने जानेवाली गुडाल , गोम्बे स्ट्रीम राष्ट्रीय उद्यान , तंजानिया में पाए जानेवाले जंगली चिम्पांजियों के सामाजिक और पारिवारिक बातचीत पर अपने 55 साल के अध्ययन के लिए प्रसिद्ध हैं ।', 'native_sentence_section': 'introduction', 'translated_sentence': "Known as the world's most famous specialist for chimpanzees, the Goodall, Gombe Stream National Park, is famous for its 55-year study of the social and family interactions of wild chimpanzees found in Tanzania.", 'sent_index': 1, 'facts': [['founded by', 'Jane Goodall Institute', [], True], ['award received', 'Princess of Asturias Award for Technical and Scientific Research', [['point in time', '2003']], False], ['award received', 'Tyler Prize for Environmental Achievement', [['point in time', '1997']], False], ['award received', 'William Procter Prize for Scientific Achievement', [['point in time', '1996']], False], ['award received', 'Officer of th

In [11]:
def store_file(res, file_name):
    with open(file_name, 'w', encoding='utf-8') as dfile:
        for item in res:
            json.dump(item, dfile)
            dfile.write('\n')

In [11]:
project_dir = '/home/tushar.abhishek/ire/research/project_copernicus/datasets/iterative_self_training'
# store_file(train_set, os.path.join(project_dir, 'train.jsonl'))
# store_file(valid_set, os.path.join(project_dir, 'val.jsonl'))
# store_file(test_data, os.path.join(project_dir, 'test.jsonl'))

In [12]:
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM

# config="xlm-roberta-large" 
# config="google/muril-base-cased"
# config="ai4bharat/indic-bert"
# config="sentence-transformers/LaBSE"
config="google/mt5-small"
# config="bert-base-multilingual-uncased"
# config="facebook/mbart-large-cc25" # also need to change

tokenizer = AutoTokenizer.from_pretrained(config)
model = AutoModelForSeq2SeqLM.from_pretrained(config).to(device)

In [13]:
new_tokens = ['<H>', '<R>', '<T>']
new_tokens_vocab = {}
new_tokens_vocab['additional_special_tokens'] = []
for idx, t in enumerate(new_tokens):
    new_tokens_vocab['additional_special_tokens'].append(t)
num_added_toks = tokenizer.add_special_tokens(new_tokens_vocab)
print('We have added %s tokens' % num_added_toks)
model.resize_token_embeddings(len(tokenizer))

We have added 3 tokens


Embedding(250103, 512)

In [197]:
import torch

# loading model weights from checkpoint
def get_checkpoint_file(checkpoint_path):
    file_list = []
    for file_name in os.listdir(checkpoint_path):
        if not file_name.endswith('ckpt'):
            continue
        last_modified_time = os.path.getmtime(
            os.path.join(checkpoint_path, file_name))
        file_list.append([file_name, last_modified_time])

    print(
        'total number of files within checkpoint directory: %d' % len(file_list))
    assert len(file_list) > 0, "no checkpoint file"
    # if multiple files exists then choose the last modified checkpoint path
    sorted(file_list, key=lambda x: x[1], reverse=True)
    return os.path.join(checkpoint_path, file_list[0][0])

checkpoint_path="/scratch/tabhishek/ext_train/hindi-translated-webnlg/mt5-small-1/checkpoint"
checkpoint_file = get_checkpoint_file(checkpoint_path)
print('loading the checkpoint from file : %s' %
                checkpoint_file)
with open(checkpoint_file, 'rb') as tfile:
    checkpoint = torch.load(tfile)
    model_weights = {
        k[6:]: v for k, v in checkpoint['state_dict'].items() if k.startswith('model.')}
    print(model.load_state_dict(model_weights))

total number of files within checkpoint directory: 1
loading the checkpoint from file : /scratch/tabhishek/ext_train/hindi-translated-webnlg/mt5-small-1/checkpoint/checkpoint.ckpt
<All keys matched successfully>


In [14]:
model.eval()

MT5ForConditionalGeneration(
  (shared): Embedding(250103, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(250103, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=384, bias=False)
              (k): Linear(in_features=512, out_features=384, bias=False)
              (v): Linear(in_features=512, out_features=384, bias=False)
              (o): Linear(in_features=384, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 6)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedGeluDense(
              (wi_0): Linear(in_features=512, out_features=1024, bias=False)
              (wi_1): Linear(in_features=512, out_features=1024, bias=False)
              (w

In [18]:
def get_nodes(n):
    n = n.strip()
    n = n.replace('(', '')
    n = n.replace('\"', '')
    n = n.replace(')', '')
    n = n.replace(',', ' ')
    n = n.replace('_', ' ')

    #n = ' '.join(re.split('(\W)', n))
    n = unidecode.unidecode(n)
    #n = n.lower()

    return n


def get_relation(n):
    n = n.replace('(', '')
    n = n.replace(')', '')
    n = n.strip()
    n = n.split()
    n = "_".join(n)
    return n

def fact_str(fact, enable_qualifiers=False):
    fact_str = ['<R>', get_relation(fact[0]), '<T>', get_nodes(fact[1])]
    qualifier_str = [' '.join(x) for x in fact[2]]
    if enable_qualifiers:
        fact_str += ['<Q>', qualifier_str]
    return fact_str

In [19]:
def pooled_rep(model_output, attention_mask, reduce='cls'):
    if reduce=='cls':
        return model_output[:, 0, :]
    elif reduce == "mean":
        token_embeddings = model_output #First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask
    elif reduce == 'sum':
        token_embeddings = model_output #First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        return sum_embeddings
    else:
        raise Exception('reduce function not present !!!')

def pooled_rep_v2(model_output, attention_mask, reduce='cls', layers=[8], pool='cat'):
    final_res = []
    for i in layers:
        if reduce=='cls':
            final_res.append(model_output[i][:, 0, :])
        elif reduce == "mean":
            token_embeddings = model_output[i] #First element of model_output contains all token embeddings
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
            sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
            final_res.append(sum_embeddings / sum_mask)
        elif reduce == 'sum':
            token_embeddings = model_output[i] #First element of model_output contains all token embeddings
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
            final_res.append(sum_embeddings)
        else:
            raise Exception('reduce function not present !!!')
    if pool=='sum' and len(final_res)>1:
        return torch.stack(final_res, dim=0).sum(dim=0)
    elif pool=='mean' and len(final_res)>1:
        return torch.stack(final_res, dim=0).mean(dim=0)
    else:
        return torch.cat(final_res, dim=-1)
    final_res
        
def get_candidates(fact_len):
    candidate_list = []
    facts_idx = [x for x in range(fact_len)]
    for L in range(0, len(facts_idx)+1):
        for subset in itertools.combinations(facts_idx, L):
            candidate_list.append(list(subset))
    return candidate_list 

def get_alignment(entity_name, tsentence, facts, reduce='cls', batch_size=512, layers=[8]):
    res = []
    with torch.no_grad():
        enc = tokenizer.encode_plus(tsentence, padding='longest', return_attention_mask=True, return_tensors='pt')
        #taking the [CLS] token
        s_out = model.encoder(input_ids=enc["input_ids"].to(device), attention_mask=enc["attention_mask"].to(device), output_hidden_states=True).last_hidden_states
        all_word_enc = s_out[0][:torch.sum(enc["attention_mask"])].cpu().tolist()
        sentence_encoding = pooled_rep(s_out, enc["attention_mask"].to(device), reduce=reduce)
        
        candidates = get_candidates(len(facts))
        tcandidates = []
        for c in candidates:
            temp_candidate = ['<H>', entity_name]
            for i in c:
                temp_candidate += fact_str(facts[i])
            tcandidates.append(' '.join(temp_candidate))
#             print(' '.join(temp_candidate))
        
        cenc = tokenizer.batch_encode_plus(tcandidates, padding='longest', return_attention_mask=True, return_tensors='pt')
        
        
        dataset = TensorDataset(cenc['input_ids'], cenc['attention_mask'])
        dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=batch_size)

        temp = []
        for t, batch in enumerate(dataloader):
            temp_c_out_v1 = model.encoder(input_ids=batch[0].to(
                device), attention_mask=batch[1].to(device), output_hidden_states=True).hidden_states
            temp_c_out_v1 = pooled_rep_v2(temp_c_out_v1, batch[1].to(device), reduce=reduce, layers=layers)
            
            temp_c_out_v2 = model.decoder(input_ids=batch[0].to(
                device), attention_mask=batch[1].to(device), output_hidden_states=True).hidden_states
            temp_c_out_v2 = pooled_rep_v2(temp_c_out_v2, batch[1].to(device), reduce=reduce, layers=layers)
            
            temp_c_out = torch.cat((temp_c_out_v1, temp_c_out_v2), dim=-1)
#             temp_c_out = temp_c_out_v1
            temp.append(temp_c_out)
            
        candidate_encoding = torch.vstack(temp)
        
#         c_out = model.encoder(input_ids=cenc["input_ids"].to(device), attention_mask=cenc["attention_mask"].to(device)).last_hidden_state
#         candidate_encoding = pooled_rep(c_out, cenc["attention_mask"].to(device), reduce=reduce)
        
        
        scores = F.cosine_similarity(candidate_encoding, sentence_encoding, 1, 1e-6).cpu().tolist()
        score_map = { ':'.join(list(map(lambda x:str(x).strip(), candidates[i]))):v for i,v in enumerate(scores)}
        
        return score_map

In [90]:
from collections import defaultdict
from sklearn.cluster import KMeans
from numpy.linalg import norm
from sklearn.metrics.pairwise import cosine_similarity
from scipy.special import softmax
import math


def get_similarity_scores(query, documents, reduce='mean', batch_size=10):
    with torch.no_grad():
        enc = tokenizer.encode_plus(query, padding='longest', return_attention_mask=True, return_tensors='pt')
        
        sentence_encoding = pooled_rep(s_out, enc["attention_mask"].to(device), reduce=reduce)
        
        cenc = tokenizer.batch_encode_plus(documents, padding='longest', return_attention_mask=True, return_tensors='pt')
        dataset = TensorDataset(cenc['input_ids'], cenc['attention_mask'])
        dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=batch_size)
        
        temp = []
        for t, batch in enumerate(dataloader):
            temp_c_out = model.encoder(input_ids=batch[0].to(
                device), attention_mask=batch[1].to(device), output_hidden_states=True).last_hidden_state
            temp_c_out = pooled_rep(temp_c_out, batch[1].to(device), reduce=reduce)
            temp.append(temp_c_out)
            
        candidate_encoding = torch.vstack(temp)
        scores = F.cosine_similarity(candidate_encoding, sentence_encoding, 1, 1e-6).cpu().tolist()
    return scores, all_word_enc.cpu().tolist(), candidate_encoding.cpu().tolist()
    
def cal_div_score(sent_sim, fact_sim, c):
    if len(c)==0:
        return 0.0
    fact_sim = softmax(fact_sim[c, :][:, c], axis=0)
    res = []
    for fs in fact_sim:
        temp = np.prod((fs*np.array(sent_sim)[c]))
        res.append(temp)
    #final score_calculation
    return np.sum(res)

def diversity_score(sentence_similarity, fact_emb, reduce='mean'):
    fact_sim_matrix = cosine_similarity(fact_emb, fact_emb)
    score_map = {}
    candidates = get_candidates(len(fact_emb))
    for c in candidates:
        key = ':'.join([str(x).strip() for x in c])
        score_map[key] = cal_div_score(sentence_similarity, fact_sim_matrix, c)
    return score_map

def cal_coverage_score(sim_matrix, fact_sim, rows):
    if len(rows)==0:
        return 0.0
    fact_sim = softmax(fact_sim[rows, :][:, rows], axis=1)
    
    # fact_sim K (number facts) X N (row contains words embeddings in sentence)
    
    local_sim_matrix = sim_matrix[rows,:]
    fact_coverage = local_sim_matrix.mean(axis=0)
#     final_coverage = [max(fact_coverage[i] - (total_fact_coverage[i]), 0) for i in range(len(fact_coverage))]
    return norm(fact_coverage, 2)

def coverage_score(sentence_embs, fact_embs):
    sim_matrix = cosine_similarity(fact_embs, sentence_embs)
    fact_sim_matrix = cosine_similarity(fact_emb, fact_emb)
    score_map = {}
    candidates = get_candidates(len(fact_embs))
    for c in candidates:
        key = ':'.join([str(x).strip() for x in c])
        score_map[key] = cal_coverage_score(sim_matrix, fact_sim_matrix, c)
    return score_map

def get_alignment_new(tsentence, tfacts, reduce='mean', cluster=3):
    m_facts = [' '.join(fact_str(x)) for x in tfacts]
    fact_similarity = defaultdict(lambda: defaultdict(lambda: 0))
    sentence_similarity, sent_emb, fact_emb = get_similarity_scores(tsentence, m_facts)
    # clustering based on similarity
#     dv = diversity_score(sentence_similarity, fact_emb)
    cv = coverage_score(sent_emb, fact_emb)
    final_score = {}
    for k in cv:
        final_score[k] = cv[k]
    return final_score

In [138]:
def cal_prob(output_logits, tgt_enc, pad_token_id):
    assert output_logits.shape[0] == tgt_enc.shape[0] and output_logits.shape[1] == tgt_enc.shape[1]
    prob_score = []
        
    with torch.no_grad():
        batch_size = output_logits.shape[0]
        seq_len = output_logits.shape[1]
        for i in range(batch_size):
            temp_prob_score = 0
            for j in range(seq_len):
                if tgt_enc[i][j]==pad_token_id:
                    break
                temp_prob_score+=output_logits[i][j][tgt_enc[i][j]].item()
            prob_score.append(temp_prob_score/j)
    return prob_score

def get_prob_scores(entity_name, query, facts, batch_size=10):
    random.seed(42)
    candidates = get_candidates(len(facts))
    documents = []
    
    for c in candidates:
        temp_candidate = ['<H>', entity_name]
        random.shuffle(c)
        for i in c:
            temp_candidate += fact_str(facts[i])
        documents.append(' '.join(temp_candidate))
    
    with torch.no_grad():
        # duplicate queries 
        query = [query for _ in documents]
        enc_tgt = tokenizer.batch_encode_plus(query, padding='longest', return_attention_mask=True, return_tensors='pt')
        enc_src = tokenizer.batch_encode_plus(documents, padding='longest', return_attention_mask=True, return_tensors='pt')
        dataset = TensorDataset(enc_src['input_ids'], enc_src['attention_mask'], enc_tgt['input_ids'], enc_tgt['attention_mask'])
        dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=batch_size)
        
        temp = []
        for t, batch in enumerate(dataloader):
            output_logits = model(
                batch[2].to(device),
                attention_mask=batch[3].to(device),
                labels=batch[0].to(device),
                decoder_attention_mask=batch[1].to(device)
            ).logits
            temp.extend(cal_prob(output_logits, batch[0], tokenizer.pad_token_id))
    score_map = {}
    assert len(candidates)==len(temp), "mismatch between length of candidates and model scores"
    for clist, score in zip(candidates, temp):
        sorted_clist = sorted(clist)
        key = ":".join([str(tz) for tz in sorted_clist])
        score_map[key] = score
    return score_map

In [139]:
class AlignmentEvaluation:
    def __init__(self):
        self.true_positive = 0
        self.false_positive = 0
        self.false_negative = 0
        self.correct_count = 0
        self.precision_list = []
        self.recall_list = []
        self.total_count = 0
    
    def add(self, true_y: list, pred_y: list):
        '''
        Arguments
            true_y: list containing gold fact indexes
            pred_y: list containing predicted facts indexes
        '''
        set_a = set(true_y)
        set_b = set(pred_y)
        
        tp = len(set_a.intersection(set_b))
        fp = len(set_b.difference(set_a))
        fn = len(set_a.difference(set_b))

        # storing data for calculating the accuracy
        if fp==0 and fn==0:
            self.correct_count+=1

        # storing data for calculation of global precision & recall
        self.true_positive += tp
        self.false_positive += fp
        self.false_negative += fn
        self.total_count+=1

        # calulating the local precision recall
        precision = float(tp)/(len(set_b)+1e-9)
        recall = float(tp)/(len(set_a)+1e-9)            
        self.precision_list.append(precision)
        self.recall_list.append(recall)
    
    def addlist(self, true_y: list, pred_y: list):
        '''
        Arguments
            true_y: list of lists, each list contains the gold facts
            pred_y: list of lists, eacg list contains the predicted facts indexes
        '''
        assert len(true_y)==len(pred_y), "length mismatch betweent the prediction list and gold label list"

        for x, y in zip(true_y, pred_y):
            self.add(x, y)
    
    def get_scores(self):
        global_precision = float(self.true_positive) / (self.true_positive + self.false_positive + 1e-9)
        global_recall = float(self.true_positive) / (self.true_positive + self.false_negative + 1e-9)
        global_f1 = (2*global_recall*global_precision) / (global_precision+global_recall+1e-9)

        results = {
            'precision': global_precision,
            'recall': global_recall,
            'f1': global_f1,
            'avg_precision': sum(self.precision_list)/float(self.total_count+1e-9),
            'avg_recall': sum(self.recall_list)/float(self.total_count+1e-9),
            'accuracy': self.correct_count / float(self.total_count+1e-9),
            'total_count': self.total_count,
        }

        return results

In [140]:
test_data[0]

{'entity_name': 'Calum MacLeod',
 'sentence': 'उन्होंने वार्विकशायर और डरहम के लिए इंग्लैंड में काउंटी क्रिकेट खेला है ।',
 'native_sentence_section': 'introduction',
 'translated_sentence': 'He has played county cricket in England for Warwickshire and Durham.',
 'sent_index': 3,
 'facts': [['member of sports team',
   'Durham County Cricket Club',
   [['end time', '2016'], ['start time', '2014']],
   False],
  ['member of sports team',
   'Warwickshire County Cricket Club',
   [['end time', '2009'], ['start time', '2008']],
   False],
  ['country for sport', 'Scotland', [], False],
  ['member of sports team', 'Scotland national cricket team', [], False],
  ['occupation', 'cricketer', [], False],
  ['country of citizenship', 'United Kingdom', [], False],
  ['educated at', 'Hillpark Secondary School', [], False],
  ['date of birth', '15 November 1988', [], False],
  ['place of birth', 'Glasgow', [], False]],
 'translated_facts': [['खेल टीम का सदस्य',
   'डरहम काउंटी क्रिकेट क्लब',
   []

In [141]:
# evaluation over test dataset
alignment_evaluation = {
    'empty': AlignmentEvaluation(),
    'partial': AlignmentEvaluation(),
    'complete': AlignmentEvaluation(),
    'combined': AlignmentEvaluation()
}

pooling_type='mean'
cluster_count=3

res = test_data

for item in tqdm([tz for tz in res if tz['coverage']=='complete']):
    pred_alignment = get_prob_scores(item['entity_name'], item['sentence'], item['facts'])#get_alignment_new(item['sentence'], item['facts'], reduce=pooling_type, cluster=cluster_count)
    
    gold_alignment = sorted(item['fact_index'])
    
    sorted_pred_alignment = [(k, v) for k, v in sorted(pred_alignment.items(), key=lambda x: x[1], reverse=True)]
    
    pred = sorted_pred_alignment[0][0]
    actual = ':'.join(list(map(lambda x:str(x).strip(), gold_alignment)))
    
    pred_list = pred.split(':') if pred != '' else []
    actual_list = actual.split(':') if actual != '' else []
    
    alignment_evaluation[item['coverage']].add(actual_list, pred_list)
    alignment_evaluation['combined'].add(actual_list, pred_list)
    
print('evaluating the alignment on test dataset...')
for key in ['empty', 'partial', 'complete', 'combined']:
    eval_res = alignment_evaluation[key].get_scores()
    print('%s alignment [%d] accuracy: %0.3f, global_precision: %0.3f, global_recall: %0.3f, global_f1: %0.3f, avg_precision: %0.3f, avg_recall: %0.3f' % (key, eval_res['total_count'], eval_res['accuracy'],                                                                                                                                              eval_res['precision'], eval_res['recall'], eval_res['f1'], eval_res['avg_precision'], eval_res['avg_recall']))

100%|██████████| 71/71 [08:28<00:00,  7.16s/it]

evaluating the alignment on test dataset...
empty alignment [0] accuracy: 0.000, global_precision: 0.000, global_recall: 0.000, global_f1: 0.000, avg_precision: 0.000, avg_recall: 0.000
partial alignment [0] accuracy: 0.000, global_precision: 0.000, global_recall: 0.000, global_f1: 0.000, avg_precision: 0.000, avg_recall: 0.000
complete alignment [71] accuracy: 0.014, global_precision: 0.209, global_recall: 0.397, global_f1: 0.274, avg_precision: 0.227, avg_recall: 0.381
combined alignment [71] accuracy: 0.014, global_precision: 0.209, global_recall: 0.397, global_f1: 0.274, avg_precision: 0.227, avg_recall: 0.381





In [133]:
import random

facts_threshold = 5

while True:
    random_index = random.choice([i for i in range(len(test_data))])
    if len(test_data[random_index]['facts']) <= facts_threshold and len(test_data[random_index]['fact_index'])>0 and test_data[random_index]['coverage']=='partial':
        break

print(random_index, 'facts count',len(test_data[random_index]['facts']))
print('sentence:', test_data[random_index]['sentence'])
print('translated sentence:', test_data[random_index]['translated_sentence'])
print('--'*30)
for i, fact in enumerate(test_data[random_index]['facts']):
    print(i, fact)
print('--'*30)

pred_alignment = get_prob_scores(test_data[random_index]['entity_name'], test_data[random_index]['sentence'], test_data[random_index]['facts'])
# pred_alignment = get_sentence_specific_fact_alignment(test_data[random_index]['sentence'], test_data[random_index]['translated_sentence'], test_data[random_index]['facts'], test_data[random_index]['translated_facts'], reduce=pooling_type)
gold_alignment = sorted(test_data[random_index]['fact_index'])
gold_alignment_str = ':'.join([str(x) for x in gold_alignment]) 

# index=0
# for cluster, fact in zip(clabels, test_data[random_index]['facts']):
#     print(index, "C-%d"%cluster, fact)
#     index+=1

print('=='*30)

sorted_pred_alignment = [(k, v) for k, v in sorted(pred_alignment.items(), key=lambda x: x[1], reverse=True)]
for item in sorted_pred_alignment:
    string = " %s %s"%(item[0], item[1])
    if item[0]==gold_alignment_str:
        string = "*%s" % string.strip()
    print(string)

52 facts count 5
sentence: मिशेल जोसेफ स्वेपसन ( जन्म 4 अक्टूबर 1993 ) एक ऑस्ट्रेलियाई क्रिकेटर हैं ।
translated sentence: Mitchell Joseph Swepson (born 4 October 1993) is an Australian cricketer.
------------------------------------------------------------
0 ['date of birth', '04 October 1993', [], False]
1 ['occupation', 'cricketer', [], False]
2 ['country of citizenship', 'Australia', [], False]
3 ['member of sports team', 'Queensland cricket team', [['start time', '2015']], False]
4 ['member of sports team', 'Brisbane Heat', [['start time', '2015']], False]
------------------------------------------------------------
 1 -26.750060962191945
  -33.35058168707223
 0 -33.996230857125646
 2:4 -34.5526661379584
 0:2 -38.05343831818679
 1:2:3:4 -38.30819024710819
 1:3:4 -39.73881224928231
 3:4 -40.05091026733662
 0:1:2:3:4 -40.086906630417396
 2 -40.42180354019691
 0:1:2:3 -40.70852154698865
 0:2:3 -40.78264997745382
 0:2:3:4 -41.064815850093446
 0:2:4 -41.229867080162315
 2:3 -41.2867842

In [16]:
test_sample = test_data[1]
print(test_sample)

{'entity_name': 'Henri Bergson', 'sentence': 'उन्हें १९२७ का साहित्य का नोबेल पुरस्कार प्रदान किया गया था ।', 'native_sentence_section': 'introduction', 'translated_sentence': 'He was awarded the Nobel Prize for Literature in 1927.', 'sent_index': 1, 'facts': [['award received', 'Nobel Prize in Literature', [['point in time', '1927'], ['prize money', '126501.0 Swedish krona']], False], ['nominated for', 'Nobel Prize in Literature', [['point in time', '1921']], False], ['nominated for', 'Nobel Prize in Literature', [['point in time', '1918']], False], ['nominated for', 'Nobel Prize in Literature', [['point in time', '1914']], False], ['nominated for', 'Nobel Prize in Literature', [['point in time', '1913']], False], ['nominated for', 'Nobel Prize in Literature', [['point in time', '1912']], False], ['nominated for', 'Nobel Prize in Literature', [['point in time', '1915']], False], ['nominated for', 'Nobel Prize in Literature', [['point in time', '1928']], False], ['member of', 'Lincean 

In [23]:
import random

random.seed(42)
candidates = get_candidates(len(test_sample['facts']))
tcandidates = []
entity_name = test_sample['entity_name']
facts = test_sample['facts']

for c in candidates:
    temp_candidate = ['<H>', entity_name]
    random.shuffle(c)
    for i in c:
        temp_candidate += fact_str(facts[i])
    tcandidates.append(' '.join(temp_candidate))

In [75]:
src_text = tcandidates
tgt_text = test_sample['sentence']

In [68]:
print(len(src_text), len(tgt_text))

1024 1024


In [69]:
src_enc = tokenizer.batch_encode_plus(src_text, padding='longest', return_tensors='pt', return_attention_mask=True)
tgt_enc = tokenizer.batch_encode_plus(tgt_text, padding='longest', return_tensors='pt', return_attention_mask=True)

In [86]:
def cal_prob(output_logits, tgt_enc, pad_token_id):
    assert output_logits.shape[0] == tgt_enc.shape[0] and output_logits.shape[1] == tgt_enc.shape[1]
    prob_score = []
        
    with torch.no_grad():
        batch_size = output_logits.shape[0]
        seq_len = output_logits.shape[1]
        for i in range(batch_size):
            temp_prob_score = 0
            for j in range(seq_len):
                if tgt_enc[i][j]==pad_token_id:
                    break
                temp_prob_score+=output_logits[i][j][tgt_enc[i][j]].item()
            prob_score.append(temp_prob_score/j)
    return prob_score

def get_prob_scores(entity_name, query, facts, batch_size=10):
    random.seed(42)
    candidates = get_candidates(len(facts))
    documents = []
    
    for c in candidates:
        temp_candidate = ['<H>', entity_name]
        random.shuffle(c)
        for i in c:
            temp_candidate += fact_str(facts[i])
        documents.append(' '.join(temp_candidate))
    
    with torch.no_grad():
        # duplicate queries 
        query = [query for _ in documents]
        enc_tgt = tokenizer.batch_encode_plus(query, padding='longest', return_attention_mask=True, return_tensors='pt')
        enc_src = tokenizer.batch_encode_plus(documents, padding='longest', return_attention_mask=True, return_tensors='pt')
        dataset = TensorDataset(enc_src['input_ids'], enc_src['attention_mask'], enc_tgt['input_ids'], enc_tgt['attention_mask'])
        dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=batch_size)
        
        temp = []
        for t, batch in enumerate(dataloader):
            output_logits = model(
                batch[0].to(device),
                attention_mask=batch[1].to(device),
                labels=batch[2].to(device),
                decoder_attention_mask=batch[3].to(device)
            ).logits
            temp.extend(cal_prob(output_logits, batch[2], tokenizer.pad_token_id))
    score_map = {}
    assert len(clist)==len(temp), "mismatch between length of candidates and model scores"
    for clist, score in zip(candidates, temp):
        sorted_clist = sorted(clist)
        key = ":".join(clist)
        score[key] = score
    return score_map

In [84]:
get_prob_scores(tgt_text, src_text)

[-31.478291471799213,
 -23.664619396130245,
 -24.546770294507343,
 -24.546770294507343,
 -24.546770294507343,
 -24.546770294507343,
 -24.546770294507343,
 -24.546770294507343,
 -24.546770294507343,
 -28.248214840888977,
 -28.58089150985082,
 -25.548620223999023,
 -25.548620223999023,
 -27.59205714861552,
 -25.548620223999023,
 -25.548620223999023,
 -25.548620223999023,
 -25.548620223999023,
 -25.545311331748962,
 -33.150628328323364,
 -21.502303620179493,
 -21.502303620179493,
 -21.502303620179493,
 -21.502303620179493,
 -21.502303620179493,
 -21.502303620179493,
 -23.403819878896076,
 -24.746236562728882,
 -21.502303620179493,
 -21.502303620179493,
 -21.502303620179493,
 -21.502303620179493,
 -21.502303620179493,
 -26.285025278727215,
 -24.746236562728882,
 -21.502303620179493,
 -21.502303620179493,
 -21.502303620179493,
 -21.502303620179493,
 -23.403819878896076,
 -23.246151785055797,
 -21.502303620179493,
 -21.502303620179493,
 -21.502303620179493,
 -26.285025278727215,
 -24.7462365

In [80]:
src_text

['<H> Henri Bergson',
 '<H> Henri Bergson <R> award_received <T> Nobel Prize in Literature',
 '<H> Henri Bergson <R> nominated_for <T> Nobel Prize in Literature',
 '<H> Henri Bergson <R> nominated_for <T> Nobel Prize in Literature',
 '<H> Henri Bergson <R> nominated_for <T> Nobel Prize in Literature',
 '<H> Henri Bergson <R> nominated_for <T> Nobel Prize in Literature',
 '<H> Henri Bergson <R> nominated_for <T> Nobel Prize in Literature',
 '<H> Henri Bergson <R> nominated_for <T> Nobel Prize in Literature',
 '<H> Henri Bergson <R> nominated_for <T> Nobel Prize in Literature',
 '<H> Henri Bergson <R> member_of <T> Lincean Academy',
 '<H> Henri Bergson <R> student_of <T> Jacques Maritain',
 '<H> Henri Bergson <R> nominated_for <T> Nobel Prize in Literature <R> award_received <T> Nobel Prize in Literature',
 '<H> Henri Bergson <R> nominated_for <T> Nobel Prize in Literature <R> award_received <T> Nobel Prize in Literature',
 '<H> Henri Bergson <R> award_received <T> Nobel Prize in Literat

In [66]:
cal_prob(output.logits, tgt_enc['input_ids'], tokenizer.pad_token_id)

[-19.08577154159546]

In [85]:
a = [3, 1, 2]
sorted(a)

[1, 2, 3]