In [48]:
import warnings
import os
from model_function import *
from model_utils import *
from utils import *
from torch.utils.data import DataLoader
import torch.nn.functional as Fin
import timeit
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib 
from torchdiffeq import odeint as odeint
import matplotlib
matplotlib.use('Agg')
import argparse
import sys
import time
import torch
torch.manual_seed(42)
torch.cuda.empty_cache() 
import torch.optim as optim
import random
import logging
logging.propagate = False 
logging.getLogger().setLevel(logging.ERROR)
import sys

set_seed(42)
cwd = os.getcwd()
#data_path = {'z500':str(cwd) + '/era5_data/geopotential_500/*.nc','t850':str(cwd) + '/era5_data/temperature_850/*.nc'}
SOLVERS = ["dopri8","dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams', 'fixed_adams',"adaptive_heun","euler"]
parser = argparse.ArgumentParser('ClimODE')

solver = "euler"
atol = 5e-3
rtol = 5e-3
step_size = None  # Optional fixed step size
niters = 300
scale = 0
batch_size = 8
spectral = 0  # Choices: [0, 1]
lr = 0.0005
weight_decay = 1e-5

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


train_time_scale= slice('2006','2016')
val_time_scale = slice('2016','2016')
test_time_scale = slice('2017','2018')

paths_to_data = [str(cwd) + '/era5_data/geopotential_500/*.nc',str(cwd) + '/era5_data/temperature_850/*.nc',str(cwd) + '/era5_data/2m_temperature/*.nc',str(cwd) + '/era5_data/10m_u_component_of_wind/*.nc',str(cwd) + '/era5_data/10m_v_component_of_wind/*.nc']

num_years = len(range(2006,2016))


class Climate_encoder_free_uncertain_summary(nn.Module): 
    
    def __init__(self,num_channels,const_channels,out_types,method,use_att,use_err,use_pos):
        super().__init__()
        self.layers = [5,3,2]
        self.hidden = [128,64,2*out_types]
        input_channels = 30 + out_types*int(use_pos) + 34*(1-int(use_pos))
        self.vel_f = Climate_ResNet_2D(input_channels,self.layers,self.hidden)

        if use_att: 
            self.vel_att = Self_attn_conv(input_channels,10)
            self.gamma = nn.Parameter(torch.tensor([0.1]))

        self.scales = num_channels
        self.const_channel = const_channels
        
        self.out_ch = out_types
        self.past_samples = 0
        self.const_info = 0
        self.lat_map = 0
        self.lon_map = 0
        self.elev = 0
        self.pos_emb = 0
        self.elev_info_grad_x = 0
        self.elev_info_grad_y = 0
        self.method = method
        err_in =  9 + out_types*int(use_pos) + 34*(1-int(use_pos))
        if use_err: self.noise_net = Climate_ResNet_2D(err_in,[3,2,2],[128,64,2*out_types])
        if use_pos: self.pos_enc = Climate_ResNet_2D(4,[2,1,1],[32,16,out_types])
        self.att = use_att
        self.err = use_err
        self.pos = use_pos
        self.pos_feat = 0
        self.lsm =0 
        self.oro =0 


    def update_param(self, params):
        self.past_samples = params[0]
        self.const_info = params[1]
        self.lat_map = params[2]
        self.lon_map = params[3]

    def pde(self,t,vs):

        ds = vs[:,-self.out_ch:,:,:].view(-1,self.out_ch,vs.shape[2],vs.shape[3]).float()
        v = vs[:,:2*self.out_ch,:,:].view(-1,2*self.out_ch,vs.shape[2],vs.shape[3]).float()
        t_emb = ((t*100)%24).view(1,1,1,1).expand(ds.shape[0],1,ds.shape[2],ds.shape[3])
        sin_t_emb = torch.sin(torch.pi*t_emb/12 - torch.pi/2)
        cos_t_emb = torch.cos(torch.pi*t_emb/12 - torch.pi/2)
        
        sin_seas_emb = torch.sin(torch.pi*t_emb/(12*365) - torch.pi/2)
        cos_seas_emb = torch.cos(torch.pi*t_emb/(12*365) - torch.pi/2)

        day_emb = torch.cat([sin_t_emb,cos_t_emb],dim=1)
        seas_emb = torch.cat([sin_seas_emb,cos_seas_emb],dim=1)
        
        ds_grad_x = torch.gradient(ds,dim=3)[0]
        ds_grad_y = torch.gradient(ds,dim=2)[0]
        nabla_u = torch.cat([ds_grad_x,ds_grad_y],dim=1)

        if self.pos:
            comb_rep = torch.cat([t_emb/24,day_emb,seas_emb,nabla_u,v,ds,self.pos_feat],dim=1)
        else:
            cos_lat_map,sin_lat_map = torch.cos(self.new_lat_map),torch.sin(self.new_lat_map)
            cos_lon_map,sin_lon_map = torch.cos(self.new_lon_map),torch.sin(self.new_lon_map)
            t_cyc_emb = torch.cat([day_emb,seas_emb],dim=1)
            pos_feats = torch.cat([cos_lat_map,cos_lon_map,sin_lat_map,sin_lon_map,sin_lat_map*cos_lon_map,sin_lat_map*sin_lon_map],dim=1)
            pos_time_ft = self.get_time_pos_embedding(t_cyc_emb,pos_feats)
            comb_rep = torch.cat([t_emb/24,day_emb,seas_emb,nabla_u,v,ds,self.new_lat_map,self.new_lon_map,self.lsm,self.oro,pos_feats,pos_time_ft],dim=1)

        if self.att: dv = self.vel_f(comb_rep) + self.gamma*self.vel_att(comb_rep)
        else: dv = self.vel_f(comb_rep)
        v_x = v[:,:self.out_ch,:,:].view(-1,self.out_ch,vs.shape[2],vs.shape[3]).float()
        v_y = v[:,-self.out_ch:,:,:].view(-1,self.out_ch,vs.shape[2],vs.shape[3]).float()

        adv1 = v_x*ds_grad_x + v_y*ds_grad_y
        adv2 = ds*(torch.gradient(v_x,dim=3)[0] + torch.gradient(v_y,dim=2)[0] )
        

        ds = adv1 + adv2

        dvs = torch.cat([dv,ds],1)
        return dvs
    

    def get_time_pos_embedding(self,time_feats,pos_feats):
        for idx in range(time_feats.shape[1]):
            tf = time_feats[:,idx].unsqueeze(dim=1)*pos_feats
            if idx == 0:
                final_out = tf
            else:
                final_out = torch.cat([final_out,tf],dim=1)

        return final_out

    def noise_net_contrib(self,t,pos_enc,s_final,noise_net,H,W):

        t_emb = (t%24).view(-1,1,1,1,1)
        sin_t_emb = torch.sin(torch.pi*t_emb/12 - torch.pi/2).expand(len(s_final),s_final.shape[1],1,H,W)
        cos_t_emb = torch.cos(torch.pi*t_emb/12 - torch.pi/2).expand(len(s_final),s_final.shape[1],1,H,W)
        
        sin_seas_emb = torch.sin(torch.pi*t_emb/(12*365)- torch.pi/2).expand(len(s_final),s_final.shape[1],1,H,W)
        cos_seas_emb = torch.cos(torch.pi*t_emb/(12*365)- torch.pi/2).expand(len(s_final),s_final.shape[1],1,H,W)

        pos_enc = pos_enc.expand(len(s_final),s_final.shape[1],-1,H,W).flatten(start_dim=0,end_dim=1)
        t_cyc_emb = torch.cat([sin_t_emb,cos_t_emb,sin_seas_emb,cos_seas_emb],dim=2).flatten(start_dim=0,end_dim=1)

        pos_time_ft = self.get_time_pos_embedding(t_cyc_emb,pos_enc[:,2:-2])

        comb_rep = torch.cat([t_cyc_emb,s_final.flatten(start_dim=0,end_dim=1),pos_enc,pos_time_ft],dim=1)

        final_out = noise_net(comb_rep).view(len(t),-1,2*self.out_ch,H,W)

        mean = s_final + final_out[:,:,:self.out_ch]
        std = nn.Softplus()(final_out[:,:,self.out_ch:])
        
        return mean,std


    def forward(self, T=None, data=None, atol=0.1, rtol=0.1):
        # if T is None:
        #     T = torch.tensor([0.0, 1.0])  # Example time tensor
        # if data is None:
        #     data = torch.zeros(1, self.out_ch, 128, 128)  # Ensure 4D shape
        print(type(self.past_samples))

        H, W = self.past_samples.shape[2], self.past_samples.shape[3]
        
        # Debug the shapes
        print(f"Shape of self.past_samples: {self.past_samples.shape}")
        print(f"Shape of data: {data.shape}")

        # Ensure batch size matches between self.past_samples and data
        if self.past_samples.shape[0] != data.shape[0]:
            self.past_samples = self.past_samples.expand(data.shape[0], -1, H, W)  # Adjust the batch size

        # Concatenate past samples and data along the channel dimension
        final_data = torch.cat([self.past_samples, data.float().view(-1, self.out_ch, H, W)], 1)

        init_time = T[0].item() * 6
        final_time = T[-1].item() * 6
        steps_val = final_time - init_time

        # Check data shape
        print(f"Data shape: {data.shape}")

        if self.pos:
            lat_map = self.lat_map.unsqueeze(dim=0) * torch.pi / 180
            lon_map = self.lon_map.unsqueeze(dim=0) * torch.pi / 180
            pos_rep = torch.cat([lat_map.unsqueeze(dim=0), lon_map.unsqueeze(dim=0), self.const_info], dim=1)
            
            # Check pos_rep shape
            print(f"pos_rep shape: {pos_rep.shape}")

            self.pos_feat = self.pos_enc(pos_rep).expand(data.shape[0], -1, data.shape[3], data.shape[4])
            final_pos_enc = self.pos_feat
        else:
            self.oro,self.lsm = self.const_info[0,0],self.const_info[0,1]
            self.lsm = self.lsm.unsqueeze(dim=0).expand(data.shape[0],-1,data.shape[3],data.shape[4])
            self.oro  = F.normalize(self.const_info[0,0]).unsqueeze(dim=0).expand(data.shape[0],-1,data.shape[3],data.shape[4])
            self.new_lat_map = self.lat_map.expand(data.shape[0],1,data.shape[3],data.shape[4])*torch.pi/180 # Converting to radians
            self.new_lon_map = self.lon_map.expand(data.shape[0],1,data.shape[3],data.shape[4])*torch.pi/180
            cos_lat_map,sin_lat_map = torch.cos(self.new_lat_map),torch.sin(self.new_lat_map)
            cos_lon_map,sin_lon_map = torch.cos(self.new_lon_map),torch.sin(self.new_lon_map)
            pos_feats = torch.cat([cos_lat_map,cos_lon_map,sin_lat_map,sin_lon_map,sin_lat_map*cos_lon_map,sin_lat_map*sin_lon_map],dim=1)
            final_pos_enc = torch.cat([self.new_lat_map,self.new_lon_map,pos_feats,self.lsm,self.oro],dim=1)


        new_time_steps = torch.linspace(init_time,final_time,steps=int(steps_val)+1).to(data.device)
        t = 0.01*new_time_steps.float().to(data.device).flatten().float()
        pde_rhs  = lambda t,vs: self.pde(t,vs) # make the ODE forward function
        final_result = odeint(pde_rhs,final_data,t,method=self.method,atol=atol,rtol=rtol)
        #breakpoint()
        s_final = final_result[:,:,-self.out_ch:,:,:].view(len(t),-1,self.out_ch,H,W)

        if self.err:
            mean,std = self.noise_net_contrib(T,final_pos_enc,s_final[0:len(s_final):6],self.noise_net,H,W)

        else:
            s_final = s_final[0:len(s_final):6]

        return mean,std,s_final[0:len(s_final):6]





model = Climate_encoder_free_uncertain_summary(len(paths_to_data),2,out_types=len(paths_to_data),method=solver,use_att=True,use_err=True,use_pos=False).to(device)

Random seed set as 42


In [45]:
from torchsummary import summary
# from torchinfo import summary

In [54]:
const_info_path = [str(cwd) + '/era5_data/constants/constants_5.625deg.nc']

In [None]:
max_lev = []
min_lev = []
levels = ["z","t","t2m","u10","v10"]
paths_to_data = [str(cwd) + '/era5_data/geopotential_500/*.nc',str(cwd) + '/era5_data/temperature_850/*.nc',str(cwd) + '/era5_data/2m_temperature/*.nc',str(cwd) + '/era5_data/10m_u_component_of_wind/*.nc',str(cwd) + '/era5_data/10m_v_component_of_wind/*.nc']
paths_to_data = paths_to_data[0:5]
levels = levels[0:5]

for idx,data in enumerate(paths_to_data):
    Train_data,Val_data,Test_data,time_steps,lat,lon,mean,std,time_stamp = get_train_test_data_batched_regional(data,train_time_scale,val_time_scale,test_time_scale,levels[idx],spectral,"EastAfrica")  
    max_lev.append(mean)
    min_lev.append(std)
    if idx==0: 
        Final_train_data = Train_data
        Final_val_data = Val_data
        Final_test_data = Test_data
    else:
        Final_train_data = torch.cat([Final_train_data,Train_data],dim=2)
        Final_val_data = torch.cat([Final_val_data,Val_data],dim=2)
        Final_test_data = torch.cat([Final_test_data,Test_data],dim=2)



In [46]:
type(Train_data)

torch.Tensor

In [52]:

# Move model to devicedevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Set the dimensions for your input data based on the variables in the training loop
num_years = len(range(2006,2016))  # Example number of years (time dimension)
args_scale = 0
H,W = Train_data.shape[3],Train_data.shape[4]

# Input shape for data
input_shape = (num_years, 1, len(paths_to_data)*(args_scale+1), H, W)  # Adjusted for the training loop


summary(model, input_size=input_shape, batch_size=1)

<class 'int'>


AttributeError: 'int' object has no attribute 'shape'

In [53]:
Train_loader = DataLoader(Final_train_data[2:],batch_size=8,shuffle=False,pin_memory=False)
Val_loader = DataLoader(Final_val_data[2:],batch_size=8,shuffle=False,pin_memory=False)
Test_loader = DataLoader(Final_test_data[2:],batch_size=8,shuffle=False,pin_memory=False)
time_loader = DataLoader(time_steps[2:],batch_size=8,shuffle=False,pin_memory=False)
time_idx_steps = torch.tensor([i for i in range(365*4)]).view(-1,1)
time_idx = DataLoader(time_idx_steps[2:],batch_size=8,shuffle=False,pin_memory=False)
#Model declaration

In [55]:
vel_train,vel_val = load_velocity(['train_10year_2day_mm','val_10year_2day_mm'])
const_channels_info,lat_map,lon_map = add_constant_info_region(const_info_path,"EastAfrica",H,W)

In [58]:
for epoch in range(1):
    total_train_loss = 0
    val_loss = 0
    test_loss = 0
    RMSD = []
    
    for entry,(time_steps,batch) in enumerate(zip(time_loader,Train_loader)):
        data = batch[0].to(device).view(num_years,1,len(paths_to_data)*(1),H,W)
        # print(vel_train[entry].shape)
        past_sample = vel_train[entry].view(num_years,2*len(paths_to_data)*(1),H,W).to(device)
        model.update_param([past_sample,const_channels_info.to(device),lat_map.to(device),lon_map.to(device)])
        t = time_steps.float().to(device).flatten()
        summary(model,(t,data))

RuntimeError: shape '[10, 10, 6, 6]' is invalid for input of size 204800

In [59]:
# Print the original tensor's shape and total elements
print(f"Original vel_train[entry] shape: {vel_train[entry].shape}")
print(f"Total elements: {vel_train[entry].numel()}")

# Print the dimensions you're trying to use
print(f"num_years: {num_years}")
print(f"2*len(paths_to_data)*(1): {2*len(paths_to_data)*(1)}")
print(f"H, W: {H}, {W}")

# Calculate the total elements in the proposed reshape
proposed_reshape_elements = num_years * (2*len(paths_to_data)*(1)) * H * W
print(f"Proposed reshape total elements: {proposed_reshape_elements}")

Original vel_train[entry] shape: torch.Size([10, 2, 5, 32, 64])
Total elements: 204800
num_years: 10
2*len(paths_to_data)*(1): 10
H, W: 6, 6
Proposed reshape total elements: 3600


In [60]:
print(model)

Climate_encoder_free_uncertain_summary(
  (vel_f): Climate_ResNet_2D(
    (layer_cnn): ModuleList(
      (0): Sequential(
        (0): ResidualBlock(
          (activation): LeakyReLU(negative_slope=0.3)
          (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (drop): Dropout(p=0.1, inplace=False)
          (shortcut): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
          (norm1): Identity()
          (norm2): Identity()
        )
        (1): ResidualBlock(
          (activation): LeakyReLU(negative_slope=0.3)
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2)