In [6]:
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
from skimage.util import random_noise
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [7]:
from read_data import get_mias_data, get_dx_data
from dataloaders import mias_dataset, dx_dataset
from model import cnn_autoencoder
from gaussian_noise_transform import GaussianNoise

In [8]:
SEED = 0  # single seed for torch (weight init, DataLoader shuffling) and numpy's global RNG

# Seed before any model or DataLoader is created so init and shuffle order are reproducible.
torch.manual_seed(SEED)
# NOTE(review): seeding numpy's global RNG only makes the noise reproducible if
# GaussianNoise draws from it — confirm in gaussian_noise_transform.
np.random.seed(SEED)

# Prefer the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [9]:
img_size = 128  # target side length: every image is resized to img_size x img_size pixels

# Clean targets: convert the input image to a tensor, then resize it to a square.
# Resize with an (h, w) tuple forces both sides. A bare int would only match the
# shorter edge and keep the aspect ratio, so non-square inputs would not come out
# 128 x 128 and could not be collated into one batch.
clean_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((img_size, img_size)),
])

# Noisy inputs: the same pipeline as clean_transform, plus additive Gaussian noise.
# GaussianNoise(0, 0.005): mean 0. The second argument is presumably the variance
# (as with skimage's random_noise) — TODO confirm in gaussian_noise_transform.
noise_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((img_size, img_size)),
    GaussianNoise(0, 0.005),
])

In [10]:
# MIAS training split: each item is a (clean, noisy) pair produced by the two transforms.
mias_batch_size = 10
mias_data = mias_dataset('mias_train', clean_transform, noise_transform)
mias_data_loader = DataLoader(mias_data, shuffle=True, batch_size=mias_batch_size)

In [11]:
# for _, noisy in mias_data_loader:
#     plt.imshow(noisy.squeeze(), cmap="gray")
#     plt.show()

In [12]:
# DX training split, served one image per batch in shuffled order.
# NOTE(review): the saved run of this cell ended in a KeyboardInterrupt, and
# dx_data_loader is not used by the cells that follow in this notebook.
dx_batch_size = 1
dx_data = dx_dataset('dx_train', clean_transform, noise_transform)
dx_data_loader = DataLoader(dx_data, shuffle=True, batch_size=dx_batch_size)

KeyboardInterrupt: 

In [None]:
# for _, noisy in dx_data_loader:
#     plt.imshow(noisy.squeeze(), cmap="gray")
#     plt.show()

## Define, train, and evaluate the CNN autoencoder on MIAS

In [None]:
# Build the autoencoder and move it to the target device (nn.Module.to returns the module).
mias_net = cnn_autoencoder().to(device)
# Adam over all model parameters with a learning rate of 1e-3.
optimizer = optim.Adam(params=mias_net.parameters(), lr=1e-3)
# Pixel-wise binary cross-entropy between the reconstruction and the clean image.
# NOTE(review): BCELoss needs both output and target in [0, 1] — assumes
# cnn_autoencoder ends in a sigmoid; confirm in model.py.
criterion = nn.BCELoss()

In [None]:
num_epochs = 1000  # full passes over the MIAS training set
log_every = 10     # print the loss once every `log_every` epochs

# Training mode, in case the model has dropout / batch-norm layers. This also keeps
# the cell correct when it is re-run after an inference cell switched to eval().
mias_net.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0
    for clean, noisy in mias_data_loader:
        # Cast both tensors to float32 so input, output and BCE target dtypes agree.
        # The original already cast `noisy`, suggesting the noise pipeline can yield float64.
        clean = clean.to(device).float()
        noisy = noisy.to(device).float()

        optimizer.zero_grad()
        output = mias_net(noisy)
        loss = criterion(output, clean)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    if epoch % log_every == log_every - 1:
        # Report the mean per-batch loss so the value does not scale with the batch count.
        mean_loss = running_loss / max(num_batches, 1)
        print("epoch {:d}: mean batch loss {:f}".format(epoch + 1, mean_loss))

In [None]:
# Draw one shuffled batch from the MIAS loader for visual inspection below.
sample_iter = iter(mias_data_loader)
clean, noisy = next(sample_iter)

In [None]:
# Show the first clean (target) image of the sample batch.
fig, ax = plt.subplots()
ax.imshow(clean[0].squeeze(), cmap='gray')
ax.set_title("Clean image")
ax.axis("off")
plt.show()

In [None]:
# Show the first noisy (model input) image of the sample batch.
fig, ax = plt.subplots()
ax.imshow(noisy[0].squeeze(), cmap='gray')
ax.set_title("Noisy input")
ax.axis("off")
plt.show()

In [None]:
# Denoise the sample batch and show the model's reconstruction of its first image.
# eval() + no_grad(): inference without training-mode layer behaviour and without
# building an autograd graph. The previous train/eval mode is restored afterwards.
was_training = mias_net.training
mias_net.eval()
with torch.no_grad():
    prediction_batch = mias_net(noisy.float().to(device))
mias_net.train(was_training)

prediction = prediction_batch[0].cpu().numpy().squeeze()

fig, ax = plt.subplots()
ax.imshow(prediction, cmap='gray')
ax.set_title("Denoised prediction")
ax.axis("off")
plt.show()