In [None]:
import sys
# Prepend the project checkout to the module search path so that the `Code.*` imports used by this
# notebook can be resolved. The path is relative to the notebook's working directory.
sys.path.insert(0, "../../projects/code-2023-deephyptrails/")  # you might need to insert the absolute path here

In [None]:
from torch.utils.data import DataLoader
from Code.Models.LightningGPT import GPT, GPTConfig
from Code.Dataset.ReviewDataset import ReviewsDataset, AMZWalkDataset
import torch
import umap
import umap.plot
import numpy as np
import torch
import hdbscan
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Some settings
# Both paths are relative to the notebook's working directory.
# JSONL file that is handed to ReviewsDataset to load the walks.
dataset_path: str = "code-2023-deephyptrails/data/all-data/dataset-subtrails.jsonl"
# Checkpoint file whose "state_dict" entry is loaded into the GPT model.
model_path: str = "code-2023-deephyptrails/data/all-data/model_data-subtrails-annotated-walks_final_200a7298.ckpt"

In [None]:
def change_categories(walk: list, new_category_dict: dict):
    """Return a copy of ``walk`` in which every node carries ``new_category_dict``.

    Each node is rebuilt as a four-element list: its first three entries are
    kept and the fourth entry is replaced by ``new_category_dict``. Any further
    entries of the original node are dropped.

    Args:
        walk: Sequence of indexable nodes with at least three entries each.
        new_category_dict: Object stored as fourth entry of every new node.
            The very same object is shared by all nodes (it is not copied).

    Returns:
        A new list of new node lists; ``walk`` itself is left unmodified.
    """
    return [[node[0], node[1], node[2], new_category_dict] for node in walk]

In [None]:
# Load the dataset and keep only the walks stored under the 'annotated_walks' key.
dataset = ReviewsDataset(dataset_path)
walk_collections, _ = dataset.get_walks()  # the second return value is not used in this notebook
all_walkies = walk_collections['annotated_walks']

In [None]:
# Collect every distinct category assignment found on the first node of a walk
# (the dict at position 3 of that node) and give each one a stable integer index.
#
# The values of every dict are read in ONE fixed key order (the key order of the very first
# walk). Relying on each dict's own `.values()` order would silently attach values to the
# wrong keys as soon as two dicts were built with a different insertion order.
category_keys = list(all_walkies[0][0][3].keys())
category_set = {
    tuple(walk[0][3][key] for key in category_keys)
    for walk in all_walkies
}
# Turn the de-duplicated value tuples back into dicts with the shared key order.
all_category_combinations = [dict(zip(category_keys, values)) for values in category_set]
# Descending sort by (cat1, cat2, cat3, cat4), so the row order no longer depends on set iteration order.
all_category_combinations.sort(
    key=lambda combo: (combo['cat1'], combo['cat2'], combo['cat3'], combo['cat4']),
    reverse=True,
)
# Row index -> feature combination; used later to address the rows of the loss matrix.
feature_index = dict(enumerate(all_category_combinations))

In [None]:
# Build the evaluation set: for each of the four category flags take (at most) the first ten
# walks whose first node has that flag set to 1, and pair every such walk with every known
# feature combination.
walkies = []
walk_index = {}  # column index -> token sequence of the walk, including the BOS/EOS ids
for idx in range(1, 5):
    category_key = f'cat{idx}'
    current_walkies = [walk for walk in all_walkies if walk[0][3][category_key] == 1][:10]
    for current_walky in current_walkies:
        # Token ids are shifted by 2 because id 0 (start) and id 1 (end) frame the sequence.
        shifted_tokens = [node[2] + 2 for node in current_walky]
        walk_index[len(walk_index)] = [0, *shifted_tokens, 1]
        # One copy of the walk per feature combination.
        walkies.extend(
            change_categories(current_walky, current_dict)
            for current_dict in all_category_combinations
        )
walk_idx = len(walk_index)  # total number of selected walks

In [None]:
# Model hyper-parameters.
# NOTE(review): these presumably have to match the architecture stored in the checkpoint — confirm.
config = GPTConfig(
    vocab_size=100 + 2,  # ids 0 (BOS) and 1 (EOS) are reserved, every other token is shifted by 2
    block_size=20 + 1,  # one additional position so that EOS can be predicted at the end
    n_layer=4,
    n_head=4,
    n_embd=16,
    feature_embd_dim=12,
    bias=False,
)
model = GPT(config)
# torch.load unpickles the file, so only point `model_path` at checkpoints from a trusted source.
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint["state_dict"])

In [None]:
# Compute the model loss of every (walk, feature combination) pair in `walkies`.
model.eval()
eval_dict = {}
# Use a dedicated name instead of rebinding `dataset`: the original ReviewsDataset stays
# available and `dataset.args` still resolves when this cell is executed a second time.
walk_dataset = AMZWalkDataset(
    walkies,
    walk_type="subtrails-test",
    args=dataset.args,
)
dataloader = DataLoader(
    walk_dataset,
    num_workers=0,
    batch_size=1,  # one walk per batch, so every recorded loss belongs to exactly one walk
    shuffle=False,  # evaluation only: keeps the stored index aligned with the position in `walkies`
)
losses = []
# Pure inference: without autograd no computation graph is built for the forward passes.
with torch.no_grad():
    for i, batch in enumerate(dataloader):
        # batch[0] holds the token sequence, batch[1] the features of the walk.
        # Next-token setup: the targets are the inputs shifted by one position.
        targets = batch[0][:, 1:].contiguous()
        inputs = batch[0][:, :-1].contiguous()
        features = batch[1]
        outputs = model(idx=inputs, targets=targets, features=features, return_last_hidden_state=True)
        # NOTE(review): the key "idx:" (with the trailing colon) looks like a typo; it is kept
        # unchanged here in case something reads it.
        losses.append({"idx:": i, "loss": outputs['loss'].item(), "features": features, "walk": batch[0]})

In [None]:
def get_idx_by_walk_or_feature(idx_dict: dict, walk_or_feature):
    """Reverse lookup: return the first key of ``idx_dict`` whose value equals ``walk_or_feature``.

    Entries are checked in the dict's iteration order with ``walk_or_feature == value``.

    Args:
        idx_dict: Mapping from index to a walk or a feature combination.
        walk_or_feature: The value to search for.

    Returns:
        The key of the first matching entry, or ``None`` if no value matches.
    """
    matching_keys = (key for key, value in idx_dict.items() if walk_or_feature == value)
    return next(matching_keys, None)

In [None]:
# Loss matrix: one row per feature combination, one column per walk.
amazing_matrix = np.zeros((len(feature_index), len(walk_index)))
for loss_entry in losses:
    idx = get_idx_by_walk_or_feature(feature_index, loss_entry['features'])
    # The batch dimension has size 1, so element 0 is the token list of the walk.
    idy = get_idx_by_walk_or_feature(walk_index, loss_entry['walk'].tolist()[0])
    # A failed lookup yields None, and NumPy interprets None as `np.newaxis`: the assignment
    # below would then overwrite a whole row/column (or the whole matrix) without any error.
    if idx is None or idy is None:
        raise LookupError(
            f"Could not locate evaluated sample in the index (feature row: {idx}, walk column: {idy})"
        )
    amazing_matrix[idx, idy] = loss_entry['loss']

In [None]:
import seaborn as sns

# Heatmap of the model loss for every feature combination (rows) and walk (columns).
fig, ax = plt.subplots(figsize=(8, 8))
sns.heatmap(amazing_matrix, cmap="Blues", ax=ax)
ax.set_xlabel("Walk")
ax.set_ylabel("Feature combination")
fig.savefig("code-2023-deephyptrails/data/potential-paper-figures/subtrails_heatmap.pdf", dpi=300)

In [None]:
# 2D UMAP projection of the feature combinations (rows of the loss matrix).
# NOTE(review): the labels assume 16 rows that come in four consecutive groups of four,
# in the row order of `amazing_matrix` — confirm against `feature_index`.
cluster_labels = np.repeat(
    np.array(["Cat1/Even", "Cat2/Odd", "Cat3/First Even", "Cat4/First Odd"]),
    4,
)
# A fixed random_state makes the stochastic embedding (and thus the saved figure) reproducible.
mapper = umap.UMAP(n_neighbors=4, random_state=42)
mapper.fit(amazing_matrix)  # only the fitted mapper is needed; the embedding is read by the plot
umap.plot.points(mapper, labels=cluster_labels)
plt.title("feature combinations")
plt.savefig("code-2023-deephyptrails/data/potential-paper-figures/subtrails_synthetic_feature.pdf", dpi=300)

In [None]:
# 2D UMAP projection of the walks (columns of the loss matrix, hence the transpose).
# NOTE(review): the labels assume 40 walks in four consecutive groups of ten, i.e. that each
# of the four categories contributed exactly ten walks — confirm against `walk_index`.
cluster_labels = np.repeat(
    np.array(["Even", "Odd", "First Even", "First Odd"]),
    10,
)
# A fixed random_state makes the stochastic embedding (and thus the saved figure) reproducible.
mapper = umap.UMAP(random_state=42)
mapper.fit(amazing_matrix.T)  # only the fitted mapper is needed; the embedding is read by the plot
umap.plot.points(mapper, labels=cluster_labels)
plt.title("Walks")
plt.savefig("code-2023-deephyptrails/data/potential-paper-figures/subtrails_synthetic_walks.pdf", dpi=300)