In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
from dataloaders import mias_dataset
from model import cnn_autoencoder
from gaussian_noise_transform import GaussianNoise

In [None]:
# Reproducibility: seed the global RNGs before anything random happens
# (weight initialisation, DataLoader shuffling and, presumably, the
# GaussianNoise transform) so Restart & Run All reproduces the same run.
# Note: some CUDA kernels can still be nondeterministic on GPU.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

# Train on the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# Transforms handed to mias_dataset for the clean and the noisy branch.
# The clean branch applies nothing: an empty Compose returns its input unchanged.
clean_transform = transforms.Compose([])

# Additive Gaussian noise for the network input.
# NOTE(review): whether GaussianNoise treats its second argument as a
# variance or a standard deviation is defined in gaussian_noise_transform —
# confirm there before tuning NOISE_SCALE.
NOISE_MEAN = 0
NOISE_SCALE = 0.005
noise_transform = transforms.Compose([GaussianNoise(NOISE_MEAN, NOISE_SCALE)])

In [None]:
# Training data: the pre-saved tensor file is wrapped in the project dataset.
# The training loop unpacks each batch as (clean, noisy), so the dataset is
# expected to yield pairs in that order.
# NOTE(review): torch.load is pickle-based — only load trusted files.
TRAIN_BATCH_SIZE = 10

mias_train_data = torch.load('mias_train.pt')
mias_dataset_ = mias_dataset(mias_train_data, clean_transform, noise_transform)
mias_data_loader = DataLoader(
    dataset=mias_dataset_,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
)

In [None]:
# Held-out test set: cast each tensor to float32 and move it to the compute
# device once, so evaluation can feed it straight into the network.
# NOTE(review): torch.load is pickle-based — only load trusted files.
mias_clean_test = torch.load('mias_clean_test.pt').float().to(device)
mias_noisy_test = torch.load('mias_noisy_test.pt').float().to(device)

In [None]:
# Model, loss and optimiser. nn.BCELoss requires network outputs in [0, 1];
# NOTE(review): this assumes cnn_autoencoder ends in a sigmoid and that the
# clean targets are scaled to [0, 1] — confirm in model.py / the data files.
LEARNING_RATE = 0.001

mias_net = cnn_autoencoder().to(device)  # Module.to moves in place and returns self
criterion = nn.BCELoss()
optimizer = optim.Adam(mias_net.parameters(), lr=LEARNING_RATE)
mias_net.train()

In [None]:
# Train for NUM_EPOCHS epochs. Every EVAL_EVERY epochs, report that epoch's
# mean training loss and the loss on the full held-out test set.
NUM_EPOCHS = 1000
EVAL_EVERY = 10

# Make sure we start in training mode even if an earlier cell (or a previous
# run) left the network in eval mode.
mias_net.train()

for epoch in range(NUM_EPOCHS):
    running_train_loss = 0.0
    counter = 0
    for clean, noisy in mias_data_loader:
        clean = clean.to(device)
        noisy = noisy.to(device)

        optimizer.zero_grad()
        output = mias_net(noisy.float())
        train_loss = criterion(output, clean.float())
        train_loss.backward()
        optimizer.step()

        running_train_loss += train_loss.item()
        counter += 1

    if (epoch + 1) % EVAL_EVERY == 0:
        # eval() switches dropout/batch-norm layers (if the model has any) to
        # inference behaviour. no_grad() avoids building an autograd graph over
        # the entire test set, which would only waste (GPU) memory.
        mias_net.eval()
        with torch.no_grad():
            test_output = mias_net(mias_noisy_test)
            test_loss = criterion(test_output, mias_clean_test).item()
        mias_net.train()

        # An empty loader would otherwise raise ZeroDivisionError here.
        mean_train_loss = running_train_loss / counter if counter else float("nan")
        # Report the 1-based epoch number so the log agrees with the
        # (epoch + 1) % EVAL_EVERY schedule (prints 10, 20, ... not 9, 19, ...).
        print("Epoch: {:d}, train loss: {:f}, test loss {:f}".format(epoch + 1, mean_train_loss, test_loss))

### Plot noisy inputs, NN-denoised predictions, and clean ground truths

In [None]:
# Denoise the held-out images in inference mode. The training cell leaves the
# network in train mode, so switch to eval() first. Otherwise dropout layers
# (if any) stay active, and batch-norm layers (if any) would update their running
# statistics from test data. no_grad() avoids keeping an autograd graph for
# the whole test set. The previous mode is restored afterwards so re-running
# the training cell behaves as before.
was_training = mias_net.training
mias_net.eval()
with torch.no_grad():
    test_pred = mias_net(mias_noisy_test)
mias_net.train(was_training)

panel_titles = ("Noisy input", "Denoised (NN)", "Clean ground truth")
for i in range(len(test_pred)):
    test_img = mias_noisy_test[i].squeeze().cpu().numpy()
    test_pred_img = test_pred[i].squeeze().cpu().detach().numpy()
    clean_img = mias_clean_test[i].squeeze().cpu().numpy()

    # figsize is (width, height): three panels side by side need a wide figure.
    fig, axes = plt.subplots(1, 3, figsize=(21, 7))
    for ax, img, title in zip(axes, (test_img, test_pred_img, clean_img), panel_titles):
        ax.imshow(img, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    plt.show()