In [165]:
from torch.utils.data.dataset import Dataset
import torch
from glob import glob
from PIL import Image
import torchvision
from tqdm import tqdm
import numpy as np




import json, os, sys
import time

# class LevelSetDatasetESN(Dataset):
#     def __init__(self, ):

#     def __len__(self):

#     def __getitem__(self, idx):
    

class LevelSetDataset(Dataset):
    """
    Dataset object for CNN models
    Temporal is defined implicitly 
    as the number of channels
    example: 
        - X dimension
            [H, W, C=number_of_timestap(t)]
        - Y dimension
            [W, W, C =(t+1)]
    """
    def __init__(self, input_image_path:str,
                target_image_path:str,
                threshold:float=0.5,
                num_input_steps:int=3,
                num_future_steps:int=1,
                image_dimension:int=32,
                data_transformations=None,
                istraining_mode:bool=True
                ):
        
        self.input_image_path    = input_image_path
        self.target_image_path   = target_image_path
        self.threshold           = threshold
        self.num_input_steps     = num_input_steps
        self.num_future_steps    = num_future_steps
        self.image_dimension     = image_dimension
        self.data_transformations= data_transformations
        self.istraining_mode     = istraining_mode
        
        
        # get a list of input filenames as sort them (e.g. 1.png, 2.png,..,N.png)
        input_image_fp = sorted(glob(os.path.join(self.input_image_path , "*")), 
                                    key=lambda x: int(os.path.basename(x).split('.')[0])
                                                     )
        
        
        # repeat the input image untill it matches the number of segmentation
        # step of the target image
        self.input_image_fp = [i for i in input_image_fp for _ in range(100)][:1000]
        
        # get a list of the target filenames and sort them by the first id and second
        # id after the underscore (e.g.  1_1.png, 1_2,..,N_M.png)
        self.target_image_fp= sorted(glob(os.path.join(self.target_image_path , "*")),
                                    key=lambda x: (int(os.path.basename(x).split('_')[0]), 
                                                   int(os.path.basename(x).split('_')[1].split('.')[0]))
                                    )[:1000]
        
        self.input_image_fp  = self.input_image_fp [0+self.num_input_steps-1:len(self.input_image_fp)-self.num_future_steps]
        self.target_image_fp = self.target_image_fp[0+self.num_input_steps-1:len(self.input_image_fp)-self.num_future_steps]
        
        
        # check if in training mode
        # to apply transformations
        if (self.data_transformations is None) and (self.istraining_mode):
            self.data_transformations= torchvision.transforms.Compose([
                                            torchvision.transforms.Resize(size=(self.image_dimension,self.image_dimension), 
                                                                                interpolation=Image.BILINEAR),
                                            torchvision.transforms.RandomHorizontalFlip(p=0.5),
                                            torchvision.transforms.RandomVerticalFlip(p=0.5),
                                            torchvision.transforms.ToTensor()
                                                                      ])
            
        if (self.data_transformations is None) and (not self.istraining_mode):
            self.data_transformations== torchvision.transforms.Compose([
                                            torchvision.transforms.Resize(size=(self.image_dimension,self.image_dimension), 
                                                                                interpolation=Image.BILINEAR),
                                            torchvision.transforms.ToTensor()
                                                                      ])
            
            
            
        self.transforms = self.data_transformations    
        self.mean_image =  self._compute_mean(self.input_image_fp)
        self.stddev_image = self._compute_stddev(self.input_image_fp)

            
    def _create_binary_mask(self, x):
        x[x>=self.threshold] = 1
        x[x <self.threshold] = 0
        return x
    
    def _stat_norm(self, x):
        norm =torchvision.transforms.Compose([torchvision.transforms.Resize(
            size=(self.image_dimension,self.image_dimension), 
                      interpolation=Image.BILINEAR),
                    torchvision.transforms.ToTensor()])
        return norm(x)
    
    def _compute_mean(self,  fp_list):
        mean_image = torch.zeros([1, self.image_dimension, self.image_dimension])
        file_counter = 0
        for fp in tqdm(fp_list):
            mean_image+=self._stat_norm(Image.open(fp).convert('L'))   
            file_counter += 1
        mean_image /= file_counter
        return mean_image
        
    def _compute_stddev(self, fp_list):
        stddev_image = torch.zeros([1, self.image_dimension, self.image_dimension])
        file_counter = 0
        for fp in tqdm(fp_list):
            stddev_image += (self._stat_norm(Image.open(fp).convert('L')) - self.mean_image)**2
            file_counter += 1
        stddev_image /= file_counter
        stddev_image = torch.sqrt(stddev_image)
        return stddev_image
                                     
    def __len__(self):
        return len(self.target_image_fp) - (self.num_input_steps+self.num_future_steps)

    def __getitem__(self, index):
        X          = torch.zeros((self.image_dimension, self.image_dimension, self.num_input_steps+1))
        # place the input as the first channel (t)
        
        for step_idx, step in enumerate(np.arange(index, self.num_input_steps, 1)):
            print(step_idx, step , self.target_image_fp[step+self.num_input_steps+self.num_future_steps-1])
            target_image = Image.open(self.target_image_fp[step+self.num_input_steps+self.num_future_steps-1])
            target_image = self.transforms(target_image)
            target_image = self._create_binary_mask(target_image)
            X[:, :, step_idx] = target_image # (t+1)
           
        input_img  = Image.open(self.input_image_fp[index]).convert('L')
        #         input_img  = self.transforms(input_img) 
        input_img  = (self.transforms(input_img) - self.mean_image )/self.stddev_image 
        X[:, :, 0] = input_img 
#         print('l', self.target_image_fp[index+self.num_input_steps+self.num_future_steps-1])
#         print('='*15)
        target_image = Image.open(self.target_image_fp[index+self.num_input_steps+self.num_future_steps-1])
        target_image = self.transforms(target_image)
        target_image = self._create_binary_mask(target_image)
        image_name   = self.target_image_fp[index+self.num_input_steps+self.num_future_steps-1].split('/')[-1]
    
        Y = target_image 
        return X, Y, image_name
    
    
    
def create_datasets(parameters:dict):
    dataloader={}
    for partition in ['train', 'test', 'val']:
        input_image_path = os.path.join(parameters['input_image_path'],partition)
        target_image_path = os.path.join(parameters['target_image_path'],partition)
        if partition=='train':
            ds = LevelSetDataset(
            input_image_path=input_image_path,
            target_image_path=target_image_path,
            threshold=parameters['threshold'],
            num_input_steps=parameters['num_input_steps'],
            num_future_steps=parameters['num_future_steps'],
            image_dimension=parameters['image_dimension'],
            data_transformations=parameters['data_transformations'],
            istraining_mode=True
            )

            dl = torch.utils.data.DataLoader(
            ds,
            batch_size=parameters['batch_size'],
            shuffle=parameters['shuffle'],
            num_workers=parameters['num_workers'],
            pin_memory=parameters['pin_memory']
            )
        else:
            ds = LevelSetDataset(
            input_image_path=input_image_path,
            target_image_path=target_image_path,
            threshold=parameters['threshold'],
            num_input_steps=parameters['num_input_steps'],
            num_future_steps=parameters['num_future_steps'],
            image_dimension=parameters['image_dimension'],
            data_transformations=parameters['data_transformations'],
            istraining_mode=False
            )

            dl = torch.utils.data.DataLoader(
            ds,
            num_workers=parameters['num_workers'],
            pin_memory=parameters['pin_memory'],
            batch_size=1,
            shuffle=False
            )        
        dataloader[partition]=dl
    return dataloader
        

In [166]:
# dataset used
dataset= 'CIFAR_100' # 'CIFAR_10' , 'BSR', 'WEIZMANN'

# name of the model
model= 'RNN3' # 'RNN', 'LSTM', 'GRU', 'ESN'

# Model instance
batch_size      = 64
image_dimension = 32

# path to orignal images
original_image_directory     = f'../../Data/FINAL_DATA/HETEROGENEOUS/{dataset}/'
# path to segmentation images
segmentation_image_directory = f'../../Data/FINAL_DATA/HETEROGENEOUS/{dataset}/SEGMENTATION_DATA/BINARY_SEGMENTATION/'

# device to perform computation (CPU or GPU)
device             = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# parameters to the dataloader
parameters = {
"input_image_path":original_image_directory,
"target_image_path":segmentation_image_directory,
"image_dimension":image_dimension,
'threshold':0.5,
'num_input_steps':3,
'num_future_steps':1,
'data_transformations':None,
"istraining_mode":True,
"batch_size":batch_size, 
"shuffle":False,
"pin_memory":False,
"num_workers": 0 if sys.platform =='win32' else 4
}


# use GPU if available
parameters["cuda"] = torch.cuda.is_available()
dataloaders        = create_datasets(parameters)




  0%|          | 0/997 [00:00<?, ?it/s][A
100%|██████████| 997/997 [00:00<00:00, 7635.34it/s][A

  0%|          | 0/997 [00:00<?, ?it/s][A
100%|██████████| 997/997 [00:00<00:00, 7158.05it/s][A

  0%|          | 0/997 [00:00<?, ?it/s][A
100%|██████████| 997/997 [00:00<00:00, 7256.33it/s][A

  0%|          | 0/997 [00:00<?, ?it/s][A
100%|██████████| 997/997 [00:00<00:00, 6813.86it/s][A

  0%|          | 0/997 [00:00<?, ?it/s][A
100%|██████████| 997/997 [00:00<00:00, 7135.08it/s][A

  0%|          | 0/997 [00:00<?, ?it/s][A
100%|██████████| 997/997 [00:00<00:00, 6986.96it/s][A


In [162]:
dataloaders['train'].dataset.target_image_fp

['../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_1.jpg',
 '../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_2.jpg',
 '../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_3.jpg',
 '../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_4.jpg',
 '../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_5.jpg',
 '../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_6.jpg',
 '../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_7.jpg',
 '../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_8.jpg',
 '../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_9.jpg',
 '../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTAT

In [163]:
# dataloaders['train'].dataset.target_image_fp

In [167]:
next(iter(dataloaders['train']));

0 0 ../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_6.jpg
1 1 ../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_7.jpg
2 2 ../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_8.jpg
0 1 ../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_7.jpg
1 2 ../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_8.jpg
0 2 ../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_8.jpg


In [175]:
next(iter(dataloaders['train']));

0 0 ../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_6.jpg
1 1 ../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_7.jpg
2 2 ../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_8.jpg
0 1 ../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_7.jpg
1 2 ../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_8.jpg
0 2 ../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_8.jpg


In [182]:
x[:,:,:,0]

tensor([[[-0.5320, -0.5927, -0.6814,  ..., -1.2546, -1.4361, -1.2081],
         [-0.5835, -0.5715, -0.5733,  ..., -0.9946, -1.1641, -0.9913],
         [-0.5148, -0.5110, -0.4586,  ..., -1.0691, -1.2260, -0.8971],
         ...,
         [ 0.3046,  0.2030,  0.6679,  ..., -0.1121,  0.0301, -0.2081],
         [ 0.6994,  0.3735,  0.9025,  ..., -0.0913,  0.0419, -0.2255],
         [ 0.6303, -0.0717,  0.4528,  ...,  0.0422,  0.1455, -0.0936]],

        [[-1.1682, -1.3124, -1.2450,  ..., -0.6295, -0.5955, -0.4948],
         [-1.1614, -1.2315, -1.1403,  ..., -0.3899, -0.4354, -0.3546],
         [-1.0998, -1.1692, -1.0747,  ..., -0.4098, -0.4784, -0.2750],
         ...,
         [-0.4502, -0.2282, -0.2179,  ...,  0.5228,  0.3733,  0.5196],
         [-0.2899, -0.2049, -0.3179,  ...,  0.7529,  0.5588,  0.7423],
         [-0.0622, -0.1042, -0.2035,  ...,  0.5566,  0.1780,  0.6290]],

        [[-0.3377, -0.2914, -0.4423,  ...,  0.3555, -0.2045,  0.6345],
         [-0.5473, -0.4105, -0.6543,  ...,  0

In [136]:
x[0, :, :,0]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [177]:
for i, data in enumerate(dataloaders['train']):
    x, y, n = data
    print(i, n)

0 0 ../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_6.jpg
1 1 ../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_7.jpg
2 2 ../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_8.jpg
0 1 ../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_7.jpg
1 2 ../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_8.jpg
0 2 ../../Data/FINAL_DATA/HETEROGENEOUS/CIFAR_100/SEGMENTATION_DATA/BINARY_SEGMENTATION/train/1_8.jpg
0 ('1_6.jpg', '1_7.jpg', '1_8.jpg', '1_9.jpg', '1_10.jpg', '1_11.jpg', '1_12.jpg', '1_13.jpg', '1_14.jpg', '1_15.jpg', '1_16.jpg', '1_17.jpg', '1_18.jpg', '1_19.jpg', '1_20.jpg', '1_21.jpg', '1_22.jpg', '1_23.jpg', '1_24.jpg', '1_25.jpg', '1_26.jpg', '1_27.jpg', '1_28.jpg', '1_29.jpg', '1_30.jpg', '1_31.jpg', '1_32.jpg', '1_33.jpg', '1_34.jpg', '1_35.jpg', '1_36.jpg', '1_37.jpg', '1_38

In [94]:
# x.shape, y.shape, n

In [64]:
y.flatten().shape

torch.Size([49152])

_________________


____________________



__________

In [168]:
import numpy as np
a = np.arange(1, 10)
a

array([1, 2, 3, 4, 5, 6, 7, 8, 9])

In [169]:
num_input_steps =3
num_future_steps=1

num_train = len(a)
a = a[0+num_input_steps-1:num_train-num_future_steps]
a

array([3, 4, 5, 6, 7, 8])

In [222]:

for index in range(0, len(a)-num_input_steps-num_future_steps):
    x, y = [], []
    c = 0
  
    
    if index < (len(a)-num_input_steps+num_future_steps+1):
#         print(index, num_input_steps+num_future_steps+1)
        for s_i,s  in enumerate(np.arange(index, index+num_input_steps, 1)):
#             print(s_i,s)
            x.append(a[s])
    #     print(index)

        y.append(a[index+num_input_steps+num_future_steps-1])

        print(x)
        print(y)
        print()

        
#  x = self.data[index:index+self.window]
#         y = self.data[index+self.window,0:target_cols]


[1, 2, 3]
[4]

[2, 3, 4]
[5]

[3, 4, 5]
[6]

[4, 5, 6]
[7]

[5, 6, 7]
[8]



In [172]:
x=None
y=False
if x is None and not y:
    print('working')

working
