In [2]:
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from tensorboardX import SummaryWriter

writer = SummaryWriter("logs")

class DnCNN(nn.Module):
    def __init__(self, channels, num_of_layers=17):
        super(DnCNN, self).__init__()
        kernel_size = 7
        padding = 3
        features = 32
        layers = []
        layers.append(nn.Conv1d(in_channels=channels, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(num_of_layers-2):
            layers.append(nn.Conv1d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
            layers.append(nn.BatchNorm1d(features))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv1d(in_channels=features, out_channels=channels, kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
    def forward(self, x):
        out = self.dncnn(x.to(torch.float32))
        return out


class MyDataset(Dataset):
    def __init__(self, data_dir, noisy_dir, is_train):
        self.x = None
        self.y = None
        f_list = os.listdir(data_dir)
        if is_train: 
            for idx in range(min(sample_num, len(f_list))):
                if idx in [3,6,9]:
                    continue
                if self.x is None:
                    self.x = np.load(data_dir + str(idx+1) + '.npy')[:180, :]
                    self.y = np.load(noisy_dir + str(idx+1) + '.npy')[:180, :]
                else:
                    self.x = np.concatenate((self.x, np.load(data_dir + str(idx+1) + '.npy')[:180, :]))
                    self.y = np.concatenate((self.y, np.load(noisy_dir + str(idx+1) + '.npy')[:180, :]))
        else:
            for idx in range(min(sample_num, len(f_list))):
                if self.x is None:
                    self.x = np.load(data_dir + str(idx+1) + '.npy')[180:200, :]
                    self.y = np.load(noisy_dir + str(idx+1) + '.npy')[180:200, :]
                else:
                    self.x = np.concatenate((self.x, np.load(data_dir + str(idx+1) + '.npy')[180:200, :]))
                    self.y = np.concatenate((self.y, np.load(noisy_dir + str(idx+1) + '.npy')[180:200, :]))
            # print(len(self.x))
        
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        x = self.x[idx]
        y = self.y[idx]
        return x, y

def batch_PSNR(img, imclean):
    print(img.shape)
    print(imclean.shape)
    mse = np.mean((imclean - img) ** 2)
    if mse == 0:
        return 100
    max_pixel = 1.0
    psnr = 20 * np.log10(max_pixel / np.sqrt(mse))
    return psnr

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
device = torch.device("cuda")

parser = argparse.ArgumentParser(description="DnCNN")
parser.add_argument("--preprocess", type=bool, default=False, help='run prepare_data or not')
parser.add_argument("--batchSize", type=int, default=64, help="Training batch size")
parser.add_argument("--num_of_layers", type=int, default=17, help="Number of total layers")
parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
parser.add_argument("--milestone", type=int, default=30, help="When to decay learning rate; should be less than epochs")
parser.add_argument("--lr", type=float, default=1e-3, help="Initial learning rate")
parser.add_argument("--outf", type=str, default=".", help='path of log files')
parser.add_argument("--mode", type=str, default="B", help='with known noise level (S) or blind training (B)')
parser.add_argument("--noiseL", type=float, default=25, help='noise level; ignored when mode=B')
parser.add_argument("--val_noiseL", type=float, default=25, help='noise level used on validation set')

opt = parser.parse_args(args=[])  

sample_num = 10

model = DnCNN(1).cuda()
# print(model)

optimizer = optim.Adam(model.parameters(), lr=0.00002)
criterion = nn.MSELoss(size_average=False)


for clean_dataset,noisy_dataset in zip([
                                        # 'data/noisy_npy/','data/noisy_npy/','data/noisy_npy/','data/noisy_npy/',
                                        # 'data/bg_npy/','data/bg_npy/','data/bg_npy/',
                                        'data/bg_npy/',
                                        # 'data/clean_npy/','data/clean_npy/','data/clean_npy/','data/clean_npy/'
                                        ]
                                       ,[
                                         # 'data/bg_npy/','data/clean_npy/','data/new-n-to-c/','data/new-n-to-bg/',
                                         # 'data/noisy_npy/','data/clean_npy/','data/new-bg-to-c/',
                                           'data/new-bg-to-n/',
                                         # 'data/noisy_npy/','data/bg_npy/','data/new-c-to-n/','data/new-c-to-bg/'
                                       ]):

# for clean_dataset,noisy_dataset in zip(['data/noisy_npy/','data/noisy_npy/','data/bg_npy/','data/bg_npy/','data/clean_npy/','data/clean_npy/']
#                                        ,['data/bg_npy/','data/clean_npy/','data/noisy_npy/','data/clean_npy/','data/noisy_npy/','data/bg_npy/']):
# for clean_dataset,noisy_dataset in zip([
# #                                         # 'data/noisy_wv/','data/noisy_wv/',
#                                         'data/bg_wv/','data/bg_wv/',
# #                                         # 'data/clean_wv/','data/clean_wv/'
#                                        ]
#                                        ,[
# #                                          # 'data/bg_wv/','data/clean_wv/',
#                                          'data/noisy_wv/','data/clean_wv/',
# #                                          # 'data/noisy_wv/','data/bg_wv/'
#                                         ]):


    trian_data_dir = []
    train_dataset = MyDataset(clean_dataset, noisy_dataset, True)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    step = 0
    for epoch in range(opt.epochs):
        if epoch < opt.milestone:
            current_lr = opt.lr
        else:
            current_lr = opt.lr / 10.
        # set learning rate
        for param_group in optimizer.param_groups:
            param_group["lr"] = current_lr
        # print('learning rate %f' % current_lr)
        # train
        for i, data in enumerate(train_loader):
            # training step
            model.train()
            model.zero_grad()
            optimizer.zero_grad()
            img_train = Variable(data[0].cuda()).unsqueeze(1).to(torch.float32)
            imgn_train = Variable(data[1].cuda()).unsqueeze(1).to(torch.float32)
            noise = Variable((data[1]-data[0]).cuda()).unsqueeze(1).to(torch.float32)
            out_train = model(imgn_train)
            loss = criterion(out_train, noise) / (imgn_train.size()[0]*2)
            loss.backward()
            optimizer.step()
            # results
            model.eval()
            out_train = torch.clamp(imgn_train-model(imgn_train), 0., 1.)
            # if you are using older version of PyTorch, you may need to change loss.item() to loss.data[0]
            if step % 100 == 0:
                # Log the scalar values
                print(loss.item())
            step += 1

        if epoch  % 10 == 0:
            print(epoch)
        ## the end of each epoch
        model.eval()
        # validate
        psnr_val = 0

        val_dataset = MyDataset(clean_dataset, noisy_dataset, False)
        val_loader = DataLoader(val_dataset, batch_size=128, shuffle=True)
        # print(len(val_loader))
        for i, data in enumerate(val_loader, 0):
            img_val = torch.unsqueeze(data[0], 1)
            noise = data[1] - data[0]
            imgn_val = torch.unsqueeze(data[1], 1)
            img_val, imgn_val = Variable(img_val.cuda(), volatile=True), Variable(imgn_val.cuda(), volatile=True)
            out_val = torch.clamp(imgn_val-model(imgn_val), 0., 1.)
    torch.save(model.state_dict(), os.path.join(opt.outf, "net-{}-{}-part-30.pth".format(noisy_dataset.split('/')[1], clean_dataset.split('/')[1])))

    sample_num = 10

    model_dncnn = DnCNN(1)
    # model_dncnn.load_state_dict(torch.load("R-noisy-clean.pth"))
    model_dncnn.load_state_dict(torch.load("net-{}-{}-part-30.pth".format(noisy_dataset.split('/')[1], clean_dataset.split('/')[1])))
    model_dncnn.eval()
    Tensor = torch.cuda.FloatTensor
    input_A = Tensor(1, 1, 1000)

    val_dataset = MyDataset(clean_dataset, noisy_dataset, False)
    val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)

    from pathlib import Path
    path = Path("data/DnCNN-{}-{}-part-30/".format(noisy_dataset.split('/')[1], clean_dataset.split('/')[1]))
    path.mkdir(parents=True, exist_ok=True)

    out = []
    for i in range(len(val_dataset)):
        _, y = val_dataset[i]
        result = model_dncnn(torch.from_numpy(y).unsqueeze(0).unsqueeze(0)).detach().numpy()
        out.append(np.clip(y - result[0][0],0,1))
        if (i+1) % 20 == 0:
            np.save("data/DnCNN-{}-{}-part-30/{}".format(noisy_dataset.split('/')[1], clean_dataset.split('/')[1], i//20+1), np.asarray(out)) # DnCNN-noisy-clean R-noisy-clean
            out = []


50.85176086425781
0


  img_val, imgn_val = Variable(img_val.cuda(), volatile=True), Variable(imgn_val.cuda(), volatile=True)


0.7039929032325745
10
0.6817913055419922
20
0.7567195296287537
30
0.5274320840835571
40
0.5714787244796753
50
0.7146579623222351
60
0.5757565498352051
70
0.5676959156990051
80
0.6320645809173584
90
