<a href="https://colab.research.google.com/github/aryanasadianuoit/Fitnets/blob/master/FitNets_Maxout_Student_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.autograd import Function
from dataloader import cifar10
from torch.nn import functional as F
from torch.nn.parameter import Parameter 
from torchsummary import summary
from general_utils import train
from globals import *

# Default number of linear pieces pooled by each Maxout unit: a conv layer with
# C output channels yields C // MAX_OUT_NUMBER feature maps after Maxout.
MAX_OUT_NUMBER = 2
# Default factor applied to the gradient flowing back through a Maxout unit.
W_B_LR_SCALE = 0.05


class Maxout(Function):
    """Channel-wise maxout activation for 4-D (N, C, H, W) tensors.

    Consecutive groups of ``max_out`` channels are reduced with ``max``:
    output channel ``f`` is the maximum over input channels
    ``f * max_out`` .. ``f * max_out + max_out - 1``.

    The backward pass routes the incoming gradient only to the winning
    channel of each group (the others receive zero) and multiplies it by
    ``grad_scale``.

    NOTE(review): because the scaling happens in the backward pass of the
    activation rather than in the optimizer, it compounds across stacked
    Maxout layers (17 layers at 0.05 scale the gradient reaching the first
    conv by 0.05**17). If a per-layer learning-rate scale was intended,
    optimizer parameter groups are the usual place for it -- confirm.

    Usage: ``Maxout.apply(x)`` or ``Maxout.apply(x, max_out, grad_scale)``.
    Omitted arguments fall back to ``MAX_OUT_NUMBER`` / ``W_B_LR_SCALE``,
    read at call time.
    """

    # Both forward and backward must be @staticmethods for autograd Functions.
    @staticmethod
    def forward(ctx, input, max_out=None, grad_scale=None):
        """Return the per-group channel maximum of ``input``.

        Args:
            input: tensor of shape (N, C, H, W); C must be divisible by
                ``max_out``.
            max_out: channels per maxout group (default ``MAX_OUT_NUMBER``).
            grad_scale: factor applied to the gradient in backward
                (default ``W_B_LR_SCALE``).

        Returns:
            Tensor of shape (N, C // max_out, H, W).

        Raises:
            ValueError: if C is not divisible by ``max_out``.
        """
        if max_out is None:
            max_out = MAX_OUT_NUMBER
        if grad_scale is None:
            grad_scale = W_B_LR_SCALE

        batch, channels, height, width = input.shape
        if channels % max_out != 0:
            raise ValueError(
                f"Maxout: {channels} channels is not divisible by max_out={max_out}"
            )
        # Channel f * max_out + k becomes piece k of feature map f.
        grouped = input.reshape(batch, channels // max_out, max_out, height, width)
        values, indices = torch.max(grouped, dim=2)

        # Only the argmax indices and the input shape are needed for backward;
        # saving them (not the whole input) keeps less memory alive.
        ctx.save_for_backward(indices)
        ctx.input_shape = input.shape
        ctx.max_out = max_out
        ctx.grad_scale = grad_scale
        return values

    @staticmethod
    def backward(ctx, grad_output):
        """Scatter the scaled gradient back to each group's winning channel.

        Returns one gradient per forward argument: the input gradient, and
        ``None`` for the non-tensor ``max_out`` / ``grad_scale`` arguments.
        """
        (indices,) = ctx.saved_tensors
        batch, feature_maps, height, width = indices.shape
        grad_pieces = grad_output.new_zeros(
            batch, feature_maps, ctx.max_out, height, width
        )
        grad_pieces.scatter_(
            2, indices.unsqueeze(2), (grad_output * ctx.grad_scale).unsqueeze(2)
        )
        # Undo the grouping done in forward (piece k of f -> channel f*max_out+k).
        return grad_pieces.reshape(ctx.input_shape), None, None


# Maxout student network for 10-class image classification (trained on CIFAR-10 below).
class CNN(nn.Module):
    """Maxout convolutional classifier: 17 conv layers and a linear head.

    Every conv layer ``hN`` is followed by ``Maxout.apply`` (stored as
    ``moN``). Conv layers h4, h10 and h16 additionally pass through a max pool
    (``hN_pooling``) before their Maxout. For a 3x32x32 input the spatial
    size goes 32 -> 16 (after h4) -> 8 (after h10) -> 1 (after h16). h16's
    128 channels are then reduced to the 64 features ``fc1`` expects
    (assumes ``MAX_OUT_NUMBER == 2``).

    ``forward`` returns raw class scores (logits) of shape (batch, 10), which
    is what ``nn.CrossEntropyLoss`` expects; apply ``F.softmax`` to the
    result if probabilities are needed.
    """

    # (in_channels, out_channels, kernel_size, padding) of conv layers h0..h16.
    # Each in_channels is the previous out_channels halved by Maxout
    # (assumes MAX_OUT_NUMBER == 2).
    _CONV_SPECS = (
        (3, 32, 9, 4),    # h0
        (16, 32, 3, 1),   # h1
        (16, 32, 3, 1),   # h2
        (16, 48, 3, 1),   # h3
        (24, 48, 3, 1),   # h4
        (24, 80, 3, 1),   # h5
        (40, 80, 3, 1),   # h6
        (40, 80, 3, 1),   # h7
        (40, 80, 3, 1),   # h8
        (40, 80, 3, 1),   # h9
        (40, 80, 3, 1),   # h10
        (40, 128, 3, 1),  # h11
        (64, 128, 3, 1),  # h12
        (64, 128, 3, 1),  # h13
        (64, 128, 3, 1),  # h14
        (64, 128, 3, 1),  # h15
        (64, 128, 3, 1),  # h16
    )
    # conv index -> (kernel_size, stride) of the max pool applied after it.
    _POOLING = {4: (2, 2), 10: (2, 2), 16: (8, 1)}

    def __init__(self):
        super().__init__()
        for idx, (in_ch, out_ch, kernel, pad) in enumerate(self._CONV_SPECS):
            # Attribute names h0..h16 / hN_pooling are kept so state_dict keys
            # (and saved checkpoints) stay the same.
            setattr(self, f"h{idx}",
                    nn.Conv2d(in_ch, out_ch, kernel_size=kernel, padding=pad))
            if idx in self._POOLING:
                pool_kernel, pool_stride = self._POOLING[idx]
                setattr(self, f"h{idx}_pooling",
                        nn.MaxPool2d(kernel_size=pool_kernel, stride=pool_stride))
            # Maxout is a stateless autograd Function, not a module; the moN
            # attributes are kept for compatibility with the original layout.
            setattr(self, f"mo{idx}", Maxout.apply)

        self.fc1 = nn.Linear(in_features=64 * 1 * 1, out_features=10)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) image batch to (batch, 10) logits."""
        out = x
        for idx in range(len(self._CONV_SPECS)):
            # Layers are looked up via getattr (not cached in a plain list) so
            # that replicated copies, e.g. under nn.DataParallel, use their own
            # modules.
            out = getattr(self, f"h{idx}")(out)
            if idx in self._POOLING:
                out = getattr(self, f"h{idx}_pooling")(out)
            out = getattr(self, f"mo{idx}")(out)

        # flatten keeps the batch dimension; an unexpected spatial size now
        # fails loudly in fc1 instead of being folded into the batch axis.
        out = torch.flatten(out, 1)
        # Return logits: the previous ReLU zeroed every negative score, and the
        # softmax double-normalized outputs already fed to CrossEntropyLoss
        # (which applies log-softmax itself).
        return self.fc1(out)

# Build, summarize and train the Maxout student model. Guarded so that
# importing this module does not start a training run.
if __name__ == "__main__":
    test = CNN()
    print(test)

    # Layer-by-layer summary for a single 3x32x32 input, computed on the CPU.
    summary(model=test, input_size=(3, 32, 32), device="cpu")

    train_loader, test_loader = cifar10()
    # NOTE(review): weight_decay=0.9 equals the momentum value and is far
    # above typical SGD weight-decay values; it looks like a typo -- confirm.
    optimizer = torch.optim.SGD(
        test.parameters(), lr=0.005, momentum=0.9, weight_decay=0.9
    )
    criterion = nn.CrossEntropyLoss()
    train(
        test,
        optimizer=optimizer,
        criterion=criterion,
        # Keyword spelling kept as-is; presumably matches general_utils.train.
        train_loder=train_loader,
        test_loader=test_loader,
        train_on_gpu=True,
        multiple_gpu=True,
        path=SERVER_2_PREFIX_ADDRESS + "testmaxout.pth",
        epochs=5,
    )