In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image

In [None]:
path = './data'
# Every data directory hangs off the single root above.
trainInPath = f'{path}/train'            # noisy training images
trainOutPath = f'{path}/train_cleaned'   # clean training targets
testPath = f'{path}/test_dirty'
testOutPath = f'{path}/test_clean'
savePath = f'{path}/AEsave'              # reconstructions written during training

In [None]:
def loadImg(path):
    """Load every image in ``path`` (sorted by filename) as a float32 RGB array in [0, 1].

    Files are visited in sorted order so that a noisy directory and its clean
    counterpart yield images in the same (filename-paired) order. Each image is
    resized with ``cv2.resize(..., (540, 420))``; OpenCV's ``dsize`` is
    (width, height), so arrays come back as 420 rows x 540 columns x 3 channels.

    Args:
        path: directory containing the image files.

    Returns:
        list of float32 np.ndarray, one per image, values scaled by 1/255.

    Raises:
        ValueError: if a non-hidden directory entry cannot be decoded as an image.
    """
    data = []
    for title in sorted(os.listdir(path)):
        if title.startswith('.'):
            # hidden entries (.DS_Store, .ipynb_checkpoints, ...) are not images
            continue
        file_path = os.path.join(path, title)
        img = cv2.imread(file_path)
        if img is None:
            # cv2.imread reports unreadable / non-image files by returning None
            raise ValueError("could not read image: " + file_path)
        # OpenCV decodes to BGR; convert so the channel order matches PIL (RGB),
        # which saveImg uses to write images back out
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img.astype(np.float32), (540, 420))  # dsize is (width, height)
        data.append(img / 255.0)  # scale to [0, 1]; no mean subtraction is done
    return data

def saveImg(path, data):
    """Write each image in ``data`` to ``path`` as ``<index>.png``.

    Args:
        path: output directory; created if it does not exist.
        data: iterable of HxWx3 (RGB) float arrays whose intended range is [0, 1].
    """
    os.makedirs(path, exist_ok=True)
    for i, img in enumerate(data):
        # Clip before casting: out-of-range values (e.g. negative network outputs)
        # would otherwise wrap around in uint8 instead of saturating to black/white.
        img = np.clip(np.asarray(img, dtype=np.float32), 0.0, 1.0)
        # Round rather than truncate so a pixel/255 value maps back to its original byte.
        img = np.rint(img * 255.0).astype(np.uint8)
        Image.fromarray(img).save(os.path.join(path, str(i) + '.png'))

In [None]:
# Stack the per-image (H, W, C) arrays and reorder axes to PyTorch's (N, C, H, W) layout.
train = np.asarray(loadImg(trainInPath)).transpose(0, 3, 1, 2)
train_cleaned = np.asarray(loadImg(trainOutPath)).transpose(0, 3, 1, 2)
# Test images are not loaded yet:
#test = loadImg(testPath)
#test_cleaned = loadImg(testOutPath)

In [None]:
# Noisy inputs and clean targets are paired purely by sorted filename, so both stacks
# must agree exactly; a mismatch would silently misalign (noisy, clean) pairs later.
assert train.ndim == 4 and train.shape[1] == 3, train.shape
assert train.shape == train_cleaned.shape, (train.shape, train_cleaned.shape)
train.shape

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
class autoencoder(nn.Module):
    """Convolutional denoising autoencoder for 3-channel images.

    Encoder: two [conv -> BatchNorm -> ReLU -> max-pool] stages.
    Decoder: mirrors them with max-unpool (reusing the encoder's pooling indices)
    followed by transposed convolutions.

    With the kernel/stride choices below, inputs whose height and width are
    multiples of 12 (at least 24), e.g. (N, 3, 420, 540), are reconstructed at
    exactly the input size.

    The final activation is a sigmoid so predictions lie in [0, 1], the same range
    as pixel/255 targets (a tanh output spans [-1, 1]).
    """

    def __init__(self):
        super().__init__()
        # ---- encoder (spatial sizes shown for a 420 x 540 H x W input) ----
        # 420x540 -> 140x180
        self.conv1 = nn.Conv2d(
                in_channels = 3,
                out_channels = 64,
                kernel_size = 3,
                stride = 3,
                padding = 1
        )
        self.batch1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(True)
        # 140x180 -> 70x90; indices are reused by unpool2
        self.maxpool1 = nn.MaxPool2d(
                kernel_size = 2,
                stride = 2,
                return_indices=True
        )
        # 70x90 -> 35x45
        self.conv2 = nn.Conv2d(
                in_channels = 64,
                out_channels = 128,
                kernel_size = 3,
                stride = 2,
                padding = 1
        )
        self.batch2 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU(True)
        # 35x45 -> 34x44 (stride 1); indices are reused by unpool1
        self.maxpool2 = nn.MaxPool2d(
                kernel_size = 2,
                stride = 1,
                return_indices=True
        )

        # ---- decoder ----
        # 34x44 -> 35x45
        self.unpool1 = nn.MaxUnpool2d(
                kernel_size = 2,
                stride = 1,
        )
        # 35x45 -> 70x90
        self.convt1 = nn.ConvTranspose2d(
                in_channels = 128,
                out_channels = 64,
                kernel_size = 4,
                stride = 2,
                padding = 1
        )
        self.batch3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU(True)
        # 70x90 -> 140x180
        self.unpool2 = nn.MaxUnpool2d(
                kernel_size = 2,
                stride = 2,
        )
        # 140x180 -> 420x540
        self.convt2 = nn.ConvTranspose2d(
                in_channels = 64,
                out_channels = 3,
                kernel_size = 5,
                stride = 3,
                padding = 1
        )
        self.batch4 = nn.BatchNorm2d(3)
        # Output range must match the [0, 1] targets. This module has no parameters,
        # so state_dict keys are the same as with the previous tanh head.
        self.out_act = nn.Sigmoid()

    def forward(self, x):
        """Reconstruct ``x``; output values lie in [0, 1].

        The output has the same shape as ``x`` when H and W are multiples of 12.
        """
        # encoder
        x, indices1 = self.maxpool1(self.relu1(self.batch1(self.conv1(x))))
        x, indices2 = self.maxpool2(self.relu2(self.batch2(self.conv2(x))))
        # decoder: unpool with the matching encoder indices, innermost first
        x = self.unpool1(x, indices2)
        x = self.relu3(self.batch3(self.convt1(x)))
        x = self.unpool2(x, indices1)
        x = self.batch4(self.convt2(x))
        return self.out_act(x)
        
# Initialize the network. Seeding torch's global RNG first makes the random weight
# initialization repeatable across Restart & Run All (GPU kernels may still add
# some run-to-run nondeterminism).
SEED = 0
torch.manual_seed(SEED)
model = autoencoder().to(device)

# hyperparameters
learning_rate = 1e-2
num_epochs = 300
batch_size = 144  # samples per optimizer step

# loss and optimizer: pixel-wise MSE between reconstruction and clean target,
# Adam with a small L2 penalty (weight_decay)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate, weight_decay = 1e-5)
total_loss = 0  # loss accumulator initialised here for the training cell



In [None]:
# Paired (noisy, clean) dataset over the in-memory training arrays
class denosiDataset(Dataset):
    """Paired (noisy, clean) image dataset backed by in-memory NumPy arrays.

    Args:
        inputs: noisy images, indexed along axis 0; defaults to the notebook's
            global ``train`` array.
        targets: clean images aligned one-to-one with ``inputs``; defaults to the
            global ``train_cleaned`` array.

    Raises:
        ValueError: if ``inputs`` and ``targets`` differ in length, which would
            otherwise silently drop or misalign pairs.
    """

    def __init__(self, inputs=None, targets=None):
        if inputs is None:
            inputs = train
        if targets is None:
            targets = train_cleaned
        if len(inputs) != len(targets):
            raise ValueError(
                "inputs and targets must have the same length, got {} and {}".format(
                    len(inputs), len(targets)))
        # from_numpy shares memory with the arrays (no copy)
        self.x_data = torch.from_numpy(inputs)
        self.y_data = torch.from_numpy(targets)
        self.n_samples = len(inputs)

    def __getitem__(self, index):
        """Return the (noisy, clean) tensor pair at ``index``."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.n_samples

# Build the training dataset; called with no arguments, it wraps the global
# `train` (noisy) and `train_cleaned` (clean) arrays.
dataset = denosiDataset()


In [None]:
# Mini-batch loader over the paired dataset, reshuffled at the start of every epoch.
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [None]:
# Train the autoencoder.
# Each epoch's loss is the sample-weighted mean MSE over the whole training set, so
# `lossList` holds one plain float per epoch (a real learning curve).
os.makedirs(savePath, exist_ok=True)  # saveImg writes snapshots into this directory
lossList = []
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    n_seen = 0
    for data, target in train_loader:
        data = data.to(device = device)
        targets = target.to(device = device)

        output = model(data)
        loss = criterion(output, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # MSELoss averages within the batch; re-weight by batch size so a short
        # final batch contributes proportionally to the epoch mean
        running_loss += loss.item() * data.size(0)
        n_seen += data.size(0)

    epoch_loss = running_loss / max(n_seen, 1)
    lossList.append(epoch_loss)
    if epoch % 10 == 0 and n_seen:
        print("epoch [{}/{}], loss:{:.4f}".format(epoch + 1, num_epochs, epoch_loss))
        # snapshot the last batch's reconstructions: NCHW tensor -> NHWC array in [0, 1]
        snapshot = output.detach().cpu().numpy().transpose(0, 2, 3, 1)
        saveImg(savePath, np.clip(snapshot, 0.0, 1.0))

# state_dict() must be *called*: passing the bound method would save no weights
torch.save(model.state_dict(), './autoencoder.pth')

In [None]:
print(device)
# Learning curve, one point per epoch. float() turns 0-d tensors (possibly on the GPU)
# into Python numbers; matplotlib goes through NumPy, which rejects CUDA tensors.
epoch_losses = [float(v) for v in lossList]
fig, ax = plt.subplots()
ax.plot(range(1, len(epoch_losses) + 1), epoch_losses)
ax.set(title="Training loss per epoch", xlabel="epoch", ylabel="loss")
plt.show()