In [2]:
from data_loaders.AUS_dataset import AUSDataset, AUSPytorchDataset

from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn
from project_settings import EOS_TOK, EOC_TOK
import pdb
from tqdm import tqdm, trange
import numpy as np
from models.Model import MeanModel, TruncatModel, NNModel
from project_settings import ExpConfig, DatasetConfig
from utils import chunkify, encode_chunks, transform_chunk_to_dict

## Metric

In [3]:
from metrics import micro_contrast,macro_contrast

## misc

In [5]:
# def transform_chunk_to_dict(chunk):
#     n,dim=chunk.size()
#     chunk_dict={}
#     chunk_dict["input_ids"]=chunk[0]
#     chunk_dict["token_type_ids"]=chunk[1]
#     chunk_dict["attention_mask"]=chunk[2]
#     return chunk_dict

## Benchmark tranformer

Here we test the transformer via a **retrieval task**. We want to pair case **description** with the right **catchphrases** for each case in our legal dataset.

As seen preivously in the **dataset analysis**, the case description are in general very long, average length is **34k** chars,thus around **6k tokens**. However, transformer relying to squared attention only takes 512 tokens. So in this very naive baseline benchmark, we just truncate the sentence at **512th token** 

In [None]:
# cuda
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

ds = AUSDataset()
exp_config = ExpConfig("MeanModel")
train_dataloader = ds.get_data_loader(split='train', batch_size=2, shuffle=True)
val_dataloader = ds.get_data_loader(split='val', batch_size=2, shuffle=True)
test_dataloader = ds.get_data_loader(split='test', batch_size=2, shuffle=True)

tokenizer = AutoTokenizer.from_pretrained(exp_config.uri)
encoder = AutoModel.from_pretrained(exp_config.uri) 
test_micro_contrast = 0  # running loss
test_macro_contrast = 0
nb_test_steps = 0


for step, batch in tqdm(enumerate(test_dataloader)):
    # Add batch to GPU
    "batch = tuple(t.to(device) for t in batch)"
    # Unpack the inputs from our dataloader
    sentences, catchphrases = batch  # len(sentences)=2, len(catchphrases)=2
    # Clear out the gradients (by default they accumulate)

    sentences_a, catchphrase_a = sentences[0], catchphrases[0]
    sentences_b, catchphrase_b = sentences[1], catchphrases[1]
    batch_catchphrase_a = catchphrase_a.split(EOC_TOK)
    batch_catchphrase_b = catchphrase_b.split(EOC_TOK)

    encoded_batch_catchphrase_a = tokenizer(batch_catchphrase_a, truncation=True, return_tensors="pt",
                                            padding='max_length', max_length=512)
    encoded_batch_catchphrase_b = tokenizer(batch_catchphrase_b, truncation=True, return_tensors="pt",
                                            padding='max_length', max_length=512)

    encoded_sentence_a = tokenizer(sentences_a, truncation=True, return_tensors="pt", padding='max_length',
                                   max_length=512)
    encoded_sentence_b = tokenizer(sentences_b, truncation=True, return_tensors="pt", padding='max_length',
                                   max_length=512)
    print("sentences_a length:", len(sentences_a))
    _, batch_catchphrase_embedding_a = encoder(**encoded_batch_catchphrase_a)  # [7, 768]
    _, batch_catchphrase_embedding_b = encoder(**encoded_batch_catchphrase_b)  # [13,768]

    _, sentence_embedding_a = encoder(**encoded_sentence_a)  # [1, 768]
    _, sentence_embedding_b = encoder(**encoded_sentence_b)  # [1, 768]

    left_left = torch.cdist(sentence_embedding_a, batch_catchphrase_embedding_a, p=2.0)  # [1, 768]*[7, 768]=[1, 7]
    left_right = torch.cdist(sentence_embedding_a, batch_catchphrase_embedding_b,
                             p=2.0)  # [1, 768]*[13, 768]=[1, 13]

    right_right = torch.cdist(sentence_embedding_b, batch_catchphrase_embedding_b,
                              p=2.0)  # [1, 768]*[13, 768]=[1, 13]
    right_left = torch.cdist(sentence_embedding_b, batch_catchphrase_embedding_a, p=2.0)  # [1, 768]*[7, 768]=[1, 7]

    nb_test_steps += 1
    test_macro_contrast += macro_contrast(left_left, left_right, right_right, right_left)
    test_micro_contrast += micro_contrast(left_left, left_right, right_right, right_left)

    print("Test micro contrast: {}".format(test_micro_contrast / nb_test_steps))
    print("Test macro contrast: {}".format(test_macro_contrast / nb_test_steps))

0it [00:00, ?it/s]

sentences_a length: 9849


1it [00:00,  2.46it/s]

Test micro contrast: -0.2915652096271515
Test macro contrast: -0.2915654182434082
sentences_a length: 43829


2it [00:01,  2.08it/s]

Test micro contrast: -0.15208914875984192
Test macro contrast: -0.15208923816680908
sentences_a length: 37594


3it [00:01,  1.99it/s]

Test micro contrast: -0.059666335582733154
Test macro contrast: -0.07713142782449722
sentences_a length: 16457


4it [00:02,  1.43it/s]

Test micro contrast: -0.06612031161785126
Test macro contrast: -0.060889363288879395
sentences_a length: 13591


5it [00:03,  1.58it/s]

Test micro contrast: -0.1408812701702118
Test macro contrast: -0.049582671374082565
sentences_a length: 5725


6it [00:04,  1.40it/s]

Test micro contrast: -0.16578318178653717
Test macro contrast: -0.05182870104908943
sentences_a length: 4359


7it [00:04,  1.66it/s]

Test micro contrast: -0.19738708436489105
Test macro contrast: -0.07518605142831802
sentences_a length: 28842


8it [00:04,  1.84it/s]

Test micro contrast: -0.25972646474838257
Test macro contrast: -0.10576280951499939
sentences_a length: 11473


9it [00:05,  1.96it/s]

Test micro contrast: -0.23750482499599457
Test macro contrast: -0.08275498449802399
sentences_a length: 8456


10it [00:05,  2.23it/s]

Test micro contrast: -0.19420306384563446
Test macro contrast: -0.029058266431093216
sentences_a length: 6620


11it [00:06,  2.21it/s]

Test micro contrast: -0.14246319234371185
Test macro contrast: -0.01392724271863699
sentences_a length: 31855


12it [00:06,  1.94it/s]

Test micro contrast: -0.14163948595523834
Test macro contrast: 0.004203418735414743


13it [00:07,  2.04it/s]

sentences_a length: 36872
Test micro contrast: -0.15020564198493958
Test macro contrast: 0.004489348269999027


14it [00:07,  2.04it/s]

sentences_a length: 29055
Test micro contrast: -0.15139184892177582
Test macro contrast: -0.007746492046862841
sentences_a length: 7935


15it [00:08,  1.84it/s]

Test micro contrast: -0.09898453205823898
Test macro contrast: -0.008049090392887592
sentences_a length: 76543


16it [00:09,  1.44it/s]

Test micro contrast: -0.08227216452360153
Test macro contrast: -0.012295186519622803
sentences_a length: 46535


17it [00:10,  1.48it/s]

Test micro contrast: -0.06599713116884232
Test macro contrast: -0.013866831548511982
sentences_a length: 3686


18it [00:10,  1.74it/s]

Test micro contrast: -0.022096101194620132
Test macro contrast: -0.03038530796766281


19it [00:10,  1.81it/s]

sentences_a length: 42466
Test micro contrast: -0.022785818204283714
Test macro contrast: -0.030638758093118668


20it [00:11,  2.03it/s]

sentences_a length: 14931
Test micro contrast: -0.047829221934080124
Test macro contrast: -0.03186030313372612
sentences_a length: 18023


21it [00:11,  2.03it/s]

Test micro contrast: -0.058846257627010345
Test macro contrast: -0.035361915826797485
sentences_a length: 14096


22it [00:12,  1.65it/s]

Test micro contrast: -0.05002466216683388
Test macro contrast: -0.03876715525984764
sentences_a length: 4481


23it [00:12,  1.92it/s]

Test micro contrast: -0.045403748750686646
Test macro contrast: -0.03463570028543472
sentences_a length: 12980


24it [00:13,  2.06it/s]

Test micro contrast: -0.009089861996471882
Test macro contrast: -0.03550989553332329
sentences_a length: 1951


25it [00:13,  2.10it/s]

Test micro contrast: -0.03269847482442856
Test macro contrast: -0.041228197515010834
sentences_a length: 6401


26it [00:14,  2.09it/s]

Test micro contrast: -0.03033062443137169
Test macro contrast: -0.03752788156270981
sentences_a length: 2845


27it [00:14,  2.48it/s]

Test micro contrast: -0.055818770080804825
Test macro contrast: -0.04925359785556793
sentences_a length: 30020


28it [00:15,  2.05it/s]

Test micro contrast: -0.052247993648052216
Test macro contrast: -0.04458241909742355
sentences_a length: 39861


29it [00:16,  1.51it/s]

Test micro contrast: -0.04358787089586258
Test macro contrast: -0.04476936161518097
sentences_a length: 9912


30it [00:16,  1.78it/s]

Test micro contrast: -0.02267518825829029
Test macro contrast: -0.01812257058918476


31it [00:16,  1.92it/s]

sentences_a length: 5403
Test micro contrast: -0.021871112287044525
Test macro contrast: -0.026183927431702614
sentences_a length: 25180


32it [00:17,  1.67it/s]

Test micro contrast: -0.04150398448109627
Test macro contrast: -0.029904872179031372
sentences_a length: 31675


33it [00:18,  1.37it/s]

Test micro contrast: -0.08122880756855011
Test macro contrast: -0.03610634803771973
sentences_a length: 28761


34it [00:19,  1.42it/s]

Test micro contrast: -0.08646271377801895
Test macro contrast: -0.042667388916015625
sentences_a length: 51217


35it [00:20,  1.23it/s]

Test micro contrast: -0.08477326482534409
Test macro contrast: -0.04952338710427284


36it [00:21,  1.37it/s]

sentences_a length: 7098
Test micro contrast: -0.09187286347150803
Test macro contrast: -0.053943634033203125


37it [00:21,  1.51it/s]

sentences_a length: 8538
Test micro contrast: -0.08995797485113144
Test macro contrast: -0.054261624813079834


38it [00:21,  1.76it/s]

sentences_a length: 12820
Test micro contrast: -0.08996535837650299
Test macro contrast: -0.05359787121415138
sentences_a length: 20795


39it [00:22,  1.60it/s]

Test micro contrast: -0.08718250691890717
Test macro contrast: -0.05174751952290535


40it [00:23,  1.47it/s]

sentences_a length: 66738
Test micro contrast: -0.0874921903014183
Test macro contrast: -0.05294307321310043
sentences_a length: 57633


41it [00:24,  1.26it/s]

Test micro contrast: -0.08664834499359131
Test macro contrast: -0.05376070365309715
sentences_a length: 24273


42it [00:25,  1.38it/s]

Test micro contrast: -0.08495472371578217
Test macro contrast: -0.05285012722015381
sentences_a length: 24988


43it [00:25,  1.64it/s]

Test micro contrast: -0.10063589364290237
Test macro contrast: -0.05217268690466881
sentences_a length: 39933


44it [00:26,  1.66it/s]

Test micro contrast: -0.09374707192182541
Test macro contrast: -0.050816938281059265
sentences_a length: 3738


45it [00:26,  1.92it/s]

Test micro contrast: -0.09634701162576675
Test macro contrast: -0.05354951694607735
sentences_a length: 18375


46it [00:26,  1.83it/s]

Test micro contrast: -0.10679062455892563
Test macro contrast: -0.05282467231154442


47it [00:27,  1.86it/s]

sentences_a length: 9564
Test micro contrast: -0.11388696730136871
Test macro contrast: -0.061069224029779434
sentences_a length: 4147


48it [00:28,  1.50it/s]

Test micro contrast: -0.1165132224559784
Test macro contrast: -0.06158343330025673
sentences_a length: 18508


49it [00:29,  1.49it/s]

Test micro contrast: -0.1248675286769867
Test macro contrast: -0.0673394501209259
sentences_a length: 1028


50it [00:29,  1.47it/s]

Test micro contrast: -0.08492667227983475
Test macro contrast: -0.06336750090122223
sentences_a length: 41055


51it [00:31,  1.15it/s]

Test micro contrast: -0.08275644481182098
Test macro contrast: -0.06206556037068367
sentences_a length: 48924


52it [00:31,  1.19it/s]

Test micro contrast: -0.07971576601266861
Test macro contrast: -0.06414929032325745


53it [00:32,  1.41it/s]

sentences_a length: 29804
Test micro contrast: -0.07617688924074173
Test macro contrast: -0.06090410798788071
sentences_a length: 34544


54it [00:32,  1.46it/s]

Test micro contrast: -0.07111556828022003
Test macro contrast: -0.05955566465854645
sentences_a length: 16492


55it [00:33,  1.30it/s]

Test micro contrast: -0.0747065320611
Test macro contrast: -0.061950623989105225
sentences_a length: 4961


56it [00:34,  1.50it/s]

Test micro contrast: -0.0666022077202797
Test macro contrast: -0.06038342043757439
sentences_a length: 9619


57it [00:34,  1.70it/s]

Test micro contrast: -0.09173166006803513
Test macro contrast: -0.06360092014074326
sentences_a length: 14358


58it [00:35,  1.89it/s]

Test micro contrast: -0.09134522080421448
Test macro contrast: -0.06476467847824097
sentences_a length: 7946


59it [00:35,  1.61it/s]

Test micro contrast: -0.09457195550203323
Test macro contrast: -0.06705283373594284
sentences_a length: 7102


60it [00:36,  1.77it/s]

Test micro contrast: -0.09365064650774002
Test macro contrast: -0.06859808415174484
sentences_a length: 7328


61it [00:36,  2.08it/s]

Test micro contrast: -0.0991426482796669
Test macro contrast: -0.0725482702255249
sentences_a length: 8596


62it [00:37,  2.25it/s]

Test micro contrast: -0.10159669816493988
Test macro contrast: -0.07312946021556854
sentences_a length: 6493


63it [00:38,  1.48it/s]

Test micro contrast: -0.10560236871242523
Test macro contrast: -0.07332772761583328
sentences_a length: 16077


64it [00:38,  1.63it/s]

Test micro contrast: -0.10659968107938766
Test macro contrast: -0.07294259965419769
sentences_a length: 14898


65it [00:39,  1.90it/s]

Test micro contrast: -0.10746711492538452
Test macro contrast: -0.07432783395051956
sentences_a length: 3798


66it [00:40,  1.40it/s]

Test micro contrast: -0.1094883382320404
Test macro contrast: -0.07412353903055191
sentences_a length: 11647


67it [00:40,  1.42it/s]

Test micro contrast: -0.08876711130142212
Test macro contrast: -0.0775754526257515
sentences_a length: 8726


68it [00:41,  1.44it/s]

Test micro contrast: -0.09003031253814697
Test macro contrast: -0.0772327110171318
sentences_a length: 52237


69it [00:42,  1.40it/s]

Test micro contrast: -0.0878654271364212
Test macro contrast: -0.0777420774102211
sentences_a length: 4497


70it [00:43,  1.23it/s]

Test micro contrast: -0.09790365397930145
Test macro contrast: -0.0784536600112915
sentences_a length: 10040


71it [00:44,  1.21it/s]

Test micro contrast: -0.09848601371049881
Test macro contrast: -0.07802968472242355
sentences_a length: 25666


72it [00:44,  1.23it/s]

Test micro contrast: -0.08629627525806427
Test macro contrast: -0.07949912548065186
sentences_a length: 9120


73it [00:45,  1.35it/s]

Test micro contrast: -0.07967273145914078
Test macro contrast: -0.0787789449095726
sentences_a length: 6645


74it [00:45,  1.55it/s]

Test micro contrast: -0.0798962339758873
Test macro contrast: -0.07739415019750595
sentences_a length: 16558


75it [00:46,  1.61it/s]

Test micro contrast: -0.08164047449827194
Test macro contrast: -0.07552285492420197
sentences_a length: 4290


76it [00:47,  1.66it/s]

Test micro contrast: -0.08630166202783585
Test macro contrast: -0.07710706442594528
sentences_a length: 24349


77it [00:47,  1.80it/s]

Test micro contrast: -0.08636971563100815
Test macro contrast: -0.07813010364770889
sentences_a length: 62731


78it [00:48,  1.56it/s]

Test micro contrast: -0.0817670226097107
Test macro contrast: -0.07782872766256332
sentences_a length: 34654


79it [00:48,  1.65it/s]

Test micro contrast: -0.10452708601951599
Test macro contrast: -0.07818122208118439
sentences_a length: 7999


80it [00:49,  1.81it/s]

Test micro contrast: -0.09899888932704926
Test macro contrast: -0.07875292748212814


81it [00:49,  1.79it/s]

sentences_a length: 6986
Test micro contrast: -0.10573559999465942
Test macro contrast: -0.07637592405080795
sentences_a length: 8008


82it [00:50,  2.08it/s]

Test micro contrast: -0.10559378564357758
Test macro contrast: -0.07729228585958481
sentences_a length: 29728


83it [00:50,  1.96it/s]

Test micro contrast: -0.09599396586418152
Test macro contrast: -0.0702342763543129
sentences_a length: 25881


84it [00:51,  1.56it/s]

Test micro contrast: -0.09613732248544693
Test macro contrast: -0.07026087492704391
sentences_a length: 20123


85it [00:52,  1.70it/s]

Test micro contrast: -0.09388694167137146
Test macro contrast: -0.06981449574232101
sentences_a length: 30282


86it [00:53,  1.44it/s]

Test micro contrast: -0.09265435487031937
Test macro contrast: -0.07000884413719177
sentences_a length: 49385


87it [00:54,  1.18it/s]

Test micro contrast: -0.09465117007493973
Test macro contrast: -0.0717281773686409
sentences_a length: 19927


88it [00:55,  1.20it/s]

Test micro contrast: -0.10307541489601135
Test macro contrast: -0.07116243988275528
sentences_a length: 7480


89it [00:55,  1.47it/s]

Test micro contrast: -0.10246115922927856
Test macro contrast: -0.07229117304086685
sentences_a length: 7512


90it [00:55,  1.60it/s]

Test micro contrast: -0.10554265975952148
Test macro contrast: -0.07057129591703415
sentences_a length: 10505


91it [00:56,  1.65it/s]

Test micro contrast: -0.10367537289857864
Test macro contrast: -0.06860993802547455
sentences_a length: 12361


92it [00:57,  1.76it/s]

Test micro contrast: -0.10332786291837692
Test macro contrast: -0.06905641406774521
sentences_a length: 13231


93it [00:57,  1.77it/s]

Test micro contrast: -0.11241645365953445
Test macro contrast: -0.07102367281913757
sentences_a length: 3300


94it [00:58,  1.83it/s]

Test micro contrast: -0.11087185144424438
Test macro contrast: -0.07010876387357712
sentences_a length: 22887


95it [00:58,  1.73it/s]

Test micro contrast: -0.10991932451725006
Test macro contrast: -0.0702228918671608
sentences_a length: 49918


96it [00:59,  1.58it/s]

Test micro contrast: -0.10895314067602158
Test macro contrast: -0.06967020779848099
sentences_a length: 18918


97it [01:00,  1.68it/s]

Test micro contrast: -0.10135995596647263
Test macro contrast: -0.0675693079829216
sentences_a length: 42223


98it [01:00,  1.58it/s]

Test micro contrast: -0.10721321403980255
Test macro contrast: -0.06868034601211548
sentences_a length: 8081


99it [01:00,  1.94it/s]

Test micro contrast: -0.10793555527925491
Test macro contrast: -0.06909677386283875
sentences_a length: 25109


100it [01:01,  2.04it/s]

Test micro contrast: -0.11433444917201996
Test macro contrast: -0.06613782048225403
sentences_a length: 6994


101it [01:01,  2.04it/s]

Test micro contrast: -0.11418984830379486
Test macro contrast: -0.06548433750867844
sentences_a length: 33942
Test micro contrast: -0.11258488148450851


102it [01:02,  1.81it/s]

Test macro contrast: -0.06577669084072113


103it [01:03,  1.47it/s]

sentences_a length: 66030
Test micro contrast: -0.11546555906534195
Test macro contrast: -0.06676119565963745
sentences_a length: 16272


104it [01:04,  1.44it/s]

Test micro contrast: -0.11555671691894531
Test macro contrast: -0.06693897396326065
sentences_a length: 13671


105it [01:04,  1.75it/s]

Test micro contrast: -0.1156507134437561
Test macro contrast: -0.06681861728429794
sentences_a length: 27637


106it [01:05,  1.50it/s]

Test micro contrast: -0.10236270725727081
Test macro contrast: -0.06563343107700348
sentences_a length: 9221


107it [01:06,  1.57it/s]

Test micro contrast: -0.10057570785284042
Test macro contrast: -0.06504914164543152
sentences_a length: 9681


108it [01:06,  1.79it/s]

Test micro contrast: -0.1070118099451065
Test macro contrast: -0.0636371597647667
sentences_a length: 14718


109it [01:07,  1.65it/s]

Test micro contrast: -0.12113679945468903
Test macro contrast: -0.06538040935993195
sentences_a length: 3324


110it [01:07,  1.57it/s]

Test micro contrast: -0.1104222759604454
Test macro contrast: -0.0662970319390297
sentences_a length: 4165


111it [01:08,  1.75it/s]

Test micro contrast: -0.10872430354356766
Test macro contrast: -0.06502936035394669
sentences_a length: 25851


112it [01:09,  1.50it/s]

Test micro contrast: -0.11066184937953949
Test macro contrast: -0.06639426201581955
sentences_a length: 48551


113it [01:10,  1.35it/s]

Test micro contrast: -0.11855035275220871
Test macro contrast: -0.06969314068555832
sentences_a length: 27856


114it [01:10,  1.49it/s]

Test micro contrast: -0.11644253879785538
Test macro contrast: -0.06597577780485153
sentences_a length: 5203


115it [01:11,  1.49it/s]

Test micro contrast: -0.115366131067276
Test macro contrast: -0.06409580260515213
sentences_a length: 55284


116it [01:12,  1.23it/s]

Test micro contrast: -0.10960303246974945
Test macro contrast: -0.06438794732093811
sentences_a length: 16712


117it [01:13,  1.24it/s]

Test micro contrast: -0.1114208772778511
Test macro contrast: -0.06588244438171387
sentences_a length: 19849


118it [01:13,  1.41it/s]

Test micro contrast: -0.11318126320838928
Test macro contrast: -0.0671706274151802
sentences_a length: 3843


119it [01:14,  1.54it/s]

Test micro contrast: -0.11695852130651474
Test macro contrast: -0.06766831874847412
sentences_a length: 8219


120it [01:15,  1.31it/s]

Test micro contrast: -0.1165732815861702
Test macro contrast: -0.06776262074708939
sentences_a length: 33805


121it [01:15,  1.36it/s]

Test micro contrast: -0.11959771811962128
Test macro contrast: -0.07036682963371277
sentences_a length: 8955


122it [01:16,  1.54it/s]

Test micro contrast: -0.12051709741353989
Test macro contrast: -0.0734056904911995
sentences_a length: 26265


123it [01:17,  1.44it/s]

Test micro contrast: -0.11895917356014252
Test macro contrast: -0.0724334716796875
sentences_a length: 17262


124it [01:17,  1.60it/s]

Test micro contrast: -0.12142619490623474
Test macro contrast: -0.07470331341028214
sentences_a length: 28095


125it [01:18,  1.52it/s]

Test micro contrast: -0.1220051720738411
Test macro contrast: -0.07596590369939804
sentences_a length: 5138


126it [01:19,  1.49it/s]

Test micro contrast: -0.1215498223900795
Test macro contrast: -0.07652760297060013
sentences_a length: 33780


127it [01:25,  2.36s/it]

Test micro contrast: -0.121297188103199
Test macro contrast: -0.07638566941022873
sentences_a length: 34246


128it [01:26,  1.89s/it]

Test micro contrast: -0.11722979694604874
Test macro contrast: -0.07640977203845978
sentences_a length: 6879


129it [01:26,  1.41s/it]

Test micro contrast: -0.11741777509450912
Test macro contrast: -0.07423307001590729
sentences_a length: 8301


130it [01:26,  1.10s/it]

Test micro contrast: -0.11407721042633057
Test macro contrast: -0.0764470249414444
sentences_a length: 18782


131it [01:27,  1.08it/s]

Test micro contrast: -0.11593083292245865
Test macro contrast: -0.07909170538187027
sentences_a length: 38665


132it [01:28,  1.06s/it]

Test micro contrast: -0.11300309002399445
Test macro contrast: -0.07897688448429108
sentences_a length: 11450


133it [01:29,  1.13it/s]

Test micro contrast: -0.11153977364301682
Test macro contrast: -0.08068326860666275
sentences_a length: 49874


134it [01:30,  1.08it/s]

Test micro contrast: -0.1119474545121193
Test macro contrast: -0.0821710005402565


135it [01:30,  1.32it/s]

sentences_a length: 30596
Test micro contrast: -0.11274603754281998
Test macro contrast: -0.08319015055894852
sentences_a length: 15751


136it [01:31,  1.46it/s]

Test micro contrast: -0.1141093373298645
Test macro contrast: -0.08347848802804947
sentences_a length: 43072


137it [01:31,  1.33it/s]

Test micro contrast: -0.11311952024698257
Test macro contrast: -0.08271225541830063
sentences_a length: 22614


138it [01:33,  1.10it/s]

Test micro contrast: -0.11347213387489319
Test macro contrast: -0.08248186856508255
sentences_a length: 16158


139it [01:33,  1.21it/s]

Test micro contrast: -0.11296045035123825
Test macro contrast: -0.08183258026838303
sentences_a length: 14286


140it [01:34,  1.45it/s]

Test micro contrast: -0.11331024020910263
Test macro contrast: -0.08150547742843628
sentences_a length: 16360


141it [01:35,  1.39it/s]

Test micro contrast: -0.11585424840450287
Test macro contrast: -0.07819641381502151
sentences_a length: 6635


142it [01:35,  1.71it/s]

Test micro contrast: -0.11586862057447433
Test macro contrast: -0.07847598195075989
sentences_a length: 13230


143it [01:36,  1.54it/s]

Test micro contrast: -0.12027681618928909
Test macro contrast: -0.08314567059278488


144it [01:36,  1.61it/s]

sentences_a length: 14046
Test micro contrast: -0.11964760720729828
Test macro contrast: -0.08277431130409241
sentences_a length: 38394


145it [01:37,  1.53it/s]

Test micro contrast: -0.11898774653673172
Test macro contrast: -0.08201815187931061
sentences_a length: 4135


146it [01:38,  1.35it/s]

Test micro contrast: -0.11029914766550064
Test macro contrast: -0.0804554894566536
sentences_a length: 6101


147it [01:38,  1.48it/s]

Test micro contrast: -0.10852514952421188
Test macro contrast: -0.08059290051460266


148it [01:39,  1.70it/s]

sentences_a length: 32584
Test micro contrast: -0.1125762015581131
Test macro contrast: -0.08167403191328049
sentences_a length: 72910


149it [01:40,  1.38it/s]

Test micro contrast: -0.11333458125591278
Test macro contrast: -0.08246587216854095
sentences_a length: 14555


In [3]:


# def train_contrast_retrieval(data_config, exp_config):


# cuda
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

ds = AUSDataset()
train_dataloader = ds.get_data_loader(split='train', batch_size=2, shuffle=True)
val_dataloader = ds.get_data_loader(split='val', batch_size=2, shuffle=True)
test_dataloader = ds.get_data_loader(split='test', batch_size=2, shuffle=True)

tokenizer = AutoTokenizer.from_pretrained(exp_config.uri)
encoder = AutoModel.from_pretrained(exp_config.uri)

model = NNModel(exp_config)

optimizer = torch.optim.Adam(encoder.parameters(), lr=0.001)
# print(encoder.parameters())

# for name, param in encoder.named_parameters():
#     print(name, param.size())
# Store our loss and accuracy for plotting
train_loss_set = []

# Number of training epochs (authors recommend between 2 and 4)
epochs = exp_config.epochs

# trange is a tqdm wrapper around the normal python range
for epoch__ in trange(epochs, desc="Epoch"):

    print("start training")

    # Set our model to training mode (as opposed to evaluation mode)
    encoder.train()

    # Tracking variables
    tr_loss = 0  # running loss
    nb_tr_steps = 0

    # Train the data for one epoch
    for step, batch in tqdm(enumerate(train_dataloader)):
        optimizer.zero_grad()
        # Unpack the inputs from our dataloader
        sentences, catchphrases = batch  # len(sentences)=2, len(catchphrases)=2
        # Clear out the gradients (by default they accumulate)

        sentences_a, catchphrase_a = sentences[0], catchphrases[0]
        sentences_b, catchphrase_b = sentences[1], catchphrases[1]

        batch_catchphrase_a = catchphrase_a.split(EOC_TOK)
        batch_catchphrase_b = catchphrase_b.split(EOC_TOK)

        encoded_batch_catchphrase_a = tokenizer(batch_catchphrase_a, truncation=True, return_tensors="pt",
                                                padding='max_length', max_length=128)
        encoded_batch_catchphrase_b = tokenizer(batch_catchphrase_b, truncation=True, return_tensors="pt",
                                                padding='max_length', max_length=128)

        sentence_indices_a = tokenizer(sentences_a, truncation=True, return_tensors="pt", padding='max_length',
                                       max_length=512*12)
        sentence_indices_b = tokenizer(sentences_b, truncation=True, return_tensors="pt", padding='max_length',
                                       max_length=512*12)

        _, batch_catchphrase_embedding_a = encoder(**encoded_batch_catchphrase_a)  # [7, 768]
        _, batch_catchphrase_embedding_b = encoder(**encoded_batch_catchphrase_b)  # [13,768]

        
        chunk_indices_a=chunkify(sentence_indices_a)
        
        chunk_indices_b=chunkify(sentence_indices_b)
         

        chunk_embeddings_a=encode_chunks(chunk_indices_a,encoder)
        chunk_embeddings_b=encode_chunks(chunk_indices_b,encoder)
    
        
        #################### Aggregation ######################
        sentence_embedding_a=torch.mean(chunk_embeddings_a,dim=0)
        sentence_embedding_b=torch.mean(chunk_embeddings_b,dim=0)
        

        triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)

        anchor_a = sentence_embedding_a.unsqueeze(0)
        anchor_b = sentence_embedding_b.unsqueeze(0)
        batch_train_loss_set=[]
        for catchphrase_embedding_a in tqdm(batch_catchphrase_embedding_a):
            for catchphrase_embedding_b in batch_catchphrase_embedding_b:
                positive = catchphrase_embedding_a.unsqueeze(0)
                negative = catchphrase_embedding_b.unsqueeze(0)
                loss = triplet_loss(anchor_a, positive, negative)
                batch_train_loss_set.append(loss.unsqueeze(0))
                loss = triplet_loss(anchor_b, negative, positive)
                batch_train_loss_set.append(loss.unsqueeze(0))
        batch_loss=torch.mean(torch.cat(batch_train_loss_set))
        batch_loss.backward()
        optimizer.step()
        # Update tracking variables
        tr_loss += batch_loss.item()
        nb_tr_steps += 1

        print("Train loss: {}".format(tr_loss/nb_tr_steps))





Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

<generator object Module.parameters at 0x11ead27d0>
embeddings.word_embeddings.weight torch.Size([30522, 128])
embeddings.position_embeddings.weight torch.Size([512, 128])
embeddings.token_type_embeddings.weight torch.Size([2, 128])
embeddings.LayerNorm.weight torch.Size([128])
embeddings.LayerNorm.bias torch.Size([128])
encoder.layer.0.attention.self.query.weight torch.Size([128, 128])
encoder.layer.0.attention.self.query.bias torch.Size([128])
encoder.layer.0.attention.self.key.weight torch.Size([128, 128])
encoder.layer.0.attention.self.key.bias torch.Size([128])
encoder.layer.0.attention.self.value.weight torch.Size([128, 128])
encoder.layer.0.attention.self.value.bias torch.Size([128])
encoder.layer.0.attention.output.dense.weight torch.Size([128, 128])
encoder.layer.0.attention.output.dense.bias torch.Size([128])
encoder.layer.0.attention.output.LayerNorm.weight torch.Size([128])
encoder.layer.0.attention.output.LayerNorm.bias torch.Size([128])
encoder.layer.0.intermediate.dense.


0it [00:00, ?it/s][A

100%|██████████| 3/3 [00:00<00:00, 419.23it/s]

1it [00:01,  1.93s/it][A

Train loss: 0.9539670348167419




100%|██████████| 3/3 [00:00<00:00, 1381.22it/s]

2it [00:03,  1.84s/it][A

Train loss: 0.9075373709201813




100%|██████████| 8/8 [00:00<00:00, 977.12it/s]

3it [00:05,  1.80s/it][A

Train loss: 1.1290541291236877




100%|██████████| 9/9 [00:00<00:00, 738.07it/s]

4it [00:07,  1.83s/it][A

Train loss: 1.256189987063408




100%|██████████| 8/8 [00:00<00:00, 1122.63it/s]

5it [00:08,  1.79s/it][A

Train loss: 1.1780067563056946




100%|██████████| 4/4 [00:00<00:00, 618.06it/s]

6it [00:10,  1.75s/it][A

Train loss: 1.146119753519694




100%|██████████| 6/6 [00:00<00:00, 2272.72it/s]

7it [00:12,  1.72s/it][A

Train loss: 1.1346594946725028




100%|██████████| 9/9 [00:00<00:00, 334.79it/s]

8it [00:14,  1.87s/it][A

Train loss: 1.124840572476387




100%|██████████| 16/16 [00:00<00:00, 1346.57it/s]

9it [00:16,  2.01s/it][A

Train loss: 1.1646321217219036




100%|██████████| 9/9 [00:00<00:00, 1404.55it/s]

10it [00:18,  1.95s/it][A

Train loss: 1.1421508550643922




100%|██████████| 2/2 [00:00<00:00, 943.81it/s]

11it [00:20,  1.86s/it][A

Train loss: 1.128929316997528




100%|██████████| 4/4 [00:00<00:00, 673.86it/s]

12it [00:21,  1.83s/it][A

Train loss: 1.1101349145174026




100%|██████████| 8/8 [00:00<00:00, 793.08it/s]

13it [00:23,  1.82s/it][A

Train loss: 1.097806637103741




100%|██████████| 7/7 [00:00<00:00, 805.07it/s]

14it [00:25,  1.80s/it][A

Train loss: 1.0909655775342668




100%|██████████| 5/5 [00:00<00:00, 283.03it/s]

15it [00:27,  1.84s/it][A

Train loss: 1.0879690329233804




100%|██████████| 3/3 [00:00<00:00, 1232.53it/s]

16it [00:29,  1.77s/it][A

Train loss: 1.0805836729705334




100%|██████████| 5/5 [00:00<00:00, 2139.51it/s]

17it [00:30,  1.74s/it][A

Train loss: 1.0767362994306229




100%|██████████| 10/10 [00:00<00:00, 882.70it/s]

18it [00:32,  1.73s/it][A

Train loss: 1.072382089164522




100%|██████████| 7/7 [00:00<00:00, 1602.89it/s]

19it [00:34,  1.73s/it][A

Train loss: 1.0712848493927403




100%|██████████| 5/5 [00:00<00:00, 743.28it/s]

20it [00:35,  1.73s/it][A

Train loss: 1.0669104963541032




100%|██████████| 7/7 [00:00<00:00, 1335.58it/s]

21it [00:37,  1.73s/it][A

Train loss: 1.0615287905647641




100%|██████████| 5/5 [00:00<00:00, 1008.49it/s]

22it [00:39,  1.81s/it][A

Train loss: 1.0627357092770664




100%|██████████| 5/5 [00:00<00:00, 814.30it/s]

23it [00:41,  1.79s/it][A

Train loss: 1.0599717601485874




100%|██████████| 7/7 [00:00<00:00, 1444.46it/s]

24it [00:43,  1.81s/it][A

Train loss: 1.058027721941471




100%|██████████| 7/7 [00:00<00:00, 802.85it/s]

25it [00:45,  1.93s/it][A

Train loss: 1.0584177803993224




100%|██████████| 5/5 [00:00<00:00, 748.64it/s]

26it [00:47,  1.89s/it][A

Train loss: 1.0560370293947368




100%|██████████| 7/7 [00:00<00:00, 489.82it/s]

27it [00:49,  1.86s/it][A

Train loss: 1.054437679273111




100%|██████████| 16/16 [00:00<00:00, 930.66it/s]

28it [00:50,  1.87s/it][A

Train loss: 1.0534850720848357




100%|██████████| 8/8 [00:00<00:00, 1014.74it/s]

29it [00:52,  1.86s/it][A

Train loss: 1.051804947442022




100%|██████████| 4/4 [00:00<00:00, 349.07it/s]

30it [00:55,  2.01s/it][A

Train loss: 1.0502635776996612




100%|██████████| 6/6 [00:00<00:00, 599.34it/s]

31it [00:56,  1.94s/it][A

Train loss: 1.0492189911104017




100%|██████████| 3/3 [00:00<00:00, 946.44it/s]

32it [00:58,  1.93s/it][A

Train loss: 1.0475201085209846




100%|██████████| 5/5 [00:00<00:00, 394.51it/s]

33it [01:00,  1.91s/it][A

Train loss: 1.0464765339186697




100%|██████████| 14/14 [00:00<00:00, 1013.10it/s]

34it [01:02,  1.96s/it][A

Train loss: 1.0449990279534285




100%|██████████| 2/2 [00:00<00:00, 1473.75it/s]

35it [01:04,  1.85s/it][A

Train loss: 1.0440946340560913




100%|██████████| 7/7 [00:00<00:00, 799.94it/s]

36it [01:06,  1.84s/it][A

Train loss: 1.0426333546638489




100%|██████████| 6/6 [00:00<00:00, 500.50it/s]

37it [01:07,  1.83s/it][A

Train loss: 1.041484790879327




100%|██████████| 11/11 [00:00<00:00, 1343.47it/s]

38it [01:10,  1.91s/it][A

Train loss: 1.0403339894194352




100%|██████████| 9/9 [00:00<00:00, 813.76it/s]

39it [01:11,  1.88s/it][A

Train loss: 1.039344127361591




100%|██████████| 4/4 [00:00<00:00, 317.14it/s]

40it [01:13,  1.91s/it][A

Train loss: 1.0383957266807555




100%|██████████| 9/9 [00:00<00:00, 562.78it/s]

41it [01:16,  2.01s/it][A

Train loss: 1.0373242730047645




100%|██████████| 3/3 [00:00<00:00, 448.24it/s]

42it [01:18,  1.99s/it][A

Train loss: 1.0365815545831407




100%|██████████| 3/3 [00:00<00:00, 830.12it/s]

43it [01:20,  2.01s/it][A

Train loss: 1.034249911474627




100%|██████████| 6/6 [00:00<00:00, 768.94it/s]

44it [01:21,  1.90s/it][A

Train loss: 1.033472474325787




100%|██████████| 6/6 [00:00<00:00, 468.64it/s]

45it [01:23,  1.95s/it][A

Train loss: 1.0330443369017708




100%|██████████| 8/8 [00:00<00:00, 481.72it/s]

46it [01:25,  2.03s/it][A

Train loss: 1.032287006792815




100%|██████████| 6/6 [00:00<00:00, 672.92it/s]

47it [01:27,  1.97s/it][A

Train loss: 1.0307795823888575




100%|██████████| 6/6 [00:00<00:00, 982.92it/s]

48it [01:29,  1.95s/it][A

Train loss: 1.0300911230345566




100%|██████████| 12/12 [00:00<00:00, 507.64it/s]

49it [01:31,  1.98s/it][A

Train loss: 1.0303238861414852




100%|██████████| 19/19 [00:00<00:00, 627.47it/s]

50it [01:33,  2.03s/it][A

Train loss: 1.028086905479431




100%|██████████| 2/2 [00:00<00:00, 534.92it/s]

51it [01:35,  1.92s/it][A

Train loss: 1.0282884368709488




100%|██████████| 6/6 [00:00<00:00, 715.73it/s]

52it [01:37,  1.89s/it][A

Train loss: 1.0263760766157737




100%|██████████| 6/6 [00:00<00:00, 1386.01it/s]

53it [01:39,  1.82s/it][A

Train loss: 1.0253380267125256




100%|██████████| 3/3 [00:00<00:00, 887.56it/s]

54it [01:40,  1.76s/it][A

Train loss: 1.0249824877138491




100%|██████████| 4/4 [00:00<00:00, 658.14it/s]

55it [01:42,  1.74s/it][A

Train loss: 1.0258116006851197




100%|██████████| 7/7 [00:00<00:00, 849.49it/s]

56it [01:44,  1.78s/it][A

Train loss: 1.022093957023961




100%|██████████| 3/3 [00:00<00:00, 645.44it/s]

57it [01:46,  1.87s/it][A

Train loss: 1.0187447740320574




100%|██████████| 4/4 [00:00<00:00, 401.45it/s]

58it [01:48,  1.98s/it][A

Train loss: 1.0190570786081512




100%|██████████| 8/8 [00:00<00:00, 1134.25it/s]
58it [01:50,  1.91s/it]
Epoch:   0%|          | 0/10 [01:50<?, ?it/s]


KeyboardInterrupt: 

In [2]:
if __name__ == '__main__':
    ds = AUSDataset()

    # ds.save_processed_splits()

    test_dataloader = ds.get_data_loader(split='test', batch_size=2, shuffle=True)
    # print(test_dl.batch_size)
    # for i in test_dl:
    #     print(len(i[0]),len(i[1]))
    data_config = DatasetConfig("AUS")
    exp_config = ExpConfig("MeanModel")
    train_contrast_retrieval(data_config, exp_config)

NameError: name 'train_contrast_retrieval' is not defined