In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl

import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from src.models import *
# from ilan_src.models import *
from src.dataloader import *
from src.utils import *
from src.evaluation import *

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

import pickle

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
import sys

## Set up dataset - my way

In [4]:
members = 10
zero_noise = False

In [5]:
DATADRIVE = '/home/jupyter/data/'

In [6]:
# ds_train = TiggeMRMSDataset(
#     tigge_dir=f'{DATADRIVE}/tigge/32km/',
#     tigge_vars=['total_precipitation_ens10'],
#     mrms_dir=f'{DATADRIVE}/mrms/4km/RadarOnly_QPE_06H/',
#     rq_fn=f'{DATADRIVE}/mrms/4km/RadarQuality.nc',
#     data_period=('2018-01', '2019-12'),
#     val_days=5,
#     split='train',
#     tp_log=0.01,
#     ensemble_mode='random',
#     idx_stride=16
# )

In [7]:
# ds_train.mins.to_netcdf('tmp/mins1.nc')
# ds_train.maxs.to_netcdf('tmp/maxs1.nc')

In [8]:
mins = xr.open_dataset('tmp/mins1.nc')
maxs = xr.open_dataset('tmp/maxs1.nc')

In [10]:
ds_test = TiggeMRMSDataset(
    tigge_dir=f'{DATADRIVE}/tigge/32km/',
    tigge_vars=['total_precipitation_ens10'],
    mrms_dir=f'{DATADRIVE}/mrms/4km/RadarOnly_QPE_06H/',
#     rq_fn=f'{DATADRIVE}/mrms/4km/RadarQuality.nc',
    data_period=('2020-01', '2020-12'),
    first_days=5,
    tp_log=0.01,
    mins=mins,
    maxs=maxs,
    ensemble_mode='random',
    idx_stride=16
)

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


Loading data


In [11]:
# For ens_tp with stuff
ds_test_pad = TiggeMRMSDataset(
    tigge_dir=f'{DATADRIVE}/tigge/32km/',
    tigge_vars=['total_precipitation_ens10'],
    mrms_dir=f'{DATADRIVE}/mrms/4km/RadarOnly_QPE_06H/',
#     rq_fn=f'{DATADRIVE}/mrms/4km/RadarQuality.nc',
    data_period=('2020-01', '2020-12'),
    first_days=5,
    tp_log=0.01,
    mins=mins,
    maxs=maxs,
    ensemble_mode='random',
    idx_stride=16,
    pad_tigge=10,
    pad_tigge_channel=True,
)

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


Loading data


## Load models

### Model 1: single_forecast_tp_pure_sr_pretraining

In [12]:
name='single_forecast_tp_pure_sr_pretraining'

In [13]:
zero_noise=True

In [14]:
model_dir = '/home/jupyter/data/saved_models/saved_models/leingan/single_forecast_tp_pure_sr_pretraining/0'

In [15]:
sys.path.append(model_dir)

In [16]:
gan = BaseGAN2.load_from_checkpoint(
    f"{model_dir}/epoch=199-step=133999.ckpt")

In [17]:
model = gan.gen
model = model.to(device)
model.train(False);

### Model 2: ens_mean_L1_weighted_gen_loss

In [None]:
name='ens_mean_L1_weighted_gen_loss'

In [68]:
model_dir = '/home/jupyter/data/saved_models/saved_models/leingan/ens10_tp/random/ens_mean_L1_weighted_gen_loss/3'

In [69]:
sys.path.append(model_dir)

In [70]:
gan = BaseGAN2.load_from_checkpoint(
    f"{model_dir}/epoch=349-step=234499.ckpt")

In [71]:
model = gan.gen
model = model.to(device)
model.train(False);

### Model 3: ens10_tp_and_added_vars_TCW_broadfield_channel

In [None]:
name='ens10_tp_and_added_vars_TCW_broadfield_channel'

In [None]:
ds_test = ds_test_pad

In [92]:
model_dir = '/home/jupyter/data/saved_models/saved_models/leingan/ens10_tp_and_added_vars_TCW_broadfield_channel/15'

In [93]:
sys.path.append(model_dir)

In [94]:
sys.path

['/home/jupyter/repositories/nwp-downscale/notebooks/stephan_notebooks',
 '/opt/conda/envs/ilan/lib/python39.zip',
 '/opt/conda/envs/ilan/lib/python3.9',
 '/opt/conda/envs/ilan/lib/python3.9/lib-dynload',
 '',
 '/opt/conda/envs/ilan/lib/python3.9/site-packages',
 '/opt/conda/envs/ilan/lib/python3.9/site-packages/IPython/extensions',
 '/home/jupyter/.ipython',
 '/home/jupyter/data/saved_models/saved_models/leingan/ens10_tp_and_added_vars_TCW_broadfield_channel/15']

In [95]:
gan = BaseGAN2.load_from_checkpoint(
    f"{model_dir}/epoch=499-step=258499.ckpt")

TypeError: __init__() got an unexpected keyword argument 'input_channels'

In [79]:
model = gan.gen
model = model.to(device)
model.train(False);

## Full field eval

In [18]:
def create_valid_predictions(model, ds_valid, member_idx=None, zero_noise=False):
    """Predict the full (unpatched) field for every valid time in ``ds_valid``.

    For each time index, ``ds_valid.return_full_array`` supplies the input
    ``X``. The generator is called as ``model(X[None], noise)``, and the first
    batch item / first channel of its output is kept. The stacked predictions
    are then mapped back to physical values: first the min-max scaling
    (``ds_valid.mins.tp`` / ``ds_valid.maxs.tp``) is inverted, then the log
    transform is undone if ``ds_valid.tp_log`` is set.

    Args:
        model: Generator module called as ``model(input, noise)``. It is
            evaluated without gradient tracking.
        ds_valid: Dataset exposing ``return_full_array``, ``tigge``,
            ``mrms``, ``mins``, ``maxs``, ``tp_log`` and ``pad_mrms``.
        member_idx: Ensemble member forwarded to ``return_full_array``.
        zero_noise: If True, an all-zero noise tensor is used instead of
            standard-normal noise.

    Returns:
        xr.DataArray named ``'tp'`` with dims ``(valid_time, lat, lon)``.
        The lat/lon coordinates are the MRMS coordinates offset by
        ``ds_valid.pad_mrms``.
    """
    preds = []
    # Inference only: without no_grad every forward pass records an autograd
    # graph (keeping activations alive) that is never used.
    with torch.no_grad():
        for t in tqdm(range(len(ds_valid.tigge.valid_time))):
            X, _ = ds_valid.return_full_array(t, member_idx=member_idx)
            # One noise tensor shaped like the input, plus a batch axis. It is
            # drawn on the CPU and then moved, so the random stream does not
            # depend on the device.
            noise = torch.randn(1, X.shape[0], X.shape[1], X.shape[2]).to(device)
            if zero_noise:
                noise.zero_()
            pred = model(torch.FloatTensor(X[None]).to(device), noise)
            preds.append(pred.cpu().numpy()[0, 0])
    preds = np.array(preds)

    # Invert the min-max scaling.
    tp_min = ds_valid.mins.tp.values
    tp_max = ds_valid.maxs.tp.values
    preds = preds * (tp_max - tp_min) + tp_min

    # Undo the log transform, if one was applied.
    if ds_valid.tp_log:
        preds = log_retrans(preds, ds_valid.tp_log)

    # Attach coordinates: MRMS lat/lon shifted by the dataset's padding.
    pad = ds_valid.pad_mrms
    return xr.DataArray(
        preds,
        dims=['valid_time', 'lat', 'lon'],
        coords={
            'valid_time': ds_valid.tigge.valid_time,
            'lat': ds_valid.mrms.lat.isel(lat=slice(pad, pad + preds.shape[1])),
            'lon': ds_valid.mrms.lon.isel(lon=slice(pad, pad + preds.shape[2])),
        },
        name='tp'
    )

In [19]:
def create_stitched_predictions(model, ds_test, member_idx, zero_noise=False):
    """Predict patch by patch and stitch the patches onto the MRMS grid.

    Every entry ``(time_idx, lat_idx, lon_idx)`` of ``ds_test.idxs`` is
    predicted separately. Its output is written into the block
    ``[lat_idx * ratio : lat_idx * ratio + patch_mrms]`` (and the same for
    lon) of a NaN-initialised copy of ``ds_test.mrms``. If patches overlap,
    later patches overwrite earlier ones. The result is then unscaled and
    un-logged in the same way as ``create_valid_predictions``.

    Args:
        model: Generator module called as ``model(input, noise)``. It is
            evaluated without gradient tracking.
        ds_test: Patch dataset exposing ``idxs``, ``ratio``, ``patch_mrms``,
            ``mrms``, ``mins``, ``maxs``, ``tp_log`` and a ``__getitem__``
            that accepts ``member_idx``.
        member_idx: Ensemble member forwarded to ``ds_test.__getitem__``.
        zero_noise: If True, an all-zero noise tensor is used instead of
            standard-normal noise.

    Returns:
        xr.DataArray laid out like ``ds_test.mrms`` with ``time`` renamed to
        ``valid_time``. Grid cells not covered by any patch stay NaN.
    """
    # np.nan rather than np.NaN: the capitalised alias was removed in NumPy 2.0.
    preds = ds_test.mrms.copy(True) * np.nan
    # Inference only: skip autograd bookkeeping for every patch.
    with torch.no_grad():
        for idx in tqdm(range(len(ds_test.idxs))):
            time_idx, lat_idx, lon_idx = ds_test.idxs[idx]
            lat_start = lat_idx * ds_test.ratio
            lon_start = lon_idx * ds_test.ratio
            lat_slice = slice(lat_start, lat_start + ds_test.patch_mrms)
            lon_slice = slice(lon_start, lon_start + ds_test.patch_mrms)

            X, _ = ds_test.__getitem__(idx, member_idx=member_idx)

            # Drawn on the CPU and then moved, so the random stream does not
            # depend on the device.
            noise = torch.randn(1, X.shape[0], X.shape[1], X.shape[2]).to(device)
            if zero_noise:
                noise.zero_()
            p = model(torch.FloatTensor(X[None]).to(device), noise)

            preds[time_idx, lat_slice, lon_slice] = p.cpu().numpy()[0, 0]

    # Invert the min-max scaling.
    tp_min = ds_test.mins.tp.values
    tp_max = ds_test.maxs.tp.values
    preds = preds * (tp_max - tp_min) + tp_min

    # Undo the log transform, if one was applied.
    if ds_test.tp_log:
        preds = log_retrans(preds, ds_test.tp_log)
    return preds.rename({'time': 'valid_time'})

In [20]:
def create_valid_ensemble(model, ds_valid, nens, stitched=False, zero_noise=False):
    """Build an ``nens``-member ensemble, one predictor call per member.

    ``stitched=True`` uses the patch-wise ``create_stitched_predictions``;
    otherwise the full-field ``create_valid_predictions`` is used. Member
    ``i`` is generated with ``member_idx=i``. The members are concatenated
    along a new ``'member'`` dimension.
    """
    predict = create_stitched_predictions if stitched else create_valid_predictions
    members_out = []
    for idx in range(nens):
        members_out.append(predict(model, ds_valid, member_idx=idx, zero_noise=zero_noise))
    return xr.concat(members_out, 'member')

In [None]:
ens_pred = create_valid_ensemble(model, ds_test, members, zero_noise=zero_noise)

  8%|▊         | 9/110 [01:11<13:23,  7.96s/it]

In [None]:
ens_pred_stitched = create_valid_ensemble(model, ds_test, members, stitched=True, zero_noise=zero_noise)

In [None]:
ens_pred.to_netcdf(f'tmp/ens_pred_{name}.nc')
ens_pred_stitched.to_netcdf(f'tmp/ens_pred_stitched_{name}.nc')

In [None]:
ens_pred.isel(valid_time=1).plot(vmin=0, vmax=20, cmap='gist_ncar_r', col='member')

## Get ground truth

In [None]:
# Ground truth: invert the min-max scaling with the same tp min/max as the
# model predictions (max-only scaling silently assumed min == 0), then undo
# the log transform if one was applied.
mrms = ds_test.mrms.rename({'time': 'valid_time'})
mrms = mrms * (ds_test.maxs.tp.values - ds_test.mins.tp.values) + ds_test.mins.tp.values
if ds_test.tp_log:
    mrms = log_retrans(mrms, ds_test.tp_log)

In [None]:
mrms.to_netcdf('tmp/mrms.nc')

## Get interpolation baseline

In [None]:
# First TIGGE input variable, mapped back to physical values with the full
# min-max inverse (max-only scaling silently assumed min == 0).
# NOTE(review): assumes the TIGGE precipitation input was scaled with the
# tp min/max, as the previous max-only code did — confirm against the dataset.
tigge = ds_test.tigge.isel(variable=0)
tigge = tigge * (ds_test.maxs.tp.values - ds_test.mins.tp.values) + ds_test.mins.tp.values
if ds_test.tp_log:
    tigge = log_retrans(tigge, ds_test.tp_log)

In [None]:
interp = tigge.interp_like(mrms, method='linear')

In [None]:
interp.to_netcdf('tmp/interp_ens.nc')

## Get HREF baseline

In [None]:
href = xr.open_mfdataset('/home/jupyter/data/hrefv2//4km/total_precipitation/2020*.nc')

In [None]:
href = href.tp.diff('lead_time').sel(lead_time=np.timedelta64(12, 'h'))

In [None]:
href['valid_time'] = href.init_time + href.lead_time
href = href.swap_dims({'init_time': 'valid_time'})

In [None]:
href = href.assign_coords({'lat': interp.lat.values, 'lon': interp.lon.values})

In [None]:
overlap_times = np.intersect1d(interp.valid_time, href.valid_time)

In [None]:
href = href.sel(valid_time=overlap_times)

In [None]:
href.load();

In [None]:
href.to_netcdf('tmp/href.nc')

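# Old

**NOTE:** these scratch cells are kept for reference only. Several of them use names that are not created in any earlier cell of this notebook (`hrrr`, `det_pred`, `det_pred2`, `xs`). The `rq` and `mrms_mask` plots also run before the cell that defines those variables. As a result, this section does not survive *Restart & Run All*.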

## Get mask

In [None]:
ds = xr.open_dataset(
    '/home/jupyter/data/hrrr/raw/total_precipitation/20180215_00.nc')

In [None]:
from src.regrid import *

In [None]:
ds_regridded = regrid(ds, 4, lons=(235, 290), lats=(50, 20))

In [None]:
hrrr_mask = np.isfinite(ds_regridded).tp.isel(init_time=0, lead_time=0)

In [None]:
rq.plot(vmin=0, vmax=1)

In [None]:
(rq>0.3).plot(vmin=0, vmax=1)

In [None]:
mrms_mask.plot(vmin=0, vmax=1)

In [None]:
rq = xr.open_dataarray(f'{DATADRIVE}/mrms/4km/RadarQuality.nc')
mrms_mask = rq>-1
mrms_mask = mrms_mask.assign_coords({
    'lat': hrrr_mask.lat,
    'lon': hrrr_mask.lon
})

In [None]:
total_mask = mrms_mask * hrrr_mask

In [None]:
total_mask = total_mask.isel(lat=slice(0, -6))

In [None]:
total_mask = total_mask.assign_coords({'lat': interp.lat.values, 'lon': interp.lon.values})

In [None]:
total_mask.plot()

## Compute scores

In [None]:
hrrr = hrrr.isel(lat=slice(0, -6))

In [None]:
hrrr = hrrr.assign_coords({'lat': interp.lat.values, 'lon': interp.lon.values})

In [None]:
# Apply mask
mrms = mrms.where(total_mask)
det_pred = det_pred.where(total_mask)
hrrr = hrrr.where(total_mask)
interp = interp.where(total_mask)

In [None]:
det_pred2 = det_pred2.where(total_mask)


In [None]:
hrrr.load()

### Bias

In [None]:
mrms.mean().values

In [None]:
det_pred.mean().values

In [None]:
interp.mean().values

In [None]:
hrrr.mean().values

### Histograms

In [None]:
# 25 log-spaced edges from 10**0 - 1 = 0 to 10**2 - 1 = 99, giving 24 bins,
# plus the midpoint of each bin.
bins = np.logspace(0, 2, 25) - 1
mid_bin = 0.5 * (bins[:-1] + bins[1:])

In [None]:
def plot_hist(ds, bins, label, ax=None):
    """Plot a histogram of ``ds`` as a line of counts at the bin centres.

    Args:
        ds: Object exposing a ``.values`` array (e.g. an xarray DataArray).
        bins: Increasing bin edges, passed to ``np.histogram``.
        label: Legend label for the plotted line.
        ax: Matplotlib axes to draw on. Defaults to the current pyplot axes.
    """
    counts, edges = np.histogram(ds.values, bins=bins)
    # Centres are derived from the edges actually used, not from a
    # notebook-level ``mid_bin``, so the x-positions always match ``bins``.
    centres = 0.5 * (edges[:-1] + edges[1:])
    if ax is None:
        ax = plt.gca()
    ax.plot(centres, counts, marker='o', label=label)

In [None]:
plt.figure(figsize=(10, 5))
plot_hist(det_pred, bins, 'GAN')
plot_hist(mrms, bins, 'Obs')
plot_hist(interp, bins, 'Interp')
plot_hist(hrrr, bins, 'HRRR')
plt.yscale('log')
plt.legend()

### RMSE

In [None]:
xs.rmse(det_pred, mrms, dim=['lat', 'lon', 'valid_time'], skipna=True).values

In [None]:
xs.rmse(interp, mrms, dim=['lat', 'lon', 'valid_time'], skipna=True).values

In [None]:
xs.rmse(hrrr, mrms, dim=['lat', 'lon', 'valid_time'], skipna=True).values

### FSS

In [None]:
thresh = 10
window = 100 // 4

In [None]:
def compute_fss(f, o, thresh, window, time_mean=True):
    """Fractions Skill Score of ``f`` against ``o`` for one threshold.

    How it is computed:
    - Both fields are binarised with ``> thresh``.
    - Each binary field becomes a neighbourhood fraction via a centred
      ``window`` x ``window`` rolling mean over ``lat``/``lon``.
    - The score per valid time is
      ``1 - mean((F - O)**2) / (mean(F**2) + mean(O**2))``, with means over
      lat/lon.

    The formula is symmetric in ``f`` and ``o``. If neither field exceeds the
    threshold at a given time, that time yields 0/0 = NaN.

    Returns:
        FSS per ``valid_time``, or its mean over ``valid_time`` when
        ``time_mean`` is True.
    """
    spatial = ('lat', 'lon')
    neighbourhood = {'lat': window, 'lon': window}

    def _fractions(field):
        # Share of exceedances inside the centred neighbourhood window.
        return (field > thresh).rolling(neighbourhood, center=True).mean()

    frac_f = _fractions(f)
    frac_o = _fractions(o)
    numerator = ((frac_f - frac_o) ** 2).mean(spatial)
    denominator = (frac_f ** 2).mean(spatial) + (frac_o ** 2).mean(spatial)
    score = 1 - numerator / denominator
    return score.mean('valid_time') if time_mean else score

In [None]:
compute_fss(mrms, det_pred, thresh, window).values

In [None]:
compute_fss(mrms, interp, thresh, window).values

In [None]:
compute_fss(mrms, hrrr, thresh, window).values

In [None]:
fig, ax = plt.subplots(figsize=(20, 10))
det_pred.isel(valid_time=2).plot(vmin=0, vmax=20)
ax.set_aspect('equal')

In [None]:
fig, ax = plt.subplots(figsize=(20, 10))
det_pred2.isel(valid_time=2).plot(vmin=0, vmax=20)
ax.set_aspect('equal')

In [None]:
fig, ax = plt.subplots(figsize=(20, 10))
mrms.isel(valid_time=2).plot(vmin=0, vmax=20)
ax.set_aspect('equal')

In [None]:
fig, ax = plt.subplots(figsize=(20, 10))
interp.isel(valid_time=2).plot(vmin=0, vmax=20)
ax.set_aspect('equal')

In [None]:
fig, ax = plt.subplots(figsize=(20, 10))
hrrr.isel(valid_time=2).plot(vmin=0, vmax=20)
ax.set_aspect('equal')

In [None]:
interp[50].plot(vmin=0, vmax=20)

In [None]:
hrrr[50].plot(vmin=0, vmax=20)

In [None]:

 eps = 1e-6
bin_edges = [-eps] + np.linspace(eps, log_retrans(ds_max, tp_log)+eps, 51).tolist()
pred_means.append(np.mean(preds.sel(member=0)))
pred_hists.append(np.histogram(preds.sel(member=0), bins = bin_edges, density=False)[0])
truth_means.append(np.mean(truth))
truth_hists.append(np.histogram(truth, bins = bin_edges, density=False)[0])

truth_pert = truth + np.random.normal(scale=1e-6, size=truth.shape)
preds_pert = preds + np.random.normal(scale=1e-6, size=preds.shape) 
