# IMPORT LIBRARIES

In [None]:
import sys
sys.path.append("../")

import config
import MODEL

import os
import math
import numpy as np
import pandas as pd
from collections import Counter
import random

import torch
from torch.utils.data.dataset import Dataset
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error

# Seed every RNG this notebook draws from so runs are reproducible. The
# anchor/positive year permutations and batch shuffles in the training cell
# use Python's `random` module, which torch.manual_seed does not affect.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# HYPERPARAMETERS

In [None]:
# Every hyperparameter below is read from the project `config` module,
# except model_name, which is fixed for this notebook.

# TIME SERIES INFO
window = config.window

# CHANNELS INFO
dynamic_channels, static_channels, output_channels = (
	config.dynamic_channels,
	config.static_channels,
	config.output_channels,
)

# LABELS INFO
unknown = config.unknown  # fill value that get_data substitutes for NaNs

# MODEL INFO
model_name = "kgssl"
code_dim = config.code_dim
device = torch.device(config.device)  # use torch.device("cpu") to force CPU
recon_weight, contrastive_weight, static_weight = (
	config.recon_weight,
	config.contrastive_weight,
	config.static_weight,
)
# Normaliser that turns the weighted loss sum into a weighted mean.
sum_weight = recon_weight + contrastive_weight + static_weight

# TRAIN INFO
train, batch_size, epochs, learning_rate = (
	config.train,
	config.batch_size,
	config.epochs,
	config.learning_rate,
)

print("Hyperparameters:{}".format(model_name))
print("\n".join(
	"{} : {}".format(label, value)
	for label, value in (
		("window", window),
		("dynamic_channels", dynamic_channels),
		("static_channels", static_channels),
		("output_channels", output_channels),
		("unknown", unknown),
		("model_name", model_name),
		("code_dim", code_dim),
		("device", device),
		("recon_weight", recon_weight),
		("contrastive_weight", contrastive_weight),
		("static_weight", static_weight),
		("train", train),
		("batch_size", batch_size),
		("epochs", epochs),
		("learning_rate", learning_rate),
	)
))

# DEFINE DIRECTORIES

In [None]:
# Filesystem locations for preprocessed inputs, result figures and saved model weights.
PREPROCESSED_DIR, RESULT_DIR, MODEL_DIR = (
	config.PREPROCESSED_DIR,
	config.RESULT_DIR,
	config.MODEL_DIR,
)

# LOAD DATA

In [None]:
def load_dataset(file):
	"""Load ``<PREPROCESSED_DIR>/<file>.npz`` with ``np.load`` and return the archive.

	allow_pickle=True lets object arrays in the archive be unpickled, so only
	point this at trusted, locally produced files.
	"""
	npz_path = os.path.join(PREPROCESSED_DIR, "{}.npz".format(file))
	return np.load(npz_path, allow_pickle=True)

def get_data(dataset, index, preprocessed=True):
	"""Return the ``index`` subset of ``dataset["data"]``, optionally standardised.

	dataset      : mapping (e.g. the archive from load_dataset) providing "data",
	               "train_data_means", "train_data_stds" and the ``index`` key.
	index        : key whose array is used to index the first axis of the data.
	preprocessed : if True, subtract the stored training means and divide by the
	               stored training stds before anything else.

	NaNs (including any created by the normalisation) become the global
	``unknown``; np.nan_to_num also maps +/-inf to the largest finite floats.
	"""
	raw = dataset["data"]
	if preprocessed:
		scaled = (raw - dataset["train_data_means"]) / dataset["train_data_stds"]
	else:
		scaled = raw
	filled = np.nan_to_num(scaled, nan=unknown)
	return filled[dataset[index]]

# BUILD MODEL

In [None]:
# Autoencoder fed the dynamic + output channels; its output_channels size is
# the number of static channels (the training loop regresses static attributes).
model = MODEL.ae(
	input_channels=len(dynamic_channels) + len(output_channels),
	code_dim=code_dim,
	output_channels=len(static_channels),
	device=device,
).to(device)

pytorch_total_params = sum(
	param.numel() for param in model.parameters() if param.requires_grad
)
print("#Parameters:{}".format(pytorch_total_params))
print(model)

# reduction="none" keeps per-element errors so the training loop decides
# which axes are summed and which are averaged.
mse_criterion = torch.nn.MSELoss(reduction="none")
contrastive_criterion = MODEL.SimCLRLoss(temperature=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# TRAIN MODEL

In [None]:
def _sample_year_permutations(nodes, years):
	"""Return a (nodes, years) int64 array whose rows are random permutations of range(years).

	Draws from Python's ``random`` module, one ``random.sample`` call per node.
	"""
	permutations = np.zeros((nodes, years))
	for node in range(nodes):
		permutations[node] = random.sample(range(years), years)
	return permutations.astype(np.int64)


def _select_pairs(data, anchor_years, positive_years, year):
	"""Gather one (anchor, positive) window per node for column ``year`` of the permutations.

	Anchor and positive always come from the same node. A pair is dropped when
	both sides point at the same year, or when either side contains ``unknown``
	anywhere in its last channel (streamflow per the original notebook comment —
	NOTE(review): confirm the channel order).
	"""
	node_idx = np.arange(data.shape[0])
	anchor_idx = anchor_years[:, year]
	positive_idx = positive_years[:, year]
	anchor_data = data[node_idx, anchor_idx]
	positive_data = data[node_idx, positive_idx]

	# Identical years would make anchor and positive the very same window.
	distinct = anchor_idx != positive_idx
	anchor_data = anchor_data[distinct]
	positive_data = positive_data[distinct]

	complete = (anchor_data[:, :, -1] != unknown).all(axis=1) & (positive_data[:, :, -1] != unknown).all(axis=1)
	return anchor_data[complete], positive_data[complete]


def _batch_losses(batch_anchor_data, batch_positive_data):
	"""Forward one batch of pairs; return (total, recon, contrastive, static) loss tensors.

	Anchors and positives are stacked along dim 0 so the contrastive criterion
	sees both halves of every pair in one batch of code vectors.
	"""
	input_channels = dynamic_channels + output_channels
	batch_input = torch.cat((batch_anchor_data[:, :, input_channels], batch_positive_data[:, :, input_channels]), dim=0)
	# Static attributes are read from the first time step of each window.
	batch_static = torch.cat((batch_anchor_data[:, 0, static_channels], batch_positive_data[:, 0, static_channels]), dim=0)

	batch_code_vec, batch_static_pred, batch_input_pred = model(x=batch_input)

	# RECON LOSS: squared error summed over the channel axis (dim 2), then averaged.
	recon_loss = torch.mean(torch.sum(mse_criterion(batch_input_pred, batch_input), dim=2))
	# CONTRASTIVE LOSS
	contrastive_loss = contrastive_criterion(batch_code_vec)
	# INVERSE LOSS: mean squared error over static channels, then over the batch.
	static_loss = torch.mean(torch.mean(mse_criterion(batch_static_pred, batch_static), dim=1))
	total_loss = (recon_weight*recon_loss + contrastive_weight*contrastive_loss + static_weight*static_loss)/sum_weight
	return total_loss, recon_loss, contrastive_loss, static_loss


def _run_epoch(file, index, training):
	"""Run one pass over the ``file``/``index`` split and return mean per-batch losses.

	With ``training`` True the model is in train mode and the optimizer steps
	after each batch. Otherwise the model is in eval mode with gradients disabled.

	Returns (total, recon, contrastive, static). Each value is averaged over
	every batch processed in the pass, or is NaN if no pair survived filtering.
	"""
	if training:
		model.train()
	else:
		model.eval()

	# The archive is not kept past get_data, so its file handle can be released.
	data = get_data(load_dataset(file), index)
	nodes, years = data.shape[0], data.shape[1]
	anchor_years = _sample_year_permutations(nodes, years)
	positive_years = _sample_year_permutations(nodes, years)

	loss_sums = [0.0, 0.0, 0.0, 0.0]
	n_batches = 0  # counted explicitly: batch counts differ per year after filtering
	grad_mode = torch.enable_grad() if training else torch.no_grad()
	with grad_mode:
		for year in range(years):
			anchor_data, positive_data = _select_pairs(data, anchor_years, positive_years, year)
			n_pairs = anchor_data.shape[0]
			order = random.sample(range(n_pairs), n_pairs)
			for start in range(0, n_pairs, batch_size):
				batch_idx = order[start:start + batch_size]
				if training:
					optimizer.zero_grad()
				losses = _batch_losses(
					torch.from_numpy(anchor_data[batch_idx]).to(device),
					torch.from_numpy(positive_data[batch_idx]).to(device),
				)
				if training:
					# LOSS BACKPROPAGATE
					losses[0].backward()
					optimizer.step()
				loss_sums = [total + loss.item() for total, loss in zip(loss_sums, losses)]
				n_batches += 1

	if n_batches == 0:
		return (float("nan"),) * 4
	return tuple(total / n_batches for total in loss_sums)


if train:

	train_loss = []
	valid_loss = []
	min_loss = float("inf")

	for epoch in range(1, epochs+1):

		# LOSS ON TRAIN SET
		epoch_loss, epoch_recon_loss, epoch_contrastive_loss, epoch_static_loss = _run_epoch("strided_train", "train_index", training=True)
		print('Epoch:{}\tTrain Loss:{:.4f}\tRecon Loss:{:.4f}\tCont Loss:{:.4f}\tStatic Loss:{:.4f}'.format(epoch, epoch_loss, epoch_recon_loss, epoch_contrastive_loss, epoch_static_loss), end="\t")
		train_loss.append(epoch_loss)

		# LOSS ON VALIDATION SET
		# NOTE(review): the validation file is read with the "train_index" key, as
		# in the original notebook; confirm this is the intended node subset.
		epoch_loss, epoch_recon_loss, epoch_contrastive_loss, epoch_static_loss = _run_epoch("strided_valid", "train_index", training=False)
		print('Valid Loss:{:.4f}\tRecon Loss:{:.4f}\tCont Loss:{:.4f}\tStatic Loss:{:.4f}\tMin Loss:{:.4f}'.format(epoch_loss, epoch_recon_loss, epoch_contrastive_loss, epoch_static_loss, min_loss))
		valid_loss.append(epoch_loss)
		# Checkpoint whenever validation loss improves (a NaN loss never does).
		if min_loss > epoch_loss:
			min_loss = epoch_loss
			torch.save(model.state_dict(), os.path.join(MODEL_DIR, model_name))

	# PLOT LOSS
	fig = plt.figure(figsize=(10,10))
	ax1 = fig.add_subplot(111)
	ax1.set_xlabel("#Epoch", fontsize=50)

	# PLOT TRAIN LOSS (left y-axis)
	lns1 = ax1.plot(train_loss, color='red', marker='o', linewidth=4, label="TRAIN LOSS")

	# PLOT VALIDATION LOSS (independent right y-axis)
	ax2 = ax1.twinx()
	lns2 = ax2.plot(valid_loss, color='blue', marker='o', linewidth=4, label="VAL LOSS")

	# Combine handles from both axes so one legend lists both curves.
	lns = lns1+lns2
	labs = [l.get_label() for l in lns]
	ax1.legend(lns, labs, loc="upper right", fontsize=40, frameon=False)

	plt.tight_layout(pad=0.0,h_pad=0.0,w_pad=0.0)
	plt.savefig(os.path.join(RESULT_DIR, "{}_SCORE.pdf".format(model_name)), format = "pdf")
	plt.close(fig)