In [1]:
import os
import shutil

import torch
import torch.utils.data

import argparse
import re

In [2]:
## utility functions
def makedir(path):
    """Create directory `path`, including missing parents, if it is not there yet.

    Calling this on an already existing directory is a no-op. If `path`
    exists but is not a directory, `os.makedirs` raises FileExistsError.
    """
    # exist_ok=True replaces the former exists()-then-makedirs() sequence,
    # which could fail if the directory appeared between the two calls.
    os.makedirs(path, exist_ok=True)

def create_logger(log_filename, display=True):
    """Open `log_filename` in append mode and build a logging closure for it.

    Args:
        log_filename: path of the log file; it is opened with mode 'a'.
        display: when true, every logged message is also printed.

    Returns:
        A pair ``(logger, close)``. ``logger(text)`` writes ``text`` plus a
        newline to the file; ``close`` is the file object's ``close`` method.
    """
    log_file = open(log_filename, 'a')
    messages_written = 0

    def logger(text):
        # the closure keeps log_file alive after create_logger has returned
        nonlocal messages_written
        if display:
            print(text)
        log_file.write(text + '\n')
        messages_written += 1
        # every 10th message, force the buffered text out to disk
        if messages_written % 10 == 0:
            log_file.flush()
            os.fsync(log_file.fileno())

    return logger, log_file.close

In [3]:
## settings: model hyperparameters and experiment id
# identifier of this run; it becomes the last component of the model directory
experiment_run = '003'

# number of output classes of the classifier
num_classes = 2
# PPNet reads index 0 as the number of prototypes, index 1 as the number of
# channels produced by the add-on layers, and indices 2/3 as the prototype
# kernel height/width
prototype_shape = (128, 8, 1, 1)
prototype_activation_function = 'log'
add_on_layers_type = 'regular'

In [4]:
## settings
base_architecture = 'vgg19'
# leading run of lowercase letters of the architecture name ('vgg' for 'vgg19')
base_architecture_type = re.match('^[a-z]*', base_architecture).group(0)

# the trailing '/' is kept so the value stays identical for any code that
# concatenates onto model_dir
model_dir = './saved_models/' + base_architecture + '/' + experiment_run + '/'
makedir(model_dir)

# keep a snapshot of the notebook next to the saved models; this expects
# main.ipynb to be in the current working directory
my_filepath = os.path.join(os.getcwd(), 'main.ipynb')
shutil.copy(src=my_filepath, dst=model_dir)

log, logclose = create_logger(log_filename=os.path.join(model_dir, 'train.log'))
img_dir = os.path.join(model_dir, 'img')
makedir(img_dir)
weight_matrix_filename = 'outputL_weights'
prototype_img_filename_prefix = 'prototype-img'
prototype_self_act_filename_prefix = 'prototype-self-act'
proto_bound_boxes_filename_prefix = 'bb'

log('base_architecture: {0}'.format(base_architecture))
log('base_architecture_type: {0}'.format(base_architecture_type))
# label fixed: the value logged here is model_dir (was misspelled 'mode_dir')
log('model_dir: {0}'.format(model_dir))
# report both ends of the copy; the old message printed the source path as
# if it were the destination
log('saved {0} to {1}'.format(my_filepath, model_dir))

data_path = './datasets/cub200_cropped/'
train_dir = data_path + 'train_cropped_augmented/'
test_dir = data_path + 'test_cropped/'
train_push_dir = data_path + 'train_cropped/'

log('train dir: {0}'.format(train_dir))
log('test_dir: {0}'.format(test_dir))
log('train_push_dir: {0}'.format(train_push_dir))

base_architecture: vgg19
base_architecture_type: vgg
mode_dir: ./saved_models/vgg19/003/
saved main.ipynb to /home/jovyan/codes/ProtoPNet/main.ipynb
train dir: ./datasets/cub200_cropped/train_cropped_augmented/
test_dir: ./datasets/cub200_cropped/test_cropped/
train_push_dir: ./datasets/cub200_cropped/train_cropped/


In [5]:
## load the dataset

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# side length (pixels) every image is resized to, and the loader batch sizes
img_size = 56
train_batch_size = 5
test_batch_size = 5
train_push_batch_size = 5

# per-channel statistics used for input normalization
preprocess_mean = (0.485, 0.456, 0.406)
preprocess_std = (0.229, 0.224, 0.225)

normalize = transforms.Normalize(mean=preprocess_mean, std=preprocess_std)

# training set: resized, converted to tensor and normalized
train_transform = transforms.Compose([
    transforms.Resize(size=(img_size, img_size)),
    transforms.ToTensor(),
    normalize,
])
train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=False,
)

# push set: deliberately NOT normalized here (no `normalize` in the pipeline)
train_push_transform = transforms.Compose([
    transforms.Resize(size=(img_size, img_size)),
    transforms.ToTensor(),
])
train_push_dataset = datasets.ImageFolder(root=train_push_dir, transform=train_push_transform)
train_push_loader = torch.utils.data.DataLoader(
    dataset=train_push_dataset,
    batch_size=train_push_batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
)

# test set: same pipeline as the training set, but iterated in a fixed order
test_transform = transforms.Compose([
    transforms.Resize(size=(img_size, img_size)),
    transforms.ToTensor(),
    normalize,
])
test_dataset = datasets.ImageFolder(root=test_dir, transform=test_transform)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
)

for message in (
        'INFO: training set size: {0}'.format(len(train_loader.dataset)),
        'INFO: push set size: {0}'.format(len(train_push_loader.dataset)),
        'INFO: test set size: {0}'.format(len(test_loader.dataset)),
        'INFO: batch size: {0}'.format(train_batch_size)):
    log(message)

INFO: training set size: 60
INFO: push set size: 60
INFO: test set size: 60
INFO: batch size: 5


In [6]:
## vgg_features
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

class VGG_features(nn.Module):
    """Convolutional part of a VGG network built from a layer configuration.

    `cfg` is a sequence whose entries are either the string 'M' (a 2x2
    max-pool with stride 2) or an integer (a 3x3 convolution with padding 1
    and that many output channels, followed by an in-place ReLU). While the
    layers are built, the kernel size, stride and padding of every
    convolution and pooling layer are recorded for `conv_info`.
    """

    def __init__(self, cfg):
        super(VGG_features, self).__init__()
        # one entry per conv / pooling layer, in network order
        self.kernel_sizes = []
        self.strides = []
        self.paddings = []
        self.features = self._make_layers(cfg)

    def forward(self, x):
        """Run `x` through the sequential feature extractor."""
        return self.features(x)

    def _initialize_weights(self):
        """Kaiming-normal init for convolutions, small normal init for linear layers."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

    def _record_geometry(self, kernel_size, stride, padding):
        """Append one layer's kernel size / stride / padding to the bookkeeping lists."""
        self.kernel_sizes.append(kernel_size)
        self.strides.append(stride)
        self.paddings.append(padding)

    def _make_layers(self, cfg):
        """Translate `cfg` into an nn.Sequential and count the conv layers."""
        self.n_layers = 0
        modules = []
        channels = 3  # the first convolution takes 3 input channels
        for entry in cfg:
            if entry == 'M':
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                self._record_geometry(2, 2, 0)
                continue
            modules.append(nn.Conv2d(channels, entry, kernel_size=3, padding=1))
            modules.append(nn.ReLU(inplace=True))
            self.n_layers += 1
            self._record_geometry(3, 1, 1)
            channels = entry

        return nn.Sequential(*modules)

    def conv_info(self):
        """Return the recorded (kernel_sizes, strides, paddings) lists."""
        return self.kernel_sizes, self.strides, self.paddings

    def num_layers(self):
        '''
        the number of conv layers in the network
        '''
        return self.n_layers

    def __repr__(self):
        # the displayed number is the conv-layer count plus 3
        return 'VGG{}'.format(self.num_layers() + 3)

def vgg19_features():
    """Build the VGG-19 feature extractor and load downloaded weights into it.

    The checkpoint is fetched with `model_zoo.load_url` into
    './pretrained_models'. Entries whose key starts with 'classifier' are
    dropped before loading, and loading is non-strict.
    """
    layer_plan = ([64] * 2 + ['M']
                  + [128] * 2 + ['M']
                  + [256] * 4 + ['M']
                  + [512] * 4 + ['M']
                  + [512] * 4 + ['M'])
    net = VGG_features(layer_plan)
    state = model_zoo.load_url('https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', model_dir='./pretrained_models')

    # collect the keys first: entries cannot be deleted while iterating the dict
    classifier_keys = [name for name in state if name.startswith('classifier')]
    for name in classifier_keys:
        del state[name]

    net.load_state_dict(state, strict=False)
    return net

In [7]:
import math

def compute_proto_layer_rf_info_v2(img_size, layer_filter_sizes, layer_strides, layer_paddings, prototype_kernel_size):
    """Propagate receptive-field information through a layer stack and the prototype layer.

    Starts from ``[img_size, 1, 1, 0.5]`` (size, jump, receptive-field size,
    center of the first receptive field), applies `compute_layer_rf_info`
    once per listed layer, and finally applies it for the prototype layer
    (stride 1, 'VALID' padding, kernel `prototype_kernel_size`).

    Returns:
        The ``[n_out, j_out, r_out, start_out]`` list of the prototype layer.
    """
    n_layers = len(layer_filter_sizes)
    assert(n_layers == len(layer_strides))
    assert(n_layers == len(layer_paddings))

    current_rf = [img_size, 1, 1, 0.5]
    for idx, kernel in enumerate(layer_filter_sizes):
        current_rf = compute_layer_rf_info(layer_filter_size=kernel,
                                           layer_stride=layer_strides[idx],
                                           layer_padding=layer_paddings[idx],
                                           previous_layer_rf_info=current_rf)

    return compute_layer_rf_info(layer_filter_size=prototype_kernel_size,
                                 layer_stride=1,
                                 layer_padding='VALID',
                                 previous_layer_rf_info=current_rf)

def compute_layer_rf_info(layer_filter_size, layer_stride, layer_padding,
                          previous_layer_rf_info):
    """Compute one layer's receptive-field information from its input's.

    Args:
        layer_filter_size: kernel size of the layer.
        layer_stride: stride of the layer.
        layer_padding: 'SAME', 'VALID', or an int giving the padding on one side.
        previous_layer_rf_info: ``[n, j, r, start]`` of the layer input, i.e.
            size, receptive-field jump, receptive-field size and center of
            the first receptive field.

    Returns:
        ``[n_out, j_out, r_out, start_out]`` for the layer output.
    """
    n_in, j_in, r_in, start_in = (previous_layer_rf_info[k] for k in range(4))

    if layer_padding == 'SAME':
        n_out = math.ceil(float(n_in) / float(layer_stride))
        remainder = n_in % layer_stride
        covered = layer_stride if remainder == 0 else remainder
        total_pad = max(layer_filter_size - covered, 0)
        # sanity checks
        assert(n_out == math.floor((n_in - layer_filter_size + total_pad)/layer_stride) + 1)
        assert(total_pad == (n_out-1)*layer_stride - n_in + layer_filter_size)
    elif layer_padding == 'VALID':
        n_out = math.ceil(float(n_in - layer_filter_size + 1) / float(layer_stride))
        total_pad = 0
        # sanity checks
        assert(n_out == math.floor((n_in - layer_filter_size + total_pad)/layer_stride) + 1)
        assert(total_pad == (n_out-1)*layer_stride - n_in + layer_filter_size)
    else:
        # integer padding: the given amount is applied on each side
        total_pad = layer_padding * 2
        n_out = math.floor((n_in - layer_filter_size + total_pad)/layer_stride) + 1

    pad_left = math.floor(total_pad/2)

    return [
        n_out,
        j_in * layer_stride,
        r_in + (layer_filter_size - 1)*j_in,
        start_in + ((layer_filter_size - 1)/2 - pad_left)*j_in,
    ]

In [8]:
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
# from receptive_field import compute_proto_layer_rf_info_v2

class PPNet(nn.Module):
    """Prototype-based classifier.

    Pipeline of `forward`: `features` -> 1x1 `add_on_layers` -> squared L2
    distance of every spatial position to every prototype -> minimum over
    positions -> similarity score -> bias-free linear layer.

    Args:
        features: module producing the convolutional feature map; the
            out_channels of its last nn.Conv2d sets the add-on input width.
        img_size: stored as-is (reported by `__repr__`).
        prototype_shape: (num_prototypes, channels, height, width).
        proto_layer_rf_info: stored as-is.
        num_classes: number of classes; must divide the number of prototypes.
        prototype_activation_function: 'log', 'linear', or a callable mapping
            a distance tensor to a similarity tensor.
        add_on_layers_type: accepted for interface compatibility; the add-on
            block built here does not depend on it.
    """

    def __init__(self, features, img_size, prototype_shape,
                 proto_layer_rf_info, num_classes,
                 prototype_activation_function='log',
                 add_on_layers_type='bottleneck'):

        super(PPNet, self).__init__()
        self.img_size = img_size
        self.prototype_shape = prototype_shape
        self.num_prototypes = prototype_shape[0]
        self.num_classes = num_classes
        # keeps the denominator of the log similarity away from zero
        self.epsilon = 1e-4

        # prototype_activation_function could be 'log', 'linear',
        # or a generic function that converts distance to similarity score
        self.prototype_activation_function = prototype_activation_function

        # Class identities of the prototypes: without domain specific
        # knowledge the same number of prototypes is allocated to each class.
        assert(self.num_prototypes % self.num_classes == 0)
        # a onehot indication matrix for each prototype's class identity
        # (a plain tensor attribute, not a registered parameter or buffer)
        self.prototype_class_identity = torch.zeros(self.num_prototypes,
                                                    self.num_classes)

        num_prototypes_per_class = self.num_prototypes // self.num_classes
        for j in range(self.num_prototypes):
            self.prototype_class_identity[j, j // num_prototypes_per_class] = 1

        self.proto_layer_rf_info = proto_layer_rf_info

        # this has to be named features to allow the precise loading
        self.features = features
        first_add_on_layer_in_channels = [i for i in features.modules() if isinstance(i, nn.Conv2d)][-1].out_channels

        # two 1x1 convolutions mapping the backbone channels to the prototype
        # channel count; the final sigmoid keeps the outputs in (0, 1)
        self.add_on_layers = nn.Sequential(
                nn.Conv2d(in_channels=first_add_on_layer_in_channels, out_channels=self.prototype_shape[1], kernel_size=1),
                nn.ReLU(),
                nn.Conv2d(in_channels=self.prototype_shape[1], out_channels=self.prototype_shape[1], kernel_size=1),
                nn.Sigmoid()
                )

        self.prototype_vectors = nn.Parameter(torch.rand(self.prototype_shape),
                                              requires_grad=True)

        # constant all-ones kernel used to sum squared activations per patch
        self.ones = nn.Parameter(torch.ones(self.prototype_shape),
                                 requires_grad=False)

        self.h_last_layer = nn.Linear(self.num_prototypes, self.num_classes,
                                    bias=False) # do not use bias

        self._initialize_weights()

    def forward(self, x):
        """Return (class logits, minimum distance to each prototype) for batch `x`."""
        f_out = self.f_conv_layer(x)
        g_p_out, min_distances, _ = self.g_p_layer(f_out)
        h_out = self.h_last_layer(g_p_out)
        return h_out, min_distances

    def f_conv_layer(self, x):
        """Apply the backbone followed by the add-on layers."""
        x = self.features(x)
        x = self.add_on_layers(x)
        return x

    def g_p_layer(self, z):
        """Compute prototype distances for feature map `z`.

        Uses ||z - p||^2 = ||z||^2 - 2<z, p> + ||p||^2, each term obtained
        with a convolution, clamped at zero with a ReLU.

        Returns:
            (similarity scores, per-prototype minimum distance over all
            spatial positions, full distance map).
        """
        z2 = z ** 2
        z2_patch_sum = F.conv2d(input=z2, weight=self.ones)

        p2 = self.prototype_vectors ** 2
        p2 = torch.sum(p2, dim=(1, 2, 3))
        # one value per prototype, shaped to broadcast over the spatial dims
        p2_reshape = p2.view(-1, 1, 1)

        zp = F.conv2d(input=z, weight=self.prototype_vectors)
        intermediate_result = - 2 * zp + p2_reshape
        zp_distances = F.relu(z2_patch_sum + intermediate_result)

        # global min-pooling expressed as a negated global max-pooling
        min_zp_distance = -F.max_pool2d(-zp_distances,
                                      kernel_size=(zp_distances.size()[2],
                                                   zp_distances.size()[3]))
        min_zp_distance = min_zp_distance.view(-1, self.num_prototypes)
        g_p_out = self.distance_2_similarity(min_zp_distance)

        return g_p_out, min_zp_distance, zp_distances

    def distance_2_similarity(self, distances):
        """Convert distances to similarity scores.

        Honors `prototype_activation_function`: 'log' gives
        log((d + 1) / (d + epsilon)), 'linear' gives -d, and a callable is
        applied to the distances directly.

        Raises:
            ValueError: if the configured activation is none of the above.
        """
        activation = self.prototype_activation_function
        if isinstance(activation, str):
            if activation == 'log':
                return torch.log((distances + 1) / (distances + self.epsilon))
            if activation == 'linear':
                return -distances
        elif callable(activation):
            return activation(distances)
        raise ValueError(
            "prototype_activation_function must be 'log', 'linear' or a "
            "callable, got {!r}".format(activation))

    def push_forward(self, x):
        '''this method is needed for the pushing operation'''
        f_out = self.f_conv_layer(x)
        g_p_out, min_distances, distances = self.g_p_layer(f_out)
        #return g_p_out, distances
        return f_out, distances

    def prune_prototypes(self, prototypes_to_prune):
        '''
        prototypes_to_prune: a list of indices each in
        [0, current number of prototypes - 1] that indicates the prototypes to
        be removed
        '''
        # sorted() makes the surviving prototypes keep their original relative
        # order; iteration order of a set difference is not guaranteed
        prototypes_to_keep = sorted(set(range(self.num_prototypes)) - set(prototypes_to_prune))

        self.prototype_vectors = nn.Parameter(self.prototype_vectors.data[prototypes_to_keep, ...],
                                              requires_grad=True)

        self.prototype_shape = list(self.prototype_vectors.size())
        self.num_prototypes = self.prototype_shape[0]

        # changing self.h_last_layer in place
        # changing in_features and out_features make sure the numbers are consistent
        self.h_last_layer.in_features = self.num_prototypes
        self.h_last_layer.out_features = self.num_classes
        self.h_last_layer.weight.data = self.h_last_layer.weight.data[:, prototypes_to_keep]

        # self.ones is nn.Parameter
        self.ones = nn.Parameter(self.ones.data[prototypes_to_keep, ...],
                                 requires_grad=False)
        # self.prototype_class_identity is torch tensor
        # so it does not need .data access for value update
        self.prototype_class_identity = self.prototype_class_identity[prototypes_to_keep, :]

    def __repr__(self):
        # PPNet(self, features, img_size, prototype_shape,
        # proto_layer_rf_info, num_classes, init_weights=True):
        rep = (
            'PPNet(\n'
            '\tfeatures: {},\n'
            '\timg_size: {},\n'
            '\tprototype_shape: {},\n'
            '\tproto_layer_rf_info: {},\n'
            '\tnum_classes: {},\n'
            '\tepsilon: {}\n'
            ')'
        )

        return rep.format(self.features,
                          self.img_size,
                          self.prototype_shape,
                          self.proto_layer_rf_info,
                          self.num_classes,
                          self.epsilon)

    def set_h_last_layer_incorrect_connection(self, incorrect_strength):
        '''
        the incorrect strength will be actual strength if -0.5 then input -0.5
        '''
        # weight is 1 where a prototype belongs to the class and
        # `incorrect_strength` everywhere else
        positive_one_weights_locations = torch.t(self.prototype_class_identity)
        negative_one_weights_locations = 1 - positive_one_weights_locations

        correct_class_connection = 1
        incorrect_class_connection = incorrect_strength
        self.h_last_layer.weight.data.copy_(
            correct_class_connection * positive_one_weights_locations
            + incorrect_class_connection * negative_one_weights_locations)

    def _initialize_weights(self):
        """Initialize the add-on convolutions and the last-layer connections."""
        for m in self.add_on_layers.modules():
            if isinstance(m, nn.Conv2d):
                # every init technique has an underscore _ in the name
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        self.set_h_last_layer_incorrect_connection(incorrect_strength=-0.5)



def construct_PPNet(base_architecture, pretrained, img_size,
                    prototype_shape, num_classes,
                    prototype_activation_function='log',
                    add_on_layers_type='bottleneck'):
    """Assemble a PPNet on top of the VGG-19 feature extractor.

    The backbone always comes from `vgg19_features()`; `base_architecture`
    and `pretrained` are accepted but not read by this function. The
    receptive-field information of the prototype layer is derived from the
    backbone's conv/pool geometry and `prototype_shape[2]`.
    """
    backbone = vgg19_features()
    kernel_sizes, strides, paddings = backbone.conv_info()
    rf_info = compute_proto_layer_rf_info_v2(
        img_size=img_size,
        layer_filter_sizes=kernel_sizes,
        layer_strides=strides,
        layer_paddings=paddings,
        prototype_kernel_size=prototype_shape[2],
    )
    return PPNet(
        features=backbone,
        img_size=img_size,
        prototype_shape=prototype_shape,
        proto_layer_rf_info=rf_info,
        num_classes=num_classes,
        prototype_activation_function=prototype_activation_function,
        add_on_layers_type=add_on_layers_type,
    )


In [9]:
# construct the model
ppnet = construct_PPNet(
    base_architecture=base_architecture,
    pretrained=True,
    img_size=img_size,
    prototype_shape=prototype_shape,
    num_classes=num_classes,
    prototype_activation_function=prototype_activation_function,
    add_on_layers_type=add_on_layers_type,
)

# optimizer over the add-on layers and the prototype vectors only
warm_optimizer_lrs = dict(add_on_layers=3e-3, prototype_vectors=3e-3)
warm_optimizer_specs = [
    dict(params=ppnet.add_on_layers.parameters(),
         lr=warm_optimizer_lrs['add_on_layers'],
         weight_decay=1e-3),
    dict(params=ppnet.prototype_vectors,
         lr=warm_optimizer_lrs['prototype_vectors']),
]
warm_optimizer = torch.optim.Adam(warm_optimizer_specs)

# optimizer over backbone, add-on layers and prototype vectors, with a step
# schedule that multiplies the learning rates by 0.1 every joint_lr_step_size steps
joint_optimizer_lrs = dict(features=1e-4, add_on_layers=3e-3, prototype_vectors=3e-3)
joint_lr_step_size = 5
joint_optimizer_specs = [
    dict(params=ppnet.features.parameters(),
         lr=joint_optimizer_lrs['features'],
         weight_decay=1e-3),
    dict(params=ppnet.add_on_layers.parameters(),
         lr=joint_optimizer_lrs['add_on_layers'],
         weight_decay=1e-3),
    dict(params=ppnet.prototype_vectors,
         lr=joint_optimizer_lrs['prototype_vectors']),
]
joint_optimizer = torch.optim.Adam(joint_optimizer_specs)
joint_lr_scheduler = torch.optim.lr_scheduler.StepLR(joint_optimizer, step_size=joint_lr_step_size, gamma=0.1)

# optimizer over the final linear layer only
h_last_layer_optimizer_lr = 1e-4
h_last_layer_optimizer_specs = [
    dict(params=ppnet.h_last_layer.parameters(), lr=h_last_layer_optimizer_lr),
]
h_last_layer_optimizer = torch.optim.Adam(h_last_layer_optimizer_specs)

In [10]:
import time
import torch

def _train_or_test(model, dataloader, optimizer, class_specific, coefs, log):
    is_train = optimizer is not None
    start = time.time()
    n_examples = 0
    n_correct = 0
    n_batches = 0
    total_cross_entropy = 0
    total_cluster_cost = 0
    total_separation_cost = 0
    total_avg_separation_cost = 0

    for i, (image, label) in enumerate(dataloader):
        input_imag = image 
        target_class = label
        log('INFO: training with the images batch#{0}'.format(i))
        grad_req = torch.enable_grad() if is_train else torch.no_grad()
        with grad_req:
            output, min_distances = model(input_imag)
            # compute loss
            cross_entropy = torch.nn.functional.cross_entropy(output, target_class)

            max_dist = (ppnet.prototype_shape[1] * ppnet.prototype_shape[2] * ppnet.prototype_shape[3])

            correct_class_indicators = torch.t(ppnet.prototype_class_identity[:, target_class]) # batch_size * num_prototypes
            inverted_min_distances2correct_class, _ = torch.max((max_dist - min_distances) * correct_class_indicators, dim=1)  # max over all correct prototypes
            cluster_cost = torch.mean(max_dist - inverted_min_distances2correct_class)

            # calculate separation cost
            wrong_class_indicators = 1 - correct_class_indicators
            inverted_min_distances2wrong_class, _ = torch.max((max_dist - min_distances) * wrong_class_indicators, dim=1)
            separation_cost = torch.mean(max_dist - inverted_min_distances2wrong_class)

            # calculate avg cluster cost
            avg_separation_cost = torch.sum(min_distances * wrong_class_indicators, dim=1) / torch.sum(wrong_class_indicators, dim=1)
            avg_separation_cost = torch.mean(avg_separation_cost)
                
            l1_mask = 1 - torch.t(ppnet.prototype_class_identity) 
            l1 = (ppnet.h_last_layer.weight * l1_mask).norm(p=1)

            # evaluation statistics
            _, predicted = torch.max(output.data, dim=1)
            n_examples += target_class.size(0)
            n_correct += (predicted == target_class).sum().item()

            n_batches += 1
            total_cross_entropy += cross_entropy.item()
            total_cluster_cost += cluster_cost.item()
            total_separation_cost += separation_cost.item()
            total_avg_separation_cost += avg_separation_cost.item()

        # compute gradient and do SGD step
        if is_train:
            loss = (coefs['crs_ent'] * cross_entropy + coefs['clst'] * cluster_cost + coefs['sep'] * separation_cost + coefs['l1'] * l1)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        del input_imag
        del target_class
        del output
        del predicted
        del min_distances

    end = time.time()
    training_or_testing = 'training' if is_train else 'testing'
    log('INFO: {} time: \t{}'.format(training_or_testing, end -  start))
    log('INFO: average cross entropy per batch: \t{0}'.format(total_cross_entropy / n_batches))
    log('INFO: average cluster loss per batch: \t{0}'.format(total_cluster_cost / n_batches))
    log('INFO: separation loss:\t{0}'.format(total_separation_cost / n_batches))
    log('INFO: avged separation loss:\t{0}'.format(total_avg_separation_cost / n_batches))
    log('INFO: accu: \t{0}%'.format(n_correct / n_examples * 100))
    
    return n_correct / n_examples


def tnt_train(model, dataloader, optimizer, class_specific, coefs, log):
    """Put `model` in training mode and run one optimizing pass over `dataloader`.

    Returns the accuracy reported by `_train_or_test`.
    """
    assert(optimizer is not None)
    model.train()
    accuracy = _train_or_test(model=model,
                              dataloader=dataloader,
                              optimizer=optimizer,
                              class_specific=class_specific,
                              coefs=coefs,
                              log=log)
    return accuracy


def tnt_test(model, dataloader, class_specific, log):
    """Put `model` in eval mode and run one evaluation pass over `dataloader`.

    No optimizer and no loss coefficients are passed on, so `_train_or_test`
    only measures. Returns the accuracy it reports.
    """
    model.eval()
    accuracy = _train_or_test(model=model,
                              dataloader=dataloader,
                              optimizer=None,
                              class_specific=class_specific,
                              coefs=None,
                              log=log)
    return accuracy


def tnt_last_only(model, log=print):
    """Make only the last linear layer of `model` trainable.

    Freezes `features`, `add_on_layers` and `prototype_vectors`, and enables
    gradients for `h_last_layer`. The flags are set on the `model` argument
    (previously a global network was modified regardless of the argument).
    `log` is accepted for interface compatibility and is not used.
    """
    for p in model.features.parameters():
        p.requires_grad = False
    for p in model.add_on_layers.parameters():
        p.requires_grad = False
    model.prototype_vectors.requires_grad = False
    for p in model.h_last_layer.parameters():
        p.requires_grad = True

def tnt_warm_only(model, log=print):
    """Freeze the backbone of `model` and make everything else trainable.

    Disables gradients for `features`; enables them for `add_on_layers`,
    `prototype_vectors` and `h_last_layer`. The flags are set on the `model`
    argument (previously a global network was modified regardless of the
    argument). `log` is accepted for interface compatibility and is not used.
    """
    for p in model.features.parameters():
        p.requires_grad = False
    for p in model.add_on_layers.parameters():
        p.requires_grad = True
    model.prototype_vectors.requires_grad = True
    for p in model.h_last_layer.parameters():
        p.requires_grad = True
    

def tnt_joint(model, log=print):
    """Make every part of `model` trainable.

    Enables gradients for `features`, `add_on_layers`, `prototype_vectors`
    and `h_last_layer`. The flags are set on the `model` argument (previously
    a global network was modified regardless of the argument). `log` is
    accepted for interface compatibility and is not used.
    """
    for p in model.features.parameters():
        p.requires_grad = True
    for p in model.add_on_layers.parameters():
        p.requires_grad = True
    model.prototype_vectors.requires_grad = True
    for p in model.h_last_layer.parameters():
        p.requires_grad = True
    

In [11]:
def compute_rf_prototype(img_size, prototype_patch_index, protoL_rf_info):
    """Map a prototype patch index to its receptive-field box in the image.

    Args:
        img_size: image side length used to clip the box.
        prototype_patch_index: sequence ``(image index, height index, width index)``.
        protoL_rf_info: receptive-field information of the prototype layer.

    Returns:
        ``[image index, height start, height end, width start, width end]``.
    """
    img_index, row, col = (prototype_patch_index[k] for k in range(3))
    box = compute_rf_protoL_at_spatial_location(img_size, row, col, protoL_rf_info)
    return [img_index, box[0], box[1], box[2], box[3]]

def compute_rf_protoL_at_spatial_location(img_size, height_index, width_index, protoL_rf_info):
    """Return the image-space receptive field of one prototype-layer position.

    `protoL_rf_info` is read as ``[n, j, r, start]``: layer size, jump,
    receptive-field size and center of the first receptive field. The box is
    centered at ``start + index * j`` along each axis, spans ``r / 2`` to
    either side, and is clipped to ``[0, img_size]``.

    Returns:
        ``[height start, height end, width start, width end]`` as ints.
    """
    n, jump, rf_size, first_center = (protoL_rf_info[k] for k in range(4))
    assert(height_index < n)
    assert(width_index < n)

    half_extent = rf_size / 2

    def clipped_span(index):
        center = first_center + (index * jump)
        return max(int(center - half_extent), 0), min(int(center + half_extent), img_size)

    top, bottom = clipped_span(height_index)
    left, right = clipped_span(width_index)
    return [top, bottom, left, right]


In [12]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import copy
import time

# from receptive_field import compute_rf_prototype
# from helpers import makedir, find_high_activation_crop

def preprocess(x, mean, std):
    """Return a normalized copy of `x`; the input tensor is left untouched.

    `x` must have 3 entries along dimension 1. Channel ``c`` of the result is
    ``(x[:, c] - mean[c]) / std[c]``.
    """
    assert x.size(1) == 3
    normalized = torch.zeros_like(x)
    for channel in (0, 1, 2):
        shifted = x[:, channel, :, :] - mean[channel]
        normalized[:, channel, :, :] = shifted / std[channel]
    return normalized


def preprocess_input_function(x):
    """Normalize `x` into a new tensor using the notebook's mean/std settings.

    Thin wrapper around `preprocess` with the module-level `preprocess_mean`
    and `preprocess_std`.
    """
    normalized = preprocess(x, mean=preprocess_mean, std=preprocess_std)
    return normalized

def find_high_activation_crop(activation_map, percentile=95):
    """Bounding box of the entries at or above a percentile of the map.

    Entries below ``np.percentile(activation_map, percentile)`` are masked
    out; the box spans from the first to the last row and column that still
    contain an unmasked entry.

    Returns:
        ``(lower_y, upper_y + 1, lower_x, upper_x + 1)``, i.e. the end
        indices are exclusive.
    """
    cutoff = np.percentile(activation_map, percentile)
    keep = np.ones(activation_map.shape)
    keep[activation_map < cutoff] = 0

    def row_is_active(i):
        return np.amax(keep[i]) > 0.5

    def col_is_active(j):
        return np.amax(keep[:, j]) > 0.5

    row_ids = range(keep.shape[0])
    # each bound falls back to 0 when no row/column qualifies
    lower_y = next((i for i in row_ids if row_is_active(i)), 0)
    upper_y = next((i for i in reversed(row_ids) if row_is_active(i)), 0)

    col_ids = range(keep.shape[1])
    lower_x = next((j for j in col_ids if col_is_active(j)), 0)
    upper_x = next((j for j in reversed(col_ids) if col_is_active(j)), 0)

    return lower_y, upper_y+1, lower_x, upper_x+1

# push each prototype to the nearest patch in the training set
def push_prototypes(dataloader, # pytorch dataloader (must be unnormalized in [0,1])
                    prototype_network, # pytorch network with prototype_vectors
                    class_specific=True,
                    preprocess_input_function=None, # normalize if needed
                    prototype_layer_stride=1,
                    root_dir_for_saving_prototypes=None, # if not None, prototypes will be saved here
                    epoch_number=None, # if not provided, prototypes saved previously will be overwritten
                    prototype_img_filename_prefix=None,
                    prototype_self_act_filename_prefix=None,
                    proto_bound_boxes_filename_prefix=None,
                    save_prototype_class_identity=True, # which class the prototype image comes from
                    log=print,
                    prototype_activation_function_in_numpy=None):
    """Replace every prototype vector with its nearest feature-map patch.

    Iterates over `dataloader`, lets `update_prototypes_on_batch` track, per
    prototype, the closest patch seen so far, and finally copies those
    patches into `prototype_network.prototype_vectors`.

    When `root_dir_for_saving_prototypes` is given, an 'epoch-<epoch_number>'
    sub-directory is created and the bounding-box arrays are saved there,
    provided `proto_bound_boxes_filename_prefix` is given too. When it is
    None nothing is written by this function.

    `save_prototype_class_identity` and `log` are accepted for interface
    compatibility and are not read here.
    """
    prototype_network.eval()

    prototype_shape = prototype_network.prototype_shape
    n_prototypes = prototype_network.num_prototypes
    # saves the closest distance seen so far
    global_min_proto_dist = np.full(n_prototypes, np.inf)
    # saves the patch representation that gives the current smallest distance
    global_min_fmap_patches = np.zeros(
        [n_prototypes,
         prototype_shape[1],
         prototype_shape[2],
         prototype_shape[3]])

    # proto_rf_boxes and proto_bound_boxes columns:
    # 0: image index in the entire dataset
    # 1: height start index
    # 2: height end index
    # 3: width start index
    # 4: width end index
    # 5: (optional) class identity
    proto_rf_boxes = np.full(shape=[n_prototypes, 6], fill_value=-1)
    proto_bound_boxes = np.full(shape=[n_prototypes, 6], fill_value=-1)

    if root_dir_for_saving_prototypes is None:
        # the documented default: previously os.path.join(None, ...) raised
        # a TypeError here
        # NOTE(review): update_prototypes_on_batch also defaults
        # dir_for_saving_prototypes to None; confirm it skips saving then
        proto_epoch_dir = None
    else:
        proto_epoch_dir = os.path.join(root_dir_for_saving_prototypes, 'epoch-'+str(epoch_number))
        makedir(proto_epoch_dir)

    search_batch_size = dataloader.batch_size
    num_classes = prototype_network.num_classes

    for push_iter, (search_batch_input, search_y) in enumerate(dataloader):
        # index of the batch's first image within the whole dataset
        start_index_of_search_batch = push_iter * search_batch_size
        update_prototypes_on_batch(search_batch_input,
                                   start_index_of_search_batch,
                                   prototype_network,
                                   global_min_proto_dist,
                                   global_min_fmap_patches,
                                   proto_rf_boxes,
                                   proto_bound_boxes,
                                   class_specific=class_specific,
                                   search_y=search_y,
                                   num_classes=num_classes,
                                   preprocess_input_function=preprocess_input_function,
                                   prototype_layer_stride=prototype_layer_stride,
                                   dir_for_saving_prototypes=proto_epoch_dir,
                                   prototype_img_filename_prefix=prototype_img_filename_prefix,
                                   prototype_self_act_filename_prefix=prototype_self_act_filename_prefix,
                                   prototype_activation_function_in_numpy=prototype_activation_function_in_numpy)

    # the box arrays can only be written with both a directory and a prefix
    if proto_epoch_dir is not None and proto_bound_boxes_filename_prefix is not None:
        np.save(os.path.join(proto_epoch_dir, proto_bound_boxes_filename_prefix + '-receptive_field' + str(epoch_number) + '.npy'), proto_rf_boxes)
        np.save(os.path.join(proto_epoch_dir, proto_bound_boxes_filename_prefix + str(epoch_number) + '.npy'), proto_bound_boxes)

    prototype_update = np.reshape(global_min_fmap_patches, tuple(prototype_shape))
    prototype_network.prototype_vectors.data.copy_(torch.tensor(prototype_update, dtype=torch.float32))

# update each prototype for current search batch
def update_prototypes_on_batch(search_batch_input,
                               start_index_of_search_batch,
                               prototype_network,
                               global_min_proto_dist, # this will be updated
                               global_min_fmap_patches, # this will be updated
                               proto_rf_boxes, # this will be updated
                               proto_bound_boxes, # this will be updated
                               class_specific=True,
                               search_y=None, # required if class_specific == True
                               num_classes=None, # required if class_specific == True
                               preprocess_input_function=None,
                               prototype_layer_stride=1,
                               dir_for_saving_prototypes=None,
                               prototype_img_filename_prefix=None,
                               prototype_self_act_filename_prefix=None,
                               prototype_activation_function_in_numpy=None):
    """Update the running best match of every prototype using one search batch.

    For each prototype j, the smallest distance between the prototype and any
    feature-map patch of the batch is looked up. When that distance is lower
    than global_min_proto_dist[j], the four "this will be updated" arrays are
    overwritten in place for row j.

    Args:
        search_batch_input: batch of images, indexed by image along dim 0.
            NOTE(review): assumed to be a (N, C, H, W) tensor with H == W,
            because the activation map is resized to a square of side H -
            confirm against the dataloader.
        start_index_of_search_batch: offset added to the in-batch image index
            before it is stored in column 0 of the box arrays.
        prototype_network: object providing eval(), push_forward(),
            prototype_shape, prototype_class_identity, proto_layer_rf_info,
            prototype_activation_function and epsilon.
        global_min_proto_dist: per-prototype running minimum distance.
        global_min_fmap_patches: per-prototype feature-map patch that achieved
            the running minimum.
        proto_rf_boxes: row j receives the image index followed by elements
            1-4 of compute_rf_prototype's result (and the label in column 5
            when the array has 6 columns and search_y is given).
        proto_bound_boxes: row j receives the image index followed by the four
            values returned by find_high_activation_crop (and the label in
            column 5 under the same condition).
        class_specific: when True, prototype j is only compared against images
            whose label equals argmax(prototype_class_identity[j]); when False
            every image in the batch is searched.
        search_y: per-image integer labels; required if class_specific.
        num_classes: number of classes; required if class_specific.
        preprocess_input_function: optional callable applied to
            search_batch_input before the forward pass; None means the batch
            is fed to the network unchanged.
        prototype_layer_stride: multiplier turning a distance-map position
            into a feature-map position.
        dir_for_saving_prototypes, prototype_img_filename_prefix,
        prototype_self_act_filename_prefix: accepted for interface
            compatibility; this function does not read them and writes no
            files.
        prototype_activation_function_in_numpy: callable used to turn
            distances into activations when the network's activation function
            is neither 'log' nor 'linear'.

    Raises:
        ValueError: if class_specific is True but search_y or num_classes is
            missing.
    """
    prototype_network.eval()

    # preprocess_input_function defaults to None, so it must not be called blindly.
    if preprocess_input_function is not None:
        search_batch = preprocess_input_function(search_batch_input)
    else:
        search_batch = search_batch_input

    with torch.no_grad():
        protoL_input_tmp, proto_dist_tmp = prototype_network.push_forward(search_batch)

    # detach().cpu() is a no-op for CPU tensors and makes the conversion safe
    # when the network produced its outputs on another device.
    protoL_input_ = protoL_input_tmp.detach().cpu().numpy()
    proto_dist_ = proto_dist_tmp.detach().cpu().numpy()

    class_to_img_index_dict = None
    if class_specific:
        if search_y is None or num_classes is None:
            raise ValueError('search_y and num_classes are required when class_specific is True')
        class_to_img_index_dict = {key: [] for key in range(num_classes)}
        # img_y is the image's integer label
        for img_index, img_y in enumerate(search_y):
            class_to_img_index_dict[img_y.item()].append(img_index)

    prototype_shape = prototype_network.prototype_shape
    n_prototypes = prototype_shape[0]
    proto_h = prototype_shape[2]
    proto_w = prototype_shape[3]
    max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3]

    for j in range(n_prototypes):
        if class_specific:
            # target_class is the class that prototype j is assigned to
            target_class = torch.argmax(prototype_network.prototype_class_identity[j]).item()
            target_img_indices = class_to_img_index_dict[target_class]
            # no image of the target class in this batch: nothing to compare against
            if len(target_img_indices) == 0:
                continue
            proto_dist_j = proto_dist_[target_img_indices][:, j, :, :]
        else:
            # not class specific: search through every image of the batch
            target_img_indices = None
            proto_dist_j = proto_dist_[:, j, :, :]

        batch_min_proto_dist_j = np.amin(proto_dist_j)
        if batch_min_proto_dist_j < global_min_proto_dist[j]:
            batch_argmin_proto_dist_j = list(np.unravel_index(np.argmin(proto_dist_j, axis=None),
                                                              proto_dist_j.shape))

            if class_specific:
                # change the argmin index from the index among images of the
                # target class to the index in the entire search batch
                batch_argmin_proto_dist_j[0] = target_img_indices[batch_argmin_proto_dist_j[0]]

            # retrieve the corresponding feature map patch
            img_index_in_batch = batch_argmin_proto_dist_j[0]
            fmap_height_start_index = batch_argmin_proto_dist_j[1] * prototype_layer_stride
            fmap_height_end_index = fmap_height_start_index + proto_h
            fmap_width_start_index = batch_argmin_proto_dist_j[2] * prototype_layer_stride
            fmap_width_end_index = fmap_width_start_index + proto_w

            batch_min_fmap_patch_j = protoL_input_[img_index_in_batch, :,
                                                   fmap_height_start_index:fmap_height_end_index,
                                                   fmap_width_start_index:fmap_width_end_index]

            global_min_proto_dist[j] = batch_min_proto_dist_j
            global_min_fmap_patches[j] = batch_min_fmap_patch_j

            # get the receptive field boundary of the image patch that generates the representation
            protoL_rf_info = prototype_network.proto_layer_rf_info
            rf_prototype_j = compute_rf_prototype(search_batch.size(2), batch_argmin_proto_dist_j, protoL_rf_info)

            # side length used for upsampling the activation map: dim 1 of the
            # selected image (its dim 0 after the former (1, 2, 0) transpose)
            original_img_size = search_batch_input[rf_prototype_j[0]].shape[1]

            # save the prototype receptive field information
            proto_rf_boxes[j, 0] = rf_prototype_j[0] + start_index_of_search_batch
            proto_rf_boxes[j, 1] = rf_prototype_j[1]
            proto_rf_boxes[j, 2] = rf_prototype_j[2]
            proto_rf_boxes[j, 3] = rf_prototype_j[3]
            proto_rf_boxes[j, 4] = rf_prototype_j[4]
            if proto_rf_boxes.shape[1] == 6 and search_y is not None:
                proto_rf_boxes[j, 5] = search_y[rf_prototype_j[0]].item()

            # find the highly activated region of the original image
            proto_dist_img_j = proto_dist_[img_index_in_batch, j, :, :]
            if prototype_network.prototype_activation_function == 'log':
                proto_act_img_j = np.log((proto_dist_img_j + 1) / (proto_dist_img_j + prototype_network.epsilon))
            elif prototype_network.prototype_activation_function == 'linear':
                proto_act_img_j = max_dist - proto_dist_img_j
            else:
                proto_act_img_j = prototype_activation_function_in_numpy(proto_dist_img_j)
            upsampled_act_img_j = cv2.resize(proto_act_img_j, dsize=(original_img_size, original_img_size),
                                             interpolation=cv2.INTER_CUBIC)
            proto_bound_j = find_high_activation_crop(upsampled_act_img_j)

            # save the prototype boundary (rectangular boundary of highly activated region)
            proto_bound_boxes[j, 0] = proto_rf_boxes[j, 0]
            proto_bound_boxes[j, 1] = proto_bound_j[0]
            proto_bound_boxes[j, 2] = proto_bound_j[1]
            proto_bound_boxes[j, 3] = proto_bound_j[2]
            proto_bound_boxes[j, 4] = proto_bound_j[3]
            if proto_bound_boxes.shape[1] == 6 and search_y is not None:
                proto_bound_boxes[j, 5] = search_y[rf_prototype_j[0]].item()


In [13]:
# train the model
# Schedule: warm-up epochs train with warm_optimizer, later epochs train jointly
# with joint_optimizer. On every 10th epoch from push_start onwards the
# prototypes are pushed and, unless the activation is 'linear', 20 extra
# last-layer-only epochs follow.
log('INFO: start training the model ......')
num_train_epochs = 1000
num_warm_epochs = 5

# a push happens when epoch >= push_start and epoch is in push_epochs
push_start = 10
push_epochs = [i for i in range(num_train_epochs) if i % 10 == 0]

# coefficients handed to tnt_train through its coefs argument
coefs = {'crs_ent': 1, 'clst': 0.8, 'sep': -0.08, 'l1': 1e-4}
class_specific = True

for epoch in range(num_train_epochs):
    log('INFO:epoch: \t{0}'.format(epoch))

    if epoch < num_warm_epochs:
        tnt_warm_only(model=ppnet, log=log)
        _ = tnt_train(model=ppnet, dataloader=train_loader, optimizer=warm_optimizer,
                      class_specific=class_specific, coefs=coefs, log=log)
    else:
        tnt_joint(model=ppnet, log=log)
        _ = tnt_train(model=ppnet, dataloader=train_loader, optimizer=joint_optimizer,
                      class_specific=class_specific, coefs=coefs, log=log)
        # Step the scheduler after the epoch's optimizer updates: since
        # PyTorch 1.1.0, stepping it first skips the first value of the
        # learning-rate schedule and triggers a UserWarning.
        # NOTE(review): assumes joint_lr_scheduler is a torch.optim.lr_scheduler
        # object wrapping joint_optimizer - confirm where it is created.
        joint_lr_scheduler.step()

    accu = tnt_test(model=ppnet, dataloader=test_loader,
                    class_specific=class_specific, log=log)

    if epoch >= push_start and epoch in push_epochs:
        push_prototypes(
            train_push_loader, # pytorch dataloader (must be unnormalized in [0,1])
            prototype_network=ppnet, # pytorch network with prototype_vectors
            class_specific=class_specific,
            preprocess_input_function=preprocess_input_function, # normalize if needed
            prototype_layer_stride=1,
            root_dir_for_saving_prototypes=img_dir, # if not None, prototypes will be saved here
            epoch_number=epoch, # if not provided, prototypes saved previously will be overwritten
            prototype_img_filename_prefix=prototype_img_filename_prefix,
            prototype_self_act_filename_prefix=prototype_self_act_filename_prefix,
            proto_bound_boxes_filename_prefix=proto_bound_boxes_filename_prefix,
            save_prototype_class_identity=True,
            log=log)
        # re-evaluate right after the push, before the last layer is re-tuned
        accu = tnt_test(model=ppnet, dataloader=test_loader,
                        class_specific=class_specific, log=log)

        if prototype_activation_function != 'linear':
            # train only the last layer for 20 epochs, testing after each one
            tnt_last_only(model=ppnet, log=log)
            for i in range(20):
                _ = tnt_train(model=ppnet, dataloader=train_loader, optimizer=h_last_layer_optimizer,
                              class_specific=class_specific, coefs=coefs, log=log)
                accu = tnt_test(model=ppnet, dataloader=test_loader,
                                class_specific=class_specific, log=log)

# close the log file opened by create_logger
logclose()

INFO: start training the model ......
INFO:epoch: 	0
INFO: training with the images batch#0
INFO: training with the images batch#1
INFO: training with the images batch#2
INFO: training with the images batch#3
INFO: training with the images batch#4
INFO: training with the images batch#5
INFO: training with the images batch#6
INFO: training with the images batch#7
INFO: training with the images batch#8
INFO: training with the images batch#9
INFO: training with the images batch#10
INFO: training with the images batch#11
INFO: training time: 	2.9157440662384033
INFO: average cross entropy per batch: 	0.8079073280096054
INFO: average cluster loss per batch: 	0.8237993468840917
INFO: separation loss:	0.8235770414272944
INFO: avged separation loss:	2.5115922689437866
INFO: accu: 	55.00000000000001%
INFO: training with the images batch#0
INFO: training with the images batch#1
INFO: training with the images batch#2
INFO: training with the images batch#3
INFO: training with the images batch#4
IN



INFO: training with the images batch#1
INFO: training with the images batch#2
INFO: training with the images batch#3
INFO: training with the images batch#4
INFO: training with the images batch#5
INFO: training with the images batch#6
INFO: training with the images batch#7
INFO: training with the images batch#8
INFO: training with the images batch#9
INFO: training with the images batch#10
INFO: training with the images batch#11
INFO: training time: 	7.823610782623291
INFO: average cross entropy per batch: 	0.9311062668760618
INFO: average cluster loss per batch: 	0.652057888607184
INFO: separation loss:	0.7447778185208639
INFO: avged separation loss:	2.4589456717173257
INFO: accu: 	53.333333333333336%
INFO: training with the images batch#0
INFO: training with the images batch#1
INFO: training with the images batch#2
INFO: training with the images batch#3
INFO: training with the images batch#4
INFO: training with the images batch#5
INFO: training with the images batch#6
INFO: training wi

KeyboardInterrupt: 