In [1]:
import os
import numpy as np
import pandas as pd
import xarray as xr

import torch
# Seed both RNGs with fixed values so the stochastic cells (random sample
# index, noise added in dummy_model) give the same results after a kernel
# restart. torch.random.seed() would instead draw a fresh non-deterministic
# seed on every run.
torch.manual_seed(0)
np.random.seed(0)


# Directory holding the yearly `weather_round1_train_<year>` zarr stores.
# Point it at your dataset by setting the WEATHER_DATA_DIR environment
# variable; the previously hard-coded path remains the fallback.
data_dir = os.environ.get('WEATHER_DATA_DIR', '/localdata_ssd/gaoziyi/dataset')

def chunk_time(ds):
    """Rechunk ``ds`` so that every chunk holds exactly one time step.

    All other dimensions are kept as a single chunk spanning their full length.

    Args:
        ds: object exposing ``sizes`` (mapping of dimension name to length)
            and ``chunk`` -- e.g. an ``xarray.Dataset`` or ``xarray.DataArray``.

    Returns:
        The result of ``ds.chunk`` called with the computed chunk sizes.

    Note:
        ``chunk`` needs a chunk manager (dask) to be installed; without one
        xarray raises ``ValueError: unrecognized chunk manager dask``.
    """
    # ``sizes`` is a name -> length mapping on both Dataset and DataArray,
    # whereas ``dims`` is a plain tuple on DataArray and has no ``.items()``.
    chunks = dict(ds.sizes)
    chunks['time'] = 1
    return ds.chunk(chunks)


def load_dataset(years=range(2007, 2008), root=None):
    """Open the yearly training zarr stores and join them along ``time``.

    Args:
        years: iterable of years to load; one store named
            ``weather_round1_train_<year>`` is opened per year. The default
            loads 2007 only, matching the previously hard-coded range.
        root: directory holding the stores; ``None`` means the module-level
            ``data_dir``.

    Returns:
        The concatenated dataset after ``chunk_time`` has been applied.
    """
    root = data_dir if root is None else root
    yearly = []
    for year in years:
        store = os.path.join(root, f'weather_round1_train_{year}')
        opened = xr.open_zarr(store, consolidated=True)
        # Log the time span covered by each store as it is opened.
        print(f'{store}, {opened.time.values[0]} ~ {opened.time.values[-1]}')
        yearly.append(opened)
    combined = xr.concat(yearly, 'time')
    return chunk_time(combined)

# Everything below works on the `x` variable of the loaded dataset.
ds = load_dataset().x

num_step = 20  # for 5-days
names = list(ds.channel.values)
test_names = names[-5:]  # the five trailing channels
shape = ds.shape  # batch x channel x lat x lon
times = ds.time.values
# Skip the first timestamp and the last `num_step` ones: getitem() also reads
# the frame 6 hours before each init time and the frames up to 5 days after it.
init_times = times[1:-num_step]
num_data = len(init_times)

print(f'\n shape: {shape}')
print('\n times: {} ~ {}'.format(times[0], times[-1]))
print('\n init_times: {} ~ {}'.format(init_times[0], init_times[-1]))
print(f'\n names: {names}')
print(f'\n test_names: {test_names}\n')

/localdata_ssd/gaoziyi/dataset/weather_round1_train_2007, 2007-01-01T00:00:00.000000000 ~ 2007-12-31T18:00:00.000000000
/localdata_ssd/gaoziyi/dataset/weather_round1_train_2008, 2008-01-01T00:00:00.000000000 ~ 2008-12-31T18:00:00.000000000
/localdata_ssd/gaoziyi/dataset/weather_round1_train_2009, 2009-01-01T00:00:00.000000000 ~ 2009-12-31T18:00:00.000000000
/localdata_ssd/gaoziyi/dataset/weather_round1_train_2010, 2010-01-01T00:00:00.000000000 ~ 2010-12-31T18:00:00.000000000
/localdata_ssd/gaoziyi/dataset/weather_round1_train_2011, 2011-01-01T00:00:00.000000000 ~ 2011-12-31T18:00:00.000000000


ValueError: unrecognized chunk manager dask - must be one of: []

In [None]:
# visualize any variable at any time 
def visualize(time, name):
    """Draw one channel of the module-level ``ds`` at one time on a lat/lon map.

    Args:
        time: time label forwarded to ``ds.sel(time=...)``.
        name: channel name; must be one of the module-level ``names``.

    Raises:
        AssertionError: if ``name`` is not in ``names``.
    """
    # Plotting-only dependencies stay local to this function.
    import cartopy.crs as ccrs
    import matplotlib.pyplot as plt

    assert name in names, f'unknown channel {name!r}'
    field = ds.sel(time=time, channel=name)

    def plot(data, ax, title):
        """Render ``data`` on ``ax`` with coastlines and left/bottom grid labels."""
        data.plot(
            ax=ax,
            x='lon',
            y='lat',
            transform=ccrs.PlateCarree(),
            add_colorbar=False,
        )
        ax.set_title(title)
        ax.coastlines()
        gl = ax.gridlines(draw_labels=True, linewidth=0.5)
        # Keep tick labels on the left and bottom edges only.
        gl.top_labels = False
        gl.right_labels = False

    fig, ax = plt.subplots(figsize=(8, 6), subplot_kw={"projection": ccrs.PlateCarree()})
    plot(field, ax, title=f'{name.upper()}')
    
# load_dataset() as written opens only the 2007 store, so plot a timestamp
# that is present in the loaded data ('20080101-00' lies outside that range).
for channel_name in ('t2m', 'u10', 'v10', 'msl', 'tp'):
    visualize(time='20070101-00', name=channel_name)

In [None]:
# load seqs from dataset
def getitem(idx):
    """Build one (input, target) pair for the initial time ``init_times[idx]``.

    The input holds the two frames at t-6h and t with every channel; the
    target holds the following 6-hourly frames up to t+5 days, restricted to
    ``test_names``. NaNs in both are replaced via ``torch.nan_to_num``.

    Returns:
        Tuple ``(input, target)`` of torch tensors.
    """
    assert idx < num_data
    start = init_times[idx]
    window = pd.date_range(
        start - pd.Timedelta(hours=6),
        start + pd.Timedelta(days=5),  # you can reduce it for auto-regressive training
        freq='6h',
    )
    history, future = window[:2], window[2:]

    # you can use subset of input, eg: only surface
    frames_in = ds.sel(time=history)
    frames_out = ds.sel(time=future, channel=test_names)

    inputs = torch.nan_to_num(torch.from_numpy(frames_in.values))  # t c h w
    targets = torch.nan_to_num(torch.from_numpy(frames_out.values))  # t c h w
    return inputs, targets


def dummy_model(input, target=None):
    """Placeholder model that returns random or noise-perturbed tensors.

    Args:
        input: accepted for interface compatibility; not read here.
        target: optional tensor. When given, the output is ``target`` plus
            standard-normal noise of the same shape.

    Returns:
        ``target + noise`` when ``target`` is provided, otherwise a random
        tensor of shape (20, 5, 101, 101), i.e. step x channel x lat x lon.
    """
    if target is not None:
        # TODO, train your model, base on input and target,
        # here we add noise to target produce fake output
        noise = torch.randn_like(target)
        return target + noise
    return torch.randn(20, 5, 101, 101)  # step x channel x lat x lon


# Daily climate baseline per evaluated variable. run_eval() iterates these
# entries in order and uses each value as `clim` when normalizing the RMSE.
climates = dict(
    t2m=3.1084048748016357,
    u10=4.114771819114685,
    v10=4.184110546112061,
    msl=729.5839385986328,
    tp=0.49046186606089276,
)


def compute_rmse(out, tgt):
    """Return the root-mean-square error between ``out`` and ``tgt``.

    The mean is taken over all elements, so the result is a 0-dim tensor.
    """
    squared_error = (out - tgt).pow(2)
    return squared_error.mean().sqrt()


def run_eval(output, target):
    '''
    Score a forecast against the per-variable baselines in ``climates``.

        output: (batch x step x channel x lat x lon), eg: N x 20 x 5 x H x W
        target: (batch x step x channel x lat x lon), eg: N x 20 x 5 x H x W

    Channel ``i`` of both tensors is paired with the ``i``-th entry of
    ``climates``, so the channel order must follow that dict's order.

    Returns a dict mapping every variable name to its score, plus a
    ``'score'`` entry holding the mean of the per-variable scores. All values
    are plain Python floats.
    '''
    result = {}
    num_steps = output.shape[1]
    for cid, (name, clim) in enumerate(climates.items()):
        nrmse_per_step = []
        for sid in range(num_steps):
            # float() extracts the scalar from the 0-dim tensor so the
            # averaging below runs on plain numbers; np.mean over raw tensors
            # fails when they require grad or live on an accelerator.
            rmse = float(compute_rmse(output[:, sid, cid], target[:, sid, cid]))

            # normalized rmse, lower is better,
            # 0 means equal to climate baseline,
            # less than 0 means better than climate baseline,
            # -1 means perfect prediction
            nrmse_per_step.append((rmse - clim) / clim)

        # Flip the sign so that higher is better and clip at 0, i.e. anything
        # no better than the baseline scores 0.
        result[name] = max(0.0, -float(np.mean(nrmse_per_step)))

    # Overall score: mean of the per-variable scores computed above.
    result['score'] = float(np.mean(list(result.values())))
    return result

# Please note the data has already been normalized with a pre-computed mean and std, so you don't need to normalize it again.
# Draw one random sample index and build its (input, target) pair.
idx = np.random.randint(num_data)
input, target = getitem(idx)
# The dummy model receives the target here, so its output is target + noise.
output = dummy_model(input, target)
# Report the shape and value range of each tensor.
print('input: {}, {} ~ {}'.format(input.shape, input.min(), input.max()))
print('output: {}, {} ~ {}'.format(output.shape, output.min(), output.max()))
print('target: {}, {} ~ {}'.format(target.shape, target.min(), target.max()))


# Run eval on the normalized output and target; note that the online evaluation will un-normalize your output to the original scale.
# `[None]` adds the leading batch axis that run_eval indexes over.
result = run_eval(output[None], target[None])
# Format every entry of the result dict as "name: value " with two decimals.
score = "".join(f"{k}: {v:.2f} " for k, v in result.items()) 
print(score)