In [None]:
import gc
from PIL import Image
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms
import numpy as np
import matplotlib.pyplot as plt
import os
import pandas as pd
plt.rcParams.update({"figure.figsize": (15, 10)})  # default size (inches) for every figure below
#from layer import KernelConv2d, GaussianKernel, PolynomialKernel
from functools import partial # To invoke Kernel objects with input parameters when creating KernelConv2d object (e.g. partial(GaussianKernel, 0.05) for Gaussian OR partial(PolynomialKernel,2,3) for Polynomial)
%matplotlib inline
def mkdirs(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    ``exist_ok=True`` makes the call idempotent and avoids the race between a
    separate existence check and the creation.
    """
    os.makedirs(path, exist_ok=True)
        

In [None]:
# Seed torch's global RNG and disable cuDNN autotuning / non-deterministic
# kernels to make repeated runs more reproducible.
seed = 17
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True



In [None]:
class lr_find():
    """Learning-rate range test.

    Starting from ``min_lr``, the learning rate of ``optimizer.param_groups[0]``
    is increased linearly towards ``max_lr`` over ``num_iter`` steps. Each step
    trains on one batch and records the learning rate used and the resulting
    loss in ``self.history`` (keys ``"lr"``, ``"iterations"``, ``"loss"``).

    Args:
        optimizer: optimizer whose first parameter group's ``lr`` is swept.
        model: the ``nn.Module`` being trained.
        criterion: loss called as ``criterion(outputs, labels)``.
        train_loader: loader yielding ``(inputs, labels, ...)`` batches; it is
            restarted automatically when exhausted.
        val_loader: optional loader. When given, the recorded loss is the mean
            validation loss measured after each training step; otherwise it is
            the training-batch loss.
        min_lr, max_lr: bounds of the sweep.
        num_iter: number of sweep steps.
        steps_per_epoch, epochs: accepted for API compatibility; unused.
        device: device batches are moved to; defaults to the model's device.
    """

    def __init__(self, optimizer, model, criterion, train_loader, val_loader=None, min_lr=1e-6, max_lr=1e-1,
                 num_iter=100, steps_per_epoch=None, epochs=None, device=None):
        super().__init__()

        self.min_lr = min_lr
        self.max_lr = max_lr
        self.total_iterations = num_iter
        self.history = {}
        self.iteration = 0
        self.optimizer = optimizer
        self.model = model
        # Restarts the train loader transparently when it runs out of batches.
        self.iter_wrapper = DataLoaderIterWrapper(train_loader)
        self.val_loader = val_loader
        self.criterion = criterion

        self.model_device = next(self.model.parameters()).device
        self.device = device if device else self.model_device

        if min_lr:
            self.set_min_lr()

    def train_batch(self, iter_wrapper, acc_steps=1):
        """Take one optimizer step, accumulating gradients over ``acc_steps`` batches.

        Each batch loss is divided by ``acc_steps`` before ``backward``.

        Returns:
            The sum of those scaled losses, as a Python float.
        """
        self.model.train()
        self.optimizer.zero_grad()

        total_loss = 0.0
        for _ in range(acc_steps):
            inputs, labels = next(iter_wrapper)
            # Use the device chosen in __init__, not a notebook-level global.
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            loss = self.criterion(self.model(inputs), labels) / acc_steps
            loss.backward()
            # Accumulate a float so no autograd graph is kept alive between batches.
            total_loss += loss.item()
        self.optimizer.step()

        return total_loss

    def _validate(self, data_loader):
        """Return the mean loss per sample over ``data_loader`` (no gradients)."""
        running_loss = 0
        self.model.eval()
        with torch.no_grad():
            for inputs, labels in data_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                batch_size = inputs.size(0)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                running_loss += loss.item() * batch_size

        return running_loss / len(data_loader.dataset)

    def current_lr(self):
        """Linear schedule: ``min_lr`` at iteration 0, ``max_lr`` at ``total_iterations``."""
        x = self.iteration / self.total_iterations
        lr = self.min_lr + (self.max_lr - self.min_lr) * x
        return lr

    def set_min_lr(self, logs=None):
        """Set the first parameter group's learning rate to ``min_lr``."""
        self.optimizer.param_groups[0]["lr"] = self.min_lr

    def run_test(self):
        """Run the sweep, filling ``self.history``."""
        try:
            from tqdm.auto import tqdm  # optional progress bar
        except ImportError:
            def tqdm(iterable):
                return iterable

        for _ in tqdm(range(self.total_iterations)):
            self.iteration += 1

            self.history.setdefault("lr", []).append(self.optimizer.param_groups[0]["lr"])
            self.history.setdefault("iterations", []).append(self.iteration)

            # Always train at the current LR. Validation alone would leave the
            # model unchanged and produce a flat loss curve.
            train_loss = self.train_batch(self.iter_wrapper)
            if self.val_loader:
                loss = self._validate(self.val_loader)
            else:
                loss = train_loss
            self.history.setdefault("loss", []).append(loss)

            self.optimizer.param_groups[0]["lr"] = self.current_lr()

    def plot_lr(self):
        '''Helper function to quickly inspect the learning rate schedule.'''
        plt.plot(self.history['iterations'], self.history['lr'])
        plt.yscale('log')
        plt.xlabel('Iteration')
        plt.ylabel('Learning rate')
        plt.show()

    def plot_loss(self):
        '''Helper function to quickly observe the learning rate experiment results.'''
        plt.plot(self.history['lr'], self.history['loss'])
        plt.xscale('log')
        plt.xlabel('Learning rate')
        plt.ylabel('Loss')
        plt.show()
        
   
class DataLoaderIterWrapper(object):
    """Iterator over a ``torch.utils.data.DataLoader`` (or any re-iterable) that can restart itself.

    Each ``next()`` returns the first two elements ``(inputs, labels)`` of the
    next batch; any extra elements are discarded. When the underlying iterator
    is exhausted, a fresh one is created if ``auto_reset`` is true (so iteration
    never ends); otherwise ``StopIteration`` propagates.
    """

    def __init__(self, data_loader, auto_reset=True):
        self.data_loader = data_loader
        self.auto_reset = auto_reset
        self._iterator = iter(data_loader)

    def __iter__(self):
        # Complete the iterator protocol so the wrapper also works in for-loops.
        return self

    def __next__(self):
        try:
            inputs, labels, *_ = next(self._iterator)
        except StopIteration:
            if not self.auto_reset:
                raise
            self._iterator = iter(self.data_loader)
            # An empty loader raises StopIteration again here.
            inputs, labels, *_ = next(self._iterator)

        return inputs, labels
        

In [None]:
class Kernel(nn.Module):
    """Convolution-like layer whose response is a kernel function (or a mix of two).

    The filters live in an ordinary ``nn.Conv2d`` (``self.conv1``). ``kernel_fn``
    selects the response, and every variant is followed by a ReLU:

        0  polynomial              (conv(x) + c) ** degree
        1  gaussian                exp(-gamma * ||patch - filter||^2) + bias
        2  sigmoid                 tanh(conv(x))
        3  polynomial + conv       s * clamp(poly, 1e-10, 1e10) + (1 - s) * conv
        4  sigmoid + conv          s * tanh(conv) + (1 - s) * conv
        5  gaussian + conv         s * gaussian + (1 - s) * conv
        6  gaussian + polynomial   s * gaussian + (1 - s) * poly
        7  gaussian + sigmoid      s * gaussian + (1 - s) * tanh(conv)
        8  polynomial + sigmoid    s * poly + (1 - s) * tanh(conv)

    where ``s = sigmoid(rho)``. ``c``, ``degree`` and ``gamma`` are stored as
    frozen parameters; ``rho`` (kernel_fn >= 3) is trainable.

    Args:
        in_channel, out_channel, kernelsize, padding: forwarded to ``nn.Conv2d``.
        kernel_fn: integer 0..8 selecting the variant above.
        c, degree: polynomial offset and exponent.
        gamma: Gaussian width.
        rhok: initial value of the mixing logit ``rho``.
    """

    def __init__(self, in_channel, out_channel, kernelsize, kernel_fn, c=1.0, degree=5, gamma=0.5, rhok=0.02, padding=0):
        super(Kernel, self).__init__()
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernelsize, padding=padding)

        if kernel_fn == 0:
            self.c = torch.nn.parameter.Parameter(torch.tensor(c), requires_grad=False)
            self.degree = torch.nn.parameter.Parameter(torch.tensor(degree), requires_grad=False)

        if kernel_fn == 1:
            self.gamma = torch.nn.parameter.Parameter(torch.tensor(gamma), requires_grad=False)

        if kernel_fn >= 3:
            self.rho = torch.nn.parameter.Parameter(torch.tensor(rhok, requires_grad=True))
            self.c = torch.nn.parameter.Parameter(torch.tensor(c), requires_grad=False)
            self.degree = torch.nn.parameter.Parameter(torch.tensor(degree), requires_grad=False)
            self.gamma = torch.nn.parameter.Parameter(torch.tensor(gamma), requires_grad=False)

        self.in_channel = in_channel
        self.out_channel = out_channel
        self.kernelsize = kernelsize
        self.kernel_fn = kernel_fn

    def __compute_shape(self, x):
        """Spatial size ``(h, w)`` of ``self.conv1``'s output for input ``x``."""
        kh, kw = self.conv1.kernel_size
        ph, pw = self.conv1.padding
        sh, sw = self.conv1.stride
        dh, dw = self.conv1.dilation
        h = (x.shape[2] + 2 * ph - dh * (kh - 1) - 1) // sh + 1
        w = (x.shape[3] + 2 * pw - dw * (kw - 1) - 1) // sw + 1
        return h, w

    def convolution(self, x):
        return self.conv1(x)

    def sigmoidkerv(self, x):
        return torch.tanh(self.convolution(x))

    def polynomial(self, x):
        return (self.convolution(x) + self.c) ** self.degree

    def gaussian(self, x):
        """Gaussian (RBF) response between every input patch and every filter.

        Returns a tensor of shape (batch, out_channels, h, w), laid out like
        ``self.conv1(x)``.
        """
        conv = self.conv1
        # (B, L, C*kh*kw): one flattened patch per output position, row-major over (h, w).
        patches = F.unfold(x, conv.kernel_size, conv.dilation, conv.padding, conv.stride).transpose(1, 2)
        h, w = self.__compute_shape(x)
        # (out, C*kh*kw): flattened in the same (C, kh, kw) order as the patches.
        filters = conv.weight.reshape(conv.out_channels, -1)
        diff = patches.unsqueeze(2) - filters.view(1, 1, conv.out_channels, -1)  # (B, L, out, D)
        sq_dist = torch.sum(diff ** 2, -1)                                       # (B, L, out)
        out = torch.exp(-self.gamma * sq_dist)
        if conv.bias is not None:
            out = out + conv.bias
        # (B, L, out) -> (B, out, L) -> (B, out, h, w)
        return out.transpose(1, 2).reshape(x.shape[0], conv.out_channels, h, w)

    def polyconv(self, x):
        conv = self.convolution(x)
        # Clamp the polynomial branch to keep it finite and positive.
        poly = torch.clamp(((conv + self.c) ** self.degree), min=1e-10, max=1e10)
        return torch.sigmoid(self.rho) * (poly) + (1 - torch.sigmoid(self.rho)) * conv

    def polysigm(self, x):
        conv = self.convolution(x)
        return torch.sigmoid(self.rho) * ((conv + self.c) ** self.degree) + (1 - torch.sigmoid(self.rho)) * torch.tanh(conv)

    def kernel_fn_a(self, x):
        # Polynomial
        return F.relu(self.polynomial(x))

    def kernel_fn_b(self, x):
        # Gaussian
        return F.relu(self.gaussian(x))

    def kernel_fn_c(self, x):
        # Sigmoid
        return F.relu(self.sigmoidkerv(x))

    def kernel_fn_d(self, x):
        # Polynomial + Convolution
        return F.relu(self.polyconv(x))

    def kernel_fn_e(self, x):
        # Sigmoid + Convolution
        conv = self.convolution(x)
        return F.relu(torch.sigmoid(self.rho) * torch.tanh(conv) + (1 - torch.sigmoid(self.rho)) * conv)

    def kernel_fn_f(self, x):
        # Gaussian + Convolution
        return F.relu(torch.sigmoid(self.rho) * self.gaussian(x) + (1 - torch.sigmoid(self.rho)) * self.convolution(x))

    def kernel_fn_g(self, x):
        # Gaussian + Polynomial
        return F.relu(torch.sigmoid(self.rho) * self.gaussian(x) + (1 - torch.sigmoid(self.rho)) * self.polynomial(x))

    def kernel_fn_h(self, x):
        # Gaussian + Sigmoid
        return F.relu(torch.sigmoid(self.rho) * self.gaussian(x) + (1 - torch.sigmoid(self.rho)) * self.sigmoidkerv(x))

    def kernel_fn_i(self, x):
        # Polynomial + Sigmoid
        return F.relu(self.polysigm(x))

    def forward(self, x):
        """Dispatch to the variant selected by ``self.kernel_fn``."""
        json = {0: self.kernel_fn_a, 1: self.kernel_fn_b, 2: self.kernel_fn_c,
                3: self.kernel_fn_d, 4: self.kernel_fn_e, 5: self.kernel_fn_f,
                6: self.kernel_fn_g, 7: self.kernel_fn_h, 8: self.kernel_fn_i}
        return json[self.kernel_fn](x)
    
 


In [None]:
class ParallelKernelConv2d(torch.nn.Conv2d):
    """Conv2d whose output mixes a linear and a polynomial response of the same filters.

        out = s * (conv(x; a*W, a*b) + c) ** degree + (1 - s) * conv(x; W, b)
        s   = sigmoid(rho)

    ``W``/``b`` are this layer's own ``weight``/``bias``. The polynomial branch
    reuses them scaled by the trainable scalar ``a``. The scaling is applied
    inside ``forward`` so gradients reach ``a`` and the shared weights.
    ``rho`` is trainable; ``c`` and ``degree`` are frozen.

    Args:
        in_channels ... padding_mode: as for ``nn.Conv2d`` (``bias`` defaults to False).
            Both branches zero-pad; ``padding_mode`` is stored but not applied.
        c, degree: polynomial offset and exponent.
        gamma: accepted for signature compatibility; unused.
        rhok: initial mixing logit ``rho``.
        a: initial scale of the polynomial branch.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=False, padding_mode='zeros', c=1.0, degree=3, gamma=0.5, rhok=0.02, a=1.0):
        super(ParallelKernelConv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                                   padding, dilation, groups, bias, padding_mode)
        self.a = torch.nn.parameter.Parameter(torch.tensor(a, requires_grad=True))
        self.rho = torch.nn.parameter.Parameter(torch.tensor(rhok, requires_grad=True))
        self.c = torch.nn.parameter.Parameter(torch.tensor(c), requires_grad=False)
        self.degree = torch.nn.parameter.Parameter(torch.tensor(degree), requires_grad=False)

    def _zero_padded_conv(self, x, weight, bias):
        """Convolve with this layer's stride/padding/dilation/groups (zero padding)."""
        return F.conv2d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)

    def forward(self, x):
        conv = self._zero_padded_conv(x, self.weight, self.bias)
        scaled_bias = None if self.bias is None else self.a * self.bias
        poly = (self._zero_padded_conv(x, self.a * self.weight, scaled_bias) + self.c) ** self.degree
        mix = torch.sigmoid(self.rho)
        return mix * poly + (1 - mix) * conv
        
        

In [None]:
class _LeNet5CIFAR10Base(nn.Module):
    """LeNet-5 body shared by all CIFAR-10 variants below.

    Subclasses only choose the two feature extractors. ``conv1`` must map 3 -> 6
    channels and ``conv2`` 6 -> 16 channels with 5x5 filters, so a 3x32x32
    input flattens to 16*5*5 features after the two 2x2 max-pools.
    """

    def __init__(self, conv1, conv2):
        super().__init__()
        # conv1/conv2 are built by the caller before fc1..fc3, so parameter
        # initialisation consumes the RNG in the same order as before.
        self.conv1 = conv1
        self.conv2 = conv2
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return (N, 10) log-probabilities."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2, stride=2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2, stride=2))
        x = x.reshape(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # NOTE(review): ReLU on the final logits clamps negative scores to 0
        # before log_softmax. Kept so existing checkpoints behave identically.
        x = F.relu(self.fc3(x))
        return F.log_softmax(x, dim=1)


def _lenet_kernel(in_channel, out_channel, kernel_fn, degree=5):
    """5x5 ``Kernel`` layer with c=1.0, as used by every LeNet-5 variant."""
    return Kernel(in_channel=in_channel, out_channel=out_channel, kernelsize=5,
                  c=1.0, degree=degree, kernel_fn=kernel_fn)


class LeNet5CIFAR10_Conv(_LeNet5CIFAR10Base):
    """Baseline: two ordinary convolutions."""
    def __init__(self):
        super().__init__(nn.Conv2d(3, 6, 5), nn.Conv2d(6, 16, 5))


class LeNet5CIFAR10_Polynomial(_LeNet5CIFAR10Base):
    """Both layers: polynomial kernel (degree 5)."""
    def __init__(self):
        super().__init__(_lenet_kernel(3, 6, 0), _lenet_kernel(6, 16, 0))


class LeNet5CIFAR10_Gaussian(_LeNet5CIFAR10Base):
    """Both layers: Gaussian kernel."""
    def __init__(self):
        super().__init__(_lenet_kernel(3, 6, 1), _lenet_kernel(6, 16, 1))


class LeNet5CIFAR10_Sigmoid(_LeNet5CIFAR10Base):
    """Both layers: sigmoid (tanh) kernel."""
    def __init__(self):
        super().__init__(_lenet_kernel(3, 6, 2), _lenet_kernel(6, 16, 2))


class LeNet5CIFAR10_ConvPoly(_LeNet5CIFAR10Base):
    """Both layers: polynomial (degree 5) + convolution mix."""
    def __init__(self):
        super().__init__(_lenet_kernel(3, 6, 3), _lenet_kernel(6, 16, 3))


class LeNet5CIFAR10_ConvGauss(_LeNet5CIFAR10Base):
    """Both layers: Gaussian + convolution mix."""
    def __init__(self):
        super().__init__(_lenet_kernel(3, 6, 5), _lenet_kernel(6, 16, 5))


class LeNet5CIFAR10_ConvSigmoid(_LeNet5CIFAR10Base):
    """Both layers: sigmoid + convolution mix."""
    def __init__(self):
        super().__init__(_lenet_kernel(3, 6, 4), _lenet_kernel(6, 16, 4))


class LeNet5CIFAR10_SigmoidGauss(_LeNet5CIFAR10Base):
    """Both layers: Gaussian + sigmoid mix."""
    def __init__(self):
        super().__init__(_lenet_kernel(3, 6, 7), _lenet_kernel(6, 16, 7))


class LeNet5CIFAR10_SigmoidPoly(_LeNet5CIFAR10Base):
    """Both layers: polynomial (degree 5) + sigmoid mix."""
    def __init__(self):
        super().__init__(_lenet_kernel(3, 6, 8), _lenet_kernel(6, 16, 8))


class LeNet5CIFAR10_GaussPoly(_LeNet5CIFAR10Base):
    """Both layers: Gaussian + polynomial (degree 5) mix."""
    def __init__(self):
        super().__init__(_lenet_kernel(3, 6, 6), _lenet_kernel(6, 16, 6))


class LeNet5CIFAR10_Poly_Conv(_LeNet5CIFAR10Base):
    """Polynomial kernel (degree 3) first, ordinary convolution second."""
    def __init__(self):
        super().__init__(_lenet_kernel(3, 6, 0, degree=3), nn.Conv2d(6, 16, 5))


class LeNet5CIFAR10_Conv_Poly(_LeNet5CIFAR10Base):
    """Ordinary convolution first, polynomial kernel (degree 3) second."""
    def __init__(self):
        super().__init__(nn.Conv2d(3, 6, 5), _lenet_kernel(6, 16, 0, degree=3))


class LeNet5CIFAR10_ConvPoly_Conv(_LeNet5CIFAR10Base):
    """Polynomial (degree 3) + convolution mix first, ordinary convolution second."""
    def __init__(self):
        super().__init__(_lenet_kernel(3, 6, 3, degree=3), nn.Conv2d(6, 16, 5))


class LeNet5CIFAR10_Conv_ConvPoly(_LeNet5CIFAR10Base):
    """Ordinary convolution first, polynomial (degree 3) + convolution mix second."""
    def __init__(self):
        super().__init__(nn.Conv2d(3, 6, 5), _lenet_kernel(6, 16, 3, degree=3))


class LeNet5CIFAR10_Conv_Sigmoid(_LeNet5CIFAR10Base):
    """Ordinary convolution first, sigmoid kernel second."""
    def __init__(self):
        super().__init__(nn.Conv2d(3, 6, 5), _lenet_kernel(6, 16, 2))


class LeNet5CIFAR10_Sigmoid_Conv(_LeNet5CIFAR10Base):
    """Sigmoid kernel first, ordinary convolution second."""
    def __init__(self):
        super().__init__(_lenet_kernel(3, 6, 2), nn.Conv2d(6, 16, 5))
    
    

In [None]:
 F.relu(0.01 * torch.tensor([1e15, 234]) **3)

In [None]:
'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn


cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    def __init__(self, vgg_name):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        out = torch.clamp(F.log_softmax(out,dim=1), 1e-10,1)
        out = torch.log(out)
        return out

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [ParallelKernelConv2d(in_channels= in_channels,out_channels= x,kernel_size= 3,padding=1, c=1.0,degree=3),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)


def test():
    net = VGG("VGG19")
    x = torch.randn(3,3,32,32)
    y = net(x)
#     make_dot(y)
    print(y.size())
    
#     Kernel(in_channel= in_channels,out_channel= x,kernelsize= 3,padding=1, c=1.0,degree=5,kernel_fn=3)

In [None]:
# torch.log(torch.tensor([0.23,0.111, 0.00001]))

In [None]:

# Smoke test: build VGG19 and print the output shape for a random 3x3x32x32 batch.
test()

In [None]:
from torchvision import models
# The model trained below: VGG19 built from ParallelKernelConv2d layers.
model = VGG(vgg_name="VGG19")

In [None]:
# print(model.conv1.kernel_fn._parameters)

In [None]:
# CIFAR-10 loaders. Both splits use ToTensor + per-channel Normalize; only the
# training split is shuffled. (Building loaders records no autograd ops, so no
# no_grad() context is needed here.)
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(
        "data", train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]),
    ),
    batch_size=128,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(
        "data", train=False, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]),
    ),
    batch_size=128,
    shuffle=False,
)
# Number of batches in each split.
print(len(train_loader))
print(len(test_loader))

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# NLL loss, paired with the log_softmax outputs of the models above; then move
# the model to the selected device (the returned model is displayed).
criterion = torch.nn.NLLLoss()
model.to(device)

In [None]:
# Plain SGD at lr=0.03 (only used by the commented-out LR finder; it is replaced
# a few cells below, and train() builds its own optimizer).
optimizer = optim.SGD(params=model.parameters(), lr=0.03)

In [None]:
# model = torch.load("VGG19/VGG19_.pth")

In [None]:


# from tqdm import tqdm
# lr_finder= lr_find(optimizer, model, criterion,train_loader= train_loader, device=device )
# lr_finder.run_test()

In [None]:
# lr_finder.plot_loss()

In [None]:
# SGD with momentum and L2 weight decay; train() creates an identically
# configured optimizer of its own.
optimizer = optim.SGD(params=model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

In [None]:
# Display the weight decay configured on the optimizer's first parameter group.
optimizer.param_groups[0]["weight_decay"]

In [None]:
! pip install barbar
import time
from barbar import Bar

# Number of epochs train() runs; also sizes the per-epoch report below.
nb_epochs = 50
# torch.manual_seed(42)  # disabled alternative seed; the notebook seeds with 17 at the top


def compute_accuray(pred, true):
    """Top-1 accuracy, in percent, of ``pred`` against integer labels ``true``.

    Args:
        pred: (N, num_classes) scores; the arg-max of each row is the prediction.
        true: (N,) class indices (may live on a different device than ``pred``).

    Returns:
        A float in [0, 100]. An empty batch raises ZeroDivisionError.
    """
    pred_idx = pred.argmax(dim=1)
    # Count matches with a single tensor reduction instead of a Python-level
    # sum over a numpy boolean array.
    correct = (pred_idx == true.to(pred_idx.device)).sum().item()
    return (correct / pred_idx.numel()) * 100

def plot_loss_epoch(train_loss, test_loss, epochs, model):
    """Plot per-epoch train and test loss.

    ``train_loss`` and ``test_loss`` must each hold ``epochs`` values; ``model``
    is only used (via ``str``) in the title.
    """
    epoch_axis = range(1, epochs + 1)
    train_line, = plt.plot(epoch_axis, train_loss, marker="o")
    test_line, = plt.plot(epoch_axis, test_loss, marker="o")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend([train_line, test_line], ["train_loss", "test_loss"])
    plt.title("Loss Vs Epoch for: " + str(model))
    plt.show()
    
def plot_accuracy_epoch(train_accuracy, test_accuracy, epochs, model):
    """Plot per-epoch train and test accuracy.

    Both accuracy sequences must hold ``epochs`` values; ``model`` is only used
    (via ``str``) in the title.
    """
    epoch_axis = range(1, epochs + 1)
    train_line, = plt.plot(epoch_axis, train_accuracy, marker="o")
    test_line, = plt.plot(epoch_axis, test_accuracy, marker="o")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.legend([train_line, test_line], ["train_accuracy", "test_accuracy"])
    plt.title("Accuracy Vs Epoch for: " + str(model))
    plt.show()

def model_comparison(test_acc, times, epochs):
    """Overlay the validation-accuracy curves of several models.

    Args:
        test_acc: mapping model name -> per-epoch validation accuracies.
        times: mapping model name -> elapsed time at the end of each epoch
            (same keys and lengths as ``test_acc``).
        epochs: number of entries in each accuracy list.

    Draws two figures, accuracy vs. time and accuracy vs. epoch. Each figure has
    one curve per model and a legend naming every model.
    """
    names = list(test_acc)

    # Show once after the loop so all models share one figure and the legend
    # entries line up with the curves.
    for name in names:
        plt.plot(times[name], test_acc[name], marker="o")
    plt.xlabel('Time')
    plt.ylabel('Validation_Accuracy')
    plt.title("Validation Accuracy Vs Time")
    plt.legend(names)
    plt.show()

    for name in names:
        plt.plot(range(1, epochs + 1), test_acc[name], marker="o")
    plt.xlabel('Epochs')
    plt.ylabel('Validation_Accuracy')
    plt.title("Validation Accuracy Vs Epoch")
    plt.legend(names)
    plt.show()
    

    


def _train_one_epoch(m, optimizer, scheduler, iter_loss):
    """One pass over ``train_loader``; returns (mean batch loss, mean batch accuracy %).

    Every batch loss is appended to ``iter_loss``; ``scheduler`` is stepped once per batch.
    """
    m.train(mode=True)
    loss_sum = 0.
    acc_sum = 0.
    for data, target in Bar(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = m(data)
        loss = criterion(output, target)
        loss_value = loss.item()
        iter_loss.append(loss_value)
        loss_sum += loss_value
        loss.backward()
        # Cap the global gradient norm at 1 before the update.
        torch.nn.utils.clip_grad_norm_(m.parameters(), 1)
        optimizer.step()
        acc_sum += compute_accuray(torch.exp(output), target)
        scheduler.step()
    return loss_sum / len(train_loader), acc_sum / len(train_loader)


def _evaluate_one_epoch(m, iter_loss, nan_output):
    """Evaluate on ``test_loader`` without gradients; returns (mean loss, mean accuracy %).

    Batch losses are appended to ``iter_loss``. For every batch whose loss is
    NaN, the model output is printed and appended (detached, on CPU) to ``nan_output``.
    """
    m.train(mode=False)
    loss_sum = 0.
    acc_sum = 0.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = m(data)
            loss_value = criterion(output, target).item()
            if np.isnan(loss_value):
                print("loss: {}, output {} : ".format(loss_value, output))
                nan_output.append(output.detach().cpu())
            iter_loss.append(loss_value)
            loss_sum += loss_value
            acc_sum += compute_accuray(torch.exp(output), target)
    return loss_sum / len(test_loader), acc_sum / len(test_loader)


def train(m, out_dir):
    """Train ``m`` for ``nb_epochs`` epochs with SGD + CyclicLR, evaluating after each epoch.

    Relies on the notebook globals ``train_loader``, ``test_loader``, ``criterion``,
    ``device``, ``nb_epochs``, ``Bar``, ``mkdirs`` and ``compute_accuray``.
    Whenever the mean test loss improves, the whole model is saved with
    ``torch.save`` to ``<out_dir>/<last component of out_dir>_.pth``.

    Returns:
        ``(train_accuracy, test_accuracy, train_losses, test_losses, times, nan_output)``:
        per-epoch lists (accuracies rounded to 2 decimals, ``times`` in seconds
        since training started) and the outputs of test batches whose loss was NaN.
    """
    iter_loss = []
    train_losses, test_losses = [], []
    train_accuracy, test_accuracy = [], []
    times = []
    nan_output = []
    best_test_loss = float("inf")
    mkdirs(os.path.join(out_dir, "models"))
    # Use only the last path component in the file name, so a nested out_dir
    # (e.g. "runs/vgg") does not point into a non-existent sub-directory.
    save_model_path = os.path.join(out_dir, os.path.basename(os.path.normpath(str(out_dir))) + "_.pth")
    optimizer = optim.SGD(m.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    start_time = time.time()

    scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-6, max_lr=1e-1, step_size_up=2000, step_size_down=None,
                                                  mode='triangular2', gamma=1.0, scale_fn=None, scale_mode='cycle', cycle_momentum=True,
                                                  base_momentum=0.8, max_momentum=0.9, last_epoch=-1)

    for epoch in range(nb_epochs):
        train_loss, train_acc = _train_one_epoch(m, optimizer, scheduler, iter_loss)
        test_loss, test_acc = _evaluate_one_epoch(m, iter_loss, nan_output)

        times.append(np.round(time.time() - start_time, 2))
        train_losses.append(train_loss)
        train_accuracy.append(round(train_acc, 2))
        test_losses.append(test_loss)
        test_accuracy.append(round(test_acc, 2))

        print("Epoch {}: train loss is {}, train accuracy is {}; test loss is {}, test accuracy is {}, lr is: {},Weight_decay is {}".
              format(epoch, round(train_loss, 2),
                     round(train_acc, 2),
                     round(test_loss, 2),
                     round(test_acc, 2),
                     optimizer.param_groups[0]['lr'],
                     optimizer.param_groups[0]["weight_decay"]))
        if test_loss < best_test_loss:
            torch.save(m, save_model_path)
            best_test_loss = test_loss

    return train_accuracy, test_accuracy, train_losses, test_losses, times, nan_output
    

In [None]:
# Train the VGG19 model; the model with the best mean test loss is saved under VGG19_ConvPoly/.
train_accuracy, test_accuracy, train_losses, test_losses, times , nan_output= train(model, "VGG19_ConvPoly")

In [None]:
# Scratch check: `value == np.nan` is always False (NaN compares unequal to
# everything, including itself), so NaNs must be detected with np.isnan.
loss = [2.3, np.nan, 2.45]
nan_losses = [value for value in loss if np.isnan(value)]
for l in nan_losses:
    print(l)

In [None]:
# Inspect how torch.clamp treats a NaN entry next to an ordinary out-of-range value.
import numpy as np
import torch

torch.tensor([np.nan, 23]).clamp(min=0, max=1)


In [None]:
# Print each entry of the first element of nan_output, one per line.
for nan_entry in nan_output[0]:
    print(nan_entry)

In [None]:
# One row per recorded epoch. The epoch count comes from the history itself rather than
# the separately configured nb_epochs, so every column is guaranteed to have the same length.
report = pd.DataFrame({
    "Epochs": range(1, len(train_accuracy) + 1),
    "Train_Accuracy": train_accuracy,
    "Test_Accuracy": test_accuracy,
    "Train_Loss": train_losses,
    "Test_Loss": test_losses,
    "Time": times,
})
report.to_csv("VGG19.csv", index=False)

In [None]:
# Use the same string label as the accuracy plot; the bare name `VGG` does not appear to be
# defined in this notebook. The epoch count follows the recorded history instead of a literal 80.
plot_loss_epoch(train_losses, test_losses, epochs=len(train_losses), model="VGG19")

In [None]:
# The epoch count follows the recorded history instead of a hard-coded 80.
plot_accuracy_epoch(train_accuracy, test_accuracy, epochs=len(train_accuracy), model="VGG19")

In [None]:
# print(model.conv1.convK1.weight[0:1,0:1])
# print(model.conv1.weight[0:1,0:1])

In [None]:
# print(model.conv1.kernel_fn._parameters)

In [None]:
# print(model.conv1._parameters)

In [None]:

# Load a fully pickled model onto the CPU and switch it to inference mode.
# NOTE(review): torch.load unpickles arbitrary objects; only load checkpoints you trust.
MODEL_PATH = "/kaggle/input/New folder/LeNet5CIFAR10_Polynomial(1,3).pth"
model = torch.load(MODEL_PATH, map_location=torch.device("cpu"))
model.eval()

In [None]:
# model.conv2._parameters

In [None]:
# torch.sigmoid(torch.tensor(1.2526))

In [None]:
def imshowCIFAR10(img, label):
  
    img = img.numpy()
    img = img.transpose(1,2,0)
    print(label)
    plt.axis("off")
    fig = plt.figure
    plt.imshow(img, cmap=plt.cm.hot)
    plt.show()

In [None]:
def imshowMNIST(img, label):
  
    img = img.numpy()
    img = img.transpose(2,0,1)
    print(label)
    plt.axis("off")
    fig = plt.figure
    plt.imshow(img, cmap=plt.cm.hot)
    plt.show()

In [None]:
# Conv layers to inspect, keyed by the names plot_weights accepts.
layers = dict(conv1=model.conv1.conv1, conv2=model.conv2.conv1)

In [None]:
# Materialise every (images, labels) batch from the training loader, one list entry per batch.
images, labels = [], []
for img, label in train_loader:
    images.append(img)
    labels.append(label)

In [None]:
# Pick one sample from the first batch. Image and label must use the SAME index;
# previously image 7 was shown with label 21.
sample_idx = 7
image = images[0][sample_idx]
label = labels[0][sample_idx]
print(image.shape)
imshowCIFAR10(image, label)
# images[0][1].numpy().transpose(1,2,0)

## Visualization

#### Filter Maps

In [None]:
def plot_filters_single_channel_big(t):
    
    #setting the rows and columns
    nrows = t.shape[0]*t.shape[2]
    ncols = t.shape[1]*t.shape[3]
    
    
    npimg = np.array(t.numpy(), np.float32)
    npimg = npimg.transpose((0, 2, 1, 3))
    npimg = npimg.ravel().reshape(nrows, ncols)
    
    npimg = npimg.T
    
    fig, ax = plt.subplots()    
    imgplot = sns.heatmap(npimg, xticklabels=False, yticklabels=False, cmap='gray', ax=ax, cbar=False)
    
    
def plot_filters_single_channel(t):
    
    #kernels depth * number of kernels
    nplots = t.shape[0]*t.shape[1]
    ncols = 12
    
    nrows = 1 + nplots//ncols
    #convert tensor to numpy image
    npimg = np.array(t.numpy(), np.float32)
    
    count = 0
    fig = plt.figure(figsize=(ncols, nrows))
    
    #looping through all the kernels in each channel
    for i in range(t.shape[0]):
        for j in range(t.shape[1]):
            count += 1
            ax1 = fig.add_subplot(nrows, ncols, count)
            npimg = np.array(t[i, j].numpy(), np.float32)
            npimg = (npimg - np.mean(npimg)) / np.std(npimg)
            npimg = np.minimum(1, np.maximum(0, (npimg + 0.5)))
            ax1.imshow(npimg)
            ax1.set_title(str(i) + ',' + str(j))
            ax1.axis('off')
            ax1.set_xticklabels([])
            ax1.set_yticklabels([])
   
    plt.tight_layout()
    plt.show()
    
    
    
    
def plot_filters_multi_channel(t, save_path='myimage.png'):
    """Plot each 3-channel kernel t[i] as a small RGB image, save the figure, then show it.

    Args:
        t: 4-D weight tensor whose kernels t[i] are (3, kh, kw) (plot_weights only calls this
           when t.shape[1] == 3).
        save_path: file the figure is written to; defaults to the original 'myimage.png'.
    """
    num_kernels = t.shape[0]
    num_cols = 12
    # Only as many rows as needed; one row per kernel produced a huge, mostly empty figure.
    num_rows = max(1, -(-num_kernels // num_cols))
    weights = t.detach().cpu().numpy().astype(np.float32)

    fig = plt.figure(figsize=(num_cols, num_rows))
    for i in range(num_kernels):
        ax1 = fig.add_subplot(num_rows, num_cols, i + 1)
        kernel = weights[i]
        std = kernel.std()
        # guard constant kernels, which would otherwise divide by zero
        kernel = (kernel - kernel.mean()) / std if std > 0 else np.zeros_like(kernel)
        kernel = np.clip(kernel + 0.5, 0, 1).transpose(1, 2, 0)
        ax1.imshow(kernel)
        ax1.axis('off')
        ax1.set_title(str(i))
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])

    # Lay out before saving so the written file matches what is displayed.
    plt.tight_layout()
    plt.savefig(save_path, dpi=100)
    plt.show()
    
    
    
    
def plot_weights(model, layer_name, single_channel=True, collated=False):
    """Visualise the weights of the conv layer registered under ``layer_name``.

    The layer is looked up in the notebook-level ``layers`` dict; ``model`` is accepted for
    call-site compatibility but is not used.

    Args:
        model: unused.
        layer_name: key into ``layers``.
        single_channel: plot each 2-D kernel slice separately (True) or whole 3-channel
            kernels as RGB images (False).
        collated: with single_channel, draw all slices as one heatmap instead of a grid.
    """
    layer = layers[layer_name]
    if not isinstance(layer, nn.Conv2d):
        print("Can only visualize layers which are convolutional")
        return

    weight_tensor = layer.weight.data

    if single_channel:
        plotter = plot_filters_single_channel_big if collated else plot_filters_single_channel
        plotter(weight_tensor)
        return

    print(weight_tensor.shape)
    if weight_tensor.shape[1] != 3:
        print("Can only plot weights with three channels with single channel = False")
        return
    plot_filters_multi_channel(weight_tensor)
        
#visualize weights for alexnet - first conv layer
# plot_weights(alexnet, 0, single_channel = False)

In [None]:

import seaborn as sns  # used by plot_filters_single_channel_big for its heatmap

# One small image per kernel slice of conv1.
plot_weights(model=model, layer_name="conv1",
             single_channel=True, collated=False)

In [None]:
# RGB view of each conv1 kernel; plot_weights only draws this when the layer has 3 input channels.
plot_weights(model, "conv1", single_channel=False, collated=False)

In [None]:
# All conv1 kernel slices collated into one heatmap.
plot_weights(model, "conv1", single_channel=True, collated=True)

In [None]:
# All conv2 kernel slices collated into one heatmap.
plot_weights(model, "conv2", single_channel=True, collated=True)

#### Occlusion Experiments







In [None]:
#custom function to conduct occlusion experiments

def occlusion(model, image, label, occ_size=15, occ_stride=5, occ_pixel=0.5,
              input_shape=(1, 1, 28, 28), verbose=False):
    """Occlusion-sensitivity map: slide a grey patch over the image and record a class probability.

    Args:
        model: network whose output for one image is indexed as ``output[0, label]``.
            ``torch.exp`` is applied to it, so it presumably returns log-probabilities
            (e.g. log_softmax); TODO confirm.
        image: image tensor with ``prod(input_shape)`` elements.
        label: class index whose probability is recorded.
        occ_size: side length of the square occluding patch, in pixels.
        occ_stride: step between consecutive patch positions, in pixels.
        occ_pixel: value written into the occluded patch.
        input_shape: shape the image is reshaped to for each forward pass; its last two
            entries are (height, width). Defaults to the previously hard-coded 1x1x28x28.
        verbose: print the map size and every probability (the former debug output).

    Returns:
        Tensor of shape (out_h, out_w). ``heatmap[h, w]`` is the probability when the patch
        covers rows ``h*occ_stride : h*occ_stride+occ_size`` and the matching columns for
        ``w``. Only patches that end strictly inside the image are evaluated.
    """
    height, width = input_shape[-2], input_shape[-1]

    # Number of patch positions along each axis (0 if the patch does not fit).
    output_height = max(0, int(np.ceil((height - occ_size) / occ_stride)))
    output_width = max(0, int(np.ceil((width - occ_size) / occ_stride)))
    if verbose:
        print(width)
        print(height)
        print("output_height: ", output_height)
        print("output_width: ", output_width)

    heatmap = torch.zeros((output_height, output_width))
    base = image.detach().reshape(input_shape)

    # Inference only: no autograd graph is needed for the probabilities.
    with torch.no_grad():
        for h in range(output_height):
            h_start = h * occ_stride
            h_end = min(height, h_start + occ_size)
            for w in range(output_width):
                w_start = w * occ_stride
                w_end = min(width, w_start + occ_size)

                occluded = base.clone()
                # dim -2 is height (rows) and dim -1 is width (columns); the old code swapped them
                occluded[..., h_start:h_end, w_start:w_end] = occ_pixel

                prob = torch.exp(model(occluded))[0, label].item()
                if verbose:
                    print(prob)
                heatmap[h, w] = prob

    return heatmap

In [None]:
# Show the second sample of the first batch together with its label.
imshowMNIST(img=images[0][1], label=labels[0][1])

In [None]:
# Occlusion sensitivity for the currently selected sample, drawn as a seaborn heatmap.
heatmap = occlusion(model, image, label.item())
imgplot = sns.heatmap(heatmap, yticklabels=False, xticklabels=False)
figure = imgplot.get_figure()

#### Activation Maps

In [None]:
 list(model.children())

In [None]:
# Capture intermediate outputs by module name via forward hooks.
activation = {}


def get_activation(name):
    """Return a forward hook that stores the module's detached output in ``activation[name]``."""
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook


# Re-running this cell would otherwise stack another set of hooks on the same modules;
# remove the handles saved by a previous run first.
for _old_handle in globals().get("activation_hook_handles", []):
    _old_handle.remove()

activation_hook_handles = [
    getattr(model, module_name).register_forward_hook(get_activation(module_name))
    for module_name in ("conv1", "conv2", "fc1", "fc2", "fc3")
]

x = image.reshape(1, 3, 32, 32)
output = model(x)
print(activation['conv1'].shape)

In [None]:
# activation

In [None]:
# One subplot row per conv1 feature map.
act = activation['conv1'].squeeze()
print(act.shape)
fig, axarr = plt.subplots(act.size(0), figsize=(10,10))
for channel_idx, feature_map in enumerate(act):
    axarr[channel_idx].imshow(feature_map)

In [None]:
# conv2 feature maps in a 4x4 grid (zip stops at whichever of axes/channels runs out first).
act = activation['conv2'].squeeze()
print(act.shape)
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(4,4))
# Distinct loop name so the Axes grid is not shadowed by a single Axes.
for axis, feature in zip(axes.flatten(), act):
    axis.imshow(feature)



### Activation Maximization

In [None]:
# First top-level child module of the model.
[*model.children()][0]

In [None]:
class SaveFeatures():
    """Forward hook that keeps the most recent output of ``module`` in ``self.features``.

    The output is stored as-is (still attached to the autograd graph), so a loss computed
    from ``self.features`` can backpropagate to the network input, as activation
    maximization requires.
    """

    def __init__(self, module):
        self.features = None  # set on the next forward pass through `module`
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # Do not re-wrap with torch.tensor(..., requires_grad=True).cuda(): that copies into a
        # new tensor cut off from the graph (no gradient reaches the input) and fails without CUDA.
        self.features = output

    @property
    def f(self):
        """Short alias for ``features``."""
        return self.features

    def close(self):
        """Remove the forward hook from the module."""
        self.hook.remove()

In [None]:
# Hook the model's first child module.
activations = SaveFeatures([*model.children()][0])

In [None]:
# register_forward_hook() requires a hook callable; calling it with no argument raises TypeError.
# SaveFeatures already registered a hook on this module, so inspect the registered hooks instead.
list(model.children())[0]._forward_hooks

In [None]:
# SaveFeatures stores the hooked output as `.features`; it is only populated after a
# forward pass has gone through the hooked module.
activations.features[0, 1]

In [None]:
class FilterVisualizer():
    """Activation maximization: optimise an input image so one filter of one layer fires strongly.

    The image starts as low-contrast noise and is optimised at progressively larger
    resolutions (coarse-to-fine). Resizing and blurring use torch ops, so no OpenCV is needed.
    """

    def __init__(self, model, size=56, upscaling_steps=12, upscaling_factor=1.2, in_channels=1):
        """
        Args:
            model: network to probe; it is switched to eval mode here.
            size: side length in pixels of the initial square image.
            upscaling_steps: number of optimise-then-upscale rounds.
            upscaling_factor: resize factor applied after each round.
            in_channels: channels of the generated image (1 matches the original behaviour;
                use 3 for RGB models).
        """
        self.size, self.upscaling_steps, self.upscaling_factor = size, upscaling_steps, upscaling_factor
        self.in_channels = in_channels
        self.model = model
        self.model.eval()

    def visualize(self, layer, filter, lr=0.1, opt_steps=20, blur=None, save_image=True):
        """Optimise an image to maximise the mean activation of channel ``filter`` of child ``layer``.

        Args:
            layer: index into ``list(self.model.children())`` of the module to hook.
            filter: channel index in that module's output (name kept for caller compatibility).
            lr: Adam learning rate for the pixel values.
            opt_steps: optimiser steps per resolution.
            blur: optional box-blur kernel size applied after each upscale.
            save_image: when True, write the result to disk via ``save``.

        Sets ``self.output`` to the last optimised image as an (H, W, C) NumPy array.
        """
        param = next(self.model.parameters())
        # Match the model's dtype/device (a float64 image fed to a float32 model fails).
        device, dtype = param.device, param.dtype

        sz = self.size
        noise = np.uint8(np.random.uniform(150, 180, (1, self.in_channels, sz, sz))) / 255
        img = torch.as_tensor(noise, dtype=dtype, device=device)

        captured = {}

        def _store(module, inputs, output):
            # keep the graph attached so the loss can reach the input pixels
            captured["features"] = output

        target = list(self.model.children())[layer]
        handle = target.register_forward_hook(_store)
        try:
            for _round in range(self.upscaling_steps):
                img_var = img.detach().clone().requires_grad_(True)
                optimizer = torch.optim.Adam([img_var], lr=lr, weight_decay=1e-6)
                for _step in range(opt_steps):
                    optimizer.zero_grad()
                    self.model(img_var)
                    loss = -captured["features"][0, filter].mean()
                    loss.backward()
                    optimizer.step()
                img = img_var.detach()
                self.output = img[0].cpu().numpy().transpose(1, 2, 0)
                sz = int(self.upscaling_factor * sz)
                img = F.interpolate(img, size=(sz, sz), mode="bicubic", align_corners=False)
                if blur is not None:
                    # Box blur to reduce high-frequency patterns; crop keeps even kernels same-size.
                    img = F.avg_pool2d(img, kernel_size=blur, stride=1, padding=blur // 2,
                                       count_include_pad=False)[..., :sz, :sz]
        finally:
            handle.remove()

        if save_image:
            self.save(layer, filter)

    def save(self, layer, filter):
        """Write ``self.output``, clipped to [0, 1], to layer_<layer>_filter_<filter>.jpg."""
        image = np.clip(self.output, 0, 1)
        if image.ndim == 3 and image.shape[-1] == 1:
            # imsave expects (H, W) for single-channel data, not (H, W, 1)
            image = image[..., 0]
        plt.imsave("layer_"+str(layer)+"_filter_"+str(filter)+".jpg", image)

In [None]:
# Visualise filter 2 of the model's second child module.
# layer_idx / filter_idx avoid shadowing the builtin `filter` for the rest of the notebook.
layer_idx = 1
filter_idx = 2
FV = FilterVisualizer(model, size=56, upscaling_steps=12, upscaling_factor=1.2)
FV.visualize(layer_idx, filter_idx, blur=5)

#### Saliency Maps

In [None]:

# Saliency setup: track gradients on the input pixels and pick the top-scoring class.
X = image.reshape(1, 1, 28, 28).requires_grad_()
scores = model(X).exp()  # model presumably returns log-probabilities, hence exp
score_max_index = scores.argmax()
score_max = scores[0, score_max_index]



In [None]:
# Backpropagate the top class score so X.grad holds the input gradients.
torch.autograd.backward(score_max)


In [None]:
# Per-pixel saliency: largest absolute input gradient across channels (dim 1).
saliency, _ = X.grad.detach().abs().max(dim=1)

# Render the saliency map of the single sample as a heatmap, without axes.
plt.imshow(saliency[0], cmap=plt.cm.hot)
plt.axis('off')
plt.show()

In [None]:
# Show the same sample again for side-by-side comparison with the saliency map.
imshowMNIST(img=images[0][1], label=labels[0][1])