In [None]:
%matplotlib inline

from datetime import datetime
from datetime import timedelta

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import tensorflow as tf
from tensorflow.keras.models import load_model

# Rainfall colour ramp: sixteen RGB anchors given as 0-255 triples and scaled
# to the 0-1 floats matplotlib expects. The trailing comments are the labels
# carried by the original table (presumably rain-rate levels -- TODO confirm
# units). Because only colours are passed to `from_list`, the anchors are
# spread evenly over the colormap rather than positioned at those values.
color_list = np.array([
    (255, 255, 255),  # 0.0
    (245, 245, 255),  # 0.2
    (180, 180, 255),  # 0.5
    (120, 120, 255),  # 1.5
    (20, 20, 255),    # 2.5
    (0, 216, 195),    # 4.0
    (0, 150, 144),    # 6.0
    (0, 102, 102),    # 10
    (255, 255, 0),    # 15
    (255, 200, 0),    # 20
    (255, 150, 0),    # 30
    (255, 100, 0),    # 40
    (255, 0, 0),      # 50
    (200, 0, 0),      # 60
    (120, 0, 0),      # 75
    (40, 0, 0),       # > 100
]) / 255.

# 32 quantisation levels; masked / invalid cells are drawn in gray.
bom_cm = LinearSegmentedColormap.from_list("BOM-RF3", color_list, N=32)
bom_cm.set_bad(color='gray')

In [None]:
# Named tile centres as [x, y]. The loader functions in this notebook read
# them as loc[0] (x) and loc[1] (y) and crop +/-512 around the centre
# (presumably kilometres in the data's projected grid -- TODO confirm).
_tile_rows = (
    ("VICTAS", 1130.8837, -4376.935),
    ("SYDM", 1740.4043, -3757.409),
    ("W_NSW", 935.5831, -3523.967),
    ("SA", 189.0527, -3376.863),
    ("SE_WA", -661.3803, -3393.506),
    ("SW_WA", -1414.6972, -3458.529),
    ("NW_WA", -1519.7374, -2568.012),
    ("NE_WA", -710.4872, -2498.161),
    ("N_SA", 203.0897, -2480.281),
    ("W_QLD", 1010.0804, -2574.165),
    ("SE_QLD", 1790.4011, -2828.908),
    ("NW_NT", -704.899, -1609.515),
    ("NT", 216.9767, -1593.277),
    ("N_QLD", 867.3753, -1618.745),
    ("NE_QLD", 1597.875, -1905.570),
    ("NW_WA_COAST", -1623.6555, -1687.006),
)
tiles = {tile_name: [centre_x, centre_y]
         for tile_name, centre_x, centre_y in _tile_rows}

# Bare expression so the notebook displays the mapping.
tiles

In [None]:
def gpm_pred(loc, d):
    """Predict a GPM-style precipitation field from Himawari-8 bands B8 and B14.

    Reads both bands at time `d` from the daily HIM8 file, keeps every second
    pixel along both axes, runs the `unet_mse_gpm.h5` Keras model over two
    overlapping windows and attaches the result to the x/y grid of the GPM
    file as variable `prec`.

    Parameters
    ----------
    loc : sequence of two numbers, or a falsy value
        Tile centre as (x, y). When truthy, the result is cropped to
        +/-512 around the centre, multiplied by 1000 to match the dataset
        coordinates (presumably km -> m -- TODO confirm). When falsy, the
        whole field is returned.
    d : datetime.datetime
        Timestamp to select; also determines the HIM8 file name (YYYYMMDD).

    Returns
    -------
    xarray.DataArray
        The `prec` field, loaded into memory so it stays usable after the
        source file has been closed.

    Notes
    -----
    `load_model` (tensorflow.keras.models) must be importable in the notebook
    namespace when this is called. The model is reloaded on every call.
    """
    h8_fp = "/data/pluvi_pondus/HIM8_AU_2B/HIM8_2B_AU_{}.nc".format(d.strftime("%Y%m%d"))
    # Context manager closes the file even if the time selection fails.
    with xr.open_dataset(h8_fp) as h8_ds:
        b8 = h8_ds.B8.sel(time=d).data[::2, ::2]
        b14 = h8_ds.B14.sel(time=d).data[::2, ::2]

    # Channel-last stack: (rows, cols, 2).
    x = np.stack((b8, b14), axis=-1)

    mse = load_model('../small_z/unet_mse_gpm.h5')
    out = np.zeros((513, 613), dtype=np.float32)
    # Two overlapping windows cover the domain; the second write wins where
    # they overlap.
    out[:-1, :-101] = mse.predict(x[None, :-1, :-201, :])[:, :, :, 0]
    out[1:, 101:] = mse.predict(x[None, 1:, 201:, :])[:, :, :, 0]

    # The GPM file is only used as a carrier for the x/y/time coordinates.
    # Previously it was opened and never closed, leaking a handle per call.
    with xr.open_dataset("/data/pluvi_pondus/GPM/GPM_BoM_201811.nc") as gpm_ds:
        ds = gpm_ds.sel(time=d)
        ds = ds.drop(["albers_conical_equal_area", "PrecCal"])
        ds['prec'] = (('y', 'x'), out)

        prec = ds.prec
        if loc:
            # y slice runs high -> low, matching the original selection.
            prec = prec.sel(x=slice((loc[0]-512)*1000, (loc[0]+512)*1000),
                            y=slice((loc[1]+512)*1000, (loc[1]-512)*1000))
        # Pull any lazily-backed coordinates into memory before the file closes.
        prec = prec.load()

    return prec
    
    
def gpm_predz(loc, d):
    """Predict a GPM-style precipitation field from Himawari-8 bands plus ERA5 `z`.

    Same procedure as the B8/B14-only predictor, but the `unet_mse_gpm_z.h5`
    model receives a second input built from the ERA5 `z` variable of three
    files (z1000, z800, z500), each sampled at the time nearest to `d`.

    Parameters
    ----------
    loc : sequence of two numbers, or a falsy value
        Tile centre as (x, y). When truthy, the result is cropped to
        +/-512 around the centre, multiplied by 1000 to match the dataset
        coordinates (presumably km -> m -- TODO confirm). When falsy, the
        whole field is returned.
    d : datetime.datetime
        Timestamp to select; also determines the HIM8 file name (YYYYMMDD).

    Returns
    -------
    xarray.DataArray
        The `prec` field, loaded into memory so it stays usable after the
        source file has been closed.

    Notes
    -----
    `load_model` (tensorflow.keras.models) must be importable in the notebook
    namespace when this is called. The model is reloaded on every call.
    """
    # Read each ERA5 file inside its own context manager; previously the three
    # datasets were opened and never closed.
    z_layers = []
    for level in (1000, 800, 500):
        with xr.open_dataset("/data/pluvi_pondus/ERA5/au_z{}_201811.nc".format(level)) as dsz:
            z_layers.append(dsz['z'].sel(time=d, method='nearest').values)

    # Channel order is (z1000, z800, z500), as in the original stack.
    z = np.stack(z_layers, axis=-1)

    h8_fp = "/data/pluvi_pondus/HIM8_AU_2B/HIM8_2B_AU_{}.nc".format(d.strftime("%Y%m%d"))
    with xr.open_dataset(h8_fp) as h8_ds:
        b8 = h8_ds.B8.sel(time=d).data[::2, ::2]
        b14 = h8_ds.B14.sel(time=d).data[::2, ::2]

    x = np.stack((b8, b14), axis=-1)

    mse = load_model('../small_z/unet_mse_gpm_z.h5')
    out = np.zeros((513, 613), dtype=np.float32)
    # Two overlapping windows; the ERA5 input is trimmed by 50 columns on the
    # matching side (presumably the same window on a coarser grid -- TODO confirm).
    out[:-1, :-101] = mse.predict([x[None, :-1, :-201, :], z[None, :, :-50, :]])[:, :, :, 0]
    out[1:, 101:] = mse.predict([x[None, 1:, 201:, :], z[None, :, 50:, :]])[:, :, :, 0]

    # The GPM file only supplies the x/y/time coordinates for the output.
    with xr.open_dataset("/data/pluvi_pondus/GPM/GPM_BoM_201811.nc") as gpm_ds:
        ds = gpm_ds.sel(time=d)
        ds = ds.drop(["albers_conical_equal_area", "PrecCal"])
        ds['prec'] = (('y', 'x'), out)

        prec = ds.prec
        if loc:
            prec = prec.sel(x=slice((loc[0]-512)*1000, (loc[0]+512)*1000),
                            y=slice((loc[1]+512)*1000, (loc[1]-512)*1000))
        # Pull any lazily-backed coordinates into memory before the file closes.
        prec = prec.load()

    return prec

    
def plot_rainfields_pred(loc, d):
    """Plot the Rainfields-style prediction made from Himawari-8 bands B8 and B14.

    Reads both bands at time `d`, keeps every second pixel along both axes,
    runs the `unet_mse_rainfields.h5` Keras model over two overlapping windows
    and plots the resulting `prec` field with the `bom_cm` colormap (0-5).

    Parameters
    ----------
    loc : sequence of two numbers, or a falsy value
        Tile centre as (x, y). When truthy, the plot is cropped to +/-512
        around the centre, multiplied by 1000 to match the dataset coordinates
        (presumably km -> m -- TODO confirm). When falsy, the whole field is
        plotted, consistent with the other loaders in this notebook.
    d : datetime.datetime
        Timestamp to select; also determines the HIM8 file name (YYYYMMDD).

    Returns
    -------
    None
        The figure is produced as a side effect of `DataArray.plot`.

    Notes
    -----
    `load_model` (tensorflow.keras.models) must be importable in the notebook
    namespace when this is called. The model is reloaded on every call.
    """
    h8_fp = "/data/pluvi_pondus/HIM8_AU_2B/HIM8_2B_AU_{}.nc".format(d.strftime("%Y%m%d"))

    with xr.open_dataset(h8_fp) as h8_ds:
        b8 = h8_ds.B8.sel(time=d).data[::2, ::2]
        b14 = h8_ds.B14.sel(time=d).data[::2, ::2]

        # Derive the output grid while the file is still open. The original
        # did this after close(), i.e. on an already-closed dataset.
        grid = h8_ds.sel(time=d)
        grid = grid.drop(["albers_conical_equal_area", "B8", "B14"])
        # Thin the coordinates the same way as the band data ([::2]).
        grid['y'] = grid['y'].values[::2]
        grid['x'] = grid['x'].values[::2]
        grid = grid.load()

    x = np.stack((b8, b14), axis=-1)

    mse = load_model('../small_z/unet_mse_rainfields.h5')

    out = np.zeros((1025, 1225), dtype=np.float32)
    # Two overlapping windows; the second write wins where they overlap.
    out[:-1, :-201] = mse.predict(x[None, :-1, :-201, :])[:, :, :, 0]
    out[1:, 201:] = mse.predict(x[None, 1:, 201:, :])[:, :, :, 0]

    grid['prec'] = (('y', 'x'), out)

    if loc:
        # y slice runs high -> low, matching the original selection.
        grid = grid.sel(x=slice((loc[0]-512)*1000, (loc[0]+512)*1000),
                        y=slice((loc[1]+512)*1000, (loc[1]-512)*1000))

    grid.prec.plot(cmap=bom_cm, vmin=0, vmax=5)
    
    
def plot_rainfields_predz(loc, d):
    """Plot the Rainfields-style prediction made from Himawari-8 bands plus ERA5 `z`.

    Same procedure as the B8/B14-only plot, but the `unet_mse_rainfields_z.h5`
    model receives a second input built from the ERA5 `z` variable of three
    files (z1000, z800, z500), each sampled at the time nearest to `d`.

    Parameters
    ----------
    loc : sequence of two numbers, or a falsy value
        Tile centre as (x, y). When truthy, the plot is cropped to +/-512
        around the centre, multiplied by 1000 to match the dataset coordinates
        (presumably km -> m -- TODO confirm). When falsy, the whole field is
        plotted, consistent with the other loaders in this notebook.
    d : datetime.datetime
        Timestamp to select; also determines the HIM8 file name (YYYYMMDD).

    Returns
    -------
    None
        The figure is produced as a side effect of `DataArray.plot`.

    Notes
    -----
    `load_model` (tensorflow.keras.models) must be importable in the notebook
    namespace when this is called. The model is reloaded on every call.
    """
    # Read each ERA5 file inside its own context manager; previously the three
    # datasets were opened and never closed.
    z_layers = []
    for level in (1000, 800, 500):
        with xr.open_dataset("/data/pluvi_pondus/ERA5/au_z{}_201811.nc".format(level)) as dsz:
            z_layers.append(dsz['z'].sel(time=d, method='nearest').values)

    # Channel order is (z1000, z800, z500), as in the original stack.
    z = np.stack(z_layers, axis=-1)

    h8_fp = "/data/pluvi_pondus/HIM8_AU_2B/HIM8_2B_AU_{}.nc".format(d.strftime("%Y%m%d"))

    with xr.open_dataset(h8_fp) as h8_ds:
        b8 = h8_ds.B8.sel(time=d).data[::2, ::2]
        b14 = h8_ds.B14.sel(time=d).data[::2, ::2]

        # Derive the output grid while the file is still open. The original
        # did this after close(), i.e. on an already-closed dataset.
        grid = h8_ds.sel(time=d)
        grid = grid.drop(["albers_conical_equal_area", "B8", "B14"])
        # Thin the coordinates the same way as the band data ([::2]).
        grid['y'] = grid['y'].values[::2]
        grid['x'] = grid['x'].values[::2]
        grid = grid.load()

    x = np.stack((b8, b14), axis=-1)

    mse = load_model('../small_z/unet_mse_rainfields_z.h5')

    out = np.zeros((1025, 1225), dtype=np.float32)
    # Two overlapping windows; the ERA5 input is trimmed by 50 columns on the
    # matching side (presumably the same window on a coarser grid -- TODO confirm).
    out[:-1, :-201] = mse.predict([x[None, :-1, :-201, :], z[None, :, :-50, :]])[:, :, :, 0]
    out[1:, 201:] = mse.predict([x[None, 1:, 201:, :], z[None, :, 50:, :]])[:, :, :, 0]

    grid['prec'] = (('y', 'x'), out)

    if loc:
        # y slice runs high -> low, matching the original selection.
        grid = grid.sel(x=slice((loc[0]-512)*1000, (loc[0]+512)*1000),
                        y=slice((loc[1]+512)*1000, (loc[1]-512)*1000))

    grid.prec.plot(cmap=bom_cm, vmin=0, vmax=5)

In [None]:
def rainfields(loc, date):
    """Open the Rainfields file for `date` and return its `precipitation` field.

    Parameters
    ----------
    loc : sequence of two numbers, or a falsy value
        Tile centre as (x, y). When truthy, the field is cropped to +/-512
        around the centre. Unlike the other loaders in this notebook, no x1000
        scaling is applied (presumably this product's x/y share the units of
        `loc` -- TODO confirm). When falsy, the whole field is returned.
    date : datetime.datetime
        Timestamp of the product; it determines the file name
        (`310_YYYYMMDD_HHMMSS.prcp-c10.nc`).

    Returns
    -------
    xarray.DataArray
        The `precipitation` variable. The dataset is left open because the
        returned array may still read from the file.
    """
    # Bug fix: the file name is built from the `date` argument. The original
    # body read a global `d`, silently ignoring the argument (or raising
    # NameError when `d` was not defined yet).
    fp = "/data/pluvi_pondus/Rainfields/310_{}_{}.prcp-c10.nc".format(date.strftime("%Y%m%d"),
                                                                      date.strftime("%H%M%S"))
    ds = xr.open_dataset(fp)

    if loc:
        # y slice runs high -> low, matching the original selection.
        return ds.precipitation.sel(x=slice(loc[0]-512, loc[0]+512),
                                    y=slice(loc[1]+512, loc[1]-512))
    else:
        return ds.precipitation


# `d` is first needed in this cell, but the notebook's only other assignment
# sits in a later cell, so a fresh-kernel Run All raised NameError here.
# Define the same timestamp at first use.
d = datetime(2018, 11, 5, 0, 0)

# Rainfields precipitation for the SYDM tile, drawn with the BOM colormap (0-5).
ds = rainfields(tiles["SYDM"], d)
ds.plot(cmap=bom_cm, vmin=0, vmax=5)

In [None]:
def gpm(loc, d):
    """Return the GPM `PrecCal` field at time `d`, optionally cropped to a tile.

    Parameters
    ----------
    loc : sequence of two numbers, or a falsy value
        Tile centre as (x, y). When truthy, the field is cropped to +/-512
        around the centre, multiplied by 1000 to match the dataset coordinates
        (presumably km -> m -- TODO confirm). When falsy, the whole field is
        returned.
    d : datetime.datetime
        Timestamp passed to `sel(time=...)`.

    Returns
    -------
    xarray.DataArray
        The `PrecCal` variable (cropped or whole).
    """
    snapshot = xr.open_dataset("/data/pluvi_pondus/GPM/GPM_BoM_201811.nc").sel(time=d)
    field = snapshot.PrecCal

    # Guard clause: no tile requested -> hand back the full field.
    if not loc:
        return field

    x_window = slice((loc[0]-512)*1000, (loc[0]+512)*1000)
    # y runs high -> low, so the upper bound comes first.
    y_window = slice((loc[1]+512)*1000, (loc[1]-512)*1000)
    return field.sel(x=x_window, y=y_window)

# GPM PrecCal for the SYDM tile at time `d`, drawn with the BOM colormap (0-5).
# Requires `d` to have been assigned by a previously executed cell.
ds = gpm(tiles["SYDM"], d)
ds.plot(cmap=bom_cm, vmin=0, vmax=5)

In [None]:
def crr(loc, date):
    """Open the Himawari CRR file for `date` and return its `precipitation_flux`.

    Parameters
    ----------
    loc : sequence of two numbers, or a falsy value
        Tile centre as (x, y). When truthy, the field is cropped to +/-512
        around the centre, multiplied by 1000 to match the dataset coordinates
        (presumably km -> m -- TODO confirm). When falsy, the whole field is
        returned.
    date : datetime.datetime
        Timestamp of the product; it determines the file name prefix
        (`YYYYMMDDHHMMSS-P1S-ABOM_CRR-...`).

    Returns
    -------
    xarray.DataArray
        The `precipitation_flux` variable at the file's first time index. The
        dataset is left open because the returned array may still read from
        the file.
    """
    # Bug fix: the file name is built from the `date` argument. The original
    # body read a global `d`, silently ignoring the argument (or raising
    # NameError when `d` was not defined yet).
    fp = "/data/pluvi_pondus/Himawari-CRR/{}{}-P1S-ABOM_CRR-PRJ_AEA132_2000-HIMAWARI8-AHI.nc".format(
        date.strftime("%Y%m%d"), date.strftime("%H%M%S"))
    ds = xr.open_dataset(fp)

    # Take the first time index of the file.
    ds = ds.isel(time=0)

    if loc:
        # y slice runs high -> low, matching the original selection.
        return ds.precipitation_flux.sel(x=slice((loc[0]-512)*1000, (loc[0]+512)*1000),
                                         y=slice((loc[1]+512)*1000, (loc[1]-512)*1000))
    else:
        return ds.precipitation_flux

# Himawari CRR precipitation flux for the SYDM tile, drawn with the BOM colormap (0-5).
# Requires `d` to have been assigned by a previously executed cell.
ds = crr(tiles["SYDM"], d)
ds.plot(cmap=bom_cm, vmin=0, vmax=5)

In [None]:
# Timestamp (2018-11-05 00:00) passed as the time argument to the loaders,
# predictors and plots in this notebook.
d = datetime(2018, 11, 5, 0, 0)

# NOTE(review): `plot_gpm` is not defined in any cell visible here -- confirm
# it is defined elsewhere; otherwise this call raises NameError.
plot_gpm(tiles["SYDM"], d)

In [None]:
# Model prediction (B8/B14 only) on the GPM grid for the SYDM tile, drawn
# with the BOM colormap (0-5). `gpm_pred` calls `load_model`, so the
# tensorflow import must have run before this cell.
ds = gpm_pred(tiles["SYDM"], d)
ds.plot(cmap=bom_cm, vmin=0, vmax=5)

In [None]:
# Model prediction (B8/B14 plus ERA5 z) on the GPM grid for the SYDM tile,
# drawn with the BOM colormap (0-5). `gpm_predz` calls `load_model`, so the
# tensorflow import must have run before this cell.
ds = gpm_predz(tiles["SYDM"], d)
ds.plot(cmap=bom_cm, vmin=0, vmax=5)

In [None]:
import tensorflow as tf
from tensorflow.keras.models import load_model

In [None]:
# NOTE(review): `plot_gpm_pred` is not defined in any cell visible here --
# confirm it is defined elsewhere; otherwise this call raises NameError.
plot_gpm_pred(tiles["SYDM"], d)

In [None]:
# NOTE(review): `plot_gpm_predz` is not defined in any cell visible here --
# confirm it is defined elsewhere; otherwise this call raises NameError.
plot_gpm_predz(tiles["SYDM"], d)

In [None]:
# Plot the Rainfields-model prediction (B8/B14 only) for the SYDM tile at time `d`.
plot_rainfields_pred(tiles["SYDM"], d)

In [None]:
# Plot the Rainfields-model prediction (B8/B14 plus ERA5 z) for the SYDM tile at time `d`.
plot_rainfields_predz(tiles["SYDM"], d)