# Generative Adversarial Network (GAN)

GANs are based on an adversarial process in which two models are trained simultaneously: a generator network G, which learns to synthesize samples from the distribution of the provided data, and a discriminator network D, which estimates the probability that a sample came from the original data rather than
from the generator. During the training stage, both models compete with each other:
G tries to synthesize samples so realistic that D can no longer differentiate between real and fake ones. The implementation presented in this notebook also includes a classifier as a third model. It penalizes the generator for generating samples from the wrong class. This addition is particularly useful because it later allows the user to indicate from which class (distribution) the samples should be synthesized.

In [None]:
import os

import torch
import numpy as np
from torch import nn
from keras.models import load_model
from keras.callbacks import EarlyStopping, ModelCheckpoint
from torch.utils.data import DataLoader

from python_research.augmentation.GAN.WGAN import WGAN
from python_research.augmentation.GAN.classifier import Classifier
from python_research.augmentation.GAN.discriminator import Discriminator
from python_research.augmentation.GAN.generator import Generator
from python_research.keras_models import build_1d_model
from python_research.dataset_structures import OrderedDataLoader, \
    HyperspectralDataset, BalancedSubset
from python_research.augmentation.GAN.samples_generator import SamplesGenerator

# Root directory with the input data, located two levels above this notebook
DATA_DIR = os.path.join('..', '..', 'hypernet-data')
# Dataset and ground-truth locations inside DATA_DIR.
# NOTE(review): the empty last component looks like a placeholder for the
# actual file names - fill it in before running the notebook.
DATASET_PATH = os.path.join(DATA_DIR, '')
GT_PATH = os.path.join(DATA_DIR, '')
# Every artifact produced by this notebook is written below the data directory
RESULTS_DIR = os.path.join(DATA_DIR, 'results', 'gan_augmentation')
os.makedirs(RESULTS_DIR, exist_ok=True)


# Prepare the data

Extract the training, validation and test sets. The training set will be balanced (each class will have an equal number of samples).

In [None]:
# Number of samples to be extracted from each class as training samples
SAMPLES_PER_CLASS = 100
# Part of the training set to be extracted as validation set; the value is
# passed to BalancedSubset in place of a per-class sample count
VAL_PART = 0.1

# Load the dataset together with its ground truth; this object is later used
# as the test set
test_data = HyperspectralDataset(DATASET_PATH, GT_PATH)

test_data.normalize_labels()

# Extract the training set from the loaded dataset first, then extract the
# validation set from the training set (the order of these calls matters).
# NOTE(review): presumably BalancedSubset removes the drawn samples from its
# source dataset so that the three sets are disjoint - confirm against the
# BalancedSubset implementation.
train_data = BalancedSubset(test_data, SAMPLES_PER_CLASS)
val_data = BalancedSubset(train_data, VAL_PART)


# Data normalization

Data is normalized using Min-Max feature scaling. The min and max values are extracted from the training and validation sets and are then used to scale all three sets.

In [None]:
# Scaling bounds are taken over the training and validation sets only.
# The argument order makes the built-ins pick the training value exactly when
# it is strictly larger (smaller) than the validation value.
max_ = max(val_data.max, train_data.max)
min_ = min(val_data.min, train_data.min)
# Scale every set with the very same bounds
for subset in (train_data, val_data, test_data):
    subset.normalize_min_max(min_=min_, max_=max_)


# Data loaders and models initialization

GAN is composed of three models: generator, discriminator and classifier. All of them have an identical topology (2 hidden layers with 512 neurons each).

In [None]:
# Number of epochs without improvement on validation set after which the
# training will be terminated for the GAN classifier
CLASSIFIER_PATIENCE = 30
# Learning rate shared by the generator, discriminator and classifier
# optimizers
LEARNING_RATE = 0.00001
# Number of classes in the dataset
CLASSES_COUNT = 16
# Number of samples per batch in both data loaders
BATCH_SIZE = 64

# Initialize pytorch data loaders: the project-specific ordered loader is
# used for GAN training, the shuffled loader (incomplete last batch dropped)
# for classifier pre-training
custom_data_loader = OrderedDataLoader(train_data, BATCH_SIZE)
data_loader = DataLoader(train_data, batch_size=BATCH_SIZE,
                         shuffle=True, drop_last=True)

# Use the GPU whenever one is available
cuda = torch.cuda.is_available()

# The size of the last data axis (number of spectral bands) is the
# input size of every model
input_shape = bands_count = train_data.shape[-1]

classifier_criterion = nn.CrossEntropyLoss()
# Initialize generator, discriminator and classifier
generator = Generator(input_shape, CLASSES_COUNT)
discriminator = Discriminator(input_shape)
classifier = Classifier(classifier_criterion, input_shape, CLASSES_COUNT,
                        use_cuda=cuda, patience=CLASSIFIER_PATIENCE)

# Move the models to the GPU *before* the optimizers are constructed, so that
# the optimizers are created for parameters that already live on the device
# they will be trained on (as recommended by the torch.optim documentation)
if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    classifier = classifier.cuda()
    classifier_criterion = classifier_criterion.cuda()

# Optimizers: one Adam instance per model, all with identical settings
optimizer_G = torch.optim.Adam(generator.parameters(),
                               lr=LEARNING_RATE,
                               betas=(0, 0.9))
optimizer_D = torch.optim.Adam(discriminator.parameters(),
                               lr=LEARNING_RATE,
                               betas=(0, 0.9))
optimizer_C = torch.optim.Adam(classifier.parameters(),
                               lr=LEARNING_RATE,
                               betas=(0, 0.9))

# Classifier pre-training

The classifier has to be trained beforehand, so it gives valuable feedback to the generator regarding the classes of samples that it generates.

In [None]:
# Number of classifier training epochs
CLASSIFIER_EPOCHS = 200

# Train classifier on the shuffled training loader, using its dedicated
# optimizer
classifier.train_(data_loader, optimizer_C, CLASSIFIER_EPOCHS)

# GAN training

In [None]:
# Number of epochs without improvement on discriminator loss after which
# the GAN training will be terminated
GAN_PATIENCE = 200
# Gradient penalty coefficient (passed to the WGAN as lambda_gp)
LAMBDA_GP = 10
# Number of GAN epochs
GAN_EPOCHS = 2000

# Path handed to the GAN trainer; the generation cell loads the generator
# weights from the same location
generator_path = os.path.join(RESULTS_DIR, "generator_model")

# Keyword options of the WGAN wrapper
wgan_options = dict(use_cuda=cuda,
                    lambda_gp=LAMBDA_GP,
                    patience=GAN_PATIENCE)

# Build the GAN out of the three models and the two adversarial optimizers
gan = WGAN(generator, discriminator, classifier,
           optimizer_G, optimizer_D, **wgan_options)
# Run the adversarial training on the ordered data loader
gan.train(custom_data_loader, GAN_EPOCHS, bands_count,
          BATCH_SIZE, CLASSES_COUNT, generator_path)


# Generating samples

When the training is complete, the generator is used to synthesize new samples. The generation process is performed by the **`SamplesGenerator`** class, using its `generate` method. It accepts the training set (in order to calculate the number of samples in each class) and the pre-trained generator model.

In [None]:

# Generate samples using trained Generator: rebuild the network and restore
# the weights written during GAN training
generator = Generator(input_shape, CLASSES_COUNT)
generator_path = os.path.join(RESULTS_DIR, "generator_model")
# Load the checkpoint onto the CPU first, so that weights saved on a GPU
# machine can also be restored where no GPU is available; the model is moved
# to the GPU right below when one is present
generator.load_state_dict(torch.load(generator_path, map_location='cpu'))
if cuda:
    generator = generator.cuda()
train_data.convert_to_numpy()

# SamplesGenerator is told which device to use through a string
device = 'gpu' if cuda else 'cpu'
samples_generator = SamplesGenerator(device=device)
generated_x, generated_y = samples_generator.generate(train_data,
                                                      generator)
# Convert generated Tensors back to numpy and append a trailing axis of
# length one, matching the layout the datasets get below
generated_x = np.expand_dims(generated_x.detach().cpu().numpy(), axis=-1)

# Add one dimension to convert row vectors to column vectors (keras
# requirement)
train_data.expand_dims(axis=-1)
test_data.expand_dims(axis=-1)
val_data.expand_dims(axis=-1)

# Add generated samples (data and labels) to original dataset
train_data.vstack(generated_x)
train_data.hstack(generated_y)


# Training and evaluation

In [None]:
# Early-stopping patience: epochs without improvement on the validation set
# after which the training will be terminated
PATIENCE = 15
# Number of kernels in the first convolutional layer
KERNELS = 200
# Size of the kernel in the first convolutional layer
KERNEL_SIZE = 5
# Maximum number of training epochs
EPOCHS = 200

# Location of the best model; written by the checkpoint callback and read
# back after the training
model_path = os.path.join(RESULTS_DIR, "GAN_augmentation") + "_model"

# Keras callbacks: stop early and keep only the best model on disk
early = EarlyStopping(patience=PATIENCE)
checkpoint = ModelCheckpoint(model_path, save_best_only=True)

# Build the 1D model; its input shape is the shape of a single sample
sample_shape = test_data.shape[1:]
model = build_1d_model(sample_shape, KERNELS, KERNEL_SIZE, CLASSES_COUNT)

# Training inputs (augmented with the generated samples) and validation pair
train_x = train_data.get_data()
train_y = train_data.get_one_hot_labels(CLASSES_COUNT)
validation_pair = (val_data.get_data(),
                   val_data.get_one_hot_labels(CLASSES_COUNT))

# Train model
history = model.fit(x=train_x,
                    y=train_y,
                    batch_size=BATCH_SIZE,
                    epochs=EPOCHS,
                    verbose=2,
                    callbacks=[early, checkpoint],
                    validation_data=validation_pair)

# Continue with the best checkpoint rather than the last training state
model = load_model(model_path)

# Calculate test set score with GAN augmentation
test_x = test_data.get_data()
test_y = test_data.get_one_hot_labels(CLASSES_COUNT)
test_score = model.evaluate(x=test_x, y=test_y)
# The second entry of the evaluation result is reported
# NOTE(review): presumably the accuracy - depends on the metrics the model
# was compiled with in build_1d_model
print("Test set score with GAN offline augmentation: {}".format(test_score[1]))
