In [0]:
import torch.nn as nn
import torch, h5py
from torch.autograd import Variable
from torch.utils.data import DataLoader
import numpy as np
import os, math
from google.colab import drive
from tqdm import tqdm_notebook
from skimage.metrics import peak_signal_noise_ratio as psnr
import matplotlib.pyplot as plt

In [0]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime → "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

drive.mount('/content/drive')

Wed Mar  4 17:44:15 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.59       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P0    25W / 250W |      0MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

In [0]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, file_name):
        super(Dataset, self).__init__()
        self.file_name = file_name
        with h5py.File(file_name, 'r') as data:
            self.keys = list(data.keys())

    def __len__(self):
        return len(self.keys)
    
    def __getitem__(self, index):
        with h5py.File(self.file_name, 'r') as data:
            example = np.array(data[self.keys[index]])
        return torch.Tensor(example)

    def shape(self):
        with h5py.File(self.file_name, 'r') as data:
            return np.array(data[self.keys[0]]).shape

In [0]:
def batch_psnr(clean_image, denoised_image):
    clean_image = clean_image.data.cpu().numpy().astype(np.float32)
    denoised_image = denoised_image.data.cpu().numpy().astype(np.float32)

    batch_psnr_val = 0
    for i in range(clean_image.shape[0]):
        batch_psnr_val += psnr(clean_image[i,:,:,:], denoised_image[i,:,:,:], data_range=1)

    return batch_psnr_val / clean_image.shape[0]   

def setup_gpus():
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    device_ids = [i for i in range(torch.cuda.device_count())]
    if len(device_ids) > 3:
        device_ids = device_ids[:-1]
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, device_ids))
    return device_ids

In [0]:
def init_weights(layer):
    if isinstance(layer, nn.Conv2d):
        nn.init.kaiming_normal_(layer.weight.data, a=0, mode='fan_in')
    if isinstance(layer, nn.BatchNorm2d):
        layer.weight.data.normal_(mean=0, std=math.sqrt(2./9./64.)).clamp_(-0.025,0.025)
        nn.init.constant_(layer.bias.data, 0.0)

class CNN_ReLU(nn.Module):
    def __init__(self, in_channels, out_channels, filter_size):
        super(CNN_ReLU, self).__init__()
        padding = int((filter_size - 1) / 2)
        self.layer = nn.Sequential(nn.Conv2d(in_channels, out_channels, filter_size, padding=padding, bias=False), nn.ReLU())
		
    def forward(self, x):
        return self.layer(x)


class CNN_BN_ReLU(nn.Module):
    def __init__(self, in_channels, out_channels, filter_size):
        super(CNN_BN_ReLU, self).__init__()
        padding = int((filter_size - 1) / 2)
        self.layer = nn.Sequential(nn.Conv2d(in_channels, out_channels, filter_size, padding=padding, bias=False),\
		                           nn.BatchNorm2d(in_channels),\
		                           nn.ReLU())
		
    def forward(self, x):
        return self.layer(x)
	

class CNN(nn.Module):
    def __init__(self, in_channels, out_channels, filter_size):
        super(CNN, self).__init__()
        padding = int((filter_size - 1) / 2)
        self.layer = nn.Conv2d(in_channels, out_channels, filter_size, padding=padding, bias=False)
		
    def forward(self, x):
        return self.layer(x)

class DnCNN(nn.Module):
    def __init__(self, num_layers, input_channels, output_channels, filter_size):
        super(DnCNN, self).__init__()
        self.layers = nn.Sequential(
        CNN_ReLU(input_channels, output_channels, filter_size),
        nn.Sequential(*[CNN_BN_ReLU(output_channels, output_channels, filter_size) for x in range(num_layers)]),
        CNN(output_channels, input_channels, filter_size))
		
    def forward(self, x):
        return self.layers(x)	

In [0]:
train_set = '/content/drive/My Drive/Colab Notebooks/train.h5'
val_set = '/content/drive/My Drive/Colab Notebooks/val.h5'
batch_size = 64

assert os.path.exists(train_set), f'Cannot find training vectors file {train_set}'
assert os.path.exists(val_set), f'Cannot find validation vectors file {val_set}'

print('Loading datasets')

train_data = Dataset(train_set)
val_data = Dataset(val_set)

print(f'Number of training examples: {len(train_data)}')
print(f'Number of validation examples: {len(val_data)}')

train_loader = DataLoader(dataset=train_data, num_workers=os.cpu_count(), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_data, num_workers=os.cpu_count(), batch_size=batch_size, shuffle=False) 

Loading datasets
Number of training examples: 96400
Number of validation examples: 14132


In [0]:
RESUME_TRAINING = True

DEPTH = 17
INPUT_CHANNELS = 1
OUTPUT_CHANNELS = 64
FILTER_SIZE = 3

LEARNING_RATE = 0.01
WEIGHT_DECAY = 0.00001
MOMENTUM = 0.9

END_LR = 0.00001
START_LR = 0.01
LR_EPOCHS = 50
GAMMA = np.power(END_LR / START_LR, 1/LR_EPOCHS)

NUM_ITERATIONS = 50

# detect gpus and setup environment variables
device_ids = setup_gpus()
print(f'Cuda devices found: {[torch.cuda.get_device_name(i) for i in device_ids]}')

model = DnCNN(DEPTH, INPUT_CHANNELS, OUTPUT_CHANNELS, FILTER_SIZE)
model.apply(init_weights)
model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()

loss = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=GAMMA)

Cuda devices found: ['Tesla P100-PCIE-16GB']


In [0]:
epochs_trained = 0
epoch_losses = []
epoch_val_losses = []
epoch_psnrs = []
min_val_loss = 1000

if RESUME_TRAINING:
    checkpoint = torch.load('/content/drive/My Drive/Colab Notebooks/logs/model.state')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    epochs_trained = checkpoint['epoch']
    epoch_losses = checkpoint['epoch_train_losses']
    epoch_val_losses = checkpoint['epoch_val_losses']
    epoch_psnrs = checkpoint['epoch_psnrs']
    min_val_loss = checkpoint['min_val_loss']

In [0]:


# print(model)
for epoch in tqdm_notebook(range(NUM_ITERATIONS - epochs_trained)):
    print(f'Training epoch {epoch+1} with lr={optimizer.param_groups[0]["lr"]}')
    epoch_loss = 0
    num_steps = 0
    model.train()

    for batch_num, batch in enumerate(train_loader):
        model.zero_grad()
        optimizer.zero_grad()

        noise_25 = torch.FloatTensor(batch.size()).normal_(mean=0, std=25/255)
        noisy_image = batch + noise_25

        noisy_image = Variable(noisy_image.cuda())
        noise_25 =  Variable(noise_25.cuda())

        predict = model(noisy_image)
        batch_loss = loss(noise_25, predict) / batch.size()[0]
        batch_loss.backward()
        optimizer.step()
        epoch_loss += batch_loss.item()
        num_steps += 1

        batch.detach()
        noisy_image.detach()
        noise_25.detach()
        batch_loss.detach()
        del batch, noisy_image, noise_25, batch_loss

    epoch_losses.append(epoch_loss/num_steps)

    epoch_val_loss = 0
    epoch_psnr = 0
    num_steps = 0
    model.eval()

    for batch_num, batch in enumerate(val_loader):
        noise_25 = torch.FloatTensor(batch.size()).normal_(mean=0, std=25/255)
        noisy_image = batch + noise_25

        noisy_image = Variable(noisy_image.cuda())
        noise_25 =  Variable(noise_25.cuda())

        predict = model(noisy_image)
        val_loss = loss(noise_25, predict) / batch.size()[0]
        epoch_val_loss += val_loss.item()
        num_steps += 1

        if val_loss < min_val_loss:
            min_val_loss = val_loss
            torch.save(model.state_dict, f'/content/drive/My Drive/Colab Notebooks/logs/t_star.state')

        # Calculate PSNR
        denoised_image = torch.clamp(noisy_image - predict, 0.0, 1.0)
        epoch_psnr += batch_psnr(batch, denoised_image)

        batch.detach()
        noisy_image.detach()
        noise_25.detach()
        val_loss.detach()
        del batch, noisy_image, noise_25, val_loss

    scheduler.step()
    
    epoch_val_losses.append(epoch_val_loss/num_steps)
    epoch_psnrs.append(epoch_psnr/num_steps)

    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'epoch_train_losses': epoch_losses,
            'epoch_val_losses': epoch_val_losses,
            'min_val_loss': min_val_loss,
            'epoch_psnrs': epoch_psnrs,
            }, '/content/drive/My Drive/Colab Notebooks/logs/model.state')

    print(f'Epoch {epoch} train loss = {epoch_loss/num_steps}, val loss = {epoch_val_loss/num_steps}, PSNR = {epoch_psnr/num_steps}')

HBox(children=(IntProgress(value=0, max=50), HTML(value='')))

HBox(children=(IntProgress(value=0, max=50), HTML(value='')))

Training epoch 1 with lr=0.01
Training epoch 1 with lr=0.01


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch 0 train loss = 0.0002517753310379234, val loss = 3.166324432842629e-05, PSNR = 27.17054581129571
Training epoch 2 with lr=0.008709635899560806
Epoch 0 train loss = 0.0002517753310379234, val loss = 3.166324432842629e-05, PSNR = 27.17054581129571
Training epoch 2 with lr=0.008709635899560806
Epoch 1 train loss = 0.00018392477254956124, val loss = 2.3280115861755666e-05, PSNR = 28.713474826659535
Training epoch 3 with lr=0.007585775750291837
Epoch 1 train loss = 0.00018392477254956124, val loss = 2.3280115861755666e-05, PSNR = 28.713474826659535
Training epoch 3 with lr=0.007585775750291837
Epoch 2 train loss = 0.00018092194067613734, val loss = 2.6335660105644107e-05, PSNR = 28.047256201338495
Training epoch 4 with lr=0.006606934480075959
Epoch 2 train loss = 0.00018092194067613734, val loss = 2.6335660105644107e-05, PSNR = 28.047256201338495
Training epoch 4 with lr=0.006606934480075959
Epoch 3 train loss = 0.00017926065454266626, val loss = 2.1380835097174875e-05, PSNR = 29.0812