In [1]:
!pip install torch torchvision



In [4]:
import os, imageio, numpy as np
import matplotlib.pyplot as plt
from google.colab import drive
%matplotlib inline

In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as tvF
from torch.utils.data import Dataset, DataLoader

In [36]:
# Mount Google Drive, which holds the BSDS500 image folders.
drive.mount("/content/drive", force_remount=True)

# Train / validation image folders on the mounted drive.
root_train = '/content/drive/My Drive/Colab Notebooks/Data/BSDS500/train/'
root_valid = '/content/drive/My Drive/Colab Notebooks/Data/BSDS500/val/'

# Sanity check: report how many files each split folder contains.
for split_name, split_dir in (('Train', root_train), ('Valid', root_valid)):
    print(f'{split_name} images: ', len(os.listdir(split_dir)))

Mounted at /content/drive
Train images:  400
Valid images:  100


In [71]:
class BSDDataset(Dataset):
    """Denoising dataset over a directory of images (e.g. BSDS500).

    Each item is a (noisy, clean) pair of torch tensors: ``clean`` is a
    random square crop of one image (or the whole image when cropping is
    disabled) and ``noisy`` is that crop with synthetic noise added.
    """

    def __init__(self,
                 root_dir,
                 crop_size=128,
                 train_noise_model=('gaussian', 50),
                 img_bitdepth=8):
        """
        root_dir: Path of the image directory. Every entry returned by
            os.listdir is treated as an image file.
        crop_size: Side length of the square random crop. A value <= 0
            disables cropping and the full image is used.
        train_noise_model: (noise_type, param) pair. Only 'gaussian' is
            supported; param is the maximum noise std in raw pixel units
            (e.g. 0-255 for 8-bit images).
        img_bitdepth: Bit depth of the source images; pixel values are
            scaled by 1 / (2**img_bitdepth - 1).
        """
        self.root_dir = root_dir
        self.crop_size = crop_size
        self.img_bitdepth = img_bitdepth
        self.noise = train_noise_model[0]
        # Noise parameter rescaled to the same range as the scaled pixels.
        self.noise_param = train_noise_model[1] / (2**self.img_bitdepth - 1)
        # Sorted so a given index always maps to the same file; os.listdir
        # order is arbitrary.
        self.imgs = sorted(os.listdir(root_dir))

    def __len__(self):
        """Number of directory entries found in root_dir."""
        return len(self.imgs)

    def _random_crop_to_size(self, img):
        """Return a random crop_size x crop_size window of an (H, W, ...) array."""
        h, w = img.shape[:2]
        assert w >= self.crop_size and h >= self.crop_size, 'Cannot be croppped. Invalid size'

        # randint's upper bound is exclusive: the last valid top-left corner
        # is (h - crop_size, w - crop_size), so the bound is that value + 1.
        i = np.random.randint(0, h - self.crop_size + 1)
        j = np.random.randint(0, w - self.crop_size + 1)

        return img[i:i + self.crop_size, j:j + self.crop_size, ...]

    def _add_gaussian_noise(self, image):
        """Add zero-mean Gaussian noise and clip the result to [0, 1].

        The noise std is drawn uniformly from [0, self.noise_param) on every
        call, so each sample gets a different noise level.
        """
        std = np.random.uniform(0, self.noise_param)
        noise = np.random.normal(0, std, image.shape)
        return np.clip(image + noise, 0, 1)

    def corrupt_image(self, image):
        """Apply the configured noise model to ``image``.

        Raises:
            ValueError: if the configured noise type is not supported.
        """
        if self.noise == 'gaussian':
            return self._add_gaussian_noise(image)
        raise ValueError('No such image corruption supported')

    def _read_image(self, img_path):
        """Read one image file and scale its pixels by 1 / (2**img_bitdepth - 1)."""
        return imageio.imread(img_path) / (2**self.img_bitdepth - 1)

    def __getitem__(self, index):
        """
        Read an image, optionally crop it, corrupt it and return
        (noisy, clean) as torch tensors in the array's original layout
        (H, W, C for colour images; float64 for integer-valued sources).

        NOTE(review): most torch conv layers expect (C, H, W) float32 input;
        that conversion is presumably left to the caller - confirm.
        """
        img_path = os.path.join(self.root_dir, self.imgs[index])
        image = self._read_image(img_path)

        # Without cropping, the full image is the clean target.
        image_clean = image
        if self.crop_size > 0:
            image_clean = self._random_crop_to_size(image)

        image_noisy = self.corrupt_image(image_clean)

        # Convert to tensors; np.array copies the (possibly strided) crop view.
        image_clean = torch.from_numpy(np.array(image_clean))
        image_noisy = torch.from_numpy(np.array(image_noisy))

        return image_noisy, image_clean

In [72]:
# Shared dataset configuration for both splits.
CROP_SIZE = 128
NOISE_MODEL = ('gaussian', 50)  # (noise type, max noise std in raw pixel units)
IMG_BITDEPTH = 8

# Declare training / validation datasets.
dataset_train = BSDDataset(root_train, crop_size=CROP_SIZE, train_noise_model=NOISE_MODEL, img_bitdepth=IMG_BITDEPTH)
dataset_valid = BSDDataset(root_valid, crop_size=CROP_SIZE, train_noise_model=NOISE_MODEL, img_bitdepth=IMG_BITDEPTH)

# Declare training / validation data loaders. Only training data is shuffled;
# validation is iterated in index order rather than reshuffled every epoch.
# NOTE(review): validation crops and noise are still drawn randomly per access.
dloader_train = DataLoader(dataset_train, batch_size=1, shuffle=True)
dloader_valid = DataLoader(dataset_valid, batch_size=1, shuffle=False)