In [None]:

import os
from keras.models import load_model
from keras.callbacks import ModelCheckpoint, EarlyStopping
from python_research.keras_models import build_1d_model, build_3d_model, build_settings_for_dataset
from python_research.dataset_structures import HyperspectralDataset
from python_research.dataset_structures import BalancedSubset, ImbalancedSubset
DATA_DIR = os.path.join('..', '..', 'hypernet-data')
RESULTS_DIR = os.path.join('..', '..', 'hypernet-data', 'results', 'offline_augmentation')
DATASET_PATH = os.path.join(DATA_DIR, '')
DATASET_GT_PATH = os.path.join(DATA_DIR, '')
os.makedirs(RESULTS_DIR, exist_ok=True)


# Prepare the data
Load the data either as 1D or 3D samples, what is indicated by the **`PIXEL_NEIGHBOURHOOD`** variable. If it is equal to 1, 1D samples will be used, using only the spectral
information of a pixel. If **`PIXEL_NEIGHBOURHOOD`** variable is different than 1, pixel's neighbourhood of size equal to its value will be extracted for each sample. Neighbourhood of the pixel should be provided as an odd number. Please note that the 3D segmentation of hyperspectral data causes data leakage between train and test sets. We have described this issue in more details in our paper, which can be found [here](https://arxiv.org/abs/1811.03707).
 
The **`BALANCED`** variable indicates whether the training set should be balanced. 
If balanced, **`TRAIN_SAMPLES_PER_CLASS_COUNT`** variable indicates how many samples each class should contain. In the imbalanced case,
**`TOTAL_NUMBER_OF_SAMPLES`** variable indicates total number of samples in the training set, where samples are chosen randomly, regardless
of their class.

In both cases, validation set is set to have 10% of the data extracted from the training set.

In [None]:
BALANCED = True
PIXEL_NEIGHBOURHOOD = 7
TRAIN_SAMPLES_PER_CLASS_COUNT = 250
TOTAL_NUMBER_OF_SAMPLES = 2700

test_data = HyperspectralDataset(DATASET_PATH, DATASET_GT_PATH,
                                 neighbourhood_size=PIXEL_NEIGHBOURHOOD)
test_data.normalize_labels()
bands_count = test_data.shape[-1]
if PIXEL_NEIGHBOURHOOD == 1:
    test_data.expand_dims(axis=-1)
if BALANCED:
    train_data = BalancedSubset(test_data, TRAIN_SAMPLES_PER_CLASS_COUNT)
    val_data = BalancedSubset(train_data, 0.1)
else:
    train_data = ImbalancedSubset(test_data, TOTAL_NUMBER_OF_SAMPLES)
    val_data = ImbalancedSubset(train_data, 0.1)

# Normalize the data

Data is normalized using Min-Max feature scaling. Min and max values are extracted from train and test sets.

In [4]:
# Normalize data
max_ = train_data.max if train_data.max > val_data.max else val_data.max
min_ = train_data.min if train_data.min < val_data.min else val_data.min
train_data.normalize_min_max(min_=min_, max_=max_)
val_data.normalize_min_max(min_=min_, max_=max_)
test_data.normalize_min_max(min_=min_, max_=max_)

# Build the model

Build the keras model, depending on the dimensionality of samples prepared earlier.


In [5]:
CLASSES_COUNT = 9
NUMBER_OF_FILTERS = 200
KERNEL_SIZE = 5
PATIENCE = 15

# Build model, 1D or 3D
if PIXEL_NEIGHBOURHOOD == 1:
    model = build_1d_model((test_data.shape[1:]), NUMBER_OF_FILTERS, KERNEL_SIZE, CLASSES_COUNT)
else:
    settings = build_settings_for_dataset((PIXEL_NEIGHBOURHOOD,
                                           PIXEL_NEIGHBOURHOOD))
    model = build_3d_model(settings, CLASSES_COUNT, bands_count)

# Callbacks
early = EarlyStopping(patience=PATIENCE)
checkpoint = ModelCheckpoint(os.path.join(RESULTS_DIR, "monte_carlo_model"), save_best_only=True)


# Model and data summary

In [6]:
print(model.summary())
print("Training samples: {}".format(train_data.shape))
print("Validation samples: {}".format(val_data.shape))
print("Test samples: {}".format(test_data.shape))

# Training and evaluation

In [7]:
EPOCHS = 200
BATCH_SIZE = 64

model.fit(x=train_data.get_data(), 
          y=train_data.get_one_hot_labels(CLASSES_COUNT), 
          batch_size=BATCH_SIZE, 
          epochs=EPOCHS, 
          verbose=False, 
          callbacks=[early, checkpoint], 
          validation_data=[val_data.get_data(), val_data.get_one_hot_labels(CLASSES_COUNT)])

# Load best model
best_model = load_model(os.path.join(RESULTS_DIR, "monte_carlo_model"))

# Evaluate test set score
loss, accuracy = best_model.evaluate(x=test_data.get_data(), y=test_data.get_one_hot_labels(CLASSES_COUNT))
print("Test set accuracy: {}".format(accuracy))

   32/40526 [..............................] - ETA: 56s

  384/40526 [..............................] - ETA: 10s

  800/40526 [..............................] - ETA: 7s 

 1248/40526 [..............................] - ETA: 6s

 1696/40526 [>.............................] - ETA: 5s

 2144/40526 [>.............................] - ETA: 5s

 2592/40526 [>.............................] - ETA: 5s

 3040/40526 [=>............................] - ETA: 4s

 3488/40526 [=>............................] - ETA: 4s

 3936/40526 [=>............................] - ETA: 4s

 4384/40526 [==>...........................] - ETA: 4s

 4832/40526 [==>...........................] - ETA: 4s

 5280/40526 [==>...........................] - ETA: 4s

 5728/40526 [===>..........................] - ETA: 4s

 6176/40526 [===>..........................] - ETA: 4s

 6624/40526 [===>..........................] - ETA: 4s

 7040/40526 [====>.........................] - ETA: 4s

 7488/40526 [====>.........................] - ETA: 4s

 7936/40526 [====>.........................] - ETA: 3s



 8384/40526 [=====>........................] - ETA: 3s

 8832/40526 [=====>........................] - ETA: 3s

 9280/40526 [=====>........................] - ETA: 3s





















































































































































Test set accuracy: 0.8917238316113307
