In [None]:
import os
import sys
import json
import torch
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from model_dev.dataloader import data_provider
from model_dev.stock_picker import StockPicker
from model_dev.visualize.visualize_single import Visualize, craete_heatmap
from model_dev.visualize.scatter_plot import Scatter
from model_dev.utills import read_default_args, load_model, get_stock_meta, get_stock_heatmap_matrix

# import mse loss from torch
from torch.nn import MSELoss
# TODO: import a regularized MSE loss here (nothing is imported yet)


%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings('ignore')

from model_dev.expirement import ExpMain

In [None]:
# Baseline argument mapping supplied by the project helper. The checkpoint-loading
# cell further down overwrites individual keys in it and expands it into an
# argparse.Namespace for every target.
default_args = read_default_args()

In [None]:
# JSON config read by the loading cell; its 'data_dir' entry is replaced by each
# non-empty override listed below.
config_file = '../../configs/config_akash.json'

# Root directory holding one sub-directory per data iteration. It defaults to the
# original author's machine-specific location; set the LTSF_DATA_ROOT environment
# variable to run the notebook elsewhere without editing this cell.
ltsf_data_root = os.environ.get("LTSF_DATA_ROOT", "/Users/akashanand/repo/data/ltsf")

# One data directory per iteration to compare (iteration1, iteration2).
data_dir_overrides = [
    "{}/iteration{}/data_maestro".format(ltsf_data_root, iteration_number)
    for iteration_number in (1, 2)
]


In [None]:
# Build one matrix of learned attention weights per data iteration.
# For every iteration directory and every target index, the matching checkpoint
# folder is located, the checkpoint whose file name carries the smallest number is
# loaded (presumably the best validation loss -- confirm against the training
# code), and the model's flattened `Attention.weight` becomes one row of the matrix.
# Result: `all_matrices[k]` is a 2-D array with one row per target for iteration k.

# Number of targets; also passed to the model as `enc_in`.
num_channels = 397

all_matrices = []
for iteration_idx, data_dir_override in enumerate(data_dir_overrides):
    # Context manager so the config file handle is closed on every iteration.
    with open(config_file) as config_fh:
        config = json.load(config_fh)
    if data_dir_override != "":
        config['data_dir'] = data_dir_override
    data_dir = config['data_dir']
    raw_dir = "{}/{}".format(data_dir, config['raw_data_dir'])
    csv_dir = "{}/{}".format(data_dir, config['raw_data_csv'])
    # NOTE: `ltsf` is also read by a later cell, which therefore sees the value
    # from the last iteration of this loop.
    ltsf = "{}/ltsf".format(data_dir)
    print("data_dir: {}".format(data_dir))
    target_wise_attention = []
    for i in range(num_channels):
        # Overwrite the per-run settings in the shared defaults (this mutates
        # `default_args`, exactly as the key-by-key copy did before).
        default_args.update({
            'root_path': ltsf,
            'checkpoints': '{}/checkpoints/'.format(data_dir),
            'data_path': '03_23.csv',
            'seq_len': 120,
            'pred_len': 30,
            'batch_size': 1,
            'learning_rate': 0.025,
            'train_only': False,
            'train_epochs': 20,
            'data_segment': None,
            'model': 'nlinear_attention',
            'enc_in': num_channels,
            'patience': 5,
            'target': i,
            'stocks': None,
        })

        args = argparse.Namespace(**default_args)
        setting = 'mod_{}_sl{}_pl{}_ds_{}_tg_{}_ch_{}'.format(args.model, args.seq_len, args.pred_len, args.data_path.split('.')[0], args.target, args.enc_in)
        # Checkpoint folders of later iterations carry a "_<n>" suffix.
        if iteration_idx != 0:
            setting = setting + "_" + str(iteration_idx + 1)
        checkpoint_dir = "{}/{}".format(args.checkpoints, setting)

        # Keep only real checkpoint files so stray entries (e.g. .DS_Store) do not
        # break the numeric parse of the file name below.
        weights = [
            file_name for file_name in os.listdir(checkpoint_dir)
            if file_name.startswith('checkpoint_') and file_name.endswith('.pth')
        ]
        if not weights:
            raise FileNotFoundError("no checkpoint_*.pth files found in {}".format(checkpoint_dir))
        # Smallest number embedded in the file name; equivalent to the previous
        # "sort descending, take the last element".
        best_checkpoint = min(weights, key=lambda x: float(x.replace('checkpoint_', '').replace('.pth', '')))

        model = load_model(args)
        # map_location keeps loading working on CPU-only hosts even when the
        # checkpoint was saved from a GPU; the weights are moved to CPU below anyway.
        model.load_state_dict(torch.load("{}/{}".format(checkpoint_dir, best_checkpoint), map_location='cpu'))
        target_wise_attention.append(model.Attention.weight.cpu().detach().numpy().flatten().tolist())
    matrix = np.array(target_wise_attention)
    all_matrices.append(matrix)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Draw one attention heatmap per entry of `all_matrices`, side by side in a single
# row. The colour scale is fixed to [-1, 1] on every panel so the panels are
# directly comparable.
num_heatmaps = len(all_matrices)

# squeeze=False always returns a 2-D grid of axes, so the same code works for one
# matrix as well as several; the first (only) row is the list of panels.
fig, axes_grid = plt.subplots(1, num_heatmaps, figsize=(5 * num_heatmaps, 5), squeeze=False)
axs = axes_grid[0]

for heatmap_idx, (ax, heatmap_matrix) in enumerate(zip(axs, all_matrices)):
    sns.heatmap(heatmap_matrix, cmap='RdBu', vmin=-1, vmax=1, ax=ax)
    ax.set_title('Heatmap {}'.format(heatmap_idx))

# Display the figure with all of its subplots
plt.tight_layout()  # Ensures a bit of spacing between plots
plt.show()

In [None]:
# Load the stock metadata mapping and collect the 'name' field of every entry, in
# the mapping's iteration order. `ltsf` is the path left behind by the last
# iteration of the loading loop.
# NOTE(review): later cells index `names` by matrix row/column -- confirm that the
# metadata order matches the channel order used for training.
a = get_stock_meta("{}/instruments.json".format(ltsf), "{}/03_23.csv".format(ltsf))
names = [stock_meta['name'] for stock_meta in a.values()]

In [None]:
# Threshold-based agreement between the two iterations.
# For each target row, find the columns whose absolute attention weight exceeds
# `threshold` in iteration 1, in iteration 2, and in both. Then print, per target
# (most shared columns first): name, shared count, count in iteration 1, count in
# iteration 2.
# NOTE(review): columns are labelled with `names[col]`, i.e. this assumes one
# attention weight per stock in the same order as `names` -- confirm.
threshold = 0.4
intersection = {}
num_intersection1 = {}
num_intersection2 = {}
for row in range(len(all_matrices[0])):
    target_name = names[row]
    # Boolean masks of the "strong" columns in each iteration.
    strong1 = np.abs(all_matrices[0][row]) > threshold
    strong2 = np.abs(all_matrices[1][row]) > threshold
    num_intersection1[target_name] = int(np.count_nonzero(strong1))
    num_intersection2[target_name] = int(np.count_nonzero(strong2))
    # Columns that are strong in both iterations, in ascending column order.
    intersection[target_name] = [names[col] for col in np.flatnonzero(strong1 & strong2)]

# Rank targets by how many strong columns the two iterations share.
sorted_intersection = sorted(intersection.items(), key=lambda item: len(item[1]), reverse=True)
out = [(ranked_name, len(shared_names)) for ranked_name, shared_names in sorted_intersection]
for ranked_name, shared_count in out:
    print(ranked_name, shared_count, num_intersection1[ranked_name], num_intersection2[ranked_name])

In [None]:
# Top-k agreement between the two iterations.
# For each target row, take the `top_k` columns with the largest absolute attention
# weight in each iteration, map them to names, and keep the names present in both
# sets. Then print each target with its overlap size, largest overlap first.
top_k = 20
intersection = {}
for row in range(len(all_matrices[0])):
    # argsort is ascending, so the last `top_k` positions are the largest magnitudes.
    ranked1 = np.argsort(np.abs(all_matrices[0][row]))
    ranked2 = np.argsort(np.abs(all_matrices[1][row]))
    top_names1 = {names[col] for col in ranked1[-top_k:]}
    top_names2 = {names[col] for col in ranked2[-top_k:]}
    intersection[names[row]] = list(top_names1 & top_names2)

# Rank targets by the size of the shared top-k set.
sorted_intersection = sorted(intersection.items(), key=lambda item: len(item[1]), reverse=True)
out = [(ranked_name, len(shared_names)) for ranked_name, shared_names in sorted_intersection]
for ranked_name, shared_count in out:
    print(ranked_name, shared_count)