In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

from tqdm.notebook import tqdm
import glob
import numpy as np
from utils_analysis import *
from PIL import Image
#import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader

In [2]:
# ---- Input frames -----------------------------------------------------------
# Directory that holds the stimulus frames as PNG files.
dataset_dir = "../dataset/Nakano_etal_2010/video_stimuli/frames"
# Alternative input directory (disabled):
#dataset_dir = "../dataset/Nakano_etal_2010/video_stimuli/frames_cropped"

# When True, the analysis cell processes the frames per crop range using
# border_info.npz; when False it runs on the full frames.
# The published results were produced with borders kept (False).
remove_border = False
num_frames = 2327  # length of the frame axis of the result arrays

# ---- Model sweep ------------------------------------------------------------
training_methods = ["dino", "supervised"]
depth_list = [4, 8, 12]
num_models = 6  # models loaded per (training method, depth) pair
num_heads = 6   # result arrays reserve num_heads + 1 slots on the head axis

# ---- Inference settings -----------------------------------------------------
patch_size = 16
blur_size = 2 * patch_size
batch_size = 32

# Preprocessing applied to every frame before it is fed to a model.
# NOTE(review): the mean/std values match the commonly used ImageNet
# statistics -- confirm they agree with how the models were trained.
transform = pth_transforms.Compose(
    [
        pth_transforms.ToTensor(),
        pth_transforms.Normalize(
            (0.485, 0.456, 0.406),
            (0.229, 0.224, 0.225),
        ),
    ]
)

In [3]:
# Output directory for the .npz result files written by the analysis cell.
# The value already ends with a path separator.
save_dir = "../dataset/Nakano_etal_2010/preprocessed_data/"

In [4]:
# Compute model gaze positions for every frame, for each training method,
# depth and model index, and store everything in one compressed .npz file.
#
# Each stored array has the axes listed in res_dict["info"]:
#   (num_models, depth, num_heads + 1, num_frames, 2)
# NOTE(review): per the "num_head+mean" label the extra head slot presumably
# holds the across-head mean -- confirm in utils_analysis.
#
# model_load, get_gaze_pos_model, get_gaze_pos_model_dataset and ImageDataset
# are provided by the `utils_analysis` star import.

# Create the output directory up front so a missing folder cannot make the
# final save fail after all models have already been evaluated.
os.makedirs(save_dir, exist_ok=True)

res_dict = {}
res_dict["info"] = ["num_models", "depth", "num_head+mean", "num_frames", "xy"]
if remove_border:
    # NOTE(review): allow_pickle=True unpickles objects stored in the file;
    # only load border_info.npz from a trusted source.
    # The context manager closes the underlying archive once the arrays
    # have been read into memory.
    with np.load("../dataset/Nakano_etal_2010/video_stimuli/border_info.npz", allow_pickle=True) as border_info:
        crop_range = border_info['crop_range']
        datah = border_info['datah']
        dataw = border_info['dataw']
    num_crops = len(crop_range)
    for tm in training_methods:
        res_dict[tm] = {}
        for depth in depth_list:
            print(tm, depth)
            # Frames not covered by any crop range stay NaN.
            gaze_pos_model_all = np.nan * np.ones((num_models, depth, num_heads+1, num_frames, 2))
            for trial_idx in tqdm(range(num_models)):
                # Model indices passed to model_load are 1-based.
                model, device = model_load(tm, trial_idx+1, depth, patch_size)
                for crop_idx in tqdm(range(num_crops)):
                    # load images
                    s, e = crop_range[crop_idx] # start & end (1-based, inclusive)
                    frame_tensors = []
                    for frame_idx in range(s, e+1):
                        # Close each file as soon as it has been converted;
                        # Image.open alone keeps the file handle open.
                        with Image.open(dataset_dir + "/frame{0:04d}.png".format(frame_idx)) as img:
                            frame_tensors.append(transform(img))
                    images = torch.stack(frame_tensors)
                    images_split = torch.split(images, batch_size)

                    gaze_pos_model = []
                    for images_split_part in images_split:
                        gaze_pos_model_part = get_gaze_pos_model(model, device, images_split_part, patch_size, blur_size)
                        gaze_pos_model.append(gaze_pos_model_part)
                    # Batches are joined along the frame axis (axis 2).
                    gaze_pos_model = np.concatenate(gaze_pos_model, axis=2)

                    # modification of offset: add the first dataw/datah entry
                    # of this crop range to the x / y coordinates.
                    gaze_pos_model[:, :, :, 0] += dataw[crop_idx][0]
                    gaze_pos_model[:, :, :, 1] += datah[crop_idx][0]
                    # Frame numbers are 1-based, array indices 0-based.
                    gaze_pos_model_all[trial_idx, :, :, (s-1):e] = gaze_pos_model
                # Release the model before the next one is loaded.
                del model
                torch.cuda.empty_cache()
            res_dict[tm][str(depth)] = gaze_pos_model_all
    # The per-method entries are dicts, so NumPy stores them as pickled object
    # arrays; readers need np.load(..., allow_pickle=True).
    np.savez_compressed(os.path.join(save_dir, "vit_gaze_pos_removed_border.npz"), **res_dict)
else:
    image_path_list = sorted(glob.glob(f"{dataset_dir}/*.png"))
    # Fail fast: the result array is allocated for exactly num_frames frames,
    # so a wrong or empty frame directory would otherwise only surface inside
    # the model loop.
    if len(image_path_list) != num_frames:
        raise ValueError(
            f"Expected {num_frames} PNG frames in {dataset_dir!r}, "
            f"found {len(image_path_list)}."
        )
    dataset = ImageDataset(image_path_list, transform)
    # shuffle=False keeps the frames in sorted filename order.
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
    for tm in training_methods:
        res_dict[tm] = {}
        for depth in depth_list:
            print(tm, depth)
            gaze_pos_model_all = np.zeros((num_models, depth, num_heads+1, num_frames, 2))
            for trial_idx in tqdm(range(num_models)):
                # Model indices passed to model_load are 1-based.
                model, device = model_load(tm, trial_idx+1, depth, patch_size)
                gaze_pos_model_all[trial_idx] = get_gaze_pos_model_dataset(model, device, dataloader, patch_size, blur_size)
                # Release the model before the next one is loaded.
                del model
                torch.cuda.empty_cache()
            res_dict[tm][str(depth)] = gaze_pos_model_all
    # The per-method entries are dicts, so NumPy stores them as pickled object
    # arrays; readers need np.load(..., allow_pickle=True).
    np.savez_compressed(os.path.join(save_dir, "vit_gaze_pos.npz"), **res_dict)

dino 4


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

dino 8


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

dino 12


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

supervised 4


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

supervised 8


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

supervised 12


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]