In [None]:
import climate_learn as cl
import numpy as np
import xarray as xr

### ERA5 dataset
ERA5 is a reanalysis dataset maintained by the European Centre for Medium-Range Weather Forecasts (ECMWF). In its raw format, ERA5 contains hourly data from 1979 to the present on a global grid with 0.25° × 0.25° cells, with different climate variables at 37 pressure levels plus the planet’s surface. This corresponds to nearly 400,000 data samples, each a matrix of shape 721 × 1440. Since this is too big for most deep learning models, ClimateLearn supports downloading a smaller, pre-processed version of ERA5 data from WeatherBench.




#### Downloaded at 5.625 degree resolution:
##### 13-pressure-level
- <span style="color:red">geopotential</span> (all levels: 50,  100,  150,  200,  250,  300,  400, <span style="color:red">500</span> ,  600, 700, 850,  925, 1000 hPa)
- <span style="color:red">temperature</span> (all levels: 50,  100,  150,  200,  250,  300,  400,  500,  600, 700, <span style="color:red">850</span>,  925, 1000 hPa)
- relative_humidity (all levels: 50,  100,  150,  200,  250,  300,  400,  500,  600, 700,  850,  925, 1000 hPa)
- specific_humidity (all levels: 50,  100,  150,  200,  250,  300,  400,  500,  600, 700,  850,  925, 1000 hPa)
- u_component_of_wind (all levels: 50,  100,  150,  200,  250,  300,  400,  500,  600, 700,  850,  925, 1000 hPa)
- v_component_of_wind (all levels: 50,  100,  150,  200,  250,  300,  400,  500,  600, 700,  850,  925, 1000 hPa)

##### single-surface-level
- <span style="color:red">2m_temperature</span> 
- <span style="color:red">10m_u_component_of_wind</span> 
- <span style="color:red">10m_v_component_of_wind</span> 
- total_precipitation
- total_cloud_cover
- toa_incident_solar_radiation
  
We mark the data to download in <span style="color:red">red</span>.

In [None]:
# Download settings: one WeatherBench variable for one year is fetched into
# "<root_directory>/<variable>".
root_directory = "data"
variable = "2m_temperature"  # temperature_850, geopotential_500, 10m_u_component_of_wind, 10m_v_component_of_wind
year = 2018

cl.data.download_weatherbench(
    dst=f"{root_directory}/{variable}",
    dataset="era5",
    year=year,
    # Use the setting above so the downloaded variable always matches the
    # destination folder name (it used to be hard-coded to "2m_temperature").
    variable=variable,
    resolution=5.625,
)

### Create train / val / test

In [None]:
# Directory name -> NetCDF variable key. The insertion order of this mapping is
# the channel order of the merged array.
_NETCDF_KEYS = {
    'geopotential': 'z',
    'geopotential_500': 'z',
    'temperature': 't',
    'temperature_850': 't',
    '2m_temperature': 't2m',
    '10m_u_component_of_wind': 'u10',
    '10m_v_component_of_wind': 'v10',
}

# Pressure level kept when a file stores several levels along a "level"
# dimension (the levels highlighted in the variable list of this notebook).
_SELECTED_LEVELS = {'geopotential': 500, 'temperature': 850}


def select_merge_data(var_list, year_start, year_end, data_folder, resolution, lat, long):
    """Merge yearly NetCDF files of several variables into one ``.npy`` array.

    For every year in ``[year_start, year_end]`` and every supported name in
    ``var_list`` the file
    ``<data_folder>/<name>/<name>_<year>_<resolution>deg.nc`` is opened, every
    6th time step is kept, and the field is reshaped to ``(-1, 1, lat, long)``.
    The variables of a year are stacked along axis 1, the years along axis 0,
    and the result is saved to
    ``<data_folder>/concat_<year_start>_<year_end>_<resolution>_<n>var.npy``
    where ``n`` is the number of stacked variables.

    Args:
        var_list: Directory names of the variables to merge. Channels are
            always stacked in the fixed order geopotential, temperature,
            2m_temperature, 10m_u_component_of_wind, 10m_v_component_of_wind,
            independent of the order given here. Names without a known NetCDF
            key are reported and skipped.
        year_start: First year to include.
        year_end: Last year to include (inclusive).
        data_folder: Root folder holding one sub-directory per variable.
        resolution: Resolution tag used in the file names (e.g. ``5.625``).
        lat: Size of the latitude axis used for the reshape.
        long: Size of the longitude axis used for the reshape.

    Returns:
        None. The merged array is written to disk.

    Raises:
        ValueError: If ``var_list`` contains no supported variable or the year
            range is empty.
    """
    requested = set(var_list)
    selected = [name for name in _NETCDF_KEYS if name in requested]
    skipped = [name for name in var_list if name not in _NETCDF_KEYS]
    if skipped:
        print('skipping unsupported variables:', skipped)
    if not selected:
        raise ValueError(f"no supported variable in var_list: {list(var_list)}")
    if year_end < year_start:
        raise ValueError(f"empty year range: {year_start}..{year_end}")

    yearly_blocks = []
    for year in range(year_start, year_end + 1):
        print('>>>', year, '<<<')

        channels = []
        for name in selected:
            path = f"{data_folder}/{name}/{name}_{year}_{resolution}deg.nc"
            # The context manager closes the file handle once the values are read.
            with xr.open_dataset(path) as ds:
                # Select every 6th sample
                field = ds.isel(time=slice(None, None, 6))[_NETCDF_KEYS[name]]

                # Multi-level files must be reduced to a single level first;
                # otherwise the reshape below would fold the levels into the
                # time axis. Assumes the pressure dimension is called "level"
                # and holds hPa values -- TODO confirm against the files.
                level = _SELECTED_LEVELS.get(name)
                if level is not None and 'level' in field.dims:
                    field = field.sel(level=level)

                values = field.values

            channel = values.reshape((-1, 1, lat, long))
            print(name + ':', channel.shape)
            channels.append(channel)

        # concatenate one year
        concat_one_year = np.concatenate(channels, axis=1)
        print("concat_one_year.shape:", concat_one_year.shape)
        yearly_blocks.append(concat_one_year)

    concat_years = np.concatenate(yearly_blocks, axis=0)

    print("concat_years.shape:", concat_years.shape)
    print("total time points:", concat_years.shape[0])

    print(">>> saving data <<<")
    np.save(
        f"{data_folder}/concat_{year_start}_{year_end}_{resolution}_{concat_years.shape[1]}var.npy",
        concat_years,
    )
    print(">>> saved data <<<")

In [None]:
# Variables to merge; each entry is passed to select_merge_data as a directory name.
var_list = [
    'geopotential',
    'temperature',
    '2m_temperature',
    '10m_u_component_of_wind',
    '10m_v_component_of_wind',
]

# First and last year to merge.
year_start = 2017
year_end = 2018

# Resolution tag of the downloaded files.
resolution = 5.625

# Root folder of the downloaded data.
data_folder = 'data'

# Latitude / longitude sizes used when reshaping each field.
lat = 32
long = 64

In [None]:
# Merge the configured variables and years; files are read from and the result
# is saved under `data_folder` (the previous argument `resolution_folder` was
# never defined).
select_merge_data(var_list, year_start, year_end, data_folder, resolution, lat, long)