In [74]:
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from graphcast import data_utils
import dataclasses
from graphcast import graphcast
import his_utils

In [75]:
# Open the bundled ERA5 example file (0.25 degree, 37 levels, 12 steps
# according to its name) and keep only the first four time steps.
file_name = "testdata/source-era5_date-2022-01-01_res-0.25_levels-37_steps-12.nc"
eval_data = xr.open_dataset(file_name).isel(time=[0, 1, 2, 3])

In [76]:
# Task description: every surface, atmospheric, forcing and static variable is
# an input; the surface and atmospheric variables are the targets.
task_config = graphcast.TaskConfig(
    input_variables=(
        graphcast.TARGET_SURFACE_VARS
        + graphcast.TARGET_ATMOSPHERIC_VARS
        + graphcast.FORCING_VARS
        + graphcast.STATIC_VARS
    ),
    target_variables=(
        graphcast.TARGET_SURFACE_VARS + graphcast.TARGET_ATMOSPHERIC_VARS
    ),
    forcing_variables=graphcast.FORCING_VARS,
    pressure_levels=graphcast.PRESSURE_LEVELS[37],
    input_duration="6h",
)

# Split the evaluation data into inputs, targets and forcings for lead times
# from "6h" up to "12h" (two 6-hour steps).
eval_inputs, eval_targets, eval_forcings = (
    data_utils.extract_inputs_targets_forcings(
        eval_data,
        target_lead_times=slice("6h", f"{2 * 6}h"),
        **dataclasses.asdict(task_config),
    )
)

|| data_utils.py -> add_derived_vars() init. ||
<xarray.Dataset> Size: 4GB
Dimensions:                       (lon: 1440, lat: 721, level: 37, time: 4,
                                   batch: 1)
Coordinates:
  * lon                           (lon) float32 6kB 0.0 0.25 0.5 ... 359.5 359.8
  * lat                           (lat) float32 3kB -90.0 -89.75 ... 89.75 90.0
  * level                         (level) int32 148B 1 2 3 5 ... 950 975 1000
  * time                          (time) timedelta64[ns] 32B 00:00:00 ... 18:...
    datetime                      (batch, time) datetime64[ns] 32B ...
Dimensions without coordinates: batch
Data variables: (12/14)
    geopotential_at_surface       (lat, lon) float32 4MB ...
    land_sea_mask                 (lat, lon) float32 4MB ...
    2m_temperature                (batch, time, lat, lon) float32 17MB ...
    mean_sea_level_pressure       (batch, time, lat, lon) float32 17MB ...
    10m_v_component_of_wind       (batch, time, lat, lon) float32 

In [77]:
# Pressure-level fields (37 levels).
sample1 = xr.open_dataset("testdata/1.nc")

# Single-level fields.
sample2 = xr.open_dataset("testdata/2.nc")

# Rename before merging; otherwise the merge runs into a name conflict.
sample2 = sample2.rename({"z": "geopotential_at_surface"})
sample = xr.merge([sample1, sample2])

# Apply the same processing that is used for the sample data.
sample = his_utils.transform_dataset(sample).copy()

# Sum precipitation into 6-hour bins.
# NOTE(review): resample() is called with its default bin edges and labels --
# confirm they match the accumulation window expected for
# total_precipitation_6hr.
accumulating = sample.total_precipitation.resample(time="6h").sum()

# Keep only the time steps that fall on hours 0, 6, 12 and 18.
hours = sample.time.dt.total_seconds() / 3600
time_selector = (hours % 24).isin([0, 6, 12, 18])

# Swap in the 6-hour sums (as float32) and give the variable its 6-hour name.
new_sample = (
    sample.isel(time=time_selector)
    .assign(total_precipitation=accumulating.astype(np.float32))
    .rename({"total_precipitation": "total_precipitation_6hr"})
)

# Drop the batch dimension from the two static fields.
for static_name in ("geopotential_at_surface", "land_sea_mask"):
    new_sample[static_name] = new_sample[static_name].squeeze("batch")

# Reverse the order of the latitude axis.
new_sample = new_sample.reindex(lat=new_sample.lat[::-1])

# new_sample.to_netcdf("testdata/new_sample.nc")

In [86]:
# import cartopy.crs as ccrs
# import cartopy.feature as cfeature

# plt.figure(figsize=(20,10))
# ax = plt.axes(projection=ccrs.PlateCarree())

# im = new_sample["2m_temperature"].isel(time=0).plot(ax=ax, transform=ccrs.PlateCarree(), cmap='rainbow', cbar_kwargs={'label': '2m temperature (K)'}, add_colorbar=True)
# ax.gridlines(draw_labels=True, linestyle='--')
# ax.add_feature(cfeature.COASTLINE)
# plt.tight_layout()
# ax.set_global()
# plt.savefig("precip2.png")

In [None]:
# for var in new_sample.data_vars:
#     diff = new_sample[var] - eval_data[var]
#     print(f"{var}: {diff.min().values}, {diff.max().values}, {diff.mean().values}")
    # plot diff
    # plt.figure(figsize=(20,10))
    # ax = plt.axes(projection=ccrs.PlateCarree())
    # im = diff.isel(time=0).plot(ax=ax, 
    #                             transform=ccrs.PlateCarree(), 
    #                             cmap='rainbow', 
    #                             cbar_kwargs={'label': 'precipitation (m)'}, 
    #                             add_colorbar=True)
    # ax.gridlines(draw_labels=True, linestyle='--')
    # ax.add_feature(cfeature.COASTLINE)
    # plt.tight_layout()
    # ax.set_global()


In [13]:
# Load the saved evaluation targets, then display the result of multiplying
# them by NaN (same layout, numeric values become NaN).
eval_targets = xr.open_dataset("testdata/eval_targets.nc")
np.nan * eval_targets

In [3]:
import his_utils

# Call the project helper for two time steps on a 0.25 degree grid with 37
# pressure levels, starting at 2022-01-01; the returned object is the cell's
# displayed output. (Presumably a NaN-filled dataset, judging by the helper's
# name -- confirm in his_utils.)
his_utils.create_nan_dataset(
    time_steps=2,
    resolution=0.25,
    pressure_levels=37,
    start_time="2022-01-01",
)

In [12]:
import numpy as np
import pandas as pd
import xarray as xr
from graphcast import data_utils, solar_radiation

# Grid spacing (degrees), number of time stamps and reference start date.
resolution = 0.25
time_steps = 2
start_time = "2022-01-01"

# Longitudes span [0, 360) and latitudes [-90, 90]. The half-step added to the
# upper latitude bound keeps 90.0 inside arange's half-open interval.
lon = np.arange(0.0, 360.0, resolution, dtype=np.float32)
lat = np.arange(-90.0, 90.0 + resolution / 2, resolution, dtype=np.float32)

# Time stamps begin 6 hours after start_time and advance in 6-hour steps.
start_datetime = pd.to_datetime(start_time) + pd.Timedelta(hours=6)
time = pd.date_range(start=start_datetime, periods=time_steps, freq="6h")

# Coordinates-only dataset; "datetime" is laid out along the "time" dimension.
ds = xr.Dataset(
    coords=dict(
        lon=("lon", lon),
        lat=("lat", lat),
        datetime=("time", time),
    )
)

# Descriptive metadata for the two spatial coordinates.
ds.lat.attrs.update(long_name="latitude", units="degrees_north")
ds.lon.attrs.update(long_name="longitude", units="degrees_east")

# Forcing variable names (not referenced again within this cell).
variables = [
    "toa_incident_solar_radiation",
    "year_progress_sin",
    "year_progress_cos",
    "day_progress_sin",
    "day_progress_cos",
]

ds

In [14]:
# Run graphcast's add_tisr_var helper on `ds`. Its return value is discarded,
# so `ds` is presumably updated in place (confirm in data_utils). Then display
# `ds`.
data_utils.add_tisr_var(ds)

ds