In [38]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from skimage.color import rgb2gray
from torch.utils.data import Dataset, DataLoader
from skimage import io, transform
from torchvision import transforms, utils
import os
import numpy as np
from os import walk
import torch.optim as optim


In [39]:
class KPN(nn.Module):
    def __init__(self, device,burst_len = 1, K = 5):
        super(KPN, self).__init__()
        
        self.device = device
        self.K = K
        self.burst_len = burst_len
        self.conv_start = nn.Conv2d(burst_len,64,3,padding=1)
        self.conv_end = nn.Conv2d(64,burst_len,3,padding=1)
        
        self.batchnorm64 = nn.BatchNorm2d(64)
        self.batchnorm128 = nn.BatchNorm2d(128)
        self.batchnorm256 = nn.BatchNorm2d(256)
        self.batchnorm512 = nn.BatchNorm2d(512)
        self.batchnormK2N = nn.BatchNorm2d(K**2 * burst_len)
        
        
        self.conv_64to128 = nn.Conv2d(64,128,3,padding=1)
        self.conv_128to256 = nn.Conv2d(128,256,3,padding=1)
        self.conv_256to512 = nn.Conv2d(256,512,3,padding=1)
        self.conv_512to256 = nn.Conv2d(512,256,3,padding=1)
        self.conv_256to128 = nn.Conv2d(256,128,3,padding=1)
        self.conv_128toK2N = nn.Conv2d(128,K**2 * burst_len,3,padding=1)
        
        
        self.conv_64 = nn.Conv2d(64,64,3,padding=1)
        self.conv_128 = nn.Conv2d(128,128,3,padding=1)
        self.conv_256 = nn.Conv2d(256,256,3,padding=1)
        self.conv_512 = nn.Conv2d(512,512,3,padding=1)
        self.conv_K2N = nn.Conv2d(K**2 * burst_len,K**2 * burst_len,3,padding=1)
        
        self.avg_pool_downsample = nn.AvgPool2d(2,2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
        
        
    def custom_block(self,conv1, conv2, x,bn):
        x = conv1(x)
        x = F.relu(x)
        x = bn(x)
        x = conv2(x)
        x = F.relu(x)
        x = bn(x)
        x = conv2(x)
        x = F.relu(x)
        x = bn(x)
        
        return x
        
    def custom_block_res(self,conv1, conv2, x, y,bn):
        x = conv1(x)
        x = F.relu(x)
        x = bn(x)
        x = x + y
        x = conv2(x)
        x = F.relu(x)
        x = bn(x)
        x = conv2(x)
        x = F.relu(x)
        x = bn(x)
        
        return x
    
    def custom_block_res2(self,conv1, conv2, conv3, x, y,bn1,bn2):
        x = conv1(x)
        x = F.relu(x)
        x = bn1(x)
        x = x + y
        x = conv2(x)
        x = F.relu(x)
        x = bn2(x)
        x = conv3(x)
        x = bn2(x)
        
        return x
    
    def convolve(self,filt,x,K):
        pad_len = K//2
        
        ss = filt.size()
        
        
        pad_vector = (pad_len,pad_len,pad_len,pad_len)
        padded_x = F.pad(x, pad_vector)
        


        count = 0
        for i in range(0,K):
            for j in range(0,K):
                if i == 0 and j == 0 :
                    denoised = torch.mul(filt[:,count:count+1,:,:], padded_x[:,:,i:i+64,j:j+64])
                else:
                    denoised = denoised + torch.mul(filt[:,count:count+1,:,:], padded_x[:,:,i:i+64,j:j+64])
        
        
        return denoised
        
        
        
        
    
    #def _init_weights(m):
    
    def forward(self,x):
        enc1 = self.custom_block(self.conv_start, self.conv_64, x,self.batchnorm64)
        
        enc2 = self.avg_pool_downsample(enc1)
        enc2 = self.custom_block(self.conv_64to128, self.conv_128, enc2,self.batchnorm128)
        
        enc3 = self.avg_pool_downsample(enc2)
        enc3 = self.custom_block(self.conv_128to256, self.conv_256, enc3,self.batchnorm256)
        
        enc4 = self.avg_pool_downsample(enc3)
        enc4 = self.custom_block(self.conv_256to512, self.conv_512, enc4,self.batchnorm512)
        
        enc5 = self.avg_pool_downsample(enc4)
        enc5 = self.custom_block(self.conv_512, self.conv_512, enc5,self.batchnorm512)
        
    
        dec4 = self.upsample(enc5)
        dec4 = self.custom_block_res(self.conv_512, self.conv_512, dec4, enc4,self.batchnorm512)
        
        dec3 = self.upsample(dec4)
        dec3 = self.custom_block_res(self.conv_512to256, self.conv_256, dec3, enc3,self.batchnorm256)
        
        dec2 = self.upsample(dec3)
        dec2 = self.custom_block_res2(self.conv_256to128, self.conv_128toK2N, self.conv_K2N, dec2, enc2,self.batchnorm128,self.batchnormK2N)
        
        dec1 = self.upsample(dec2)
        
        output = self.convolve(dec1,x,self.K)

        return output
        
        
        
        
        
        
        

In [40]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    raise Exception("You requested GPU support, but there's no GPU on this machine")
    
net = KPN(device = device)
net = net.to(device)
summary(net,(1,64,64))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 64, 64]             640
       BatchNorm2d-2           [-1, 64, 64, 64]             128
            Conv2d-3           [-1, 64, 64, 64]          36,928
       BatchNorm2d-4           [-1, 64, 64, 64]             128
            Conv2d-5           [-1, 64, 64, 64]          36,928
       BatchNorm2d-6           [-1, 64, 64, 64]             128
         AvgPool2d-7           [-1, 64, 32, 32]               0
            Conv2d-8          [-1, 128, 32, 32]          73,856
       BatchNorm2d-9          [-1, 128, 32, 32]             256
           Conv2d-10          [-1, 128, 32, 32]         147,584
      BatchNorm2d-11          [-1, 128, 32, 32]             256
           Conv2d-12          [-1, 128, 32, 32]         147,584
      BatchNorm2d-13          [-1, 128, 32, 32]             256
        AvgPool2d-14          [-1, 128,

In [67]:
class Noisy_burst(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, root_dir, burst_len, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        
        f = []
        for (dirpath, dirnames, filenames) in walk(root_dir):
            f.extend(filenames)
            break
        
        #print(f)
        self.image_names = f
        self.burst_len = burst_len
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.image_names[idx])
        image = io.imread(img_name)
        image = rgb2gray(image)
        
        image = transform.resize(image, (80,80),anti_aliasing=True)
        
        sample_in = np.zeros((self.burst_len,64,64)).astype('float32')
        sample_out = np.zeros((self.burst_len,64,64)).astype('float32')
        
        for i in range(0,self.burst_len):
            x = np.random.randint(0,17)
            y = np.random.randint(0,17)
            x = 0
            y = 0
            sample_in[i,:,:] = image[x:x+64,y:y+64].astype('float32') + np.random.normal(0,0.05,(64,64)).astype('float32')
            sample_out[i,:,:] = image[x:x+64,y:y+64].astype('float32') + np.random.normal(0,0.05,(64,64)).astype('float32')
        
        
        
        sample_in[sample_in <0] = 0
        sample_in[sample_in >1] = 1
        
        sample_out[sample_out <0] = 0
        sample_out[sample_out >1] = 1
        sample_out = torch.from_numpy(sample_out)
        sample_in = torch.from_numpy(sample_in)
        #if self.transform:
        #    sample_in = self.transform(sample_in)
        #    sample_out = self.transform(sample_out)
        
        #print(sample_in.size())
        return sample_in,sample_out


In [68]:
data_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

face_dataset = Noisy_burst(root_dir='../data/celebA/img_align_celeba', burst_len = 1, transform = data_transform)
dataloader = DataLoader(face_dataset, batch_size=32,
                        shuffle=True, num_workers=3)

In [43]:
a,b = face_dataset[0]
print(a.type())

torch.FloatTensor


In [74]:
num_epochs = 200
optimizer = optim.Adam(net.parameters(), lr=0.001)
criterion = nn.MSELoss()

In [75]:
running_loss = 0
count = 0
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        net.zero_grad()
        inp,outp = data
        
        inp = inp.to(device)
        outp = outp.to(device)
        output = net(inp)
        
        error_mse = criterion(output, outp)
        optimizer.zero_grad()
        
        error_mse.backward()
        optimizer.step()
        
        running_loss += error_mse.item()
        count = count + 1
        
        if(count % 100 == 0):
            print('epoch : [%d,%d],loss: %.3f' %
              (epoch + 1, i,running_loss/100 ))
            running_loss = 0.0
            

    running_loss = 0.0
    if(epoch % 20 == 0):
        torch.save(net.state_dict(), 'Checkpoints/epoch_' + str(epoch) + '.pth')

epoch : [1,99],loss: 0.010
epoch : [1,199],loss: 0.007
epoch : [1,299],loss: 0.006
epoch : [1,399],loss: 0.007
epoch : [1,499],loss: 0.006
epoch : [1,599],loss: 0.006
epoch : [1,699],loss: 0.005
epoch : [1,799],loss: 0.005
epoch : [1,899],loss: 0.005
epoch : [1,999],loss: 0.005
epoch : [1,1099],loss: 0.005
epoch : [1,1199],loss: 0.005
epoch : [1,1299],loss: 0.005
epoch : [1,1399],loss: 0.004
epoch : [1,1499],loss: 0.004
epoch : [1,1599],loss: 0.004
epoch : [1,1699],loss: 0.004
epoch : [1,1799],loss: 0.004
epoch : [1,1899],loss: 0.004
epoch : [1,1999],loss: 0.004
epoch : [1,2099],loss: 0.004
epoch : [1,2199],loss: 0.004
epoch : [1,2299],loss: 0.004
epoch : [1,2399],loss: 0.004
epoch : [1,2499],loss: 0.004
epoch : [1,2599],loss: 0.004
epoch : [1,2699],loss: 0.004
epoch : [1,2799],loss: 0.004
epoch : [1,2899],loss: 0.004
epoch : [1,2999],loss: 0.004
epoch : [1,3099],loss: 0.004
epoch : [1,3199],loss: 0.004
epoch : [1,3299],loss: 0.004
epoch : [1,3399],loss: 0.004
epoch : [1,3499],loss: 0.

Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


KeyboardInterrupt: 

In [81]:
a,b = face_dataset[10000]
a = torch.unsqueeze(a, 0)
a = a.to(device)
c = net(a)

In [82]:
print(a.size())

torch.Size([1, 1, 64, 64])


In [83]:
a = a.cpu().detach().numpy()
c = c.cpu().detach().numpy()

In [84]:
np.shape(a)

(1, 1, 64, 64)

In [85]:
import cv2 as cv
cv.imwrite('noisy.png',a[0,0,:,:]*255)
cv.imwrite('denoised.png',c[0,0,:,:]*255)

True