In [5]:
import os
from keras.models import load_model
from keras.callbacks import ModelCheckpoint, EarlyStopping
from python_research.experiments.multiple_feature_learning.builders.keras_builders import build_1d_model, build_3d_model, build_settings_for_dataset
from python_research.experiments.utils.datasets.hyperspectral_dataset import HyperspectralDataset
from python_research.experiments.utils.datasets.subset import BalancedSubset, UnbalancedSubset

DATASET_PATH = "C:\\Users\mmyller.KPLABS\Documents\datasets\pavia\PaviaU_corrected.npy"
DATASET_GT_PATH = "C:\\Users\mmyller.KPLABS\Documents\datasets\pavia\PaviaU_gt.npy"
OUTPUT_PATH = "monte_carlo\\artifact"
BALANCED = True
TRAIN_SAMPLES_PER_CLASS_COUNT = 250
TOTAL_NUMBER_OF_SAMPLES = 2700
PIXEL_NEIGHBOURHOOD = 7
CLASSES_COUNT = 9
PATIENCE = 15
EPOCHS = 200
BATCH_SIZE = 64
os.makedirs("monte_carlo_3D", exist_ok=True)

# Load data and split into train, val and test sets
test_data = HyperspectralDataset(DATASET_PATH, DATASET_GT_PATH,
                                 neighbourhood_size=PIXEL_NEIGHBOURHOOD)
test_data.normalize_labels()
bands_count = test_data.shape[-1]
if PIXEL_NEIGHBOURHOOD == 1:
    test_data.expand_dims(axis=-1)
if BALANCED:
    train_data = BalancedSubset(test_data, TRAIN_SAMPLES_PER_CLASS_COUNT)
    val_data = BalancedSubset(train_data, 0.1)
else:
    train_data = UnbalancedSubset(test_data, TOTAL_NUMBER_OF_SAMPLES)
    val_data = UnbalancedSubset(train_data, 0.1) 

# Build model, 1D or 3D
if PIXEL_NEIGHBOURHOOD == 1:
    model = build_1d_model((test_data.shape[1:]), 200, 5, CLASSES_COUNT)
else:
    settings = build_settings_for_dataset((PIXEL_NEIGHBOURHOOD,
                                           PIXEL_NEIGHBOURHOOD))
    model = build_3d_model(settings, CLASSES_COUNT, bands_count)

print(model.summary())
print("Training samples: {}".format(train_data.shape))
print("Validation samples: {}".format(val_data.shape))
print("Test samples: {}".format(test_data.shape))

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_7 (Conv2D)            (None, 4, 4, 200)         329800    
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 2, 2, 200)         0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 2, 2, 200)         160200    
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 1, 1, 9)           7209      
_________________________________________________________________
flatten_3 (Flatten)          (None, 9)                 0         
_________________________________________________________________
softmax_3 (Softmax)          (None, 9)                 0         
Total params: 497,209
Trainable params: 497,209
Non-trainable params: 0
_________________________________________________________________
None

In [4]:
# Normalize data
max_ = train_data.max if train_data.max > val_data.max else val_data.max
min_ = train_data.min if train_data.min < val_data.min else val_data.min
train_data.normalize_min_max(min_=min_, max_=max_)
val_data.normalize_min_max(min_=min_, max_=max_)
test_data.normalize_min_max(min_=min_, max_=max_)

# Callbacks
early = EarlyStopping(patience=PATIENCE)
checkpoint = ModelCheckpoint(OUTPUT_PATH + "_model", save_best_only=True)

# Model training
model.fit(x=train_data.get_data, y=train_data.get_one_hot_labels(CLASSES_COUNT), batch_size=BATCH_SIZE, epochs=EPOCHS, verbose=False, 
          callbacks=[early, checkpoint], validation_data=[val_data.get_data, val_data.get_one_hot_labels(CLASSES_COUNT)])

# Load best model
best_model = load_model(OUTPUT_PATH + "_model")

# Evaluate test set score
accuracy = best_model.evaluate(x=test_data.get_data, y=test_data.get_one_hot_labels(CLASSES_COUNT))[1]
print("Test set accuracy: {}".format(accuracy))

   32/40526 [..............................] - ETA: 59s

  320/40526 [..............................] - ETA: 12s

  544/40526 [..............................] - ETA: 11s

  864/40526 [..............................] - ETA: 9s 

 1152/40526 [..............................] - ETA: 9s

 1504/40526 [>.............................] - ETA: 8s

 1824/40526 [>.............................] - ETA: 8s

 2144/40526 [>.............................] - ETA: 7s

 2496/40526 [>.............................] - ETA: 7s

 2816/40526 [=>............................] - ETA: 7s

 3200/40526 [=>............................] - ETA: 7s

 3520/40526 [=>............................] - ETA: 7s

 3904/40526 [=>............................] - ETA: 7s

 4224/40526 [==>...........................] - ETA: 6s

 4512/40526 [==>...........................] - ETA: 6s

 4896/40526 [==>...........................] - ETA: 6s

 5216/40526 [==>...........................] - ETA: 6s

 5600/40526 [===>..........................] - ETA: 6s

 5920/40526 [===>..........................] - ETA: 6s

 6272/40526 [===>..........................] - ETA: 6s

 6592/40526 [===>..........................] - ETA: 6s

 6912/40526 [====>.........................] - ETA: 6s

 7296/40526 [====>.........................] - ETA: 6s

 7616/40526 [====>.........................] - ETA: 5s

 8000/40526 [====>.........................] - ETA: 5s

 8416/40526 [=====>........................] - ETA: 5s

 8704/40526 [=====>........................] - ETA: 5s

 9024/40526 [=====>........................] - ETA: 5s

 9408/40526 [=====>........................] - ETA: 5s

























































































































































































Test set accuracy: 0.926466959476918
