## TO-DO 
Для начала поработаем с маленьким датасетом. Пусть будет 20 разных фотографий: 20 * 2 * 15 = 600 изображений суммарно на трейн (2 типа искажения — awgn и bayer, 15 значений sigma).
Каждая фотография предобрабатывается заранее; для увеличения датасета и борьбы с переобучением (overfit) добавлю random crop.
1. Для начала попробую просто взять первые train_size изображений из raw_images.

In [25]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os

from colour_demosaicing import (
    EXAMPLES_RESOURCES_DIRECTORY,
    demosaicing_CFA_Bayer_bilinear,
    demosaicing_CFA_Bayer_Malvar2004,
    demosaicing_CFA_Bayer_Menon2007,
    mosaicing_CFA_Bayer,
    masks_CFA_Bayer
)

In [26]:
class DataPreparation:
    """Generate noisy training images from a directory of source images.

    Two variants are produced for every source image and every noise level:

    * ``awgn``  -- the image is read as grayscale and additive white Gaussian
      noise is applied directly;
    * ``bayer`` -- the colour image is sampled with a GRBG Bayer mask, noise is
      added to the mosaic, the result is demosaiced (Menon 2007) and averaged
      over channels to a single-channel image.

    Noise levels are ``sigma = 1, 2, ..., max_sigma``.  Results are written to
    ``<data_dir>/awgn/`` and ``<data_dir>/bayer/`` as ``<sigma>_<image name>``.
    Generation runs directly from the constructor.
    """

    def __init__(self, max_sigma, data_size, in_dir='raw_images/', data_dir='data/'):
        """Store the configuration and immediately generate the dataset.

        Args:
            max_sigma: Largest noise standard deviation; must be an integer,
                since it is also used as the number of sigma levels.
            data_size: Maximum number of source images to process.
            in_dir: Directory containing the source images.
            data_dir: Root output directory.
        """
        self.max_sigma = max_sigma
        self.data_size = data_size
        self.in_dir = in_dir
        self.data_dir = data_dir

        self.make_data(in_dir, data_dir, data_size)
        print("Data has been prepared, check ", data_dir)
        print("Number of data: ", 2 * data_size * max_sigma)

    def _list_images(self, in_dir, size):
        """Return up to ``size`` non-hidden file names from ``in_dir``, sorted."""
        # Hidden entries (e.g. ".DS_Store") are dropped *before* slicing so they
        # do not use up the ``size`` budget; sorting makes the selection
        # reproducible, because ``os.listdir`` order is arbitrary.
        visible = sorted(
            name for name in os.listdir(in_dir) if not name.startswith('.')
        )
        return visible[:size]

    def make_data(self, in_dir, out_dir, size):
        """Create the ``awgn`` and ``bayer`` variants for every sigma level.

        Args:
            in_dir: Directory with the source images.
            out_dir: Root output directory; one sub-directory per mode is
                created inside it if missing.
            size: Maximum number of source images to use.
        """
        sigmas = np.linspace(1, self.max_sigma, self.max_sigma)
        image_names = self._list_images(in_dir, size)

        for mode in ['awgn', 'bayer']:
            save_dir = os.path.join(out_dir, mode)
            # cv2.imwrite does not create missing directories; it just fails.
            os.makedirs(save_dir, exist_ok=True)
            # The awgn mode works on grayscale input, the bayer mode on colour.
            read_flag = cv2.IMREAD_GRAYSCALE if mode == 'awgn' else cv2.IMREAD_COLOR

            for i, image_name in enumerate(image_names):
                image_path = os.path.join(in_dir, image_name)
                # Read once per image: the source does not depend on sigma.
                image = cv2.imread(image_path, read_flag)
                if image is None:
                    # cv2.imread signals failure with None instead of raising.
                    print(f"Skipping unreadable image: {image_path}")
                    continue

                for sigma in sigmas:
                    if mode == 'awgn':
                        new_image = self.get_awgn_image(image, sigma)
                    else:
                        new_image = self.mosaic_awgn_demosaic(image, sigma)
                    new_image_path = os.path.join(save_dir, f'{sigma}_{image_name}')
                    if not self.save_image(new_image_path, new_image):
                        print(f"Failed to write image: {new_image_path}")
                print(f"{mode}: {i} / {size}")

    def get_rgb_masks(self, shape):
        """Return ``(r, g, b)`` binary masks of the given 2-D shape (GRBG layout).

        Each 2x2 tile is laid out as::

            G R
            B G
        """
        g = np.zeros(shape)
        g[::2, ::2] = 1
        g[1::2, 1::2] = 1

        b = np.zeros(shape)
        b[1::2, ::2] = 1

        r = np.zeros(shape)
        r[::2, 1::2] = 1

        return r, g, b

    def mosaic(self, image):
        """Sample a 3-channel image through the GRBG masks into one plane.

        Args:
            image: ``(H, W, 3)`` array; channels are taken in BGR order
                (index 0 = blue, 1 = green, 2 = red).

        Returns:
            ``(H, W)`` float array holding, at every pixel, the value of the
            single channel selected by the mask at that position.
        """
        h, w = image.shape[0], image.shape[1]
        # Must go through ``self``: there is no module-level get_rgb_masks.
        r_mask, g_mask, b_mask = self.get_rgb_masks((h, w))
        blue, green, red = image[:, :, 0], image[:, :, 1], image[:, :, 2]
        return blue * b_mask + green * g_mask + red * r_mask

    def mosaic_awgn_demosaic(self, image, sigma):
        """Mosaic ``image``, add Gaussian noise, demosaic and convert to gray.

        Args:
            image: ``(H, W, 3)`` colour image, channels in BGR order.
            sigma: Standard deviation of the noise added to the mosaic.

        Returns:
            ``(H, W)`` float array: per-pixel mean over the three demosaiced
            channels.
        """
        mosaic_im = self.mosaic(image)
        noisy_mosaic_im = self.get_awgn_image(mosaic_im, sigma)
        demosaic_noisy_im = demosaicing_CFA_Bayer_Menon2007(noisy_mosaic_im, 'GRBG')
        # Reverse the channel axis (RGB -> BGR) to match the OpenCV convention.
        bgr_im = demosaic_noisy_im[:, :, [2, 1, 0]]
        # Grayscale as the plain (unweighted) mean of the channels.
        gray = np.mean(bgr_im, axis=2)
        return gray

    def save_image(self, path, image):
        """Resize ``image`` to 512x512 and write it to ``path``.

        Returns:
            The value returned by ``cv2.imwrite`` (falsy if the write failed).
        """
        resized_im = cv2.resize(image, (512, 512))
        return cv2.imwrite(path, resized_im)

    def get_awgn_image(self, image, scale, loc=0.0):
        """Add Gaussian noise to ``image`` and return it as ``uint8``.

        Args:
            image: Array of any shape with values on the 0..255 scale.
            scale: Standard deviation of the noise.
            loc: Mean of the noise.

        Returns:
            Array of the same shape, clipped to ``[0, 255]`` and cast to
            ``uint8`` (the cast truncates the fractional part).
        """
        noise = np.random.normal(loc=loc, scale=scale, size=image.shape)
        noisy_image = np.uint8(np.clip(image + noise, 0, 255))
        return noisy_image

In [27]:
# Constructing the object runs the whole generation pipeline: it reads from the
# default `raw_images/` directory and writes under `data/`, using noise levels
# sigma = 1..15 for up to 30 source entries.
dataprep = DataPreparation(max_sigma=15, data_size=30)

awgn: 0 / 30
awgn: 1 / 30
awgn: 2 / 30
awgn: 3 / 30
awgn: 4 / 30
awgn: 5 / 30
awgn: 6 / 30
awgn: 7 / 30
awgn: 8 / 30
awgn: 9 / 30
awgn: 10 / 30
awgn: 11 / 30
awgn: 12 / 30
awgn: 13 / 30
awgn: 14 / 30
awgn: 15 / 30
awgn: 16 / 30
awgn: 17 / 30
awgn: 18 / 30
awgn: 19 / 30
awgn: 20 / 30
awgn: 21 / 30
awgn: 22 / 30
awgn: 23 / 30
awgn: 24 / 30
awgn: 25 / 30
awgn: 26 / 30
awgn: 27 / 30
awgn: 28 / 30
awgn: 29 / 30
bayer: 0 / 30
bayer: 1 / 30
bayer: 2 / 30
bayer: 3 / 30
bayer: 4 / 30
bayer: 5 / 30
bayer: 6 / 30
bayer: 7 / 30
bayer: 8 / 30
bayer: 9 / 30
bayer: 10 / 30


KeyboardInterrupt: 