In [16]:
import torch
import math
import torch
from torch import empty , cat , arange
from torch . nn . functional import fold , unfold

#to delete
from torch import nn

In [17]:
# Put on GPU
device = torch.device ("cuda" if torch.cuda.is_available() else "cpu")

In [18]:
# Extract images

noisy_imgs_1 , noisy_imgs_2 = torch.load('data/train_data.pkl')
noisy_imgs , clean_imgs = torch.load ('data/val_data.pkl')

noisy_imgs = noisy_imgs/255
clean_imgs = clean_imgs/255

# select a preset of images:

imgs_1 = noisy_imgs_1[:1]/255
imgs_2 = noisy_imgs_2[:10000]/255
print(imgs_1.shape)

torch.Size([1, 3, 32, 32])


In [19]:
class ConvolutionTransposed(object):
    def __init__(self, channels_input, channels_output, kernel_size, stride):
        self.weight = torch.empty(channels_output, channels_input, kernel_size, kernel_size).normal_()
        
        self.channel_input = channels_input
        #print('input channels', self.channel_input)
        self.kernel_size = kernel_size
        self.stride = stride
        
        print('weight',self.weight.shape)
        
        
    def forward(self, imgs):
        #print('forward')
        _,_,H,W = imgs.shape
        H_out = (H - 1)*self.stride + self.kernel_size 
        W_out = (W - 1)*self.stride + self.kernel_size 
        
        #print('Hout Wout', H_out, W_out)
        
        self.x = imgs.permute(1, 2, 3, 0).reshape(self.channel_input, -1)
        #print('x',self.x.shape)
        self.y = (self.weight.reshape(self.channel_input, -1)).t().matmul(self.x)
        #print('y',self.y.shape)
        self.y = self.y.reshape(self.y.shape[0], -1, imgs.shape[0])
        #print('y2',self.y.shape)
        self.y = self.y.permute(2, 0, 1)
        
        #print(self.y.shape)
        self.y = fold( self.y, (H_out, W_out), kernel_size=(self.kernel_size,self.kernel_size), stride=self.stride)
        
        return self.y

In [20]:
class Upsample(object):
    def __init__(self, factor_size):
        self.kernel = torch.ones(factor_size,factor_size)
        #print(self.kernel.shape)
        
    def forward(self,x):
        x0, x1, s1, s2 = x.shape
        s3, s4 = self.kernel.shape
        x = x.reshape(x0, x1, s1, 1, s2, 1)
        self.kernel = self.kernel.reshape(1, s3, 1, s4)
        return (x * self.kernel).reshape(x0, x1, s1 * s3, s2 * s4) 

In [27]:
class Convolution(object):
    def __init__(self, channels_input, channels_output, kernel_size, stride):
        
        self.weight = torch.empty(channels_output, channels_input, kernel_size, kernel_size).normal_()
        self.kernel_size = kernel_size
        self.stride = stride
        self.channels_output = channels_output
        self.channels_input = channels_input

        #print('weight',self.weight.shape)
         
        
    def forward(self, imgs):
        
        _,_,H,W = imgs.shape
        #print('h w ',H,W)
        self.H = H
        self.W = W
        self.Hout = (H - self.kernel_size)/self.stride + 1
        self.Wout = (H - self.kernel_size)/self.stride + 1
        print('Hout Wout', self.Hout, self.Wout)
        self.x = imgs
        print('x', self.x.shape)
        self.x_unfolded = unfold(self.x, kernel_size = (self.kernel_size, self.kernel_size), stride = self.stride)
        print('x unfold',self.x_unfolded.shape)
        print('w shape', self.weight.view(self.channels_output, -1).shape)
        self.y = self.x_unfolded.transpose(1, 2).matmul(self.weight.view(self.channels_output, -1).t()).transpose(1, 2)
        #print('y',self.y.shape)
        self.y = fold(self.y, (int(self.Hout), int(self.Wout)),(1,1), stride = 1)
        #self.y = self.y.view(1,10,15,15)
        return self.y  #, self_x, self_weight
    

    def backward(self,gradwrtoutput):
        dL_dS = gradwrtoutput # [B, O, SO, SO]
        dS_dX = self.weight   # weight.shape [O, I, K, K]
        
        #define the size IxKxK
        inKerKer_size = self.channels_input*self.kernel_size*self.kernel_size
        
        dL_dS_reshape = dL_dS.reshape(1,self.channels_output,self.Wout, self.Hout) # [B, O, (SOxSO)]
        dS_dX_reshape = dS_dX.reshape(self.channels_output, -1).transpose(0,1)   # [O, (IxKxK)]^T
        
        #backward input
        dL_dX_reshape = dL_dS_reshape @ dS_dX_reshape # [B, (IxKxK), (SOxSO)]
        dL_dX = fold(dL_dX_reshape, kernel_size = (self.kernel_size, self.kernel_size), stride = self.stride, output_size = (self.W, self.H)) # [B, I, SI, SI]
        
        #backward weight
        dL_dS_reshape2 = dL_dS.reshape(self.channels_output, -1) # [O, (BxSOxSO)]
        dS_dW = x_unfold # [B, (IxKxK), (SOxSO)]
        dS_dW_reshape = x_unfold.reshape(-1, inKerKer_size) # [(BxSOxSO), (IxKxK))] 
        dL_dW_reshape = dL_dS_reshape2 @ dS_dW_reshape # [O, (IxKxK)]
        dL_dW = dL_dW_reshape.view(self.channels_output,  self.channels_input, self.kernel_size, self.kernel_size) # [O, I, K, K] 
        
        """
        #backward bias
        """
        
        return dL_dX, dL_dW  

    def param(self):
        return None


In [28]:
class Net(object):
    def __init__(self):
        self.conv1 = Convolution(3, 10, kernel_size = 4, stride = 2) # tensor(1,10,15,15)
        self.conv2 = Convolution(10, 10, kernel_size = 5, stride = 2) # tensor (1,10,6,6)
        #self.convT1 = ConvolutionTransposed(10, 10, kernel_size = 5, stride = 2)
        #self.convT2 = ConvolutionTransposed(10, 32, kernel_size = 4, stride = 2)
        
        self.upsample = Upsample(3)
        self.upsample2 = Upsample(2)
        self.convSample1 = Convolution(10, 10, kernel_size = 2, stride = 1)
        self.convSample2 = Convolution(10, 3, kernel_size = 3, stride = 1)
        
    def forward(self,x):
        print(x.shape)
        y = self.conv1.forward(x)
        print(y.shape)
        y = self.conv2.forward(y)
        print(y.shape)
        """
        y = self.convT1.forward(y)
        print(y.shape)
        y = self.convT2.forward(y)
        print(y.shape)
        """
        
        
        y = self.upsample.forward(y)
        print('y',y.shape)
        y = self.convSample1.forward(y)
        print('y',y.shape)
        y = self.upsample2.forward(y)
        print('y',y.shape)
        y = self.convSample2.forward(y)
        
        return y

    def backward (self, *gradwrtoutput ) :
        return None

    

In [29]:
##main function

autoencoder = Net()
img = torch.randn(1,3,32,32)

y = autoencoder.forward(img)
print(y.shape)


torch.Size([1, 3, 32, 32])
Hout Wout 15.0 15.0
x torch.Size([1, 3, 32, 32])
x unfold torch.Size([1, 48, 225])
w shape torch.Size([10, 48])
torch.Size([1, 10, 15, 15])
Hout Wout 6.0 6.0
x torch.Size([1, 10, 15, 15])
x unfold torch.Size([1, 250, 36])
w shape torch.Size([10, 250])
torch.Size([1, 10, 6, 6])
y torch.Size([1, 10, 18, 18])
Hout Wout 17.0 17.0
x torch.Size([1, 10, 18, 18])
x unfold torch.Size([1, 40, 289])
w shape torch.Size([10, 40])
y torch.Size([1, 10, 17, 17])
y torch.Size([1, 10, 34, 34])
Hout Wout 32.0 32.0
x torch.Size([1, 10, 34, 34])
x unfold torch.Size([1, 90, 1024])
w shape torch.Size([3, 90])
torch.Size([1, 3, 32, 32])
