In [1]:
from dataset import GratingDataset
from transforms import GaussianNoise
from alexnet_rnn import AlexNetRNN

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
from torchvision import models, transforms
from utils import *
import wandb

import math



In [2]:
data_transforms = transforms.Compose([
        transforms.Resize(227), # changed from 128
        transforms.ToTensor(),
        GaussianNoise(0, 0.01), # STANDARD DEVIATION OF GAUSSIAN NOISE
    ])

root_dir = './SG_train_double_sf/'
test_root_dir = './SG_test_double_sf/'

num_seqs = 1000
batch_size = 100
num_epochs = 10

num_workers = 1

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

dir_list = [
            [root_dir+'sep_10.0',0,0.05,10.0],
            [root_dir+'sep_5.0',0,0.05,5.0],
            [root_dir+'sep_2.0',0,0.05,2.0],
            [root_dir+'sep_1.0',0,0.05,1.0],
            [root_dir+'sep_0.5',0,0.05,0.5]
            ]

test_dir_list = [
                [test_root_dir+'sep_10.0',0,0.1,10.0],
                [test_root_dir+'sep_5.0',0,0.1,5.0],
                [test_root_dir+'sep_2.0',0,0.1,2.0],
                [test_root_dir+'sep_1.0',0,0.1,1.0],
                [test_root_dir+'sep_0.5',0,0.1,0.5]
                ]

In [3]:
i = 0
train_dir = dir_list[i]
train_root_dir = train_dir[0]
train_ref_ori = train_dir[1]
train_sf = train_dir[2]
train_sep = train_dir[3]

train_ref_dir = './SG_refs/' + 'REFERENCE_ref_'+str(train_ref_ori)+'_sep_0.0_contr_1_ph_0.0_sf_'+str(train_sf)+'_NONE.png'

train_grating_dataset = GratingDataset(train_root_dir, train_ref_dir, transform=data_transforms, num_seqs=num_seqs)
train_dataloader = DataLoader(train_grating_dataset, batch_size=batch_size, shuffle=True, num_workers= num_workers)


test_dir = test_dir_list[i]

test_root_dir = test_dir[0]
test_ref_ori = test_dir[1]
test_sf = test_dir[2]
test_sep = test_dir[3]
test_ref_dir = './SG_refs/'+ 'REFERENCE_ref_' + str(test_ref_ori)+'_sep_0.0_contr_1_ph_0.0_sf_'+str(test_sf)+'_NONE.png'
test_grating_dataset = GratingDataset(test_root_dir, test_ref_dir, transform=data_transforms, num_seqs=num_seqs)
test_dataloader = DataLoader(test_grating_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)


alexnet = torchvision.models.alexnet(pretrained=True)
model= AlexNetRNN()
copy_weights(model, alexnet)
model.to(device)

loss_fn = nn.BCELoss()

optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.8)




In [4]:
for step, (images, labels) in enumerate(train_dataloader):
    images, labels = images.to(device), labels.to(device)
    break



In [5]:
out = model(images)

In [6]:
out.shape

torch.Size([100, 15])

In [14]:
last = out[:,-5:]

In [8]:
labels.shape

torch.Size([100, 5])

In [15]:
loss_fn(last, labels)

tensor(0.6937, grad_fn=<BinaryCrossEntropyBackward0>)

In [16]:
labels

tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [1