# Set up

In [None]:
import tensorflow as tf
import tensorflow.keras as tfk
import numpy as np
import random

import augmentation as augh
import plotter as plot
import folding as fold

# Confirmation line so the notebook output shows the import cell finished.
import_status = "Libraries have been imported"
print(import_status)

# Import dataset

In [None]:
# Load the raw dataset. The NpzFile is opened as a context manager so its
# file handle is closed once every stored array has been read into memory.
DATASET_PATH = "/kaggle/input/marsterrain-general/dataset_general.npz"
SHUFFLE_SEED = None  # set to an int to make the shuffle reproducible

with np.load(DATASET_PATH) as npz:
    data = {key: npz[key] for key in npz.files}

# Shuffle images and labels with the same permutation so each image stays
# paired with its mask. The shuffled arrays are written back into `data`, so
# every later cell that reads `data` sees the shuffled order rather than the
# on-disk order.
# NOTE(review): `data` is now a plain dict instead of an NpzFile; this assumes
# augh.augment_masked_set only indexes it by key ('images', 'labels') — confirm.
rng = np.random.default_rng(SHUFFLE_SEED)
shuffled_indices = rng.permutation(len(data['images']))
data['images'] = data['images'][shuffled_indices]
data['labels'] = data['labels'][shuffled_indices]
images = data['images']
labels = data['labels']

print(f"Dataset: {images.dtype}{images.shape} - {labels.dtype}{labels.shape}")

# Example augmentation

In [None]:
# Pick one random sample to preview. randrange excludes its upper bound
# (unlike randint, which includes it), so `i` is always a valid index.
i = random.randrange(len(labels))
print(f"Augmenting image {i}")

original_pair = (images[i], labels[i])
augmented_image, augmented_mask = augh.masked_augment(images[i], labels[i])

# Plot the original/augmented pair once per mask_alpha value, in the same
# order as before (presumably mask_alpha sets the mask overlay opacity).
for alpha in (0.1, 0, 1):
    plot.plot_masked_image(original_pair, (augmented_image, augmented_mask), mask_alpha=alpha)

# Augment dataset

In [None]:
# Run the project's augmentation helper over `data`; later cells index the
# result by 'images'/'labels' and export it with np.savez.
augmented_data = augh.augment_masked_set(
    data,
)

# Show class presence counts

In [None]:
def _class_presence_counts(label_set):
    """Return, for each class id, how many masks in *label_set* contain it.

    Each mask contributes at most once per class (via ``np.unique``), so the
    result counts images in which a class appears, not pixels. The per-mask
    unique values are collected in a list and concatenated once, instead of
    growing an array with ``np.append`` inside the loop (which copies the
    whole array on every iteration). Values are truncated to int before
    ``np.bincount``, matching the previous behaviour.

    Returns an empty int array when *label_set* is empty.
    """
    present = [np.unique(mask) for mask in label_set]
    if not present:
        return np.zeros(0, dtype=int)
    return np.bincount(np.concatenate(present).astype(int))


# `labels` has the same length as the loaded label array; using it avoids
# re-reading and decompressing data['labels'] from the .npz archive.
print(f"Total old images: {len(labels)}")
print(f"Total new images: {len(augmented_data['labels'])}")

old_counts = _class_presence_counts(labels)
print(f"Old classes counts: {old_counts}")

counts = _class_presence_counts(augmented_data['labels'])
print(f"New classes counts: {counts}")

# Export dataset

In [None]:
# Write every array of the augmented set into one .npz archive, stored under
# the same key it has in `augmented_data`.
export_path = "dataset_enhanced.npz"
np.savez(export_path, **augmented_data)

# Print some images

In [None]:
# Pull out the augmented arrays and preview them with the project's plotter.
augmented_images, augmented_labels = augmented_data['images'], augmented_data['labels']
plot.plot_masked_images(augmented_images, augmented_labels, row=6)

# Split dataset

In [None]:
# Partition the augmented set; the helper's result is unpacked as a
# (validation, training) pair, so the order of the names matters.
split_sets = fold.split_masked_set(augmented_data)
validation, training = split_sets

# Print example images after split

In [None]:
# Print a heading and then plot each subset, validation first.
for heading, subset in (("VALIDATION", validation), ("TRAINING", training)):
    print(heading)
    plot.plot_masked_images(subset['images'], subset['labels'])