<a href="https://colab.research.google.com/github/aim56009/Bias_GAN/blob/master/code/run_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup and imports

In [None]:
from google.colab import drive
# Mount Google Drive so the data/results paths under /content/gdrive used by Config are reachable.
drive.mount("/content/gdrive")
#%cd "/content/gdrive/MyDrive/data_gan"

In [None]:
# Clone the project repository so the `Bias_GAN.code.src...` modules can be imported below.
!git clone https://github.com/aim56009/Bias_GAN.git

In [None]:
%%capture
!pip install pytorch_lightning

In [None]:
# Exact version pins, presumably the versions the project code was developed against — confirm.
!pip install importlib-metadata==4.0.1
!pip install xarray==0.18.1

In [None]:
import getpass
import warnings
from argparse import ArgumentParser
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

warnings.filterwarnings('ignore')

#from Bias_GAN.code.src.trainer import train_cycle_gan

# Training

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import json

from Bias_GAN.code.src.model import CycleGAN
from Bias_GAN.code.src.data import DataModule
from Bias_GAN.code.src.utils import get_version, set_environment, get_checkpoint_path, save_config
from Bias_GAN.code.src.callbacks import get_cycle_gan_callbacks

In [None]:
def train_cycle_gan(config=None):
    """Main routine to train the Cycle GAN.

    Args:
        config: Training configuration (a ``Config`` instance). When omitted,
            a default ``Config()`` is used.

    Returns:
        The ``CycleGAN`` model after ``trainer.fit`` has run.
    """
    # Honour the caller's configuration. Previously the argument was
    # unconditionally replaced by a fresh Config(), so customisations were ignored.
    if config is None:
        config = Config()

    version = get_version()
    print(f'Running model: {version}')
    checkpoint_path = get_checkpoint_path(config, version)
    set_environment()
    print("checkpoint_path before:", checkpoint_path)

    # NOTE(review): the checkpoint path (not `version`) is passed as the logger
    # version, presumably so TensorBoard output lands next to the checkpoints — confirm.
    tb_logger = TensorBoardLogger(config.tensorboard_path,
                                  name=config.model_name,
                                  default_hp_metric=False,
                                  version=checkpoint_path)

    # `gpus=1` is only accepted by pytorch_lightning < 2.0.
    trainer = pl.Trainer(gpus=1,
                         max_epochs=config.epochs,
                         precision=16,
                         callbacks=get_cycle_gan_callbacks(checkpoint_path),
                         num_sanity_val_steps=1,
                         logger=tb_logger,
                         log_every_n_steps=config.log_every_n_steps,
                         deterministic=False)

    datamodule = DataModule(config,
                            training_batch_size=config.train_batch_size,
                            test_batch_size=config.test_batch_size)
    datamodule.setup("fit")

    # epoch_decay is half the total epoch count; its exact meaning is defined by CycleGAN.
    model = CycleGAN(epoch_decay=config.epochs // 2,
                     running_bias=config.running_bias)

    trainer.fit(model, datamodule)

    print('Training finished')
    return model

In [None]:
"""
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

from Bias_GAN.code.src.model import CycleGAN
from Bias_GAN.code.src.data import DataModule
from Bias_GAN.code.src.utils import get_version, set_environment, get_checkpoint_path, save_config
from Bias_GAN.code.src.callbacks import get_cycle_gan_callbacks

import json


config = Config()
version = get_version()
print(f'Running model: {version}')
#print(json.dumps(config.__dict__, indent=4))
checkpoint_path = get_checkpoint_path(config, version)
print("checkpoint_path before:",checkpoint_path)



tb_logger = TensorBoardLogger(config.tensorboard_path,
                           name=config.model_name,
                           default_hp_metric=False,
                           version = version)

trainer = pl.Trainer(gpus = 1,
                         max_epochs = config.epochs,
                         precision = 16, 
                         callbacks = get_cycle_gan_callbacks(checkpoint_path),
                         num_sanity_val_steps = 1,
                         logger = tb_logger,
                         log_every_n_steps = config.log_every_n_steps,
                         deterministic = False)
"""


# Config

In [None]:
@dataclass
class Config:
    """
    Training configuration parameters. For model evaluation parameters see
    src/configuration.py.

    Field order is significant: it defines the positional order of the
    generated ``__init__``.
    """

    # --- Paths -----------------------------------------------------------
    scratch_path: str = '/content/gdrive/MyDrive/bias_gan/results'
    # The derived paths below are computed from the *default* scratch_path when
    # the class body runs; overriding scratch_path on an instance does not
    # update them.
    tensorboard_path: str = f'{scratch_path}/'
    checkpoint_path: str = f'{scratch_path}/'
    config_path: str = f'{scratch_path}/'
    # Input NetCDF files; per their filenames, the GFDL-ESM4 historical run and
    # the W5E5 (ERA5) reference, both covering 1979-2014.
    poem_path: str = "/content/gdrive/MyDrive/bias_gan/data_gan/pr_gfdl-esm4_historical_regionbox_1979-2014.nc"
    era5_path: str = "/content/gdrive/MyDrive/bias_gan/data_gan/pr_W5E5v2.0_regionbox_era5_1979-2014.nc"

    results_path: str = f'{scratch_path}/'
    projection_path: Optional[str] = None

    # --- Data splits (years; presumably inclusive — confirm in DataModule) ---
    train_start: int = 1979
    train_end: int = 1980  # set to 2000 for full run
    # NOTE(review): the validation year 2004 also lies inside the test range
    # 2004-2014; confirm this overlap is intended.
    valid_start: int = 2004  # was 2001
    valid_end: int = 2004
    test_start: int = 2004
    test_end: int = 2014

    model_name: str = 'tibet_gan'

    # --- Training --------------------------------------------------------
    epochs: int = 2  # set to 250 for reproduction
    progress_bar_refresh_rate: int = 1  # NOTE(review): not currently passed to pl.Trainer
    train_batch_size: int = 1
    test_batch_size: int = 64
    # default_factory gives every instance its own list (no shared mutable default).
    transforms: List[str] = field(default_factory=lambda: ['log', 'normalize_minus1_to_plus1'])
    rescale: bool = False
    epsilon: float = 0.0001
    lazy: bool = False
    log_every_n_steps: int = 2  # was 10
    norm_output: bool = True
    running_bias: bool = False


def main():
    """Entry point: train the CycleGAN with a default ``Config``; returns None."""
    default_config = Config()
    train_cycle_gan(default_config)

# Run

In [None]:
# Launch training with the default configuration.
main()

In [None]:
from datetime import datetime

# Take a single timestamp so the date and time parts cannot come from two
# different now() calls (which could straddle midnight).
now = datetime.now()
time = now.strftime("%Hh_%Mm_%Ss")
date = now.strftime("%Y_%m_%d")
f'{Config.config_path}config_model_{date}_{time}.json'

In [None]:
# `version` only exists inside train_cycle_gan, so it is undefined at notebook
# scope on a fresh kernel. get_version() supplies one here; it is presumably a
# new id rather than the id of the training run — confirm get_version semantics.
version = get_version()
f'{Config.config_path}config_model_{version}.json'

In [None]:
def save_config(config, version):
    """Write ``vars(config)`` as JSON into ``config.config_path``.

    The file is named ``config_model_<id>.json``, where ``<id>`` is the last
    36 characters of ``version`` (the whole string when it is shorter).
    """
    import json

    uuid_length = 36
    run_id = version[-uuid_length:]
    target = f'{config.config_path}config_model_{run_id}.json'
    with open(target, 'w') as handle:
        handle.write(json.dumps(vars(config)))

In [None]:
# Preview of the filename save_config writes. No `config` instance exists at
# notebook scope, so read the class-level default; `version` must already be
# defined in the kernel. `version[-n:]` keeps the last n characters.
uuid_length = 36
f'{Config.config_path}config_model_{version[-uuid_length:]}.json'

In [None]:
def get_checkpoint_path(config, version):
    """Build the per-run checkpoint directory, create it, and return its path.

    The directory is ``<config.checkpoint_path>/<config.model_name>/<run id>``,
    where the run id is derived from the tail of ``version``.

    Args:
        config: Object with ``checkpoint_path`` and ``model_name`` attributes
            (e.g. a ``Config`` instance).
        version: Run identifier string; presumably ends in a 36-character
            UUID as produced by ``get_version`` — confirm.

    Returns:
        The checkpoint directory path as a string (the directory exists on return).
    """
    uuid_length = 36
    # NOTE(review): `[:-1]` drops the final character of the 36-char tail,
    # unlike the filename built in save_config; kept as-is — confirm the
    # expected version format.
    run_id = version[-uuid_length:][:-1]
    # checkpoint_path ends with '/' by default; strip it to avoid '//' in the result.
    root = config.checkpoint_path.rstrip('/')
    path = f'{root}/{config.model_name}/{run_id}'
    Path(path).mkdir(parents=True, exist_ok=True)
    return path

In [None]:
# Preview of the per-run checkpoint directory for `version` (which must already
# be defined in the kernel). `checkpoint_path` itself is local to
# train_cycle_gan, so it cannot be printed from notebook scope.
uuid_length = 36
f'{Config.checkpoint_path.rstrip("/")}/{Config.model_name}/{version[-uuid_length:][:-1]}'

In [None]:
# Display the class-level default checkpoint root.
Config.checkpoint_path