In [22]:
class Module(object):
    
     def forward(self,*input):
        raise NotImplementedError
    
     def backward(self,*gradwrtoutput):
        raise NotImplementedError
    
     def param(self):
        return []

In [23]:
class Sequential(Module):
    
    def __init__(self, modules):
        self.modules = modules
        self.input = None
    
    def forward(self, input):
        self.input = input
        output = self.input
        
        for module in self.modules:
            output = module.forward(output)
        return output
    
    def backward(self,*gradwrtouput):
        gradient = gradwrtouput
        for module in reversed(self.modules):
            gradient = module.backward(gradient)
        self.input = None
        return gradient
    
    def param(self):
        params = []
        for module in self.modules:
            params.append(module.param())
        return params
            

In [24]:
class Optimizer(object):
    def step(self):
        return NotImplementedError
    
    def zero_grad(self):
        return NotImplementedError

In [25]:
class SGD(Optimizer):
    
    def __init__(self,params,lr, mu = 0, tau = 0):
        self.params = params
        self.lr = lr
        
        # parameters in order to add momemtum 
        self.momemtum = mu
        self.dampening = tau
        self.state_momemtum = None
        
    def step(self):
        for x, grad in self.params:
            x.add_(-self.lr * grad) 
    
    def zero_grad(self):
        for x, grad in self.params:
            grad = grad.zero_()  

In [26]:
class MSE(Module):
    
    def forward(self,input,target):
        self.input = input
        self.target = target
        return (self.input - self.target).pow(2).mean() 
    
    def backward(self):
        return 2*(self.input - self.target).div(torch.tensor(self.input.size(0)))
        ## we divide by the batch size as in Pytorch

In [27]:
#Check MSE loss function
mse = MSE()
input = torch.normal(mean=torch.zeros(100,3,10),std = 1)
target = torch.normal(mean=torch.zeros(100,3,10),std = 1)
loss = mse.forward(input,target)
print(loss)
mse_ = torch.nn.MSELoss()
loss_ = mse_.forward(input,target)
print(loss_)

print(mse.backward().size())

tensor(1.9983)
tensor(1.9983)
torch.Size([100, 3, 10])


In [28]:
class Sigmoid(Module):
    
    def forward(self,input):
        self.input = input
        self.sigmoid = 1./(1+(-self.input).exp())
        return  self.sigmoid
    
    def backward(self,*gradwrtouput):
        return gradwrtouput*self.sigmoid*(1-self.sigmoid)
    

In [29]:
input = torch.normal(mean=torch.zeros(100,3,38,38),std = 1)

sig = Sigmoid()
sig_ = torch.nn.Sigmoid()

print(torch.allclose(sig.forward(input),sig_.forward(input)))

True


In [30]:
class ReLU(Module):
    def forward(self, input):
        self.input = input
        return (self.input>0.)*self.input
    
    def backward(self, *gradwrtouput):
        return gradwrtouput*(self.input>=0.)

In [31]:
input = torch.normal(mean=torch.zeros(100,3,380,380),std = 1)

re = ReLU()
re_ = torch.nn.ReLU()

print(torch.allclose(sig.forward(input),sig_.forward(input)))

True


In [32]:
import torch 

In [33]:
def nearest_upsampling(input, scale_factor):
    if isinstance(scale_factor,int):
        scale1 ,scale2 = scale_factor,scale_factor
    if isinstance(scale_factor,tuple):
        scale1 ,scale2 = scale_factor[0],scale_factor[1]
    
    N, C, H, W = tuple(input.size())
    output = torch.empty(N,C,scale1*H,scale2*W)
    
    for i in range(N):
        output[i] = torch.nn.functional.fold(input[i].view(C,1,-1).repeat(1,scale1*scale2,1), 
                                             output_size=(H*scale1,W*scale2), kernel_size = (scale1,scale2), 
                                             stride = (scale1,scale2) ).view(C,H*scale1,W*scale2)
    return output

In [34]:
class NearestUpsampling(Module):
        
    def __init__(self, scale_factor: None ):
        self.scale_factor = scale_factor
        
    
    def forward(self, input):
        self.input = input
        
        if (len(input.size()) == 4) :
            return nearest_upsampling(self.input, self.scale_factor)

        
    
    def backward(self,*gradwrtouput):
        return 3
    
    

In [35]:
input = torch.normal(mean = torch.zeros((100,3,280, 280)),std = 1)

m = torch.nn.Upsample(scale_factor=2, mode='nearest')
mm = m(input)

m_ = NearestUpsampling(2)
mm_ = m_.forward(input)

print(torch.allclose(mm,mm_))


True


In [37]:
from torch.nn.functional import fold, unfold
from torch import Tensor, empty

def conv2d_(input: Tensor, weight: Tensor, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor:
    # input is 4d tensor
    
    if isinstance(stride, int):
        stride = (stride, stride)
    if isinstance(padding, int):
        padding = (padding, padding)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)
    
    N = input.size(0)
    H_in = input.size(-2)
    W_in = input.size(-1)
    
    kernel_size = (weight.size(-2), weight.size(-1))
    C_out = weight.size(0)
    H_out = int((H_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
    W_out = int((W_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
        
    unfolded = unfold(input, kernel_size=kernel_size, dilation=dilation, stride=stride, padding=padding)
    
    #wxb  = empty(N, C_out, unfolded.size(2))
    #for ind, unfdd in enumerate(unfolded):
    #    wxb[ind] =  weight.view(C_out, -1) @ unfdd + bias.view(-1,1)
    
    #wxb = einsum('nij,njk->nik', weight.view(1, C_out, -1).repeat(N, 1, 1), unfolded) + bias.view(1, -1, 1).repeat(N,1,1)
    
    wxb = weight.view(1, C_out, -1).repeat(N, 1, 1).matmul(unfolded) + bias.view(1, -1, 1).repeat(N,1,1)
    
    return wxb.view(N, C_out, H_out, W_out)

#FALSE METHOD
def grad_conv2d_weight(self,gradwrtouput):

        N = self.input.shape[0]

        grad_ = gradwrtouput.contiguous().repeat(1, self.in_channel // self.groups, 1,
                                                      1)
        grad_ = grad_.contiguous().view(
            grad_.shape[0] * grad_.shape[1], 1, grad_.shape[2],
            grad_.shape[3])

        input = self.input.contiguous().view(1, self.input.shape[0] * self.input.shape[1],
                                        self.input.shape[2], self.input.shape[3])

        grad_weight = self.conv2d(input, grad_, None, self.dilation, self.padding,
                                   self.stride, self.in_channel * N)

        grad_weight = grad_weight.contiguous().view(
            min_batch, grad_weight.shape[1] // min_batch, grad_weight.shape[2],
            grad_weight.shape[3])

        return grad_weight.sum(dim=0).view(
            self.in_channel // self.groups, self.out_channel,
            grad_weight.shape[2], grad_weight.shape[3]).transpose(0, 1).narrow(
                2, 0, self.kernel[2]).narrow(3, 0, self.kernel[3])
    

    
        @staticmethod
    def conv2d(input: Tensor, weight: Tensor, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor:
        # input is 4d tensor

        if isinstance(stride, int):
            stride = (stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding)
        if isinstance(dilation, int):
            dilation = (dilation, dilation)

        N = input.size(0)
        H_in = input.size(-2)
        W_in = input.size(-1)

        kernel_size = (weight.size(-2), weight.size(-1))
        C_out = weight.size(0)
        H_out = int((H_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
        W_out = int((W_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)

        unfolded = unfold(input, kernel_size=kernel_size, dilation=dilation, stride=stride, padding=padding)
        
        print(unfolded.size(),weight.view(1, C_out, -1).repeat(N, 1, 1).size())
        
        if bias != None:
            wxb = weight.view(1, C_out, -1).repeat(N, 1, 1).matmul(unfolded.repeat_interleave(groups,1)) #+ ...
            #bias.view(1, -1, 1).repeat(N,1,1)
        else: 
            wxb = weight.view(1, C_out, -1).repeat(N, 1, 1).matmul(unfolded)
        #print(wxb.size())
        return wxb.view(N, C_out, H_out, W_out)



    def conv2d(input: Tensor, weight: Tensor, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor:
        # input is 4d tensor
        

        if isinstance(stride, int):
            stride = (stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding)
        if isinstance(dilation, int):
            dilation = (dilation, dilation)

        N = input.size(0)
        C_in = input.size(1)
        H_in = input.size(-2)
        W_in = input.size(-1)
        
        
        kernel_size = (weight.size(-2), weight.size(-1))
        C_out = weight.size(0)
        H_out = int((H_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
        W_out = int((W_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
        
        #weight = weight.repeat_interleave(groups,1)
        
        inp_unf = unfold(input,kernel,stride,padding,dilation)
        print("weight", weight.view(C_out, -1).t().size())
        print("inp_", inp_unf.transpose(1, 2).size())
        if bias != None:
            out_unf = inp_unf.transpose(1, 2).matmul(weight.view(C_out, -1).t()).transpose(1, 2).add(bias.view(1, -1, 1).repeat(N,1,H_out*W_out))
        else:
            out_unf = inp_unf.transpose(1, 2).matmul(weight.view(C_out, -1).t()).transpose(1, 2)
            
        out = fold(out_unf, (H_out,W_out), (1,1), dilation, padding, stride)
        
        return out

    def conv_backward(x, w, b, conv_param, dout):
        HF, WF, DF, NF = w.shape
        x_col = im2col(x, HF, WF, conv_param['pad'], conv_param['stride'])
        w_col = w.transpose(3, 0, 1, 2).reshape((NF, -1))
        db = np.sum(dout, axis=(0, 1, 3))
        dout = dout.transpose(2, 0, 1, 3)
        dout = dout.reshape((w_col.shape[0], x_col.shape[-1]))
        dx_col = w_col.T.dot(dout)
        dw_col = dout.dot(x_col.T)

        dx = col2im(dx_col, x.shape, HF, WF, conv_param['pad'], conv_param['stride'])
        dw = dw_col.reshape((dw_col.shape[0], HF, WF, DF))
        dw = dw.transpose(1, 2, 3, 0)

        return [dx, dw, db]

IndentationError: unindent does not match any outer indentation level (<tokenize>, line 65)

In [None]:
    def grad_conv2d_weight(self,input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):

        if isinstance(stride, int):
            stride = (stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding)
        if isinstance(dilation, int):
            dilation = (dilation, dilation)

        N = input.size(0)
        C_in = input.size(1)
        H_in = input.size(-2)
        W_in = input.size(-1)
        
        kernel_size = (weight_size[2], weight_size[3])
        C_out = weight_size[0]
        H_out = int((H_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
        W_out = int((W_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)

        grad_w = torch.empty(weight_size)

        for i in range(grad_w.size(0)): 
            for j in range(input.size(1)): 
                grad_w[i,j,:,:] = self.conv2d(self.input[:,j,:,:].view(N,1,H_in,W_in),
                                              weight = grad_output[i,:,:].view(1,1,H_out,W_out), 
                                              bias = None,stride =self.dilation,padding = self.padding, 
                                              dilation = self.stride,groups = self.groups).narrow(2, 0, 
                                              weight_size[-2]).narrow(3, 0, weight_size[-1]).sum(dim=0)
        
        return grad_w
    
    def conv2d(input: Tensor, weight: Tensor, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor:
        # input is 4d tensor
        

        if isinstance(stride, int):
            stride = (stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding)
        if isinstance(dilation, int):
            dilation = (dilation, dilation)

        N = input.size(0)
        C_in = input.size(1)
        H_in = input.size(-2)
        W_in = input.size(-1)
        
        
        kernel_size = (weight.size(-2), weight.size(-1))
        C_out = weight.size(0)
        H_out = int((H_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
        W_out = int((W_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
        
        inp_unf = unfold(input,kernel_size,dilation,padding,stride)
        if bias != None:
            out_unf = inp_unf.transpose(1, 2).matmul(weight.view(C_out, -1).t()).transpose(1, 2).add(bias.view(1, -1, 1).repeat(N,1,H_out*W_out))
        else:
            out_unf = inp_unf.transpose(1, 2).matmul(weight.view(C_out, -1).t()).transpose(1, 2)
        return out_unf.view(N,C_out,H_out,W_out)
    
    def _conv_transpose2d(input: Tensor, weight: Tensor, bias=None, stride=(1,1), padding=(0,0), dilation=(1,1), groups=1, output_padding=0) -> Tensor:
    # input is 4d tensor
    if isinstance(stride, int):
        stride = (stride, stride)    
    if isinstance(padding, int):
        padding = (padding, padding)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)
    if (output_padding>=stride[0] or output_padding>=stride[1]) & (output_padding>=dilation[0] or output_padding>=dilation[1]):
        raise ValueError("Invalid output_padding,output padding must be smaller than either stride or dilation")
    
    N = input.size(0)
    H_in = input.size(-2)
    W_in = input.size(-1)
    
    kernel_size = (weight.size(-2), weight.size(-1))
    C_out = weight.size(1)
    H_out = (H_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + 1
    W_out = (W_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + 1
    
    H_out_ = H_out + output_padding
    W_out_ = W_out + output_padding
    #print(H_out_)
    pad0 = (dilation[0] * (kernel_size[0] - 1) - padding[0]) 
    pad1 = (dilation[1] * (kernel_size[1] - 1) - padding[1])

    if (pad0<0) or (pad1<0):
        raise ValueError("Invalid inputs, transposed convolution not possible")
    
    if (stride[0] != 1) or (stride[1]!=1):
        inputs = torch.empty(N,input.size(1),stride[0]*H_in-1,stride[1]*W_in-1).fill_(0.)
        for i in range(H_in):
            for j in range(W_in):
                inputs[...,i*stride[0],j*stride[1]] = input[...,i,j]
    else:
        inputs = input
    
    unfolded = unfold(inputs, kernel_size=kernel_size, dilation=dilation,stride = 1, padding=(pad0,pad1))
    
    w = weight.transpose(0,1).rot90(2, [-2,-1])
    
    wxb = unfolded.transpose(1, 2).matmul(w.reshape(w.size(0), -1).t()).transpose(1, 2)
    print(wxb.size(-1))
    if (wxb.size(-1) == H_out_*W_out_):
        print(1)
        return wxb.view(N,C_out, H_out_, W_out_) + bias.view(1, -1, 1,1).repeat(N,1,1,1)
    else:
        find_size = True
        H_out_effective = H_out_ - 1
        W_out_effective = W_out_ - 1
        
        while (find_size):
            if(wxb.size(-1) == H_out_effective*W_out_effective):
                find_size = False
        wxb = wxb.view(N,C_out,H_out_effective,W_out_effective)
        print("sizeee: ", wxb.size())
        if (output_padding!=0):
            wxb_ = torch.empty(N,C_out,H_out_,W_out_).fill_(0.)
            wxb_[...,0:wxb.size(-2),0:wxb.size(-1)] = wxb
            return wxb_ + bias.view(1, -1, 1,1).repeat(N,1,1,1)
        else:
            return wxb[...,0:H_out,0:W_out] + bias.view(1, -1, 1,1).repeat(N,1,1,1)

In [992]:
from torch.nn.functional import fold, unfold
from torch import Tensor


class Conv2d(Module):
    
    def __init__(self,in_channel, out_channel, kernel_size = (2,2),stride=1, padding=0, dilation=1, groups=1,weight = None, bias = None):
        if isinstance(stride, int):
            self.stride = (stride, stride)
        else:
            self.stride = stride
        if isinstance(padding, int):
            self.padding = (padding, padding)
        else: 
            self.padding = padding
        if isinstance(dilation, int):
            self.dilation = (dilation, dilation)
        else:
            self.dilation = dilation
        if isinstance(kernel_size, int):
            self.kernel = (kernel_size, kernel_size)
        else:
            self.kernel = kernel_size
        
        self.groups = groups
        self.in_channel = in_channel
        self.out_channel = out_channel
        
        k = self.groups/(self.in_channel*self.kernel.prod())
        if weight == None:
            self.weight = torch.empty((self.out_channel,self.in_channel// self.groups,
                                       self.kernel[0],self.kernel[1])).uniform_(-k,k)
        else:
            self.weight = weight
        if bias == None:
            self.bias = torch.empty(self.out_channel).uniform_(-k,k)
        else:
            self.bias = bias
        
        self.weight_grad = torch.empty((self.kernel[0],self.kernel[1]))
        self.bias_grad = torch.empty(self.out_channel)
        
    @staticmethod
    def conv2d(input: Tensor, weight: Tensor, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor:
        # input is 4d tensor
        
        if isinstance(stride, int):
            stride = (stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding)
        if isinstance(dilation, int):
            dilation = (dilation, dilation)

        N = input.size(0)
        C_in = input.size(1)
        H_in = input.size(-2)
        W_in = input.size(-1)
        
        kernel_size = (weight.size(-2), weight.size(-1))
        C_out = weight.size(0)
        H_out = int((H_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
        W_out = int((W_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
       
        out_conv = torch.empty(N,C_out,H_out,W_out)
        
        inp_unf = unfold(input.view(N*groups,C_in//groups,H_in,W_in),kernel_size,dilation,padding,stride)#.view(N,H_out*W_out,-1,C_in//groups)
        inp_unf = inp_unf.view(N,groups,inp_unf.size(-2),inp_unf.size(-1)).transpose(0,1)
        
        
        weight = weight.view(groups,C_out//groups,C_in//groups,kernel_size[0],kernel_size[1])
        
        for i in range(groups):
            out_unf = inp_unf[i,...].transpose(1,2).matmul(weight[i,...].view(C_out//groups,-1).t()).transpose(1,2)
            out_conv[:,i*(C_out//groups):(1+i)*(C_out//groups),...] = out_unf.view(N,C_out//groups,H_out,W_out)
        if bias!=None:
            return out_conv + bias.view(1, -1, 1,1).repeat(N,1,H_out,W_out)
        return out_conv
    
    
    @staticmethod
    def conv_transpose2d(input: Tensor, weight: Tensor, bias=None, stride=(1,1), padding=(0,0), dilation=(1,1), groups=1, output_padding=0) -> Tensor:
        # input is 4d tensor
        if isinstance(stride, int):
            stride = (stride, stride)    
        if isinstance(padding, int):
            padding = (padding, padding)
        if isinstance(output_padding, int):
            output_padding = (output_padding, output_padding)
        if isinstance(dilation, int):
            dilation = (dilation, dilation)

        N = input.size(0)
        H_in = input.size(-2)
        W_in = input.size(-1)

        kernel_size = (weight.size(-2), weight.size(-1))
        C_out = weight.size(1)
        H_out = (H_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1
        W_out = (W_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1

        pad0 = (dilation[0] * (kernel_size[0] - 1) - padding[0]) * stride[0]
        pad1 = (dilation[1] * (kernel_size[1] - 1) - padding[1]) * stride[1]

        if (pad0<0) or (pad1<0):
            raise ValueError("Invalid inputs, transposed convolution not possible")
        inputs = torch.empty(N,input.size(1),2*H_in-1,2*W_in-1)
        for i in range(7):
            inputs[:,:,]

        unfolded = unfold(input, kernel_size=kernel_size, dilation=dilation,stride = stride, padding=(pad0,pad1))

        w = weight.transpose(0,1).rot90(2, [-2,-1])
        if bias != None:
            wxb = unfolded.transpose(1, 2).matmul(w.reshape(w.size(0), -1).t()).transpose(1, 2) + bias.view(1, -1, 1).repeat(N,1,1)
        else: 
            wxb = unfolded.transpose(1, 2).matmul(w.reshape(w.size(0), -1).t()).transpose(1, 2)
        return fold(wxb,(H_out, W_out), (1,1),stride = stride,dilation = dilation)


    def forward(self, input):
        self.input = input
        
        if (len(input.size()) == 4) :
            return self.conv2d(self.input,weight = self.weight, bias = self.bias,stride =self.stride,
                                  padding = self.padding, dilation = self.dilation,groups = self.groups)
        
    
    
    def grad_con2d_input(self,input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):

        if isinstance(stride, int):
            stride = (stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding)
        if isinstance(dilation, int):
            dilation = (dilation, dilation)

        N = input.size(0)
        C_in = input.size(1)
        H_in = input.size(-2)
        W_in = input.size(-1)
        
        kernel_size = (weight_size[2], weight_size[3])
        C_out = weight_size[0]
        H_out = int((H_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
        W_out = int((W_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
        
        grad_input = torch.empty(weight_size)

        for i in range(grad_input.size(0)): 
            for j in range(input.size(1)): 
                grad_input[i,j,:,:] = self.conv_transpose2d(self.input[:,j,:,:].view(N,1,H_in,W_in),
                                              weight = grad_output[:,i,:,:].view(N,1,H_out,W_out), 
                                              bias = None,stride =self.stride,padding = self.padding, 
                                              dilation = self.dilation).narrow(2, 0, 
                                              H_in).narrow(3, 0, W_in)
        
        return grad_input
    
    def grad_conv2d_weight(self,input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):
        r"""
        Computes the gradient of conv2d with respect to the weight of the convolution.
        Args:
            input: input tensor of shape (minibatch x in_channels x iH x iW)
            weight_size : Shape of the weight gradient tensor
            grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
            stride (int or tuple, optional): Stride of the convolution. Default: 1
            padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
            dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
            groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        """
        if isinstance(stride, int):
            stride = (stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding)
        if isinstance(dilation, int):
            dilation = (dilation, dilation)
        in_channels = input.shape[1]
        out_channels = grad_output.shape[1]
        N = input.shape[0]

        grad_output = grad_output.repeat(1, in_channels // groups, 1,
                                                      1)
        grad_output = grad_output.view(
            grad_output.shape[0] * grad_output.shape[1], 1, grad_output.shape[2],
            grad_output.shape[3])

        input = input.view(1, input.shape[0] * input.shape[1],
                                        input.shape[2], input.shape[3])

        grad_weight = conv2d(input, grad_output, None, dilation, padding,
                                   stride, in_channels * N)

        grad_weight = grad_weight.view(
            N, grad_weight.shape[1] // N, grad_weight.shape[2],
            grad_weight.shape[3])

        return grad_weight.sum(dim=0).view(
            in_channels // groups, out_channels,
            grad_weight.shape[2], grad_weight.shape[3]).transpose(0, 1).narrow(
                2, 0, weight_size[2]).narrow(3, 0, weight_size[3])

    def grad_conv_bias(self,grad_output,bias_size):
        
        return grad_output.transpose(0,1).reshape(grad_output.size(1),-1).sum(dim = 1)
    
    def backward(self,gradwrtouput):
        bias = self.grad_conv_bias(gradwrtouput,self.bias.shape)
        a = self.grad_conv2d_weight(self.input, self.weight.shape, gradwrtouput, stride = self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
        #b = self.conv_transpose2d(self.input, self.weight, gradwrtouput,stride = self.stride,padding = self.padding,dilation = self.dilation)
        return bias
    
    

In [997]:
with torch.no_grad():
    input = torch.normal(mean = torch.zeros(110,4,280,280),std = 1)

    kernel = torch.tensor([2,2])
    stride = 2
    padding = 2
    dilation= 1
    groups = 2

    conv = torch.nn.Conv2d(4, 2, kernel, stride=stride, padding=padding, dilation=dilation, groups=groups)
    weight = conv.state_dict()["weight"]
    bias = conv.state_dict()["bias"]
    #print(bias.size())
    conv_ = Conv2d(4,2,kernel,stride, padding, dilation, groups,weight, bias)

    m = conv.forward(input) 
    #print("size",m.size())
    m_ = conv_.forward(input)
    #m = torch.conv2d(input, weight, None, stride, padding,dilation, groups)
    #print(torch.allclose(m,m_))
    print(torch.norm(m-m_))


grad_output = torch.normal(mean=torch.zeros(m.size()),std= 1)
with torch.no_grad():
    grad = conv_.backward(grad_output)
print(grad.size())

#input.requires_grad_(True)
#weight.requires_grad_(True)
    
#output = torch.nn.functional.conv2d(input, weight)
#grad_weight = torch.autograd.grad(output, filter, grad_output)
#print(input.size(), weight.shape, grad_output.size())
#torch.nn.functional.grad.conv2d_weight(input, weight.shape, grad_output,stride = stride,dilation=dilation)

torch.Size([110, 2, 142, 142])
torch.Size([110, 2, 142, 142])
tensor(0.0001)
torch.Size([2])


352

In [947]:
import math
def _conv_transpose2d(input: Tensor, weight: Tensor, bias=None, stride=(1,1), padding=(0,0), dilation=(1,1), groups=1, output_padding=0) -> Tensor:
    # input is 4d tensor
    if isinstance(stride, int):
        stride = (stride, stride)    
    if isinstance(padding, int):
        padding = (padding, padding)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)
    if (output_padding>=stride[0] or output_padding>=stride[1]) & (output_padding>=dilation[0] or output_padding>=dilation[1]):
        raise ValueError("Invalid output_padding,output padding must be smaller than either stride or dilation")
    
    N = input.size(0)
    H_in = input.size(-2)
    W_in = input.size(-1)
    
    kernel_size = (weight.size(-2), weight.size(-1))
    C_out = weight.size(1)
    H_out = (H_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + 1
    W_out = (W_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + 1
    
    H_out_ = H_out + output_padding
    W_out_ = W_out + output_padding
    #print(H_out_)
    pad0 = (dilation[0] * (kernel_size[0] - 1) - padding[0]) 
    pad1 = (dilation[1] * (kernel_size[1] - 1) - padding[1])

    if (pad0<0) or (pad1<0):
        raise ValueError("Invalid inputs, transposed convolution not possible")
    
    if (stride[0] != 1) or (stride[1]!=1):
        inputs = torch.empty(N,input.size(1),stride[0]*H_in-1,stride[1]*W_in-1).fill_(0.)
        for i in range(H_in):
            for j in range(W_in):
                inputs[...,i*stride[0],j*stride[1]] = input[...,i,j]
    else:
        inputs = input
    
    unfolded = unfold(inputs, kernel_size=kernel_size, dilation=dilation,stride = 1, padding=(pad0,pad1))
    
    w = weight.transpose(0,1).rot90(2, [-2,-1])
    
    wxb = unfolded.transpose(1, 2).matmul(w.reshape(w.size(0), -1).t()).transpose(1, 2)
    print(wxb.size(-1))
    if (wxb.size(-1) == H_out_*W_out_):
        return wxb.view(N,C_out, H_out_, W_out_) + bias.view(1, -1, 1,1).repeat(N,1,1,1)
    else:
        find_size = True
        H_out_effective = H_out_ - 1
        W_out_effective = W_out_ - 1
        
        while (find_size):
            if(wxb.size(-1) == H_out_effective*W_out_effective):
                find_size = False
        wxb = wxb.view(N,C_out,H_out_effective,W_out_effective)
        print("sizeee: ", wxb.size())
        if (output_padding!=0):
            wxb_ = torch.empty(N,C_out,H_out_,W_out_).fill_(0.)
            wxb_[...,0:wxb.size(-2),0:wxb.size(-1)] = wxb
            return wxb_ + bias.view(1, -1, 1,1).repeat(N,1,1,1)
        else:
            return wxb[...,0:H_out,0:W_out] + bias.view(1, -1, 1,1).repeat(N,1,1,1)

In [950]:
inputs = torch.randn(10, 4,6,6)

    
stride = 2
padding = 1
dilation= 1
output_padding = 1
groups = 1

tt = torch.nn.ConvTranspose2d(4,2,(5,5),stride= stride,dilation =dilation,padding = padding
                              ,output_padding = output_padding)
weights = tt.weight
bias = tt.bias

trans_conv = tt(inputs)
#print(trans_conv[1,1,...])
trans_conv_ = _conv_transpose2d(inputs,weights,stride= stride,dilation =dilation,padding = padding,
                                bias = bias,output_padding=output_padding)
print(trans_conv_[1,1,...])
print((trans_conv-trans_conv_).norm())

#print(torch.nn.functional.conv_transpose2d(inputs, weights,output_p))

169
sizeee:  torch.Size([10, 2, 13, 13])
tensor([[ 0.0984,  0.0478, -0.0699, -0.3200,  0.2897,  0.0773,  0.2755,  0.3862,
         -0.0544, -0.0865, -0.0211,  0.0853,  0.0550,  0.0468],
        [ 0.2963,  0.4745,  0.1291, -0.4850, -0.2147, -0.3885,  0.4238,  0.6988,
          0.1219, -0.2336, -0.1225, -0.0623, -0.2228,  0.0468],
        [-0.1624,  0.0535,  0.2761,  0.1379,  0.1224, -0.1374, -0.0238,  0.0137,
          0.3729,  0.0319, -0.2822, -0.1254, -0.0122,  0.0468],
        [-0.0507,  0.5516, -0.3158,  0.2655,  0.7902,  0.0331,  0.5681,  0.7517,
         -0.2997,  0.4871,  0.2907, -0.0139,  0.0170,  0.0468],
        [-0.0163,  0.0363,  0.0076,  0.5989,  0.2335, -0.2862,  0.1563,  0.0633,
          0.0099,  0.8350, -0.1404,  0.1876,  0.0928,  0.0468],
        [ 0.1899,  0.0122,  0.0342,  1.0551,  0.3656, -0.1435, -0.0442,  0.5517,
         -0.0759,  0.4776, -0.5041, -0.0044, -0.4174,  0.0468],
        [ 0.1217,  0.0268, -0.0597,  0.1811,  0.1272,  0.3913,  0.0080,  0.3386,
        

In [931]:
13*13


169

In [944]:
inputs = torch.tensor([1.,2.,3.,4.,1.,2.,3.,4.]).view(1,2,2,2)
stride = 2
padding = 1
dilation= 1
output_padding = 1
#weight = torch.arange(0.,18).view(1,2,3,3)
conv = torch.nn.ConvTranspose2d(2,1,(3,3),stride= stride,dilation = dilation,output_padding=output_padding,bias = False)

state_dict_ = conv.state_dict()

weight = state_dict_['weight'] 
#state_dict_["bias"] = torch.zeros(1)
conv.load_state_dict(state_dict_)
ccc =conv(inputs)

trans_conv_ = _conv_transpose2d(inputs,weight,stride= stride,dilation = dilation,bias = bias,output_padding=output_padding)
#print(trans_conv_)
#print(cc.size(),trans_conv_.size())
print((ccc-trans_conv_).norm())

25
sizeee:  torch.Size([1, 1, 5, 5])
tensor(0.2083, grad_fn=<CopyBackwards>)


In [790]:
trans_conv_

tensor([[[[ 0.,  0.,  0.,  1.,  0.,  2.,  2.,  0.,  4.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  3.,  0.,  4.,  6.,  0.,  8.],
          [ 3.,  0.,  6.,  4.,  0.,  8.,  5.,  0., 10.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 9.,  0., 12., 12.,  0., 16., 15.,  0., 20.],
          [ 6.,  0., 12.,  7.,  0., 14.,  8.,  0., 16.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [18.,  0., 24., 21.,  0., 28., 24.,  0., 32.]]]])

In [951]:
def _pair(x):
    if isinstance(x, int):
        return (x, x)
    return x

def pad(input, pad=(1,1,1,1), mode="constant", value=0.0):
    # pad = (pad_left, pad_right, pad_up, pad_down)
    if len(pad) == 2:
        pad = (pad[0], pad[0], pad[1], pad[1])
    input_shape = [input.size(0), input.size(1), input.size(2), input.size(3)]
    
    input_shape[3] += (pad[0] + pad[1])
    input_shape[2] += (pad[2] + pad[3])
    i1 = pad[2]
    j1 = pad[0]
    i2 = input_shape[2] - pad[3]
    j2 = input_shape[3] - pad[1]
    
    result = empty(input_shape).fill_(value)
    result[:,:,i1:i2,j1:j2] = input
    return result

def zero_internal_pad(x, pad=(1,1)):
    x_shape = [x.size(0), x.size(1), x.size(2), x.size(3)]
    x_shape[2] += (pad[0]*(x_shape[2]-1)) 
    x_shape[3] += (pad[1]*(x_shape[3]-1)) 
    res = empty(x_shape).fill_(0.0)
    res[:,:,::pad[0]+1,::pad[1]+1] = x
    return res

# TODO: groups not working
def conv_transpose2d(input: Tensor, weight: Tensor, bias=None, stride=(1,1), padding=(0,0), dilation=(1,1), groups=1, output_padding=0) -> Tensor:
    # input is 4d tensor
    stride = _pair(stride)
    padding = _pair(padding)
    output_padding = _pair(output_padding)
    dilation = _pair(dilation)

    N = input.size(0)
    H_in = input.size(-2)
    W_in = input.size(-1)
        
    kernel_size = (weight.size(-2), weight.size(-1))
    C_out = weight.size(1)
    H_out = (H_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + 1 + output_padding[0]
    W_out = (W_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + 1 + output_padding[1]
    
    pad0 = (dilation[0] * (kernel_size[0] - 1) - padding[0])
    pad1 = (dilation[1] * (kernel_size[1] - 1) - padding[1])

    if (pad0<0) or (pad1<0):
        raise ValueError("Invalid inputs, transposed convolution not possible")
    
    if (stride[0]>1) or (stride[1]>1):
        input = zero_internal_pad(input, pad=(stride[0]-1,stride[1]-1))
    if (output_padding[0]>0) or (output_padding[1]>0):
        input = pad(input, pad=(0,output_padding[1],0,output_padding[0]))
    unfolded = unfold(input, kernel_size=kernel_size, dilation=dilation, stride=1, padding=(pad0,pad1))
    
    w = weight.transpose(0,1).rot90(-2, [-2,-1])
    
    wxb = unfolded.transpose(1, 2).matmul(w.reshape(w.size(0), -1).t()).transpose(1, 2) + bias.view(1, -1, 1).repeat(N,1,1)
    
    res = wxb.view(N, C_out, H_out, W_out)
    return res


In [960]:
from torch import empty
inputs = torch.randn(10, 4,6,6)

    
stride = 1
padding = 3
dilation= 3
output_padding = 2
groups = 1

tt = torch.nn.ConvTranspose2d(4,2,(5,5),stride= stride,dilation =dilation,padding = padding
                              ,output_padding = output_padding)
weights = tt.weight
bias = tt.bias

trans_conv = tt(inputs)
#print(trans_conv[1,1,...])
trans_conv_ = conv_transpose2d(inputs,weights,stride= stride,dilation =dilation,padding = padding,
                                bias = bias,output_padding=output_padding)
#print(trans_conv_[1,1,...])
print((trans_conv-trans_conv_).norm())



tensor(1.5515e-06, grad_fn=<CopyBackwards>)
