In [None]:
#!/usr/bin/env python

import random
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.nn import functional as F
from torchvision.models import resnet18
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torchvision.transforms import Compose
from torchvision.utils import save_image
from tqdm import tqdm
from gaussian_noise import GaussianNoise
from psnr import psnr

In [None]:
# Pick the compute device: the first CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Side length, in pixels, of the square images produced by the resize step.
img_size = 128

# Individual preprocessing steps, kept as separate names so each can be reused.
grayscale = transforms.Grayscale(num_output_channels=1)
resize = transforms.Resize(size=(img_size, img_size))
totensor = transforms.ToTensor()

# Dataset pipeline, applied in this order: single channel -> resize -> ToTensor.
PYTORCH_TRANSFORM = Compose([grayscale, resize, totensor])

# Noise generator constructed with the GaussianNoise defaults.
gaussiannoise = GaussianNoise()

def conv_down(in_ch, out_ch):
    """Return a strided convolution used for learned downsampling.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.

    Returns:
        An ``nn.Conv2d`` with ``kernel_size=4``, ``stride=2`` and ``padding=1``;
        with these settings a spatial side of ``n`` becomes ``n // 2``.
    """
    layer = nn.Conv2d(in_ch, out_ch, kernel_size=4, padding=1, stride=2)
    return layer


class Discriminator(nn.Module):
    """Convolutional classifier that emits one sigmoid score per input image.

    Five ``conv_down`` layers (each followed by ReLU) shrink every spatial side
    by a factor of 32 while widening the channels to ``out_ch * 8``.  The
    resulting feature map is flattened per sample and fed to a single linear
    layer followed by a sigmoid.

    Args:
        in_ch: number of channels of the input images.
        out_ch: channel width of the first convolution; later layers use
            2x, 4x, 8x and 8x this value.
        img_size: side length of the (square) input images.  It determines the
            input size of the linear layer.  The default of 128 matches the
            ``img_size`` used by the dataset transforms in this notebook.
            Pass ``img_size=256`` to obtain the ``out_ch * 8 * 8 * 8`` linear
            layer that earlier versions of this class always created.

    Raises:
        ValueError: if ``img_size`` is too small to survive five halvings.
    """

    def __init__(self, in_ch, out_ch=32, img_size=128):
        super(Discriminator, self).__init__()
        self.out_ch = out_ch
        self.block = nn.Sequential(
            conv_down(in_ch, out_ch),
            nn.ReLU(inplace=True),
            conv_down(out_ch, out_ch*2),
            nn.ReLU(inplace=True),
            conv_down(out_ch*2, out_ch*4),
            nn.ReLU(inplace=True),
            conv_down(out_ch*4, out_ch*8),
            nn.ReLU(inplace=True),
            conv_down(out_ch*8, out_ch*8),
            nn.ReLU(inplace=True)
        )

        # Five stride-2 convolutions reduce each side by 2**5 = 32.
        feat_side = img_size // 32
        if feat_side < 1:
            raise ValueError(
                "img_size must be at least 32, got {}".format(img_size))

        # Number of features per sample after the convolutional stack.
        self.flat_features = out_ch * 8 * feat_side * feat_side
        self.linear = nn.Linear(self.flat_features, 1)

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: tensor of shape ``(batch, in_ch, img_size, img_size)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in ``(0, 1)``.
        """
        out = self.block(x)
        # Flatten per sample.  Reshaping with ``view(-1, N)`` and a feature
        # count N that does not match the real feature map silently folds
        # several samples into one row instead of failing.
        out = out.flatten(start_dim=1)
        out = torch.sigmoid(self.linear(out))

        return out


class ConvBlock(nn.Module):
    """Two 3x3 convolutions with ReLU, optionally followed by a strided conv.

    Both 3x3 convolutions use ``padding=1``, so the spatial size is preserved
    by ``self.block``.

    Args:
        in_ch: number of input channels.
        out_ch: number of channels produced by both convolutions.
        downsample: when true, a ``conv_down(out_ch, out_ch)`` layer is
            attached and ``forward`` returns a pair instead of one tensor.
    """

    def __init__(self, in_ch, out_ch, downsample=False):
        super().__init__()
        self.downsample = downsample

        stages = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.block = nn.Sequential(*stages)

        if downsample:
            self.down_layer = conv_down(out_ch, out_ch)

    def forward(self, x):
        """Apply the block.

        Returns:
            The feature tensor when ``downsample`` is false; otherwise the
            tuple ``(downsampled_features, features)``.
        """
        features = self.block(x)
        if not self.downsample:
            return features
        return self.down_layer(features), features

class UpBlock(nn.Module):
    """Upsample with a transposed convolution, then fuse with a skip tensor.

    Args:
        in_ch: channels of the tensor to upsample; also the number of input
            channels of the fusing ``ConvBlock``.
        out_ch: channels produced by the upsampling layer and by the block.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        # The fusing block expects in_ch channels, so the upsampled tensor
        # (out_ch channels) plus the skip tensor must add up to in_ch.
        self.conv = ConvBlock(in_ch, out_ch)

    def forward(self, x, upass):
        """Upsample ``x``, concatenate ``upass`` on the channel axis, convolve."""
        merged = torch.cat([self.up(x), upass], dim=1)
        return self.conv(merged)

class DeNoiser(nn.Module):
    """Encoder/decoder network with skip connections between matching stages.

    Args:
        in_ch: channels of the input image; the output has the same count.
        frames: channel width of the first encoder stage; stage ``i`` uses
            ``frames * 2**i`` channels.
        depth: number of encoder stages.  Every stage except the last one
            downsamples and contributes a skip tensor.
    """

    def __init__(self, in_ch, frames=32, depth=2):
        super().__init__()

        self.depth = depth
        widths = [frames * (2 ** level) for level in range(depth)]

        # Encoder: only the deepest stage is built without a downsampling tail.
        self.collapse = nn.ModuleList()
        prev_channels = in_ch
        for level, width in enumerate(widths):
            has_down = (level + 1) < depth
            self.collapse.append(ConvBlock(prev_channels, width, has_down))
            prev_channels = width

        # Decoder: walk the encoder widths back up, skipping the deepest one.
        self.restore = nn.ModuleList()
        for width in reversed(widths[:-1]):
            self.restore.append(UpBlock(prev_channels, width))
            prev_channels = width

        self.final_conv = nn.Conv2d(prev_channels, in_ch, kernel_size=3, padding=1, stride=1)

    def forward(self, x):
        """Run the network; the result is passed through a sigmoid."""
        skips = []
        last_level = self.depth - 1
        for level, stage in enumerate(self.collapse):
            if level < last_level:
                x, skip = stage(x)
                skips.append(skip)
            else:
                x = stage(x)

        # Skip tensors are consumed deepest-first.
        for stage, skip in zip(self.restore, reversed(skips)):
            x = stage(x, skip)

        return torch.sigmoid(self.final_conv(x))
            
def discriminator_step(imgs, noisy, discriminator, disc_opt):
    """Run one optimisation step of the discriminator.

    The discriminator is scored with binary cross-entropy, using target 0 for
    its predictions on ``imgs`` and target 1 for its predictions on ``noisy``.

    Args:
        imgs: batch scored against the all-zeros target.
        noisy: batch scored against the all-ones target.
        discriminator: callable returning probabilities for a batch.
        disc_opt: optimizer that owns the discriminator's parameters.

    Returns:
        Tuple ``(real_loss, fake_loss)`` of loss tensors.
    """
    pred_on_real = discriminator(imgs)
    pred_on_fake = discriminator(noisy)

    zeros_target = torch.zeros_like(pred_on_real)
    ones_target = torch.ones_like(pred_on_fake)
    real_loss = F.binary_cross_entropy(pred_on_real, zeros_target)
    fake_loss = F.binary_cross_entropy(pred_on_fake, ones_target)

    # Clear stale gradients, back-propagate the summed loss, update.
    disc_opt.zero_grad()
    (real_loss + fake_loss).backward()
    disc_opt.step()

    return real_loss, fake_loss

def training_loop(n_epochs, optimizer, disc_opt, model, discriminator,
                  loss_func, train_loader, pretrain_epoch=3, alpha=0.001):
    """Adversarially train ``model`` (the denoiser) against ``discriminator``.

    For every batch the denoiser is updated first, on
    ``mse(outputs, imgs) + alpha * bce(discriminator(outputs), 0)``; the
    discriminator is then updated by ``discriminator_step`` on the clean batch
    and on a freshly denoised batch.  At the end of each epoch the last
    batch's noisy, clean and reconstructed images are written with
    ``save_image`` (``<epoch>_noisy.png``, ``<epoch>_real.png``,
    ``<epoch>_recon.png``) and the per-epoch summed losses are appended to
    ``unet_loss.txt``.

    Args:
        n_epochs: number of training epochs.
        optimizer: optimizer for the parameters of ``model``.
        disc_opt: optimizer for the parameters of ``discriminator``.
        model: denoising network.
        discriminator: network scoring images; target 0 is used for clean
            images, matching ``discriminator_step``.
        loss_func: unused; kept so existing call sites continue to work.
        train_loader: iterable yielding ``(imgs, labels)`` batches; the
            labels are ignored.
        pretrain_epoch: unused.  Discriminator pre-training is currently
            disabled; the argument is kept for backward compatibility.
        alpha: weight of the adversarial term in the denoiser loss.
    """
    # Follow the model's device instead of hard-coding CUDA so the loop also
    # runs when only a CPU is available.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    noise_gen = GaussianNoise()
    for epoch in tqdm(range(1, n_epochs + 1)):
        recon_losses = 0
        disc_losses = 0
        real_losses = 0
        fake_losses = 0
        # Last processed batch, kept for the end-of-epoch snapshot.
        noisy = imgs = outputs = None

        for clean_batch, _labels in train_loader:
            # Noise is added to the batch exactly as the loader delivered it;
            # only the results are moved to the training device.
            noisy = noise_gen(clean_batch).float().to(device)
            imgs = clean_batch.to(device)

            # --- denoiser (generator) update ---
            outputs = model(noisy)
            disc_out = discriminator(outputs)

            recon_loss = F.mse_loss(outputs, imgs)
            disc_loss = F.binary_cross_entropy(disc_out,
                                               torch.zeros_like(disc_out))

            loss = recon_loss + alpha * disc_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # --- discriminator update on a fresh noise draw ---
            # Reusing the loader's tensor avoids copying the clean batch back
            # from the device just to add noise to it.
            noisy = noise_gen(clean_batch).float().to(device)
            # Only the values are needed here, so skip building a graph.
            with torch.no_grad():
                outputs = model(noisy)
            real_loss, fake_loss = discriminator_step(imgs, outputs,
                                                      discriminator, disc_opt)

            recon_losses += recon_loss.item()
            disc_losses += disc_loss.item()
            real_losses += real_loss.item()
            fake_losses += fake_loss.item()

        # An empty loader leaves nothing to snapshot.
        if outputs is not None:
            save_image(noisy, str(epoch) + "_noisy.png")
            save_image(imgs, str(epoch) + "_real.png")
            save_image(outputs, str(epoch) + "_recon.png")

        with open("unet_loss.txt", "a") as resfile:
            print("Epoch {}: Training loss {} {} {} {}".format(
                epoch, float(recon_losses), float(disc_losses),
                float(real_losses), float(fake_losses)), file=resfile)

In [None]:
# Fine-tune the pretrained denoiser adversarially and save the result.
imgs = ImageFolder("train", PYTORCH_TRANSFORM)
full = torch.utils.data.DataLoader(imgs, batch_size=128)

# The checkpoint is loaded into the DataParallel-wrapped model, so the wrapper
# has to stay for the state_dict keys to line up.
model = nn.DataParallel(DeNoiser(1)).to(device)
model.load_state_dict(torch.load("unet.pt", map_location=device))
# The model is about to be trained further, so keep it in training mode.
model.train()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# training_loop does not use loss_func; it is passed only to satisfy its
# signature.
loss_func = nn.NLLLoss()

discriminator = nn.DataParallel(Discriminator(1)).to(device)
disc_opt = optim.Adam(discriminator.parameters(), lr=1e-4)

training_loop(10, optimizer, disc_opt, model, discriminator, loss_func, full)

torch.save(model.state_dict(), "gan.pt")


In [None]:
# Rebuild the network and load the adversarially trained weights.
model = nn.DataParallel(DeNoiser(1)).to(device)
model.load_state_dict(torch.load("gan.pt", map_location=device))
# Everything below is inference only.
model.eval()

In [None]:
# Load the clean test batch and its three noisy variants, in this order.
test_clean, test_noisy_010, test_noisy_025, test_noisy_050 = map(
    torch.load,
    (
        "test/test_clean.pt",
        "test/test_noisy_var_0.010.pt",
        "test/test_noisy_var_0.025.pt",
        "test/test_noisy_var_0.050.pt",
    ),
)

In [None]:

def _denoise_and_score(net, noisy, clean, out_path, var_label):
    """Denoise one test batch, save the prediction and print its PSNR.

    Args:
        net: model used for denoising.
        noisy: noisy test batch; cast to float before the forward pass.
        clean: reference batch handed to ``psnr`` as its first argument.
        out_path: file the prediction is written to with ``torch.save``; its
            parent directory is created when missing.
        var_label: noise variance as text, used only in the printed message.

    Returns:
        Tuple ``(pred, score)``: the prediction as a CPU float tensor and the
        value returned by ``psnr(clean, pred)``.
    """
    with torch.no_grad():
        pred = net(noisy.float()).cpu().float()
        score = psnr(clean, pred)

    # torch.save does not create missing directories.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(pred, out_path)

    print("PSNR for GAN: {:f} for noise var={}".format(score, var_label))
    return pred, score


test_pred_010, psnr_nn_010 = _denoise_and_score(
    model, test_noisy_010, test_clean,
    "gan_autoencoder_denoised/var_010.pt", "0.010")

test_pred_025, psnr_nn_025 = _denoise_and_score(
    model, test_noisy_025, test_clean,
    "gan_autoencoder_denoised/var_025.pt", "0.025")

test_pred_050, psnr_nn_050 = _denoise_and_score(
    model, test_noisy_050, test_clean,
    "gan_autoencoder_denoised/var_050.pt", "0.050")


In [None]:
# Visual check for noise variance 0.010: first test image shown as noisy
# input, GAN output and clean reference, left to right.
test_img = test_noisy_010[0].squeeze().cpu().numpy()
nn_denoised_img = test_pred_010[0].squeeze().detach().cpu().numpy()
gt_img = test_clean[0].squeeze().cpu().numpy()

fig, (ax1, ax2, ax4) = plt.subplots(nrows=1, ncols=3, figsize=(10, 30))
panels = (
    (ax1, test_img, "noisy img"),
    (ax2, nn_denoised_img, "GAN denoised img"),
    (ax4, gt_img, "ground truth"),
)
for axis, picture, caption in panels:
    axis.imshow(picture, cmap='gray')
    axis.set_title(caption)
plt.show()

In [None]:
# Visual check for noise variance 0.025: first test image shown as noisy
# input, GAN output and clean reference, left to right.
test_img = test_noisy_025[0].squeeze().cpu().numpy()
nn_denoised_img = test_pred_025[0].squeeze().detach().cpu().numpy()
gt_img = test_clean[0].squeeze().cpu().numpy()

fig, (ax1, ax2, ax4) = plt.subplots(nrows=1, ncols=3, figsize=(10, 30))
panels = (
    (ax1, test_img, "noisy img"),
    (ax2, nn_denoised_img, "GAN denoised img"),
    (ax4, gt_img, "ground truth"),
)
for axis, picture, caption in panels:
    axis.imshow(picture, cmap='gray')
    axis.set_title(caption)
plt.show()

In [None]:
# Visual check for noise variance 0.050: first test image shown as noisy
# input, GAN output and clean reference, left to right.
test_img = test_noisy_050[0].squeeze().cpu().numpy()
nn_denoised_img = test_pred_050[0].squeeze().detach().cpu().numpy()
gt_img = test_clean[0].squeeze().cpu().numpy()

fig, (ax1, ax2, ax4) = plt.subplots(nrows=1, ncols=3, figsize=(10, 30))
panels = (
    (ax1, test_img, "noisy img"),
    (ax2, nn_denoised_img, "GAN denoised img"),
    (ax4, gt_img, "ground truth"),
)
for axis, picture, caption in panels:
    axis.imshow(picture, cmap='gray')
    axis.set_title(caption)
plt.show()