In [16]:
import torch
import numpy as np
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
import clip
from models import Transformer
from CLIPGaze import CLIPGaze
from feature_extractor import visual_forward

In [17]:
def run_model(model, src, task, device = "cuda:0", im_h=20, im_w=32, project_num = 16, num_samples = 1):
    """Run ``model`` once and decode its outputs into ``num_samples`` scanpaths.

    Every scanpath starts at the centre of the ``im_h`` x ``im_w`` grid scaled
    by ``project_num``. Later fixations are taken from the model outputs until
    the first step whose most likely token is not 0.

    Args:
        model: Callable invoked as ``model(src=..., tgt=..., task=...)`` and
            returning ``(token_prob, ys, xs, ts)`` tensors. ``ys``/``xs``/``ts``
            are indexed as ``[step, sample, 0]`` and ``token_prob`` as
            ``[step, sample, :]``.
        src: Forwarded to ``model`` untouched.
        task: NumPy array; cast to float32, moved to ``device`` and tiled
            ``num_samples`` times along a new leading axis (presumably a 1-D
            embedding -- confirm against the caller).
        device: Device the task tensor is moved to.
        im_h, im_w: Grid height and width before scaling.
        project_num: Factor that scales grid cells to output coordinates.
        num_samples: Number of scanpaths to decode.

    Returns:
        List of ``num_samples`` NumPy arrays with one ``[y, x, t]`` row per
        fixation. ``y`` and ``x`` are capped at ``im_h * project_num - 2`` and
        ``im_w * project_num - 2`` respectively.
    """
    task_batch = torch.tensor(task.astype(np.float32)).to(device).unsqueeze(0).repeat(num_samples, 1)

    # Centre of the scaled grid: the fixed first fixation of every scanpath.
    center_y = (im_h // 2) * project_num
    center_x = (im_w // 2) * project_num
    start_fix = torch.tensor([center_y, center_x]).unsqueeze(0).repeat(num_samples, 1)

    with torch.no_grad():
        token_prob, ys, xs, ts = model(src=src, tgt=start_fix, task=task_batch)
    token_prob, ys, xs, ts = (out.detach().cpu().numpy() for out in (token_prob, ys, xs, ts))

    # Upper bounds applied to the decoded coordinates.
    limit_y = im_h * project_num - 2
    limit_x = im_w * project_num - 2

    decoded = []
    for sample in range(num_samples):
        # Step 0 is always the centre fixation with token 0; the model's own
        # step-0 coordinates and token are discarded, its step-0 duration kept.
        rows = [center_y, *list(ys[:, sample, 0])[1:]]
        cols = [center_x, *list(xs[:, sample, 0])[1:]]
        durations = list(ts[:, sample, 0])
        kinds = [0, *list(np.argmax(token_prob[:, sample, :], axis=-1))[1:]]

        fixations = []
        for kind, y, x, t in zip(kinds, rows, cols, durations):
            if kind != 0:
                # Any non-zero token ends the scanpath.
                break
            fixations.append([min(limit_y, y), min(limit_x, x), t])
        decoded.append(np.array(fixations))
    return decoded

In [18]:
def postprocessScanpaths(trajs):
    """Convert trajectory tuples into scanpath dictionaries.

    Args:
        trajs: Iterable of 5-tuples
            ``(task_name, img_name, condition, subject, fixs)`` where ``fixs``
            supports 2-D indexing; column 0 is read as Y, column 1 as X and
            column 2 as T.

    Returns:
        List with one dict per trajectory, holding the keys ``'X'``, ``'Y'``,
        ``'T'``, ``'subject'``, ``'name'``, ``'task'`` and ``'condition'``.
    """
    return [
        {
            'X': fixs[:, 1],
            'Y': fixs[:, 0],
            'T': fixs[:, 2],
            'subject': subject,
            'name': img_name,
            'task': task_name,
            'condition': condition,
        }
        for task_name, img_name, condition, subject, fixs in trajs
    ]

In [19]:
def load_clipgaze_model(checkpoint_path, device, im_hw=(20, 32), max_len=7, hidden_dim=1024, nhead=8, num_decoder=6):
    """Instantiate CLIPGaze and restore pretrained weights.

    Args:
        checkpoint_path: Checkpoint file readable by ``torch.load``. Either a
            bare state dict or a dict holding the state dict under ``"model"``.
        device: Device the model is built on and the checkpoint is mapped to.
        im_hw: ``(height, width)`` passed to the transformer as ``im_h`` /
            ``im_w`` and to CLIPGaze as ``spatial_dim``.
        max_len: Forwarded to CLIPGaze as ``max_len``.
        hidden_dim: Used for both ``d_model`` and ``dim_feedforward``.
        nhead: Number of attention heads.
        num_decoder: Number of decoder layers.

    Returns:
        The CLIPGaze model in eval mode with the checkpoint weights applied.

    Warns:
        UserWarning: if the checkpoint and the model disagree on parameter
            names. Loading stays non-strict, so mismatched entries are skipped
            rather than raising.
    """
    import warnings

    transformer = Transformer(
        nhead=nhead,
        d_model=hidden_dim,
        num_decoder_layers=num_decoder,
        dim_feedforward=hidden_dim,
        device=device,
        im_h=im_hw[0],
        im_w=im_hw[1],
    ).to(device)
    model = CLIPGaze(transformer, spatial_dim=im_hw, max_len=max_len, device=device).to(device)

    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code; only point this at checkpoints from a trusted source.
    state = torch.load(checkpoint_path, map_location=device)
    if isinstance(state, dict) and "model" in state:
        state = state["model"]

    # strict=False tolerates key mismatches, which would otherwise leave layers
    # at their random initialisation without any hint. Surface them instead.
    load_result = model.load_state_dict(state, strict=False)
    missing = list(getattr(load_result, "missing_keys", None) or [])
    unexpected = list(getattr(load_result, "unexpected_keys", None) or [])
    if missing or unexpected:
        warnings.warn(
            f"Checkpoint {checkpoint_path} does not fully match the model: "
            f"{len(missing)} missing key(s) {missing[:5]}, "
            f"{len(unexpected)} unexpected key(s) {unexpected[:5]}",
            stacklevel=2,
        )

    model.eval()
    return model


def extract_visual_features(image_path, clip_visual, device, token_shape, extract_layers=(6, 12, 18)):
    """Run the CLIP visual encoder on one image and return its activations.

    The image is loaded as RGB, converted to a tensor, normalised, resized to
    280 x 448 and given a leading batch axis before being passed to
    ``visual_forward``.

    Args:
        image_path: Path of the image file.
        clip_visual: Visual encoder; its ``conv1.weight`` dtype decides the
            dtype of the input tensor.
        device: Device for the input tensor and the returned activations.
        token_shape: Forwarded to ``visual_forward``.
        extract_layers: Forwarded to ``visual_forward``.

    Returns:
        A one-element list wrapping the list of activations, each with its
        first two axes swapped and cast to float32 on ``device``.
    """
    to_model_input = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.Resize((280, 448)),
    ])
    rgb_image = Image.open(image_path).convert("RGB")

    # Feed the encoder in the same precision as its own weights.
    encoder_dtype = clip_visual.conv1.weight.dtype
    batch = to_model_input(rgb_image).unsqueeze(0)
    batch = batch.to(device=device, dtype=encoder_dtype)

    with torch.no_grad():
        _, layer_outputs, _ = visual_forward(
            clip_visual,
            batch,
            extract_layers=extract_layers,
            token_shape=token_shape,
        )

    formatted = []
    for layer_output in layer_outputs:
        swapped = layer_output.permute(1, 0, 2)
        formatted.append(swapped.to(device=device, dtype=torch.float32))
    return [formatted]


def encode_target_prompt(prompt, clip_model, device):
    """Encode a text prompt with CLIP and return it as a float32 NumPy array.

    Args:
        prompt: Text handed to ``clip.tokenize``.
        clip_model: Model exposing ``encode_text``.
        device: Device the token tensor is moved to before encoding.

    Returns:
        The text features with the leading axis squeezed out, on the CPU, as
        ``np.float32``.
    """
    tokens = clip.tokenize(prompt).to(device)
    with torch.no_grad():
        encoded = clip_model.encode_text(tokens)
        embedding = encoded.squeeze(0)
    as_numpy = embedding.detach().cpu().numpy()
    return as_numpy.astype(np.float32)


def plot_scanpaths(image_path, scanpaths, im_hw=(20, 32), project_num=16, duration_radius=(5, 18)):
    """Display each scanpath in its own panel, drawn on top of the image.

    Fixation coordinates are rescaled from the ``im_hw * project_num`` grid to
    the pixel size of the image. Consecutive fixations are joined by a line
    and each fixation is drawn as a numbered marker whose size grows with its
    duration.

    Args:
        image_path: Path of the image to draw on (opened with PIL as RGB).
        scanpaths: Sequence of 2-D arrays with one row per fixation; column 0
            is read as y, column 1 as x and column 2 as duration.
        im_hw: ``(height, width)`` of the coordinate grid before scaling by
            ``project_num``.
        project_num: Scale factor applied to ``im_hw``.
        duration_radius: ``(min, max)`` marker radius that the shortest and
            longest duration within a scanpath are mapped to.

    Returns:
        None. The figure is shown with ``plt.show()``; nothing is written to
        disk.
    """
    image = Image.open(image_path).convert("RGB")
    img_w, img_h = image.size
    scale_x = img_w / (im_hw[1] * project_num)
    scale_y = img_h / (im_hw[0] * project_num)
    colors = ["#ff6b6b", "#4ecdc4", "#ffa62b", "#5d5fef", "#2ec4b6"]

    num_samples = len(scanpaths)
    fig, axes = plt.subplots(1, num_samples, figsize=(5 * num_samples, 5))
    if num_samples == 1:
        # A single panel comes back as a bare Axes; wrap it so zip() works.
        axes = [axes]

    for idx, (ax, scanpath) in enumerate(zip(axes, scanpaths)):
        ax.imshow(image)
        if len(scanpath) > 0:
            xs = scanpath[:, 1] * scale_x
            ys = scanpath[:, 0] * scale_y
            durations = scanpath[:, 2]
            color = colors[idx % len(colors)]
            ax.plot(xs, ys, "-", color=color, linewidth=2)

            dur_min, dur_max = durations.min(), durations.max()
            if dur_max == dur_min:
                # All durations equal: interpolation is degenerate, so use the
                # smallest radius for every marker.
                radii = np.full_like(durations, duration_radius[0], dtype=float)
            else:
                radii = np.interp(durations, (dur_min, dur_max), duration_radius)

            # Use scatter so marker radius grows with fixation duration
            ax.scatter(
                xs,
                ys,
                s=np.square(radii),
                color=color,
                edgecolors="white",
                linewidths=0.8,
                zorder=3,
            )
            for step, (x, y) in enumerate(zip(xs, ys), start=1):
                ax.text(x + 2, y + 2, str(step), color="white", fontsize=8, weight="bold")
        ax.set_title(f"Sample {idx + 1} ({len(scanpath)} fix)")
        ax.axis("off")

    plt.tight_layout()
    plt.show()

def save_scanpath_result(image_path, scanpaths, output_path, im_hw=(20, 32), project_num=16, duration_radius=(5, 18)):
    """Render each scanpath over the image and save the figure to disk.

    Draws the same kind of panels as ``plot_scanpaths`` (line through the
    fixations, numbered markers sized by duration) but writes the figure to
    ``output_path`` instead of showing it.

    Args:
        image_path: Path of the image to draw on (opened with PIL as RGB).
        scanpaths: Sequence of 2-D arrays with one row per fixation; column 0
            is read as y, column 1 as x and column 2 as duration.
        output_path: Destination file for the rendered figure.
        im_hw: ``(height, width)`` of the coordinate grid before scaling by
            ``project_num``.
        project_num: Scale factor applied to ``im_hw``.
        duration_radius: ``(min, max)`` marker radius that the shortest and
            longest duration within a scanpath are mapped to.

    Returns:
        None. The figure is saved at 150 dpi and then closed.
    """
    image = Image.open(image_path).convert("RGB")
    img_w, img_h = image.size
    scale_x = img_w / (im_hw[1] * project_num)
    scale_y = img_h / (im_hw[0] * project_num)
    colors = ["#ff6b6b", "#4ecdc4", "#ffa62b", "#5d5fef", "#2ec4b6"]

    num_samples = len(scanpaths)
    fig, axes = plt.subplots(1, num_samples, figsize=(5 * num_samples, 5))
    try:
        if num_samples == 1:
            # A single panel comes back as a bare Axes; wrap it so zip() works.
            axes = [axes]

        for idx, (ax, scanpath) in enumerate(zip(axes, scanpaths)):
            ax.imshow(image)
            if len(scanpath) > 0:
                xs = scanpath[:, 1] * scale_x
                ys = scanpath[:, 0] * scale_y
                durations = scanpath[:, 2]
                color = colors[idx % len(colors)]
                ax.plot(xs, ys, "-", color=color, linewidth=2)

                dur_min, dur_max = durations.min(), durations.max()
                if dur_max == dur_min:
                    # All durations equal: use the smallest radius everywhere.
                    radii = np.full_like(durations, duration_radius[0], dtype=float)
                else:
                    radii = np.interp(durations, (dur_min, dur_max), duration_radius)

                ax.scatter(xs, ys, s=np.square(radii), color=color, edgecolors="white", linewidths=0.8, zorder=3)

                for step, (x, y) in enumerate(zip(xs, ys), start=1):
                    ax.text(x + 2, y + 2, str(step), color="white", fontsize=8, weight="bold")

            ax.set_title(f"Sample {idx + 1}")
            ax.axis("off")

        fig.tight_layout()
        # Write the rendered figure to disk.
        fig.savefig(output_path, dpi=150)
    finally:
        # Always release the figure, even when drawing or saving fails:
        # callers run this in a loop that swallows exceptions, so an unclosed
        # figure per failed image would pile up in memory.
        plt.close(fig)

In [23]:
# --- Batch inference: generate and save scanpaths for every image in a folder ---
# NOTE(review): these imports sit mid-notebook; `glob` is not used in this cell
# (file discovery below goes through Path.glob).
import os
import glob

# --- CONFIGURATION ---
# Folder containing the input images
input_folder = "./images_renders"
# Folder where the rendered results are saved
output_folder = "./results/renders"
# Image filename patterns to look for
extensions = ["*.jpg", "*.jpeg", "*.png"]

# Model parameters
checkpoint_path = Path("CLIPGaze_TP.pkg")
im_hw = (20, 32)
project_num = 16
num_samples = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_version = "ViT-L/14@336px"
# token_shape handed to extract_visual_features, keyed by CLIP variant
token_shape_lookup = {"ViT-B/32": (7, 7), "ViT-B/16": (14, 14), "ViT-L/14@336px": (24, 24)}

# Create the output folder if it does not exist yet
os.makedirs(output_folder, exist_ok=True)

# 1. Load the models once, outside the loop, to save time
print("Chargement du modèle...")
token_shape = token_shape_lookup[clip_version]
clip_model, _ = clip.load(clip_version, device=device, jit=False)
clip_model.eval()
model = load_clipgaze_model(checkpoint_path, device, im_hw=im_hw)
print("Modèle chargé.")

# 2. Collect the image files (non-recursive), sorted for a stable order
image_files = []
for ext in extensions:
    image_files.extend(list(Path(input_folder).glob(ext)))
image_files = sorted(image_files)

print(f"{len(image_files)} images trouvées à traiter.")

# 3. Processing loop: one image at a time; a failure on one image is reported
#    and the loop moves on to the next.
for i, img_path in enumerate(image_files):
    try:
        filename = img_path.name
        print(f"[{i+1}/{len(image_files)}] Traitement de : {filename} ...", end=" ")

        # Derive the target word from the filename: everything before the
        # first underscore of the stem (e.g. 'oven_01.jpg' -> 'oven');
        # hyphens become spaces in the prompt.
        target_word = img_path.stem.split("_")[0]
        target_prompt = target_word.replace("-", " ")

        # Encode the text prompt
        task_embedding = encode_target_prompt(target_prompt, clip_model, device)

        # Extract the image features
        image_features = extract_visual_features(str(img_path), clip_model.visual, device, token_shape)

        # Run the model (scanpath generation)
        scanpaths = run_model(
            model=model,
            src=image_features,
            task=task_embedding,
            device=device,
            im_h=im_hw[0],
            im_w=im_hw[1],
            project_num=project_num,
            num_samples=num_samples,
        )

        # Save the visual result as result_<stem>.png in the output folder
        save_path = Path(output_folder) / f"result_{img_path.stem}.png"
        save_scanpath_result(
            str(img_path),
            scanpaths,
            str(save_path),
            im_hw=im_hw,
            project_num=project_num
        )

        print(f"Terminé -> {target_word}")

    except Exception as e:
        # Best-effort batch: only the exception message is printed (no traceback).
        print(f"\nERREUR sur {img_path}: {e}")
        continue

print("Traitement terminé ! Voir le dossier:", output_folder)

Chargement du modèle...
Modèle chargé.
8 images trouvées à traiter.
[1/8] Traitement de : appel.png ... Terminé -> appel
[2/8] Traitement de : apple.png ... Terminé -> apple
[3/8] Traitement de : camera.png ... Terminé -> camera
[4/8] Traitement de : chair.png ... Terminé -> chair
[5/8] Traitement de : duck.png ... Terminé -> duck
[6/8] Traitement de : green apple.png ... Terminé -> green apple
[7/8] Traitement de : red chair.png ... Terminé -> red chair
[8/8] Traitement de : table.png ... Terminé -> table
Traitement terminé ! Voir le dossier: ./results/renders


In [21]:
# target_prompt = target_word.replace("-", " ")
# task_embedding = encode_target_prompt(target_prompt, clip_model, device)
#
# image_features = extract_visual_features(image_path, clip_model.visual, device, token_shape)
#
# scanpaths = run_model(
#     model=model,
#     src=image_features,
#     task=task_embedding,
#     device=device,
#     im_h=im_hw[0],
#     im_w=im_hw[1],
#     project_num=project_num,
#     num_samples=num_samples,
#  )
#
# print(f"Target prompt: '{target_prompt}' | Num samples: {num_samples}")
# plot_scanpaths(image_path, scanpaths, im_hw=im_hw, project_num=project_num)