# Import the code

In [None]:
# Install the third-party packages this notebook needs into the runtime (versions are not pinned).
! pip install cartopy  netCDF4 xarray timezonefinder numpy pandas matplotlib scipy

Collecting cartopy
  Downloading cartopy-0.25.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (6.1 kB)
Collecting netCDF4
  Downloading netCDF4-1.7.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.8 kB)
Collecting timezonefinder
  Downloading timezonefinder-8.1.0-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting cftime (from netCDF4)
  Downloading cftime-1.6.4.post1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.7 kB)
Collecting h3>4 (from timezonefinder)
  Downloading h3-4.3.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (18 kB)
Downloading cartopy-0.25.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (11.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.8/11.8 MB[0m [31m132.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading netCDF4-1.7.2-cp312-cp312-manylinux_2_

In [None]:
# Clone the project repository; the next cell changes into /content/downscaling-cgan
# and adds it to sys.path so that `dsrnngan` can be imported.
!git clone https://github.com/MELAI-1/downscaling-cgan.git


Cloning into 'downscaling-cgan'...
remote: Enumerating objects: 3236, done.[K
remote: Counting objects: 100% (409/409), done.[K
remote: Compressing objects: 100% (105/105), done.[K
remote: Total 3236 (delta 351), reused 309 (delta 304), pack-reused 2827 (from 3)[K
Receiving objects: 100% (3236/3236), 83.79 MiB | 19.67 MiB/s, done.
Resolving deltas: 100% (2302/2302), done.


In [None]:
# Work from inside the cloned repository and make it (and its parent directory)
# importable. After the loop, '/content' precedes the repository on sys.path.
import os
import sys

REPO_DIR = '/content/downscaling-cgan'
os.chdir(REPO_DIR)
for import_root in (REPO_DIR, '/content'):
    sys.path.insert(0, import_root)

In [None]:
# Mount Google Drive at /content/drive (only available in Google Colab).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Import functions and set TF_USE_LEGACY_KERAS = 1 (necessary only for tensorflow version >= 2.16.0)

In [None]:

# Only needed for TensorFlow versions >= 2.16.0.
# NOTE(review): presumably has to be set before TensorFlow is first imported
# (directly or via dsrnngan) for it to take effect — keep this cell first.
os.environ["TF_USE_LEGACY_KERAS"] = "1"


import joblib
import numpy as np

# from data import write_data, gen_fcst_norm

from dsrnngan.utils.read_config import get_data_paths
from dsrnngan.data.tfrecords_generator import write_data

## We retrieve the data paths set in [**config/data_paths.yaml**](https://github.com/snath-xoc/cGAN_tutorial/blob/main/config/data_paths.yaml) and check that they are correctly set. We also set the year that is used to generate normalisation constants

In [None]:
data_paths = get_data_paths()
CONSTANTS_PATH = data_paths["GENERAL"]["CONSTANTS"]
# Year whose data is used to compute the forecast normalisation constants.
FCSTNorm_year = "2018"


# Print each configured path; lookups happen one at a time so a missing key
# surfaces right after the paths that were found.
print("Sanity check that these data paths are correctly set:\n")
for label, section, key in (
    ("FORECAST_PATH:", "GENERAL", "NGCM"),
    ("CONSTANTS_PATH:", "GENERAL", "CONSTANTS"),
    ("TRUTH_PATH:", "GENERAL", "IMERG"),
    ("TFRecords path:", "TFRecords", "tfrecords_path"),
    ("STATS path", "GENERAL", "STATS"),
):
    print(label, data_paths[section][key])


Sanity check that these data paths are correctly set:

FORECAST_PATH: /home/melvin_aims_ac_za/data/NGCM/
CONSTANTS_PATH: /home/melvin_aims_ac_za/data/constants/
TRUTH_PATH: /home/melvin_aims_ac_za/data/IMERG
TFRecords path: /home/melvin_aims_ac_za/data/tfrecords
STATS path /home/melvin_aims_ac_za/data/constants/neuralgcm_Horn_Africa_2018_stats.pkl


In [None]:
# Display the full data-path configuration returned by get_data_paths().
data_paths

{'GENERAL': {'IMERG': '/home/melvin_aims_ac_za/data/IMERG',
  'ERA5': '/bp1/geog-tropical/data/ERA-5/day',
  'IFS': '/bp1/geog-tropical/users/uz22147/east_africa_data/IFS',
  'OROGRAPHY': '/home/melvin_aims_ac_za/data/constants/elev.nc',
  'LSM': '/home/melvin_aims_ac_za/data/constants/lsm.nc',
  'LAKES': '/bp1/geog-tropical/users/uz22147/east_africa_data/constants/lake_mask.nc',
  'SEA': '/bp1/geog-tropical/users/uz22147/east_africa_data/constants/sea_mask.nc',
  'CONSTANTS': '/home/melvin_aims_ac_za/data/constants/',
  'NGCM': '/home/melvin_aims_ac_za/data/NGCM/',
  'STATS': '/home/melvin_aims_ac_za/data/constants/neuralgcm_Horn_Africa_2018_stats.pkl'},
 'NGCM': {'evaporation': [],
  'precipitation_cumulative_mean': [],
  'specific_cloud_ice_water_content_500': [],
  'specific_cloud_ice_water_content_700': [],
  'specific_cloud_ice_water_content_850': [],
  'u_component_of_wind_500': [],
  'u_component_of_wind_700': [],
  'u_component_of_wind_850': [],
  'v_component_of_wind_500': []

# Forecast normalisation constants
### Typically for AI approaches, training and optimisation as well as inference work best when the values are normalised in a manner that retains the original distribution. We therefore calculate a set of forecast normalisation constants for the region with which to normalise the variables. These normalisation constants include the mean, std, min and max.

### We normalise accordingly:
#### a) Precipitation variables of total precipitation (tp) and convective precipitation (cp) follow a log normalisation (i.e. log10(1+x) ).
#### b) 2m temperature (t2m) and surface pressure (sp) are reasonably normally distributed, so we use a simple standard normalisation of (x-mean)/std.
#### c) Other variables are bounded to be non-negative (e.g., cape, tclw, tciw, tcrw, tcwv,tcw) and divided through by their maximum: x/max.
#### d) Wind (both u and v components) can be negative and is therefore divided by its absolute maximum: $x / \max(|\min|, |\max|)$.
#### e) It is worth noting that Shortwave Solar Radiation (ssr) is an accumulated variable and needs to be converted from an hourly accumulation to a per-second rate (i.e. x/3600).
#### f) Medium Cloud Cover (mcc) does not require any normalisation (as it is a fractional value bounded at 0-1).

In [None]:
## Forecast normalisation constants (mean, std, min, max).
# Generate the stats file only if it is missing, then always load it from disk.
stats_path = data_paths['GENERAL']['STATS']
if not os.path.exists(stats_path):
    # `gen_fcst_norm` is not imported in this notebook (its import is commented
    # out in the imports cell). Fail with an actionable message instead of a
    # bare NameError when the stats file is missing.
    if 'gen_fcst_norm' not in globals():
        raise FileNotFoundError(
            f"Normalisation stats not found at {stats_path} and gen_fcst_norm() "
            "is not imported; import it (or generate the stats file) before running this cell."
        )
    gen_fcst_norm(year=FCSTNorm_year)

# NOTE(review): joblib.load unpickles the file — only load stats files you trust.
fcstNorm = joblib.load(stats_path)

print(fcstNorm)


# Data Generator
### As mentioned in the data module directory, a key part to the data load-in is the ```DataGenerator``` class from the [**data/data_generator.py**](https://github.com/snath-xoc/cGAN_tutorial/blob/main/data/data_generator.py) file.

### The ```DataGenerator``` calls the ```load_fcst_truth_batch``` function from [**data/data.py**](https://github.com/snath-xoc/cGAN_tutorial/blob/main/data/data.py#L157) that for a given date will:<br>1) Load in all desired forecast variables (i.e., using ```load_fcst_stack```)<br>2) Load in the truth variables (i.e., ```load_truth_and_mask```)<br>

### Once all file paths specified in [**config/data_paths.yaml**](https://github.com/snath-xoc/cGAN_tutorial/blob/main/config/data_paths.yaml) and [**config/local_config.yaml**](https://github.com/snath-xoc/cGAN_tutorial/blob/main/config/local_config.yaml) are set properly in the config directory we can initialise the data generator and visualise that everything looks alright

### Initialisation of the data generator:

#### First import the DataGenerator alongside the forecast fields (```all_ngcm_fields```) from the data module. We also check which dates are available for the example year 2018 using the ```get_dates``` function. The forecast lead-time range of 30 to 54 hours ahead is only quoted in the printed message; it is not passed to ```get_dates```.

In [None]:
## Example load-in: list the NGCM forecast fields, then the dates returned by
## get_dates for the chosen year (presumably dates with both NGCM and IMERG data).
from dsrnngan.data.data_generator import DataGenerator
from dsrnngan.data.data import all_ngcm_fields, get_dates

print("Looking into getting dates for the forecast fields:\n", all_ngcm_fields, "\n")

year = 2018
# NOTE(review): start_hour/end_hour only feed the message below; they are not
# passed to get_dates, so the quoted lead-time range is not enforced here.
start_hour, end_hour = 30, 54
dates = get_dates(
    year,
    obs_data_source='imerg',
    fcst_data_source='ngcm',
    data_paths=data_paths,
)

print(f"Available dates for the year {year} and forecast lead times of {start_hour} to {end_hour} are:", dates)

Looking into getting dates for the forecast fields:
 ['evaporation', 'precipitation_cumulative_mean', 'specific_cloud_ice_water_content_500', 'specific_cloud_ice_water_content_700', 'specific_cloud_ice_water_content_850', 'u_component_of_wind_500', 'u_component_of_wind_700', 'u_component_of_wind_850', 'v_component_of_wind_500', 'v_component_of_wind_700', 'v_component_of_wind_850'] 

/home/melvin_aims_ac_za/data/NGCM/evaporation/2018/evaporation_2018_ngcm_evaporation_2.8deg_6h_GHA_20180101_00h.nc
/home/melvin_aims_ac_za/data/NGCM/precipitation_cumulative_mean/2018/precipitation_cumulative_mean_2018_ngcm_precipitation_cumulative_mean_2.8deg_6h_GHA_20180101_00h.nc
/home/melvin_aims_ac_za/data/NGCM/specific_cloud_ice_water_content_500/2018/specific_cloud_ice_water_content_500_2018_ngcm_specific_cloud_ice_water_content_500_2.8deg_6h_GHA_20180101_00h.nc
/home/melvin_aims_ac_za/data/NGCM/specific_cloud_ice_water_content_700/2018/specific_cloud_ice_water_content_700_2018_ngcm_specific_cloud_ic

### Now we instantiate the DataGenerator, the main arguments are:

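#### dates: list, positional argument that is a list of dates to load in<br> fcst_fields: list, which forecast variables to load in (we use all the forecast fields as printed above)<br> start_hour: int, the first lead time to load in (we usually use 30 hours)<br> end_hour: int, the last lead time up to which to load in (usually 54 hours)<br> batch_size: int<br> shuffle: Boolean<br> constants: Boolean, whether to use land-sea mask and elevation constants<br> fcst_norm: Boolean, whether to normalise forecast variables<br> Note: the call below actually passes ```dates```, ```hour```, ```batch_size```, ```data_config``` and ```shuffle```. The fields and the normalisation setting are presumably taken from ```data_config``` (its ```input_fields``` and ```normalise_inputs``` entries) rather than from separate arguments.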

In [None]:

from dsrnngan.data.data import load_fcst_radar_batch, load_hires_constants, all_fcst_hours, DATA_PATHS, all_ifs_fields, all_era5_fields, input_fields
##🚩import ngcm function
from dsrnngan.data.data import load_ngcm, all_ngcm_fields,  get_ngcm_stats, all_ngcm_fields
from dsrnngan.utils.read_config import read_model_config, get_data_paths, get_lat_lon_range_from_config,read_data_config


In [None]:
# Read and display the data configuration (not stored; a later cell stores it as data_config).
read_data_config()

**TODO:** change the … (incomplete note — state which configuration setting needs to be changed here).

In [None]:
# Display the module-level DATA_PATHS imported from dsrnngan.data.data.
DATA_PATHS

{'GENERAL': {'IMERG': '/home/melvin_aims_ac_za/data/IMERG',
  'ERA5': '/bp1/geog-tropical/data/ERA-5/day',
  'IFS': '/bp1/geog-tropical/users/uz22147/east_africa_data/IFS',
  'OROGRAPHY': '/home/melvin_aims_ac_za/data/constants/elev.nc',
  'LSM': '/home/melvin_aims_ac_za/data/constants/lsm.nc',
  'LAKES': '/bp1/geog-tropical/users/uz22147/east_africa_data/constants/lake_mask.nc',
  'SEA': '/bp1/geog-tropical/users/uz22147/east_africa_data/constants/sea_mask.nc',
  'CONSTANTS': '/home/melvin_aims_ac_za/data/constants/',
  'NGCM': '/home/melvin_aims_ac_za/data/NGCM/',
  'STATS': '/home/melvin_aims_ac_za/data/constants/neuralgcm_Horn_Africa_2018_stats.pkl'},
 'NGCM': {'evaporation': [],
  'precipitation_cumulative_mean': [],
  'specific_cloud_ice_water_content_500': [],
  'specific_cloud_ice_water_content_700': [],
  'specific_cloud_ice_water_content_850': [],
  'u_component_of_wind_500': [],
  'u_component_of_wind_700': [],
  'u_component_of_wind_850': [],
  'v_component_of_wind_500': []

In [None]:
# Store the data configuration; it is passed to get_data_paths and DataGenerator below.
data_config=read_data_config()

In [None]:
# Display the loaded data configuration (data sources, domain bounds, input fields, ...).
data_config

namespace(data_paths='BLUE_PEBBLE',
          fcst_data_source='ngcm',
          obs_data_source='imerg',
          normalisation_year=2018,
          num_samples=320000,
          normalise_inputs=True,
          output_normalisation='log',
          num_classes=4,
          downscaling_factor=2,
          min_latitude=-18.14,
          max_latitude=29,
          latitude_step_size=2.8,
          min_longitude=16,
          max_longitude=60,
          longitude_step_size=2.8,
          input_fields=['evaporation',
                        'precipitation_cumulative_mean',
                        'specific_cloud_ice_water_content_500',
                        'specific_cloud_ice_water_content_700',
                        'specific_cloud_ice_water_content_850',
                        'u_component_of_wind_500',
                        'u_component_of_wind_700',
                        'u_component_of_wind_850',
                        'v_component_of_wind_500',
                        'v

In [None]:
# Display the data paths resolved for this data_config (output shown matches data_paths above).
get_data_paths(data_config=data_config)

{'GENERAL': {'IMERG': '/home/melvin_aims_ac_za/data/IMERG',
  'ERA5': '/bp1/geog-tropical/data/ERA-5/day',
  'IFS': '/bp1/geog-tropical/users/uz22147/east_africa_data/IFS',
  'OROGRAPHY': '/home/melvin_aims_ac_za/data/constants/elev.nc',
  'LSM': '/home/melvin_aims_ac_za/data/constants/lsm.nc',
  'LAKES': '/bp1/geog-tropical/users/uz22147/east_africa_data/constants/lake_mask.nc',
  'SEA': '/bp1/geog-tropical/users/uz22147/east_africa_data/constants/sea_mask.nc',
  'CONSTANTS': '/home/melvin_aims_ac_za/data/constants/',
  'NGCM': '/home/melvin_aims_ac_za/data/NGCM/',
  'STATS': '/home/melvin_aims_ac_za/data/constants/neuralgcm_Horn_Africa_2018_stats.pkl'},
 'NGCM': {'evaporation': [],
  'precipitation_cumulative_mean': [],
  'specific_cloud_ice_water_content_500': [],
  'specific_cloud_ice_water_content_700': [],
  'specific_cloud_ice_water_content_850': [],
  'u_component_of_wind_500': [],
  'u_component_of_wind_700': [],
  'u_component_of_wind_850': [],
  'v_component_of_wind_500': []

In [None]:
# Leftover inspection: displays the DataGenerator class object only.
DataGenerator

In [None]:
# NOTE(review): test_idx_till is not used by any cell in this notebook.
test_idx_till = 20

# Unshuffled generator over a single date/hour with batch size 1, for inspection.
dgc = DataGenerator(
    dates=['20180101'],
    hour=0,
    batch_size=1,
    data_config=data_config,
    shuffle=False,
)

### Next we draw a sample by calling the ```__getitem__``` function of the ```DataGenerator```. The sample itself should be a tuple with:
#### 1) A dictionary of:<br>&nbsp;&nbsp; a) lo_res_inputs: the forecast inputs going into the GAN, this is an array of size (```batch_size, lat, lon, n_variables```). For our domain ```lat=384``` and ```lon=352```. Our NGCM forecast list printed above has 11 variables, each contributing one channel, so ```n_variables=11``` (see the printed shape ```(1, 384, 352, 11)``` below).<br>&nbsp;&nbsp; b) hi_res_inputs: constant inputs of elevation and land-sea mask, this is an array of size (```batch_size, lat, lon, 2```)<br>&nbsp;&nbsp; c) dates and hours: arrays of size (```batch_size,```) identifying each sample
#### 2)  A dictionary of:<br>&nbsp;&nbsp; a) output: truth data, an array of (```batch_size, lat, lon```).<br>&nbsp;&nbsp; b) mask: mask containing invalid points in the truth data, an array of (```batch_size, lat, lon```)

In [None]:
# Draw the first batch and report the keys and array shapes of both dicts in the
# returned (inputs, outputs) tuple.
sample = dgc.__getitem__(0)

headers = (
    "Sample consisting of a tuple has been retrieved. Keys for the first component are:",
    "Keys for the second component are:",
)
for position, (header, component) in enumerate(zip(headers, (sample[0], sample[1]))):
    keys = list(component.keys())
    shapes = [component[k].shape for k in keys]
    if position == 1:
        print("\n")
    print(header, keys)
    print("The corresponding shapes are:")
    for key, shape in zip(keys, shapes):
        print(key, ":", shape)


* Loaded data_x_batch shape: (1, 384,
352, 11),
* data_y_batch shape: (1, 384, 352)
* Data x batch fields: (1, 384, 352, 2)

keys:

* 'lo_res_inputs',
* 'hi_res_inputs',
* 'dates',
* 'hours'



shapes:


*   (1, 384, 352, 11),
* (1, 384, 352, 2)
* (1,)
* (1,)




* lo_res_inputs : (1, 384, 352, 11)
* hi_res_inputs : (1, 384, 352, 2)
* dates : (1,)
* hours : (1,)

### Next we visualise the loaded values; an easy first check is the constant fields of elevation and land-sea mask.

In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np

# Regular 0.1° grid of the domain: 384 latitudes x 352 longitudes, matching the
# (lat, lon) = (384, 352) sample shapes printed above. np.linspace fixes the
# point count exactly; np.arange with a float step can gain or lose a point
# through rounding.
# NOTE(review): domain bounds are hard-coded here — TODO confirm they match data_config.
lats = np.linspace(-13.65, 24.65, 384)
lons = np.linspace(19.15, 54.25, 352)

# Channel 0 of hi_res_inputs: elevation.
ax = plt.axes(projection=ccrs.PlateCarree())
mesh = ax.pcolormesh(lons, lats, sample[0]['hi_res_inputs'][0, :, :, 0], cmap='terrain_r')
plt.colorbar(mesh)
plt.title('elevation')
plt.show()
plt.close()

# Channel 1 of hi_res_inputs: land-sea mask.
ax = plt.axes(projection=ccrs.PlateCarree())
mesh = ax.pcolormesh(lons, lats, sample[0]['hi_res_inputs'][0, :, :, 1], cmap='Blues')
plt.colorbar(mesh)
plt.title('land-sea mask')
plt.show()
plt.close()

In [None]:
# Side-by-side maps of the two hi-res constant fields: channel 0 is elevation and
# channel 1 the land-sea mask (as labelled in the single-panel plots).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6),
                               subplot_kw={'projection': ccrs.PlateCarree()})

# Elevation
mesh1 = ax1.pcolormesh(lons, lats, sample[0]['hi_res_inputs'][0, :, :, 0], cmap='terrain_r')
plt.colorbar(mesh1, ax=ax1)
ax1.set_title('Elevation')

# Land-sea mask
mesh2 = ax2.pcolormesh(lons, lats, sample[0]['hi_res_inputs'][0, :, :, 1], cmap='Blues')
plt.colorbar(mesh2, ax=ax2)
ax2.set_title('Land-sea mask')

plt.tight_layout()
plt.savefig('combined_maps.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Plot lo_res_inputs channels 0 and 1. These are NGCM forecast channels, not
# elevation / land-sea mask, which are in hi_res_inputs.
# NOTE(review): titles assume the channel order follows all_ngcm_fields — TODO confirm.
fig, axes = plt.subplots(1, 2, figsize=(16, 6),
                         subplot_kw={'projection': ccrs.PlateCarree()})
for ax, channel, cmap in zip(axes, (0, 1), ('terrain_r', 'Blues')):
    mesh = ax.pcolormesh(lons, lats, sample[0]['lo_res_inputs'][0, :, :, channel], cmap=cmap)
    plt.colorbar(mesh, ax=ax)
    ax.set_title(f"channel {channel}: {all_ngcm_fields[channel]}")

plt.tight_layout()
plt.savefig('combined_maps_sample_1.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Plot lo_res_inputs channels 2 and 3. These are NGCM forecast channels, not
# elevation / land-sea mask, which are in hi_res_inputs.
# NOTE(review): titles assume the channel order follows all_ngcm_fields — TODO confirm.
fig, axes = plt.subplots(1, 2, figsize=(16, 6),
                         subplot_kw={'projection': ccrs.PlateCarree()})
for ax, channel, cmap in zip(axes, (2, 3), ('terrain_r', 'Blues')):
    mesh = ax.pcolormesh(lons, lats, sample[0]['lo_res_inputs'][0, :, :, channel], cmap=cmap)
    plt.colorbar(mesh, ax=ax)
    ax.set_title(f"channel {channel}: {all_ngcm_fields[channel]}")

plt.tight_layout()
plt.savefig('combined_maps_sample_2.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Plot lo_res_inputs channels 4 and 5. These are NGCM forecast channels, not
# elevation / land-sea mask, which are in hi_res_inputs.
# NOTE(review): titles assume the channel order follows all_ngcm_fields — TODO confirm.
fig, axes = plt.subplots(1, 2, figsize=(16, 6),
                         subplot_kw={'projection': ccrs.PlateCarree()})
for ax, channel, cmap in zip(axes, (4, 5), ('terrain_r', 'Blues')):
    mesh = ax.pcolormesh(lons, lats, sample[0]['lo_res_inputs'][0, :, :, channel], cmap=cmap)
    plt.colorbar(mesh, ax=ax)
    ax.set_title(f"channel {channel}: {all_ngcm_fields[channel]}")

plt.tight_layout()
plt.savefig('combined_maps_sample_3.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Plot lo_res_inputs channels 6 and 7. These are NGCM forecast channels, not
# elevation / land-sea mask, which are in hi_res_inputs.
# NOTE(review): titles assume the channel order follows all_ngcm_fields — TODO confirm.
fig, axes = plt.subplots(1, 2, figsize=(16, 6),
                         subplot_kw={'projection': ccrs.PlateCarree()})
for ax, channel, cmap in zip(axes, (6, 7), ('terrain_r', 'Blues')):
    mesh = ax.pcolormesh(lons, lats, sample[0]['lo_res_inputs'][0, :, :, channel], cmap=cmap)
    plt.colorbar(mesh, ax=ax)
    ax.set_title(f"channel {channel}: {all_ngcm_fields[channel]}")

plt.tight_layout()
plt.savefig('combined_maps_sample_4.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Plot lo_res_inputs channels 8 and 9. These are NGCM forecast channels, not
# elevation / land-sea mask, which are in hi_res_inputs.
# NOTE(review): titles assume the channel order follows all_ngcm_fields — TODO confirm.
fig, axes = plt.subplots(1, 2, figsize=(16, 6),
                         subplot_kw={'projection': ccrs.PlateCarree()})
for ax, channel, cmap in zip(axes, (8, 9), ('terrain_r', 'Blues')):
    mesh = ax.pcolormesh(lons, lats, sample[0]['lo_res_inputs'][0, :, :, channel], cmap=cmap)
    plt.colorbar(mesh, ax=ax)
    ax.set_title(f"channel {channel}: {all_ngcm_fields[channel]}")

plt.tight_layout()
plt.savefig('combined_maps_sample_5.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Plot the last lo_res_inputs channels starting at index 10. lo_res_inputs has
# 11 channels (shape (..., 11)), so index 11 does not exist. Only channels that
# are present are drawn, which avoids an IndexError.
# NOTE(review): titles assume the channel order follows all_ngcm_fields — TODO confirm.
lo_res = sample[0]['lo_res_inputs']
channels = [ch for ch in (10, 11) if ch < lo_res.shape[-1]]

if channels:
    fig, axes = plt.subplots(1, len(channels), figsize=(8 * len(channels), 6),
                             subplot_kw={'projection': ccrs.PlateCarree()},
                             squeeze=False)
    for ax, channel, cmap in zip(axes[0], channels, ('terrain_r', 'Blues')):
        mesh = ax.pcolormesh(lons, lats, lo_res[0, :, :, channel], cmap=cmap)
        plt.colorbar(mesh, ax=ax)
        name = all_ngcm_fields[channel] if channel < len(all_ngcm_fields) else "unknown field"
        ax.set_title(f"channel {channel}: {name}")

    plt.tight_layout()
    plt.savefig('combined_maps_sample_6.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

In [None]:
# Write the TFRecord files for each listed year.
# NOTE(review): this re-imports `write_data` from dsrnngan.data.data, shadowing the
# `write_data` imported from dsrnngan.data.tfrecords_generator in the imports cell —
# confirm which implementation is intended.
from dsrnngan.data.data  import write_data

years = [2018]

# Note: the loop variable rebinds the notebook-level `year`.
for year in years:
    write_data(year)


In [None]:
import tensorflow as tf
from dsrnngan.data.tfrecords_generator import _parse_batch

# Directory holding the written TFRecords, and the raw (still serialised) training file.
tfrecords_path = data_paths["TFRecords"]["tfrecords_path"]
train_tfrecord_file = os.path.join(tfrecords_path, 'final_tfrecord/train_12.3.0.tfrecords')
dataset = tf.data.TFRecordDataset(train_tfrecord_file)

In [None]:
# Parse the serialised records into (inputs, outputs) dicts. The result is bound
# to a new name so `dataset` stays raw; re-running this cell (or later cells that
# map `dataset`) then never parses already-parsed records.
parsed_dataset = dataset.map(
    lambda x: _parse_batch(x, insize=(384, 352, 11), consize=(384, 352, 2), outsize=(384, 352, 1))
)
test = parsed_dataset.repeat().batch(2)

# Pull a single batch of 2 examples; later plotting cells reuse `inputs` / `outputs`.
inputs, outputs = next(iter(test.take(1).as_numpy_iterator()))
print(inputs['lo_res_inputs'].shape)

In [None]:
import matplotlib.pyplot as plt
from dsrnngan.data.data import all_ngcm_fields
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np

In [None]:
# Channel indices of the two fields plotted below, looked up by name in all_ngcm_fields.
field_positions = []
for var in ("evaporation", "precipitation_cumulative_mean"):
    field_positions.append(all_ngcm_fields.index(var))
evap, precip = field_positions

In [None]:
# Side-by-side maps of the evaporation and precipitation channels for the first
# example of the parsed TFRecord batch (forecast inputs, not elevation/LSM).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6),
                               subplot_kw={'projection': ccrs.PlateCarree()})

panels = (
    (ax1, evap, 'terrain_r', 'evaporation'),
    (ax2, precip, 'Blues', 'precipitation'),
)
for ax, channel, cmap, title in panels:
    mesh = ax.pcolormesh(lons, lats, inputs['lo_res_inputs'][0, :, :, channel], cmap=cmap)
    plt.colorbar(mesh, ax=ax)
    ax.set_title(title)

plt.tight_layout()
plt.savefig('verifications.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Evaporation and precipitation maps for the first example of the parsed batch
# (this repeats the previous cell and overwrites verifications.png).
# Both panels pass lons AND lats: pcolormesh accepts `(C)` or `(X, Y, C)`, so a
# call with only `(lons, C)` raises a TypeError.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6),
                               subplot_kw={'projection': ccrs.PlateCarree()})

# Evaporation
mesh1 = ax1.pcolormesh(lons, lats, inputs['lo_res_inputs'][0, :, :, evap], cmap='terrain_r')
plt.colorbar(mesh1, ax=ax1)
ax1.set_title('evaporation')

# Precipitation
mesh2 = ax2.pcolormesh(lons, lats, inputs['lo_res_inputs'][0, :, :, precip], cmap='Blues')
plt.colorbar(mesh2, ax=ax2)
ax2.set_title('precipitation')

plt.tight_layout()
plt.savefig('verifications.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Evaporation and precipitation maps for the first example of the parsed batch,
# saved as evaporation_and_precipitation.png.
# Both panels pass lons AND lats: pcolormesh accepts `(C)` or `(X, Y, C)`, so a
# call with only `(lons, C)` raises a TypeError.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6),
                               subplot_kw={'projection': ccrs.PlateCarree()})

# Evaporation
mesh1 = ax1.pcolormesh(lons, lats, inputs['lo_res_inputs'][0, :, :, evap], cmap='terrain_r')
plt.colorbar(mesh1, ax=ax1)
ax1.set_title('evaporation')

# Precipitation
mesh2 = ax2.pcolormesh(lons, lats, inputs['lo_res_inputs'][0, :, :, precip], cmap='Blues')
plt.colorbar(mesh2, ax=ax2)
ax2.set_title('precipitation')

plt.tight_layout()
plt.savefig('evaporation_and_precipitation.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
import os

import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np
import tensorflow as tf

from dsrnngan.data.data import all_ngcm_fields
from dsrnngan.data.tfrecords_generator import _parse_batch

# Prepare the data. Re-read the raw TFRecords rather than mapping `dataset`
# again: `dataset` may already hold parsed records (the parsing cell rebinds it),
# and `_parse_batch` cannot be applied to already-parsed records.
# NOTE(review): relies on `tfrecords_path` from the TFRecords-loading cell.
raw_records = tf.data.TFRecordDataset(
    os.path.join(tfrecords_path, 'final_tfrecord/train_12.3.0.tfrecords'))
parsed_records = raw_records.map(
    lambda x: _parse_batch(x, insize=(384, 352, 11), consize=(384, 352, 2), outsize=(384, 352, 1)))
test = parsed_records.repeat().batch(2)

# Take a single batch; `inputs` / `outputs` stay available to later cells.
inputs, outputs = next(iter(test.take(1).as_numpy_iterator()))
print(f"Shape: {inputs['lo_res_inputs'].shape}")

# Channel indices. lo_res_inputs has one channel per NGCM field (11 fields,
# insize=(..., 11)), so the index is the field's position in all_ngcm_fields.
# There is no per-field stride, and the precipitation field is named
# "precipitation_cumulative_mean".
# NOTE(review): assumes channel order follows all_ngcm_fields — TODO confirm.
evap_idx = all_ngcm_fields.index("evaporation")
precip_idx = all_ngcm_fields.index("precipitation_cumulative_mean")

# Coordinates of the regional grid. Local names keep the notebook-level
# `lats` / `lons` used by the other plotting cells intact.
# NOTE(review): same 0.1° domain as the constant-field plots — TODO confirm it
# matches the grid the TFRecords were written on.
n_lat, n_lon = inputs['lo_res_inputs'].shape[1:3]
lat_axis = np.linspace(-13.65, 24.65, n_lat)
lon_axis = np.linspace(19.15, 54.25, n_lon)

# Build the 2-D coordinate grids.
lons_grid, lats_grid = np.meshgrid(lon_axis, lat_axis)

# Plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6),
                               subplot_kw={'projection': ccrs.PlateCarree()})

# Evaporation
mesh1 = ax1.pcolormesh(lons_grid, lats_grid,
                       inputs['lo_res_inputs'][0, :, :, evap_idx],
                       cmap='terrain_r', transform=ccrs.PlateCarree())
plt.colorbar(mesh1, ax=ax1, label='Évaporation')
ax1.coastlines()
ax1.gridlines(draw_labels=True)
ax1.set_title('Évaporation')

# Precipitation
mesh2 = ax2.pcolormesh(lons_grid, lats_grid,
                       inputs['lo_res_inputs'][0, :, :, precip_idx],
                       cmap='Blues', transform=ccrs.PlateCarree())
plt.colorbar(mesh2, ax=ax2, label='Précipitation')
ax2.coastlines()
ax2.gridlines(draw_labels=True)
ax2.set_title('Précipitation')

plt.tight_layout()
plt.savefig('evaporation_and_precipitation.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
import os

import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np
import tensorflow as tf

from dsrnngan.data.data import all_ngcm_fields
from dsrnngan.data.tfrecords_generator import _parse_batch

# Same map as the previous cell, saved as verification.png.
# Prepare the data. Re-read the raw TFRecords rather than mapping `dataset`
# again: `dataset` may already hold parsed records (the parsing cell rebinds it),
# and `_parse_batch` cannot be applied to already-parsed records.
# NOTE(review): relies on `tfrecords_path` from the TFRecords-loading cell.
raw_records = tf.data.TFRecordDataset(
    os.path.join(tfrecords_path, 'final_tfrecord/train_12.3.0.tfrecords'))
parsed_records = raw_records.map(
    lambda x: _parse_batch(x, insize=(384, 352, 11), consize=(384, 352, 2), outsize=(384, 352, 1)))
test = parsed_records.repeat().batch(2)

# Take a single batch; `inputs` / `outputs` stay available to later cells.
inputs, outputs = next(iter(test.take(1).as_numpy_iterator()))
print(f"Shape: {inputs['lo_res_inputs'].shape}")

# Channel indices. lo_res_inputs has one channel per NGCM field (11 fields,
# insize=(..., 11)), so the index is the field's position in all_ngcm_fields.
# There is no per-field stride, and the precipitation field is named
# "precipitation_cumulative_mean".
# NOTE(review): assumes channel order follows all_ngcm_fields — TODO confirm.
evap_idx = all_ngcm_fields.index("evaporation")
precip_idx = all_ngcm_fields.index("precipitation_cumulative_mean")

# Coordinates of the regional grid. Local names keep the notebook-level
# `lats` / `lons` used by the other plotting cells intact.
# NOTE(review): same 0.1° domain as the constant-field plots — TODO confirm it
# matches the grid the TFRecords were written on.
n_lat, n_lon = inputs['lo_res_inputs'].shape[1:3]
lat_axis = np.linspace(-13.65, 24.65, n_lat)
lon_axis = np.linspace(19.15, 54.25, n_lon)

# Build the 2-D coordinate grids.
lons_grid, lats_grid = np.meshgrid(lon_axis, lat_axis)

# Plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6),
                               subplot_kw={'projection': ccrs.PlateCarree()})

# Evaporation
mesh1 = ax1.pcolormesh(lons_grid, lats_grid,
                       inputs['lo_res_inputs'][0, :, :, evap_idx],
                       cmap='terrain_r', transform=ccrs.PlateCarree())
plt.colorbar(mesh1, ax=ax1, label='Évaporation')
ax1.coastlines()
ax1.gridlines(draw_labels=True)
ax1.set_title('Évaporation')

# Precipitation
mesh2 = ax2.pcolormesh(lons_grid, lats_grid,
                       inputs['lo_res_inputs'][0, :, :, precip_idx],
                       cmap='Blues', transform=ccrs.PlateCarree())
plt.colorbar(mesh2, ax=ax2, label='Précipitation')
ax2.coastlines()
ax2.gridlines(draw_labels=True)
ax2.set_title('Précipitation')

plt.tight_layout()
plt.savefig('verification.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Quick-look of a single lo_res_inputs channel. `idx` is not defined anywhere in
# this notebook, so the channel is selected explicitly here.
# NOTE(review): precipitation chosen to match the 'Blues' colormap — change as needed.
# NOTE(review): imshow draws row 0 at the top — confirm the latitude orientation.
idx = all_ngcm_fields.index("precipitation_cumulative_mean")
plt.imshow(inputs['lo_res_inputs'][0, :, :, idx], cmap='Blues')
plt.colorbar()
plt.title(all_ngcm_fields[idx])
plt.show()