## Notebook to run retrieval for GMI IWP

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import ipywidgets as w
import matplotlib.pyplot as plt
import numpy as np
import netCDF4
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,
                               AutoMinorLocator)
from torch.utils.data import DataLoader, random_split
from iwc2tb.GMI.gmiSatData import gmiSatData
from iwc2tb.GMI.gmiData import gmiData
import os
import numpy.ma as ma
import glob
from iwc2tb.GMI.write_training_data import *
from tqdm.notebook import tqdm
import xarray
import torch

from quantnn.qrnn import QRNN
import time


## Set inputs

In [2]:
batchSize          = 4
inputs             = np.array(["ta", "t2m",  "wvp", "z0", "stype"])
outputs            = "iwp"

latlims            = [0, 65]

training_data      = gmiData(os.path.expanduser("~/Dendrite/Projects/IWP/GMI/training_data/TB_GMI_train.nc"), 
                             inputs, 
                             outputs,
                             batch_size = batchSize,
                             latlims = latlims,
                             )


norm               = training_data.norm


year               = '2020'
month              = '01'


inpath             = os.path.join(os.path.expanduser('~/Dendrite/SatData/GMI/L1B/'), year, month)
gmifiles           = glob.glob(os.path.join(inpath, "*/*.HDF5"))

outpath            = "~/Dendrite/UserAreas/Kaur/IWP/"

qrnnfile           = "qrnn_gmi_adam.nc"
qrnn               =  QRNN.load(os.path.join(os.path.expanduser('~/Dendrite/Projects/IWP/GMI/training_data/try_training/')
                                                  , qrnnfile))
quantiles          = qrnn.quantiles
imedian            = np.argwhere((quantiles >= 0.50) & (quantiles < 0.51))[0][0]


In [3]:
def get_pos_mean(validation_data, qrnn):
    
    quantiles  = qrnn.quantiles
    nquantiles = len(quantiles)
    
    y          = validation_data.y
    y_mean     = np.zeros(validation_data.y.shape)
    y_pre      = np.zeros([validation_data.y.shape[0],validation_data.y.shape[1], len(quantiles)])



    nbatch = len(validation_data)
    ibatch = validation_data.batch_size
    nscan  = validation_data.y.shape[0]
    npix   = validation_data.y.shape[1]
    
    with torch.no_grad():
        for i in range(nbatch):
            
            istart                      = ibatch * i
            iend                        = ibatch * (i + 1) 

            xx, yy                      = validation_data[i]

            xx                          = xx.reshape(-1, 18)
            y_pred                      = qrnn.predict(xx)
            y_pre[istart:iend, :, :]    = y_pred.reshape(-1, npix, nquantiles).detach().numpy()
            y_mean[istart:iend]         = qrnn.posterior_mean(xx, y_pred).reshape(-1, npix).detach().numpy() 
                   
    return y, y_pre, y_mean

In [4]:
def get_coords(validation_data):
    lat   = validation_data.lat
    lon   = validation_data.lon%360
    t     = validation_data.lst
    stype = validation_data.stype
    tb    = validation_data.x[:, :, :4]
    return tb, lat, lon, stype, t

In [5]:
def predict_iwp(qrnn, filename):


    gmi_s    = GMI_Sat(filename)  

    validation_data    = gmiSatData(gmi_s, 
                             inputs, outputs,
                             batch_size = batchSize,
                             latlims = latlims,
                             normalize = norm,      
                             )

    y, y_pre, y_mean = get_pos_mean(validation_data, qrnn)
    
    
    
    tb, lat,  lon, stype, t    = get_coords(validation_data)
    
    return y, y_pre, y_mean, lat, lon, tb, stype, t
    

In [6]:
def plot_iwp(lat, lon, iwp0, iwp,tb, mask):
    
    fig, ax = plt.subplots(1, 2, figsize = [12, 6])
    ax = ax.ravel()
    diff = 100 * (np.exp(np.abs(np.log(iwp/iwp0))) - 1)
    
    bbox = [np.min(lon),np.min(lat),np.max(lon),np.max(lat)] # set bounds for plotting
    n_add = 0
    m = Basemap(llcrnrlon=bbox[0]-n_add,llcrnrlat=bbox[1]-n_add,
                urcrnrlon=bbox[2]+n_add,urcrnrlat=bbox[3]+n_add,resolution='l',
                projection='cyl')
    x, y = m(lon, lat)    
    for var, axes, t in zip([iwp0, iwp], ax, ["GMI", "QRNN"]):
        
        cs = axes.scatter(lon[mask],lat[mask], c = var[mask]* 1000,
                          norm=colors.LogNorm(), vmin = 1, vmax = 10000,)
                        # cmap = cm.Paired)
            
        #cs = m.scatter(lon, lat, var[mask]* 1000, ax = axes)    
        axes.set_title(t)
        ax[0].set_ylabel("Latitude [deg]")
        #axes.set_xlabel("Longitude [deg]")
    cbar = fig.colorbar(cs, ax=[ax[0], ax[1]])
    cbar.ax.set_ylabel("IWP [g/m2]")
      

def get_mask(lat, lon, latlims, lonlims):
    
    im  = (lat >= latlims[0]) & (lat <= latlims[1])
    im1 = (lon >=  lonlims[0]) & (lon < lonlims[1])
    mask  = np.logical_and(im, im1)
    
    return mask    

## Loop over all files

In [None]:
for gmifile in tqdm(gmifiles[:]):
    print (gmifile)
    bname = os.path.basename(gmifile)
    outfile = os.path.join(outpath, bname[:-5] + ".nc")
    try:
        y, y_pre, y_pos_mean, lat, lon, tb, stype, t = predict_iwp(qrnn, gmifile)
        y_pre[y_pre < 0] = 0
        y_pos_mean[y_pos_mean < 0] = 0


        stype = np.argmax(stype, axis = 2)

        print (np.nanmax(y_pos_mean))

        d = xarray.Dataset({
        "iwp": (["scans", "pixels"], y_pre[:, :, imedian]),
        "iwp_mean": (["scans", "pixels"], y_pos_mean),
        "stype": (["scans", "pixels"], stype),
        "local_time": (["scans", "pixels"], t),

        },
        coords={
        "lon": (["scans", "pixels"], lon),
        "lat": (["scans", "pixels"], lat),
        #"quantiles": (["quantiles"], qrnn.quantiles)
        })

        d.to_netcdf(outfile, mode = "w")


        
        
        
    except:
        print ("file not available")
        continue

  0%|          | 0/482 [00:00<?, ?it/s]

/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/01/1B.GPM.GMI.TB2016.20200101-S173314-E190548.033199.V05A.HDF5
20.60029411315918
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/01/1B.GPM.GMI.TB2016.20200101-S112255-E125529.033195.V05A.HDF5
13.566424369812012
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/01/1B.GPM.GMI.TB2016.20200101-S051236-E064510.033191.V05A.HDF5
25.052907943725586
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/01/1B.GPM.GMI.TB2016.20200101-S064511-E081744.033192.V05A.HDF5
39.237117767333984
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/01/1B.GPM.GMI.TB2016.20200101-S003452-E020726.033188.V05A.HDF5
16.070409774780273
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/01/1B.GPM.GMI.TB2016.20200101-S203824-E221057.033201.V05A.HDF5
22.613609313964844
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/01/1B.GPM.GMI.TB2016.20200101-S020727-E034000.033189.V05A.HDF5
19.81365394592285
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/01/1B.GPM.GMI.TB2016.2020010

26.47481918334961
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/05/1B.GPM.GMI.TB2016.20200105-S184542-E201815.033262.V05A.HDF5
26.320594787597656
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/05/1B.GPM.GMI.TB2016.20200105-S154032-E171306.033260.V05A.HDF5
24.447803497314453
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/05/1B.GPM.GMI.TB2016.20200105-S232326-E005559.033265.V05A.HDF5
25.780759811401367
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/05/1B.GPM.GMI.TB2016.20200105-S075739-E093012.033255.V05A.HDF5
26.35340690612793
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/05/1B.GPM.GMI.TB2016.20200105-S062504-E075738.033254.V05A.HDF5
24.473726272583008
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/05/1B.GPM.GMI.TB2016.20200105-S215051-E232325.033264.V05A.HDF5
30.716976165771484
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/05/1B.GPM.GMI.TB2016.20200105-S140758-E154031.033259.V05A.HDF5
25.46816062927246
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/05/1B.GPM.

25.918312072753906
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/09/1B.GPM.GMI.TB2016.20200109-S060442-E073714.033316.V05A.HDF5
29.354598999023438
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/09/1B.GPM.GMI.TB2016.20200109-S090949-E104222.033318.V05A.HDF5
20.951549530029297
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/09/1B.GPM.GMI.TB2016.20200109-S104223-E121455.033319.V05A.HDF5
26.415390014648438
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/09/1B.GPM.GMI.TB2016.20200109-S152003-E165236.033322.V05A.HDF5
20.931108474731445
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/09/1B.GPM.GMI.TB2016.20200109-S043208-E060441.033315.V05A.HDF5
18.449617385864258
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/09/1B.GPM.GMI.TB2016.20200109-S230251-E003524.033327.V05A.HDF5
34.158164978027344
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/09/1B.GPM.GMI.TB2016.20200109-S182510-E195743.033324.V05A.HDF5
33.005836486816406
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/09/1B.G

20.58016586303711
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/13/1B.GPM.GMI.TB2016.20200113-S210854-E224126.033388.V05A.HDF5
20.66720199584961
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/13/1B.GPM.GMI.TB2016.20200113-S041046-E054318.033377.V05A.HDF5
23.855968475341797
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/13/1B.GPM.GMI.TB2016.20200113-S054319-E071552.033378.V05A.HDF5
32.43690490722656
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/13/1B.GPM.GMI.TB2016.20200113-S102100-E115332.033381.V05A.HDF5
33.55018997192383
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/13/1B.GPM.GMI.TB2016.20200113-S023812-E041045.033376.V05A.HDF5
25.76410484313965
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/13/1B.GPM.GMI.TB2016.20200113-S084826-E102059.033380.V05A.HDF5
25.23613929748535
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/13/1B.GPM.GMI.TB2016.20200113-S163113-E180346.033385.V05A.HDF5
22.890071868896484
/home/inderpreet/Dendrite/SatData/GMI/L1B/2020/01/13/1B.GPM.GMI

In [None]:
gmi_s    = GMI_Sat(gmifiles[:1])  

validation_data    = gmiSatData(gmi_s, 
                         inputs, outputs,
                         batch_size = batchSize,
                         latlims = latlims,
                         normalize = norm,      
                         log = xlog)

In [None]:
fig, ax = plt.subplots(1, 1, figsize = [6, 6])
rndinds = np.random.randint(1, 2700, 1000)
for i in rndinds:
    ax.plot(quantiles, y_pre[i, 120, :], 'b', alpha = 0.2)
ax.set_xlabel("quantiles")
ax.set_ylabel("IWP")
fig.savefig("quantiles.png")

In [None]:
y_pre.max()

In [None]:
import pickle
IWP  = np.concatenate(IWP, axis = 0)
LAT  = np.concatenate(LAT, axis = 0)
LON  = np.concatenate(LON, axis = 0)
LSM  = np.concatenate(LSM, axis = 0)
IWP0 = np.concatenate(IWP0, axis = 0)
IWP_mean = np.concatenate(IWP_mean, axis = 0)
with open("jan2020_IWP.pickle", "wb") as f:
    pickle.dump(IWP, f)
    pickle.dump(IWP0, f)
    pickle.dump(IWP_mean, f)
    pickle.dump(LON, f)
    pickle.dump(LAT, f)
    pickle.dump(LSM, f)
    
    
    f.close()