Can we capture the variability in both precipitation and the probability of precipitation with a single neural network (NN)?

In [1]:
import os
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data.sampler import Sampler

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from torch.distributions import Beta
import wandb



In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import itertools
import matplotlib
import datetime as dt
import glob
import bisect
import zipfile

from Data_loader import prc_norm, LoadTraining, Normalize, Custom_Sampler

In [4]:
import importlib
# importlib.reload(Data_loader)


In [5]:
### Declare directories ###

# When PRECIP_ONLY is True, every formatted path below gains a 'precip_only/' component.
PRECIP_ONLY=False
DIR_STR='precip_only/' if PRECIP_ONLY else ''


def _sorted_matches(pattern):
    """Return the sorted list of paths matching `pattern` once `{}` is filled with DIR_STR."""
    return sorted(glob.glob(pattern.format(DIR_STR)))


# Sorted file lists, one per input variable.
fils_lrh=_sorted_matches('/neelin2020/ML_input/gpm2a_dpr_era5/npy_files/{}lrh/lrh_*npy')
fils_conv_prc=_sorted_matches('/neelin2020/ML_input/gpm2a_dpr_era5/npy_files/{}conv_rain/gpm_conv_rain_*npy')
fils_imerg_prctm1=_sorted_matches('/neelin2020/ML_input/gpm2a_dpr_era5/npy_files/{}imerg_bk_rain/imerg_rain_bk_*npy')
# NOTE(review): unlike the other patterns this one hard-codes 'precip_only/' *before*
# the DIR_STR slot and has no per-variable sub-directory, so it does not point inside
# conv_nn_dir below -- confirm the intended location of the neighbour-rain files.
fils_conv_nn=_sorted_matches('/neelin2020/ML_input/gpm2a_dpr_era5/npy_files/precip_only/{}gpm_conv_neighbor_rain_*npy')

# Directories handed to the training-set loader.
conv_prc_dir, imerg_prc_tm1_dir, conv_nn_dir = (
    template.format(DIR_STR)
    for template in (
        '/neelin2020/ML_input/gpm2a_dpr_era5/npy_files/{}conv_rain/',
        '/neelin2020/ML_input/gpm2a_dpr_era5/npy_files/{}imerg_bk_rain/',
        '/neelin2020/ML_input/gpm2a_dpr_era5/npy_files/{}conv_nn_pr/',
    )
)

# Normalisation settings passed to the Normalize transform: an offset ('xbar') and a
# scale ('normalizer') for the precipitation fields and for lrh.
# (An alternative precipitation scale, prc_log_std, was tried here previously.)
prc_norm_dict=dict(xbar=0, normalizer=prc_norm)
lrh_norm_dict=dict(xbar=0, normalizer=1)


In [6]:
BATCH_SIZE=256

# Only the first 150 lrh files are used to build the training set.
training_files=fils_lrh[:150]
normalize_transform=Normalize(prc_norm_dict,lrh_norm_dict)
transformed_samples=LoadTraining(training_files, conv_prc_dir, imerg_prc_tm1_dir, conv_nn_dir,
                                 batch_size=BATCH_SIZE,
                                 transform=normalize_transform)

# batch_size=None switches off DataLoader's automatic batching; each item requested
# through the custom sampler is passed through as-is.
# (presumably Custom_Sampler yields one index per pre-built batch -- confirm in Data_loader)
batch_sampler=Custom_Sampler(len(transformed_samples),
                             BATCH_SIZE,
                             transformed_samples.array_sizes)
custom_dataloader = torch.utils.data.DataLoader(transformed_samples,
                                                batch_size=None,
                                                num_workers=6,
                                                sampler=batch_sampler)

In [7]:
# Report the dataset length and the number of batches the loader will serve.
n_samples=len(transformed_samples)
n_batches=len(custom_dataloader)
print(" {:d} samples, {:d} batches".format(n_samples,n_batches))

 12680448 samples, 49533 batches


### Check dataloader output ###

In [None]:
startTime = datetime.now()

# Pull batches 0..50 (51 in total) from the loader and keep each field as numpy arrays.
N_CHECK_BATCHES=51

lrh, conv_prc, imerg_prc_tm1, conv_nn_prc = [], [], [], []
collected={'lrh':lrh,
           'conv_prc':conv_prc,
           'imerg_prc_tm1':imerg_prc_tm1,
           'conv_nn_prc':conv_nn_prc}

for samples in itertools.islice(custom_dataloader, N_CHECK_BATCHES):
    for key, store in collected.items():
        store.append(samples[key].detach().numpy())

elapsed=datetime.now() - startTime
print("{:.2f} minutes".format(elapsed.total_seconds()/60))



In [None]:
# Histogram the fields collected from the dataloader check, one panel per field.
fig,axx=plt.subplots(2,2,figsize=(10,5))

ax=axx[0,0]
ax.hist(np.concatenate(lrh).flatten())
ax.set_title('lrh')

# Precipitation fields are multiplied by prc_norm before plotting (presumably undoing
# the Normalize transform -- confirm in Data_loader) and drawn with a log count axis.
# The product is parenthesised so that .flatten() applies to the scaled data; the
# previous form called .flatten() on prc_norm itself, which fails for a plain float.
precip_panels=((axx[0,1],'conv_prc',conv_prc),
               (axx[1,0],'conv_nn_prc',conv_nn_prc),
               (axx[1,1],'imerg_prc_tm1',imerg_prc_tm1))

for ax,name,batches in precip_panels:
    ax.hist((np.concatenate(batches)*prc_norm).flatten())
    ax.set_yscale('log')
    ax.set_title(name)

plt.tight_layout()



### Import model ###

In [8]:
import VAE_models_CVAE_one_endec #as VAE
import sys

In [9]:
# Re-import the model module so that edits made to VAE_models_CVAE_one_endec.py
# are picked up without restarting the kernel.
importlib.reload(VAE_models_CVAE_one_endec)

# sys.modules['VAE'] = VAE_models_CVAE_one_endec


<module 'VAE_models_CVAE_one_endec' from '/home/fiaz/ML/vae/Exploring_latents/VAE_models_CVAE_one_endec.py'>

In [10]:
# Model hyper-parameters, passed positionally to CVAE_ORG_mod and logged to wandb later.
LATENT_DIMS=2
INPUT_DIMS=2
NN_DIMS=4
LEARNING_RATE=1e-4

# Build the conditional VAE on the selected device and attach an Adam optimizer to it.
model=VAE_models_CVAE_one_endec.CVAE_ORG_mod(LATENT_DIMS,INPUT_DIMS,NN_DIMS)
model=model.to(device)
optimizer=torch.optim.Adam(model.parameters(),lr=LEARNING_RATE)

### Test

In [11]:
# Smoke test: push a single batch through the model and the loss, then print the
# decoder parameters at the extremes of the latent samples.
sample_batched=next(iter(custom_dataloader))
print(sample_batched.keys())

# Stack lrh and conv_prc along a new dim 1, add a trailing axis, and move the batch
# to the device the model lives on (the model was placed there with .to(device);
# leaving the batch on the CPU fails as soon as device == "cuda").
data=torch.stack((sample_batched['lrh'],
                  sample_batched['conv_prc']),dim=1).unsqueeze(2).to(device)

# No backward pass is taken here, so skip building the autograd graph.
with torch.no_grad():
    outputs=model(data)
    kl_loss, binary_loss, gamma_loss=model.vae_loss(data.squeeze(),outputs,0)

# Standard-normal latent samples handed to print_params together with a
# conditioning value of 0.6.
# NOTE(review): z and the conditioning tensor stay on the CPU as before; confirm
# whether print_params moves them to the model's device itself.
syn_size=1_000_000
z=torch.normal(mean=0.,std=1.,
               size=(syn_size,LATENT_DIMS))
VAE_models_CVAE_one_endec.print_params(z,model,torch.tensor([0.6]))

dict_keys(['lrh', 'conv_prc', 'imerg_prc_tm1', 'conv_nn_prc'])
zmax: 3.00
shape: 3.19, scale: 1.00
rain prob.: 0.79
------------------------
zmin: -3.00
shape: 1.12, scale: 1.00
rain prob.: 0.72


### Train

In [13]:
# Number of training epochs: bounds the training loop and is recorded in the wandb run config.
EPOCHS = 50


In [14]:
# Hyper-parameters and run metadata recorded alongside the metrics of this run.
run_config={
    "learning_rate": optimizer.param_groups[0]['lr'],
    "architecture": "VAE",
    "dataset": "GPM/ERA5",
    "epochs": EPOCHS,
    "batch size": BATCH_SIZE,
    "NN DIM": NN_DIMS,
    "Latent DIM": LATENT_DIMS,
    "Predictor": "lrh",
}

# Open a new wandb run in the "precip-VAE" project to track this training session.
wandb.init(project="precip-VAE", config=run_config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mfahmed[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Train for EPOCHS epochs, logging the mean losses to wandb, plotting them per epoch
# and (optionally) keeping a rolling checkpoint of the latest weights.
startTime1 = datetime.now()
SAVE_MODEL=True
# Checkpoint stem built from the actual hyper-parameters so the file name cannot
# drift from the model that was trained.
MODEL_NAME_STR='cvae_gamma_conv_rain_singleED_NN={}_LD={}'.format(NN_DIMS,LATENT_DIMS)
dir_name='/home/fiaz/ML/vae/models/'

if SAVE_MODEL:
    # Make sure the target directory exists up front, rather than discovering a
    # missing directory only after the first full epoch has been computed.
    os.makedirs(dir_name,exist_ok=True)

fig,axx=plt.subplots(1,1,figsize=(6,4))
ax=axx
ax.set_xlabel('epoch')
ax.set_ylabel('mean loss')

losses={'elbo':[],'gamma':[],
        'binary':[], 'kl':[]}
# Marker colour per loss term; the dict order is the drawing order.
loss_colors={'elbo':'black','gamma':'red','kl':'blue','binary':'orange'}

torch.manual_seed(0)
epoch_number = 0
best_vloss = 1_000_000.  # not used in this cell; kept in case later code reads it

# Fixed set of standard-normal latent samples reused for the per-epoch report.
syn_size=1_000_000
z=torch.normal(mean=0.,std=1.,
               size=(syn_size,LATENT_DIMS))

for epoch in range(EPOCHS):
    startTime2 = datetime.now()
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    mean_ELBO, mean_gamma_loss, mean_KL_loss, mean_binary_loss\
    = VAE_models_CVAE_one_endec.train_one_epoch(epoch_number,custom_dataloader,model,optimizer)

    losses['elbo'].append(mean_ELBO)
    losses['gamma'].append(mean_gamma_loss)
    losses['binary'].append(mean_binary_loss)
    losses['kl'].append(mean_KL_loss)

    # Switch to evaluation mode for the reporting below.
    model.train(False)

    VAE_models_CVAE_one_endec.print_params(z,model,torch.tensor([0.8]))

    # One marker per loss term and epoch; label only the first epoch's markers so
    # the legend gets exactly one entry per term.
    for key,color in loss_colors.items():
        ax.scatter(epoch,losses[key][-1],color=color,
                   label=key if epoch==0 else None)
    if epoch==0:
        ax.legend()

    epoch_number += 1
    print("Time for epoch: {:.2f} minutes".format((datetime.now() - startTime2).total_seconds()/60))

    wandb.log({"ELBO": mean_ELBO, "gamma loss": mean_gamma_loss,
              "binary loss": mean_binary_loss, "KL loss": mean_KL_loss})

    if SAVE_MODEL:
        model_name_prev=MODEL_NAME_STR+'_{}_epochs.pth'.format(epoch_number-1)
        model_name=MODEL_NAME_STR+'_{}_epochs.pth'.format(epoch_number)

        # Write the new checkpoint first and delete the previous one only afterwards,
        # so a failed save never leaves the run without any saved weights.
        torch.save(model.state_dict(), dir_name+model_name)
        print('Model saved as {}'.format(dir_name+model_name))

        # The previous file may be absent (e.g. removed by hand); that is not an error.
        if epoch_number>1 and os.path.exists(dir_name+model_name_prev):
            os.remove(dir_name+model_name_prev)

print("Total time: {:.2f} minutes".format((datetime.now() - startTime1).total_seconds()/60))

wandb.finish()


EPOCH 1:
batch 5000


### Diagnose

In [None]:
# Load saved weights into the model and print the decoder parameters for a
# conditioning value of 0.8.
# NOTE(review): this file name lacks the '_NN=4_LD=2' part that the training cell
# puts into its checkpoint names -- confirm this is the checkpoint you mean to load.
dst='/home/fiaz/ML/vae/models/cvae_gamma_conv_rain_singleED_50_epochs.pth'
# dst='/home/fiaz/ML/vae/models/cvae_gamma_conv_rain_singleED_NN=4_LD=1_48_epochs.pth'

# map_location remaps the stored tensors onto the current device, so a checkpoint
# written on a GPU machine can still be opened on a CPU-only one.
model.load_state_dict(torch.load(dst,map_location=device))
# Diagnostics should run in evaluation mode, as the training loop does for its reports.
model.eval()
VAE_models_CVAE_one_endec.print_params(z,model,torch.tensor([0.8]))



In [None]:
# Histogram bin edges: powers of two from 2**-2 to 2**8 in steps of 1/8 in the
# exponent, preceded by one extra edge at 1e-3.
log2_edges=np.arange(-2.,8.125,0.125)
pcp_bins=np.concatenate(([1e-3],np.power(2,log2_edges)))

# Mid-point and width of every bin.
lower_edges=pcp_bins[:-1]
upper_edges=pcp_bins[1:]
pcp_bin_center=(upper_edges+lower_edges)*0.5
dx=upper_edges-lower_edges

In [None]:
import matplotlib
# Shared colour map for both value-to-colour converters below.
cmap=plt.get_cmap('YlOrRd')

# `col` maps values in [0, 1] to colours (used for the conditioning value).
crh_norm=matplotlib.colors.Normalize(vmin=0, vmax=1.)
col=matplotlib.cm.ScalarMappable(norm=crh_norm, cmap=cmap)

# `colz` maps values in [-3, 3] to colours (used for the latent value).
colors_norm=matplotlib.colors.Normalize(vmin=-3, vmax=3.)
colz=matplotlib.cm.ScalarMappable(norm=colors_norm, cmap=cmap)

In [None]:
# Number of synthetic draws per decoder evaluation. Same value as the former
# literal 1_000_00, regrouped so that the magnitude (one hundred thousand) is obvious.
syn_size=100_000


def _sample_decoder(latent_value, crh_value, n_samples=syn_size):
    """Decode one (latent, condition) pair and summarise the sampled precipitation.

    Every latent dimension is pinned to `latent_value` and the conditioning input
    to `crh_value` for all `n_samples` rows. The decoder's three outputs are used
    as log-shape and log-rate of a Gamma distribution and as the probability of a
    Bernoulli distribution; the product of one draw from each, scaled by prc_norm,
    is histogrammed on `pcp_bins`.

    Returns
    -------
    prc_hist : numpy array
        Histogram normalised by bin width and by the number of samples that fell
        inside the bins; all-NaN when no sample fell inside any bin.
    prob : numpy array
        The Bernoulli probabilities returned by the decoder.
    """
    # Build the inputs directly on the model's device instead of allocating random
    # numbers that are immediately overwritten.
    latent=torch.full((n_samples,LATENT_DIMS),float(latent_value),
                      dtype=torch.float32,device=device)
    crh_cond=torch.full((n_samples,1),float(crh_value),
                        dtype=torch.float32,device=device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        log_alpha,log_beta,prob=model.decoder(latent,crh_cond)
        m1=torch.distributions.Gamma(log_alpha.exp(),log_beta.exp())
        m2=torch.distributions.Bernoulli(prob)
        draws=(m2.sample()*m1.sample()).cpu()

    prc_array=(draws*prc_norm).numpy().squeeze()

    counts=np.histogram(prc_array,bins=pcp_bins)[0]
    n_in_bins=counts.sum()
    if n_in_bins>0:
        prc_hist=counts/(dx*n_in_bins)
    else:
        # Nothing landed inside the bins: return NaNs explicitly rather than
        # through a 0/0 division.
        prc_hist=np.full(dx.shape,np.nan)

    return prc_hist,prob.cpu().numpy().squeeze()


fig,axx=plt.subplots(2,2,figsize=(8.,5.))

# Top row: latent pinned at 3, conditioning value swept from 0 to 0.9.
for crh_value in np.arange(0,0.95,.05):
    prc_hist,prob=_sample_decoder(3.,crh_value)
    axx[0,0].scatter(pcp_bin_center,prc_hist,color=col.to_rgba(crh_value))
    axx[0,1].hist(prob,color=col.to_rgba(crh_value))

# Bottom row: conditioning value pinned at 0.8, latent swept from -3 to 3.
for latent_value in np.arange(-3,3.2,0.2):
    prc_hist,prob=_sample_decoder(latent_value,0.8)
    axx[1,0].scatter(pcp_bin_center,prc_hist,color=colz.to_rgba(latent_value))
    axx[1,1].hist(prob,color=colz.to_rgba(latent_value))

axx[0,0].set_title('z = 3, varying condition',fontsize=13)
axx[1,0].set_title('condition = 0.8, varying z',fontsize=13)

for ax in axx[:,0]:
    ax.set_yscale('log')
    ax.set_xlim(0,150)
    ax.set_xlabel('precip.',fontsize=13)
    ax.set_ylabel('PDF',fontsize=13)

for ax in axx[:,1]:
    ax.set_xlabel('Bernoulli prob.',fontsize=13)
    ax.set_ylabel('count',fontsize=13)

for ax in axx.flat:
    ax.tick_params(which='both',labelsize=13)

plt.tight_layout()

In [None]:
# def set_seed(seed: int = 42) -> None:
#     np.random.seed(seed)
#     random.seed(seed)
#     torch.manual_seed(seed)
#     torch.cuda.manual_seed(seed)
#     # When running on the CuDNN backend, two further options must be set
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = False
#     # Set a fixed value for the hash seed
#     os.environ["PYTHONHASHSEED"] = str(seed)
#     print(f"Random seed set as {seed}")



