In [1]:
from pathlib import Path

import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

import searchnets
from searchnets.utils.dataset import VisSearchDataset

In [2]:
'''LeNet in PyTorch.'''

class LeNet(nn.Module):
    """LeNet-style CNN: two conv + max-pool stages, then three linear layers.

    Parameters
    ----------
    num_classes : int
        Number of output logits produced by ``fc3``. Default is 2.
    in_channels : int
        Number of channels in the input images. Default is 3.
    input_size : tuple of int
        ``(height, width)`` of the input images. Only used to size ``fc1``;
        the default ``(224, 224)`` gives the 16*53*53 input features that
        were previously hard-coded.

    Raises
    ------
    ValueError
        If ``input_size`` is too small to survive both conv/pool stages.
    """

    def __init__(self, num_classes=2, in_channels=3, input_size=(224, 224)):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)

        feat_h, feat_w = (self._feature_map_size(side) for side in input_size)
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f'input_size {tuple(input_size)} is too small for two '
                '5x5 conv + 2x2 max-pool stages'
            )

        self.fc1   = nn.Linear(16 * feat_h * feat_w, 120)
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, num_classes)

    @staticmethod
    def _feature_map_size(side):
        """Return the length of one spatial side after both conv/pool stages."""
        for _ in range(2):
            # unpadded 5x5 convolution removes 4 pixels, 2x2 max-pool halves (floor)
            side = (side - 4) // 2
        return side

    def forward(self, x):
        """Return the ``fc3`` outputs (one row per sample) for image batch ``x``."""
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        # flatten everything except the batch dimension before the linear layers
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

In [3]:
# Per-channel mean and standard deviation passed to ``transforms.Normalize``.
# NOTE(review): these look like the usual ImageNet statistics -- confirm they
# are the intended values for this stimulus set.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=MEAN, std=STD)

# Settings shared by every DataLoader built in this notebook.
batch_size = 64
num_workers = 32

def get_train_and_val(csv_file):
    """Build the 'train' and 'val' datasets and their loaders from ``csv_file``.

    Both datasets apply ``ToTensor`` followed by the module-level ``normalize``.
    The training loader shuffles and pins memory; the validation loader does
    neither.

    Returns
    -------
    tuple
        ``(trainset, train_loader, valset, val_loader)``
    """
    def make_transform():
        # a fresh Compose per dataset, exactly as many objects as before
        return transforms.Compose([transforms.ToTensor(), normalize])

    trainset = VisSearchDataset(csv_file=csv_file, split='train',
                                transform=make_transform())
    train_loader = DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    valset = VisSearchDataset(csv_file=csv_file, split='val',
                              transform=make_transform())
    val_loader = DataLoader(
        valset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    return trainset, train_loader, valset, val_loader

In [4]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
def train_one_epoch(model, trainset, train_loader, criterion, optimizer):
    """Train ``model`` for one full pass over ``train_loader``.

    Parameters
    ----------
    model : torch.nn.Module
        Network to optimise; switched to training mode here.
    trainset : Dataset
        No longer read (the batch count now comes from ``train_loader``);
        kept so existing callers do not need to change.
    train_loader : DataLoader
        Yields ``(batch_x, batch_y)`` pairs, which are moved to the
        module-level ``device``.
    criterion : callable
        Called as ``criterion(model(batch_x), batch_y)``; must return a
        scalar tensor.
    optimizer : torch.optim.Optimizer
        Stepped once per batch.

    Returns
    -------
    float
        Mean of the per-batch losses, or ``nan`` if the loader is empty.
    """
    model.train()

    # ask the loader how many batches it yields instead of re-deriving the
    # number from the global ``batch_size``, which the loader may not share
    batch_total = len(train_loader)
    total_loss = 0.0

    batch_pbar = tqdm(train_loader)
    for i, (batch_x, batch_y) in enumerate(batch_pbar):
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        output = model(batch_x)
        loss = criterion(output, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields a plain Python number; summing the tensor itself
        # would keep autograd history alive across the whole epoch
        loss_value = loss.item()
        batch_pbar.set_description(f'batch {i} of {batch_total}, loss: {loss_value: 7.3f}')
        total_loss += loss_value

    avg_loss = total_loss / batch_total if batch_total else float('nan')
    print(f'\tTraining Avg. Loss: {avg_loss:7.3f}')
    return avg_loss

In [6]:
def validate(model, valset, val_loader):
    """Measure classification accuracy of ``model`` on ``val_loader``.

    Accuracy is the number of correct predictions divided by the number of
    samples seen, so a smaller final batch does not skew the result (a mean
    of per-batch accuracies would over-weight it).

    Parameters
    ----------
    model : torch.nn.Module
        Network to evaluate; switched to eval mode here.
    valset : Dataset
        No longer read (the batch count now comes from ``val_loader``);
        kept so existing callers do not need to change.
    val_loader : DataLoader
        Yields ``(batch_x, batch_y)`` pairs, which are moved to the
        module-level ``device``.

    Returns
    -------
    float
        Fraction of correctly classified samples, or ``nan`` if the loader
        yields no samples.
    """
    model.eval()

    total = len(val_loader)
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        pbar = tqdm(val_loader)
        for i, (batch_x, batch_y) in enumerate(pbar):
            pbar.set_description(f'batch {i} of {total}')
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            output = model(batch_x)
            # predicted class = index of the largest output along dim 1
            # (assumes output is (batch, n_classes) -- TODO confirm)
            predicted = output.argmax(dim=1)
            n_correct += (predicted == batch_y).sum().item()
            n_seen += batch_y.size(0)

    val_acc_this_epoch = n_correct / n_seen if n_seen else float('nan')
    print(' Validation Acc: %7.3f' % val_acc_this_epoch)

    return val_acc_this_epoch

In [7]:
# Upper bound on the number of training epochs; train() iterates over
# range(1, epochs + 1).
epochs = 200
# Validation runs on every epoch for which epoch % val_epoch == 0.
val_epoch = 1
# Early-stopping budget: train() stops once the number of consecutive
# validation rounds without a new best accuracy exceeds this value.
# train() skips the early-stopping logic entirely when this is None.
patience = 10

In [8]:
def train(csv_file):
    """Train a fresh ``LeNet`` on the splits described by ``csv_file``.

    Runs for at most ``epochs`` epochs with SGD and cross-entropy loss. Every
    ``val_epoch`` epochs the model is validated. When ``patience`` is not
    None, training stops as soon as more than ``patience`` validation rounds
    in a row fail to beat the best validation accuracy seen so far.

    Returns
    -------
    tuple
        ``(model, loss)``: the model as it stands after the last epoch that
        ran (no best-checkpoint restore), and the list of per-epoch values
        returned by ``train_one_epoch``.
    """
    trainset, train_loader, valset, val_loader = get_train_and_val(csv_file)

    model = LeNet()
    model.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    loss_per_epoch = []
    best_val_acc = 0
    epochs_without_improvement = 0

    for epoch in range(1, epochs + 1):
        print(f'\nEpoch {epoch}')
        loss_per_epoch.append(
            train_one_epoch(model, trainset, train_loader, criterion, optimizer)
        )

        # not a validation epoch: nothing more to do this round
        if epoch % val_epoch != 0:
            continue
        val_acc_this_epoch = validate(model, valset, val_loader)

        # early stopping disabled
        if patience is None:
            continue

        if val_acc_this_epoch > best_val_acc:
            best_val_acc = val_acc_this_epoch
            epochs_without_improvement = 0
            print('Validation accuracy improved')
            continue

        epochs_without_improvement += 1
        if epochs_without_improvement > patience:
            print(
                f'greater than {patience} epochs without improvement in validation '
                'accuracy, stopping training')
            break

    return model, loss_per_epoch

In [9]:
def get_test(csv_file):
    """Build the 'test' dataset and its loader from ``csv_file``.

    The dataset applies ``ToTensor`` followed by the module-level
    ``normalize``. The loader keeps dataset order (no shuffling) and pins
    memory.

    Returns
    -------
    tuple
        ``(testset, test_loader)``
    """
    preprocess = transforms.Compose([transforms.ToTensor(), normalize])
    testset = VisSearchDataset(csv_file=csv_file, split='test',
                               transform=preprocess)
    test_loader = DataLoader(
        testset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    return testset, test_loader

In [10]:
def test(model, csv_file):
    """Evaluate ``model`` on the 'test' split described by ``csv_file``.

    Accuracy is the number of correct predictions divided by the number of
    samples seen, so a smaller final batch does not skew the result (a mean
    of per-batch accuracies would over-weight it).

    Parameters
    ----------
    model : torch.nn.Module
        Trained network; switched to eval mode here.
    csv_file : path-like
        Passed through to ``get_test`` to build the test set and loader.

    Returns
    -------
    float
        Fraction of correctly classified test samples, or ``nan`` if the
        loader yields no samples.
    """
    testset, test_loader = get_test(csv_file)
    model.eval()

    total = len(test_loader)
    n_correct = 0
    n_seen = 0
    pbar = tqdm(test_loader)
    with torch.no_grad():
        for i, (x_batch, y_batch) in enumerate(pbar):
            pbar.set_description(f'batch {i} of {total}')
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            output = model(x_batch)
            # predicted class = index of the largest output along dim 1
            # (assumes output is (batch, n_classes) -- TODO confirm)
            pred_batch = output.argmax(dim=1)
            n_correct += (pred_batch == y_batch).sum().item()
            n_seen += y_batch.size(0)

    acc = n_correct / n_seen if n_seen else float('nan')

    return acc

In [11]:
# Split files to run, one independent train/test cycle per entry.
# Entries start with '~'; the driver loop wraps each one in Path and calls
# expanduser() before use.
# NOTE(review): these are user-specific locations -- adjust for your checkout.
csv_files = [
    # disabled entry (commented out):
    # '~/Documents/repos/L2M/visual-search-nets/data/visual_search_stimuli/alexnet_RVvGV/alexnet_RVvGV_finetune_split.csv',
    '~/Documents/repos/L2M/visual-search-nets/data/visual_search_stimuli/alexnet_RVvRHGV/alexnet_RVvRHGV_finetune_split.csv',
    # NOTE(review): the doubled slash after '~' below looks like a typo; pathlib
    # collapses repeated separators, so it should resolve like '~/Documents/...'.
    '~//Documents/repos/L2M/visual-search-nets/data/visual_search_stimuli/alexnet_2_v_5/alexnet_2_v_5_finetune_split.csv',
]

In [12]:
# For each split file: train a fresh model, then report its test-set accuracy.
# '~' is expanded up front, so train() and test() receive a usable Path.
for csv_file in (Path(entry).expanduser() for entry in csv_files):
    model, loss_history = train(csv_file)
    acc = test(model, csv_file)
    print(f'\taccuracy on test set: {acc}')

  0%|          | 0/100 [00:00<?, ?it/s]


Epoch 1


batch 99 of 100, loss:   0.684: 100%|██████████| 100/100 [00:11<00:00, 24.33it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.693


batch 3 of 4: 100%|██████████| 4/4 [00:02<00:00,  1.49it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.609
Validation accuracy improved

Epoch 2


batch 99 of 100, loss:   0.692: 100%|██████████| 100/100 [00:10<00:00,  9.95it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.693


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.11s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.391

Epoch 3


batch 99 of 100, loss:   0.692: 100%|██████████| 100/100 [00:10<00:00,  9.44it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.691


batch 3 of 4: 100%|██████████| 4/4 [00:03<00:00,  1.31it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.391

Epoch 4


batch 99 of 100, loss:   0.678: 100%|██████████| 100/100 [00:09<00:00, 24.88it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.687


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.04s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.754
Validation accuracy improved

Epoch 5


batch 99 of 100, loss:   0.620: 100%|██████████| 100/100 [00:09<00:00, 25.08it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.658


batch 3 of 4: 100%|██████████| 4/4 [00:03<00:00,  1.11it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.652

Epoch 6


batch 99 of 100, loss:   0.080: 100%|██████████| 100/100 [00:09<00:00, 10.39it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.325


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.05s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.965
Validation accuracy improved

Epoch 7


batch 99 of 100, loss:   0.019: 100%|██████████| 100/100 [00:09<00:00, 25.00it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.030


batch 3 of 4: 100%|██████████| 4/4 [00:03<00:00,  1.28it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.996
Validation accuracy improved

Epoch 8


batch 99 of 100, loss:   0.003: 100%|██████████| 100/100 [00:09<00:00, 10.96it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.006


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.03s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   1.000
Validation accuracy improved

Epoch 9


batch 99 of 100, loss:   0.002: 100%|██████████| 100/100 [00:10<00:00,  9.91it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.002


batch 3 of 4: 100%|██████████| 4/4 [00:03<00:00,  1.29it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   1.000

Epoch 10


batch 99 of 100, loss:   0.001: 100%|██████████| 100/100 [00:10<00:00,  9.98it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.001


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.07s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   1.000

Epoch 11


batch 99 of 100, loss:   0.001: 100%|██████████| 100/100 [00:10<00:00,  9.81it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.001


batch 3 of 4: 100%|██████████| 4/4 [00:02<00:00,  1.57it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   1.000

Epoch 12


batch 99 of 100, loss:   0.001: 100%|██████████| 100/100 [00:09<00:00, 10.00it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.001


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.08s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   1.000

Epoch 13


batch 99 of 100, loss:   0.001: 100%|██████████| 100/100 [00:09<00:00, 24.73it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.001


batch 3 of 4: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   1.000

Epoch 14


batch 99 of 100, loss:   0.001: 100%|██████████| 100/100 [00:10<00:00,  9.69it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.001


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.05s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   1.000

Epoch 15


batch 99 of 100, loss:   0.001: 100%|██████████| 100/100 [00:09<00:00, 10.31it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.000


batch 3 of 4: 100%|██████████| 4/4 [00:02<00:00,  1.37it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   1.000

Epoch 16


batch 99 of 100, loss:   0.001: 100%|██████████| 100/100 [00:09<00:00, 10.47it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.000


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.07s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   1.000

Epoch 17


batch 99 of 100, loss:   0.000: 100%|██████████| 100/100 [00:09<00:00, 10.15it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.000


batch 3 of 4: 100%|██████████| 4/4 [00:02<00:00,  1.38it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   1.000

Epoch 18


batch 99 of 100, loss:   0.000: 100%|██████████| 100/100 [00:09<00:00, 23.18it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.000


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.04s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   1.000

Epoch 19


batch 99 of 100, loss:   0.000: 100%|██████████| 100/100 [00:09<00:00, 10.12it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.000


batch 3 of 4: 100%|██████████| 4/4 [00:02<00:00,  1.42it/s]
  0%|          | 0/13 [00:00<?, ?it/s]

 Validation Acc:   1.000
greater than 10 epochs without improvement in validation accuracy, stopping training


batch 12 of 13: 100%|██████████| 13/13 [00:04<00:00,  2.79it/s]


	accuracy on test set: 1.0


  0%|          | 0/100 [00:00<?, ?it/s]


Epoch 1


batch 99 of 100, loss:   0.689: 100%|██████████| 100/100 [00:09<00:00, 24.65it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.693


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.02s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.609
Validation accuracy improved

Epoch 2


batch 99 of 100, loss:   0.695: 100%|██████████| 100/100 [00:08<00:00, 11.42it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.692


batch 3 of 4: 100%|██████████| 4/4 [00:03<00:00,  2.65s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.621
Validation accuracy improved

Epoch 3


batch 99 of 100, loss:   0.694: 100%|██████████| 100/100 [00:09<00:00, 10.27it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.690


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.07s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.422

Epoch 4


batch 99 of 100, loss:   0.675: 100%|██████████| 100/100 [00:08<00:00, 11.50it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.687


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.09s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.508

Epoch 5


batch 99 of 100, loss:   0.670: 100%|██████████| 100/100 [00:09<00:00, 24.72it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.679


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.09s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.562

Epoch 6


batch 99 of 100, loss:   0.599: 100%|██████████| 100/100 [00:08<00:00, 11.53it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.656


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.10s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.738
Validation accuracy improved

Epoch 7


batch 99 of 100, loss:   0.320: 100%|██████████| 100/100 [00:09<00:00, 10.43it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.513


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.08s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.934
Validation accuracy improved

Epoch 8


batch 99 of 100, loss:   0.095: 100%|██████████| 100/100 [00:08<00:00, 11.43it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.139


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.05s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.988
Validation accuracy improved

Epoch 9


batch 99 of 100, loss:   0.010: 100%|██████████| 100/100 [00:09<00:00, 24.55it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.043


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.01s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.992
Validation accuracy improved

Epoch 10


batch 99 of 100, loss:   0.004: 100%|██████████| 100/100 [00:08<00:00, 11.20it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.012


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.10s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.996
Validation accuracy improved

Epoch 11


batch 99 of 100, loss:   0.001: 100%|██████████| 100/100 [00:10<00:00,  9.62it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.008


batch 3 of 4: 100%|██████████| 4/4 [00:03<00:00,  1.10it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.992

Epoch 12


batch 99 of 100, loss:   0.004: 100%|██████████| 100/100 [00:08<00:00, 24.52it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.014


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.08s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.992

Epoch 13


batch 99 of 100, loss:   0.037: 100%|██████████| 100/100 [00:09<00:00, 24.14it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.041


batch 3 of 4: 100%|██████████| 4/4 [00:02<00:00,  1.35it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.980

Epoch 14


batch 99 of 100, loss:   0.008: 100%|██████████| 100/100 [00:10<00:00,  9.65it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.070


batch 3 of 4: 100%|██████████| 4/4 [00:03<00:00,  2.62s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.996

Epoch 15


batch 99 of 100, loss:   0.000: 100%|██████████| 100/100 [00:10<00:00,  9.80it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.004


batch 3 of 4: 100%|██████████| 4/4 [00:02<00:00,  1.39it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.996

Epoch 16


batch 99 of 100, loss:   0.000: 100%|██████████| 100/100 [00:10<00:00, 10.00it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.001


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.12s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.996

Epoch 17


batch 99 of 100, loss:   0.002: 100%|██████████| 100/100 [00:10<00:00,  9.95it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.001


batch 3 of 4: 100%|██████████| 4/4 [00:03<00:00,  1.18it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.996

Epoch 18


batch 99 of 100, loss:   0.001: 100%|██████████| 100/100 [00:10<00:00,  9.77it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.001


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.09s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.996

Epoch 19


batch 99 of 100, loss:   0.001: 100%|██████████| 100/100 [00:09<00:00, 10.79it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.000


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.03s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.996

Epoch 20


batch 99 of 100, loss:   0.000: 100%|██████████| 100/100 [00:09<00:00, 10.11it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.000


batch 3 of 4: 100%|██████████| 4/4 [00:04<00:00,  1.09s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

 Validation Acc:   0.996

Epoch 21


batch 99 of 100, loss:   0.000: 100%|██████████| 100/100 [00:08<00:00, 11.43it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

	Training Avg. Loss:   0.000


batch 3 of 4: 100%|██████████| 4/4 [00:03<00:00,  2.70s/it]
  0%|          | 0/13 [00:00<?, ?it/s]

 Validation Acc:   0.996
greater than 10 epochs without improvement in validation accuracy, stopping training


batch 12 of 13: 100%|██████████| 13/13 [00:04<00:00,  1.39s/it]

	accuracy on test set: 0.9975961538461539



