In [25]:
import torch
from torch.nn import Conv2d, ReLU, MaxPool2d
from torch.autograd import Variable
from torchvision import models
import cv2
import sys
import numpy as np
import time

In [20]:
def replace_layers(model, i, indexes, layers):
    """Pick the module that should sit at position ``i`` of a rebuilt container.

    Args:
        model: Indexable container of the current modules.
        i: Position being resolved.
        indexes: Positions that are being swapped out.
        layers: Replacement modules, aligned one-to-one with ``indexes``.

    Returns:
        The entry of ``layers`` paired with ``i`` when ``i`` is listed in
        ``indexes``; otherwise the existing ``model[i]``.
    """
    is_swapped = i in indexes
    return layers[indexes.index(i)] if is_swapped else model[i]

In [103]:
def _drop_positions(tensor, start, stop, dim):
    """Return a copy of ``tensor`` without positions ``[start, stop)`` along ``dim``."""
    positions = torch.arange(tensor.size(dim), device=tensor.device)
    kept = positions[(positions < start) | (positions >= stop)]
    return tensor.index_select(dim, kept)


def _rebuild_sequential(sequential, replacements):
    """Return a new ``Sequential`` with the modules at the given positions swapped."""
    return torch.nn.Sequential(
        *(replacements.get(position, module)
          for position, module in enumerate(sequential)))


def prune_conv_layer(model, layer_index, filter_index):
    """Remove one output filter from a conv layer of ``model.features``.

    The conv layer at ``layer_index`` is rebuilt with one output channel
    fewer. Whatever consumes that channel is rebuilt to match: the next
    ``Conv2d`` found later in ``model.features`` loses the corresponding input
    channel, or, when there is no later conv, the first ``Linear`` of
    ``model.classifier`` loses the block of input features belonging to the
    removed channel. Rebuilt layers keep the dtype and device of the layers
    they replace, and never share storage with them.

    Only the conv, its consumer conv and the first classifier ``Linear`` are
    adjusted; any other channel-dependent module between them is left as is.
    The channel slicing assumes ``groups == 1``.

    Args:
        model: Module exposing ``features`` and ``classifier`` containers
            whose children are registered under the keys "0", "1", ...
        layer_index: Position in ``model.features`` of the conv layer to prune.
        filter_index: Output channel of that layer to remove.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        ValueError: If ``filter_index`` is outside ``[0, out_channels)``, or
            if the last conv is pruned and the classifier has no ``Linear``.
    """
    feature_modules = model.features._modules
    conv = feature_modules[str(layer_index)]
    if not 0 <= filter_index < conv.out_channels:
        raise ValueError(
            "filter_index %r is out of range for a layer with %d filters"
            % (filter_index, conv.out_channels))

    # Locate the first conv after the pruned one; it consumes the removed channel.
    next_index = None
    for candidate in range(layer_index + 1, len(feature_modules)):
        if isinstance(feature_modules[str(candidate)], torch.nn.Conv2d):
            next_index = candidate
            break

    new_conv = torch.nn.Conv2d(
        in_channels=conv.in_channels,
        out_channels=conv.out_channels - 1,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=conv.bias is not None)
    # Output filters live on dim 0 of the weight; the bias has one entry per filter.
    new_conv.weight.data = _drop_positions(
        conv.weight.data, filter_index, filter_index + 1, 0)
    if conv.bias is not None:
        new_conv.bias.data = _drop_positions(
            conv.bias.data, filter_index, filter_index + 1, 0)

    replacements = {layer_index: new_conv}

    if next_index is not None:
        next_conv = feature_modules[str(next_index)]
        next_new_conv = torch.nn.Conv2d(
            in_channels=next_conv.in_channels - 1,
            out_channels=next_conv.out_channels,
            kernel_size=next_conv.kernel_size,
            stride=next_conv.stride,
            padding=next_conv.padding,
            dilation=next_conv.dilation,
            groups=next_conv.groups,
            # The constructor expects a flag here, not the bias tensor itself.
            bias=next_conv.bias is not None)
        # Input channels live on dim 1 of the weight.
        next_new_conv.weight.data = _drop_positions(
            next_conv.weight.data, filter_index, filter_index + 1, 1)
        if next_conv.bias is not None:
            next_new_conv.bias.data = next_conv.bias.data.clone()
        replacements[next_index] = next_new_conv

        model.features = _rebuild_sequential(model.features, replacements)
        return model

    # Pruning the last conv layer. This affects the first linear layer of the
    # classifier. Everything is prepared before the model is touched so that a
    # failure here does not leave it half-pruned.
    linear_index = None
    old_linear_layer = None
    for position, module in enumerate(model.classifier._modules.values()):
        if isinstance(module, torch.nn.Linear):
            linear_index = position
            old_linear_layer = module
            break
    if old_linear_layer is None:
        raise ValueError("No linear layer found in classifier")

    # Each conv output channel feeds a contiguous block of classifier inputs.
    params_per_input_channel = old_linear_layer.in_features // conv.out_channels
    block_start = filter_index * params_per_input_channel

    new_linear_layer = torch.nn.Linear(
        old_linear_layer.in_features - params_per_input_channel,
        old_linear_layer.out_features,
        bias=old_linear_layer.bias is not None)
    new_linear_layer.weight.data = _drop_positions(
        old_linear_layer.weight.data,
        block_start, block_start + params_per_input_channel, 1)
    if old_linear_layer.bias is not None:
        new_linear_layer.bias.data = old_linear_layer.bias.data.clone()

    model.features = _rebuild_sequential(model.features, replacements)
    model.classifier = _rebuild_sequential(
        model.classifier, {linear_index: new_linear_layer})
    return model

In [107]:
# Timing run on the torchvision pretrained VGG16 (other data, not ours).
model = models.vgg16(pretrained=True)
model.train()
t0 = time.time()
model = prune_conv_layer(model, 28, 10)
print(f"The prunning took {time.time() - t0} seconds")

The prunning took 0.9545621871948242 seconds


In [108]:
#running on our data
# NOTE(review): torch.load unpickles the file and can execute arbitrary code;
# only load checkpoints that come from a trusted source.
our_vgg_model = torch.load('test.pt')
# NOTE(review): `summary` is not imported in the visible cells -- presumably it
# comes from a model-summary package; confirm it is in scope before running.
summary(our_vgg_model)
t0_vgg = time.time()
# Prune the model loaded above and time it with this cell's own timer.
our_vgg_model = prune_conv_layer(our_vgg_model, 27, 10) # change to reflect what conv layer we want to prune
print ("The prunning took " + str(time.time() - t0_vgg) + " seconds")

FileNotFoundError: [Errno 2] No such file or directory: 'VGG16_v2-OCT_Retina_half_dataset.pt'