In [1]:
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm.notebook import tqdm
import os

%load_ext autoreload
%autoreload 2

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cpu device


In [3]:
from convection_param.NetworksTorch import Unet, Sequential

n_channels=6
model = Unet(n_channels=n_channels,
                n_classes=8,
                output_channels_total=181,
                n_levels=2,
                n_features=512,
                bn1=False,
                bn2=False,
                column_height=23,
                activation=F.leaky_relu,
                linear=False).to(device)

In [4]:
model_path = "/work/bd1179/b309215/heuer23_convection_parameterization/Models/my_test_model/"
model_path_normed = os.path.normpath(model_path)
model_name = os.path.basename(model_path_normed)
load_path = os.path.join(model_path, "model.state_dict")
print('model_name: ', model_name)

state_dict = torch.load(load_path, map_location=torch.device(device))['model_state_dict']
print(f'Number of parameters in model from state_dict approximated: {sum(p.numel() for p in state_dict.values())}')

checkpoint = torch.load(load_path, map_location=torch.device(device))
model.load_state_dict(checkpoint['model_state_dict'])

model_name:  my_test_model
Number of parameters in model from state_dict approximated: 41521095


<All keys matched successfully>

In [5]:
import pickle
with open('/work/bd1179/b309215/heuer23_convection_parameterization/local_data/TrainData/20230803-171736-R2B5_y13y16_vcg-halflvl-fluxes_scalerdict_X.pickle', 'rb') as handle:
    scalerdict_X = pickle.load(handle)

# with open('scalerdict_Y.pickle', 'rb') as handle:
with open('/work/bd1179/b309215/heuer23_convection_parameterization/local_data/TrainData/20230803-171736-R2B5_y13y16_vcg-halflvl-fluxes_scalerdict_Y.pickle', 'rb') as handle:
    scalerdict_Y = pickle.load(handle)

In [6]:
del scalerdict_X['h']
print(scalerdict_X.keys())
print(scalerdict_Y.keys())
print(scalerdict_X['w_fl'].scaler.mean_)
print(scalerdict_X['w_fl'].scaler.scale_)

dict_keys(['w_fl', 'qv', 'qc', 'qi', 'u', 'v'])
dict_keys(['subg_flux_qv', 'subg_flux_qc', 'subg_flux_qi', 'subg_flux_qr', 'subg_flux_qs', 'subg_flux_h', 'subg_flux_u', 'subg_flux_v', 'clt', 'cltp', 'liq_detri', 'ice_detri', 'tot_prec'])
[0.00499798]
[0.03553773]


In [7]:
means_X = torch.from_numpy(np.array([value.scaler.mean_ for _,value in scalerdict_X.items()], dtype=np.float32))
scales_X = torch.from_numpy(np.array([value.scaler.scale_ for _,value in scalerdict_X.items()], dtype=np.float32))
print(means_X.shape)
print(scales_X.shape)

torch.Size([6, 1])
torch.Size([6, 1])


In [9]:
class StandardizeLayerInput(nn.Module):
    def __init__(self, means, scales):
        super().__init__()
        self.means = nn.Parameter(means, requires_grad=False)
        self.scales = nn.Parameter(scales, requires_grad=False)
    
    def forward(self, x):
        x = (x - self.means) / self.scales
        return x

In [10]:
means_Y = torch.from_numpy(np.concatenate([np.repeat(value.scaler.mean_, 22) if 'subg_flux' in k else value.scaler.mean_ for k,value in scalerdict_Y.items()], dtype=np.float32))
scales_Y = torch.from_numpy(np.concatenate([np.repeat(value.scaler.scale_, 22) if 'subg_flux' in k else value.scaler.scale_  for k,value in scalerdict_Y.items()], dtype=np.float32))
print(means_Y.shape)
print(scales_Y.shape)

torch.Size([181])
torch.Size([181])


In [11]:
scalerdict_X

{'w_fl': <convection_param.HelperFuncs.StandardScalerOneVar at 0x7fa130383820>,
 'qv': <convection_param.HelperFuncs.StandardScalerOneVar at 0x7fa2b037b100>,
 'qc': <convection_param.HelperFuncs.StandardScalerOneVar at 0x7fa2b037beb0>,
 'qi': <convection_param.HelperFuncs.StandardScalerOneVar at 0x7fa0f0140e20>,
 'u': <convection_param.HelperFuncs.StandardScalerOneVar at 0x7fa0f0140fa0>,
 'v': <convection_param.HelperFuncs.StandardScalerOneVar at 0x7fa0f0141120>}

In [12]:
class InvStandardizeLayerOutput(nn.Module):
    def __init__(self, means, scales):
        super().__init__()
        self.means = nn.Parameter(means, requires_grad=False)
        self.scales = nn.Parameter(scales, requires_grad=False)
    
    def forward(self, x):
        x = x * self.scales + self.means
        x[...,-1] = torch.exp(x[...,-1]) - 1
        return x

In [13]:
example_input = np.full((6,23), 0.5857707, dtype=np.float32)
example_input_torch = torch.from_numpy(example_input[None,...]).to(device)

In [14]:
# Build the chain (input-normalization -> model -> output-inverse-normalization)
model_chain = nn.Sequential(StandardizeLayerInput(means_X, scales_X), model, InvStandardizeLayerOutput(means_Y, scales_Y)).to(device)

In [15]:
model_chain.eval()
traced_script_module_chain = torch.jit.trace(model_chain, example_input_torch)

In [17]:
output = traced_script_module_chain(example_input_torch)
output = output.detach().cpu().numpy()
print(output.shape)
output[0,:5]

(1, 181)


array([-0.00163603, -0.00302981, -0.00381329,  0.00966431,  0.05513914],
      dtype=float32)

## Testing for same result with manual normalization

In [49]:
example_input_norm = np.empty_like(example_input)
for i,scaler in enumerate(scalerdict_X.values()):
    example_input_norm[i] = scaler.transform(example_input[i])
example_input_norm_torch = torch.from_numpy(example_input_norm[None,...])

In [50]:
model.eval()
traced_script_module = torch.jit.trace(model, example_input_norm_torch)

In [51]:
output2 = traced_script_module(example_input_norm_torch)
output2 = output2.detach().numpy()
output_norm = np.empty_like(output2)
i0 = 0
for i,scaler in list(zip([22]*8+[1]*5, scalerdict_Y.values())):
    scaler.scaler.mean_ = scaler.scaler.mean_.astype(np.float32)
    scaler.scaler.scale_ = scaler.scaler.scale_.astype(np.float32)
    output_norm[:,i0:i0+i] = scaler.inverse_transform(output2[:,i0:i0+i])
    i0 += i

output_norm[...,-1] = np.exp(output_norm[...,-1]) - 1
output_norm[0,:5]

  output_norm[...,-1] = np.exp(output_norm[...,-1]) - 1


array([0.00256257, 0.00141733, 0.01223962, 0.06300803, 0.12845431],
      dtype=float32)

In [52]:
print(output[:,:5])
print(output_norm == output)

[[0.00256257 0.00141733 0.01223961 0.06300804 0.12845433]]
[[False False False False False  True False False  True False False False
  False False False  True False False  True False False False  True False
  False False False False False False False  True False False False False
   True  True False False False False False False False False False False
   True False False False False False False  True False False False False
  False False False False False False  True False False False False False
  False False False  True  True False False False False False False  True
  False False False False False False False False False False False False
  False False False False False False False False False False False False
  False False False False  True False False False False False False False
  False False False False False False False False False False  True False
   True False False False False False False False False False False False
   True False False  True False False False False Fal

## Saving traced model

In [34]:
traced_script_module_chain.save(os.path.join(model_path, f'traced_model_chain_{device}.pt'))

In [35]:
# traced_script_module.save(os.path.join(model_path, 'traced_model_gpu.pt'))

# Testing output

In [41]:
output.shape
out3d = output[:,:-5]
out2d = output[:,-5:]

In [42]:
out3d = out3d.reshape(1,8,22)

In [43]:
out3d.shape

(1, 8, 22)

In [35]:
out3dnp = out3d.detach().numpy()
# torchgrad = torch.gradient(out3d, axis=-1)[0].detach().numpy()
npgrad = np.gradient(np.pad(out3dnp,1), axis=-1)

In [37]:
rho=1
tend = -1/rho*npgrad

In [42]:
out3d

tensor([[[-0.2731, -0.2718, -0.2077,  0.0426,  0.5806,  1.6189,  2.5021,
           2.0111,  2.5424,  4.2545,  4.0936,  2.3220,  2.6111,  2.6822,
           1.1742,  0.8887,  1.4639,  2.5013,  2.0258,  1.7445, -0.1113,
           0.3474],
         [-0.2867, -0.2791, -0.1472, -0.0236,  0.6536,  1.3011,  0.6817,
          -0.4935, -1.3413, -0.0104,  0.8263,  2.5139,  3.4545,  4.3861,
           3.2235,  3.3580,  0.5810, -0.2695, -0.1225, -0.0534, -0.1700,
          -0.2045],
         [ 2.7984,  6.4366,  5.4235, -1.7027, -0.3676,  0.7495,  0.3279,
          -0.1494, -0.1083, -0.0780, -0.0894, -0.0835, -0.0849, -0.0852,
          -0.0771, -0.0971, -0.0867, -0.0877, -0.0874, -0.0894, -0.0891,
          -0.0721],
         [-0.1233, -0.1203, -0.1175, -0.1277,  0.1609,  0.0344, -0.7033,
           1.1324,  3.6828,  4.6983,  2.6449,  2.1404,  1.8250,  1.3180,
           0.3411, -0.7062, -1.5319, -1.9917, -2.0617, -1.7752, -1.1768,
          -0.6175],
         [-0.6561,  2.7134,  3.1718,  2.6605