In [None]:
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.autograd import Function

In [None]:
class Maxout(Function):
    """Channel-wise maxout activation as a custom autograd ``Function``.

    The channel dimension (dim 1) of the input is split into consecutive
    groups of ``max_out`` channels and each group is reduced to its
    element-wise maximum, so an input of shape ``(N, C, *rest)`` yields an
    output of shape ``(N, C // max_out, *rest)``.

    Use it through ``Maxout.apply(x)`` (group size 4) or
    ``Maxout.apply(x, max_out)``.
    """

    # Note that both forward and backward are @staticmethods
    @staticmethod
    def forward(ctx, input, max_out=4):
        """Compute the maxout of ``input`` over groups of ``max_out`` channels.

        Args:
            ctx: autograd context used to stash state for ``backward``.
            input: tensor of shape ``(N, C, *rest)``; ``C`` must be a
                multiple of ``max_out``.
            max_out: number of consecutive channels competing in each group.

        Returns:
            Tensor of shape ``(N, C // max_out, *rest)`` holding the maximum
            of every channel group.

        Raises:
            RuntimeError: if ``max_out`` is not positive or does not divide
                the number of channels.
        """
        channels = input.shape[1]
        if max_out < 1 or channels % max_out != 0:
            raise RuntimeError(
                "Maxout expects the channel count to be a multiple of "
                "max_out, got %d channels with max_out=%d" % (channels, max_out)
            )
        feature_maps = channels // max_out

        # (N, C, *rest) -> (N, feature_maps, max_out, *rest): channel c lands
        # in group c // max_out at position c % max_out.
        grouped = input.view(input.shape[0], feature_maps, max_out, *input.shape[2:])
        y, indices = torch.max(grouped, 2)

        # Only the argmax positions and the input shape are needed to route
        # gradients, so the (potentially large) input itself is not kept alive.
        ctx.save_for_backward(indices)
        ctx.input_shape = input.shape
        ctx.max_out = max_out
        return y

    # forward has a single tensor output, so backward receives one gradient.
    @staticmethod
    def backward(ctx, grad_output):
        """Route ``grad_output`` to the winning channel of every group.

        Returns:
            A pair ``(grad_input, None)``: ``grad_input`` has the shape of
            the forward input, carries ``grad_output`` at each group's argmax
            channel and zero elsewhere; ``None`` is the (non-existent)
            gradient of the integer ``max_out`` argument.
        """
        (indices,) = ctx.saved_tensors
        input_shape = ctx.input_shape
        max_out = ctx.max_out

        grouped_shape = (input_shape[0], input_shape[1] // max_out, max_out) + tuple(input_shape[2:])
        grad_grouped = grad_output.new_zeros(grouped_shape)
        # Along the group axis every location has exactly one winner, so a
        # plain scatter (no accumulation) places each gradient value.
        grad_grouped.scatter_(2, indices.unsqueeze(2), grad_output.unsqueeze(2))

        return grad_grouped.view(input_shape), None


In [None]:
#This is an example for image reconstruction but you can modify it as you want.
class CNN(nn.Module):
    """Small convolutional network with a maxout activation after every conv.

    Each maxout keeps one channel out of every group of 4, so the channel
    count evolves as: 3 -> conv 64 -> maxout 16 -> conv 32 -> maxout 8
    -> conv 4 -> maxout 1. Kernel sizes and paddings are paired (9/4, 5/2,
    3/1) so the spatial size of the input is preserved.
    """

    def __init__(self):
        super().__init__()
        maxout = Maxout.apply

        # 3 input channels -> 64 kernels, reduced to 16 feature maps by maxout.
        self.layer1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.mo1 = maxout

        # 16 -> 32 kernels, reduced to 8 feature maps by maxout.
        self.layer2 = nn.Conv2d(16, 32, kernel_size=5, padding=2)
        self.mo2 = maxout

        # 8 -> 4 kernels; with a maxout group size of 4 this leaves a single
        # output feature map.
        self.layer3 = nn.Conv2d(8, 4, kernel_size=3, padding=1)
        self.mo3 = maxout

    def forward(self, x):
        """Apply the three conv + maxout stages in order and return the result."""
        stages = (
            (self.layer1, self.mo1),
            (self.layer2, self.mo2),
            (self.layer3, self.mo3),
        )
        out = x
        for conv, activation in stages:
            out = activation(conv(out))
        return out