In [45]:
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, Sampler
from skimage.util import view_as_windows
from utils import load_image
from transforms import ToTensor
import cv2
import yaml
from os.path import join
from transforms_test import RandomHorizontalFlip, RandomVerticalFlip, RandomRot90
from torchvision.transforms import transforms
import matplotlib.pyplot as plt 

# Helper functions

In [46]:
def create_patches(image, patch_size, step):
    """
    Cut an image into patches and stack them along a new leading axis.

    :param image: array handed to ``view_as_windows`` (presumably (Height, Width, Channels) - confirm with caller).
    :param patch_size: 3-element window shape; its entries become the last three output dimensions.
    :param step: window stride forwarded unchanged to ``view_as_windows``.
    :return: array of shape (rows * cols, patch_size[0], patch_size[1], patch_size[2]), where rows and cols are
             the first two dimensions of the windowed view.
    """
    windows = view_as_windows(image, patch_size, step)
    rows, cols = windows.shape[:2]
    patch_h, patch_w, patch_c = patch_size[0], patch_size[1], patch_size[2]
    # Collapse the (rows, cols) window grid into a single "patch index" axis.
    return np.reshape(windows, (rows * cols, patch_h, patch_w, patch_c))


In [78]:
class AdditiveWhiteGaussianNoise(object):
    """
    Additive white gaussian noise generator.

    Every random draw (noise levels and noise values) comes from a private ``RandomState`` seeded with 1, so
    two freshly built generators produce the same noise regardless of the global numpy random state.

    :param noise_level: noise standard deviation when ``fix_sigma`` is True, otherwise the largest standard
                        deviation that may be drawn (must be an integer >= 5 in that case).
    :param fix_sigma: if True, always use ``noise_level`` as the standard deviation.
    :param clip: if True, clip the noisy result to [0, 255] (assumes images use a 0-255 range - confirm).
    """
    def __init__(self, noise_level, fix_sigma=False, clip=False):
        self.noise_level = noise_level
        self.fix_sigma = fix_sigma
        self.rand = np.random.RandomState(1)
        self.clip = clip
        if not fix_sigma:
            # Candidate per-image levels for batches: 5, 10, ..., up to and including noise_level.
            self.predefined_noise = list(range(5, noise_level + 1, 5))

    def __call__(self, sample):
        """
        Generates additive white gaussian noise, and it is applied to the clean image.

        This method only computes the noisy data; it does not display anything. Plot the returned
        ``'noisy'`` array from the calling code if a visual check is needed.

        :param sample: dict whose ``'image'`` entry is the clean image (3 dimensions) or a batch of clean
                       images (4 dimensions, samples first).
        :return: dict with the untouched clean data under ``'image'`` and the float32 noisy data under ``'noisy'``.
        """
        image = sample.get('image')

        if image.ndim == 4:                 # batch of images: a different noise level per image
            samples = image.shape[0]        # (Samples, Height, Width, Channels) or (Samples, Channels, Height, Width)
            if self.fix_sigma:
                sigma = self.noise_level * np.ones((samples, 1, 1, 1))
            else:
                # Draw from the seeded generator (not the global numpy RNG) so results are reproducible.
                sigma = self.rand.choice(self.predefined_noise, size=(samples, 1, 1, 1))
            # Unit-variance noise scaled per image; sigma broadcasts over the three trailing axes.
            noise = self.rand.normal(0., 1., size=image.shape) * sigma
        else:                               # single image
            if self.fix_sigma:              # (Height, Width, Channels) or (Channels , Height, Width)
                sigma = self.noise_level
            else:
                # randint excludes its upper bound, hence the +1 to make noise_level reachable
                # (and to keep noise_level == 5 from raising on an empty range).
                sigma = self.rand.randint(5, self.noise_level + 1)
            noise = self.rand.normal(0., sigma, size=image.shape)

        noisy = image + noise

        if self.clip:
            noisy = np.clip(noisy, 0., 255.)

        return {'image': image, 'noisy': noisy.astype('float32')}

In [79]:
class NoisyImagesDataset(Dataset):
    """
    In-memory dataset of (noisy patch, clean patch) pairs.

    Every file is loaded, cut into non-overlapping square patches and passed through each noise transform;
    all resulting patches are kept in ``self.dataset`` for the lifetime of the object.

    :param files: iterable of image paths handed to ``load_image``.
    :param channels: number of image channels (forwarded to ``load_image`` and used as the patch depth).
    :param patch_size: side length of the square patches; also used as the patch stride.
    :param transform: optional callable applied to each ``{'image', 'noisy'}`` sample in ``__getitem__``.
    :param noise_transform: noise generator(s) applied while loading. Accepts an iterable of callables, a
                            single callable, or None. With None (or an empty iterable) no pairs are
                            produced and the dataset is empty.
    """
    def __init__(self, files, channels, patch_size, transform=None, noise_transform=None):
        self.channels = channels
        self.patch_size = patch_size
        self.transform = transform
        self.noise_transforms = self._normalize_noise_transforms(noise_transform)
        self.to_tensor = ToTensor()
        self.dataset = {'image': [], 'noisy': []}
        self.load_dataset(files)

    @staticmethod
    def _normalize_noise_transforms(noise_transform):
        """Return the noise transforms as a list, whatever form the caller supplied them in."""
        if noise_transform is None:
            # The signature's default must not crash the loading loop.
            return []
        if callable(noise_transform):
            return [noise_transform]
        # Materialise once so the transforms can be re-iterated for every file.
        return list(noise_transform)

    def __len__(self):
        return len(self.dataset['image'])

    def __getitem__(self, idx):
        """
        :param idx: index of the patch pair.
        :return: tuple (noisy, image) after the optional transform and the ``ToTensor`` conversion.
        """
        sample = {'image': self.dataset['image'][idx], 'noisy': self.dataset['noisy'][idx]}
        if self.transform is not None:
            sample = self.transform(sample)
        sample = self.to_tensor(sample)

        return sample.get('noisy'), sample.get('image')

    def load_dataset(self, files):
        """
        Load every file, split it into patches and store one noisy copy per noise transform.

        Files for which ``load_image`` returns None are skipped.

        :param files: iterable of image paths.
        """
        patch_size = (self.patch_size, self.patch_size, self.channels)
        for file in tqdm(files):
            image = load_image(file, self.channels)
            if image is None:
                continue

            # step == patch size, so the patches do not overlap.
            patches = create_patches(image, patch_size, step=self.patch_size)
            sample = {'image': patches, 'noisy': None}

            for noise_transform in self.noise_transforms:
                noisy_sample = noise_transform(sample)
                # extend() iterates the leading (patch) axis, storing one entry per patch.
                self.dataset['image'].extend(noisy_sample['image'])
                self.dataset['noisy'].extend(noisy_sample['noisy'])


# main

In [None]:
# Read the experiment settings from the YAML configuration file.
with open('config.yaml', 'r') as stream:
    config = yaml.safe_load(stream)

model_params = config['model']
train_params = config['train']
val_params = config['val']

# Each list file holds one image file name per line.
with open('train_test.txt', 'r') as f_train, open('val_files.txt', 'r') as f_val:
    raw_train_files = f_train.read().splitlines()
    raw_val_files = f_val.read().splitlines()

# Prefix every file name with the dataset directory configured for its split.
train_files = [join(train_params['dataset path'], file) for file in raw_train_files]
val_files = [join(val_params['dataset path'], file) for file in raw_val_files]

# Augmentations applied to each sample when it is fetched from the dataset.
training_transforms = transforms.Compose([
    RandomHorizontalFlip(),
    RandomVerticalFlip(),
    RandomRot90(),
])

# Noise is generated once, while the dataset is being loaded.
train_noise_transform = [AdditiveWhiteGaussianNoise(train_params['noise level'], clip=True)]

training_dataset = NoisyImagesDataset(
    files=train_files,
    channels=model_params['channels'],
    patch_size=train_params['patch size'],
    transform=training_transforms,
    noise_transform=train_noise_transform,
)

In [None]:
# from torchvision.transforms import ToPILImage
# import numpy as np
# import matplotlib.pyplot as plt 
# # Assuming you have the tensor as `image_tensor`
# for i in range(117):
#     img_test = training_dataset[i][1]*255
#     image_np = img_test.detach().numpy()
#     image_np = np.transpose(image_np, (1, 2, 0)) # to convert from (3, 64, 64) to (64, 64, 3)
#     image_np_uint8 = (255*image_np).astype(np.uint8)

#     # image_pil = ToPILImage()(image_np_uint8)
#     # image_pil.show()
#     print(type(image_np_uint8))
#     plt.imshow(image_np_uint8)
#     plt.show()




# test lightweight model 

In [3]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from torch.optim import Adam
from torchsummary import summary

In [7]:
class Block(nn.Module):
    """
    Two bias-free 3x3 convolutions (padding 1), each followed by a ReLU.

    The padding keeps the spatial size unchanged; only the channel count goes from ``in_ch`` to ``out_ch``.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_ch, out_ch, **conv_kwargs)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, **conv_kwargs)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU and return the activated feature map."""
        hidden = self.relu(self.conv1(x))
        return self.relu(self.conv2(hidden))

class Encoder(nn.Module):
    """
    Contracting path: a chain of ``Block`` stages with a 2x2 max-pool between consecutive stages.

    :param chs: channel counts; stage ``i`` maps ``chs[i]`` channels to ``chs[i+1]`` channels.
    """
    def __init__(self, chs=(3,64,128,256)):
        super().__init__()
        self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        self.pool       = nn.MaxPool2d(2)

    def forward(self, x):
        """
        :param x: input feature map.
        :return: list with the output of every stage, shallowest first; each stage after the first
                 works on the max-pooled output of the previous one.
        """
        ftrs = []
        for stage, block in enumerate(self.enc_blocks):
            if stage > 0:
                # Pool only between stages: pooling the deepest output would be discarded anyway.
                x = self.pool(x)
            x = block(x)
            ftrs.append(x)
        return ftrs

class Decoder(nn.Module):
    """
    Expanding path: each stage upsamples with a stride-2 transposed convolution, concatenates the matching
    encoder feature map along the channel axis and refines the result with a ``Block``.

    :param chs: channel counts; stage ``i`` upsamples ``chs[i]`` -> ``chs[i+1]`` channels and its ``Block``
                expects ``chs[i]`` input channels, so the concatenated tensor must have ``chs[i]`` channels.
    """
    def __init__(self, chs=(256, 128, 64)):
        super().__init__()
        self.chs = chs
        stage_chs = list(zip(chs[:-1], chs[1:]))
        self.upconvs = nn.ModuleList([nn.ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in stage_chs])
        self.dec_blocks = nn.ModuleList([Block(c_in, c_out) for c_in, c_out in stage_chs])

    def forward(self, x, encoder_features):
        """
        :param x: deepest feature map.
        :param encoder_features: skip connections ordered deepest first, one per decoder stage.
        :return: feature map produced by the last decoder stage.
        """
        for stage, (upconv, dec_block) in enumerate(zip(self.upconvs, self.dec_blocks)):
            upsampled = upconv(x)
            skip = encoder_features[stage]
            # The skip tensor must have the same spatial size as the upsampled one for the concat to work.
            x = dec_block(torch.cat([upsampled, skip], dim=1))
        return x


class UNet(nn.Module):
    '''
    Simple UNet-like model.

    Input: batch of images shaped (N, enc_chs[0], H, W). There is no cropping or padding of the skip
           connections, so H and W must survive the encoder's poolings and the decoder's upsamplings
           unchanged (e.g. 504 with the default three stages).

    Output: tensor shaped (N, out_ch, *out_sz) with values clamped to [0, 1]. When out_sz is None the
            resize is skipped and the output keeps the decoder's own spatial size.

    :param enc_chs: channel counts of the encoder stages.
    :param dec_chs: channel counts of the decoder stages.
    :param out_ch: number of output channels produced by the 1x1 head convolution.
    :param out_sz: (height, width) the output is resized to, or None to disable the resize.
    '''
    def __init__(self, enc_chs=(3,64,128,256), dec_chs=(256, 128, 64), out_ch=3, out_sz=(504, 504)):
        super().__init__()
        self.encoder     = Encoder(enc_chs)
        self.decoder     = Decoder(dec_chs)
        self.head        = nn.Conv2d(dec_chs[-1], out_ch, 1)
        self.out_sz      = out_sz

    def forward(self, x):
        enc_ftrs = self.encoder(x)
        # The decoder starts from the deepest feature map and consumes the rest deepest-first as skips.
        deepest, skips = enc_ftrs[-1], enc_ftrs[::-1][1:]
        out = self.head(self.decoder(deepest, skips))
        if self.out_sz is not None:
            out = F.interpolate(out, self.out_sz)
        return torch.clamp(out, min=0., max=1.)

In [9]:
# Smoke test: run an all-zero 504x504 RGB batch through a default UNet, then show the
# per-layer parameter summary.
model = UNet()
xx = torch.zeros(1, 3, 504, 504)
y = model(xx)
summary(model, input_size=(3, 504, 504))

Layer (type:depth-idx)                   Param #
├─Encoder: 1-1                           --
|    └─ModuleList: 2-1                   --
|    |    └─Block: 3-1                   38,592
|    |    └─Block: 3-2                   221,184
|    |    └─Block: 3-3                   884,736
|    └─MaxPool2d: 2-2                    --
├─Decoder: 1-2                           --
|    └─ModuleList: 2-3                   --
|    |    └─ConvTranspose2d: 3-4         131,200
|    |    └─ConvTranspose2d: 3-5         32,832
|    └─ModuleList: 2-4                   --
|    |    └─Block: 3-6                   442,368
|    |    └─Block: 3-7                   110,592
├─Conv2d: 1-3                            195
Total params: 1,861,699
Trainable params: 1,861,699
Non-trainable params: 0


Layer (type:depth-idx)                   Param #
├─Encoder: 1-1                           --
|    └─ModuleList: 2-1                   --
|    |    └─Block: 3-1                   38,592
|    |    └─Block: 3-2                   221,184
|    |    └─Block: 3-3                   884,736
|    └─MaxPool2d: 2-2                    --
├─Decoder: 1-2                           --
|    └─ModuleList: 2-3                   --
|    |    └─ConvTranspose2d: 3-4         131,200
|    |    └─ConvTranspose2d: 3-5         32,832
|    └─ModuleList: 2-4                   --
|    |    └─Block: 3-6                   442,368
|    |    └─Block: 3-7                   110,592
├─Conv2d: 1-3                            195
Total params: 1,861,699
Trainable params: 1,861,699
Non-trainable params: 0