In [1]:
from siamese_dataloader import Siamese_dataset, Siamese_dataloader
from siamese_network import SiameseNetwork, SiameseNetworkToy, SiameseNetworkToyConv
from torch import nn as nn
import torch
from torch import device, cuda, no_grad, cat
import torch.optim as optim
from tqdm import tqdm
import numpy as np
from matplotlib import pyplot as plt
from torch.nn import functional as F

In [2]:
device = 'cuda:0' if cuda.is_available() else 'cpu'
print(f"Using device: {device}")

SEED = 2024

root="../train_xml"
# Raw data
subset = np.logical_not(np.load("data/missing_file_names_mask.npy")) # Makes losses adapt to the selected mouse_records
losses = np.load("data/resnet50_losses_final_weights.npy")[subset]
mouse_records = np.load("data/mouse_record_interpolated.npy")
indices_broken = np.load("data/mouse_records_broken.npy")


# Clean data aka reasonable estimated time
estimate_times=np.load('data/sample_estimate_times.npy')
clean_subset = estimate_times[subset] < 3000
clean_subset = np.logical_and(clean_subset, estimate_times[subset] > 0)
clean_subset = np.logical_and(clean_subset, np.logical_not(indices_broken))

losses = losses[clean_subset]
mouse_records = mouse_records[clean_subset]#[:,:5,:]

# Train and test split
assert len(losses) == len(mouse_records)




Using device: cuda:0


In [11]:
indices = list(np.arange(len(losses)))
indices_randomized = np.random.choice(indices, int(len(indices)), replace=False)
train_indices = indices_randomized[:int(0.75*len(indices))]
test_indices = indices_randomized[int(0.8*len(indices)):]
val_indices = indices_randomized[int(0.75*len(indices)):int(0.8*len(indices))]

BATCH_SIZE = 4
mouse_record_shape= (30,2)

train_indices = train_indices[:4] # Lets overfit
losses_train = losses[train_indices]
mouse_records_train = mouse_records[train_indices]
dataset_train = Siamese_dataset(losses_train, mouse_records_train, seed=SEED)
dataloader_train = Siamese_dataloader(dataset_train, batch_size=BATCH_SIZE, num_workers=8, shuffle=True, mouse_record_shape=mouse_record_shape).run()

test_indices=test_indices[:1000]
losses_test = losses[test_indices]
mouse_records_test = mouse_records[test_indices]
dataset_test = Siamese_dataset(losses_test, mouse_records_test, seed=SEED)
dataloader_test = Siamese_dataloader(dataset_test, batch_size=BATCH_SIZE, num_workers=8, shuffle=True, mouse_record_shape=mouse_record_shape).run()

val_indices=val_indices[:256]
losses_val = losses[val_indices]
mouse_records_val = mouse_records[val_indices]
dataset_val = Siamese_dataset(losses_val, mouse_records_val, seed=SEED)
dataloader_val = Siamese_dataloader(dataset_val, batch_size=BATCH_SIZE, num_workers=8, shuffle=True, mouse_record_shape=mouse_record_shape).run()

len(dataset_train), len(dataset_test), len(dataset_val)

(2, 500, 128)

In [22]:
all_weights = []
for _, ( inputs_0, inputs_1, labels, weights) in enumerate(dataloader_train, 0):
    inputs_0, inputs_1, labels, weights = inputs_0.to(device), inputs_1.to(device), labels.to(device), weights.to(device)
    all_weights.append(weights)
all_weights = torch.cat(all_weights)
all_weights.sum(), inputs_0[0][:3], inputs_1[0][:3], inputs_0[1][:3], inputs_1[1][:3], all_weights, labels

(tensor(1., device='cuda:0'),
 tensor([[0.3456, 0.0029],
         [0.4778, 0.3784],
         [0.4702, 0.5510]], device='cuda:0'),
 tensor([[0.6852, 0.0377],
         [0.7407, 0.6075],
         [0.0000, 0.0000]], device='cuda:0'),
 tensor([[0.8889, 0.0545],
         [0.8245, 0.1996],
         [0.8836, 0.0800]], device='cuda:0'),
 tensor([[0.7593, 0.9772],
         [0.7066, 0.8459],
         [0.6759, 0.7534]], device='cuda:0'),
 tensor([[0.],
         [1.]], device='cuda:0'),
 tensor([[0.],
         [1.]], device='cuda:0'))

In [21]:
inputs_0[0][:3], inputs_1[0][:3], '\n',inputs_0[1][:3], inputs_1[1][:3], inputs_0[2][:3], inputs_1[2][:3], inputs_0[3][:3], inputs_1[3][:3], all_weights, labels


(tensor([[0.4383, 0.0573],
         [0.4748, 0.3499],
         [0.4822, 0.3672]], device='cuda:0'),
 tensor([[0.7193, 0.9922],
         [0.7018, 0.9062],
         [0.7018, 0.9062]], device='cuda:0'),
 '\n',
 tensor([[0.8960, 0.6480],
         [0.8247, 0.6480],
         [0.7214, 0.6446]], device='cuda:0'),
 tensor([[0.8580, 0.9561],
         [0.2331, 0.4893],
         [0.1950, 0.4354]], device='cuda:0'),
 tensor([[0.7193, 0.9922],
         [0.7018, 0.9062],
         [0.7018, 0.9062]], device='cuda:0'),
 tensor([[0.4383, 0.0573],
         [0.4748, 0.3499],
         [0.4822, 0.3672]], device='cuda:0'),
 tensor([[0.8580, 0.9561],
         [0.2331, 0.4893],
         [0.1950, 0.4354]], device='cuda:0'),
 tensor([[0.8960, 0.6480],
         [0.8247, 0.6480],
         [0.7214, 0.6446]], device='cuda:0'),
 tensor([[1.],
         [0.],
         [1.],
         [0.]], device='cuda:0'),
 tensor([[1.],
         [0.],
         [0.],
         [0.]], device='cuda:0'))

In [9]:
labels

tensor([[1.],
        [0.]], device='cuda:0')

In [9]:
inputs_0[0][:3], inputs_1[0][:3], inputs_0[-1][:3], inputs_1[-1][:3], labels

(tensor([[0.1728, 0.0087],
         [0.2161, 0.0462],
         [0.2348, 0.0529]], device='cuda:0'),
 tensor([[0.7932, 0.0139],
         [0.0000, 0.0000],
         [0.5159, 0.9860]], device='cuda:0'),
 tensor([[0.7315, 1.0000],
         [0.6862, 0.8160],
         [0.6473, 0.6975]], device='cuda:0'),
 tensor([[0.4474, 0.0206],
         [0.5335, 0.2290],
         [0.5653, 0.3286]], device='cuda:0'),
 tensor([[0.],
         [1.]], device='cuda:0'))

In [10]:
i = np.random.randint(0,10000)
i, mouse_records[i]

(8216,
 array([[0.9444444 , 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
        [0.91049385, 0.6064815 ],
       

In [79]:
def train_one_epoch(train_loader, net, criterion, optimizer, epoch):
    running_loss = 0.0
    predictions = []
    all_labels = []
    for _, (inputs_0, inputs_1, labels, weights) in enumerate(train_loader, 0):
        inputs_0, inputs_1, labels, weights = inputs_0.to(device), inputs_1.to(device), labels.to(device), weights.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs_0, inputs_1)
        loss = criterion(outputs, labels, weight=weights, size_average=False)
        loss.backward()
        optimizer.step()
        weights = weights.flatten()

        predictions.append(outputs[weights==1])
        all_labels.append(labels[weights==1])
        running_loss += loss.item()
    preds = cat(predictions)>0.5
    accuracy = (preds == cat(all_labels)).sum().item() / len(cat(predictions))
    return accuracy, running_loss

def test_one_epoch(loader, net, criterion):
    with no_grad():
        loss = 0
        predictions = []
        all_labels = []
        for _, (inputs_0, inputs_1, labels, weights) in enumerate(loader, 0):
            inputs_0, inputs_1, labels = inputs_0.to(device), inputs_1.to(device), labels.to(device)
            weights = weights.flatten()
            outputs = net(inputs_0, inputs_1)
            predictions.append(outputs[weights==1])
            all_labels.append(labels[weights==1])
            loss += criterion(outputs, labels).item()

        predictions = cat(predictions)
        all_labels = cat(all_labels)
    preds = predictions > 0.5
    accuracy = (preds == all_labels).sum().item() / len(predictions)
    return accuracy, loss

In [82]:
net = SiameseNetworkToyConv()
print(f"The number of parameters is {sum(p.numel() for p in net.parameters())}")
net.to(device)

criterion = F.binary_cross_entropy
# criterion = F.hinge_embedding_loss
optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.8)

accuracy, loss = test_one_epoch(dataloader_test, net, criterion)
print(f"Initial test accuracy is {accuracy} and loss is {loss}")

The number of parameters is 103445
Initial test accuracy is 0.46551724137931033 and loss is 93.52345123887062


In [33]:
inputs_0,inputs_1,labels,weights = next(iter(dataloader_train))

In [15]:
all_weights = []
for _, ( inputs_0, inputs_1, labels, weights) in enumerate(dataloader_train, 0):
    inputs_0, inputs_1, labels, weights = inputs_0.to(device), inputs_1.to(device), labels.to(device), weights.to(device)
    all_weights.append(weights)
all_weights = torch.cat(all_weights)
all_weights.sum()

tensor(1., device='cuda:0')

In [59]:
len(dataset_train)

2

In [86]:
inputs_0,inputs_1,labels,weights = next(iter(dataloader_train))
inputs_0, inputs_1, labels, weights = inputs_0.to(device), inputs_1.to(device), labels.to(device), weights.to(device)
# zero the parameter gradients
optimizer.zero_grad()

# forward + backward + optimize
outputs = net(inputs_0, inputs_1)
preds = outputs > 0.5
loss = criterion(outputs, labels, weight=weights)
loss.backward()
optimizer.step()
outputs, loss, labels, preds, weights,( preds[weights==1]==labels[weights==1]).sum()/weights.sum(),  #inputs_0, inputs_1

(tensor([[0.4957],
         [0.4941]], device='cuda:0', grad_fn=<SigmoidBackward0>),
 tensor(0.3508, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward0>),
 tensor([[1.],
         [0.]], device='cuda:0'),
 tensor([[False],
         [False]], device='cuda:0'),
 tensor([[1.],
         [0.]], device='cuda:0'),
 tensor(0., device='cuda:0'))

In [87]:
EPOCHS = 100
training_accuracies = []
training_losses = []
val_accuracies = []
val_losses = []
for _, epoch in enumerate(tqdm(range(EPOCHS))):
    train_accuracy, loss = train_one_epoch(dataloader_train, net, criterion, optimizer, epoch)
    val_accuracy, v_loss = test_one_epoch(dataloader_val, net, criterion)
    training_accuracies.append(train_accuracy)
    training_losses.append(loss)
    val_accuracies.append(val_accuracy)
    val_losses.append(v_loss)

    print(f"Epoch: {epoch}/{EPOCHS} | Loss : {loss} | Train accuracy: {train_accuracy} | Val accuracy: {val_accuracy} | Val loss: {v_loss}")

accuracy, loss = test_one_epoch(dataloader_test, net, criterion)
print(f"Finished Training, final test accuracy is {accuracy} and loss is {loss}, final training accuracy is {train_accuracy}")

  1%|          | 1/100 [00:00<00:56,  1.74it/s]

Epoch: 0/100 | Loss : 0.29695820808410645 | Train accuracy: 1.0 | Val accuracy: 0.5416666666666666 | Val loss: 23.69816419482231


  2%|▏         | 2/100 [00:00<00:44,  2.20it/s]

Epoch: 1/100 | Loss : 0.7588624954223633 | Train accuracy: 0.0 | Val accuracy: 0.5416666666666666 | Val loss: 24.010351479053497


  3%|▎         | 3/100 [00:01<00:41,  2.36it/s]

Epoch: 2/100 | Loss : 1.9824068546295166 | Train accuracy: 0.0 | Val accuracy: 0.5 | Val loss: 24.64696303009987


  4%|▍         | 4/100 [00:01<00:39,  2.44it/s]

Epoch: 3/100 | Loss : 0.6845635175704956 | Train accuracy: 1.0 | Val accuracy: 0.5208333333333334 | Val loss: 22.838263392448425


  5%|▌         | 5/100 [00:02<00:37,  2.52it/s]

Epoch: 4/100 | Loss : 0.5094490051269531 | Train accuracy: 1.0 | Val accuracy: 0.5 | Val loss: 23.891261130571365


  6%|▌         | 6/100 [00:02<00:36,  2.56it/s]

Epoch: 5/100 | Loss : 0.5212288498878479 | Train accuracy: 1.0 | Val accuracy: 0.5833333333333334 | Val loss: 23.884972989559174


  7%|▋         | 7/100 [00:02<00:35,  2.58it/s]

Epoch: 6/100 | Loss : 0.4164869785308838 | Train accuracy: 1.0 | Val accuracy: 0.5416666666666666 | Val loss: 24.295672237873077


  8%|▊         | 8/100 [00:03<00:35,  2.59it/s]

Epoch: 7/100 | Loss : 0.730104386806488 | Train accuracy: 0.0 | Val accuracy: 0.5208333333333334 | Val loss: 23.979073524475098


  9%|▉         | 9/100 [00:03<00:34,  2.63it/s]

Epoch: 8/100 | Loss : 0.46350061893463135 | Train accuracy: 1.0 | Val accuracy: 0.5208333333333334 | Val loss: 23.934358596801758


 10%|█         | 10/100 [00:03<00:34,  2.64it/s]

Epoch: 9/100 | Loss : 1.0585498809814453 | Train accuracy: 0.0 | Val accuracy: 0.5625 | Val loss: 24.207410871982574


 11%|█         | 11/100 [00:04<00:33,  2.63it/s]

Epoch: 10/100 | Loss : 1.3617262840270996 | Train accuracy: 0.0 | Val accuracy: 0.6041666666666666 | Val loss: 22.639227211475372


 12%|█▏        | 12/100 [00:04<00:33,  2.61it/s]

Epoch: 11/100 | Loss : 0.5956093668937683 | Train accuracy: 1.0 | Val accuracy: 0.6458333333333334 | Val loss: 23.1717249751091


 13%|█▎        | 13/100 [00:05<00:33,  2.61it/s]

Epoch: 12/100 | Loss : 0.7211794853210449 | Train accuracy: 0.0 | Val accuracy: 0.5208333333333334 | Val loss: 23.121311277151108


 14%|█▍        | 14/100 [00:05<00:32,  2.63it/s]

Epoch: 13/100 | Loss : 0.35932910442352295 | Train accuracy: 1.0 | Val accuracy: 0.6041666666666666 | Val loss: 21.992282688617706


 15%|█▌        | 15/100 [00:05<00:32,  2.63it/s]

Epoch: 14/100 | Loss : 1.078942060470581 | Train accuracy: 0.0 | Val accuracy: 0.4583333333333333 | Val loss: 24.7890242934227


 16%|█▌        | 16/100 [00:06<00:32,  2.61it/s]

Epoch: 15/100 | Loss : 0.8145522475242615 | Train accuracy: 0.0 | Val accuracy: 0.5416666666666666 | Val loss: 23.280127584934235


 17%|█▋        | 17/100 [00:06<00:32,  2.59it/s]

Epoch: 16/100 | Loss : 0.5691114068031311 | Train accuracy: 1.0 | Val accuracy: 0.4375 | Val loss: 25.033662170171738


 18%|█▊        | 18/100 [00:07<00:31,  2.57it/s]

Epoch: 17/100 | Loss : 0.32650232315063477 | Train accuracy: 1.0 | Val accuracy: 0.4166666666666667 | Val loss: 23.611854285001755


 19%|█▉        | 19/100 [00:07<00:31,  2.58it/s]

Epoch: 18/100 | Loss : 0.7714357972145081 | Train accuracy: 0.0 | Val accuracy: 0.4375 | Val loss: 23.82621544599533


 20%|██        | 20/100 [00:07<00:31,  2.57it/s]

Epoch: 19/100 | Loss : 1.1355469226837158 | Train accuracy: 0.0 | Val accuracy: 0.5208333333333334 | Val loss: 25.105731964111328


 21%|██        | 21/100 [00:08<00:30,  2.56it/s]

Epoch: 20/100 | Loss : 0.7093707323074341 | Train accuracy: 0.0 | Val accuracy: 0.625 | Val loss: 23.48364034295082


 22%|██▏       | 22/100 [00:08<00:30,  2.58it/s]

Epoch: 21/100 | Loss : 0.7196922302246094 | Train accuracy: 0.0 | Val accuracy: 0.6041666666666666 | Val loss: 23.99232789874077


 23%|██▎       | 23/100 [00:09<00:29,  2.58it/s]

Epoch: 22/100 | Loss : 1.0226099491119385 | Train accuracy: 0.0 | Val accuracy: 0.5416666666666666 | Val loss: 24.880690574645996


 24%|██▍       | 24/100 [00:09<00:29,  2.58it/s]

Epoch: 23/100 | Loss : 1.5963788032531738 | Train accuracy: 0.0 | Val accuracy: 0.5208333333333334 | Val loss: 22.745480179786682


 25%|██▌       | 25/100 [00:09<00:29,  2.57it/s]

Epoch: 24/100 | Loss : 0.4576451778411865 | Train accuracy: 1.0 | Val accuracy: 0.4791666666666667 | Val loss: 24.81593358516693


 26%|██▌       | 26/100 [00:10<00:28,  2.60it/s]

Epoch: 25/100 | Loss : 0.7445958256721497 | Train accuracy: 0.0 | Val accuracy: 0.4791666666666667 | Val loss: 24.865352362394333


 27%|██▋       | 27/100 [00:10<00:28,  2.61it/s]

Epoch: 26/100 | Loss : 0.4832838773727417 | Train accuracy: 1.0 | Val accuracy: 0.6458333333333334 | Val loss: 22.65028503537178


 28%|██▊       | 28/100 [00:10<00:27,  2.60it/s]

Epoch: 27/100 | Loss : 0.7429443001747131 | Train accuracy: 0.0 | Val accuracy: 0.4791666666666667 | Val loss: 25.837442845106125


 29%|██▉       | 29/100 [00:11<00:27,  2.59it/s]

Epoch: 28/100 | Loss : 0.8060797452926636 | Train accuracy: 0.0 | Val accuracy: 0.4791666666666667 | Val loss: 23.536776304244995


 30%|███       | 30/100 [00:11<00:27,  2.58it/s]

Epoch: 29/100 | Loss : 0.33840152621269226 | Train accuracy: 1.0 | Val accuracy: 0.7083333333333334 | Val loss: 22.881423115730286


 31%|███       | 31/100 [00:12<00:26,  2.57it/s]

Epoch: 30/100 | Loss : 0.6358605027198792 | Train accuracy: 1.0 | Val accuracy: 0.5833333333333334 | Val loss: 23.115379363298416


 32%|███▏      | 32/100 [00:12<00:26,  2.59it/s]

Epoch: 31/100 | Loss : 0.9185164570808411 | Train accuracy: 0.0 | Val accuracy: 0.5416666666666666 | Val loss: 23.55798989534378


 33%|███▎      | 33/100 [00:12<00:25,  2.61it/s]

Epoch: 32/100 | Loss : 1.1312490701675415 | Train accuracy: 0.0 | Val accuracy: 0.5208333333333334 | Val loss: 24.74942782521248


 34%|███▍      | 34/100 [00:13<00:25,  2.62it/s]

Epoch: 33/100 | Loss : 0.756710410118103 | Train accuracy: 0.0 | Val accuracy: 0.5 | Val loss: 24.155605167150497


 35%|███▌      | 35/100 [00:13<00:24,  2.64it/s]

Epoch: 34/100 | Loss : 0.5568103790283203 | Train accuracy: 1.0 | Val accuracy: 0.4791666666666667 | Val loss: 22.741821259260178


 36%|███▌      | 36/100 [00:14<00:24,  2.60it/s]

Epoch: 35/100 | Loss : 0.5308065414428711 | Train accuracy: 1.0 | Val accuracy: 0.5416666666666666 | Val loss: 22.76176518201828


 37%|███▋      | 37/100 [00:14<00:24,  2.60it/s]

Epoch: 36/100 | Loss : 0.43164631724357605 | Train accuracy: 1.0 | Val accuracy: 0.5208333333333334 | Val loss: 24.550996780395508


 38%|███▊      | 38/100 [00:14<00:23,  2.60it/s]

Epoch: 37/100 | Loss : 0.3023819923400879 | Train accuracy: 1.0 | Val accuracy: 0.3958333333333333 | Val loss: 24.709106266498566


 39%|███▉      | 39/100 [00:15<00:23,  2.58it/s]

Epoch: 38/100 | Loss : 0.5478684902191162 | Train accuracy: 1.0 | Val accuracy: 0.3958333333333333 | Val loss: 24.328100562095642


 40%|████      | 40/100 [00:15<00:23,  2.59it/s]

Epoch: 39/100 | Loss : 1.3628612756729126 | Train accuracy: 0.0 | Val accuracy: 0.4375 | Val loss: 23.708253800868988


 41%|████      | 41/100 [00:15<00:22,  2.59it/s]

Epoch: 40/100 | Loss : 1.108116865158081 | Train accuracy: 0.0 | Val accuracy: 0.4166666666666667 | Val loss: 24.072037905454636


 42%|████▏     | 42/100 [00:16<00:22,  2.59it/s]

Epoch: 41/100 | Loss : 1.5090386867523193 | Train accuracy: 0.0 | Val accuracy: 0.4166666666666667 | Val loss: 25.21490377187729


 43%|████▎     | 43/100 [00:16<00:21,  2.60it/s]

Epoch: 42/100 | Loss : 1.2215244770050049 | Train accuracy: 0.0 | Val accuracy: 0.3541666666666667 | Val loss: 24.72838455438614


 44%|████▍     | 44/100 [00:17<00:21,  2.63it/s]

Epoch: 43/100 | Loss : 1.3009473085403442 | Train accuracy: 0.0 | Val accuracy: 0.4791666666666667 | Val loss: 24.332900911569595


 45%|████▌     | 45/100 [00:17<00:20,  2.63it/s]

Epoch: 44/100 | Loss : 0.9708683490753174 | Train accuracy: 0.0 | Val accuracy: 0.5 | Val loss: 24.080877780914307


 46%|████▌     | 46/100 [00:17<00:20,  2.64it/s]

Epoch: 45/100 | Loss : 0.6084373593330383 | Train accuracy: 1.0 | Val accuracy: 0.5208333333333334 | Val loss: 23.972584903240204


 47%|████▋     | 47/100 [00:18<00:20,  2.63it/s]

Epoch: 46/100 | Loss : 0.9414840340614319 | Train accuracy: 0.0 | Val accuracy: 0.625 | Val loss: 23.8705715239048


 48%|████▊     | 48/100 [00:18<00:20,  2.59it/s]

Epoch: 47/100 | Loss : 0.763165295124054 | Train accuracy: 0.0 | Val accuracy: 0.4583333333333333 | Val loss: 24.426209062337875


 49%|████▉     | 49/100 [00:19<00:19,  2.59it/s]

Epoch: 48/100 | Loss : 0.1727498471736908 | Train accuracy: 1.0 | Val accuracy: 0.4791666666666667 | Val loss: 22.395038962364197


 50%|█████     | 50/100 [00:19<00:19,  2.60it/s]

Epoch: 49/100 | Loss : 1.2776392698287964 | Train accuracy: 0.0 | Val accuracy: 0.375 | Val loss: 25.528541147708893


 51%|█████     | 51/100 [00:19<00:18,  2.60it/s]

Epoch: 50/100 | Loss : 1.429066777229309 | Train accuracy: 0.0 | Val accuracy: 0.5416666666666666 | Val loss: 23.209844410419464


 52%|█████▏    | 52/100 [00:20<00:18,  2.58it/s]

Epoch: 51/100 | Loss : 0.7322350740432739 | Train accuracy: 0.0 | Val accuracy: 0.4583333333333333 | Val loss: 22.94877764582634


 53%|█████▎    | 53/100 [00:20<00:18,  2.61it/s]

Epoch: 52/100 | Loss : 0.610336184501648 | Train accuracy: 1.0 | Val accuracy: 0.5 | Val loss: 23.464515328407288


 54%|█████▍    | 54/100 [00:20<00:17,  2.63it/s]

Epoch: 53/100 | Loss : 0.9478814601898193 | Train accuracy: 0.0 | Val accuracy: 0.6041666666666666 | Val loss: 24.02861052751541


 55%|█████▌    | 55/100 [00:21<00:17,  2.60it/s]

Epoch: 54/100 | Loss : 1.3196452856063843 | Train accuracy: 0.0 | Val accuracy: 0.3958333333333333 | Val loss: 24.456019192934036


 56%|█████▌    | 56/100 [00:21<00:16,  2.60it/s]

Epoch: 55/100 | Loss : 0.655147910118103 | Train accuracy: 1.0 | Val accuracy: 0.625 | Val loss: 22.74320113658905


 57%|█████▋    | 57/100 [00:22<00:16,  2.59it/s]

Epoch: 56/100 | Loss : 0.7279300689697266 | Train accuracy: 0.0 | Val accuracy: 0.5 | Val loss: 23.708566814661026


 58%|█████▊    | 58/100 [00:22<00:16,  2.59it/s]

Epoch: 57/100 | Loss : 0.9236322641372681 | Train accuracy: 0.0 | Val accuracy: 0.6875 | Val loss: 23.11312273144722


 59%|█████▉    | 59/100 [00:22<00:15,  2.60it/s]

Epoch: 58/100 | Loss : 0.6786673665046692 | Train accuracy: 1.0 | Val accuracy: 0.5416666666666666 | Val loss: 24.97857066988945


 60%|██████    | 60/100 [00:23<00:15,  2.60it/s]

Epoch: 59/100 | Loss : 1.084267258644104 | Train accuracy: 0.0 | Val accuracy: 0.4583333333333333 | Val loss: 23.95868620276451


 61%|██████    | 61/100 [00:23<00:15,  2.48it/s]

Epoch: 60/100 | Loss : 0.6166558861732483 | Train accuracy: 1.0 | Val accuracy: 0.5416666666666666 | Val loss: 22.357433915138245


 62%|██████▏   | 62/100 [00:24<00:15,  2.49it/s]

Epoch: 61/100 | Loss : 0.3564213216304779 | Train accuracy: 1.0 | Val accuracy: 0.5625 | Val loss: 23.378176152706146


 63%|██████▎   | 63/100 [00:24<00:14,  2.57it/s]

Epoch: 62/100 | Loss : 0.39919641613960266 | Train accuracy: 1.0 | Val accuracy: 0.5416666666666666 | Val loss: 23.89143082499504


 64%|██████▍   | 64/100 [00:24<00:13,  2.57it/s]

Epoch: 63/100 | Loss : 0.4153136610984802 | Train accuracy: 1.0 | Val accuracy: 0.4583333333333333 | Val loss: 24.233053654432297


 65%|██████▌   | 65/100 [00:25<00:13,  2.57it/s]

Epoch: 64/100 | Loss : 0.5829087495803833 | Train accuracy: 1.0 | Val accuracy: 0.4791666666666667 | Val loss: 25.530151784420013


 66%|██████▌   | 66/100 [00:25<00:13,  2.57it/s]

Epoch: 65/100 | Loss : 0.4449027478694916 | Train accuracy: 1.0 | Val accuracy: 0.3958333333333333 | Val loss: 23.926257073879242


 67%|██████▋   | 67/100 [00:25<00:12,  2.57it/s]

Epoch: 66/100 | Loss : 0.6710909008979797 | Train accuracy: 1.0 | Val accuracy: 0.4791666666666667 | Val loss: 23.86031973361969


 68%|██████▊   | 68/100 [00:26<00:12,  2.57it/s]

Epoch: 67/100 | Loss : 0.3674764037132263 | Train accuracy: 1.0 | Val accuracy: 0.4375 | Val loss: 24.41035658121109


 69%|██████▉   | 69/100 [00:26<00:12,  2.57it/s]

Epoch: 68/100 | Loss : 1.1939575672149658 | Train accuracy: 0.0 | Val accuracy: 0.4791666666666667 | Val loss: 24.901640325784683


 70%|███████   | 70/100 [00:27<00:11,  2.57it/s]

Epoch: 69/100 | Loss : 0.594192385673523 | Train accuracy: 1.0 | Val accuracy: 0.5833333333333334 | Val loss: 22.894330084323883


 71%|███████   | 71/100 [00:27<00:11,  2.56it/s]

Epoch: 70/100 | Loss : 0.7443143129348755 | Train accuracy: 0.0 | Val accuracy: 0.5 | Val loss: 23.10138291120529


 72%|███████▏  | 72/100 [00:27<00:10,  2.57it/s]

Epoch: 71/100 | Loss : 0.43840426206588745 | Train accuracy: 1.0 | Val accuracy: 0.5 | Val loss: 23.585577368736267


 73%|███████▎  | 73/100 [00:28<00:10,  2.56it/s]

Epoch: 72/100 | Loss : 0.3477882742881775 | Train accuracy: 1.0 | Val accuracy: 0.4791666666666667 | Val loss: 22.71367996931076


 74%|███████▍  | 74/100 [00:28<00:10,  2.59it/s]

Epoch: 73/100 | Loss : 1.0225509405136108 | Train accuracy: 0.0 | Val accuracy: 0.4166666666666667 | Val loss: 25.68341302871704


 75%|███████▌  | 75/100 [00:29<00:09,  2.64it/s]

Epoch: 74/100 | Loss : 1.326697587966919 | Train accuracy: 0.0 | Val accuracy: 0.4583333333333333 | Val loss: 23.05307701230049


 76%|███████▌  | 76/100 [00:29<00:09,  2.66it/s]

Epoch: 75/100 | Loss : 1.4091891050338745 | Train accuracy: 0.0 | Val accuracy: 0.625 | Val loss: 23.70604172348976


 77%|███████▋  | 77/100 [00:29<00:08,  2.61it/s]

Epoch: 76/100 | Loss : 0.8349019885063171 | Train accuracy: 0.0 | Val accuracy: 0.5833333333333334 | Val loss: 24.01616421341896


 78%|███████▊  | 78/100 [00:30<00:08,  2.60it/s]

Epoch: 77/100 | Loss : 0.719139575958252 | Train accuracy: 0.0 | Val accuracy: 0.4375 | Val loss: 24.94396686553955


 79%|███████▉  | 79/100 [00:30<00:08,  2.59it/s]

Epoch: 78/100 | Loss : 0.664797306060791 | Train accuracy: 1.0 | Val accuracy: 0.5833333333333334 | Val loss: 22.92532452940941


 80%|████████  | 80/100 [00:31<00:07,  2.56it/s]

Epoch: 79/100 | Loss : 0.49110424518585205 | Train accuracy: 1.0 | Val accuracy: 0.4583333333333333 | Val loss: 25.190729022026062


 81%|████████  | 81/100 [00:31<00:07,  2.60it/s]

Epoch: 80/100 | Loss : 0.9648751020431519 | Train accuracy: 0.0 | Val accuracy: 0.4583333333333333 | Val loss: 23.976810306310654


 82%|████████▏ | 82/100 [00:31<00:06,  2.59it/s]

Epoch: 81/100 | Loss : 0.4847845137119293 | Train accuracy: 1.0 | Val accuracy: 0.5416666666666666 | Val loss: 24.44339257478714


 83%|████████▎ | 83/100 [00:32<00:06,  2.58it/s]

Epoch: 82/100 | Loss : 0.8053840398788452 | Train accuracy: 0.0 | Val accuracy: 0.5208333333333334 | Val loss: 25.021875321865082


 84%|████████▍ | 84/100 [00:32<00:06,  2.57it/s]

Epoch: 83/100 | Loss : 0.8502234220504761 | Train accuracy: 0.0 | Val accuracy: 0.3333333333333333 | Val loss: 25.854250073432922


 85%|████████▌ | 85/100 [00:32<00:05,  2.56it/s]

Epoch: 84/100 | Loss : 0.5121937394142151 | Train accuracy: 1.0 | Val accuracy: 0.5208333333333334 | Val loss: 22.369323253631592


 86%|████████▌ | 86/100 [00:33<00:05,  2.56it/s]

Epoch: 85/100 | Loss : 0.4642184376716614 | Train accuracy: 1.0 | Val accuracy: 0.4375 | Val loss: 25.031331211328506


 87%|████████▋ | 87/100 [00:33<00:05,  2.55it/s]

Epoch: 86/100 | Loss : 0.23505692183971405 | Train accuracy: 1.0 | Val accuracy: 0.5 | Val loss: 22.768192529678345


 88%|████████▊ | 88/100 [00:34<00:04,  2.55it/s]

Epoch: 87/100 | Loss : 1.2013658285140991 | Train accuracy: 0.0 | Val accuracy: 0.4166666666666667 | Val loss: 23.96453756093979


 89%|████████▉ | 89/100 [00:34<00:04,  2.55it/s]

Epoch: 88/100 | Loss : 0.5063012838363647 | Train accuracy: 1.0 | Val accuracy: 0.6041666666666666 | Val loss: 22.684883892536163


 90%|█████████ | 90/100 [00:34<00:03,  2.56it/s]

Epoch: 89/100 | Loss : 0.7176575064659119 | Train accuracy: 0.0 | Val accuracy: 0.3541666666666667 | Val loss: 25.50743854045868


 91%|█████████ | 91/100 [00:35<00:03,  2.55it/s]

Epoch: 90/100 | Loss : 0.9860779047012329 | Train accuracy: 0.0 | Val accuracy: 0.5 | Val loss: 24.58515852689743


 92%|█████████▏| 92/100 [00:35<00:03,  2.56it/s]

Epoch: 91/100 | Loss : 0.6149012446403503 | Train accuracy: 1.0 | Val accuracy: 0.5 | Val loss: 22.810930132865906


 93%|█████████▎| 93/100 [00:36<00:02,  2.56it/s]

Epoch: 92/100 | Loss : 1.0905163288116455 | Train accuracy: 0.0 | Val accuracy: 0.4791666666666667 | Val loss: 25.101469844579697


 94%|█████████▍| 94/100 [00:36<00:02,  2.55it/s]

Epoch: 93/100 | Loss : 0.5663942098617554 | Train accuracy: 1.0 | Val accuracy: 0.6041666666666666 | Val loss: 23.70565155148506


 95%|█████████▌| 95/100 [00:36<00:01,  2.57it/s]

Epoch: 94/100 | Loss : 0.9242264628410339 | Train accuracy: 0.0 | Val accuracy: 0.5625 | Val loss: 22.90398469567299


 96%|█████████▌| 96/100 [00:37<00:01,  2.58it/s]

Epoch: 95/100 | Loss : 0.8082968592643738 | Train accuracy: 0.0 | Val accuracy: 0.4583333333333333 | Val loss: 24.205366671085358


 97%|█████████▋| 97/100 [00:37<00:01,  2.63it/s]

Epoch: 96/100 | Loss : 1.1786061525344849 | Train accuracy: 0.0 | Val accuracy: 0.5416666666666666 | Val loss: 24.699366748332977


 98%|█████████▊| 98/100 [00:37<00:00,  2.63it/s]

Epoch: 97/100 | Loss : 0.4083280563354492 | Train accuracy: 1.0 | Val accuracy: 0.4583333333333333 | Val loss: 26.836972951889038


 99%|█████████▉| 99/100 [00:38<00:00,  2.60it/s]

Epoch: 98/100 | Loss : 0.7054397463798523 | Train accuracy: 0.0 | Val accuracy: 0.4791666666666667 | Val loss: 22.847467362880707


100%|██████████| 100/100 [00:38<00:00,  2.58it/s]

Epoch: 99/100 | Loss : 0.2383473664522171 | Train accuracy: 1.0 | Val accuracy: 0.6041666666666666 | Val loss: 25.218609541654587





Finished Training, final test accuracy is 0.5172413793103449 and loss is 95.3354980647564, final training accuracy is 1.0
