In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="3"

from tqdm.notebook import tqdm
import glob
import numpy as np
from utils_analysis import *
from PIL import Image
#import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader
import utils
import vision_transformer as vits
from vision_transformer import DINOHead

In [2]:
def model_dino_official_load(arch_name, patch_size,
                             models_dir="../trained_model_weights/dino_official/backbone/"):
    """Load an official DINO backbone checkpoint into a frozen, eval-mode ViT.

    Args:
        arch_name: Architecture tag used in the checkpoint file name; must be
            "deitsmall" or "vitbase".
        patch_size: Patch size, used both in the checkpoint file name and as
            the ``patch_size`` argument of the model constructor.
        models_dir: Directory containing
            ``dino_{arch_name}{patch_size}_pretrain.pth``.

    Returns:
        Tuple ``(model, device)``. ``model`` has ``requires_grad`` disabled on
        all parameters, is in eval mode and has been moved to ``device``
        (CUDA when available, otherwise CPU).

    Raises:
        ValueError: If ``arch_name`` is not one of the supported tags.
    """
    # Validate before touching the disk: previously an unknown tag left `arch`
    # unbound and failed with UnboundLocalError only after the checkpoint had
    # already been read.
    arch_by_name = {"deitsmall": "vit_small", "vitbase": "vit_base"}
    if arch_name not in arch_by_name:
        raise ValueError(
            f"Unsupported arch_name {arch_name!r}; expected one of {sorted(arch_by_name)}."
        )
    arch = arch_by_name[arch_name]

    weight_path = os.path.join(models_dir, f"dino_{arch_name}{patch_size}_pretrain.pth")
    checkpoint = torch.load(weight_path, map_location="cpu")

    model = vits.__dict__[arch](patch_size=patch_size)
    # strict=True: the checkpoint keys must match the model's keys exactly.
    model.load_state_dict(checkpoint, strict=True)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    for p in model.parameters():
        p.requires_grad = False
    model.eval()
    model.to(device)
    return model, device

In [3]:
def model_deit_official_load(arch="vit_small", patch_size=16,
                             weight_path="../trained_model_weights/deit_official/deit_small_patch16_224.pth"):
    """Load the official supervised DeiT checkpoint into a frozen, eval-mode ViT.

    Args:
        arch: Name of the constructor to look up in ``vision_transformer``.
        patch_size: Patch size passed to the model constructor.
        weight_path: Path to the checkpoint file. The file is expected to be a
            dict whose ``"model"`` entry holds the state dict. The default is
            the DeiT-small/16 checkpoint, so ``arch``/``patch_size`` should be
            changed together with this path.

    Returns:
        Tuple ``(model, device)``. ``model`` has ``requires_grad`` disabled on
        all parameters, is in eval mode and has been moved to ``device``
        (CUDA when available, otherwise CPU).
    """
    import warnings  # local import keeps this cell self-contained

    checkpoint = torch.load(weight_path, map_location="cpu")
    model = vits.__dict__[arch](patch_size=patch_size)

    # strict=False so that checkpoint entries without a counterpart in the
    # model (the classification head weights) are ignored.
    load_result = model.load_state_dict(checkpoint["model"], strict=False)
    # strict=False also hides the opposite problem: model parameters that the
    # checkpoint does not provide stay randomly initialised. Surface that.
    if load_result.missing_keys:
        warnings.warn(
            f"{len(load_result.missing_keys)} model parameter(s) were not found in "
            f"{weight_path!r} and keep their initial values: {load_result.missing_keys}"
        )

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    for p in model.parameters():
        p.requires_grad = False
    model.eval()
    model.to(device)
    return model, device

In [4]:
# Directory that holds the video-stimulus frames as PNG files; the functions
# below read it as a module-level global.
dataset_dir = "../dataset/Nakano_etal_2010/video_stimuli/frames"
#dataset_dir = "../dataset/Nakano_etal_2010/video_stimuli/frames_cropped"
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code -- only load this file from a trusted source.
border_info = np.load("../dataset/Nakano_etal_2010/video_stimuli/border_info.npz", allow_pickle=True)

In [5]:
# Per-crop metadata taken from border_info. Downstream code unpacks
# crop_range[i] into a (start, end) pair of frame numbers used in the file
# names, and adds dataw[i][0] / datah[i][0] to the x / y gaze coordinates.
crop_range = border_info['crop_range'] 
datah = border_info['datah']
dataw = border_info['dataw']

num_crops = len(crop_range)
# Length of the frame axis of the result array allocated per model.
# NOTE(review): hard-coded; presumably the total number of stimulus frames -- confirm.
num_frames = 2327

In [6]:
def get_gazepos_official_model_remove_border(model, device, transform, patch_size=16, batch_size=16):
    """Compute model gaze positions crop by crop and shift them by each crop's offset.

    Relies on the module-level ``dataset_dir``, ``crop_range``, ``dataw``,
    ``datah``, ``num_crops`` and ``num_frames``.

    Args:
        model: Model exposing ``depth`` and ``num_heads``; forwarded to
            ``get_gaze_pos_model``.
        device: Device forwarded to ``get_gaze_pos_model``.
        transform: Callable applied to every opened PIL image; its outputs are
            stacked with ``torch.stack``.
        patch_size: Forwarded to ``get_gaze_pos_model``; the blur size is twice
            this value.
        batch_size: Number of frames handed to ``get_gaze_pos_model`` per call.

    Returns:
        Array of shape ``(model.depth, model.num_heads + 1, num_frames, 2)``.
        Frames not covered by any crop range stay NaN.
    """
    blur_size = 2 * patch_size
    result_shape = (model.depth, model.num_heads + 1, num_frames, 2)
    gaze_pos_model_all = np.full(result_shape, np.nan)

    for crop_idx in tqdm(range(num_crops)):
        # Frame numbers in crop_range are used directly in the file names and
        # are inclusive on both ends.
        first_frame, last_frame = crop_range[crop_idx]
        frames = [
            Image.open(dataset_dir + "/frame{0:04d}.png".format(frame_idx))
            for frame_idx in range(first_frame, last_frame + 1)
        ]
        images = torch.stack(list(map(transform, frames)))

        # Run the model in chunks of batch_size and join the chunks back along
        # the frame axis (axis 2).
        per_batch = [
            get_gaze_pos_model(model, device, chunk, patch_size, blur_size)
            for chunk in torch.split(images, batch_size)
        ]
        crop_gaze = np.concatenate(per_batch, axis=2)

        # Shift x / y by this crop's offset.
        crop_gaze[:, :, :, 0] += dataw[crop_idx][0]
        crop_gaze[:, :, :, 1] += datah[crop_idx][0]

        # Frame number k is stored at index k - 1.
        gaze_pos_model_all[:, :, (first_frame - 1):last_frame] = crop_gaze

    # Only drops this function's local reference; the caller's reference is unaffected.
    del model
    torch.cuda.empty_cache()
    return gaze_pos_model_all

In [7]:
def get_gazepos_official_model(model, device, transform, patch_size=16, batch_size=32):
    """Compute model gaze positions for every PNG frame in ``dataset_dir``.

    Frames are taken in sorted file-name order and processed in that order
    (the loader does not shuffle).

    Args:
        model: Model forwarded to ``get_gaze_pos_model_dataset``.
        device: Device forwarded to ``get_gaze_pos_model_dataset``.
        transform: Per-image transform handed to ``ImageDataset``.
        patch_size: Forwarded to ``get_gaze_pos_model_dataset``; the blur size
            is twice this value.
        batch_size: Batch size of the ``DataLoader``.

    Returns:
        Whatever ``get_gaze_pos_model_dataset`` returns for the full frame set.

    Raises:
        FileNotFoundError: If ``dataset_dir`` contains no ``*.png`` files.
    """
    blur_size = patch_size * 2

    image_path_list = sorted(glob.glob(f"{dataset_dir}/*.png"))
    # Fail early with a readable message instead of handing an empty dataset
    # to the model loop.
    if not image_path_list:
        raise FileNotFoundError(f"No PNG frames found in {dataset_dir!r}.")
    # The original left a frame-count check against num_frames disabled; it is
    # still not enforced here.

    dataset = ImageDataset(image_path_list, transform)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
    return get_gaze_pos_model_dataset(model, device, dataloader, patch_size, blur_size)

In [8]:
#arch_names = ["vitbase"]#["deitsmall", "vitbase"]
#arch_name = "vitbase" #"deitsmall", "vitbase"
#patch_sizes = [8, 16]

# Preprocessing applied to every frame: convert to a tensor, then normalise
# each channel with a fixed mean / std.
# NOTE(review): these constants match the commonly used ImageNet statistics -- confirm
# they are what the checkpoints were trained with.
transform = pth_transforms.Compose([
    pth_transforms.ToTensor(),
    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

In [9]:
# Output directory for the .npz file written by the save cell below.
save_dir = "../dataset/Nakano_etal_2010/preprocessed_data/"

In [10]:
# Selects the extraction routine used by the cells below:
# True -> get_gazepos_official_model_remove_border, False -> get_gazepos_official_model.
remove_border = False

In [11]:
# Result container; "info" names the four axes of every stored array.
res_dict = {"info": ["depth", "num_head+mean", "num_frames", "xy"]}

# DINO-pretrained small ViT with patch size 16.
model, device = model_dino_official_load(arch_name="deitsmall", patch_size=16)
# Pick the extraction routine from the remove_border flag and run it.
res_dict[f"dino_deit_small16"] = (
    get_gazepos_official_model_remove_border
    if remove_border
    else get_gazepos_official_model
)(model, device, transform)

# Release the model before the next one is loaded.
del model
torch.cuda.empty_cache()

  0%|          | 0/73 [00:00<?, ?it/s]

In [None]:
# Supervised DeiT small/16, loaded with the loader's default arguments.
model, device = model_deit_official_load()
# Pick the extraction routine from the remove_border flag and run it.
res_dict[f"supervised_deit_small16"] = (
    get_gazepos_official_model_remove_border
    if remove_border
    else get_gazepos_official_model
)(model, device, transform)

# Release the model once its results are stored.
del model
torch.cuda.empty_cache()

  0%|          | 0/73 [00:00<?, ?it/s]

In [None]:
# Create the output directory if needed so the write below cannot fail on a
# missing folder after the long-running model cells have finished.
os.makedirs(save_dir, exist_ok=True)
# Each key of res_dict becomes one named array in the archive.
np.savez_compressed(f"{save_dir}/vit_official_gaze_pos.npz", **res_dict)

In [None]:
# Disabled legacy code: the whole cell is one string literal, so nothing inside
# it is executed. It refers to names (model_official_load, arch_name,
# patch_sizes) that are not defined in the visible cells, so it would not run
# as-is if re-enabled.
"""
batch_size = 2
res_dict = {}
res_dict["info"] = ["depth", "num_head+mean", "num_frames", "xy"]
#for arch_name in arch_names:
for patch_size in patch_sizes:
    blur_size = patch_size * 2
    print(arch_name, patch_size)
    model, device = model_official_load(arch_name, patch_size)
    num_heads = model.num_heads
    depth = model.depth
    gaze_pos_model_all = np.nan * np.ones((depth, num_heads+1, num_frames, 2))
    for crop_idx in tqdm(range(num_crops)):
        # load images
        s, e = crop_range[crop_idx] # start & end
        imgs = [Image.open(dataset_dir + "/frame{0:04d}.png".format(frame_idx)) for frame_idx in range(s, e+1)]
        images = torch.stack([transform(img) for img in imgs])
        images_split = torch.split(images, batch_size)

        gaze_pos_model = []
        for images_split_part in images_split:
            gaze_pos_model_part = get_gaze_pos_model(model, device, images_split_part, patch_size, blur_size)
            gaze_pos_model.append(gaze_pos_model_part)
        gaze_pos_model = np.concatenate(gaze_pos_model, axis=2)

        # modification of offset
        gaze_pos_model[:, :, :, 0] += dataw[crop_idx][0]
        gaze_pos_model[:, :, :, 1] += datah[crop_idx][0]
        gaze_pos_model_all[:, :, (s-1):e] = gaze_pos_model
    del model
    torch.cuda.empty_cache()
    res_dict[f"{arch_name}{patch_size}"] = gaze_pos_model_all

np.savez_compressed(f"{save_dir}/vit_official_{arch_name}_gaze_pos.npz", **res_dict)
res_dict.keys()
"""