# Use NLI model on Archival dataset

In [1]:
%load_ext autoreload
%autoreload 2

from IPython.display import display, HTML
import sys
import os
from os import path

sys.path.append("./../src")
tmp_path = path.join('.cache', '2023-04-04')
os.makedirs(tmp_path,exist_ok=True)

In [2]:
#SERVER_DIR = '/Users/dunguyen/Projects/IGRIDA/RUNS'
SERVER_DIR = '/srv/tempdd/dunguyen/RUNS'
LOG_PATH = path.join(SERVER_DIR, 'logs')
DATA_CACHE = path.join(SERVER_DIR, 'dataset')
MODEL_CACHE = path.join(SERVER_DIR, 'models')

MONGO_CACHE = path.join(DATA_CACHE, 'archival', 'mongo_db') # Generated from Mongodb
AUTOGESTION_CACHE = path.join(DATA_CACHE, 'archival', 'autogestion') # generated from autogestion repository
INFERENCE_CACHE = path.join(DATA_CACHE, 'archival', 'inference') # model generation

In [3]:
from modules.logger import init_logging
from modules.logger import log

init_logging(color=True)

# Evaluation in generated test set

In [4]:
from data_module.archival_module import ArchivalNLIDM
from model_module.lstm.archival_lstm_module import ArchivalLstmModule
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 

###############
# PREPARE DATA
###############
dm = ArchivalNLIDM(cache_path=DATA_CACHE, batch_size=16, num_workers=8)
dm.prepare_data()
dm.setup('test')

###############
# MODEL MODULE
###############

# model's parameters (could be changed in other versions)
MODEL_NAME = 'run=7_vector=0_lentropy=0.02'
m_kwargs = dict(
    n_context=1, 
    d_embedding=300
)

# leave the default ones
model = ArchivalLstmModule(cache_path=MODEL_CACHE,
    mode='dev',
    vocab=dm.vocab,
    concat_context=True,
    data='archival',
    num_class=dm.num_class, **m_kwargs)

ckpt_path = path.join(LOG_PATH, 'archival', MODEL_NAME, 'checkpoints', 'best.ckpt')
if torch.cuda.is_available():
    checkpoint = torch.load(ckpt_path)
else:
    checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
    
model.to(device)
model.load_state_dict(checkpoint['state_dict'])
model.eval()

print('Model is in cuda: ',next(model.parameters()).is_cuda)

14-04-2023 01:21:28 | [34m    INFO[0m [1m [4m archival_module.py:prepare_data:82 [0m [34mLoaded vocab at /srv/tempdd/dunguyen/RUNS/dataset/archival/vocab.pt[0m
14-04-2023 01:21:28 | [34m    INFO[0m [1m [4m archival_module.py:prepare_data:84 [0m [34mVocab size: 16792[0m
14-04-2023 01:21:28 | [34m    INFO[0m [1m [4m dataset.py:__init__:83 [0m [34mLoad dataset from /srv/tempdd/dunguyen/RUNS/dataset/archival/test.json[0m
14-04-2023 01:21:29 | [32;1m   DEBUG[0m [1m [4m dual_lstm_attention.py:__init__:33 [0m [32;1mInitialize embedding from random[0m
Model is in cuda:  True


# Evaluation in a subset

# Evaluation in pre-established links

## QA Links

In [None]:
for side, frequency in attended_word_frequency.items():
    sorted_frequency = dict(sorted(frequency.items(), key=lambda x: x[1], reverse=True))
    attended_word_frequency[side] = sorted_frequency

## Similarity Links

## TF-IDF

In [119]:
n = 20

for side, frequency in attended_word_frequency.items():
    print(str(n)+' words most attended in '+ side)
    print()
    top_n = dict(list(frequency.items())[:n])
    df = pd.DataFrame(list(frequency.items()), columns=['Word', 'Count'])
    df = df[df['Count'] > 1]
    display(df.transpose())
    print('=20')

20 words most attended in premise



Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,2684,2685,2686,2687,2688,2689,2690,2691,2692,2693
Word,le,de,",",.,être,et,un,ce,à,que,...,participant,dépense,sortie,1957,varga,épisode,phénomène,pourvoir,franc,octroyer
Count,5562,4936,2366,1300,1016,976,970,786,766,624,...,2,2,2,2,2,2,2,2,2,2


=20
20 words most attended in hypothesis



Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,4267,4268,4269,4270,4271,4272,4273,4274,4275,4276
Word,le,de,",",.,et,un,être,à,que,ce,...,archive,prôner,workers,motor,copropriétaire,1851,dépourvu,p.a.l.,s.e.c.a.m.,popularité
Count,10767,9707,6117,2902,2515,2509,2314,2087,1612,1427,...,2,2,2,2,2,2,2,2,2,2


=20


In [124]:
for side, frequency in attended_word_frequency.items():
    top_n = dict(list(frequency.items())[:n])
    df = pd.DataFrame(list(frequency.items()), columns=['Word', 'Count'])
    df = df[df['Count'] > 1]
    display(df[50:].transpose())
    print('=20')

Unnamed: 0,50,51,52,53,54,55,56,57,58,59,...,2684,2685,2686,2687,2688,2689,2690,2691,2692,2693
Word,conseil,travailleur,entreprise,entre,politique,si,société,révolution,celui,autre,...,participant,dépense,sortie,1957,varga,épisode,phénomène,pourvoir,franc,octroyer
Count,106,104,103,101,100,98,96,93,92,90,...,2,2,2,2,2,2,2,2,2,2


=20


Unnamed: 0,50,51,52,53,54,55,56,57,58,59,...,4267,4268,4269,4270,4271,4272,4273,4274,4275,4276
Word,avec,conseil,devoir,société,entreprise,travailleur,politique,dire,celui,économique,...,archive,prôner,workers,motor,copropriétaire,1851,dépourvu,p.a.l.,s.e.c.a.m.,popularité
Count,228,225,224,220,219,212,208,203,200,188,...,2,2,2,2,2,2,2,2,2,2


=20


In [None]:
for side, frequency in attended_word_frequency.items():
    top_n = dict(list(frequency.items())[:n])
    df = pd.DataFrame(list(frequency.items()), columns=['Word', 'Count'])
    df = df[df['Count'] > 1]
    display(df[50:].transpose())
    print('=20')

#### Test on Archival study subset : on work



In [22]:
## Server parameters
ON_IGRIDA = True

if ON_IGRIDA:
    SERVER_DIR = '/srv/tempdd/dunguyen/RUNS'
else:
    SERVER_DIR = '/Users/dunguyen/Projects/IGRIDA/historic/2023-02-23/archival'

LOG_PATH = path.join(SERVER_DIR, 'logs')
DATA_CACHE = path.join(SERVER_DIR, 'dataset')
#MODEL_CACHE = path.join(SERVER_DIR, 'models')

## Model parameters (the model inference has already been done in IGRIDA. In this notebook we retrieve only its inference)
MODEL_NAME = 'run=6_vector=0_lentropy=0.05'

In [25]:
import pandas as pd
import numpy as np
from tqdm.auto import tqdm

POST_PROCESS_INFERENCE_PATH = path.join(LOG_PATH, 'archival', MODEL_NAME, 'predictions','clean_inference_sentence_pairs_v2.json')
INFERENCE_PATH = path.join(LOG_PATH, 'archival', MODEL_NAME, 'predictions','batch_inference_sentence_pairs_v2.json')

# load inference file
if path.exists(POST_PROCESS_INFERENCE_PATH):
    inference_sentences = pd.read_json(POST_PROCESS_INFERENCE_PATH, encoding='utf-8')
    print(f'Load inference_sentences from {POST_PROCESS_INFERENCE_PATH}')
else:
    
    inference_sentences = pd.read_json(INFERENCE_PATH, orient='records', encoding='utf-8')

    def remove_mask(row):
        for side in ['premise', 'hypothesis']:
            padding_mask = np.array(row[f'padding_mask.{side}'])

            a_hat = np.array(row[f'a_hat.{side}'])
            a_hat = a_hat[~padding_mask]
            row[f'a_hat.{side}'] = a_hat.tolist()

            ids = np.array(row[f'{side}_ids'])
            ids = ids[~padding_mask]
            row[f'{side}_ids'] = ids.tolist()
        return row

    if 'padding_mask.premise' in inference_sentences.columns:
        tqdm.pandas(desc='remove_mask')
        inference_sentences = inference_sentences.progress_apply(remove_mask,axis=1)
        inference_sentences.drop(columns=['padding_mask.premise', 'padding_mask.hypothesis'], inplace=True)
    
    with open(POST_PROCESS_INFERENCE_PATH, 'w', encoding='utf-8') as f:
        inference_sentences.to_json(f, force_ascii=False)
        print(f'Save inference_sentences to {POST_PROCESS_INFERENCE_PATH}')
        
inference_sentences

remove_mask:   0%|          | 0/223666 [00:00<?, ?it/s]

Save inference_sentences to /srv/tempdd/dunguyen/RUNS/logs/archival/run=6_vector=0_lentropy=0.05/predictions/clean_inference_sentence_pairs_v2.json


Unnamed: 0,y_hat,y_score,a_hat.premise,a_hat.hypothesis,source.bloc.id,target.bloc.id,source.sentence.id,target.sentence.id,source.bloc.uid,target.bloc.uid,...,hypothesis_ids,source.bloc.index,target.bloc.index,source.sentence.index,target.sentence.index,source.tokens,target.tokens,source.article.title,premise.tokens,hypothesis.tokens
0,1,0.816464,"[0.007651310400000001, 0.013778822500000001, 0...","[0.0627513155, 0.0203241948, 0.0320522711, 0.0...",80.0,35.0,0.0,2.0,FMSH_PB188a_18-19_201_07,FMSH_PB188a_11-12_125_02,...,"[12624, 12632, 11727, 12637, 11508, 12630, 126...",,,,,,,,,
1,0,0.382935,"[0.0495468192, 0.0609955639, 0.0146217896, 0.0...","[0.028592417000000002, 0.022165052600000002, 0...",80.0,35.0,1.0,0.0,FMSH_PB188a_18-19_201_07,FMSH_PB188a_11-12_125_02,...,"[12638, 11707, 12446, 11577, 11482, 12638, 124...",,,,,,,,,
2,1,0.767685,"[0.0353953615, 0.034988597, 0.01225645, 0.0257...","[0.0183737185, 0.0168948267, 0.0108001148, 0.0...",80.0,35.0,1.0,1.0,FMSH_PB188a_18-19_201_07,FMSH_PB188a_11-12_125_02,...,"[12629, 12446, 12632, 12518, 11202, 12631, 879...",,,,,,,,,
3,1,0.736607,"[0.0379820652, 0.033533476300000004, 0.0103123...","[0.0449870862, 0.0260418579, 0.0365362316, 0.0...",80.0,35.0,1.0,2.0,FMSH_PB188a_18-19_201_07,FMSH_PB188a_11-12_125_02,...,"[12624, 12632, 11727, 12637, 11508, 12630, 126...",,,,,,,,,
4,0,0.009523,"[0.5871393681, 0.4128606319]","[0.0293221492, 0.0265875515, 0.0202338286, 0.0...",80.0,35.0,2.0,0.0,FMSH_PB188a_18-19_201_07,FMSH_PB188a_11-12_125_02,...,"[12638, 11707, 12446, 11577, 11482, 12638, 124...",,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
223661,1,0.990256,"[0.09885212780000001, 0.059863925000000005, 0....","[0.0161787346, 0.0162396468, 0.0161599331, 0.0...",137.0,142.0,2.0,1.0,FMSH_PB188a_20-21_064_02,FMSH_PB188a_20-21_067_02,...,"[12638, 11638, 12637, 12629, 12487, 12328, 113...",,,,,,,,,
223662,1,0.960653,"[0.0758343861, 0.0483686328, 0.1165912524, 0.4...","[0.0896834284, 0.0997016653, 0.1304979771, 0.0...",137.0,142.0,2.0,2.0,FMSH_PB188a_20-21_064_02,FMSH_PB188a_20-21_067_02,...,"[12638, 1563, 10637, 12619, 1, 12611, 12423, 1...",,,,,,,,,
223663,0,0.029427,"[0.0108872922, 0.0054137427, 0.0020751467, 0.0...","[0.0018269485, 0.0016038899000000001, 0.003085...",137.0,143.0,0.0,0.0,FMSH_PB188a_20-21_064_02,FMSH_PB188a_20-21_069_03,...,"[12625, 12613, 6217, 12592, 12606, 12459, 1262...",,,,,,,,,
223664,1,0.501854,"[0.0066407137, 0.0056945598, 0.0025590872, 0.0...","[0.3468779922, 0.17439484600000002, 0.11852170...",137.0,143.0,0.0,1.0,FMSH_PB188a_20-21_064_02,FMSH_PB188a_20-21_069_03,...,"[12626, 12621, 12370, 12623, 12635]",,,,,,,,,
