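# Training Predictive Model

This notebook trains the face-crop gaze model end to end:

1. Inspect the collected samples.
2. Search hyperparameters with ASHA (Ray Tune).
3. Train the final model with the best configuration.
4. Evaluate it on the test split and map its error across the screen.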

### Import libraries

In [None]:
import datetime, json, random, IPython, pandas as pd, numpy as np, matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from mpl_toolkits.axes_grid1 import ImageGrid
import matplotlib.image as mpimg
import torch, pytorch_lightning as pl
from ray import tune
from pytorch_lightning.loggers import TensorBoardLogger

from models import GazeDataModule, SingleModel, EyesModel, FullModel
from utils  import (
    get_config,
    tune_asha,
    get_best_results,
    save_model,
    plot_asha_param_grid,
    plot_parallel_param_loss,
    latest_tune_dir,
    _build_datamodule,
    _build_model,
    predict_screen_errors,
)

# project settings
SETTINGS, COLOURS, EYETRACKER, TF = get_config("config.ini")

%load_ext autoreload
%autoreload 2


### Dataset information

In [None]:
# Load sample positions and the screen-coverage grid. The grid is transposed;
# presumably so rows follow screen Y and columns screen X, matching the
# meshgrid used in the next cell (TODO confirm).
df = pd.read_csv("data/positions.csv")
region_map = np.load("data/region_map.npy").T

# Share of grid cells that received at least one sample, as a percentage.
coverage = np.count_nonzero(region_map > 0) / region_map.size * 100

print(
    f"# of samples: {len(df)}",
    f"Coverage: {coverage:.2f}% of screen surface",
    f"Crop size: {SETTINGS['image_size']} x {SETTINGS['image_size']} px",
    sep="\n",
)

### Data visualization

In [None]:
# Sample distribution over the screen, laid out as marginal histograms:
#   - bottom-left: coverage heatmap;
#   - top-left: X histogram above the heatmap;
#   - bottom-right: Y histogram beside it;
#   - top-right: the same map as a 3-D surface.
x = np.arange(region_map.shape[1])
y = np.arange(region_map.shape[0])
X, Y = np.meshgrid(x, y)

fig = plt.figure(figsize=(15, 10))

ax = fig.add_subplot(221)
ax.hist(df['x'], bins=20, edgecolor='k')
ax.set_title('Screen X')
ax.margins(x=0)

ax = fig.add_subplot(222, projection='3d')
# NOTE(review): `Axes3D.dist` is deprecated since Matplotlib 3.6 and may have
# no effect on newer versions; the replacement is `ax.set_box_aspect(None, zoom=...)`.
ax.dist = 9
ax.plot_surface(X, Y, region_map, cmap="inferno")
ax.set_xlabel('Screen X', labelpad=-10)
ax.set_ylabel('Screen Y', labelpad=-10)

ax = fig.add_subplot(223)
ax.imshow(region_map, interpolation='bilinear', cmap="inferno")

ax = fig.add_subplot(224)
ax.hist(df['y'], bins=15, edgecolor='k', orientation='horizontal')
ax.set_title('Screen Y')
ax.margins(y=0)  # same tight fit as the X histogram
# imshow draws row 0 at the top; flip this marginal so it lines up with the heatmap.
ax.invert_yaxis()

plt.setp(fig.get_axes(), xticks=[], yticks=[])
plt.tight_layout()
# savefig does not create missing directories.
Path('media/images').mkdir(parents=True, exist_ok=True)
plt.savefig('media/images/0_data_distribution.png')
plt.show()

In [None]:
# Two right-eye crops side by side, showing how lighting varies across samples.
low_bright  = plt.imread("data/r_eye/12.jpg")
high_bright = plt.imread("data/r_eye/31.jpg")

fig, axs = plt.subplots(1, 2, figsize=(6, 3))
panels = [(low_bright, "darker"), (high_bright, "brighter")]
for ax, (img, title) in zip(axs, panels):
    ax.imshow(img)
    ax.set_title(title)
    ax.axis("off")
plt.tight_layout()
plt.show()

In [None]:
def plot_calibration(eye: str, data_dir: str = "data", save_dir: str = "media/images") -> None:
    """Show the nine calibration crops for one eye as a 3x3 grid and save it.

    Reads ``{data_dir}/{eye}/1.jpg`` ... ``{data_dir}/{eye}/9.jpg``, places them
    in the grid according to ``order`` below, and writes the figure to
    ``{save_dir}/0_calibration_{eye}.png``.

    Parameters
    ----------
    eye : str
        Sub-directory of ``data_dir`` holding the crops, e.g. ``"l_eye"``.
    data_dir : str
        Root folder of the collected samples.
    save_dir : str
        Folder the figure is written to; created if it does not exist.
    """
    imgs = [plt.imread(f"{data_dir}/{eye}/{i}.jpg") for i in range(1, 10)]

    # Grid cell (row-major) -> 0-based index into `imgs`. Presumably this puts
    # each crop at the screen position of its calibration point (TODO confirm).
    order = [3, 6, 8, 2, 4, 7, 1, 0, 5]

    fig, axs = plt.subplots(3, 3, figsize=(4, 4))
    for ax, img_i in zip(axs.flat, order):
        ax.imshow(imgs[img_i])
        ax.axis("off")
    fig.suptitle(f"9-point calibration: {eye}")
    plt.tight_layout()

    # savefig does not create missing directories.
    out = Path(save_dir)
    out.mkdir(parents=True, exist_ok=True)
    plt.savefig(out / f"0_calibration_{eye}.png", dpi=120)
    plt.show()

plot_calibration("l_eye")
plot_calibration("r_eye")


### Hyperparameter tuning (ASHA)

In [None]:
# Hyperparameter search space; Ray Tune samples one value per key for each trial.
search_space = dict(
    seed=tune.randint(0, 10000),
    bs=tune.choice([256, 512, 1024]),
    lr=tune.loguniform(1e-4, 3e-3),
    channels=tune.choice([(32, 64, 128), (48, 96, 192), (64, 128, 256)]),
    hidden=tune.choice([128, 256, 512]),
)

# Run the ASHA search on face crops only: 20 sampled configurations,
# presumably with a budget of up to 15 epochs each (TODO confirm in utils.tune_asha).
analysis = tune_asha(
    search_space=search_space,
    train_func="single",  # resolved to the matching trainer wrapper inside utils
    name="face/tune",
    img_types=["face"],
    num_samples=20,
    num_epochs=15,
    data_dir=Path.cwd() / "data",
    seed=87,
)

In [None]:
# Scatter of sampled hyperparameters against trial results, saved to disk.
plot_asha_param_grid(
    analysis,
    save_path="media/images/final_face_explore_scatter.png",
)

In [None]:
# Parallel-coordinates view of hyperparameters vs. loss, saved to disk.
plot_parallel_param_loss(
    analysis,
    save_path="media/images/final_face_explore_parallel.png",
)

### Training

In [None]:
# Timestamp naming this run's log / export directory.
start_time = datetime.datetime.now().strftime("%Y-%b-%d %H-%M-%S")

# Reload the best hyperparameters found by the ASHA search.
# NOTE(review): the search ran with name="face/tune", but the lookup root here
# is logs/face, which also gains a "final" folder from this cell. Confirm that
# `latest_tune_dir` only considers tune experiments.
tune_dir = Path.cwd() / "logs" / "face"
best_cfg = get_best_results(latest_tune_dir(tune_dir))
pl.seed_everything(best_cfg["seed"])

dm = GazeDataModule(
    data_dir=Path.cwd() / "data",
    batch_size=best_cfg["bs"],
    img_types=["face"],
    seed=best_cfg["seed"],
)

model = _build_model(best_cfg, ["face"])

# bf16 autocast needs an Ampere-or-newer GPU (compute capability >= 8.0).
# Older CUDA devices fall back to fp16 mixed precision; non-CUDA runs keep bf16.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8:
    precision = "16-mixed"
else:
    precision = "bf16-mixed"

trainer = pl.Trainer(
    max_epochs=100,
    accelerator="auto",
    devices="auto",
    precision=precision,
    logger=TensorBoardLogger(
        save_dir=Path.cwd() / "logs",
        name=f"face/final/{start_time}",
        # Graph logging needs `model.example_input_array`; otherwise Lightning
        # skips it with a warning. Presumably set by the model (TODO confirm).
        log_graph=True,
    ),
    callbacks=[
        # Keep the single lowest-val_loss epoch as "best", plus the last epoch.
        pl.callbacks.ModelCheckpoint(
            filename="best",
            monitor="val_loss",
            mode="min",
            save_last=True,
            save_top_k=1,
        )
    ],
)

trainer.fit(model, datamodule=dm)

# Restore the best-validation weights before exporting.
best_path = trainer.checkpoint_callback.best_model_path
if not best_path:
    # e.g. training was interrupted before the first validation run.
    raise RuntimeError(
        "No best checkpoint was saved; training may have stopped before the "
        "first validation or 'val_loss' was never logged."
    )
# weights_only=False unpickles arbitrary objects. That is acceptable only because
# this checkpoint was just written by this run; never use it on untrusted files.
state = torch.load(best_path, map_location="cpu", weights_only=False)
model.load_state_dict(state["state_dict"])

# Export the plain weights plus the chosen config under this run's log directory.
out_dir = Path.cwd() / "logs" / "face" / "final" / start_time
out_dir.mkdir(parents=True, exist_ok=True)

save_model(
    # .cpu() moves the model in place; presumably so the exported weights load
    # without a GPU.
    model.cpu(),
    best_cfg,
    out_dir / "eyetracking_model.pt",
    out_dir / "eyetracking_config.json",
)

### Model Evaluation

In [None]:
# Evaluate the best (lowest val_loss) checkpoint on the held-out test split.
test_results = trainer.test(ckpt_path="best", datamodule=dm)[0]


def _test_metric(results: dict, name: str):
    """Return ``results[name + "_epoch"]``, falling back to ``results[name]``.

    Lightning appends ``_epoch`` to metrics logged with both ``on_step`` and
    ``on_epoch``; otherwise the plain name is used. Returns ``None`` when
    neither key exists, so a missing metric is reported instead of crashing.
    """
    return results.get(f"{name}_epoch", results.get(name))


def _fmt(value, width: int = 8) -> str:
    """Format a metric with 2 decimals, or a right-aligned 'n/a' if it is missing."""
    return f"{'n/a':>{width}}" if value is None else f"{value:{width}.2f}"


loss = _test_metric(test_results, "test_loss")
mae  = _test_metric(test_results, "test_mae")
mse  = _test_metric(test_results, "test_mse")
rmse = _test_metric(test_results, "test_rmse")

print("────────  Test set  ────────")
print(f"MSE   : {_fmt(mse)}  px²")
print(f"RMSE  : {_fmt(rmse)}  px")
print(f"MAE   : {_fmt(mae)}  px")
print(f"Loss  : {_fmt(loss)}  (Smooth-L1)")

In [None]:
# Evaluate the exported face model over the screen. The outputs are an error
# heatmap image and the raw error array, presumably per screen region
# (see utils.predict_screen_errors).
predict_screen_errors(
    "face",
    path_model=out_dir.joinpath("eyetracking_model.pt"),
    path_config=out_dir.joinpath("eyetracking_config.json"),
    path_plot=out_dir.joinpath("error_heatmap_face.png"),
    path_errors=out_dir.joinpath("errors.npy"),
    steps=10,
)