# Interpretable image classifier using deep learning

In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
import numpy as np

In [77]:
class FeatureModel(nn.Module):
    '''
    Small CNN that maps an image patch (3, Q, Q) to a feature vector (V).

    Three unpadded 3x3 convolutions (each followed by BatchNorm and ReLU)
    shrink the spatial side by 2 each, so a Q x Q patch becomes a
    (Q - 6) x (Q - 6) map with 128 channels before the final linear layer.
    '''

    def __init__(self, patch_size=10, num_features=100):
        '''
        Define model that maps (3, Q, Q) to (V)
        Args
            patch_size - (Q) side of the square input patch; the default of 10
                         yields the original 128*4*4 flattened size
            num_features - (V) length of the returned vector
        Raises
            ValueError if patch_size is too small to survive the three
            unpadded 3x3 convolutions
        '''
        super(FeatureModel, self).__init__()
        # Each unpadded 3x3 convolution removes 2 pixels per spatial side.
        feature_side = patch_size - 3 * (3 - 1)
        if feature_side < 1:
            raise ValueError(
                'patch_size must be at least 7, got {}'.format(patch_size))
        self.flat_features = 128 * feature_side * feature_side

        self.l1 = self.get_conv_layer((3, 64, 3))
        self.l2 = self.get_conv_layer((64, 64, 3))
        self.l3 = self.get_conv_layer((64, 128, 3))
        self.l = nn.Sequential(self.l1, self.l2, self.l3)
        self.fc1 = nn.Linear(self.flat_features, num_features)

    def get_conv_layer(self, param):
        '''
        Args
            param - (in_channels, out_channels, kernel_size) forwarded to nn.Conv2d
        Returns
            Conv2d -> BatchNorm2d -> ReLU block
        '''
        return nn.Sequential(nn.Conv2d(*param), nn.BatchNorm2d(param[1]), nn.ReLU())

    def forward(self, x):
        '''
        Args
            patch - (N, 3, Q, Q)
        Returns
            representation - (N, V)
        Convolutional neural network that maps patch to a vector
        '''
        x1 = self.l(x)
        # Flatten per sample rather than view(-1, ...): a patch of the wrong
        # size then fails in the linear layer instead of silently being
        # reinterpreted as extra batch rows.
        return self.fc1(torch.flatten(x1, start_dim=1))

# Smoke test: a single 10x10 RGB patch should come out as one 100-dim vector.
model = FeatureModel()
sample_patch = torch.randn(1, 3, 10, 10)
model(sample_patch).shape


torch.Size([1, 100])

In [84]:
class AdhocNet(nn.Module):
    '''
    Patch-wise image classifier.

    Every image is cut into non-overlapping Q x Q patches, each patch is
    scored on its own, and the per-patch class scores are summed into the
    image-level class scores. The per-patch scores are returned as well so
    the contribution of each image region can be inspected.
    '''

    def __init__(self, patch_size):
        '''
        Args
            patch_size - (Q) side of the square patches
        '''
        super(AdhocNet, self).__init__()
        self.patch_size = patch_size
        # NOTE: FeatureModel is built without arguments; its linear layer is
        # sized for 10x10 patches, so other patch sizes need a matching
        # feature model.
        self.feature_model = FeatureModel()
        self.predictor = nn.Linear(100, 10)

    def img_to_patch(self, img, patch_size):
        '''
        Args
            img - (3, W, H)
            patch_size - (Q)
        Returns
            Patches (W/Q, H/Q, 3, Q, Q), where patches[i, j] is the
            contiguous tile img[:, i*Q:(i+1)*Q, j*Q:(j+1)*Q]
        Raises
            ValueError if Q is not positive or W / H is not a multiple of Q
        '''
        _, width, height = img.shape
        if patch_size <= 0 or width % patch_size or height % patch_size:
            raise ValueError(
                'image of size {}x{} cannot be tiled by patches of size {}'.format(
                    width, height, patch_size))
        # unfold(dim, size, step) slides a window along `dim` and appends the
        # window contents as a new trailing axis: (3, W/Q, H/Q, Q, Q).
        tiles = img.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
        # Move the channel axis behind the two patch-grid axes.
        return tiles.permute(1, 2, 0, 3, 4)

    def patches_to_representations(self, patches):
        '''
        Args
            patches - (W/Q, H/Q, 3, Q, Q)
        Returns
            scores - (W/Q, H/Q, C) output of the predictor for every patch
        '''
        rows, cols = patches.shape[:2]
        # Run all patches as a single batch instead of one forward pass per
        # patch. Batch-dependent layers (BatchNorm in training mode) therefore
        # see the patches of one image together rather than one at a time.
        flat_patches = patches.reshape(rows * cols, *patches.shape[2:])
        scores = self.predictor(self.feature_model(flat_patches))
        return scores.reshape(rows, cols, scores.shape[-1])

    def forward(self, x):
        '''
        Args
            x - (B, 3, W, H)
        Returns
            patch scores (B, W/Q, H/Q, C) and image scores (B, C); the image
            scores are the patch scores summed over the patch grid
        '''
        batch_representations = torch.stack([
            self.patches_to_representations(self.img_to_patch(img, self.patch_size))
            for img in x
        ])
        return batch_representations, batch_representations.sum(dim=[1, 2])
        

# Smoke test: one 200x200 RGB image with 10-pixel patches should give a
# (1, 20, 20, 10) patch-score tensor and (1, 10) image-level scores.
model = AdhocNet(patch_size=10)
sample_batch = torch.randn(1, 3, 200, 200)
rep, scores = model(sample_batch)
(rep.shape, scores.shape)

(torch.Size([1, 20, 20, 10]), torch.Size([1, 10]))