# Class descriptions
PatchDataset should return a collection of patch files to be used for training or validation.
DatasetManager should be a new class that handles and selects patch datasets.

# Imports

In [2]:
from PatchGeneration.Modules.Network.datautils import *
import os
import torch

# Test DatasetManager
Test the DatasetManager by selecting a folder of data points, generating a dataset from it, and splitting it into a 20% validation / 80% training set.

In [3]:
# Build a DatasetManager over the Test_Bunny samples, then split the generated
# dataset into a validation part (fraction below) and a training part.
# The names dataManager / trainingSet / validationSet are read by later cells.
BATCH_SIZE = 100
SAMPLE_FOLDER = "samples/Test_Bunny"
VALIDATION_FRACTION = 0.2

dataManager = DatasetManager(batch_size=BATCH_SIZE)
dataManager.addFolder(SAMPLE_FOLDER)
dataManager.generateDatasetFromFolders()
dataManager.splitData(VALIDATION_FRACTION)

trainingSet = dataManager.getTrainingSet()
validationSet = dataManager.getValidationSet()

The print statements below check that the split is a true partition: the training and validation sets share no data points, and together they contain all of the data points.

In [4]:
# Import numpy explicitly rather than relying on the datautils star import to provide `np`.
import numpy as np

# len() of each split object; if these are DataLoaders this is a batch count,
# not a data-point count (TODO confirm against DatasetManager).
print("Training data:")
print(len(trainingSet))
print("Validation data:")
print(len(validationSet))

train_paths = trainingSet.dataset.data_path
val_paths = validationSet.dataset.data_path
all_split_paths = np.union1d(train_paths, val_paths)

# Disjointness: data paths present in both splits (expected: an empty array).
print(np.intersect1d(train_paths, val_paths))
# Coverage, checked in both directions: every path in the splits belongs to the
# manager, and every path of the manager appears in one of the splits.
print(np.all(np.isin(all_split_paths, dataManager.data_path)))
print(np.all(np.isin(dataManager.data_path, all_split_paths)))

Training data:
8
Validation data:
2
[]
True
True


# Testing Saving and Loading
Test saving and loading the dataset and the split, using files with throwaway names. (Saving into a non-existent subdirectory would leave that directory behind after the file is removed, so the files are saved to the current directory.)

In [5]:
DATASET_NAME = "TESTDATAESST_JWZ.h5"
# Round-trip the dataset through a file in the current directory. The file is
# removed in `finally` so a failing save/load does not leave a stray artifact;
# the existence check avoids masking the original error if the file was never written.
try:
    dataManager.saveDataset(DATASET_NAME)
    dataManager.loadDataset(DATASET_NAME)
finally:
    if os.path.exists(DATASET_NAME):
        os.remove(DATASET_NAME)

In [6]:
SPLIT_NAME = "JWZSPLITTHESHIT.npy"
# Round-trip the train/validation split through a file in the current directory.
# The file is removed in `finally` so a failing save/load does not leave a stray
# artifact; the existence check avoids masking the original error if nothing was written.
try:
    dataManager.saveSplit(SPLIT_NAME)
    dataManager.loadSplit(SPLIT_NAME)
finally:
    if os.path.exists(SPLIT_NAME):
        os.remove(SPLIT_NAME)

In [7]:
# Peek at the first validation batch only (prints nothing if the set is empty).
first_batch = next(iter(dataManager.getValidationSet()), None)
if first_batch is not None:
    # Each batch unpacks into exactly four items; only the first (inputs) is inspected.
    inputs, _, _, _ = first_batch
    inputs = inputs.type(torch.FloatTensor)
    # Swap the last two axes; the recorded output shows torch.Size([100, 20, 64])
    # after this permute.
    inputs = inputs.permute(0, 2, 1)
    print(inputs.shape, inputs)


torch.Size([100, 20, 64]) tensor([[[4.9404e-01, 4.9629e-01, 5.0000e-01,  ..., 3.5917e-01,
          4.9700e-01, 4.8361e-01],
         [6.0233e-01, 5.0503e-01, 5.0000e-01,  ..., 1.0400e+00,
          6.5421e-01, 7.5234e-01],
         [3.6399e-01, 3.9227e-01, 5.0000e-01,  ..., 6.1746e-01,
          1.0076e+00, 9.6782e-01],
         ...,
         [1.0000e+00, 0.0000e+00, 1.0000e+00,  ..., 4.7000e+01,
          3.6000e+01, 3.3000e+01],
         [4.0000e+00, 2.0000e+00, 3.0000e+00,  ..., 6.0000e+01,
          6.3000e+01, 6.2000e+01],
         [1.3000e+01, 5.0000e+00, 8.0000e+00,  ..., 6.0000e+01,
          6.3000e+01, 6.2000e+01]],

        [[4.9696e-01, 5.0000e-01, 5.1557e-01,  ..., 4.5721e-01,
          0.0000e+00, 0.0000e+00],
         [4.7630e-01, 5.0000e-01, 6.0547e-01,  ..., 1.5952e-01,
          0.0000e+00, 0.0000e+00],
         [3.8938e-01, 5.0000e-01, 5.0951e-01,  ..., 1.1182e-02,
          0.0000e+00, 0.0000e+00],
         ...,
         [1.0000e+00, 0.0000e+00, 1.0000e+00,  ..., 4

In [3]:
# Sanity check that a float literal is a float; isinstance is the idiomatic type test.
print(isinstance(0.1, float))

True
