# Set up the environment

In [1]:
import codecs, json
import os
import subprocess

import numpy as np
import torch
import torchvision.models as tmodels
from torchvision.models import *

# Pick the models of interest, load their Places365 weights, and write a JSON file listing each model's layers (plus a file of selected layers for the ResNets)

In [13]:
# Architectures to process; a '<arch>_places365.pth.tar' checkpoint is fetched for each.
model_names = ['alexnet', 'resnet50', 'resnet18']
# Root output directory; the JSON files are written to <file_path>/layer_names/.
file_path = '/mnt/raid/ni/agnessa/RSA/'
layer_dir = os.path.join(file_path, 'layer_names')
os.makedirs(layer_dir, exist_ok=True)


def _dump_json(records, path_name):
    """Write *records* to *path_name* as indented UTF-8 JSON and close the file.

    Uses the same ``separators``/``indent`` settings for every output file so
    all layer-name files share one layout.
    """
    with open(path_name, 'w', encoding='utf-8') as f:
        json.dump(records, f, separators=(',', ':'), sort_keys=False, indent=4)


for arch in model_names:
    model_file = '%s_places365.pth.tar' % arch
    # Download the checkpoint only if it is not already in the working directory.
    # check=True makes a failed download raise here instead of at torch.load.
    if not os.path.isfile(model_file):
        weight_url = 'http://places2.csail.mit.edu/models_places365/' + model_file
        subprocess.run(['wget', weight_url], check=True)

    # Build the torchvision architecture with a 365-way classifier head.
    model = tmodels.__dict__[arch](num_classes=365)
    # Map every tensor onto the CPU, whatever device it was saved from.
    checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
    # Drop the 'module.' prefix from parameter names (presumably added by
    # nn.DataParallel when the checkpoint was trained) so the keys match `model`.
    state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
    # Loading the weights also validates that the checkpoint matches the architecture.
    model.load_state_dict(state_dict)

    # One record per module; named_modules() also yields the root module as layer ''.
    model_layers = [{'model': arch, 'layer': name} for name, _ in model.named_modules()]
    _dump_json(model_layers, os.path.join(layer_dir, arch + '.json'))

    # For ResNets, also save the subset of modules whose names contain exactly
    # one '.', i.e. direct children of a top-level module (e.g. 'layer1.0').
    # This must stay inside the branch: `selected_layers` only exists here.
    if 'res' in arch:
        selected_layers = [
            {'model': arch, 'layer': name}
            for name, _ in model.named_modules()
            if name.count('.') == 1
        ]
        _dump_json(
            selected_layers,
            os.path.join(layer_dir, arch + '_selected_layers.json'),
        )