In [213]:
import torch
from torch.utils.data import TensorDataset, DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
from siamese_network import SiameseNetworkConv
from torch import nn
from training_utils import train_one_epoch, validate

In [227]:
# --- Configuration ---------------------------------------------------------
SEED = 0                   # seeds the split so it is identical after a kernel restart
SUBSELECT_TO_DEBUG = 1000  # only the first N samples of the file are used
batch_size = 4096

# --- Load ------------------------------------------------------------------
# Bind `data` once (as a tensor) instead of re-using the name for the ndarray.
data = torch.from_numpy(np.load('data/siamese_data_cleaned.npy'))
dataset = TensorDataset(data)

# --- Split -----------------------------------------------------------------
# Take the first SUBSELECT_TO_DEBUG indices, then shuffle them with a *seeded*
# generator. The previous unseeded np.random.shuffle produced a different
# train/val/test membership on every run, making accuracies incomparable.
rng = np.random.default_rng(SEED)
indices = np.arange(len(data))[:SUBSELECT_TO_DEBUG]
rng.shuffle(indices)

# 80% train / 5% validation / 15% test, using the same boundary expressions
# as before so the split sizes are unchanged.
n_selected = len(indices)
train_end = int(0.8 * n_selected)
val_end = int(0.85 * n_selected)

train_indices = indices[:train_end]
val_indices = indices[train_end:val_end]
test_indices = indices[val_end:]
assert len(train_indices) + len(val_indices) + len(test_indices) == n_selected

dataset_train = Subset(dataset, train_indices)
dataset_val = Subset(dataset, val_indices)
dataset_test = Subset(dataset, test_indices)

# --- Loaders ---------------------------------------------------------------
# NOTE(review): `collate_fn` is not defined or imported in the visible cells;
# it has to already exist in the kernel for this cell to run. Confirm where it
# comes from and define/import it above so Restart & Run All works.
# NOTE(review): the validation/test loaders also shuffle; left as-is because
# the behaviour of `collate_fn` is unknown here — confirm whether it is needed.
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
dataloader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

len(dataset_train), len(dataset_val), len(dataset_test)

(800, 50, 150)

In [228]:
# Training schedule and the per-epoch history buffers that the training cell
# appends to.
EPOCHS = 100
losses, train_accuracies, val_accuracies = [], [], []

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device {device}")

# Build the network and move its parameters onto the selected device before
# handing them to the optimizer.
net = SiameseNetworkConv()
net.to(device)

optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
criterion = nn.BCELoss()

# Report how many parameters are trainable (requires_grad=True).
print(
    "Number of params:",
    sum(p.numel() for p in filter(lambda p: p.requires_grad, net.parameters())),
)

# Accuracy of the freshly initialised network on the training loader; the bare
# expression on the last line makes the notebook display it.
initial_train_accuracy = validate(net, dataloader_train, device)
initial_train_accuracy

Using device cuda
Number of params: 1561333


array([0.48875])

In [243]:
# Train for 20 rounds of EPOCHS epochs each. After every epoch the loss and the
# train/validation accuracies are appended to the history lists created in the
# setup cell; a progress line is printed every 50th epoch of a round.
history = (losses, train_accuracies, val_accuracies)
for i in range(20):
    for epoch in range(EPOCHS):
        loss, train_accuracy = train_one_epoch(net, dataloader_train, optimizer, criterion, device)
        val_accuracy = validate(net, dataloader_val, device)

        # One append per history list, in (loss, train, val) order.
        for series, value in zip(history, (loss, train_accuracy, val_accuracy)):
            series.append(value)

        if not epoch % 50:
            print(f"Epoch {epoch+EPOCHS*i}, loss: {loss}, train accuracy: {train_accuracy}, val accuracy: {val_accuracy}")

# Figure 1: loss per epoch.
plt.plot(losses, label='loss')
plt.title('Loss')
plt.show()

# Figure 2: validation curve first, then training curve, so line colours and
# legend order stay as before.
for curve, curve_label in ((val_accuracies, 'val accuracy'), (train_accuracies, 'train accuracy')):
    plt.plot(curve, label=curve_label)
plt.title("Accuracy")
plt.legend()
plt.show()

Epoch 0, loss: 0.29327115416526794, train accuracy: [0.8875], val accuracy: [0.58]
Epoch 50, loss: 0.26142022013664246, train accuracy: [0.9], val accuracy: [0.52]
Epoch 100, loss: 0.270355224609375, train accuracy: [0.88625], val accuracy: [0.46]
Epoch 150, loss: 0.27338260412216187, train accuracy: [0.8875], val accuracy: [0.46]
Epoch 200, loss: 0.24251176416873932, train accuracy: [0.90125], val accuracy: [0.6]
Epoch 250, loss: 0.2538853883743286, train accuracy: [0.89875], val accuracy: [0.42]
Epoch 300, loss: 0.27622395753860474, train accuracy: [0.88], val accuracy: [0.68]
Epoch 350, loss: 0.26879820227622986, train accuracy: [0.8925], val accuracy: [0.62]
Epoch 400, loss: 0.26942726969718933, train accuracy: [0.87875], val accuracy: [0.48]
Epoch 450, loss: 0.28986090421676636, train accuracy: [0.87375], val accuracy: [0.5]
Epoch 500, loss: 0.27953872084617615, train accuracy: [0.8875], val accuracy: [0.46]
Epoch 550, loss: 0.26273345947265625, train accuracy: [0.8925], val accur