In [1]:
import numpy as np
from scipy.io import loadmat 
import os
import glob
from tqdm.notebook import tqdm

In [2]:
def _center_and_norm(embedding, center):
    """Return ``(embedding - center, Euclidean norm of that over the last axis)``."""
    centered = embedding - center
    return centered, np.linalg.norm(centered, axis=-1)


def _keep_leading(centered, dist, n_keep):
    """Keep the first ``n_keep`` entries of the last distance axis (and the
    matching second-to-last axis of the centered coordinates)."""
    return centered[..., :n_keep, :], dist[..., :n_keep]


def compute_mds_dist(matfile_path, training_methods=("dino", "supervised"),
                     depth_list=(4, 8, 12), num_models=6, with_mean=False,
                     num_heads=6):
    """Distances of MDS embeddings from the median subject position.

    Loads the ``mds_res_all`` struct from a MATLAB file. The per-coordinate
    median of the subject embeddings (``mds_res``) is the reference point.
    Every embedding read (subjects, ``gbvs``, ``f"{tm}_{depth}"`` and
    ``f"{tm}_deit_small16"``) is expressed relative to it, together with its
    Euclidean distance to it over the last axis.

    Parameters
    ----------
    matfile_path : str
        Path to a ``.mat`` file containing a struct named ``mds_res_all``.
    training_methods : sequence of str
        Field-name prefixes of the model embeddings to read.
    depth_list : sequence of int
        Depths of the ``f"{tm}_{depth}"`` fields to read. ``depth`` is also used
        as the row length of ``mds_dist_flatten``. This assumes axis 1 of each
        such distance array has length ``depth`` (TODO confirm).
    num_models : int
        Unused; accepted only for backward compatibility. The number of model
        instances is taken from the first axis of each loaded array.
    with_mean : bool
        If False, keep only the first ``num_heads`` entries along the last
        distance axis. This presumably drops a trailing "mean" entry (TODO
        confirm). If True, keep every entry.
    num_heads : int
        Number of leading entries kept when ``with_mean`` is False. It was
        previously hard-coded to 6, which is still the default.

    Returns
    -------
    dict
        Keys ``"mds_centered"``, ``"mds_dist"``, ``"mds_dist_min"`` and
        ``"mds_dist_flatten"``, each a dict with one sub-dict per training
        method, keyed by ``str(depth)``.

        - ``mds_centered`` and ``mds_dist`` also hold ``"subj"``, ``"gbvs"``
          and ``f"{tm}_deit_small16"``.
        - ``mds_dist_min`` also holds ``f"{tm}_deit_small16"``, a scalar.
    """
    matfile = loadmat(matfile_path, simplify_cells=True)['mds_res_all']

    result_keys = ("mds_centered", "mds_dist", "mds_dist_min", "mds_dist_flatten")
    res_dict = {key: {tm: {} for tm in training_methods} for key in result_keys}

    # Reference point: per-coordinate median over the subjects (axis 0).
    mds_res = matfile["mds_res"]
    mds_res_center = np.median(mds_res, 0)

    for name, embedding in (("subj", mds_res), ("gbvs", matfile["gbvs"])):
        centered, dist = _center_and_norm(embedding, mds_res_center)
        res_dict["mds_centered"][name] = centered
        res_dict["mds_dist"][name] = dist

    for tm in training_methods:
        # Models trained at each depth.
        for depth in depth_list:
            centered, dist = _center_and_norm(matfile[f"{tm}_{depth}"], mds_res_center)
            if not with_mean:
                centered, dist = _keep_leading(centered, dist, num_heads)
            res_dict["mds_centered"][tm][str(depth)] = centered
            res_dict["mds_dist"][tm][str(depth)] = dist

            # Smallest distance per model instance (first axis).
            n_instances = len(dist)
            res_dict["mds_dist_min"][tm][str(depth)] = dist.reshape(n_instances, -1).min(axis=1)
            # Move the last axis before axis 1, then lay out rows of length ``depth``.
            res_dict["mds_dist_flatten"][tm][str(depth)] = dist.transpose(0, 2, 1).reshape(-1, depth)

        # Official DeiT-small/16 model: a single instance, one rank lower.
        deit_key = f"{tm}_deit_small16"
        centered, dist = _center_and_norm(matfile[deit_key], mds_res_center)
        if not with_mean:
            centered, dist = _keep_leading(centered, dist, num_heads)
        res_dict["mds_centered"][deit_key] = centered
        res_dict["mds_dist"][deit_key] = dist
        res_dict["mds_dist_min"][deit_key] = np.min(dist)
    return res_dict

In [3]:
# Every MDS result file to process, in a deterministic (sorted) order.
matfile_path_list = sorted(glob.glob("../results/mds_results_*.mat"))

In [4]:
# Compute MDS distances for every result file and save them beside the input
# as ``mds_dist_<suffix>.npz``, where ``<suffix>`` is the input's file stem
# without the ``mds_results_`` prefix. ``os.path`` is used instead of
# splitting on "/" so Windows separators work. The stem is taken with
# ``splitext`` so dots inside the name do not cause two inputs to share one
# output file. The values of ``res_dict`` are nested dicts, which NumPy stores
# as pickled object arrays: reload them with ``np.load(..., allow_pickle=True)``.
MATFILE_PREFIX = "mds_results_"
for path in tqdm(matfile_path_list):
    res_dict = compute_mds_dist(path)
    stem = os.path.splitext(os.path.basename(path))[0]
    suffix = stem[len(MATFILE_PREFIX):] if stem.startswith(MATFILE_PREFIX) else stem
    save_name = os.path.join(os.path.dirname(path), "mds_dist_" + suffix)
    np.savez_compressed(save_name, **res_dict)

  0%|          | 0/40 [00:00<?, ?it/s]