# IMPORT LIBRARIES

In [7]:
from tqdm import *
import sys
sys.path.append("../")

import config2 as config
import MODEL
import UTILS

import os
import time
import argparse
import math
import numpy as np
import pandas as pd
from collections import Counter
import random

import torch
from torch.utils.data.dataset import Dataset
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error

torch.manual_seed(0)

<torch._C.Generator at 0x2b24132754f0>

# HYPERPARAMETERS

In [8]:
# Command-line interface for the run. The defaults are the values this
# notebook was originally hard-wired to (init 0, fold 0, 7 few-shot years,
# model "ctlstm", date "20230420").
parser = argparse.ArgumentParser(
    description='arguments')

# Define your command-line arguments
parser.add_argument('--init', type=int, default=0, help='init number')
parser.add_argument('--fold', type=int, default=0, help='fold number')
parser.add_argument('--few_shot_years', type=int, default=7, help='few_shot_years')
parser.add_argument('--model_name', type=str, default='ctlstm', help='model_name')
parser.add_argument('--date', type=str, default='20230420', help='date')

# parse_known_args() reads sys.argv but, unlike parse_args(), does not abort on
# options this parser does not define. That keeps the cell working inside a
# Jupyter kernel (which is started with an extra "-f <connection file>" pair)
# while still honouring --init/--fold/... when the code is run as a script.
# Unrecognised arguments are collected in the second return value and ignored.
args, _ignored_argv = parser.parse_known_args()

# Now you can access the parsed arguments using args.init, args.fold, etc.
print(args)


Namespace(init=0, fold=0, few_shot_years=7, model_name='ctlstm', date='20230420')


In [9]:
# Settings for the run. Most values come from the `config` module; the model
# name, init index, fold and number of few-shot years come from the parsed
# command-line arguments; the inner-loop step count and early-stopping
# patience are fixed here.

# TIME SERIES INFO
window = config.window
stride = config.stride

# CHANNELS INFO (used below to index the last axis of the data arrays)
dynamic_channels = config.dynamic_channels
static_channels = config.static_channels
output_channels = config.output_channels

# LABELS INFO: sentinel value that stands in for missing entries
unknown = config.unknown

# MODEL INFO
model_name = args.model_name
forward_code_dim = config.forward_code_dim
latent_code_dim = config.latent_code_dim
device = torch.device(config.device)
dropout = config.dropout

# TRAIN INFO
train = config.train
batch_size = config.batch_size
epochs = config.epochs
learning_rate = config.learning_rate            # step size of the inner adaptation updates
meta_learning_rate = config.meta_learning_rate  # learning rate handed to the Adam optimizers
num_inner_steps = 5                             # adaptation steps taken per basin
init = args.init
fold = args.fold
max_patience = 10                               # epochs without improvement tolerated before stopping
few_shot_years = args.few_shot_years

# Echo the configuration, one "name : value" line per setting.
print("Hyperparameters:{}".format(model_name))
print("\n".join(
    "{} : {}".format(setting_name, setting_value)
    for setting_name, setting_value in (
        ("window", window),
        ("stride", stride),
        ("dynamic_channels", dynamic_channels),
        ("static_channels", static_channels),
        ("output_channels", output_channels),
        ("unknown", unknown),
        ("model_name", model_name),
        ("forward_code_dim", forward_code_dim),
        ("latent_code_dim", latent_code_dim),
        ("device", device),
        ("dropout", dropout),
        ("train", train),
        ("batch_size", batch_size),
        ("epochs", epochs),
        ("learning_rate", learning_rate),
        ("meta_learning_rate", meta_learning_rate),
        ("num_inner_steps", num_inner_steps),
        ("init", init),
        ("fold", fold),
        ("max_patience", max_patience),
        ("few_shot_years", few_shot_years),
    )
))

AttributeError: module 'config2' has no attribute 'static_channels'

# DEFINE DIRECTORIES

In [None]:
# Results and model checkpoints are grouped under a sub-folder named after the
# run date; the preprocessed inputs are read straight from the configured folder.
DATE = args.date
PREPROCESSED_DIR = config.PREPROCESSED_DIR
RESULT_DIR, MODEL_DIR = (
    os.path.join(base_dir, DATE) for base_dir in (config.RESULT_DIR, config.MODEL_DIR)
)

# LOAD DATA

In [None]:
def load_dataset(file):
    """Load ``<PREPROCESSED_DIR>/<file>.npz`` with NumPy and return the result.

    Args:
        file: Archive name without the ``.npz`` extension.

    Returns:
        Whatever ``np.load`` returns for that path; callers index it by array name.

    Note:
        The archive is opened with ``allow_pickle=True``, so it must only be
        pointed at trusted files.
    """
    archive_path = os.path.join(PREPROCESSED_DIR, f"{file}.npz")
    return np.load(archive_path, allow_pickle=True)

def get_data(dataset, index, preprocessed=True,fold=0):
    """Return the rows of ``dataset["data"]`` selected for one fold, optionally normalized.

    Args:
        dataset: Mapping with a ``"data"`` array indexed with four axes (feature
            last), ``"train_data_means"`` / ``"train_data_stds"`` indexed by
            fold, and an entry ``index`` that is also indexed by fold.
        index: Key of the per-fold row selector, e.g. ``"in_indices"``.
        preprocessed: When true, z-score each feature with the fold's training
            mean and standard deviation.
        fold: Which fold's statistics and row selector to use.

    Returns:
        ``data[dataset[index][fold]]`` after normalization and after every NaN
        has been replaced by the module-level ``unknown`` sentinel.
    """
    values = dataset["data"]
    if preprocessed:
        fold_mean = dataset["train_data_means"][fold]
        fold_std = dataset["train_data_stds"][fold]

        scaled = np.zeros_like(values)
        for channel in range(fold_mean.shape[0]):
            channel_values = values[:, :, :, channel]
            if fold_std[channel] == 0:
                # A constant feature cannot be scaled; pass it through unchanged.
                scaled[:, :, :, channel] = channel_values
                continue
            scaled[:, :, :, channel] = (channel_values - fold_mean[channel]) / fold_std[channel]
        values = scaled

    # Missing values are encoded with the sentinel only after normalization.
    values = np.nan_to_num(values, nan=unknown)
    return values[dataset[index][fold]]



# TRAIN MODEL

In [None]:
def _masked_sequence_loss(criterion, label, pred):
    """Return the loss between ``label`` and ``pred``, ignoring missing labels.

    ``criterion`` must return an element-wise loss. Entries whose label equals
    the module-level ``unknown`` sentinel are zeroed out, the remaining losses
    are summed, and the total is divided by the number of positions along the
    first two axes that have at least one known channel.
    """
    loss = criterion(label, pred)                   # PER CHANNEL LOSS
    mask = (label != unknown).float()               # CREATE MASK
    loss = loss * mask                              # MULTIPLY MASK
    loss = torch.sum(loss, dim=2)                   # PER INSTANCE LOSS
    mask = (torch.sum(mask, dim=2) > 0).float()
    return torch.sum(loss) / torch.sum(mask)        # MEAN SEQUENCE LOSS


def _basin_query_loss(inverse_model, forward_model, criterion, basin_support_data, basin_query_data, window):
    """Adapt ``forward_model`` to one basin and return its loss on the query set.

    Steps:
      1. ``inverse_model`` encodes the basin's support inputs+outputs; the
         codes are averaged over the first axis into one embedding.
      2. That embedding is repeated over every sample and ``window`` time steps
         and used as the static input of ``forward_model``.
      3. ``forward_model`` takes ``num_inner_steps`` gradient steps of size
         ``learning_rate`` on the support loss (weights overwritten in place).
      4. The query loss is computed with the adapted weights.
      5. ``forward_model`` is restored from a snapshot taken before step 3.

    Returns:
        The scalar query-loss tensor (still attached to the autograd graph).
    """
    basin_dynamic_support_input = basin_support_data[:, :, dynamic_channels].to(device)
    basin_dynamic_support_input_output = basin_support_data[:, :, dynamic_channels+output_channels].to(device)

    # One embedding per basin: mean of the encoder codes over the support samples.
    basin_static_data,_,_,_ = inverse_model(x=basin_dynamic_support_input_output)
    basin_static_data = torch.mean(basin_static_data,axis=0)
    basin_static_support_input = torch.repeat_interleave(basin_static_data.unsqueeze(0),basin_dynamic_support_input.shape[0],axis=0).to(device)
    basin_static_support_input = torch.repeat_interleave(basin_static_support_input.unsqueeze(1),window,axis=1).to(device)
    basin_support_label = basin_support_data[:, :, output_channels].to(device)

    basin_dynamic_query_input = basin_query_data[:, :, dynamic_channels].to(device)
    basin_static_query_input = torch.repeat_interleave(basin_static_data.unsqueeze(0),basin_dynamic_query_input.shape[0],axis=0).to(device)
    basin_static_query_input = torch.repeat_interleave(basin_static_query_input.unsqueeze(1),window,axis=1).to(device)
    basin_query_label = basin_query_data[:, :, output_channels].to(device)

    # Snapshot of the current meta-weights, used to undo the adaptation below.
    basin_model = getattr(MODEL, "tamlstm")(input_dynamic_channels=len(dynamic_channels), input_static_channels=latent_code_dim, hidden_dim=forward_code_dim, output_channels=len(output_channels), dropout=dropout)
    basin_model = basin_model.to(device)
    basin_model.load_state_dict(forward_model.state_dict())

    # INNER LOOP: adapt forward_model to this basin on the support set.
    # NOTE(review): the weights are overwritten through `param.data`, so the
    # outer backward pass sees the adapted weights as leaves — confirm this
    # first-order behaviour is what is intended.
    for _ in range(num_inner_steps):
        batch_pred = forward_model(x_dynamic=basin_dynamic_support_input, x_static=basin_static_support_input.float() )
        basin_loss = _masked_sequence_loss(criterion, basin_support_label, batch_pred)

        grad = torch.autograd.grad(basin_loss, forward_model.parameters(),create_graph=True, allow_unused=True)
        fast_weights = list(map(lambda p: p[1] - learning_rate * p[0], zip(grad, forward_model.parameters())))
        for param, fast_params in zip(forward_model.parameters(), fast_weights):
            param.data = fast_params

    # Query loss with the adapted weights.
    batch_pred = forward_model(x_dynamic=basin_dynamic_query_input, x_static=basin_static_query_input.float() )
    basin_loss = _masked_sequence_loss(criterion, basin_query_label, batch_pred)

    # RESET MODEL FOR NEXT TASK
    forward_model.load_state_dict(basin_model.state_dict())
    return basin_loss


if train:
    print("fold:{}\tinit:{}".format(fold, init))

    # Checkpoints and the loss plot are written below; make sure the target
    # folders exist so a long run does not fail at save time.
    os.makedirs(MODEL_DIR, exist_ok=True)
    os.makedirs(RESULT_DIR, exist_ok=True)

    # BUILD MODEL
    inverse_model = getattr(MODEL, "ae")(input_channels=len(dynamic_channels)+len(output_channels), code_dim=latent_code_dim, hidden_dim=latent_code_dim, output_channels=len(static_channels), device=device)
    inverse_model = inverse_model.to(device)
    pytorch_total_params = sum(p.numel() for p in inverse_model.parameters() if p.requires_grad)
    print(inverse_model)
    forward_model = getattr(MODEL, "tamlstm")(input_dynamic_channels=len(dynamic_channels), input_static_channels=latent_code_dim, hidden_dim=forward_code_dim, output_channels=len(output_channels), dropout=dropout)
    forward_model = forward_model.to(device)
    pytorch_total_params += sum(p.numel() for p in forward_model.parameters() if p.requires_grad)
    print(forward_model)
    print("#Parameters:{}".format(pytorch_total_params))
    criterion = torch.nn.MSELoss(reduction="none")
    optimizer_embedding = torch.optim.Adam(list(inverse_model.parameters()), lr=meta_learning_rate)
    optimizer_forward = torch.optim.Adam(list(forward_model.parameters()), lr=meta_learning_rate)

    train_loss = []
    valid_loss = []
    min_loss = 10000
    patience = 0

    for epoch in range(1,epochs+1):

        start = time.time()

        # ---------------- LOSS ON TRAIN SET ----------------
        inverse_model.train()
        forward_model.train()

        # LOAD DATA
        file, index = "strided_train", "in_indices"
        dataset = load_dataset(file)
        data = get_data(dataset, index,fold=fold)
        nodes, years, window, channels = data.shape

        # Reseed NumPy every epoch (kept from the original cell).
        np.random.seed(0)

        epoch_loss = 0

        # Get instance for each node
        node_data = data[np.arange(nodes)]

        # Visit the basins in a fresh random order, batch_size at a time.
        random_batches = random.sample(range(node_data.shape[0]),node_data.shape[0])
        for batch in range(math.ceil(nodes/batch_size)):
            batch_loss = []

            random_batch = random_batches[batch*batch_size:(batch+1)*batch_size]
            batch_data = torch.from_numpy(node_data[random_batch]).to(device)

            # GET BATCH DATA AND LABEL
            batch_support_data, batch_query_data = UTILS.datsetSupportQuerry(batch_data)

            for i in range(batch_support_data.shape[0]):
                batch_loss.append(_basin_query_loss(inverse_model, forward_model, criterion, batch_support_data[i], batch_query_data[i], window))

            # OUTER UPDATE: mean query loss over the basins of this batch.
            batch_loss = torch.stack(batch_loss).mean(0)
            optimizer_forward.zero_grad()
            optimizer_embedding.zero_grad()
            batch_loss.backward()
            optimizer_embedding.step()
            optimizer_forward.step()

            # AGGREGATE LOSS
            epoch_loss += batch_loss.item()

        epoch_loss /= ((batch+1))
        print('Epoch:{}\tTrain Loss:{:.4f}'.format(epoch, epoch_loss), end="\t")
        train_loss.append(epoch_loss)

        # ---------------- SCORE ON VALIDATION SET ----------------
        # The per-basin adaptation still takes gradient steps here, with both
        # models in eval mode; no optimizer step is taken.
        inverse_model.eval()
        forward_model.eval()

        # LOAD DATA
        file, index = "strided_valid", "in_indices"
        dataset = load_dataset(file)
        data = get_data(dataset, index,fold=fold)
        nodes, years, window, channels = data.shape

        epoch_loss = 0

        # Get instance for each node
        node_data = data[np.arange(nodes)]

        random_batches = random.sample(range(node_data.shape[0]),node_data.shape[0])
        for batch in range(math.ceil(nodes/batch_size)):
            batch_loss = []

            random_batch = random_batches[batch*batch_size:(batch+1)*batch_size]
            batch_data = torch.from_numpy(node_data[random_batch]).to(device)

            # GET BATCH DATA AND LABEL
            batch_support_data, batch_query_data = UTILS.datsetSupportQuerry(batch_data)

            for i in range(batch_support_data.shape[0]):
                batch_loss.append(_basin_query_loss(inverse_model, forward_model, criterion, batch_support_data[i], batch_query_data[i], window))

            batch_loss = torch.stack(batch_loss).mean(0)
            # AGGREGATE LOSS
            epoch_loss += batch_loss.item()

        epoch_loss /= ((batch+1))
        print("Val Loss:{:.4f}\tMin Loss:{:.4f}\tPatience:{}".format(epoch_loss, min_loss, patience), end="\t")
        valid_loss.append(epoch_loss)

        # EARLY STOPPING: checkpoint on improvement, otherwise count a strike.
        if min_loss>epoch_loss:
            min_loss = epoch_loss
            torch.save(inverse_model.state_dict(), os.path.join(MODEL_DIR, "{}_inverse".format(model_name)))
            torch.save(forward_model.state_dict(), os.path.join(MODEL_DIR, "{}_forward".format(model_name)))

            patience = 0
        else:
            patience+=1
        if patience>max_patience:
            break
        end = time.time()
        print("Time:{:.4f}".format(end-start))

    # PLOT LOSS
    fig = plt.figure(figsize=(10,10))
    ax1 = fig.add_subplot(111)
    ax1.set_xlabel("#Epoch", fontsize=50)

    # PLOT TRAIN LOSS
    lns1 = ax1.plot(train_loss, color='red', marker='o', linewidth=4, label="TRAIN LOSS")

    # PLOT VALIDATION SCORE (on its own y-axis)
    ax2 = ax1.twinx()
    lns2 = ax2.plot(valid_loss, color='blue', marker='o', linewidth=4, label="VAL LOSS")

    # One shared legend for both axes.
    lns = lns1+lns2
    labs = [l.get_label() for l in lns]
    ax1.legend(lns, labs, loc="upper right", fontsize=40, frameon=False)

    plt.tight_layout(pad=0.0,h_pad=0.0,w_pad=0.0)
    plt.savefig(os.path.join(RESULT_DIR, "{}_SCORE.pdf".format(model_name)), format = "pdf")
    plt.close()

# TEST MODEL

## IN DISTRIBUTION

In [None]:
print("IN\tfold:{}\tinit:{}".format(fold, init))
start = time.time()

# BUILD MODEL
# Same two architectures as in training: the "ae" encoder that produces a
# latent basin embedding and the "tamlstm" forward model that consumes it.
inverse_model = getattr(MODEL, "ae")(input_channels=len(dynamic_channels)+len(output_channels), code_dim=latent_code_dim, hidden_dim=latent_code_dim, output_channels=len(static_channels), device=device)
inverse_model = inverse_model.to(device)
pytorch_total_params = sum(p.numel() for p in inverse_model.parameters() if p.requires_grad)
print(inverse_model)
forward_model = getattr(MODEL, "tamlstm")(input_dynamic_channels=len(dynamic_channels), input_static_channels=latent_code_dim, hidden_dim=forward_code_dim, output_channels=len(output_channels), dropout=dropout)
forward_model = forward_model.to(device)
pytorch_total_params += sum(p.numel() for p in forward_model.parameters() if p.requires_grad)
print(forward_model)
print("#Parameters:{}".format(pytorch_total_params))
criterion = torch.nn.MSELoss(reduction="none")
optimizer_embedding = torch.optim.Adam(list(inverse_model.parameters()), lr=meta_learning_rate)
optimizer_forward = torch.optim.Adam(list(forward_model.parameters()), lr=meta_learning_rate)

# LOAD MODEL
# Training saves the encoder as "<model_name>_inverse" and the forward model as
# "<model_name>_forward"; each network must be restored from its own file.
# map_location lets a checkpoint be loaded onto whichever device is configured.
inverse_model.load_state_dict(torch.load(os.path.join(MODEL_DIR, "{}_inverse".format(model_name)), map_location=device))
inverse_model.eval()
forward_model.load_state_dict(torch.load(os.path.join(MODEL_DIR, "{}_forward".format(model_name)), map_location=device))
forward_model.eval()

# LOAD Few shot DATA
# Support set: the training split restricted to its last (2*few_shot_years - 1)
# entries along axis 1.
file, index = "strided_train", "in_indices"
dataset = load_dataset(file)
few_shot_span = few_shot_years*2-1
data = get_data(dataset, index,fold=fold)[:, -few_shot_span:]
nodes, years, window, channels = data.shape

# LOAD DATA
# Query set: the test split. From here on `dataset`, `file`, `index` and the
# shape variables describe the test data.
file, index = "strided_test", "in_indices"
dataset = load_dataset(file)
data_test = get_data(dataset, index,fold=fold)
nodes, years, window, channels = data_test.shape

np.random.seed(0)

# Output buffers, pre-filled with the missing-value sentinel.
output_shape = (nodes, years, window, len(output_channels))
dataset_true = unknown*np.ones(output_shape, dtype=np.float32)
dataset_pred = unknown*np.ones(output_shape, dtype=np.float32)

# Get instance for each node and move both splits to the compute device.
node_selector = np.arange(nodes)
node_data_train = torch.from_numpy(data[node_selector]).to(device)
node_data_test = torch.from_numpy(data_test[node_selector]).to(device)

support_data, query_data = node_data_train,node_data_test

# print(support_data.shape)



# Per-basin evaluation: embed the basin from its support data, adapt the
# forward model on the support set, predict the query (test) set, then restore
# the meta-learned weights so the next basin starts from the same point.
for i in range(nodes):
    print(i)

    # GET Basin support and query data
    basin_support_data, basin_query_data = support_data[i], query_data[i]

    basin_dynamic_support_input = basin_support_data[:, :, dynamic_channels].to(device)
    basin_dynamic_support_input_output = basin_support_data[:, :, dynamic_channels+output_channels].to(device)

    # One embedding per basin: mean of the encoder codes over the support samples,
    # repeated over every sample and every time step of the window.
    basin_static_data,_,_,_ = inverse_model(x=basin_dynamic_support_input_output)
    basin_static_data = torch.mean(basin_static_data,axis=0)
    basin_static_support_input = torch.repeat_interleave(basin_static_data.unsqueeze(0),basin_dynamic_support_input.shape[0],axis=0).to(device)
    basin_static_support_input = torch.repeat_interleave(basin_static_support_input.unsqueeze(1),window,axis=1).to(device)
    basin_support_label = basin_support_data[:, :, output_channels].to(device)

    basin_dynamic_query_input = basin_query_data[:, :, dynamic_channels].to(device)
    basin_static_query_input = torch.repeat_interleave(basin_static_data.unsqueeze(0),basin_dynamic_query_input.shape[0],axis=0).to(device)
    basin_static_query_input = torch.repeat_interleave(basin_static_query_input.unsqueeze(1),window,axis=1).to(device)
    basin_query_label = basin_query_data[:, :, output_channels].to(device)

    # Snapshot of the loaded meta-weights, used to undo the adaptation below.
    basin_model = getattr(MODEL, "tamlstm")(input_dynamic_channels=len(dynamic_channels), input_static_channels=latent_code_dim, hidden_dim=forward_code_dim, output_channels=len(output_channels), dropout=dropout)
    basin_model = basin_model.to(device)
    basin_model.load_state_dict(forward_model.state_dict())

    for step in range(num_inner_steps):

        # GET OUTPUT
        batch_pred = forward_model(x_dynamic=basin_dynamic_support_input, x_static=basin_static_support_input.float() )

        # CALCULATE LOSS (labels equal to `unknown` are masked out)
        basin_loss = criterion(basin_support_label, batch_pred)                                   # PER CHANNEL LOSS
        mask = (basin_support_label!=unknown).float()                                             # CREATE MASK
        basin_loss = basin_loss * mask                                                            # MULTIPLY MASK
        basin_loss, mask = torch.sum(basin_loss, dim=2), (torch.sum(mask, dim=2)>0).float()       # PER INSTANCE LOSS
        basin_loss = torch.sum(basin_loss)/torch.sum(mask)                                        # MEAN SEQUENCE LOSS

        # INNER UPDATE: one gradient step, written straight into the weights.
        grad = torch.autograd.grad(basin_loss, forward_model.parameters(),create_graph=True, allow_unused=True)
        fast_weights = list(map(lambda p: p[1] - learning_rate * p[0], zip(grad, forward_model.parameters())))
        for param, fast_params in zip(forward_model.parameters(), fast_weights):
            param.data = fast_params

    # Predict the query set with the adapted weights.
    batch_pred = forward_model(x_dynamic=basin_dynamic_query_input,x_static=basin_static_query_input.float())
    batch_label = basin_query_label

    # STORE OUTPUT
    dataset_true[i] = batch_label.detach().cpu().numpy()
    dataset_pred[i] = batch_pred.detach().cpu().numpy()

    # RESET MODEL FOR NEXT TASK
    # Without this the adaptation to basin i would carry over into basin i+1;
    # the training and validation loops perform the same reset.
    forward_model.load_state_dict(basin_model.state_dict())

# Undo the z-score normalization of the output channels using the statistics
# stored with the most recently loaded dataset.
# NOTE(review): entries still holding the `unknown` sentinel are rescaled too,
# so they may no longer equal `unknown` when the metric helpers receive it —
# confirm how UTILS masks missing values.
target_std = dataset["train_data_stds"][fold][output_channels]
target_mean = dataset["train_data_means"][fold][output_channels]
dataset_true = (dataset_true*target_std)+target_mean
dataset_pred = (dataset_pred*target_std)+target_mean

# Collapse the strided windows and drop the first `stride` steps along axis 1.
dataset_true = UTILS.unstride_array(dataset_true)[:, stride:]
dataset_pred = UTILS.unstride_array(dataset_pred)[:, stride:]

# Scores, reported per sample and per node.
per_sample_RMSE = UTILS.per_sample_RMSE(dataset_true, dataset_pred, unknown)
_, per_node_RMSE = UTILS.per_node_RMSE(dataset_true, dataset_pred, unknown)
per_sample_R2 = UTILS.per_sample_R2(dataset_true, dataset_pred, unknown)
_, per_node_R2 = UTILS.per_node_R2(dataset_true, dataset_pred, unknown)
print("Per Sample RMSE:{:.4f}\tPer Node RMSE:{:.4f}\tPer Sample R2:{:.4f}\tPer Node R2:{:.4f}".format(per_sample_RMSE, per_node_RMSE, per_sample_R2, per_node_R2))

# Persist ground truth (keyed by fold) and predictions (keyed by the few-shot
# budget and model name) in the results folder.
true_path = os.path.join(RESULT_DIR, "{}_{}_{}".format(file, index, "true_{}".format(fold)))
pred_path = os.path.join(RESULT_DIR, "{}_{}_{}_{}".format(file, index, few_shot_years, model_name))
np.save(true_path, dataset_true)
np.save(pred_path, dataset_pred)


end = time.time()
print("Time:{:.4f}".format(end-start))

In [None]:
print("OUT\tfold:{}\tinit:{}".format(fold, init))
start = time.time()

# BUILD MODEL
# Same two architectures as in training: the "ae" encoder that produces a
# latent basin embedding and the "tamlstm" forward model that consumes it.
inverse_model = getattr(MODEL, "ae")(input_channels=len(dynamic_channels)+len(output_channels), code_dim=latent_code_dim, hidden_dim=latent_code_dim, output_channels=len(static_channels), device=device)
inverse_model = inverse_model.to(device)
pytorch_total_params = sum(p.numel() for p in inverse_model.parameters() if p.requires_grad)
print(inverse_model)
forward_model = getattr(MODEL, "tamlstm")(input_dynamic_channels=len(dynamic_channels), input_static_channels=latent_code_dim, hidden_dim=forward_code_dim, output_channels=len(output_channels), dropout=dropout)
forward_model = forward_model.to(device)
pytorch_total_params += sum(p.numel() for p in forward_model.parameters() if p.requires_grad)
print(forward_model)
print("#Parameters:{}".format(pytorch_total_params))
criterion = torch.nn.MSELoss(reduction="none")
optimizer_embedding = torch.optim.Adam(list(inverse_model.parameters()), lr=meta_learning_rate)
optimizer_forward = torch.optim.Adam(list(forward_model.parameters()), lr=meta_learning_rate)

# LOAD MODEL
# Training saves the encoder as "<model_name>_inverse" and the forward model as
# "<model_name>_forward"; each network must be restored from its own file.
# map_location lets a checkpoint be loaded onto whichever device is configured.
inverse_model.load_state_dict(torch.load(os.path.join(MODEL_DIR, "{}_inverse".format(model_name)), map_location=device))
inverse_model.eval()
forward_model.load_state_dict(torch.load(os.path.join(MODEL_DIR, "{}_forward".format(model_name)), map_location=device))
forward_model.eval()

# LOAD Few shot DATA
# Support set: the training split's "out_indices" rows, restricted to their
# last (2*few_shot_years - 1) entries along axis 1.
file, index = "strided_train", "out_indices"
dataset = load_dataset(file)
few_shot_span = few_shot_years*2-1
data = get_data(dataset, index,fold=fold)[:, -few_shot_span:]
nodes, years, window, channels = data.shape

# LOAD DATA
# Query set: the test split. From here on `dataset`, `file`, `index` and the
# shape variables describe the test data.
file, index = "strided_test", "out_indices"
dataset = load_dataset(file)
data_test = get_data(dataset, index,fold=fold)
nodes, years, window, channels = data_test.shape

np.random.seed(0)

# Output buffers, pre-filled with the missing-value sentinel.
output_shape = (nodes, years, window, len(output_channels))
dataset_true = unknown*np.ones(output_shape, dtype=np.float32)
dataset_pred = unknown*np.ones(output_shape, dtype=np.float32)

# Get instance for each node and move both splits to the compute device.
node_selector = np.arange(nodes)
node_data_train = torch.from_numpy(data[node_selector]).to(device)
node_data_test = torch.from_numpy(data_test[node_selector]).to(device)

support_data, query_data = node_data_train,node_data_test

# print(support_data.shape)



# Per-basin few-shot adaptation and evaluation.
# For every basin: infer a static embedding from its support data with the
# inverse model, fine-tune a private copy of the forward model on the support
# data for `num_inner_steps` gradient steps, then predict the query data with
# that copy and store labels/predictions in dataset_true / dataset_pred.
for i in range(nodes):
    print(i)

    # GET Basin support and query data
    basin_support_data, basin_query_data = support_data[i], query_data[i]

    # SUPPORT SET: dynamic inputs, dynamic+output inputs for the inverse model, labels
    basin_dynamic_support_input = basin_support_data[:, :, dynamic_channels].to(device)
    basin_dynamic_support_input_output = basin_support_data[:, :, dynamic_channels+output_channels].to(device)
    basin_support_label = basin_support_data[:, :, output_channels].to(device)

    # BASIN EMBEDDING: first output of the inverse model, averaged over axis 0.
    # The embedding is only an input to the adaptation (the inverse model is
    # not updated here), so no autograd graph is recorded for it.
    with torch.no_grad():
        basin_static_data, _, _, _ = inverse_model(x=basin_dynamic_support_input_output)
        basin_static_data = torch.mean(basin_static_data, dim=0)

    # Repeat the embedding along two new leading axes (support entries, window)
    # so it lines up with the dynamic support input.
    basin_static_support_input = torch.repeat_interleave(basin_static_data.unsqueeze(0), basin_dynamic_support_input.shape[0], dim=0).to(device)
    basin_static_support_input = torch.repeat_interleave(basin_static_support_input.unsqueeze(1), window, dim=1).to(device)

    # QUERY SET: same embedding, repeated to match the dynamic query input
    basin_dynamic_query_input = basin_query_data[:, :, dynamic_channels].to(device)
    basin_static_query_input = torch.repeat_interleave(basin_static_data.unsqueeze(0), basin_dynamic_query_input.shape[0], dim=0).to(device)
    basin_static_query_input = torch.repeat_interleave(basin_static_query_input.unsqueeze(1), window, dim=1).to(device)
    basin_query_label = basin_query_data[:, :, output_channels].to(device)

    # BASIN MODEL: private copy of the meta-learned forward model.
    # All adaptation below happens on this copy so that `forward_model` stays
    # untouched and one basin's fine-tuning cannot leak into the next basin.
    basin_model = getattr(MODEL, "tamlstm")(input_dynamic_channels=len(dynamic_channels), input_static_channels=latent_code_dim, hidden_dim=forward_code_dim, output_channels=len(output_channels), dropout=dropout)
    basin_model = basin_model.to(device)
    basin_model.load_state_dict(forward_model.state_dict())
    # Run the copy in the same train/eval mode as the model it was cloned from.
    basin_model.train(forward_model.training)
    basin_params = [p for p in basin_model.parameters() if p.requires_grad]

    for step in range(num_inner_steps):

        # GET OUTPUT
        batch_pred = basin_model(x_dynamic=basin_dynamic_support_input, x_static=basin_static_support_input.float())

        # CALCULATE LOSS
        basin_loss = criterion(batch_pred, basin_support_label)                                 # PER CHANNEL LOSS
        mask = (basin_support_label != unknown).float()                                         # CREATE MASK
        basin_loss = basin_loss * mask                                                          # MULTIPLY MASK
        basin_loss, mask = torch.sum(basin_loss, dim=2), (torch.sum(mask, dim=2) > 0).float()   # PER INSTANCE LOSS
        labelled_instances = torch.sum(mask)
        if labelled_instances == 0:
            # No labelled support instance: the mean would be 0/0 and the NaN
            # gradient would destroy the weights, so skip adaptation instead.
            break
        basin_loss = torch.sum(basin_loss) / labelled_instances                                 # MEAN SEQUENCE LOSS

        # LOSS BACKPROPAGATE
        # First-order gradients are enough at evaluation time, so no
        # second-order graph (create_graph) is built.
        grad = torch.autograd.grad(basin_loss, basin_params, allow_unused=True)
        with torch.no_grad():
            for param, param_grad in zip(basin_params, grad):
                # allow_unused=True yields None for parameters that did not
                # contribute to the loss; those are left unchanged.
                if param_grad is not None:
                    param -= learning_rate * param_grad

    # PREDICT the query data with the adapted copy (inference only)
    with torch.no_grad():
        batch_pred = basin_model(x_dynamic=basin_dynamic_query_input, x_static=basin_static_query_input.float())
    batch_label = basin_query_label

    # STORE OUTPUT
    dataset_true[i] = batch_label.detach().cpu().numpy()
    dataset_pred[i] = batch_pred.detach().cpu().numpy()

# DE-NORMALISE with the training statistics of this fold (x * std + mean)
# NOTE(review): entries of dataset_true that hold the `unknown` sentinel are
# rescaled here as well, so they may no longer equal `unknown` when the metric
# helpers below receive them - confirm how UTILS masks missing values.
output_stds = dataset["train_data_stds"][fold][output_channels]
output_means = dataset["train_data_means"][fold][output_channels]
dataset_true = (dataset_true*output_stds)+output_means
dataset_pred = (dataset_pred*output_stds)+output_means

# UNSTRIDE, then drop the first `stride` entries along axis 1
dataset_true = UTILS.unstride_array(dataset_true)[:, stride:]
dataset_pred = UTILS.unstride_array(dataset_pred)[:, stride:]

# METRICS (only the second value returned by the per-node helpers is used)
per_sample_RMSE = UTILS.per_sample_RMSE(dataset_true, dataset_pred, unknown)
per_node_RMSE = UTILS.per_node_RMSE(dataset_true, dataset_pred, unknown)[1]
per_sample_R2 = UTILS.per_sample_R2(dataset_true, dataset_pred, unknown)
per_node_R2 = UTILS.per_node_R2(dataset_true, dataset_pred, unknown)[1]
print("Per Sample RMSE:{:.4f}\tPer Node RMSE:{:.4f}\tPer Sample R2:{:.4f}\tPer Node R2:{:.4f}".format(per_sample_RMSE, per_node_RMSE, per_sample_R2, per_node_R2))

# SAVE ground truth (one file per fold) and predictions (one file per few-shot setting and model)
true_path = os.path.join(RESULT_DIR, "{}_{}_{}".format(file, index, "true_{}".format(fold)))
pred_path = os.path.join(RESULT_DIR, "{}_{}_{}_{}".format(file, index, few_shot_years, model_name))
np.save(true_path, dataset_true)
np.save(pred_path, dataset_pred)


end = time.time()
print("Time:{:.4f}".format(end-start))

In [None]:
# file, index = "strided_train", "test_index"
# dataset = load_dataset(file)
# dataset_processed_data = get_data(dataset, index, preprocessed=False,dataset_fold=dataset_fold)
# dataset_processed_data = data_normalize(dataset_processed_data,norm_region)[:,:,:,:-1]     
# dataset_processed_data = dataset_processed_data[:10]  
# lakes, years, window, channels = dataset_processed_data.shape
# dataset_obs = get_data(dataset, index, preprocessed=False,dataset_fold=dataset_fold)
# dataset_obs = data_normalize(dataset_obs,norm_region)[:,:,:,-1]   
# dataset_obs = dataset_obs[:10] 
# batch_processed_data_support, batch_obs_support =  dataset_processed_data, dataset_obs 

# file, index = "strided_test", "test_index"
# dataset = load_dataset(file)
# dataset_processed_data = get_data(dataset, index, preprocessed=False,dataset_fold=dataset_fold)
# dataset_processed_data = data_normalize(dataset_processed_data,norm_region)[:,:,:,:-1]     
# dataset_processed_data = dataset_processed_data[:10]  
# lakes, years, window, channels = dataset_processed_data.shape
# dataset_obs = get_data(dataset, index, preprocessed=False,dataset_fold=dataset_fold)
# dataset_obs = data_normalize(dataset_obs,norm_region)[:,:,:,-1]   
# dataset_obs = dataset_obs[:10]

# total_runs=1
# dataset_true = unknown*np.ones((lakes, years, window), dtype=np.float32)
# dataset_pred = unknown*np.ones((lakes, years, window), dtype=np.float32)


# batch_processed_data_querry, batch_obs_querry = dataset_processed_data, dataset_obs 
# batch_loss = []
# for i,lake in enumerate(batch_lake):
#     # GET PARAMETER COPY FOR TASK  
#     inverse_model.load_state_dict(torch.load(os.path.join(MODEL_DIR, "{}_inverse".format(model_name))))
#     inverse_model.eval()
#     forward_model.load_state_dict(torch.load(os.path.join(MODEL_DIR, "{}_forward".format(model_name))))
#     # GET LAKE DATA
#     lake_processed_data_support, lake_obs_support, lake_processed_data_querry, lake_obs_querry = batch_processed_data_support[i], batch_obs_support[i], batch_processed_data_querry[i], batch_obs_querry[i]
#     # GET LAKE DATA SUPPORT
#     lake_processed_data = torch.from_numpy(lake_processed_data_support).unsqueeze(0).to(device)
#     lake_obs = torch.from_numpy(lake_obs_support).unsqueeze(0).to(device)
#     # lake_code_vec =  torch.from_numpy(mean_lake_code[lake]).unsqueeze(0).to(device)               

#     # GET OUTPUT
#     lake_processed_data_dynamic = lake_processed_data[:,  :, :, dynamic_channels]           
#     lake_processed_data_static = lake_processed_data[:,  :, :, static_channels] 
#     lake_processed_data_inputs_embedding = torch.cat((lake_processed_data_dynamic,lake_obs.unsqueeze(-1)),axis = 3)         
#     lake_processed_data_inputs_embedding = UTILS.unstride_array(lake_processed_data_inputs_embedding.cpu())             
#     lake_processed_data_inputs_embedding = torch.from_numpy(lake_processed_data_inputs_embedding).to(device).float()
#     lake_code_vec,_,_ = inverse_model(x=lake_processed_data_inputs_embedding)               
#     lake_code_vec_processed =  lake_code_vec.unsqueeze(1).repeat(1, years, 1)   
#     lake_code_vec_processed =  lake_code_vec_processed.unsqueeze(2).repeat(1, 1, window, 1) 
#     for step in range(num_inner_steps):                  
#         out = forward_model(x_dynamic=lake_processed_data_dynamic, x_static=lake_code_vec_processed)

#         # CALCULATE LOSS
#         lake_loss = UTILS.per_lake_loss(y_true=lake_obs, y_pred=out, criterion=criterion, unknown=unknown)

#         # LOSS BACKPROPOGATE
#         grad = torch.autograd.grad(lake_loss, forward_model.parameters(),create_graph=True, allow_unused=True)
# # 				for (name, param), grad in zip(params.items(), grads):
# # 				if grad_clip > 0 and grad is not None:
# # 					grad = grad.clamp(min=-grad_clip,max=grad_clip)                
#         fast_weights = list(map(lambda p: p[1] - learning_rate * p[0], zip(grad, forward_model.parameters())))
#         for param, fast_params in zip(forward_model.parameters(), fast_weights):
#             param.data = fast_params

#     # GET LAKE DATA QUERRY
    
#     lake_processed_data = torch.from_numpy(lake_processed_data_querry).unsqueeze(0).to(device)
#     lake_obs = torch.from_numpy(lake_obs_querry).unsqueeze(0).to(device)


#     # GET OUTPUT
#     lake_processed_data_dynamic = lake_processed_data[:,  :, :, dynamic_channels]
#     lake_processed_data_static = lake_processed_data[:,  :, :, static_channels]
#     lake_processed_data_inputs_embedding = torch.cat((lake_processed_data_dynamic,lake_obs.unsqueeze(-1)),axis = 3)          
#     lake_processed_data_inputs_embedding = UTILS.unstride_array(lake_processed_data_inputs_embedding.cpu())             
#     lake_processed_data_inputs_embedding = torch.from_numpy(lake_processed_data_inputs_embedding).to(device).float()
#     lake_code_vec,_,_ = inverse_model(x=lake_processed_data_inputs_embedding) 
#     lake_code_vec_processed =  lake_code_vec.unsqueeze(1).repeat(1, years, 1)   
#     lake_code_vec_processed =  lake_code_vec_processed.unsqueeze(2).repeat(1, 1, window, 1)     
#     out = forward_model(x_dynamic=lake_processed_data_dynamic, x_static=lake_code_vec_processed)
#     # print(out.shape)
#     # print(lake_obs.shape)
#     lake_loss = UTILS.per_lake_loss(y_true=lake_obs, y_pred=out, criterion=criterion, unknown=unknown)
#     batch_loss.append(lake_loss)
    
    
#         # STORE OUTPUT
#     dataset_true[i, :,:] = lake_obs.detach().cpu().numpy()
#     dataset_pred[i, :,:] = out.detach().cpu().numpy()
    
    
    
# # dataset_true_ensemble =   np.mean(dataset_true,axis=-1)   
# # dataset_pred_ensemble =   np.mean(dataset_pred,axis=-1) 
# # print(dataset_true_ensemble.shape)
# print("TRUE_OBS",dataset_true.shape)
# dataset_true = (dataset_true*dataset["train_data_stds"][output_channels])+dataset["train_data_means"][output_channels]
# dataset_pred = (dataset_pred*dataset["train_data_stds"][output_channels])+dataset["train_data_means"][output_channels]

# dataset_true = np.expand_dims(dataset_true, axis=3)
# dataset_pred = np.expand_dims(dataset_pred, axis=3)

# dataset_true = UTILS.unstride_array(dataset_true)
# dataset_pred = UTILS.unstride_array(dataset_pred)
# dataset_true = dataset_true[:, stride:]
# dataset_pred = dataset_pred[:, stride:]

# per_sample_RMSE = UTILS.per_sample_RMSE(dataset_true, dataset_pred, unknown)
# all_node_RMSE, per_node_RMSE = UTILS.per_node_RMSE(dataset_true, dataset_pred, unknown)
# per_sample_R2 = UTILS.per_sample_R2(dataset_true, dataset_pred, unknown)
# all_node_r2, per_node_R2 = UTILS.per_node_R2(dataset_true, dataset_pred, unknown)
# print("Per Sample RMSE:{:.4f}\tPer Node RMSE:{:.4f}\tPer Sample R2:{:.4f}\tPer Node R2:{:.4f}".format(per_sample_RMSE, per_node_RMSE, per_sample_R2, per_node_R2))
# np.save(os.path.join(RESULT_DIR, "{}_{}_{}".format(file, index, "true")), dataset_true)
# np.save(os.path.join(RESULT_DIR, "{}_{}_{}".format(file, index, model_name)), dataset_pred)    
    

In [None]:
# print("IN\tfold:{}\tinit:{}".format(fold, init))
# start = time.time()

# # BUILD MODEL
# model = getattr(MODEL, "tamlstm")(input_dynamic_channels=len(dynamic_channels), input_static_channels=latent_code_dim, hidden_dim=forward_code_dim, output_channels=len(output_channels), dropout=dropout)
# model = model.to(device)
# pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# criterion = torch.nn.MSELoss(reduction="none")
# # print("#Parameters:{}".format(pytorch_total_params))
# # print(model)

# # LOAD MODEL
# model.load_state_dict(torch.load(os.path.join(MODEL_DIR, "{}".format(model_name))))

# # LOAD Few shot DATA
# file, index = "strided_train", "in_indices"
# dataset = load_dataset(file)
# data = get_data(dataset, index,fold=fold)[:,-(few_shot_years*2-1):]
# nodes, years, window, channels = data.shape
# # print(nodes, years, window, channels)

# # LOAD DATA
# file, index = "strided_test", "in_indices"
# dataset = load_dataset(file)
# data_test = get_data(dataset, index,fold=fold)
# nodes, years, window, channels = data_test.shape

# np.random.seed(0)
# random_static_data = np.random.normal(0, 1, size = (nodes, latent_code_dim))
# random_static_data = np.repeat(random_static_data[:, np.newaxis,np.newaxis],window,axis=2)

# dataset_true = unknown*np.ones((nodes, years, window, len(output_channels)), dtype=np.float32)
# dataset_pred = unknown*np.ones((nodes, years, window, len(output_channels)), dtype=np.float32)


# #Get instance for each node
# node_data_train = data[np.arange(nodes)]
# node_data_test = data_test[np.arange(nodes)]
# # print(node_data.shape)

# node_data_train = torch.from_numpy(node_data_train).to(device)
# node_data_test = torch.from_numpy(node_data_test).to(device)
# random_static_data = torch.from_numpy(random_static_data).to(device)
# support_data, query_data = node_data_train,node_data_test

# # print(support_data.shape)



# for i in range(nodes):
#     print(i)

#     # GET Basin support and query data
#     basin_support_data, basin_query_data = support_data[i], query_data[i]
#     basin_random_static_data = random_static_data[i]
    
#     # print(basin_support_data.shape)
#     # print(basin_query_data.shape)
    
#     basin_dynamic_support_input = basin_support_data[:, :, dynamic_channels].to(device)
#     basin_static_support_input = torch.repeat_interleave(basin_random_static_data,basin_dynamic_support_input.shape[0],axis=0).to(device)
#     basin_support_label = basin_support_data[:, :, output_channels].to(device)

#     basin_dynamic_query_input = basin_query_data[:, :, dynamic_channels].to(device)
#     basin_static_query_input = torch.repeat_interleave(basin_random_static_data,basin_dynamic_query_input.shape[0],axis=0).to(device)
#     basin_query_label = basin_query_data[:, :, output_channels].to(device)

#     basin_model = getattr(MODEL, "tamlstm")(input_dynamic_channels=len(dynamic_channels), input_static_channels=latent_code_dim, hidden_dim=forward_code_dim, output_channels=len(output_channels), dropout=dropout)
#     basin_model = basin_model.to(device)
#     basin_model.load_state_dict(model.state_dict())

#     for step in range(num_inner_steps):
#         # GET OUTPUT
#         # print("basin_dynamic_support_input",basin_dynamic_support_input.shape)
#         batch_pred = model(x_dynamic=basin_dynamic_support_input,x_static=basin_static_support_input.float())
#         # print(batch_pred.shape)

#         # CALCULATE LOSS
#         basin_loss = criterion(basin_support_label, batch_pred)											# PER CHANNEL LOSS
#         mask = (basin_support_label!=unknown).float()													# CREATE MASK
#         basin_loss = basin_loss * mask															# MULTIPLY MASK
#         basin_loss, mask = torch.sum(basin_loss, dim=2), (torch.sum(mask, dim=2)>0).float()		# PER INSTANCE LOSS
#         basin_loss = torch.sum(basin_loss)/torch.sum(mask)										# MEAN SEQUENCE LOSS

#         # LOSS BACKPROPOGATE
#         grad = torch.autograd.grad(basin_loss, model.parameters(),create_graph=True, allow_unused=True)
#         fast_weights = list(map(lambda p: p[1] - learning_rate * p[0], zip(grad, model.parameters())))
#         for param, fast_params in zip(model.parameters(), fast_weights):
#             param.data = fast_params



#     model.eval()
#     # GET Basin support and query data

#     batch_pred = model(x_dynamic=basin_dynamic_query_input,x_static=basin_static_query_input.float())
#     batch_label = basin_query_label

#     # print(batch_pred.shape)

#     # STORE OUTPUT
#     dataset_true[i] = batch_label.detach().cpu().numpy()
#     dataset_pred[i] = batch_pred.detach().cpu().numpy()

# dataset_true = (dataset_true*dataset["train_data_stds"][fold][output_channels])+dataset["train_data_means"][fold][output_channels]
# dataset_pred = (dataset_pred*dataset["train_data_stds"][fold][output_channels])+dataset["train_data_means"][fold][output_channels]
# dataset_true = UTILS.unstride_array(dataset_true)
# dataset_pred = UTILS.unstride_array(dataset_pred)
# dataset_true = dataset_true[:, stride:]
# dataset_pred = dataset_pred[:, stride:]

# per_sample_RMSE = UTILS.per_sample_RMSE(dataset_true, dataset_pred, unknown)
# _, per_node_RMSE = UTILS.per_node_RMSE(dataset_true, dataset_pred, unknown)
# per_sample_R2 = UTILS.per_sample_R2(dataset_true, dataset_pred, unknown)
# _, per_node_R2 = UTILS.per_node_R2(dataset_true, dataset_pred, unknown)
# print("Per Sample RMSE:{:.4f}\tPer Node RMSE:{:.4f}\tPer Sample R2:{:.4f}\tPer Node R2:{:.4f}".format(per_sample_RMSE, per_node_RMSE, per_sample_R2, per_node_R2))
# np.save(os.path.join(RESULT_DIR, "{}_{}_{}".format(file, index, "true_{}".format(fold))), dataset_true)
# np.save(os.path.join(RESULT_DIR, "{}_{}_{}_{}".format(file, index, few_shot_years, model_name)), dataset_pred)


# end = time.time()
# print("Time:{:.4f}".format(end-start))

## OUT DISTRIBUTION

In [None]:
# print("Out\tfold:{}\tinit:{}".format(fold, init))
# start = time.time()

# # BUILD MODEL
# model = getattr(MODEL, "tamlstm")(input_dynamic_channels=len(dynamic_channels), input_static_channels=latent_code_dim, hidden_dim=forward_code_dim, output_channels=len(output_channels), dropout=dropout)
# model = model.to(device)
# pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# # print("#Parameters:{}".format(pytorch_total_params))
# # print(model)

# # LOAD MODEL
# model.load_state_dict(torch.load(os.path.join(MODEL_DIR, "{}".format(model_name))))

# # LOAD Few shot DATA
# file, index = "strided_train", "out_indices"
# dataset = load_dataset(file)
# data = get_data(dataset, index,fold=fold)[:,-(few_shot_years*2-1):]
# nodes, years, window, channels = data.shape
# # print(nodes, years, window, channels)

# # LOAD DATA
# file, index = "strided_test", "out_indices"
# dataset = load_dataset(file)
# data_test = get_data(dataset, index,fold=fold)
# nodes, years, window, channels = data_test.shape

# dataset_true = unknown*np.ones((nodes, years, window, len(output_channels)), dtype=np.float32)
# dataset_pred = unknown*np.ones((nodes, years, window, len(output_channels)), dtype=np.float32)


# #Get instance for each node
# node_data_train = data[np.arange(nodes)]
# node_data_test = data_test[np.arange(nodes)]
# # print(node_data.shape)

# node_data_train = torch.from_numpy(node_data_train).to(device)
# node_data_test = torch.from_numpy(node_data_test).to(device)
# support_data, query_data = node_data_train,node_data_test

# # print(support_data.shape)



# for i in range(nodes):
#     print(i)

#     # GET Basin support and query data
#     basin_support_data, basin_query_data = support_data[i], query_data[i]
#     # print(basin_support_data.shape)
#     # print(basin_query_data.shape)
    
#     basin_dynamic_support_input = basin_support_data[:, :, dynamic_channels].to(device)
#     basin_static_support_input = basin_support_data[:, :, static_channels].to(device)
#     basin_support_label = basin_support_data[:, :, output_channels].to(device)

#     basin_dynamic_query_input = basin_query_data[:, :, dynamic_channels].to(device)
#     basin_static_query_input = basin_query_data[:, :, static_channels].to(device)
#     basin_query_label = basin_query_data[:, :, output_channels].to(device)

#     basin_model = getattr(MODEL, "tamlstm")(input_dynamic_channels=len(dynamic_channels), input_static_channels=latent_code_dim, hidden_dim=forward_code_dim, output_channels=len(output_channels), dropout=dropout)
#     basin_model = basin_model.to(device)
#     basin_model.load_state_dict(model.state_dict())

#     for step in range(num_inner_steps):
#         # GET OUTPUT
#         # print("basin_dynamic_support_input",basin_dynamic_support_input.shape)
#         batch_pred = model(x_dynamic=basin_dynamic_support_input)
#         # print(batch_pred.shape)

#         # CALCULATE LOSS
#         basin_loss = criterion(basin_support_label, batch_pred)											# PER CHANNEL LOSS
#         mask = (basin_support_label!=unknown).float()													# CREATE MASK
#         basin_loss = basin_loss * mask															# MULTIPLY MASK
#         basin_loss, mask = torch.sum(basin_loss, dim=2), (torch.sum(mask, dim=2)>0).float()		# PER INSTANCE LOSS
#         basin_loss = torch.sum(basin_loss)/torch.sum(mask)										# MEAN SEQUENCE LOSS

#         # LOSS BACKPROPOGATE
#         grad = torch.autograd.grad(basin_loss, model.parameters(),create_graph=True, allow_unused=True)
#         fast_weights = list(map(lambda p: p[1] - learning_rate * p[0], zip(grad, model.parameters())))
#         for param, fast_params in zip(model.parameters(), fast_weights):
#             param.data = fast_params



#     model.eval()
#     # GET Basin support and query data

#     batch_pred = model(x_dynamic=basin_dynamic_query_input)
#     batch_label = basin_query_label

#     # print(batch_pred.shape)

#     # STORE OUTPUT
#     dataset_true[i] = batch_label.detach().cpu().numpy()
#     dataset_pred[i] = batch_pred.detach().cpu().numpy()

# dataset_true = (dataset_true*dataset["train_data_stds"][fold][output_channels])+dataset["train_data_means"][fold][output_channels]
# dataset_pred = (dataset_pred*dataset["train_data_stds"][fold][output_channels])+dataset["train_data_means"][fold][output_channels]
# dataset_true = UTILS.unstride_array(dataset_true)
# dataset_pred = UTILS.unstride_array(dataset_pred)
# dataset_true = dataset_true[:, stride:]
# dataset_pred = dataset_pred[:, stride:]

# per_sample_RMSE = UTILS.per_sample_RMSE(dataset_true, dataset_pred, unknown)
# _, per_node_RMSE = UTILS.per_node_RMSE(dataset_true, dataset_pred, unknown)
# per_sample_R2 = UTILS.per_sample_R2(dataset_true, dataset_pred, unknown)
# _, per_node_R2 = UTILS.per_node_R2(dataset_true, dataset_pred, unknown)
# print("Per Sample RMSE:{:.4f}\tPer Node RMSE:{:.4f}\tPer Sample R2:{:.4f}\tPer Node R2:{:.4f}".format(per_sample_RMSE, per_node_RMSE, per_sample_R2, per_node_R2))
# np.save(os.path.join(RESULT_DIR, "{}_{}_{}".format(file, index, "true_{}".format(fold))), dataset_true)
# np.save(os.path.join(RESULT_DIR, "{}_{}_{}_{}".format(file, index, few_shot_years, model_name)), dataset_pred)


# end = time.time()
# print("Time:{:.4f}".format(end-start))