In [1]:
from siamese_dataloader import Siamese_dataset, Siamese_dataloader
from torch import nn as nn
from torch import device, cuda, no_grad, manual_seed, cat
import torch.optim as optim
from torch.utils.data import Subset
from tqdm import tqdm
import numpy as np
import torch

RuntimeError: operator torchvision::nms does not exist

In [2]:
SEED = 2024

root="../train_xml"
# Raw data
subset = np.logical_not(np.load("data/missing_file_names_mask.npy")) # Makes losses adapt to the selected mouse_records
losses = np.load("data/resnet50_losses_final_weights.npy")[subset]
mouse_records = np.load("data/mouse_record_interpolated.npy")

# Clean data aka reasonable estimated time
estimate_times=np.load('data/sample_estimate_times.npy')
clean_subset = estimate_times[subset] < 3000
clean_subset = np.logical_and(clean_subset, estimate_times[subset] > 0)

losses = losses[clean_subset]
mouse_records = mouse_records[clean_subset]

# Train and test split
assert len(losses) == len(mouse_records)
indices = list(np.arange(len(losses)))
indices_randomized = np.random.choice(indices, int(len(indices)), replace=False)
train_indices = indices_randomized[:int(0.8*len(indices))]
test_indices = indices_randomized[int(0.8*len(indices)):]

losses_train = losses[train_indices]
mouse_records_train = mouse_records[train_indices]
dataset_train = Siamese_dataset(losses_train, mouse_records_train, seed=SEED)
dataloader_train = Siamese_dataloader(dataset_train, batch_size=4096, num_workers=8, shuffle=True).run()

losses_test = losses[test_indices]
mouse_records_test = mouse_records[test_indices]
dataset_test = Siamese_dataset(losses_test, mouse_records_test, seed=SEED)
dataloader_test = Siamese_dataloader(dataset_test, batch_size=4096, num_workers=8, shuffle=True).run()

len(dataset_train), len(dataset_test)

(872249, 218063)

In [3]:
def train_one_epoch(train_loader, net, criterion, optimizer, epoch):
    running_loss = 0.0
    for _, (inputs_0, inputs_1, labels) in enumerate(train_loader, 0):
        inputs_0, inputs_1, labels = inputs_0.to(device), inputs_1.to(device), labels.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs_0, inputs_1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss

In [4]:
class Siamese_network(nn.Module):
    # It should predict 0, 1.
    # 0 if 1st sample has higher loss
    # 1 if 2nd sample has higher loss
    def __init__(self):
        super().__init__()

        # input is 30x3 dim and output is 1 dim
        self.shared_features = nn.Sequential(
            nn.Linear(60, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.BatchNorm1d(32),
        )

        self.classification_head = nn.Sequential(
            nn.Linear(32, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.BatchNorm1d(16),
            nn.Linear(16, 1),
        )

    def forward(self, x, y):
        x = self.shared_features(x)
        y = self.shared_features(y)

        return self.classification_head(x * y)


In [5]:
cuda.is_available()

False

In [26]:
device = 'cuda:0' if cuda.is_available() else 'cpu'
print(f"Using device: {device}")

net = Siamese_network()
print(f"The number of parameters is {sum(p.numel() for p in net.parameters())}")
net.to(device)

# criterion = nn.BCEWithLogitsLoss()
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.8)

print("Training")
history = []

with no_grad():
    loss = 0
    accuracy = 0
    predictions = []
    all_labels = []
    for _, (inputs_0, inputs_1, labels) in enumerate(dataloader_test, 0):
        inputs_0, inputs_1, labels = inputs_0.to(device), inputs_1.to(device), labels.to(device)

        outputs = net(inputs_0, inputs_1)
        predictions.append(outputs)
        all_labels.append(labels)
        loss += criterion(outputs, labels).item()

    predictions = cat(predictions)
    all_labels = cat(all_labels)
preds = predictions<0
accuracy = (preds == all_labels).sum().item() / len(predictions)
print(f"Initial test accuracy is {accuracy} and loss is {loss}")

Using device: cpu
The number of parameters is 165921
Training
Initial test accuracy is 0.5002269986196649 and loss is 38.64784795045853


In [27]:
device

'cpu'

In [29]:
EPOCHS = 50
for _, epoch in enumerate(tqdm(range(EPOCHS))):
    e_loss = train_one_epoch(dataloader_train, net, criterion, optimizer, epoch)
    print(f"Epoch: {epoch}/{EPOCHS} | Loss : {e_loss}")
    history.append(e_loss)

with no_grad():
    loss = 0
    accuracy = 0
    predictions = []
    all_labels = []
    for _, (inputs_0, inputs_1, labels) in enumerate(dataloader_train, 0):
        inputs_0, inputs_1, labels = inputs_0.to(device), inputs_1.to(device), labels.to(device)
        outputs = net(inputs_0, inputs_1)
        predictions.append(outputs)
        all_labels.append(labels)
        loss += criterion(outputs, labels).item()

    predictions = cat(predictions)
    all_labels = cat(all_labels)
    preds = predictions<0
    train_accuracy = (preds == all_labels).sum().item() / len(predictions)

with no_grad():
    loss = 0
    accuracy = 0
    predictions = []
    all_labels = []
    for _, (inputs_0, inputs_1, labels) in enumerate(dataloader_test, 0):
        inputs_0, inputs_1, labels = inputs_0.to(device), inputs_1.to(device), labels.to(device)
        outputs = net(inputs_0, inputs_1)
        predictions.append(outputs)
        all_labels.append(labels)
        loss += criterion(outputs, labels).item()

    predictions = cat(predictions)
    all_labels = cat(all_labels)
    preds = predictions<0
    accuracy = (preds == all_labels).sum().item() / len(predictions)
print(f"Finished Training, final test accuracy is {accuracy} and loss is {loss}, final training accuracy is {train_accuracy}")


  2%|▏         | 1/50 [00:37<30:30, 37.35s/it]

Epoch: 0/50 | Loss : 147.65708720684052


  4%|▍         | 2/50 [01:12<28:37, 35.78s/it]

Epoch: 1/50 | Loss : 147.6504823565483


  6%|▌         | 3/50 [01:45<27:11, 34.71s/it]

Epoch: 2/50 | Loss : 147.64918982982635


  8%|▊         | 4/50 [02:18<26:13, 34.21s/it]

Epoch: 3/50 | Loss : 147.6452980041504


 10%|█         | 5/50 [02:55<26:12, 34.94s/it]

Epoch: 4/50 | Loss : 147.646508872509


 12%|█▏        | 6/50 [03:27<25:05, 34.21s/it]

Epoch: 5/50 | Loss : 147.6442995071411


 14%|█▍        | 7/50 [04:01<24:25, 34.08s/it]

Epoch: 6/50 | Loss : 147.643668115139


 16%|█▌        | 8/50 [04:35<23:41, 33.84s/it]

Epoch: 7/50 | Loss : 147.64353972673416


 18%|█▊        | 9/50 [05:06<22:32, 33.00s/it]

Epoch: 8/50 | Loss : 147.64343398809433


 20%|██        | 10/50 [05:39<22:00, 33.01s/it]

Epoch: 9/50 | Loss : 147.64296925067902


 22%|██▏       | 11/50 [06:14<21:48, 33.54s/it]

Epoch: 10/50 | Loss : 147.64259040355682


 24%|██▍       | 12/50 [06:51<22:01, 34.79s/it]

Epoch: 11/50 | Loss : 147.64292407035828


 26%|██▌       | 13/50 [07:24<21:10, 34.34s/it]

Epoch: 12/50 | Loss : 147.64332020282745


 28%|██▊       | 14/50 [08:02<21:16, 35.45s/it]

Epoch: 13/50 | Loss : 147.64412266016006


 30%|███       | 15/50 [08:38<20:38, 35.39s/it]

Epoch: 14/50 | Loss : 147.64390981197357


 32%|███▏      | 16/50 [09:15<20:21, 35.92s/it]

Epoch: 15/50 | Loss : 147.6434475183487


 34%|███▍      | 17/50 [09:53<20:04, 36.49s/it]

Epoch: 16/50 | Loss : 147.64313513040543


 36%|███▌      | 18/50 [10:22<18:13, 34.18s/it]

Epoch: 17/50 | Loss : 147.64440041780472


 38%|███▊      | 19/50 [10:58<18:05, 35.02s/it]

Epoch: 18/50 | Loss : 147.64312040805817


 40%|████      | 20/50 [11:28<16:41, 33.38s/it]

Epoch: 19/50 | Loss : 147.6428923010826


 42%|████▏     | 21/50 [12:05<16:39, 34.46s/it]

Epoch: 20/50 | Loss : 147.64375030994415


 44%|████▍     | 22/50 [12:39<15:58, 34.25s/it]

Epoch: 21/50 | Loss : 147.64381474256516


 46%|████▌     | 23/50 [13:11<15:07, 33.60s/it]

Epoch: 22/50 | Loss : 147.6432700753212


 48%|████▊     | 24/50 [13:47<14:52, 34.34s/it]

Epoch: 23/50 | Loss : 147.6428900361061


 50%|█████     | 25/50 [14:20<14:07, 33.91s/it]

Epoch: 24/50 | Loss : 147.6437286734581


 52%|█████▏    | 26/50 [14:55<13:45, 34.40s/it]

Epoch: 25/50 | Loss : 147.64362370967865


 54%|█████▍    | 27/50 [15:32<13:24, 34.96s/it]

Epoch: 26/50 | Loss : 147.64353775978088


 56%|█████▌    | 28/50 [16:07<12:52, 35.13s/it]

Epoch: 27/50 | Loss : 147.64364844560623


 58%|█████▊    | 29/50 [16:41<12:11, 34.81s/it]

Epoch: 28/50 | Loss : 147.64370465278625


 60%|██████    | 30/50 [17:16<11:37, 34.88s/it]

Epoch: 29/50 | Loss : 147.6423093676567


 62%|██████▏   | 31/50 [17:56<11:32, 36.45s/it]

Epoch: 30/50 | Loss : 147.6440269947052


 64%|██████▍   | 32/50 [18:32<10:48, 36.05s/it]

Epoch: 31/50 | Loss : 147.64213049411774


 66%|██████▌   | 33/50 [19:04<09:52, 34.85s/it]

Epoch: 32/50 | Loss : 147.6438748240471


 68%|██████▊   | 34/50 [19:32<08:46, 32.88s/it]

Epoch: 33/50 | Loss : 147.64345955848694


 70%|███████   | 35/50 [20:00<07:53, 31.59s/it]

Epoch: 34/50 | Loss : 147.6427036523819


 72%|███████▏  | 36/50 [20:30<07:12, 30.91s/it]

Epoch: 35/50 | Loss : 147.64375537633896


 74%|███████▍  | 37/50 [21:00<06:40, 30.83s/it]

Epoch: 36/50 | Loss : 147.64360857009888


 76%|███████▌  | 38/50 [21:28<05:59, 29.99s/it]

Epoch: 37/50 | Loss : 147.6430356502533


 78%|███████▊  | 39/50 [21:59<05:30, 30.05s/it]

Epoch: 38/50 | Loss : 147.64384818077087


 80%|████████  | 40/50 [22:30<05:05, 30.51s/it]

Epoch: 39/50 | Loss : 147.64365124702454


 82%|████████▏ | 41/50 [23:04<04:43, 31.48s/it]

Epoch: 40/50 | Loss : 147.64308607578278


 84%|████████▍ | 42/50 [23:36<04:12, 31.52s/it]

Epoch: 41/50 | Loss : 147.64316189289093


 86%|████████▌ | 43/50 [24:10<03:47, 32.51s/it]

Epoch: 42/50 | Loss : 147.64308667182922


 88%|████████▊ | 44/50 [24:47<03:22, 33.83s/it]

Epoch: 43/50 | Loss : 147.64346301555634


 90%|█████████ | 45/50 [25:18<02:44, 32.87s/it]

Epoch: 44/50 | Loss : 147.6427525281906


 92%|█████████▏| 46/50 [25:48<02:08, 32.09s/it]

Epoch: 45/50 | Loss : 147.64364182949066


 94%|█████████▍| 47/50 [26:18<01:34, 31.35s/it]

Epoch: 46/50 | Loss : 147.64279693365097


 96%|█████████▌| 48/50 [26:52<01:04, 32.20s/it]

Epoch: 47/50 | Loss : 147.64331084489822


 98%|█████████▊| 49/50 [27:23<00:31, 31.70s/it]

Epoch: 48/50 | Loss : 147.64277404546738


100%|██████████| 50/50 [27:52<00:00, 33.46s/it]

Epoch: 49/50 | Loss : 147.6435136795044





Finished Training, final test accuracy is 0.5009744890238143 and loss is 37.4307045340538, final training accuracy is 0.49955230673809886
