# ERA5-SWVL1-SR: Downscale swvl1 (volumetric soil water, layer 1) from ERA5 resolution to ERA5-Land resolution

In [None]:
# IPython autoreload extension in mode 2: imported modules are reloaded before
# each cell executes, so edits to the project's .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import xarray as xr
from pathlib import Path
import matplotlib.pyplot as plt
import json
import torchvision
import torch

In [None]:
# The bundled seaborn style sheets were renamed to "seaborn-v0_8*" in
# Matplotlib 3.6 and the legacy "seaborn" name was removed in 3.8. Prefer the
# new name and fall back to the old one on installations that predate it.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")

In [None]:
# Import the project-local modules

from models import Discriminator,Generator
from train_model import SRGAN_training
from plots import image_look_evaluation
from datasets import SRDataset
from evaluate_model import evaluate_model,evaluate_model_robustness
from utils import create_model_folder_and_copy_json
from downscale_era5_swvl1 import super_resolve_swvl1_world,super_resolve_swvl1_local_patch
from rasters_manipulation import load_and_clean_raster

## Parameters

In [None]:
# Run switches for the notebook; all of them are off by default.

# --- Training ---
# Builds the datasets/dataloaders and runs SRGAN_training.
TRAIN_MODEL = False
# Forwarded to create_model_folder_and_copy_json and to SRGAN_training.
RESUME_TRAINING = False

# --- Diagnostics ---
# Prints the shapes of one training batch and dumps one LR/HR pair as PNG.
CHECK_DATASET = False
# Computes and plots the evaluation metrics (also requires TRAIN_MODEL).
PLOT_METRICS = False
# Runs image_look_evaluation on evaluation samples (also requires TRAIN_MODEL).
PLOT_SR_IMAGES = False

In [None]:
# When True, the demo cells at the end of the notebook display and plot the
# rasters; the flag is also passed as `verbose` to
# super_resolve_swvl1_local_patch.
DISPLAY_DEMO = True

In [None]:
# Obtain the checkpoint path (later given to SRGAN_training and torch.load)
# and the path of the JSON configuration used for this run.
# NOTE(review): how the helper combines the two flags with `model_folder_path`
# is defined in utils.py, not here -- confirm there.
model_path, configuration_path = create_model_folder_and_copy_json(
    model_folder_path=Path("models/final_state/"),
    model_path_parent=Path("models"),
    train_model=TRAIN_MODEL,
    resume_training=RESUME_TRAINING,
)

In [None]:
# Load the run configuration; the cells below read its "dataloader", "dataset"
# and "training" sections. JSON text is UTF-8, so decode it explicitly instead
# of relying on the platform's locale-dependent default encoding.
with open(configuration_path, encoding="utf-8") as file:
    configuration = json.load(file)

In [None]:
# Batch size and worker count for the DataLoaders, both read from the
# "dataloader" section of the JSON configuration.
BATCH_SIZE, NUM_WORKERS = (
    configuration["dataloader"][option] for option in ("batch_size", "num_workers")
)

In [None]:
# Low- and high-resolution image dimensions, both read from the "dataset"
# section of the JSON configuration and handed to SRDataset / Discriminator.
low_res_image_dim, high_res_image_dim = (
    configuration["dataset"][option]
    for option in ("low_res_image_dim", "high_res_image_dim")
)

In [None]:
# Training hyper-parameters, read from the "training" section of the JSON
# configuration and forwarded to SRGAN_training further down.
alpha = configuration["training"]["alpha"]
generator_learning_rate = configuration["training"]["generator_learning_rate"]
discriminator_learning_rate = configuration["training"]["discriminator_learning_rate"]
number_of_epochs = configuration["training"]["number_of_epochs"]
# The flag used to be compared against the exact string "True", which silently
# produced False for a real JSON boolean (json.load returns a Python bool) or
# for "true". Normalising through str() accepts the original spelling, any
# letter case, and a genuine boolean.
pre_training = str(configuration["training"]["pre_training"]).strip().lower() == "true"
pre_train_number_of_epochs = configuration["training"]["pre_train_number_of_epochs"]

##  Datasets, Preprocessing and Data loading

In [None]:
# Data folders handed to SRDataset: the low-resolution ("era5") and
# high-resolution ("era5land") folder of each split.
train_lr_data_path, train_hr_data_path = map(
    Path, ("bucket_tensor/train/era5", "bucket_tensor/train/era5land")
)
test_lr_data_path, test_hr_data_path = map(
    Path, ("bucket_tensor/test/era5", "bucket_tensor/test/era5land")
)

In [None]:
# The training dataset and its loader are only needed when a training run is
# requested.
if TRAIN_MODEL:
    train_dataset = SRDataset(
        lr_data_path=train_lr_data_path,
        hr_data_path=train_hr_data_path,
        low_res_image_dim=low_res_image_dim,
        high_res_image_dim=high_res_image_dim,
    )
    # Shuffled batches; drop_last discards a trailing incomplete batch so that
    # every batch holds exactly BATCH_SIZE samples.
    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
    )

In [None]:
# The evaluation dataset (test split) and its loader are built under the same
# TRAIN_MODEL switch as the training ones.
if TRAIN_MODEL:
    eval_dataset = SRDataset(
        lr_data_path=test_lr_data_path,
        hr_data_path=test_hr_data_path,
        low_res_image_dim=low_res_image_dim,
        high_res_image_dim=high_res_image_dim,
    )
    # NOTE(review): shuffle=True and drop_last=True mirror the training loader;
    # for evaluation this reorders samples on every pass and skips a trailing
    # incomplete batch -- confirm that this is intended.
    eval_dataloader = torch.utils.data.DataLoader(
        dataset=eval_dataset,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
    )

In [None]:
# Optional sanity check: print the shapes of one training batch and write its
# first low-/high-resolution pair as PNG files into the working directory.
if CHECK_DATASET and not TRAIN_MODEL:
    # train_dataset / train_dataloader only exist when TRAIN_MODEL is True;
    # report that instead of failing with a NameError.
    print(" * Dataset check skipped: set TRAIN_MODEL = True to build the training dataset.")
elif CHECK_DATASET:
    print(f" * Dataset contains {len(train_dataset)} image(s).")
    # Only the first batch is inspected on purpose -- this is just a smoke
    # test. The None default covers a loader that yields no batch at all.
    batch = next(iter(train_dataloader), None)
    if batch is not None:
        lr_image, hr_image = batch
        print(lr_image.shape)
        print(hr_image.shape)
        # write_png expects a uint8 (C, H, W) tensor.
        # NOTE(review): repeat(3, 1, 1) and mul(255) assume single-channel
        # samples with values in [0, 1] -- confirm against SRDataset.
        torchvision.io.write_png(lr_image[0, ...].repeat(3, 1, 1).mul(255).byte(), "lr_image.png")
        torchvision.io.write_png(hr_image[0, ...].repeat(3, 1, 1).mul(255).byte(), "hr_image.png")

## Train model

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Newly constructed networks; the discriminator is parameterised by the
# low-resolution image dimension from the configuration.
discriminator = Discriminator(low_res_size=low_res_image_dim)
generator = Generator()

In [None]:
# Train the generator/discriminator pair with the hyper-parameters read from
# the configuration; skipped entirely unless TRAIN_MODEL is set.
if TRAIN_MODEL:
    SRGAN_training(
        generator,
        discriminator,
        train_dataloader,
        model_path,
        alpha_=alpha,
        G_learning_rate=generator_learning_rate,
        D_learning_rate=discriminator_learning_rate,
        device=device,
        number_of_epochs=number_of_epochs,
        resume_training=RESUME_TRAINING,
        pre_training=pre_training,
        pre_train_number_of_epochs=pre_train_number_of_epochs,
    )

## Evaluate model

In [None]:
# Rebuild the generator and restore its weights from the
# "generator_state_dict" entry of the checkpoint stored at `model_path`.
generator = Generator()
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint saved during a GPU run can still be loaded on a CPU-only machine.
checkpoint = torch.load(model_path, map_location=device)
generator.load_state_dict(checkpoint["generator_state_dict"])
# NOTE(review): the generator is not switched to eval mode here -- confirm
# that the evaluation / super-resolution helpers do so themselves.

In [None]:
# Evaluate the generator on the evaluation loader, plot the returned metrics
# table with the plotly backend and display the mean of each column.
# TRAIN_MODEL is required because eval_dataloader is only built in that case.
if PLOT_METRICS and TRAIN_MODEL:
    metrics_df = evaluate_model(generator, eval_dataloader, device)
    metrics_df.plot(backend="plotly").show()
    display(metrics_df.mean())

In [None]:
# Same evaluation through evaluate_model_robustness with noise_power=0.05;
# the result is plotted and averaged like the plain metrics.
# TRAIN_MODEL is required because eval_dataloader is only built in that case.
if PLOT_METRICS and TRAIN_MODEL:
    n_metrics_df = evaluate_model_robustness(
        generator,
        eval_dataloader,
        device,
        noise_power=0.05,
    )
    n_metrics_df.plot(backend="plotly").show()
    display(n_metrics_df.mean())

In [None]:
# Visual inspection of super-resolved samples taken from the evaluation
# loader. TRAIN_MODEL is required because eval_dataloader is only built in
# that case.
if TRAIN_MODEL and PLOT_SR_IMAGES:
    # Follow the device selected for the rest of the notebook instead of a
    # hard-coded "cuda", which presumably cannot work on a CPU-only machine.
    # `device.type` is still the plain string "cuda" whenever a GPU is in use.
    image_look_evaluation(
        eval_dataloader,
        generator,
        nb_samples=50,
        device=device.type,
        seed=890,
    )

## Test model on new ERA5 raster

Only rasters containing a single timestamp are currently supported.

In [None]:
# Low-resolution ERA5 input raster for the demo, loaded through the project's
# cleaning helper.
# NOTE(review): the meaning of `tolerance` is defined in
# rasters_manipulation.py -- confirm there.
era5_raster = load_and_clean_raster(
    "data/era5_31-12-2022.nc",
    tolerance=1e-6,
)
if DISPLAY_DEMO:
    display(era5_raster)

In [None]:
# ERA5-Land raster for the same date; it is passed to
# super_resolve_swvl1_local_patch and plotted next to the model output below.
era5land_raster = xr.open_dataset("data/era5land_31-12-2022.nc")
if DISPLAY_DEMO:
    display(era5land_raster)

In [None]:
# Super-resolve the whole ERA5 raster with the restored generator.
# NOTE(review): the date string matches the date in the input file name;
# presumably it selects the timestamp to process -- confirm in
# downscale_era5_swvl1.py.
era5land_model_raster = super_resolve_swvl1_world(
    era5_raster,
    "2022-12-31",
    generator,
    device,
    BATCH_SIZE,
    NUM_WORKERS,
)
if DISPLAY_DEMO:
    display(era5land_model_raster)

In [None]:
# Super-resolve a single patch selected by latitude/longitude; the ERA5-Land
# raster is passed along and verbose output follows the DISPLAY_DEMO switch.
model_era5land_array = super_resolve_swvl1_local_patch(
    generator,
    era5_raster,
    latitude=43.78,
    longitude=10.69,
    era5land_raster=era5land_raster,
    device=device,
    verbose=DISPLAY_DEMO,
)

In [None]:
# Plot the swvl1 field of the low-resolution ERA5 input raster.
if DISPLAY_DEMO:
    era5_raster.swvl1.plot()

In [None]:
# Plot the swvl1 field of the ERA5-Land raster loaded from disk.
if DISPLAY_DEMO:
    era5land_raster.swvl1.plot()

In [None]:
# Plot the swvl1 field of the raster produced by super_resolve_swvl1_world.
if DISPLAY_DEMO:
    era5land_model_raster.swvl1.plot()

# End of notebook