In [1]:
import numpy as np
import pandas as pd
import io
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from collections import Counter
import pickle as pkl
import random
import pdb
from sklearn.model_selection import train_test_split



random.seed(1337)

PAD_IDX = 0
UNK_IDX = 1
BATCH_SIZE = 32
NUM_CLASSES = 3
BIDIRECTIONAL = True
NUM_DIRECTIONS = 1 if not BIDIRECTIONAL else 2
RNN_HIDDEN_SIZE = 200
LIN_HIDDEN_SIZE = 256


In [2]:
train_data = pd.read_csv('hw2_data/snli_train.tsv',sep='\t')
label_map = {'contradiction':0, 'entailment':2, 'neutral':1}
train_data.replace({'label':label_map},inplace=True)

In [3]:
val_data = pd.read_csv('hw2_data/snli_val.tsv',sep='\t')
label_map = {'contradiction':0, 'entailment':2, 'neutral':1}
val_data.replace({'label':label_map},inplace=True)

In [4]:
#train_data,val_data = train_test_split(data,test_size=0.2,random_state=1337)

In [5]:
def prepare_data(df):
    df['sentence1'] = df['sentence1'].str.split()
    df['sentence2'] = df['sentence2'].str.split()
    return df

In [6]:
train_data = prepare_data(train_data)
val_data = prepare_data(val_data)
#emb_weights = pkl.load(open('emb_weights.pkl','rb'))
#token2id = pkl.load(open('token2id.pkl','rb'))
#id2token = pkl.load(open('id2token.pkl','rb'))

In [7]:
MAX_SENTENCE_LENGTH = int(train_data['sentence1'].str.len().quantile(0.95))

In [8]:
MAX_SENTENCE_LENGTH

25

In [57]:
def load_vectors(fname,vocab):  
    fin = io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore')
    n, d = map(int, fin.readline().split())
    print(n)
    print(d)
    data = {}
    count = 0
    for line in fin:
        tokens = line.rstrip().split(' ')
        if tokens[0] in vocab:
            #convert map object to numpy ndarray
            count+=1
            print(count)
            data[tokens[0]] = np.fromiter(map(float, tokens[1:]),dtype=np.float)
    return data

In [58]:
def build_vocab(train_data,max_vocab_size=50000):
    # Returns:
    # id2token: list of tokens, where id2token[i] returns token that corresponds to token i
    # token2id: dictionary where keys represent tokens and corresponding values represent indices
    all_tokens = []
    for i in range(len(train_data)):
        all_tokens.extend(train_data.iloc[i]['sentence1'])
        all_tokens.extend(train_data.iloc[i]['sentence2'])
    token_counter = Counter(all_tokens)
    vocab, count = zip(*token_counter.most_common(max_vocab_size))
    print('Done building vocab')
    vocab = list(vocab)
    emb_weights = load_vectors('wiki-news-300d-1M.vec',vocab)
    print('Done getting embedding weights')
    vocab = [word for word in vocab if word in emb_weights] 
    id2token = list(vocab)
    token2id = dict(zip(vocab, range(2,2+len(vocab)))) 
    id2token = ['<pad>', '<unk>'] + id2token
    token2id['<pad>'] = PAD_IDX 
    token2id['<unk>'] = UNK_IDX
    return token2id, id2token, emb_weights

In [59]:
token2id,id2token,emb_weights = build_vocab(train_data)

Done building vocab
999994
300
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
2

2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389


4130
4131
4132
4133
4134
4135
4136
4137
4138
4139
4140
4141
4142
4143
4144
4145
4146
4147
4148
4149
4150
4151
4152
4153
4154
4155
4156
4157
4158
4159
4160
4161
4162
4163
4164
4165
4166
4167
4168
4169
4170
4171
4172
4173
4174
4175
4176
4177
4178
4179
4180
4181
4182
4183
4184
4185
4186
4187
4188
4189
4190
4191
4192
4193
4194
4195
4196
4197
4198
4199
4200
4201
4202
4203
4204
4205
4206
4207
4208
4209
4210
4211
4212
4213
4214
4215
4216
4217
4218
4219
4220
4221
4222
4223
4224
4225
4226
4227
4228
4229
4230
4231
4232
4233
4234
4235
4236
4237
4238
4239
4240
4241
4242
4243
4244
4245
4246
4247
4248
4249
4250
4251
4252
4253
4254
4255
4256
4257
4258
4259
4260
4261
4262
4263
4264
4265
4266
4267
4268
4269
4270
4271
4272
4273
4274
4275
4276
4277
4278
4279
4280
4281
4282
4283
4284
4285
4286
4287
4288
4289
4290
4291
4292
4293
4294
4295
4296
4297
4298
4299
4300
4301
4302
4303
4304
4305
4306
4307
4308
4309
4310
4311
4312
4313
4314
4315
4316
4317
4318
4319
4320
4321
4322
4323
4324
4325
4326
4327
4328
4329


5989
5990
5991
5992
5993
5994
5995
5996
5997
5998
5999
6000
6001
6002
6003
6004
6005
6006
6007
6008
6009
6010
6011
6012
6013
6014
6015
6016
6017
6018
6019
6020
6021
6022
6023
6024
6025
6026
6027
6028
6029
6030
6031
6032
6033
6034
6035
6036
6037
6038
6039
6040
6041
6042
6043
6044
6045
6046
6047
6048
6049
6050
6051
6052
6053
6054
6055
6056
6057
6058
6059
6060
6061
6062
6063
6064
6065
6066
6067
6068
6069
6070
6071
6072
6073
6074
6075
6076
6077
6078
6079
6080
6081
6082
6083
6084
6085
6086
6087
6088
6089
6090
6091
6092
6093
6094
6095
6096
6097
6098
6099
6100
6101
6102
6103
6104
6105
6106
6107
6108
6109
6110
6111
6112
6113
6114
6115
6116
6117
6118
6119
6120
6121
6122
6123
6124
6125
6126
6127
6128
6129
6130
6131
6132
6133
6134
6135
6136
6137
6138
6139
6140
6141
6142
6143
6144
6145
6146
6147
6148
6149
6150
6151
6152
6153
6154
6155
6156
6157
6158
6159
6160
6161
6162
6163
6164
6165
6166
6167
6168
6169
6170
6171
6172
6173
6174
6175
6176
6177
6178
6179
6180
6181
6182
6183
6184
6185
6186
6187
6188


7732
7733
7734
7735
7736
7737
7738
7739
7740
7741
7742
7743
7744
7745
7746
7747
7748
7749
7750
7751
7752
7753
7754
7755
7756
7757
7758
7759
7760
7761
7762
7763
7764
7765
7766
7767
7768
7769
7770
7771
7772
7773
7774
7775
7776
7777
7778
7779
7780
7781
7782
7783
7784
7785
7786
7787
7788
7789
7790
7791
7792
7793
7794
7795
7796
7797
7798
7799
7800
7801
7802
7803
7804
7805
7806
7807
7808
7809
7810
7811
7812
7813
7814
7815
7816
7817
7818
7819
7820
7821
7822
7823
7824
7825
7826
7827
7828
7829
7830
7831
7832
7833
7834
7835
7836
7837
7838
7839
7840
7841
7842
7843
7844
7845
7846
7847
7848
7849
7850
7851
7852
7853
7854
7855
7856
7857
7858
7859
7860
7861
7862
7863
7864
7865
7866
7867
7868
7869
7870
7871
7872
7873
7874
7875
7876
7877
7878
7879
7880
7881
7882
7883
7884
7885
7886
7887
7888
7889
7890
7891
7892
7893
7894
7895
7896
7897
7898
7899
7900
7901
7902
7903
7904
7905
7906
7907
7908
7909
7910
7911
7912
7913
7914
7915
7916
7917
7918
7919
7920
7921
7922
7923
7924
7925
7926
7927
7928
7929
7930
7931


9462
9463
9464
9465
9466
9467
9468
9469
9470
9471
9472
9473
9474
9475
9476
9477
9478
9479
9480
9481
9482
9483
9484
9485
9486
9487
9488
9489
9490
9491
9492
9493
9494
9495
9496
9497
9498
9499
9500
9501
9502
9503
9504
9505
9506
9507
9508
9509
9510
9511
9512
9513
9514
9515
9516
9517
9518
9519
9520
9521
9522
9523
9524
9525
9526
9527
9528
9529
9530
9531
9532
9533
9534
9535
9536
9537
9538
9539
9540
9541
9542
9543
9544
9545
9546
9547
9548
9549
9550
9551
9552
9553
9554
9555
9556
9557
9558
9559
9560
9561
9562
9563
9564
9565
9566
9567
9568
9569
9570
9571
9572
9573
9574
9575
9576
9577
9578
9579
9580
9581
9582
9583
9584
9585
9586
9587
9588
9589
9590
9591
9592
9593
9594
9595
9596
9597
9598
9599
9600
9601
9602
9603
9604
9605
9606
9607
9608
9609
9610
9611
9612
9613
9614
9615
9616
9617
9618
9619
9620
9621
9622
9623
9624
9625
9626
9627
9628
9629
9630
9631
9632
9633
9634
9635
9636
9637
9638
9639
9640
9641
9642
9643
9644
9645
9646
9647
9648
9649
9650
9651
9652
9653
9654
9655
9656
9657
9658
9659
9660
9661


10976
10977
10978
10979
10980
10981
10982
10983
10984
10985
10986
10987
10988
10989
10990
10991
10992
10993
10994
10995
10996
10997
10998
10999
11000
11001
11002
11003
11004
11005
11006
11007
11008
11009
11010
11011
11012
11013
11014
11015
11016
11017
11018
11019
11020
11021
11022
11023
11024
11025
11026
11027
11028
11029
11030
11031
11032
11033
11034
11035
11036
11037
11038
11039
11040
11041
11042
11043
11044
11045
11046
11047
11048
11049
11050
11051
11052
11053
11054
11055
11056
11057
11058
11059
11060
11061
11062
11063
11064
11065
11066
11067
11068
11069
11070
11071
11072
11073
11074
11075
11076
11077
11078
11079
11080
11081
11082
11083
11084
11085
11086
11087
11088
11089
11090
11091
11092
11093
11094
11095
11096
11097
11098
11099
11100
11101
11102
11103
11104
11105
11106
11107
11108
11109
11110
11111
11112
11113
11114
11115
11116
11117
11118
11119
11120
11121
11122
11123
11124
11125
11126
11127
11128
11129
11130
11131
11132
11133
11134
11135
11136
11137
11138
11139
11140
11141
1114

12551
12552
12553
12554
12555
12556
12557
12558
12559
12560
12561
12562
12563
12564
12565
12566
12567
12568
12569
12570
12571
12572
12573
12574
12575
12576
12577
12578
12579
12580
12581
12582
12583
12584
12585
12586
12587
12588
12589
12590
12591
12592
12593
12594
12595
12596
12597
12598
12599
12600
12601
12602
12603
12604
12605
12606
12607
12608
12609
12610
12611
12612
12613
12614
12615
12616
12617
12618
12619
12620
12621
12622
12623
12624
12625
12626
12627
12628
12629
12630
12631
12632
12633
12634
12635
12636
12637
12638
12639
12640
12641
12642
12643
12644
12645
12646
12647
12648
12649
12650
12651
12652
12653
12654
12655
12656
12657
12658
12659
12660
12661
12662
12663
12664
12665
12666
12667
12668
12669
12670
12671
12672
12673
12674
12675
12676
12677
12678
12679
12680
12681
12682
12683
12684
12685
12686
12687
12688
12689
12690
12691
12692
12693
12694
12695
12696
12697
12698
12699
12700
12701
12702
12703
12704
12705
12706
12707
12708
12709
12710
12711
12712
12713
12714
12715
12716
1271

14019
14020
14021
14022
14023
14024
14025
14026
14027
14028
14029
14030
14031
14032
14033
14034
14035
14036
14037
14038
14039
14040
14041
14042
14043
14044
14045
14046
14047
14048
14049
14050
14051
14052
14053
14054
14055
14056
14057
14058
14059
14060
14061
14062
14063
14064
14065
14066
14067
14068
14069
14070
14071
14072
14073
14074
14075
14076
14077
14078
14079
14080
14081
14082
14083
14084
14085
14086
14087
14088
14089
14090
14091
14092
14093
14094
14095
14096
14097
14098
14099
14100
14101
14102
14103
14104
14105
14106
14107
14108
14109
14110
14111
14112
14113
14114
14115
14116
14117
14118
14119
14120
14121
14122
14123
14124
14125
14126
14127
14128
14129
14130
14131
14132
14133
14134
14135
14136
14137
14138
14139
14140
14141
14142
14143
14144
14145
14146
14147
14148
14149
14150
14151
14152
14153
14154
14155
14156
14157
14158
14159
14160
14161
14162
14163
14164
14165
14166
14167
14168
14169
14170
14171
14172
14173
14174
14175
14176
14177
14178
14179
14180
14181
14182
14183
14184
1418

15545
15546
15547
15548
15549
15550
15551
15552
15553
15554
15555
15556
15557
15558
15559
15560
15561
15562
15563
15564
15565
15566
15567
15568
15569
15570
15571
15572
15573
15574
15575
15576
15577
15578
15579
15580
15581
15582
15583
15584
15585
15586
15587
15588
15589
15590
15591
15592
15593
15594
15595
15596
15597
15598
15599
15600
15601
15602
15603
15604
15605
15606
15607
15608
15609
15610
15611
15612
15613
15614
15615
15616
15617
15618
15619
15620
15621
15622
15623
15624
15625
15626
15627
15628
15629
15630
15631
15632
15633
15634
15635
15636
15637
15638
15639
15640
15641
15642
15643
15644
15645
15646
15647
15648
15649
15650
15651
15652
15653
15654
15655
15656
15657
15658
15659
15660
15661
15662
15663
15664
15665
15666
15667
15668
15669
15670
15671
15672
15673
15674
15675
15676
15677
15678
15679
15680
15681
15682
15683
15684
15685
15686
15687
15688
15689
15690
15691
15692
15693
15694
15695
15696
15697
15698
15699
15700
15701
15702
15703
15704
15705
15706
15707
15708
15709
15710
1571

17150
17151
17152
17153
17154
17155
17156
17157
17158
17159
17160
17161
17162
17163
17164
17165
17166
17167
17168
17169
17170
17171
17172
17173
17174
17175
17176
17177
17178
17179
17180
17181
17182
17183
17184
17185
17186
17187
17188
17189
17190
17191
17192
17193
17194
17195
17196
17197
17198
17199
17200
17201
17202
17203
17204
17205
17206
17207
17208
17209
17210
17211
17212
17213
17214
17215
17216
17217
17218
17219
17220
17221
17222
17223
17224
17225
17226
17227
17228
17229
17230
17231
17232
17233
17234
17235
17236
17237
17238
17239
17240
17241
17242
17243
17244
17245
17246
17247
17248
17249
17250
17251
17252
17253
17254
17255
17256
17257
17258
17259
17260
17261
17262
17263
17264
17265
17266
17267
17268
17269
17270
17271
17272
17273
17274
17275
17276
17277
17278
17279
17280
17281
17282
17283
17284
17285
17286
17287
17288
17289
17290
17291
17292
17293
17294
17295
17296
17297
17298
17299
17300
17301
17302
17303
17304
17305
17306
17307
17308
17309
17310
17311
17312
17313
17314
17315
1731

18658
18659
18660
18661
18662
18663
18664
18665
18666
18667
18668
18669
18670
18671
18672
18673
18674
18675
18676
18677
18678
18679
18680
18681
18682
18683
18684
18685
18686
18687
18688
18689
18690
18691
18692
18693
18694
18695
18696
18697
18698
18699
18700
18701
18702
18703
18704
18705
18706
18707
18708
18709
18710
18711
18712
18713
18714
18715
18716
18717
18718
18719
18720
18721
18722
18723
18724
18725
18726
18727
18728
18729
18730
18731
18732
18733
18734
18735
18736
18737
18738
18739
18740
18741
18742
18743
18744
18745
18746
18747
18748
18749
18750
18751
18752
18753
18754
18755
18756
18757
18758
18759
18760
18761
18762
18763
18764
18765
18766
18767
18768
18769
18770
18771
18772
18773
18774
18775
18776
18777
18778
18779
18780
18781
18782
18783
18784
18785
18786
18787
18788
18789
18790
18791
18792
18793
18794
18795
18796
18797
18798
18799
18800
18801
18802
18803
18804
18805
18806
18807
18808
18809
18810
18811
18812
18813
18814
18815
18816
18817
18818
18819
18820
18821
18822
18823
1882

20086
20087
20088
20089
20090
20091
20092
20093
20094
20095
20096
20097
20098
20099
20100
20101
20102
20103
20104
20105
20106
20107
20108
20109
20110
20111
20112
20113
20114
20115
20116
20117
20118
20119
20120
20121
20122
20123
20124
20125
20126
20127
20128
20129
20130
20131
20132
20133
20134
20135
20136
20137
20138
20139
20140
20141
20142
20143
20144
20145
20146
20147
20148
20149
20150
20151
20152
20153
20154
20155
20156
20157
20158
20159
20160
20161
20162
20163
20164
20165
20166
20167
20168
20169
20170
20171
20172
20173
20174
20175
20176
20177
20178
20179
20180
20181
20182
20183
20184
20185
20186
20187
20188
20189
20190
20191
20192
20193
20194
20195
20196
20197
20198
20199
20200
20201
20202
20203
20204
20205
20206
20207
20208
20209
20210
20211
20212
20213
20214
20215
20216
20217
20218
20219
20220
20221
20222
20223
20224
20225
20226
20227
20228
20229
20230
20231
20232
20233
20234
20235
20236
20237
20238
20239
20240
20241
20242
20243
20244
20245
20246
20247
20248
20249
20250
20251
2025

21614
21615
21616
21617
21618
21619
21620
21621
21622
21623
21624
21625
21626
21627
21628
21629
21630
21631
21632
21633
21634
21635
21636
21637
21638
21639
21640
21641
21642
21643
21644
21645
21646
21647
21648
21649
21650
21651
21652
21653
21654
21655
21656
21657
21658
21659
21660
21661
21662
21663
21664
21665
21666
21667
21668
21669
21670
21671
21672
21673
21674
21675
21676
21677
21678
21679
21680
21681
21682
21683
21684
21685
21686
21687
21688
21689
21690
21691
21692
21693
21694
21695
21696
21697
21698
21699
21700
21701
21702
21703
21704
21705
21706
21707
21708
21709
21710
21711
21712
21713
21714
21715
21716
21717
21718
21719
21720
21721
21722
21723
21724
21725
21726
21727
21728
21729
21730
21731
21732
21733
21734
21735
21736
21737
21738
21739
21740
21741
21742
21743
21744
21745
21746
21747
21748
21749
21750
21751
21752
21753
21754
21755
21756
21757
21758
21759
21760
21761
21762
21763
21764
21765
21766
21767
21768
21769
21770
21771
21772
21773
21774
21775
21776
21777
21778
21779
2178

In [60]:
len(token2id)

22059

In [109]:
pkl.dump(emb_weights,open('emb_weights.pkl','wb'))

In [110]:
pkl.dump(token2id,open('token2id.pkl','wb'))

In [111]:
pkl.dump(id2token,open('id2token.pkl','wb'))

In [44]:
len(emb_weights)

20434

In [126]:
a = torch.from_numpy(np.arange(0,2*3*4).reshape(2,3,4))
b = np.zeros((2,3,1),dtype=int)
b[0,2,:] = 1
b[1,1,:] = 1
b = torch.from_numpy(b)

In [9]:
token2id = pkl.load(open('token2id.pkl','rb'))
id2token = pkl.load(open('id2token.pkl','rb'))
emb_weights = pkl.load(open('emb_weights.pkl','rb'))

In [10]:
class SNLIDataset(Dataset):
    """
    Class that represents a train/validation/test dataset that's readable for PyTorch
    Note that this class inherits torch.utils.data.Dataset
    """

    def __init__(self, df, token2id):
        """
        @param data_list: list of character
        @param target_list: list of targets

        """
        self.sentence1,self.sentence2,self.target_list = df['sentence1'].values,df['sentence2'].values,df['label'].values
        assert (len(self.sentence1) == len(self.target_list))
        self.token2id = token2id

    def __len__(self):
        return len(self.target_list)

    def __getitem__(self, i):
        """
        Triggered when you call dataset[i]
        """
        x1_mask,x2_mask = [],[]
        x1_word_idx,x2_word_idx = [],[]
        for word in self.sentence1[i][:MAX_SENTENCE_LENGTH]:
            if word in self.token2id.keys():
                x1_word_idx.append(self.token2id[word])
                x1_mask.append(0)
            else:
                x1_word_idx.append(UNK_IDX)
                x1_mask.append(1)
        
        for word in self.sentence2[i][:MAX_SENTENCE_LENGTH]:
            if word in self.token2id.keys():
                x2_word_idx.append(self.token2id[word])
                x2_mask.append(0)
            else:
                x2_word_idx.append(UNK_IDX)
                x2_mask.append(1)
                
        label = self.target_list[i]
        return [x1_word_idx,x2_word_idx, len(x1_word_idx),len(x2_word_idx),x1_mask,x2_mask, label]


In [11]:
def snli_collate_func(batch):
    """
    Customized function for DataLoader that dynamically pads the batch so that all
    data have the same length
    """
    x1_list = []
    x2_list = []
    length_x1_list = []
    length_x2_list = []
    label_list = []
    x1_mask_list,x2_mask_list = [],[]

    for datum in batch:
        label_list.append(datum[6])
        length_x1_list.append(datum[2])
        length_x2_list.append(datum[3])
        x1_padded_vec = np.pad(np.array(datum[0]),
                                pad_width=((0,MAX_SENTENCE_LENGTH-datum[2])),
                                mode="constant", constant_values=0)
        x2_padded_vec = np.pad(np.array(datum[1]),
                                pad_width=((0,MAX_SENTENCE_LENGTH-datum[3])),
                                mode="constant", constant_values=0)
        x1_mask_padded_vec = np.pad(np.array(datum[4]),
                                pad_width=((0,MAX_SENTENCE_LENGTH-datum[2])),
                                mode="constant", constant_values=0)
        x2_mask_padded_vec = np.pad(np.array(datum[5]),
                                pad_width=((0,MAX_SENTENCE_LENGTH-datum[3])),
                                mode="constant", constant_values=0)
        x1_list.append(x1_padded_vec)
        x2_list.append(x2_padded_vec)
        x1_mask_list.append(x1_mask_padded_vec)
        x2_mask_list.append(x2_mask_padded_vec)
    # padding
    
    ind_dec_order = np.argsort(length_x1_list)[::-1]
    x1_list = np.array(x1_list)[ind_dec_order]
    x2_list = np.array(x2_list)[ind_dec_order]
    length_x1_list = np.array(length_x1_list)[ind_dec_order]
    length_x2_list = np.array(length_x2_list)[ind_dec_order]
    label_list = np.array(label_list)[ind_dec_order]
    x1_mask_list = np.array(x1_mask_list)[ind_dec_order].reshape(len(batch),-1,1)
    x2_mask_list = np.array(x2_mask_list)[ind_dec_order].reshape(len(batch),-1,1)
    #print(x1_mask_list)
    #print(x1_mask_list.shape)
    return [torch.from_numpy(x1_list), torch.from_numpy(x2_list), length_x1_list, length_x2_list, torch.from_numpy(x1_mask_list).float(), torch.from_numpy(x2_mask_list).float(), torch.from_numpy(label_list)]

In [12]:
train_dataset = SNLIDataset(train_data, token2id)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=BATCH_SIZE,
                                           collate_fn=snli_collate_func,
                                           shuffle=True)

val_dataset = SNLIDataset(val_data, token2id)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                           batch_size=BATCH_SIZE,
                                           collate_fn=snli_collate_func,
                                           shuffle=True)

# test_dataset = SNLIDataset(test_data, token2id)
# test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
#                                            batch_size=BATCH_SIZE,
#                                            collate_fn=snli_collate_func,
#                                            shuffle=True)

In [13]:
weights_mat = np.zeros((len(token2id),300))
for i in range(2,len(emb_weights)):
    weights_mat[i] = emb_weights[id2token[i]]

weights_mat[1] = np.random.randn(300)

In [14]:
class RNN(nn.Module):
    def __init__(self, hidden_size, num_layers, vocab_size,weights, bidirectional = True):
        # RNN Accepts the following hyperparams:
        # hidden_size: Hidden Size of layer in RNN
        # num_layers: number of layers in RNN
        # num_classes: number of output classes
        # vocab_size: vocabulary size
        # emb_weights = pretrained embedding weights
        super(RNN, self).__init__()

        self.num_layers, self.hidden_size = num_layers, hidden_size
        self.embedding = nn.Embedding(vocab_size, 300, padding_idx=PAD_IDX)
        self.lstm = nn.LSTM(300,hidden_size, num_layers, batch_first=True,bidirectional=bidirectional)
        self.gru = nn.GRU(300,hidden_size, num_layers, batch_first=True,bidirectional=bidirectional)
        self.num_directions = 1 if not bidirectional else 2
        self.embedding.weight.data.copy_(torch.from_numpy(weights))

    def init_hidden_gru(self, batch_size):
        # Function initializes the activation of recurrent neural net at timestep 0
        # Needs to be in format (num_layers, batch_size, hidden_size)
        hidden = torch.randn(self.num_directions*self.num_layers, batch_size, self.hidden_size).to(DEVICE)
        return hidden
    
    def init_hidden_lstm(self, batch_size):
        # Function initializes the activation of recurrent neural net at timestep 0
        # Needs to be in format (num_directions*num_layers, batch_size, hidden_size)
        hidden = torch.randn(self.num_directions*self.num_layers, batch_size, self.hidden_size).to(DEVICE)
        c_0 = torch.randn(self.num_directions*self.num_layers, batch_size, self.hidden_size).to(DEVICE)
        return hidden, c_0

    def forward(self, x, lengths,masks):
        # reset hidden state
        
        true2sorted = sorted(range(len(lengths)), key=lambda i: -lengths[i])
        sorted2true = sorted(range(len(lengths)), key=lambda i: true2sorted[i])
        #enc_input = torch.stack([enc_input[i, :] for i in true2sorted], dim=1)
        x = x[true2sorted]
        lengths = lengths[true2sorted]

        batch_size, seq_len = x.size()

        #self.hidden, self.c = self.init_hidden_lstm(batch_size)
        self.hidden = self.init_hidden_gru(batch_size)
    
        # get embedding of characters
        embed = self.embedding(x)
        #mask out all others except <UNK> token to freeze their weights
        #print(embed)
        #print(embed.size())
        #print(masks)
        #print(masks.size())
        embed = masks*embed + (1-masks)*embed.clone().detach()
        # pack padded sequence
        embed = torch.nn.utils.rnn.pack_padded_sequence(embed, lengths, batch_first=True)
        # fprop though RNN
        #rnn_out, self.hidden = self.rnn(embed, self.hidden)
        
        #rnn_out, (self.hidden, self.c) = self.lstm(embed, (self.hidden,self.c))
        rnn_out,self.hidden = self.gru(embed,self.hidden)
        # undo packing
        rnn_out, _ = torch.nn.utils.rnn.pad_packed_sequence(rnn_out, batch_first=True)
        rnn_out = rnn_out.view(batch_size,-1,self.num_directions,self.hidden_size)
        # sum hidden activations of RNN across time
        rnn_out = torch.sum(rnn_out, dim=1)
        #concat both directions
        if(self.num_directions == 2):
            out_concat = torch.cat((rnn_out[:,-1,:],rnn_out[:,-2,:]),dim=1)
        else:
            out_concat = torch.cat((rnn_out[:,-1,:]),dim=1)
        #print(out_concat.size())
        out_concat = out_concat[sorted2true]
        return out_concat


In [15]:
class ClassificationNetwork(nn.Module):
    def __init__(self, num_inputs, hidden_size, num_outputs,num_directions=2,interact_type = 'concat'):
        super(ClassificationNetwork, self).__init__()
        # Fully connected and ReLU layers
        if(interact_type == 'concat'):
            self.fc1 = nn.Linear(num_directions*num_inputs*2, hidden_size)
        else:
            self.fc1 = nn.Linear(num_directions*num_inputs, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_outputs)
        self.relu = nn.ReLU(inplace=True)
        #self.dropout = nn.Dropout(p=0.5)
        self.type = interact_type

        # Initialize weights
        self._init_weights()

    def forward(self, embedding_output1, embedding_output2):
        if(self.type == 'concat'):
            input = torch.cat((embedding_output1,embedding_output2),dim=1)
        elif(self.type == 'mul'):
            input = embedding_output1 * embedding_output2
        
        input = input.view(input.size(0), -1) # Reshape input to batch_size x num_inputs
        output = self.fc1(input)
        output = self.relu(output)
        #output = self.dropout(output)
        output = self.fc2(output)
        return output

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.uniform_(m.bias)

In [16]:
DEVICE = 'cuda:3' if torch.cuda.is_available() else 'cpu'

In [22]:
def test_model(loader, model,classification_network):
    """
    Help function that tests the model's performance on a dataset
    @param: loader - data loader for the dataset to test against
    """
    correct = 0
    total = 0
    model.eval()
    for x1,x2,length_x1,length_x2,x1_mask,x2_mask,label in loader:
        x1,x2,x1_mask,x2_mask,label = x1.to(DEVICE),x2.to(DEVICE),x1_mask.to(DEVICE),x2_mask.to(DEVICE),label.to(DEVICE)
        outputs_x1 = model(x1, length_x1,x1_mask)
        outputs_x2 = model(x2,length_x2,x2_mask)
        outputs = F.softmax(classification_network(outputs_x1,outputs_x2),dim=1)
        predicted = outputs.max(1, keepdim=True)[1]

        total += label.size(0)
        correct += predicted.eq(label.view_as(predicted)).sum().item()
    return (100 * correct / total)


model = RNN(hidden_size=RNN_HIDDEN_SIZE, num_layers=1, vocab_size=len(token2id),weights=weights_mat, bidirectional = True).to(DEVICE)
classification_network = ClassificationNetwork(num_inputs=RNN_HIDDEN_SIZE, hidden_size=LIN_HIDDEN_SIZE, num_outputs=NUM_CLASSES,num_directions=2,interact_type='mul').to(DEVICE)
learning_rate = 3e-4
num_epochs = 5 # number epoch to train

# Criterion and Optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(list(model.parameters())+list(classification_network.parameters()), lr=learning_rate)

# Train the model
total_step = len(train_loader)

for epoch in range(num_epochs):
    for i, (x1,x2,length_x1,length_x2,x1_mask,x2_mask,label) in enumerate(train_loader):
        x1,x2,x1_mask,x2_mask,label = x1.to(DEVICE),x2.to(DEVICE),x1_mask.to(DEVICE),x2_mask.to(DEVICE),label.to(DEVICE)
        model.train()
        optimizer.zero_grad()
        # Forward pass
        outputs_x1 = model(x1, length_x1,x1_mask)
        outputs_x2 = model(x2,length_x2,x2_mask)
        outputs = classification_network(outputs_x1,outputs_x2)
        loss = criterion(outputs, label)

        # Backward and optimize
        loss.backward()
        optimizer.step()
        # validate every 100 iterations
        if i > 0 and i % 100 == 0:
            # validate
            val_acc = test_model(val_loader, model,classification_network=classification_network)
            print('Epoch: [{}/{}], Step: [{}/{}], Validation Acc: {}'.format(
                       epoch+1, num_epochs, i+1, len(train_loader), val_acc))


Epoch: [1/5], Step: [101/3125], Validation Acc: 30.5
Epoch: [1/5], Step: [201/3125], Validation Acc: 35.2
Epoch: [1/5], Step: [301/3125], Validation Acc: 38.2
Epoch: [1/5], Step: [401/3125], Validation Acc: 34.9
Epoch: [1/5], Step: [501/3125], Validation Acc: 36.0
Epoch: [1/5], Step: [601/3125], Validation Acc: 35.5
Epoch: [1/5], Step: [701/3125], Validation Acc: 39.6
Epoch: [1/5], Step: [801/3125], Validation Acc: 34.2
Epoch: [1/5], Step: [901/3125], Validation Acc: 37.6
Epoch: [1/5], Step: [1001/3125], Validation Acc: 35.8
Epoch: [1/5], Step: [1101/3125], Validation Acc: 39.1
Epoch: [1/5], Step: [1201/3125], Validation Acc: 40.7
Epoch: [1/5], Step: [1301/3125], Validation Acc: 38.3
Epoch: [1/5], Step: [1401/3125], Validation Acc: 39.8
Epoch: [1/5], Step: [1501/3125], Validation Acc: 39.2
Epoch: [1/5], Step: [1601/3125], Validation Acc: 39.6
Epoch: [1/5], Step: [1701/3125], Validation Acc: 40.8
Epoch: [1/5], Step: [1801/3125], Validation Acc: 40.9
Epoch: [1/5], Step: [1901/3125], Vali

KeyboardInterrupt: 

In [50]:
test_acc = test_model(test_loader, model,classification_network=classification_network)

In [51]:
test_acc

67.2

In [53]:
class CNN(nn.Module):
    def __init__(self, hidden_size, num_layers, vocab_size,weights,kernel_size = 3):

        super(CNN, self).__init__()

        self.num_layers, self.hidden_size = num_layers, hidden_size
        self.embedding = nn.Embedding(vocab_size, 300, padding_idx=PAD_IDX)
        self.embedding.weight.data.copy_(torch.from_numpy(weights))
        self.kernel_size = kernel_size
        self.conv1 = nn.Conv1d(300, hidden_size, kernel_size=self.kernel_size, padding=1, bias = True)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.conv2 = nn.Conv1d(hidden_size, hidden_size, kernel_size=self.kernel_size, padding=1, bias = True)
        self.bn2 = nn.BatchNorm1d(hidden_size)
        
        self.relu = nn.ReLU()

        #self.linear = nn.Linear(hidden_size, num_classes)
        #self._init_weights()

    def forward(self, x, lengths,masks):
        batch_size, seq_len = x.size()

        embed = self.embedding(x)
        embed = masks*embed + (1-masks)*embed.clone().detach()
        embed = embed.transpose(1,2)
        #hidden = self.conv1(embed.transpose(1,2)).transpose(1,2)
        #hidden = self.relu(hidden.contiguous().view(-1, hidden.size(-1))).view(batch_size, seq_len, hidden.size(-1))      
#         hidden = self.conv2(hidden.transpose(1,2)).transpose(1,2)
#         hidden = self.relu(hidden.contiguous().view(-1, hidden.size(-1))).view(batch_size, seq_len, hidden.size(-1))
        
        #changed to work like a normal conv layer where the final hidden size is:(batch_size,num_filters,output_size)
        #and finally maxpooling over the final output size
        hidden = self.conv1(embed)
        #hidden = self.bn1(hidden)
        hidden = self.relu(hidden)
        hidden = self.conv2(hidden)
        #hidden = self.bn2(hidden)
        hidden = self.relu(hidden)
        #print(hidden.size())
        hidden = torch.max(hidden, dim=2)
        #print(hidden[0].size())
        #logits = self.linear(hidden)
        return hidden[0]
    
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.xavier_normal_(m.weight)
                nn.init.uniform_(m.bias)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.uniform_(m.bias)

In [57]:
def test_model(loader, model,classification_network):
    """
    Help function that tests the model's performance on a dataset
    @param: loader - data loader for the dataset to test against
    """
    correct = 0
    total = 0
    model.eval()
    for x1,x2,length_x1,length_x2,x1_mask,x2_mask,label in loader:
        x1,x2,x1_mask,x2_mask,label = x1.to(DEVICE),x2.to(DEVICE),x1_mask.to(DEVICE),x2_mask.to(DEVICE),label.to(DEVICE)
        outputs_x1 = model(x1, length_x1,x1_mask)
        outputs_x2 = model(x2,length_x2,x2_mask)
        outputs = F.softmax(classification_network(outputs_x1,outputs_x2),dim=1)
        predicted = outputs.max(1, keepdim=True)[1]

        total += label.size(0)
        correct += predicted.eq(label.view_as(predicted)).sum().item()
    return (100 * correct / total)

CNN_HIDDEN_SIZE = 512

model = CNN(hidden_size=CNN_HIDDEN_SIZE, num_layers=1, vocab_size=len(token2id),weights=weights_mat).to(DEVICE)
classification_network = ClassificationNetwork(num_inputs=CNN_HIDDEN_SIZE, hidden_size=LIN_HIDDEN_SIZE, num_outputs=NUM_CLASSES,num_directions=1,interact_type='concat').to(DEVICE)
learning_rate = 3e-4
num_epochs = 10 # number epoch to train

# Criterion and Optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(list(model.parameters())+list(classification_network.parameters()), lr=learning_rate)

# Train the model
total_step = len(train_loader)
train_loss_hist = []
val_acc_hist = []

for epoch in range(num_epochs):
    for i, (x1,x2,length_x1,length_x2,x1_mask,x2_mask,label) in enumerate(train_loader):
        x1,x2,x1_mask,x2_mask,label = x1.to(DEVICE),x2.to(DEVICE),x1_mask.to(DEVICE),x2_mask.to(DEVICE),label.to(DEVICE)
        model.train()
        optimizer.zero_grad()
        # Forward pass
        outputs_x1 = model(x1, length_x1,x1_mask)
        outputs_x2 = model(x2,length_x2,x2_mask)
        print(outputs_x1.size())
        outputs = classification_network(outputs_x1,outputs_x2)
        loss = criterion(outputs, label)

        # Backward and optimize
        loss.backward()
        optimizer.step()
        train_loss_hist.append(loss.item())
        # validate every 100 iterations
        if i > 0 and i % 100 == 0:
            val_acc = test_model(val_loader, model,classification_network=classification_network)
            val_acc_hist.append(val_acc)
            print('Epoch: [{}/{}], Step: [{}/{}], Validation Acc: {}'.format(
                       epoch+1, num_epochs, i+1, len(train_loader), val_acc))
            # validate
    
            

val_acc_hist = np.array(val_acc_hist)
max_val_acc = np.max(val_acc_hist)
max_val_acc_epoch = np.argmax(val_acc_hist)
print(max_val_acc)
print(max_val_acc_epoch)

torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size([32, 512])
torch.Size

KeyboardInterrupt: 

In [55]:
RNN_HIDDEN_SIZE

512

In [34]:
#based on the original CNN models presented in CNN for text classification paper
#try to do multiple kernels with different kernel sizes and concat their output(results in using just one CNN layer)
#inspired from allennlp repo for seq2vecencoder
class MultipleCNN(nn.Module):
    def __init__(self, hidden_size, num_layers, vocab_size,weights):

        super(MultipleCNN, self).__init__()

        self.num_layers, self.hidden_size = num_layers, hidden_size
        self.embedding = nn.Embedding(vocab_size, 300, padding_idx=PAD_IDX)
        self.embedding.weight.data.copy_(torch.from_numpy(weights))
        self.filter_sizes = [2,3,4,5]
#         self.conv_filter_1 = nn.Conv1d(300, hidden_size, kernel_size=2)
#         self.conv_filter_2 = nn.Conv1d(300, hidden_size, kernel_size=3)
#         self.conv_filter_3 = nn.Conv1d(300, hidden_size, kernel_size=2)
#         self.conv_filter_4 = nn.Conv1d(300, hidden_size, kernel_size=3)

        self.conv_layers = [nn.Conv1d(300, hidden_size, kernel_size=i, padding=1) for i in self.filter_sizes]
        
        #add modules for later using it as list rather than adding each one individually 
        
        for i,conv_layer in enumerate(self.conv_layers):
            self.add_module('conv_layer_{}'.format(i),conv_layer)
        #self.conv1 = nn.Conv1d(300, hidden_size, kernel_size=3, padding=1)
        #self.conv2 = nn.Conv1d(hidden_size, hidden_size, kernel_size=3, padding=1)
        
        self.relu = nn.ReLU()

        #self.linear = nn.Linear(hidden_size, num_classes)

    def forward(self, x, lengths,masks):
        batch_size, seq_len = x.size()

        embed = self.embedding(x)
        embed = masks*embed + (1-masks)*embed.clone().detach()
        embed = embed.transpose(1,2)
        #hidden = self.conv1(embed.transpose(1,2)).transpose(1,2)
        #hidden = self.relu(hidden.contiguous().view(-1, hidden.size(-1))).view(batch_size, seq_len, hidden.size(-1))      
#         hidden = self.conv2(hidden.transpose(1,2)).transpose(1,2)
#         hidden = self.relu(hidden.contiguous().view(-1, hidden.size(-1))).view(batch_size, seq_len, hidden.size(-1))
        
        #changed to work like a normal conv layer where the final hidden size is:(batch_size,num_filters,output_size)
        #and finally maxpooling over the final output size
        outputs_conv = []
        for i in range(len(self.filter_sizes)):
            conv_layer = getattr(self,'conv_layer_{}'.format(i))
            outputs_conv.append(self.relu(conv_layer(embed)).max(dim=2)[0])
        
        final_out = torch.cat(outputs_conv,dim=1)
        #print(final_out.size())
        #print(hidden[0].size())
        #logits = self.linear(hidden)
        return final_out

In [36]:
def test_model(loader, model,classification_network):
    """
    Help function that tests the model's performance on a dataset
    @param: loader - data loader for the dataset to test against
    """
    correct = 0
    total = 0
    model.eval()
    for x1,x2,length_x1,length_x2,x1_mask,x2_mask,label in loader:
        x1,x2,x1_mask,x2_mask,label = x1.to(DEVICE),x2.to(DEVICE),x1_mask.to(DEVICE),x2_mask.to(DEVICE),label.to(DEVICE)
        outputs_x1 = model(x1, length_x1,x1_mask)
        outputs_x2 = model(x2,length_x2,x2_mask)
        outputs = F.softmax(classification_network(outputs_x1,outputs_x2),dim=1)
        predicted = outputs.max(1, keepdim=True)[1]

        total += label.size(0)
        correct += predicted.eq(label.view_as(predicted)).sum().item()
    return (100 * correct / total)

CNN_HIDDEN_SIZE = 512
model = MultipleCNN(hidden_size=CNN_HIDDEN_SIZE, num_layers=1, vocab_size=len(token2id),weights=weights_mat).to(DEVICE)
classification_network = ClassificationNetwork(num_inputs=4*CNN_HIDDEN_SIZE, hidden_size=LIN_HIDDEN_SIZE, num_outputs=NUM_CLASSES,num_directions=1).to(DEVICE)
learning_rate = 3e-4
num_epochs = 10 # number epoch to train

# Criterion and Optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(list(model.parameters())+list(classification_network.parameters()), lr=learning_rate)

# Train the model
total_step = len(train_loader)
train_loss_hist = []
val_acc_hist = []
for epoch in range(num_epochs):
    for i, (x1,x2,length_x1,length_x2,x1_mask,x2_mask,label) in enumerate(train_loader):
        x1,x2,x1_mask,x2_mask,label = x1.to(DEVICE),x2.to(DEVICE),x1_mask.to(DEVICE),x2_mask.to(DEVICE),label.to(DEVICE)
        model.train()
        optimizer.zero_grad()
        # Forward pass
        outputs_x1 = model(x1, length_x1,x1_mask)
        outputs_x2 = model(x2,length_x2,x2_mask)
        outputs = classification_network(outputs_x1,outputs_x2)
        loss = criterion(outputs, label)

        # Backward and optimize
        loss.backward()
        optimizer.step()
        train_loss_hist.append(loss.item())
        # validate every 100 iterations
        if i > 0 and i % 100 == 0:
            # validate
            val_acc = test_model(val_loader, model,classification_network=classification_network)
            print('Epoch: [{}/{}], Step: [{}/{}], Training Loss: {}'.format(epoch+1,num_epochs,i+1,len(train_loader),loss.item()))
            print('Epoch: [{}/{}], Step: [{}/{}], Validation Acc: {}'.format(
                       epoch+1, num_epochs, i+1, len(train_loader), val_acc))
            val_acc_hist.append(val_acc)
            

Epoch: [1/10], Step: [101/3125], Training Loss: 1.1229647397994995
Epoch: [1/10], Step: [101/3125], Validation Acc: 46.2
Epoch: [1/10], Step: [201/3125], Training Loss: 1.0760072469711304
Epoch: [1/10], Step: [201/3125], Validation Acc: 55.0
Epoch: [1/10], Step: [301/3125], Training Loss: 0.8719912767410278
Epoch: [1/10], Step: [301/3125], Validation Acc: 56.2
Epoch: [1/10], Step: [401/3125], Training Loss: 0.8041960000991821
Epoch: [1/10], Step: [401/3125], Validation Acc: 55.8
Epoch: [1/10], Step: [501/3125], Training Loss: 0.7246642112731934
Epoch: [1/10], Step: [501/3125], Validation Acc: 60.5
Epoch: [1/10], Step: [601/3125], Training Loss: 0.7938550114631653
Epoch: [1/10], Step: [601/3125], Validation Acc: 60.5
Epoch: [1/10], Step: [701/3125], Training Loss: 0.9209326505661011
Epoch: [1/10], Step: [701/3125], Validation Acc: 56.7
Epoch: [1/10], Step: [801/3125], Training Loss: 0.7432420253753662
Epoch: [1/10], Step: [801/3125], Validation Acc: 59.3
Epoch: [1/10], Step: [901/3125],

Epoch: [3/10], Step: [601/3125], Training Loss: 0.463113397359848
Epoch: [3/10], Step: [601/3125], Validation Acc: 66.2
Epoch: [3/10], Step: [701/3125], Training Loss: 0.5352632999420166
Epoch: [3/10], Step: [701/3125], Validation Acc: 65.6
Epoch: [3/10], Step: [801/3125], Training Loss: 0.6947352886199951
Epoch: [3/10], Step: [801/3125], Validation Acc: 66.5
Epoch: [3/10], Step: [901/3125], Training Loss: 0.4957398772239685
Epoch: [3/10], Step: [901/3125], Validation Acc: 67.0
Epoch: [3/10], Step: [1001/3125], Training Loss: 0.8345785737037659
Epoch: [3/10], Step: [1001/3125], Validation Acc: 66.0
Epoch: [3/10], Step: [1101/3125], Training Loss: 0.5729821920394897
Epoch: [3/10], Step: [1101/3125], Validation Acc: 64.1
Epoch: [3/10], Step: [1201/3125], Training Loss: 0.5527711510658264
Epoch: [3/10], Step: [1201/3125], Validation Acc: 65.1
Epoch: [3/10], Step: [1301/3125], Training Loss: 0.4693891108036041
Epoch: [3/10], Step: [1301/3125], Validation Acc: 67.1
Epoch: [3/10], Step: [140

Epoch: [5/10], Step: [1101/3125], Training Loss: 0.4246789813041687
Epoch: [5/10], Step: [1101/3125], Validation Acc: 65.9
Epoch: [5/10], Step: [1201/3125], Training Loss: 0.09909610450267792
Epoch: [5/10], Step: [1201/3125], Validation Acc: 67.5
Epoch: [5/10], Step: [1301/3125], Training Loss: 0.30465900897979736
Epoch: [5/10], Step: [1301/3125], Validation Acc: 67.9
Epoch: [5/10], Step: [1401/3125], Training Loss: 0.246515691280365
Epoch: [5/10], Step: [1401/3125], Validation Acc: 68.3
Epoch: [5/10], Step: [1501/3125], Training Loss: 0.31940966844558716
Epoch: [5/10], Step: [1501/3125], Validation Acc: 66.5
Epoch: [5/10], Step: [1601/3125], Training Loss: 0.2033720165491104
Epoch: [5/10], Step: [1601/3125], Validation Acc: 67.0
Epoch: [5/10], Step: [1701/3125], Training Loss: 0.5279890298843384
Epoch: [5/10], Step: [1701/3125], Validation Acc: 66.4
Epoch: [5/10], Step: [1801/3125], Training Loss: 0.5281568765640259
Epoch: [5/10], Step: [1801/3125], Validation Acc: 67.7
Epoch: [5/10],

Epoch: [7/10], Step: [1601/3125], Training Loss: 0.08262232691049576
Epoch: [7/10], Step: [1601/3125], Validation Acc: 67.1
Epoch: [7/10], Step: [1701/3125], Training Loss: 0.2034592628479004
Epoch: [7/10], Step: [1701/3125], Validation Acc: 65.2
Epoch: [7/10], Step: [1801/3125], Training Loss: 0.22055676579475403
Epoch: [7/10], Step: [1801/3125], Validation Acc: 67.0
Epoch: [7/10], Step: [1901/3125], Training Loss: 0.2956238090991974
Epoch: [7/10], Step: [1901/3125], Validation Acc: 66.4
Epoch: [7/10], Step: [2001/3125], Training Loss: 0.33237606287002563
Epoch: [7/10], Step: [2001/3125], Validation Acc: 67.1
Epoch: [7/10], Step: [2101/3125], Training Loss: 0.21481171250343323
Epoch: [7/10], Step: [2101/3125], Validation Acc: 66.8
Epoch: [7/10], Step: [2201/3125], Training Loss: 0.3193310499191284
Epoch: [7/10], Step: [2201/3125], Validation Acc: 68.5
Epoch: [7/10], Step: [2301/3125], Training Loss: 0.17313474416732788
Epoch: [7/10], Step: [2301/3125], Validation Acc: 66.6
Epoch: [7/1

Epoch: [9/10], Step: [2101/3125], Training Loss: 0.1275109350681305
Epoch: [9/10], Step: [2101/3125], Validation Acc: 67.9
Epoch: [9/10], Step: [2201/3125], Training Loss: 0.07499410212039948
Epoch: [9/10], Step: [2201/3125], Validation Acc: 67.2
Epoch: [9/10], Step: [2301/3125], Training Loss: 0.21343280375003815
Epoch: [9/10], Step: [2301/3125], Validation Acc: 67.2
Epoch: [9/10], Step: [2401/3125], Training Loss: 0.08270110934972763
Epoch: [9/10], Step: [2401/3125], Validation Acc: 67.5
Epoch: [9/10], Step: [2501/3125], Training Loss: 0.03841625154018402
Epoch: [9/10], Step: [2501/3125], Validation Acc: 67.0
Epoch: [9/10], Step: [2601/3125], Training Loss: 0.061952102929353714
Epoch: [9/10], Step: [2601/3125], Validation Acc: 68.1
Epoch: [9/10], Step: [2701/3125], Training Loss: 0.20915617048740387
Epoch: [9/10], Step: [2701/3125], Validation Acc: 68.9
Epoch: [9/10], Step: [2801/3125], Training Loss: 0.08882326632738113
Epoch: [9/10], Step: [2801/3125], Validation Acc: 66.1
Epoch: [

In [39]:
max_val_acc = np.max(val_acc_hist)
max_val_acc_epoch = np.argmax(val_acc_hist)
print(max_val_acc)

69.6


In [45]:
pkl.dump(train_loss_hist,open('train_hist_mult_cnn.pkl','wb'))
pkl.dump(val_acc_hist,open('val_hist_mult_cnn.pkl','wb'))

In [137]:
CNN_HIDDEN_SIZES = [256,512]
INTERACT_TYPES = ['concat','mul']
KERNEL_SIZES = [3,5]
somelists = [CNN_HIDDEN_SIZES,KERNEL_SIZES,INTERACT_TYPES]

result = list(itertools.product(*somelists))
df_param = pd.DataFrame(result,columns=['hidden_size','kernel_size','interaction_type'])
df_param['train_loss_hist'] = None
df_param['val_acc_hist'] = None
df_param['max_val_acc'] = None
df_param['max_val_acc_epoch'] = None

In [138]:
df_param

Unnamed: 0,hidden_size,kernel_size,interaction_type,train_loss_hist,val_acc_hist,max_val_acc,max_val_acc_epoch
0,256,3,concat,,,,
1,256,3,mul,,,,
2,256,5,concat,,,,
3,256,5,mul,,,,
4,512,3,concat,,,,
5,512,3,mul,,,,
6,512,5,concat,,,,
7,512,5,mul,,,,


In [103]:
CNN_HIDDEN_SIZES = [200,512]
INTERACT_TYPES = ['concat','mul']
KERNEL_SIZES = [3,5]
somelists = [CNN_HIDDEN_SIZES,KERNEL_SIZES,INTERACT_TYPES]
LIN_HIDDEN_SIZE = 256
result = list(itertools.product(*somelists))
df_param = pd.DataFrame(result,columns=['hidden_size','kernel_size','interaction_type'])
df_param['train_loss_hist'] = None
df_param['val_acc_hist'] = None
df_param['max_val_acc'] = None
df_param['max_val_acc_epoch'] = None

learning_rate = 3e-4
num_epochs = 10 # number epoch to train
for param_i in range(len(df_param)):
    print(df_param.iloc[param_i])
    CNN_HIDDEN_SIZE = df_param.iloc[param_i]['hidden_size']
    KERNEL_SIZE = df_param.iloc[param_i]['kernel_size']
    INTERACT_TYPE = df_param.iloc[param_i]['interaction_type']
    
    model = CNN(hidden_size=CNN_HIDDEN_SIZE, num_layers=1, vocab_size=len(token2id),weights=weights_mat,kernel_size=KERNEL_SIZE).to(DEVICE)
    classification_network = ClassificationNetwork(num_inputs=CNN_HIDDEN_SIZE, hidden_size=LIN_HIDDEN_SIZE, num_outputs=NUM_CLASSES,num_directions=1,interact_type=INTERACT_TYPE).to(DEVICE)
    
    # Criterion and Optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(list(model.parameters())+list(classification_network.parameters()), lr=learning_rate)

    # Train the model
    total_step = len(train_loader)
    train_loss_hist = []
    val_acc_hist = []

    for epoch in range(num_epochs):
        for i, (x1,x2,length_x1,length_x2,x1_mask,x2_mask,label) in enumerate(train_loader):
            x1,x2,x1_mask,x2_mask,label = x1.to(DEVICE),x2.to(DEVICE),x1_mask.to(DEVICE),x2_mask.to(DEVICE),label.to(DEVICE)
            model.train()
            optimizer.zero_grad()
            # Forward pass
            outputs_x1 = model(x1, length_x1,x1_mask)
            outputs_x2 = model(x2,length_x2,x2_mask)
            outputs = classification_network(outputs_x1,outputs_x2)
            loss = criterion(outputs, label)

            # Backward and optimize
            loss.backward()
            optimizer.step()
            train_loss_hist.append(loss.item())
            # validate every 100 iterations
            if i > 0 and i % 100 == 0:
                val_acc = test_model(val_loader, model,classification_network=classification_network)
                print('Epoch: [{}/{}], Step: [{}/{}], Training Loss: {}'.format(epoch+1,num_epochs,i+1,len(train_loader),loss.item()))
                val_acc_hist.append(val_acc)
                print('Epoch: [{}/{}], Step: [{}/{}], Validation Acc: {}'.format(epoch+1, num_epochs, i+1, len(train_loader), val_acc))
                # validate
        val_acc = test_model(val_loader, model,classification_network=classification_network)
        val_acc_hist.append(val_acc)
        print('Epoch: [{}/{}], Step: [{}/{}], Validation Acc: {}'.format(
                   epoch+1, num_epochs, i+1, len(train_loader), val_acc))


    val_acc_hist = np.array(val_acc_hist)
    max_val_acc = np.max(val_acc_hist)
    max_val_acc_epoch = np.argmax(val_acc_hist)
    #df_param.set_value(i,'train_loss_hist',train_loss_hist)
    df_param.at[param_i,'train_loss_hist'] = np.array(train_loss_hist)
    
    df_param.at[param_i,'val_acc_hist'] = val_acc_hist
    df_param.at[param_i,'max_val_acc'] = max_val_acc
    df_param.at[param_i,'max_val_acc_epoch'] = max_val_acc_epoch+1
    print(max_val_acc)
    print(max_val_acc_epoch)

hidden_size             200
kernel_size               3
interaction_type     concat
train_loss_hist        None
val_acc_hist           None
max_val_acc            None
max_val_acc_epoch      None
Name: 0, dtype: object
Epoch: [1/10], Step: [101/3125], Training Loss: 1.0734792947769165
Epoch: [1/10], Step: [101/3125], Validation Acc: 44.5
Epoch: [1/10], Step: [201/3125], Training Loss: 1.0744177103042603
Epoch: [1/10], Step: [201/3125], Validation Acc: 55.0
Epoch: [1/10], Step: [301/3125], Training Loss: 1.028780221939087
Epoch: [1/10], Step: [301/3125], Validation Acc: 54.5
Epoch: [1/10], Step: [401/3125], Training Loss: 0.6995874643325806
Epoch: [1/10], Step: [401/3125], Validation Acc: 56.7
Epoch: [1/10], Step: [501/3125], Training Loss: 0.8842617869377136
Epoch: [1/10], Step: [501/3125], Validation Acc: 58.5
Epoch: [1/10], Step: [601/3125], Training Loss: 1.0650224685668945
Epoch: [1/10], Step: [601/3125], Validation Acc: 56.8
Epoch: [1/10], Step: [701/3125], Training Loss: 0.718243

Epoch: [3/10], Step: [401/3125], Training Loss: 1.0571174621582031
Epoch: [3/10], Step: [401/3125], Validation Acc: 67.8
Epoch: [3/10], Step: [501/3125], Training Loss: 0.6541440486907959
Epoch: [3/10], Step: [501/3125], Validation Acc: 66.9
Epoch: [3/10], Step: [601/3125], Training Loss: 0.6137204766273499
Epoch: [3/10], Step: [601/3125], Validation Acc: 68.2
Epoch: [3/10], Step: [701/3125], Training Loss: 0.59590744972229
Epoch: [3/10], Step: [701/3125], Validation Acc: 66.9
Epoch: [3/10], Step: [801/3125], Training Loss: 0.6771609783172607
Epoch: [3/10], Step: [801/3125], Validation Acc: 66.6
Epoch: [3/10], Step: [901/3125], Training Loss: 0.9394620060920715
Epoch: [3/10], Step: [901/3125], Validation Acc: 67.6
Epoch: [3/10], Step: [1001/3125], Training Loss: 0.6120421290397644
Epoch: [3/10], Step: [1001/3125], Validation Acc: 66.8
Epoch: [3/10], Step: [1101/3125], Training Loss: 0.8885984420776367
Epoch: [3/10], Step: [1101/3125], Validation Acc: 65.9
Epoch: [3/10], Step: [1201/312

Epoch: [5/10], Step: [901/3125], Training Loss: 0.6541707515716553
Epoch: [5/10], Step: [901/3125], Validation Acc: 68.9
Epoch: [5/10], Step: [1001/3125], Training Loss: 0.46971943974494934
Epoch: [5/10], Step: [1001/3125], Validation Acc: 68.2
Epoch: [5/10], Step: [1101/3125], Training Loss: 0.795462429523468
Epoch: [5/10], Step: [1101/3125], Validation Acc: 68.5
Epoch: [5/10], Step: [1201/3125], Training Loss: 0.5756881833076477
Epoch: [5/10], Step: [1201/3125], Validation Acc: 65.2
Epoch: [5/10], Step: [1301/3125], Training Loss: 0.37578344345092773
Epoch: [5/10], Step: [1301/3125], Validation Acc: 68.2
Epoch: [5/10], Step: [1401/3125], Training Loss: 0.4424251914024353
Epoch: [5/10], Step: [1401/3125], Validation Acc: 69.5
Epoch: [5/10], Step: [1501/3125], Training Loss: 0.44197434186935425
Epoch: [5/10], Step: [1501/3125], Validation Acc: 68.3
Epoch: [5/10], Step: [1601/3125], Training Loss: 0.45086583495140076
Epoch: [5/10], Step: [1601/3125], Validation Acc: 66.2
Epoch: [5/10], 

Epoch: [7/10], Step: [1301/3125], Training Loss: 0.47891440987586975
Epoch: [7/10], Step: [1301/3125], Validation Acc: 67.5
Epoch: [7/10], Step: [1401/3125], Training Loss: 0.19726520776748657
Epoch: [7/10], Step: [1401/3125], Validation Acc: 68.1
Epoch: [7/10], Step: [1501/3125], Training Loss: 0.24004001915454865
Epoch: [7/10], Step: [1501/3125], Validation Acc: 69.2
Epoch: [7/10], Step: [1601/3125], Training Loss: 0.42772766947746277
Epoch: [7/10], Step: [1601/3125], Validation Acc: 67.5
Epoch: [7/10], Step: [1701/3125], Training Loss: 0.3925114870071411
Epoch: [7/10], Step: [1701/3125], Validation Acc: 66.6
Epoch: [7/10], Step: [1801/3125], Training Loss: 0.44506099820137024
Epoch: [7/10], Step: [1801/3125], Validation Acc: 65.7
Epoch: [7/10], Step: [1901/3125], Training Loss: 0.5493829846382141
Epoch: [7/10], Step: [1901/3125], Validation Acc: 68.0
Epoch: [7/10], Step: [2001/3125], Training Loss: 0.5653658509254456
Epoch: [7/10], Step: [2001/3125], Validation Acc: 67.0
Epoch: [7/1

Epoch: [9/10], Step: [1701/3125], Training Loss: 0.23722928762435913
Epoch: [9/10], Step: [1701/3125], Validation Acc: 66.9
Epoch: [9/10], Step: [1801/3125], Training Loss: 0.31583282351493835
Epoch: [9/10], Step: [1801/3125], Validation Acc: 66.8
Epoch: [9/10], Step: [1901/3125], Training Loss: 0.2600606679916382
Epoch: [9/10], Step: [1901/3125], Validation Acc: 67.8
Epoch: [9/10], Step: [2001/3125], Training Loss: 0.3725855350494385
Epoch: [9/10], Step: [2001/3125], Validation Acc: 68.4
Epoch: [9/10], Step: [2101/3125], Training Loss: 0.2850981056690216
Epoch: [9/10], Step: [2101/3125], Validation Acc: 69.1
Epoch: [9/10], Step: [2201/3125], Training Loss: 0.270473450422287
Epoch: [9/10], Step: [2201/3125], Validation Acc: 68.4
Epoch: [9/10], Step: [2301/3125], Training Loss: 0.2543026804924011
Epoch: [9/10], Step: [2301/3125], Validation Acc: 68.4
Epoch: [9/10], Step: [2401/3125], Training Loss: 0.22653305530548096
Epoch: [9/10], Step: [2401/3125], Validation Acc: 68.2
Epoch: [9/10],

Epoch: [1/10], Step: [1901/3125], Training Loss: 0.8402913212776184
Epoch: [1/10], Step: [1901/3125], Validation Acc: 56.4
Epoch: [1/10], Step: [2001/3125], Training Loss: 0.898898184299469
Epoch: [1/10], Step: [2001/3125], Validation Acc: 55.8
Epoch: [1/10], Step: [2101/3125], Training Loss: 0.8270833492279053
Epoch: [1/10], Step: [2101/3125], Validation Acc: 57.9
Epoch: [1/10], Step: [2201/3125], Training Loss: 0.8428290486335754
Epoch: [1/10], Step: [2201/3125], Validation Acc: 56.8
Epoch: [1/10], Step: [2301/3125], Training Loss: 0.9624852538108826
Epoch: [1/10], Step: [2301/3125], Validation Acc: 56.8
Epoch: [1/10], Step: [2401/3125], Training Loss: 0.8046665191650391
Epoch: [1/10], Step: [2401/3125], Validation Acc: 57.5
Epoch: [1/10], Step: [2501/3125], Training Loss: 0.935598611831665
Epoch: [1/10], Step: [2501/3125], Validation Acc: 59.6
Epoch: [1/10], Step: [2601/3125], Training Loss: 0.8730244636535645
Epoch: [1/10], Step: [2601/3125], Validation Acc: 58.4
Epoch: [1/10], Ste

Epoch: [3/10], Step: [2401/3125], Training Loss: 0.7729149460792542
Epoch: [3/10], Step: [2401/3125], Validation Acc: 64.0
Epoch: [3/10], Step: [2501/3125], Training Loss: 0.6847625970840454
Epoch: [3/10], Step: [2501/3125], Validation Acc: 62.7
Epoch: [3/10], Step: [2601/3125], Training Loss: 0.681370735168457
Epoch: [3/10], Step: [2601/3125], Validation Acc: 63.5
Epoch: [3/10], Step: [2701/3125], Training Loss: 0.5371304750442505
Epoch: [3/10], Step: [2701/3125], Validation Acc: 64.2
Epoch: [3/10], Step: [2801/3125], Training Loss: 0.8388192057609558
Epoch: [3/10], Step: [2801/3125], Validation Acc: 67.0
Epoch: [3/10], Step: [2901/3125], Training Loss: 0.6973751783370972
Epoch: [3/10], Step: [2901/3125], Validation Acc: 64.4
Epoch: [3/10], Step: [3001/3125], Training Loss: 0.8275224566459656
Epoch: [3/10], Step: [3001/3125], Validation Acc: 64.9
Epoch: [3/10], Step: [3101/3125], Training Loss: 0.7040328979492188
Epoch: [3/10], Step: [3101/3125], Validation Acc: 64.6
Epoch: [3/10], St

Epoch: [5/10], Step: [2901/3125], Training Loss: 0.5706024765968323
Epoch: [5/10], Step: [2901/3125], Validation Acc: 64.0
Epoch: [5/10], Step: [3001/3125], Training Loss: 0.8212899565696716
Epoch: [5/10], Step: [3001/3125], Validation Acc: 63.7
Epoch: [5/10], Step: [3101/3125], Training Loss: 0.7624106407165527
Epoch: [5/10], Step: [3101/3125], Validation Acc: 65.3
Epoch: [5/10], Step: [3125/3125], Validation Acc: 66.0
Epoch: [6/10], Step: [101/3125], Training Loss: 0.7318341135978699
Epoch: [6/10], Step: [101/3125], Validation Acc: 66.8
Epoch: [6/10], Step: [201/3125], Training Loss: 0.513002872467041
Epoch: [6/10], Step: [201/3125], Validation Acc: 66.1
Epoch: [6/10], Step: [301/3125], Training Loss: 0.9127899408340454
Epoch: [6/10], Step: [301/3125], Validation Acc: 65.0
Epoch: [6/10], Step: [401/3125], Training Loss: 0.3818405866622925
Epoch: [6/10], Step: [401/3125], Validation Acc: 62.8
Epoch: [6/10], Step: [501/3125], Training Loss: 0.7695192694664001
Epoch: [6/10], Step: [501/

Epoch: [8/10], Step: [201/3125], Training Loss: 0.44973284006118774
Epoch: [8/10], Step: [201/3125], Validation Acc: 66.5
Epoch: [8/10], Step: [301/3125], Training Loss: 0.5507070422172546
Epoch: [8/10], Step: [301/3125], Validation Acc: 65.3
Epoch: [8/10], Step: [401/3125], Training Loss: 0.25497984886169434
Epoch: [8/10], Step: [401/3125], Validation Acc: 65.7
Epoch: [8/10], Step: [501/3125], Training Loss: 0.4188198447227478
Epoch: [8/10], Step: [501/3125], Validation Acc: 65.4
Epoch: [8/10], Step: [601/3125], Training Loss: 0.34905344247817993
Epoch: [8/10], Step: [601/3125], Validation Acc: 66.7
Epoch: [8/10], Step: [701/3125], Training Loss: 0.4267171025276184
Epoch: [8/10], Step: [701/3125], Validation Acc: 65.2
Epoch: [8/10], Step: [801/3125], Training Loss: 0.36455899477005005
Epoch: [8/10], Step: [801/3125], Validation Acc: 67.4
Epoch: [8/10], Step: [901/3125], Training Loss: 0.3178556263446808
Epoch: [8/10], Step: [901/3125], Validation Acc: 65.8
Epoch: [8/10], Step: [1001/3

Epoch: [10/10], Step: [601/3125], Training Loss: 0.4906935393810272
Epoch: [10/10], Step: [601/3125], Validation Acc: 67.9
Epoch: [10/10], Step: [701/3125], Training Loss: 0.3007795214653015
Epoch: [10/10], Step: [701/3125], Validation Acc: 67.8
Epoch: [10/10], Step: [801/3125], Training Loss: 0.2787054777145386
Epoch: [10/10], Step: [801/3125], Validation Acc: 66.1
Epoch: [10/10], Step: [901/3125], Training Loss: 0.21237266063690186
Epoch: [10/10], Step: [901/3125], Validation Acc: 66.3
Epoch: [10/10], Step: [1001/3125], Training Loss: 0.36229151487350464
Epoch: [10/10], Step: [1001/3125], Validation Acc: 67.2
Epoch: [10/10], Step: [1101/3125], Training Loss: 0.30740463733673096
Epoch: [10/10], Step: [1101/3125], Validation Acc: 65.2
Epoch: [10/10], Step: [1201/3125], Training Loss: 0.22371764481067657
Epoch: [10/10], Step: [1201/3125], Validation Acc: 68.0
Epoch: [10/10], Step: [1301/3125], Training Loss: 0.4700711965560913
Epoch: [10/10], Step: [1301/3125], Validation Acc: 67.4
Epoc

Epoch: [2/10], Step: [801/3125], Training Loss: 0.8322593569755554
Epoch: [2/10], Step: [801/3125], Validation Acc: 63.0
Epoch: [2/10], Step: [901/3125], Training Loss: 1.12204909324646
Epoch: [2/10], Step: [901/3125], Validation Acc: 63.4
Epoch: [2/10], Step: [1001/3125], Training Loss: 0.6552528738975525
Epoch: [2/10], Step: [1001/3125], Validation Acc: 62.3
Epoch: [2/10], Step: [1101/3125], Training Loss: 0.8722684383392334
Epoch: [2/10], Step: [1101/3125], Validation Acc: 61.8
Epoch: [2/10], Step: [1201/3125], Training Loss: 0.7358697652816772
Epoch: [2/10], Step: [1201/3125], Validation Acc: 63.8
Epoch: [2/10], Step: [1301/3125], Training Loss: 0.8436897993087769
Epoch: [2/10], Step: [1301/3125], Validation Acc: 63.0
Epoch: [2/10], Step: [1401/3125], Training Loss: 0.8538573384284973
Epoch: [2/10], Step: [1401/3125], Validation Acc: 63.3
Epoch: [2/10], Step: [1501/3125], Training Loss: 0.6955330967903137
Epoch: [2/10], Step: [1501/3125], Validation Acc: 63.7
Epoch: [2/10], Step: [

Epoch: [4/10], Step: [1301/3125], Training Loss: 0.696312665939331
Epoch: [4/10], Step: [1301/3125], Validation Acc: 66.7
Epoch: [4/10], Step: [1401/3125], Training Loss: 0.43636414408683777
Epoch: [4/10], Step: [1401/3125], Validation Acc: 65.7
Epoch: [4/10], Step: [1501/3125], Training Loss: 0.7810540795326233
Epoch: [4/10], Step: [1501/3125], Validation Acc: 67.9
Epoch: [4/10], Step: [1601/3125], Training Loss: 0.48136189579963684
Epoch: [4/10], Step: [1601/3125], Validation Acc: 66.1
Epoch: [4/10], Step: [1701/3125], Training Loss: 0.549876868724823
Epoch: [4/10], Step: [1701/3125], Validation Acc: 65.5
Epoch: [4/10], Step: [1801/3125], Training Loss: 0.5363568067550659
Epoch: [4/10], Step: [1801/3125], Validation Acc: 65.1
Epoch: [4/10], Step: [1901/3125], Training Loss: 0.5599669814109802
Epoch: [4/10], Step: [1901/3125], Validation Acc: 65.1
Epoch: [4/10], Step: [2001/3125], Training Loss: 0.4807821810245514
Epoch: [4/10], Step: [2001/3125], Validation Acc: 65.8
Epoch: [4/10], S

Epoch: [6/10], Step: [1701/3125], Training Loss: 0.4456534683704376
Epoch: [6/10], Step: [1701/3125], Validation Acc: 64.5
Epoch: [6/10], Step: [1801/3125], Training Loss: 0.6322200894355774
Epoch: [6/10], Step: [1801/3125], Validation Acc: 64.5
Epoch: [6/10], Step: [1901/3125], Training Loss: 0.8216878771781921
Epoch: [6/10], Step: [1901/3125], Validation Acc: 63.7
Epoch: [6/10], Step: [2001/3125], Training Loss: 0.5433279871940613
Epoch: [6/10], Step: [2001/3125], Validation Acc: 65.0
Epoch: [6/10], Step: [2101/3125], Training Loss: 0.49276304244995117
Epoch: [6/10], Step: [2101/3125], Validation Acc: 64.0
Epoch: [6/10], Step: [2201/3125], Training Loss: 0.572292149066925
Epoch: [6/10], Step: [2201/3125], Validation Acc: 64.8
Epoch: [6/10], Step: [2301/3125], Training Loss: 0.8296078443527222
Epoch: [6/10], Step: [2301/3125], Validation Acc: 65.1
Epoch: [6/10], Step: [2401/3125], Training Loss: 0.5993982553482056
Epoch: [6/10], Step: [2401/3125], Validation Acc: 64.5
Epoch: [6/10], S

Epoch: [8/10], Step: [2101/3125], Training Loss: 0.2570982575416565
Epoch: [8/10], Step: [2101/3125], Validation Acc: 62.2
Epoch: [8/10], Step: [2201/3125], Training Loss: 0.3564380407333374
Epoch: [8/10], Step: [2201/3125], Validation Acc: 61.6
Epoch: [8/10], Step: [2301/3125], Training Loss: 0.6120457649230957
Epoch: [8/10], Step: [2301/3125], Validation Acc: 63.6
Epoch: [8/10], Step: [2401/3125], Training Loss: 0.24654044210910797
Epoch: [8/10], Step: [2401/3125], Validation Acc: 62.5
Epoch: [8/10], Step: [2501/3125], Training Loss: 0.11985526233911514
Epoch: [8/10], Step: [2501/3125], Validation Acc: 63.6
Epoch: [8/10], Step: [2601/3125], Training Loss: 0.4723634123802185
Epoch: [8/10], Step: [2601/3125], Validation Acc: 63.2
Epoch: [8/10], Step: [2701/3125], Training Loss: 0.28689414262771606
Epoch: [8/10], Step: [2701/3125], Validation Acc: 64.6
Epoch: [8/10], Step: [2801/3125], Training Loss: 0.21431097388267517
Epoch: [8/10], Step: [2801/3125], Validation Acc: 63.7
Epoch: [8/10

Epoch: [10/10], Step: [2501/3125], Training Loss: 0.04979032650589943
Epoch: [10/10], Step: [2501/3125], Validation Acc: 61.8
Epoch: [10/10], Step: [2601/3125], Training Loss: 0.13754446804523468
Epoch: [10/10], Step: [2601/3125], Validation Acc: 62.5
Epoch: [10/10], Step: [2701/3125], Training Loss: 0.315653920173645
Epoch: [10/10], Step: [2701/3125], Validation Acc: 62.7
Epoch: [10/10], Step: [2801/3125], Training Loss: 0.282657265663147
Epoch: [10/10], Step: [2801/3125], Validation Acc: 61.6
Epoch: [10/10], Step: [2901/3125], Training Loss: 0.0688466727733612
Epoch: [10/10], Step: [2901/3125], Validation Acc: 62.2
Epoch: [10/10], Step: [3001/3125], Training Loss: 0.08564211428165436
Epoch: [10/10], Step: [3001/3125], Validation Acc: 63.1
Epoch: [10/10], Step: [3101/3125], Training Loss: 0.2568601667881012
Epoch: [10/10], Step: [3101/3125], Validation Acc: 62.3
Epoch: [10/10], Step: [3125/3125], Validation Acc: 64.0
67.9
110
hidden_size           200
kernel_size             5
interac

Epoch: [2/10], Step: [2801/3125], Training Loss: 0.7381757497787476
Epoch: [2/10], Step: [2801/3125], Validation Acc: 63.7
Epoch: [2/10], Step: [2901/3125], Training Loss: 0.5293483734130859
Epoch: [2/10], Step: [2901/3125], Validation Acc: 63.5
Epoch: [2/10], Step: [3001/3125], Training Loss: 0.7318052053451538
Epoch: [2/10], Step: [3001/3125], Validation Acc: 63.0
Epoch: [2/10], Step: [3101/3125], Training Loss: 1.0729809999465942
Epoch: [2/10], Step: [3101/3125], Validation Acc: 64.2
Epoch: [2/10], Step: [3125/3125], Validation Acc: 62.6
Epoch: [3/10], Step: [101/3125], Training Loss: 0.6764594912528992
Epoch: [3/10], Step: [101/3125], Validation Acc: 63.6
Epoch: [3/10], Step: [201/3125], Training Loss: 0.6391422152519226
Epoch: [3/10], Step: [201/3125], Validation Acc: 62.6
Epoch: [3/10], Step: [301/3125], Training Loss: 0.8141990303993225
Epoch: [3/10], Step: [301/3125], Validation Acc: 60.0
Epoch: [3/10], Step: [401/3125], Training Loss: 0.5270096063613892
Epoch: [3/10], Step: [4

Epoch: [5/10], Step: [101/3125], Training Loss: 0.6356619000434875
Epoch: [5/10], Step: [101/3125], Validation Acc: 62.9
Epoch: [5/10], Step: [201/3125], Training Loss: 0.436411589384079
Epoch: [5/10], Step: [201/3125], Validation Acc: 63.8
Epoch: [5/10], Step: [301/3125], Training Loss: 0.5238315463066101
Epoch: [5/10], Step: [301/3125], Validation Acc: 62.8
Epoch: [5/10], Step: [401/3125], Training Loss: 0.6443713903427124
Epoch: [5/10], Step: [401/3125], Validation Acc: 62.7
Epoch: [5/10], Step: [501/3125], Training Loss: 0.6425521969795227
Epoch: [5/10], Step: [501/3125], Validation Acc: 63.2
Epoch: [5/10], Step: [601/3125], Training Loss: 0.6116721034049988
Epoch: [5/10], Step: [601/3125], Validation Acc: 63.7
Epoch: [5/10], Step: [701/3125], Training Loss: 0.5135117769241333
Epoch: [5/10], Step: [701/3125], Validation Acc: 63.3
Epoch: [5/10], Step: [801/3125], Training Loss: 0.42945393919944763
Epoch: [5/10], Step: [801/3125], Validation Acc: 66.1
Epoch: [5/10], Step: [901/3125],

Epoch: [7/10], Step: [501/3125], Training Loss: 0.4469790756702423
Epoch: [7/10], Step: [501/3125], Validation Acc: 65.1
Epoch: [7/10], Step: [601/3125], Training Loss: 0.29684916138648987
Epoch: [7/10], Step: [601/3125], Validation Acc: 65.1
Epoch: [7/10], Step: [701/3125], Training Loss: 0.5238204598426819
Epoch: [7/10], Step: [701/3125], Validation Acc: 64.6
Epoch: [7/10], Step: [801/3125], Training Loss: 0.3624670207500458
Epoch: [7/10], Step: [801/3125], Validation Acc: 64.1
Epoch: [7/10], Step: [901/3125], Training Loss: 0.38687601685523987
Epoch: [7/10], Step: [901/3125], Validation Acc: 61.8
Epoch: [7/10], Step: [1001/3125], Training Loss: 0.4987877309322357
Epoch: [7/10], Step: [1001/3125], Validation Acc: 63.1
Epoch: [7/10], Step: [1101/3125], Training Loss: 0.566204845905304
Epoch: [7/10], Step: [1101/3125], Validation Acc: 63.4
Epoch: [7/10], Step: [1201/3125], Training Loss: 0.31164389848709106
Epoch: [7/10], Step: [1201/3125], Validation Acc: 64.2
Epoch: [7/10], Step: [13

Epoch: [9/10], Step: [901/3125], Training Loss: 0.21244001388549805
Epoch: [9/10], Step: [901/3125], Validation Acc: 61.5
Epoch: [9/10], Step: [1001/3125], Training Loss: 0.1692037582397461
Epoch: [9/10], Step: [1001/3125], Validation Acc: 61.2
Epoch: [9/10], Step: [1101/3125], Training Loss: 0.3659396767616272
Epoch: [9/10], Step: [1101/3125], Validation Acc: 63.5
Epoch: [9/10], Step: [1201/3125], Training Loss: 0.3391084671020508
Epoch: [9/10], Step: [1201/3125], Validation Acc: 61.1
Epoch: [9/10], Step: [1301/3125], Training Loss: 0.23301169276237488
Epoch: [9/10], Step: [1301/3125], Validation Acc: 63.7
Epoch: [9/10], Step: [1401/3125], Training Loss: 0.26705634593963623
Epoch: [9/10], Step: [1401/3125], Validation Acc: 62.7
Epoch: [9/10], Step: [1501/3125], Training Loss: 0.22003912925720215
Epoch: [9/10], Step: [1501/3125], Validation Acc: 63.8
Epoch: [9/10], Step: [1601/3125], Training Loss: 0.31345558166503906
Epoch: [9/10], Step: [1601/3125], Validation Acc: 63.0
Epoch: [9/10]

Epoch: [1/10], Step: [1101/3125], Training Loss: 0.7445151209831238
Epoch: [1/10], Step: [1101/3125], Validation Acc: 62.0
Epoch: [1/10], Step: [1201/3125], Training Loss: 0.8877547979354858
Epoch: [1/10], Step: [1201/3125], Validation Acc: 62.2
Epoch: [1/10], Step: [1301/3125], Training Loss: 0.8553110361099243
Epoch: [1/10], Step: [1301/3125], Validation Acc: 62.3
Epoch: [1/10], Step: [1401/3125], Training Loss: 0.8366382122039795
Epoch: [1/10], Step: [1401/3125], Validation Acc: 64.6
Epoch: [1/10], Step: [1501/3125], Training Loss: 0.7447705268859863
Epoch: [1/10], Step: [1501/3125], Validation Acc: 63.2
Epoch: [1/10], Step: [1601/3125], Training Loss: 0.7982133626937866
Epoch: [1/10], Step: [1601/3125], Validation Acc: 63.7
Epoch: [1/10], Step: [1701/3125], Training Loss: 0.7374877333641052
Epoch: [1/10], Step: [1701/3125], Validation Acc: 62.7
Epoch: [1/10], Step: [1801/3125], Training Loss: 0.9397774338722229
Epoch: [1/10], Step: [1801/3125], Validation Acc: 62.4
Epoch: [1/10], S

Epoch: [3/10], Step: [1601/3125], Training Loss: 0.7329577803611755
Epoch: [3/10], Step: [1601/3125], Validation Acc: 68.4
Epoch: [3/10], Step: [1701/3125], Training Loss: 0.6278986930847168
Epoch: [3/10], Step: [1701/3125], Validation Acc: 67.6
Epoch: [3/10], Step: [1801/3125], Training Loss: 0.806607186794281
Epoch: [3/10], Step: [1801/3125], Validation Acc: 69.7
Epoch: [3/10], Step: [1901/3125], Training Loss: 0.7533426284790039
Epoch: [3/10], Step: [1901/3125], Validation Acc: 67.8
Epoch: [3/10], Step: [2001/3125], Training Loss: 0.6217280030250549
Epoch: [3/10], Step: [2001/3125], Validation Acc: 66.7
Epoch: [3/10], Step: [2101/3125], Training Loss: 0.6225863695144653
Epoch: [3/10], Step: [2101/3125], Validation Acc: 67.6
Epoch: [3/10], Step: [2201/3125], Training Loss: 0.5531108379364014
Epoch: [3/10], Step: [2201/3125], Validation Acc: 68.5
Epoch: [3/10], Step: [2301/3125], Training Loss: 0.8181281685829163
Epoch: [3/10], Step: [2301/3125], Validation Acc: 70.9
Epoch: [3/10], St

Epoch: [5/10], Step: [2001/3125], Training Loss: 0.516187310218811
Epoch: [5/10], Step: [2001/3125], Validation Acc: 67.8
Epoch: [5/10], Step: [2101/3125], Training Loss: 0.6696739792823792
Epoch: [5/10], Step: [2101/3125], Validation Acc: 68.1
Epoch: [5/10], Step: [2201/3125], Training Loss: 0.49908190965652466
Epoch: [5/10], Step: [2201/3125], Validation Acc: 69.5
Epoch: [5/10], Step: [2301/3125], Training Loss: 0.527518093585968
Epoch: [5/10], Step: [2301/3125], Validation Acc: 68.8
Epoch: [5/10], Step: [2401/3125], Training Loss: 0.4164397120475769
Epoch: [5/10], Step: [2401/3125], Validation Acc: 67.8
Epoch: [5/10], Step: [2501/3125], Training Loss: 0.45591774582862854
Epoch: [5/10], Step: [2501/3125], Validation Acc: 67.4
Epoch: [5/10], Step: [2601/3125], Training Loss: 0.5953698754310608
Epoch: [5/10], Step: [2601/3125], Validation Acc: 68.0
Epoch: [5/10], Step: [2701/3125], Training Loss: 0.37665075063705444
Epoch: [5/10], Step: [2701/3125], Validation Acc: 68.6
Epoch: [5/10], 

Epoch: [7/10], Step: [2401/3125], Training Loss: 0.21406078338623047
Epoch: [7/10], Step: [2401/3125], Validation Acc: 69.1
Epoch: [7/10], Step: [2501/3125], Training Loss: 0.3702787160873413
Epoch: [7/10], Step: [2501/3125], Validation Acc: 68.4
Epoch: [7/10], Step: [2601/3125], Training Loss: 0.22988061606884003
Epoch: [7/10], Step: [2601/3125], Validation Acc: 68.3
Epoch: [7/10], Step: [2701/3125], Training Loss: 0.6769160032272339
Epoch: [7/10], Step: [2701/3125], Validation Acc: 68.8
Epoch: [7/10], Step: [2801/3125], Training Loss: 0.25983926653862
Epoch: [7/10], Step: [2801/3125], Validation Acc: 66.7
Epoch: [7/10], Step: [2901/3125], Training Loss: 0.23668932914733887
Epoch: [7/10], Step: [2901/3125], Validation Acc: 67.7
Epoch: [7/10], Step: [3001/3125], Training Loss: 0.35714396834373474
Epoch: [7/10], Step: [3001/3125], Validation Acc: 69.3
Epoch: [7/10], Step: [3101/3125], Training Loss: 0.3968829810619354
Epoch: [7/10], Step: [3101/3125], Validation Acc: 69.5
Epoch: [7/10],

Epoch: [9/10], Step: [2801/3125], Training Loss: 0.3234059810638428
Epoch: [9/10], Step: [2801/3125], Validation Acc: 67.4
Epoch: [9/10], Step: [2901/3125], Training Loss: 0.1335100680589676
Epoch: [9/10], Step: [2901/3125], Validation Acc: 68.0
Epoch: [9/10], Step: [3001/3125], Training Loss: 0.15935415029525757
Epoch: [9/10], Step: [3001/3125], Validation Acc: 69.1
Epoch: [9/10], Step: [3101/3125], Training Loss: 0.1785789430141449
Epoch: [9/10], Step: [3101/3125], Validation Acc: 69.8
Epoch: [9/10], Step: [3125/3125], Validation Acc: 68.9
Epoch: [10/10], Step: [101/3125], Training Loss: 0.06321727484464645
Epoch: [10/10], Step: [101/3125], Validation Acc: 69.4
Epoch: [10/10], Step: [201/3125], Training Loss: 0.20675799250602722
Epoch: [10/10], Step: [201/3125], Validation Acc: 68.8
Epoch: [10/10], Step: [301/3125], Training Loss: 0.03291372209787369
Epoch: [10/10], Step: [301/3125], Validation Acc: 69.5
Epoch: [10/10], Step: [401/3125], Training Loss: 0.280093252658844
Epoch: [10/10

Epoch: [1/10], Step: [3001/3125], Training Loss: 0.8070165514945984
Epoch: [1/10], Step: [3001/3125], Validation Acc: 62.3
Epoch: [1/10], Step: [3101/3125], Training Loss: 0.8724780082702637
Epoch: [1/10], Step: [3101/3125], Validation Acc: 61.7
Epoch: [1/10], Step: [3125/3125], Validation Acc: 61.7
Epoch: [2/10], Step: [101/3125], Training Loss: 0.9243180155754089
Epoch: [2/10], Step: [101/3125], Validation Acc: 59.9
Epoch: [2/10], Step: [201/3125], Training Loss: 0.9085099697113037
Epoch: [2/10], Step: [201/3125], Validation Acc: 60.5
Epoch: [2/10], Step: [301/3125], Training Loss: 0.7184205055236816
Epoch: [2/10], Step: [301/3125], Validation Acc: 63.0
Epoch: [2/10], Step: [401/3125], Training Loss: 0.9313076734542847
Epoch: [2/10], Step: [401/3125], Validation Acc: 63.5
Epoch: [2/10], Step: [501/3125], Training Loss: 0.918367862701416
Epoch: [2/10], Step: [501/3125], Validation Acc: 62.4
Epoch: [2/10], Step: [601/3125], Training Loss: 0.7952067255973816
Epoch: [2/10], Step: [601/31

Epoch: [4/10], Step: [301/3125], Training Loss: 0.7213506698608398
Epoch: [4/10], Step: [301/3125], Validation Acc: 66.0
Epoch: [4/10], Step: [401/3125], Training Loss: 0.7795091271400452
Epoch: [4/10], Step: [401/3125], Validation Acc: 64.9
Epoch: [4/10], Step: [501/3125], Training Loss: 0.46231991052627563
Epoch: [4/10], Step: [501/3125], Validation Acc: 64.6
Epoch: [4/10], Step: [601/3125], Training Loss: 0.6382578015327454
Epoch: [4/10], Step: [601/3125], Validation Acc: 65.0
Epoch: [4/10], Step: [701/3125], Training Loss: 0.5422007441520691
Epoch: [4/10], Step: [701/3125], Validation Acc: 66.9
Epoch: [4/10], Step: [801/3125], Training Loss: 0.6628296375274658
Epoch: [4/10], Step: [801/3125], Validation Acc: 62.7
Epoch: [4/10], Step: [901/3125], Training Loss: 0.459025502204895
Epoch: [4/10], Step: [901/3125], Validation Acc: 66.6
Epoch: [4/10], Step: [1001/3125], Training Loss: 0.8276776671409607
Epoch: [4/10], Step: [1001/3125], Validation Acc: 65.6
Epoch: [4/10], Step: [1101/312

Epoch: [6/10], Step: [801/3125], Training Loss: 0.27287396788597107
Epoch: [6/10], Step: [801/3125], Validation Acc: 65.4
Epoch: [6/10], Step: [901/3125], Training Loss: 0.6721771955490112
Epoch: [6/10], Step: [901/3125], Validation Acc: 67.6
Epoch: [6/10], Step: [1001/3125], Training Loss: 0.3439188301563263
Epoch: [6/10], Step: [1001/3125], Validation Acc: 66.8
Epoch: [6/10], Step: [1101/3125], Training Loss: 0.27193683385849
Epoch: [6/10], Step: [1101/3125], Validation Acc: 66.5
Epoch: [6/10], Step: [1201/3125], Training Loss: 0.6010063886642456
Epoch: [6/10], Step: [1201/3125], Validation Acc: 66.3
Epoch: [6/10], Step: [1301/3125], Training Loss: 0.2655787467956543
Epoch: [6/10], Step: [1301/3125], Validation Acc: 67.9
Epoch: [6/10], Step: [1401/3125], Training Loss: 0.3553316593170166
Epoch: [6/10], Step: [1401/3125], Validation Acc: 68.6
Epoch: [6/10], Step: [1501/3125], Training Loss: 0.4718402028083801
Epoch: [6/10], Step: [1501/3125], Validation Acc: 68.7
Epoch: [6/10], Step: 

Epoch: [8/10], Step: [1201/3125], Training Loss: 0.3535829782485962
Epoch: [8/10], Step: [1201/3125], Validation Acc: 67.8
Epoch: [8/10], Step: [1301/3125], Training Loss: 0.37373459339141846
Epoch: [8/10], Step: [1301/3125], Validation Acc: 67.5
Epoch: [8/10], Step: [1401/3125], Training Loss: 0.1906854510307312
Epoch: [8/10], Step: [1401/3125], Validation Acc: 65.5
Epoch: [8/10], Step: [1501/3125], Training Loss: 0.23669490218162537
Epoch: [8/10], Step: [1501/3125], Validation Acc: 65.4
Epoch: [8/10], Step: [1601/3125], Training Loss: 0.27808135747909546
Epoch: [8/10], Step: [1601/3125], Validation Acc: 66.8
Epoch: [8/10], Step: [1701/3125], Training Loss: 0.17098470032215118
Epoch: [8/10], Step: [1701/3125], Validation Acc: 66.9
Epoch: [8/10], Step: [1801/3125], Training Loss: 0.17398062348365784
Epoch: [8/10], Step: [1801/3125], Validation Acc: 66.2
Epoch: [8/10], Step: [1901/3125], Training Loss: 0.24503032863140106
Epoch: [8/10], Step: [1901/3125], Validation Acc: 66.6
Epoch: [8/

Epoch: [10/10], Step: [1601/3125], Training Loss: 0.21051207184791565
Epoch: [10/10], Step: [1601/3125], Validation Acc: 66.4
Epoch: [10/10], Step: [1701/3125], Training Loss: 0.23742994666099548
Epoch: [10/10], Step: [1701/3125], Validation Acc: 66.4
Epoch: [10/10], Step: [1801/3125], Training Loss: 0.13118979334831238
Epoch: [10/10], Step: [1801/3125], Validation Acc: 67.0
Epoch: [10/10], Step: [1901/3125], Training Loss: 0.1503363847732544
Epoch: [10/10], Step: [1901/3125], Validation Acc: 66.4
Epoch: [10/10], Step: [2001/3125], Training Loss: 0.11843923479318619
Epoch: [10/10], Step: [2001/3125], Validation Acc: 66.6
Epoch: [10/10], Step: [2101/3125], Training Loss: 0.19352516531944275
Epoch: [10/10], Step: [2101/3125], Validation Acc: 67.2
Epoch: [10/10], Step: [2201/3125], Training Loss: 0.07209999114274979
Epoch: [10/10], Step: [2201/3125], Validation Acc: 66.7
Epoch: [10/10], Step: [2301/3125], Training Loss: 0.10297583788633347
Epoch: [10/10], Step: [2301/3125], Validation Acc

Epoch: [2/10], Step: [1801/3125], Training Loss: 0.7806349396705627
Epoch: [2/10], Step: [1801/3125], Validation Acc: 67.2
Epoch: [2/10], Step: [1901/3125], Training Loss: 0.9725139141082764
Epoch: [2/10], Step: [1901/3125], Validation Acc: 66.4
Epoch: [2/10], Step: [2001/3125], Training Loss: 0.801888108253479
Epoch: [2/10], Step: [2001/3125], Validation Acc: 66.3
Epoch: [2/10], Step: [2101/3125], Training Loss: 0.7756314873695374
Epoch: [2/10], Step: [2101/3125], Validation Acc: 65.7
Epoch: [2/10], Step: [2201/3125], Training Loss: 0.6322757601737976
Epoch: [2/10], Step: [2201/3125], Validation Acc: 65.2
Epoch: [2/10], Step: [2301/3125], Training Loss: 0.9735460877418518
Epoch: [2/10], Step: [2301/3125], Validation Acc: 66.8
Epoch: [2/10], Step: [2401/3125], Training Loss: 0.6180745363235474
Epoch: [2/10], Step: [2401/3125], Validation Acc: 67.2
Epoch: [2/10], Step: [2501/3125], Training Loss: 0.6397939920425415
Epoch: [2/10], Step: [2501/3125], Validation Acc: 66.7
Epoch: [2/10], St

Epoch: [4/10], Step: [2201/3125], Training Loss: 0.7943367958068848
Epoch: [4/10], Step: [2201/3125], Validation Acc: 66.1
Epoch: [4/10], Step: [2301/3125], Training Loss: 0.3917333781719208
Epoch: [4/10], Step: [2301/3125], Validation Acc: 66.0
Epoch: [4/10], Step: [2401/3125], Training Loss: 0.45872944593429565
Epoch: [4/10], Step: [2401/3125], Validation Acc: 65.4
Epoch: [4/10], Step: [2501/3125], Training Loss: 0.6579786539077759
Epoch: [4/10], Step: [2501/3125], Validation Acc: 66.9
Epoch: [4/10], Step: [2601/3125], Training Loss: 0.5094728469848633
Epoch: [4/10], Step: [2601/3125], Validation Acc: 66.1
Epoch: [4/10], Step: [2701/3125], Training Loss: 0.7638616561889648
Epoch: [4/10], Step: [2701/3125], Validation Acc: 65.8
Epoch: [4/10], Step: [2801/3125], Training Loss: 0.6398367881774902
Epoch: [4/10], Step: [2801/3125], Validation Acc: 64.1
Epoch: [4/10], Step: [2901/3125], Training Loss: 0.43728694319725037
Epoch: [4/10], Step: [2901/3125], Validation Acc: 65.2
Epoch: [4/10],

Epoch: [6/10], Step: [2601/3125], Training Loss: 0.403083860874176
Epoch: [6/10], Step: [2601/3125], Validation Acc: 64.0
Epoch: [6/10], Step: [2701/3125], Training Loss: 0.2835376560688019
Epoch: [6/10], Step: [2701/3125], Validation Acc: 64.7
Epoch: [6/10], Step: [2801/3125], Training Loss: 0.39948907494544983
Epoch: [6/10], Step: [2801/3125], Validation Acc: 62.1
Epoch: [6/10], Step: [2901/3125], Training Loss: 0.22600626945495605
Epoch: [6/10], Step: [2901/3125], Validation Acc: 64.5
Epoch: [6/10], Step: [3001/3125], Training Loss: 0.3750884234905243
Epoch: [6/10], Step: [3001/3125], Validation Acc: 62.6
Epoch: [6/10], Step: [3101/3125], Training Loss: 0.3482931852340698
Epoch: [6/10], Step: [3101/3125], Validation Acc: 63.9
Epoch: [6/10], Step: [3125/3125], Validation Acc: 62.6
Epoch: [7/10], Step: [101/3125], Training Loss: 0.20412448048591614
Epoch: [7/10], Step: [101/3125], Validation Acc: 65.1
Epoch: [7/10], Step: [201/3125], Training Loss: 0.08179338276386261
Epoch: [7/10], S

Epoch: [8/10], Step: [3001/3125], Training Loss: 0.19037380814552307
Epoch: [8/10], Step: [3001/3125], Validation Acc: 63.7
Epoch: [8/10], Step: [3101/3125], Training Loss: 0.06018519774079323
Epoch: [8/10], Step: [3101/3125], Validation Acc: 63.8
Epoch: [8/10], Step: [3125/3125], Validation Acc: 64.2
Epoch: [9/10], Step: [101/3125], Training Loss: 0.04769989103078842
Epoch: [9/10], Step: [101/3125], Validation Acc: 63.8
Epoch: [9/10], Step: [201/3125], Training Loss: 0.15103524923324585
Epoch: [9/10], Step: [201/3125], Validation Acc: 65.2
Epoch: [9/10], Step: [301/3125], Training Loss: 0.06637651473283768
Epoch: [9/10], Step: [301/3125], Validation Acc: 65.0
Epoch: [9/10], Step: [401/3125], Training Loss: 0.07334764301776886
Epoch: [9/10], Step: [401/3125], Validation Acc: 64.4
Epoch: [9/10], Step: [501/3125], Training Loss: 0.22991378605365753
Epoch: [9/10], Step: [501/3125], Validation Acc: 63.8
Epoch: [9/10], Step: [601/3125], Training Loss: 0.06271404027938843
Epoch: [9/10], Step

Epoch: [1/10], Step: [101/3125], Training Loss: 1.0696094036102295
Epoch: [1/10], Step: [101/3125], Validation Acc: 39.8
Epoch: [1/10], Step: [201/3125], Training Loss: 1.0499444007873535
Epoch: [1/10], Step: [201/3125], Validation Acc: 39.0
Epoch: [1/10], Step: [301/3125], Training Loss: 1.046463966369629
Epoch: [1/10], Step: [301/3125], Validation Acc: 44.6
Epoch: [1/10], Step: [401/3125], Training Loss: 0.8904687166213989
Epoch: [1/10], Step: [401/3125], Validation Acc: 45.0
Epoch: [1/10], Step: [501/3125], Training Loss: 1.1183472871780396
Epoch: [1/10], Step: [501/3125], Validation Acc: 49.6
Epoch: [1/10], Step: [601/3125], Training Loss: 1.0601534843444824
Epoch: [1/10], Step: [601/3125], Validation Acc: 52.1
Epoch: [1/10], Step: [701/3125], Training Loss: 1.0433036088943481
Epoch: [1/10], Step: [701/3125], Validation Acc: 53.6
Epoch: [1/10], Step: [801/3125], Training Loss: 1.1123273372650146
Epoch: [1/10], Step: [801/3125], Validation Acc: 55.3
Epoch: [1/10], Step: [901/3125], 

Epoch: [3/10], Step: [601/3125], Training Loss: 0.6212116479873657
Epoch: [3/10], Step: [601/3125], Validation Acc: 62.1
Epoch: [3/10], Step: [701/3125], Training Loss: 0.7828437685966492
Epoch: [3/10], Step: [701/3125], Validation Acc: 64.9
Epoch: [3/10], Step: [801/3125], Training Loss: 0.5846081972122192
Epoch: [3/10], Step: [801/3125], Validation Acc: 62.6
Epoch: [3/10], Step: [901/3125], Training Loss: 0.8454771637916565
Epoch: [3/10], Step: [901/3125], Validation Acc: 61.9
Epoch: [3/10], Step: [1001/3125], Training Loss: 0.7785950303077698
Epoch: [3/10], Step: [1001/3125], Validation Acc: 59.5
Epoch: [3/10], Step: [1101/3125], Training Loss: 0.7738522291183472
Epoch: [3/10], Step: [1101/3125], Validation Acc: 63.5
Epoch: [3/10], Step: [1201/3125], Training Loss: 0.6788663864135742
Epoch: [3/10], Step: [1201/3125], Validation Acc: 62.7
Epoch: [3/10], Step: [1301/3125], Training Loss: 0.7604480385780334
Epoch: [3/10], Step: [1301/3125], Validation Acc: 63.7
Epoch: [3/10], Step: [14

Epoch: [5/10], Step: [1101/3125], Training Loss: 0.48693278431892395
Epoch: [5/10], Step: [1101/3125], Validation Acc: 64.3
Epoch: [5/10], Step: [1201/3125], Training Loss: 0.5080539584159851
Epoch: [5/10], Step: [1201/3125], Validation Acc: 63.6
Epoch: [5/10], Step: [1301/3125], Training Loss: 0.5400339365005493
Epoch: [5/10], Step: [1301/3125], Validation Acc: 62.7
Epoch: [5/10], Step: [1401/3125], Training Loss: 0.4695316553115845
Epoch: [5/10], Step: [1401/3125], Validation Acc: 62.0
Epoch: [5/10], Step: [1501/3125], Training Loss: 0.705337643623352
Epoch: [5/10], Step: [1501/3125], Validation Acc: 63.9
Epoch: [5/10], Step: [1601/3125], Training Loss: 0.3434630334377289
Epoch: [5/10], Step: [1601/3125], Validation Acc: 64.1
Epoch: [5/10], Step: [1701/3125], Training Loss: 0.3929605185985565
Epoch: [5/10], Step: [1701/3125], Validation Acc: 63.1
Epoch: [5/10], Step: [1801/3125], Training Loss: 0.39134684205055237
Epoch: [5/10], Step: [1801/3125], Validation Acc: 64.0
Epoch: [5/10], 

Epoch: [7/10], Step: [1501/3125], Training Loss: 0.25175511837005615
Epoch: [7/10], Step: [1501/3125], Validation Acc: 63.4
Epoch: [7/10], Step: [1601/3125], Training Loss: 0.3762384057044983
Epoch: [7/10], Step: [1601/3125], Validation Acc: 64.1
Epoch: [7/10], Step: [1701/3125], Training Loss: 0.21051424741744995
Epoch: [7/10], Step: [1701/3125], Validation Acc: 63.4
Epoch: [7/10], Step: [1801/3125], Training Loss: 0.3169766664505005
Epoch: [7/10], Step: [1801/3125], Validation Acc: 64.5
Epoch: [7/10], Step: [1901/3125], Training Loss: 0.36879923939704895
Epoch: [7/10], Step: [1901/3125], Validation Acc: 63.2
Epoch: [7/10], Step: [2001/3125], Training Loss: 0.4000057876110077
Epoch: [7/10], Step: [2001/3125], Validation Acc: 63.3
Epoch: [7/10], Step: [2101/3125], Training Loss: 0.4481079578399658
Epoch: [7/10], Step: [2101/3125], Validation Acc: 64.7
Epoch: [7/10], Step: [2201/3125], Training Loss: 0.2826695442199707
Epoch: [7/10], Step: [2201/3125], Validation Acc: 63.4
Epoch: [7/10]

Epoch: [9/10], Step: [1901/3125], Training Loss: 0.14498183131217957
Epoch: [9/10], Step: [1901/3125], Validation Acc: 61.9
Epoch: [9/10], Step: [2001/3125], Training Loss: 0.21536755561828613
Epoch: [9/10], Step: [2001/3125], Validation Acc: 64.7
Epoch: [9/10], Step: [2101/3125], Training Loss: 0.1838529109954834
Epoch: [9/10], Step: [2101/3125], Validation Acc: 65.2
Epoch: [9/10], Step: [2201/3125], Training Loss: 0.09544853121042252
Epoch: [9/10], Step: [2201/3125], Validation Acc: 62.4
Epoch: [9/10], Step: [2301/3125], Training Loss: 0.218521386384964
Epoch: [9/10], Step: [2301/3125], Validation Acc: 63.9
Epoch: [9/10], Step: [2401/3125], Training Loss: 0.21979166567325592
Epoch: [9/10], Step: [2401/3125], Validation Acc: 63.1
Epoch: [9/10], Step: [2501/3125], Training Loss: 0.1318407654762268
Epoch: [9/10], Step: [2501/3125], Validation Acc: 62.1
Epoch: [9/10], Step: [2601/3125], Training Loss: 0.13349178433418274
Epoch: [9/10], Step: [2601/3125], Validation Acc: 64.3
Epoch: [9/10

In [None]:
pkl.dump(df_param,open('df_cnn_cv_correct.pkl','wb'))

In [None]:
RNN_HIDDEN_SIZES = [200,512]
INTERACT_TYPES = ['concat','mul']
KERNEL_SIZES = [3,5]
somelists = [RNN_HIDDEN_SIZES,INTERACT_TYPES]
LIN_HIDDEN_SIZE = 256
result = list(itertools.product(*somelists))
df_param_rnn = pd.DataFrame(result,columns=['hidden_size','interaction_type'])
df_param_rnn['train_loss_hist'] = None
df_param_rnn['val_acc_hist'] = None
df_param_rnn['max_val_acc'] = None
df_param_rnn['max_val_acc_epoch'] = None

learning_rate = 3e-4
num_epochs = 10 # number epoch to train
for param_i in range(len(df_param_rnn)):
    print(df_param_rnn.iloc[param_i])
    RNN_HIDDEN_SIZE = int(df_param_rnn.iloc[param_i]['hidden_size'])
    INTERACT_TYPE = df_param_rnn.iloc[param_i]['interaction_type']
    
    model = RNN(hidden_size=RNN_HIDDEN_SIZE, num_layers=1, vocab_size=len(token2id),weights=weights_mat, bidirectional = True).to(DEVICE)
    classification_network = ClassificationNetwork(num_inputs=RNN_HIDDEN_SIZE, hidden_size=LIN_HIDDEN_SIZE, num_outputs=NUM_CLASSES,num_directions=NUM_DIRECTIONS,interact_type=INTERACT_TYPE).to(DEVICE)
    
    # Criterion and Optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(list(model.parameters()) + list(classification_network.parameters()), lr=learning_rate)

    # Train the model
    total_step = len(train_loader)
    train_loss_hist = []
    val_acc_hist = []

    for epoch in range(num_epochs):
        for i, (x1,x2,length_x1,length_x2,x1_mask,x2_mask,label) in enumerate(train_loader):
            x1,x2,x1_mask,x2_mask,label = x1.to(DEVICE),x2.to(DEVICE),x1_mask.to(DEVICE),x2_mask.to(DEVICE),label.to(DEVICE)
            model.train()
            optimizer.zero_grad()
            # Forward pass
            outputs_x1 = model(x1, length_x1,x1_mask)
            outputs_x2 = model(x2,length_x2,x2_mask)
            outputs = classification_network(outputs_x1,outputs_x2)
            loss = criterion(outputs, label)

            # Backward and optimize
            loss.backward()
            optimizer.step()
            train_loss_hist.append(loss.item())
            # validate every 100 iterations
            if i > 0 and i % 100 == 0:
                val_acc = test_model(val_loader, model,classification_network=classification_network)
                print('Epoch: [{}/{}], Step: [{}/{}], Training Loss: {}'.format(epoch+1,num_epochs,i+1,len(train_loader),loss.item()))
                val_acc_hist.append(val_acc)
                print('Epoch: [{}/{}], Step: [{}/{}], Validation Acc: {}'.format(epoch+1, num_epochs, i+1, len(train_loader), val_acc))
                # validate
        val_acc = test_model(val_loader, model,classification_network=classification_network)
        val_acc_hist.append(val_acc)
        print('Epoch: [{}/{}], Step: [{}/{}], Validation Acc: {}'.format(
                   epoch+1, num_epochs, i+1, len(train_loader), val_acc))


    val_acc_hist = np.array(val_acc_hist)
    max_val_acc = np.max(val_acc_hist)
    max_val_acc_epoch = np.argmax(val_acc_hist)
    #df_param.set_value(i,'train_loss_hist',train_loss_hist)
    df_param_rnn.at[param_i,'train_loss_hist'] = np.array(train_loss_hist)
    
    df_param_rnn.at[param_i,'val_acc_hist'] = val_acc_hist
    df_param_rnn.at[param_i,'max_val_acc'] = max_val_acc
    df_param_rnn.at[param_i,'max_val_acc_epoch'] = max_val_acc_epoch+1
    print(max_val_acc)
    print(max_val_acc_epoch)

hidden_size             200
interaction_type     concat
train_loss_hist        None
val_acc_hist           None
max_val_acc            None
max_val_acc_epoch      None
Name: 0, dtype: object
Epoch: [1/10], Step: [101/3125], Training Loss: 1.4171068668365479
Epoch: [1/10], Step: [101/3125], Validation Acc: 38.3
Epoch: [1/10], Step: [201/3125], Training Loss: 1.0441713333129883
Epoch: [1/10], Step: [201/3125], Validation Acc: 35.6
Epoch: [1/10], Step: [301/3125], Training Loss: 1.2311501502990723
Epoch: [1/10], Step: [301/3125], Validation Acc: 36.1
Epoch: [1/10], Step: [401/3125], Training Loss: 1.143058180809021
Epoch: [1/10], Step: [401/3125], Validation Acc: 39.8
Epoch: [1/10], Step: [501/3125], Training Loss: 1.2052719593048096
Epoch: [1/10], Step: [501/3125], Validation Acc: 37.1
Epoch: [1/10], Step: [601/3125], Training Loss: 1.1479912996292114
Epoch: [1/10], Step: [601/3125], Validation Acc: 42.1
Epoch: [1/10], Step: [701/3125], Training Loss: 1.0620074272155762
Epoch: [1/10], St

Epoch: [3/10], Step: [401/3125], Training Loss: 1.1430144309997559
Epoch: [3/10], Step: [401/3125], Validation Acc: 58.8
Epoch: [3/10], Step: [501/3125], Training Loss: 0.8388009667396545
Epoch: [3/10], Step: [501/3125], Validation Acc: 59.6
Epoch: [3/10], Step: [601/3125], Training Loss: 0.9113153219223022
Epoch: [3/10], Step: [601/3125], Validation Acc: 58.7
Epoch: [3/10], Step: [701/3125], Training Loss: 0.9049214124679565
Epoch: [3/10], Step: [701/3125], Validation Acc: 58.7
Epoch: [3/10], Step: [801/3125], Training Loss: 0.9867337942123413
Epoch: [3/10], Step: [801/3125], Validation Acc: 58.7
Epoch: [3/10], Step: [901/3125], Training Loss: 0.8704606294631958
Epoch: [3/10], Step: [901/3125], Validation Acc: 59.0
Epoch: [3/10], Step: [1001/3125], Training Loss: 0.9325309991836548
Epoch: [3/10], Step: [1001/3125], Validation Acc: 59.2
Epoch: [3/10], Step: [1101/3125], Training Loss: 0.845345139503479
Epoch: [3/10], Step: [1101/3125], Validation Acc: 59.5
Epoch: [3/10], Step: [1201/31

Epoch: [5/10], Step: [901/3125], Training Loss: 0.6167958378791809
Epoch: [5/10], Step: [901/3125], Validation Acc: 63.2
Epoch: [5/10], Step: [1001/3125], Training Loss: 0.870089590549469
Epoch: [5/10], Step: [1001/3125], Validation Acc: 63.2
Epoch: [5/10], Step: [1101/3125], Training Loss: 0.7834392189979553
Epoch: [5/10], Step: [1101/3125], Validation Acc: 64.7
Epoch: [5/10], Step: [1201/3125], Training Loss: 0.982855498790741
Epoch: [5/10], Step: [1201/3125], Validation Acc: 64.7
Epoch: [5/10], Step: [1301/3125], Training Loss: 0.8549488186836243
Epoch: [5/10], Step: [1301/3125], Validation Acc: 65.5
Epoch: [5/10], Step: [1401/3125], Training Loss: 0.7748283743858337
Epoch: [5/10], Step: [1401/3125], Validation Acc: 65.8
Epoch: [5/10], Step: [1501/3125], Training Loss: 0.6787644028663635
Epoch: [5/10], Step: [1501/3125], Validation Acc: 66.1
Epoch: [5/10], Step: [1601/3125], Training Loss: 0.8384294509887695
Epoch: [5/10], Step: [1601/3125], Validation Acc: 66.0
Epoch: [5/10], Step:

Epoch: [7/10], Step: [1401/3125], Training Loss: 0.7283430099487305
Epoch: [7/10], Step: [1401/3125], Validation Acc: 69.2
Epoch: [7/10], Step: [1501/3125], Training Loss: 0.5376429557800293
Epoch: [7/10], Step: [1501/3125], Validation Acc: 66.2
Epoch: [7/10], Step: [1601/3125], Training Loss: 0.6461762189865112
Epoch: [7/10], Step: [1601/3125], Validation Acc: 68.6
Epoch: [7/10], Step: [1701/3125], Training Loss: 0.6749982833862305
Epoch: [7/10], Step: [1701/3125], Validation Acc: 67.1
Epoch: [7/10], Step: [1801/3125], Training Loss: 0.7948722243309021
Epoch: [7/10], Step: [1801/3125], Validation Acc: 67.2
Epoch: [7/10], Step: [1901/3125], Training Loss: 0.5664923787117004
Epoch: [7/10], Step: [1901/3125], Validation Acc: 69.1
Epoch: [7/10], Step: [2001/3125], Training Loss: 0.6454054713249207
Epoch: [7/10], Step: [2001/3125], Validation Acc: 67.8
Epoch: [7/10], Step: [2101/3125], Training Loss: 0.8307227492332458
Epoch: [7/10], Step: [2101/3125], Validation Acc: 69.4
Epoch: [7/10], S

In [None]:
pkl.dump(df_param_rnn,open('df_rnn_cv_correct.pkl','wb'))

## Refitting on the best model 

### CNN

In [59]:
def test_model(loader, model,classification_network):
    """
    Help function that tests the model's performance on a dataset
    @param: loader - data loader for the dataset to test against
    """
    correct = 0
    total = 0
    model.eval()
    for x1,x2,length_x1,length_x2,x1_mask,x2_mask,label in loader:
        x1,x2,x1_mask,x2_mask,label = x1.to(DEVICE),x2.to(DEVICE),x1_mask.to(DEVICE),x2_mask.to(DEVICE),label.to(DEVICE)
        outputs_x1 = model(x1, length_x1,x1_mask)
        outputs_x2 = model(x2,length_x2,x2_mask)
        outputs = F.softmax(classification_network(outputs_x1,outputs_x2),dim=1)
        predicted = outputs.max(1, keepdim=True)[1]

        total += label.size(0)
        correct += predicted.eq(label.view_as(predicted)).sum().item()
    return (100 * correct / total)

CNN_HIDDEN_SIZE = 512

best_cnn_model = CNN(hidden_size=CNN_HIDDEN_SIZE, num_layers=1, vocab_size=len(token2id),weights=weights_mat).to(DEVICE)
classification_network = ClassificationNetwork(num_inputs=CNN_HIDDEN_SIZE, hidden_size=LIN_HIDDEN_SIZE, num_outputs=NUM_CLASSES,num_directions=1,interact_type='concat').to(DEVICE)
learning_rate = 3e-4
num_epochs = 10 # number epoch to train

# Criterion and Optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(list(best_cnn_model.parameters())+list(classification_network.parameters()), lr=learning_rate)

# Train the model
total_step = len(train_loader)
train_loss_hist = []
val_acc_hist = []

for epoch in range(num_epochs):
    for i, (x1,x2,length_x1,length_x2,x1_mask,x2_mask,label) in enumerate(train_loader):
        x1,x2,x1_mask,x2_mask,label = x1.to(DEVICE),x2.to(DEVICE),x1_mask.to(DEVICE),x2_mask.to(DEVICE),label.to(DEVICE)
        best_cnn_model.train()
        optimizer.zero_grad()
        # Forward pass
        outputs_x1 = best_cnn_model(x1, length_x1,x1_mask)
        outputs_x2 = best_cnn_model(x2,length_x2,x2_mask)
        outputs = classification_network(outputs_x1,outputs_x2)
        loss = criterion(outputs, label)

        # Backward and optimize
        loss.backward()
        optimizer.step()
        train_loss_hist.append(loss.item())
        # validate every 100 iterations
        if i > 0 and i % 100 == 0:
            val_acc = test_model(val_loader, best_cnn_model,classification_network=classification_network)
            val_acc_hist.append(val_acc)
            print('Epoch: [{}/{}], Step: [{}/{}], Validation Acc: {}'.format(
                       epoch+1, num_epochs, i+1, len(train_loader), val_acc))
            # validate
    
            

val_acc_hist = np.array(val_acc_hist)
max_val_acc = np.max(val_acc_hist)
max_val_acc_epoch = np.argmax(val_acc_hist)
print(max_val_acc)
print(max_val_acc_epoch)

Epoch: [1/10], Step: [101/3125], Validation Acc: 50.6
Epoch: [1/10], Step: [201/3125], Validation Acc: 56.2
Epoch: [1/10], Step: [301/3125], Validation Acc: 57.9
Epoch: [1/10], Step: [401/3125], Validation Acc: 57.0
Epoch: [1/10], Step: [501/3125], Validation Acc: 59.1
Epoch: [1/10], Step: [601/3125], Validation Acc: 59.4
Epoch: [1/10], Step: [701/3125], Validation Acc: 60.9
Epoch: [1/10], Step: [801/3125], Validation Acc: 60.9
Epoch: [1/10], Step: [901/3125], Validation Acc: 62.0
Epoch: [1/10], Step: [1001/3125], Validation Acc: 62.0
Epoch: [1/10], Step: [1101/3125], Validation Acc: 61.2
Epoch: [1/10], Step: [1201/3125], Validation Acc: 60.8
Epoch: [1/10], Step: [1301/3125], Validation Acc: 62.1
Epoch: [1/10], Step: [1401/3125], Validation Acc: 60.8
Epoch: [1/10], Step: [1501/3125], Validation Acc: 59.0
Epoch: [1/10], Step: [1601/3125], Validation Acc: 61.6
Epoch: [1/10], Step: [1701/3125], Validation Acc: 61.7
Epoch: [1/10], Step: [1801/3125], Validation Acc: 62.1
Epoch: [1/10], Step

Epoch: [5/10], Step: [2701/3125], Validation Acc: 67.6
Epoch: [5/10], Step: [2801/3125], Validation Acc: 67.3
Epoch: [5/10], Step: [2901/3125], Validation Acc: 67.2
Epoch: [5/10], Step: [3001/3125], Validation Acc: 68.8
Epoch: [5/10], Step: [3101/3125], Validation Acc: 68.8
Epoch: [6/10], Step: [101/3125], Validation Acc: 70.1
Epoch: [6/10], Step: [201/3125], Validation Acc: 68.6
Epoch: [6/10], Step: [301/3125], Validation Acc: 67.6
Epoch: [6/10], Step: [401/3125], Validation Acc: 67.5
Epoch: [6/10], Step: [501/3125], Validation Acc: 68.5
Epoch: [6/10], Step: [601/3125], Validation Acc: 67.2
Epoch: [6/10], Step: [701/3125], Validation Acc: 67.8
Epoch: [6/10], Step: [801/3125], Validation Acc: 67.8
Epoch: [6/10], Step: [901/3125], Validation Acc: 66.8
Epoch: [6/10], Step: [1001/3125], Validation Acc: 66.4
Epoch: [6/10], Step: [1101/3125], Validation Acc: 67.0
Epoch: [6/10], Step: [1201/3125], Validation Acc: 69.4
Epoch: [6/10], Step: [1301/3125], Validation Acc: 69.2
Epoch: [6/10], Step

Epoch: [10/10], Step: [2201/3125], Validation Acc: 66.6
Epoch: [10/10], Step: [2301/3125], Validation Acc: 68.2
Epoch: [10/10], Step: [2401/3125], Validation Acc: 67.2
Epoch: [10/10], Step: [2501/3125], Validation Acc: 68.4
Epoch: [10/10], Step: [2601/3125], Validation Acc: 67.8
Epoch: [10/10], Step: [2701/3125], Validation Acc: 67.1
Epoch: [10/10], Step: [2801/3125], Validation Acc: 68.0
Epoch: [10/10], Step: [2901/3125], Validation Acc: 66.9
Epoch: [10/10], Step: [3001/3125], Validation Acc: 66.7
Epoch: [10/10], Step: [3101/3125], Validation Acc: 66.8
71.3
122


In [60]:
original_state_dicts_cnn = {
    'embedding_network': best_cnn_model.state_dict(),
    'classification_network': classification_network.state_dict()
}

### RNN

In [64]:
def test_model(loader, model,classification_network):
    """
    Help function that tests the model's performance on a dataset
    @param: loader - data loader for the dataset to test against
    """
    correct = 0
    total = 0
    model.eval()
    classification_network.eval()
    for x1,x2,length_x1,length_x2,x1_mask,x2_mask,label in loader:
        x1,x2,x1_mask,x2_mask,label = x1.to(DEVICE),x2.to(DEVICE),x1_mask.to(DEVICE),x2_mask.to(DEVICE),label.to(DEVICE)
        outputs_x1 = model(x1, length_x1,x1_mask)
        outputs_x2 = model(x2,length_x2,x2_mask)
        outputs = F.softmax(classification_network(outputs_x1,outputs_x2),dim=1)
        predicted = outputs.max(1, keepdim=True)[1]

        total += label.size(0)
        correct += predicted.eq(label.view_as(predicted)).sum().item()
    return (100 * correct / total)

RNN_HIDDEN_SIZE = 512
INTERACT_TYPE = 'concat'

best_rnn_model = RNN(hidden_size=RNN_HIDDEN_SIZE, num_layers=1, vocab_size=len(token2id),weights=weights_mat, bidirectional = True).to(DEVICE)
classification_network_rnn = ClassificationNetwork(num_inputs=RNN_HIDDEN_SIZE, hidden_size=LIN_HIDDEN_SIZE, num_outputs=NUM_CLASSES,num_directions=NUM_DIRECTIONS,interact_type=INTERACT_TYPE).to(DEVICE)

# Criterion and Optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(list(best_rnn_model.parameters()) + list(classification_network_rnn.parameters()), lr=learning_rate)

# Train the model
total_step = len(train_loader)
train_loss_hist = []
val_acc_hist = []

for epoch in range(num_epochs):
    for i, (x1,x2,length_x1,length_x2,x1_mask,x2_mask,label) in enumerate(train_loader):
        x1,x2,x1_mask,x2_mask,label = x1.to(DEVICE),x2.to(DEVICE),x1_mask.to(DEVICE),x2_mask.to(DEVICE),label.to(DEVICE)
        best_rnn_model.train()
        classification_network_rnn.train()
        optimizer.zero_grad()
        # Forward pass
        outputs_x1 = best_rnn_model(x1, length_x1,x1_mask)
        outputs_x2 = best_rnn_model(x2,length_x2,x2_mask)
        outputs = classification_network_rnn(outputs_x1,outputs_x2)
        loss = criterion(outputs, label)

        # Backward and optimize
        loss.backward()
        optimizer.step()
        train_loss_hist.append(loss.item())
        # validate every 100 iterations
        if i > 0 and i % 100 == 0:
            val_acc = test_model(val_loader, best_rnn_model,classification_network=classification_network_rnn)
            print('Epoch: [{}/{}], Step: [{}/{}], Training Loss: {}'.format(epoch+1,num_epochs,i+1,len(train_loader),loss.item()))
            val_acc_hist.append(val_acc)
            print('Epoch: [{}/{}], Step: [{}/{}], Validation Acc: {}'.format(epoch+1, num_epochs, i+1, len(train_loader), val_acc))
            # validate
    val_acc = test_model(val_loader, best_rnn_model,classification_network=classification_network_rnn)
    val_acc_hist.append(val_acc)
    print('Epoch: [{}/{}], Step: [{}/{}], Validation Acc: {}'.format(
               epoch+1, num_epochs, i+1, len(train_loader), val_acc))

Epoch: [1/10], Step: [101/3125], Training Loss: 1.1995611190795898
Epoch: [1/10], Step: [101/3125], Validation Acc: 37.6
Epoch: [1/10], Step: [201/3125], Training Loss: 1.1100894212722778
Epoch: [1/10], Step: [201/3125], Validation Acc: 37.4
Epoch: [1/10], Step: [301/3125], Training Loss: 1.0222305059432983
Epoch: [1/10], Step: [301/3125], Validation Acc: 38.4
Epoch: [1/10], Step: [401/3125], Training Loss: 1.0443081855773926
Epoch: [1/10], Step: [401/3125], Validation Acc: 38.9
Epoch: [1/10], Step: [501/3125], Training Loss: 1.1561411619186401
Epoch: [1/10], Step: [501/3125], Validation Acc: 37.7
Epoch: [1/10], Step: [601/3125], Training Loss: 0.9454556107521057
Epoch: [1/10], Step: [601/3125], Validation Acc: 38.7
Epoch: [1/10], Step: [701/3125], Training Loss: 1.0953350067138672
Epoch: [1/10], Step: [701/3125], Validation Acc: 41.1
Epoch: [1/10], Step: [801/3125], Training Loss: 1.1802767515182495
Epoch: [1/10], Step: [801/3125], Validation Acc: 42.0
Epoch: [1/10], Step: [901/3125],

Epoch: [3/10], Step: [601/3125], Training Loss: 0.8916212320327759
Epoch: [3/10], Step: [601/3125], Validation Acc: 57.1
Epoch: [3/10], Step: [701/3125], Training Loss: 0.8576177954673767
Epoch: [3/10], Step: [701/3125], Validation Acc: 58.3
Epoch: [3/10], Step: [801/3125], Training Loss: 0.6413313150405884
Epoch: [3/10], Step: [801/3125], Validation Acc: 58.8
Epoch: [3/10], Step: [901/3125], Training Loss: 0.8267204761505127
Epoch: [3/10], Step: [901/3125], Validation Acc: 57.5
Epoch: [3/10], Step: [1001/3125], Training Loss: 0.7518324255943298
Epoch: [3/10], Step: [1001/3125], Validation Acc: 59.5
Epoch: [3/10], Step: [1101/3125], Training Loss: 0.9723003506660461
Epoch: [3/10], Step: [1101/3125], Validation Acc: 56.9
Epoch: [3/10], Step: [1201/3125], Training Loss: 0.7372439503669739
Epoch: [3/10], Step: [1201/3125], Validation Acc: 57.8
Epoch: [3/10], Step: [1301/3125], Training Loss: 0.9425835013389587
Epoch: [3/10], Step: [1301/3125], Validation Acc: 58.6
Epoch: [3/10], Step: [14

Epoch: [5/10], Step: [1101/3125], Training Loss: 0.9302197694778442
Epoch: [5/10], Step: [1101/3125], Validation Acc: 64.9
Epoch: [5/10], Step: [1201/3125], Training Loss: 0.7532354593276978
Epoch: [5/10], Step: [1201/3125], Validation Acc: 65.0
Epoch: [5/10], Step: [1301/3125], Training Loss: 0.7656674981117249
Epoch: [5/10], Step: [1301/3125], Validation Acc: 63.9
Epoch: [5/10], Step: [1401/3125], Training Loss: 0.7302685379981995
Epoch: [5/10], Step: [1401/3125], Validation Acc: 64.9
Epoch: [5/10], Step: [1501/3125], Training Loss: 0.6648122668266296
Epoch: [5/10], Step: [1501/3125], Validation Acc: 63.9
Epoch: [5/10], Step: [1601/3125], Training Loss: 0.6083792448043823
Epoch: [5/10], Step: [1601/3125], Validation Acc: 64.0
Epoch: [5/10], Step: [1701/3125], Training Loss: 0.8468359708786011
Epoch: [5/10], Step: [1701/3125], Validation Acc: 64.2
Epoch: [5/10], Step: [1801/3125], Training Loss: 0.896939754486084
Epoch: [5/10], Step: [1801/3125], Validation Acc: 63.6
Epoch: [5/10], St

Epoch: [7/10], Step: [1601/3125], Training Loss: 0.6373110413551331
Epoch: [7/10], Step: [1601/3125], Validation Acc: 69.1
Epoch: [7/10], Step: [1701/3125], Training Loss: 0.7456009984016418
Epoch: [7/10], Step: [1701/3125], Validation Acc: 68.5
Epoch: [7/10], Step: [1801/3125], Training Loss: 0.6611131429672241
Epoch: [7/10], Step: [1801/3125], Validation Acc: 69.5
Epoch: [7/10], Step: [1901/3125], Training Loss: 0.7957515716552734
Epoch: [7/10], Step: [1901/3125], Validation Acc: 68.6
Epoch: [7/10], Step: [2001/3125], Training Loss: 0.7345466613769531
Epoch: [7/10], Step: [2001/3125], Validation Acc: 68.3
Epoch: [7/10], Step: [2101/3125], Training Loss: 0.6303751468658447
Epoch: [7/10], Step: [2101/3125], Validation Acc: 67.9
Epoch: [7/10], Step: [2201/3125], Training Loss: 0.6335746645927429
Epoch: [7/10], Step: [2201/3125], Validation Acc: 67.9
Epoch: [7/10], Step: [2301/3125], Training Loss: 0.7813224792480469
Epoch: [7/10], Step: [2301/3125], Validation Acc: 68.2
Epoch: [7/10], S

Epoch: [9/10], Step: [2101/3125], Training Loss: 0.5182740688323975
Epoch: [9/10], Step: [2101/3125], Validation Acc: 71.8
Epoch: [9/10], Step: [2201/3125], Training Loss: 0.5161539912223816
Epoch: [9/10], Step: [2201/3125], Validation Acc: 69.8
Epoch: [9/10], Step: [2301/3125], Training Loss: 0.6029080152511597
Epoch: [9/10], Step: [2301/3125], Validation Acc: 72.5
Epoch: [9/10], Step: [2401/3125], Training Loss: 0.7643371224403381
Epoch: [9/10], Step: [2401/3125], Validation Acc: 70.4
Epoch: [9/10], Step: [2501/3125], Training Loss: 0.5227996110916138
Epoch: [9/10], Step: [2501/3125], Validation Acc: 70.5
Epoch: [9/10], Step: [2601/3125], Training Loss: 0.6224384307861328
Epoch: [9/10], Step: [2601/3125], Validation Acc: 70.8
Epoch: [9/10], Step: [2701/3125], Training Loss: 0.6552159786224365
Epoch: [9/10], Step: [2701/3125], Validation Acc: 69.9
Epoch: [9/10], Step: [2801/3125], Training Loss: 0.7933003306388855
Epoch: [9/10], Step: [2801/3125], Validation Acc: 72.2
Epoch: [9/10], S

In [74]:
max(val_acc_hist)

72.5

In [65]:
original_state_dicts_rnn = {
    'embedding_network': best_rnn_model.state_dict(),
    'classification_network': classification_network_rnn.state_dict()
}

# MNLI

## MNLI Validation Accuracy

In [37]:
mnli_train_data = pd.read_csv('hw2_data/mnli_train.tsv',sep='\t')
mnli_val_data = pd.read_csv('hw2_data/mnli_val.tsv',sep='\t')


In [38]:
mnli_val_data.head()

Unnamed: 0,sentence1,sentence2,label,genre
0,"'Not entirely , ' I snapped , harsher than int...",I spoke more harshly than I wanted to .,entailment,fiction
1,cook and then the next time it would be my tur...,I would cook and then the next turn would be h...,contradiction,telephone
2,The disorder hardly seemed to exist before the...,The disorder did n't seem to be as common when...,entailment,slate
3,"The Report and Order , in large part , adopts ...",The Report and Order ignores recommendations f...,contradiction,government
4,"IDPA 's OIG 's mission is to prevent , detect ...",IDPA 's OIG 's mission is clear and cares abou...,entailment,government


In [40]:
label_map = {'contradiction':0, 'entailment':2, 'neutral':1}
mnli_train_data.replace({'label':label_map},inplace=True)
mnli_val_data.replace({'label':label_map},inplace=True)

In [41]:
mnli_train_data = prepare_data(mnli_train_data)
mnli_val_data = prepare_data(mnli_val_data)

In [73]:
mnli_val_data.head()

Unnamed: 0,sentence1,sentence2,label,genre
0,"['Not, entirely, ,, ', I, snapped, ,, harsher,...","[I, spoke, more, harshly, than, I, wanted, to, .]",2,fiction
1,"[cook, and, then, the, next, time, it, would, ...","[I, would, cook, and, then, the, next, turn, w...",0,telephone
2,"[The, disorder, hardly, seemed, to, exist, bef...","[The, disorder, did, n't, seem, to, be, as, co...",2,slate
3,"[The, Report, and, Order, ,, in, large, part, ...","[The, Report, and, Order, ignores, recommendat...",0,government
4,"[IDPA, 's, OIG, 's, mission, is, to, prevent, ...","[IDPA, 's, OIG, 's, mission, is, clear, and, c...",2,government


In [43]:
genres=mnli_val_data['genre'].unique()

In [66]:
val_acc_rnn = []
val_acc_cnn = []

In [None]:
#the data loader and the vocab remains same for the mnli dataset

train_dataset = SNLIDataset(train_data, token2id)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=BATCH_SIZE,
                                           collate_fn=snli_collate_func,
                                           shuffle=True)

In [67]:
for i in range(len(genres)):
    val_dataset_mnli = SNLIDataset(mnli_val_data[(mnli_val_data['genre'] == genres[i])][['sentence1','sentence2','label']], token2id)
    val_loader_mnli = torch.utils.data.DataLoader(dataset=val_dataset_mnli,
                                               batch_size=BATCH_SIZE,
                                               collate_fn=snli_collate_func,
                                               shuffle=True)
    val_acc_cnn.append(test_model(val_loader_mnli, best_cnn_model,classification_network=classification_network))
    val_acc_rnn.append(test_model(val_loader_mnli, best_rnn_model,classification_network=classification_network_rnn))

In [68]:
val_acc_cnn

[44.92462311557789,
 41.791044776119406,
 41.616766467065865,
 41.732283464566926,
 41.24236252545825]

In [69]:
val_acc_rnn

[44.120603015075375,
 44.17910447761194,
 42.21556886227545,
 47.539370078740156,
 43.686354378818734]

In [70]:
df_mnli_acc = pd.DataFrame({'Genre':genres , 'accuracy_cnn':val_acc_cnn, 'accuracy_rnn':val_acc_rnn}) 

In [71]:
df_mnli_acc

Unnamed: 0,Genre,accuracy_cnn,accuracy_rnn
0,fiction,44.924623,44.120603
1,telephone,41.791045,44.179104
2,slate,41.616766,42.215569
3,government,41.732283,47.53937
4,travel,41.242363,43.686354


In [72]:
print(df_mnli_acc.to_latex(index=False))

\begin{tabular}{lrr}
\toprule
      Genre &  accuracy\_cnn &  accuracy\_rnn \\
\midrule
    fiction &     44.924623 &     44.120603 \\
  telephone &     41.791045 &     44.179104 \\
      slate &     41.616766 &     42.215569 \\
 government &     41.732283 &     47.539370 \\
     travel &     41.242363 &     43.686354 \\
\bottomrule
\end{tabular}



## MNLI Fine Tuning

In [97]:
df_mnli_ft_val_acc = pd.DataFrame({'Genre':genres})

In [98]:
for i in range(len(genres)):
    df_mnli_ft_val_acc[genres[i]] = None

In [99]:
df_mnli_ft_val_acc

Unnamed: 0,Genre,fiction,telephone,slate,government,travel
0,fiction,,,,,
1,telephone,,,,,
2,slate,,,,,
3,government,,,,,
4,travel,,,,,


In [79]:
import copy

In [84]:
classification_network_cnn = copy.deepcopy(classification_network)

In [100]:
num_epochs = 2
for genre_i in range(len(genres)):
    print('Genre: {}'.format(genres[genre_i]))
    train_dataset_mnli = SNLIDataset(mnli_train_data[(mnli_train_data['genre'] == genres[genre_i])][['sentence1','sentence2','label']], token2id)
    train_loader_mnli = torch.utils.data.DataLoader(dataset=train_dataset_mnli,
                                           batch_size=BATCH_SIZE,
                                           collate_fn=snli_collate_func,
                                           shuffle=True)
    val_dataset_mnli = SNLIDataset(mnli_val_data[(mnli_val_data['genre'] == genres[genre_i])][['sentence1','sentence2','label']], token2id)
    val_loader_mnli = torch.utils.data.DataLoader(dataset=val_dataset_mnli,
                                               batch_size=BATCH_SIZE,
                                               collate_fn=snli_collate_func,
                                               shuffle=True)
    
    best_rnn_model.load_state_dict(original_state_dicts_rnn['embedding_network'])
    classification_network_rnn.load_state_dict(original_state_dicts_rnn['classification_network'])
    # Criterion and Optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(list(best_rnn_model.parameters()) + list(classification_network_rnn.parameters()), lr=learning_rate)
    
    train_loss_hist = []
    val_acc_hist = []

    for epoch in range(num_epochs):
        for i, (x1,x2,length_x1,length_x2,x1_mask,x2_mask,label) in enumerate(train_loader_mnli):
            x1,x2,x1_mask,x2_mask,label = x1.to(DEVICE),x2.to(DEVICE),x1_mask.to(DEVICE),x2_mask.to(DEVICE),label.to(DEVICE)
            best_rnn_model.train()
            classification_network_rnn.train()
            optimizer.zero_grad()
            # Forward pass
            outputs_x1 = best_rnn_model(x1, length_x1,x1_mask)
            outputs_x2 = best_rnn_model(x2,length_x2,x2_mask)
            outputs = classification_network_rnn(outputs_x1,outputs_x2)
            loss = criterion(outputs, label)

            # Backward and optimize
            loss.backward()
            optimizer.step()
            train_loss_hist.append(loss.item())
            # validate every 100 iterations
            if i > 0 and i % 100 == 0:
                val_acc = test_model(val_loader_mnli, best_rnn_model,classification_network=classification_network_rnn)
                print('Epoch: [{}/{}], Step: [{}/{}], Training Loss: {}'.format(epoch+1,num_epochs,i+1,len(train_loader_mnli),loss.item()))
                val_acc_hist.append(val_acc)
                print('Epoch: [{}/{}], Step: [{}/{}], Validation Acc: {}'.format(epoch+1, num_epochs, i+1, len(train_loader_mnli), val_acc))
                # validate
        val_acc = test_model(val_loader_mnli, best_rnn_model,classification_network=classification_network_rnn)
        val_acc_hist.append(val_acc)
        print('Epoch: [{}/{}], Step: [{}/{}], Validation Acc: {}'.format(
                   epoch+1, num_epochs, i+1, len(train_loader_mnli), val_acc))
        
    for genre_j in range(len(genres)):
        val_dataset_mnli = SNLIDataset(mnli_val_data[(mnli_val_data['genre'] == genres[genre_j])][['sentence1','sentence2','label']], token2id)
        val_loader_mnli = torch.utils.data.DataLoader(dataset=val_dataset_mnli,
                                               batch_size=BATCH_SIZE,
                                               collate_fn=snli_collate_func,
                                               shuffle=True)
        df_mnli_ft_val_acc.at[genre_i,genres[genre_j]] = test_model(val_loader_mnli, best_rnn_model,classification_network=classification_network_rnn)
    
    
    
    

Genre: fiction
Epoch: [1/2], Step: [101/120], Training Loss: 0.8693899512290955
Epoch: [1/2], Step: [101/120], Validation Acc: 52.76381909547739
Epoch: [1/2], Step: [120/120], Validation Acc: 52.96482412060301
Epoch: [2/2], Step: [101/120], Training Loss: 0.7448746562004089
Epoch: [2/2], Step: [101/120], Validation Acc: 55.778894472361806
Epoch: [2/2], Step: [120/120], Validation Acc: 53.969849246231156
Genre: telephone
Epoch: [1/2], Step: [101/134], Training Loss: 0.905422031879425
Epoch: [1/2], Step: [101/134], Validation Acc: 54.527363184079604
Epoch: [1/2], Step: [134/134], Validation Acc: 56.417910447761194
Epoch: [2/2], Step: [101/134], Training Loss: 0.7695083618164062
Epoch: [2/2], Step: [101/134], Validation Acc: 56.91542288557214
Epoch: [2/2], Step: [134/134], Validation Acc: 55.52238805970149
Genre: slate
Epoch: [1/2], Step: [101/126], Training Loss: 1.0679054260253906
Epoch: [1/2], Step: [101/126], Validation Acc: 50.99800399201597
Epoch: [1/2], Step: [126/126], Validation 

In [90]:
len(train_loader_mnli.dataset)

3836

In [101]:
df_mnli_ft_val_acc

Unnamed: 0,Genre,fiction,telephone,slate,government,travel
0,fiction,55.1759,54.9254,49.8004,54.0354,50.2037
1,telephone,55.3769,57.1144,50.7984,54.2323,51.6293
2,slate,54.9749,53.3333,51.2974,55.4134,51.3238
3,government,53.9698,54.2289,50.6986,56.6929,52.0367
4,travel,52.8643,55.1244,50.1996,56.7913,53.4623


In [102]:
print(df_mnli_ft_val_acc.to_latex(index=False))

\begin{tabular}{llllll}
\toprule
      Genre &  fiction & telephone &    slate & government &   travel \\
\midrule
    fiction &  55.1759 &   54.9254 &  49.8004 &    54.0354 &  50.2037 \\
  telephone &  55.3769 &   57.1144 &  50.7984 &    54.2323 &  51.6293 \\
      slate &  54.9749 &   53.3333 &  51.2974 &    55.4134 &  51.3238 \\
 government &  53.9698 &   54.2289 &  50.6986 &    56.6929 &  52.0367 \\
     travel &  52.8643 &   55.1244 &  50.1996 &    56.7913 &  53.4623 \\
\bottomrule
\end{tabular}

