# Basic Pipeline

In [None]:
import json
from random import random
from typing import Sequence

import matplotlib.pyplot as plt
import numpy as np
from keras.preprocessing.image import ImageDataGenerator

In [None]:
def quick_plot_img(data, x_dim=80, y_dim=80, layer_dim=3, order="F"):
    """
    Display one flattened image with matplotlib.

    ``data`` may be a list or an array; it is converted with ``np.array`` and
    reshaped to ``(x_dim, y_dim, layer_dim)`` using the given index ``order``
    before being drawn with ``plt.imshow``. The default dimensions and order
    are the ones used by this project's dataset.
    """
    target_shape = (x_dim, y_dim, layer_dim)
    plt.imshow(np.array(data).reshape(target_shape, order=order))


def format_img(data, x_dim=80, y_dim=80, layer_dim=3, order="F"):
    """
    Reshape a batch of flattened images into a 4D numpy array.

    Args:
        data: Sequence (or 2D array) of flattened images, each holding
            ``x_dim * y_dim * layer_dim`` values.
        x_dim, y_dim, layer_dim: Target shape of each image. The defaults
            are the ones used by this project's dataset.
        order: Index order passed to ``np.reshape``. ``"F"`` is also the
            default used by ``quick_plot_img``.

    Returns:
        numpy.ndarray of shape ``(len(data), x_dim, y_dim, layer_dim)``.
    """
    # Use layer_dim rather than a fixed 3 so other channel counts work.
    return np.reshape(
        np.array(data), (len(data), x_dim, y_dim, layer_dim), order=order
    )


def train_test_validation_split(
    X: Sequence, y: Sequence, train_size: float = 0.8, validation=True
):
    """
    Randomly split features and labels into training, validation and test sets.

    Each sample is put in the training set with probability ``train_size``.
    Samples not chosen for training are divided between validation and test
    with equal probability when ``validation`` is true. Otherwise they all go
    to the test set. The resulting split sizes are printed because the split
    is random and a caller may want to retry.

    Args:
        X (Sequence): Indexable features, aligned position-by-position with y.
        y (Sequence): Labels; ``len(y)`` determines how many samples are split.
        train_size (float, optional): Probability that a sample ends up in the
            training set. Defaults to 0.8.
        validation (bool, optional): Whether to split the non-training samples
            further into validation and test. Defaults to True.

    Returns:
        tuple: ``(train_X, train_y, val_X, val_y, test_X, test_y)`` when
        ``validation`` is true (the validation lists may be empty), otherwise
        ``(train_X, train_y, test_X, test_y)``.
    """
    # A set gives O(1) membership tests in the loop below, keeping the split
    # linear in the number of samples.
    train_set_index = {i for i in range(len(y)) if random() <= train_size}

    train_X, train_y = [], []
    val_X, val_y = [], []
    test_X, test_y = [], []

    for i, label in enumerate(y):
        if i in train_set_index:
            train_X.append(X[i])
            train_y.append(label)
        elif validation and random() > 0.5:
            val_X.append(X[i])
            val_y.append(label)
        else:
            test_X.append(X[i])
            test_y.append(label)

    # Report the result - it's random so someone might want to retry
    stats = (
        f"train: {len(train_X)}, "
        + f"validation: {len(val_X)}, "
        + f"test: {len(test_X)}"
    )
    print(stats)

    # The tuple shape follows the caller's request, not whether the random
    # split happened to produce validation samples, so a 6-way unpacking
    # never breaks on an empty validation set.
    if validation:
        return (train_X, train_y, val_X, val_y, test_X, test_y)
    return (train_X, train_y, test_X, test_y)

In [None]:
with open("data/shipsnet.json", "r") as f:
    data = json.load(f)

# Top-level keys of the loaded JSON object
print(list(data.keys()))

# First and last ten labels, to eyeball how labels are encoded
print(data["labels"][:10], data["labels"][-10:])

# Class balance: how many labels equal 1 and how many equal 0
print("True: ", sum(1 for label in data["labels"] if label == 1))
print("False: ", sum(1 for label in data["labels"] if label == 0))

# Visual sanity check on a single sample
quick_plot_img(data["data"][5])

In [None]:
# Random ~80/10/10 split into train / validation / test (sizes are printed).
train_X, train_y, val_X, val_y, test_X, test_y = train_test_validation_split(
    X=data["data"],
    y=data["labels"],
    train_size=0.8,
    validation=True,
)

## 2. Make data generators (implements data augmentation steps)

In [None]:
# Training generator: rescales pixel values and applies random augmentation.
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
    # NOTE(review): a separate validation set (val_X/val_y) already exists;
    # this additionally reserves a 20% "validation" subset of whatever is fed
    # to the generator — confirm both are intended.
    validation_split=0.2,
)

# For test, no point in augmentation: apply only the same rescaling.
test_datagen = ImageDataGenerator(rescale=1.0 / 255)

# ImageDataGenerator.fit requires a rank-4 (samples, rows, cols, channels)
# array; train_X holds one flat list per image, so reshape it first.
train_datagen.fit(format_img(train_X))

In [None]:
# Peek at the first training label (shown as the cell's output).
train_y[0]

In [None]:
# Sanity-check format_img: batch-reshape every image, then plot the first one.
test = format_img(data=data["data"])
quick_plot_img(data=test[0])