In [None]:
import os
import torch

import numpy as np
from keras.models import load_model
from keras.callbacks import EarlyStopping, ModelCheckpoint
from torch import nn
from torch.utils.data import DataLoader

from python_research.augmentation.GAN.WGAN import WGAN
from python_research.augmentation.GAN.classifier import Classifier
from python_research.augmentation.GAN.discriminator import Discriminator
from python_research.augmentation.GAN.generator import Generator
from python_research.experiments.multiple_feature_learning.builders.\
    keras_builders import \
    build_1d_model
from python_research.experiments.utils.datasets.data_loader import \
    OrderedDataLoader
from python_research.experiments.utils.datasets.hyperspectral_dataset import \
    HyperspectralDataset
from python_research.experiments.utils.datasets.subset import BalancedSubset
from python_research.augmentation.GAN.samples_generator import SamplesGenerator

# Experiment configuration. The three paths below are empty placeholders
# and have to be set before the notebook is run.
# Dataset path handed to HyperspectralDataset
DATASET_PATH = ""
# Second path handed to HyperspectralDataset (presumably the ground-truth
# labels -- confirm against HyperspectralDataset)
GT_PATH = ""
# Directory in which the generator and Keras model checkpoints are stored;
# it is created if it does not exist
OUTPUT_PATH = ""
# Number of samples to be extracted from each class as training samples
SAMPLES_PER_CLASS = 100
# Percentage of the training set to be extracted as validation set
VAL_PART = 0.1
# Number of epochs without improvement on validation set after which the
# training of the Keras model will be terminated (EarlyStopping patience)
PATIENCE = 15
# Number of kernels in the first convolutional layer
KERNELS = 200
# Size of the kernel in the first convolutional layer
KERNEL_SIZE = 5
# Number of classes in the dataset
CLASSES_COUNT = 16
# Batch size shared by the PyTorch data loaders, the GAN training call
# and the Keras model fit
BATCH_SIZE = 64
# Number of training epochs of the Keras model
EPOCHS = 1
# Number of epochs without improvement on validation set after which the
# training will be terminated for the GAN classifier
CLASSIFIER_PATIENCE = 30
# Number of epochs without improvement on discriminator loss after
# which the GAN training will be terminated
GAN_PATIENCE = 200
# Learning rate of the generator, discriminator and classifier optimizers
LEARNING_RATE = 0.00001
# Gradient penalty value passed to WGAN as lambda_gp
LAMBDA_GP = 10
# Number of GAN epochs; the same value is also passed to the training of
# the GAN classifier
GAN_EPOCHS = 1


# Make sure the directory for the model checkpoints exists
os.makedirs(OUTPUT_PATH, exist_ok=True)

# Load the whole dataset; this object is later used as the test set
test_data = HyperspectralDataset(DATASET_PATH, GT_PATH)
test_data.normalize_labels()

# Extract the training set from the dataset and the validation set from
# the training set
train_data = BalancedSubset(test_data, SAMPLES_PER_CLASS)
val_data = BalancedSubset(train_data, VAL_PART)

# Normalize all three sets with one common range: the global minimum and
# maximum over the training and validation sets
min_ = min(train_data.min, val_data.min)
max_ = max(train_data.max, val_data.max)
for subset in (train_data, val_data, test_data):
    subset.normalize_min_max(min_=min_, max_=max_)

# PyTorch data loaders over the training set: the project's
# OrderedDataLoader (passed to the GAN training later on) and a shuffled
# torch DataLoader that drops the incomplete last batch (used for the
# classifier below)
custom_data_loader = OrderedDataLoader(train_data, BATCH_SIZE)
data_loader = DataLoader(train_data, batch_size=BATCH_SIZE,
                         shuffle=True, drop_last=True)

# Whether a CUDA device can be used
cuda = torch.cuda.is_available()

# Size of the last axis of the training data, used as the networks' input
# size
bands_count = train_data.shape[-1]
input_shape = bands_count

# Loss of the classifier, followed by the three networks
classifier_criterion = nn.CrossEntropyLoss()
generator = Generator(input_shape, CLASSES_COUNT)
discriminator = Discriminator(input_shape)
classifier = Classifier(classifier_criterion, input_shape, CLASSES_COUNT,
                        use_cuda=cuda, patience=CLASSIFIER_PATIENCE)

# One Adam optimizer per network, all sharing the same hyper-parameters
adam_settings = {"lr": LEARNING_RATE, "betas": (0, 0.9)}
optimizer_G = torch.optim.Adam(generator.parameters(), **adam_settings)
optimizer_D = torch.optim.Adam(discriminator.parameters(), **adam_settings)
optimizer_C = torch.optim.Adam(classifier.parameters(), **adam_settings)

# Move the networks and the loss to the GPU when one is available
if cuda:
    generator, discriminator, classifier, classifier_criterion = (
        module.cuda()
        for module in (generator, discriminator, classifier,
                       classifier_criterion)
    )

# Train the classifier on the shuffled loader
# NOTE(review): the epoch count is GAN_EPOCHS -- confirm that sharing it
# with the GAN training is intended
classifier.train_(data_loader, optimizer_C, GAN_EPOCHS)


In [None]:
# Assemble the WGAN from the three networks and the generator and
# discriminator optimizers
gan = WGAN(
    generator, discriminator, classifier,
    optimizer_G, optimizer_D,
    use_cuda=cuda, lambda_gp=LAMBDA_GP, patience=GAN_PATIENCE,
)
# Train the GAN; the last argument is the path from which the generator
# weights are loaded again in the next cell
generator_checkpoint = os.path.join(OUTPUT_PATH, "generator_model")
gan.train(custom_data_loader, GAN_EPOCHS, bands_count,
          BATCH_SIZE, CLASSES_COUNT, generator_checkpoint)


In [None]:

# Generate samples using trained Generator

# Rebuild the generator and restore the weights stored under
# "generator_model". map_location="cpu" keeps the checkpoint loadable on a
# CPU-only machine even when it was written from a GPU; the network is
# moved to the GPU afterwards if one is available.
generator = Generator(input_shape, CLASSES_COUNT)
generator_path = os.path.join(OUTPUT_PATH, "generator_model")
generator.load_state_dict(torch.load(generator_path, map_location="cpu"))
if cuda:
    generator = generator.cuda()
train_data.convert_to_numpy()

# Device name in the form expected by SamplesGenerator
device = 'gpu' if cuda else 'cpu'
samples_generator = SamplesGenerator(device=device)
generated_x, generated_y = samples_generator.generate(train_data, generator)
# Convert the generated tensor to numpy and append a trailing axis so that
# it matches the expanded datasets below
generated_x = np.expand_dims(generated_x.detach().cpu().numpy(), axis=-1)

# Add one dimension to convert row vectors to column vectors (keras
# requirement)
# NOTE(review): these calls and the stacking below look like in-place
# modifications of the datasets -- if so, this cell must not be re-run
# without re-running the notebook from the first cell
train_data.expand_dims(axis=-1)
test_data.expand_dims(axis=-1)
val_data.expand_dims(axis=-1)

# Add generated samples and their labels to the original training set
train_data.vstack(generated_x)
train_data.hstack(generated_y)

# Path under which the best Keras model is checkpointed and later reloaded
model_path = os.path.join(OUTPUT_PATH, "GAN_augmentation") + "_model"

# Keras callbacks: early stopping and best-model checkpointing
early = EarlyStopping(patience=PATIENCE)
checkpoint = ModelCheckpoint(model_path, save_best_only=True)

# Build the 1D model for the per-sample shape of the data
model = build_1d_model(test_data.shape[1:], KERNELS,
                       KERNEL_SIZE, CLASSES_COUNT)

# Train on the augmented training set, validating on the validation set
train_x = train_data.get_data()
train_y = train_data.get_one_hot_labels(CLASSES_COUNT)
validation_set = (val_data.get_data(),
                  val_data.get_one_hot_labels(CLASSES_COUNT))
history = model.fit(x=train_x,
                    y=train_y,
                    batch_size=BATCH_SIZE,
                    epochs=EPOCHS,
                    verbose=2,
                    callbacks=[early, checkpoint],
                    validation_data=validation_set)

# Restore the best checkpointed model
model = load_model(model_path)

# Calculate the test set score of the model trained with GAN augmentation
# (index 1 of the evaluation result is presumably the accuracy -- this
# depends on the metrics compiled in build_1d_model)
test_x = test_data.get_data()
test_y = test_data.get_one_hot_labels(CLASSES_COUNT)
test_score = model.evaluate(x=test_x, y=test_y)
print("Test set score with GAN offline augmentation: {}".format(test_score[1]))