In [37]:
import glob
import os
import time

import numpy as np

import cv2
import torch
import torch.cuda
import torch.nn as nn
import torch.nn.init as init
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import Dataset, DataLoader

In [38]:
# --- Experiment configuration ------------------------------------------------
model = "DnCNN"            # architecture tag; used below to name the output directory
batch_size = 128           # patches per optimisation step
# NOTE(review): both paths are empty and are not read by the visible cells,
# which pass "train_set" to load_train() directly — confirm intended usage.
train_path = ""
test_path = ""
noise_level = 15           # Gaussian noise sigma on the 0-255 intensity scale
n_epoch = 180
learning_rate = 1e-3

# Patch extraction: side length of the square patch and sliding-window step.
patch_size = 40
stride = 10

cuda = torch.cuda.is_available()

save_path = os.path.join('models', f"{model}_sigma{noise_level}")

# makedirs (rather than mkdir) also creates the parent 'models' directory on a
# fresh checkout, and exist_ok=True makes the cell safe to re-run.
os.makedirs(save_path, exist_ok=True)

In [39]:
def load_train(train_path):
    """Load every ``*.jpg`` directly inside ``train_path`` as a grayscale image.

    Args:
        train_path: Directory to search (not recursive).

    Returns:
        list: The decoded images in sorted file-name order. Files that OpenCV
        could not decode are reported and left out.
    """
    # os.path.join keeps an empty train_path relative to the working directory
    # (plain concatenation with '/*.jpg' would search the filesystem root).
    # Sorting makes the order independent of the directory listing order.
    train_set = sorted(glob.glob(os.path.join(train_path, '*.jpg')))
    print(f"Train Set: {len(train_set)} images")

    images = []
    for file_name in train_set:
        image = cv2.imread(file_name, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread returns None instead of raising when a file cannot be
            # read; skip it here so later code never receives a None image.
            print(f"Warning: could not read {file_name}, skipping")
            continue
        images.append(image)

    print(f"{len(images)} images loaded")

    return images

In [40]:
def load_test(test_path):
    """Load every ``*.jpg`` directly inside ``test_path`` as a grayscale image.

    Args:
        test_path: Directory to search (not recursive).

    Returns:
        list: The decoded images in sorted file-name order. Files that OpenCV
        could not decode are reported and left out.
    """
    # os.path.join keeps an empty test_path relative to the working directory
    # (plain concatenation with '/*.jpg' would search the filesystem root).
    # Sorting makes the order independent of the directory listing order.
    test_set = sorted(glob.glob(os.path.join(test_path, '*.jpg')))
    print(f"Test Set: {len(test_set)} images")

    images = []
    for file_name in test_set:
        image = cv2.imread(file_name, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread returns None instead of raising when a file cannot be
            # read; skip it here so later code never receives a None image.
            print(f"Warning: could not read {file_name}, skipping")
            continue
        images.append(image)

    print(f"{len(images)} images loaded")

    return images

In [41]:
def gen_patches(img):
    """Cut a 2-D image into overlapping square patches.

    A window of side ``patch_size`` slides over ``img`` in steps of ``stride``
    (both module-level settings). Windows that would extend past the bottom or
    right edge are not produced.

    Args:
        img: Two-dimensional array-like exposing ``.shape`` and 2-D slicing.

    Returns:
        list: Slices of ``img`` in row-major window order; empty when the image
        is smaller than one patch in either dimension.
    """
    height, width = img.shape

    # Top-left corners of every window that fits completely inside the image.
    row_starts = range(0, height - patch_size + 1, stride)
    col_starts = range(0, width - patch_size + 1, stride)

    return [
        img[top:top + patch_size, left:left + patch_size]
        for top in row_starts
        for left in col_starts
    ]

In [42]:
def gen_data(images):
    """Turn a list of images into one uint8 patch array sized to whole batches.

    Every image is cut with ``gen_patches``; the patches are stacked, given a
    trailing channel axis, and the leading ``len % batch_size`` patches are
    dropped so that the result holds a whole number of batches.

    Args:
        images: Iterable of 2-D images accepted by ``gen_patches``.

    Returns:
        numpy.ndarray: uint8 array of shape ``(N, patch_h, patch_w, 1)`` with
        ``N`` a multiple of ``batch_size``. When no patch could be extracted
        the shape is ``(0, patch_size, patch_size, 1)``.
    """
    patches = [patch for img in images for patch in gen_patches(img)]

    if not patches:
        # np.array([]) is 1-D, so adding axis 3 below would raise; hand back a
        # correctly shaped empty array instead.
        return np.empty((0, patch_size, patch_size, 1), dtype='uint8')

    data = np.array(patches, dtype='uint8')
    data = np.expand_dims(data, axis=3)

    # Number of leading patches that do not fit into a whole batch.
    discard = len(data) % batch_size

    # Slicing drops the same leading rows np.delete(range(discard)) would,
    # without copying the whole array a second time.
    return data[discard:]

In [43]:
# Read the grayscale training images from the "train_set" directory and cut
# them into the patch array used for training.
images = load_train("train_set")
data = gen_data(images)
# Debug aid: uncomment to inspect the raw patch array.
#print(data)

Train Set: 200 images
200 images loaded


In [None]:
class DnDataset(Dataset):
    """Dataset that pairs each clean patch with a freshly noised copy.

    Noise is drawn anew on every ``__getitem__`` call, so the same index yields
    a different noisy sample each time it is requested.

    Args:
        patches: Tensor of clean patches indexed along dimension 0.
        sigma: Noise standard deviation on the 0-255 scale; it is divided by
            255 before being applied to the patch values.
    """

    def __init__(self, patches, sigma):
        # Zero-argument super() does not look the class up by its global name,
        # so construction keeps working even if ``DnDataset`` is later rebound
        # (easy to do by accident in a notebook).
        super().__init__()
        self.patches = patches
        self.sigma = sigma

    def __getitem__(self, idx):
        """Return ``(noisy, clean)`` for the patch at ``idx``."""
        batch_x = self.patches[idx]
        noise = torch.randn(batch_x.size()).mul_(self.sigma / 255.0)
        batch_y = batch_x + noise
        return batch_y, batch_x

    def __len__(self):
        """Number of patches along dimension 0."""
        return len(self.patches)

In [44]:
class DnCNN(nn.Module):
    """Residual denoising CNN: the stack's output is subtracted from the input.

    Layer layout:
        Conv + ReLU,
        ``depth - 2`` repetitions of Conv [+ BatchNorm] + ReLU,
        Conv.

    Args:
        depth: Total number of convolution layers.
        n_channels: Feature maps in the hidden layers.
        images_channels: Channels of the input and output image.
        bnorm: Insert BatchNorm after each hidden convolution. When disabled
            the hidden convolutions get a bias term instead.
        kernel_size: Side of the square convolution kernel. Use an odd value so
            the padding keeps the spatial size unchanged.
    """

    def __init__(self, depth=17, n_channels=64, images_channels=1, bnorm=True, kernel_size=3):
        super().__init__()

        # "Same" padding for odd kernels: the stack's output must match the
        # input's spatial size for the subtraction in forward() to be valid.
        padding = kernel_size // 2

        layers = [
            nn.Conv2d(in_channels=images_channels, out_channels=n_channels, kernel_size=kernel_size,
                      padding=padding, bias=True),
            nn.ReLU(inplace=True),
        ]

        for _ in range(depth - 2):
            # A bias is redundant directly before BatchNorm, so it is only
            # enabled when normalisation is switched off.
            layers.append(
                nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size,
                          padding=padding, bias=not bnorm))
            if bnorm:
                # NOTE(review): PyTorch's momentum is the weight given to the
                # new batch statistic, so 0.95 tracks recent batches closely —
                # confirm this value is intended.
                layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum=0.95))
            layers.append(nn.ReLU(inplace=True))

        layers.append(
            nn.Conv2d(in_channels=n_channels, out_channels=images_channels, kernel_size=kernel_size,
                      padding=padding, bias=False))

        self.dncnn = nn.Sequential(*layers)
        self.initialize_weights()

    def forward(self, x):
        """Return ``x`` minus the output of the convolutional stack."""
        y = x
        out = self.dncnn(x)
        return y - out

    def initialize_weights(self):
        """Orthogonal conv weights, zero biases, identity BatchNorm affine."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                init.orthogonal_(module.weight)

                if module.bias is not None:
                    # Initialise the bias tensor, not the module object.
                    init.constant_(module.bias, 0)

            elif isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)

In [None]:
# Build the denoiser and switch it to training mode.
model = DnCNN()
model.train()

# Move the network to the GPU when one was detected.
model = model.cuda() if cuda else model

# Adam over all network parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
# Squared error summed (not averaged) over every element of the batch.
criterion = nn.MSELoss(reduction='sum')
# Multiply the learning rate by 0.2 at epochs 30, 60 and 90.
scheduler = MultiStepLR(optimizer=optimizer, milestones=[30, 60, 90], gamma=0.2)

In [None]:
# Patch extraction is deterministic, so build the training tensor once instead
# of re-reading and re-cutting every image at the start of each epoch.
images = load_train("train_set")
xs = gen_data(images)
xs = xs.astype('float32') / 255.0
# (N, H, W, C) -> (N, C, H, W), the layout the convolution layers expect.
xs = torch.from_numpy(xs.transpose((0, 3, 1, 2)))

# Use a distinct variable name: binding the instance to ``DnDataset`` would
# shadow the class and make a second construction fail.
train_dataset = DnDataset(xs, sigma=noise_level)
DLoader = DataLoader(dataset=train_dataset, num_workers=4, drop_last=True, batch_size=batch_size, shuffle=True)

for epoch in range(0, n_epoch):
    epoch_loss = 0
    start_time = time.time()

    for n_count, batch_yx in enumerate(DLoader):
        optimizer.zero_grad()

        # The dataset yields (noisy, clean). Unpack unconditionally so the
        # loop also runs on CPU-only machines.
        batch_y, batch_x = batch_yx
        if cuda:
            batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

        loss = criterion(model(batch_y), batch_x)

        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()

        if n_count % 10 == 0:
            print('%4d %4d / %4d loss = %2.4f' % (epoch+1, n_count, xs.size(0)//batch_size, loss.item()/batch_size))

    # Advance the LR schedule once per epoch, after the optimizer updates;
    # passing the epoch index to step() is deprecated.
    scheduler.step()

    # Per-sample average of the summed loss, plus wall-clock time for the epoch.
    n_samples = max(len(DLoader) * batch_size, 1)
    elapsed_time = time.time() - start_time
    print('epoch = %4d , loss = %4.4f , time = %4.2f s' % (epoch+1, epoch_loss/n_samples, elapsed_time))