# Example: extraction of VGG-16 representations

In [1]:
# general imports

from __future__ import print_function, division
import numpy as np
import matplotlib.pyplot as plt
import time
import sys
import os
import os.path as path
from os import listdir 
from os.path import isfile, join
import copy
import pickle
from tqdm import tqdm
import argparse
from collections import namedtuple

In [2]:
# imports from torch 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torchsummary import summary

import torchvision
from torchvision import datasets, models, transforms
from torchsummary import summary

from torchvision.models import vgg16

In [3]:
# change your paths here (or set the INTRINSIC_DIM_ROOT environment variable,
# which takes precedence over the default below, so the notebook runs on
# other machines without editing this cell)
ROOT = os.environ.get('INTRINSIC_DIM_ROOT', '/home/ansuini/repos/IntrinsicDimDeep')

# work from the repository root and make its modules importable
os.chdir(ROOT)
sys.path.insert(0, ROOT)

# and here: folder where the extracted representations are saved
# (relative to ROOT, because of the chdir above)
results_folder = '.'

### Selected classes

In [4]:
# Tags of the image folders to process. The last tag, 'mix', is not counted as
# an object class (presumably a mixture of the classes above -- TODO confirm).
category_tags = [
    'n01882714', 'n02086240', 'n02087394', 'n02094433',
    'n02100583', 'n02100735', 'n02279972',
    'mix',
]
n_objects = len(category_tags) - 1
print('N.of classes : %d' % n_objects)

N.of classes : 7


In [5]:
# random generator init: seed torch (DataLoader shuffling draws from its RNG)
# and numpy, and ask cuDNN for deterministic algorithms. benchmark must also be
# off, otherwise cuDNN may still autotune and pick non-deterministic kernels.
SEED = 999
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(SEED)
np.random.seed(SEED)

<torch._C.Generator at 0x7f6d75fbcad0>

In [6]:
# parameters: architecture name (not referenced by the cells below),
# max number of images extracted per class, and dataloader batch size
arch, nsamples, bs = 'vgg16', 500, 16

In [7]:
# function to select checkpoint layers and to determine their depths
# (this works for AlexNet and VGG-like architectures)
def getDepths(model):
    """Select the checkpoint layers of an AlexNet/VGG-like network and their depths.

    Checkpoints, in order: the raw input (represented by the string 'input'),
    every module of ``model.features`` whose class name contains 'MaxPool2d',
    and every module of ``model.classifier`` whose class name contains 'Linear'.
    Depth is the running count of Conv2d/Linear modules traversed so far; for
    classifier checkpoints it is that count plus one (original convention --
    presumably accounting for the stage between features and classifier,
    TODO confirm).

    Args:
        model: object exposing iterable ``features`` and ``classifier``.

    Returns:
        (modules, names, depths): parallel lists of checkpoint modules and their
        labels, and a numpy array with the depth of each checkpoint.
    """
    checkpoints = [('input', 'input', 0)]
    n_weight_layers = 0

    for layer in model.features:
        kind = type(layer).__name__
        if 'Conv2d' in kind or 'Linear' in kind:
            n_weight_layers += 1
        if 'MaxPool2d' in kind:
            checkpoints.append((layer, 'MaxPool2d', n_weight_layers))

    for layer in model.classifier:
        if 'Linear' in type(layer).__name__:
            n_weight_layers += 1
            checkpoints.append((layer, 'Linear', n_weight_layers + 1))

    modules = [entry[0] for entry in checkpoints]
    names = [entry[1] for entry in checkpoints]
    depths = np.array([entry[2] for entry in checkpoints])
    return modules, names, depths

def getLayerDepth(layer):
    """Count the convolutions in a stage made of blocks (e.g. a ResNet stage).

    ``layer`` is iterated block by block, and for each block only its direct
    children are inspected: a child counts when its class name contains 'Conv'.
    Convolutions nested deeper (inside a child container) are not counted.

    Returns:
        int, the number of matching children over all blocks (0 for an empty stage).
    """
    return sum(
        1
        for block in layer
        for child in block.children()
        if 'Conv' in type(child).__name__
    )

In [8]:
# function to select checkpoint layers and to determine their depths
# (this works for ResNets architectures)
def getResNetsDepths(model):
    """Select the checkpoint layers of a ResNet and their cumulative depths.

    Checkpoints, in order: the raw input (the string 'input'), ``model.maxpool``,
    the stages ``model.layer1`` .. ``model.layer4``, ``model.avgpool`` and
    ``model.fc``. maxpool, avgpool and fc each add one to the depth; each stage
    adds the number of convolutions reported by getLayerDepth.

    Returns:
        (modules, names, depths): parallel lists of checkpoint modules and their
        names (the attribute names above), and a numpy array of cumulative depths.
    """
    modules, names, depth_list = ['input'], ['input'], [0]
    depth = 0

    def record(attr, increment):
        # advance the running depth and register model.<attr> as a checkpoint
        nonlocal depth
        depth += increment
        modules.append(getattr(model, attr))
        names.append(attr)
        depth_list.append(depth)

    record('maxpool', 1)
    for stage in ('layer1', 'layer2', 'layer3', 'layer4'):
        record(stage, getLayerDepth(getattr(model, stage)))
    record('avgpool', 1)
    record('fc', 1)

    return modules, names, np.array(depth_list)

# Instantiate model and define checkpoints

In [9]:
# Build VGG-16 with pre-trained weights (pretrained=True).
print('Instantiating pre-trained model')
model = vgg16(pretrained=True)

Instantiating pre-trained model


In [10]:
# run on the first CUDA device when one is available, on the CPU otherwise
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)

In [24]:
# Switch the network to evaluation mode (train(False) is what eval() does):
# dropout and batch normalization are no longer in training behavior, so the
# network acts as a 'passive' feed-forward device. Forgetting this produces
# catastrophically wrong results.
model.train(False)
print('Training mode : {}'.format(model.training))

Training mode : False


In [None]:
# checkpoint modules, their labels and depths for the VGG-like model above
modules, names, depths = getDepths(model)
print('List of layers from which to extract representations: {}'.format(names))

In [13]:
# images preprocessing methods

# per-channel mean and std used to normalize the images; Normalize reads them
# from these lists so the two cannot drift apart
mean_imgs = [0.485, 0.456, 0.406]
std_imgs = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=mean_imgs, std=std_imgs)

# Data transformations (same as suggested by Soumith Chintala's script):
# resize, center-crop to 224, convert to tensor, normalize
data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

# Extraction of representations


For each class, this extracts the representations of $\sim 500$ images 
at every checkpoint, including the input and the output.

Each set of representations is saved as a matrix of shape (number of images, embedding dimension).

A typical filename is n02086240_5.npy: this file contains the representations of class n02086240 extracted at the sixth (5+1) checkpoint layer, which is the max pooling after the last convolutional layer, as you can check by printing the list of names.

In [22]:
# checkpoint index -> name; the index is the suffix of the saved .npy files
for idx, layer_name in enumerate(names):
    print(idx, layer_name)

0 input
1 MaxPool2d
2 MaxPool2d
3 MaxPool2d
4 MaxPool2d
5 MaxPool2d
6 Linear
7 Linear
8 Linear


In [14]:
n_layers = len(modules)  # number of checkpoints, input included
embdims = list()         # embedding dimension (n. of units) of every saved matrix, in save order

In [16]:
def extract_representations(model, modules, dataloader, nsamples, device):
    """Return the flattened activations of every checkpoint for up to `nsamples` images.

    All checkpoints are hooked at once and the dataloader is traversed a single
    time. This (i) needs one forward pass per batch instead of one per
    checkpoint, and (ii) guarantees that row j of every returned matrix comes
    from the same image: with ``shuffle=True`` every new pass over the
    dataloader draws a different permutation, so extracting each layer in its
    own pass would mix different image subsets across layers.

    Activations are read only after the forward pass of a batch has finished.
    So if a later layer modifies a checkpoint's output in place (e.g. an
    inplace ReLU -- check the model definition), the stored values include that
    modification, on CPU and GPU alike.

    Args:
        model: network already in eval mode and moved to `device`.
        modules: checkpoints as returned by getDepths; the string 'input'
            stands for the preprocessed images themselves.
        dataloader: iterable of (inputs, labels) batches.
        nsamples: maximum number of images to keep.
        device: device on which `model` lives.

    Returns:
        List with one CPU tensor of shape (n_images, n_units) per checkpoint,
        in the order of `modules`, with n_images = min(nsamples, images available).

    Raises:
        ValueError: if the dataloader yields no images.
    """
    chunks = [[] for _ in modules]
    captured = {}

    def make_hook(slot):
        # bind `slot` per checkpoint (a closure over the loop variable would
        # only ever see its last value)
        def hook(module, inp, output):
            captured[slot] = output
        return hook

    handles = []
    n_seen = 0
    try:
        for slot, module in enumerate(modules):
            if module != 'input':
                handles.append(module.register_forward_hook(make_hook(slot)))

        # no autograd graph is needed: we only read activations
        with torch.no_grad():
            for inputs, _ in dataloader:
                captured.clear()
                model(inputs.to(device))
                for slot, module in enumerate(modules):
                    out = inputs if module == 'input' else captured[slot]
                    chunks[slot].append(out.reshape(out.shape[0], -1).cpu())
                n_seen += inputs.shape[0]
                if n_seen >= nsamples:
                    break
    finally:
        # always detach the hooks (also on errors / KeyboardInterrupt) so a
        # re-run of this cell does not stack duplicate hooks on the model
        for handle in handles:
            handle.remove()

    if n_seen == 0:
        raise ValueError('the dataloader yielded no images')

    representations = []
    for slot in range(len(chunks)):
        # concatenate once (not batch by batch) and keep exactly nsamples rows
        representations.append(torch.cat(chunks[slot], 0)[:nsamples])
        chunks[slot] = None  # release the per-batch pieces as soon as merged
    return representations


for tag in category_tags:

    print('Processing class: {}'.format(tag))
    data_folder = join(ROOT, 'data', 'imagenet_training_single_objs', tag)
    image_dataset = datasets.ImageFolder(data_folder, data_transforms)
    dataloader = torch.utils.data.DataLoader(image_dataset,
                                             batch_size=bs,
                                             shuffle=True,
                                             num_workers=1)

    representations = extract_representations(model, modules, dataloader, nsamples, device)

    # one file per (class, checkpoint): <tag>_<checkpoint index>.npy
    for l, Out in enumerate(representations):
        embdims.append(Out.shape[1])
        np.save(join(results_folder, tag + '_' + str(l) + '.npy'), Out.numpy())
    del representations  # free this class's matrices before the next class

Processing class: n01882714
Processing class: n02086240


Traceback (most recent call last):


KeyboardInterrupt: 

  File "/usr/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
