## TO-DO 
Для начала поработаем с маленьким датасетом. Пусть будет 20 разных фотографий: 20 фотографий × 2 режима (awgn и bayer) × 15 значений sigma = 600 изображений суммарно на трейн.
Каждая фотография предобрабатывается заранее; чтобы увеличить датасет и снизить риск переобучения (overfitting), добавлю random crop.
1. Сначала попробую самый простой вариант: взять первые train_size изображений из raw_images.

In [1]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os

from colour_demosaicing import (
    EXAMPLES_RESOURCES_DIRECTORY,
    demosaicing_CFA_Bayer_bilinear,
    demosaicing_CFA_Bayer_Malvar2004,
    demosaicing_CFA_Bayer_Menon2007,
    mosaicing_CFA_Bayer,
    masks_CFA_Bayer
)

In [8]:
class DataPreparation:
    """Generate noisy copies of raw images for denoising experiments.

    Two degradation modes are produced for every selected source image and
    every noise level ``sigma`` in ``1..max_sigma``:

    * ``awgn``  - the image is read as grayscale and Gaussian noise is added;
    * ``bayer`` - the colour image is sampled with a GRBG Bayer mask, noise is
      added to the mosaic, the mosaic is demosaiced (Menon 2007) and the
      channels are averaged into one grayscale plane.

    Every result is resized to 256x256 and written to
    ``<out_dir>/<mode>/<sigma>_<image_name>``.

    Constructing the object immediately generates the validation split
    (train generation is currently switched off in ``__init__``).

    Args:
        max_sigma: largest noise standard deviation; also the number of
            noise levels generated (1, 2, ..., max_sigma).
        train_size: number of source images reserved for training.
        val_size: number of source images used for validation.
        in_dir: directory with the source images.
        data_dir: root directory for the generated data.
    """

    _MODES = ('awgn', 'bayer')

    def __init__(self, max_sigma, train_size, val_size, in_dir='raw_images/', data_dir='data/'):
        self.max_sigma = max_sigma
        self.train_size = train_size
        self.in_dir = in_dir
        self.data_dir = data_dir
        self.val_size = val_size

        # NOTE: train generation is disabled; to re-enable it call
        # self.make_data(in_dir, os.path.join(data_dir, "train"), self.train_size)
        self.make_val_data(in_dir, os.path.join(data_dir, "val"), self.val_size)
        print("Data has been prepared, check ", data_dir)

    @staticmethod
    def _list_images(in_dir, start, count):
        """Return ``count`` visible file names from ``in_dir`` starting at ``start``.

        ``os.listdir`` returns entries in arbitrary order, so the names are
        sorted to make the train/validation split reproducible. Hidden files
        (leading dot) are removed *before* slicing so that they do not eat
        into the requested number of images.
        """
        names = sorted(
            name for name in os.listdir(in_dir) if not name.startswith('.')
        )
        return names[start:start + count]

    def _generate(self, in_dir, out_dir, image_names):
        """Write every mode/sigma variant of ``image_names`` below ``out_dir``.

        Raises:
            ValueError: if a source file cannot be decoded as an image.
            OSError: if an output image cannot be written.
        """
        # linspace yields floats, so file names look like "1.0_<name>".
        sigmas = np.linspace(1, self.max_sigma, self.max_sigma)
        total = len(image_names) * len(sigmas)

        for mode in self._MODES:
            save_dir = os.path.join(out_dir, mode)
            # cv2.imwrite does not create missing directories.
            os.makedirs(save_dir, exist_ok=True)

            counter = 0
            for image_name in image_names:
                image_path = os.path.join(in_dir, image_name)
                # Read once per image: the degradation functions do not
                # modify their input, so it can be reused for every sigma.
                if mode == 'awgn':
                    image = cv2.imread(image_path, 0)  # grayscale
                else:
                    image = cv2.imread(image_path)  # BGR colour
                if image is None:
                    raise ValueError(f"Cannot read image: {image_path}")

                for sigma in sigmas:
                    if mode == 'awgn':
                        new_image = self.get_awgn_image(image, sigma)
                    else:
                        new_image = self.mosaic_awgn_demosaic(image, sigma)

                    new_image_path = os.path.join(save_dir, f'{sigma}_{image_name}')
                    if not self.save_image(new_image_path, new_image):
                        raise OSError(f"Cannot write image: {new_image_path}")
                    print(f"{mode}: {counter} / {total}")
                    counter += 1

    def make_data(self, in_dir, out_dir, size):
        """Generate the training split from the first ``size`` images of ``in_dir``."""
        print("============================ TRAIN ==========================")
        self._generate(in_dir, out_dir, self._list_images(in_dir, 0, size))

    def make_val_data(self, in_dir, out_dir, size):
        """Generate the validation split from the ``size`` images that follow
        the first ``self.train_size`` images of ``in_dir``."""
        print("============================ VALIDATION ==========================")
        val_images = self._list_images(in_dir, self.train_size, size)
        self._generate(in_dir, out_dir, val_images)

    def get_rgb_masks(self, shape):
        """Return the ``(r, g, b)`` 0/1 sampling masks of a GRBG Bayer pattern.

        Layout of every 2x2 cell::

            G R
            B G
        """
        g = np.zeros(shape)
        g[::2, ::2] = 1
        g[1::2, 1::2] = 1

        b = np.zeros(shape)
        b[1::2, ::2] = 1

        r = np.zeros(shape)
        r[::2, 1::2] = 1

        return r, g, b

    def mosaic(self, image):
        """Sample a BGR image (channel order used by OpenCV) with the GRBG
        masks and return the single-plane mosaic."""
        h, w = image.shape[0], image.shape[1]
        r_mask, g_mask, b_mask = self.get_rgb_masks((h, w))
        blue, green, red = image[:, :, 0], image[:, :, 1], image[:, :, 2]
        return blue * b_mask + green * g_mask + red * r_mask

    def mosaic_awgn_demosaic(self, image, sigma):
        """Mosaic a BGR image, add Gaussian noise of std ``sigma`` to the
        mosaic, demosaic it and return the channel-averaged grayscale plane."""
        mosaic_im = self.mosaic(image)
        noisy_mosaic_im = self.get_awgn_image(mosaic_im, sigma)
        demosaic_noisy_im = demosaicing_CFA_Bayer_Menon2007(noisy_mosaic_im, 'GRBG')
        # The demosaiced result is RGB; reorder to BGR as in the OpenCV convention.
        bgr_im = demosaic_noisy_im[:, :, [2, 1, 0]]
        # Grayscale as the plain mean of the three channels.
        return np.mean(bgr_im, axis=2)

    def save_image(self, path, image):
        """Resize ``image`` to 256x256, write it to ``path`` and return the
        success flag reported by ``cv2.imwrite``."""
        resized_im = cv2.resize(image, (256, 256))
        return cv2.imwrite(path, resized_im)

    def get_awgn_image(self, image, scale, loc=0.0):
        """Return ``image`` with additive Gaussian noise as a ``uint8`` array.

        Args:
            image: input array; it is not modified.
            scale: standard deviation of the noise.
            loc: mean of the noise.

        The noisy values are clipped to ``[0, 255]`` before the cast to
        ``uint8`` (the cast drops the fractional part).
        """
        noise = np.random.normal(loc=loc, scale=scale, size=image.shape)
        return np.uint8(np.clip(image + noise, 0, 255))

In [9]:
%%time
# Constructing DataPreparation starts the generation right away from __init__
# (only the validation call is active there), so this cell writes the noisy
# validation images for 15 noise levels under data/val/.
dataprep = DataPreparation(max_sigma=15, train_size=30, val_size=10)

awgn: 0 / 150
awgn: 1 / 150
awgn: 2 / 150
awgn: 3 / 150
awgn: 4 / 150
awgn: 5 / 150
awgn: 6 / 150
awgn: 7 / 150
awgn: 8 / 150
awgn: 9 / 150
awgn: 10 / 150
awgn: 11 / 150
awgn: 12 / 150
awgn: 13 / 150
awgn: 14 / 150
awgn: 15 / 150
awgn: 16 / 150
awgn: 17 / 150
awgn: 18 / 150
awgn: 19 / 150
awgn: 20 / 150
awgn: 21 / 150
awgn: 22 / 150
awgn: 23 / 150
awgn: 24 / 150
awgn: 25 / 150
awgn: 26 / 150
awgn: 27 / 150
awgn: 28 / 150
awgn: 29 / 150
awgn: 30 / 150
awgn: 31 / 150
awgn: 32 / 150
awgn: 33 / 150
awgn: 34 / 150
awgn: 35 / 150
awgn: 36 / 150
awgn: 37 / 150
awgn: 38 / 150
awgn: 39 / 150
awgn: 40 / 150
awgn: 41 / 150
awgn: 42 / 150
awgn: 43 / 150
awgn: 44 / 150
awgn: 45 / 150
awgn: 46 / 150
awgn: 47 / 150
awgn: 48 / 150
awgn: 49 / 150
awgn: 50 / 150
awgn: 51 / 150
awgn: 52 / 150
awgn: 53 / 150
awgn: 54 / 150
awgn: 55 / 150
awgn: 56 / 150
awgn: 57 / 150
awgn: 58 / 150
awgn: 59 / 150
awgn: 60 / 150
awgn: 61 / 150
awgn: 62 / 150
awgn: 63 / 150
awgn: 64 / 150
awgn: 65 / 150
awgn: 66 / 150
awgn: