In [54]:
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms

import searchnets
from searchnets.utils.dataset import VisSearchDataset

In [2]:
model = searchnets.nets.alexnet.build(pretrained=True)

In [30]:
activations = {}

def hook_fn(module, inp, out):
    activations[module] = out

def register_forward_hooks(net):
    modules = [module for module in net.modules() if type(module) != nn.Sequential and type(module) != type(net)]
    for module in modules:
        module.register_forward_hook(hook_fn)

In [32]:
register_forward_hooks(model)

In [35]:
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

In [37]:
normalize = transforms.Normalize(mean=MEAN,
                                 std=STD)

In [38]:
batch_size = 64
num_workers = 32

In [25]:
def get_vis_search_activations(csv_file, model, layer, batch_size, num_workers):
    trainset = VisSearchDataset(csv_file=csv_file,
                            split='train',
                            transform=transforms.Compose(
                                [transforms.ToTensor(), normalize]
                            ))

    train_loader = DataLoader(trainset, batch_size=batch_size,
                              shuffle=True, num_workers=num_workers,
                              pin_memory=True)

    model.eval()
    for batch_x, batch_y in train_loader:

searchnets.nets.alexnet.AlexNet

In [49]:
csv_file = Path('~/Documents/repos/L2M/visual-search-nets/data/visual_search_stimuli/alexnet_RVvGV/alexnet_RVvGV_finetune_split.csv')
csv_file = csv_file.expanduser()

In [52]:
trainset = VisSearchDataset(csv_file=csv_file,
                        split='test',
                        transform=transforms.Compose(
                            [transforms.ToTensor(), normalize]
                        ))

train_loader = DataLoader(trainset, batch_size=batch_size,
                          shuffle=True, num_workers=num_workers,
                          pin_memory=True)

In [55]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [57]:
model.to(device)

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [59]:
batch_x, batch_y = next(iter(train_loader))
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
output = model(batch_x)

In [69]:
for k, v in activations.items():
    print("mean: ", v.detach().cpu().numpy().mean(), "std: ", v.detach().cpu().numpy().std())

mean:  0.29947537 std:  1.3874729
mean:  0.29947537 std:  1.3874729
mean:  0.5531618 std:  2.3343318
mean:  0.8639849 std:  3.3647711
mean:  0.8639849 std:  3.3647711
mean:  2.4694996 std:  6.4624257
mean:  0.666239 std:  3.2622116
mean:  0.666239 std:  3.2622116
mean:  0.5553713 std:  2.2769845
mean:  0.5553713 std:  2.2769845
mean:  0.153945 std:  1.0107583
mean:  0.153945 std:  1.0107583
mean:  0.4688521 std:  1.8375064
mean:  0.4688521 std:  1.8375064
mean:  0.46984673 std:  2.6531706
mean:  1.2925751 std:  2.5875597
mean:  1.2925751 std:  2.5875597
mean:  1.2871842 std:  3.859484
mean:  0.32277316 std:  1.4871278
mean:  0.32277316 std:  1.4871278
mean:  -0.00092787744 std:  4.0739326
