In [None]:
import ast
import cPickle

import matplotlib.pyplot as plt
import numpy as np

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models

In [None]:
%matplotlib inline
from itertools import product
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
import os
from torchvision.datasets import ImageFolder, DatasetFolder
import torch.utils.data

## Generate Stimuli for the ImageNet Experiment

In [None]:
## Generate Stimuli

# Input/output locations. The defaults are the original cluster paths; set the
# environment variables to run the notebook elsewhere without editing code.
imagenet_path = os.environ.get('IMAGENET_PATH', '/scratch/groups/jlg/imagenet')
savedir = os.environ.get('ATT_NET_STIM_DIR', '/scratch/users/akshayj/att_net_stimuli')
if not os.path.exists(savedir):
    os.makedirs(savedir)

# Category subdirectory names under imagenet_path that objects are drawn from.
labels = ['elephant', 'leopard', 'truck', 'plane']

def generate_stimuli(nObjectsPerImg=4, savedir=savedir, nStims=100, imsz=256,
                     imagenet_path=imagenet_path, labels=labels, tilesz=128):
    """Compose grid stimuli from random ImageNet images and save them per class.

    Each stimulus is an ``imsz x imsz`` RGB canvas filled with gray (128) and
    divided into a grid of ``tilesz x tilesz`` cells. ``nObjectsPerImg``
    distinct categories (the first entries of a random permutation of
    ``labels``) are each drawn from a random file in
    ``imagenet_path/<category>/``, resized to one cell, and pasted into
    distinct random cells. The finished canvas is written as
    ``stim<nObjectsPerImg>_<k>.png`` into one folder per pasted object, named
    ``<category>_<cell index>``.

    Args:
        nObjectsPerImg: number of objects pasted into each stimulus.
        savedir: root output directory.
        nStims: number of stimuli to generate.
        imsz: side length of the square canvas, in pixels.
        imagenet_path: directory containing one subdirectory per category.
        labels: category subdirectory names to draw from.
        tilesz: side length of each grid cell / pasted object, in pixels.

    Returns:
        dict mapping stimulus number (1-based) to ``{category: cell index}``,
        where the cell index is the same one used in the folder names. The
        dict is also saved to ``savedir/meta<nObjectsPerImg>.npy``.

    Raises:
        ValueError: if more objects are requested than there are labels or
            grid cells.
    """
    # Floor division keeps the grid offsets integral (imsz/128 is a float in Py3).
    offsets = np.arange(imsz // tilesz) * tilesz
    # Row-major list of the top-left corner of every cell. Built locally so the
    # function does not depend on the global `product`, which a later cell rebinds.
    all_locs = np.array([(r, c) for r in offsets for c in offsets])
    if nObjectsPerImg > min(len(labels), len(all_locs)):
        raise ValueError(
            'nObjectsPerImg={} exceeds available labels ({}) or grid cells ({})'.format(
                nObjectsPerImg, len(labels), len(all_locs)))

    # Directory listings do not change between stimuli; list each category once.
    category_files = {}

    meta5 = dict()
    for j in range(nStims):

        if j % 500 == 0:
            print('Stimulus {}'.format(j))

        background = np.full((imsz, imsz, 3), 128, dtype=np.uint8)  # gray background
        loc_ind = np.random.choice(len(all_locs), size=(nObjectsPerImg,), replace=False)
        lbls = np.random.permutation(labels)

        meta5[j+1] = dict()

        for i in range(nObjectsPerImg):
            catg = lbls[i]
            if catg not in category_files:
                category_files[catg] = os.listdir(os.path.join(imagenet_path, catg))
            img_path = os.path.join(imagenet_path, catg, np.random.choice(category_files[catg]))
            # convert('RGB') normalizes grayscale, RGBA, CMYK and palette images to 3 channels.
            img = np.array(Image.open(img_path).convert('RGB').resize((tilesz, tilesz)))

            row, col = all_locs[loc_ind[i]]
            background[row:row + tilesz, col:col + tilesz, :] = img

            # Record which grid cell this category occupies (matches the folder name below).
            meta5[j+1][catg] = int(loc_ind[i])

        savename = 'stim{0}_{1:03d}'.format(nObjectsPerImg, j+1)

        # The same stimulus is saved once per object, under '<category>_<cell>'.
        for i in range(nObjectsPerImg):
            new_savedir = '{}/{}_{}'.format(savedir, lbls[i], loc_ind[i])
            if not os.path.exists(new_savedir):
                os.makedirs(new_savedir)
            plt.imsave('{}/{}.png'.format(new_savedir, savename), background)

    np.save(savedir + '/meta{0}'.format(nObjectsPerImg), meta5)
    return meta5

#meta5 = generate_stimuli(nStims=25000)

## Main Pipeline: Model Setup and Helper Functions

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-channel mean/std normalization applied after ToTensor.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Class-index -> label mapping. The file is assumed to hold a Python dict literal;
# literal_eval parses literals only (and raises otherwise), whereas eval would
# execute arbitrary code contained in the file.
with open('imagenet1000.txt') as f:
    num2label = ast.literal_eval(f.read())

imsize = 240
# Resize the shorter side to imsize, center-crop to imsize x imsize, convert to a
# tensor, then normalize. transforms.Resize replaces the deprecated (and later
# removed) transforms.Scale with the same semantics.
loader = transforms.Compose([transforms.Resize(imsize), transforms.CenterCrop(imsize), transforms.ToTensor(), normalize])
def image_loader(image_name):
    """Load an image file and preprocess it with ``loader``.

    The image is converted to RGB and passed through ``loader`` (resize,
    center-crop, ToTensor, normalization).

    Args:
        image_name: path to the image file.

    Returns:
        A float tensor of shape (1, 3, imsize, imsize) on the CPU. The leading
        singleton axis is a batch axis. Moving the tensor to a device is left
        to the caller.
    """
    # Read through a file object we close ourselves rather than relying on PIL
    # to release it (the same pattern torchvision's default pil_loader uses).
    # convert() forces the lazy load, so it must happen inside the `with`.
    with open(image_name, 'rb') as fh:
        rgb = Image.open(fh).convert('RGB')

    tensor = loader(rgb).float()
    # Add a leading batch axis; callers batching via DataLoader squeeze it back out.
    return tensor.unsqueeze(0)


In [None]:
def print_predictions(output, num2label=num2label, numlines=3):
    """Print the top-scoring classes for the first item of ``output``.

    Args:
        output: score tensor indexed as ``output[0][class_idx]``; only
            ``output[0]`` is used. Presumably (batch, num_classes) scores;
            confirm against callers.
        num2label: mapping from class index to a human-readable label.
        numlines: number of top classes to print; clipped to the number of
            classes available.

    Prints one ``<label> <score>`` line per class, highest score first.
    """
    # detach/cpu so this works for CUDA tensors and tensors that require grad
    # (the previous .numpy() call failed on GPU tensors).
    scores = output[0].detach().cpu()
    k = max(0, min(numlines, scores.numel()))
    top_scores, top_idx = scores.topk(k)
    for idx, score in zip(top_idx.tolist(), top_scores.tolist()):
        # Single formatted string prints identically under Python 2 and 3.
        print('{} {}'.format(num2label[idx], score))

## Training (currently a single-batch forward-pass check; no optimizer step yet)

In [None]:
from run_att_net import *

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): this differs from the `savedir` the generation cell writes to;
# confirm which directory holds the current stimuli.
stimdir = '/home/users/akshayj/att_net/imagenet_stimuli/'
# Each '<category>_<location>' subfolder is one ImageFolder class.
train_dataset = ImageFolder(stimdir, loader=image_loader)
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=50, shuffle=True)

vgg19 = models.vgg19(pretrained=True).to(device)
# Split VGG19's feature stack after its first layer so the attention gain can
# be applied to that layer's output before the rest of the stack runs.
vgg_layers = list(vgg19.features)
conv1 = nn.Sequential(*vgg_layers[:1])
convRest = nn.Sequential(*vgg_layers[1:])

model = AttMLP(128, 576).to(device)  # 576 = 24 * 24, the gain-map size used below
# Nearest-neighbour upsampling by 10 turns the 24x24 gain map into 240x240 (= imsize).
m = nn.Upsample(scale_factor=10, mode='nearest')

for step, (x, y) in enumerate(train_data_loader):
    # image_loader adds a singleton axis, so batches arrive as (B, 1, 3, H, W).
    # squeeze(1) removes only that axis; a bare squeeze() would also drop the
    # batch axis when the final batch has size 1.
    b_x = x.squeeze(1).to(device)

    # Class index -> category index -> one-hot hidden-unit initialization.
    # Assumes ImageFolder's sorted classes come in contiguous runs of 4 location
    # folders per category. `//` keeps the result integral; torch.div performs
    # true division on integer tensors in newer PyTorch, which breaks scatter_.
    hid = (y // 4).unsqueeze(1).to(device)
    hid_1hot = torch.zeros(hid.shape[0], 4, device=device)
    hid_1hot.scatter_(1, hid, 1)
    hid_1hot = hid_1hot.unsqueeze(0)

    # conv1 output
    conv1_out = conv1(b_x)

    # Output of MLP. Each item's 24x24 map becomes one channel of a single
    # "image" so Upsample can process the whole batch at once.
    output = model(b_x, hid_1hot)
    out_reshape = output.view(1, -1, 24, 24)
    # [0] (not squeeze) keeps the leading axis even when it has length 1.
    gain = m(out_reshape)[0]

    # Same gain for every conv1 channel: a broadcast view instead of 64 stacked copies.
    gain_map = gain.unsqueeze(1).expand_as(conv1_out)

    # Named `gated` rather than `product` so itertools.product stays usable
    # by the stimulus-generation code.
    gated = conv1_out * gain_map

    out = convRest(gated)
    print('{} {}'.format(gated.shape, out.shape))
    break  # debug: inspect shapes for a single batch only
