In [1]:
import xarray as xr
import numpy as np
import torch 
from typing import List, Tuple
from enum import Enum
from torch.utils.data import Dataset, DataLoader
import scipy

In [3]:
dataset_path = "/home/egauillard/data/PR_era5_MED_1degr_19400101_20240229_new.nc"

In [9]:

ds = xr.open_dataset(dataset_path, chunks=None)

In [5]:
print(f"xarray version: {xr.__version__}")

print(f"numpy version: {np.__version__}")
print(f"scispy version: {scipy.__version__}")


xarray version: 2024.5.0
numpy version: 1.26.4
scispy version: 1.13.1


In [10]:

result = ds.groupby('time.week').mean(dim='time')
result



AttributeError: 'ScipyArrayWrapper' object has no attribute 'oindex'

In [7]:
result = ds.groupby('time.week').apply(lambda x: x.mean(dim='time'))



AttributeError: 'ScipyArrayWrapper' object has no attribute 'oindex'

In [None]:
ds["tp"].sel(time="1940-01-01").mean(dim = "time")

In [None]:
sea_mask = xr.open_dataset("/scistor/ivm/shn051/extreme_events_forecasting/primary_code/data/ERA5_land_sea_mask_1deg.nc")

In [None]:
ds["tp"]

KeyError: "No variable named 'tp'. Variables on the dataset include ['time', 'time_bnds', 'longitude', 'latitude', 't2m']"

In [None]:
sea = sea_mask["lsm"].values
ds['tp'].where(sea_mask == 0).mean(dim=['latitude', 'longitude'])
print(ds['tp'].where(sea_mask == 1).mean(dim=['latitude', 'longitude']))

KeyError: "No variable named 'tp'. Variables on the dataset include ['time', 'time_bnds', 'longitude', 'latitude', 't2m']"

In [None]:
ds.sek

In [None]:
class StackType(Enum):
    DAILY = 1
    WEEKLY = 7
    BIWEEKLY = 14
    MONTHLY = 30
    BIMONTHY = 60
    

In [None]:

class DataScaler:
    def __init__(self, mode, min, max, mean, std) -> None:
        self.mode = mode
        self.min = min
        self.max = max
        self.mean = mean
        self.std = std
    
    def normalize(self,data):
        return (data - self.min) / (self.max - self.min)
    
    def standardize(self,data):
        return (data - self.mean) / self.std

    def remove_outliers(self, data, clipped_min, clipped_max):
        data = np.clip(data.data[0, :].astype(float), clipped_min, clipped_max)
        bottom = clipped_max - clipped_min
        bottom[bottom == 0] = "nan"
        data = (data - clipped_min) / bottom
        return np.nan_to_num(data, 0)

    def scale(self, data):
        if self.mode == "normalize":
            return self.normalize(data)
        elif self.mode == "standardize":
            return self.standardize(data)
        else:
            raise ValueError(f"Mode '{self.mode}' non reconnu.")


In [None]:

class HierarchicalAggregator:
    def __init__(self,stack_type_input: list[StackType], lead_time_output: int = 45, resolution_output: int = 14, scaler : callable =None):
        self.stack_type_input = stack_type_input
        self.lead_time_output = lead_time_output
        self.resolution_output = resolution_output
        self._temporal_idx = 0
        self.scaler = scaler
        self._temporal_idx_maping = {}
        self._current_temporal_idx = 0
    
    def aggregate(self, data: xr.DataArray, idx: int):
        input_data = []
        target_data = []

        # compute the window of temporal indexes we will use to create the input data
        width_input = max([stack.value for stack in self.stack_type_input])
        input_time_indexes = data.time.values[self._current_temporal_idx:self._current_temporal_idx + width_input]
        input_window = data.sel(time=input_time_indexes)
        print('input window',input_window)
        print('width window',width_input)

        # compute the indexes of interest for the target data
        start_idx_output = self._current_temporal_idx + width_input + 1
        # (TODO) pour l'instant on a t+1 qui est predit, est ce qu'on le veut, aussi il faut savoir c quoi la target definition
        target_time_indexes = data.time.values[start_idx_output : start_idx_output + self.lead_time_output: self.resolution_output]
        print('target time indexes',target_time_indexes)
        print('start idx output',start_idx_output)
        
        # compute the aggregated data for the input and stack them
        rolling_mean = {stack.value : input_window.rolling(time = stack.value, center = False ).mean() for stack in self.stack_type_input}
        input_data = [rolling_mean[stack.value].values for stack in self.stack_type_input]
        input_data = np.stack(input_data, axis = 0)

        # prepare target data 
        target_seq = data.sel(time = target_time_indexes).values
        target_data.append(target_seq)


        # update temporal index for next iteration
        self._current_temporal_idx = start_idx_output + self.lead_time_output +1
        self._temporal_idx_maping[idx] = self._current_temporal_idx
        print('current temporal idx',self._current_temporal_idx)

        return torch.Tensor(input_data), torch.Tensor(target_data)
    

In [None]:
from typing import List, Dict

import xarray as xr
import os 
from enum import Enum

class Resolution(Enum):
    WEEKLY = "week"
    DAILY = "day"
    MONTHLY = "month"
    SEASON = "season"
    YEARLY = "year"

class DataStatistics:
    def __init__(self, data: xr.DataArray, years : List[int] , months: List[int], resolution: Resolution):
        """ Saves the data and the statistics for the given years and months.
        Compute the statistics with different level of resolution (week, day, month, season, year) over all the years and months given.
        If the statistics are already computed, it loads them, otherwise it computes them and saves them.

        Args:
            data (xr.DataArray): _description_
            years (List[int]): _description_
            months (List[int]): _description_
            resolution (Resolution): _description_
        """
        self.data = data
        self.years = years
        self.months = months
        self.resolution = resolution
        self.stats = self._get_stats()
    
    def _get_stats(self) -> Dict[str, float]:
        path_mean, path_std  = f"data/{self.resolution.value}_{self.years}_{self.months}_average.nc", f"data/{self.resolution.value}_{self.years}_{self.months}_std.nc"
        
        if os.path.exists(path_mean) and os.path.exists(path_std):
            average = xr.open_dataarray(path_mean)
            std = xr.open_dataarray(path_std)
        else:
            average, std = self._compute_stats()
            self.save_stats(average, std)
        return {"mean" : average, "std" : std}
        
    
    def _compute_stats(self) -> xr.DataArray:
        # check we have the right year and months 
        self.data= self.data.sel(time = self.data.time.dt.year.isin(self.years))
        self.data = self.data.sel(time = self.data.time.dt.month.isin(self.months))
       
        average = self.data.groupby(f"time.{self.resolution.value}").mean(dim = "time")
        std = self.data.groupby(f"time.{self.resolution.value}").std(dim = "time")
        return average, std
    
    def save_stats(self, average, std):
        average.to_netcdf(f"data/{self.resolution.value}_{self.years}_{self.months}_average.nc")
        std.to_netcdf(f"data/{self.resolution.value}_{self.years}_{self.months}_std.nc")        
        


In [None]:
class DatasetEra(Dataset):
    def __init__(
        self,
        wandb_config : dict,
        variables_nh : List[Enum],
        variables_med : List[Enum],
        data_dirs : str,
        temporal_aggregator : HierarchicalAggregator,
        stat_provider : DataStatistics,
        scaler : DataScaler = None,
    ):
        self._initialize_config(wandb_config)
        self.variables_nh = variables_nh
        self.variables_med = variables_med
        self.data_dirs = data_dirs
        self.scaler = scaler
        self.med_data, self.nh_data = self._load_and_prepare_data()
        self.temporal_aggregator = temporal_aggregator
        self.first_year = self.med_data.time.dt.year.min().item()
        self.last_year = self.med_data.time.dt.year.max().item()
        self.stats = self._get_stats(Resolution.WEEKLY)
        
    def _get_stats(self, resolution: Resolution):
        stat_med = DataStatistics(self.med_data, self.relevant_years, self.relevant_months, resolution)._get_stats()
        stat_nh = DataStatistics(self.nh_data, self.relevant_years, self.relevant_months, resolution)._get_stats()
        return {"mediteranean" : stat_med, 'north_hemisphere' : stat_nh}
        
    def _initialize_config(self, wandb_config):
        """Initialize configuration settings."""
        ds_conf = wandb_config["dataset_config"]
        self.land_mask = ds_conf["land_mask"]
        self.relevant_months = ds_conf["relevant_months"]
        self.relevant_years = ds_conf["relevant_years"]

    def _load_and_prepare_data(self):
        """Load data from directories and keep the relevant years/months."""
        # load the data from the directories
        med_data = self._load_data(self.data_dirs['mediteranean'])
        med_data = med_data[self.variables_med]
        nh_data = self._load_data(self.data_dirs['north_hemisphere'])
        nh_data = nh_data[self.variables_nh]
        print("data loaded")

        # remap north mediteranean to north hemisphere 
        med_data = self.remap_MED_to_NH(nh_data,med_data)
        print("data remapped")

        return med_data, nh_data
    
    def _load_data(self, dir_path):
        """Load data from a specified directory using xarray."""
        ds = xr.open_mfdataset(f"{dir_path}/*.nc", combine='by_coords')
        ds = self._filter_data_by_time(ds)
        return ds
    
    def remap_MED_to_NH(self, nh_data, med_data):
        """Remap Mediterranean data to North Hemisphere grid and pad with zeros."""
        # Empty array same dimensions and coordinates as nh_data
        print("nh_data",nh_data)
        fill_values = {var: 0 for var in nh_data.data_vars}
        remapped_data = xr.full_like(nh_data, fill_value=fill_values)
        print("remapped_data",remapped_data)

        # Find the overlap region and assign Mediterranean data to the remapped data
        med_lon_min, med_lon_max = med_data.min(dim = "longitude"), med_data.max(dim = "longitude")
        med_lat_min, med_lat_max = med_data.min(dim = "latitude"), med_data.max(dim = "latitude")

        # replace the values of the remapped data with the mediteranean data
        remapped_data = remapped_data.where(
            (remapped_data.longitude >= med_lon_min) & (remapped_data.longitude <= med_lon_max) &
            (remapped_data.latitude >= med_lat_min) & (remapped_data.latitude <= med_lat_max),
            other = med_data
        )

        return remapped_data

    def _filter_data_by_time(self, data):
        """Filter the data to include only the relevant months and years."""
        data = data.sel(time=data['time.year'].isin(self.relevant_years))
        data = data.sel(time=data['time.month'].isin(self.relevant_months))
        return data
    
    
    def __len__(self):
        len_med, len_nh = self.temporal_aggregator.compute_len_dataset(), self.temporal_aggregator.compute_len_dataset()
        assert len_med == len_nh, "The length of the two datasets should be the same."
        return len_med
    

    def __getitem__(self, idx):
        # Aggregate the input data
        med_input_data, med_target_data = self.temporal_aggregator.aggregate(self.med_data, idx)
        nh_input_data, nh_target_data = self.temporal_aggregator.aggregate(self.nh_data, idx)


        # Concatenate Mediterranean and North Hemisphere data along the variable dimension
        input_data = torch.cat([med_input_data, nh_input_data], dim=0)
        target_data = torch.cat([med_target_data, nh_target_data], dim=0)

        return input_data, target_data


In [None]:

    
class TemporalAggregator:
    def __init__(self, stack_number_input : int,lead_time_number : int, resolution_input : int, resolution_output: int, scaler: DataScaler):
        self.name = "TemporalAggregator"
        self.stack_number_input = stack_number_input
        self.lead_time_number = lead_time_number
        self.resolution_input = resolution_input
        self.resolution_output = resolution_output
        self.scaler = scaler
        self._temporal_idx = 0
        self._current_temporal_idx = 0
        self._temporal_idx_maping = {}
    
    def compute_len_dataset(self, data: xr.DataArray):
        width_input = self.stack_number_input*self.resolution_input
        return len(data.time.values) - width_input - self.lead_time_number*self.resolution_output
    
    def aggregate(self, data: xr.DataArray, idx: int):
        input_data = []
        target_data = []

        # compute the window of temporal indexes we will use to create the input data
        width_input = self.stack_number_input*self.resolution_input
        input_time_indexes = data.time.values[self._current_temporal_idx:self._current_temporal_idx + width_input]
        input_window = data.sel(time=input_time_indexes)
        # si dans input window, certaines données on plus 2 mois de différence
        # on peut pas les stacker
        

        # compute the indexes of interest for the target data
        start_idx_output = self._current_temporal_idx + width_input + 1
        target_time_indexes = data.time.values[start_idx_output : start_idx_output + self.resolution_output*self.lead_time_number]
        output_window = data.sel(time = target_time_indexes)

        # stack mean for 
        for i in range(0,self.stack_number_input):
            mean_input = input_window.sel(time = input_time_indexes[i*self.resolution_input:(i+1)*self.resolution_input]).mean(dim = "time")
            input_data.append(mean_input.values)
        for i in range(0,self.lead_time_number):
            mean_output = output_window.sel(time = target_time_indexes[i*self.resolution_output:(i+1)*self.resolution_output]).mean(dim = "time")
            target_data.append(mean_output.values)


        input_data = np.stack(input_data, axis = 0)
        target_data = np.stack(target_data, axis = 0)

        # update temporal index for next iteration
        self._current_temporal_idx = start_idx_output + self.resolution_output*self.lead_time_number +1
        self._temporal_idx_maping[idx] = self._current_temporal_idx

        return torch.Tensor(input_data), torch.Tensor(target_data)

In [None]:
aggregator = HierarchicalAggregator(stack_type_input=[StackType.DAILY, StackType.WEEKLY], lead_time_output=45, resolution_output=14)

relevant_years = [2015, 2016, 2017, 2018, 2019]
relevant_months = [9, 10, 11,12,1,2]
ds = xr.open_mfdataset(f"{dataset_path}", combine='by_coords')
ds = ds["t2m"]
ds = ds.sel(time=ds['time.year'].isin(relevant_years))
ds = ds.sel(time=ds['time.month'].isin(relevant_months))



KeyboardInterrupt: 

In [None]:
temp_aggregator = TemporalAggregator(stack_number_input=2,lead_time_number=2, resolution_input=14, resolution_output=14, scaler=None)
temp_aggregator.compute_len_dataset(ds)


850

In [None]:
stats = DataStatistics(ds, relevant_years, relevant_months, Resolution.WEEKLY)
print(stats.stats)

{'mean': <xarray.DataArray 't2m' (week: 28, latitude: 61, longitude: 201)> Size: 3MB
[343308 values with dtype=float64]
Coordinates:
  * longitude  (longitude) float32 804B -10.0 -9.75 -9.5 ... 39.5 39.75 40.0
  * latitude   (latitude) float32 244B 45.0 44.75 44.5 44.25 ... 30.5 30.25 30.0
  * week       (week) int64 224B 1 2 3 4 5 6 7 8 9 ... 46 47 48 49 50 51 52 53
Attributes:
    long_name:     2 metre temperature
    units:         K
    cell_methods:  time: mean, 'std': <xarray.DataArray 't2m' (week: 28, latitude: 61, longitude: 201)> Size: 3MB
[343308 values with dtype=float64]
Coordinates:
  * longitude  (longitude) float32 804B -10.0 -9.75 -9.5 ... 39.5 39.75 40.0
  * latitude   (latitude) float32 244B 45.0 44.75 44.5 44.25 ... 30.5 30.25 30.0
  * week       (week) int64 224B 1 2 3 4 5 6 7 8 9 ... 46 47 48 49 50 51 52 53
Attributes:
    long_name:     2 metre temperature
    units:         K
    cell_methods:  time: mean}


In [None]:
ds['wet_season_year']= ds.time.dt.year*(ds.time.dt.month>=9) + (ds.time.dt.year-1)*(ds.time.dt.month<9)
grouped_data = ds.groupby('wet_season_year')

resolution = "week"
# remove the year if there is not all months of the wet season
# tesxt the first element of the groupby
print(grouped_data)
for _, group in grouped_data:
    print(group.time.dt.weekofyear[2].values)
    # print the resolution number of this date



DatasetGroupBy, grouped over 'wet_season_year'
84 groups with labels 1939, 1940, 1941, ..., 2021, 2022.
1
36
36
36
35
35
36
36
36
36
35
35
36
36
36
35
35
36
36
36
36
35
35
36
36
36
35
35
35
36
36
36
35
35
36
36
36
36
35
35
36
36
36
35
35
36
36
36
36
35
35
36
36
36
35
35
35
36
36
36
35
35
36
36
36
36
35
35
36
36
36
35
35
36
36
36
36
35
35
36
36
36
35
35




In [None]:
#get season of the first element


print(grouped_data[2014][0].time.dt.season)



AttributeError: 'DataArray' object has no attribute 'dt'

In [None]:


group_indices = grouped_data.groups

# Affichage du premier groupe d'indices

for key, group in grouped_data:
    print(key)
    print(group)
    break

2014
<xarray.DataArray 't2m' (time: 59, latitude: 61, longitude: 201)> Size: 6MB
dask.array<getitem, shape=(59, 61, 201), dtype=float64, chunksize=(59, 61, 201), chunktype=numpy.ndarray>
Coordinates:
  * time             (time) datetime64[ns] 472B 2015-01-01T11:00:00 ... 2015-...
  * longitude        (longitude) float32 804B -10.0 -9.75 -9.5 ... 39.75 40.0
  * latitude         (latitude) float32 244B 45.0 44.75 44.5 ... 30.5 30.25 30.0
    wet_season_year  (time) int64 472B 2014 2014 2014 2014 ... 2014 2014 2014
Attributes:
    long_name:     2 metre temperature
    units:         K
    cell_methods:  time: mean


In [None]:
grouped_data.get_group(2015)

AttributeError: 'DataArrayGroupBy' object has no attribute 'get_group'

In [None]:
grouped_data[2014]

Unnamed: 0,Array,Chunk
Bytes,5.52 MiB,5.52 MiB
Shape,"(59, 61, 201)","(59, 61, 201)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 5.52 MiB 5.52 MiB Shape (59, 61, 201) (59, 61, 201) Dask graph 1 chunks in 5 graph layers Data type float64 numpy.ndarray",201  61  59,

Unnamed: 0,Array,Chunk
Bytes,5.52 MiB,5.52 MiB
Shape,"(59, 61, 201)","(59, 61, 201)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [None]:
weekly_average = ds.groupby("time.week").mean(dim="time")
print(weekly_average)

<xarray.DataArray 't2m' (week: 28, latitude: 61, longitude: 201)> Size: 3MB
dask.array<stack, shape=(28, 61, 201), dtype=float64, chunksize=(1, 61, 201), chunktype=numpy.ndarray>
Coordinates:
  * longitude  (longitude) float32 804B -10.0 -9.75 -9.5 ... 39.5 39.75 40.0
  * latitude   (latitude) float32 244B 45.0 44.75 44.5 44.25 ... 30.5 30.25 30.0
  * week       (week) int64 224B 1 2 3 4 5 6 7 8 9 ... 46 47 48 49 50 51 52 53
Attributes:
    long_name:     2 metre temperature
    units:         K
    cell_methods:  time: mean




In [None]:
input_data, target_data = aggregator.aggregate(ds, 0)

input window <xarray.DataArray 't2m' (time: 7, latitude: 61, longitude: 201)> Size: 687kB
dask.array<getitem, shape=(7, 61, 201), dtype=float64, chunksize=(7, 61, 201), chunktype=numpy.ndarray>
Coordinates:
  * time             (time) datetime64[ns] 56B 2015-01-01T11:00:00 ... 2015-0...
  * longitude        (longitude) float32 804B -10.0 -9.75 -9.5 ... 39.75 40.0
  * latitude         (latitude) float32 244B 45.0 44.75 44.5 ... 30.5 30.25 30.0
    wet_season_year  (time) int64 56B 2014 2014 2014 2014 2014 2014 2014
Attributes:
    long_name:     2 metre temperature
    units:         K
    cell_methods:  time: mean
width window 7
target time indexes ['2015-01-09T11:00:00.000000000' '2015-01-23T11:00:00.000000000'
 '2015-02-06T11:00:00.000000000' '2015-02-20T11:00:00.000000000']
start idx output 8


KeyboardInterrupt: 

In [None]:
dataset = DatasetEra(
    wandb_config = {"dataset_config" :{"variable_names": ["t2m"], "land_mask": False, "relevant_months": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "relevant_years": [1940, 2022]}},
    variables = ["t2m"],
    data_dirs = {"mediteranean": dataset_path, "north_hemisphere": dataset_path},
    temporal_aggregator = aggregator,
    scaler = None
)

data loaded
nh_data <xarray.Dataset> Size: 72MB
Dimensions:    (time: 731, latitude: 61, longitude: 201)
Coordinates:
  * time       (time) datetime64[ns] 6kB 1940-01-01T11:00:00 ... 2022-12-31T1...
  * longitude  (longitude) float32 804B -10.0 -9.75 -9.5 ... 39.5 39.75 40.0
  * latitude   (latitude) float32 244B 45.0 44.75 44.5 44.25 ... 30.5 30.25 30.0
Data variables:
    t2m        (time, latitude, longitude) float64 72MB dask.array<chunksize=(731, 61, 201), meta=np.ndarray>
Attributes:
    CDI:          Climate Data Interface version 1.9.10 (https://mpimet.mpg.d...
    Conventions:  CF-1.6
    history:      Wed Jun 05 15:44:23 2024: cdo -sellonlatbox,-10,40,30,45 /s...
    frequency:    day
    CDO:          Climate Data Operators version 1.9.10 (https://mpimet.mpg.d...
remapped_data <xarray.Dataset> Size: 72MB
Dimensions:    (time: 731, latitude: 61, longitude: 201)
Coordinates:
  * time       (time) datetime64[ns] 6kB 1940-01-01T11:00:00 ... 2022-12-31T1...
  * longitude  (longit