<a href="https://colab.research.google.com/github/aim56009/Bias_GAN/blob/master/code/temperature_run_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports 

In [1]:
"""
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)
"""

"\ngpu_info = !nvidia-smi\ngpu_info = '\n'.join(gpu_info)\nif gpu_info.find('failed') >= 0:\n  print('Not connected to a GPU')\nelse:\n  print(gpu_info)\n"

In [2]:
from google.colab import drive
drive.mount("/content/gdrive")

Mounted at /content/gdrive


In [3]:
!git clone https://github.com/aim56009/Bias_GAN.git

Cloning into 'Bias_GAN'...
remote: Enumerating objects: 842, done.[K
remote: Counting objects: 100% (111/111), done.[K
remote: Compressing objects: 100% (111/111), done.[K
remote: Total 842 (delta 54), reused 0 (delta 0), pack-reused 731[K
Receiving objects: 100% (842/842), 130.21 MiB | 16.76 MiB/s, done.
Resolving deltas: 100% (531/531), done.


In [4]:
%%capture
!pip install pytorch_lightning
from pytorch_lightning.loggers import TensorBoardLogger
!pip install basemap
!pip install importlib-metadata==4.0.1
!pip install xarray==0.18.1
!pip install torchvision

In [5]:
import os
import xarray as xr
import torch
import json
import glob
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import argparse
import pathlib
import cv2
import matplotlib


from tensorboard.backend.event_processing import event_accumulator
from pytorch_lightning.callbacks import Callback
from datetime import datetime
from io import BytesIO
from PIL import Image
from dataclasses import dataclass, field
from typing import List


#from Bias_GAN.code.src.model import CycleGAN, Generator, DataModule                     
from Bias_GAN.code.src.model import CycleGAN, Generator#, DataModule                     

#from Bias_GAN.code.src.data import TestData, CycleDataset
from Bias_GAN.code.src.utils import get_version, set_environment, get_checkpoint_path, save_config, log_transform, inv_norm_transform, inv_log_transform, inv_norm_minus1_to_plus1_transform, norm_minus1_to_plus1_transform 
from Bias_GAN.code.src.plots import PlotAnalysis, plot_basemap
from Bias_GAN.code.src.callbacks import get_cycle_gan_callbacks, MAE_Callback
from Bias_GAN.code.src.inference_tas import Inference, EvaluateCheckpoints, create_folder

# Data.py

In [6]:
class DataModule(pl.LightningDataModule):
    """Lightning data module that builds the `CycleDataset` splits and their loaders.

    Args:
        config: configuration object passed unchanged to every dataset.
        training_batch_size: batch size of the shuffled training loader.
        test_batch_size: batch size of the validation and test loaders.
    """

    def __init__(self,
                 config,
                 training_batch_size: int = 4,
                 test_batch_size: int = 64):
        super().__init__()
        self.config = config
        self.training_batch_size = training_batch_size
        self.test_batch_size = test_batch_size

    def setup(self, stage: str = None):
        """Create the datasets needed for `stage` ('fit', 'test', 'predict' or None)."""
        cfg = self.config

        if stage is None or stage == 'fit':
            self.train = CycleDataset('train', cfg)
            self.valid = CycleDataset('valid', cfg)

        if stage == 'test':
            self.test = CycleDataset('test', cfg)
            self.valid = CycleDataset('valid', cfg)

        if stage == 'predict':
            # NOTE(review): ProjectionDataset is not defined in the visible
            # part of this notebook -- confirm where it comes from.
            self.test = ProjectionDataset(cfg)

    def _ordered_loader(self, dataset):
        """Unshuffled loader with the evaluation batch size (shared by val/test)."""
        return DataLoader(dataset,
                          batch_size=self.test_batch_size,
                          shuffle=False,
                          num_workers=0,
                          pin_memory=True)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        loader_options = dict(batch_size=self.training_batch_size,
                              shuffle=True,
                              num_workers=0,
                              pin_memory=True)
        return DataLoader(self.train, **loader_options)

    def val_dataloader(self):
        """Unshuffled loader over the validation split."""
        return self._ordered_loader(self.valid)

    def test_dataloader(self):
        """Unshuffled loader over the test split."""
        return self._ordered_loader(self.test)


def show_image(image):
    """Draw `image` with `plt.imshow` after squeezing out its singleton dimensions."""
    squeezed = image.squeeze()
    plt.imshow(squeezed)


def get_random_sample(dataset):
    """Return one element of `dataset`, picked at a uniformly random index."""
    position = np.random.randint(0, len(dataset))
    return dataset[position]


In [7]:
from dataclasses import dataclass
import cftime
from torch.utils.data import DataLoader


@dataclass
class TestData:
    """Container for the data arrays that are compared during evaluation.

    `era5` and `gan` are required; every other array is optional and
    defaults to None.
    """

    era5: xr.DataArray
    gan: xr.DataArray
    climate_model: xr.DataArray = None
    cmip_model: xr.DataArray = None
    gan_constrained: xr.DataArray = None
    poem: xr.DataArray = None
    quantile_mapping: xr.DataArray = None
    uuid: str = None
    # Unannotated, so this is a plain class attribute and not a dataclass field.
    model = None

    def model_name_definition(self, key):
        """Return the display name for the data set identified by `key`.

        Raises KeyError for an unknown key.
        """
        display_names = {
            'era5': 'ERA5',
            'gan': 'GAN (unconstrained)',
            'climate_model': 'Climate model',
            'cmip_model': 'GFDL-ESM4',
            'poem': 'CM2Mc-LPJmL',
            'gan_constrained': 'GAN',
            'quantile_mapping': 'Quantile mapping',
        }
        return display_names[key]

    def colors(self, key):
        """Return the plot colour for the data set identified by `key`.

        Raises KeyError for an unknown key. NOTE(review): unlike
        `model_name_definition`, there is no entry for 'poem' here.
        """
        plot_colors = {
            'era5': 'k',
            'gan': 'brown',
            'cmip_model': 'b',
            'climate_model': 'r',
            'gan_constrained': 'c',
            'quantile_mapping': 'm',
        }
        return plot_colors[key]

    def convert_units(self):
        """Reassign `climate_model`, `era5` and `gan` to themselves.

        The scaling factor that used to follow each assignment is commented
        out in the notebook, so this currently leaves the values unchanged.
        """
        for attribute in ('climate_model', 'era5', 'gan'):
            setattr(self, attribute, getattr(self, attribute))

    def crop_test_period(self):
        """Restrict `climate_model` and `era5` to the time range covered by `gan`."""
        print('')
        print(f'Test set period: {self.gan.time[0].values} - {self.gan.time[-1].values}')
        period = slice(self.gan.time[0], self.gan.time[-1])
        self.climate_model = self.climate_model.sel(time=period)
        self.era5 = self.era5.sel(time=period)

    def show_mean(self):
        """Print the overall mean of the ERA5, climate-model and GAN arrays."""
        print('')
        print('Mean [mm/d]:')
        era5_mean = self.era5.mean().values
        print(f'ERA5: {era5_mean:2.3f}')
        climate_model_mean = self.climate_model.mean().values
        print(f'Climate Model: {climate_model_mean:2.3f}')
        gan_mean = self.gan.mean().values
        print(f'GAN:  {gan_mean:2.3f}')




class CycleDataset(torch.utils.data.Dataset):
    """Paired-by-index samples from the ERA5 ('A') and climate-model ('B') `tas` fields.

    Both fields are cut to the time window of `stage` and transformed with
    the statistics of the training window of the same source.

    Args:
        stage: one of 'train', 'valid', 'test'.
        config: object providing `transforms`, `lazy`, `poem_path`,
            `era5_path` and the `*_start` / `*_end` split boundaries.
        epsilon: offset handed to `log_transform`.
    """

    def __init__(self, stage, config, epsilon=0.0001):
        """
            stage: train, valid, test
        """
        self.transforms = config.transforms
        self.epsilon = epsilon
        self.config = config

        if config.lazy:
            self.cache = False
            self.chunks = {'time': 1}
        else:
            self.cache = True
            self.chunks = None

        self.splits = {
                "train": [str(config.train_start), str(config.train_end)],
                "valid": [str(config.valid_start), str(config.valid_end)],
                "test":  [str(config.test_start), str(config.test_end)],
        }

        self.stage = stage
        self.climate_model = self.load_climate_model_data()
        climate_model_reference = self.load_climate_model_data(is_reference=True)
        self.era5 = self.load_era5_data()
        era5_reference = self.load_era5_data(is_reference=True)

        # __getitem__ indexes both series with the same index, so the dataset
        # can only be as long as the shorter of the two (the sources may not
        # have the same number of time steps in a given window).
        self.num_samples = min(len(self.era5.time.values),
                               len(self.climate_model.time.values))

        self.era5 = self.apply_transforms(self.era5, era5_reference)
        self.climate_model = self.apply_transforms(self.climate_model, climate_model_reference)

    def _load_tas(self, path, is_reference):
        """Open `path`, take its `tas` variable and cut it to the requested window.

        The window is the training split when `is_reference` is true,
        otherwise the split of `self.stage`.
        """
        field = xr.open_dataset(path, cache=self.cache, chunks=self.chunks).tas

        if not self.config.lazy:
            field = field.load()

        start, end = self.splits['train' if is_reference else self.stage]
        return field.sel(time=slice(start, end))

    def load_climate_model_data(self, is_reference=False):
        """ Y-domain samples """
        return self._load_tas(self.config.poem_path, is_reference)

    def load_era5_data(self, is_reference=False):
        """ X-domain samples """
        return self._load_tas(self.config.era5_path, is_reference)

    def apply_transforms(self, data, data_ref):
        """Apply the transforms named in `self.transforms` to `data`.

        `data_ref` supplies the reference statistics for the normalisations;
        it is log-transformed alongside `data` when 'log' is requested.
        """
        if 'log' in self.transforms:
            data = log_transform(data, self.epsilon)
            data_ref = log_transform(data_ref, self.epsilon)

        if 'normalize' in self.transforms:
            # norm_transform is not imported at the top of the notebook.
            # Assumes it lives next to inv_norm_transform in the utils
            # module -- TODO confirm.
            from Bias_GAN.code.src.utils import norm_transform
            data = norm_transform(data, data_ref)

        if 'normalize_minus1_to_plus1' in self.transforms:
            data = norm_minus1_to_plus1_transform(data, data_ref)

        return data

    def __getitem__(self, index):
        """Return {'A': era5, 'B': climate model} tensors for time step `index`.

        Each tensor is float and gets a leading dimension of size one.
        """
        x = torch.from_numpy(self.era5.isel(time=index).values).float().unsqueeze(0)
        y = torch.from_numpy(self.climate_model.isel(time=index).values).float().unsqueeze(0)

        return {'A': x, 'B': y}

    def __len__(self):
        return self.num_samples


# Main training loop

## define MAE callback

In [8]:
class MAE_Callback(Callback):
    """Evaluate the newest checkpoint after every training epoch.

    Logs the mean absolute GAN-minus-ERA5 bias of the time means as 'MAE'
    and, optionally, adds a latitudinal-mean plot and a histogram to the
    TensorBoard logger.

    Args:
        logger: logger whose `experiment.add_image` receives the figures.
        checkpoint_path: directory searched for `*.ckpt` files.
        config: configuration providing `date`, `time` and `config_path`.
        validation: forwarded to `EvaluateCheckpoints`.
        lat_mean: add the latitudinal-mean plot each epoch.
        plt_hist: add the histogram plot each epoch.
    """

    def __init__(self,logger,checkpoint_path,config, validation=True, lat_mean=False, plt_hist=False):
        self.MAE_list = []
        self.logger = logger
        self.checkpoint_path = checkpoint_path
        self.config = config
        self.version = get_version(config.date,config.time)
        self.validation = validation
        self.lat_mean = lat_mean
        self.plt_hist = plt_hist

    @staticmethod
    def _figure_to_tensor(fig):
        """Render `fig` to PNG, close it, and return the image as a tensor.

        Closing matters: a new figure is created every epoch, and pyplot
        keeps every figure alive until it is closed explicitly.
        """
        buf = BytesIO()
        fig.savefig(buf, format='png')
        plt.close(fig)
        buf.seek(0)
        return torchvision.transforms.ToTensor()(Image.open(buf))

    def on_train_epoch_end(self, trainer, pl_module):
        """Evaluate the most recently created checkpoint and log MAE and plots."""
        checkpoint_files = glob.glob(str(self.checkpoint_path) + '/*.ckpt')
        if not checkpoint_files:
            print("No test data available.")
            return

        last_checkpoint = max(checkpoint_files, key=os.path.getctime)
        evaluation = EvaluateCheckpoints(checkpoint_path=last_checkpoint,
                                         config_path=self.config.config_path + self.version + "/config_model.json",
                                         save_model=True,
                                         validation=self.validation,
                                         version=self.version)
        # The return value of run() is not needed here; the data is fetched
        # through get_test_data() below.
        evaluation.run()
        test_data_ = evaluation.get_test_data()

        if test_data_ is None or not test_data_:
            print("No test data available.")
            return

        gan_data = getattr(test_data_, 'gan')
        era5_data = getattr(test_data_, "era5")

        bias = gan_data.mean('time') - era5_data.mean('time')
        mae = abs(bias).values.mean()
        print("GAN-OBS",f" \t \t MAE: {mae:2.3f} [mm/d]")
        self.MAE_list.append(mae)
        print("MAE_list:",self.MAE_list)

        self.log('MAE', mae)

        if self.lat_mean:
            self._log_latitudinal_mean(gan_data, era5_data, trainer.current_epoch)

        if self.plt_hist:
            self._log_histogram(gan_data, era5_data, trainer.current_epoch)

    def _log_latitudinal_mean(self, gan_data, era5_data, epoch):
        """Plot the lon/time mean of both arrays against latitude and log the image."""
        data_era5 = era5_data.mean(dim=("lon", "time"))
        data_gan = gan_data.mean(dim=("lon", "time"))

        fig = plt.figure()
        plt.plot(data_gan.lat, data_gan.data,
                 label="gan",
                 alpha=0.9,
                 linestyle='-',
                 linewidth=2,
                 color="red")

        plt.plot(data_era5.lat, data_era5,
                 label="era5",
                 alpha=1,
                 linestyle='--',
                 linewidth=2,
                 color="black")

        # NOTE(review): fixed axis limits -- confirm they suit the plotted variable.
        plt.ylim(0,3)
        plt.xlim(25,58)
        plt.xlabel('Latitude')
        plt.ylabel('Mean temperature [??]')
        plt.grid()
        plt.legend(loc='upper right')

        img = self._figure_to_tensor(fig)
        self.logger.experiment.add_image("latitudinal_mean", img, epoch)

    def _log_histogram(self, gan_data, era5_data, epoch):
        """Plot log-scaled density histograms of both arrays and log the image."""
        shared_style = dict(bins=100,
                            histtype='step',
                            log=True,
                            density=True,
                            linewidth=2)

        fig = plt.figure()
        plt.hist(gan_data.values.flatten(),
                 label="gan",
                 alpha=0.9,
                 color="red",
                 **shared_style)

        plt.hist(era5_data.values.flatten(),
                 label="era5",
                 alpha=1,
                 color="black",
                 **shared_style)

        plt.xlabel('Temperature [??]')
        plt.ylabel('Histogram')
        plt.xlim(0,400)
        plt.grid()
        plt.legend(loc='upper right')

        img_ = self._figure_to_tensor(fig)
        self.logger.experiment.add_image("histogram", img_, epoch)

## Train Cycle GAN

In [9]:
def train_cycle_gan(config, pretrain_path=False,validation=True,track_lat_mean=False,plt_hist=False, accelerator=None):
    """ Main routine to train the Cycle GAN.

    Args:
        config: configuration object; it is used as passed in.
        pretrain_path: checkpoint to start from; any falsy value trains
            from scratch.
        validation: forwarded to `MAE_Callback`.
        track_lat_mean: forwarded to `MAE_Callback` as `lat_mean`.
        plt_hist: forwarded to `MAE_Callback`.
        accelerator: Lightning accelerator name. When None, the module-level
            variable `accelerator` is used if it exists, otherwise "auto".

    Returns:
        The trained model.

    Side effect: sets the module-level variable `version`, which later
    notebook cells read.
    """
    global version

    if accelerator is None:
        # Earlier revisions relied on a notebook-level `accelerator` variable.
        accelerator = globals().get("accelerator", "auto")

    version = get_version(config.date,config.time)
    print(f'Running model: {version}')
    checkpoint_path = get_checkpoint_path(config, version)
    set_environment()

    tb_logger = TensorBoardLogger(config.tensorboard_path,name="",version=version,default_hp_metric=False)

    create_folder(f"/content/gdrive/MyDrive/bias_gan/results/{version}")
    save_config(config, version)

    mae_callback = MAE_Callback(tb_logger,checkpoint_path,config,validation,lat_mean=track_lat_mean,plt_hist=plt_hist)

    # `devices` replaces the deprecated `gpus` argument, which newer
    # Lightning releases no longer accept.
    trainer = pl.Trainer(callbacks=[mae_callback] + get_cycle_gan_callbacks(checkpoint_path),
                         accelerator=accelerator,
                         devices = 1,
                         max_epochs = config.epochs,
                         precision = 16,
                         num_sanity_val_steps = 1,
                         logger = tb_logger,
                         log_every_n_steps = config.log_every_n_steps,
                         deterministic = False,
                         enable_model_summary=False)

    datamodule = DataModule(config, training_batch_size = config.train_batch_size, test_batch_size = config.test_batch_size)
    datamodule.setup("fit")

    if not pretrain_path:
        print("no pretraining")
        model = CycleGAN(d_lr=config.d_lr, g_lr=config.g_lr, beta_1=config.beta_1, beta_2=config.beta_2,
                         epoch_decay = config.epochs // 2, running_bias=config.running_bias,
                         num_resnet_blocks=config.num_resnet_layer, default_nbr_resnet=config.default_nbr_resnet)
    else:
        print("using pretrained model with path:",pretrain_path)
        # load_from_checkpoint is a classmethod: building an instance first
        # only created a model that was thrown away immediately.
        model = CycleGAN.load_from_checkpoint(pretrain_path)

    trainer.fit(model, datamodule)

    print('Training finished')
    return model

# Config

In [10]:
load_pretrained_world_gan=False

In [13]:
@dataclass
class Config:
    """ 
    Training configuration parameters. For model evaluation parameters see
    src/configuration.py.

    Only the annotated names are dataclass fields (and therefore constructor
    arguments). The unannotated names (`transformations`, `d_lr`, `g_lr`,
    `beta_1`, `beta_2`, `epoch_decay`, `time`, `date`, `default_nbr_resnet`,
    `num_resnet_layer`) are plain class attributes shared by all instances.
    """
    
    # Output locations; the f-strings are evaluated once, while the class body runs.
    scratch_path: str = '/content/gdrive/MyDrive/bias_gan/results'
    tensorboard_path: str = f'{scratch_path}/'
    checkpoint_path: str = f'{scratch_path}/'
    config_path: str = f'{scratch_path}/'
    # NOTE(review): both file names contain `pr`, while the loaders in this
    # notebook read the `tas` variable -- confirm these are the intended files.
    poem_path: str = f"/content/gdrive/MyDrive/bias_gan/data/detrend_pr_gfdl-esm4_historical_regionbox_1979-2014.nc"
    era5_path: str = f"/content/gdrive/MyDrive/bias_gan/data/detrend_pr_W5E5v2.0_regionbox_era5_1979-2014.nc"
   

    results_path: str = f'{scratch_path}/'
    projection_path: str = None

    # Split boundaries; the datasets convert them to strings and use them as time-slice bounds.
    # NOTE(review): 2004 is both valid_end and test_start -- confirm whether the overlap is intended.
    train_start: int = 1979
    train_end: int = 1980 #2000 
    valid_start: int = 2001 #was 2001
    valid_end: int = 2004
    test_start: int = 2004
    test_end: int = 2014
    
    model_name: str = 'tibet_gan'

    epochs: int = 2 # set to 250 for reproduction
    progress_bar_refresh_rate: int = 50
    train_batch_size: int = 1
    test_batch_size: int = 64
    # Per-instance list of transform names (built by the default factory).
    transforms: List = field(default_factory=lambda: ['log', 'normalize_minus1_to_plus1'])
    # Same names as a class attribute; other cells read this from the class itself.
    transformations = ['log', 'normalize_minus1_to_plus1']
    rescale: bool = False
    epsilon: float = 0.0001
    lazy: bool = False
    log_every_n_steps: int = 10 ### was 10
    norm_output: bool = True
    running_bias: bool = False

    # Optimiser settings (class attributes, not fields).
    d_lr = 2e-4
    g_lr = 2e-4
    beta_1 = 0.5
    beta_2 = 0.999
    epoch_decay = 200
    

    # Evaluated once, when this cell runs -- not each time a Config is created.
    time = datetime.now().time().strftime("%Hh_%Mm_%Ss")
    date = datetime.now().date().strftime("%Y_%m_%d")

    # Reads the notebook-level flag `load_pretrained_world_gan` while the class body runs.
    if load_pretrained_world_gan==True:
      default_nbr_resnet=False
      num_resnet_layer=7
    else:
      default_nbr_resnet=True
      num_resnet_layer=6


def main():
    """Train the Cycle GAN with a default `Config`; the returned model is discarded."""
    train_cycle_gan(Config())

# Run

In [14]:
# Switches for this run.
do_training = True
from_skratch = True      # False: resume from the checkpoint of `runtime_instance`
track_lat_mean = True
plt_hist = True

# Earlier run whose results folder is used when resuming or when training is skipped.
runtime_instance = "2023_02_10_10h_26m_48s"

if do_training:
    accelerator = "gpu"

    if from_skratch:
        train_cycle_gan(Config(),
                        validation=False,
                        track_lat_mean=track_lat_mean,
                        plt_hist=plt_hist)
    else:
        train_cycle_gan(Config(),
                        f"/content/gdrive/MyDrive/bias_gan/results/{runtime_instance}/last.ckpt",
                        validation=True,
                        track_lat_mean=track_lat_mean,
                        plt_hist=plt_hist)

INFO:lightning_fabric.utilities.seed:Global seed set to 42
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit None Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Running model: 2023_02_15_10h_58m_44s


AttributeError: ignored

In [17]:
climate_model_

In [16]:
climate_model_ = xr.open_dataset(Config.poem_path)
climate_model_.tas.values.shape

AttributeError: ignored

In [15]:
climate_model = xr.open_dataset(Config.era5_path)
climate_model.tas.values.shape

AttributeError: ignored

# Tensorboard logging

In [None]:
%load_ext tensorboard

In [None]:
if do_training==True: 
    %tensorboard --logdir /content/gdrive/MyDrive/bias_gan/results/{version}/

In [None]:
if do_training==False: 
  %tensorboard --logdir /content/gdrive/MyDrive/bias_gan/results/{runtime_instance}/

## Save images from TensorBoard event files to Drive

In [None]:
save_images_for_gif = False


def save_tensorboard_images(event_file, outdir, start_index=239):
    """Export every image stored in a TensorBoard event file as numbered JPEGs.

    Each image tag gets its own sub-folder of `outdir` ('/' in the tag is
    replaced by '_'). Files are named `<start_index + position>.jpg`,
    zero-padded to four digits.

    Args:
        event_file: path of the TensorBoard event file.
        outdir: target directory; created if missing.
        start_index: number given to the first image of each tag. The
            default keeps the offset that was previously hard-coded; pass 0
            to number from 0000.
    """
    event_acc = event_accumulator.EventAccumulator(event_file, size_guidance={'images': 0})
    event_acc.Reload()

    outdir = pathlib.Path(outdir)
    outdir.mkdir(exist_ok=True, parents=True)

    for tag in event_acc.Tags()['images']:
        dirpath = outdir / tag.replace('/', '_')
        dirpath.mkdir(exist_ok=True, parents=True)

        for index, event in enumerate(event_acc.Images(tag), start=start_index):
            encoded = np.frombuffer(event.encoded_image_string, dtype=np.uint8)
            image = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
            outpath = dirpath / '{:04}.jpg'.format(index)
            cv2.imwrite(outpath.as_posix(), image)



if save_images_for_gif == True:
    path_to_event_file = '/content/gdrive/MyDrive/bias_gan/results/2023_02_02_10h_51m_31s/events.out.tfevents.1675331519.gpu-001.2945388.0'
    outdir = "/content/gdrive/MyDrive/bias_gan/results/2023_02_02_10h_51m_31s"
    save_tensorboard_images(path_to_event_file, outdir)

In [None]:
!ls "/content/gdrive/MyDrive/bias_gan/results/latitudinal_mean"

## get MAE

In [None]:
# This is an alias of the Config class itself, not a copy or an instance.
Config_adjusted_trafo = Config
# Sets `transforms` on the class, so the datasets below can read it from the class object.
Config_adjusted_trafo.transforms = Config_adjusted_trafo.transformations
# Number of samples in each split.
len_training_dataset = len(CycleDataset('train', Config_adjusted_trafo))
len_valid_dataset = len(CycleDataset('valid', Config_adjusted_trafo))
len_test_dataset = len(CycleDataset('test', Config_adjusted_trafo))

len_training_dataset, len_valid_dataset, len_test_dataset

In [None]:
combine_mae_training_fragments = False
if combine_mae_training_fragments:
    import tensorflow as tf
    from tensorflow.python.summary.summary_iterator import summary_iterator

    def _read_mae_series(event_file, epoch_offset=0):
        """Return (epochs, mae_values) for every 'MAE' scalar in `event_file`.

        The step of each event is divided by `len_training_dataset` and
        shifted by `epoch_offset`.
        """
        epochs, values = [], []
        for event in summary_iterator(event_file):
            for value in event.summary.value:
                if value.tag == 'MAE':
                    epochs.append(epoch_offset + event.step/len_training_dataset)
                    values.append(value.simple_value)
        return epochs, values

    # (event file, epoch offset) for each training fragment, in chronological
    # order. NOTE(review): the offsets are presumably the epochs completed by
    # the preceding fragments -- confirm.
    mae_fragments = [
        ('/content/gdrive/MyDrive/bias_gan/results/2023_01_19_16h_26m_23s/events.out.tfevents.1674145623.c0c1f3e09513.1224.0', 0),
        ('/content/gdrive/MyDrive/bias_gan/results/2023_01_31_10h_24m_37s/events.out.tfevents.1675157093.dgx-002.1451939.4', 69),
        ('/content/gdrive/MyDrive/bias_gan/results/2023_01_31_19h_24m_28s/events.out.tfevents.1675189486.dgx-002.1619062.0', 69+34),
        ('/content/gdrive/MyDrive/bias_gan/results/2023_02_01_16h_44m_03s/events.out.tfevents.1675266262.dgx-002.1953755.0', 69+34+65),
        ('/content/gdrive/MyDrive/bias_gan/results/2023_02_02_08h_29m_00s/events.out.tfevents.1675322959.dgx-002.2208864.0', 69+34+65+59),
        ('/content/gdrive/MyDrive/bias_gan/results/2023_02_02_10h_51m_31s/events.out.tfevents.1675331519.gpu-001.2945388.0', 69+34+65+59+7),
    ]

    epoch_total = []
    mae_total = []
    for event_file, epoch_offset in mae_fragments:
        fragment_epochs, fragment_mae = _read_mae_series(event_file, epoch_offset)
        epoch_total += fragment_epochs
        mae_total += fragment_mae

    plt.figure(figsize=(10, 7))
    plt.plot(epoch_total, mae_total)
    plt.xlabel("epochs")
    plt.ylabel("MAE")
    plt.title("MAE VS EPOCHS validation")
    plt.show()

## Make gifs


In [None]:
import imageio
import os
from google.colab import files

def create_gif(images_folder, gif_name, duration=0.7):
    """Combine all `*.jpg` files of a folder, in sorted name order, into one GIF.

    Args:
        images_folder: directory holding the frames; a `str` or a
            `pathlib.Path`.
        gif_name: path of the GIF file to write.
        duration: passed through to `imageio.mimsave`.
    """
    folder = pathlib.Path(images_folder)
    frames = [imageio.imread(frame_path) for frame_path in sorted(folder.glob("*.jpg"))]
    imageio.mimsave(gif_name, frames, duration=duration)

In [None]:
# Switch for this cell. It must not be called `create_gif`: that name belongs
# to the helper function, and assigning a boolean to it makes the call below fail.
make_histogram_gif = False

# Build the GIF from the saved histogram frames and store it at `gif_name`.
gif_name = '/content/gdrive/MyDrive/bias_gan/results/histogram.gif'
images_folder = pathlib.Path("/content/gdrive/MyDrive/bias_gan/results/histograms_combined")

if make_histogram_gif:
    create_gif(images_folder, gif_name)
    # show the gif in colab
    # Imported under an alias so that PIL's `Image`, used by the training
    # callback, is not replaced.
    from IPython.display import Image as IPythonImage
    with open(gif_name,'rb') as f:
        display(IPythonImage(data=f.read()))

In [None]:
# Switch for this cell, kept separate from the `create_gif` helper function.
# NOTE(review): `create_gif` must be bound to the helper function (not to a
# boolean) when this cell runs -- check earlier cells before enabling.
make_latitudinal_mean_gif = False

# Build the GIF from the saved latitudinal-mean frames and store it at `gif_name`.
gif_name = '/content/gdrive/MyDrive/bias_gan/results/latitudinal_mean.gif'
images_folder = pathlib.Path("/content/gdrive/MyDrive/bias_gan/results/latitudinal_mean")

if make_latitudinal_mean_gif:
    create_gif(images_folder, gif_name)
    # show the gif in colab
    # Imported under an alias so that PIL's `Image`, used by the training
    # callback, is not replaced.
    from IPython.display import Image as IPythonImage
    with open(gif_name,'rb') as f:
        display(IPythonImage(data=f.read()))

# Evaluation

## Run Evaluation


In [None]:
# Evaluate the run that was just trained, or an earlier one when training was skipped.
version_ = runtime_instance if do_training == False else version

checkpoint_path = f"/content/gdrive/MyDrive/bias_gan/results/{version_}/last.ckpt"
config_path = f"/content/gdrive/MyDrive/bias_gan/results/{version_}/config_model.json"

data = EvaluateCheckpoints(
    checkpoint_path=checkpoint_path,
    config_path=config_path,
    save_model=True,
    version=version_,
)

In [None]:
test_data, reconstruct_data = data.run()
test_data = data.get_test_data()

In [None]:
# Average absolute error between ERA5 and the back-transformed GAN output.
# np.mean divides by the actual number of values instead of a hard-coded
# grid size, so the result stays correct if the test period or grid changes.
gan_abs_error = np.abs(test_data.era5.values - inv_transform(test_data.gan.values.squeeze()))
avg_gan = np.round(np.mean(gan_abs_error),2)
print("average absolute differnce in tas values obs-gan:",avg_gan)

# Average absolute error between ERA5 and the raw climate model.
cm_abs_error = np.abs(test_data.era5.values - test_data.climate_model.values)
avg_cm = np.round(np.mean(cm_abs_error),2)
print("average absolute differnce in tas values obs-cm:",avg_cm)

In [None]:
# Latitudinal profile: mean over the first and third axis of each array.
# (The ERA5 profile was previously stored under a misspelled name, so the
# plot call below referred to a variable this cell never defined.)
data_era5 = test_data.era5.values.mean(axis=(0,2))

#data_gan= inv_transform(test_data.gan.values.squeeze()).mean(axis=(0,2))
data_gan = inv_transform(test_data.gan, climate_model_reference).squeeze().mean(axis=(0,2))

plt.figure()
plt.plot(test_data.gan.lat, data_gan,
          label="gan",
          alpha=0.9,
          linestyle='-',
          linewidth=2,
          color="red")


plt.plot(test_data.era5.lat, data_era5,
          label="era5",
          alpha=1,
          linestyle='--',
          linewidth=2,
          color="black")
plt.xlim(25,58)
plt.xlabel('Latitude')
plt.ylabel('Mean temperature [K]')
plt.grid()
plt.legend(loc='upper right')  
plt.show()

In [None]:
# Latitudinal profile of the back-transformed GAN output only.
gan_back_transformed = inv_transform(test_data.gan, climate_model_reference)
data_gan = gan_back_transformed.squeeze().mean(axis=(0,2))

gan_line_style = dict(label="gan",
                      alpha=0.9,
                      linestyle='-',
                      linewidth=2,
                      color="red")

plt.figure()
plt.plot(test_data.gan.lat, data_gan, **gan_line_style)

plt.xlim(25,58)
plt.xlabel('Latitude')
plt.ylabel('Mean temperature [K]')
plt.grid()
plt.legend(loc='upper right')
plt.show()

In [None]:
def load_climate_model_reference_data():
    """Load the climate-model field for the training period.

    Opens `Config.poem_path`, takes `poem_precipitation` when the file has
    it and `precipitation` otherwise, and cuts the result to
    `Config.train_start` .. `Config.train_end`.
    """
    dataset = xr.open_dataset(Config.poem_path)

    if 'poem_precipitation' in dataset.variables:
        field = dataset.poem_precipitation
    else:
        field = dataset.precipitation

    if not Config.lazy:
        field = field.load()

    training_period = slice(str(Config.train_start), str(Config.train_end))
    return field.sel(time=training_period)

In [None]:
def load_climate_model_data(is_reference=False, stage="test"):
    """ Y-domain samples: the climate-model `tas` field cut to one split.

    Args:
        is_reference: when true, return the training period regardless of
            `stage`.
        stage: 'train', 'valid' or 'test'; defaults to the split that was
            previously hard-coded.
    """
    splits = {
        "train": [str(Config.train_start), str(Config.train_end)],
        "valid": [str(Config.valid_start), str(Config.valid_end)],
        "test":  [str(Config.test_start), str(Config.test_end)],
    }

    # Both branches of the former variable-name check selected `tas`, so
    # the check itself was redundant.
    climate_model = xr.open_dataset(Config.poem_path,
                                    cache=True, chunks=None).tas

    if not Config.lazy:
        climate_model = climate_model.load()

    start, end = splits['train'] if is_reference else splits[stage]
    return climate_model.sel(time=slice(start, end))

In [None]:
Config.transforms

In [None]:
def apply_transforms( data, data_ref):
    """Apply the transforms named in `Config.transforms` to `data`.

    `data_ref` supplies the reference statistics for the normalisations and
    is log-transformed alongside `data` when 'log' is requested. Reads the
    notebook-level variable `epsilon`.
    """
    if 'log' in Config.transforms:
        data = log_transform(data, epsilon)
        data_ref = log_transform(data_ref, epsilon)

    if 'normalize' in Config.transforms:
        # norm_transform is not imported at the top of the notebook.
        # Assumes it lives next to inv_norm_transform in the utils
        # module -- TODO confirm.
        from Bias_GAN.code.src.utils import norm_transform
        data = norm_transform(data, data_ref)

    if 'normalize_minus1_to_plus1' in Config.transforms:
        data = norm_minus1_to_plus1_transform(data, data_ref)

    return data

In [None]:
# Offset read by the notebook-level apply_transforms() through the global name `epsilon`.
epsilon=0.0001

# Test-period field and the training-period field used as its reference.
climate_model = load_climate_model_data()
climate_model_reference = load_climate_model_data(is_reference=True)

#era5 = load_era5_data()
#era5_reference = load_era5_data(is_reference=True)
#era5 = apply_transforms(era5, era5_reference)

# Transformed copy; the trailing underscore distinguishes it from the raw field.
climate_model_ = apply_transforms(climate_model, climate_model_reference)

In [None]:
climate_model,inv_transform(climate_model_,climate_model_reference)

In [None]:
climate_model.shape

In [None]:
4018*60*118

In [None]:
np.sum(np.round(climate_model,0) == np.round(inv_transform(climate_model_,climate_model_reference),0))-28447440

# Create reconstructions

In [None]:
# This is an alias of the Config class itself, not a copy or an instance.
Config_adjusted_trafo = Config
# Sets `transforms` on the class, so the dataset below can read it from the class object.
Config_adjusted_trafo.transforms = Config_adjusted_trafo.transformations
# Training split, used as the source of the reconstruction examples.
dataset = CycleDataset('train', Config_adjusted_trafo)

In [None]:
 nbr_reconstruction_examples = 2

## Define inverse transformation and define forward/backward models

In [None]:
class Generator(torch.nn.Module):
    """Thin wrapper that forwards its input through a given generator network.

    Args:
        generator_model: module stored as `self.generator` and called in
            `forward`.
        constrain: accepted but not used by this class.
    """

    def __init__(self, generator_model: torch.nn.Module, constrain=True):
        super().__init__()
        self.generator = generator_model

    def forward(self, x):
        """Return `self.generator(x)`."""
        return self.generator(x)

In [None]:
def inv_transform(data, reference=None):
    """Undo the transforms listed in ``Config.transformations``.

    The inverse steps run in the reverse order of the forward transforms:
    the normalisation is undone first, then the log transform. Because the
    forward normalisation is done against a log-transformed reference when
    'log' is enabled, the reference is log-transformed here as well before
    it is passed to the inverse normalisation helpers.

    Args:
        data: transformed values to map back.
        reference: reference values handed to the inverse normalisation
            helpers. When ``None``, the ``tas`` variable of the file at
            ``Config.era5_path``, restricted to the period
            ``Config.train_start`` .. ``Config.train_end``, is loaded and used.

    Returns:
        ``data`` after all enabled inverse transforms have been applied.
    """
    if reference is None:
        # Open the file in a context manager so the handle is released again;
        # this function is called repeatedly from the reconstruction loops.
        # `.values` materialises the array before the dataset is closed.
        with xr.open_dataset(Config.era5_path) as era5:
            reference = era5.tas.sel(
                time=slice(str(Config.train_start), str(Config.train_end))
            ).values

    if 'log' in Config.transformations:
        reference = log_transform(reference, Config.epsilon)

    if 'normalize' in Config.transformations:
        data = inv_norm_transform(data, reference)

    if 'normalize_minus1_to_plus1' in Config.transformations:
        data = inv_norm_minus1_to_plus1_transform(data, reference)

    if 'log' in Config.transformations:
        data = inv_log_transform(data, Config.epsilon)

    return data

ckpt_path = Config.checkpoint_path + f"{version_}" +"/last.ckpt"

# `load_from_checkpoint` is a classmethod: call it on the class instead of on a
# throw-away `CycleGAN()` instance (newer pytorch_lightning releases reject the
# instance call). The checkpoint is read once and both generators are taken from
# the same frozen model instead of loading the file twice.
ckpt_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cycle_gan = CycleGAN.load_from_checkpoint(checkpoint_path=ckpt_path, map_location=ckpt_device)
cycle_gan.freeze()
cycle_gan = cycle_gan.to(ckpt_device)

# model_fw wraps the B -> A generator, model_bw the A -> B generator.
model_fw = Generator(cycle_gan.g_B2A, constrain=False)
model_bw = Generator(cycle_gan.g_A2B, constrain=False)

## reconstruction starting with climate model

In [None]:
# Cycle B -> A -> B: take a 'B' sample (plotted as climate model data), translate
# it with model_fw (g_B2A) and map the result back with model_bw (g_A2B).
recon_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for i in range(nbr_reconstruction_examples):
    test_data_ = dataset[i]

    obs = test_data_['B'].to(recon_device)
    with torch.no_grad():  # inference only, no autograd graph needed
        gan = model_fw(obs)
        rec = model_bw(gan)

    # NOTE(review): all three fields are inverted with inv_transform's default
    # reference; confirm that this is also the right reference for the 'B' input
    # and its reconstruction (an earlier attempt used climate_model_reference).
    data_obs = inv_transform(obs.squeeze().cpu())
    data_gan = inv_transform(gan.squeeze().cpu())
    data_rec = inv_transform(rec.squeeze().cpu())

    # Mean absolute difference over all grid cells (no hard-coded grid size).
    mean_abs_error = torch.mean(torch.abs(data_obs - data_gan))
    print("average predicted error in temperature:",np.round(mean_abs_error,0),"degrees K")

    fig, ax = plt.subplots(1, 3, figsize=(20, 5))
    panels = [
        (data_obs, "climate model data"),
        (data_gan, "generated observation (gan)"),
        (data_rec, "reconstruction of climate model data"),
    ]
    for axis, (field, title) in zip(ax, panels):
        cs = axis.pcolormesh(field.squeeze())
        fig.colorbar(cs, ax=axis, extend='max')
        axis.set_title(title)

    plt.show()

## reconstruction starting with observations

In [None]:
# Cycle A -> B -> A: take an 'A' sample (plotted as observation data), translate
# it with model_bw (g_A2B) to generated climate model data and map that back with
# model_fw (g_B2A). The generators were previously applied in the B -> A -> B
# order copied from the cell above, i.e. g_B2A was applied to an 'A' sample.
recon_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for i in range(nbr_reconstruction_examples):
    test_data_ = dataset[i]

    model = test_data_['A'].to(recon_device)
    with torch.no_grad():  # inference only, no autograd graph needed
        gan = model_bw(model)
        rec = model_fw(gan)

    # NOTE(review): all three fields are inverted with inv_transform's default
    # reference; confirm that this is also appropriate for the generated
    # climate-model field in the middle panel.
    data_model = inv_transform(model.squeeze().cpu())
    data_gan = inv_transform(gan.squeeze().cpu())
    data_rec = inv_transform(rec.squeeze().cpu())

    fig, ax = plt.subplots(1, 3, figsize=(20, 5))
    panels = [
        (data_model, "observation data"),
        (data_gan, "generated climate model data (gan)"),
        (data_rec, "reconstruction of observation data"),
    ]
    for axis, (field, title) in zip(ax, panels):
        cs = axis.pcolormesh(field.squeeze())
        fig.colorbar(cs, ax=axis, extend='max')
        axis.set_title(title)

    plt.show()

# Plot  **frames**

## Plot single frames

Set the `chose_day` parameter to plot the precipitation on a specific day.

In [None]:
# Value passed as time_index to both single-frame plots below.
chose_day=10

# The same frame is drawn twice: once with the default projection, once with projection="cyl".
PlotAnalysis(test_data).single_frames(time_index=chose_day)
PlotAnalysis(test_data).single_frames(projection="cyl",time_index=chose_day)

## Plot of the average frame for each dataset in test_data

In [None]:
# Average-frame plot with projection="cyl"; scale_precip_by is forwarded unchanged (presumably a display scaling factor, confirm in PlotAnalysis).
PlotAnalysis(test_data).avg_frames(projection="cyl",scale_precip_by = 10)

## plot of the average **errors** between era5 & gan / climate_model

In [None]:
# Absolute-error version of the average-frame plot; note the different scale_precip_by (20 instead of 10 above).
PlotAnalysis(test_data).avg_frames_abs_err(projection="cyl", scale_precip_by = 20)

**TODO**: plot spatial plot - mean Error - also show lands

# Plot **histogram** statistics
Precipitation rates averaged over time and longitudes and relative frequency histograms

## histogram no log

Here we plot the histogram over the daily precipitation values in the test dataset. 

In [None]:
# Fresh single-axes figure that the histogram is drawn onto via ax=ax.
fig, ax = plt.subplots(1,1,figsize=(4, 4),  constrained_layout=True)

# Histogram with log=False; xlim_end=30 presumably caps the x-axis (confirm in PlotAnalysis.histograms).
PlotAnalysis(test_data).histograms(single_plot=False, ax=ax, show_legend=True, annotate=True,log=False,xlim_end=30)

## histogram log on **density**

Precipitation values over 50 are very rare, so on a linear scale the three curves lie right on top of each other and it is hard to see anything. We therefore apply the log to the probability density to make the differences visible.

In [None]:
# Fresh single-axes figure that the histogram is drawn onto via ax=ax.
fig, ax = plt.subplots(1,1,figsize=(6, 6),  constrained_layout=True)

# Same histogram call as in the previous section, but with log=True and without an x-limit argument.
PlotAnalysis(test_data).histograms(single_plot=False, ax=ax, show_legend=True, annotate=True,log=True)

## plot histogram log density **differences**

Number of days in the test data set:

In [None]:
# Length of the time coordinate of the `gan` field stored on test_data.
len(test_data.gan.time)

In [None]:
# Fresh single-axes figure that the difference histogram is drawn onto via ax=ax.
fig, ax = plt.subplots(1,1,figsize=(6, 6),  constrained_layout=True)

PlotAnalysis(test_data).histogram_diff(single_plot=False, ax=ax, show_legend=True, annotate=True)

## plot log **precipitation**

Applying the **log** to the data itself, instead of to the number of points in each bin as in the previous plot, puts the densities on a common scale:

In [None]:
# Fresh single-axes figure that the log histogram is drawn onto via ax=ax.
fig, ax = plt.subplots(1,1,figsize=(6, 6),  constrained_layout=True)
PlotAnalysis(test_data).log_histograms(single_plot=False, ax=ax, show_legend=True, annotate=True)

## plot histogram log precipitation differences

In [None]:
# Create a dedicated figure for this plot. The cell previously reused `ax` from
# the preceding cell, i.e. an axes object belonging to a figure that had already
# been rendered there, and it failed on its own when that cell had not been run.
fig, ax = plt.subplots(1,1,figsize=(6, 6),  constrained_layout=True)
PlotAnalysis(test_data).log_histogram_diff(single_plot=False, ax=ax, show_legend=True, annotate=True)

# Plot **latitudinal** **mean**

In [None]:
# Latitudinal-mean plot with the method's default arguments.
PlotAnalysis(test_data).latitudinal_mean()

# Try loading finished gan world

## new cyclegan model code

## load new model: 

In [None]:
#state_dict = torch.load("/content/gdrive/MyDrive/bias_gan/results/pretrained_gan_world/last.ckpt",map_location=torch.device('cpu'))
#CycleGAN(num_resnet_layer = 7).load_state_dict(state_dict, strict=False)

# SSIM comparison

In [None]:
from skimage.metrics import structural_similarity as ssim

# Open the three NetCDF files and select one variable from each:
# the GAN output of run `version_`, the ERA5 (W5E5) file and the GFDL-ESM4 file.
# NOTE(review): data_gan and data_model are also assigned in the reconstruction
# cells above; from here on they refer to the data loaded from these files.
data_gan = xr.open_dataset(f'/content/gdrive/MyDrive/bias_gan/results/{version_}/gan.nc').gan_precipitation
data_era5 = xr.open_dataset(f"/content/gdrive/MyDrive/bias_gan/data_gan/pr_W5E5v2.0_regionbox_era5_1979-2014.nc").era5_precipitation #*3600*24 
data_model = xr.open_dataset(f"/content/gdrive/MyDrive/bias_gan/data_gan/pr_gfdl-esm4_historical_regionbox_1979-2014.nc").precipitation #*3600*24 

# Pull the raw arrays out of the xarray objects; these are what ssim() is called with below.
gan_values = data_gan.values
era5_values = data_era5.values
model_values = data_model.values

Calculate the SSIM for the GAN only for the last 4018 entries, because that is the size of the test dataset.

SSIM for the climate model

In [None]:
# Calculate SSIM
model_score, model_diff = ssim(era5_values[-4018:,:,:], model_values[-4018:,:,:], full=True)
print("model score:", model_score)

SSIM for the GAN

In [None]:
# Compare the GAN output with the matching (last) ERA5 time steps; the slice
# length follows the GAN output instead of a hard-coded 4018.
era5_for_gan = era5_values[-gan_values.shape[0]:,:,:]

# The inputs are floating-point physical values, so the value range has to be
# passed explicitly: guessing it from the dtype gives a wrong SSIM, and recent
# scikit-image releases refuse float input without `data_range`.
gan_data_range = max(era5_for_gan.max(), gan_values.max()) - min(era5_for_gan.min(), gan_values.min())

gan_score, gan_diff = ssim(era5_for_gan, gan_values, data_range=gan_data_range, full=True)
print("gan score:", gan_score)

In [None]:
# Shapes of the three arrays used for the SSIM comparison (GAN, ERA5, climate model).
gan_values.shape,era5_values.shape,model_values.shape

# Compare metrics

In [None]:
# Disabled batch evaluation: the loop below is wrapped in a string literal, so it
# is not executed; the cell only echoes the string as its output.
"""
instances = ["2023_01_11_13h_04m_08s","2023_01_12_05h_34m_48s","2023_01_12_07h_34m_09s","2023_01_13_07h_17m_53s", "2023_01_13_11h_06m_15s","2023_01_14_08h_45m_11s"]

for i in instances: 
    evaluation_instance = i
    checkpoint_path = f"/content/gdrive/MyDrive/bias_gan/results/{evaluation_instance}/last.ckpt" 
    config_path = f"/content/gdrive/MyDrive/bias_gan/results/{evaluation_instance}/config_model.json"
    data = EvaluateCheckpoints(checkpoint_path=checkpoint_path, config_path=config_path, save_model=True)
    data.run()
    test_data = data.get_test_data()
    print("")
    PlotAnalysis(test_data).avg_frames_abs_err(projection="cyl", scale_precip_by = 10)
    print("")
    PlotAnalysis(test_data).latitudinal_mean()
    print("")
"""

#Data: raw output
# check whats available in isimip


# do we use the right data? the already bias corrected data (also with downscaling)

# downscaled climate model

In [None]:
# 1. remove all trends and add again at the end !!!
# 2. 