In [None]:
import torch
from tqdm import tqdm
from argparse import Namespace


# Evaluation settings; the same Namespace is handed to the test dataloader factory.
config = Namespace(
    data_folder='./wm_bench_data',
    max_seq_len=20,
    rs_img_size=32,
    batch_size=10,
    num_workers=4,
    use_cnn=1,
    model_path='./model.pt'
)

# Preferred GPU for this run. torch.device('cuda:3') is accepted at construction
# but fails with "invalid device ordinal" on the first `.to(device)` when the
# machine has fewer GPUs, so fall back to cuda:0 in that case, and to CPU when
# CUDA is unavailable.
GPU_INDEX = 3
if torch.cuda.is_available():
    gpu_index = GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')

In [None]:
from src.utils.data_utils import get_test_multitask_dataloader

# Build the multitask test dataloaders from `config`. A later cell calls
# `.values()` on the result, so it is a mapping — presumably task name ->
# DataLoader (TODO confirm against src.utils.data_utils).
test_loader = get_test_multitask_dataloader(config)

In [None]:
from src.model import WM_Model

# Restore the trained model. map_location puts every saved tensor on `device`;
# without it, tensors saved from a GPU are restored to that same GPU, which fails
# on CPU-only machines or when that GPU index does not exist here.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model_data = torch.load(config.model_path, map_location=device)

# The checkpoint carries the training config (a mapping under 'config') and the
# weights (under 'model_state_dict').
model = WM_Model(Namespace(**model_data['config']), device).to(device)
model.load_state_dict(model_data['model_state_dict'])
model.eval()

In [None]:
# Models whose recurrent activity is collected below, keyed by a display name;
# these names become the keys of `rnn_out_all`.
model_dict = dict(model_1=model)

In [None]:
import numpy as np

# Run every model in `model_dict` over the STSC test split and collect, per
# batch, the ground-truth responses and each model's `rnn_out`.
# NOTE(review): the STSC loader is selected by position in `test_loader`, which
# relies on the mapping's insertion order — confirm index 4 is the 'STSC_Task'
# loader.
STSC_LOADER_INDEX = 4

rnn_out_all = {}     # model name -> list of per-batch rnn_out arrays
resp_batch_all = []  # list of per-batch ground-truth response arrays

# Iterate the STSC loader directly. Zipping all task loaders together would stop
# at the shortest one (silently dropping STSC batches) and would load data for
# tasks that are never used here.
stsc_loader = list(test_loader.values())[STSC_LOADER_INDEX]

with torch.no_grad():
    for stim_batch, resp_batch, seq_len in tqdm(stsc_loader, desc='STSC test'):
        # Responses are only used as numpy labels, so they never go to `device`.
        resp_batch_all.append(resp_batch.cpu().numpy())
        stim_batch = stim_batch.to(device)

        for model_name, net in model_dict.items():
            # Only the second model output (`rnn_out`) is used downstream.
            _, rnn_out, _, _, _ = net(stim_batch, 'STSC_Task', seq_len)
            rnn_out_all.setdefault(model_name, []).append(rnn_out.cpu().numpy())

In [None]:
# Merge the per-batch arrays: responses become one flat label vector, and each
# model's rnn outputs are joined along axis 0 (the per-batch leading axis).
resp_batch_all = np.concatenate(resp_batch_all, axis=0).reshape(-1)
rnn_out_all = {
    name: np.concatenate(batches, axis=0)
    for name, batches in rnn_out_all.items()
}

In [None]:
import json

import os

# Load the raw STSC trials from the same data folder the dataloader uses. The
# `with` block closes the file handle, which `json.load(open(...))` leaked.
stsc_data_path = os.path.join(config.data_folder, 'Spatial_Task_Switching_Cued',
                              'data_rand_trials.json')
with open(stsc_data_path, 'r') as fp:
    test_data = json.load(fp)

# Integer code for each task name that appears in a trial's 'task_gt' list.
tasks_gt_map = {'Up_Down': 0, 'Left_Right': 1, 'Cue_Up_Down': 2, 'Cue_Left_Right': 3}

# Flatten every trial's per-step task names into one integer label per step.
# NOTE(review): assumes the JSON 'test' trial order matches the order of the STSC
# test batches, and one 'task_gt' entry per time step — confirm.
tasks_gt = np.array([
    tasks_gt_map[task]
    for trial in test_data["test"]
    for task in trial['task_gt']
])

In [None]:
# Flatten the recurrent outputs to one row per time step so they line up with
# the flattened `resp_batch_all` labels.
# `rnn_out_all` is keyed by the names in `model_dict`, whose only entry is
# 'model_1'; the 'lstm_1024' key of `hidden_states` is the label later cells use.
# NOTE(review): assumes rnn_out is (batch, seq, hidden) — confirm.
rnn_out_model_1 = rnn_out_all['model_1']
hidden_states = {
    'lstm_1024': rnn_out_model_1.reshape(-1, rnn_out_model_1.shape[-1]),
}

In [None]:
# Time steps whose response label equals this value are dropped before PCA.
# NOTE(review): 2 looks like a "no response required" label — confirm against
# the dataset definition.
IGNORED_RESP_LABEL = 2

n_steps = len(resp_batch_all)
step_hidden = hidden_states['lstm_1024']
# Every label array is indexed per time step, so their lengths must agree.
assert len(step_hidden) == n_steps, 'hidden states and responses are misaligned'
assert len(tasks_gt) >= n_steps, 'fewer task labels than response labels'

# Vectorized selection of the kept steps (replaces a per-element Python loop).
keep = resp_batch_all != IGNORED_RESP_LABEL
tasksss_gt = tasks_gt[:n_steps][keep]
respsss_gt = resp_batch_all[keep]
hidden_statesss = step_hidden[keep]

In [None]:
from sklearn.decomposition import PCA

# Project the filtered hidden states onto their first two principal components;
# the fixed random_state keeps the projection reproducible across runs.
pca = PCA(n_components=2, random_state=69).fit(hidden_statesss)
hidden_statesss_pca = pca.transform(hidden_statesss)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap


sns.set_style("ticks")

# Change font style
plt.rcParams["font.family"] = "Serif"
sns.set_context("paper", font_scale=1.5, rc={"lines.linewidth": 1})

fig, ax = plt.subplots(figsize=(5, 5))

# PCA projection of the kept time steps, coloured by ground-truth response (0/1).
# NOTE(review): seaborn treats `size` as a size *semantic* (here a constant), not
# a marker area; if smaller markers were intended, `s=` is the marker-size kwarg.
scatterplot = sns.scatterplot(x=hidden_statesss_pca[:, 0],
                              y=hidden_statesss_pca[:, 1],
                              hue=respsss_gt, hue_order=[0, 1], edgecolor='none',
                              size=0.2, palette='tab10', ax=ax)

sns.despine(ax=ax, left=False, bottom=False, right=True, top=True)

# Define custom labels for the legend
legend_labels = ['Left / Top GT', 'Right / Bottom GT']

# matplotlib pairs handles and labels positionally and drops any extras, so take
# exactly as many handles as custom labels.
handles, _ = ax.get_legend_handles_labels()
ax.legend(handles=handles[:len(legend_labels)], labels=legend_labels,
          frameon=False, bbox_to_anchor=(0.09, 0.98), fontsize=20, markerscale=2)

ax.set_xlabel('PC1', fontsize=25)
ax.set_ylabel('PC2', fontsize=25)

ax.set_xticks([])
ax.set_yticks([])

plt.show()


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Plot the hidden states in scatter plot using seaborn

sns.set_style("ticks")

# Change font style
plt.rcParams["font.family"] = "Serif"
sns.set_context("paper", font_scale=1.5, rc={"lines.linewidth": 1})

fig, ax = plt.subplots(figsize=(5, 5))

# Same PCA projection, coloured by task code; hue_order restricts the mapping to
# codes 0 and 1 from `tasks_gt_map`.
# NOTE(review): seaborn treats `size` as a size *semantic* (here a constant), not
# a marker area; if smaller markers were intended, `s=` is the marker-size kwarg.
scatterplot = sns.scatterplot(x=hidden_statesss_pca[:, 0],
                              y=hidden_statesss_pca[:, 1],
                              hue=tasksss_gt, hue_order=[0, 1], edgecolor='none',
                              size=0.2, palette=['tab:green', 'tab:red'], ax=ax)

sns.despine(ax=ax, left=False, bottom=False, right=True, top=True)

# Define custom labels for the legend
legend_labels = ['Top vs Bottom Task', 'Left vs Right Task']

# matplotlib pairs handles and labels positionally and drops any extras, so take
# exactly as many handles as custom labels.
handles, _ = ax.get_legend_handles_labels()
ax.legend(handles=handles[:len(legend_labels)], labels=legend_labels,
          frameon=False, bbox_to_anchor=(0.09, 0.98), fontsize=20, markerscale=2)

ax.set_xlabel('PC1', fontsize=25)
ax.set_ylabel('PC2', fontsize=25)

ax.set_xticks([])
ax.set_yticks([])

plt.show()
