In [5]:
import sys
sys.path.append('../src')

import json
import os
import skimage.io
from config import DATASET_DIR
from sklearn.model_selection import train_test_split

Total dataset has 29300 elements
- Train: 70% (20510)
- Validation: 20% (5860)
- Test: 10% (2930)

In [2]:
with open(os.path.join(DATASET_DIR, 'galaxy_segment_classes.json')) as fp:
    galaxy_segment = json.load(fp)

In [3]:
# Convert the dictionary keys to a list
keys = list(galaxy_segment.keys())

train_keys, temp_keys = train_test_split(keys, test_size=0.3, random_state=42)
val_keys, test_keys = train_test_split(temp_keys, test_size=(1/3), random_state=42)

def split_data(keys, original_data):
    return {key: original_data[key] for key in keys}

# Create the splits
train_data = split_data(train_keys, galaxy_segment)
val_data = split_data(val_keys, galaxy_segment)
test_data = split_data(test_keys, galaxy_segment)

# Check that there is no overlap between the sets
assert not set(train_keys) & set(val_keys), "Overlap detected between train and validation sets!"
assert not set(train_keys) & set(test_keys), "Overlap detected between train and test sets!"
assert not set(val_keys) & set(test_keys), "Overlap detected between validation and test sets!"

print("No overlap between train, validation, and test sets.")

with open(os.path.join(DATASET_DIR, 'galaxy_train.json'), 'w') as file:
    json.dump(train_data, file)

with open(os.path.join(DATASET_DIR, 'galaxy_val.json'), 'w') as file:
    json.dump(val_data, file)

with open(os.path.join(DATASET_DIR, 'galaxy_test.json'), 'w') as file:
    json.dump(test_data, file)

# Print the number of items in each split to verify
print(f'Training set size: {len(train_data)}')
print(f'Validation set size: {len(val_data)}')
print(f'Test set size: {len(test_data)}')


No overlap between train, validation, and test sets.
Training set size: 20510
Validation set size: 5860
Test set size: 2930


In [20]:
# COMPUTE MEAN AND STD OVER THE WHOLE DATASET FOR LATER NORMALIZATION
running_mean = 0.0
running_std_dev = 0.0

for k,v in galaxy_segment.items():
    image_path = os.path.join(DATASET_DIR, "original/zoo2Main", v['filename'])
    image = skimage.io.imread(image_path)

    running_mean += image.mean(axis=(0,1))
    running_std_dev += image.std(axis=(0,1))

In [26]:
mean = [round(x, 4) for x in running_mean/len(galaxy_segment)]
std = [round(x, 4) for x in running_std_dev/len(galaxy_segment)]

In [27]:
print(f"Dataset mean: {mean}")
print(f"Dataset std: {std}")

Dataset mean: [8.9483, 6.2549, 5.1239]
Dataset std: [14.9211, 12.5903, 8.8366]
