# Getting Started with the Gaze-Aware Vision Foundation Model

This notebook provides a quick introduction to using the gaze tracking system.

## What You'll Learn

1. Initializing and running the gaze predictor
2. Predicting future gaze positions with the temporal model
3. Efficient inference with SNN conversion and INT8 quantization
4. Visualizing gaze trajectories
5. Benchmarking baseline vs. quantized models


```python
# Import required libraries
import sys
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from models.gaze_tracking.predictor import GazePredictor, TemporalPredictor
from models.multimodal_foundation.vlm import GazeAwareVLM
from models.efficient_inference.snn_converter import convert_to_snn
from models.efficient_inference.quantization import quantize_model

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
```


## 1. Basic Gaze Prediction

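Let's start by initializing the gaze predictor and predicting the gaze direction for a single eye image. The image here is random noise, so the predicted angles only confirm that the pipeline runs end to end.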

```python
# Initialize the model and switch it to evaluation mode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GazePredictor(hidden_dim=128).to(device)
model.eval()

# Create a synthetic single-channel 64x64 eye image: (batch, channels, height, width).
# This is random noise, not a real eye image.
eye_image = torch.randn(1, 1, 64, 64, device=device)

# Predict gaze (yaw and pitch in radians)
with torch.no_grad():
    yaw, pitch = model(eye_image)

# Convert radians to degrees (.item() works because the batch size is 1)
yaw_deg = np.degrees(yaw.item())
pitch_deg = np.degrees(pitch.item())

print("Predicted gaze direction:")
print(f"  Yaw (horizontal): {yaw_deg:.2f} degrees")
print(f"  Pitch (vertical): {pitch_deg:.2f} degrees")
```
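To relate the two angles to a direction in space, the sketch below converts yaw and pitch (in radians) into 3D unit gaze vectors. It assumes one convention that is common in gaze-estimation code, where `(yaw, pitch) = (0, 0)` points along -z (toward the camera); other codebases flip signs or axes, and this notebook does not state which convention `GazePredictor` uses. `gaze_angles_to_vector` is a hypothetical helper, not part of the project.

```python
import numpy as np


def gaze_angles_to_vector(yaw, pitch):
    """Convert gaze angles (radians) to 3D unit direction vectors.

    Hypothetical helper; assumes the convention
        x = -cos(pitch) * sin(yaw)
        y = -sin(pitch)
        z = -cos(pitch) * cos(yaw)
    so (yaw, pitch) = (0, 0) points along -z. Accepts scalars or arrays.
    """
    yaw = np.asarray(yaw, dtype=np.float64)
    pitch = np.asarray(pitch, dtype=np.float64)
    return np.stack(
        [-np.cos(pitch) * np.sin(yaw), -np.sin(pitch), -np.cos(pitch) * np.cos(yaw)],
        axis=-1,
    )


# Example: a batch of two gaze directions given in radians
yaw_rad = np.array([0.0, np.radians(20.0)])
pitch_rad = np.array([0.0, np.radians(-10.0)])
vectors = gaze_angles_to_vector(yaw_rad, pitch_rad)

print(np.degrees(yaw_rad), np.degrees(pitch_rad))
print(vectors)
print(np.linalg.norm(vectors, axis=-1))  # every row has unit length
```

Because cos²(pitch)·(sin²(yaw) + cos²(yaw)) + sin²(pitch) = 1, the vectors are unit length for any pair of angles, which makes them convenient for computing angular errors with a dot product.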


## 2. Temporal Gaze Prediction

Predict the next gaze position from a short history of previous positions.

```python
# Create the temporal predictor (input_dim=2: one 2D gaze position per time step)
temporal_model = TemporalPredictor(input_dim=2, hidden_dim=64).to(device)
temporal_model.eval()

# Generate a synthetic gaze history: (batch, sequence_length, input_dim)
sequence_length = 10
gaze_sequence = torch.randn(1, sequence_length, 2, device=device)

# Predict the next position from the history
with torch.no_grad():
    predicted_next = temporal_model(gaze_sequence)

print(f"Predicted next gaze position: {predicted_next.cpu().numpy()}")
```
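The architecture of `TemporalPredictor` is not shown in this notebook. As a point of reference for the same task, a simple non-learned baseline extrapolates the most recent velocity. The sketch below assumes the same `(batch, sequence_length, 2)` layout as `gaze_sequence`; `constant_velocity_predict` is a hypothetical helper, not part of the project.

```python
import numpy as np


def constant_velocity_predict(sequence):
    """Predict the next position by extrapolating the last step's velocity.

    sequence: array of shape (batch, sequence_length, dim), sequence_length >= 2.
    Returns an array of shape (batch, dim).
    """
    sequence = np.asarray(sequence, dtype=np.float64)
    if sequence.ndim != 3 or sequence.shape[1] < 2:
        raise ValueError("expected shape (batch, sequence_length >= 2, dim)")
    last = sequence[:, -1, :]
    velocity = last - sequence[:, -2, :]  # displacement over the final step
    return last + velocity


# A synthetic sweep moving +1 unit per step horizontally and -0.5 vertically
history = np.stack([np.arange(10, dtype=float), -0.5 * np.arange(10)], axis=-1)[None]
print(history.shape)                       # (1, 10, 2)
print(constant_velocity_predict(history))  # [[10.  -5.]]
```

On real recordings, checking that the learned model beats this baseline is a quick sanity test; smooth pursuit is close to constant velocity, while saccades are where a learned predictor has room to help.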


## 3. Efficient Inference with SNN

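Convert the model to a spiking neural network (SNN) for a claimed 38x reduction in energy consumption. SNN energy savings generally depend on event-driven (neuromorphic) hardware; simulating the network over many time steps on a CPU or GPU is usually slower than running the original model.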

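```python
# Convert to SNN. num_steps is presumably the number of simulation time steps
# per inference: more steps track the original activations more closely but
# increase latency.
print("Converting to Spiking Neural Network...")
snn_model = convert_to_snn(model, num_steps=25)
# Energy consumption is not measured in this cell.
print("Conversion complete.")
```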
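ANN-to-SNN conversion is commonly based on rate coding: each ReLU is replaced by an integrate-and-fire (IF) neuron whose spike rate over the simulation window approximates the original activation. Whether `convert_to_snn` works this way is an assumption; the sketch below only illustrates the general idea with a single IF neuron (threshold 1, reset by subtraction) driven by a constant input for `num_steps` steps. `if_neuron_rate` is a hypothetical helper.

```python
def if_neuron_rate(input_current, num_steps=25, threshold=1.0):
    """Simulate one integrate-and-fire neuron and return its spike rate.

    The membrane potential accumulates a constant input each step. When it
    reaches the threshold, the neuron spikes and the threshold is subtracted
    (reset by subtraction). The rate approximates a clipped ReLU,
    clip(input / threshold, 0, 1), with an error below 1 / num_steps.
    """
    potential = 0.0
    spikes = 0
    for _ in range(num_steps):
        potential += input_current
        if potential >= threshold:
            spikes += 1
            potential -= threshold
    return spikes / num_steps


for activation in [-0.3, 0.2, 0.5, 0.9, 1.5]:
    rate = if_neuron_rate(activation)
    print(f"activation {activation:+.2f} -> spike rate {rate:.2f}")
```

Negative inputs never fire (the ReLU's zero region), inputs above the threshold saturate at one spike per step, and in-range values are quantized in steps of 1 / `num_steps`. This is why a larger `num_steps` trades latency for accuracy, and why conversion pipelines typically normalize weights or thresholds so activations stay in range.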


## 4. Model Quantization

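Quantize the model to INT8. Storing weights as 8-bit integers instead of 32-bit floats makes the quantized layers up to 4x smaller and usually speeds up inference on CPU.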

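```python
import copy

# Dynamic quantization: weights are stored as INT8 and activations are quantized
# on the fly. PyTorch's quantized kernels run on CPU, so quantize a CPU copy and
# keep the original `model` on `device` (assumes quantize_model uses PyTorch's
# built-in quantization).
print("Quantizing model to INT8...")
quantized_model = quantize_model(copy.deepcopy(model).cpu(), quantization_type='dynamic')
print("Quantization complete.")
```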
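The 4x figure follows from storage: an FP32 weight takes 4 bytes and an INT8 weight takes 1. The NumPy sketch below shows one common scheme, symmetric per-tensor quantization; it is not necessarily what `quantize_model` does. PyTorch's dynamic quantization, for example, converts only supported layer types such as `nn.Linear` and recurrent layers, so a model with convolutional layers left in FP32 shrinks by less than 4x overall. The helper names are hypothetical.

```python
import numpy as np


def quantize_int8_symmetric(weights):
    """Symmetric per-tensor INT8 quantization.

    Maps floats to integers in [-127, 127] with a single scale factor,
    scale = max(|w|) / 127, so that w is approximately q * scale.
    """
    weights = np.asarray(weights, dtype=np.float32)
    max_abs = float(np.max(np.abs(weights)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q, scale):
    """Recover approximate FP32 values from INT8 codes."""
    return q.astype(np.float32) * scale


rng = np.random.default_rng(0)
w = rng.standard_normal((128, 128)).astype(np.float32)  # stand-in FP32 weight matrix
q, scale = quantize_int8_symmetric(w)
w_hat = dequantize(q, scale)

print(f"FP32: {w.nbytes} bytes, INT8: {q.nbytes} bytes ({w.nbytes / q.nbytes:.0f}x smaller)")
print(f"max abs error: {np.max(np.abs(w - w_hat)):.5f} (scale / 2 = {scale / 2:.5f})")
```

Rounding to the nearest integer bounds the per-weight error by half a quantization step (`scale / 2`), which is why a single large outlier weight, by inflating `scale`, hurts accuracy for every other weight in the tensor.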


## 5. Visualization

Visualize a synthetic gaze trajectory, both as yaw and pitch over time and as a 2D path in yaw-pitch space.

```python
# Generate a synthetic gaze trajectory in degrees: two full periods of sinusoidal
# motion, which traces an ellipse in yaw-pitch space
t = np.linspace(0, 4*np.pi, 100)
yaw_trajectory = 20 * np.sin(t)
pitch_trajectory = 15 * np.cos(t)

# Plot
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Time series
axes[0].plot(t, yaw_trajectory, label='Yaw', linewidth=2)
axes[0].plot(t, pitch_trajectory, label='Pitch', linewidth=2)
axes[0].set_xlabel('Time (arbitrary units)')
axes[0].set_ylabel('Angle (degrees)')
axes[0].set_title('Gaze Over Time')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# 2D trajectory. After two full periods the path ends where it started, so the
# end is drawn as an 'x' on top of the start marker to keep both visible.
axes[1].plot(yaw_trajectory, pitch_trajectory, linewidth=2)
axes[1].scatter(yaw_trajectory[0], pitch_trajectory[0], c='green', s=100, label='Start', zorder=5)
axes[1].scatter(yaw_trajectory[-1], pitch_trajectory[-1], c='red', marker='x', s=100, label='End', zorder=6)
axes[1].set_xlabel('Yaw (degrees)')
axes[1].set_ylabel('Pitch (degrees)')
axes[1].set_title('2D Gaze Trajectory')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
axes[1].axis('equal')

plt.tight_layout()
plt.show()
```


## 6. Performance Benchmarking

Compare the latency, throughput, and memory use of the baseline and quantized models.

```python
from models.efficient_inference.quantization import benchmark_inference

# PyTorch's INT8 quantized kernels run on CPU (assuming quantize_model uses
# PyTorch's built-in quantization). If the baseline runs on a GPU, the speedup
# below compares different devices rather than FP32 vs. INT8 alone.

# Benchmark baseline
print("Benchmarking baseline model...")
baseline_metrics = benchmark_inference(model, input_shape=(1, 1, 64, 64), num_iterations=100)

print("\nBaseline Performance:")
print(f"  Latency: {baseline_metrics['avg_latency_ms']:.2f} ms")
print(f"  Throughput: {baseline_metrics['throughput_fps']:.1f} FPS")
print(f"  Memory: {baseline_metrics['memory_mb']:.1f} MB")

# Benchmark quantized
print("\nBenchmarking quantized model...")
quantized_metrics = benchmark_inference(quantized_model, input_shape=(1, 1, 64, 64), num_iterations=100)

# Speedup > 1 means the quantized model has lower average latency
speedup = baseline_metrics['avg_latency_ms'] / quantized_metrics['avg_latency_ms']

print("\nQuantized Performance:")
print(f"  Latency: {quantized_metrics['avg_latency_ms']:.2f} ms")
print(f"  Throughput: {quantized_metrics['throughput_fps']:.1f} FPS")
print(f"  Memory: {quantized_metrics['memory_mb']:.1f} MB")
print(f"  Speedup: {speedup:.2f}x")
```
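The speedup printed above is the ratio of average latencies, and at batch size 1 the throughput is simply 1000 divided by the latency in milliseconds. The internals of `benchmark_inference` are not shown in this notebook, so the following is a hypothetical, framework-free sketch of how such metrics are commonly measured: warm-up calls are excluded, each call is timed with `time.perf_counter`, and the median is reported alongside the mean because it is less sensitive to outliers. `benchmark`, `slow_op`, and `fast_op` are illustrative names, not project APIs.

```python
import statistics
import time


def benchmark(fn, num_iterations=100, warmup=10, batch_size=1):
    """Time repeated calls to fn() and summarize latency and throughput.

    Warm-up calls run first and are excluded from the statistics. For GPU
    models, fn should call torch.cuda.synchronize() before returning, since
    CUDA kernels launch asynchronously.
    """
    for _ in range(warmup):
        fn()
    latencies_ms = []
    for _ in range(num_iterations):
        start = time.perf_counter()
        fn()
        latencies_ms.append((time.perf_counter() - start) * 1000.0)
    avg_ms = statistics.mean(latencies_ms)
    return {
        "avg_latency_ms": avg_ms,
        "median_latency_ms": statistics.median(latencies_ms),
        "throughput_fps": batch_size * 1000.0 / avg_ms,  # samples per second
    }


def slow_op():
    sum(i * i for i in range(20_000))


def fast_op():
    sum(i * i for i in range(5_000))


slow = benchmark(slow_op)
fast = benchmark(fast_op)
print(f"slow: {slow['avg_latency_ms']:.3f} ms, {slow['throughput_fps']:.0f} FPS")
print(f"fast: {fast['avg_latency_ms']:.3f} ms, {fast['throughput_fps']:.0f} FPS")
print(f"speedup: {slow['avg_latency_ms'] / fast['avg_latency_ms']:.2f}x")
```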


## Next Steps

- Explore the `demo.py` script for interactive demonstrations
- Run `evaluate.py` for comprehensive benchmarking
- Explore the `GazeAwareVLM` integration for multimodal vision-language understanding
- Try the models on real eye-tracking datasets

For more information, see [README.md](../README.md).