In [None]:
import torch
from tqdm import tqdm
from argparse import Namespace


config = Namespace(
    data_folder='./wm_bench_data', 
    max_seq_len=20, 
    rs_img_size=32, 
    batch_size=1, 
    num_workers=4, 
    model_path='./model.pt', 
    use_cnn=1, 
)

# Preferred CUDA device index. Hard-coding 'cuda:3' fails with "invalid device ordinal"
# on machines with fewer than four GPUs, so fall back to cuda:0 there, and to CPU
# when CUDA is unavailable.
GPU_INDEX = 3

if torch.cuda.is_available():
    _gpu_index = GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f'cuda:{_gpu_index}')
else:
    device = torch.device('cpu')

In [None]:
from src.utils.data_utils import get_test_multitask_dataloader

# NOTE(review): presumably a dict of per-task DataLoaders. The extraction cells below
# zip its .values() positionally, so its order must match their task lists.
test_loader = get_test_multitask_dataloader(config)

In [None]:
from src.model import WM_Model

# map_location remaps tensors saved on any GPU onto `device`. Without it, a checkpoint
# written on CUDA cannot be loaded on a CPU-only (or differently numbered GPU) machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources.
model_data = torch.load(config.model_path, map_location=device)

# The checkpoint carries both the constructor config (as a dict) and the weights.
model = WM_Model(Namespace(**model_data['config']), device).to(device)
model.load_state_dict(model_data['model_state_dict'])
model.eval()

In [None]:
# Registry of loaded models by name; the extraction cells pick one by key.
model_dict = dict(model_1=model)

In [None]:
import numpy as np

# Task names in the positional order of test_loader.values(): the i-th element of
# every zipped multi-task batch is fed to the model under the i-th name.
_TASK_NAMES = (
    'SC_Task',
    'SFR_Task',
    'SI_Task',
    'SMU_Task',
    'STSC_Task',
    'VIRec_2C_Task',
    'VSR_Task',
    'VSRec_Task',
    'CD_Color_Task',
    'CD_Orientation_Task',
    'CD_Size_Task',
    'CD_Gap_Task',
    'CD_Conj_Task',
    'Complex_WM_Task',
)

rnn_out_dict = {task: [] for task in _TASK_NAMES}

_task_loaders = list(test_loader.values())
if len(_task_loaders) < len(_TASK_NAMES):
    raise ValueError(
        f'Expected at least {len(_TASK_NAMES)} task loaders, got {len(_task_loaders)}.'
    )
dataloader = zip(*_task_loaders)

model = model_dict['model_1']

with torch.no_grad():
    for multi_task_batch in tqdm(dataloader):
        for task, task_batch in zip(_TASK_NAMES, multi_task_batch):
            # Element 0 is the stimulus tensor and element 2 the sequence lengths.
            # The response tensor is not needed for hidden-state extraction, so it is
            # no longer copied to the device.
            stim_batch = task_batch[0].to(device)
            seq_len = task_batch[2]

            _, rnn_out, _, _, _ = model(stim_batch, task, seq_len)

            # batch_size is 1 (see config), so sample 0 is the whole batch; keep its
            # first seq_len[0] timesteps.
            rnn_out_dict[task].append(rnn_out[0, :seq_len[0], :].cpu().numpy())

In [None]:
import numpy as np

# Re-extract hidden states, dropping a task-specific number of trailing timesteps per
# trial (presumably the response period -- confirm against the task definitions).
# Each entry is (task name, fn(task_batch) -> number of trailing steps to drop), listed
# in the positional order of test_loader.values().
_TASK_TRAILING_STEPS = (
    ('SC_Task', lambda task_batch: 1),
    ('SFR_Task', lambda task_batch: 1),
    ('SI_Task', lambda task_batch: 1),
    ('SMU_Task', lambda task_batch: task_batch[3][0]),    # set_size of the trial
    ('STSC_Task', lambda task_batch: 0),
    ('VIRec_2C_Task', lambda task_batch: 1),
    ('VSR_Task', lambda task_batch: task_batch[3][0]),    # list_length of the trial
    ('VSRec_Task', lambda task_batch: task_batch[3][0]),  # list_length of the trial
    ('CD_Color_Task', lambda task_batch: 1),
    ('CD_Orientation_Task', lambda task_batch: 1),
    ('CD_Size_Task', lambda task_batch: 1),
    ('CD_Gap_Task', lambda task_batch: 1),
    ('CD_Conj_Task', lambda task_batch: 1),
    ('Complex_WM_Task', lambda task_batch: 2),
)

rnn_out_dict = {task: [] for task, _ in _TASK_TRAILING_STEPS}

_task_loaders = list(test_loader.values())
if len(_task_loaders) < len(_TASK_TRAILING_STEPS):
    raise ValueError(
        f'Expected at least {len(_TASK_TRAILING_STEPS)} task loaders, got {len(_task_loaders)}.'
    )
dataloader = zip(*_task_loaders)

# model_dict only defines 'model_1'; any other key raises KeyError.
model = model_dict['model_1']

with torch.no_grad():
    for multi_task_batch in tqdm(dataloader):
        for (task, trailing_steps), task_batch in zip(_TASK_TRAILING_STEPS, multi_task_batch):
            # Element 0 is the stimulus tensor and element 2 the sequence lengths; the
            # response tensor is unused here, so it is not copied to the device.
            stim_batch = task_batch[0].to(device)
            seq_len = task_batch[2]

            _, rnn_out, _, _, _ = model(stim_batch, task, seq_len)

            # batch_size is 1 (see config), so sample 0 is the whole batch.
            end = int(seq_len[0]) - int(trailing_steps(task_batch))
            rnn_out_dict[task].append(rnn_out[0, :end, :].cpu().numpy())

In [None]:
# Regroup hidden states by timestep: bucket[ts] collects the state vector at ts from
# every trial. The five change-detection variants share one 'CD_Task' bucket.
_CD_VARIANTS = {'CD_Color_Task', 'CD_Orientation_Task', 'CD_Size_Task',
                'CD_Gap_Task', 'CD_Conj_Task'}

rnn_out_ts_dict = {
    name: {}
    for name in ('SC_Task', 'SFR_Task', 'SI_Task', 'SMU_Task', 'STSC_Task',
                 'VIRec_2C_Task', 'VSR_Task', 'VSRec_Task', 'CD_Task', 'Complex_WM_Task')
}

for trial in range(len(rnn_out_dict['SC_Task'])):
    for task in list(rnn_out_dict.keys()):
        bucket = rnn_out_ts_dict['CD_Task' if task in _CD_VARIANTS else task]
        trial_states = rnn_out_dict[task][trial]
        for ts in range(len(trial_states)):
            bucket.setdefault(ts, []).append(trial_states[ts])

In [None]:
import numpy as np

# For each task and timestep, the per-unit variance of the collected state vectors
# (across trials, and across the pooled variants for CD_Task).
rnn_out_ts_dict_var = {
    name: []
    for name in ('SC_Task', 'SFR_Task', 'SI_Task', 'SMU_Task', 'STSC_Task',
                 'VIRec_2C_Task', 'VSR_Task', 'VSRec_Task', 'CD_Task', 'Complex_WM_Task')
}

for task, per_timestep in rnn_out_ts_dict.items():
    rnn_out_ts_dict_var[task].extend(
        np.var(per_timestep[ts], axis=0) for ts in range(len(per_timestep))
    )

In [None]:
# Stack each task's per-timestep variance vectors into one array (rows = timesteps).
(sc_task_var, sfr_task_var, si_task_var, smu_task_var, stsc_task_var,
 virec_2c_task_var, vsr_task_var, vsrec_task_var, cd_task_var,
 complex_wm_task_var) = (
    np.array(rnn_out_ts_dict_var[name])
    for name in ('SC_Task', 'SFR_Task', 'SI_Task', 'SMU_Task', 'STSC_Task',
                 'VIRec_2C_Task', 'VSR_Task', 'VSRec_Task', 'CD_Task', 'Complex_WM_Task')
)

In [None]:
def _mean_over_timesteps(task):
    """Average a task's per-timestep unit variances over timesteps.

    Recomputed from ``rnn_out_ts_dict_var`` rather than from the ``*_task_var``
    arrays, so re-running this cell is idempotent. The previous ``x = np.mean(x, axis=0)``
    form collapsed the result to a scalar on a second run.
    """
    return np.mean(np.array(rnn_out_ts_dict_var[task]), axis=0)


sc_task_var = _mean_over_timesteps('SC_Task')
sfr_task_var = _mean_over_timesteps('SFR_Task')
si_task_var = _mean_over_timesteps('SI_Task')
smu_task_var = _mean_over_timesteps('SMU_Task')
stsc_task_var = _mean_over_timesteps('STSC_Task')
virec_2c_task_var = _mean_over_timesteps('VIRec_2C_Task')
vsr_task_var = _mean_over_timesteps('VSR_Task')
vsrec_task_var = _mean_over_timesteps('VSRec_Task')
cd_task_var = _mean_over_timesteps('CD_Task')
complex_wm_task_var = _mean_over_timesteps('Complex_WM_Task')

In [None]:
# One row per task, in the same order as the heatmap's y-axis labels.
_per_task_vars = [
    sc_task_var, sfr_task_var, si_task_var, smu_task_var, stsc_task_var,
    virec_2c_task_var, vsr_task_var, vsrec_task_var, cd_task_var, complex_wm_task_var,
]
task_var_array = np.array(_per_task_vars)

In [None]:
# Scale every column (unit) by its largest variance across tasks so each unit peaks at 1.
# A unit whose variance is 0 on every task would give 0/0 = NaN, which KMeans rejects;
# using a denominator of 1 for such units keeps their column all-zero instead.
_unit_max_var = np.max(task_var_array, axis=0)
task_var_array_normalized = task_var_array / np.where(_unit_max_var > 0, _unit_max_var, 1.0)

In [None]:
# Put units on rows and tasks on columns: KMeans clusters the rows.
# NOTE(review): not idempotent -- running this cell twice transposes back.
task_var_array_normalized = np.transpose(task_var_array_normalized)

In [None]:
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

# Inclusive upper bound of the cluster-count sweep (matches the former range(2, 100)).
MAX_CLUSTERS = 99

# silhouette_score requires 2 <= n_labels <= n_samples - 1, and KMeans requires
# n_clusters <= n_samples. Cap the sweep at n_units - 1 so small hidden sizes don't crash.
n_units = task_var_array_normalized.shape[0]
k_max = min(MAX_CLUSTERS, n_units - 1)
if k_max < 2:
    raise ValueError(f'Need at least 3 units to sweep cluster counts, got {n_units}.')

num_clust = []
silhouette_scores = []

for num_clusters in range(2, k_max + 1):
    kmeans = KMeans(n_clusters=num_clusters, random_state=45).fit(task_var_array_normalized)

    num_clust.append(num_clusters)
    silhouette_scores.append(silhouette_score(task_var_array_normalized, kmeans.labels_))

In [None]:
import matplotlib.pyplot as plt

# Silhouette score as a function of k; the peak picks the cluster count used below.
fig, ax = plt.subplots()
ax.plot(num_clust, silhouette_scores)
ax.set_xlabel('Number of Clusters')
ax.set_ylabel('Silhouette Score')
plt.show()

In [None]:
# Index of the first maximal silhouette score (ties resolve to the smallest k).
_best_idx = max(range(len(silhouette_scores)), key=silhouette_scores.__getitem__)
max_silhoette_score = silhouette_scores[_best_idx]
apt_num_clusters = num_clust[_best_idx]

In [None]:
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

# Refit with the selected cluster count and display its silhouette score.
kmeans = KMeans(n_clusters=apt_num_clusters, random_state=45)
kmeans.fit(task_var_array_normalized)

silhouette_score(task_var_array_normalized, kmeans.labels_)

In [None]:
# Bucket each unit's row of normalized task variances under its KMeans cluster label.
clustered_task_var_array_normalized = [[] for _ in range(apt_num_clusters)]

for unit_row, cluster_label in zip(task_var_array_normalized, kmeans.labels_):
    clustered_task_var_array_normalized[cluster_label].append(unit_row)

In [None]:
# Stack the cluster buckets so units of the same cluster are adjacent, then transpose
# back to tasks-on-rows for the heatmap.
concatenated_clusters = np.concatenate(clustered_task_var_array_normalized, axis=0).T
concatenated_clusters.shape

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

plt.rcParams["font.family"] = "Serif"
sns.set_context("paper", font_scale=1.5, rc={"lines.linewidth": 1})

# Short y-axis labels, in the row order of task_var_array.
task_labels = ['SC', 'SFR', 'SI', 'SMU', 'STSC', 'VIR', 'VSR', 'VSRec', 'CD', 'CS']

fig, ax = plt.subplots(figsize=(15, 5))
heatmap = ax.imshow(concatenated_clusters, cmap='hot', aspect='auto')

ax.set_ylabel('Task', fontsize=28)

ax.set_xticks([])
ax.set_yticks(np.arange(0, 10, 1))
ax.set_yticklabels(task_labels, fontsize=20)

cb = fig.colorbar(heatmap, ax=ax)
cb.ax.set_yticks([0, 1])
cb.ax.set_ylabel('Normalized Task Variance', fontsize=18)
cb.ax.tick_params(labelsize=20)

plt.show()