In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import seaborn as sns

In [2]:
vit_gaze_pos = np.load("../dataset/animal_parts_dataset/preprocessed_data/vit_gaze_pos_val.npz", allow_pickle=True)

In [3]:
training_methods = ["dino"]#, "supervised"] 
depth_list = [8, 12]
num_models = 6
num_heads = 6

In [4]:
eye_pos_data = np.load("../dataset/animal_parts_dataset/preprocessed_data/shifted_eye_pos_val.npz", allow_pickle=True)
mouth_pos_data = np.load("../dataset/animal_parts_dataset/preprocessed_data/shifted_mouth_pos_val.npz", allow_pickle=True)

In [5]:
mouth_pos = mouth_pos_data["mouth_pos"]

In [6]:
img_indices = mouth_pos_data["img_indices"].astype(int)

In [7]:
eye_pos = eye_pos_data["eye_pos"]
img_id = eye_pos_data["img_id"]

In [8]:
sigma = 30
num_images = len(img_indices)

In [9]:
num_images

875

In [10]:
"""
def get_gaze_weight_animal(tm, depth):
    gaze_weight = np.zeros((num_models, depth, num_heads, num_images))
    for m_idx in tqdm(range(num_models)):
        for d_idx in tqdm(range(depth)):
            for h_idx in tqdm(range(num_heads)):
                vit_gaze_pos_head = vit_gaze_pos_depth[m_idx, d_idx, h_idx]
                for im_id in range(num_images):
                    ep = eye_pos[img_id == im_id]
                    num_kp = len(ep)
                    gp = vit_gaze_pos_head[im_id]
                    d = np.exp(- ((ep[:, 0] - gp[0])**2 + (ep[:, 1] - gp[1])**2) / (2*(sigma**2)))
                    gaze_weight[m_idx, d_idx, h_idx, im_id] = np.clip(np.sum(d), 0, 1)
    return gaze_weight
"""

def get_gaze_weight_animal(tm, depth):
    gaze_weight_eye = np.zeros((num_models, depth, num_heads, num_images))
    gaze_weight_mouth = np.zeros((num_models, depth, num_heads, num_images))
    for m_idx in tqdm(range(num_models)):
        for d_idx in tqdm(range(depth)):
            for h_idx in tqdm(range(num_heads)):
                vit_gaze_pos_head = vit_gaze_pos_depth[m_idx, d_idx, h_idx]
                for i, im_id in enumerate(img_indices):
                    gp = vit_gaze_pos_head[im_id]
                    # eye
                    ep = eye_pos[img_id == im_id]
                    ed = np.exp(- ((ep[:, 0] - gp[0])**2 + (ep[:, 1] - gp[1])**2) / (2*(sigma**2)))
                    sum_ed = np.sum(ed)

                    # mouth
                    mp = mouth_pos[i]
                    md = np.exp(- ((mp[0] - gp[0])**2 + (mp[1] - gp[1])**2) / (2*(sigma**2)))
    
                    sum_d = sum_ed + md
                    if sum_d > 1:
                        gaze_weight_eye[m_idx, d_idx, h_idx, i] = sum_ed / sum_d
                        gaze_weight_mouth[m_idx, d_idx, h_idx, i] = md / sum_d
                    elif sum_d <= 1 and sum_d > 0:
                        gaze_weight_eye[m_idx, d_idx, h_idx, i] = sum_ed
                        gaze_weight_mouth[m_idx, d_idx, h_idx, i] = md
    return gaze_weight_eye, gaze_weight_mouth

In [11]:
res_dict = {}
for tm in training_methods:
    res_dict[tm] = {}
    for depth in depth_list:
        res_dict[tm][str(depth)] = {}
        vit_gaze_pos_depth = vit_gaze_pos[tm].item()[str(depth)]
        gaze_weight_eye, gaze_weight_mouth = get_gaze_weight_animal(tm, depth)
        res_dict[tm][str(depth)]["eye"] = gaze_weight_eye
        res_dict[tm][str(depth)]["mouth"] = gaze_weight_mouth

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

In [12]:
np.savez_compressed(f"../results/gazew_animals_val.npz", **res_dict)