# IMPORT LIBRARIES

In [1]:
from tqdm import *
import sys
sys.path.append("../")

import config2_ctlstm_modified_v2 as config
import MODEL
import UTILS

import os
import time
import argparse
import math
import numpy as np
import pandas as pd
from collections import Counter
import random
import pickle

import torch
from torch.utils.data.dataset import Dataset
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error

torch.manual_seed(0)

<torch._C.Generator at 0x7efcfbd495b0>

# HYPERPARAMETERS

In [2]:
parser = argparse.ArgumentParser(
    description='arguments')
parser.add_argument('--init', type=int, default=0, help='init number')
parser.add_argument('--fold', type=int, default=0, help='fold number')
parser.add_argument('--model_name', type=str, default='ctlstm', help='model_name')
parser.add_argument('--date', type=str, default='20240804', help='date')

args = parser.parse_args()



In [3]:
# TIME SERIES INFO
window = config.window
stride = config.stride
channels = config.channels_names
print(channels)
# CHANNELS INFO
dynamic_channels = config.dynamic_channels
static_channels = config.static_channels
output_channels = config.output_channels
# metaflux_channels = config.metaflux_channels
no_normalize_channels = config.no_normalize_channels
normalize_channels = config.normalize_channels

# LABELS INFO
unknown = config.unknown

# MODEL INFO
model_name = args.model_name
forward_code_dim = config.forward_code_dim
device = torch.device(config.device)
dropout = config.dropout

# TRAIN INFO
train = config.train
batch_size = config.batch_size
epochs = config.epochs
learning_rate = config.learning_rate
init = args.init
fold = args.fold

print("Hyperparameters:{}".format(model_name))
print("window : {}".format(window))
print("stride : {}".format(stride))
print("dynamic_channels : {}".format(dynamic_channels))
print("static_channels : {}".format(static_channels))
print("output_channels : {}".format(output_channels))
print("unknown : {}".format(unknown))
print("model_name : {}".format(model_name))
print("forward_code_dim : {}".format(forward_code_dim))
print("device : {}".format(device))
print("dropout : {}".format(dropout))
print("train : {}".format(train))
print("batch_size : {}".format(batch_size))
print("epochs : {}".format(epochs))
print("learning_rate : {}".format(learning_rate))
print("init : {}".format(init))
print("fold : {}".format(fold))

['P_ERA' 'Lai' 'VPD_ERA' 'TA_ERA' 'SW_IN_ERA' 'GPP_NT_VUT_REF' 'RECO'
 'pft_MF' 'pft_CRO' 'pft_CSH' 'pft_DBF' 'pft_EBF' 'pft_ENF' 'pft_GRA'
 'pft_OSH' 'pft_SAV' 'pft_SNO' 'pft_WET' 'pft_WSA' 'climate_Arctic'
 'climate_Continental' 'climate_Temperate' 'climate_Tropical'
 'climate_Arid' 'Lat' 'Lon']
Hyperparameters:ctlstm
window : 30
stride : 15
dynamic_channels : [0, 1, 2, 3, 4]
static_channels : [7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
output_channels : [6]
unknown : nan
model_name : ctlstm
forward_code_dim : 256
device : cuda
dropout : 0.4
train : True
batch_size : 64
epochs : 1
learning_rate : 0.001
init : 1
fold : 0


# DEFINE DIRECTORIES

In [32]:
DATE = args.date
PREPROCESSED_DIR = config.PREPROCESSED_DIR
RESULT_DIR = os.path.join(config.RESULT_DIR, DATE)
MODEL_DIR = os.path.join(config.MODEL_DIR, DATE)

# LOAD DATA

In [33]:
def load_dataset(file):
    dataset = np.load(os.path.join(PREPROCESSED_DIR, "{}.npz".format(file)), allow_pickle=True)
    return dataset

def get_data(dataset,preprocessed=True):
    data = dataset["data"]
    print(data.shape)
    if preprocessed:
        data_mean = dataset["train_data_means"]
        data_std =  dataset["train_data_stds"]

        #print(data_mean.shape)
        normalized_data = np.zeros_like(data)
        if len(data.shape)==4:
            for feature in range(data_mean.shape[0]):
                if data_std[feature]!=0:
                    normalized_data[:,:,:,feature] = (data[:,:,:,feature] - data_mean[feature])/data_std[feature]
                else:
                    normalized_data[:,:,:,feature] = data[:,:,:,feature]
                normalized_data[:,:,:,no_normalize_channels] = data[:,:,:,no_normalize_channels]
        else:
            normalized_data[:,:,normalize_channels] = data[:,:,normalize_channels] - data_mean[normalize_channels]/data_std[normalize_channels]
            # normalized_data[:,:,-1] = data[:,:,-1]
        data = normalized_data
    data = np.nan_to_num(data, nan=unknown)
    #print(data.shape)
    return data

In [34]:
def get_data_test(dataset, preprocessed=True):
    data = dataset["data"]
    data_final = []
    if preprocessed:
        data_mean = dataset["train_data_means"]
        data_std = dataset["train_data_stds"]
        for data_value in data:
            normalized_data = np.zeros_like(data_value)
            if len(data_value.shape)==3:
                for feature in range(data_mean.shape[0]):
                    if data_std[feature]!=0:
                        normalized_data[:,:,feature] = (data_value[:,:,feature] - data_mean[feature])/data_std[feature]
                    else:
                        normalized_data[:,:,feature] = data_value[:,:,feature]
                normalized_data[:,:,no_normalize_channels] = data_value[:,:,no_normalize_channels]
            else:
                 normalized_data[:,normalize_channels] = data_value[:,normalize_channels] - data_mean[normalize_channels]/data_std[normalize_channels]
            data_final.append(normalized_data)
            
    data_final = np.array(data_final, dtype=object)
    data_final = np.nan_to_num(data_final, nan=unknown)
    return data_final
    

In [None]:
# def unstride_array_list(strided_data_list):
#     unstrided_list = []

#     for strided_data in strided_data_list:
#         shape = strided_data.shape
#         data = config.unknown * np.ones((shape[0], 1 + (shape[1] // config.stride), shape[2]))
#         data[:, config.stride:] = strided_data[:, ::2, config.stride:]
#         first_part = config.unknown * np.ones((shape[0], 1, config.stride))
#         second_part = strided_data[:, 1::2, config.stride + 1:]
#         data[:, :config.stride] = np.concatenate((first_part, second_part), axis=1)
#         data = np.reshape(data, (shape[0], -1))
#         unstrided_list.append(data)

#     return unstrided_list


def unstride_array(strided_data_list):
    combined_second_parts = []
    for strided_data in strided_data_list:
        shape = strided_data.shape
        if shape[1] > 0:
            second_part = strided_data[:, 1::2, config.stride + 1:]
        else:
            second_part = np.empty((shape[0], 0, shape[2]))
        combined_second_parts.append(second_part)

    combined_unstrided_data = np.concatenate(combined_second_parts, axis=1)
    return combined_unstrided_data





In [35]:
# def get_data_test(dataset, preprocessed=True, unknown=0):
#     data = dataset["data"]
#     data_final = np.empty_like(data)  # Initialize with the same shape as data
    
#     if preprocessed:
#         data_mean = dataset["train_data_means"]
#         data_std = dataset["train_data_stds"]
#         no_normalize_channels = []  # Assuming this is defined somewhere
        
#         for idx, data_value in enumerate(data):
#             normalized_data = np.zeros_like(data_value)
            
#             if len(data_value.shape) == 3:
#                 for feature in range(data_mean.shape[0]):
#                     if data_std[feature] != 0:
#                         # Normalize only the second half along the first axis
#                         normalized_data[len(data_value)//2:, :, feature] = (
#                             (data_value[len(data_value)//2:, :, feature] - data_mean[feature]) / data_std[feature]
#                         )
#                         # Copy channels that should not be normalized
#                         for channel in no_normalize_channels:
#                             normalized_data[:, :, channel] = data_value[:, :, channel]
#                     else:
#                         normalized_data[:, :, feature] = data_value[:, :, feature]
                        
#             else:  # Assuming 2D case
#                 normalize_channels = []  # Assuming this is defined somewhere
#                 for channel in normalize_channels:
#                     if data_std[channel] != 0:
#                         # Normalize only the second half along the first axis
#                         normalized_data[len(data_value)//2:, channel] = (
#                             (data_value[len(data_value)//2:, channel] - data_mean[channel]) / data_std[channel]
#                         )
#                     else:
#                         normalized_data[:, channel] = data_value[:, channel]
            
#             data_final[idx] = normalized_data
    
#     data_final = np.nan_to_num(data_final, nan=unknown)
#     return data_final


In [36]:
file, index = "strided_train", "in_indices"
dataset = load_dataset(file)
data = dataset["data"]
normalized_data = np.zeros_like(data)
print(normalized_data.shape)
data = get_data(dataset)
nodes, years, window, channels = data.shape
print(nodes, years, window, channels)


(145, 533, 30, 26)
(145, 533, 30, 26)
145 533 30 26


In [37]:
# file, index = "strided_test", "in_indices"
# dataset = load_dataset(file)
# data = dataset["data"]
# normalized_data = np.zeros_like(data)
# print(normalized_data.shape)
# for i in data:
#     print(i.shape)
# print(data.shape)
# data = get_data(dataset)
# nodes, years, window, channels = data.shape
# print(nodes, years, window, channels)

# TRAIN MODEL

In [38]:
if train:
#     print("fold:{}\tinit:{}".format(fold, init))

    # BUILD MODEL
    model = getattr(MODEL, "ctlstm")(input_dynamic_channels=len(dynamic_channels), input_static_channels=len(static_channels), hidden_dim=forward_code_dim, output_channels=len(output_channels), dropout=dropout)
    model = model.to(device)
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("#Parameters:{}".format(pytorch_total_params))
    print(model)
    criterion = torch.nn.MSELoss(reduction="none")
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    train_loss = []
    valid_loss = []
    min_loss = 10000

    for epoch in range(1,epochs+1):

        start = time.time()

        # LOSS ON TRAIN SET
        model.train()

        # LOAD DATA
        file, index = "strided_train", "in_indices"
        dataset = load_dataset(file)
        data = get_data(dataset)
        nodes, years, window, channels = data.shape
        print(nodes, years, window, channels)

        # GET RANDOM YEARS
        random_years = np.zeros((nodes, years))
        for node in range(nodes):
            random_years[node] = random.sample(range(years), years)
        random_years = random_years.astype(np.int64)
        # print(random_years.shape)

        # LOSS
        epoch_loss = 0
        for year in range(random_years.shape[1]):

            #Get instance for each node
            node_data = data[np.arange(nodes), random_years[:, year]]
            # print(node_data.shape)

            random_batches = random.sample(range(node_data.shape[0]),node_data.shape[0])
            for batch in range(math.ceil(nodes/batch_size)):

                optimizer.zero_grad()

                # GET BATCH DATA AND LABEL
                random_batch = random_batches[batch*batch_size:(batch+1)*batch_size]
                batch_data = torch.from_numpy(node_data[random_batch]).float().to(device)
                batch_dynamic_input = batch_data[:, :, dynamic_channels].float().to(device)
                batch_static_input = batch_data[:, :, static_channels].float().to(device)
#                 print(batch_dynamic_input.shape, "batch dybamic inout")
#                 print(batch_static_input.shape, "batch static inptu")
#                 print(batch_data.shape)
#                 print(output_channels)
#                 print(batch_data[:, :, output_channels])
                batch_label = batch_data[:, :, output_channels].float().to(device)
                # print(batch_dynamic_input.shape, batch_static_input.shape, batch_label.shape)

                # GET OUTPUT
                batch_pred = model(x_dynamic=batch_dynamic_input, x_static=batch_static_input)
#                 print(batch_pred.shape, "This is batch pred shape")

                # CALCULATE LOSS
                batch_loss = criterion(batch_label, batch_pred)											# PER CHANNEL LOSS
                mask = (batch_label!=unknown).float()													# CREATE MASK
                batch_loss = batch_loss * mask															# MULTIPLY MASK
                batch_loss, mask = torch.sum(batch_loss, dim=2), (torch.sum(mask, dim=2)>0).float()		# PER INSTANCE LOSS
                batch_loss = torch.sum(batch_loss)/torch.sum(mask)										# MEAN SEQUENCE LOSS
                # print(batch_loss.shape)

                # LOSS BACKPROPOGATE
                batch_loss.backward()
                optimizer.step()

                # AGGREGATE LOSS
                epoch_loss += batch_loss.item()

        epoch_loss /= ((batch+1)*(year+1))
        print('Epoch:{}\tTrain Loss:{:.4f}'.format(epoch, epoch_loss), end="\t")
        train_loss.append(epoch_loss)

        # SCORE ON VALIDATION SET
        model.eval()

        # LOAD DATA
        data = data[:,int(0.8*years):,:,:]      
        nodes, years, window, channels = data.shape
        
#         file, index = "strided_valid", "in_indices"
#         dataset = load_dataset(file)
#         data = get_data(dataset)
#         nodes, years, window, channels = data.shape
        # print(nodes, years, window, channels)

        # SCORE
        epoch_loss = 0
        for year in range(years):

            #Get instance for each node
            node_data = data[np.arange(nodes), year]
            # print(node_data.shape)

            for batch in range(math.ceil(nodes/batch_size)):

                # GET BATCH DATA AND LABEL
                batch_data = torch.from_numpy(node_data[batch*batch_size:(batch+1)*batch_size]).float().to(device)
                batch_dynamic_input = batch_data[:, :, dynamic_channels].float().to(device)
                batch_static_input = batch_data[:, :, static_channels].float().to(device)
                batch_label = batch_data[:, :, output_channels].float().to(device)
                # print(batch_dynamic_input.shape, batch_static_input.shape, batch_label.shape)

                # GET OUTPUTbatch static inptu
                batch_pred = model(x_dynamic=batch_dynamic_input, x_static=batch_static_input)
                # print(batch_pred.shape)

                # CALCULATE LOSS
                batch_loss = criterion(batch_label, batch_pred)											# PER CHANNEL LOSS
                mask = (batch_label!=unknown).float()													# CREATE MASK
                batch_loss = batch_loss * mask															# MULTIPLY MASK
                batch_loss, mask = torch.sum(batch_loss, dim=2), (torch.sum(mask, dim=2)>0).float()		# PER SEQUENCE LOSS
                batch_loss = torch.sum(batch_loss)/torch.sum(mask)										# MEAN SEQUENCE LOSS
                # print(batch_loss.shape)

                # AGGREGATE LOSSbut we need fp
                epoch_loss += batch_loss.item()

        epoch_loss /= ((batch+1)*(year+1))
        print("Val Loss:{:.4f}\tMin Loss:{:.4f}".format(epoch_loss, min_loss), end="\t")
        valid_loss.append(epoch_loss)
        if min_loss>epoch_loss:
            min_loss = epoch_loss
            torch.save(model.state_dict(), os.path.join(MODEL_DIR, "{}".format(model_name)))
        end = time.time()
        print("Time:{:.4f}".format(end-start))

    # PLOT LOSS
    fig = plt.figure(figsize=(10,10))
    ax1 = fig.add_subplot(111)
    ax1.set_xlabel("#Epoch", fontsize=50)

    # PLOT TRAIN LOSS
    lns1 = ax1.plot(train_loss, color='red', marker='o', linewidth=4, label="TRAIN LOSS")

    # PLOT VALIDATION SCORE
    ax2 = ax1.twinx()
    lns2 = ax2.plot(valid_loss, color='blue', marker='o', linewidth=4, label="VAL LOSS")

    # added these three lines
    lns = lns1+lns2
    labs = [l.get_label() for l in lns]
    ax1.legend(lns, labs, loc="upper right", fontsize=40, frameon=False)

    plt.tight_layout(pad=0.0,h_pad=0.0,w_pad=0.0)
    plt.savefig(os.path.join(RESULT_DIR, "{}_SCORE.pdf".format(model_name)), format = "pdf")
    plt.close()

#Parameters:289025
ctlstm(
  (encoder): LSTM(24, 256, batch_first=True)
  (out): Linear(in_features=256, out_features=1, bias=True)
  (dropout): Dropout(p=0.4, inplace=False)
)
(145, 533, 30, 26)
145 533 30 26
Epoch:1	Train Loss:0.3729	Val Loss:0.3785	Min Loss:10000.0000	Time:6.4336


# TEST MODEL

## IN DISTRIBUTION

In [40]:
# print("IN\tfold:{}\tinit:{}".format(fold, init))

# BUILD MODEL
model = getattr(MODEL, "ctlstm")(input_dynamic_channels=len(dynamic_channels), input_static_channels=len(static_channels), hidden_dim=forward_code_dim, output_channels=len(output_channels), dropout=dropout)
model = model.to(device)
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print("#Parameters:{}".format(pytorch_total_params))
# print(model)

# LOAD MODEL
model.load_state_dict(torch.load(os.path.join(MODEL_DIR, "{}".format(model_name))))
model.eval()

# LOAD DATA
file, index = "strided_test", "in_indices"
dataset = load_dataset(file)
#do a for loop for each test site: 1, x, 30, 26 where x is the sliding windows
data = get_data_test(dataset)
# print(data.shape)
dataset_true_final = []
dataset_pred_final = []
# dataset_metaflux_final = []

dataset_true_final_data = []
dataset_pred_final_data = []
# dataset_metaflux_final_data = []
for data_value in data:
#     print(data_value.shape)
    data_value = np.expand_dims(data_value, axis=0)
#     print(data_value.shape)
    nodes, years, window, channels = data_value.shape
#     print(nodes, years, window, channels)

    dataset_true = unknown*np.ones((nodes, years, stride, len(output_channels)), dtype=np.float32)
#     print("this is the dtat true shape")
#     print(dataset_true.shape)
    dataset_pred = unknown*np.ones((nodes, years, stride, len(output_channels)), dtype=np.float32)
#     dataset_metaflux = unknown*np.ones((nodes, years, stride, len(metaflux_channels)), dtype=np.float32)
    for year in range(years):

        #Get instance for each node
        node_data = data_value[np.arange(nodes), year]
        node_data = np.array(node_data, dtype=np.float32)
#         print("this is node data shape")
#         print(node_data.shape)

        for batch in range(math.ceil(nodes/batch_size)):

            # GET BATCH DATA AND LABEL
            batch_data = torch.from_numpy(node_data[batch*batch_size:(batch+1)*batch_size]).float().to(device)
            batch_dynamic_input = batch_data[:, :, dynamic_channels].float().to(device)
            batch_static_input = batch_data[:, :, static_channels].float().to(device)
            batch_label = batch_data[:, :, output_channels].float().to(device)   #this is the true value
#             batch_metaflux = batch_data[:, :, metaflux_channels].float().to(device)
#             print(batch_dynamic_input.shape, batch_static_input.shape, batch_label.shape)
            
        
            # GET OUTPUT
            batch_pred = model(x_dynamic=batch_dynamic_input, x_static=batch_static_input)    #this is the pred value 
            print("this is batch pred shape")
            print(batch_pred.shape)
            
            #Slicing the batch_pred and batch_label into two halves and taking the second half.
            half_idx = stride
            batch_pred_second_half = batch_pred[:, -half_idx:]
            batch_label_second_half = batch_label[:, -half_idx:]
#             batch_metaflux_second_half = batch_metaflux[:, -half_idx:]
#             print("this is the shape of the half one")
#             print(batch_label_second_half.shape)

            # STORE OUTPUT
            dataset_true[batch*batch_size:(batch+1)*batch_size, year] = batch_label_second_half.detach().cpu().numpy()  #this is the true value
            dataset_pred[batch*batch_size:(batch+1)*batch_size, year] = batch_pred_second_half.detach().cpu().numpy()    #this is the predicted value
#             dataset_metaflux[batch*batch_size:(batch+1)*batch_size, year] = batch_metaflux_second_half.detach().cpu().numpy()
#     print("dtaset true")
#     print(dataset_true.shape)
#     print("dtaset pred")
#     print(dataset_pred.shape)
    dataset_true_final.append(dataset_true)
    dataset_pred_final.append(dataset_pred)
    
# print("thos one")
# print(dataset_true_final[0].shape)
# print(dataset_true_final[1].shape)
#dataset_true_final = np.concatenate(dataset_true_final, axis=1)
# print("THIS IS THE SHAPE OF THE DATASET TRUE FINAL")
# print(dataset_true_final.shape)

#dataset_pred_final = np.concatenate(dataset_pred_final, axis=1)
# print("THIS IS THE SHAPE OF THE DATASET PRED FINAL")
# print(dataset_pred_final.shape)

# for true_final in dataset_true_final:
#     dataset_true_final_data.append((true_final*dataset["train_data_stds"][output_channels])+dataset["train_data_means"][output_channels])
    
# for pred_final in dataset_pred_final:
#     dataset_pred_final_data.append((pred_final*dataset["train_data_stds"][output_channels])+dataset["train_data_means"][output_channels])
    
# for metaflux_final in dataset_metaflux_final:
#     dataset_metaflux_final_data.append((metaflux_final*dataset["train_data_stds"][metaflux_channels])+dataset["train_data_means"][metaflux_channels])
# dataset_pred_final_data = np.array(dataset_pred_final_data)
# print("shape1")
# print(dataset_pred_final_data.shape)
# print(dataset_true_final_data.shape)
# print("shape2")
# # print((UTILS.unstride_array(dataset_true_final)).shape)

# print(dataset_true_final.shape)
# print(dataset_pred_final.shape)

# for value in dataset_true_final_data:
#     print("one")
#     print(value.shape)
#     dataset_true_final = unstride_array(value)
#     print("two")
#     print(dataset_true_final.shape)
final_output_true = []
for value in dataset_true_final:
#     print(value.shape)
    value = np.reshape(value, (value.shape[0], value.shape[1] * value.shape[2], value.shape[3]))
    print(value.shape)
    final_output_true.append(value)
    
final_output_pred = []
for value in dataset_pred_final:
#     print(value.shape)
    value = np.reshape(value, (value.shape[0], value.shape[1] * value.shape[2], value.shape[3]))
    print(value.shape)
    final_output_pred.append(value)
    

# per_sample_RMSE = UTILS.per_sample_RMSE(dataset_true_final, dataset_pred_final, unknown)
# _, per_node_RMSE = UTILS.per_node_RMSE(dataset_true_final, dataset_pred_final, unknown)
# per_sample_R2 = UTILS.per_sample_R2(dataset_true_final, dataset_pred_final, unknown)
# _, per_node_R2 = UTILS.per_node_R2(dataset_true_final, dataset_pred_final, unknown)
# print("Per Sample RMSE:{:.4f}\tPer Node RMSE:{:.4f}\tPer Sample R2:{:.4f}\tPer Node R2:{:.4f}".format(per_sample_RMSE, per_node_RMSE, per_sample_R2, per_node_R2))
# with open(os.path.join(RESULT_DIR, "{}_{}_{}".format(file, index, "ameriflux_true_{}".format(fold))), 'wb') as f:
#     pickle.dump(final_output_true, f)
# with open(os.path.join(RESULT_DIR, "{}_{}_{}_{}".format(file, index, "ameriflux_pred", model_name)), 'wb') as f:
#     pickle.dump(final_output_pred, f)
    
with open(os.path.join(RESULT_DIR, "{}_{}_{}".format(file, index, "true_{}".format(fold))), 'wb') as f:
    pickle.dump(final_output_true, f)
with open(os.path.join(RESULT_DIR, "{}_{}_{}".format(file, index, model_name)), 'wb') as f:
    pickle.dump(final_output_pred, f)

# np.save(os.path.join(RESULT_DIR, "{}_{}_{}".format(file, index, "ameriflux_true_{}".format(fold))), dataset_true_final)
# np.save(os.path.join(RESULT_DIR, "{}_{}_{}_{}".format(file, index, "ameriflux_pred", model_name)), dataset_pred_final)
# np.save(os.path.join(RESULT_DIR, "{}_{}_{}".format(file, index, "metaflux_{}".format(fold))), dataset_metaflux_final)

shape1
Per Sample RMSE:18.9865	Per Node RMSE:18.9865	Per Sample R2:-0.1181	Per Node R2:-0.1181


## OUT DISTRIBUTION

In [None]:
# print("IN\tfold:{}\tinit:{}".format(fold, init))

# BUILD MODEL
model = getattr(MODEL, "ctlstm")(input_dynamic_channels=len(dynamic_channels), input_static_channels=len(static_channels), hidden_dim=forward_code_dim, output_channels=len(output_channels), dropout=dropout)
model = model.to(device)
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print("#Parameters:{}".format(pytorch_total_params))
# print(model)

# LOAD MODEL
model.load_state_dict(torch.load(os.path.join(MODEL_DIR, "{}".format(model_name))))
model.eval()

# LOAD DATA
file, index = "strided_test", "out_indices"
dataset = load_dataset(file)
#do a for loop for each test site: 1, x, 30, 26 where x is the sliding windows
data = get_data_test(dataset)
# print(data.shape)
dataset_true_final = []
dataset_pred_final = []
# dataset_metaflux_final = []

dataset_true_final_data = []
dataset_pred_final_data = []
# dataset_metaflux_final_data = []
for data_value in data:
#     print(data_value.shape)
    data_value = np.expand_dims(data_value, axis=0)
#     print(data_value.shape)
    nodes, years, window, channels = data_value.shape
#     print(nodes, years, window, channels)

    dataset_true = unknown*np.ones((nodes, years, stride, len(output_channels)), dtype=np.float32)
#     print("this is the dtat true shape")
#     print(dataset_true.shape)
    dataset_pred = unknown*np.ones((nodes, years, stride, len(output_channels)), dtype=np.float32)
#     dataset_metaflux = unknown*np.ones((nodes, years, stride, len(metaflux_channels)), dtype=np.float32)
    for year in range(years):

        #Get instance for each node
        node_data = data_value[np.arange(nodes), year]
        node_data = np.array(node_data, dtype=np.float32)
#         print("this is node data shape")
#         print(node_data.shape)

        for batch in range(math.ceil(nodes/batch_size)):

            # GET BATCH DATA AND LABEL
            batch_data = torch.from_numpy(node_data[batch*batch_size:(batch+1)*batch_size]).float().to(device)
            batch_dynamic_input = batch_data[:, :, dynamic_channels].float().to(device)
            batch_static_input = batch_data[:, :, static_channels].float().to(device)
            batch_label = batch_data[:, :, output_channels].float().to(device)   #this is the true value
#             batch_metaflux = batch_data[:, :, metaflux_channels].float().to(device)
#             print(batch_dynamic_input.shape, batch_static_input.shape, batch_label.shape)
            
        
            # GET OUTPUT
            batch_pred = model(x_dynamic=batch_dynamic_input, x_static=batch_static_input)    #this is the pred value 
            print("this is batch pred shape")
            print(batch_pred.shape)
            
            #Slicing the batch_pred and batch_label into two halves and taking the second half.
            half_idx = stride
            batch_pred_second_half = batch_pred[:, -half_idx:]
            batch_label_second_half = batch_label[:, -half_idx:]
#             batch_metaflux_second_half = batch_metaflux[:, -half_idx:]
#             print("this is the shape of the half one")
#             print(batch_label_second_half.shape)

            # STORE OUTPUT
            dataset_true[batch*batch_size:(batch+1)*batch_size, year] = batch_label_second_half.detach().cpu().numpy()  #this is the true value
            dataset_pred[batch*batch_size:(batch+1)*batch_size, year] = batch_pred_second_half.detach().cpu().numpy()    #this is the predicted value
#             dataset_metaflux[batch*batch_size:(batch+1)*batch_size, year] = batch_metaflux_second_half.detach().cpu().numpy()
#     print("dtaset true")
#     print(dataset_true.shape)
#     print("dtaset pred")
#     print(dataset_pred.shape)
    dataset_true_final.append(dataset_true)
    dataset_pred_final.append(dataset_pred)
    
# print("thos one")
# print(dataset_true_final[0].shape)
# print(dataset_true_final[1].shape)
#dataset_true_final = np.concatenate(dataset_true_final, axis=1)
# print("THIS IS THE SHAPE OF THE DATASET TRUE FINAL")
# print(dataset_true_final.shape)

#dataset_pred_final = np.concatenate(dataset_pred_final, axis=1)
# print("THIS IS THE SHAPE OF THE DATASET PRED FINAL")
# print(dataset_pred_final.shape)

# for true_final in dataset_true_final:
#     dataset_true_final_data.append((true_final*dataset["train_data_stds"][output_channels])+dataset["train_data_means"][output_channels])
    
# for pred_final in dataset_pred_final:
#     dataset_pred_final_data.append((pred_final*dataset["train_data_stds"][output_channels])+dataset["train_data_means"][output_channels])
    
# for metaflux_final in dataset_metaflux_final:
#     dataset_metaflux_final_data.append((metaflux_final*dataset["train_data_stds"][metaflux_channels])+dataset["train_data_means"][metaflux_channels])
# dataset_pred_final_data = np.array(dataset_pred_final_data)
# print("shape1")
# print(dataset_pred_final_data.shape)
# print(dataset_true_final_data.shape)
# print("shape2")
# # print((UTILS.unstride_array(dataset_true_final)).shape)

# print(dataset_true_final.shape)
# print(dataset_pred_final.shape)

# for value in dataset_true_final_data:
#     print("one")
#     print(value.shape)
#     dataset_true_final = unstride_array(value)
#     print("two")
#     print(dataset_true_final.shape)
final_output_true = []
for value in dataset_true_final:
#     print(value.shape)
    value = np.reshape(value, (value.shape[0], value.shape[1] * value.shape[2], value.shape[3]))
    print(value.shape)
    final_output_true.append(value)
    
final_output_pred = []
for value in dataset_pred_final:
#     print(value.shape)
    value = np.reshape(value, (value.shape[0], value.shape[1] * value.shape[2], value.shape[3]))
    print(value.shape)
    final_output_pred.append(value)
    

# per_sample_RMSE = UTILS.per_sample_RMSE(dataset_true_final, dataset_pred_final, unknown)
# _, per_node_RMSE = UTILS.per_node_RMSE(dataset_true_final, dataset_pred_final, unknown)
# per_sample_R2 = UTILS.per_sample_R2(dataset_true_final, dataset_pred_final, unknown)
# _, per_node_R2 = UTILS.per_node_R2(dataset_true_final, dataset_pred_final, unknown)
# print("Per Sample RMSE:{:.4f}\tPer Node RMSE:{:.4f}\tPer Sample R2:{:.4f}\tPer Node R2:{:.4f}".format(per_sample_RMSE, per_node_RMSE, per_sample_R2, per_node_R2))
# with open(os.path.join(RESULT_DIR, "{}_{}_{}".format(file, index, "ameriflux_true_{}".format(fold))), 'wb') as f:
#     pickle.dump(final_output_true, f)
# with open(os.path.join(RESULT_DIR, "{}_{}_{}_{}".format(file, index, "ameriflux_pred", model_name)), 'wb') as f:
#     pickle.dump(final_output_pred, f)
    
with open(os.path.join(RESULT_DIR, "{}_{}_{}".format(file, index, "true_{}".format(fold))), 'wb') as f:
    pickle.dump(final_output_true, f)
with open(os.path.join(RESULT_DIR, "{}_{}_{}".format(file, index, model_name)), 'wb') as f:
    pickle.dump(final_output_pred, f)


# np.save(os.path.join(RESULT_DIR, "{}_{}_{}".format(file, index, "ameriflux_true_{}".format(fold))), dataset_true_final)
# np.save(os.path.join(RESULT_DIR, "{}_{}_{}_{}".format(file, index,"ameriflux_pred", model_name)), dataset_pred_final)
# np.save(os.path.join(RESULT_DIR, "{}_{}_{}".format(file, index, "metaflux_{}".format(fold))), dataset_metaflux_final)