In [None]:
from IPython.display import HTML, display

# Remove padding and border from empty output prompts in JupyterLab so cells that
# produce no output do not leave a visible gutter. `display` is imported explicitly
# instead of relying on the name the IPython kernel injects into the namespace.
display(HTML('<style>.p-Widget.jp-OutputPrompt.jp-OutputArea-prompt:'
             + 'empty {padding: 0; border: 0;}</style>'))

In [None]:
import os, sys, time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm as tqdm
from copy import deepcopy
from glob import glob
from PIL import Image

import torch, torchvision
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

In [None]:
import PIL.ImageFile as ImageFile

# Let PIL load image files that are truncated on disk instead of failing on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [None]:
# Make the project's model-option helpers importable. The path is relative to the
# notebook's working directory. The membership check keeps the cell idempotent, so
# re-running it does not keep appending duplicate entries to sys.path.
MODEL_OPTS_DIR = '../mouseland/model_opts'
if MODEL_OPTS_DIR not in sys.path:
    sys.path.append(MODEL_OPTS_DIR)

# NOTE(review): these star imports are the only visible source of get_model_options,
# get_recommended_transforms and get_all_feature_maps used below (and possibly of
# `sns`, which the plotting cells use but no cell imports). Explicit imports would make
# these dependencies visible; confirm which module defines each name.
from feature_extraction import *
from model_options import *
from image_ops import *

In [None]:
# Select the current CUDA device, which the argument-less .cuda() calls below use.
# The default is device 3, as before. Set CUDA_DEVICE_INDEX in the environment to use a
# different GPU on machines with another layout.
CUDA_DEVICE = int(os.environ.get('CUDA_DEVICE_INDEX', 3))
torch.cuda.set_device(CUDA_DEVICE)

In [None]:
root = 'vessel_assets/'


def list_images(root_dir, pattern='*.jpg'):
    """Return a DataFrame with one ``ImageName`` row per file in *root_dir* matching *pattern*.

    ``ImageName`` holds the bare file name (no directory component), so it can be joined
    back onto *root_dir*. This works with or without a trailing separator and on any
    platform's path separator.

    Rows are sorted by name. glob's own order is arbitrary, and a stable row order
    matters because feature-map rows are later matched to images by position.

    An empty directory yields an empty frame that still has the ``ImageName`` column.
    """
    paths = sorted(glob(os.path.join(root_dir, pattern)))
    return pd.DataFrame({'ImageName': [os.path.basename(path) for path in paths]})


image_df = list_images(root)

In [None]:
class StimulusSet(Dataset):
    """Map-style dataset that loads the images listed in a table's ``ImageName`` column.

    Args:
        csv: a pandas DataFrame, or a path (str or os.PathLike) to a CSV file, with an
            ``ImageName`` column of file names relative to *root_dir*.
        root_dir: directory containing the image files; ``~`` is expanded.
        image_transforms: optional callable applied to each RGB PIL image.

    Raises:
        TypeError: if *csv* is neither a DataFrame nor a path.
    """

    def __init__(self, csv, root_dir, image_transforms=None):
        self.root = os.path.expanduser(root_dir)
        self.transforms = image_transforms

        if isinstance(csv, pd.DataFrame):
            self.df = csv
        elif isinstance(csv, (str, os.PathLike)):
            self.df = pd.read_csv(csv)
        else:
            # Any other type would leave self.df unset and surface later as a
            # confusing AttributeError; fail here with a clear message instead.
            raise TypeError('csv must be a pandas DataFrame or a path to a CSV file, '
                            f'got {type(csv).__name__}')

        self.images = self.df.ImageName

    def __getitem__(self, index):
        """Load image number *index* as RGB and apply the transforms, if any."""
        filename = os.path.join(self.root, self.images.iloc[index])
        # convert() returns a new image, so the source file can be closed right away.
        with Image.open(filename) as source:
            img = source.convert('RGB')

        if self.transforms is not None:
            img = self.transforms(img)

        return img

    def __len__(self):
        return len(self.images)

In [None]:
class Array2DataLoader(Dataset):
    """Map-style dataset over an in-memory stack of images.

    Args:
        img_array: a numpy array whose first axis indexes images, or a path
            (str or os.PathLike) to a ``.npy`` file holding such an array.
        image_transforms: optional callable applied to each RGB PIL image.

    Raises:
        TypeError: if *img_array* is neither an ndarray nor a path.
    """

    def __init__(self, img_array, image_transforms=None):
        self.transforms = image_transforms
        if isinstance(img_array, np.ndarray):
            self.images = img_array
        elif isinstance(img_array, (str, os.PathLike)):
            self.images = np.load(img_array)
        else:
            # Any other type would leave self.images unset and fail later in __len__.
            raise TypeError('img_array must be a numpy array or a path to a .npy file, '
                            f'got {type(img_array).__name__}')

    def __getitem__(self, index):
        """Return image number *index* as an RGB PIL image, transformed if requested."""
        # NOTE(review): Image.fromarray needs a PIL-compatible dtype/shape (e.g. uint8
        # HxW or HxWx3); confirm the .npy samples are stored that way.
        img = Image.fromarray(self.images[index]).convert('RGB')
        if self.transforms is not None:
            img = self.transforms(img)
        return img

    def __len__(self):
        return self.images.shape[0]

In [None]:
model_string = 'alexnet_imagenet'

# Look up the model's metadata and its recommended preprocessing via the
# star-imported model_options helpers.
model_options = get_model_options()
image_transforms = get_recommended_transforms(model_string)
selected_options = model_options[model_string]
model_name, train_type, model_call = (
    selected_options['model_name'],
    selected_options['train_type'],
    selected_options['call'],
)

# NOTE(review): `model_call` is a Python expression string from the options table and
# is executed with eval(). That is acceptable for a trusted local config, but never for
# external input. Module.eval() switches to evaluation mode in place and returns the
# module; cuda() moves it to the current CUDA device.
model = eval(model_call)
model = model.eval().cuda()

In [None]:
# shuffle stays at its default (False), so batches follow image_df row order. The
# activity table later looks up image names by that position.
stimulus_dataset = StimulusSet(image_df, root, image_transforms)
stimulus_loader = DataLoader(stimulus_dataset, batch_size=64)

In [None]:
# Run every stimulus batch through the model and collect per-layer activations.
# NOTE(review): get_all_feature_maps comes from the star-imported helper modules. From
# its use below, it appears to return a dict keyed by layer name (e.g. 'Conv2d-1')
# whose values are torch tensors indexed by image first (numpy=False); confirm.
stimulus_features = get_all_feature_maps(model, stimulus_loader, numpy=False)

In [None]:
# Pool the ImageNet samples into one array, train images first and eval images second.
# Later cells derive each row's source from this ordering.
eval_images = np.load('../datasets/samples/imagenet_eval_sample.npy')
train_images = np.load('../datasets/samples/imagenet_train_sample.npy')
imagenet_images = np.concatenate([train_images, eval_images])

In [None]:
# Same loader setup for the ImageNet arrays. The default shuffle=False keeps rows in
# the train-then-eval order in which imagenet_images was concatenated.
imagenet_dataset = Array2DataLoader(imagenet_images, image_transforms)
imagenet_loader = DataLoader(imagenet_dataset, batch_size=64)

In [None]:
# Per-layer activations for the ImageNet sample, collected the same way as for the stimuli.
# NOTE(review): assumed to be a dict of layer name -> tensor indexed by image first; confirm.
imagenet_features = get_all_feature_maps(model, imagenet_loader, numpy=False)

In [None]:
# First conv layer's activation shapes: the ImageNet sample first, then the stimuli.
tuple(features['Conv2d-1'].shape for features in (imagenet_features, stimulus_features))

In [None]:
# Inspect a single activation: the Conv2d-1 output for the first stimulus image
# (row 0 of image_df, assuming get_all_feature_maps preserves loader order).
sample_feature_map = stimulus_features['Conv2d-1'][0]
sample_feature_map.shape

In [None]:
def treves_rolls(x):
    """Treves-Rolls sparseness of an activation pattern.

        a = (sum_i x_i / N) ** 2 / (sum_i x_i ** 2 / N)

    Both sums and N run over *every* element of ``x``; a multi-dimensional input
    (e.g. a conv feature map) is treated as one flat population. By Cauchy-Schwarz,
    a lies in [0, 1]. It is 1 when all elements are equal, and 1/N when exactly one
    element is non-zero.

    Args:
        x: a torch tensor, a numpy array, or anything ``np.asarray`` accepts.

    Returns:
        A 0-d tensor for tensor input, otherwise a numpy float64 scalar. Both support
        ``.item()``. An all-zero input gives nan (0 / 0).
    """
    # NOTE(review): the activity tables take .abs() for max/mean activity but pass raw,
    # possibly signed activations here; confirm which is intended.
    if isinstance(x, torch.Tensor):
        # mean() requires a floating dtype, so integer tensors are promoted to float64.
        values = x if x.is_floating_point() else x.double()
    else:
        values = np.asarray(x, dtype=np.float64)
    # Normalise by the total element count, not x.shape[0]. For multi-dimensional maps,
    # the first axis alone is not the population size, and using it can push the
    # result above 1.
    return values.mean() ** 2 / (values ** 2).mean()

In [None]:
# One record per (layer, stimulus image): the image's name, layer identity and depth,
# and summary statistics of that image's activation in that layer.
activity_dictlist = []
for layer_depth, layer_name in enumerate(tqdm(stimulus_features)):
    layer_maps = stimulus_features[layer_name]
    for image_idx, activity in enumerate(layer_maps):
        stimulus_name = image_df.ImageName.iloc[image_idx]
        magnitude = activity.abs()
        activity_dictlist.append(dict(
            image=stimulus_name,
            model=model_name,
            train_type=train_type,
            model_layer=layer_name,
            model_layer_index=layer_name.split('-')[1],
            model_layer_depth=layer_depth,
            max_activity=magnitude.max().item(),
            mean_activity=magnitude.mean().item(),
            sparseness=treves_rolls(activity).item(),
        ))

# Two independent frames from the same records. activity_df is reassigned by the
# ImageNet cell further down, while stim_info keeps the stimulus results.
activity_df = pd.DataFrame(activity_dictlist)
stim_info = pd.DataFrame(activity_dictlist)

In [None]:
from scipy.stats import pearsonr

# Correlate mean |activity| with sparseness across every (image, layer) row of the
# stimulus table. This reads stim_info rather than activity_df, because activity_df is
# reassigned to the ImageNet results by a later cell and would silently change meaning
# if this cell were re-run out of order.
pearsonr(stim_info.mean_activity, stim_info.sparseness)

In [None]:
# ImageNet rows are ordered train-first (imagenet_images concatenates train_images and
# then eval_images), so each image's source follows from its position.
# NOTE(review): the split point was hard-coded as 1000; this assumes it was meant to be
# the train sample size. Confirm.
n_train_images = len(train_images)

activity_dictlist = []
for map_key_i, map_key in enumerate(tqdm(imagenet_features)):
    target_map = imagenet_features[map_key]
    for target_i, target_activity in enumerate(target_map):
        # A single if/else labels every index. Separate `< 1000` and `> 1000` tests
        # skip index 1000, which then keeps the previous iteration's label.
        image_source = 'imagenet_train' if target_i < n_train_images else 'imagenet_val'
        magnitude = target_activity.abs()

        activity_dictlist.append({
            'image': target_i,
            'model': model_name,
            'train_type': train_type,
            'model_layer': map_key,
            'model_layer_index': map_key.split('-')[1],
            'model_layer_depth': map_key_i,
            'max_activity': magnitude.max().item(),
            'mean_activity': magnitude.mean().item(),
            'sparseness': treves_rolls(target_activity).item(),
            'image_source': image_source,
        })

# activity_df is reassigned here (it previously held the stimulus table); imgnet_info
# is an independent frame built from the same ImageNet records.
activity_df = pd.DataFrame(activity_dictlist)
imgnet_info = pd.DataFrame(activity_dictlist)

In [None]:
# Tag each table with where its images came from. ImageNet image ids are prefixed so
# they cannot be confused with stimulus file names once the tables are stacked.
stim_df = stim_info.copy(deep=True).assign(image_source='vessel')
imgnet_df = imgnet_info.copy(deep=True).assign(
    image=lambda frame: 'imagenet_' + frame['image'].astype('str'))

In [None]:
# Stack stimulus and ImageNet rows into one long table. ignore_index gives the result a
# fresh 0..n-1 index. Otherwise the two inputs' RangeIndexes would collide, producing
# duplicate labels that make .loc lookups ambiguous.
combo_df = pd.concat([stim_df, imgnet_df], ignore_index=True)

In [None]:
# Persist the combined table without the positional index. results/ is created first,
# because to_csv fails if the directory does not exist (e.g. on a fresh checkout).
os.makedirs('results', exist_ok=True)
combo_df.to_csv('results/alexnet_special.csv', index=False)

In [None]:
# ImageNet sparseness against layer depth, with a linear regression fit.
# NOTE(review): `sns` is not imported in any visible cell. Presumably
# `import seaborn as sns` belongs in the import cell (unless a star import supplies it).
sns.regplot(x='model_layer_depth', y='sparseness', data = imgnet_info);

In [None]:
# ImageNet mean |activity| against layer depth, with a linear regression fit.
# NOTE(review): relies on `sns` being in scope, but no visible cell imports seaborn.
sns.regplot(x='model_layer_depth', y='mean_activity', data = imgnet_info);

In [None]:
# Average ImageNet sparseness per layer; groupby sorts its keys, so rows run shallow to deep.
imgnet_info.groupby('model_layer_depth').sparseness.mean()