<a href="https://colab.research.google.com/github/aim56009/Bias_GAN/blob/master/data/run_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports 

In [1]:
from google.colab import drive
drive.mount("/content/gdrive")
#%cd "/content/gdrive/MyDrive/data_gan"

Mounted at /content/gdrive


In [2]:
!git clone https://github.com/aim56009/Bias_GAN.git

Cloning into 'Bias_GAN'...
remote: Enumerating objects: 313, done.[K
remote: Counting objects: 100% (214/214), done.[K
remote: Compressing objects: 100% (87/87), done.[K
remote: Total 313 (delta 136), reused 196 (delta 124), pack-reused 99[K
Receiving objects: 100% (313/313), 79.67 MiB | 15.77 MiB/s, done.
Resolving deltas: 100% (168/168), done.


In [3]:
%%capture
!pip install pytorch_lightning

In [4]:
!pip install importlib-metadata==4.0.1
!pip install xarray==0.18.1

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting importlib-metadata==4.0.1
  Downloading importlib_metadata-4.0.1-py3-none-any.whl (16 kB)
Installing collected packages: importlib-metadata
  Attempting uninstall: importlib-metadata
    Found existing installation: importlib-metadata 5.2.0
    Uninstalling importlib-metadata-5.2.0:
      Successfully uninstalled importlib-metadata-5.2.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
markdown 3.4.1 requires importlib-metadata>=4.4; python_version < "3.10", but you have importlib-metadata 4.0.1 which is incompatible.
gym 0.25.2 requires importlib-metadata>=4.8.0; python_version < "3.10", but you have importlib-metadata 4.0.1 which is incompatible.[0m[31m
[0mSuccessfully installed importlib-metadata-4.0.1
Looking in indexes: https://pypi.org/simp

In [5]:
from argparse import ArgumentParser
import warnings
warnings.filterwarnings('ignore')
from dataclasses import dataclass, field
from typing import List
import getpass

#from Bias_GAN.code.src.trainer import train_cycle_gan

# Training

In [6]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import json

from Bias_GAN.code.src.model import CycleGAN
from Bias_GAN.code.src.data import DataModule
from Bias_GAN.code.src.utils import get_version, set_environment, get_checkpoint_path, save_config
from Bias_GAN.code.src.callbacks import get_cycle_gan_callbacks

In [7]:
def train_cycle_gan(config, pretrain_path=False):
    """ Main routing to train the Cycle GAN """
    
    config = Config()
    version = get_version()
    print(f'Running model: {version}')
    checkpoint_path = get_checkpoint_path(config, version)
    
    set_environment()
    print("checkpoint_path before:",checkpoint_path) 
    tb_logger = TensorBoardLogger(config.tensorboard_path,
                           name=config.model_name,
                           default_hp_metric=False,
                           version = checkpoint_path)
                           #version = version

    trainer = pl.Trainer(gpus = 1,
                         max_epochs = config.epochs,
                         precision = 16, 
                         #progress_bar_refresh_rate = config.progress_bar_refresh_rate,
                         callbacks = get_cycle_gan_callbacks(checkpoint_path),
                         num_sanity_val_steps = 1,
                         logger = tb_logger,
                         log_every_n_steps = config.log_every_n_steps,
                         deterministic = False,
                         accelerator=accelerator) 

    datamodule = DataModule(config, training_batch_size = config.train_batch_size,
                                    test_batch_size = config.test_batch_size)

    datamodule.setup("fit")

    
    
    if pretrain_path==False:
      print("no pretraining")
      model = CycleGAN(epoch_decay = config.epochs // 2,running_bias=config.running_bias)
    else:
      print("using pretrained model with path:",pretrain_path)
      model = CycleGAN(epoch_decay = config.epochs // 2,running_bias=config.running_bias).load_from_checkpoint(pretrain_path)

    trainer.fit(model, datamodule)

    save_config(config, version)
    print('Training finished')
    return model

# Config

In [8]:
@dataclass
class Config:
    """ 
    Training configuration parameters. For model evaluation parameters see
    src/configuration.py.
    """
    
    scratch_path: str = '/content/gdrive/MyDrive/bias_gan/results'
    tensorboard_path: str = f'{scratch_path}/'
    checkpoint_path: str = f'{scratch_path}/'
    config_path: str = f'{scratch_path}/'
    poem_path: str = f"/content/gdrive/MyDrive/bias_gan/data_gan/pr_gfdl-esm4_historical_regionbox_1979-2014.nc"
    era5_path: str = f"/content/gdrive/MyDrive/bias_gan/data_gan/pr_W5E5v2.0_regionbox_era5_1979-2014.nc"
   

    results_path: str = f'{scratch_path}/'
    projection_path: str = None

    train_start: int = 1979
    train_end: int = 1980 # set to 2000 for full run
    valid_start: int = 2004 #was 2001
    valid_end: int = 2004
    test_start: int = 2004
    test_end: int = 2014
    
    model_name: str = 'tibet_gan'

    epochs: int = 2 # set to 250 for reproduction
    progress_bar_refresh_rate: int = 1
    train_batch_size: int = 1
    test_batch_size: int = 64
    transforms: List = field(default_factory=lambda: ['log', 'normalize_minus1_to_plus1'])
    rescale: bool = False
    epsilon: float = 0.0001
    lazy: bool = False
    log_every_n_steps: int = 2 ### was 10
    norm_output: bool = True
    running_bias: bool = False


def main():
    _ = train_cycle_gan(Config())

# Run

In [11]:
accelerator="gpu"

#train_cycle_gan(Config())

train_cycle_gan(Config(),"/content/gdrive/MyDrive/bias_gan/results/2023_01_05_12h_30m_56s/last.ckpt")

INFO:lightning_lite.utilities.seed:Global seed set to 42
  rank_zero_deprecation(
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit native Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True


Running model: 2023_01_05_12h_36m_05s
checkpoint_path before: /content/gdrive/MyDrive/bias_gan/results//2023_01_05_12h_36m_05s


INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


using pretrained model with path: /content/gdrive/MyDrive/bias_gan/results/2023_01_05_12h_30m_56s/last.ckpt


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type          | Params
----------------------------------------
0 | d_A   | Discriminator | 2.8 M 
1 | d_B   | Discriminator | 2.8 M 
2 | g_A2B | Generator     | 449 K 
3 | g_B2A | Generator     | 449 K 
----------------------------------------
6.4 M     Trainable params
0         Non-trainable params
6.4 M     Total params
12.849    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached.


Training finished


CycleGAN(
  (d_A): Discriminator(
    (net): Sequential(
      (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
      (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (4): LeakyReLU(negative_slope=0.2, inplace=True)
      (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (7): LeakyReLU(negative_slope=0.2, inplace=True)
      (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
      (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (10): LeakyReLU(negative_slope=0.2, inplace=True)
      (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    )
  )
  (d_B): Discriminator(
    (ne