In [4]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torch.nn as nn
from torchvision import models
from os.path import join
from PIL import Image
from tqdm import tqdm
import torch.optim as optim

In [5]:
class Noise(Dataset):
    def __init__(self, root, gt_root, transform=None):
        """ Intialize the dataset """
        self.filenames = []
        self.root = root
        self.gt_root = gt_root
        self.transform = transform

        # read filenames
        for i in range(1,1152):
            blur, gau, sp, poi = join(gt_root,str(i)+'.png'), join(root,str(i)+'_gaussian.png'), join(root,str(i)+'_sp.png'), join(root,str(i)+'_poisson.png')
            blur, gau, sp, poi = Image.open(blur), Image.open(gau), Image.open(sp), Image.open(poi)
            if self.transform is not None:
                blur = self.transform(blur)
                gau = self.transform(gau)
                sp = self.transform(sp)
                poi = self.transform(poi)
            self.filenames.append([blur,gau,sp,poi])

        self.len = len(self.filenames)

    def __getitem__(self, index):
        """ Get a sample from the dataset """
        blur, gau, sp, poi = self.filenames[index]
        
        return blur, gau, sp, poi

    def __len__(self):
        """ Total number of samples in the dataset """
        return self.len

In [8]:
def save_checkpoint(checkpoint_path, model, optimizer):
    state = {'state_dict': model.state_dict(),
             'optimizer' : optimizer.state_dict()}
    torch.save(state, checkpoint_path)
    print('model saved to %s' % checkpoint_path)
    
def load_checkpoint(checkpoint_path, model, optimizer):
    state = torch.load(checkpoint_path)
    model.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    print('model loaded from %s' % checkpoint_path)

In [6]:
from torchvision.transforms.transforms import RandomCrop
# Create the dataset. 
# transforms.ToTensor() automatically converts PIL images to
trans = transforms.Compose([
    #transforms.Resize(128),
    transforms.CenterCrop(256),
    #transforms.RandomHorizontalFlip(),
    #transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

NoiseSet = Noise(root='/content/blurred_noise',gt_root='/content/blurred',transform=trans)
train_set, val_set = torch.utils.data.random_split(NoiseSet, [900, 251])
train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=0)
val_loader = DataLoader(val_set, batch_size=8, num_workers=0)

# UNet

In [145]:
"""
The UNet model credit to jakeoung
https://github.com/jakeoung/Unet_pytorch/blob/master/model.py
"""

import torch
from torch import nn
import torch.nn.functional as F

def add_conv_stage(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=True, useBN=False):
  if useBN:
    return nn.Sequential(
      nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
      nn.BatchNorm2d(dim_out),
      nn.LeakyReLU(0.1),
      nn.Conv2d(dim_out, dim_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
      nn.BatchNorm2d(dim_out),
      nn.LeakyReLU(0.1)
    )
  else:
    return nn.Sequential(
      nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
      nn.ReLU(),
      nn.Conv2d(dim_out, dim_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
      nn.ReLU()
    )

def add_merge_stage(ch_coarse, ch_fine, in_coarse, in_fine, upsample):
  conv = nn.ConvTranspose2d(ch_coarse, ch_fine, 4, 2, 1, bias=False)
  torch.cat(conv, in_fine)

  return nn.Sequential(
    nn.ConvTranspose2d(ch_coarse, ch_fine, 4, 2, 1, bias=False)
  )
  upsample(in_coarse)

def upsample(ch_coarse, ch_fine):
  return nn.Sequential(
    nn.ConvTranspose2d(ch_coarse, ch_fine, 4, 2, 1, bias=False),
    nn.ReLU()
  )

class Net(nn.Module):
  def __init__(self, useBN=False):
    super(Net, self).__init__()

    self.conv1   = add_conv_stage(1, 32, useBN=useBN)
    self.conv2   = add_conv_stage(32, 64, useBN=useBN)
    self.conv3   = add_conv_stage(64, 128, useBN=useBN)
    self.conv4   = add_conv_stage(128, 256, useBN=useBN)
    self.conv5   = add_conv_stage(256, 512, useBN=useBN)

    self.conv4m = add_conv_stage(512, 256, useBN=useBN)
    self.conv3m = add_conv_stage(256, 128, useBN=useBN)
    self.conv2m = add_conv_stage(128,  64, useBN=useBN)
    self.conv1m = add_conv_stage( 64,  32, useBN=useBN)

    self.conv0  = nn.Sequential(
        nn.Conv2d(32, 1, 3, 1, 1),
        nn.Sigmoid()
    )

    self.max_pool = nn.MaxPool2d(2)

    self.upsample54 = upsample(512, 256)
    self.upsample43 = upsample(256, 128)
    self.upsample32 = upsample(128,  64)
    self.upsample21 = upsample(64 ,  32)

    ## weight initialization
    for m in self.modules():
      if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        if m.bias is not None:
          m.bias.data.zero_()


  def forward(self, x):
    conv1_out = self.conv1(x)
    #return self.upsample21(conv1_out)
    conv2_out = self.conv2(self.max_pool(conv1_out))
    conv3_out = self.conv3(self.max_pool(conv2_out))
    conv4_out = self.conv4(self.max_pool(conv3_out))
    conv5_out = self.conv5(self.max_pool(conv4_out))

    conv5m_out = torch.cat((self.upsample54(conv5_out), conv4_out), 1)
    conv4m_out = self.conv4m(conv5m_out)

    conv4m_out_ = torch.cat((self.upsample43(conv4m_out), conv3_out), 1)
    conv3m_out = self.conv3m(conv4m_out_)

    conv3m_out_ = torch.cat((self.upsample32(conv3m_out), conv2_out), 1)
    conv2m_out = self.conv2m(conv3m_out_)

    conv2m_out_ = torch.cat((self.upsample21(conv2m_out), conv1_out), 1)
    conv1m_out = self.conv1m(conv2m_out_)

    conv0_out = self.conv0(conv1m_out)

    return conv0_out

In [146]:
device = 'cuda'
model = Net().to(device) # try decreasing the depth value if there is a memory error


In [None]:
from torchsummary import summary
summary(model,(1,512,512))

In [None]:
model.train()  # set training mode

In [152]:
epoch = 100
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5)
#lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
criterion = nn.MSELoss()
# load_checkpoint('/content/drive/MyDrive/torch/DIP/DenoiseAE_0103.pth', model, optimizer)

In [153]:
def model_RGB(img_in):
    img_R = model(img_in[:,0,:,:].unsqueeze(1))
    img_G = model(img_in[:,1,:,:].unsqueeze(1))
    img_B = model(img_in[:,2,:,:].unsqueeze(1))
    img_out = torch.cat((img_R,img_G,img_B),dim=1)
    return img_out

In [158]:
def test(model):
    criterion = nn.MSELoss()
    model.eval()  # Important: set evaluation mode
    gau_loss, sp_loss, poi_loss = 0,0,0
    with torch.no_grad(): # This will free the GPU memory used for back-prop
        for origin, gau, sp, poi in val_loader:
            origin, gau, sp, poi = origin.cuda(), gau.cuda(), sp.cuda(), poi.cuda()
            output = model_RGB(gau)
            gau_loss += criterion(output, origin).item()
            output = model_RGB(sp)
            sp_loss += criterion(output, origin).item()
            output = model_RGB(poi)
            poi_loss += criterion(output, origin).item()

    # l = len(val_set)
    # gau_loss /= l
    # sp_loss /= l
    # poi_loss /= l
    print('\nValidation set: Gaussian loss: {:.4f}, Salt&Pepper loss: {:.4f}, Poisson loss: {:.4f}\n'.format(gau_loss, sp_loss, poi_loss))

In [171]:
mn = 100
for ep in range(epoch):
    g,s,p,now = 0,0,0,0
    for batch_idx, (origin, gau, sp, poi) in enumerate(train_loader):
        origin, gau, sp, poi = origin.cuda(), gau.cuda(), sp.cuda(), poi.cuda()
        output1 = model_RGB(gau)
        gau_loss = criterion(output1, origin)
        output2 = model_RGB(sp)
        sp_loss = criterion(output2, origin)
        output3 = model_RGB(poi)
        poi_loss = criterion(output3, origin)
        loss_1, loss_2, loss_3 = criterion(output1, output2),criterion(output1, output3),criterion(output3, output2)
        loss = (gau_loss + sp_loss + poi_loss) + (loss_1+loss_2+loss_3)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
            
        g += gau_loss.item()
        s += sp_loss.item()
        p += poi_loss.item()

    #lr_scheduler.step()
    # l = len(train_set)
    # g /= l
    # s /= l
    # p /= l
    now = (g+s+p)/3
    if now < mn:
        mn = now
        save_checkpoint('/content/drive/MyDrive/torch/DIP/DenoiseAE_UNet_loss.pth', model, optimizer)

    print('Train Epoch: {} Gaussian loss: {:.4f}, Salt&Pepper loss: {:.4f}, Poisson loss: {:.4f}\n'.format(ep, g, s, p))
    test(model)

model saved to /content/drive/MyDrive/torch/DIP/DenoiseAE_UNet_loss.pth
Train Epoch: 0 Gaussian loss: 0.1032, Salt&Pepper loss: 0.0996, Poisson loss: 0.0782


Validation set: Gaussian loss: 0.0171, Salt&Pepper loss: 0.0147, Poisson loss: 0.0110

model saved to /content/drive/MyDrive/torch/DIP/DenoiseAE_UNet_loss.pth
Train Epoch: 1 Gaussian loss: 0.0650, Salt&Pepper loss: 0.0581, Poisson loss: 0.0431


Validation set: Gaussian loss: 0.0170, Salt&Pepper loss: 0.0147, Poisson loss: 0.0109

Train Epoch: 2 Gaussian loss: 0.0649, Salt&Pepper loss: 0.0581, Poisson loss: 0.0433


Validation set: Gaussian loss: 0.0169, Salt&Pepper loss: 0.0146, Poisson loss: 0.0110

Train Epoch: 3 Gaussian loss: 0.0648, Salt&Pepper loss: 0.0581, Poisson loss: 0.0436


Validation set: Gaussian loss: 0.0169, Salt&Pepper loss: 0.0146, Poisson loss: 0.0111

model saved to /content/drive/MyDrive/torch/DIP/DenoiseAE_UNet_loss.pth
Train Epoch: 4 Gaussian loss: 0.0643, Salt&Pepper loss: 0.0577, Poisson loss: 0.0435


V

KeyboardInterrupt: ignored

# Evaluate

In [205]:
#model = DAE().cuda()
model.eval()
optimizer = optim.Adam(model.parameters(), lr=1e-5)
criterion = nn.MSELoss()
load_checkpoint('/content/drive/MyDrive/torch/DIP/DenoiseAE_UNet_loss.pth', model, optimizer)

t2i = transforms.Compose([
    #transforms.Resize(256),
    #transforms.CenterCrop(512),
    transforms.ToTensor(),
    #transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])
i2t = transforms.Compose([
    transforms.ToPILImage(),
    #transforms.Resize((720,720)),
])

stn = Image.open('blurred/402.png')
stn = t2i(stn).cuda()
img = Image.open('blurred_noise/402_poisson.png')
im_t = t2i(img).cuda()
# stn = train_set[305][0].cuda()
# im_t = train_set[305][1].cuda()
im_o2 = i2t(im_t)
print(im_t.shape)
stn_t = torch.unsqueeze(stn,0)
im_t = torch.unsqueeze(im_t,0)
# im_t = im_t.unfold(2, 64, 64).unfold(3, 64, 64)
# output = torch.zeros(1,3,768,768)
# for i in range(12):
#     for j in range(12):
#         output[:,:,i*64:(i+1)*64,j*64:(j+1)*64] = model(im_t[:,:,i,j,:,:])
#         #output[:,:,i*64:(i+1)*64,j*64:(j+1)*64] = im_t[:,:,i,j,:,:]

# for i in range(5):
#     for j in range(5):
#         im_t[:,:,i*128:(i+1)*128,j*128:(j+1)*128] = model(im_t[:,:,i*128:(i+1)*128,j*128:(j+1)*128])
im_t = model_RGB(im_t)

loss = criterion(stn_t,im_t)
stn_t = torch.squeeze(stn_t,0)
stn = i2t(stn_t)
im_t = torch.squeeze(im_t,0)
im_o = i2t(im_t)

model loaded from /content/drive/MyDrive/torch/DIP/DenoiseAE_UNet_loss.pth
torch.Size([3, 720, 720])


In [206]:
im_o.save('denoise.png')
im_o2.save('noise.png')
stn.save('blurred.png')

In [188]:
loss

tensor(0.0005, device='cuda:0', grad_fn=<MseLossBackward0>)