In [1]:
import glob
import time
import os
import pandas as pd
import sklearn.metrics
from sklearn.preprocessing import MinMaxScaler
import pickle
from argparse import ArgumentParser, Namespace
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from itertools import chain
from tqdm import tqdm
import copy
import shutil
import pickle

In [2]:
es_url = 'http://ckg07:9200'
es_index = 'wikidatadwd-augmented'

# Input Paths

# GDrive Path: /table-linker-dataset/2019-iswc_challenge_data/t2dv2/canonical-with-context/t2dv2-train-canonical/
train_path = "/home/sriamazingram/USC/Others/ISI/data/t2dv2/train-canonical"

# GDrive Path: /table-linker-dataset/2019-iswc_challenge_data/t2dv2/canonical-with-context/t2dv2-dev-canonical/
dev_path = "/home/sriamazingram/USC/Others/ISI/data/t2dv2/dev-canonical"

# GDrive Path: /table-linker-dataset/2019-iswc_challenge_data/t2dv2/ground_truth/Xinting_GT_csv
ground_truth_files = "/home/sriamazingram/USC/Others/ISI/data/t2dv2/GT"

# OUTPUT PATHS
output_path = "/home/sriamazingram/USC/Others/ISI/data/t2dv2"
train_output_path = f'{output_path}/train1-output'
dev_output_path = f'{output_path}/dev-output'

# increase version to create a new folder for an experiment
VERSION = "12_0"

train_candidate_path = f'{train_output_path}/{VERSION}/candidates'
train_feature_path = f'{train_output_path}/{VERSION}/features'

dev_candidate_path = f'{dev_output_path}/{VERSION}/candidates'
dev_feature_path = f'{dev_output_path}/{VERSION}/features'
dev_output_predictions = f'{dev_output_path}/{VERSION}/dev_predictions'
dev_predictions_top_k = f'{dev_output_path}/{VERSION}/dev_predictions_top_k'
dev_colorized_path = f'{dev_output_path}/{VERSION}/dev_predictions_colorized'
dev_metrics_path = f'{dev_output_path}/{VERSION}/dev_predictions_metrics'
dev_metrics_col_wise = f'{dev_output_path}/{VERSION}/dev_predictions_column_metrics'

aux_field = 'class_count,property_count,context'


train_prop_count = f'{train_output_path}/{VERSION}/train_prop_count' 
train_class_count = f'{train_output_path}/{VERSION}/train_class_count'
train_context_path = f'{train_output_path}/{VERSION}/train_context'

dev_prop_count = f'{dev_output_path}/{VERSION}/dev_prop_count'
dev_class_count = f'{dev_output_path}/{VERSION}/dev_class_count'
dev_context_path = f'{dev_output_path}/{VERSION}/dev_context'

temp_dir = f'{output_path}/temp'

pos_output = f'{temp_dir}/training_data/pos_features.pkl'
neg_output = f'{temp_dir}/training_data/neg_features.pkl'
min_max_scaler_path = f'{temp_dir}/training_data/gt_normalization_factor.pkl'

final_score_column = 'gt_score'
threshold = final_score_column+":median"

model_save_path = f'{dev_output_path}/{VERSION}/saved_models'
best_model_path = ''

In [3]:
!mkdir -p "$temp_dir"

!mkdir -p "$train_prop_count"
!mkdir -p "$dev_prop_count"
!mkdir -p "$train_class_count"
!mkdir -p "$dev_class_count"
!mkdir -p "$train_context_path"
!mkdir -p "$dev_context_path"

!mkdir -p "$train_candidate_path"
!mkdir -p "$dev_candidate_path"

!mkdir -p "$train_feature_path"
!mkdir -p "$dev_feature_path"

!mkdir -p "$temp_dir/training_data"
!mkdir -p "$dev_output_predictions"
!mkdir -p "$model_save_path"
!mkdir -p "$dev_predictions_top_k"
!mkdir -p "$dev_colorized_path"
!mkdir -p "$dev_metrics_path"
!mkdir -p "$dev_metrics_col_wise"

In [4]:
features = ['pgr_rts','monge_elkan','monge_elkan_aliases','des_cont_jaccard',
            'jaro_winkler','levenshtein','singleton', 'context_score', 'smc_class_score', 'smc_property_score']

## Candidate Generation

In [49]:
def candidate_generation(path, gt_path, output_path, class_count_path, prop_count_path, context_path):
    file_list = glob.glob(path + '/*.csv')
    for i, file in enumerate(file_list):
        st = time.time()
        filename = file.split('/')[-1]
        print(f"{filename}: {i+1} of {len(file_list)}")
        gt_file = f"{ground_truth_files}/{filename}"
        output_file = f"{output_path}/{filename}"
        
        !tl clean -c label -o label_clean "$file" / \
        --url $es_url --index $es_index \
        get-fuzzy-augmented-matches -c label_clean \
        --auxiliary-fields {aux_field} \
        --auxiliary-folder "$temp_dir" / \
        --url $es_url --index $es_index \
        get-exact-matches -c label_clean \
        --auxiliary-fields {aux_field} \
        --auxiliary-folder "$temp_dir" / \
        drop-duplicate -c kg_id --score-column retrieval_score --keep-method exact-match / \
        ground-truth-labeler --gt-file "$gt_file" > "$output_file"
        
        for field in aux_field.split(','):
            aux_list = []
            for f in glob.glob(f'{temp_dir}/*{field}.tsv'):
                aux_list.append(pd.read_csv(f, sep='\t', dtype=object))
            aux_df = pd.concat(aux_list).drop_duplicates(subset=['qnode'])
            if field == 'class_count':
                class_count_file = f"{class_count_path}/{filename.strip('.csv')}_class_count.tsv"
                aux_df.to_csv(class_count_file, sep='\t', index=False)
            elif field == 'property_count':
                prop_count_file = f"{prop_count_path}/{filename.strip('.csv')}_prop_count.tsv"
                aux_df.to_csv(prop_count_file, sep='\t', index=False)
            elif field == 'context':
                context_file = f"{context_path}/{filename.strip('.csv')}_context.tsv"
                aux_df.to_csv(context_file, sep='\t', index=False)
        print(time.time() - st)

In [50]:
candidate_generation(train_path, ground_truth_files, train_candidate_path, train_class_count, train_prop_count, train_context_path)

37856682_0_6818907050314633217.csv: 1 of 342
clean Time: 0.026865005493164062s
get-fuzzy-augmented-matches Time: 37.09188103675842s
get-exact-matches Time: 11.392009019851685s
drop-duplicate-kg_id Time: 12.520089864730835s
ground-truth-labeler Time: 1.3781394958496094s
73.70809745788574
B8QWQQAB.csv: 2 of 342
clean Time: 0.009328126907348633s
get-fuzzy-augmented-matches Time: 5.501533269882202s
get-exact-matches Time: 5.759377479553223s
drop-duplicate-kg_id Time: 1.0977163314819336s
ground-truth-labeler Time: 0.13187575340270996s
24.412264108657837
HIFQAGMX.csv: 3 of 342
clean Time: 0.003030538558959961s
get-fuzzy-augmented-matches Time: 8.626246929168701s
get-exact-matches Time: 5.7916340827941895s
drop-duplicate-kg_id Time: 1.4750936031341553s
ground-truth-labeler Time: 0.18973875045776367s
21.15478515625
1LD1MWA8.csv: 4 of 342
clean Time: 0.004044055938720703s
get-fuzzy-augmented-matches Time: 14.138941287994385s
get-exact-matches Time: 6.461132764816284s
drop-duplicate-kg_id Time: 

get-exact-matches Time: 0.5711367130279541s
drop-duplicate-kg_id Time: 0.6388230323791504s
ground-truth-labeler Time: 0.07263422012329102s
15.670592308044434
XSUGP66N.csv: 32 of 342
clean Time: 0.0030591487884521484s
get-fuzzy-augmented-matches Time: 17.06795382499695s
get-exact-matches Time: 6.510787487030029s
drop-duplicate-kg_id Time: 1.9133780002593994s
ground-truth-labeler Time: 0.14757609367370605s
31.125152826309204
6D4OURQN.csv: 33 of 342
clean Time: 0.0034699440002441406s
get-fuzzy-augmented-matches Time: 10.248651504516602s
get-exact-matches Time: 1.1332075595855713s
drop-duplicate-kg_id Time: 0.5008640289306641s
ground-truth-labeler Time: 0.05468606948852539s
16.52730965614319
J3P3ZJZ0.csv: 34 of 342
clean Time: 0.006976604461669922s
get-fuzzy-augmented-matches Time: 16.373947620391846s
get-exact-matches Time: 6.388511419296265s
drop-duplicate-kg_id Time: 1.212627649307251s
ground-truth-labeler Time: 0.1474299430847168s
29.116442918777466
BLUL2XZW.csv: 35 of 342
clean Time: 

21.595680236816406
H22OORUW.csv: 62 of 342
clean Time: 0.0025920867919921875s
get-fuzzy-augmented-matches Time: 10.281410217285156s
get-exact-matches Time: 1.461930513381958s
drop-duplicate-kg_id Time: 0.8326132297515869s
ground-truth-labeler Time: 0.09321427345275879s
17.402437448501587
DPY34RCV.csv: 63 of 342
clean Time: 0.001982450485229492s
get-fuzzy-augmented-matches Time: 5.427448272705078s
get-exact-matches Time: 1.0722920894622803s
drop-duplicate-kg_id Time: 0.388089656829834s
ground-truth-labeler Time: 0.05138444900512695s
11.616313219070435
7LUSAF2U.csv: 64 of 342
clean Time: 0.0032024383544921875s
get-fuzzy-augmented-matches Time: 4.894717216491699s
get-exact-matches Time: 6.1729419231414795s
drop-duplicate-kg_id Time: 0.3780384063720703s
ground-truth-labeler Time: 0.04501962661743164s
15.98365044593811
4UBWEGHX.csv: 65 of 342
clean Time: 0.003270864486694336s
get-fuzzy-augmented-matches Time: 16.69702672958374s
get-exact-matches Time: 6.405397415161133s
drop-duplicate-kg_id

15.76198410987854
DPUA686B.csv: 92 of 342
clean Time: 0.0029213428497314453s
get-fuzzy-augmented-matches Time: 10.888827085494995s
get-exact-matches Time: 6.565765619277954s
drop-duplicate-kg_id Time: 0.7315347194671631s
ground-truth-labeler Time: 0.07511758804321289s
23.04561185836792
4N70FY5Z.csv: 93 of 342
clean Time: 0.0024962425231933594s
get-fuzzy-augmented-matches Time: 11.316134214401245s
get-exact-matches Time: 6.156339645385742s
drop-duplicate-kg_id Time: 1.0982723236083984s
ground-truth-labeler Time: 0.10043978691101074s
23.658005952835083
58891288_0_1117541047012405958.csv: 94 of 342
clean Time: 0.009407758712768555s
get-fuzzy-augmented-matches Time: 18.06826949119568s
get-exact-matches Time: 6.47948694229126s
drop-duplicate-kg_id Time: 2.6641719341278076s
ground-truth-labeler Time: 0.19439005851745605s
33.16319417953491
84548468_0_5955155464119382182.csv: 95 of 342
clean Time: 0.00948333740234375s
get-fuzzy-augmented-matches Time: 18.916303634643555s
get-exact-matches Time

21.497297763824463
0H0U54UZ.csv: 122 of 342
clean Time: 0.0016651153564453125s
get-fuzzy-augmented-matches Time: 5.730553865432739s
get-exact-matches Time: 5.750287294387817s
drop-duplicate-kg_id Time: 0.4080324172973633s
ground-truth-labeler Time: 0.0568540096282959s
16.614811897277832
XNZE1KVH.csv: 123 of 342
clean Time: 0.004827976226806641s
get-fuzzy-augmented-matches Time: 7.502103567123413s
get-exact-matches Time: 1.542799472808838s
drop-duplicate-kg_id Time: 0.9331960678100586s
ground-truth-labeler Time: 0.09245681762695312s
14.96919059753418
1438042989043_35_20150728002309-00287-ip-10-236-191-2_875026214_2.csv: 124 of 342
clean Time: 0.0015172958374023438s
get-fuzzy-augmented-matches Time: 10.761279344558716s
get-exact-matches Time: 1.6189429759979248s
drop-duplicate-kg_id Time: 0.5102486610412598s
ground-truth-labeler Time: 0.06860089302062988s
17.596606254577637
D65TEZWN.csv: 125 of 342
clean Time: 0.00886678695678711s
get-fuzzy-augmented-matches Time: 11.482626914978027s
get

21.3410542011261
VRQZ6O0N.csv: 152 of 342
clean Time: 0.006264686584472656s
get-fuzzy-augmented-matches Time: 13.09453821182251s
get-exact-matches Time: 6.756845951080322s
drop-duplicate-kg_id Time: 2.016061544418335s
ground-truth-labeler Time: 0.19559311866760254s
27.525787830352783
617HJHR2.csv: 153 of 342
clean Time: 0.005648612976074219s
get-fuzzy-augmented-matches Time: 10.64138913154602s
get-exact-matches Time: 6.421900033950806s
drop-duplicate-kg_id Time: 1.7388763427734375s
ground-truth-labeler Time: 0.14825224876403809s
24.4491970539093
84R0FAAB.csv: 154 of 342
clean Time: 0.0026366710662841797s
get-fuzzy-augmented-matches Time: 10.147409677505493s
get-exact-matches Time: 6.427304267883301s
drop-duplicate-kg_id Time: 1.843979835510254s
ground-truth-labeler Time: 0.1348402500152588s
24.319247722625732
J8SUPBQ6.csv: 155 of 342
clean Time: 0.003053426742553711s
get-fuzzy-augmented-matches Time: 10.471451997756958s
get-exact-matches Time: 6.500074148178101s
drop-duplicate-kg_id Ti

clean Time: 0.0030138492584228516s
get-fuzzy-augmented-matches Time: 6.693006753921509s
get-exact-matches Time: 1.6539828777313232s
drop-duplicate-kg_id Time: 0.7384312152862549s
ground-truth-labeler Time: 0.14991497993469238s
14.309862613677979
IVJ00F9K.csv: 183 of 342
clean Time: 0.004628181457519531s
get-fuzzy-augmented-matches Time: 14.940380573272705s
get-exact-matches Time: 6.390960454940796s
drop-duplicate-kg_id Time: 2.3230090141296387s
ground-truth-labeler Time: 0.17940139770507812s
29.65944790840149
QCMUCPC1.csv: 184 of 342
clean Time: 0.011358976364135742s
get-fuzzy-augmented-matches Time: 10.508042573928833s
get-exact-matches Time: 2.0125572681427s
drop-duplicate-kg_id Time: 1.3912079334259033s
ground-truth-labeler Time: 0.1319410800933838s
19.687819719314575
UJNATN8A.csv: 185 of 342
clean Time: 0.002320528030395508s
get-fuzzy-augmented-matches Time: 11.016988277435303s
get-exact-matches Time: 6.636722564697266s
drop-duplicate-kg_id Time: 1.4178276062011719s
ground-truth-la

get-exact-matches Time: 6.606149911880493s
drop-duplicate-kg_id Time: 0.7858362197875977s
ground-truth-labeler Time: 0.08448243141174316s
22.621639728546143
5IXA0RAI.csv: 213 of 342
clean Time: 0.0017099380493164062s
get-fuzzy-augmented-matches Time: 3.3788628578186035s
get-exact-matches Time: 0.9917254447937012s
drop-duplicate-kg_id Time: 0.5080852508544922s
ground-truth-labeler Time: 0.06803750991821289s
9.572571039199829
79EIBGWR.csv: 214 of 342
clean Time: 0.0017402172088623047s
get-fuzzy-augmented-matches Time: 9.235327959060669s
get-exact-matches Time: 0.6492302417755127s
drop-duplicate-kg_id Time: 0.3126852512359619s
ground-truth-labeler Time: 0.05911087989807129s
14.873131275177002
0MZX65PH.csv: 215 of 342
clean Time: 0.002799510955810547s
get-fuzzy-augmented-matches Time: 8.964818954467773s
get-exact-matches Time: 5.816089868545532s
drop-duplicate-kg_id Time: 0.3683278560638428s
ground-truth-labeler Time: 0.05172538757324219s
19.71360206604004
SDQACBPT.csv: 216 of 342
clean Ti

ground-truth-labeler Time: 0.06019306182861328s
11.49648928642273
UEEUASVQ.csv: 243 of 342
clean Time: 0.002842426300048828s
get-fuzzy-augmented-matches Time: 4.579188585281372s
get-exact-matches Time: 6.546862363815308s
drop-duplicate-kg_id Time: 0.4089784622192383s
ground-truth-labeler Time: 0.04707694053649902s
16.216078281402588
VG9104AO.csv: 244 of 342
clean Time: 0.01519465446472168s
get-fuzzy-augmented-matches Time: 10.097188949584961s
get-exact-matches Time: 6.019238710403442s
drop-duplicate-kg_id Time: 1.1858174800872803s
ground-truth-labeler Time: 0.17478680610656738s
22.309381246566772
IUBTQXYO.csv: 245 of 342
clean Time: 0.0030400753021240234s
get-fuzzy-augmented-matches Time: 4.975883960723877s
get-exact-matches Time: 5.635691404342651s
drop-duplicate-kg_id Time: 0.4230632781982422s
ground-truth-labeler Time: 0.05765128135681152s
15.653536796569824
35188621_0_6058553107571275232.csv: 246 of 342
clean Time: 0.022904157638549805s
get-fuzzy-augmented-matches Time: 20.52530384

clean Time: 0.006613016128540039s
get-fuzzy-augmented-matches Time: 12.149035453796387s
get-exact-matches Time: 6.340533971786499s
drop-duplicate-kg_id Time: 2.027259111404419s
ground-truth-labeler Time: 0.16651058197021484s
26.110668659210205
39173938_0_7916056990138658530.csv: 274 of 342
clean Time: 0.004941463470458984s
get-fuzzy-augmented-matches Time: 15.942415475845337s
get-exact-matches Time: 2.725801944732666s
drop-duplicate-kg_id Time: 2.338519334793091s
ground-truth-labeler Time: 0.1811234951019287s
27.056317567825317
10579449_0_1681126353774891032.csv: 275 of 342
clean Time: 0.0030274391174316406s
get-fuzzy-augmented-matches Time: 5.398084878921509s
get-exact-matches Time: 1.089456558227539s
drop-duplicate-kg_id Time: 0.3917117118835449s
ground-truth-labeler Time: 0.12161898612976074s
11.681758165359497
S6RYCPCW.csv: 276 of 342
clean Time: 0.0026292800903320312s
get-fuzzy-augmented-matches Time: 12.11391282081604s
get-exact-matches Time: 6.5922088623046875s
drop-duplicate-kg

clean Time: 0.04909873008728027s
get-fuzzy-augmented-matches Time: 37.781731367111206s
get-exact-matches Time: 11.937984466552734s
drop-duplicate-kg_id Time: 6.671890735626221s
ground-truth-labeler Time: 0.4311509132385254s
64.76940488815308
3J5ABW9E.csv: 304 of 342
clean Time: 0.006597995758056641s
get-fuzzy-augmented-matches Time: 10.157222270965576s
get-exact-matches Time: 6.524601459503174s
drop-duplicate-kg_id Time: 1.4689271450042725s
ground-truth-labeler Time: 0.14702224731445312s
23.386901140213013
96AUP4BE.csv: 305 of 342
clean Time: 0.003099679946899414s
get-fuzzy-augmented-matches Time: 11.960345983505249s
get-exact-matches Time: 6.321634531021118s
drop-duplicate-kg_id Time: 1.5138771533966064s
ground-truth-labeler Time: 0.14323139190673828s
25.13296866416931
RQN3WIO2.csv: 306 of 342
clean Time: 0.001957416534423828s
get-fuzzy-augmented-matches Time: 5.434004545211792s
get-exact-matches Time: 5.810628175735474s
drop-duplicate-kg_id Time: 0.3874521255493164s
ground-truth-labe

get-exact-matches Time: 5.856691360473633s
drop-duplicate-kg_id Time: 0.7781052589416504s
ground-truth-labeler Time: 0.10341405868530273s
22.368530750274658
J4XOF8WJ.csv: 334 of 342
clean Time: 0.0028328895568847656s
get-fuzzy-augmented-matches Time: 6.441299915313721s
get-exact-matches Time: 6.21259331703186s
drop-duplicate-kg_id Time: 0.5021300315856934s
ground-truth-labeler Time: 0.06668496131896973s
17.932737112045288
P11KZF71.csv: 335 of 342
clean Time: 0.003498554229736328s
get-fuzzy-augmented-matches Time: 11.879830837249756s
get-exact-matches Time: 5.724027872085571s
drop-duplicate-kg_id Time: 1.1113173961639404s
ground-truth-labeler Time: 0.10646438598632812s
23.54521870613098
KUN2Y3DX.csv: 336 of 342
clean Time: 0.0038199424743652344s
get-fuzzy-augmented-matches Time: 15.452715158462524s
get-exact-matches Time: 5.8128416538238525s
drop-duplicate-kg_id Time: 1.1979336738586426s
ground-truth-labeler Time: 0.14229559898376465s
27.646260738372803
24036779_0_5608105867560183058.cs

In [51]:
candidate_generation(dev_path, ground_truth_files, dev_candidate_path, dev_class_count, dev_prop_count, dev_context_path)

BOXTVP7V.csv: 1 of 58
clean Time: 0.0022814273834228516s
get-fuzzy-augmented-matches Time: 11.034621477127075s
get-exact-matches Time: 1.5870893001556396s
drop-duplicate-kg_id Time: 1.2067675590515137s
ground-truth-labeler Time: 0.10948967933654785s
19.037237882614136
E5SHJSQZ.csv: 2 of 58
clean Time: 0.0031685829162597656s
get-fuzzy-augmented-matches Time: 8.51126217842102s
get-exact-matches Time: 1.0991947650909424s
drop-duplicate-kg_id Time: 0.538841724395752s
ground-truth-labeler Time: 0.08556389808654785s
14.8853440284729
DBH21J5D.csv: 3 of 58
clean Time: 0.0029687881469726562s
get-fuzzy-augmented-matches Time: 3.725006103515625s
get-exact-matches Time: 5.710453748703003s
drop-duplicate-kg_id Time: 0.4380614757537842s
ground-truth-labeler Time: 0.06373262405395508s
14.415447235107422
84575189_0_6365692015941409487.csv: 4 of 58
clean Time: 0.010727167129516602s
get-fuzzy-augmented-matches Time: 17.339505910873413s
get-exact-matches Time: 6.510434865951538s
drop-duplicate-kg_id Time

get-exact-matches Time: 6.505092620849609s
drop-duplicate-kg_id Time: 0.4042322635650635s
ground-truth-labeler Time: 0.1363234519958496s
20.792482614517212
NE9XVY42.csv: 32 of 58
clean Time: 0.005224466323852539s
get-fuzzy-augmented-matches Time: 10.718997716903687s
get-exact-matches Time: 1.1681997776031494s
drop-duplicate-kg_id Time: 0.870417594909668s
ground-truth-labeler Time: 0.1050724983215332s
17.574124574661255
VB0WL533.csv: 33 of 58
clean Time: 0.0023097991943359375s
get-fuzzy-augmented-matches Time: 10.821867227554321s
get-exact-matches Time: 1.6170899868011475s
drop-duplicate-kg_id Time: 1.2157838344573975s
ground-truth-labeler Time: 0.10077929496765137s
18.736942052841187
XXYFPD8I.csv: 34 of 58
clean Time: 0.002892017364501953s
get-fuzzy-augmented-matches Time: 9.85008692741394s
get-exact-matches Time: 5.734575510025024s
drop-duplicate-kg_id Time: 0.4232504367828369s
ground-truth-labeler Time: 0.060744524002075195s
20.66550922393799
8N4ZTXDV.csv: 35 of 58
clean Time: 0.0016

## Feature Generation

In [80]:
def feature_generation(candidate_dir, class_count_dir, property_count_dir, context_path, output_path):
    file_list = glob.glob(candidate_dir + '/*.csv')
    for i, file in enumerate(file_list):
        filename = file.split('/')[-1]
        print(f"{filename}: {i+1} of {len(file_list)}")
        class_count_file = f"{class_count_dir}/{filename.strip('.csv')}_class_count.tsv"
        property_count_file = f"{property_count_dir}/{filename.strip('.csv')}_prop_count.tsv"
        context_file = f"{context_path}/{filename.strip('.csv')}_context.tsv"
        output_file = f"{output_path}/{filename}"
        !time tl string-similarity "$file" -i --method symmetric_monge_elkan:tokenizer=word -o monge_elkan \
            / string-similarity -i --method symmetric_monge_elkan:tokenizer=word -c label_clean kg_aliases -o monge_elkan_aliases \
            / string-similarity -i --method jaro_winkler -o jaro_winkler \
            / string-similarity -i --method levenshtein -o levenshtein \
            / string-similarity -i --method jaccard:tokenizer=word -c kg_descriptions context -o des_cont_jaccard \
            / create-singleton-feature -o singleton \
            / context-match \
            --context-file "$context_file" \
            -o context_score \
            / pgt-semantic-tf-idf \
            -o smc_class_score \
            --pagerank-column pagerank \
            --retrieval-score-column retrieval_score \
            --feature-file "$class_count_file" \
            --feature-name class_count \
            / pgt-semantic-tf-idf \
            -o smc_property_score \
            --pagerank-column pagerank \
            --retrieval-score-column retrieval_score \
            --feature-file "$property_count_file" \
            --feature-name property_count \
            > $output_file

In [81]:
feature_generation(train_candidate_path, train_class_count, train_prop_count, train_context_path, train_feature_path)

37856682_0_6818907050314633217.csv: 1 of 342
B8QWQQAB.csv: 2 of 342
HIFQAGMX.csv: 3 of 342
1LD1MWA8.csv: 4 of 342
1XNHBBRZ.csv: 5 of 342
38428277_0_1311643810102462607.csv: 6 of 342
1KJ39NFE.csv: 7 of 342
HFNU4Y9W.csv: 8 of 342
MZNZLWYW.csv: 9 of 342
K2V1VODK.csv: 10 of 342
E5DKRW4W.csv: 11 of 342
YCXXPVD2.csv: 12 of 342
ZX8GERJC.csv: 13 of 342
6NO3AH02.csv: 14 of 342
FVFG3EJ2.csv: 15 of 342
26XDNAJB.csv: 16 of 342
NPGBDBM4.csv: 17 of 342
ERPSWFMM.csv: 18 of 342
ZDAZ5PQ5.csv: 19 of 342
2LM6W2JV.csv: 20 of 342
0KL64BZL.csv: 21 of 342
2JN1R1VW.csv: 22 of 342
X0TEEJCK.csv: 23 of 342
UL2BYXAR.csv: 24 of 342
QTYEU8F5.csv: 25 of 342
9XF3SP0B.csv: 26 of 342
EL9S7KDR.csv: 27 of 342
HTUXRVUC.csv: 28 of 342
29414811_12_251152470253168163.csv: 29 of 342
F487BS0V.csv: 30 of 342
8DOZTMTY.csv: 31 of 342
XSUGP66N.csv: 32 of 342
6D4OURQN.csv: 33 of 342
J3P3ZJZ0.csv: 34 of 342
BLUL2XZW.csv: 35 of 342
64ZFZ4K2.csv: 36 of 342
9834884_0_3871985887467090123.csv: 37 of 342
NUTCUXCN.csv: 38 of 342
YXYVNO79.c

QH8JJV75.csv: 295 of 342
77694908_0_6083291340991074532.csv: 296 of 342
I3MK4TC6.csv: 297 of 342
G0QTILKH.csv: 298 of 342
9PYE5TKS.csv: 299 of 342
JZ9RW99R.csv: 300 of 342
UPNVIIDW.csv: 301 of 342
AJI584YU.csv: 302 of 342
50245608_0_871275842592178099.csv: 303 of 342
3J5ABW9E.csv: 304 of 342
96AUP4BE.csv: 305 of 342
RQN3WIO2.csv: 306 of 342
I7JP3F91.csv: 307 of 342
LFET4LDT.csv: 308 of 342
RA12VETD.csv: 309 of 342
CUC5JLK3.csv: 310 of 342
CKRLO13X.csv: 311 of 342
N6S2LRDK.csv: 312 of 342
Y3XLQHGC.csv: 313 of 342
UU8Q91MG.csv: 314 of 342
EHOBJNIC.csv: 315 of 342
PQN3CY7B.csv: 316 of 342
40534006_0_4617468856744635526.csv: 317 of 342
22864497_0_8632623712684511496.csv: 318 of 342
NYXNRZCM.csv: 319 of 342
JUFYSXYP.csv: 320 of 342
ZT4Q61TK.csv: 321 of 342
OJXA73X7.csv: 322 of 342
OMJX8TT6.csv: 323 of 342
T8SL8HGK.csv: 324 of 342
XVMPJ993.csv: 325 of 342
W0ZNF869.csv: 326 of 342
MBM31U4C.csv: 327 of 342
VE3T1LHT.csv: 328 of 342
X7U2V5YE.csv: 329 of 342
DSK6KM67.csv: 330 of 342
EPFTNCPD.csv:


real	0m9.303s
user	0m9.684s
sys	0m0.697s
NUTCUXCN.csv: 38 of 342
pgt-semantic-tf-idf-class_count Time: 0.619443416595459s
pgt-semantic-tf-idf-property_count Time: 1.1702473163604736s

real	0m3.953s
user	0m4.982s
sys	0m0.630s
YXYVNO79.csv: 39 of 342
pgt-semantic-tf-idf-class_count Time: 0.3605515956878662s
pgt-semantic-tf-idf-property_count Time: 0.6901242733001709s

real	0m3.438s
user	0m4.549s
sys	0m0.531s
O668CSQ3.csv: 40 of 342
pgt-semantic-tf-idf-class_count Time: 0.38735103607177734s
pgt-semantic-tf-idf-property_count Time: 0.758237361907959s

real	0m4.002s
user	0m5.729s
sys	0m0.633s
E0LR4TZL.csv: 41 of 342
pgt-semantic-tf-idf-class_count Time: 0.6746814250946045s
pgt-semantic-tf-idf-property_count Time: 1.2197165489196777s

real	0m3.984s
user	0m5.126s
sys	0m0.549s
R4K6322V.csv: 42 of 342
pgt-semantic-tf-idf-class_count Time: 0.18490862846374512s
pgt-semantic-tf-idf-property_count Time: 0.4878237247467041s

real	0m3.240s
user	0m4.408s
sys	0m0.618s
L2WQ1RA3.csv: 43 of 342
pgt-seman

pgt-semantic-tf-idf-property_count Time: 0.727837085723877s

real	0m3.442s
user	0m4.508s
sys	0m0.549s
1T91CHXV.csv: 82 of 342
pgt-semantic-tf-idf-class_count Time: 0.6964280605316162s
pgt-semantic-tf-idf-property_count Time: 0.9476034641265869s

real	0m4.004s
user	0m5.351s
sys	0m0.635s
EXZ2ZWNZ.csv: 83 of 342
pgt-semantic-tf-idf-class_count Time: 1.4202773571014404s
pgt-semantic-tf-idf-property_count Time: 2.436164617538452s

real	0m5.821s
user	0m6.657s
sys	0m0.609s
OAWHF5BM.csv: 84 of 342
pgt-semantic-tf-idf-class_count Time: 0.24118638038635254s
pgt-semantic-tf-idf-property_count Time: 0.500537633895874s

real	0m5.273s
user	0m4.935s
sys	0m0.549s
20135078_0_7570343137119682530.csv: 85 of 342
pgt-semantic-tf-idf-class_count Time: 1.5437250137329102s
pgt-semantic-tf-idf-property_count Time: 2.670609474182129s

real	0m5.757s
user	0m6.911s
sys	0m0.633s
39107734_2_2329160387535788734.csv: 86 of 342
pgt-semantic-tf-idf-class_count Time: 0.44888830184936523s
pgt-semantic-tf-idf-property_coun

pgt-semantic-tf-idf-property_count Time: 0.5576448440551758s

real	0m3.899s
user	0m5.457s
sys	0m0.632s
D65TEZWN.csv: 125 of 342
pgt-semantic-tf-idf-class_count Time: 0.668710470199585s
pgt-semantic-tf-idf-property_count Time: 1.228830337524414s

real	0m4.086s
user	0m5.212s
sys	0m0.586s
9SERGNIZ.csv: 126 of 342
pgt-semantic-tf-idf-class_count Time: 0.7587018013000488s
pgt-semantic-tf-idf-property_count Time: 1.3082220554351807s

real	0m4.224s
user	0m5.296s
sys	0m0.617s
ICA6GXTG.csv: 127 of 342
pgt-semantic-tf-idf-class_count Time: 0.36213088035583496s
pgt-semantic-tf-idf-property_count Time: 0.6955149173736572s

real	0m3.470s
user	0m4.639s
sys	0m0.546s
4NXQPKFO.csv: 128 of 342
pgt-semantic-tf-idf-class_count Time: 0.6018812656402588s
pgt-semantic-tf-idf-property_count Time: 1.1857807636260986s

real	0m4.059s
user	0m5.160s
sys	0m0.598s
WNKF57RH.csv: 129 of 342
pgt-semantic-tf-idf-class_count Time: 0.13661432266235352s
pgt-semantic-tf-idf-property_count Time: 0.35295724868774414s

real	0m


real	0m3.233s
user	0m4.442s
sys	0m0.484s
QFDEZDAG.csv: 168 of 342
pgt-semantic-tf-idf-class_count Time: 0.3006305694580078s
pgt-semantic-tf-idf-property_count Time: 0.5890653133392334s

real	0m3.313s
user	0m4.593s
sys	0m0.520s
G2K4GSYB.csv: 169 of 342
pgt-semantic-tf-idf-class_count Time: 0.4976511001586914s
pgt-semantic-tf-idf-property_count Time: 1.0335614681243896s

real	0m3.819s
user	0m4.978s
sys	0m0.510s
6XCOGRWM.csv: 170 of 342
pgt-semantic-tf-idf-class_count Time: 0.6175937652587891s
pgt-semantic-tf-idf-property_count Time: 1.129859209060669s

real	0m4.009s
user	0m5.144s
sys	0m0.585s
8V9GOLD1.csv: 171 of 342
pgt-semantic-tf-idf-class_count Time: 0.7654869556427002s
pgt-semantic-tf-idf-property_count Time: 1.28281569480896s

real	0m4.033s
user	0m5.008s
sys	0m0.644s
43237185_1_3636357855502246981.csv: 172 of 342
pgt-semantic-tf-idf-class_count Time: 0.3441948890686035s
pgt-semantic-tf-idf-property_count Time: 0.6364734172821045s

real	0m3.302s
user	0m4.502s
sys	0m0.529s
52UWNGIM.

pgt-semantic-tf-idf-property_count Time: 0.6457023620605469s

real	0m3.441s
user	0m4.530s
sys	0m0.648s
HMULE47M.csv: 212 of 342
pgt-semantic-tf-idf-class_count Time: 0.879798412322998s
pgt-semantic-tf-idf-property_count Time: 1.3383195400238037s

real	0m4.141s
user	0m5.220s
sys	0m0.621s
5IXA0RAI.csv: 213 of 342
pgt-semantic-tf-idf-class_count Time: 0.13654255867004395s
pgt-semantic-tf-idf-property_count Time: 0.37849926948547363s

real	0m3.217s
user	0m4.518s
sys	0m0.665s
79EIBGWR.csv: 214 of 342
pgt-semantic-tf-idf-class_count Time: 0.23220396041870117s
pgt-semantic-tf-idf-property_count Time: 0.4887962341308594s

real	0m3.172s
user	0m4.321s
sys	0m0.613s
0MZX65PH.csv: 215 of 342
pgt-semantic-tf-idf-class_count Time: 0.20994257926940918s
pgt-semantic-tf-idf-property_count Time: 0.46290063858032227s

real	0m3.095s
user	0m4.325s
sys	0m0.502s
SDQACBPT.csv: 216 of 342
pgt-semantic-tf-idf-class_count Time: 0.534358024597168s
pgt-semantic-tf-idf-property_count Time: 0.942436695098877s

real	0

pgt-semantic-tf-idf-class_count Time: 0.8489255905151367s
pgt-semantic-tf-idf-property_count Time: 1.6484010219573975s

real	0m4.655s
user	0m5.693s
sys	0m0.647s
7T0JJJVE.csv: 256 of 342
pgt-semantic-tf-idf-class_count Time: 0.6400225162506104s
pgt-semantic-tf-idf-property_count Time: 1.0947329998016357s

real	0m4.257s
user	0m5.748s
sys	0m0.628s
6VLKFW8J.csv: 257 of 342
pgt-semantic-tf-idf-class_count Time: 0.24343442916870117s
pgt-semantic-tf-idf-property_count Time: 0.4637312889099121s

real	0m3.209s
user	0m4.241s
sys	0m0.544s
0AQOU1Z2.csv: 258 of 342
pgt-semantic-tf-idf-class_count Time: 0.21094059944152832s
pgt-semantic-tf-idf-property_count Time: 0.45990896224975586s

real	0m3.103s
user	0m4.326s
sys	0m0.512s
8EVYXGLE.csv: 259 of 342
pgt-semantic-tf-idf-class_count Time: 0.9376451969146729s
pgt-semantic-tf-idf-property_count Time: 1.6788995265960693s

real	0m4.745s
user	0m6.049s
sys	0m0.702s
BCNHZUB2.csv: 260 of 342
pgt-semantic-tf-idf-class_count Time: 0.7096526622772217s
pgt-seman

pgt-semantic-tf-idf-property_count Time: 0.6668548583984375s

real	0m3.383s
user	0m4.497s
sys	0m0.574s
9PYE5TKS.csv: 299 of 342
pgt-semantic-tf-idf-class_count Time: 0.5004491806030273s
pgt-semantic-tf-idf-property_count Time: 0.9352800846099854s

real	0m3.865s
user	0m4.996s
sys	0m0.571s
JZ9RW99R.csv: 300 of 342
pgt-semantic-tf-idf-class_count Time: 0.7179999351501465s
pgt-semantic-tf-idf-property_count Time: 1.403041124343872s

real	0m4.303s
user	0m5.445s
sys	0m0.661s
UPNVIIDW.csv: 301 of 342
pgt-semantic-tf-idf-class_count Time: 0.7748560905456543s
pgt-semantic-tf-idf-property_count Time: 1.3720498085021973s

real	0m4.076s
user	0m5.114s
sys	0m0.582s
AJI584YU.csv: 302 of 342
pgt-semantic-tf-idf-class_count Time: 0.46900081634521484s
pgt-semantic-tf-idf-property_count Time: 0.9038770198822021s

real	0m3.687s
user	0m4.766s
sys	0m0.625s
50245608_0_871275842592178099.csv: 303 of 342
pgt-semantic-tf-idf-class_count Time: 2.208080768585205s
pgt-semantic-tf-idf-property_count Time: 4.1710903

pgt-semantic-tf-idf-class_count Time: 0.2685372829437256s
pgt-semantic-tf-idf-property_count Time: 0.5350325107574463s

real	0m3.284s
user	0m4.519s
sys	0m0.583s


In [82]:
feature_generation(dev_candidate_path, dev_class_count, dev_prop_count, dev_context_path, dev_feature_path)

BOXTVP7V.csv: 1 of 58
E5SHJSQZ.csv: 2 of 58
DBH21J5D.csv: 3 of 58
84575189_0_6365692015941409487.csv: 4 of 58
3OX1PGQD.csv: 5 of 58
4HYT5D2J.csv: 6 of 58
28086084_0_3127660530989916727.csv: 7 of 58
U4430LA9.csv: 8 of 58
50270082_0_444360818941411589.csv: 9 of 58
PT0GTLGV.csv: 10 of 58
FV3PPNAQ.csv: 11 of 58
JYC6D9MU.csv: 12 of 58
VADKVBSJ.csv: 13 of 58
FU7P6GOF.csv: 14 of 58
6T4QNE30.csv: 15 of 58
TLAL3B63.csv: 16 of 58
VNSUNG1M.csv: 17 of 58
FDOC6GMJ.csv: 18 of 58
29414811_2_4773219892816395776.csv: 19 of 58
CR0Q0GDE.csv: 20 of 58
RCL5LZUM.csv: 21 of 58
9V2P69CI.csv: 22 of 58
RF6RSJ5W.csv: 23 of 58
SYRX0I75.csv: 24 of 58
54SEC9F3.csv: 25 of 58
MBCHQ4TC.csv: 26 of 58
RWEJTWBK.csv: 27 of 58
34LOX8E9.csv: 28 of 58
OYFD9B7F.csv: 29 of 58
ZR25NVUN.csv: 30 of 58
JTWZYYBU.csv: 31 of 58
NE9XVY42.csv: 32 of 58
VB0WL533.csv: 33 of 58
XXYFPD8I.csv: 34 of 58
8N4ZTXDV.csv: 35 of 58
YV0V8O3A.csv: 36 of 58
39759273_0_1427898308030295194.csv: 37 of 58
OEMDOUBY.csv: 38 of 58
14380604_4_332923570574676

39759273_0_1427898308030295194.csv: 37 of 58
pgt-semantic-tf-idf-class_count Time: 1.4313597679138184s
pgt-semantic-tf-idf-property_count Time: 2.672819137573242s

real	0m5.604s
user	0m6.792s
sys	0m0.647s
OEMDOUBY.csv: 38 of 58
pgt-semantic-tf-idf-class_count Time: 0.35991334915161133s
pgt-semantic-tf-idf-property_count Time: 0.709599494934082s

real	0m3.499s
user	0m4.662s
sys	0m0.599s
14380604_4_3329235705746762392.csv: 39 of 58
pgt-semantic-tf-idf-class_count Time: 0.23534512519836426s
pgt-semantic-tf-idf-property_count Time: 0.500577449798584s

real	0m3.186s
user	0m4.411s
sys	0m0.452s
VEKB4XZC.csv: 40 of 58
pgt-semantic-tf-idf-class_count Time: 1.1003592014312744s
pgt-semantic-tf-idf-property_count Time: 1.9371049404144287s

real	0m4.893s
user	0m6.058s
sys	0m0.627s
1UEUW7EP.csv: 41 of 58
pgt-semantic-tf-idf-class_count Time: 0.4443821907043457s
pgt-semantic-tf-idf-property_count Time: 0.9211266040802002s

real	0m3.863s
user	0m5.199s
sys	0m0.610s
14067031_0_559833072073397908.csv: 42

## Generate training data

In [83]:
def merge_files(args):
    datapath = args.train_path
    eval_file_names = []
    for (dirpath, dirnames, filenames) in os.walk(datapath):
        for fn in filenames:
            if "csv" not in fn:
                continue
            abs_fn = f"{dirpath}/{fn}"
            assert os.path.isfile(abs_fn)
            if os.path.getsize(abs_fn) == 0:
                continue
            eval_file_names.append(abs_fn)
    df_list = []
    for fn in eval_file_names:
        fid = fn.split('/')[-1].split('.csv')[0]
        df = pd.read_csv(fn)
        df['table_id'] = fid
        df_list.append(df)
    return pd.concat(df_list) 

def compute_normalization_factor(args, all_data):
    min_max_scaler_path = args.min_max_scaler_path
    all_data_features = all_data[features]
    scaler = MinMaxScaler()
    scaler.fit(all_data_features)
    pickle.dump(scaler, open(min_max_scaler_path, 'wb'))
    return scaler

def generate_train_data(args):
    scaler_path = args.min_max_scaler_path
    scaler = pickle.load(open(scaler_path, 'rb'))
    final_list = []
    sfeatures = copy.deepcopy(features) + ['evaluation_label']
    print(sfeatures)
    normalize_features = features
    evaluation_label = ['evaluation_label']
    positive_features_final = []
    negative_features_final = []
    for i,file in enumerate(glob.glob(args.train_path + '/*.csv')):
        file_name = file.split('/')[-1]
        print(file_name)
        if os.path.getsize(file) == 0:
                continue
        d_sample = pd.read_csv(file)
        d_sample = d_sample[(d_sample["smc_class_score"]>0)].reset_index(drop=True)
        grouped_obj = d_sample.groupby(['column', 'row'])
        for cell in grouped_obj:
            cell[1][normalize_features] = scaler.transform(cell[1][normalize_features])
            pos_features = []
            neg_features = []
            a = cell[1][cell[1]['evaluation_label'] == 1]
            if a.empty:
                continue
            pos_rows = cell[1][cell[1]['evaluation_label'] == 1][features].to_numpy()
            for i in range(len(pos_rows)):
                pos_features.append(pos_rows[i])
            neg_rows = cell[1][cell[1]['evaluation_label'] == -1][features].to_numpy()
            for i in range(len(neg_rows)):
                neg_features.append(neg_rows[i])
            random.shuffle(pos_features)
            random.shuffle(pos_features)
            positive_features_final.append(pos_features)
            negative_features_final.append(neg_features)
            
    print(len(positive_features_final), len(positive_features_final[0]))
    print(len(negative_features_final), len(negative_features_final[0]))
    pickle.dump(positive_features_final,open(args.pos_output,'wb'))
    pickle.dump(negative_features_final,open(args.neg_output,'wb'))


In [84]:
gen_training_data_args = Namespace(train_path=train_feature_path, pos_output=pos_output, neg_output=neg_output, 
                 min_max_scaler_path=min_max_scaler_path)
all_data = merge_files(gen_training_data_args)
scaler = compute_normalization_factor(gen_training_data_args, all_data)
generate_train_data(gen_training_data_args)

['pgr_rts', 'monge_elkan', 'monge_elkan_aliases', 'des_cont_jaccard', 'jaro_winkler', 'levenshtein', 'singleton', 'context_score', 'smc_class_score', 'smc_property_score', 'evaluation_label']
37856682_0_6818907050314633217.csv
B8QWQQAB.csv
HIFQAGMX.csv
1LD1MWA8.csv
1XNHBBRZ.csv
38428277_0_1311643810102462607.csv
1KJ39NFE.csv
HFNU4Y9W.csv
MZNZLWYW.csv
K2V1VODK.csv
E5DKRW4W.csv
YCXXPVD2.csv
ZX8GERJC.csv
6NO3AH02.csv
FVFG3EJ2.csv
26XDNAJB.csv
NPGBDBM4.csv
ERPSWFMM.csv
ZDAZ5PQ5.csv
2LM6W2JV.csv
0KL64BZL.csv
2JN1R1VW.csv
X0TEEJCK.csv
UL2BYXAR.csv
QTYEU8F5.csv
9XF3SP0B.csv
EL9S7KDR.csv
HTUXRVUC.csv
29414811_12_251152470253168163.csv
F487BS0V.csv
8DOZTMTY.csv
XSUGP66N.csv
6D4OURQN.csv
J3P3ZJZ0.csv
BLUL2XZW.csv
64ZFZ4K2.csv
9834884_0_3871985887467090123.csv
NUTCUXCN.csv
YXYVNO79.csv
O668CSQ3.csv
E0LR4TZL.csv
R4K6322V.csv
L2WQ1RA3.csv
BBN3425A.csv
FCG0YNIZ.csv
2IEVUWPV.csv
6FGUGZF9.csv
57681CMM.csv
CLW28GXT.csv
0ZH7HCT0.csv
5INQ2HVE.csv
L94PAXBK.csv
21245481_0_8730460088443117515.csv
3QIWU8Z7.c

## Model definition

In [85]:
# Dataset
class T2DV2Dataset(Dataset):
    def __init__(self, pos_features, neg_features):
        self.pos_features = pos_features
        self.neg_features = neg_features
    
    def __len__(self):
        return len(self.pos_features)
    
    def __getitem__(self, idx):
        return self.pos_features[idx], self.neg_features[idx]

# Model
class PairwiseNetwork(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        #original 10x20, 20x10, 10x10, 10x1
        self.fc1 = nn.Linear(hidden_size, 2*hidden_size)
        self.fc2 = nn.Linear(2*hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, 1)
    
    def forward(self, pos_features, neg_features):
        # Positive pass
        x = F.relu(self.fc1(pos_features))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        pos_out = torch.sigmoid(self.fc4(x))
        
        # Negative Pass
        x = F.relu(self.fc1(neg_features))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        neg_out = torch.sigmoid(self.fc4(x))
        
        return pos_out, neg_out
    
    def predict(self, test_feat):
        x = F.relu(self.fc1(test_feat))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        test_out = torch.sigmoid(self.fc4(x))
        return test_out

# Pairwise Loss
class PairwiseLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.m = 0
    
    def forward(self, pos_out, neg_out):
        distance = (1 - pos_out) + neg_out
        loss = torch.mean(torch.max(torch.tensor(0), distance))
        return loss

## Training

In [86]:
def generate_dataloader(positive_feat_path, negative_feat_path):
    pos_features = pickle.load(open(positive_feat_path, 'rb'))
    neg_features = pickle.load(open(negative_feat_path, 'rb'))

    pos_features_flatten = list(chain.from_iterable(pos_features))
    neg_features_flatten = list(chain.from_iterable(neg_features))

    train_dataset = T2DV2Dataset(pos_features_flatten, neg_features_flatten)
    train_dataloader = DataLoader(train_dataset, batch_size=64)
    return train_dataloader

def infer_scores(min_max_scaler_path, input_table_path, output_table_path, model):
    scaler = pickle.load(open(min_max_scaler_path, 'rb'))
    normalize_features = features
    for file in glob.glob(input_table_path + '/*.csv'):
        file_name = file.split('/')[-1]
        if os.path.getsize(file) == 0:
                continue
                
        print(file_name)
        d_sample = pd.read_csv(file)
#         d_sample = d_sample[d_sample["smc_class_score"]>0]
        grouped_obj = d_sample.groupby(['column', 'row'])
        new_df_list = []
        pred = []
        for cell in grouped_obj:
            cell[1][normalize_features] = scaler.transform(cell[1][normalize_features])
            sorted_df = cell[1].sort_values('smc_class_score',ascending=False)
            sorted_df_features = sorted_df[normalize_features]
            new_df_list.append(sorted_df)
            arr = sorted_df_features.to_numpy()
            test_inp = []
            for a in arr:
                test_inp.append(a)
            test_tensor = torch.tensor(test_inp).float()
            scores = model.predict(test_tensor)
            scores_list = torch.squeeze(scores).tolist()
            if not type(scores_list) is list:
                pred.append(scores_list)
            else:
                pred.extend(scores_list)
        test_df = pd.concat(new_df_list)
        test_df[final_score_column] = pred
        test_df.to_csv(f"{output_table_path}/{file_name}", index=False)

def train(args):
    if torch.cuda.is_available():
        device = torch.device('cuda')
    
    else:
        device = torch.device('cpu')
    train_dataloader = generate_dataloader(args.positive_feat_path, args.negative_feat_path)
    criterion = PairwiseLoss()
    EPOCHS = args.num_epochs
    model = PairwiseNetwork(len(features)).to(device=device)
    optimizer = Adam(model.parameters(), lr=args.lr)
    top1_max_prec = 0
    for epoch in range(EPOCHS):
        train_epoch_loss = 0
        avg_loss = 0
        model.train()
        for bid, batch in tqdm(enumerate(train_dataloader), position=0, leave=True):
            positive_feat = torch.tensor(batch[0].float())
            negative_feat = torch.tensor(batch[1].float())
            optimizer.zero_grad()
            pos_out, neg_out = model(positive_feat, negative_feat)
            loss = criterion(pos_out, neg_out)
            loss.backward()
            optimizer.step()
            train_epoch_loss += loss
        avg_loss = train_epoch_loss / bid

        # Evaluation
        model.eval()
        infer_scores(args.min_max_scaler_path, args.dev_path, args.dev_output, model)
        eval_data = merge_eval_files(args.dev_output)
        res, candidate_eval_data = parse_eval_files_stats(eval_data, final_score_column)
        top1_precision = res['num_tasks_with_model_score_top_one_accurate']/res['num_tasks_with_gt']
        if top1_precision > top1_max_prec:
            top1_max_prec = top1_precision
            model_save_name = 'epoch_{}_loss_{}_top1_{}.pth'.format(epoch, avg_loss, top1_max_prec)
            best_model_path = os.path.join(args.model_save_path, model_save_name)
            torch.save(model.state_dict(), best_model_path)
        
        print("Epoch {}, Avg Loss is {}, epoch top1 {}, max top1 {}".format(epoch, avg_loss, top1_precision, top1_max_prec))
    return best_model_path

In [87]:
def merge_eval_files(final_score_path):
    eval_file_names = []
    df_list = []
    for (dirpath, dirnames, filenames) in os.walk(final_score_path):
        for fn in filenames:
            if "csv" not in fn:
                continue
            abs_fn = os.path.join(dirpath, fn)
            assert os.path.isfile(abs_fn)
            if os.path.getsize(abs_fn) == 0:
                continue
            eval_file_names.append(abs_fn)
    
    for fn in eval_file_names:
        fid = fn.split('/')[-1].split('.csv')[0]
        df = pd.read_csv(fn)
        df['table_id'] = fid
        df_list.append(df)
    return pd.concat(df_list)

def parse_eval_files_stats(eval_data, method):
    res = {}
    candidate_eval_data = eval_data.groupby(['table_id', 'column', 'row'])['table_id'].count().reset_index(name="count")
    res['num_tasks_with_gt'] = len(eval_data[pd.notna(eval_data['GT_kg_id'])].groupby(['table_id', 'column', 'row']))
    num_tasks_with_model_score_top_one_accurate = []
    num_tasks_with_model_score_top_five_accurate = []
    num_tasks_with_model_score_top_ten_accurate = []
    has_gt_list = []
    has_gt_in_candidate = []
    for i, row in candidate_eval_data.iterrows():
        table_id, row_idx, col_idx = row['table_id'], row['row'], row['column']
        c_e_data = eval_data[(eval_data['table_id'] == table_id) & (eval_data['row'] == row_idx) & (eval_data['column'] == col_idx)]
        assert len(c_e_data) > 0
        if np.nan not in set(c_e_data['GT_kg_id']):
            has_gt_list.append(1)
        else:
            has_gt_list.append(0)
        if 1 in set(c_e_data['evaluation_label']):
            has_gt_in_candidate.append(1)
        else:
            has_gt_in_candidate.append(0)
                    
        #rank on model score
        s_data = c_e_data.sort_values(by=[method], ascending=False)
        if s_data.iloc[0]['evaluation_label'] == 1:
            num_tasks_with_model_score_top_one_accurate.append(1)
        else:
            num_tasks_with_model_score_top_one_accurate.append(0)
        if 1 in set(s_data.iloc[0:5]['evaluation_label']):
            num_tasks_with_model_score_top_five_accurate.append(1)
        else:
            num_tasks_with_model_score_top_five_accurate.append(0)
        if 1 in set(s_data.iloc[0:10]['evaluation_label']):
            num_tasks_with_model_score_top_ten_accurate.append(1)
        else:
            num_tasks_with_model_score_top_ten_accurate.append(0)
            
    res['num_tasks_with_model_score_top_one_accurate'] = sum(num_tasks_with_model_score_top_one_accurate)
    res['num_tasks_with_model_score_top_five_accurate'] = sum(num_tasks_with_model_score_top_five_accurate)
    res['num_tasks_with_model_score_top_ten_accurate'] = sum(num_tasks_with_model_score_top_ten_accurate)
    return res, candidate_eval_data

In [88]:
training_args = Namespace(num_epochs=20, lr=0.001, positive_feat_path=pos_output, negative_feat_path=neg_output,
                         dev_path=dev_feature_path, dev_output=dev_output_predictions,
                         model_save_path=model_save_path, min_max_scaler_path=min_max_scaler_path)

In [89]:
## Call Training
best_model_path = train(training_args)

  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
276it [00:02, 118.25it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
53it [00:00, 522.21it/s]

Epoch 0, Avg Loss is 0.6838570237159729, epoch top1 0.8274165202108963, max top1 0.8274165202108963


276it [00:00, 561.89it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
60it [00:00, 591.63it/s]

Epoch 1, Avg Loss is 0.127713143825531, epoch top1 0.8126537785588752, max top1 0.8274165202108963


276it [00:00, 586.26it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
57it [00:00, 569.78it/s]

Epoch 2, Avg Loss is 0.08688374608755112, epoch top1 0.7943760984182777, max top1 0.8274165202108963


276it [00:00, 581.83it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
58it [00:00, 577.86it/s]

Epoch 3, Avg Loss is 0.07403706014156342, epoch top1 0.7834797891036906, max top1 0.8274165202108963


276it [00:00, 588.25it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
59it [00:00, 582.52it/s]

Epoch 4, Avg Loss is 0.0676841139793396, epoch top1 0.7666080843585237, max top1 0.8274165202108963


276it [00:00, 578.79it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
59it [00:00, 586.70it/s]

Epoch 5, Avg Loss is 0.06369482725858688, epoch top1 0.7398945518453427, max top1 0.8274165202108963


276it [00:00, 580.82it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
59it [00:00, 587.76it/s]

Epoch 6, Avg Loss is 0.06072162464261055, epoch top1 0.7247803163444639, max top1 0.8274165202108963


276it [00:00, 569.69it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
59it [00:00, 588.87it/s]

Epoch 7, Avg Loss is 0.058280590921640396, epoch top1 0.6977152899824253, max top1 0.8274165202108963


276it [00:00, 569.15it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
59it [00:00, 581.98it/s]

Epoch 8, Avg Loss is 0.05617678910493851, epoch top1 0.6766256590509666, max top1 0.8274165202108963


276it [00:00, 561.02it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
58it [00:00, 578.88it/s]

Epoch 9, Avg Loss is 0.05429631099104881, epoch top1 0.6569420035149385, max top1 0.8274165202108963


276it [00:00, 583.04it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
58it [00:00, 577.24it/s]

Epoch 10, Avg Loss is 0.052619464695453644, epoch top1 0.6355008787346221, max top1 0.8274165202108963


276it [00:00, 580.65it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
59it [00:00, 583.81it/s]

Epoch 11, Avg Loss is 0.051144812256097794, epoch top1 0.6193321616871704, max top1 0.8274165202108963


276it [00:00, 586.23it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
59it [00:00, 580.95it/s]

Epoch 12, Avg Loss is 0.049924541264772415, epoch top1 0.604920913884007, max top1 0.8274165202108963


276it [00:00, 586.91it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
57it [00:00, 567.12it/s]

Epoch 13, Avg Loss is 0.048934824764728546, epoch top1 0.5905096660808435, max top1 0.8274165202108963


276it [00:00, 585.97it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
48it [00:00, 477.76it/s]

Epoch 14, Avg Loss is 0.048074476420879364, epoch top1 0.5789103690685413, max top1 0.8274165202108963


276it [00:00, 516.14it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
59it [00:00, 584.17it/s]

Epoch 15, Avg Loss is 0.04729203134775162, epoch top1 0.5708260105448155, max top1 0.8274165202108963


276it [00:00, 587.11it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
56it [00:00, 557.99it/s]

Epoch 16, Avg Loss is 0.04657056927680969, epoch top1 0.5609841827768014, max top1 0.8274165202108963


276it [00:00, 583.01it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
59it [00:00, 587.15it/s]

Epoch 17, Avg Loss is 0.04589607194066048, epoch top1 0.5511423550087874, max top1 0.8274165202108963


276it [00:00, 563.10it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


  positive_feat = torch.tensor(batch[0].float())
  negative_feat = torch.tensor(batch[1].float())
59it [00:00, 583.76it/s]

Epoch 18, Avg Loss is 0.04530151188373566, epoch top1 0.5476274165202109, max top1 0.8274165202108963


276it [00:00, 586.50it/s]


BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv
Epoch 19, Avg Loss is 0.04478145390748978, epoch t

In [90]:
best_model_path

'/home/sriamazingram/USC/Others/ISI/data/t2dv2/dev-output/12_1/saved_models/epoch_0_loss_0.6838570237159729_top1_0.8274165202108963.pth'

## Dev Prediction

In [5]:
def dev_prediction(dev_feature_path, dev_predictions_top_k, dev_output_predictions, saved_model, output_column, min_max_scaler_path, k=5):
    for file in glob.glob(dev_feature_path + '/*.csv'):
        filename = file.split("/")[-1]
        print(filename)
        feature_str =  ",".join(features)
        if os.path.getsize(file) == 0:
            continue
        # location where the output generated by the predictions wil be stored.
        dev_output = f"{dev_predictions_top_k}/{filename}"
        !tl predict-using-model $file -o $output_column \
            --features $feature_str \
            --ranking-model $saved_model \
            --normalization-factor $min_max_scaler_path \
            / create-pseudo-gt \
            --column-thresholds $threshold \
            --filter smc_class_score:0 \
            / get-kg-links -c $output_column -k $k --k-rows \
            > $dev_output

In [6]:
def add_color(dev_predictions_top_k, dev_colorized_path, score_column, k=5):
    for file in glob.glob(dev_predictions_top_k + '/*.csv'):
        filename = file.split("/")[-1]
        print(filename)
        if os.path.getsize(file) == 0:
                    continue
                
        dev_color_file = f"{dev_colorized_path}/{filename.strip('.csv')}.xlsx"
        !tl add-color $file -c "$score_column,evaluation_label" -k $k --output $dev_color_file

In [7]:
def compute_metrics(dev_predictions_top_k, dev_predictions_metrics, score_column, k=5):
    df_list = []
    for file in glob.glob(dev_predictions_top_k + '/*.csv'):
        filename = file.split("/")[-1]
        print(filename)
        if os.path.getsize(file) == 0:
                    continue
        dev_metrics_file = f"{dev_predictions_metrics}/{filename}"
        df = pd.read_csv(file)
        col_df = []
        for col, coldf in df.groupby(by=["column"]):
            smc_rows = 0
            rows = 0
            smc_recall = 0
            smc_candidates = 0
            pgt_rows = 0
            pgt_recall = 0
            for row, rowdf in coldf.groupby(by=["row"]):
                rows += 1
                count = rowdf[(rowdf["smc_class_score"] > 0)].shape[0]
                if count > 0:
                    smc_rows += 1
                smc_candidates += count
                recall = rowdf[((rowdf["smc_class_score"] > 0) & (rowdf["evaluation_label"] == 1))].shape[0]
                smc_recall += recall
                p_count = rowdf[(rowdf["pseudo_gt"] == 1)].shape[0]
                if p_count > 0:
                    pgt_rows += 1
                p_recall = rowdf[((rowdf["pseudo_gt"] == 1) & (rowdf["evaluation_label"] == 1))].shape[0]
                pgt_recall += p_recall
            col_df.append(pd.DataFrame([{"filename":filename, "column": col, "rows": rows, "smc_rows": smc_rows, "smc_recall_rows": smc_recall, "smc_candidates": smc_candidates, "pgt_rows": pgt_rows, "pgt_recall": pgt_recall, "smc_recall": smc_recall/smc_rows if smc_rows!=0 else 0, "pgt_accuracy": pgt_recall/pgt_rows if pgt_rows!=0 else 0}]))
        df_list.append(pd.concat(col_df))
    return pd.concat(df_list)

In [8]:
# best_model_path = '/home/sriamazingram/USC/Others/ISI/data/t2dv2/dev-output/12_1/saved_models/epoch_0_loss_0.6838570237159729_top1_0.8274165202108963.pth'
best_model_path = '/home/sriamazingram/USC/Others/ISI/data/t2dv2/dev-output/12_0/saved_models/epoch_19_loss_0.1209137961268425_top1_0.8340949033391916.pth'
dev_prediction(dev_feature_path, dev_predictions_top_k, dev_output_predictions, best_model_path, final_score_column, min_max_scaler_path, k=1)

BOXTVP7V.csv
predict-using-model Time: 5.418988466262817s
create-pseudo-gt Time: 0.09844112396240234s
get-kg-links-gt_score Time: 0.42121458053588867s
E5SHJSQZ.csv
predict-using-model Time: 0.46292948722839355s
create-pseudo-gt Time: 0.047797441482543945s
get-kg-links-gt_score Time: 0.1605374813079834s
DBH21J5D.csv
predict-using-model Time: 0.45867323875427246s
create-pseudo-gt Time: 0.049750328063964844s
get-kg-links-gt_score Time: 0.15622425079345703s
84575189_0_6365692015941409487.csv
predict-using-model Time: 0.9911816120147705s
create-pseudo-gt Time: 0.17081546783447266s
get-kg-links-gt_score Time: 0.6941359043121338s
3OX1PGQD.csv
predict-using-model Time: 0.4571664333343506s
create-pseudo-gt Time: 0.04681682586669922s
get-kg-links-gt_score Time: 0.15894222259521484s
4HYT5D2J.csv
predict-using-model Time: 0.5606391429901123s
create-pseudo-gt Time: 0.07158446311950684s
get-kg-links-gt_score Time: 0.2722949981689453s
28086084_0_3127660530989916727.csv
predict-using-model Time: 1.809

292E016E.csv
predict-using-model Time: 0.7273874282836914s
create-pseudo-gt Time: 0.10996675491333008s
get-kg-links-gt_score Time: 0.4365420341491699s
V1MLK9TP.csv
predict-using-model Time: 0.7399024963378906s
create-pseudo-gt Time: 0.11232447624206543s
get-kg-links-gt_score Time: 0.42385339736938477s
NBYU3S9Y.csv
predict-using-model Time: 0.466763973236084s
create-pseudo-gt Time: 0.04730725288391113s
get-kg-links-gt_score Time: 0.1652841567993164s
52299421_0_4473286348258170200.csv
predict-using-model Time: 1.0503613948822021s
create-pseudo-gt Time: 0.36276865005493164s
get-kg-links-gt_score Time: 0.7175037860870361s
PMSAYLPC.csv
predict-using-model Time: 0.731025218963623s
create-pseudo-gt Time: 0.1127777099609375s
get-kg-links-gt_score Time: 0.4479975700378418s


In [9]:
metrics_df = compute_metrics(dev_predictions_top_k, dev_metrics_path, final_score_column, k=1)

BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
VADKVBSJ.csv
FU7P6GOF.csv
6T4QNE30.csv
TLAL3B63.csv
VNSUNG1M.csv
FDOC6GMJ.csv
29414811_2_4773219892816395776.csv
CR0Q0GDE.csv
RCL5LZUM.csv
9V2P69CI.csv
RF6RSJ5W.csv
SYRX0I75.csv
54SEC9F3.csv
MBCHQ4TC.csv
RWEJTWBK.csv
34LOX8E9.csv
OYFD9B7F.csv
ZR25NVUN.csv
JTWZYYBU.csv
NE9XVY42.csv
VB0WL533.csv
XXYFPD8I.csv
8N4ZTXDV.csv
YV0V8O3A.csv
39759273_0_1427898308030295194.csv
OEMDOUBY.csv
14380604_4_3329235705746762392.csv
VEKB4XZC.csv
1UEUW7EP.csv
14067031_0_559833072073397908.csv
093BPOP2.csv
J5WTHYK6.csv
IYEDUUIU.csv
JAV53EZQ.csv
DHDSWRU2.csv
RPS3P53T.csv
U7PSL9LZ.csv
45073662_0_3179937335063201739.csv
DKR353LM.csv
KL3RUA2V.csv
EJMFROMS.csv
292E016E.csv
V1MLK9TP.csv
NBYU3S9Y.csv
52299421_0_4473286348258170200.csv
PMSAYLPC.csv


In [10]:
metrics_df

Unnamed: 0,filename,column,rows,smc_rows,smc_recall_rows,smc_candidates,pgt_rows,pgt_recall,smc_recall,pgt_accuracy
0,BOXTVP7V.csv,0,20,20,20,20,16,16,1.000000,1.000000
0,BOXTVP7V.csv,2,20,20,20,20,0,0,1.000000,0.000000
0,E5SHJSQZ.csv,0,20,20,11,20,10,7,0.550000,0.700000
0,DBH21J5D.csv,0,20,20,20,20,20,20,1.000000,1.000000
0,84575189_0_6365692015941409487.csv,2,99,92,80,92,64,60,0.869565,0.937500
...,...,...,...,...,...,...,...,...,...,...
0,NBYU3S9Y.csv,0,20,20,19,20,19,19,0.950000,1.000000
0,52299421_0_4473286348258170200.csv,1,92,92,88,92,83,81,0.956522,0.975904
0,PMSAYLPC.csv,0,20,20,19,20,0,0,0.950000,0.000000
0,PMSAYLPC.csv,1,20,20,20,20,0,0,1.000000,0.000000


In [11]:
metrics_df['smc_recall'].mean()

0.8093272029938183

In [12]:
metrics_df['pgt_accuracy'].mean()

0.60950998914938

In [13]:
add_color(dev_predictions_top_k, dev_colorized_path, final_score_column, k=1)

BOXTVP7V.csv
add-color Time: 0.3345224857330322s
E5SHJSQZ.csv
add-color Time: 0.04653811454772949s
DBH21J5D.csv
add-color Time: 0.040328264236450195s
84575189_0_6365692015941409487.csv
add-color Time: 0.09169578552246094s
3OX1PGQD.csv
add-color Time: 0.03899812698364258s
4HYT5D2J.csv
add-color Time: 0.043579816818237305s
28086084_0_3127660530989916727.csv
add-color Time: 0.1759474277496338s
U4430LA9.csv
add-color Time: 0.04608297348022461s
50270082_0_444360818941411589.csv
add-color Time: 0.12504291534423828s
PT0GTLGV.csv
add-color Time: 0.08594560623168945s
FV3PPNAQ.csv
add-color Time: 0.050455331802368164s
JYC6D9MU.csv
add-color Time: 0.06686568260192871s
VADKVBSJ.csv
add-color Time: 0.06985068321228027s
FU7P6GOF.csv
add-color Time: 0.05729341506958008s
6T4QNE30.csv
add-color Time: 0.07089996337890625s
TLAL3B63.csv
add-color Time: 0.05772519111633301s
VNSUNG1M.csv
add-color Time: 0.057912588119506836s
FDOC6GMJ.csv
add-color Time: 0.06678223609924316s
29414811_2_4773219892816395776.cs

In [14]:
metrics_df.to_csv(f"{dev_metrics_path}/metrics1.csv", index=False)

In [139]:
dev_prediction(dev_feature_path, dev_predictions_top_k, dev_output_predictions, best_model_path, final_score_column, min_max_scaler_path, k=5)
metrics_df = compute_metrics(dev_predictions_top_k, dev_metrics_path, final_score_column, k=5)
print(metrics_df['smc_recall'].mean(), metrics_df['pgt_accuracy'].mean())
add_color(dev_predictions_top_k, dev_colorized_path, final_score_column, k=5)
metrics_df.to_csv(f"{dev_metrics_path}/metrics5.csv", index=False)

BOXTVP7V.csv
predict-using-model Time: 0.5665733814239502s
create-pseudo-gt Time: 0.08074164390563965s
get-kg-links-gt_score Time: 0.2836918830871582s
E5SHJSQZ.csv
predict-using-model Time: 0.43309974670410156s
create-pseudo-gt Time: 0.047197818756103516s
get-kg-links-gt_score Time: 0.15593266487121582s
DBH21J5D.csv
predict-using-model Time: 0.4545016288757324s
create-pseudo-gt Time: 0.044663429260253906s
get-kg-links-gt_score Time: 0.15007567405700684s
84575189_0_6365692015941409487.csv
predict-using-model Time: 0.9130144119262695s
create-pseudo-gt Time: 0.16237974166870117s
get-kg-links-gt_score Time: 0.667809009552002s
3OX1PGQD.csv
predict-using-model Time: 0.4298079013824463s
create-pseudo-gt Time: 0.04340696334838867s
get-kg-links-gt_score Time: 0.14835476875305176s
4HYT5D2J.csv
predict-using-model Time: 0.5688464641571045s
create-pseudo-gt Time: 0.07027649879455566s
get-kg-links-gt_score Time: 0.25863051414489746s
28086084_0_3127660530989916727.csv
predict-using-model Time: 1.700

292E016E.csv
predict-using-model Time: 0.6815779209136963s
create-pseudo-gt Time: 0.10749387741088867s
get-kg-links-gt_score Time: 0.4225749969482422s
V1MLK9TP.csv
predict-using-model Time: 0.9217581748962402s
create-pseudo-gt Time: 0.11023688316345215s
get-kg-links-gt_score Time: 0.4081439971923828s
NBYU3S9Y.csv
predict-using-model Time: 0.5111043453216553s
create-pseudo-gt Time: 0.056909799575805664s
get-kg-links-gt_score Time: 0.22411894798278809s
52299421_0_4473286348258170200.csv
predict-using-model Time: 0.9369409084320068s
create-pseudo-gt Time: 0.1533513069152832s
get-kg-links-gt_score Time: 0.6594874858856201s
PMSAYLPC.csv
predict-using-model Time: 0.6817476749420166s
create-pseudo-gt Time: 0.10379171371459961s
get-kg-links-gt_score Time: 0.4075276851654053s
BOXTVP7V.csv
E5SHJSQZ.csv
DBH21J5D.csv
84575189_0_6365692015941409487.csv
3OX1PGQD.csv
4HYT5D2J.csv
28086084_0_3127660530989916727.csv
U4430LA9.csv
50270082_0_444360818941411589.csv
PT0GTLGV.csv
FV3PPNAQ.csv
JYC6D9MU.csv
V

In [44]:
d = pd.read_csv("/home/sriamazingram/USC/Others/ISI/data/t2dv2/dev-output/12_1/smc_context_score/median_final/dev_predictions_metrics/metrics.csv")
d.sort_values(by=["pgt_accuracy"]).head(20)

Unnamed: 0,filename,column,rows,smc_rows,smc_recall_rows,smc_candidates,pgt_rows,pgt_recall,smc_recall,pgt_accuracy
72,VEKB4XZC.csv,0,20,20,20,1536,10,0,1.0,0.0
51,54SEC9F3.csv,2,20,20,0,1443,1,0,0.0,0.0
49,SYRX0I75.csv,4,20,20,0,20,3,0,0.0,0.0
23,FU7P6GOF.csv,2,20,20,0,155,17,0,0.0,0.0
28,TLAL3B63.csv,1,20,20,19,1561,20,0,0.95,0.0
25,6T4QNE30.csv,1,20,20,20,577,19,2,1.0,0.105263
52,MBCHQ4TC.csv,0,20,20,4,38,7,1,0.2,0.142857
26,6T4QNE30.csv,2,20,20,20,1422,10,3,1.0,0.3
65,VB0WL533.csv,2,20,20,18,351,10,5,0.9,0.5
47,SYRX0I75.csv,2,20,20,19,174,9,5,0.95,0.555556
