In [11]:
import torch.nn as nn
import torch, h5py
from torch.autograd import Variable
from torch.utils.data import DataLoader
import numpy as np
import os

In [12]:
# Pick the first CUDA GPU when one is visible to PyTorch, otherwise run on CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(device)

cpu


In [3]:
class CNN_ReLU(nn.Module):
    """2-D convolution followed by a ReLU.

    Padding is ``(filter_size - 1) // 2``, so with the default stride of 1 an
    odd ``filter_size`` preserves the spatial size of the input.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the convolution.
        filter_size: side length of the square convolution kernel.
    """

    def __init__(self, in_channels, out_channels, filter_size):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, filter_size,
                         padding=(filter_size - 1) // 2)
        self.layer = nn.Sequential(conv, nn.ReLU())

    def forward(self, x):
        """Apply the convolution, then the ReLU, to ``x``."""
        return self.layer(x)


class CNN_BN_ReLU(nn.Module):
    """2-D convolution, then batch normalisation, then ReLU.

    Padding is ``(filter_size - 1) // 2``, so with the default stride of 1 an
    odd ``filter_size`` preserves the spatial size of the input.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the convolution (and
            therefore the number of channels the BatchNorm normalises).
        filter_size: side length of the square convolution kernel.
    """

    def __init__(self, in_channels, out_channels, filter_size):
        super(CNN_BN_ReLU, self).__init__()
        padding = (filter_size - 1) // 2
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, filter_size, padding=padding),
            # BatchNorm sits after the conv, so it must be sized to the conv's
            # output channels, not its input channels.
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply conv -> BatchNorm -> ReLU to ``x``."""
        return self.layer(x)
	

class CNN(nn.Module):
    """A single 2-D convolution with no activation.

    Padding is ``(filter_size - 1) // 2``, so with the default stride of 1 an
    odd ``filter_size`` preserves the spatial size of the input.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the convolution.
        filter_size: side length of the square convolution kernel.
    """

    def __init__(self, in_channels, out_channels, filter_size):
        super().__init__()
        half_width = (filter_size - 1) // 2
        self.layer = nn.Conv2d(in_channels, out_channels, filter_size, padding=half_width)

    def forward(self, x):
        """Convolve ``x``."""
        return self.layer(x)
	

In [4]:
class DnCNN(nn.Module):
    """DnCNN-style stack of convolutions.

    ``self.layers`` is a 3-stage ``nn.Sequential``:
      0. head:  ``CNN_ReLU`` mapping ``input_channels`` -> ``output_channels``;
      1. body:  ``num_layers`` ``CNN_BN_ReLU`` blocks at ``output_channels`` width;
      2. tail:  plain ``CNN`` mapping ``output_channels`` -> ``input_channels``.
    Code that builds per-stage optimizer parameter groups indexes these three
    stages, so keep that layout.

    NOTE(review): ``output_channels`` is the hidden feature width, not the
    channel count of the model's output (which is ``input_channels``). The total
    number of conv layers is ``num_layers + 2``; the DnCNN paper's "depth 17"
    corresponds to 15 body blocks. Confirm which convention is intended.

    Args:
        num_layers: number of Conv-BN-ReLU blocks in the body.
        input_channels: channels of the input image (and of the output).
        output_channels: number of hidden feature maps.
        filter_size: side length of every square convolution kernel.
    """

    def __init__(self, num_layers, input_channels, output_channels, filter_size):
        super().__init__()
        head = CNN_ReLU(input_channels, output_channels, filter_size)
        body = nn.Sequential(*(CNN_BN_ReLU(output_channels, output_channels, filter_size)
                               for _ in range(num_layers)))
        tail = CNN(output_channels, input_channels, filter_size)
        self.layers = nn.Sequential(head, body, tail)

    def forward(self, x):
        """Run ``x`` through head, body and tail in order."""
        return self.layers(x)

In [5]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over the top-level entries of an HDF5 file.

    Item ``i`` is the array stored under the ``i``-th top-level key, returned as
    a ``torch.Tensor`` of PyTorch's default float dtype. The file is reopened on
    every read instead of keeping a handle on the object — presumably so the
    dataset can be used from DataLoader worker processes; confirm if changing.

    Args:
        file_name: path to the HDF5 file.
    """

    def __init__(self, file_name):
        super().__init__()
        self.file_name = file_name
        # Snapshot the key list once; items are looked up by position in it.
        with h5py.File(file_name, 'r') as h5:
            self.keys = [key for key in h5.keys()]

    def __len__(self):
        """Number of top-level entries in the file."""
        return len(self.keys)

    def _load(self, index):
        """Read the entry at position ``index`` of ``self.keys`` as a NumPy array."""
        with h5py.File(self.file_name, 'r') as h5:
            return np.array(h5[self.keys[index]])

    def __getitem__(self, index):
        """Return entry ``index`` as a tensor."""
        return torch.Tensor(self._load(index))

    def shape(self):
        """Shape of the first entry (entries are assumed to share one shape — TODO confirm)."""
        return self._load(0).shape

In [6]:
# Paths to the HDF5 training / validation files and the loader batch size.
train_set = 'Image-Denoising-DNN/train.h5'
val_set = 'Image-Denoising-DNN/val.h5'
batch_size = 128

assert os.path.exists(train_set), f'Cannot find training vectors file {train_set}'
assert os.path.exists(val_set), f'Cannot find validation vectors file {val_set}'

print('Loading datasets')

train_data = Dataset(train_set)
val_data = Dataset(val_set)

# os.cpu_count() returns None when the CPU count cannot be determined, which
# DataLoader rejects; fall back to 0 workers (load in the main process).
num_workers = os.cpu_count() or 0
train_loader = DataLoader(dataset=train_data, num_workers=num_workers,
                          batch_size=batch_size, shuffle=True)

Loading datasets


In [7]:
# --- Model hyper-parameters ---------------------------------------------------
# NOTE(review): DnCNN(DEPTH, ...) builds DEPTH Conv-BN-ReLU body blocks plus a
# head and a tail conv, i.e. DEPTH + 2 conv layers in total. The DnCNN paper's
# depth-17 network has 15 body blocks; confirm which is intended.
DEPTH = 17
INPUT_CHANNELS = 1
OUTPUT_CHANNELS = 64
FILTER_SIZE = 3

# --- Optimiser / schedule -----------------------------------------------------
WEIGHT_DECAY = 0.0001
MOMENTUM = 0.9

# Learning rate decays exponentially from START_LR to END_LR over LR_EPOCHS epochs.
END_LR = 0.0001
START_LR = 0.1
LR_EPOCHS = 50
LEARNING_RATE = START_LR
# ExponentialLR multiplies the lr by GAMMA after every epoch, so we need
# START_LR * GAMMA ** LR_EPOCHS == END_LR  =>  GAMMA = (END_LR / START_LR) ** (1 / LR_EPOCHS)
# (~0.871 here). The previous formula log(END/START) / -EPOCHS gave ~0.138,
# which collapsed the lr by ~7x per epoch.
GAMMA = float(np.exp(np.log(END_LR / START_LR) / LR_EPOCHS))

# --- Training data corruption -------------------------------------------------
# Std-dev of the additive Gaussian noise applied to each training batch.
# NOTE(review): 25/255 assumes pixel values are scaled to [0, 1] — confirm
# against how train.h5 was written.
NOISE_SIGMA = 25 / 255

NUM_ITERATIONS = 1  # number of training epochs
SEED = 0

torch.manual_seed(SEED)

# Move the model to the device before handing its parameters to the optimiser.
model = DnCNN(DEPTH, INPUT_CHANNELS, OUTPUT_CHANNELS, FILTER_SIZE).to(device)
# Per-stage parameter groups so that the tail conv is excluded from weight decay.
optimizer = torch.optim.SGD([{'params': model.layers[0].parameters()},
                             {'params': model.layers[1].parameters()},
                             {'params': model.layers[2].parameters(), 'weight_decay': 0}],
                            lr=LEARNING_RATE,
                            momentum=MOMENTUM,
                            weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=GAMMA)
loss = nn.MSELoss()

epoch_losses = []  # one loss value per training *batch*

for epoch in range(NUM_ITERATIONS):
    print(f'Training epoch {epoch+1}')
    model.train()

    for batch_num, batch in enumerate(train_loader):
        batch = batch.to(device)
        # Comparing model(batch) with batch itself only teaches an identity map.
        # Corrupt the clean batch and train the model to recover it, so that
        # model(x) still returns an image.
        noisy = batch + torch.randn_like(batch) * NOISE_SIGMA

        optimizer.zero_grad()
        predict = model(noisy)
        batch_loss = loss(predict, batch)  # MSELoss(input, target)
        epoch_losses.append(batch_loss.item())
        batch_loss.backward()
        optimizer.step()

    scheduler.step()

# Report the loss on the last batch without tracking gradients and with
# BatchNorm in eval mode, so this check does not update running statistics.
model.eval()
with torch.no_grad():
    print(loss(model(noisy), batch))

Training epoch 1


KeyboardInterrupt: 