In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torchvision import transforms
from gaussian_noise import GaussianNoise

In [None]:
# Prefer the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [2]:
# Noise transform built with GaussianNoise's default arguments; the training loop calls it on
# each clean batch to produce the noisy network input.
# NOTE(review): the noise level and any clipping are defined in the local gaussian_noise
# module, which is not visible here — confirm they match the saved noisy test tensors.
gaussiannoise = GaussianNoise()

In [None]:
# Side length, in pixels, of the square images fed to the network.
img_size = 128

# Each preprocessing step keeps its own module-level name because later cells reuse them
# (for example `totensor` when converting the OpenCV results back to tensors).
totensor = transforms.ToTensor()
grayscale = transforms.Grayscale(num_output_channels=1)
resize = transforms.Resize(size=(img_size, img_size))

# Applied in this order: convert to tensor, collapse to one channel, resize to img_size x img_size.
transforms_ = transforms.Compose([totensor, grayscale, resize])

# Training images are read from the "train" folder and served in shuffled batches of 128.
train_data = torchvision.datasets.ImageFolder(root="train", transform=transforms_)
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)

In [None]:
# Load the saved test tensors: the clean targets and four noisy sets whose file names
# carry the noise variance (0.005, 0.010, 0.025, 0.050).
(
    test_clean,
    test_noisy_005,
    test_noisy_010,
    test_noisy_025,
    test_noisy_050,
) = map(
    torch.load,
    (
        "test/test_clean.pt",
        "test/test_noisy_var_0.005.pt",
        "test/test_noisy_var_0.010.pt",
        "test/test_noisy_var_0.025.pt",
        "test/test_noisy_var_0.050.pt",
    ),
)

# While training, the per-epoch test loss is measured on the var=0.010 set only.
# Each dataset item is a (noisy, clean) pair.
test_dataset = torch.utils.data.TensorDataset(test_noisy_010, test_clean)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128)

In [None]:
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim

class cnn_autoencoder(nn.Module):
    """Convolutional autoencoder mapping a 1-channel image to a 1-channel image of the same size.

    The encoder is three stride-2 convolutions (1 -> 64 -> 128 -> 256 channels), each halving
    the spatial size; the decoder is three stride-2 transposed convolutions
    (256 -> 128 -> 64 -> 1 channels), each doubling it again. For a [1, 128, 128] input the
    bottleneck is [256, 16, 16]. Every layer except the last is followed by ReLU and then
    batch normalisation; the last layer is followed by a sigmoid, so outputs lie in [0, 1].
    """

    def __init__(self):
        super().__init__()

        # --- Encoder: kernel 4, stride 2, padding 1 halves height and width ---
        # [1, 128, 128] -> [64, 64, 64]
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.conv1_bn = nn.BatchNorm2d(num_features=64)

        # [64, 64, 64] -> [128, 32, 32]
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.conv2_bn = nn.BatchNorm2d(num_features=128)

        # [128, 32, 32] -> [256, 16, 16]
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1)
        self.conv3_bn = nn.BatchNorm2d(num_features=256)

        # --- Decoder: the transposed convolutions mirror the encoder and double height and width ---
        # [256, 16, 16] -> [128, 32, 32]
        self.conv4 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.conv4_bn = nn.BatchNorm2d(num_features=128)

        # [128, 32, 32] -> [64, 64, 64]
        self.conv5 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.conv5_bn = nn.BatchNorm2d(num_features=64)

        # [64, 64, 64] -> [1, 128, 128]; no batch norm on the output layer
        self.conv6 = nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """Run the encoder and decoder on a batch `x` and return the sigmoid-activated output."""
        hidden_stages = (
            (self.conv1, self.conv1_bn),
            (self.conv2, self.conv2_bn),
            (self.conv3, self.conv3_bn),
            (self.conv4, self.conv4_bn),
            (self.conv5, self.conv5_bn),
        )
        # Each hidden stage applies the convolution, then ReLU, then batch norm (in that order).
        for conv, norm in hidden_stages:
            x = norm(F.relu(conv(x)))
        return torch.sigmoid(self.conv6(x))

In [None]:
# Build the autoencoder directly on the selected device (Module.to returns the module itself).
net = cnn_autoencoder().to(device)

# Binary cross-entropy between the sigmoid output and the clean target image.
criterion = nn.BCELoss()
optimizer = optim.Adam(params=net.parameters(), lr=1e-3)

# Switch to training mode; as the last expression this also displays the module summary.
net.train()

cnn_autoencoder(
  (conv1): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv1_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv3_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv4_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv5_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv6): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)

In [None]:
from tqdm import tqdm

# Train for 10 epochs. After each epoch, report the mean training loss and the mean loss
# on the held-out (noisy var=0.010, clean) test pairs.
for epoch in range(10):

    running_train_loss = 0
    counter = 0

    for train_batch in train_data_loader:
        # ImageFolder yields (image, label); only the image is used. The clean image is the
        # target and a freshly noised copy of it is the network input.
        gt = train_batch[0].to(device)
        noisy = gaussiannoise(train_batch[0]).to(device)

        optimizer.zero_grad()
        output = net(noisy.float())
        train_loss = criterion(output, gt.float())
        train_loss.backward()
        optimizer.step()
        running_train_loss += train_loss.item()
        counter += 1

    # Evaluation pass: eval mode so BatchNorm uses its running statistics, and no_grad so
    # no autograd graph is built for the test batches.
    net.eval()
    test_loss = []
    with torch.no_grad():
        # The batch variables have their own names on purpose: rebinding `test_clean` here
        # would replace the full clean test tensor that later cells use for PSNR and for
        # the ground-truth plots with the last batch only.
        for batch_noisy, batch_clean in test_data_loader:
            batch_noisy = batch_noisy.to(device)
            batch_clean = batch_clean.to(device)
            test_output = net(batch_noisy.float())
            test_loss_val = criterion(test_output, batch_clean.float())
            test_loss.append(test_loss_val.item())
    net.train()

    print("Epoch: {:d}, train loss: {:f}, test loss {:f}".format(epoch, running_train_loss/counter, np.mean(test_loss)))

Epoch: 0, train loss: 0.611233, test loss 0.578793
Epoch: 1, train loss: 0.560013, test loss 0.560193
Epoch: 2, train loss: 0.557247, test loss 0.558593
Epoch: 3, train loss: 0.555441, test loss 0.556569
Epoch: 4, train loss: 0.554043, test loss 0.556235
Epoch: 5, train loss: 0.553540, test loss 0.555194
Epoch: 6, train loss: 0.553393, test loss 0.555506
Epoch: 7, train loss: 0.552449, test loss 0.554327
Epoch: 8, train loss: 0.552390, test loss 0.558210
Epoch: 9, train loss: 0.552658, test loss 0.555041


In [None]:
# torch.save(net.state_dict(), "cnn_autoencoder_pathology.model")

In [None]:
# Compare the trained CNN with OpenCV non-local-means denoising on the var=0.010 test set.
# NOTE(review): `psnr` and `cv` are not defined or imported in any visible cell — presumably
# a PSNR helper and `import cv2 as cv`; confirm both are in scope on a fresh kernel.

# --- CNN predictions ---
# The training loop leaves the net in train mode, where BatchNorm normalises with (and
# updates from) the statistics of whatever batch it sees. Switch to eval mode for inference.
net.eval()
with torch.no_grad():
    # Move each chunk to the model's device and process the set in chunks of 128 (the
    # loaders' batch size) so the whole test set never has to fit in one forward pass.
    test_pred_010 = torch.cat(
        [net(chunk.to(device).float()).cpu() for chunk in test_noisy_010.split(128)],
        dim=0,
    ).float()
psnr_nn_010 = psnr(test_clean, test_pred_010)

# --- Classical baseline: OpenCV fast non-local means, one image at a time ---
denoised_frames = []
for noisy_img in test_noisy_010:
    # Clip to [0, 1] before scaling so out-of-range noisy pixels saturate instead of wrapping
    # around in the uint8 cast (a no-op if the saved tensors are already within [0, 1]).
    noisy_u8 = (np.clip(noisy_img.cpu().numpy().squeeze(), 0.0, 1.0) * 255).astype(np.uint8)
    denoised = cv.fastNlMeansDenoising(noisy_u8, None, 50, 7, 21)
    # ToTensor expects H x W x C and returns a float tensor of shape [1, H, W] scaled to [0, 1].
    denoised_frames.append(totensor(denoised[:, :, np.newaxis]))

# Stack once at the end instead of growing the tensor with torch.cat on every iteration.
if denoised_frames:
    fastNI_010 = torch.stack(denoised_frames, dim=0)
else:
    fastNI_010 = torch.empty((0, 1, img_size, img_size))

psnr_opencv_010 = psnr(test_clean, fastNI_010)

torch.save(test_pred_010, "cnn_autoencoder_denoised/var_010.pt")
torch.save(fastNI_010, "classical_denoised/var_010.pt")


print("PSNR for CNN: {:f} vs OpenCV: {:f} for noise var=0.010".format(psnr_nn_010, psnr_opencv_010))

In [None]:
# First test sample at noise var=0.010, as 2-D arrays: noisy input, CNN output,
# OpenCV output, and the clean ground truth.
test_img = test_noisy_010[0].squeeze().cpu().numpy()
nn_denoised_img = test_pred_010[0].squeeze().cpu().detach().numpy()
ni_denoised_img = fastNI_010[0].squeeze().numpy()
gt_img = test_clean[0].squeeze().cpu().numpy()

# NOTE(review): matplotlib's figsize is (width, height) in inches, so (10, 40) is a tall,
# narrow canvas for a single row of four panels — confirm this is intended.
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 40))
panels = (
    ("noisy img", test_img),
    ("NN denoised img", nn_denoised_img),
    ("Classically denoised img", ni_denoised_img),
    ("ground truth", gt_img),
)
for ax, (title, image) in zip(axes, panels):
    ax.imshow(image, cmap='gray')
    ax.set_title(title)
plt.show()

In [None]:
# Compare the trained CNN with OpenCV non-local-means denoising on the var=0.025 test set.
# NOTE(review): `psnr` and `cv` are not defined or imported in any visible cell — presumably
# a PSNR helper and `import cv2 as cv`; confirm both are in scope on a fresh kernel.

# --- CNN predictions ---
# The training loop leaves the net in train mode, where BatchNorm normalises with (and
# updates from) the statistics of whatever batch it sees. Switch to eval mode for inference.
net.eval()
with torch.no_grad():
    # Move each chunk to the model's device and process the set in chunks of 128 (the
    # loaders' batch size) so the whole test set never has to fit in one forward pass.
    test_pred_025 = torch.cat(
        [net(chunk.to(device).float()).cpu() for chunk in test_noisy_025.split(128)],
        dim=0,
    ).float()
psnr_nn_025 = psnr(test_clean, test_pred_025)

# --- Classical baseline: OpenCV fast non-local means, one image at a time ---
denoised_frames = []
for noisy_img in test_noisy_025:
    # Clip to [0, 1] before scaling so out-of-range noisy pixels saturate instead of wrapping
    # around in the uint8 cast (a no-op if the saved tensors are already within [0, 1]).
    noisy_u8 = (np.clip(noisy_img.cpu().numpy().squeeze(), 0.0, 1.0) * 255).astype(np.uint8)
    denoised = cv.fastNlMeansDenoising(noisy_u8, None, 50, 7, 21)
    # ToTensor expects H x W x C and returns a float tensor of shape [1, H, W] scaled to [0, 1].
    denoised_frames.append(totensor(denoised[:, :, np.newaxis]))

# Stack once at the end instead of growing the tensor with torch.cat on every iteration.
if denoised_frames:
    fastNI_025 = torch.stack(denoised_frames, dim=0)
else:
    fastNI_025 = torch.empty((0, 1, img_size, img_size))

psnr_opencv_025 = psnr(test_clean, fastNI_025)

torch.save(test_pred_025, "cnn_autoencoder_denoised/var_025.pt")
torch.save(fastNI_025, "classical_denoised/var_025.pt")

print("PSNR for CNN: {:f} vs OpenCV: {:f} for noise var=0.025".format(psnr_nn_025, psnr_opencv_025))

In [None]:
# First test sample at noise var=0.025, as 2-D arrays: noisy input, CNN output,
# OpenCV output, and the clean ground truth.
test_img = test_noisy_025[0].squeeze().cpu().numpy()
nn_denoised_img = test_pred_025[0].squeeze().cpu().detach().numpy()
ni_denoised_img = fastNI_025[0].squeeze().numpy()
gt_img = test_clean[0].squeeze().cpu().numpy()

# NOTE(review): matplotlib's figsize is (width, height) in inches, so (10, 40) is a tall,
# narrow canvas for a single row of four panels — confirm this is intended.
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 40))
panels = (
    ("noisy img", test_img),
    ("NN denoised img", nn_denoised_img),
    ("Classically denoised img", ni_denoised_img),
    ("ground truth", gt_img),
)
for ax, (title, image) in zip(axes, panels):
    ax.imshow(image, cmap='gray')
    ax.set_title(title)
plt.show()

In [None]:
# Compare the trained CNN with OpenCV non-local-means denoising on the var=0.050 test set.
# NOTE(review): `psnr` and `cv` are not defined or imported in any visible cell — presumably
# a PSNR helper and `import cv2 as cv`; confirm both are in scope on a fresh kernel.

# --- CNN predictions ---
# The training loop leaves the net in train mode, where BatchNorm normalises with (and
# updates from) the statistics of whatever batch it sees. Switch to eval mode for inference.
net.eval()
with torch.no_grad():
    # Move each chunk to the model's device and process the set in chunks of 128 (the
    # loaders' batch size) so the whole test set never has to fit in one forward pass.
    test_pred_050 = torch.cat(
        [net(chunk.to(device).float()).cpu() for chunk in test_noisy_050.split(128)],
        dim=0,
    ).float()
psnr_nn_050 = psnr(test_clean, test_pred_050)

# --- Classical baseline: OpenCV fast non-local means, one image at a time ---
denoised_frames = []
for noisy_img in test_noisy_050:
    # Clip to [0, 1] before scaling so out-of-range noisy pixels saturate instead of wrapping
    # around in the uint8 cast (a no-op if the saved tensors are already within [0, 1]).
    noisy_u8 = (np.clip(noisy_img.cpu().numpy().squeeze(), 0.0, 1.0) * 255).astype(np.uint8)
    denoised = cv.fastNlMeansDenoising(noisy_u8, None, 50, 7, 21)
    # ToTensor expects H x W x C and returns a float tensor of shape [1, H, W] scaled to [0, 1].
    denoised_frames.append(totensor(denoised[:, :, np.newaxis]))

# Stack once at the end instead of growing the tensor with torch.cat on every iteration.
if denoised_frames:
    fastNI_050 = torch.stack(denoised_frames, dim=0)
else:
    fastNI_050 = torch.empty((0, 1, img_size, img_size))

psnr_opencv_050 = psnr(test_clean, fastNI_050)

torch.save(test_pred_050, "cnn_autoencoder_denoised/var_050.pt")
torch.save(fastNI_050, "classical_denoised/var_050.pt")

print("PSNR for CNN: {:f} vs OpenCV: {:f} for noise var=0.050".format(psnr_nn_050, psnr_opencv_050))

In [None]:
# First test sample at noise var=0.050, as 2-D arrays: noisy input, CNN output,
# OpenCV output, and the clean ground truth.
test_img = test_noisy_050[0].squeeze().cpu().numpy()
nn_denoised_img = test_pred_050[0].squeeze().cpu().detach().numpy()
ni_denoised_img = fastNI_050[0].squeeze().numpy()
gt_img = test_clean[0].squeeze().cpu().numpy()

# NOTE(review): matplotlib's figsize is (width, height) in inches, so (10, 40) is a tall,
# narrow canvas for a single row of four panels — confirm this is intended.
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 40))
panels = (
    ("noisy img", test_img),
    ("NN denoised img", nn_denoised_img),
    ("Classically denoised img", ni_denoised_img),
    ("ground truth", gt_img),
)
for ax, (title, image) in zip(axes, panels):
    ax.imshow(image, cmap='gray')
    ax.set_title(title)
plt.show()