In [1]:
import numpy as np
import pandas as pd
import torch
from torch import nn
import os
import glob
from skimage import io, color, img_as_float

from image_py_scripts import run_image_inpainting

import torch.utils.data as data
from pathlib import Path

In [2]:
ROOT = Path('C:/Users/Talha/OneDrive - Higher Education Commission/Documents/GitHub/convmc-net/Image_Inpainting_Data/BSDS300/images')
train_dir = ROOT / 'train'
test_dir = ROOT / 'test'

ground_truth_train_dir = train_dir / 'groundtruth'
lowrank_train_dir = train_dir / 'lowrank'

ground_truth_test_dir = test_dir / 'groundtruth'
lowrank_test_dir = test_dir / 'lowrank'

In [3]:
def make_imgs(split = 'train', shape = (150, 300), sampling_rate = 0.2, dB = 5.0):
    if split == 'train':
        jpg_files = list(train_dir.glob('*.jpg'))
        for idx, img in enumerate(jpg_files):
            image = io.imread(img)
            image = color.rgb2gray(img_as_float(image))
            image = np.resize(image, shape)

            np.save(os.path.join(ground_truth_train_dir, f'ground_image_MC_train_{idx}.npy'), image)
            
            image_lowrank = run_image_inpainting.add_gmm_noise(image = image, per = sampling_rate, dB = dB)
            np.save(os.path.join(lowrank_train_dir, f'lowrank_image_MC_train_{idx}.npy'), image_lowrank)
    
    else:
        jpg_files = list(test_dir.glob('*.jpg'))
        for idx, img in enumerate(jpg_files):
            image = io.imread(img)
            image = color.rgb2gray(img_as_float(image))
            image = np.resize(image, shape)

            np.save(os.path.join(ground_truth_test_dir, f'ground_image_MC_test_{idx}.npy'), image)
            
            image_lowrank = run_image_inpainting.add_gmm_noise(image = image, per = sampling_rate, dB = dB)
            np.save(os.path.join(lowrank_test_dir, f'lowrank_image_MC_test_{idx}.npy'), image_lowrank)

"""
Example Usage: make_imgs(split = 'train', shape = (150, 300), sampling_rate = 0.2, dB = 5.0)
"""

In [19]:
np.load(os.path.join(train_dir, f'groundtruth/ground_image_MC_train_' + str(0) + '.npy'))

array([[0.44320627, 0.46055922, 0.46194314, ..., 0.37857922, 0.40463137,
        0.40379804],
       [0.43037804, 0.4272898 , 0.39642235, ..., 0.61461569, 0.56952902,
        0.56168588],
       [0.59108627, 0.67061882, 0.62327725, ..., 0.61532118, 0.64500431,
        0.66770039],
       ...,
       [0.33893294, 0.34285451, 0.34901608, ..., 0.33698353, 0.33249647,
        0.33363529],
       [0.34483451, 0.36052078, 0.37285098, ..., 0.35339098, 0.34946941,
        0.34664902],
       [0.34969961, 0.34549529, 0.33848549, ..., 0.37052078, 0.35875608,
        0.35483451]])

In [20]:
class ImageDataset(data.Dataset):
    def __init__(self, shape, split, path, transform = None):
        self.shape = shape
        
        # TRAIN
        if split == 0:
            # dummy image loader
            images_L = torch.zeros(tuple([200]) + self.shape) # --> shape: (200, shape)
            images_D = torch.zeros(tuple([200]) + self.shape) # --> shape: (200, shape)
            for n in range(200):
                L = np.load(os.path.join(path, f'lowrank/lowrank_image_MC_train_' + str(n) + '.npy'))
                D = np.load(os.path.join(path, f'groundtruth/ground_image_MC_train_' + str(n) + '.npy'))
                # L, D = preprocess(L, D, None, None, None)

                images_L[n] = torch.from_numpy(L)
                images_D[n] = torch.from_numpy(D)

         # TEST
        if split == 1:
            images_L = torch.zeros(tuple([100]) + self.shape) # --> shape: (200, shape)
            images_D = torch.zeros(tuple([100]) + self.shape) # --> shape: (200, shape)
            for n in range(100):
                L = np.load(os.path.join(path, f'lowrank/lowrank_image_MC_test_' + str(n) + '.npy'))
                D = np.load(os.path.join(path, f'groundtruth/ground_image_MC_test_' + str(n) + '.npy'))
                # L, D = preprocess(L, D, None, None, None)

                images_L[n] = torch.from_numpy(L)
                images_D[n] = torch.from_numpy(D)


        self.transform = transform
        self.images_L = images_L
        self.images_D = images_D

    def __getitem__(self, index):
        L = self.images_L[index]
        D = self.images_D[index]
        return L, D

    def __len__(self):
        return len(self.images_L)

In [21]:
# Create DataLoaders
train_dataset = ImageDataset((params_net['size1'], params_net['size2']), 0, ROOT)
train_loader = data.DataLoader(train_dataset, batch_size = 5, shuffle = True)
test_dataset = ImageDataset((params_net['size1'], params_net['size2']), 1, ROOT)
test_loader = data.DataLoader(test_dataset, batch_size = 2)
