In [None]:
import torch
import torch.nn as nn
import torchvision
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
#import torch.utils.data
from utils import *
import numpy as np
import math
from numpy.random import default_rng
import pandas as pd
from scipy.stats import ks_2samp
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import os
import sys
from collections import namedtuple
from torch.utils.data import Dataset, DataLoader
from spacetorch.datasets import DatasetRegistry
import spacetorch.analyses.core as core
from spacetorch.utils import (
    figure_utils,
    plot_utils,
    spatial_utils,
    array_utils,
    seed_str,
)
from spacetorch.analyses.floc import get_floc_tissue
from spacetorch.datasets import floc
from spacetorch.maps import nsd_floc
from spacetorch.maps.it_map import ITMap
from spacetorch.paths import PROJ_DIR, RESULTS_DIR
#from spacetorch.constants import RNG_SEED
from spacetorch.feature_extractor import FeatureExtractor
from einops import reduce, rearrange
from spacetorch.datasets import floc
from spacetorch.maps import it_map

from utils import get_model
import argparse

from load_model import *

# Best stimulus plot

In [None]:

def compute(layers, dataloader, model):
    """Spatially averaged activations for the first pair of layers in ``layers``.

    Args:
        layers: sequence of model submodules; only ``layers[0]`` and ``layers[1]``
            are used.
        dataloader: loader handed to ``FeatureExtractor``.
        model: network the layers belong to.

    Returns:
        (avg_features, kw): ``avg_features`` is a (batch, c0 + c1) array holding the
        per-channel spatial mean of both layers, concatenated along the channel
        axis; ``kw`` is the first value returned by ``get_closest_factors`` for the
        first layer's channel count.

    Raises:
        ValueError: if ``layers`` holds fewer than two entries.
    """
    # The previous version looped over layer pairs but broke out after the first
    # one, so only the first pair was ever used. Make that explicit, and fail
    # loudly instead of hitting UnboundLocalError/IndexError when no pair exists.
    if len(layers) < 2:
        raise ValueError(f"compute() needs at least two layers, got {len(layers)}")

    # 1. Extract the features of the first two layers (inputs/labels are discarded).
    features, _, _ = FeatureExtractor(dataloader, 32).extract_features(
        model, [layers[0], layers[1]], return_inputs_and_labels=True
    )
    features_1, features_2 = features

    # 2. Average across the spatial dimensions -> (batch, channels), then stack
    #    both layers' channels side by side.
    avg_features_1 = reduce(features_1, 'b c h w -> b c', 'mean')
    avg_features_2 = reduce(features_2, 'b c h w -> b c', 'mean')
    avg_features = np.concatenate((avg_features_1, avg_features_2), axis=1)

    # 3. Factor the first layer's channel count into a (kw, kh) sheet size.
    kw, _ = get_closest_factors(avg_features_1.shape[1])

    return avg_features, kw





def plot_IT(vtc_tissues, layers_names, sel_range, save_path="V4.pdf"):
    """Draw one row of category-selectivity maps per tissue and save each as a PDF.

    For every tissue a 1x5 figure is drawn with one panel per contrast
    (Faces, Bodies, Characters, Places, Objects) using
    ``make_single_contrast_map`` with a symmetric colour range, plus a shared
    colorbar titled with the matching layer name.

    Args:
        vtc_tissues: list of tissue objects exposing ``make_single_contrast_map``.
        layers_names: colorbar titles, aligned by position with ``vtc_tissues``.
        sel_range: colour limits are ``[-sel_range, sel_range]``.
        save_path: output PDF path. With a single tissue the figure is written to
            exactly this path; with several tissues the tissue index is appended
            to the file stem (e.g. ``V4_0.pdf``, ``V4_1.pdf``) so that figures do
            not overwrite one another.
    """
    figure_utils.set_text_sizes()
    contrast_dict = {c.name: c for c in floc.DOMAIN_CONTRASTS}
    contrast_order = ["Faces", "Bodies", "Characters", "Places", "Objects"]

    stem, ext = os.path.splitext(save_path)
    multiple_figures = len(vtc_tissues) > 1

    # enumerate gives the position directly (list.index is O(n) per call and
    # returns the wrong position when the list contains equal elements).
    for idx, vtc_tissue in enumerate(vtc_tissues):
        fig, ic_row = plt.subplots(figsize=(16, 5.6), ncols=5)
        plot_utils.remove_spines(list(ic_row), to_remove="all")

        for contrast_name, ax in zip(contrast_order, ic_row):
            contrast = contrast_dict[contrast_name]
            handle = vtc_tissue.make_single_contrast_map(ax, contrast, final_psm=0.5e-2, rasterized=True, vmin=-sel_range, vmax=sel_range, linewidths=0)

            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_ylim([0, 50])  # size of the square
            ax.set_xlim([0, 25])
            plot_utils.add_scale_bar(ax, 10)

        # Shared colorbar in a new axis to the right of the panels. All panels use
        # the same vmin/vmax, so the last panel's handle represents every panel.
        cax = fig.add_axes([0.92, 0.35, 0.01, 0.35])
        cax.set_title(layers_names[idx])
        fig.colorbar(handle, cax=cax, ticks=[-sel_range, 0, sel_range])

        out_path = f"{stem}_{idx}{ext}" if multiple_figures else save_path
        fig.savefig(out_path, format='pdf')
        # Close this specific figure (not just "the current one") to free memory.
        plt.close(fig)


In [None]:
def main():
    """Show, for every unit of four ``model.layer4`` convolutions, its top ImageNet image.

    1. Load the model.
    2. Extract activations of ``layer4[0].conv1``, ``layer4[0].conv2``,
       ``layer4[1].conv1`` and ``layer4[1].conv2`` and average them over the
       spatial dimensions, giving one value per (image, channel).
    3. For every channel, pick the image with the largest mean activation and
       tile all of those images into a single grid figure.
    """
    model = load_model(pool_type='gaussian', kap_kernelsize=0.23, continuous=True, local_conv=False, expname='gaussian_0.23_continuous_prog_t', epoch=110, sel_range=10)

    # NOTE(review): ``layers`` / ``layers_names`` are not used below; the call is
    # kept so that this function's side effects are unchanged.
    layers, layers_names = load_layers_names_forcontinuous(model)

    print("Model loading completed！==============================")

    # A seeded generator keeps the shuffled image subset reproducible across runs.
    dataloader = DataLoader(
        DatasetRegistry.get("ImageNet"),
        batch_size=159,
        shuffle=True,
        num_workers=1,
        pin_memory=True,
        generator=torch.Generator().manual_seed(0),
    )

    target_layers = [
        model.layer4[0].conv1,
        model.layer4[0].conv2,
        model.layer4[1].conv1,
        model.layer4[1].conv2,
    ]
    # Inputs are requested together with the features so the argmax indices
    # below refer to the very images that produced the activations. Previously
    # the images were re-read by iterating the dataset in its own order, which
    # does not match the shuffled loader order (wrong image per unit) and pulled
    # the whole dataset into memory.
    features, inputs, _ = FeatureExtractor(dataloader, 32).extract_features(
        model, target_layers, return_inputs_and_labels=True
    )

    # Average across the spatial dimensions and put all units side by side.
    avg_features = np.concatenate(
        [reduce(layer_features, 'b c h w -> b c', 'mean') for layer_features in features],
        axis=1,
    )

    _, kh = get_closest_factors(avg_features.shape[1])

    # For every unit (column), the row index of its most-activating image.
    max_idx = np.argmax(avg_features, axis=0)
    # assumes ``inputs`` is indexable per image (numpy array or tensor) — confirm;
    # as_tensor accepts either, which make_grid requires.
    max_images = [torch.as_tensor(inputs[i]) for i in max_idx]

    # make_grid has no ``ncol`` argument: ``nrow`` is the number of images per row
    # and the row count follows from len(max_images). Guard against nrow == 0,
    # which would make make_grid divide by zero.
    img = torchvision.utils.make_grid(max_images, nrow=max(1, int(kh / 2)), padding=2)

    plt.figure(figsize=(10, 32))
    # NOTE(review): if the dataset normalises images, imshow clips values outside
    # [0, 1]; consider make_grid(..., normalize=True) — confirm.
    plt.imshow(img.permute(1, 2, 0))
    plt.axis('off')
    plt.show()
