In [1]:
import torch
import torch.nn as nn
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from torch.optim import Adam
from torchvision.transforms import ToTensor
import statistics as stats
from pathlib import Path

from DIM.classification_stats import precision
from DIM.models import DeepInfoAsLatent, Encoder, Classifier

In [2]:
#!mkdir data data/tv models models/run1
cur_dir = '/notebooks/DockerShared/MINE'
data_dir = '/notebooks/DockerShared/MINE/data/tv'
root = Path(cur_dir)
model_path = root / Path('models/run1')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 128
num_classes = 10
fully_supervised = False
reload = 0
epochs = 100
# image size 3, 32, 32; batch size must be an even number; shuffle must be True
ds = CIFAR10(data_dir, download=True, transform=ToTensor())
len_train = len(ds) // 10 * 9
len_test = len(ds) - len_train
train, test = random_split(ds, [len_train, len_test])
train_l = DataLoader(train, batch_size=batch_size, shuffle=True, drop_last=True)
test_l = DataLoader(test, batch_size=batch_size, shuffle=True, drop_last=True)

Files already downloaded and verified


In [3]:
if fully_supervised:
    classifier = nn.Sequential(Encoder(), Classifier()).to(device)
else:
    classifier = DeepInfoAsLatent('run1', '990').to(device)
    if reload > 0:
        classifier = torch.load(model_path / Path(f'w_dim{reload}.mdl'))
optim = Adam(classifier.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

In [4]:
for epoch in range(reload + 1, reload + epochs):
    ll = []
    batch = tqdm(train_l, total=len_train // batch_size)
    for x, target in batch:
        optim.zero_grad()
        x, target = x.to(device), target.to(device)
        y = classifier(x)
        loss = criterion(y, target)
        ll.append(loss.detach().item())
        batch.set_description(f'{epoch} Train Loss: {stats.mean(ll)}')
        loss.backward()
        optim.step()

    confusion = torch.zeros(num_classes, num_classes)
    batch = tqdm(test_l, total=len_test // batch_size)
    ll = []
    for x, target in batch:
        x, target = x.to(device), target.to(device)
        y = classifier(x)
        loss = criterion(y, target)
        ll.append(loss.detach().item())
        batch.set_description(f'{epoch} Test Loss: {stats.mean(ll)}')
        _, predicted = y.detach().max(1)
        for item in zip(predicted, target):
            confusion[item[0], item[1]] += 1

    precis = precision(confusion)
    print(precis)
    classifier_path = model_path / Path('w_dim' + str(epoch) + '.mdl')
    classifier_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(classifier, str(classifier_path))

1 Train Loss: 2.280734795790452: 100%|██████████| 351/351 [00:21<00:00, 18.38it/s] 
1 Test Loss: 2.2570445354168234: 100%|██████████| 39/39 [00:02<00:00, 15.79it/s]


(tensor([0.6699, 0.6224, 0.0408, 0.0000, 0.0082, 0.3645, 0.6968, 0.0020, 0.0000,
        0.0000]), 0.24419070512820512)


2 Train Loss: 2.2400303403196853: 100%|██████████| 351/351 [00:20<00:00, 17.48it/s]
2 Test Loss: 2.225627581278483: 100%|██████████| 39/39 [00:02<00:00, 14.18it/s] 


(tensor([0.7033, 0.7245, 0.0204, 0.0000, 0.0348, 0.6248, 0.8028, 0.0138, 0.0000,
        0.0000]), 0.296474358974359)


3 Train Loss: 2.2147052654853234: 100%|██████████| 351/351 [00:20<00:00, 17.55it/s]
3 Test Loss: 2.2046985565087733: 100%|██████████| 39/39 [00:02<00:00, 15.28it/s]


(tensor([0.7221, 0.7606, 0.0082, 0.0000, 0.0784, 0.6853, 0.8193, 0.0614, 0.0000,
        0.0000]), 0.31810897435897434)


4 Train Loss: 2.197145465432409: 100%|██████████| 351/351 [00:20<00:00, 17.02it/s] 
4 Test Loss: 2.18870681982774: 100%|██████████| 39/39 [00:02<00:00, 13.89it/s]  


(tensor([0.7397, 0.7761, 0.0000, 0.0000, 0.1499, 0.7191, 0.8266, 0.1381, 0.0020,
        0.0020]), 0.33974358974358976)


5 Train Loss: 2.1826729672586818: 100%|██████████| 351/351 [00:20<00:00, 17.07it/s]
5 Test Loss: 2.176142007876665: 100%|██████████| 39/39 [00:02<00:00, 15.94it/s] 


(tensor([0.7510, 0.7919, 0.0000, 0.0000, 0.2074, 0.7191, 0.8310, 0.2411, 0.0040,
        0.0102]), 0.36017628205128205)


6 Train Loss: 2.170157556180601: 100%|██████████| 351/351 [00:20<00:00, 16.78it/s] 
6 Test Loss: 2.1637723996089053: 100%|██████████| 39/39 [00:02<00:00, 15.01it/s]


(tensor([0.7623, 0.7823, 0.0000, 0.0000, 0.2500, 0.7265, 0.8206, 0.3294, 0.0139,
        0.0343]), 0.3762019230769231)


7 Train Loss: 2.1594058720134943: 100%|██████████| 351/351 [00:21<00:00, 16.65it/s]
7 Test Loss: 2.153832753499349: 100%|██████████| 39/39 [00:02<00:00, 15.11it/s] 


(tensor([0.7686, 0.7765, 0.0000, 0.0000, 0.2731, 0.7505, 0.8246, 0.4150, 0.0378,
        0.0707]), 0.39603365384615385)


8 Train Loss: 2.149028492109728: 100%|██████████| 351/351 [00:20<00:00, 17.06it/s] 
8 Test Loss: 2.1445351258302345: 100%|██████████| 39/39 [00:02<00:00, 14.15it/s]


(tensor([0.7623, 0.7669, 0.0000, 0.0000, 0.3340, 0.7331, 0.8008, 0.4655, 0.0580,
        0.0972]), 0.40625)


9 Train Loss: 2.1390339268578424: 100%|██████████| 351/351 [00:19<00:00, 17.75it/s]
9 Test Loss: 2.133167951534956: 100%|██████████| 39/39 [00:02<00:00, 14.75it/s] 


(tensor([0.7525, 0.7630, 0.0020, 0.0000, 0.3696, 0.7285, 0.7912, 0.5325, 0.1178,
        0.1232]), 0.42247596153846156)


10 Train Loss: 2.1275463070285285: 100%|██████████| 351/351 [00:20<00:00, 17.09it/s]
10 Test Loss: 2.1212504704793296: 100%|██████████| 39/39 [00:03<00:00, 12.98it/s]


(tensor([0.7569, 0.7934, 0.0000, 0.0000, 0.3984, 0.7465, 0.7791, 0.5415, 0.1849,
        0.1417]), 0.4387019230769231)


11 Train Loss: 2.1158440038349555: 100%|██████████| 351/351 [00:20<00:00, 17.13it/s]
11 Test Loss: 2.109738826751709: 100%|██████████| 39/39 [00:02<00:00, 15.40it/s] 


(tensor([0.7476, 0.7722, 0.0000, 0.0000, 0.4242, 0.7490, 0.7807, 0.5754, 0.2286,
        0.1542]), 0.44751602564102566)


12 Train Loss: 2.1043467046188833: 100%|██████████| 351/351 [00:20<00:00, 17.32it/s]
12 Test Loss: 2.0995103395902195: 100%|██████████| 39/39 [00:02<00:00, 15.91it/s]


(tensor([0.7490, 0.7592, 0.0000, 0.0000, 0.4590, 0.7311, 0.7651, 0.5870, 0.2749,
        0.1805]), 0.4551282051282051)


13 Train Loss: 2.0940132596214274: 100%|██████████| 351/351 [00:20<00:00, 16.98it/s]
13 Test Loss: 2.0904679115001974: 100%|██████████| 39/39 [00:02<00:00, 14.47it/s]


(tensor([0.7250, 0.7582, 0.0000, 0.0000, 0.4795, 0.7470, 0.7646, 0.5968, 0.3406,
        0.2004]), 0.4651442307692308)


14 Train Loss: 2.0836738860844886: 100%|██████████| 351/351 [00:20<00:00, 17.23it/s]
14 Test Loss: 2.0816823030129457: 100%|██████████| 39/39 [00:02<00:00, 15.06it/s]


(tensor([0.7235, 0.7394, 0.0000, 0.0000, 0.4795, 0.7289, 0.7631, 0.5968, 0.4084,
        0.2242]), 0.4701522435897436)


15 Train Loss: 2.074351202049147: 100%|██████████| 351/351 [00:19<00:00, 17.80it/s] 
15 Test Loss: 2.071556488672892: 100%|██████████| 39/39 [00:02<00:00, 14.02it/s] 


(tensor([0.6988, 0.7457, 0.0000, 0.0000, 0.5041, 0.7365, 0.7711, 0.5889, 0.4394,
        0.2551]), 0.4779647435897436)


16 Train Loss: 2.0658713728614004: 100%|██████████| 351/351 [00:20<00:00, 17.24it/s]
16 Test Loss: 2.0641940465340247: 100%|██████████| 39/39 [00:02<00:00, 14.56it/s]


(tensor([0.6961, 0.7447, 0.0000, 0.0000, 0.5195, 0.7430, 0.7731, 0.6134, 0.4542,
        0.2949]), 0.4879807692307692)


17 Train Loss: 2.056550872631562: 100%|██████████| 351/351 [00:20<00:00, 17.41it/s] 
17 Test Loss: 2.05771533648173: 100%|██████████| 39/39 [00:02<00:00, 15.45it/s]  


(tensor([0.6745, 0.7201, 0.0000, 0.0000, 0.5113, 0.7371, 0.7782, 0.6095, 0.4761,
        0.3279]), 0.48717948717948717)


18 Train Loss: 2.0481400843019837: 100%|██████████| 351/351 [00:20<00:00, 17.00it/s]
18 Test Loss: 2.047761776508429: 100%|██████████| 39/39 [00:02<00:00, 15.04it/s] 


(tensor([0.6751, 0.7201, 0.0000, 0.0000, 0.5205, 0.7565, 0.7560, 0.6482, 0.5150,
        0.3859]), 0.5014022435897436)


19 Train Loss: 2.040683647845885: 100%|██████████| 351/351 [00:20<00:00, 17.42it/s] 
19 Test Loss: 2.039140453705421: 100%|██████████| 39/39 [00:02<00:00, 14.77it/s] 


(tensor([0.6686, 0.6898, 0.0000, 0.0000, 0.5216, 0.7325, 0.7646, 0.6304, 0.5200,
        0.4263]), 0.4987980769230769)


20 Train Loss: 2.031848164365502: 100%|██████████| 351/351 [00:20<00:00, 17.26it/s] 
20 Test Loss: 2.0327006792410827: 100%|██████████| 39/39 [00:02<00:00, 15.24it/s]


(tensor([0.6804, 0.6815, 0.0000, 0.0000, 0.5471, 0.7131, 0.7626, 0.6554, 0.5677,
        0.4323]), 0.5076121794871795)


21 Train Loss: 2.024742018802893: 100%|██████████| 351/351 [00:20<00:00, 17.20it/s] 
21 Test Loss: 2.0265922393554296: 100%|██████████| 39/39 [00:02<00:00, 15.88it/s]


(tensor([0.6556, 0.6724, 0.0000, 0.0000, 0.5339, 0.7385, 0.7726, 0.6391, 0.5785,
        0.4462]), 0.5074118589743589)


22 Train Loss: 2.0173857113574645: 100%|██████████| 351/351 [00:20<00:00, 17.08it/s]
22 Test Loss: 2.018518704634446: 100%|██████████| 39/39 [00:02<00:00, 14.94it/s] 


(tensor([0.6784, 0.6911, 0.0000, 0.0000, 0.5359, 0.7365, 0.7621, 0.6548, 0.5726,
        0.4757]), 0.5142227564102564)


23 Train Loss: 2.0109482011903723: 100%|██████████| 351/351 [00:20<00:00, 17.37it/s]
23 Test Loss: 2.013261669721359: 100%|██████████| 39/39 [00:02<00:00, 15.38it/s] 


(tensor([0.6804, 0.6892, 0.0000, 0.0000, 0.5307, 0.7405, 0.7490, 0.6469, 0.5936,
        0.4970]), 0.5166266025641025)


24 Train Loss: 2.0041914838671344: 100%|██████████| 351/351 [00:20<00:00, 17.12it/s]
24 Test Loss: 2.006326051858755: 100%|██████████| 39/39 [00:02<00:00, 14.47it/s] 


(tensor([0.6751, 0.6776, 0.0000, 0.0000, 0.5430, 0.7445, 0.7490, 0.6601, 0.5976,
        0.5121]), 0.5196314102564102)


25 Train Loss: 1.9992206697790031: 100%|██████████| 351/351 [00:20<00:00, 17.09it/s]
25 Test Loss: 2.00230509806902: 100%|██████████| 39/39 [00:02<00:00, 15.14it/s]  


(tensor([0.6725, 0.6602, 0.0000, 0.0000, 0.5430, 0.7390, 0.7586, 0.6416, 0.6044,
        0.5142]), 0.5168269230769231)


26 Train Loss: 1.9935292328185166: 100%|██████████| 351/351 [00:20<00:00, 17.08it/s]
26 Test Loss: 1.9983911361449804: 100%|██████████| 39/39 [00:02<00:00, 15.88it/s]


(tensor([0.6660, 0.6429, 0.0000, 0.0000, 0.5729, 0.7226, 0.7430, 0.6535, 0.6083,
        0.5434]), 0.5184294871794872)


27 Train Loss: 1.9891207340436103: 100%|██████████| 351/351 [00:20<00:00, 16.84it/s]
27 Test Loss: 1.992127504104223: 100%|██████████| 39/39 [00:02<00:00, 15.87it/s] 


(tensor([0.6686, 0.6609, 0.0020, 0.0000, 0.5629, 0.7285, 0.7505, 0.6568, 0.6315,
        0.5283]), 0.5222355769230769)


28 Train Loss: 1.9844760453259502: 100%|██████████| 351/351 [00:19<00:00, 17.93it/s]
28 Test Loss: 1.9863940416238246: 100%|██████████| 39/39 [00:02<00:00, 15.41it/s]


(tensor([0.6497, 0.6724, 0.0041, 0.0000, 0.5533, 0.7305, 0.7565, 0.6561, 0.6400,
        0.5598]), 0.5254407051282052)


29 Train Loss: 1.9798322477911272: 100%|██████████| 351/351 [00:21<00:00, 16.56it/s]
29 Test Loss: 1.9854371975629757: 100%|██████████| 39/39 [00:02<00:00, 14.60it/s]


(tensor([0.6667, 0.6718, 0.0000, 0.0000, 0.5697, 0.7249, 0.7545, 0.6410, 0.6243,
        0.5688]), 0.5252403846153846)


30 Train Loss: 1.975229626707202: 100%|██████████| 351/351 [00:20<00:00, 17.11it/s] 
30 Test Loss: 1.980916976928711: 100%|██████████| 39/39 [00:02<00:00, 14.61it/s] 


(tensor([0.6608, 0.6564, 0.0061, 0.0000, 0.5881, 0.7435, 0.7590, 0.6627, 0.6355,
        0.5619]), 0.530448717948718)


31 Train Loss: 1.9724979509315599: 100%|██████████| 351/351 [00:20<00:00, 17.38it/s]
31 Test Loss: 1.9764012709642067: 100%|██████████| 39/39 [00:02<00:00, 15.09it/s]


(tensor([0.6575, 0.6641, 0.0102, 0.0000, 0.5717, 0.7351, 0.7525, 0.6614, 0.6660,
        0.5587]), 0.5308493589743589)


32 Train Loss: 1.968817655517165: 100%|██████████| 351/351 [00:20<00:00, 16.77it/s] 
32 Test Loss: 1.9759615415181868: 100%|██████████| 39/39 [00:02<00:00, 15.89it/s]


(tensor([0.6634, 0.6255, 0.0204, 0.0000, 0.5779, 0.7206, 0.7339, 0.6627, 0.6500,
        0.5657]), 0.5250400641025641)


33 Train Loss: 1.9657165110960306: 100%|██████████| 351/351 [00:20<00:00, 17.02it/s]
33 Test Loss: 1.9690268131402822: 100%|██████████| 39/39 [00:02<00:00, 15.28it/s]


(tensor([0.6569, 0.6602, 0.0225, 0.0000, 0.5615, 0.7345, 0.7550, 0.6522, 0.6494,
        0.5810]), 0.5306490384615384)


34 Train Loss: 1.9624992381473552: 100%|██████████| 351/351 [00:20<00:00, 17.37it/s]
34 Test Loss: 1.9673862335009453: 100%|██████████| 39/39 [00:02<00:00, 14.36it/s]


(tensor([0.6419, 0.6609, 0.0449, 0.0000, 0.5893, 0.7291, 0.7586, 0.6548, 0.6534,
        0.5723]), 0.5336538461538461)


35 Train Loss: 1.9597769768489393: 100%|██████████| 351/351 [00:21<00:00, 16.42it/s]
35 Test Loss: 1.964676557443081: 100%|██████████| 39/39 [00:02<00:00, 14.26it/s] 


(tensor([0.6490, 0.6486, 0.0449, 0.0000, 0.5749, 0.7331, 0.7530, 0.6607, 0.6733,
        0.5842]), 0.5354567307692307)


36 Train Loss: 1.9558342138246934: 100%|██████████| 351/351 [00:19<00:00, 17.59it/s]
36 Test Loss: 1.9636712624476507: 100%|██████████| 39/39 [00:02<00:00, 15.96it/s]


(tensor([0.6419, 0.6551, 0.0675, 0.0000, 0.5811, 0.7186, 0.7349, 0.6587, 0.6594,
        0.5677]), 0.5316506410256411)


37 Train Loss: 1.9526973697874281: 100%|██████████| 351/351 [00:21<00:00, 16.51it/s]
37 Test Loss: 1.9586224128038456: 100%|██████████| 39/39 [00:02<00:00, 14.91it/s]


(tensor([0.6575, 0.6402, 0.0796, 0.0000, 0.5758, 0.7200, 0.7475, 0.6607, 0.6720,
        0.5931]), 0.5374599358974359)


38 Train Loss: 1.9500603207156189: 100%|██████████| 351/351 [00:20<00:00, 17.03it/s]
38 Test Loss: 1.9566610195697882: 100%|██████████| 39/39 [00:02<00:00, 15.77it/s]


(tensor([0.6569, 0.6344, 0.0879, 0.0000, 0.5881, 0.7211, 0.7309, 0.6581, 0.6793,
        0.6101]), 0.539863782051282)


39 Train Loss: 1.9474750719858371: 100%|██████████| 351/351 [00:19<00:00, 17.59it/s]
39 Test Loss: 1.9543585716149745: 100%|██████████| 39/39 [00:02<00:00, 15.02it/s]


(tensor([0.6438, 0.6378, 0.0920, 0.0000, 0.5676, 0.7311, 0.7390, 0.6581, 0.6673,
        0.5793]), 0.5348557692307693)


40 Train Loss: 1.9445877333312294: 100%|██████████| 351/351 [00:20<00:00, 16.83it/s]
40 Test Loss: 1.9519456594418256: 100%|██████████| 39/39 [00:02<00:00, 15.70it/s]


(tensor([0.6451, 0.6506, 0.1268, 0.0000, 0.5779, 0.7331, 0.7390, 0.6713, 0.6707,
        0.5879]), 0.5432692307692307)


41 Train Loss: 1.9427098956203188: 100%|██████████| 351/351 [00:21<00:00, 16.10it/s]
41 Test Loss: 1.9503550560046465: 100%|██████████| 39/39 [00:02<00:00, 15.23it/s]


(tensor([0.6451, 0.6243, 0.1490, 0.0000, 0.5893, 0.7246, 0.7465, 0.6529, 0.6733,
        0.5963]), 0.5428685897435898)


42 Train Loss: 1.940182353356625: 100%|██████████| 351/351 [00:20<00:00, 17.23it/s] 
42 Test Loss: 1.9470506998208852: 100%|██████████| 39/39 [00:02<00:00, 15.25it/s]


(tensor([0.6549, 0.6371, 0.1575, 0.0000, 0.5594, 0.7295, 0.7243, 0.6607, 0.6713,
        0.5859]), 0.5408653846153846)


43 Train Loss: 1.9377496612717284: 100%|██████████| 351/351 [00:21<00:00, 16.57it/s]
43 Test Loss: 1.9435142064705873: 100%|██████████| 39/39 [00:02<00:00, 15.25it/s]


(tensor([0.6471, 0.6313, 0.1816, 0.0000, 0.5638, 0.7146, 0.7364, 0.6640, 0.6600,
        0.6081]), 0.5434695512820513)


44 Train Loss: 1.9362266104445498: 100%|██████████| 351/351 [00:20<00:00, 16.82it/s]
44 Test Loss: 1.9415667454401653: 100%|██████████| 39/39 [00:02<00:00, 15.50it/s]


(tensor([0.6385, 0.6255, 0.1898, 0.0000, 0.5615, 0.7405, 0.7149, 0.6686, 0.7058,
        0.6247]), 0.5498798076923077)


45 Train Loss: 1.9329082867358824: 100%|██████████| 351/351 [00:21<00:00, 16.41it/s]
45 Test Loss: 1.941489149362613: 100%|██████████| 39/39 [00:02<00:00, 15.85it/s] 


(tensor([0.6444, 0.6166, 0.1984, 0.0000, 0.5585, 0.7226, 0.7329, 0.6680, 0.6687,
        0.6040]), 0.5440705128205128)


46 Train Loss: 1.9298661183088253: 100%|██████████| 351/351 [00:20<00:00, 17.18it/s]
46 Test Loss: 1.9376804186747625: 100%|██████████| 39/39 [00:02<00:00, 15.85it/s]


(tensor([0.6399, 0.6448, 0.2316, 0.0000, 0.5533, 0.7092, 0.7329, 0.6686, 0.6727,
        0.6057]), 0.5486778846153846)


47 Train Loss: 1.9284715193968553: 100%|██████████| 351/351 [00:20<00:00, 17.43it/s]
47 Test Loss: 1.9373362889656653: 100%|██████████| 39/39 [00:02<00:00, 15.65it/s]


(tensor([0.6380, 0.6332, 0.2413, 0.0000, 0.5738, 0.7151, 0.7264, 0.6627, 0.6600,
        0.6032]), 0.5480769230769231)


48 Train Loss: 1.9271244646137595: 100%|██████████| 351/351 [00:20<00:00, 16.84it/s]
48 Test Loss: 1.9327324139766204: 100%|██████████| 39/39 [00:02<00:00, 15.21it/s]


(tensor([0.6431, 0.6409, 0.2449, 0.0000, 0.5257, 0.7226, 0.7364, 0.6581, 0.6899,
        0.6134]), 0.5502804487179487)


49 Train Loss: 1.925110537442047: 100%|██████████| 351/351 [00:20<00:00, 16.74it/s] 
49 Test Loss: 1.9349311743027124: 100%|██████████| 39/39 [00:02<00:00, 15.82it/s]


(tensor([0.6419, 0.6286, 0.2531, 0.0000, 0.5553, 0.7066, 0.7264, 0.6660, 0.6880,
        0.6101]), 0.5500801282051282)


50 Train Loss: 1.9224067783763266: 100%|██████████| 351/351 [00:19<00:00, 17.90it/s]
50 Test Loss: 1.9322983148770454: 100%|██████████| 39/39 [00:02<00:00, 15.89it/s]


(tensor([0.6248, 0.6344, 0.2761, 0.0000, 0.5626, 0.7026, 0.7209, 0.6627, 0.6892,
        0.6162]), 0.5514823717948718)


51 Train Loss: 1.9209982051468983: 100%|██████████| 351/351 [00:21<00:00, 16.67it/s]
51 Test Loss: 1.9266253006763947: 100%|██████████| 39/39 [00:02<00:00, 15.26it/s]


(tensor([0.6204, 0.6364, 0.2951, 0.0000, 0.5647, 0.6980, 0.7410, 0.6607, 0.6839,
        0.6073]), 0.5532852564102564)


52 Train Loss: 1.9191703694498436: 100%|██████████| 351/351 [00:20<00:00, 17.39it/s]
52 Test Loss: 1.9274686941733727: 100%|██████████| 39/39 [00:02<00:00, 14.73it/s]


(tensor([0.6299, 0.6262, 0.2889, 0.0000, 0.5524, 0.7046, 0.7309, 0.6607, 0.6899,
        0.6162]), 0.5526842948717948)


53 Train Loss: 1.9178429661992609: 100%|██████████| 351/351 [00:20<00:00, 17.10it/s]
53 Test Loss: 1.9264726516528008: 100%|██████████| 39/39 [00:02<00:00, 15.89it/s]


(tensor([0.6262, 0.6332, 0.3094, 0.0000, 0.5635, 0.6972, 0.7304, 0.6529, 0.6740,
        0.6174]), 0.5528846153846154)


54 Train Loss: 1.9159952364755832: 100%|██████████| 351/351 [00:20<00:00, 17.14it/s]
54 Test Loss: 1.9255483884077806: 100%|██████████| 39/39 [00:02<00:00, 15.21it/s]


(tensor([0.6346, 0.6185, 0.3102, 0.0000, 0.5606, 0.7080, 0.7129, 0.6509, 0.6873,
        0.6263]), 0.5534855769230769)


55 Train Loss: 1.914863045059378: 100%|██████████| 351/351 [00:20<00:00, 16.87it/s] 
55 Test Loss: 1.9222984222265391: 100%|██████████| 39/39 [00:02<00:00, 15.20it/s]


(tensor([0.6280, 0.6204, 0.3197, 0.0000, 0.5585, 0.7032, 0.7209, 0.6746, 0.6972,
        0.6174]), 0.5564903846153846)


56 Train Loss: 1.9133777241421561: 100%|██████████| 351/351 [00:20<00:00, 17.30it/s]
56 Test Loss: 1.9217374661029913: 100%|██████████| 39/39 [00:02<00:00, 14.95it/s]


(tensor([0.6235, 0.6383, 0.3429, 0.0000, 0.5656, 0.7092, 0.7298, 0.6647, 0.7038,
        0.6113]), 0.5614983974358975)


57 Train Loss: 1.9126073556747871: 100%|██████████| 351/351 [00:20<00:00, 17.14it/s]
57 Test Loss: 1.92137165252979: 100%|██████████| 39/39 [00:02<00:00, 15.18it/s]  


(tensor([0.6142, 0.6339, 0.3497, 0.0000, 0.5615, 0.7006, 0.7022, 0.6561, 0.7097,
        0.6162]), 0.5568910256410257)


58 Train Loss: 1.9109175436177486: 100%|██████████| 351/351 [00:19<00:00, 17.56it/s]
58 Test Loss: 1.9191611057672746: 100%|██████████| 39/39 [00:02<00:00, 15.21it/s]


(tensor([0.6373, 0.6185, 0.3354, 0.0000, 0.5700, 0.7092, 0.7209, 0.6561, 0.6958,
        0.6138]), 0.5580929487179487)


59 Train Loss: 1.9087564551252925: 100%|██████████| 351/351 [00:19<00:00, 17.61it/s]
59 Test Loss: 1.918330348454989: 100%|██████████| 39/39 [00:02<00:00, 15.78it/s] 


(tensor([0.6125, 0.6062, 0.3374, 0.0000, 0.5741, 0.6926, 0.7309, 0.6706, 0.7052,
        0.6356]), 0.5588942307692307)


60 Train Loss: 1.9080140247643842: 100%|██████████| 351/351 [00:20<00:00, 17.00it/s]
60 Test Loss: 1.9187564697021093: 100%|██████████| 39/39 [00:02<00:00, 14.55it/s]


(tensor([0.6216, 0.6204, 0.3497, 0.0000, 0.5606, 0.6892, 0.7149, 0.6541, 0.6799,
        0.6061]), 0.5520833333333334)


61 Train Loss: 1.9075146322576408: 100%|██████████| 351/351 [00:19<00:00, 17.60it/s]
61 Test Loss: 1.9159942284608498: 100%|██████████| 39/39 [00:02<00:00, 15.92it/s]


(tensor([0.6294, 0.6332, 0.3402, 0.0000, 0.5462, 0.7166, 0.7209, 0.6529, 0.6879,
        0.6093]), 0.5562900641025641)


62 Train Loss: 1.9065711406561046: 100%|██████████| 351/351 [00:19<00:00, 18.14it/s]
62 Test Loss: 1.915583115357619: 100%|██████████| 39/39 [00:02<00:00, 16.19it/s] 


(tensor([0.6255, 0.6390, 0.3558, 0.0000, 0.5697, 0.6886, 0.7319, 0.6502, 0.6873,
        0.6222]), 0.5592948717948718)


63 Train Loss: 1.9046325792274583: 100%|██████████| 351/351 [00:19<00:00, 17.70it/s]
63 Test Loss: 1.9135097968272674: 100%|██████████| 39/39 [00:02<00:00, 14.64it/s]


(tensor([0.6294, 0.6236, 0.3566, 0.0000, 0.5749, 0.6972, 0.7209, 0.6627, 0.6992,
        0.6397]), 0.5629006410256411)


64 Train Loss: 1.9042118132284225: 100%|██████████| 351/351 [00:20<00:00, 17.36it/s]
64 Test Loss: 1.9154520095923009: 100%|██████████| 39/39 [00:02<00:00, 15.16it/s]


(tensor([0.6301, 0.6042, 0.3517, 0.0000, 0.5635, 0.7066, 0.7264, 0.6653, 0.6886,
        0.6121]), 0.5570913461538461)


65 Train Loss: 1.9027736546986462: 100%|██████████| 351/351 [00:20<00:00, 17.35it/s]
65 Test Loss: 1.9127882627340465: 100%|██████████| 39/39 [00:02<00:00, 15.63it/s]


(tensor([0.6471, 0.6313, 0.3633, 0.0000, 0.5656, 0.7032, 0.7157, 0.6634, 0.7078,
        0.6032]), 0.5625)


66 Train Loss: 1.9020608351101562: 100%|██████████| 351/351 [00:21<00:00, 16.66it/s]
66 Test Loss: 1.9148524296589386: 100%|██████████| 39/39 [00:02<00:00, 15.90it/s]


(tensor([0.6137, 0.6050, 0.3653, 0.0000, 0.5667, 0.6833, 0.7189, 0.6587, 0.6873,
        0.6303]), 0.5552884615384616)


67 Train Loss: 1.9006231381342962: 100%|██████████| 351/351 [00:19<00:00, 17.59it/s]
67 Test Loss: 1.9148769745459924: 100%|██████████| 39/39 [00:02<00:00, 15.34it/s]


(tensor([0.6223, 0.6139, 0.3640, 0.0041, 0.5779, 0.6766, 0.7062, 0.6647, 0.6912,
        0.6126]), 0.5556891025641025)


68 Train Loss: 1.9002099624726168: 100%|██████████| 351/351 [00:20<00:00, 17.24it/s]
68 Test Loss: 1.911634277074765: 100%|██████████| 39/39 [00:02<00:00, 15.11it/s] 


(tensor([0.6412, 0.6339, 0.3571, 0.0062, 0.5782, 0.6952, 0.7183, 0.6561, 0.6833,
        0.6113]), 0.5604967948717948)


69 Train Loss: 1.901027567026622: 100%|██████████| 351/351 [00:20<00:00, 17.30it/s] 
69 Test Loss: 1.9098418248005402: 100%|██████████| 39/39 [00:02<00:00, 15.22it/s]


(tensor([0.6353, 0.6320, 0.3599, 0.0082, 0.5720, 0.6926, 0.7189, 0.6601, 0.6932,
        0.6424]), 0.5639022435897436)


70 Train Loss: 1.8984634044163586: 100%|██████████| 351/351 [00:20<00:00, 17.18it/s]
70 Test Loss: 1.9123490345783722: 100%|██████████| 39/39 [00:02<00:00, 15.99it/s]


(tensor([0.6137, 0.6146, 0.3857, 0.0062, 0.5492, 0.6952, 0.7203, 0.6495, 0.7038,
        0.6288]), 0.5590945512820513)


71 Train Loss: 1.8978470798231597: 100%|██████████| 351/351 [00:20<00:00, 17.41it/s]
71 Test Loss: 1.9059751308881319: 100%|██████████| 39/39 [00:02<00:00, 15.96it/s]


(tensor([0.6419, 0.6281, 0.3893, 0.0246, 0.5626, 0.6980, 0.7123, 0.6686, 0.7032,
        0.6215]), 0.5673076923076923)


72 Train Loss: 1.8968307836103304: 100%|██████████| 351/351 [00:20<00:00, 16.87it/s]
72 Test Loss: 1.9084189243805714: 100%|██████████| 39/39 [00:02<00:00, 15.83it/s]


(tensor([0.6181, 0.6120, 0.3906, 0.0205, 0.5708, 0.6826, 0.7324, 0.6509, 0.7078,
        0.6202]), 0.5627003205128205)


73 Train Loss: 1.8968403920149193: 100%|██████████| 351/351 [00:20<00:00, 17.46it/s]
73 Test Loss: 1.9071032939813075: 100%|██████████| 39/39 [00:02<00:00, 15.89it/s]


(tensor([0.6380, 0.6429, 0.3714, 0.0370, 0.5779, 0.6760, 0.7036, 0.6627, 0.7106,
        0.6081]), 0.5651041666666666)


74 Train Loss: 1.895102205099883: 100%|██████████| 351/351 [00:19<00:00, 17.65it/s] 
74 Test Loss: 1.9074919773982122: 100%|██████████| 39/39 [00:02<00:00, 15.34it/s]


(tensor([0.6196, 0.6100, 0.3796, 0.0617, 0.5852, 0.6873, 0.7129, 0.6509, 0.6892,
        0.6362]), 0.5653044871794872)


75 Train Loss: 1.8941765262870027: 100%|██████████| 351/351 [00:20<00:00, 17.02it/s]
75 Test Loss: 1.9052350673920069: 100%|██████████| 39/39 [00:02<00:00, 15.81it/s]


(tensor([0.6184, 0.6236, 0.3898, 0.0700, 0.5565, 0.6786, 0.7278, 0.6535, 0.6899,
        0.6283]), 0.5657051282051282)


76 Train Loss: 1.8929390126144106: 100%|██████████| 351/351 [00:20<00:00, 17.23it/s]
76 Test Loss: 1.9046272314511812: 100%|██████████| 39/39 [00:02<00:00, 15.27it/s]


(tensor([0.6169, 0.6204, 0.3551, 0.0947, 0.5799, 0.6932, 0.7062, 0.6653, 0.6972,
        0.6397]), 0.5689102564102564)


77 Train Loss: 1.892220036256687: 100%|██████████| 351/351 [00:20<00:00, 17.50it/s] 
77 Test Loss: 1.9036682477364173: 100%|██████████| 39/39 [00:02<00:00, 15.20it/s]


(tensor([0.6331, 0.6224, 0.3796, 0.1150, 0.5556, 0.6713, 0.7163, 0.6746, 0.6938,
        0.6215]), 0.5703125)


78 Train Loss: 1.891512556633039: 100%|██████████| 351/351 [00:21<00:00, 16.02it/s] 
78 Test Loss: 1.9041076195545685: 100%|██████████| 39/39 [00:02<00:00, 15.25it/s]


(tensor([0.6275, 0.6390, 0.3627, 0.1379, 0.5565, 0.6653, 0.6982, 0.6680, 0.6879,
        0.6323]), 0.5697115384615384)


79 Train Loss: 1.8898481874384432: 100%|██████████| 351/351 [00:19<00:00, 17.78it/s]
79 Test Loss: 1.9040290178396764: 100%|██████████| 39/39 [00:02<00:00, 15.58it/s]


(tensor([0.6145, 0.6262, 0.3525, 0.1520, 0.5647, 0.6687, 0.7169, 0.6522, 0.7052,
        0.6085]), 0.5681089743589743)


80 Train Loss: 1.8894090081891444: 100%|██████████| 351/351 [00:20<00:00, 17.45it/s]
80 Test Loss: 1.9034776015159411: 100%|██████████| 39/39 [00:02<00:00, 14.84it/s]


(tensor([0.6326, 0.6023, 0.3627, 0.1745, 0.5574, 0.6554, 0.6976, 0.6292, 0.6918,
        0.6316]), 0.5653044871794872)


81 Train Loss: 1.887560783967673: 100%|██████████| 351/351 [00:21<00:00, 16.06it/s] 
81 Test Loss: 1.8980306081282787: 100%|██████████| 39/39 [00:02<00:00, 15.54it/s]


(tensor([0.6341, 0.6236, 0.3620, 0.1910, 0.5585, 0.6700, 0.6962, 0.6509, 0.7078,
        0.6247]), 0.5737179487179487)


82 Train Loss: 1.8869432047900991: 100%|██████████| 351/351 [00:21<00:00, 16.71it/s]
82 Test Loss: 1.899757492236602: 100%|██████████| 39/39 [00:02<00:00, 15.99it/s] 


(tensor([0.6314, 0.6143, 0.3640, 0.2259, 0.5556, 0.6414, 0.6908, 0.6627, 0.6899,
        0.6518]), 0.5745192307692307)


83 Train Loss: 1.8862284163803795: 100%|██████████| 351/351 [00:20<00:00, 17.02it/s]
83 Test Loss: 1.8974626003167567: 100%|██████████| 39/39 [00:02<00:00, 14.17it/s]


(tensor([0.6360, 0.6390, 0.3730, 0.2202, 0.5451, 0.6607, 0.6982, 0.6522, 0.6859,
        0.6235]), 0.5753205128205128)


84 Train Loss: 1.885115478792761: 100%|██████████| 351/351 [00:20<00:00, 16.77it/s] 
84 Test Loss: 1.8988447739527776: 100%|██████████| 39/39 [00:02<00:00, 15.20it/s]


(tensor([0.6008, 0.6146, 0.3551, 0.2361, 0.5732, 0.6467, 0.6888, 0.6482, 0.6793,
        0.6288]), 0.5687099358974359)


85 Train Loss: 1.8844862476712958: 100%|██████████| 351/351 [00:20<00:00, 16.86it/s]
85 Test Loss: 1.8983141672916901: 100%|██████████| 39/39 [00:02<00:00, 16.05it/s]


(tensor([0.6118, 0.6062, 0.3571, 0.2474, 0.5861, 0.6235, 0.7163, 0.6601, 0.6873,
        0.6275]), 0.5739182692307693)


86 Train Loss: 1.8836870767452099: 100%|██████████| 351/351 [00:20<00:00, 17.52it/s]
86 Test Loss: 1.8979357419869838: 100%|██████████| 39/39 [00:02<00:00, 15.85it/s]


(tensor([0.6243, 0.6337, 0.3476, 0.2587, 0.5594, 0.6295, 0.6841, 0.6548, 0.6912,
        0.6227]), 0.5723157051282052)


87 Train Loss: 1.8825788277166862: 100%|██████████| 351/351 [00:20<00:00, 16.96it/s]
87 Test Loss: 1.8959903625341563: 100%|██████████| 39/39 [00:02<00:00, 13.86it/s]


(tensor([0.6204, 0.6267, 0.3694, 0.2769, 0.5717, 0.6228, 0.6908, 0.6588, 0.7092,
        0.6296]), 0.5793269230769231)


88 Train Loss: 1.882547658393186: 100%|██████████| 351/351 [00:21<00:00, 16.35it/s] 
88 Test Loss: 1.8943189810483883: 100%|██████████| 39/39 [00:02<00:00, 15.04it/s]


(tensor([0.6176, 0.6262, 0.3620, 0.2669, 0.5512, 0.6267, 0.6861, 0.6469, 0.7200,
        0.6235]), 0.5743189102564102)


89 Train Loss: 1.8815158007830977: 100%|██████████| 351/351 [00:21<00:00, 16.56it/s]
89 Test Loss: 1.8941174837259145: 100%|██████████| 39/39 [00:02<00:00, 14.78it/s]


(tensor([0.6189, 0.6332, 0.3673, 0.2916, 0.5453, 0.6355, 0.6807, 0.6522, 0.7106,
        0.6263]), 0.5777243589743589)


90 Train Loss: 1.8817325460265504: 100%|██████████| 351/351 [00:21<00:00, 16.71it/s]
90 Test Loss: 1.8948353559542925: 100%|██████████| 39/39 [00:02<00:00, 15.33it/s]


(tensor([0.5988, 0.6320, 0.3551, 0.2996, 0.5691, 0.5868, 0.7042, 0.6647, 0.6899,
        0.6404]), 0.5757211538461539)


91 Train Loss: 1.8813749239315674: 100%|██████████| 351/351 [00:19<00:00, 17.66it/s]
91 Test Loss: 1.8936342398325603: 100%|██████████| 39/39 [00:02<00:00, 15.28it/s]


(tensor([0.6294, 0.6262, 0.3770, 0.2957, 0.5667, 0.6387, 0.6667, 0.6482, 0.7097,
        0.6207]), 0.5795272435897436)


92 Train Loss: 1.8791835820912635: 100%|██████████| 351/351 [00:20<00:00, 16.91it/s]
92 Test Loss: 1.892681020956773: 100%|██████████| 39/39 [00:02<00:00, 14.62it/s] 


(tensor([0.6228, 0.6236, 0.3620, 0.3148, 0.5738, 0.6076, 0.6881, 0.6462, 0.6879,
        0.6316]), 0.577323717948718)


93 Train Loss: 1.8797083722899781: 100%|██████████| 351/351 [00:24<00:00, 14.37it/s]
93 Test Loss: 1.8942782145280104: 100%|██████████| 39/39 [00:03<00:00, 10.59it/s]


(tensor([0.6078, 0.6332, 0.3429, 0.3189, 0.5380, 0.6135, 0.6768, 0.6568, 0.6992,
        0.6141]), 0.5717147435897436)


94 Train Loss: 1.8785623332374117: 100%|██████████| 351/351 [00:21<00:00, 16.25it/s]
94 Test Loss: 1.894310501905588: 100%|██████████| 39/39 [00:02<00:00, 15.62it/s] 


(tensor([0.6240, 0.6429, 0.3552, 0.3080, 0.5594, 0.6056, 0.6908, 0.6542, 0.6918,
        0.6202]), 0.5769230769230769)


95 Train Loss: 1.8794634671632382: 100%|██████████| 351/351 [00:21<00:00, 16.43it/s]
95 Test Loss: 1.8945157008293347: 100%|██████████| 39/39 [00:02<00:00, 13.54it/s]


(tensor([0.6067, 0.6274, 0.3620, 0.3340, 0.5667, 0.6008, 0.6807, 0.6660, 0.6759,
        0.6316]), 0.5767227564102564)


96 Train Loss: 1.8779796321167905: 100%|██████████| 351/351 [00:23<00:00, 15.22it/s]
96 Test Loss: 1.8913808999917445: 100%|██████████| 39/39 [00:02<00:00, 13.92it/s]


(tensor([0.6164, 0.6228, 0.3742, 0.3285, 0.5576, 0.6168, 0.6787, 0.6391, 0.7026,
        0.6121]), 0.5763221153846154)


97 Train Loss: 1.8789408274865218: 100%|██████████| 351/351 [00:21<00:00, 16.08it/s]
97 Test Loss: 1.8897766027695093: 100%|██████████| 39/39 [00:02<00:00, 15.14it/s]


(tensor([0.6196, 0.6204, 0.3653, 0.3368, 0.5676, 0.6000, 0.6835, 0.6759, 0.7012,
        0.6275]), 0.5811298076923077)


98 Train Loss: 1.8773038044274701: 100%|██████████| 351/351 [00:20<00:00, 16.77it/s]
98 Test Loss: 1.8924976740127954: 100%|██████████| 39/39 [00:02<00:00, 14.11it/s]


(tensor([0.6130, 0.6158, 0.3714, 0.3450, 0.5553, 0.6040, 0.6707, 0.6297, 0.6998,
        0.6518]), 0.5769230769230769)


99 Train Loss: 1.8768601590751584: 100%|██████████| 351/351 [00:21<00:00, 16.30it/s]
99 Test Loss: 1.889027164532588: 100%|██████████| 39/39 [00:02<00:00, 15.99it/s] 


(tensor([0.6125, 0.6378, 0.3730, 0.3374, 0.5656, 0.6056, 0.6821, 0.6601, 0.6980,
        0.6242]), 0.5811298076923077)
