In [1]:
!pip install torch torchvision



In [2]:
from math import log10
from datetime import datetime
from skimage.transform import resize
import os, json, imageio, numpy as np

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

from google.colab import drive
%matplotlib inline

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, lr_scheduler
import torchvision.transforms as transforms
import torchvision.transforms.functional as tvF
from torch.utils.data import Dataset, DataLoader

In [4]:
class BSDDataset(Dataset):
    """Dataset of image files that yields two independently corrupted copies.

    Every item is a pair ``(image_noisy, image_target)``: the same (optionally
    cropped) image with two separate draws of the configured noise, returned
    as float64 tensors in channel-first layout.

    Args:
        root_dir: directory whose entries are all treated as image files.
        crop_size: side length of the random square crop; values <= 0 disable
            cropping.
        noise_model: name of the corruption model; only ``'gaussian'`` is
            implemented.
        noise_sigma: standard deviation of the Gaussian noise, applied after
            the image has been scaled to [0, 1].
        img_bitdepth: bit depth of the stored images, used to scale pixel
            values to [0, 1].
        seed: if not ``None``, seeds NumPy's global random generator.
    """

    def __init__(self,
                 root_dir,
                 crop_size=64,
                 noise_model='gaussian',
                 noise_sigma=0.2,
                 img_bitdepth=8,
                 seed=None):
        self.seed = seed
        self.root_dir = root_dir
        self.crop_size = crop_size
        self.img_bitdepth = img_bitdepth
        self.noise_model = noise_model
        self.noise_sigma = noise_sigma
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable between runs and machines.
        self.imgs = sorted(os.listdir(root_dir))

        # `is not None` so that a seed of 0 is honoured as well.
        if self.seed is not None:
            np.random.seed(self.seed)

    def __len__(self):
        return len(self.imgs)

    def _random_crop_to_size(self, img):
        """Returns a random ``crop_size`` x ``crop_size`` window of ``img``.

        Images with a side shorter than ``crop_size`` are resized to exactly
        ``crop_size`` x ``crop_size`` and returned whole.
        """
        size = self.crop_size
        h, w = img.shape[:2]

        if min(w, h) < size:
            # The resized image already has the requested shape, so there is
            # nothing left to crop (and the old h / w no longer apply).
            return resize(img, (size, size))

        # randint's upper bound is exclusive, hence the +1: it makes the last
        # valid offset reachable and keeps the call legal when h == size.
        i = np.random.randint(0, h - size + 1)
        j = np.random.randint(0, w - size + 1)

        return img[i:i + size, j:j + size]

    def _add_gaussian_noise(self, image):
        """Adds zero-mean Gaussian noise and clips the result to [0, 1]."""
        noisy_image = image + np.random.normal(0, self.noise_sigma, image.shape)
        return np.clip(noisy_image, 0, 1)

    def corrupt_image(self, image):
        """Applies the configured noise model to ``image``.

        Raises:
            ValueError: if ``noise_model`` is not a supported model.
        """
        if self.noise_model == 'gaussian':
            return self._add_gaussian_noise(image)
        raise ValueError('No such noise model.')

    def __getitem__(self, index):
        # Load image and scale it to [0, 1]
        img_path = os.path.join(self.root_dir, self.imgs[index])
        image = imageio.imread(img_path) / (2**self.img_bitdepth - 1)

        # Crop source image
        if self.crop_size > 0:
            image = self._random_crop_to_size(image)

        # Generate two independent noisy realisations of the same image
        image_noisy = self.corrupt_image(image)
        image_target = self.corrupt_image(image)

        # Transpose channels: (H, W, C) -> (C, H, W)
        image_target = np.array(image_target).transpose((2, 0, 1))
        image_noisy = np.array(image_noisy).transpose((2, 0, 1))

        # Convert to tensor
        image_target = torch.from_numpy(image_target).type(torch.DoubleTensor)
        image_noisy = torch.from_numpy(image_noisy).type(torch.DoubleTensor)

        return image_noisy, image_target

In [17]:
def conv_block(in_channels, out_channels):
    """Builds two 3x3 same-padding convolutions, each followed by a ReLU."""
    layers = [
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)


class Denoiser(nn.Module):
    """Encoder-decoder network with skip connections that predicts a residual.

    Three max-pooling stages halve the resolution on the way down; the decoder
    upsamples bilinearly and concatenates the matching encoder feature map at
    each level. The output of the final 1x1 convolution is added to the input
    image, so the returned tensor has the same shape as the input.
    """

    def __init__(self):
        super().__init__()

        self.encode_1 = conv_block(3, 64)
        self.encode_2 = conv_block(64, 128)
        self.encode_3 = conv_block(128, 256)
        self.encode_4 = conv_block(256, 512)

        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2,
                                    mode='bilinear',
                                    align_corners=True)

        self.decode_3 = conv_block(256 + 512, 256)
        self.decode_2 = conv_block(128 + 256, 128)
        self.decode_1 = conv_block(128 + 64, 64)
        self.conv_last = nn.Conv2d(64, 3, 1)

    def _upsample_like(self, x, skip):
        """Upsamples ``x`` to the spatial size of ``skip``.

        When ``skip`` is exactly twice as large as ``x`` the regular 2x
        upsampling layer is used. MaxPool2d floors odd sizes, so for inputs
        whose height or width is not divisible by 8 the doubled size can fall
        short of the skip tensor; in that case ``x`` is interpolated directly
        to the skip size so the concatenation does not fail.
        """
        target = tuple(skip.shape[-2:])
        if (2 * x.shape[-2], 2 * x.shape[-1]) == target:
            return self.upsample(x)
        return F.interpolate(x, size=target, mode='bilinear', align_corners=True)

    def forward(self, input_image):
        """Returns ``input_image`` plus the predicted residual."""
        # Encoder: keep the pre-pooling activations for the skip connections
        conv1 = self.encode_1(input_image)
        x = self.maxpool(conv1)

        conv2 = self.encode_2(x)
        x = self.maxpool(conv2)

        conv3 = self.encode_3(x)
        x = self.maxpool(conv3)

        x = self.encode_4(x)

        # Decoder: upsample, concatenate the skip along channels, convolve
        x = self._upsample_like(x, conv3)
        x = torch.cat([x, conv3], dim=1)
        x = self.decode_3(x)

        x = self._upsample_like(x, conv2)
        x = torch.cat([x, conv2], dim=1)
        x = self.decode_2(x)

        x = self._upsample_like(x, conv1)
        x = torch.cat([x, conv1], dim=1)
        x = self.decode_1(x)

        noise = self.conv_last(x)

        return input_image + noise

***
# Utility functions

In [19]:
def clear_line():
    """Overwrites the current terminal line with 80 spaces and rewinds the cursor."""
    blank = ' ' * 80
    # Leading '\r' jumps to column 0; the trailing '\r' (instead of a newline)
    # leaves the cursor there so the next print reuses the same line.
    print('\r{}'.format(blank), end='\r')


def progress_bar(batch_idx, num_batches, report_interval, train_loss):
    """Prints an in-place progress bar for the current report interval.

    Args:
        batch_idx: zero-based index of the current batch.
        num_batches: total number of batches; only used to size the counter.
        report_interval: number of batches the bar spans before wrapping.
        train_loss: loss value shown next to the bar.
    """

    # Width of the batch counter = number of digits in num_batches.
    # ceil(log10(n)) is one short for exact powers of ten (e.g. 100 -> 2) and
    # overflows for n == 0, so count the digits directly.
    dec = len(str(num_batches))
    bar_size = 21 + dec
    progress = (batch_idx % report_interval) / report_interval
    fill = int(progress * bar_size) + 1
    bar = '=' * fill + '>' + ' ' * (bar_size - fill)
    # '\r' with no trailing newline redraws the bar on the same line.
    print('\rBatch {:>{dec}d} [{}] Train loss: {:>1.5f}'.format(batch_idx + 1, bar, train_loss, dec=dec), end='')


def time_elapsed_since(start):
    """Computes the time elapsed since ``start``.

    Args:
        start: a ``datetime`` taken earlier with ``datetime.now()``.

    Returns:
        A tuple ``(string, ms)``: the elapsed time formatted without the
        fractional seconds, and the elapsed time in whole milliseconds.
    """

    elapsed = datetime.now() - start
    # str(timedelta) only carries a '.ffffff' suffix when the microseconds are
    # non-zero, so cut at the dot instead of slicing off a fixed 7 characters
    # (which would leave an empty string for a whole-second delta).
    string = str(elapsed).split('.')[0]
    ms = int(elapsed.total_seconds() * 1000)

    return string, ms


def show_on_epoch_end(epoch_time, valid_time, valid_loss, valid_psnr):
    """Prints the end-of-epoch summary line (timings, validation loss, PSNR)."""

    template = 'Train time: {} | Valid time: {} | Valid loss: {:>1.5f} | Avg PSNR: {:.2f} dB'
    summary = template.format(epoch_time, valid_time, valid_loss, valid_psnr)

    # Wipe whatever the progress bar left on the current line first
    clear_line()
    print(summary)


def show_on_report(batch_idx, num_batches, loss, elapsed):
    """Prints the running training statistics for one report interval.

    Args:
        batch_idx: zero-based index of the batch that closes the interval.
        num_batches: total number of batches in the epoch.
        loss: average loss over the interval.
        elapsed: average time per batch in milliseconds; truncated to an int.
    """

    clear_line()
    # Pad the batch counter to the number of digits in num_batches so every
    # report line aligns; ceil(log10(n)) is one short when n is an exact power
    # of ten (e.g. 100 batches).
    dec = len(str(num_batches))
    print('Batch {:>{dec}d} / {:d} | Avg loss: {:>1.5f} | Avg train time / batch: {:d} ms'.format(batch_idx + 1, num_batches, loss, int(elapsed), dec=dec))


def plot_per_epoch(ckpt_dir, title, measurements, y_label):
    """Saves a line plot of one per-epoch statistic as a PNG in ``ckpt_dir``.

    The file name is the lower-cased title with spaces replaced by dashes.
    """

    # Epochs are numbered from 1 on the x axis
    epochs = range(1, len(measurements) + 1)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(epochs, measurements)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    # Epoch ticks are integers only
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    plt.tight_layout()

    slug = title.replace(' ', '-').lower()
    plot_fname = os.path.join(ckpt_dir, '{}.png'.format(slug))
    plt.savefig(plot_fname, dpi=200)
    plt.close()

def psnr(input, target):
    """Peak signal-to-noise ratio in dB between two tensors, for a peak value of 1."""

    mse = F.mse_loss(input, target)
    ratio = 1 / mse
    return 10 * torch.log10(ratio)

In [20]:
class AvgMeter(object):
    """Running-average tracker.

    Keeps the most recent value (``val``), the weighted total (``sum``), the
    total weight (``count``) and their ratio (``avg``). Handy for elapsed
    times, minibatch losses and similar per-step quantities.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Puts every tracked quantity back to zero."""
        self.val, self.avg, self.sum, self.count = 0, 0., 0, 0

    def update(self, val, n=1):
        """Records ``val`` as seen ``n`` times and refreshes the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

***
# Train the model

In [21]:
# Mount Google Drive (Colab only); force_remount=True re-mounts even if the
# drive is already attached to this session
drive.mount("/content/drive", force_remount=True)

# Paths to the BSDS500 train / validation image folders on the mounted drive
root_train = '/content/drive/My Drive/Colab Notebooks/Data/BSDS500/train/'
root_valid = '/content/drive/My Drive/Colab Notebooks/Data/BSDS500/val/'

# Sanity check: count the directory entries in each split
print('Train images: ', len(os.listdir(root_train)))
print('Valid images: ', len(os.listdir(root_valid)))

Mounted at /content/drive
Train images:  400
Valid images:  100


In [22]:
class Params:
    """Plain container for the paths and hyper-parameters of a training run."""

    def __init__(self):
        # Dataset folders and checkpoint destination on the mounted drive
        self.train_dir = '/content/drive/My Drive/Colab Notebooks/Data/BSDS500/train/'
        self.valid_dir = '/content/drive/My Drive/Colab Notebooks/Data/BSDS500/val/'
        self.ckpt_save_path = '/content/drive/My Drive/Colab Notebooks'

        # Optimisation settings
        self.nb_epochs = 10
        self.batch_size = 4
        self.learning_rate = 0.001
        self.loss = 'l2'

        # Corruption model
        # NOTE(review): BSDDataset scales images to [0, 1] before adding
        # noise, so a sigma of 50 looks like a 0-255-scale value - confirm.
        self.noise_model = 'gaussian'
        self.noise_sigma = 50

        # Data pipeline, reporting and reproducibility
        self.crop_size = 64
        self.report_interval = 4
        self.plot_stats = True
        self.seed = 57
        self.image_bitdepth = 8

In [23]:
def get_loaders(params):
    """Builds the training and validation data loaders from ``params``.

    Args:
        params: object exposing ``train_dir``, ``valid_dir``, ``crop_size``,
            ``noise_model``, ``noise_sigma`` and ``batch_size``. The optional
            attributes ``image_bitdepth`` (default 8) and ``seed`` (default
            ``None``) are forwarded to the datasets when present.

    Returns:
        A tuple ``(train_loader, valid_loader)``.
    """
    # Settings shared by both splits; bit depth and seed used to be silently
    # dropped even though Params defines them.
    dataset_kwargs = dict(crop_size=params.crop_size,
                          noise_model=params.noise_model,
                          noise_sigma=params.noise_sigma,
                          img_bitdepth=getattr(params, 'image_bitdepth', 8),
                          seed=getattr(params, 'seed', None))

    # Declare training / validation datasets
    dataset_train = BSDDataset(params.train_dir, **dataset_kwargs)
    dataset_valid = BSDDataset(params.valid_dir, **dataset_kwargs)

    # Declare training / validation data loaders. Only the training split is
    # shuffled; validation keeps the dataset order so batches line up with
    # dataset.imgs.
    train_loader = DataLoader(dataset_train, batch_size=params.batch_size, shuffle=True)
    valid_loader = DataLoader(dataset_valid, batch_size=params.batch_size, shuffle=False)
    return train_loader, valid_loader

In [35]:
def train_model(model, criterion, optim, train_loader, valid_loader, params):
    """Trains ``model`` and validates it after every epoch.

    Args:
        model: network mapping a batch of inputs to a batch of outputs.
        criterion: loss called as ``criterion(results, targets)``.
        optim: optimizer over the model parameters.
        train_loader: iterable of ``(inputs, targets)`` training batches.
        valid_loader: iterable of ``(inputs, targets)`` validation batches.
        params: object exposing ``nb_epochs`` and ``report_interval``.

    Returns:
        Dict with one value per epoch under the keys ``'train_loss'``,
        ``'valid_loss'`` and ``'valid_psnr'``.
    """
    num_batches = len(train_loader)

    # Dictionary of tracked stats, one entry per epoch
    stats = {'train_loss': [],
             'valid_loss': [],
             'valid_psnr': []}

    # Main training loop
    for epoch in range(params.nb_epochs):
        print('Epoch {:d} / {:d}'.format(epoch + 1, params.nb_epochs))

        # Validation below switches to eval mode, so re-enable training mode
        # at the start of every epoch rather than once before the loop.
        model.train()
        epoch_start = datetime.now()

        # Init stat meters: loss/time are reset every report interval,
        # train_loss_meter accumulates over the whole epoch
        loss_meter = AvgMeter()
        time_meter = AvgMeter()
        train_loss_meter = AvgMeter()

        # Train on batches
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            batch_start = datetime.now()
            progress_bar(batch_idx, num_batches, params.report_interval, loss_meter.val)

            # Denoise image
            results = model(inputs)
            loss = criterion(results, targets)
            loss_meter.update(loss.item())
            train_loss_meter.update(loss.item())

            # Zero gradients, perform a backward pass, and update the weights
            optim.zero_grad()
            loss.backward()
            optim.step()

            # Report/update statistics
            time_meter.update(time_elapsed_since(batch_start)[1])
            if (batch_idx + 1) % params.report_interval == 0 and batch_idx:
                show_on_report(batch_idx, num_batches, loss_meter.avg, time_meter.avg)
                loss_meter.reset()
                time_meter.reset()

        epoch_time = time_elapsed_since(epoch_start)[0]

        # Validation loop
        model.eval()

        valid_start = datetime.now()
        valid_loss_meter = AvgMeter()
        psnr_meter = AvgMeter()

        # No gradients are needed for evaluation; skipping the autograd graph
        # saves memory and time.
        with torch.no_grad():
            for inputs, targets in valid_loader:
                # Denoise
                results = model(inputs)

                # Update loss
                loss = criterion(results, targets)
                valid_loss_meter.update(loss.item())

                # Compute PSNR per image against the loader's target
                results = results.cpu()
                targets = targets.cpu()
                for i in range(results.shape[0]):
                    psnr_meter.update(psnr(results[i], targets[i]).item())

        valid_loss = valid_loss_meter.avg
        valid_time = time_elapsed_since(valid_start)[0]
        valid_psnr = psnr_meter.avg
        show_on_epoch_end(epoch_time, valid_time, valid_loss, valid_psnr)

        # Record per-epoch statistics
        stats['train_loss'].append(train_loss_meter.avg)
        stats['valid_loss'].append(valid_loss)
        stats['valid_psnr'].append(valid_psnr)

    return stats


In [38]:
    def test(self, test_loader, show):
        """Evaluates denoiser on test set and saves a montage per processed item.

        NOTE(review): this is a method body with no enclosing class in this
        cell. It relies on ``self.model``, ``self.p``, ``self.use_cuda`` and
        ``create_montage``, none of which are defined in the visible notebook
        code - confirm where it belongs before running it.

        Args:
            test_loader: iterable yielding ``(source, target)`` batches; its
                ``dataset.imgs`` list is indexed to name the saved images.
            show: maximum number of batches to process; 0 processes none.
        """

        # Equivalent to switching the model to evaluation mode
        self.model.train(False)

        source_imgs = []
        denoised_imgs = []
        clean_imgs = []

        # Create directory for denoised images
        # (a 'denoised' folder next to the parent directory of self.p.data)
        denoised_dir = os.path.dirname(self.p.data)
        save_path = os.path.join(denoised_dir, 'denoised')
        if not os.path.isdir(save_path):
            os.mkdir(save_path)

        for batch_idx, (source, target) in enumerate(test_loader):
            # Only do first <show> images
            if show == 0 or batch_idx >= show:
                break

            # The un-moved (CPU) source and target are what get stored
            source_imgs.append(source)
            clean_imgs.append(target)

            if self.use_cuda:
                source = source.cuda()

            # Denoise; detach() drops the autograd history of the output
            denoised_img = self.model(source).detach()
            denoised_imgs.append(denoised_img)

        # Squeeze tensors
        # (drops dim 0 when it has size 1 - presumably a batch size of 1; verify against the loader)
        source_imgs = [t.squeeze(0) for t in source_imgs]
        denoised_imgs = [t.squeeze(0) for t in denoised_imgs]
        clean_imgs = [t.squeeze(0) for t in clean_imgs]

        # Create montage and save images
        print('Saving images and montages to: {}'.format(save_path))
        for i in range(len(source_imgs)):
            # Name lookup by position: assumes the loader iterates the dataset in order - TODO confirm
            img_name = test_loader.dataset.imgs[i]
            create_montage(img_name, self.p.noise_type, save_path, source_imgs[i], denoised_imgs[i], clean_imgs[i], show)


In [36]:
# Hyper-parameters and paths for this run
params = Params()

# Cast the network to float64: BSDDataset returns DoubleTensor batches
model = Denoiser().double()
optim = Adam(model.parameters(), lr=params.learning_rate)
# Mean-squared error, matching params.loss == 'l2'
criterion = nn.MSELoss()

In [37]:
# Build the train / validation loaders from the paths and settings in params
train_loader, valid_loader = get_loaders(params)

In [34]:
# Run training; the recorded output below ends in a KeyboardInterrupt
train_model(model, criterion, optim, train_loader, valid_loader, params)

Epoch 1 / 10
Batch  1 [=>                      ] Train loss: 0.00000

KeyboardInterrupt: ignored

In [None]:
# NOTE(review): Noise2Noise is not defined in the visible part of this
# notebook - this cell presumably needs an external definition; confirm.
params = Params()
n2n = Noise2Noise(params, trainable=True)

Noise2Noise: Learning Image Restoration without Clean Data (Lethinen et al., 2018)


In [None]:
# NOTE(review): dloader_train / dloader_valid are not defined in the visible
# cells (the loaders built in this notebook are train_loader / valid_loader) - confirm.
n2n.train(dloader_train, dloader_valid)

Training parameters: 
  Train dir = /content/drive/My Drive/Colab Notebooks/Data/BSDS500/train/
  Valid dir = /content/drive/My Drive/Colab Notebooks/Data/BSDS500/val/
  Ckpt save path = /content/drive/My Drive/Colab Notebooks
  Nb epochs = 10
  Batch size = 4
  Learning rate = 0.001
  Adam = [0.9, 0.99, 1e-08]
  Loss = l2
  Noise type = gaussian
  Noise param = 50
  Crop size = 64
  Report interval = 4
  Plot stats = True
  Seed = 57
  Image bitdepth = 8
  Cuda = False
  Clean targets = False
  Ckpt overwrite = True

EPOCH 1 / 10
Batch  4 / 100 | Avg loss: 0.25787 | Avg train time / batch: 5309 ms
Batch  8 / 100 | Avg loss: 0.08295 | Avg train time / batch: 5124 ms
Batch 12 / 100 | Avg loss: 0.06237 | Avg train time / batch: 5118 ms
Batch 16 / 100 | Avg loss: 0.04230 | Avg train time / batch: 5104 ms
Batch 20 / 100 | Avg loss: 0.03580 | Avg train time / batch: 5075 ms
Batch 24 / 100 | Avg loss: 0.02924 | Avg train time / batch: 5112 ms
Batch 28 / 100 | Avg loss: 0.02631 | Avg train ti