In [7]:
import argparse
import sys
import re
import os, glob, datetime, time
import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
import torch.nn.init as init
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
import data_generator as dg
from data_generator import DenoisingDataset
import warnings

In [8]:
warnings.filterwarnings('ignore')

In [9]:
# Params
if any(["jupyter" in arg for arg in sys.argv]):
    # Simulate command line arguments (replace these with your desired defaults)
    sys.argv = ['ipykernel_launcher.py', '--model', 'DnCNN', '--batch_size', '4', '--train_data', './data/Train400', '--sigma', '4', '--epoch', '2', '--lr', '0.0005']

parser = argparse.ArgumentParser(description='PyTorch DnCNN')
parser.add_argument('--model', default='DnCNN', type=str, help='choose a type of model')
parser.add_argument('--batch_size', default=4, type=int, help='batch size')
parser.add_argument('--train_data', default='./data/Train400', type=str, help='path of train data')
parser.add_argument('--sigma', default=4, type=int, help='noise level')
parser.add_argument('--epoch', default=2, type=int, help='number of train epoches')
parser.add_argument('--lr', default=5e-4, type=float, help='initial learning rate for Adam')
args = parser.parse_args()
batch_size = args.batch_size
cuda = torch.cuda.is_available()
n_epoch = args.epoch
sigma = args.sigma

save_dir = os.path.join('models', args.model+'_' + 'sigma' + str(sigma))

if not os.path.exists(save_dir):
    os.mkdir(save_dir)

class Residual_Blocks(nn.Module):
    def __init__(self, in_channels=64, out_channels=64, padding = 1):
        super(Residual_Blocks, self).__init__()
        self.convx_1 = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)
        self.Leakyrelu = nn.LeakyReLU(inplace=True)
        self.BN = nn.BatchNorm2d(64)
        self.convx_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)
        self.Leakyrelu = nn.LeakyReLU(inplace=True)

    def forward(self, x):
        x = self.BN(self.Leakyrelu(self.convx_1(x)))
        out = self.BN(self.Leakyrelu(self.convx_2(x)))
        return out
    
class Attention(nn.Module):
    def __init__(self, reduction=8):
        super(Attention, self).__init__()
        self.red = reduction
        self.query = nn.Conv2d(in_channels=64, out_channels=64//self.red, kernel_size=3, padding=1, stride=1, bias=False)
        self.value = nn.Conv2d(in_channels=64, out_channels=64//self.red, kernel_size=3, padding=1, stride=1, bias=False)
        self.key = nn.Conv2d(in_channels=64, out_channels=64//self.red, kernel_size=3, padding=1, stride=1, bias=False)
        self.out = nn.Conv2d(in_channels=64//self.red, out_channels=64, kernel_size=3, padding=1, stride=1, bias=False)
        
    def forward(self, x):
        b, c, h, w = x.size()
        query = self.query(x).view(b, -1, h*w)
        value = self.value(x).view(b, -1, h*w)
        key = self.key(x).view(b, -1, h*w)
        mul_1 = torch.bmm(query, value.transpose(1,2))
        attention_weights = mul_1 / np.sqrt(64//self.red)
        res_1 = nn.functional.softmax(attention_weights, dim=-1)
        mul_2 = torch.bmm(res_1, key).view(b, -1, h, w)
        out = self.out(mul_2)
        return x + out

        
class DnCNN(nn.Module):
    def __init__(self, n_channels=64, image_channels=1):
        super(DnCNN, self).__init__()
        self.conv1 = nn.Conv2d(image_channels, n_channels, kernel_size=3, padding=1, bias=False)
        self.attention = Attention()
        self.act = nn.LeakyReLU(inplace=True)
        self.out = nn.Sequential(*[Residual_Blocks() for _ in range(1)])
        self.dn = nn.Conv2d(64, 1, kernel_size=3, padding=1, bias=True)
        self._initialize_weights()
        
    def forward(self, x):
        y = x
        out = self.act(self.attention(self.conv1(x)))
        out = self.out(out)
        dn = y - self.dn(out)
        return dn
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                print('init weight')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

                
class sum_squared_error(_Loss):  # PyTorch 0.4.1
    """
    Definition: sum_squared_error = 1/2 * nn.MSELoss(reduction = 'sum')
    The backward is defined as: input-target
    """
    def __init__(self, size_average=None, reduce=None, reduction='sum'):
        super(sum_squared_error, self).__init__(size_average, reduce, reduction)

    def forward(self, input, target):
        # return torch.sum(torch.pow(input-target,2), (0,1,2,3)).div_(2)
        return torch.nn.functional.mse_loss(input, target, size_average=None, reduce=None, reduction='sum').div_(2)


def log(*args, **kwargs):
     print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)

In [10]:
import torch

# Attempt to manually select a CUDA device
try:
    torch.cuda.current_device()
    print("Current CUDA device:", torch.cuda.get_device_name())
except Exception as e:
    print("Error accessing CUDA device:", str(e))

Current CUDA device: NVIDIA GeForce RTX 3070 Laptop GPU


In [None]:
from tqdm import tqdm

if __name__ == '__main__':
    # model selection
    print('===> Building model')
    model = DnCNN()
    
    initial_epoch = 0
    model.train()
    # criterion = nn.MSELoss(reduction = 'sum')  # PyTorch 0.4.1
    criterion = sum_squared_error()
    if cuda:
        model = model.cuda()
         # device_ids = [0]
         # model = nn.DataParallel(model, device_ids=device_ids).cuda()
         # criterion = criterion.cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay = 1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.2)  # learning rates
    for epoch in tqdm(range(initial_epoch, n_epoch)):

        scheduler.step(epoch)  # step to the learning rate in this epcoh
        xs = dg.datagenerator(data_dir=args.train_data)
        xs = xs.astype('float32')/255.0
        xs = torch.from_numpy(xs.transpose((0, 3, 1, 2)))  # tensor of the clean patches, NXCXHXW
        DDataset = DenoisingDataset(xs, sigma)
        DLoader = DataLoader(dataset=DDataset, num_workers=4, drop_last=True, batch_size=batch_size, shuffle=True)
        epoch_loss = 0
        start_time = time.time()

        for n_count, batch_yx in enumerate(DLoader):
                optimizer.zero_grad()
                if cuda:
                    batch_x, batch_y = batch_yx[1].cuda(), batch_yx[0].cuda()
                loss = criterion(model(batch_y), batch_x)
                epoch_loss += loss.item()
                loss.backward()
                optimizer.step()
                if n_count % 10 == 0:
                    print('%4d %4d / %4d loss = %2.4f' % (epoch+1, n_count, xs.size(0)//batch_size, loss.item()/batch_size))
        elapsed_time = time.time() - start_time

        log('epcoh = %4d , loss = %4.4f , time = %4.2f s' % (epoch+1, epoch_loss/n_count, elapsed_time))
        np.savetxt('train_result.txt', np.hstack((epoch+1, epoch_loss/n_count, elapsed_time)), fmt='%2.4f')
        # torch.save(model.state_dict(), os.path.join(save_dir, 'model_%03d.pth' % (epoch+1)))
        torch.save(model, os.path.join(save_dir, 'model_%03d.pth' % (epoch+1)))


===> Building model
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight


  0%|          | 0/2 [00:00<?, ?it/s]


AxisError: axis 3 is out of bounds for array of dimension 2