# Scanpath Predictive Model --- EyeFormer
### CODE USED FOR ASSIGNMENT A3a
---

<div class="alert alert-block alert-info">
    
# Table of contents
* <a href='#1.'> 1. Model overview </a>
* <a href='#2.'> 2. Preparations </a>
* <a href='#3.'> 3. Inference </a>
* <a href='#4.'> 4. Visualisation of results </a>
* <a href='#5.'> 5. Evaluation methods </a>


<a href='#T1'><b>Student Task 1.</b> Scanpath Prediction </a>

    
<a href='#T2'><b>Student Task 2.</b> Evaluation metrics </a>


## 1. Model overview <a id='1.'></a>


<img src="imgs/img1.png">

**EyeFormer** is a model that combines a Transformer with reinforcement learning to predict scanpaths in free-viewing tasks, i.e. to simulate the sequence of gaze points a person produces while viewing an image or video. The Transformer captures global information in the image, while reinforcement learning continuously optimises the prediction policy. Together they enable accurate scanpath and saliency prediction.


<img src="imgs/img2.png">



### 1. Generic scanpath prediction
Generic scanpath prediction means predicting the eye movement pattern that most observers exhibit when viewing a particular image or scene. It is critical for understanding which areas attract the most attention and for designing more attractive interfaces.

In EyeFormer, generic prediction is performed by the Transformer model, which captures global information from the image, while reinforcement learning continuously optimises its prediction strategy. The result is a prediction of the areas most people are likely to fixate and the order in which they fixate them.

### 2. Personalised scanpath prediction

Unlike generic prediction, personalised scanpath prediction targets the unique eye movement pattern of an individual, based on their visual behaviour. Each person may focus on different areas, in a different order, when looking at an image. EyeFormer can generate predictions that match a specific individual's behaviour by learning from a small amount of that individual's scanpath data.

<img src="imgs/img3.png">

Personalised predictions are important in many real-world applications. For example, in custom user interface design, we need to ensure that the interface layout adapts to the visual habits of a particular user to enhance their experience.

## 2. Preparations <a id='2.'></a>




### Step 1.  Environment preparation
First, make sure you have installed the necessary Python packages and environment. The following code will help you create and configure the environment needed to run the EyeFormer model.

In [1]:
!pip install -qq ruamel.yaml==0.17.21

In [2]:
import argparse
import os
import ruamel.yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
import csv

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torch.distributed as dist

In [3]:
%%capture

# PyTorch is published on pip as `torch`; `cudatoolkit` is a conda-only package
!pip install -q torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2
!pip install -q opencv-python==4.5.3.56 Pillow einops multimatch-gaze
!pip install -q transformers==4.8.1
!pip install -q timm==0.4.9

%cd EyeFormer-UIST2024

# Download the weights file (Generic scanpath prediction weights file)
!pip install gdown
!gdown https://drive.google.com/uc?id=1n2l7leXJqAM16TZnlpMiw1N-jeY7Bi4K -O ./weights/

### Step 2. Loading and configuring the model

**NOTE:**
If you have your own evaluation images, change the `eval_image_root` value in `configs/Tracking.yaml` to your image directory.

In [4]:
from models.model_tracking import TrackingTransformer
from models.vit import interpolate_pos_embed

import utils
from dataset import create_dataset, create_sampler, create_loader


The next cell creates a configuration class **ARGS** holding the runtime parameters, then loads the model configuration from `configs/Tracking.yaml` and saves a copy to the output directory. Make sure that `checkpoint` points to the pre-trained weights file.

In [5]:
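from ruamel.yaml import YAML


class ARGS:
    """Runtime options for single-process (non-distributed) CPU inference."""

    def __init__(self, config='./configs/Tracking.yaml'):
        self.config = config                             # path to the YAML config
        self.checkpoint = './weights/checkpoint_19.pth'  # pre-trained weights
        self.resume = False
        self.output_dir = 'output/tracking_eval'         # predicted_result.csv is written here
        self.text_encoder = 'bert-base-uncased'
        self.device = 'cpu'
        self.seed = 42
        self.world_size = 1
        self.dist_url = 'env://'
        self.distributed = False                         # distributed mode is never initialised in main()


args = ARGS(config='./configs/Tracking.yaml')

# Load the config in round-trip mode (keeps comments and key order)
yaml = YAML(typ='rt')
with open(args.config, 'r') as f:
    config = yaml.load(f)

# Save a copy of the config next to the outputs
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as f:
    yaml.dump(config, f)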

## 3. Model inference <a id='3.'></a>
Use the EyeFormer model for inference to generate predictions.

Define a **test** function that runs inference: it iterates over the test images, predicts a scanpath for each one, and saves all predicted fixations to a CSV file.

In [6]:
def test(model, data_loader, tokenizer, device, output_dir, config):
    # Inference only: put the model in evaluation mode
    model.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Testing:'

    image_names = []
    results = []
    widths = []
    heights = []
    # user_ids = []  # only needed for personalised prediction

    with torch.no_grad():  # gradients are not needed at test time
        for i, (image, image_name, width, height) in enumerate(metric_logger.log_every(data_loader, 1, header)):
            image = image.to(device, non_blocking=True)

            # One predicted scanpath per image: a list of points [x, y, t, ...],
            # with x and y normalised by the image width and height
            coord = model.inference(image)
            coord = coord.detach().cpu().numpy().tolist()

            image_names.extend(image_name)
            results.extend(coord)
            widths.extend(width.numpy().tolist())
            heights.extend(height.numpy().tolist())

    # Write one CSV row per fixation, converting x and y back to pixels
    with open(os.path.join(output_dir, 'predicted_result.csv'), 'w', newline='') as wfile:
        writer = csv.writer(wfile)
        writer.writerow(["image", "width", "height", "x", "y", "timestamp"])

        for image, width, height, coord in zip(image_names, widths, heights, results):
            for row in coord:
                x = row[0] * width
                y = row[1] * height
                t = row[2]
                writer.writerow([image, width, height, x, y, t])
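Each row that `test` writes rescales the model's normalised `x`/`y` outputs by the image's width and height and passes the third value through as `timestamp`. The helper below is a hypothetical, standalone sketch of that conversion (the name `to_pixel_rows` is ours, not part of EyeFormer), assuming each predicted point is a list `[x_norm, y_norm, t, ...]` with `x_norm` and `y_norm` in [0, 1].

```python
from typing import List, Sequence


def to_pixel_rows(image: str, width: int, height: int,
                  coords: Sequence[Sequence[float]]) -> List[list]:
    """Convert normalised fixations into CSV rows, mirroring the loop in test().

    Hypothetical helper: assumes each point is [x_norm, y_norm, t, ...]
    with x_norm and y_norm in [0, 1]; any extra values are ignored.
    """
    rows = []
    for point in coords:
        x = point[0] * width    # normalised x -> pixels
        y = point[1] * height   # normalised y -> pixels
        t = point[2]            # passed through unchanged
        rows.append([image, width, height, x, y, t])
    return rows


rows = to_pixel_rows("1.png", 800, 600, [[0.5, 0.25, 0.3], [0.125, 0.75, 0.2]])
print(rows)
# [['1.png', 800, 600, 400.0, 150.0, 0.3], ['1.png', 800, 600, 100.0, 450.0, 0.2]]
```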

The **main** function builds the dataset and the model, loads the checkpoint weights, and calls `test` to run inference. The predictions are saved to `predicted_result.csv` in `args.output_dir`.

In [7]:
def main(args, config):
    # utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating dataset")
    datasets = [create_dataset('inference', config)]

    # Distributed sampling is disabled here: single-process inference uses the default sampler
    samplers = [None]

    data_loader = create_loader(datasets,
                                samplers,
                                batch_size=[config['batch_size_test']],
                                num_workers=[16],
                                is_trains=[False],
                                collate_fns=[None])[0]

    # tokenizer = BertTokenizer.from_pretrained(args.text_encoder)
    tokenizer = None

    #### Model ####
    print("Creating model")
    model = TrackingTransformer(config=config, init_deit=False)

    model = model.to(device)

    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        state_dict = checkpoint['model']

        msg = model.load_state_dict(state_dict)
        print('load checkpoint from %s' % args.checkpoint)
        print(msg)

    model_without_ddp = model

    print("Start testing")
    start_time = time.time()

    test(model, data_loader, tokenizer, device, args.output_dir, config)

    # dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Testing time {}'.format(total_time_str))

In [8]:
## run main function to get prediction file.

main(args, config)

Creating dataset
Creating model
Model will generate 16 points
load checkpoint from ./weights/checkpoint_19.pth
<All keys matched successfully>
Start testing
Testing:  [0/7]  eta: 0:00:41    time: 5.9185  data: 0.8545
Testing:  [1/7]  eta: 0:00:19    time: 3.2791  data: 0.4298
Testing:  [2/7]  eta: 0:00:11    time: 2.3823  data: 0.2883
Testing:  [3/7]  eta: 0:00:07    time: 1.9310  data: 0.2162
Testing:  [4/7]  eta: 0:00:04    time: 1.6621  data: 0.1734
Testing:  [5/7]  eta: 0:00:02    time: 1.4813  data: 0.1448
Testing:  [6/7]  eta: 0:00:01    time: 1.3518  data: 0.1242
Testing: Total time: 0:00:09 (1.3828 s / it)
Testing time 0:00:09


## 4. Visualization of results <a id='4.'></a>

The following visualisation code allows you to display the **ground truth and prediction results** on the image separately or simultaneously. You can choose which results to display as desired.

First, run the following cell to import the required libraries and define the helper functions that read the scanpath CSV files.

In [9]:
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os, random, csv
from IPython.display import display


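def extract_scanpaths(csvfile):
    """Group the rows of a scanpath CSV into one dict per image.

    Assumes all rows belonging to one image are contiguous in the file.
    """
    res = []
    cur = None   # image name of the scanpath being built
    obj = None   # scanpath being built

    with open(csvfile, newline='') as f:
        reader = csv.DictReader(f, delimiter=',')
        for row in reader:
            uniq = row['image']

            if uniq != cur:
                # A new image starts: store the previous scanpath, if any
                if obj is not None:
                    res.append(obj)

                obj = {'image': row['image'], 'width': int(row['width']), 'height': int(row['height']),
                       'scanpath': [], 'duration': []}

            x, y, t = float(row['x']), float(row['y']), float(row['timestamp'])
            obj['scanpath'].append([x, y])
            obj['duration'].append(t)

            cur = uniq

    # Store the last scanpath (skipped if the file has no data rows)
    if obj is not None and len(obj['scanpath']) > 0:
        res.append(obj)

    return res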

def collect_info(res, dir, method):
    images = []
    xs_list = []
    ys_list = []
    ts_list = []
    methods = []

    for r in res:
        image_path = os.path.join(dir, r["image"])
        images.append(image_path)
        scan_path = r["scanpath"]
        xs = [p[0] for p in scan_path]
        ys = [p[1] for p in scan_path]
        xs_list.append(xs)
        ys_list.append(ys)
        ts_list.append(r['duration'])
        methods.append(method)

    return images, xs_list, ys_list, ts_list, methods
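`extract_scanpaths` relies on all rows of one image being contiguous in the CSV. As a hedged alternative sketch (not the notebook's code), the same grouping can be written with `itertools.groupby`, which makes that contiguity assumption explicit. `group_scanpaths` is a hypothetical name; it reads CSV text rather than a file path and uses the column names written by `test`.

```python
import csv
import io
from itertools import groupby


def group_scanpaths(csv_text):
    """Group contiguous rows by image, returning the same dict layout as extract_scanpaths().

    Hypothetical helper. Like extract_scanpaths(), it assumes the rows of one image
    are contiguous; non-contiguous rows for the same image yield separate entries.
    """
    reader = csv.DictReader(io.StringIO(csv_text))
    res = []
    for image, rows in groupby(reader, key=lambda r: r['image']):
        rows = list(rows)
        res.append({
            'image': image,
            'width': int(rows[0]['width']),
            'height': int(rows[0]['height']),
            'scanpath': [[float(r['x']), float(r['y'])] for r in rows],
            'duration': [float(r['timestamp']) for r in rows],
        })
    return res


sample = """image,width,height,x,y,timestamp
1.png,800,600,400.0,150.0,0.3
1.png,800,600,100.0,450.0,0.2
2.png,640,480,320.0,240.0,0.5
"""
for sp in group_scanpaths(sample):
    print(sp['image'], sp['scanpath'], sp['duration'])
# 1.png [[400.0, 150.0], [100.0, 450.0]] [0.3, 0.2]
# 2.png [[320.0, 240.0]] [0.5]
```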

The following cell defines the `visualize` function. Its `data_type` argument selects what is drawn: `'gt'` (ground truth), `'pred'` (prediction) or `'comb'` (both).

In [10]:
def draw_scanpath(ax, xs, ys, ts, cmap, first_edgecolor, width, height):
    """Draw one scanpath on ax: arrows between consecutive fixations, then a circle per fixation."""
    if len(xs) == 0:
        return
    # Sample 2n - 1 colours: even indices colour the circles, odd indices the arrows
    colors = (cmap(np.linspace(0, 1, 2 * len(xs) - 1)) * 255).astype(np.uint8)

    for i in range(1, len(xs)):
        ax.axes.arrow(
            xs[i - 1],
            ys[i - 1],
            xs[i] - xs[i - 1],
            ys[i] - ys[i - 1],
            width=min(width, height) / 300 * 3,
            head_width=0.05,
            head_length=0.01,
            color=colors[i * 2 - 1] / 255.,
            alpha=1,
        )
    for i in range(len(xs)):
        circle = plt.Circle(
            (xs[i], ys[i]),
            # Radius grows with the fixation's timestamp (duration) value
            radius=min(width, height) / 35 * ts[i] * 2 * 1.1 * 1.5,
            edgecolor=first_edgecolor if i == 0 else 'black',  # highlight the first fixation
            facecolor=colors[i * 2] / 255.,
            linewidth=1,
        )
        ax.axes.add_patch(circle)


def visualize(data_type):
    """Draw ground-truth ('gt'), predicted ('pred') or both ('comb') scanpaths and save one PNG per image."""
    if data_type not in ("gt", "pred", "comb"):
        raise ValueError("Invalid data_type. Use 'gt', 'pred' or 'comb'.")

    output_dir = os.path.join('/notebooks/compdesign2024/Visual_Saliency/scanpath_visualization_output', data_type)
    os.makedirs(output_dir, exist_ok=True)

    data_dir = "/notebooks/compdesign2024/Visual_Saliency/imgs/test"
    gt_file = "/notebooks/compdesign2024/Visual_Saliency/EyeFormer-UIST2024/evaluation/testing_ground_truth.csv"
    pred_file = "/notebooks/compdesign2024/Visual_Saliency/EyeFormer-UIST2024/output/tracking_eval/predicted_result.csv"

    gt_res = extract_scanpaths(gt_file)
    pred_res = extract_scanpaths(pred_file)

    g_images, g_xs_list, g_ys_list, g_ts_list, _ = collect_info(gt_res, data_dir, "gt")
    p_images, p_xs_list, p_ys_list, p_ts_list, _ = collect_info(pred_res, data_dir, "pred")

    # Map image path -> (xs, ys, ts); if an image appears more than once, the first scanpath is kept
    image_to_gt, image_to_pred = {}, {}
    for img, xs, ys, ts in zip(g_images, g_xs_list, g_ys_list, g_ts_list):
        image_to_gt.setdefault(img, (xs, ys, ts))
    for img, xs, ys, ts in zip(p_images, p_xs_list, p_ys_list, p_ts_list):
        image_to_pred.setdefault(img, (xs, ys, ts))
    images = list(set(g_images + p_images))

    # Colour maps: ground truth in 'winter_r', prediction in 'autumn_r'
    cm_gt = plt.get_cmap('winter_r')
    cm_pred = plt.get_cmap('autumn_r')

    for image in images:
        try:
            img = Image.open(image).convert("RGB")
        except FileNotFoundError:
            print(f"File not found: {image}")
            continue

        img.putalpha(int(255 * 0.8))  # make the background image slightly transparent
        img = np.array(img)
        height, width = img.shape[:2]

        plt.axis('off')
        ax = plt.imshow(img)

        # Ground truth: first fixation outlined in red
        if data_type in ("gt", "comb") and image in image_to_gt:
            draw_scanpath(ax, *image_to_gt[image], cm_gt, 'red', width, height)
        # Prediction: first fixation outlined in blue
        if data_type in ("pred", "comb") and image in image_to_pred:
            draw_scanpath(ax, *image_to_pred[image], cm_pred, 'blue', width, height)

        imagename = os.path.basename(image).split(".")[0]
        ax.figure.savefig(os.path.join(output_dir, '{}_{}.png'.format(imagename, data_type)),
                          dpi=120, bbox_inches="tight")
        plt.close(ax.figure)
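In `visualize`, a scanpath with n fixations samples 2n - 1 colours from its colormap: even indices colour the fixation circles and odd indices colour the arrows between them, so the colour changes smoothly along the path. Each circle's radius is `min(width, height) / 35 * t * 2 * 1.1 * 1.5`, i.e. roughly 0.094 · min(width, height) · t, where `t` is the fixation's `timestamp` value. The pure-Python sketch below (hypothetical helper names, no matplotlib needed) reproduces that bookkeeping so it can be checked in isolation.

```python
def fixation_radius(width, height, t):
    """Circle radius used in visualize(): scales with min(width, height) and the fixation value t."""
    return min(width, height) / 35 * t * 2 * 1.1 * 1.5


def colour_indices(n):
    """Indices into the 2n - 1 sampled colours, returned as (circle indices, arrow indices)."""
    circles = [2 * i for i in range(n)]          # fixation i -> colour 2i
    arrows = [2 * i - 1 for i in range(1, n)]    # arrow ending at fixation i -> colour 2i - 1
    return circles, arrows


print(round(fixation_radius(700, 1400, 0.5), 2))   # 33.0  (700 / 35 * 0.5 * 3.3)
print(colour_indices(3))                            # ([0, 2, 4], [1, 3])
```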




Next, we define `create_display`, which builds a test interface by pasting sticker images (such as buttons) onto a background image.

In [11]:
from PIL import Image

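def create_display(background_path, sticker_paths, coordinates, sizes):
    """Paste sticker images onto a background image.

    coordinates: sticker centres as (x %, y %) of the background size.
    sizes: length of each sticker's longer side in pixels (aspect ratio is kept).
    """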
    # Open the background image
    background = Image.open(background_path).convert("RGBA")
    bg_width, bg_height = background.size

    # Loop through each sticker image
    for i, sticker_path in enumerate(sticker_paths):
        sticker = Image.open(sticker_path).convert("RGBA")
        max_size = int(sizes[i].item() if isinstance(sizes[i], torch.Tensor) else sizes[i])

        aspect_ratio = sticker.width / sticker.height
        if sticker.width > sticker.height:
            new_width = max_size
            new_height = int(max_size / aspect_ratio)
        else:
            new_height = max_size
            new_width = int(max_size * aspect_ratio)

        sticker = sticker.resize((new_width, new_height), Image.LANCZOS)
        
        # Offset by half the sticker size so that the coordinates refer to the sticker's centre point
        x = int(coordinates[i][0] * bg_width / 100) - new_width // 2
        y = int(coordinates[i][1] * bg_height / 100) - new_height // 2
        
        background.paste(sticker, (x, y), sticker)

    return background
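The placement arithmetic in `create_display` can be checked without any image files: the sticker's longer side is scaled to `sizes[i]` pixels with its aspect ratio preserved, and `coordinates[i]` (in percent of the background) marks the sticker's centre. `sticker_box` below is a hypothetical pure-Python restatement of those lines that returns the paste offset and the resized sticker size.

```python
def sticker_box(bg_width, bg_height, sticker_width, sticker_height, center_pct, max_size):
    """Return (x, y, new_width, new_height) as computed in create_display().

    Hypothetical helper. center_pct: (x %, y %) of the background;
    max_size: target length of the sticker's longer side in pixels.
    """
    aspect_ratio = sticker_width / sticker_height
    if sticker_width > sticker_height:
        new_width, new_height = max_size, int(max_size / aspect_ratio)
    else:
        new_height, new_width = max_size, int(max_size * aspect_ratio)
    # Shift by half the sticker size so that center_pct marks the sticker's centre
    x = int(center_pct[0] * bg_width / 100) - new_width // 2
    y = int(center_pct[1] * bg_height / 100) - new_height // 2
    return x, y, new_width, new_height


# A 200x100 button, 100 px on its longer side, centred at (50 %, 80 %) of a 1000x800 background
print(sticker_box(1000, 800, 200, 100, (50, 80), 100))   # (450, 615, 100, 50)
```

Note that the top-left offset can become negative when a sticker is centred near an edge; Pillow's `paste` then simply clips the part that falls outside the background.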

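Optimization code: the following cell composes three layout variants of the same interface (changing button positions and sizes), saves each one to the assignment folder and to the test-image folder, then reruns inference and saves the predicted scanpath visualisations.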

In [13]:
output_path = "/notebooks/compdesign2024/Visual_Saliency/imgs/A3a/"
output_path_backup = "/notebooks/compdesign2024/Visual_Saliency/imgs/test/"
scanpath_csv_path = "/notebooks/compdesign2024/Visual_Saliency/EyeFormer-UIST2024/output/tracking_eval/predicted_result.csv"

background_path = "/notebooks/compdesign2024/Visual_Saliency/imgs/background.png"
sticker_paths = ["/notebooks/compdesign2024/Visual_Saliency/imgs/button1.png", "/notebooks/compdesign2024/Visual_Saliency/imgs/button2.png", "/notebooks/compdesign2024/Visual_Saliency/imgs/button3.png"]
# Three layout variants: (centre coordinates in % of the background, sticker sizes in pixels)
variants = [
    ([(52, 50), (70, 80), (30, 80)], [100, 100, 100]),
    ([(52, 50), (70, 70), (30, 80)], [100, 100, 100]),
    ([(52, 50), (70, 70), (30, 80)], [100, 120, 80]),
]

for idx, (coordinates, sizes) in enumerate(variants, start=1):
    image = create_display(background_path, sticker_paths, coordinates, sizes)
    image.save(output_path + f"{idx}.png")          # copy in imgs/A3a
    image.save(output_path_backup + f"{idx}.png")   # copy in imgs/test, the folder visualize() reads from

# Re-run inference and save the predicted scanpath visualisations
main(args, config)
visualize("pred")
visualize("pred")


Creating dataset
Creating model
Model will generate 16 points
load checkpoint from ./weights/checkpoint_19.pth
<All keys matched successfully>
Start testing
Testing:  [0/7]  eta: 0:00:42    time: 6.1150  data: 4.8567
Testing:  [1/7]  eta: 0:00:20    time: 3.4554  data: 2.4321
Testing:  [2/7]  eta: 0:00:12    time: 2.5229  data: 1.6214
Testing:  [3/7]  eta: 0:00:08    time: 2.0598  data: 1.2168
Testing:  [4/7]  eta: 0:00:05    time: 1.7798  data: 0.9735
Testing:  [5/7]  eta: 0:00:03    time: 1.5929  data: 0.8113
Testing:  [6/7]  eta: 0:00:01    time: 1.4604  data: 0.6957
Testing: Total time: 0:00:10 (1.5004 s / it)
Testing time 0:00:10
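The three variants above are listed by hand. If more layouts need to be compared, one option (a sketch, not part of the assignment code) is to enumerate a grid of candidate centres and sizes with `itertools.product`. `layout_grid` and the candidate values below are hypothetical; each yielded `(coordinates, sizes)` pair has the same shape as the arguments `create_display` expects.

```python
from itertools import product


def layout_grid(position_options, size_options):
    """Yield every (coordinates, sizes) combination.

    Hypothetical helper.
    position_options: one list of candidate (x %, y %) centres per sticker.
    size_options: one list of candidate sizes (pixels, longer side) per sticker.
    """
    for coordinates in product(*position_options):
        for sizes in product(*size_options):
            yield list(coordinates), list(sizes)


# Hypothetical candidates: only the second button's position and size vary
candidate_positions = [[(52, 50)], [(70, 80), (70, 70)], [(30, 80)]]
candidate_sizes = [[100], [100, 120], [100]]
variants = list(layout_grid(candidate_positions, candidate_sizes))
print(len(variants))   # 4
print(variants[0])     # ([(52, 50), (70, 80), (30, 80)], [100, 100, 100])
```

The grid grows multiplicatively with the number of candidates per sticker, so in practice only a few elements should be varied at a time, since every variant requires a full inference run.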


Ignore this block. It is used for manual grading of your report.