# WGAN experiments
This notebook contains the code to run a single WGAN experiment.

## Set up
To use the custom modules defined in `src`, we first make sure that the working directory is the root folder.

In [None]:
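# First make sure we start in the root folder
import os

root_folder = 'final-project'
while not os.getcwd().endswith(root_folder):
    parent = os.path.dirname(os.getcwd())
    # Stop at the filesystem root instead of looping forever
    if parent == os.getcwd():
        raise FileNotFoundError(f"Root folder '{root_folder}' not found above the current directory")
    os.chdir(parent)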

## Configuration
Specify the dataset, the fraction of the labeled training data to use, the ratio of unlabeled data, the number of WGAN epochs per batch (`epb`), and the name under which the model is saved in `results/raw/wgan`. Then load the configuration.

In [None]:
# Change parameters here
dataset = 'mams'
fraction = 0.1
unlabeled_ratio = 1.0
epb = 3
name = f'TEST_RUN_wgan_{dataset}_fr{fraction}_ur{unlabeled_ratio}_epb{epb}'

# Load the configuration
from src.experiments import get_config

config = get_config()
data_config = config[dataset]
model_config = config['wgan']
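The cells in this notebook read a fixed set of keys from the object returned by `get_config()`. The sketch below collects those keys and checks that a configuration provides them before any training starts. It assumes `get_config()` returns nested dictionaries; the helper `missing_config_keys` is hypothetical, and only the key names are taken from the code in this notebook.

```python
# Hypothetical helper: check that a configuration contains every key this notebook reads.
# The key names come from the notebook cells; the nested-dict structure is an assumption.
TOP_LEVEL_KEYS = ['optimizer', 'loss', 'batch_size', 'epochs', 'result_path', 'wgan']
DATASET_KEYS = ['train', 'val', 'test', 'unsupervised', 'classes', 'metrics']
WGAN_KEYS = ['c_optimizer', 'g_optimizer']


def missing_config_keys(config, dataset):
    """Return a sorted list of dotted key paths that are absent from `config`."""
    missing = [key for key in TOP_LEVEL_KEYS if key not in config]
    if dataset not in config:
        missing.append(dataset)
    else:
        missing += [f'{dataset}.{key}' for key in DATASET_KEYS if key not in config[dataset]]
    if 'wgan' in config:
        missing += [f'wgan.{key}' for key in WGAN_KEYS if key not in config['wgan']]
    return sorted(missing)
```

Calling `assert not missing_config_keys(config, dataset)` right after loading the configuration would surface a missing key before the data is parsed, instead of as a `KeyError` halfway through training.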

## Load the data
Load the training, validation and test data from `datasets`.

In [None]:
from src.data import Preprocessor

# Load training, validation and testing data
preprocessor = Preprocessor()
(trainX, trainY), _ = preprocessor.parse_train(
    dataset,
    data_config['train'],
    validation_split=1 - fraction,
    unlabeled_ratio=unlabeled_ratio,
    unlabeled_data=data_config['unsupervised']
)
val_data = preprocessor.parse_test(dataset, data_config['val'])
test_data = preprocessor.parse_test(dataset, data_config['test'])
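`parse_train` is called with `validation_split=1 - fraction`, so only `fraction` of the labeled training set is kept for training and the held-out part is discarded (assigned to `_`). The sketch below illustrates such a split on plain Python lists. It is an assumption about how `Preprocessor` behaves: the function `keep_fraction`, the random shuffle and the fixed seed are hypothetical, and the real implementation may split differently (for example, stratified by class).

```python
import random


def keep_fraction(X, y, fraction, seed=0):
    """Hypothetical sketch: keep `fraction` of the labeled samples.

    Mirrors calling a splitter with validation_split = 1 - fraction.
    Returns ((kept_X, kept_y), (held_out_X, held_out_y)).
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError('fraction must be in (0, 1]')
    indices = list(range(len(X)))
    random.Random(seed).shuffle(indices)  # reproducible shuffle
    n_keep = round(len(X) * fraction)
    keep, rest = indices[:n_keep], indices[n_keep:]
    kept = ([X[i] for i in keep], [y[i] for i in keep])
    held_out = ([X[i] for i in rest], [y[i] for i in rest])
    return kept, held_out
```

With `fraction = 0.1` and 100 labeled samples, 10 samples are kept for training and 90 are held out.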

## Fit the model
Build, compile and train the model on the data loaded in the previous step.

The model is trained via the regular `model.fit`, with the following additional callbacks to monitor performance and save the model:
* `WGANCallback` updates the generator by training it adversarially against a critic in a WGAN. Without this callback, the generator is never updated!
* `EvaluateCallback` evaluates the model on an additional dataset (the test dataset).
* `ModelCheckpoint(..., save_weights_only=True, ...)` calls the model's `save_weights()`, i.e. `SavableModel.save_weights()`, to save the model weights. (`BaselineModel` is a subclass of `SavableModel`.)
* `CSVLogger` logs the history to a CSV file. This includes the results from the `EvaluateCallback`.
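`WGANCallback` delegates generator training to the `WGAN` model, which is created below with `critic_steps=5`, presumably the number of critic updates per generator update. In the standard WGAN formulation (Arjovsky et al., 2017), the critic loss is the mean critic score on generated samples minus the mean score on real samples, and the generator loss is the negated mean score on generated samples. The pure-Python sketch below shows only those losses and the update schedule. It is not the `src.models.WGAN` implementation: the function names are hypothetical, and details such as weight clipping or a gradient penalty are omitted.

```python
from statistics import mean


def critic_loss(real_scores, fake_scores):
    """WGAN critic loss: minimizing it widens the gap E[D(real)] - E[D(fake)]."""
    return mean(fake_scores) - mean(real_scores)


def generator_loss(fake_scores):
    """WGAN generator loss: minimizing it raises the critic's score on generated samples."""
    return -mean(fake_scores)


def update_schedule(generator_steps, critic_steps=5):
    """Order of updates when the critic is trained `critic_steps` times per generator step."""
    schedule = []
    for _ in range(generator_steps):
        schedule += ['critic'] * critic_steps
        schedule.append('generator')
    return schedule
```

For example, `critic_loss([1.0, 3.0], [0.0, 1.0])` is `0.5 - 2.0 = -1.5`, and `update_schedule(2)` performs 10 critic updates interleaved with 2 generator updates.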

In [None]:
from src.components import ComplexGenerator
from src.callbacks import WGANCallback, EvaluateCallback
from src.models import BaseGAN, WGAN
# noinspection PyUnresolvedReferences
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint

# Build the BaseGAN model; keep exactly one of the two lines below active
# model = BaseGAN(data_config['classes'])  # Simple (default) generator
model = BaseGAN(data_config['classes'], generator=ComplexGenerator())  # Complex generator

model.compile(optimizer=config['optimizer'], loss=config['loss'], metrics=data_config['metrics'])

wgan = WGAN(critic_steps=5, generator=model.generator)
wgan.compile(model_config['c_optimizer'], model_config['g_optimizer'])

model.fit(trainX, trainY, batch_size=config['batch_size'], epochs=config['epochs'], validation_data=val_data, callbacks=[
    WGANCallback(wgan, trainX, batch_size=config['batch_size'], epochs_per_batch=epb),
    EvaluateCallback(test_data),
    ModelCheckpoint(
        os.path.join(config['result_path'], 'checkpoints', name),
        save_best_only=True,
        save_weights_only=True,
        monitor='val_macro_f1',
        mode='max'),
    CSVLogger(os.path.join(config['result_path'], f'{name}.csv'))
])
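The `ModelCheckpoint` above uses `save_best_only=True`, `monitor='val_macro_f1'` and `mode='max'`: at the end of each epoch, weights are written only if `val_macro_f1` is strictly greater than the best value seen so far. The sketch below reproduces that decision rule in plain Python to show which epochs trigger a save. The function `checkpoint_epochs` is hypothetical and ignores other `ModelCheckpoint` options such as `save_freq`; `val_macro_f1` is presumably produced by one of the metrics in `data_config['metrics']`.

```python
import math


def checkpoint_epochs(history, mode='max'):
    """Return the 0-based epochs at which a save_best_only checkpoint would write weights.

    `history` holds the monitored metric (e.g. val_macro_f1) for each epoch.
    A save happens only when the value strictly improves on the best seen so far.
    """
    if mode not in ('max', 'min'):
        raise ValueError("mode must be 'max' or 'min'")
    best = -math.inf if mode == 'max' else math.inf
    saved = []
    for epoch, value in enumerate(history):
        improved = value > best if mode == 'max' else value < best
        if improved:
            best = value
            saved.append(epoch)
    return saved
```

With a validation history of `[0.41, 0.45, 0.44, 0.45, 0.52]`, weights are saved after epochs 0, 1 and 4; the tie at epoch 3 does not count as an improvement, so the checkpoint on disk always holds the best epoch.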

## Quick methods
Alternatively, the following methods are available to quickly perform an experiment.

Please note: the second method generates the name automatically (`wgan_{dataset}_fr{fraction}_ur{unlabeled_ratio}_epb{epb}`).

In [None]:
from src.experiments import wgan_experiment, wgan_experiments

wgan_experiment(dataset, fraction, unlabeled_ratio, epb, name)

# To perform multiple experiments:
wgan_experiments([dataset], [fraction], [unlabeled_ratio], [epb], range(3))

# You can also use generator=ComplexGenerator() in the above methods
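`wgan_experiments` takes one list per parameter and names each run `wgan_{dataset}_fr{fraction}_ur{unlabeled_ratio}_epb{epb}`. The sketch below assumes, without confirmation from the source, that it runs every combination of those lists once per element of the last argument (here `range(3)`, presumably three repetitions). The helpers `wgan_run_name` and `expand_grid` are hypothetical, and how the real function keeps repeated runs with the same parameters apart is not shown.

```python
from itertools import product


def wgan_run_name(dataset, fraction, unlabeled_ratio, epb):
    """Build a run name following wgan_{dataset}_fr{fraction}_ur{unlabeled_ratio}_epb{epb}."""
    return f'wgan_{dataset}_fr{fraction}_ur{unlabeled_ratio}_epb{epb}'


def expand_grid(datasets, fractions, unlabeled_ratios, epbs, repetitions):
    """Hypothetical sketch: list (repetition, run name) for every parameter combination."""
    runs = []
    for dataset, fraction, ratio, epb in product(datasets, fractions, unlabeled_ratios, epbs):
        for rep in repetitions:
            runs.append((rep, wgan_run_name(dataset, fraction, ratio, epb)))
    return runs
```

Under these assumptions, `expand_grid(['mams'], [0.1, 0.5], [1.0], [3], range(3))` yields six runs: three repetitions each of `wgan_mams_fr0.1_ur1.0_epb3` and `wgan_mams_fr0.5_ur1.0_epb3`.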