In [17]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import torch.multiprocessing
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor
from PIL import Image
from tqdm import tqdm
import os
import natsort 
import pickle
from pathlib import Path
from tqdm import tqdm

In [18]:
# Extract activations from seven intermediate layers of AlexNet and VGG16,
# each both randomly initialised and ImageNet-pretrained.
# Inputs: every .png stimulus in ../data/{input_imsize}_sq.
# Outputs: one pickled dict {layer_name: tensor} per (network, trained) pair,
# written to ../results/data/net_resp/.

# --- Configuration ----------------------------------------------------------
opt_zscore = True        # NOTE(review): not used in this cell; presumably read by a later cell
sub_background = True    # NOTE(review): not used in this cell; presumably read by a later cell
networks = ['alexnet', 'vgg16']
untrained_nets = [True, False]
input_imsize = 500  # 227, 500 -- selects the source folder and tags the output file name
seed = 0            # seeds the random initialisation of the untrained networks

# Per-network settings:
#   - the pretrained weights enum;
#   - the indices, into the eval-mode graph node list from get_graph_node_names,
#     of the seven layers whose outputs are extracted.
net_configs = {
    'alexnet': (models.AlexNet_Weights.IMAGENET1K_V1, [2, 5, 8, 10, 12, 18, 21]),
    'vgg16': (models.VGG16_Weights.IMAGENET1K_V1, [4, 9, 16, 23, 30, 35, 38]),
}
layers = ['Layer1', 'Layer2', 'Layer3', 'Layer4', 'Layer5', 'Layer6', 'Layer7']

# Fail fast instead of silently reusing the previous network's weights/indices.
missing_configs = [name for name in networks if name not in net_configs]
if missing_configs:
    raise ValueError(f'No layer configuration for network(s): {missing_configs}')

# --- Load stimuli once (independent of the network being probed) ------------
data_dir = f'../data/{input_imsize}_sq'
result_dir = Path('../results/')
filenames = natsort.natsorted([name for name in os.listdir(data_dir) if name.endswith('.png')])
if not filenames:
    raise FileNotFoundError(f'No .png stimuli found in {data_dir}')
# splitext keeps interior dots (e.g. "a.b.png" -> "a.b"), unlike split('.')[0].
fname_wo_ext = [os.path.splitext(fname)[0] for fname in filenames]

# Every stimulus is resized and center-cropped to 227x227 whatever input_imsize is,
# then normalised with the ImageNet mean/std (also for the untrained nets).
transform = transforms.Compose([
    transforms.Resize(227),
    transforms.CenterCrop(227),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

batch_size = len(filenames)
batch = torch.zeros(batch_size, 3, 227, 227)
for i, filename in enumerate(filenames):
    # Context manager closes the underlying file handle after decoding.
    with Image.open(os.path.join(data_dir, filename)) as img:
        batch[i] = transform(img.convert('RGB'))

resp_root = result_dir / 'data' / 'net_resp'
resp_root.mkdir(parents=True, exist_ok=True)

# --- Extract and save layer responses ---------------------------------------
for network in tqdm(networks):
    pretrained_weight, layer_idx = net_configs[network]
    for untrained_net in untrained_nets:
        weight = None if untrained_net else pretrained_weight

        # Seed right before construction so each untrained net's random
        # init (and therefore its saved responses) is reproducible.
        torch.manual_seed(seed)
        model = getattr(models, network)(weights=weight)
        model.eval()

        # get_graph_node_names returns (train_nodes, eval_nodes); layer_idx indexes eval_nodes.
        _, nodes = get_graph_node_names(model)
        nodes_of_interest = [nodes[idx] for idx in layer_idx]
        return_nodes = dict(zip(nodes_of_interest, layers))

        feature_extractor = create_feature_extractor(model, return_nodes=return_nodes)
        feature_extractor.eval()
        # Inference only: without no_grad, autograd keeps the activation graph
        # for the whole batch and the saved tensors carry grad_fn.
        with torch.no_grad():
            out = feature_extractor(batch)  # dict: layer name -> feature-map tensor

        resp_name = f'{network}_trained_{not(untrained_net)}_inputsize_{input_imsize}.pkl'
        resp_dir = resp_root / resp_name  # NOTE: despite the name, this is the output file path
        with open(resp_dir, 'wb') as f:
            pickle.dump(out, f)

100%|██████████| 2/2 [00:26<00:00, 13.43s/it]
