---
---
## Training and Evaluating a Generative Adversarial Network on Generated MCMC Samples
---
---
This notebook contains examples of the following:
1. Train a Wasserstein GAN on MCMC sampling results
2. Use the trained model to generate samples from the learned distribution
3. Evaluate the trained models on several distribution metrics

#### Environment Setup

Refer to requirements.txt for the specific package versions.
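
To compare the local environment against the pinned versions, the installed versions can be read with the standard library's `importlib.metadata`. This is a minimal sketch; the package list is an assumption based on the imports in the next cell (`process_data` and `generative_model` appear to be local modules, not pip packages, so they are left out).

```python
# Minimal sketch: report installed versions of the pip packages this notebook imports,
# so they can be checked against the versions pinned in the requirements file.
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {distribution_name: version string, or None if it is not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


# These are pip distribution names (what `pip install` uses), not import names.
print(installed_versions(["torch", "numpy", "wandb"]))
```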

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pickle
import wandb

# Local modules: data-processing utilities and the WGAN model class
from process_data import Africa_Whole_Flat, MinMaxScaler
from generative_model import WGAN_SIMPLE

---
---
## Part 1. Model Training
---
---

#### Model inputs

**Note: This step is computationally intensive. It is recommended to train on a GPU rather than on Colab; skip to Part 2 if you only wish to generate samples from a pretrained model.**

Let's define some inputs for the model training:
- dataroot - the path to the CSV file containing the training samples
- savepath - the directory where trained model checkpoints are saved
- workers - the number of worker processes for loading the data with the DataLoader
- batch_size - the batch size used in training
- num_epochs - the number of training epochs to run. Training for longer will probably lead to better results, but will also take more time
- lr - the learning rate for training
- beta1 - the beta1 hyperparameter (decay rate of the first-moment estimate) for the Adam optimizers
- device - the device to train on: the first CUDA GPU if one is available, otherwise the CPU
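
`WGAN_SIMPLE.optimize` receives `lr` and `beta1`, but its optimizer setup is not shown here. Assuming they configure a standard Adam optimizer, the NumPy sketch below performs one Adam update to show what `beta1` controls. The values `beta2 = 0.999` and `eps = 1e-8` are assumptions (they are the defaults of PyTorch's `torch.optim.Adam`), not settings taken from this notebook.

```python
import numpy as np


def adam_step(param, grad, m, v, t, lr=0.0002, beta1=0.5, beta2=0.999, eps=1e-8):
    """One Adam update; t is the 1-based step count. Returns (param, m, v)."""
    # beta1: decay rate of the moving average of the gradient (first moment).
    # A lower value such as 0.5 tracks recent gradients more closely than 0.9 does.
    m = beta1 * m + (1 - beta1) * grad
    # beta2: decay rate of the moving average of the squared gradient (second moment).
    v = beta2 * v + (1 - beta2) * grad**2
    # Bias correction compensates for m and v being initialised at zero.
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v


p0 = np.array([1.0, -2.0])
g = np.array([0.3, -4.0])
p1, m1, v1 = adam_step(p0, g, np.zeros_like(p0), np.zeros_like(p0), t=1)
# After bias correction, the first step moves each entry by about lr * sign(grad).
print(p1)
```

Because Adam normalises by the running gradient magnitude, the step size is governed mainly by `lr`, while `beta1` sets how much momentum carries over between steps.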


In [5]:
dataroot = './data/Rayleigh_P30_downsampled_flat_extended.csv'
savepath = "./model"
workers = 1
batch_size = 128
num_epochs = 200
lr = 0.0002
beta1 = 0.5
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
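
As a rough guide to training cost, the number of optimizer steps grows with the dataset size: each epoch processes ceil(N / batch_size) minibatches, or floor(N / batch_size) if the last partial batch is dropped. The helper below is a sketch; whether `WGAN_SIMPLE.optimize` drops the last batch, and how many critic updates it performs per generator update, are not shown in this notebook, and the dataset size in the example is hypothetical.

```python
import math


def training_steps(n_samples, batch_size=128, num_epochs=200, drop_last=False):
    """Return (minibatches per epoch, total minibatches) for a dataset of n_samples rows."""
    if drop_last:
        per_epoch = n_samples // batch_size  # the incomplete final batch is skipped
    else:
        per_epoch = math.ceil(n_samples / batch_size)  # the final batch may be smaller
    return per_epoch, per_epoch * num_epochs


# Hypothetical dataset size; the real value is data.shape[0] once the CSV is loaded.
print(training_steps(10_000))  # (79, 15800)
```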

#### Load the data

In [7]:
data = np.genfromtxt(dataroot, delimiter=',', skip_header=1)  # one row per sample; the header row is skipped
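
`np.genfromtxt` reads the CSV into a 2-D float array with one row per sample, and the number of columns is what is later passed to `WGAN_SIMPLE` as `ndim`. The sketch below uses a small in-memory CSV with hypothetical column names `p1`..`p3`. It also shows per-column min-max scaling; the notebook imports `MinMaxScaler` from `process_data` but does not show how (or whether) it is applied, so the scaling here is an assumed, generic version and not that class's API.

```python
import io

import numpy as np

# A tiny in-memory stand-in for the CSV: one header row, then one MCMC sample per row.
csv_text = """p1,p2,p3
0.10,2.0,30.0
0.20,4.0,10.0
0.30,6.0,20.0
"""

# skip_header=1 drops the column-name row; each remaining row becomes one sample.
samples = np.genfromtxt(io.StringIO(csv_text), delimiter=",", skip_header=1)
ndim = samples.shape[1]  # number of parameters per sample

# Generic per-column min-max scaling to [0, 1] (assumed; the project's scaler may differ).
col_min = samples.min(axis=0)
col_max = samples.max(axis=0)
scaled = (samples - col_min) / (col_max - col_min)

# Generated samples would be mapped back to the original units with the inverse transform.
restored = scaled * (col_max - col_min) + col_min
print(ndim, scaled.min(axis=0), scaled.max(axis=0))
```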

#### Initialize Model and Start Training

In [None]:
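# Build a model sized to the number of columns (parameters) in the data, then train it;
# checkpoints are written to `savepath`.
model = WGAN_SIMPLE(ndim=data.shape[1], device=device)

model.optimize(data, output_path=savepath, use_wandb=False, batch_size=batch_size,
               epochs=num_epochs, lr=lr, beta1=beta1, device=device)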
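
`WGAN_SIMPLE.optimize` hides the training loop. Assuming it follows the standard WGAN formulation, the per-minibatch losses look like the NumPy sketch below, which uses a toy linear critic and random stand-in batches in place of the real networks and data. Whether `WGAN_SIMPLE` enforces the critic's Lipschitz constraint by weight clipping (as in the original WGAN) or by a gradient penalty (as in WGAN-GP) is not shown here; the clipping value 0.01 is only illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)


def critic(x, w, b):
    """Toy linear critic D(x) = x @ w + b (a real critic is a neural network)."""
    return x @ w + b


ndim, batch_size = 3, 128
real = rng.normal(loc=1.0, size=(batch_size, ndim))  # stand-in for a minibatch of data
fake = rng.normal(loc=0.0, size=(batch_size, ndim))  # stand-in for generator output G(z)
w, b = rng.normal(size=ndim), 0.0

# Critic loss: minimising E[D(fake)] - E[D(real)] is the same as maximising
# E[D(real)] - E[D(fake)], which for a well-trained Lipschitz critic approximates
# the Wasserstein distance between the two distributions (up to a constant factor).
critic_loss = critic(fake, w, b).mean() - critic(real, w, b).mean()

# Generator loss: the generator tries to raise the critic's score on its samples.
generator_loss = -critic(fake, w, b).mean()

# Weight clipping (original WGAN): after each critic update, clamp weights to [-c, c].
clip_value = 0.01
w_clipped = np.clip(w, -clip_value, clip_value)

print(critic_loss, generator_loss, w_clipped)
```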

---
---
## Part 2. Load Trained Model and Produce Samples
---
---

In [2]:
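# Load a saved checkpoint; map_location lets a GPU-trained checkpoint load on a CPU-only machine
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load("output/R10P/model/model_epoch0_EMD0.182109.pth", map_location=device)
model = WGAN_SIMPLE(ndim=checkpoint["ndim"])
model.load(checkpoint)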
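
The checkpoint loaded here is named `model_epoch0_EMD0.182109.pth`, which suggests that training saves checkpoints tagged with an epoch number and an EMD (earth mover's distance) score. Assuming that naming convention (inferred from this single file name) and that a lower EMD is better, the hypothetical helpers below pick the checkpoint with the lowest EMD from a directory listing.

```python
import glob
import re
from pathlib import Path

# Hypothetical naming pattern inferred from the single checkpoint name in this
# notebook ("model_epoch0_EMD0.182109.pth"); the real scheme may differ.
CKPT_PATTERN = re.compile(r"model_epoch(?P<epoch>\d+)_EMD(?P<emd>\d+(?:\.\d+)?)\.pth$")


def parse_checkpoint_name(path):
    """Return (epoch, emd) from a checkpoint file name, or None if it doesn't match."""
    match = CKPT_PATTERN.search(Path(path).name)
    if match is None:
        return None
    return int(match["epoch"]), float(match["emd"])


def best_checkpoint(paths):
    """Return the path whose file name carries the lowest EMD value, or None."""
    candidates = []
    for path in paths:
        parsed = parse_checkpoint_name(path)
        if parsed is not None:
            candidates.append((parsed[1], path))
    return min(candidates)[1] if candidates else None


# Prints None if the directory is missing or no file names match the pattern.
print(best_checkpoint(glob.glob("output/R10P/model/*.pth")))
```

The returned path could then be passed to `torch.load` in place of the hard-coded file name. EMD values written in scientific notation (e.g. `1e-05`) would not match this pattern.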

In [3]:
fake_data = model.generate()  # draw samples from the trained generator
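
One simple way to compare the generated samples with the training data is the 1-D earth mover's (Wasserstein-1) distance of each parameter's marginal. The sketch below assumes `fake_data` is a NumPy array with the same columns and units as `data` (a torch tensor would first need `.cpu().numpy()`, and scaled outputs would need the inverse transform); neither is shown in this notebook. It is one possible metric, not necessarily the one behind the EMD value in the checkpoint name. For unequal sample sizes, `scipy.stats.wasserstein_distance` computes the same 1-D quantity without subsampling.

```python
import numpy as np


def emd_1d(a, b):
    """Wasserstein-1 (earth mover's) distance between two equal-size 1-D samples.

    In 1-D with equal sample sizes, the optimal transport plan pairs the sorted
    values, so the distance is the mean absolute difference of the sorted samples.
    """
    a = np.sort(np.asarray(a, dtype=float))
    b = np.sort(np.asarray(b, dtype=float))
    if a.shape != b.shape:
        raise ValueError("this shortcut requires samples of equal size")
    return float(np.mean(np.abs(a - b)))


def marginal_emds(real, fake, n=None, seed=0):
    """Per-column 1-D EMD between real and generated samples (rows = samples)."""
    real = np.asarray(real, dtype=float)
    fake = np.asarray(fake, dtype=float)
    rng = np.random.default_rng(seed)
    if n is None:
        n = min(len(real), len(fake))
    # Subsample both sets to the same size so the sorted-pairing shortcut applies.
    real_sub = real[rng.choice(len(real), size=n, replace=False)]
    fake_sub = fake[rng.choice(len(fake), size=n, replace=False)]
    return np.array([emd_1d(real_sub[:, j], fake_sub[:, j]) for j in range(real.shape[1])])


# Stand-ins for `data` and `fake_data`: the second column is shifted by 0.5,
# and shifting a sample by a constant c moves it an EMD of exactly |c|.
rng = np.random.default_rng(1)
real = rng.normal(size=(2000, 2))
fake = real + np.array([0.0, 0.5])
print(marginal_emds(real, fake))  # approximately [0.0, 0.5]
```

With the notebook's variables this would be called as `marginal_emds(data, fake_data)`.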

---
---
## Visualization of the Generated Distribution
---
---