# 1. GAIN training using Raytune
* [Raytune](https://ray.io/)
* [Raytune documentation](https://docs.ray.io/en/latest/tune/index.html)
* [Raytune PyTorch tutorial](https://pytorch.org/tutorials/beginner/hyperparameter_tuning_tutorial.html)
## 1.1. Import all needed packages

In [2]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

from fireman_imputation.gain_training import gain_train
from fireman_imputation.src import gain_net
from fireman_imputation.src import utils

In [None]:
def _sample_power_of_two(_):
    """Return 2**k for a random integer k drawn from [2, 9), i.e. 4 ... 256."""
    return 2**np.random.randint(2, 9)


# Hyper-parameter search space for Ray Tune:
#   l1, l2     - independently sampled powers of two (4 ... 256)
#   lr         - log-uniform between 1e-4 and 1e-1
#   batch_size - one of the listed sizes
config = dict(
    l1=tune.sample_from(_sample_power_of_two),
    l2=tune.sample_from(_sample_power_of_two),
    lr=tune.loguniform(1e-4, 1e-1),
    batch_size=tune.choice([2, 4, 8, 16]),
)

In [4]:
# Sketch of the model/optimizer set-up and checkpointing for a Ray Tune trial.
# NOTE(review): input_dim, learning_rate, checkpoint_dir, epoch, val_loss,
# val_steps, correct and total are not defined in this cell - presumably this
# code is meant to live inside the training function; confirm before running.
gen = gain_net.Generator(input_dim, input_dim)
disc = gain_net.Discriminator(input_dim, input_dim)

# move the networks to the target device before the optimizers are built and
# restored, so that restored optimizer state ends up next to the parameters
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
gen.to(device)
disc.to(device)
# not tested https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html # noqa
# if torch.cuda.device_count() > 1:
#     gen = DistributedDataParallel(gen)
#     disc = DistributedDataParallel(disc)

# the optimizers are needed in both branches below: load_state_dict() can only
# be called on an optimizer that already exists
gen_opt = torch.optim.Adam(gen.parameters(), lr=learning_rate)
disc_opt = torch.optim.Adam(disc.parameters(), lr=learning_rate)

# if there is a checkpoint load the state, otherwise initialize the weights
if checkpoint_dir:
    # the tuple order must match the order used by torch.save() below
    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
    gen_state, gen_opt_state, disc_state, disc_opt_state = torch.load(checkpoint_path)
    gen.load_state_dict(gen_state)
    disc.load_state_dict(disc_state)
    gen_opt.load_state_dict(gen_opt_state)
    disc_opt.load_state_dict(disc_opt_state)
else:
    gen.apply(utils.init_weights)
    disc.apply(utils.init_weights)

# store model and optimizer states so that the trial can be resumed
with tune.checkpoint_dir(epoch) as checkpoint_dir:
    path = os.path.join(checkpoint_dir, "checkpoint")
    torch.save((gen.state_dict(), gen_opt.state_dict(),
                disc.state_dict(), disc_opt.state_dict()), path)

# report the trial metrics back to Ray Tune
tune.report(loss=(val_loss / val_steps), accuracy=correct / total)


In [None]:
# End-to-end GAIN experiment: load the data, knock out values, train the
# networks and measure the imputation error on a held-out split.

# load the data
data_orig = pd.read_csv('data/spam.csv', index_col=False)
data = data_orig.values

# create missing data
# NOTE(review): 0.5 is presumably the fraction of values removed - confirm in utils.mcar_gen
data_missing, mask = utils.mcar_gen(data, 0.5)
# scale the data to the [0, 1] range
# NOTE(review): the scaler is fitted before the train/test split, so the test
# rows also influence the scaling - confirm this is intended
scaler = MinMaxScaler(feature_range=(0, 1))
scaler.fit(data_missing)
data_missing = scaler.transform(data_missing)

# divide the data to train/test
# by default shuffles data, if pandas is passed the index shows shuffle result
data_missing_train, data_missing_test, data_train, data_test = train_test_split(
    data_missing, data, train_size=0.9)

# set hyper-parameters
gain_params = {'batch_size': 100,
               'hint_rate': 0.9,
               'alpha': 100,
               'epochs': 10,
               'learning_rate': 0.001}

# train the net
gen, disc = gain_train(gain_params, data_missing_train, cont=False)

# transform test data to tensor and forward it through generator
data_missing_test_torch, mask_test_torch = utils.gain_data_prep(data_missing_test)
data_imputed_test = gen(data_missing_test_torch, mask_test_torch)
data_imputed_test = data_imputed_test.detach().numpy()

# merge the imputed data(zero out rest in imputed data) and data with missing values
inv_mask_test = 1 - mask_test_torch.numpy()
# data_missing_test contains nan values
data_missing_test_0 = data_missing_test.copy()
data_missing_test_0[np.isnan(data_missing_test_0)] = 0
data_imputed_test = inv_mask_test*data_imputed_test + data_missing_test_0

# rescale the imputed data
data_imputed_test = scaler.inverse_transform(data_imputed_test)
data_missing_test = scaler.inverse_transform(data_missing_test)

# compute error
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html
# The `squared=False` argument was deprecated in scikit-learn 1.4 and later
# removed, so the RMSE is computed explicitly: square root of the per-column
# MSE, then the uniform average over columns (the same value that
# `squared=False` used to return).
mse_per_column = mean_squared_error(data_test, data_imputed_test, multioutput='raw_values')
RMSE = np.sqrt(mse_per_column).mean()
print(f'RMSE of test dataset is {RMSE}')