In [23]:
from torch.nn.functional import fold, unfold
from torch import Tensor

class Module:
    """Abstract building block of the framework.

    Subclasses override ``forward`` and ``backward``; ``param`` only needs
    to be overridden by modules that own trainable tensors.
    """

    def forward(self, *input):
        """Compute the module output; must be provided by subclasses."""
        raise NotImplementedError()

    def backward(self, *gradwrtoutput):
        """Propagate a gradient backwards; must be provided by subclasses."""
        raise NotImplementedError()

    def param(self):
        """Return the module's parameters; the base class has none."""
        return list()

In [24]:
class Sequential(Module):
    """Chain several modules, feeding each one's output into the next.

    Args:
        modules: ordered sequence of objects exposing ``forward``,
            ``backward`` and ``param``.
    """

    def __init__(self, modules):
        self.modules = modules
        self.input = None

    def forward(self, input):
        """Run ``input`` through every module in order and return the result."""
        self.input = input
        output = input
        for module in self.modules:
            output = module.forward(output)
        return output

    def backward(self, *gradwrtouput):
        """Propagate a gradient through the modules in reverse order.

        Returns:
            Whatever the first module's ``backward`` returns.
        """
        # ``*gradwrtouput`` packs the arguments into a tuple. Sub-modules
        # expect the gradient itself, so unwrap the usual single-argument case.
        gradient = gradwrtouput[0] if len(gradwrtouput) == 1 else gradwrtouput
        for module in reversed(self.modules):
            gradient = module.backward(gradient)
        self.input = None
        return gradient

    def param(self):
        """Return the concatenation of every module's ``param()`` list."""
        params = []
        for module in self.modules:
            # extend (not append) keeps the list flat, so an optimizer can
            # iterate over it entry by entry.
            params.extend(module.param())
        return params
            

In [25]:
class Optimizer(object):
    """Abstract optimizer interface; subclasses implement both methods."""

    def step(self):
        """Apply one parameter update.

        Raises:
            NotImplementedError: always, in the base class.
        """
        # Raise (rather than return) so that calling an unimplemented
        # optimizer fails loudly instead of silently doing nothing.
        raise NotImplementedError

    def zero_grad(self):
        """Reset the accumulated gradients.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

In [26]:
class SGD(Optimizer):
    """Stochastic gradient descent with optional momentum.

    Args:
        params: re-iterable collection of ``(tensor, grad)`` pairs; both
            entries are updated in place.
        lr: learning rate.
        mu: momentum factor. ``0`` (default) gives plain gradient descent.
        tau: dampening applied to the gradient when momentum is used.
    """

    def __init__(self,params,lr, mu = 0, tau = 0):
        self.params = params
        self.lr = lr

        # Momentum settings (attribute spelling kept for compatibility).
        self.momemtum = mu
        self.dampening = tau
        # One velocity buffer per parameter, created lazily on the first step.
        self.state_momemtum = None

    def step(self):
        """Update every parameter in place from its current gradient."""
        if self.momemtum == 0:
            for x, grad in self.params:
                x.add_(-self.lr * grad)
            return

        if self.state_momemtum is None:
            # First step: the velocity starts as a copy of the gradient.
            self.state_momemtum = [grad.clone() for _, grad in self.params]
        else:
            # v <- mu * v + (1 - tau) * g
            for velocity, (_, grad) in zip(self.state_momemtum, self.params):
                velocity.mul_(self.momemtum).add_((1 - self.dampening) * grad)

        for (x, _), velocity in zip(self.params, self.state_momemtum):
            x.add_(-self.lr * velocity)

    def zero_grad(self):
        """Zero every gradient tensor in place."""
        for _, grad in self.params:
            grad.zero_()

In [27]:
class MSE(Module):
    """Mean squared error loss, averaged over every element."""

    def forward(self,input,target):
        """Store both tensors and return ``mean((input - target) ** 2)``."""
        self.input = input
        self.target = target
        return (self.input - self.target).pow(2).mean()

    def backward(self):
        """Return the gradient of the loss with respect to ``input``.

        ``forward`` averages over all elements, so the derivative is
        ``2 * (input - target) / number_of_elements``.
        """
        diff = self.input - self.target
        # Divide by the element count (not just the batch size) so the
        # gradient matches the mean taken in forward.
        return 2 * diff / diff.numel()

In [28]:
class Sigmoid(Module):
    """Element-wise logistic sigmoid activation."""

    def forward(self,input):
        """Return ``1 / (1 + exp(-input))`` and cache it for backward."""
        self.input = input
        self.sigmoid = 1./(1+(-self.input).exp())
        return self.sigmoid

    def backward(self,*gradwrtouput):
        """Return ``grad * s * (1 - s)`` where ``s`` is the cached output.

        The gradient may be passed directly or wrapped in one-element
        tuples (as happens when it is forwarded through ``*args``).
        """
        grad = gradwrtouput
        # ``*gradwrtouput`` is always a tuple; peel single-element wrappers
        # so the arithmetic below operates on the tensor itself.
        while isinstance(grad, tuple) and len(grad) == 1:
            grad = grad[0]
        return grad * self.sigmoid * (1 - self.sigmoid)
    

In [29]:
class ReLU(Module):
    """Element-wise rectified linear unit."""

    def forward(self, input):
        """Return ``input`` where it is positive and zero elsewhere."""
        self.input = input
        return (self.input>0.)*self.input

    def backward(self, *gradwrtouput):
        """Pass the gradient through where the cached input was positive.

        The gradient may be passed directly or wrapped in one-element
        tuples (as happens when it is forwarded through ``*args``).
        """
        grad = gradwrtouput
        # ``*gradwrtouput`` is always a tuple; peel single-element wrappers.
        while isinstance(grad, tuple) and len(grad) == 1:
            grad = grad[0]
        # Use the same strict ``> 0`` mask as forward, so inputs that were
        # zeroed out receive zero gradient.
        return grad * (self.input > 0.)

In [34]:
def _pair(x):
    """Duplicate an ``int`` into a 2-tuple; return anything else unchanged."""
    return (x, x) if isinstance(x, int) else x

def pad(input, pad=(1,1,1,1), mode="constant", value=0.0):
    """Pad the last two dimensions of a 4-D tensor with a constant.

    Args:
        input: tensor with exactly four dimensions.
        pad: ``(left, right, up, down)`` amounts. A 2-tuple ``(w, h)`` is
            expanded to ``(w, w, h, h)``.
        mode: only ``"constant"`` is implemented.
        value: fill value for the padded border.

    Returns:
        A new tensor with the same dtype and device as ``input``.

    Raises:
        NotImplementedError: if ``mode`` is not ``"constant"``.
    """
    if mode != "constant":
        # Fail loudly rather than returning constant padding under a
        # different mode name.
        raise NotImplementedError("pad only supports mode='constant'")
    if len(pad) == 2:
        pad = (pad[0], pad[0], pad[1], pad[1])
    left, right, top, bottom = pad[0], pad[1], pad[2], pad[3]
    n, c, h, w = input.shape

    # new_full inherits dtype and device from ``input``, so float64 or
    # non-CPU tensors are not silently converted.
    result = input.new_full((n, c, h + top + bottom, w + left + right), value)
    result[:, :, top:top + h, left:left + w] = input
    return result

def zero_internal_pad(x, pad=(1,1)):
    """Insert zeros between neighbouring entries of the last two dimensions.

    Args:
        x: tensor with exactly four dimensions.
        pad: ``(rows, cols)`` number of zeros inserted between consecutive
            rows and between consecutive columns.

    Returns:
        A new tensor with the same dtype and device as ``x`` in which the
        original values sit at every ``pad[0] + 1``-th row and
        ``pad[1] + 1``-th column.
    """
    n, c, h, w = x.shape
    gap_h, gap_w = pad[0], pad[1]
    # new_zeros inherits dtype and device from ``x``.
    res = x.new_zeros((n, c, h + gap_h * (h - 1), w + gap_w * (w - 1)))
    res[:, :, ::gap_h + 1, ::gap_w + 1] = x
    return res



In [115]:
def nearest_upsampling(input, scale_factor):
    """Repeat every entry along dims 2 and 3 of ``input``.

    ``scale_factor`` is normalised with ``_pair``; its first entry is the
    repeat count along dim 2 and its second the count along dim 3.
    """
    factors = _pair(scale_factor)
    stretched_rows = input.repeat_interleave(factors[0], dim=2)
    return stretched_rows.repeat_interleave(factors[1], dim=3)

In [255]:
class NearestUpsampling(Module):
    """Nearest-neighbour upsampling of the last two dimensions of a 4-D tensor.

    Args:
        scale_factor: an ``int`` (used for both dimensions) or a pair
            ``(rows, cols)`` of integer repeat counts.
    """

    def __init__(self, scale_factor):
        self.scale_factor = _pair(scale_factor)

    def forward(self, input):
        """Return ``input`` with each entry repeated along dims 2 and 3.

        Raises:
            ValueError: if ``input`` is not 4-dimensional.
        """
        self.input_size = input.size()
        if len(self.input_size) != 4:
            # Fail loudly instead of implicitly returning None.
            raise ValueError(
                "NearestUpsampling expects a 4-D input, got %d dimensions"
                % len(self.input_size)
            )
        return nearest_upsampling(input, self.scale_factor)

    def backward(self, gradwrtouput):
        """Return the gradient with respect to the input of ``forward``.

        Every input entry was copied into one ``rows x cols`` block of the
        output, so its gradient is the sum of the incoming gradient over
        that block.
        """
        n, c, h, w = self.input_size
        s_h, s_w = self.scale_factor[0], self.scale_factor[1]
        # Split each spatial axis into (original index, offset inside the
        # block) and sum over the two offset axes.
        return gradwrtouput.reshape(n, c, h, s_h, w, s_w).sum(dim=(3, 5))


In [266]:
# Sanity check: compare the backward pass of NearestUpsampling with the
# gradient autograd computes through torch.nn.Upsample.
input = torch.normal(mean = torch.zeros((110,3,280, 280)),std = 1).requires_grad_(requires_grad=True)

# Reference: out = sum(upsample(input) ** 3), differentiated by autograd.
m = torch.nn.Upsample(scale_factor=2, mode='nearest')
mm = m(input)
out = mm.pow(3).sum()
out.backward()

# Hand-written module fed with d(out)/d(mm) = 3 * mm ** 2.
m_ = NearestUpsampling(2)
mm_ = m_.forward(input)

# Report the largest *absolute* deviation: a signed max would hide errors
# where our gradient is smaller than the reference.
print((m_.backward(3*mm.pow(2))-input.grad).abs().max())

# Optional forward check:
#print(torch.allclose(mm,mm_))
#print((mm-mm_).norm())

tensor(0., grad_fn=<MaxBackward1>)


In [258]:
# Upsample a 3x3 index grid by 3 along both axes, then take every third
# entry to show that the original grid is recovered.
x = torch.arange(9).reshape(3, 3).repeat_interleave(3, dim=0).repeat_interleave(3, dim=1)
x[::3, ::3]

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

In [35]:
# TODO: groups not working
def conv_transpose2d(input: Tensor, weight: Tensor, bias=None, stride=(1,1), padding=(0,0), dilation=(1,1), groups=1, output_padding=0) -> Tensor:
    """Transposed 2-D convolution built from zero-insertion and ``unfold``.

    The input is dilated with zeros according to ``stride``, padded, and
    then correlated with the 180-degree-rotated, channel-transposed kernel.

    Args:
        input: 4-D tensor ``(N, C_in, H_in, W_in)``.
        weight: 4-D tensor whose dim 0 matches the input channels and whose
            dim 1 gives the output channels.
        bias: optional 1-D tensor with one entry per output channel.
        stride, padding, dilation, output_padding: ``int`` or pair.
        groups: only ``1`` is supported.

    Returns:
        Tensor of shape ``(N, weight.size(1), H_out, W_out)``.

    Raises:
        NotImplementedError: if ``groups != 1``.
        ValueError: if ``padding`` exceeds ``dilation * (kernel - 1)``.
    """
    if groups != 1:
        # Grouped weights would otherwise be treated as ungrouped and give
        # a result with the wrong meaning.
        raise NotImplementedError("conv_transpose2d only supports groups=1")

    stride = _pair(stride)
    padding = _pair(padding)
    output_padding = _pair(output_padding)
    dilation = _pair(dilation)

    N = input.size(0)
    H_in = input.size(-2)
    W_in = input.size(-1)

    kernel_size = (weight.size(-2), weight.size(-1))
    C_out = weight.size(1)
    H_out = (H_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + 1 + output_padding[0]
    W_out = (W_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + 1 + output_padding[1]

    # Padding handed to unfold so the correlation below yields H_out x W_out.
    pad0 = dilation[0] * (kernel_size[0] - 1) - padding[0]
    pad1 = dilation[1] * (kernel_size[1] - 1) - padding[1]
    if pad0 < 0 or pad1 < 0:
        raise ValueError("Invalid inputs, transposed convolution not possible")

    if stride[0] > 1 or stride[1] > 1:
        input = zero_internal_pad(input, pad=(stride[0] - 1, stride[1] - 1))
    if output_padding[0] > 0 or output_padding[1] > 0:
        # Extra rows/columns are added on the bottom/right side only.
        input = pad(input, pad=(0, output_padding[1], 0, output_padding[0]))
    unfolded = unfold(input, kernel_size=kernel_size, dilation=dilation, stride=1, padding=(pad0, pad1))

    # Swap the channel axes and rotate the kernel by 180 degrees.
    w = weight.transpose(0, 1).rot90(-2, [-2, -1])
    out = unfolded.transpose(1, 2).matmul(w.reshape(w.size(0), -1).t()).transpose(1, 2)

    if bias is not None:
        # bias defaults to None, so it must only be added when provided.
        out = out + bias.view(1, -1, 1)

    return out.reshape(N, C_out, H_out, W_out)



In [32]:
def grad_conv2d_weight(self,input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):
        r"""
        Computes the gradient of conv2d with respect to the weight of the convolution.

        The gradient is obtained by calling ``conv2d`` with the (reshaped)
        output gradient acting as the kernel and the (reshaped) input as
        the image, then summing the per-sample results over the batch.

        Args:
            self: not used in the body; the function is written with a
                leading ``self`` parameter although it is defined at
                module level.
            input: input tensor of shape (minibatch x in_channels x iH x iW)
            weight_size : Shape of the weight gradient tensor; only entries
                2 and 3 (kernel height and width) are read.
            grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
            stride (int or tuple, optional): Stride of the convolution. Default: 1
            padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
            dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
            groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

        Returns:
            Tensor whose first two dimensions are
            (out_channels, in_channels // groups) and whose last two are
            trimmed to (weight_size[2], weight_size[3]).
        """
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        in_channels = input.shape[1]
        out_channels = grad_output.shape[1]
        N = input.shape[0]

        # Tile the output gradient in_channels // groups times along the
        # channel axis, then flatten batch and channel into one axis of
        # single-channel kernels.
        grad_output = grad_output.repeat(1, in_channels // groups, 1,
                                                      1)
        grad_output = grad_output.view(
            grad_output.shape[0] * grad_output.shape[1], 1, grad_output.shape[2],
            grad_output.shape[3])

        # Fold the batch axis of the input into its channel axis.
        input = input.view(1, input.shape[0] * input.shape[1],
                                        input.shape[2], input.shape[3])

        # conv2d's positional order is (input, weight, bias, stride, padding,
        # dilation, groups): the forward dilation is passed as the stride
        # and the forward stride as the dilation.
        # NOTE(review): the swap looks intentional (the same formulation is
        # used for weight gradients in torch.nn.grad) -- confirm.
        # groups = in_channels * N gives each (sample, channel) plane its
        # own group.
        grad_weight = conv2d(input, grad_output, None, dilation, padding,
                                   stride, in_channels * N)

        # Separate the batch axis again so it can be summed out below.
        grad_weight = grad_weight.view(
            N, grad_weight.shape[1] // N, grad_weight.shape[2],
            grad_weight.shape[3])

        # Sum over the batch, restore the two channel axes, and cut the
        # spatial axes down to the kernel size (the convolution above can
        # return more rows/columns than the kernel has).
        return grad_weight.sum(dim=0).view(
            in_channels // groups, out_channels,
            grad_weight.shape[2], grad_weight.shape[3]).transpose(0, 1).narrow(
                2, 0, weight_size[2]).narrow(3, 0, weight_size[3])

In [202]:
def conv2d(input: Tensor, weight: Tensor, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor:
    """2-D convolution (cross-correlation) implemented with ``unfold`` and matmul.

    Args:
        input: 4-D tensor ``(N, C_in, H_in, W_in)``.
        weight: 4-D tensor ``(C_out, C_in // groups, kH, kW)``.
        bias: optional 1-D tensor with ``C_out`` entries.
        stride, padding, dilation: ``int`` or pair.
        groups: number of channel groups; must divide ``C_in`` and ``C_out``.

    Returns:
        Tensor ``(N, C_out, H_out, W_out)`` with the dtype and device of
        ``input``.
    """
    stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)

    N = input.size(0)
    C_in = input.size(1)
    H_in = input.size(-2)
    W_in = input.size(-1)

    kernel_size = (weight.size(-2), weight.size(-1))
    C_out = weight.size(0)
    # Integer floor division avoids the float round-trip of int(a / b + 1).
    H_out = (H_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1
    W_out = (W_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1

    group_in = C_in // groups
    group_out = C_out // groups

    # Treat every (sample, group) pair as its own sample so one unfold call
    # extracts the patches of all groups. reshape also accepts
    # non-contiguous input.
    inp_unf = unfold(input.reshape(N * groups, group_in, H_in, W_in), kernel_size, dilation, padding, stride)
    inp_unf = inp_unf.view(N, groups, inp_unf.size(-2), inp_unf.size(-1)).transpose(0, 1)

    # One flattened kernel matrix per group.
    weight = weight.reshape(groups, group_out, group_in * kernel_size[0] * kernel_size[1])

    # new_empty inherits dtype and device from ``input``, so float64 or
    # non-CPU results are not written into a float32 CPU buffer.
    out_conv = input.new_empty(N, C_out, H_out, W_out)
    for i in range(groups):
        out_unf = inp_unf[i].transpose(1, 2).matmul(weight[i].t()).transpose(1, 2)
        out_conv[:, i * group_out:(i + 1) * group_out] = out_unf.reshape(N, group_out, H_out, W_out)

    if bias is not None:
        # Broadcasting replaces the explicit repeat over N, H_out, W_out.
        return out_conv + bias.view(1, -1, 1, 1)
    return out_conv

In [31]:
class Conv2d(Module):
    """2-D convolution layer with manually computed gradients.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size: ``int`` or pair ``(kH, kW)``.
        stride, padding, dilation: ``int`` or pair.
        groups: number of channel groups. The input gradient in
            ``backward`` is only available for ``groups == 1``.
        weight: optional initial weight ``(out_channel,
            in_channel // groups, kH, kW)``; drawn at random when omitted.
        bias: optional initial bias ``(out_channel,)``; drawn at random
            when omitted.
    """

    def __init__(self,in_channel, out_channel, kernel_size = (2,2),stride=1, padding=0, dilation=1, groups=1,weight = None, bias = None):
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.kernel = _pair(kernel_size)

        self.groups = groups
        self.in_channel = in_channel
        self.out_channel = out_channel

        # Uniform initialisation in [-sqrt(k), sqrt(k)] with
        # k = groups / (C_in * kH * kW). self.kernel is a tuple, so the
        # product is spelled out.
        k = self.groups / (self.in_channel * self.kernel[0] * self.kernel[1])
        bound = k ** 0.5
        # ``is None`` avoids an element-wise comparison when a tensor is given.
        if weight is None:
            self.weight = torch.empty((self.out_channel, self.in_channel // self.groups,
                                       self.kernel[0], self.kernel[1])).uniform_(-bound, bound)
        else:
            self.weight = weight
        if bias is None:
            self.bias = torch.empty(self.out_channel).uniform_(-bound, bound)
        else:
            self.bias = bias

        # Gradient accumulators share the shape of their parameter and
        # start at zero.
        self.weight_grad = torch.zeros_like(self.weight)
        self.bias_grad = torch.zeros_like(self.bias)

    def forward(self, input):
        """Convolve ``input`` and cache it for ``backward``.

        Raises:
            ValueError: if ``input`` is not 4-dimensional.
        """
        self.input = input
        if len(input.size()) != 4:
            # Fail loudly instead of implicitly returning None.
            raise ValueError(
                "Conv2d expects a 4-D input, got %d dimensions" % len(input.size())
            )
        # conv2d is a module-level function, not a method of this class.
        return conv2d(self.input, weight=self.weight, bias=self.bias, stride=self.stride,
                      padding=self.padding, dilation=self.dilation, groups=self.groups)

    def grad_conv_bias(self,grad_output,bias_size):
        """Sum ``grad_output`` over every axis except the channel axis.

        ``bias_size`` is accepted for interface compatibility and not used.
        """
        return grad_output.transpose(0,1).reshape(grad_output.size(1),-1).sum(dim = 1)

    def backward(self,gradwrtouput):
        """Accumulate parameter gradients and return the input gradient.

        Args:
            gradwrtouput: gradient with respect to the output of ``forward``.

        Returns:
            Gradient with respect to the input given to ``forward``.

        Raises:
            NotImplementedError: if ``groups != 1``.
        """
        # In-place accumulation keeps the tensors returned by param() valid.
        self.bias_grad.add_(self.grad_conv_bias(gradwrtouput, self.bias.shape))
        # grad_conv2d_weight is a module-level function that takes ``self``
        # as an explicit first argument.
        self.weight_grad.add_(
            grad_conv2d_weight(self, self.input, self.weight.shape, gradwrtouput,
                               stride=self.stride, padding=self.padding,
                               dilation=self.dilation, groups=self.groups)
        )
        return self._grad_input(gradwrtouput)

    def _grad_input(self, gradwrtouput):
        """Return the input gradient as a transposed convolution of ``gradwrtouput``."""
        if self.groups != 1:
            raise NotImplementedError("input gradient is only implemented for groups=1")
        kernel = (self.weight.size(-2), self.weight.size(-1))
        # When the stride does not tile the padded input exactly, the
        # transposed convolution comes out short; output_padding restores
        # the original spatial size.
        output_padding = tuple(
            self.input.size(dim) - ((gradwrtouput.size(dim) - 1) * s - 2 * p + d * (k - 1) + 1)
            for dim, s, p, d, k in zip((2, 3), self.stride, self.padding, self.dilation, kernel)
        )
        # An explicit zero bias is passed so the call does not depend on how
        # conv_transpose2d treats bias=None.
        return conv_transpose2d(gradwrtouput, self.weight,
                                bias=self.weight.new_zeros(self.in_channel),
                                stride=self.stride, padding=self.padding,
                                dilation=self.dilation, output_padding=output_padding)

    def param(self):
        """Return ``(parameter, gradient)`` pairs for weight and bias."""
        return [(self.weight, self.weight_grad), (self.bias, self.bias_grad)]
    
    