In [1]:
import os
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import math
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

In [2]:
class DnCNN(nn.Module):
    """DnCNN-style image-to-image network.

    Layout: Conv+ReLU, then ``num_of_layers - 2`` x [Conv(no bias)+BatchNorm+ReLU],
    then a final Conv(no bias). Every conv is 3x3 with padding 1, so the spatial
    size of the input is preserved. The output is the network's direct prediction
    (no residual subtraction is done here); the training cell in this notebook
    regresses the clean image from the noisy one.

    Args:
        channels: number of input and output image channels (3 = RGB, the original value).
        num_of_layers: total number of conv layers (default 17 = the original 1 + 15 + 1).
        features: number of hidden feature maps (default 64, the original value).

    Raises:
        ValueError: if ``num_of_layers`` < 2 (the input and output convs are mandatory).
    """

    def __init__(self, channels=3, num_of_layers=17, features=64):
        super(DnCNN, self).__init__()
        if num_of_layers < 2:
            raise ValueError("num_of_layers must be at least 2 (input and output conv)")
        kernel_size = 3
        padding = 1  # 3x3 kernel, stride 1, padding 1 -> H and W unchanged
        # First layer keeps its bias (no BatchNorm follows it).
        layers = [
            nn.Conv2d(channels, features, kernel_size=kernel_size, padding=padding),
            nn.ReLU(inplace=True),
        ]
        # Hidden layers: bias is redundant because BatchNorm re-centres the output.
        for _ in range(num_of_layers - 2):
            layers.append(nn.Conv2d(in_channels=features, out_channels=features,
                                    kernel_size=kernel_size, padding=padding, bias=False))
            layers.append(nn.BatchNorm2d(features))
            layers.append(nn.ReLU(inplace=True))
        # Project back to image channels.
        layers.append(nn.Conv2d(in_channels=features, out_channels=channels,
                                kernel_size=kernel_size, padding=padding, bias=False))
        # Attribute name and layer order are unchanged, so saved state_dicts stay loadable.
        self.dncnn = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the conv stack; for input (N, channels, H, W) the output has the same shape."""
        return self.dncnn(x)

In [3]:
def weights_init_kaiming(m):
    """Initialise one module in place; meant to be used as ``net.apply(weights_init_kaiming)``.

    Dispatch is on the class name, as before:
      * names containing 'Conv' or 'Linear': Kaiming-normal weights (a=0, mode='fan_in');
        biases are left untouched.
      * names containing 'BatchNorm': weights ~ N(0, sqrt(2/9/64)) clamped to
        [-0.025, 0.025], biases set to 0.
        NOTE(review): a zero-mean BatchNorm scale is unusual (1.0 is the common
        choice); kept as in the original.

    Modules whose ``weight``/``bias`` is missing or ``None`` are skipped instead of
    raising, e.g. ``BatchNorm2d(affine=False)`` or a container named ``ConvBlock``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        if weight is not None:
            nn.init.kaiming_normal_(weight, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        # Mutate the Parameters directly under no_grad instead of via the legacy `.data`.
        with torch.no_grad():
            if weight is not None:
                weight.normal_(mean=0, std=math.sqrt(2. / 9. / 64.)).clamp_(-0.025, 0.025)
            if bias is not None:
                bias.zero_()

In [4]:
# Seed the RNG so the random weight initialisation below is reproducible across runs.
SEED = 0
torch.manual_seed(SEED)
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
net = DnCNN()
net.apply(weights_init_kaiming)
model = net.to(device)
print(model)

DnCNN(
  (dncnn): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ReLU(inplace=True)
    (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU(inplace=True)
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (12): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): Conv2d(64, 64, kernel_size=(3, 3), str

In [5]:
# Base learning rate; the training cell derives its per-epoch schedule from `lr`.
lr = 1e-3
# Adam over every trainable parameter of the denoiser.
optimizer = optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# Training configuration. `lr`, `model`, `optimizer` and `device` come from earlier cells.
epochs = 50
milestone = 30  # from this epoch on, the learning rate is the base `lr` halved (once)
image_size = (240, 240)
# Divisor applied to the summed squared error.
# NOTE(review): images stay in [0, 255], so this only rescales the loss; the value is
# kept unchanged from the original run so losses remain comparable.
loss_scale = 255 * 240 * 2

# Portable paths (the original hard-coded Windows-only "\\" separators).
base_dir = os.path.abspath("")
good_images_dir = os.path.join(base_dir, "data", "train")
noisy_images_dir = os.path.join(base_dir, "data", "train", "noisy_train")
# The original saved to `opt.outf`, but no `opt` exists in this notebook (NameError at
# the end of the first epoch); checkpoints are written next to the notebook instead.
checkpoint_dir = base_dir
os.makedirs(checkpoint_dir, exist_ok=True)

# Summed squared error; `reduction="sum"` replaces the deprecated `size_average=False`.
criteria = nn.MSELoss(reduction="sum")


def load_image_tensor(path, size=image_size):
    """Load an image file as a float32 tensor of shape (1, 3, H, W), values in [0, 255]."""
    # `with` closes the file handle; convert("RGB") guards against grayscale/RGBA
    # files, which would otherwise break the HWC -> CHW transpose.
    with Image.open(path) as img:
        pixels = np.array(img.convert("RGB").resize(size)).astype('float32')
    return torch.from_numpy(np.transpose(pixels, (2, 0, 1))).unsqueeze(0)


for epoch in range(epochs):
    # Derived from the untouched base `lr` each epoch; the original overwrote `lr`, so the
    # rate halved again on every epoch after the milestone (lr / 2**20 by the last epoch).
    current_lr = lr if epoch < milestone else lr / 2
    for param_group in optimizer.param_groups:
        param_group["lr"] = current_lr
    print(f'learning rate {current_lr}')

    model.train()
    step = 0
    total_loss = 0.0
    # Sorted so the visiting order does not depend on the filesystem.
    for i in sorted(os.listdir(noisy_images_dir)):
        step = step + 1
        optimizer.zero_grad()
        # Noisy files are named "<clean stem>_<suffix>"; the target is "<clean stem>.jpg".
        goodi = i.split('_')
        ntensor = load_image_tensor(os.path.join(noisy_images_dir, i)).to(device)
        gtensor = load_image_tensor(os.path.join(good_images_dir, goodi[0] + '.jpg')).to(device)

        out_train = model(ntensor)
        loss = criteria(out_train, gtensor) / loss_scale
        # .item() detaches; summing the tensor kept every step's autograd graph in memory.
        total_loss += loss.item()
        if step % 100 == 0:
            print(f'Step: {step} in loss value {loss.item()}')
        loss.backward()
        optimizer.step()
    model.eval()
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'net.pth'))
    avg_loss = total_loss / step if step else float('nan')
    print(f'checkpoint epoch: {epoch}, average_loss: {avg_loss}')

learning rate 0.001




Step: 100 in loss value 21939.775390625
Step: 200 in loss value 33999.30859375
Step: 300 in loss value 3231.089599609375
Step: 400 in loss value 2391.1884765625
Step: 500 in loss value 2966.862060546875
Step: 600 in loss value 2082.3583984375
