In [1]:
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
from skimage.util import random_noise
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
from dataloaders import mias_dataset
from model import cnn_autoencoder
from gaussian_noise_transform import GaussianNoise

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [4]:
clean_transform = transforms.Compose([ 
                                     ])

noise_transform = transforms.Compose([GaussianNoise(0,0.005), ## add gaussian noise of mean 0 var 0.01
                                     ])

In [5]:
mias_train_data = torch.load('mias_train.pt')
mias_dataset_ = mias_dataset(mias_train_data, clean_transform, noise_transform)
mias_data_loader = DataLoader(dataset=mias_dataset_, batch_size=10, shuffle=True) 

In [6]:
mias_clean_test = torch.load('mias_clean_test.pt')
mias_noisy_test = torch.load('mias_noisy_test.pt')
mias_clean_test = mias_clean_test.float()
mias_noisy_test = mias_noisy_test.float()
mias_clean_test = mias_clean_test.to(device)
mias_noisy_test = mias_noisy_test.to(device);

In [9]:
mias_net = cnn_autoencoder()
mias_net.to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(mias_net.parameters(), lr=0.001)
mias_net.train()

cnn_autoencoder(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (upsamp4): Upsample(scale_factor=2.0, mode=nearest)
  (conv4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (upsamp5): Upsample(scale_factor=2.0, mode=nearest)
  (conv5): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

In [10]:
for epoch in range(1000):
    
    running_train_loss = 0
    counter = 0
    for clean, noisy in mias_data_loader:
        
        clean = clean.to(device)
        noisy = noisy.to(device)
        
        optimizer.zero_grad()
        output = mias_net(noisy.float())
        train_loss = criterion(output, clean.float())
        train_loss.backward()
        optimizer.step()
        running_train_loss += train_loss.item()
        counter += 1
    
    if (epoch+10) % 1 == 0:
        mias_net.eval()
        test_output = mias_net(mias_noisy_test)
        test_loss = criterion(test_output, mias_clean_test)
        test_loss = test_loss.item()
        mias_net.train()
        print("Epoch: {:d}, train loss: {:f}, test loss {:f}".format(epoch, running_train_loss/counter, test_loss))

Epoch: 0, train loss: 0.333238, test loss 0.270467
Epoch: 1, train loss: 0.319087, test loss 0.269362
Epoch: 2, train loss: 0.318123, test loss 0.271115
Epoch: 3, train loss: 0.317486, test loss 0.269075
Epoch: 4, train loss: 0.317222, test loss 0.270509
Epoch: 5, train loss: 0.317085, test loss 0.268362
Epoch: 6, train loss: 0.316941, test loss 0.268700
Epoch: 7, train loss: 0.316749, test loss 0.268159
Epoch: 8, train loss: 0.316863, test loss 0.271076
Epoch: 9, train loss: 0.316506, test loss 0.269085
Epoch: 10, train loss: 0.316277, test loss 0.269293
Epoch: 11, train loss: 0.316074, test loss 0.268971
Epoch: 12, train loss: 0.316325, test loss 0.268171


KeyboardInterrupt: 