In [2]:
import torch
import numpy as np
import pandas as pd
import matchzoo as mz

#import importlib
#importlib.reload(mz)

In [3]:
relation = pd.read_csv("../../../robust04/qrels.rob04.txt", sep=" ",
                      names = ["id_left", "dummy", "id_right", "label"]).drop("dummy",axis=1)
relation['label'] = relation['label'].astype(float)
relation = relation.dropna(axis=0, how='any')
relation.id_left = relation.id_left.astype("str")
relation.id_right = relation.id_right.astype("str")

In [4]:
left = pd.read_csv("../../../robust04/query_trec45_no.txt", sep=":",
                  names = ["id_left", "text_left"])
left = left.dropna(axis=0, how='any')
left.id_left = left.id_left.astype("str")

In [5]:
doc_dict = dict()
with open("../../../robust04/trec45.doc.txt") as f:
    for line in f:
        info = line.split("\t")
        if len(info) == 4:
            text = info[2] + " ".join(info[3].split(" ")[:500])
        else:
            text = " ".join(info[2].split(" ")[:500])
        doc_dict[info[1]] = text

In [6]:
right = pd.DataFrame(list(doc_dict.items()),columns = ['id_right','text_right']) 
right = right.dropna(axis=0, how='any')
right.id_right = right.id_right.astype("str")
right_id_right = right.id_right.drop_duplicates()

In [8]:
train_id_left = pd.read_csv('../../../cedr/data/robust/f1.train.pairs', sep = "\t",
                         names = ["id_left", "id_right"]).id_left.astype("str").drop_duplicates()
vali_id_left = pd.read_csv('../../../cedr/data/robust/f1.valid.run', sep = " ",
                         names = ["id_left", "dummy1", "id_right", "rank", "score", "dummy2"]).id_left.astype("str").drop_duplicates()
test_id_left = pd.read_csv('../../../cedr/data/robust/f1.test.run', sep = " ",
                         names = ["id_left", "dummy1", "id_right", "rank", "score", "dummy2"]).id_left.astype("str").drop_duplicates()

In [9]:
relation_train = pd.merge(pd.merge(relation,train_id_left, on=["id_left"], how="inner"), 
                          right_id_right, how = "inner")[["id_left", "id_right", "label"]]
relation_vali = pd.merge(pd.merge(relation,vali_id_left, on=["id_left"], how="inner"), 
                          right_id_right, how = "inner")[["id_left", "id_right", "label"]]
relation_test = pd.merge(pd.merge(relation,test_id_left, on=["id_left"], how="inner"), 
                          right_id_right, how = "inner")[["id_left", "id_right", "label"]]

left_train = pd.merge(train_id_left, left, how="inner", on="id_left")[["id_left", "text_left"]]
left_train.set_index("id_left", inplace=True)
left_vali = pd.merge(vali_id_left, left, how="inner", on="id_left")[["id_left", "text_left"]]
left_vali.set_index("id_left", inplace=True)
left_test = pd.merge(test_id_left, left, how="inner", on="id_left")[["id_left", "text_left"]]
left_test.set_index("id_left", inplace=True)

right_train = pd.merge(relation_train.id_right.drop_duplicates(), right, how="inner", on="id_right")[["id_right", "text_right"]].drop_duplicates()
right_train.set_index("id_right", inplace=True)
right_vali = pd.merge(relation_vali.id_right.drop_duplicates(), right, how="inner", on="id_right")[["id_right", "text_right"]].drop_duplicates()
right_vali.set_index("id_right", inplace=True)
right_test = pd.merge(relation_test.id_right.drop_duplicates(), right, how="inner", on="id_right")[["id_right", "text_right"]].drop_duplicates()
right_test.set_index("id_right", inplace=True)

In [10]:
print('data loading ...')
train_pack_raw = mz.DataPack(relation=relation_train,left=left_train,right=right_train)
dev_pack_raw = mz.DataPack(relation=relation_vali,left=left_vali,right=right_vali)
test_pack_raw = mz.DataPack(relation=relation_test,left=left_test,right=right_test)
print('data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`')

data loading ...
data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`


In [11]:
ranking_task = mz.tasks.Ranking(losses=mz.losses.RankHingeLoss())
#ranking_task = mz.tasks.Ranking(losses=mz.losses.RankCrossEntropyLoss())
ranking_task.metrics = [
    mz.metrics.NormalizedDiscountedCumulativeGain(k=3),
    mz.metrics.NormalizedDiscountedCumulativeGain(k=5),
    mz.metrics.MeanAveragePrecision()
]
print("`ranking_task` initialized with metrics", ranking_task.metrics)

`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]


In [12]:
preprocessor = mz.preprocessors.BasicPreprocessor(
    truncated_length_left = 10,
    truncated_length_right = 500,
    #filter_low_freq = 2
)

In [13]:
train_pack_processed = preprocessor.fit_transform(train_pack_raw)
dev_pack_processed = preprocessor.transform(dev_pack_raw)
test_pack_processed = preprocessor.transform(test_pack_raw)

Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 150/150 [00:00<00:00, 5063.14it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 94668/94668 [03:14<00:00, 487.56it/s]
Processing text_right with append: 100%|██████████| 94668/94668 [00:00<00:00, 531387.25it/s]
Building FrequencyFilter from a datapack.: 100%|██████████| 94668/94668 [00:07<00:00, 12181.38it/s]
Processing text_right with transform: 100%|██████████| 94668/94668 [00:10<00:00, 8705.60it/s] 
Processing text_left with extend: 100%|██████████| 150/150 [00:00<00:00, 157483.25it/s]
Processing text_right with extend: 100%|██████████| 94668/94668 [00:00<00:00, 148657.60it/s]
Building Vocabulary from a datapack.: 100%|██████████| 37583046/37583046 [00:12<00:00, 2990465.42it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 150/150 [00:00<00:00, 8188.16it/s]
Processing text_rig

In [13]:
vocab_dict = dict()
with open("../../../embeddings/bing_vocab") as vocab:
    for line in vocab:
        info = line.split("\t")
        vocab_dict[info[1][:-1]] = info[0] 

with open("../../../embeddings/bing_knrm.txt", "w") as fo, open("../../../embeddings/bing_embedding") as f:
    i = 0
    for line in f:
        if i == 0:
            i += 1
            continue
        
        else:
            info = line.split(" ")
            if info[0] in vocab_dict:
                info[0] = vocab_dict[info[0]]
                fo.write(" ".join(info))
                

In [19]:
# http://boston.lti.cs.cmu.edu/appendices/WSDM2018-ConvKNRM/

raw_embedding = mz.embedding.load_from_file("../../../embeddings/bing_knrm.txt", mode = "glove")
term_index = preprocessor.context['vocab_unit'].state['term_index']
embedding_matrix = raw_embedding.build_matrix(term_index)
l2_norm = np.sqrt((embedding_matrix * embedding_matrix).sum(axis=1))
embedding_matrix = embedding_matrix / l2_norm[:, np.newaxis] 

In [25]:
with open("../../../embeddings/rob04_jinjin_out.txt", "w") as fo,\
    open("../../../embeddings/rob-emd-300-min3-out.txt") as f:
        i = 0
        for line in f:
            if i == 0:
                i += 1
                continue
            fo.write(line)

In [26]:
# jinjin's
raw_embedding = mz.embedding.load_from_file("../../../embeddings/rob04_jinjin_out.txt", mode = "glove")
term_index = preprocessor.context['vocab_unit'].state['term_index']
embedding_matrix = raw_embedding.build_matrix(term_index)
l2_norm = np.sqrt((embedding_matrix * embedding_matrix).sum(axis=1))
embedding_matrix = embedding_matrix / l2_norm[:, np.newaxis] 

In [21]:
glove_embedding = mz.datasets.embeddings.load_glove_embedding(dimension=100)
term_index = preprocessor.context['vocab_unit'].state['term_index']
embedding_matrix = glove_embedding.build_matrix(term_index)
l2_norm = np.sqrt((embedding_matrix * embedding_matrix).sum(axis=1))
embedding_matrix = embedding_matrix / l2_norm[:, np.newaxis]

In [32]:
trainset = mz.dataloader.Dataset(
    data_pack=train_pack_processed,
    mode='pair',
    num_dup = 750,
    num_neg=1,
    batch_size = 8
)
testset = mz.dataloader.Dataset(
    data_pack=dev_pack_processed
)

301
306
307
308
312
313
320
321
322
324
325
326
327
328
330
332
334
335
337
338
342
343
344
347
348
349
350
351
352
354
355
358
360
361
362
363
364
365
368
369
371
374
376
377
379
380
382
386
387
390
393
396
397
398
402
403
404
405
407
408
412
413
415
417
419
420
421
422
423
424
425
427
430
431
432
434
435
436
437
438
439
440
444
445
446
449
450
602
603
604
605
606
611
614
616
618
620
622
623
624
625
626
627
628
630
631
632
633
636
637
638
639
643
644
648
649
650
651
652
653
655
657
659
661
663
664
666
667
668
671
673
674
675
676
677
678
680
682
683
685
686
687
688
689
691
693
695
697
698


In [33]:
padding_callback = mz.models.KNRM.get_default_padding_callback()

trainloader = mz.dataloader.DataLoader(
    dataset=trainset,
    #batch_size=8,
    stage='train',
    #resample=True,
    #sort=False,
    callback=padding_callback,
)
testloader = mz.dataloader.DataLoader(
    dataset=testset,
    #batch_size=20,
    stage='dev',
    callback=padding_callback
)

In [34]:
model = mz.models.KNRM({"embedding_freeze":True})

model.params['task'] = ranking_task
model.params['embedding'] = embedding_matrix
model.params['kernel_num'] = 21
model.params['sigma'] = 0.1
model.params['exact_sigma'] = 0.001

model.build()

print(model)
print('Trainable params: ', sum(p.numel() for p in model.parameters() if p.requires_grad))

KNRM(
  (embedding): Embedding(254001, 300, padding_idx=0)
  (kernels): ModuleList(
    (0): GaussianKernel()
    (1): GaussianKernel()
    (2): GaussianKernel()
    (3): GaussianKernel()
    (4): GaussianKernel()
    (5): GaussianKernel()
    (6): GaussianKernel()
    (7): GaussianKernel()
    (8): GaussianKernel()
    (9): GaussianKernel()
    (10): GaussianKernel()
    (11): GaussianKernel()
    (12): GaussianKernel()
    (13): GaussianKernel()
    (14): GaussianKernel()
    (15): GaussianKernel()
    (16): GaussianKernel()
    (17): GaussianKernel()
    (18): GaussianKernel()
    (19): GaussianKernel()
    (20): GaussianKernel()
  )
  (out): Linear(in_features=21, out_features=1, bias=True)
)
Trainable params:  22


In [35]:
optimizer = torch.optim.Adadelta(model.parameters())

trainer = mz.trainers.Trainer(
    model=model,
    optimizer=optimizer,
    trainloader=trainloader,
    validloader=testloader,
    validate_interval=None,
    epochs=1
)

In [36]:
trainer.run()

HBox(children=(FloatProgress(value=0.0, max=896625.0), HTML(value='')))

[Iter-896625 Loss-0.781]:
  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.1901 - normalized_discounted_cumulative_gain@5(0.0): 0.166 - mean_average_precision(0.0): 0.1341

Cost time: 22066.634841680527s


In [17]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [50]:
del left_train
del right_train
del relation_train
torch.cuda.empty_cache()

In [73]:
for _, group in groups:
    labels = group.label.unique()
    for label in labels[:-1]:
        pos_samples = group[group.label == label]
        pos_samples = pd.concat([pos_samples] * num_dup)
        neg_samples = group[group.label < label]
        for _, pos_sample in pos_samples.iterrows():
            pos_sample = pd.DataFrame([pos_sample])
            neg_sample = neg_samples.sample(num_neg, replace=True)
            pairs.extend((pos_sample, neg_sample))
new_relation = pd.concat(pairs, ignore_index=True)

Unnamed: 0,id_left,id_right,label
0,301,FBIS3-10169,0.0
1,301,FBIS3-10243,1.0
2,301,FBIS3-10319,0.0
3,301,FBIS3-10397,1.0
4,301,FBIS3-10491,1.0
5,301,FBIS3-10635,0.0
6,301,FBIS3-10721,1.0
7,301,FBIS3-10910,1.0
8,301,FBIS3-10937,1.0
9,301,FBIS3-11028,0.0


In [59]:
for (inputs, target) in trainer._trainloader:
    print(inputs)
            

{'text_left': tensor([[ 7789, 42596, 22650],
        [ 7789, 42596, 22650],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [    0, 27834, 20771],
        [    0, 27834, 20771],
        [    0,     0, 51572],
        [    0,     0, 51572],
        [    0,     0, 40277],
        [    0,     0, 40277],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [ 7789, 42596, 22650],
        [ 7789, 42596, 22650],
        [    0, 28493,  7117],
        [    0, 28493,  7117],
        [36280,  7838, 13266],
        [36280,  7838, 13266],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [    0,  8243, 21206],
        [    0,  8243, 21206],
        [51038, 24367, 41541],
        [51038, 24367, 41541],
        [27360, 50482, 51742],
        [27360, 50482, 51742],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [20286, 48188, 33125],
        [20286, 48188, 33

  x[key] = np.array(val)


{'text_left': tensor([[42861, 19208, 41324],
        [42861, 19208, 41324],
        [15118, 16761, 53569],
        [15118, 16761, 53569],
        [    0, 33384, 53887],
        [    0, 33384, 53887],
        [27360, 50482, 51742],
        [27360, 50482, 51742],
        [55144, 54803, 12749],
        [55144, 54803, 12749],
        [    0, 28493,  7117],
        [    0, 28493,  7117],
        [    0,     0, 51572],
        [    0,     0, 51572],
        [12974,  7478, 15610],
        [12974,  7478, 15610],
        [15118, 16761, 53569],
        [15118, 16761, 53569],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [    0,     0, 40277],
        [    0,     0, 40277],
        [ 6490, 14634,  9181],
        [ 6490, 14634,  9181],
        [27360, 50482, 51742],
        [27360, 50482, 51742],
        [10604, 54650, 18466],
        [10604, 54650, 18466],
        [    0, 27834, 20771],
        [    0, 27834, 20

{'text_left': tensor([[    0,     0, 40277],
        [    0,     0, 40277],
        [    0,     0, 40277],
        [    0,     0, 40277],
        [    0,     0, 51572],
        [    0,     0, 51572],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [15118, 16761, 53569],
        [15118, 16761, 53569],
        [    0,     0, 51572],
        [    0,     0, 51572],
        [    0, 42254, 27691],
        [    0, 42254, 27691],
        [    0, 42254, 27691],
        [    0, 42254, 27691],
        [42861, 19208, 41324],
        [42861, 19208, 41324],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [31015, 34759,  9790],
        [31015, 34759,  9790],
        [ 7789, 42596, 22650],
        [ 7789, 42596, 22650],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [48116, 48685, 35357],
        [48116, 48685, 35357],
        [    0, 27834, 20771],
        [    0, 27834, 20

{'text_left': tensor([[48116, 48685, 35357],
        [48116, 48685, 35357],
        [ 8272,  9612, 18360],
        [ 8272,  9612, 18360],
        [    0, 33384, 53887],
        [    0, 33384, 53887],
        [    0,     0, 40277],
        [    0,     0, 40277],
        [27360, 50482, 51742],
        [27360, 50482, 51742],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [    0, 32673, 18360],
        [    0, 32673, 18360],
        [    0, 27834, 20771],
        [    0, 27834, 20771],
        [ 6490, 14634,  9181],
        [ 6490, 14634,  9181],
        [    0, 26068,  6778],
        [    0, 26068,  6778],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [    0, 27834, 20771],
        [    0, 27834, 20771],
        [    0, 33384, 53887],
        [    0, 33384, 53887],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [    0,     0, 51572],
        [    0,     0, 51572],
        [31015, 34759,  9790],
        [31015, 34759,  9

{'text_left': tensor([[    0, 15118, 16761, 53569],
        [    0, 15118, 16761, 53569],
        [    0,     0, 27056, 20330],
        [    0,     0, 27056, 20330],
        [    0, 52937, 26210,  6606],
        [    0, 52937, 26210,  6606],
        [    0, 20286, 48188, 33125],
        [    0, 20286, 48188, 33125],
        [    0,     0, 27834, 20771],
        [    0,     0, 27834, 20771],
        [    0,     0, 27834, 20771],
        [    0,     0, 27834, 20771],
        [    0, 22634, 34825, 23834],
        [    0, 22634, 34825, 23834],
        [    0,     0, 27834, 20771],
        [    0,     0, 27834, 20771],
        [    0,     0,     0, 51572],
        [    0,     0,     0, 51572],
        [    0, 34449,  7319, 40694],
        [    0, 34449,  7319, 40694],
        [    0,     0,     0, 51572],
        [    0,     0,     0, 51572],
        [    0, 20286, 48188, 33125],
        [    0, 20286, 48188, 33125],
        [    0, 22634, 34825, 23834],
        [    0, 22634, 34825, 23834]

{'text_left': tensor([[    0, 50697, 31902],
        [    0, 50697, 31902],
        [15118, 16761, 53569],
        [15118, 16761, 53569],
        [    0, 27834, 20771],
        [    0, 27834, 20771],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [27028, 23128, 13021],
        [27028, 23128, 13021],
        [ 8272,  9612, 18360],
        [ 8272,  9612, 18360],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [27028, 23128, 13021],
        [27028, 23128, 13021],
        [    0, 27056, 20330],
        [    0, 27056, 20330],
        [    0,     0, 51572],
        [    0,     0, 51572],
        [27360, 50482, 51742],
        [27360, 50482, 51742],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [    0, 27834, 20771],
        [    0, 27834, 20771],
        [    0,     0, 51572],
        [    0,     0, 51572],
        [27360, 50482, 51742],
        [27360, 50482, 51

{'text_left': tensor([[    0, 27834, 20771],
        [    0, 27834, 20771],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [    0,     0, 51572],
        [    0,     0, 51572],
        [    0, 50499, 41164],
        [    0, 50499, 41164],
        [    0,     0, 51572],
        [    0,     0, 51572],
        [38649, 51329, 45206],
        [38649, 51329, 45206],
        [38229, 47869, 20374],
        [38229, 47869, 20374],
        [    0, 28493,  7117],
        [    0, 28493,  7117],
        [36280,  7838, 13266],
        [36280,  7838, 13266],
        [    0, 33384, 53887],
        [    0, 33384, 53887],
        [    0,     0, 40277],
        [    0,     0, 40277],
        [    0, 16989, 53891],
        [    0, 16989, 53891],
        [    0, 19758, 48591],
        [    0, 19758, 48591],
        [ 7789, 42596, 22650],
        [ 7789, 42596, 22650],
        [ 8272,  9612, 18360],
        [ 8272,  9612, 18360],
        [    0, 19758, 48591],
        [    0, 19758, 48

{'text_left': tensor([[10685, 43916, 16082],
        [10685, 43916, 16082],
        [34449,  7319, 40694],
        [34449,  7319, 40694],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [    0, 27834, 20771],
        [    0, 27834, 20771],
        [55144, 54803, 12749],
        [55144, 54803, 12749],
        [15118, 16761, 53569],
        [15118, 16761, 53569],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [10604, 54650, 18466],
        [10604, 54650, 18466],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [    0, 42254, 27691],
        [    0, 42254, 27691],
        [10685, 43916, 16082],
        [10685, 43916, 16082],
        [    0,     0, 51572],
        [    0,     0, 51572],
        [27028, 23128, 13021],
        [27028, 23128, 13021],
        [52937, 26210,  6606],
        [52937, 26210,  6606],
        [    0, 27834, 20771],
        [    0, 27834, 20771],
        [    0,     0, 51572],
        [    0,     0, 51

{'text_left': tensor([[20286, 48188, 33125],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [51193, 21215,  8957],
        [51193, 21215,  8957],
        [    0, 19758, 48591],
        [    0, 19758, 48591],
        [27360, 50482, 51742],
        [27360, 50482, 51742],
        [    0, 19758, 48591],
        [    0, 19758, 48591],
        [35951, 23194, 41128],
        [35951, 23194, 41128],
        [31015, 34759,  9790],
        [31015, 34759,  9790],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [38649, 51329, 45206],
        [38649, 51329, 45206],
        [ 8272,  9612, 18360],
        [ 8272,  9612, 18360],
        [38649, 51329, 45206],
        [38649, 51329, 45206],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [    0, 50697, 31902],
        [    0, 50697, 31902],
        [    0, 27056, 20330],
        [    0, 27056, 20

{'text_left': tensor([[36280,  7838, 13266],
        [36280,  7838, 13266],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [    0, 28493,  7117],
        [    0, 28493,  7117],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [10604, 54650, 18466],
        [10604, 54650, 18466],
        [31015, 34759,  9790],
        [31015, 34759,  9790],
        [    0, 19758, 48591],
        [    0, 19758, 48591],
        [    0, 19758, 48591],
        [    0, 19758, 48591],
        [    0, 50697, 31902],
        [    0, 50697, 31902],
        [52937, 26210,  6606],
        [52937, 26210,  6606],
        [    0,     0, 40277],
        [    0,     0, 40277],
        [27360, 50482, 51742],
        [27360, 50482, 51742],
        [    0, 27834, 20771],
        [    0, 27834, 20771],
        [    0, 50697, 31902],
        [    0, 50697, 31902],
        [    0, 50697, 31902],
        [    0, 50697, 31902],
        [    0, 27834, 20771],
        [    0, 27834, 20

{'text_left': tensor([[22634, 34825, 23834],
        [22634, 34825, 23834],
        [55144, 54803, 12749],
        [55144, 54803, 12749],
        [15118, 16761, 53569],
        [15118, 16761, 53569],
        [ 7789, 42596, 22650],
        [ 7789, 42596, 22650],
        [    0, 27834, 20771],
        [    0, 27834, 20771],
        [    0,     0, 40277],
        [    0,     0, 40277],
        [    0, 42254, 27691],
        [    0, 42254, 27691],
        [52937, 26210,  6606],
        [52937, 26210,  6606],
        [    0, 16989, 53891],
        [    0, 16989, 53891],
        [    0,     0, 51572],
        [    0,     0, 51572],
        [    0, 27834, 20771],
        [    0, 27834, 20771],
        [    0, 50697, 31902],
        [    0, 50697, 31902],
        [42861, 19208, 41324],
        [42861, 19208, 41324],
        [    0, 27834, 20771],
        [    0, 27834, 20771],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [36280,  7838, 13266],
        [36280,  7838, 13

{'text_left': tensor([[    0, 28493,  7117],
        [    0, 28493,  7117],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [    0, 26068,  6778],
        [    0, 26068,  6778],
        [42861, 19208, 41324],
        [42861, 19208, 41324],
        [31015, 34759,  9790],
        [31015, 34759,  9790],
        [    0, 42254, 27691],
        [    0, 42254, 27691],
        [42861, 19208, 41324],
        [42861, 19208, 41324],
        [10604, 54650, 18466],
        [10604, 54650, 18466],
        [    0, 27834, 20771],
        [    0, 27834, 20771],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [31015, 34759,  9790],
        [31015, 34759,  9790],
        [    0, 27834, 20771],
        [    0, 27834, 20771],
        [ 7789, 42596, 22650],
        [ 7789, 42596, 22650],
        [    0,     0, 40277],
        [    0,     0, 40277],
        [    0, 16989, 53891],
        [    0, 16989, 53891],
        [ 7789, 42596, 22650],
        [ 7789, 42596, 22

{'text_left': tensor([[10685, 43916, 16082],
        [10685, 43916, 16082],
        [38649, 51329, 45206],
        [38649, 51329, 45206],
        [51038, 24367, 41541],
        [51038, 24367, 41541],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [    0, 28493,  7117],
        [    0, 28493,  7117],
        [27028, 23128, 13021],
        [27028, 23128, 13021],
        [    0,     0, 40277],
        [    0,     0, 40277],
        [    0, 28493,  7117],
        [    0, 28493,  7117],
        [34449,  7319, 40694],
        [34449,  7319, 40694],
        [ 8272,  9612, 18360],
        [ 8272,  9612, 18360],
        [10604, 54650, 18466],
        [10604, 54650, 18466],
        [    0, 26068,  6778],
        [    0, 26068,  6778],
        [    0, 28493,  7117],
        [    0, 28493,  7117],
        [    0,     0, 51572],
        [    0,     0, 51572],
        [ 8272,  9612, 18360],
        [ 8272,  9612, 18360],
        [    0, 15246, 40197],
        [    0, 15246, 40

{'text_left': tensor([[31015, 34759,  9790],
        [31015, 34759,  9790],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [ 7789, 42596, 22650],
        [ 7789, 42596, 22650],
        [    0, 16989, 53891],
        [    0, 16989, 53891],
        [    0, 28493,  7117],
        [    0, 28493,  7117],
        [48116, 48685, 35357],
        [48116, 48685, 35357],
        [    0, 27834, 20771],
        [    0, 27834, 20771],
        [ 8272,  9612, 18360],
        [ 8272,  9612, 18360],
        [31015, 34759,  9790],
        [31015, 34759,  9790],
        [    0, 27834, 20771],
        [    0, 27834, 20771],
        [    0,     0, 51572],
        [    0,     0, 51572],
        [    0,     0, 40277],
        [    0,     0, 40277],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [27360, 50482, 51742],
        [27360, 50482, 51742],
        [    0,     0, 40277],
        [    0,     0, 40277],
        [    0, 28493,  7117],
        [    0, 28493,  7

{'text_left': tensor([[    0, 48116, 48685, 35357],
        [    0, 48116, 48685, 35357],
        [    0, 22634, 34825, 23834],
        [    0, 22634, 34825, 23834],
        [    0, 20286, 48188, 33125],
        [    0, 20286, 48188, 33125],
        [    0, 15118, 16761, 53569],
        [    0, 15118, 16761, 53569],
        [    0, 22634, 34825, 23834],
        [    0, 22634, 34825, 23834],
        [    0, 22634, 34825, 23834],
        [    0, 22634, 34825, 23834],
        [    0,     0, 27834, 20771],
        [    0,     0, 27834, 20771],
        [    0,     0, 19758, 48591],
        [    0,     0, 19758, 48591],
        [    0, 31015, 34759,  9790],
        [    0, 31015, 34759,  9790],
        [    0,     0, 50697, 31902],
        [    0,     0, 50697, 31902],
        [    0,  8272,  9612, 18360],
        [    0,  8272,  9612, 18360],
        [    0,  7789, 42596, 22650],
        [    0,  7789, 42596, 22650],
        [    0,     0, 19758, 48591],
        [    0,     0, 19758, 48591]

{'text_left': tensor([[20286, 48188, 33125],
        [20286, 48188, 33125],
        [48116, 48685, 35357],
        [48116, 48685, 35357],
        [    0, 26068,  6778],
        [    0, 26068,  6778],
        [    0, 28493,  7117],
        [    0, 28493,  7117],
        [15118, 16761, 53569],
        [15118, 16761, 53569],
        [35951, 23194, 41128],
        [35951, 23194, 41128],
        [31015, 34759,  9790],
        [31015, 34759,  9790],
        [10685, 43916, 16082],
        [10685, 43916, 16082],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [    0, 50697, 31902],
        [    0, 50697, 31902],
        [    0, 26068,  6778],
        [    0, 26068,  6778],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [12974,  7478, 15610],
        [12974,  7478, 15610],
        [42861, 19208, 41324],
        [42861, 19208, 41324],
        [12974,  7478, 15610],
        [12974,  7478, 15610],
        [ 8272,  9612, 18360],
        [ 8272,  9612, 18

{'text_left': tensor([[    0,     0, 16989, 53891],
        [    0,     0, 16989, 53891],
        [    0, 10685, 43916, 16082],
        [    0, 10685, 43916, 16082],
        [    0,     0,     0, 51572],
        [    0,     0,     0, 51572],
        [    0,     0, 19758, 48591],
        [    0,     0, 19758, 48591],
        [    0,     0,     0, 51572],
        [    0,     0,     0, 51572],
        [    0,     0, 50697, 31902],
        [    0,     0, 50697, 31902],
        [    0, 22634, 34825, 23834],
        [    0, 22634, 34825, 23834],
        [    0,     0,     0, 51572],
        [    0,     0,     0, 51572],
        [    0, 34449,  7319, 40694],
        [    0, 34449,  7319, 40694],
        [    0,     0,     0, 51572],
        [    0,     0,     0, 51572],
        [    0, 27028, 23128, 13021],
        [    0, 27028, 23128, 13021],
        [    0, 42861, 19208, 41324],
        [    0, 42861, 19208, 41324],
        [    0,     0, 28493,  7117],
        [    0,     0, 28493,  7117]

{'text_left': tensor([[ 6490, 14634,  9181],
        [ 6490, 14634,  9181],
        [    0,  8243, 21206],
        [    0,  8243, 21206],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [15118, 16761, 53569],
        [15118, 16761, 53569],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [    0,     0, 51572],
        [    0,     0, 51572],
        [ 7789, 42596, 22650],
        [ 7789, 42596, 22650],
        [    0, 50697, 31902],
        [    0, 50697, 31902],
        [55144, 54803, 12749],
        [55144, 54803, 12749],
        [    0, 19758, 48591],
        [    0, 19758, 48591],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [52937, 26210,  6606],
        [52937, 26210,  6606],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [38229, 47869, 20374],
        [38229, 47869, 20374],
        [ 8272,  9612, 18360],
        [ 8272,  9612, 18360],
        [    0,  8243, 21206],
        [    0,  8243, 21

{'text_left': tensor([[    0, 50697, 31902],
        [    0, 50697, 31902],
        [    0, 42254, 27691],
        [    0, 42254, 27691],
        [    0, 27056, 20330],
        [    0, 27056, 20330],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [    0, 27834, 20771],
        [    0, 27834, 20771],
        [15118, 16761, 53569],
        [15118, 16761, 53569],
        [    0, 27834, 20771],
        [    0, 27834, 20771],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [    0, 27834, 20771],
        [    0, 27834, 20771],
        [15118, 16761, 53569],
        [15118, 16761, 53569],
        [35951, 23194, 41128],
        [35951, 23194, 41128],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [15118, 16761, 53569],
        [15118, 16761, 53569],
        [    0, 42254, 27691],
        [    0, 42254, 27691],
        [    0, 28493,  7117],
        [    0, 28493,  7

{'text_left': tensor([[20286, 48188, 33125],
        [20286, 48188, 33125],
        [    0,     0, 40277],
        [    0,     0, 40277],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [36280,  7838, 13266],
        [36280,  7838, 13266],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [    0, 42254, 27691],
        [    0, 42254, 27691],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [22634, 34825, 23834],
        [    0, 27834, 20771],
        [    0, 27834, 20771],
        [ 7789, 42596, 22650],
        [ 7789, 42596, 22650],
        [    0, 26068,  6778],
        [    0, 26068,  6778],
        [20286, 48188, 33125],
        [20286, 48188, 33125],
        [    0,     0, 51572],
        [    0,     0, 51572],
        [    0,     0, 51572],
        [    0,     0, 51572],
        [34449,  7319, 40694],
        [34449,  7319, 40694],
        [ 7789, 42596, 22650],
        [ 7789, 42596, 22

{'text_left': tensor([[    0, 22634, 34825, 23834],
        [    0, 22634, 34825, 23834],
        [    0, 15118, 16761, 53569],
        [    0, 15118, 16761, 53569],
        [    0,     0, 19758, 48591],
        [    0,     0, 19758, 48591],
        [    0, 15118, 16761, 53569],
        [    0, 15118, 16761, 53569],
        [    0, 20286, 48188, 33125],
        [    0, 20286, 48188, 33125],
        [    0,  7789, 42596, 22650],
        [    0,  7789, 42596, 22650],
        [    0, 10604, 54650, 18466],
        [    0, 10604, 54650, 18466],
        [    0,     0, 19758, 48591],
        [    0,     0, 19758, 48591],
        [    0, 42861, 19208, 41324],
        [    0, 42861, 19208, 41324],
        [    0, 10685, 43916, 16082],
        [    0, 10685, 43916, 16082],
        [    0,     0, 27834, 20771],
        [    0,     0, 27834, 20771],
        [    0, 22634, 34825, 23834],
        [    0, 22634, 34825, 23834],
        [    0, 22634, 34825, 23834],
        [    0, 22634, 34825, 23834]

In [66]:
import tqdm
with tqdm(enumerate(trainer._trainloader)) as pbar:
    for step, (inputs, target) in pbar:
        outputs = self._model(inputs)
        

TypeError: 'module' object is not callable

In [37]:
right_vali.loc[relation_vali['id_right'].unique()]

Passing list-likes to .loc or [] with any missing label will raise
KeyError in the future, you can use .reindex() as an alternative.

See the documentation here:
https://pandas.pydata.org/pandas-docs/stable/indexing.html#deprecate-loc-reindex-listlike
  """Entry point for launching an IPython kernel.


Unnamed: 0_level_0,text_right
id_right,Unnamed: 1_level_1
FR940203-0-00059,nmf has specify the manner and location in whi...
FR940617-0-00103,a act means the marine mammal protection act o...
FR940617-0-00104,3 category iii i a there is information indica...
FR940419-2-00009,id 032894e marine mammal agency national marin...
FR940203-0-00069,comment certain impact while insignificant ind...
FR940228-2-00026,id 020994c marine mammal agency national marin...
FR941216-2-00020,id 120194a marine mammal agency national marin...
FR941227-2-00087,fish and wildlife service available of draft r...
FR940127-1-00058,3 saiga antelope saiga tatarica eia submit a d...
FR941006-2-00015,id 091694c marine mammal agency national marin...


In [32]:
num_dup = 100
num_neg = 1
pairs = []
groups = relation.sort_values(
    'label', ascending=False).groupby('id_left')
for leftid, group in groups:
    print(leftid)
    labels = group.label.unique()
    for label in labels[:-1]:
        pos_samples = group[group.label == label]
        pos_samples = pd.concat([pos_samples] * num_dup)
        pos_samples.index = np.arange(0,pos_samples.shape[0] * (num_neg + 1), num_neg + 1)
        neg_samples = group[group.label < label]
        neg_samples_new = neg_samples.sample(num_neg * pos_samples.shape[0], replace=True)
        neg_samples_new.index = np.concatenate([np.arange(i, i + num_neg) for i in np.arange(1,pos_samples.shape[0] * (num_neg + 1) + 1,num_neg + 1)])
        pairs.append(pd.concat([pos_samples, neg_samples_new]).sort_index())
new_relation = pd.concat(pairs, ignore_index=True)

301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700


In [33]:
new_relation = pd.concat(pairs, ignore_index=True)

In [34]:
new_relation

Unnamed: 0,id_left,id_right,label
0,301,FBIS3-10082,1.0
1,301,FBIS4-2042,0.0
2,301,FT942-17001,1.0
3,301,FBIS4-68746,0.0
4,301,FT941-10611,1.0
...,...,...,...
3482395,700,LA110790-0028,0.0
3482396,700,LA043089-0044,1.0
3482397,700,LA102290-0023,0.0
3482398,700,LA050290-0015,1.0


In [85]:
pairs

[    id_left    id_right  label
 849     301  FT932-4965    1.0,
      id_left          id_right  label
 1154     301       FT942-10977    0.0
 846      301  FR940727-0-00091    0.0
 1456     301     LA091190-0102    0.0
 509      301        FBIS4-3044    0.0
 1234     301       FT944-11113    0.0
 803      301  FR940503-2-00169    0.0
 1026     301        FT931-7529    0.0
 396      301       FBIS4-10739    0.0
 457      301       FBIS4-21139    0.0
 904      301       FT921-11079    0.0
 297      301       FBIS3-45756    0.0
 566      301       FBIS4-41839    0.0
 343      301       FBIS3-59962    0.0
 1059     301        FT932-6233    0.0
 1477     301     LA100790-0068    0.0
 1049     301        FT932-3338    0.0
 326      301       FBIS3-56182    0.0
 1190     301        FT942-8808    0.0
 906      301       FT921-12538    0.0
 775      301  FR940202-2-00150    0.0
 1031     301        FT931-9535    0.0
 396      301       FBIS4-10739    0.0
 863      301  FR940930-2-00058    0.0