In [1]:
import os
import cv2
import torch
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt 
import matplotlib.patches as patches
from mpl_toolkits.axes_grid1 import make_axes_locatable
from einops import repeat

# Root of the workspace; every other path below is derived from it.
work_dir = "/mnt/workspace"
data_dir = os.path.join(work_dir, "datasets/era5")   # inputs and lat/lon grid files
save_dir = os.path.join(work_dir, "output/results")  # saved outputs, metrics and figures

# Grid coordinate vectors, loaded once and shared by the cells below.
lat, lon = (np.load(os.path.join(data_dir, fname)) for fname in ('lat.npy', 'lon.npy'))

# Evaluated variables; plot_var unpacks its input stack in this same order.
eval_names = ("z500", "t850", "t2m", "u10", "v10", "tp")

[2022-08-15 16:11:18,277.277 dsw37230-5564686ddb-9mqpj:198501 INFO utils.py:30] NOTICE: PAIDEBUGGER is turned off.


In [None]:
def adjust_sample_ratio(max_iter=150000, k=8000):
    """Plot and return an inverse-sigmoid decay schedule.

    For 0-based iteration ``itr`` the ratio is ``k / (k + exp((itr + 1) / k))``:
    it starts just below 1 and decreases monotonically toward 0.

    Args:
        max_iter: number of iterations to evaluate.
        k: decay constant; larger values keep the ratio near 1 for longer.

    Returns:
        list with one ratio per iteration (``max_iter`` entries).
    """
    # ``step`` is the 1-based iteration number, i.e. ``itr + 1``.
    ratios = [k / (k + np.exp(step / k)) for step in range(1, max_iter + 1)]
    plt.plot(np.arange(len(ratios)), ratios)
    plt.show()
    return ratios

In [None]:
ratios = adjust_sample_ratio(max_iter=150000, k=8000)  # explicit defaults, for readability

In [None]:
def plot_var(imgs, exp, lat, lon, init_time, lead_time, save_dir="log_images", use_colorbar=False, use_ticks=False,
             plot_names=("ws",), fixed_range=False):
    """Render per-variable maps with a zoomed inset and save them as PNGs.

    A wind-speed channel ``sqrt(u10**2 + v10**2)`` is appended to ``imgs``.
    For every channel whose save name is in ``plot_names``:
    - the box ``[by:by+bh, bx:bx+bw]`` (outlined in red) is upscaled 4x;
    - the upscaled patch is pasted into the bottom-right corner (outlined in green);
    - the figure is written to ``<save_dir>/<exp lower>_<name>_<init_time>_<lead_time:03d>.png``.

    Args:
        imgs: stack of 6 2-D fields unpacked as (z500, t850, t2m, u10, v10, tp),
            i.e. shape (6, H, W).
        exp: experiment label used in the title and (lower-cased) file name.
        lat, lon: coordinate vectors used for tick labels when ``use_ticks``.
        init_time: initialisation time string used in the file name.
        lead_time: integer lead time in hours (shown as ``t=<lead_time>h``).
        save_dir: output directory, created if missing.
        use_colorbar: attach a colorbar to the right of the map.
        use_ticks: label the axes with lat/lon every ``step`` pixels instead
            of hiding them.
        plot_names: save names of the channels to plot, out of
            ("z500", "t850", "t2m", "u10", "v10", "tp", "ws").
            Defaults to wind speed only, as before.
        fixed_range: if True, use the per-channel (vmin, vmax) presets in
            ``min_max`` where one is defined instead of autoscaling.
    """
    os.makedirs(save_dir, exist_ok=True)
    print(save_dir)

    # Per-channel metadata, indexed like the stacked array built below
    # (the 6 input fields followed by wind speed).
    save_names = ["z500", "t850", "t2m", "u10", "v10", "tp", "ws"]
    cmaps = ["cividis", "RdYlBu_r", "RdYlBu_r", "bwr", "bwr", "jet", "bwr"]
    var_names = [
        r'Z500 [m$^2$ s$^{-2}$]',
        r'T850 [K]',
        r'T2M [K]',
        r'U10 [m s$^{-1}$]',
        r'V10 [m s$^{-1}$]',
        r'TP [mm]',
        r'WS [m s$^{-1}$]',
    ]
    # Optional fixed colour ranges (used only with fixed_range); () = autoscale.
    min_max = [(), (), (260, 310), (), (), (0, 50), (0, 25)]

    step = int(60 / 0.25)  # tick spacing in pixels (presumably 60 deg on a 0.25 deg grid)
    linewidth = 1
    # Source box (red) and the size of its 4x-upscaled inset (green), in pixels.
    bx, by, bw, bh = 450, 250, 64, 64
    pw, ph = bw * 4, bh * 4

    # Unpacking also enforces exactly 6 input channels.
    _, _, _, u10, v10, _ = imgs
    speed = np.sqrt(u10 ** 2 + v10 ** 2)
    # np.concatenate returns a new array, so pasting the inset below never
    # mutates the caller's ``imgs``.
    imgs = np.concatenate([imgs, speed[None]], axis=0)

    for i, save_name in enumerate(save_names):
        if save_name not in plot_names:
            continue

        var_name = var_names[i]
        cmap = cmaps[i]
        img = imgs[i]
        ih, iw = img.shape

        # Paste a 4x zoom of the red box into the bottom-right corner
        # (cv2.resize takes dsize as (width, height)).
        patch = cv2.resize(img[by:by + bh, bx:bx + bw], (pw, ph))
        img[-ph:, -pw:] = patch

        title = f"{exp} {var_name} t={lead_time}h"
        save_f = os.path.join(save_dir, f"{exp.lower()}_{save_name}_{init_time}_{lead_time:03d}.png")
        print(save_name, img.min(), img.max(), save_f)

        vrange = min_max[i] if fixed_range else ()
        vmin, vmax = vrange if vrange else (None, None)

        fig, ax = plt.subplots()
        im = ax.imshow(img, cmap=cmap, vmin=vmin, vmax=vmax)
        ax.add_patch(patches.Rectangle((bx, by), bw, bh, linewidth=linewidth, edgecolor='r', facecolor='none'))
        ax.add_patch(patches.Rectangle((iw - pw - linewidth, ih - ph - linewidth), pw, ph,
                                       linewidth=linewidth, edgecolor='g', facecolor='none'))

        if use_ticks:
            xticks = np.linspace(0, iw, len(lon))
            yticks = np.linspace(0, ih, len(lat))
            xlabels = [f"{label:.2f}" for label in lon]
            ylabels = [f"{label:.2f}" for label in reversed(lat)]
            ax.set_xticks(xticks[::step])
            ax.set_xticklabels(xlabels[::step])
            ax.set_yticks(yticks[::step])
            # Net effect of the double reversal: labels run in ``lat`` order from
            # the top row. NOTE(review): they line up with the tick positions
            # only when len(lat) - 1 is a multiple of ``step`` — confirm.
            ax.set_yticklabels(ylabels[::step][::-1])
        else:
            ax.axis("off")

        ax.set_title(title, fontsize=16)
        fig.tight_layout()

        if use_colorbar:
            cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
            fig.colorbar(im, cax=cax)

        fig.savefig(save_f, bbox_inches='tight', pad_inches=0.0, transparent=True, dpi=300)
        plt.show()
        plt.close(fig)

In [None]:
exp = "Bilinear"
lead_time = 120
init_time = "2018060400"

# Arrays saved by this run for the chosen init time / lead time.
run_dir = os.path.join(save_dir, exp.lower(), init_time)
noise = np.load(os.path.join(data_dir, 'noise.npy'))
output = np.load(os.path.join(run_dir, f'output_{lead_time:03d}.npy'))
target = np.load(os.path.join(run_dir, f'target_{lead_time:03d}.npy'))

log_dir = os.path.join(save_dir, "log_images")
plot_var(target, "ERA5", lat, lon, init_time, lead_time, log_dir, use_colorbar=True, use_ticks=True)
# Other figures from the same arrays:
# plot_var(output, exp, lat, lon, init_time, lead_time, log_dir)
# plot_var(target, "ERA5", lat, lon, init_time, lead_time, log_dir)
# plot_var(noise, "Noise", lat, lon, init_time, lead_time, log_dir)

In [11]:
def load_predicition(pred_f, lat, lon, init_time, lead_time=120, names=None):
    """Load saved predictions for one lead time into an ``xarray.Dataset``.

    ``pred_f`` is read with ``torch.load``. It must be a dict with:
    - an ``"idx"`` entry;
    - one ``"<name>_<lead_time>"`` entry per variable.
    Each entry is a sequence of arrays that is concatenated along axis 0.

    Args:
        pred_f: path to the saved predictions.
        lat, lon: coordinate values for the ``lat`` / ``lon`` dims.
        init_time: coordinate values for the ``time`` dim (one per sample).
        lead_time: selects the ``"<name>_<lead_time>"`` keys.
        names: variable names to load; defaults to the module-level
            ``eval_names``.

    Returns:
        xarray.Dataset with one (time, lat, lon) variable per name.

    Raises:
        KeyError: if a requested ``"<name>_<lead_time>"`` key is missing.
    """
    if names is None:
        names = eval_names

    # torch.load unpickles the file, so only use it on trusted files.
    # NOTE(review): newer torch releases default to weights_only=True, which
    # may refuse pickled numpy arrays — confirm against the installed version.
    predictions = torch.load(pred_f)
    idx = np.concatenate(predictions.pop("idx"))
    # NOTE(review): ``arr[idx]`` moves row idx[j] to position j. If ``idx``
    # records the sample index of each saved row, restoring sample order would
    # need ``arr[np.argsort(idx)]`` instead — confirm how ``idx`` was produced.

    preds = []
    for name in names:
        key = f"{name}_{lead_time}"
        # Explicit check instead of ``assert`` (stripped under ``python -O``).
        if key not in predictions:
            raise KeyError(f"{key!r} not in predictions; available keys: {list(predictions)}")
        values = np.concatenate(predictions[key])[idx]
        preds.append(xr.DataArray(
            values,
            dims=['time', 'lat', 'lon'],
            coords={'time': init_time, 'lat': lat, 'lon': lon},
            name=name,
        ))

    return xr.merge(preds)

In [None]:
def plot_t2m(results, save_f, lat=39.916668, lon=116.383331,
             time_range=('2018-07-01', '2018-08-01'), save=False):
    """Plot each model's 2 m temperature time series at one location.

    Args:
        results: mapping of legend label -> dataset with a ``t2m`` variable
            on ``time``/``lat``/``lon`` coordinates. Values are shifted by
            -273.15 for display in degrees Celsius (i.e. assumed Kelvin).
        save_f: output path; written only when ``save`` is True.
        lat, lon: target location; the nearest grid point is used
            (the default is presumably Beijing).
        time_range: (start, end) bounds passed to ``slice`` on ``time``.
        save: write the figure to ``save_f``. Off by default, matching the
            previous behaviour where saving was commented out.
    """
    start, end = time_range
    fig, ax = plt.subplots()
    for label, ds in results.items():
        point = ds.t2m.sel(lat=lat, lon=lon, method="nearest")
        point = point.sel(time=slice(start, end))
        ax.plot(point.time.values, point.values - 273.15, label=label)
    ax.legend()
    ax.grid()
    ax.set_ylabel(r'$^\circ$C')
    if save:
        fig.savefig(save_f, bbox_inches='tight', pad_inches=0.0, transparent=True, dpi=300)
    plt.show()

In [None]:
t2m_dir = os.path.join(save_dir, "t2m")
save_f = os.path.join(t2m_dir, "t2m.png")

# Legend label -> NetCDF file holding that model's t2m series.
# Currently disabled: "Bilinear" (bilinear_t2m.nc), "SwinRDM" (swinrdm_t2m.nc).
t2m_files = {
    "ERA5": "era5_t2m.nc",
    "SwinIR": "swinir_t2m.nc",
    "SwinRDM*": "ensemble_t2m.nc",
}
t2m_results = {label: xr.open_dataset(os.path.join(t2m_dir, fname)) for label, fname in t2m_files.items()}
plot_t2m(t2m_results, save_f)

In [None]:
fid_dir = os.path.join(save_dir, "fid")
save_f = os.path.join(fid_dir, "fid.png")


def _read_fid_curve(path):
    """Read one integer FID value per non-blank line of ``path``."""
    # ``with`` closes the file; the previous open(...).readlines() leaked it.
    with open(path) as f:
        return [int(line.strip()) for line in f if line.strip()]


# Legend label -> file with that model's FID curve.
fid_files = {
    "Bilinear": "bilinear_fid.txt",
    "SwinIR": "swinir_fid.txt",
    "SwinRDM": "swinrdm_fid.txt",
    "SwinRDM*": "ensemble_fid.txt",
}
fid_results = {name: _read_fid_curve(os.path.join(fid_dir, fname)) for name, fname in fid_files.items()}

fig, ax = plt.subplots()
for name, y in fid_results.items():
    print(f"{name}: {y[-1] if y else 'n/a'}")
    # Values are placed evenly over 6..120 h of forecast time.
    ax.plot(np.linspace(6, 120, len(y)), y, label=name)

ax.set_ylabel('FID')
ax.set_xlabel('Forecast Time (Hours)')
ax.legend()
fig.savefig(save_f, bbox_inches='tight', pad_inches=0.0, transparent=True, dpi=300)
plt.show()

In [18]:
def csi(pred, gt, eps=1e-6, thresholds=(2, 5, 10, 20, 50)):
    """Print and return the Critical Success Index at several thresholds.

    For each threshold ``th`` an event is ``value >= th`` and
    ``CSI = hits / (hits + misses + false_alarms + eps)``.

    Args:
        pred, gt: predictions and ground truth; anything ``np.asarray``
            accepts (ndarray, xarray.DataArray, pandas Series, ...) with
            broadcast-compatible shapes.
        eps: added to the denominator so a threshold with no events in
            either input scores 0 instead of dividing by zero.
        thresholds: event thresholds to evaluate, in print order.

    Returns:
        dict mapping each threshold to its CSI score.
    """
    pred_arr = np.asarray(pred)
    gt_arr = np.asarray(gt)
    scores = {}
    for th in thresholds:
        pred_evt = pred_arr >= th
        gt_evt = gt_arr >= th
        hits = np.sum(pred_evt & gt_evt)
        false_alarms = np.sum(pred_evt & ~gt_evt)
        misses = np.sum(~pred_evt & gt_evt)
        score = hits / (hits + misses + false_alarms + eps)
        print(f"CSI{th}: {score:.3f}")
        scores[th] = score
    return scores

In [31]:
def rmse(pred, gt, mean_dims=None):
    """Print and return the latitude-weighted RMSE of each variable.

    Squared errors are weighted by cos(lat), with the weights normalised to
    mean 1 over the ``lat`` coordinate. They are then averaged over
    ``mean_dims`` and square-rooted.

    Args:
        pred, gt: xarray objects with a ``lat`` coordinate in degrees and the
            same variables (presumably Datasets — iterating the result
            yields variable names).
        mean_dims: dims to average over. ``None`` (the default) or ``...``
            averages over all dims, the same as the former ``xr.ALL_DIMS``
            default (a deprecated alias of ``...``).

    Returns:
        The RMSE object, one entry per variable.
    """
    error = pred - gt
    weights_lat = np.cos(np.deg2rad(error.lat))
    weights_lat = weights_lat / weights_lat.mean()
    scores = np.sqrt((error ** 2 * weights_lat).mean(mean_dims))
    for name in scores:
        score = scores[name].values
        # The value is square-rooted, so label it RMSE (it used to say MSE).
        if np.ndim(score) == 0:
            print(f'RMSE_{name}: {score:.3f}')
        else:
            # Partial reductions leave an array; ``:.3f`` would fail on it.
            print(f'RMSE_{name}: {np.round(score, 3)}')
    return scores


In [39]:
# Ground-truth fields for the 14-day evaluation.
gt_f = os.path.join(data_dir, "gt_14d.nc")
gt = xr.open_dataset(gt_f)

exp = "swinrnn256_unet128_lr0.0005_itr30000_vrnn_tf"
pred_dir = os.path.join(work_dir, "output", exp)
pred_f = os.path.join(pred_dir, "predictions.pth")
# Predictions reuse the ground-truth grid and timestamps as coordinates.
pred = load_predicition(
    pred_f,
    lat=gt.lat,
    lon=gt.lon,
    init_time=gt.time,
    lead_time=336,  # hours, i.e. 14 days, matching gt_14d.nc
)

In [41]:
rmse(pred=pred, gt=gt)  # latitude-weighted score per variable, averaged over all dims

MSE_z500: 933.232
MSE_t850: 3.982
MSE_t2m: 3.339
MSE_u10: 4.511
MSE_v10: 4.682
MSE_tp: 2.284
