In [1]:
import os
import torch
from config import data_config

from src.utils import generate_patches, split_train_val_test
from src.datafactory import SatelliteDataset

In [2]:
# Setup data configs
# The module is imported under an alias here so this cell stays re-runnable:
# the name `data_config` is rebound below from the module to the config
# instance, and calling `data_config.data_config()` a second time would then
# fail because the instance is no longer the module.
from config import data_config as data_config_module

data_config = data_config_module.data_config()

In [3]:
# Create the directories named by the config's PATCHES / TRAIN / VAL / TEST
# data paths.
# exist_ok=True makes the call a no-op when the directory is already present,
# so no separate os.path.exists() check is needed and the cell can be re-run
# safely. If a path exists but is not a directory, makedirs raises instead of
# silently continuing.
for output_dir in (
    data_config.PATCHES_DATA_PATH,
    data_config.TRAIN_DATA_PATH,
    data_config.VAL_DATA_PATH,
    data_config.TEST_DATA_PATH,
):
    os.makedirs(output_dir, exist_ok=True)

In [4]:
# Build the dataset object rooted at the configured raw-data path and
# report how many items it contains.
base_dataset = SatelliteDataset(root=data_config.RAW_DATA_PATH)
n_base_samples = len(base_dataset)
print(n_base_samples)

72


### Patchify the dataset

In [5]:
# Cut the base dataset into patches of the configured size, then print the
# shapes of the returned image and mask collections side by side.
patches_images, patches_masks = generate_patches(
    base_dataset,
    data_config.PATCH_SIZE,
)
print(*(patches.shape for patches in (patches_images, patches_masks)))

(1305, 256, 256, 3) (1305, 256, 256)


### Save patches data

In [6]:
# Persist the image and mask patches under the configured patches directory.
# os.path.join is used instead of `+` so the target path is correct whether or
# not PATCHES_DATA_PATH ends with a path separator (and also when it is a
# path-like object rather than a plain string).
torch.save(patches_images, os.path.join(data_config.PATCHES_DATA_PATH, 'patches_images.pt'))
torch.save(patches_masks, os.path.join(data_config.PATCHES_DATA_PATH, 'patches_masks.pt'))

### Split Data

In [7]:
# Partition the patches into train / validation / test subsets using the
# sizes and random state taken from the config.
(
    Xtrain_patches,
    Xval_patches,
    Xtest_patches,
    ytrain_patches,
    yval_patches,
    ytest_patches,
) = split_train_val_test(
    X=patches_images,
    y=patches_masks,
    test_size=data_config.TEST_SIZE,
    val_size=data_config.VAL_SIZE,
    random_state=data_config.RANDOM_STATE,
)

# Display the shape of every subset as the cell's output value.
tuple(
    subset.shape
    for subset in (
        Xtrain_patches,
        Xval_patches,
        Xtest_patches,
        ytrain_patches,
        yval_patches,
        ytest_patches,
    )
)

((913, 256, 256, 3),
 (261, 256, 256, 3),
 (131, 256, 256, 3),
 (913, 256, 256),
 (261, 256, 256),
 (131, 256, 256))

### Save Split Data

In [9]:
# Write each split's X and y arrays into that split's configured directory.
# Each entry is (target directory, file name, object to save); the save order
# matches the order of the entries.
# os.path.join is used instead of `+` so the target path is correct whether or
# not the configured directory ends with a path separator.
split_outputs = (
    (data_config.TRAIN_DATA_PATH, 'Xtrain_patches.pt', Xtrain_patches),
    (data_config.TRAIN_DATA_PATH, 'ytrain_patches.pt', ytrain_patches),
    (data_config.VAL_DATA_PATH, 'Xval_patches.pt', Xval_patches),
    (data_config.VAL_DATA_PATH, 'yval_patches.pt', yval_patches),
    (data_config.TEST_DATA_PATH, 'Xtest_patches.pt', Xtest_patches),
    (data_config.TEST_DATA_PATH, 'ytest_patches.pt', ytest_patches),
)
for split_dir, file_name, split_patches in split_outputs:
    torch.save(split_patches, os.path.join(split_dir, file_name))