In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat, savemat
import os
import glob
from scipy import interpolate
import pandas as pd
from tqdm.notebook import tqdm

In [2]:
# load subject gaze pos
# The archive is opened in a `with` block so its file handle is closed as soon
# as the array has been read into memory.
# NOTE(review): allow_pickle=True lets numpy unpickle object arrays, which can
# execute arbitrary code -- only load archives produced by this project.
with np.load("../preprocessed_data/goodsubj_gaze_pos.npz", allow_pickle=True) as subj_npz:
    subj_gaze_pos = subj_npz["gaze_data_goodsubj"]
# Number of subjects = length of the first axis of the loaded array.
num_subj = len(subj_gaze_pos)

In [3]:
"""
arch_names = ["deitsmall", "vitbase"]
patch_sizes = [8, 16]
dist_mat_dict = {}

# load vit gaze pos
vit_gaze_pos_data = {}
for arch_name in arch_names:
    vit_gaze_pos_data |= np.load(f"../preprocessed_data/vit_official_{arch_name}_gaze_pos.npz", allow_pickle=True)
"""

'\narch_names = ["deitsmall", "vitbase"]\npatch_sizes = [8, 16]\ndist_mat_dict = {}\n\n# load vit gaze pos\nvit_gaze_pos_data = {}\nfor arch_name in arch_names:\n    vit_gaze_pos_data |= np.load(f"../preprocessed_data/vit_official_{arch_name}_gaze_pos.npz", allow_pickle=True)\n'

In [4]:
# Open the ViT gaze-position archive; individual arrays are fetched later by key.
# NOTE(review): allow_pickle=True permits unpickling object arrays -- trusted files only.
vit_gaze_pos_data = np.load(
    "../preprocessed_data/vit_official_gaze_pos.npz",
    allow_pickle=True,
)

In [5]:
# List the array names stored in the archive.
list(vit_gaze_pos_data.files)

['info', 'dino_deit_small16', 'supervised_deit_small16']

In [6]:
# Frame indices to be masked out later, written as half-open [start, stop)
# ranges and flattened into one 1-D integer index array.
blanks = np.concatenate([
    np.arange(start, stop)
    for start, stop in (
        (140, 155),
        (311, 326),
        (538, 553),
        (740, 772),
        (911, 926),
        (1094, 1123),
        (1319, 1334),
        (1651, 1666),
        (1835, 1850),
        (1988, 2003),
        (2167, 2182),
        (2326, 2327),  # the single trailing index 2326
    )
])

In [7]:
# Total number of masked frame indices (blanks is 1-D, so size == length).
blanks.size

197

In [8]:
# Lengths of the source and target time axes used for resampling below.
# NOTE(review): presumably the model frame count and the gaze-recording
# sample count -- confirm against the preprocessing that produced the data.
num_frames, num_sampling = 2327, 3883

In [9]:
# Two evenly spaced grids over [0, 1]: the source axis has num_frames points,
# the target axis has num_sampling points; both include the end points.
time = np.linspace(0.0, 1.0, num=num_frames)
time_upsample = np.linspace(0.0, 1.0, num=num_sampling)

In [10]:
# Archive entries to process (the archive's 'info' entry is left out).
model_keys = [
    "dino_deit_small16",
    "supervised_deit_small16",
]

In [11]:
# Resample every model's gaze trace from the `time` grid onto the
# `time_upsample` grid with nearest-neighbour interpolation.
# The arrays are indexed as [:, :, frame, coord], so axis 2 is the one being
# resampled; the result has num_sampling entries along that axis.
vit_gaze_pos_data_upsample = {}
for key in model_keys:
    # Work on an independent float64 copy: NaN cannot be stored in an integer
    # array, and the loaded data should not be modified in place.
    gaze_pos = np.array(vit_gaze_pos_data[key], dtype=np.float64)
    # Mask the blank frames; nearest-neighbour lookup carries the NaNs over to
    # whichever upsampled points fall closest to a blank frame.
    gaze_pos[:, :, blanks, :] = np.nan
    # A single interpolator over axis=2 handles every leading-axis / coordinate
    # channel at once, replacing one interp1d object per channel.
    resample = interpolate.interp1d(time, gaze_pos, kind="nearest", axis=2)
    vit_gaze_pos_data_upsample[key] = resample(time_upsample)

In [12]:
# Write the upsampled arrays to a compressed archive, one entry per model key.
save_dir = "../preprocessed_data/"
# os.path.join copes with the trailing slash in save_dir; plain string
# formatting with an extra "/" produced a doubled separator in the path.
np.savez_compressed(
    os.path.join(save_dir, "vit_official_gaze_pos_upsample.npz"),
    **vit_gaze_pos_data_upsample,
)

In [13]:
"""
vit_gaze_pos_data_upsample = {}
for arch_name in arch_names:
    for patch_size in patch_sizes:
        vit_gaze_pos_data_tmp = vit_gaze_pos_data[f"{arch_name}{patch_size}"]
        depth, num_heads, _, _ = vit_gaze_pos_data_tmp.shape
        vit_gaze_pos_data_tmp[:, :, blanks, :] = np.nan
        gaze_pos_upsample = np.zeros((depth, num_heads, num_sampling, 2))
        for d in range(depth):
            for h in range(num_heads):
                for i in range(2): # x, y
                    gp = vit_gaze_pos_data_tmp[d, h, :, i]
                    f = interpolate.interp1d(time, gp, kind="nearest")
                    gp_upsample = f(time_upsample) 
                    gaze_pos_upsample[d, h, :, i] = gp_upsample
        vit_gaze_pos_data_upsample[f"{arch_name}{patch_size}"] = gaze_pos_upsample
"""

'\nvit_gaze_pos_data_upsample = {}\nfor arch_name in arch_names:\n    for patch_size in patch_sizes:\n        vit_gaze_pos_data_tmp = vit_gaze_pos_data[f"{arch_name}{patch_size}"]\n        depth, num_heads, _, _ = vit_gaze_pos_data_tmp.shape\n        vit_gaze_pos_data_tmp[:, :, blanks, :] = np.nan\n        gaze_pos_upsample = np.zeros((depth, num_heads, num_sampling, 2))\n        for d in range(depth):\n            for h in range(num_heads):\n                for i in range(2): # x, y\n                    gp = vit_gaze_pos_data_tmp[d, h, :, i]\n                    f = interpolate.interp1d(time, gp, kind="nearest")\n                    gp_upsample = f(time_upsample) \n                    gaze_pos_upsample[d, h, :, i] = gp_upsample\n        vit_gaze_pos_data_upsample[f"{arch_name}{patch_size}"] = gaze_pos_upsample\n'

In [14]:
# Sanity check: show which model keys have an upsampled array.
vit_gaze_pos_data_upsample.keys()

dict_keys(['dino_deit_small16', 'supervised_deit_small16'])

In [15]:
# For each model, build a (num_subj, depth, num_heads) table holding the median
# Euclidean distance between the model's upsampled gaze trace and each
# subject's gaze trace. NaN samples are ignored by the median.
dist_mat_dict = {}
for model_name in model_keys:
    model_gaze = vit_gaze_pos_data_upsample[model_name]
    n_layers, n_heads, _, _ = model_gaze.shape
    subj_dist = np.zeros((num_subj, n_layers, n_heads))
    for subj_idx in tqdm(range(num_subj)):
        # Distance over the last (coordinate) axis for every sample ...
        per_sample_dist = np.linalg.norm(
            model_gaze - subj_gaze_pos[subj_idx],
            axis=-1,
        )
        # ... then collapse the remaining last (sample) axis with a NaN-aware median.
        subj_dist[subj_idx] = np.nanmedian(per_sample_dist, axis=-1)
    dist_mat_dict[model_name] = subj_dist

  0%|          | 0/104 [00:00<?, ?it/s]

  0%|          | 0/104 [00:00<?, ?it/s]

In [16]:
"""
for arch_name in arch_names:
    for patch_size in patch_sizes:
        vit_gaze_pos_data_tmp = vit_gaze_pos_data_upsample[f"{arch_name}{patch_size}"]
        depth, num_heads, _, _ = vit_gaze_pos_data_tmp.shape
        dist_mat = np.zeros((num_subj, depth, num_heads))
        for i in tqdm(range(num_subj)):
            gaze_diff = vit_gaze_pos_data_tmp - subj_gaze_pos[i]
            diff_norm = np.linalg.norm(gaze_diff, axis=-1)
            dist_mat[i] = np.nanmedian(diff_norm, axis=-1)
        dist_mat_dict[f"{arch_name}{patch_size}"] = dist_mat
"""

'\nfor arch_name in arch_names:\n    for patch_size in patch_sizes:\n        vit_gaze_pos_data_tmp = vit_gaze_pos_data_upsample[f"{arch_name}{patch_size}"]\n        depth, num_heads, _, _ = vit_gaze_pos_data_tmp.shape\n        dist_mat = np.zeros((num_subj, depth, num_heads))\n        for i in tqdm(range(num_subj)):\n            gaze_diff = vit_gaze_pos_data_tmp - subj_gaze_pos[i]\n            diff_norm = np.linalg.norm(gaze_diff, axis=-1)\n            dist_mat[i] = np.nanmedian(diff_norm, axis=-1)\n        dist_mat_dict[f"{arch_name}{patch_size}"] = dist_mat\n'

In [17]:
# Store the per-model distance tables in one compressed archive,
# using the model keys as entry names.
np.savez_compressed(
    "../preprocessed_data/subj2vit_official_dist.npz",
    **dist_mat_dict,
)