In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import torch
import time
import os
import numpy as np
from pathlib import Path
from PIL import Image
from skimage.transform import resize
import helper
import matplotlib.pyplot as plt
from matplotlib import pyplot
from matplotlib.image import imread
from torchvision import datasets
from torchvision import datasets, transforms, models
from torch import nn, optim, Tensor
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from torch.utils.data import DataLoader, Dataset, TensorDataset
import random
# Seed value intended for reproducibility. Nothing in the cells visible here
# passes it to an RNG (random / numpy / torch) - TODO confirm it is applied downstream.
seed=42

In [None]:
# Loading paths: one sub-directory per spectral band plus the ground-truth masks.
base_path = Path('../input/95cloud-cloud-segmentation-on-satellite-images/95-cloud_training_only_additional_to38-cloud')
red_dir   = base_path/'train_red_additional_to38cloud'
blue_dir  = base_path/'train_blue_additional_to38cloud'
green_dir = base_path/'train_green_additional_to38cloud'
nir_dir   = base_path/'train_nir_additional_to38cloud'
gt_dir    = base_path/'train_gt_additional_to38cloud'

# Saving paths
# exist_ok=True keeps this cell re-runnable: a plain os.mkdir raises
# FileExistsError as soon as the directories are left over from an earlier run.
for out_dir in ('./TrueColor_imgs', './FalseColor_imgs', './GT'):
    os.makedirs(out_dir, exist_ok=True)

In [None]:
class RGB_CloudDataset (Dataset):
    """Dataset pairing red/green/blue band images with a ground-truth mask.

    Every regular file found in ``red_dir`` defines one sample; the matching
    green, blue and ground-truth files are located by replacing ``'red'`` in
    the file name with the respective band name.

    Args:
        red_dir, blue_dir, green_dir, gt_dir: ``pathlib.Path`` directories
            holding the band images and the masks.
        transform: optional callable taking and returning an ``(x, y)`` tuple.
    """

    def __init__(self, red_dir, blue_dir, green_dir, gt_dir, transform= None):
        self.transform = transform

        # Path.iterdir() yields entries in arbitrary, filesystem-dependent
        # order; sorting makes sample indices stable across runs and machines.
        red_files = sorted(f for f in red_dir.iterdir() if not f.is_dir())

        # One dictionary per sample, mapping band name -> file path.
        self.files = [self.combine_files(f, green_dir, blue_dir, gt_dir)
                      for f in red_files]

    def combine_files(self, red_file: Path, green_dir, blue_dir, gt_dir):
        """Return the band-name -> path mapping that belongs to ``red_file``."""
        files = {'red': red_file,
                 'green':green_dir/red_file.name.replace('red', 'green'),
                 'blue': blue_dir/red_file.name.replace('red', 'blue'),
                 'gt': gt_dir/red_file.name.replace('red', 'gt')}

        return files

    @staticmethod
    def _read_image(path):
        """Load an image file into a numpy array and close the file handle."""
        # The context manager releases the file descriptor immediately instead
        # of waiting for garbage collection; np.array() copies the pixel data
        # before the file is closed.
        with Image.open(path) as img:
            return np.array(img)

    def OpenAsArray(self, idx):
        """Return sample ``idx`` as a channel-first red/green/blue array.

        The three bands are stacked on a new last axis and that axis is then
        moved to the front (``(H, W, 3) -> (3, H, W)`` assuming each band file
        decodes to a 2-D array - TODO confirm).
        """
        record = self.files[idx]
        TrueColor = np.stack([self._read_image(record['red']),
                              self._read_image(record['green']),
                              self._read_image(record['blue'])], axis = 2)

        TrueColor = TrueColor.transpose((2, 0, 1))

        return TrueColor

    def OpenMask(self, idx, add_dims=False):
        """Return the mask of sample ``idx``: 1 where the pixel is 255, else 0.

        With ``add_dims=True`` a leading axis of length 1 is prepended.
        """
        raw_mask = self._read_image(self.files[idx]['gt'])
        raw_mask = np.where(raw_mask==255, 1, 0)

        return np.expand_dims(raw_mask, 0) if add_dims else raw_mask

    def __len__(self):
        """Number of samples, i.e. regular files found in the red directory."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for sample ``idx``, passed through ``transform`` if set."""
        x = self.OpenAsArray(idx)
        y = self.OpenMask(idx, add_dims=False)

        if self.transform is not None:
            x, y = self.transform((x, y))

        return x, y

In [None]:
class NirGB_CloudDataset (Dataset):
    """Dataset pairing near-infrared/green/blue band images with a mask.

    Every regular file found in ``red_dir`` defines one sample; the matching
    green, blue, nir and ground-truth files are located by replacing ``'red'``
    in the file name with the respective band name.

    Args:
        red_dir, blue_dir, green_dir, nir_dir, gt_dir: ``pathlib.Path``
            directories holding the band images and the masks.
        transform: optional callable taking and returning an ``(x, y)`` tuple
            of numpy arrays; applied before conversion to tensors.
    """

    def __init__(self, red_dir, blue_dir, green_dir, nir_dir, gt_dir, transform = None):
        self.transform = transform

        # Path.iterdir() yields entries in arbitrary, filesystem-dependent
        # order; sorting makes sample indices stable across runs and machines.
        red_files = sorted(f for f in red_dir.iterdir() if not f.is_dir())

        # One dictionary per sample, mapping band name -> file path.
        self.files = [self.combine_files(f, green_dir, blue_dir, nir_dir, gt_dir)
                      for f in red_files]

    def combine_files(self, red_file: Path, green_dir, blue_dir, nir_dir, gt_dir):
        """Return the band-name -> path mapping that belongs to ``red_file``."""
        files = {'red': red_file,
                 'green':green_dir/red_file.name.replace('red', 'green'),
                 'blue': blue_dir/red_file.name.replace('red', 'blue'),
                 'nir': nir_dir/red_file.name.replace('red', 'nir'),
                 'gt': gt_dir/red_file.name.replace('red', 'gt')}

        return files

    @staticmethod
    def _read_image(path):
        """Load an image file into a numpy array and close the file handle."""
        # np.array() copies the pixel data before the context manager closes
        # the file, so no descriptor is left open per sample.
        with Image.open(path) as img:
            return np.array(img)

    def OpenAsArray(self, idx):
        """Return sample ``idx`` as a channel-first nir/green/blue array.

        The three bands are stacked on a new last axis and that axis is then
        moved to the front (``(H, W, 3) -> (3, H, W)`` assuming each band file
        decodes to a 2-D array - TODO confirm).
        """
        record = self.files[idx]
        FalseColor = np.stack([self._read_image(record['nir']),
                               self._read_image(record['green']),
                               self._read_image(record['blue'])], axis = 2)

        FalseColor      = FalseColor.transpose((2, 0, 1))

        return FalseColor

    def OpenMask(self, idx, add_dims=False):
        """Return the mask of sample ``idx``: 1 where the pixel is 255, else 0.

        With ``add_dims=True`` a leading axis of length 1 is prepended.
        """
        raw_mask = self._read_image(self.files[idx]['gt'])
        raw_mask = np.where(raw_mask==255, 1, 0)

        return np.expand_dims(raw_mask, 0) if add_dims else raw_mask

    def __len__(self):
        """Number of samples, i.e. regular files found in the red directory."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for sample ``idx`` as torch tensors.

        The constructor's ``transform`` used to be stored but silently
        ignored; it is now applied when it is not ``None``. With
        ``transform=None`` the result is unchanged.
        """
        x = self.OpenAsArray(idx)
        y = self.OpenMask(idx, add_dims=False)

        if self.transform is not None:
            x, y = self.transform((x, y))
            # Transforms may hand back strided views; give torch plain
            # contiguous buffers.
            x, y = np.ascontiguousarray(x), np.ascontiguousarray(y)

        # NOTE(review): if the band files decode to uint16, torch.from_numpy
        # may reject that dtype on older torch releases - confirm against the
        # data and the installed torch version.
        return torch.from_numpy(x), torch.from_numpy(y)

In [None]:
class Resize(object):
    """Resize an ``(x, y)`` sample to ``size`` x ``size`` pixels.

    Args:
        size: target height and width (default 256).

    ``x`` keeps the length of its leading axis and is resized along the two
    remaining axes; ``y`` is resized to ``(size, size)``. Both results come
    back from ``skimage.transform.resize`` with ``preserve_range=True``, so
    values are not rescaled.
    """
    def __init__(self, size = 256):
        self.size = size

    def __call__(self, sample):
        x, y = sample
        image = resize(x, (x.shape[0], self.size, self.size), mode = "constant",
                       preserve_range = True, anti_aliasing = False)
        # order=0 (nearest neighbour) keeps the mask's label values intact.
        # The default order for non-boolean input interpolates, which turns a
        # 0/1 mask into fractional values along every cloud boundary.
        mask = resize(y, (self.size, self.size), order = 0, mode = "constant",
                      preserve_range = True, anti_aliasing = False)
        return image, mask
    

class Normalize(object):
    """Standardise the image of an ``(x, y)`` sample channel by channel.

    Args:
        mean: per-channel values subtracted from the image.
        std: per-channel values the centred image is divided by.

    The image is expected channel-first; the target ``y`` is passed through
    untouched.
    """
    def __init__(self, mean, std):
        self.mean = mean
        self.std  = std

    def __call__(self, sample):
        image, target = sample
        # Move channels to the last axis so mean/std broadcast per channel.
        channels_last = image.transpose(1,2,0)
        standardised = (channels_last - self.mean) / self.std
        # Restore the channel-first layout before handing the sample back.
        return standardised.transpose(2,0,1), target
    
    
# Transform pipelines for the training and testing data.
# NOTE(review): the mean/std values look like the usual ImageNet channel
# statistics, which are defined for inputs scaled to [0, 1]; the datasets here
# hand over the pixel values as read from file - confirm the intended scale.

# Training: resize to 256x256, then standardise each channel.
train_transforms=transforms.Compose([Resize(256),
                                     Normalize([0.485, 0.456, 0.406], [0.229, 0.224,0.225])])

# Testing: standardise only. Compose iterates over its argument when called,
# so the single transform must be wrapped in a list; a bare Normalize object
# raises "TypeError: 'Normalize' object is not iterable" on first use.
test_transforms=transforms.Compose([Normalize([0.485, 0.456, 0.406], [0.229, 0.224,0.225])])

In [None]:
# Export every false-colour sample and its mask to disk, one file per sample.
# NOTE(review): 'FasleColor' in the file name looks like a typo of 'FalseColor';
# it is kept as-is because code loading these files may depend on it - confirm.
# NOTE(review): the true-colour export cell writes to the same './GT/Mask_{}'
# names, so whichever cell runs last decides the mask contents - confirm intent.
NirGB = NirGB_CloudDataset(red_dir, blue_dir, green_dir, nir_dir, gt_dir, transform=None)
for i, d in enumerate(NirGB):
    false_color, mask = d
    torch.save(false_color, './FalseColor_imgs/FasleColor_{}'.format(i))
    torch.save(mask, './GT/Mask_{}'.format(i))

In [None]:
# Export every transformed true-colour sample and its mask to disk.
# NOTE(review): the false-colour export cell writes to the same './GT/Mask_{}'
# names, so whichever cell runs last decides the mask contents - confirm intent.
data = RGB_CloudDataset(red_dir, blue_dir, green_dir, gt_dir, transform=train_transforms)
for i, d in enumerate(data):
    true_color, mask = d
    torch.save(true_color, './TrueColor_imgs/TrueColor_{}'.format(i))
    torch.save(mask, './GT/Mask_{}'.format(i))

In [None]:
# Preview the output naming scheme for the first sample only.

# Sorted ground-truth listing, kept from the original cell; it is no longer
# used for the lookup below.
list_dir = sorted(os.listdir(base_path/'train_gt_additional_to38cloud'))

# Take the ground-truth name from the dataset's own record for that index.
# Indexing a separately sorted directory listing only matches the dataset when
# both happen to enumerate files in the same order. Iterating over the records
# also avoids decoding a full sample just to print three names.
for idx, record in enumerate(data.files):
    img_name = record['gt'].name
    new_img_name = img_name.replace('gt', 'TrueColor')
    print(img_name)
    print( './TrueColor_imgs/{}'.format(new_img_name))
    print( './GT/{}'.format(img_name))
    break