In [None]:
import os
from keras.models import load_model
from keras.callbacks import EarlyStopping, ModelCheckpoint

from python_research.experiments.multiple_feature_learning.builders.\
    keras_builders import \
    build_1d_model
from python_research.experiments.utils.datasets.hyperspectral_dataset import HyperspectralDataset
from python_research.augmentation.transformations import PCATransformation
from python_research.experiments.utils.datasets.subset import BalancedSubset
from python_research.augmentation.augmenter import Augmenter

# Placeholder input paths -- fill these in before running.
DATASET_PATH = ""
GT_PATH = ""
# Output directory for model checkpoints; "" means the current working directory.
OUTPUT_PATH = ""

# Experiment hyperparameters
PIXEL_NEIGHBORHOOD = 1
SAMPLES_PER_CLASS = 300
VAL_PART = 0.1
PATIENCE = 15
KERNELS = 200
KERNEL_SIZE = 5
CLASSES_COUNT = 16
BATCH_SIZE = 64
EPOCHS = 200

# os.makedirs("") raises FileNotFoundError, so only create a directory when an
# explicit output path was configured (os.path.join("", name) already resolves
# to the current working directory).
if OUTPUT_PATH:
    os.makedirs(OUTPUT_PATH, exist_ok=True)

# Load dataset. The full dataset is loaded into `test_data`; the training and
# validation subsets are drawn from it below.
# NOTE(review): this presumably relies on BalancedSubset removing the drawn
# samples from its source, otherwise the test set overlaps the training set --
# TODO confirm.
test_data = HyperspectralDataset(DATASET_PATH, GT_PATH,
                                 neighbourhood_size=PIXEL_NEIGHBORHOOD)

test_data.normalize_labels()
# Add a trailing axis so every sample is a column vector for the 1D model.
test_data.expand_dims(axis=-1)

# Extract training and validation sets: SAMPLES_PER_CLASS samples per class
# for training, then (presumably) a VAL_PART fraction of those for validation.
train_data = BalancedSubset(test_data, SAMPLES_PER_CLASS)
val_data = BalancedSubset(train_data, VAL_PART)

# Keras callbacks. `early` is reused by both training cells below;
# `checkpoint` keeps only the best baseline model on disk (by Keras' default
# monitored metric, val_loss).
early = EarlyStopping(patience=PATIENCE)
checkpoint = ModelCheckpoint(os.path.join(OUTPUT_PATH, "offline_augmentation") + "_model",
                             save_best_only=True)

# Min-max normalize every split with one shared range computed from the
# training and validation samples only, so test statistics do not drive the
# scaling (assuming the subsets were removed from test_data -- see above).
max_ = max(train_data.max, val_data.max)
min_ = min(train_data.min, val_data.min)
train_data.normalize_min_max(min_=min_, max_=max_)
val_data.normalize_min_max(min_=min_, max_=max_)
test_data.normalize_min_max(min_=min_, max_=max_)

# Build 1d model; the input shape is the shape of a single sample.
model = build_1d_model(test_data.shape[1:], KERNELS,
                       KERNEL_SIZE, CLASSES_COUNT)

# summary() prints the table itself and returns None, so it is not wrapped
# in print() (which would emit a stray "None").
model.summary()


In [None]:
# Train model without train set augmentation (baseline).

# Must match the path `checkpoint` was created with in the first cell.
baseline_model_path = os.path.join(OUTPUT_PATH, "offline_augmentation") + "_model"

# Train for up to EPOCHS epochs; the `early` callback stops training once the
# monitored validation metric has not improved for PATIENCE epochs. With a
# hard-coded single epoch, early stopping could never take effect.
history = model.fit(x=train_data.get_data(),
                    y=train_data.get_one_hot_labels(CLASSES_COUNT),
                    batch_size=BATCH_SIZE,
                    epochs=EPOCHS,
                    verbose=2,
                    callbacks=[early, checkpoint],
                    validation_data=(val_data.get_data(),
                                     val_data.get_one_hot_labels(CLASSES_COUNT)))

# Load best model: the checkpoint uses save_best_only=True, so this restores
# the best epoch rather than keeping the last epoch's weights.
model = load_model(baseline_model_path)

# Calculate test set score without augmentation.
# NOTE(review): test_score[1] is presumably accuracy, assuming build_1d_model
# compiles with a single accuracy metric -- TODO confirm.
test_score = model.evaluate(x=test_data.get_data(),
                            y=test_data.get_one_hot_labels(CLASSES_COUNT))
print("Test set score without offline augmentation: {}".format(test_score[1]))


In [None]:
# Offline augmentation: enlarge the training set with PCA-transformed samples,
# train a freshly built model on it, and report its test score for comparison
# with the baseline.
#
# NOTE: this cell mutates train_data in place (drops and re-adds the last axis
# and appends augmented samples). Re-running it without re-running the first
# cell appends yet another batch of augmented samples.

# Remove last dimension (convert column vectors to row vectors) before
# fitting the PCA-based transformation.
train_data.data = train_data.get_data()[:, :, 0]

# Augment training set
transformation = PCATransformation(low=0.9, high=1.1, n_components=train_data.shape[-1])
transformation.fit(train_data.get_data())
augmenter = Augmenter(transformation, sampling_mode='twice')
augmented_data, augmented_labels = augmenter.augment(train_data, transformations=1)

# Add augmented samples to the training set
train_data.vstack(augmented_data)
train_data.hstack(augmented_labels)

# Restore the trailing axis expected by the 1D model.
train_data.expand_dims(axis=-1)

# Start from a freshly built model. At this point `model` holds the baseline
# weights reloaded in the previous cell; continuing to train those would mix
# the baseline training into the "with augmentation" score.
model = build_1d_model(test_data.shape[1:], KERNELS,
                       KERNEL_SIZE, CLASSES_COUNT)

augmented_model_path = os.path.join(OUTPUT_PATH, "offline_augmentation_augmented") + "_model"
checkpoint = ModelCheckpoint(augmented_model_path, save_best_only=True)

# Train model for up to EPOCHS epochs. `early` is reused from the first cell;
# Keras resets EarlyStopping's internal state at the start of each fit().
history = model.fit(x=train_data.get_data(),
                    y=train_data.get_one_hot_labels(CLASSES_COUNT),
                    batch_size=BATCH_SIZE,
                    epochs=EPOCHS,
                    verbose=2,
                    callbacks=[early, checkpoint],
                    validation_data=(val_data.get_data(),
                                     val_data.get_one_hot_labels(CLASSES_COUNT)))

# Load best model (checkpoint uses save_best_only=True)
model = load_model(augmented_model_path)

# Calculate test set score with augmentation. The test set itself is not
# augmented, so this score is directly comparable to the baseline's.
test_score = model.evaluate(x=test_data.get_data(),
                            y=test_data.get_one_hot_labels(CLASSES_COUNT))
print("Test set score with offline augmentation: {}".format(test_score[1]))