Acknowledgements: https://github.com/swook/GazeML

This notebook is adapted from the GazeML project linked above.

## Imports

In [None]:
# NOTE(review): `sio` is conventionally the import alias for `scipy.io`; verify
# that this PyPI package (a 2.5 kB wheel per the install log below) is really
# the intended dependency.
!pip install sio

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sio
  Downloading sio-0.1.0-py3-none-any.whl (2.5 kB)
Collecting shortio>=0.1.0
  Downloading shortio-0.1.0-py3-none-any.whl (4.9 kB)
Installing collected packages: shortio, sio
Successfully installed shortio-0.1.0 sio-0.1.0


In [None]:
from __future__ import print_function, division
import os

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import numpy as np
import cv2
import argparse
from scipy.spatial.transform import Rotation as R
import cv2 as cv
from typing import Optional

import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import glob
import json
import sio

import torch.nn as nn

## Prepare the dataset

In [None]:
# Mount Google Drive at /content/drive (Colab only); the training cell reads
# the image folder from a path under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%cd drive/My Drive

/content/drive/My Drive


In [None]:
# !gdown 13KJi8Ohbxp5q2hGoaS6tzrA_Ejdshafs
# !gdown 1wqTA4gutC-L4h8TcMMQO_3jJYBL15-h3

In [None]:
# !unzip imgs.zip

## Utility Functions

In [None]:
"""Utility methods for gaze angle and error calculations."""

def pitchyaw_to_vector(pitchyaws):
    """Convert (pitch, yaw) angle pairs to 3D direction vectors.

    Args:
        pitchyaws: array of shape (n, 2); column 0 is the pitch angle and
            column 1 the yaw angle, both in radians.

    Returns:
        Array of shape (n, 3) with rows
        (cos(pitch) * sin(yaw), sin(pitch), cos(pitch) * cos(yaw)).
    """
    pitch = pitchyaws[:, 0]
    yaw = pitchyaws[:, 1]
    cos_pitch = np.cos(pitch)

    vectors = np.empty((pitchyaws.shape[0], 3))
    vectors[:, 0] = cos_pitch * np.sin(yaw)
    vectors[:, 1] = np.sin(pitch)
    vectors[:, 2] = cos_pitch * np.cos(yaw)
    return vectors


def vector_to_pitchyaw(vectors):
    """Convert 3D direction vectors to (pitch, yaw) angle pairs in radians.

    Args:
        vectors: array of shape (n, 3); rows need not be unit length, they are
            normalized before the angles are computed.

    Returns:
        Array of shape (n, 2): column 0 is arcsin(y) (theta), column 1 is
        arctan2(x, z) (phi), computed on the normalized vectors.
    """
    count = vectors.shape[0]
    lengths = np.linalg.norm(vectors, axis=1).reshape(count, 1)
    unit = np.divide(vectors, lengths)

    angles = np.empty((count, 2))
    angles[:, 0] = np.arcsin(unit[:, 1])  # theta
    angles[:, 1] = np.arctan2(unit[:, 0], unit[:, 2])  # phi
    return angles

radians_to_degrees = 180.0 / np.pi


def angular_error(a, b):
    """Calculate the per-sample angular error in degrees (via cosine similarity).

    Args:
        a, b: arrays of shape (n, 2) holding (pitch, yaw) pairs in radians, or
            arrays with any other column count (e.g. (n, 3)) that are used
            directly as direction vectors.

    Returns:
        Array of shape (n,) with the angle between each pair of rows, in
        degrees.
    """
    a = pitchyaw_to_vector(a) if a.shape[1] == 2 else a
    b = pitchyaw_to_vector(b) if b.shape[1] == 2 else b

    ab = np.sum(np.multiply(a, b), axis=1)

    # Avoid zero-valued norms (division by zero would give NaNs)
    a_norm = np.clip(np.linalg.norm(a, axis=1), a_min=1e-7, a_max=None)
    b_norm = np.clip(np.linalg.norm(b, axis=1), a_min=1e-7, a_max=None)

    similarity = np.divide(ab, np.multiply(a_norm, b_norm))

    # Floating-point rounding can push the cosine marginally outside [-1, 1]
    # for (anti)parallel vectors, which would make arccos return NaN.
    similarity = np.clip(similarity, -1.0, 1.0)

    return np.arccos(similarity) * radians_to_degrees


def mean_angular_error(a, b):
    """Return the mean of the per-sample angular errors between `a` and `b`.

    Both arguments are forwarded unchanged to `angular_error`; the result is
    the average of the values it returns.
    """
    per_sample_errors = angular_error(a, b)
    return np.mean(per_sample_errors)


def draw_gaze(image_in, eye_pos, pitchyaw, length=40.0, thickness=2, color=(0, 0, 255)):
    """Draw a gaze arrow starting at `eye_pos` and return the annotated image.

    Args:
        image_in: image array. A 2D or single-channel image is converted to
            BGR first; any other image is drawn on directly, so the caller's
            array is modified in place.
        eye_pos: (x, y) start point of the arrow in pixels.
        pitchyaw: sequence whose element 0 is pitch and element 1 is yaw, in
            radians.
        length: scale factor applied to the sines of the angles, in pixels.
        thickness: line thickness passed to OpenCV.
        color: line color passed to OpenCV.

    Returns:
        The image the arrow was drawn on.
    """
    canvas = image_in
    is_single_channel = len(canvas.shape) == 2 or canvas.shape[2] == 1
    if is_single_channel:
        canvas = cv.cvtColor(canvas, cv.COLOR_GRAY2BGR)

    pitch, yaw = pitchyaw[0], pitchyaw[1]
    dx = -length * np.sin(yaw)
    dy = length * np.sin(pitch)

    start_point = tuple(np.round(eye_pos).astype(np.int32))
    end_point = tuple(np.round([eye_pos[0] + dx, eye_pos[1] + dy]).astype(int))
    cv.arrowedLine(canvas, start_point, end_point, color,
                   thickness, cv.LINE_AA, tipLength=0.2)
    return canvas

In [None]:
def preprocess_unityeyes_image(img, json_data):
    """Crop, normalize and label one UnityEyes sample.

    The frame is converted to grayscale, warped to a fixed 160x96 eye patch
    (scaled by eye width, centered on the interior-margin bounding box plus a
    random offset), blurred by a random amount and histogram-equalized.
    Landmarks are pushed through the same transform and rescaled to the 80x48
    heatmap grid. Results are random unless NumPy's global RNG is seeded.

    Args:
        img: image array whose first two dimensions are (height, width); it is
            passed to cv2.cvtColor with COLOR_BGR2GRAY, so presumably a BGR
            image as returned by cv2.imread — TODO confirm.
        json_data: dict with keys 'interior_margin_2d', 'caruncle_2d' and
            'iris_2d' (lists of strings, each evaluating to an (x, y, z)
            tuple) and 'eye_details' -> 'look_vec' (a string evaluating to a
            sequence with at least three components).

    Returns:
        Dict with keys 'img' (float32 patch scaled to [0, 1]), 'transform'
        and 'transform_inv' (3x3 arrays), 'eye_middle', 'heatmaps'
        (34 x 48 x 80), 'landmarks' (34 x 2, in (y, x) order, heatmap
        coordinates) and 'gaze' (float32 (pitch, yaw) from vector_to_pitchyaw).
    """
    # Output patch size in pixels (width, height)
    ow = 160
    oh = 96
    # Prepare to segment eye image
    ih, iw = img.shape[:2]
    ih_2, iw_2 = ih/2.0, iw/2.0

    # Heatmaps are produced at half the patch resolution
    heatmap_w = int(ow/2)
    heatmap_h = int(oh/2)

    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    def process_coords(coords_list):
        # NOTE(review): eval() executes whatever expression the JSON string
        # contains; if the values are plain numeric tuples, ast.literal_eval
        # would be a safer drop-in — confirm before changing.
        coords = [eval(l) for l in coords_list]
        # Flip the y axis (image height minus y)
        return np.array([(x, ih-y, z) for (x, y, z) in coords])
    
    interior_landmarks = process_coords(json_data['interior_margin_2d'])
    caruncle_landmarks = process_coords(json_data['caruncle_2d'])
    iris_landmarks = process_coords(json_data['iris_2d'])

    # Eye extent: mean caruncle position on one side, interior-margin point 8
    # on the other; the crop is 1.5x that horizontal distance.
    left_corner = np.mean(caruncle_landmarks[:, :2], axis=0)
    right_corner = interior_landmarks[8, :2]
    eye_width = 1.5 * abs(left_corner[0] - right_corner[0])
    # Center of the interior-margin bounding box
    eye_middle = np.mean([np.amin(interior_landmarks[:, :2], axis=0),
                          np.amax(interior_landmarks[:, :2], axis=0)], axis=0)

    # Normalize to eye width.
    scale = ow/eye_width

    # Move the (scaled) eye middle to the origin
    translate = np.asmatrix(np.eye(3))
    translate[0, 2] = -eye_middle[0] * scale
    translate[1, 2] = -eye_middle[1] * scale

    # Re-center on the output patch with up to +/-10 px of random jitter
    rand_x = np.random.uniform(low=-10, high=10)
    rand_y = np.random.uniform(low=-10, high=10)
    recenter = np.asmatrix(np.eye(3))
    recenter[0, 2] = ow/2 + rand_x
    recenter[1, 2] = oh/2 + rand_y

    scale_mat = np.asmatrix(np.eye(3))
    scale_mat[0, 0] = scale
    scale_mat[1, 1] = scale

    # Rotation augmentation is currently disabled (angle fixed at 0)
    angle = 0 #np.random.normal(0, 1) * 20 * np.pi/180
    rotation = R.from_rotvec([0, 0, angle]).as_matrix()

    # np.matrix operands make `*` a matrix product; applied right-to-left:
    # scale, translate, rotate, recenter.
    transform = recenter * rotation * translate * scale_mat
    transform_inv = np.linalg.inv(transform)
    
    # Apply transforms
    eye = cv2.warpAffine(img, transform[:2], (ow, oh))

    # Random Gaussian blur: fixed 5x5 kernel, sigma drawn from [0, 20)
    rand_blur = np.random.uniform(low=0, high=20)
    eye = cv2.GaussianBlur(eye, (5, 5), rand_blur)

    # Normalize eye image
    eye = cv2.equalizeHist(eye)
    eye = eye.astype(np.float32)
    eye = eye / 255.0

    # Gaze
    # Convert look vector to gaze direction in polar angles
    look_vec = np.array(eval(json_data['eye_details']['look_vec']))[:3].reshape((1, 3))
    #look_vec = np.matmul(look_vec, rotation.T)

    # The look vector is negated before conversion to (pitch, yaw)
    gaze = vector_to_pitchyaw(-look_vec).flatten()
    gaze = gaze.astype(np.float32)

    iris_center = np.mean(iris_landmarks[:, :2], axis=0)

    landmarks = np.concatenate([interior_landmarks[:, :2],  # all interior-margin points
                                iris_landmarks[::2, :2],  # every second iris point
                                iris_center.reshape((1, 2)),
                                [[iw_2, ih_2]],  # Eyeball center
                                ])  # must total 34 rows (see `temp` and the assert below)

    # Homogeneous coordinates -> patch pixels -> heatmap pixels
    landmarks = np.asmatrix(np.pad(landmarks, ((0, 0), (0, 1)), 'constant', constant_values=1))
    landmarks = np.asarray(landmarks * transform[:2].T) * np.array([heatmap_w/ow, heatmap_h/oh])
    landmarks = landmarks.astype(np.float32)

    # Swap columns so that landmarks are in (y, x), not (x, y)
    # This is because the network outputs landmarks as (y, x) values.
    temp = np.zeros((34, 2), dtype=np.float32)
    temp[:, 0] = landmarks[:, 1]
    temp[:, 1] = landmarks[:, 0]
    landmarks = temp

    heatmaps = get_heatmaps(w=heatmap_w, h=heatmap_h, landmarks=landmarks)

    assert heatmaps.shape == (34, heatmap_h, heatmap_w)

    return {
        'img': eye,
        'transform': np.asarray(transform),
        'transform_inv': np.asarray(transform_inv),
        'eye_middle': np.asarray(eye_middle),
        'heatmaps': np.asarray(heatmaps),
        'landmarks': np.asarray(landmarks),
        'gaze': np.asarray(gaze)
    }


def gaussian_2d(w, h, cx, cy, sigma=1.0):
    """Render a single unnormalized 2D Gaussian as an (h, w) heatmap.

    Args:
        w, h: heatmap width and height in pixels.
        cx, cy: center of the Gaussian in (column, row) pixel coordinates.
        sigma: standard deviation in pixels.

    Returns:
        Array of shape (h, w) whose value is 1.0 exactly at (cy, cx) and
        decays as exp(-d^2 / (2 * sigma^2)) with distance d from the center.
    """
    col_coords = np.linspace(0, w - 1, w, dtype=np.float32)
    row_coords = np.linspace(0, h - 1, h, dtype=np.float32)
    xs, ys = np.meshgrid(col_coords, row_coords)

    assert xs.shape == (h, w)
    alpha = -0.5 / (sigma ** 2)
    squared_distance = (xs - cx) ** 2 + (ys - cy) ** 2
    return np.exp(alpha * squared_distance)


def get_heatmaps(w, h, landmarks):
    """Build one Gaussian heatmap (sigma = 2.0) per landmark.

    Args:
        w, h: heatmap width and height in pixels.
        landmarks: iterable of (y, x) pairs in heatmap coordinates.

    Returns:
        Array stacking one (h, w) heatmap per landmark, in input order.
    """
    return np.array([
        gaussian_2d(w, h, cx=col, cy=row, sigma=2.0)
        for (row, col) in landmarks
    ])

In [None]:
class UnityEyesDataset(Dataset):
    """Dataset over a folder of UnityEyes renders (`<n>.jpg` + `<n>.json`).

    Images are ordered by the integer value of their file name, and each image
    is paired with the JSON file of the same stem in the same folder.
    """

    def __init__(self, img_dir: Optional[str] = None):
        """Index the image folder.

        Args:
            img_dir: folder containing the `.jpg`/`.json` pairs. Defaults to
                the `imgs` sub-folder of the current working directory.
        """
        if img_dir is None:
            # No leading slash here: os.path.join discards every component
            # that precedes an absolute one, which would yield '/imgs'.
            img_dir = os.path.join(os.getcwd(), 'imgs')

        def stem(path):
            return os.path.splitext(os.path.basename(path))[0]

        # Numeric sort so that '10.jpg' comes after '2.jpg'
        self.img_paths = sorted(glob.glob(os.path.join(img_dir, '*.jpg')),
                                key=lambda path: int(stem(path)))
        self.json_paths = [os.path.join(img_dir, f'{stem(path)}.json')
                           for path in self.img_paths]

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load sample `idx` and return the raw frame plus preprocessed targets.

        Returns:
            Dict with 'full_img' and 'json_data' plus every key produced by
            `preprocess_unityeyes_image`.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        full_img = cv2.imread(self.img_paths[idx])
        if full_img is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f'Could not read image: {self.img_paths[idx]}')
        with open(self.json_paths[idx]) as f:
            json_data = json.load(f)

        eye_sample = preprocess_unityeyes_image(full_img, json_data)
        sample = {'full_img': full_img, 'json_data': json_data }
        sample.update(eye_sample)
        return sample

In [None]:
class MPIIGaze(Dataset):
    """Dataset over the MPIIGaze evaluation subset (normalized eye images).

    The sample lists under `Evaluation Subset/sample list for eye image` are
    parsed at construction time; images and gaze vectors are read lazily from
    the per-day `.mat` files under `Data/Normalized`.
    """

    def __init__(self, mpii_dir: str = 'datasets/MPIIGaze'):
        """Parse every per-person sample list found under `mpii_dir`.

        Each non-empty line is expected to be `<day>/<image name> <side>`.
        """
        self.mpii_dir = mpii_dir

        # Sorted so that sample indices are stable across runs and machines
        # (glob order is otherwise filesystem-dependent).
        eval_files = sorted(glob.glob(f'{mpii_dir}/Evaluation Subset/sample list for eye image/*.txt'))

        self.eval_entries = []
        for ef in eval_files:
            # The list file is named after the person it belongs to
            person = os.path.splitext(os.path.basename(ef))[0]
            with open(ef) as f:
                for line in f:
                    line = line.strip()
                    if line == '':
                        continue
                    img_path, side = line.split()
                    day, img = img_path.split('/')
                    self.eval_entries.append({
                        'day': day,
                        'img_name': img,
                        'person': person,
                        'side': side
                    })

    def __len__(self):
        return len(self.eval_entries)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        return self._load_sample(idx)

    def _load_sample(self, i):
        """Load entry `i`: a 160x96 float32 eye image, its gaze and the eye side.

        Returns:
            Dict with 'img' (equalized, scaled to [0, 1]; right eyes are
            mirrored horizontally), 'gaze' (array [-theta, phi]) and 'side'.
        """
        # scipy.io provides the MATLAB reader; imported here so this class
        # does not depend on the unrelated `sio` package from PyPI.
        from scipy.io import loadmat

        entry = self.eval_entries[i]
        mat_path = os.path.join(self.mpii_dir, 'Data/Normalized', entry['person'], entry['day'] + '.mat')
        mat = loadmat(mat_path)

        filenames = mat['filenames']
        row = np.argwhere(filenames == entry['img_name'])[0][0]
        side = entry['side']

        img = mat['data'][side][0, 0]['image'][0, 0][row]
        img = cv2.resize(img, (160, 96))
        img = cv2.equalizeHist(img)
        img = img / 255.
        img = img.astype(np.float32)
        if side == 'right':
            # np.fliplr returns a negative-stride view, which torch cannot
            # convert to a tensor when batching; make it contiguous.
            img = np.ascontiguousarray(np.fliplr(img))

        (x, y, z) = mat['data'][side][0, 0]['gaze'][0, 0][row]

        theta = np.arcsin(-y)
        phi = np.arctan2(-x, -z)
        gaze = np.array([-theta, phi])

        return {
            'img': img,
            'gaze': gaze,
            'side': side
        }

In [None]:
def softargmax2d(input, beta=100, dtype=torch.float32):
    """Differentiable 2D arg-max over the last two dimensions.

    A softmax with temperature `beta` is taken over each flattened (h, w) map
    and used to weight the row and column coordinates.

    Args:
        input: tensor of shape (..., h, w).
        beta: sharpness of the softmax; larger values approach a hard arg-max.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape (..., 2) holding the expected (row, column) position
        in pixel units.
    """
    *leading_dims, h, w = input.shape

    flat_maps = input.reshape(*leading_dims, h * w)
    weights = nn.functional.softmax(beta * flat_maps, dim=-1)

    # Normalized [0, 1] coordinate grids, flattened to match `weights`
    col_grid, row_grid = np.meshgrid(
        np.linspace(0, 1, w),
        np.linspace(0, 1, h),
        indexing='xy'
    )
    row_coords = torch.tensor(np.reshape(row_grid, (-1, h * w)))
    col_coords = torch.tensor(np.reshape(col_grid, (-1, h * w)))

    # get_device() is negative for CPU tensors; only move for accelerators
    device_index = weights.get_device()
    if device_index >= 0:
        row_coords = row_coords.to(device_index)
        col_coords = col_coords.to(device_index)

    # Scale the normalized coordinates back to pixel indices
    expected_row = torch.sum((h - 1) * weights * row_coords, dim=-1)
    expected_col = torch.sum((w - 1) * weights * col_coords, dim=-1)

    return torch.stack([expected_row, expected_col], dim=-1).type(dtype)


def softargmax1d(input, beta=100):
    """Differentiable arg-max over the last dimension.

    Args:
        input: tensor of shape (..., n).
        beta: sharpness of the softmax; larger values approach a hard arg-max.

    Returns:
        Tensor of shape (...) with the expected index along the last axis.
    """
    *_, n = input.shape
    weights = nn.functional.softmax(beta * input, dim=-1)
    # Create the coordinates on the input's device; a CPU-only linspace makes
    # the product below fail for CUDA inputs.
    indices = torch.linspace(0, 1, n, device=weights.device)
    return torch.sum((n - 1) * weights * indices, dim=-1)

In [None]:
class EyeSample:
    """Read-only record describing one cropped eye image.

    The two images are copied on construction (via their `.copy()` method);
    the remaining values are stored as given.
    """

    def __init__(self, orig_img, img, is_left, transform_inv, estimated_radius):
        self._is_left = is_left
        self._transform_inv = transform_inv
        self._estimated_radius = estimated_radius
        # Copies, so later edits to the caller's images do not leak in
        self._orig_img = orig_img.copy()
        self._img = img.copy()

    @property
    def orig_img(self):
        """Copy of the `orig_img` passed to the constructor."""
        return self._orig_img

    @property
    def img(self):
        """Copy of the `img` passed to the constructor."""
        return self._img

    @property
    def is_left(self):
        """The `is_left` value passed to the constructor."""
        return self._is_left

    @property
    def transform_inv(self):
        """The `transform_inv` value passed to the constructor."""
        return self._transform_inv

    @property
    def estimated_radius(self):
        """The `estimated_radius` value passed to the constructor."""
        return self._estimated_radius

In [None]:
class EyePrediction():
    """Read-only bundle of an eye sample with its predicted landmarks and gaze."""

    def __init__(self, eye_sample: EyeSample, landmarks, gaze):
        self._gaze = gaze
        self._landmarks = landmarks
        self._eye_sample = eye_sample

    @property
    def eye_sample(self):
        """The `EyeSample` this prediction was made for."""
        return self._eye_sample

    @property
    def landmarks(self):
        """The `landmarks` value passed to the constructor."""
        return self._landmarks

    @property
    def gaze(self):
        """The `gaze` value passed to the constructor."""
        return self._gaze

## Model

In [None]:
class HeatmapLoss(torch.nn.Module):
    """Per-sample mean squared error between two batches of heatmaps."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, gt):
        """Return the squared error averaged over dims 1-3, one value per sample.

        Both inputs must be 4-dimensional; the result has shape (N,).
        """
        squared_error = (pred - gt) ** 2
        return torch.mean(squared_error, dim=(1, 2, 3))


class AngularError(torch.nn.Module):
    """Per-sample mean squared error between predicted and target gaze.

    NOTE: despite the name this is a squared-error loss on the raw gaze
    components, not an angular distance.
    """

    def __init__(self):
        super(AngularError, self).__init__()

    def forward(self, gaze_pred, gaze):
        """Return one loss value per sample.

        Args:
            gaze_pred, gaze: tensors of shape (N, ...) with at least two
                dimensions, e.g. (N, 2) pitch/yaw pairs.

        Returns:
            Tensor of shape (N,): squared error averaged over every
            non-batch dimension.
        """
        loss = (gaze_pred - gaze) ** 2
        # Reduce over all non-batch dims. The previous hard-coded dims
        # (1, 2, 3) only fit 4-D heatmaps and raised on (N, 2) gaze tensors.
        return loss.flatten(start_dim=1).mean(dim=1)

In [None]:
# Pooling layer type used by the network blocks
Pool = nn.MaxPool2d


def batchnorm(x):
    """Normalize `x` per channel using the statistics of the current batch.

    Equivalent to applying a freshly constructed `nn.BatchNorm2d` in training
    mode (unit scale, zero shift), but stateless: no throwaway parameters are
    allocated per call and the computation stays on the device of `x`
    (a freshly built layer lives on the CPU and fails for CUDA inputs).

    Args:
        x: tensor of shape (N, C, H, W).

    Returns:
        Tensor of the same shape as `x`.
    """
    return nn.functional.batch_norm(x, None, None, training=True)


class Conv(nn.Module):
    """2D convolution with "same"-style padding and optional BatchNorm / ReLU.

    Args:
        inp_dim: number of input channels (checked on every forward pass).
        out_dim: number of output channels.
        kernel_size: square kernel size; padding is (kernel_size - 1) // 2.
        stride: convolution stride.
        bn: append a BatchNorm2d after the convolution.
        relu: append a ReLU as the last step.
    """

    def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True):
        super().__init__()
        self.inp_dim = inp_dim
        same_padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=same_padding, bias=True)
        # Disabled stages are stored as None and skipped in forward()
        self.relu = nn.ReLU() if relu else None
        self.bn = nn.BatchNorm2d(out_dim) if bn else None

    def forward(self, x):
        """Apply convolution, then BatchNorm and ReLU when enabled."""
        assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim)
        out = self.conv(x)
        for stage in (self.bn, self.relu):
            if stage is not None:
                out = stage(out)
        return out


class Residual(nn.Module):
    """Pre-activation bottleneck residual block (BN -> ReLU -> conv, three times).

    The channel count goes inp_dim -> out_dim/2 -> out_dim/2 -> out_dim through
    1x1, 3x3 and 1x1 convolutions. When `inp_dim != out_dim` the shortcut is a
    1x1 convolution, otherwise it is the identity.
    """

    def __init__(self, inp_dim, out_dim):
        super().__init__()
        mid_dim = int(out_dim / 2)
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(inp_dim)
        self.conv1 = Conv(inp_dim, mid_dim, 1, relu=False)
        self.bn2 = nn.BatchNorm2d(mid_dim)
        self.conv2 = Conv(mid_dim, mid_dim, 3, relu=False)
        self.bn3 = nn.BatchNorm2d(mid_dim)
        self.conv3 = Conv(mid_dim, out_dim, 1, relu=False)
        # Always constructed, but only applied when the channel counts differ
        self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False)
        self.need_skip = inp_dim != out_dim

    def forward(self, x):
        """Return the bottleneck branch output plus the shortcut of `x`."""
        residual = self.skip_layer(x) if self.need_skip else x

        out = x
        stages = ((self.bn1, self.conv1), (self.bn2, self.conv2), (self.bn3, self.conv3))
        for norm, conv in stages:
            out = conv(self.relu(norm(out)))

        out += residual
        return out


class Hourglass(nn.Module):
    """Recursive hourglass module of depth `n` operating on `f` channels.

    Each level keeps a full-resolution residual branch and a pooled branch
    that recurses (or bottoms out in a residual block when `n` reaches 1);
    the pooled branch is upsampled back and added to the first.

    Args:
        n: recursion depth.
        f: number of input/output channels.
        bn: only forwarded to the nested hourglass.
        increase: extra channels used inside the lower branch of this level.
    """

    def __init__(self, n, f, bn=None, increase=0):
        super().__init__()
        nf = f + increase
        self.up1 = Residual(f, f)
        # Lower branch
        self.pool1 = Pool(2, 2)
        self.low1 = Residual(f, nf)
        self.n = n
        # Recurse until the innermost level, which is a plain residual block
        self.low2 = Hourglass(n - 1, nf, bn=bn) if self.n > 1 else Residual(nf, nf)
        self.low3 = Residual(nf, f)

    def forward(self, x):
        """Return the sum of the skip branch and the upsampled lower branch."""
        skip_branch = self.up1(x)
        lower = self.low3(self.low2(self.low1(self.pool1(x))))
        # Resize to the exact input size so odd dimensions still line up
        upsampled = nn.functional.interpolate(lower, x.shape[2:], mode='bilinear')
        return skip_branch + upsampled

In [None]:
class Merge(nn.Module):
    """1x1 convolution (no BatchNorm, no ReLU) mapping `x_dim` to `y_dim` channels."""

    def __init__(self, x_dim, y_dim):
        super().__init__()
        self.conv = Conv(x_dim, y_dim, 1, relu=False, bn=False)

    def forward(self, x):
        """Project `x` to `y_dim` channels."""
        projected = self.conv(x)
        return projected


class EyeNet(nn.Module):
    """Stacked-hourglass network predicting eye-landmark heatmaps and gaze.

    Input images are 160x96 (width x height) single-channel patches. A shared
    stem halves the resolution, `nstack` hourglass stages each predict
    `nlandmarks` heatmaps, and a small fully connected head regresses a
    2-value gaze from downsampled stem features plus the landmark coordinates
    of the last stage.

    Args:
        nstack: number of hourglass stages.
        nfeatures: channel width used throughout the stages.
        nlandmarks: number of landmark heatmaps per stage.
        bn: forwarded to `Hourglass`.
        increase: forwarded to `Hourglass`.
        **kwargs: accepted and ignored.
    """

    def __init__(self, nstack, nfeatures, nlandmarks, bn=False, increase=0, **kwargs):
        super(EyeNet, self).__init__()

        self.img_w = 160
        self.img_h = 96
        self.nstack = nstack
        self.nfeatures = nfeatures
        self.nlandmarks = nlandmarks

        self.heatmap_w = self.img_w / 2
        self.heatmap_h = self.img_h / 2

        # Shared stem: one 2x pooling step, so features are 80x48
        self.pre = nn.Sequential(
            Conv(1, 64, 7, 1, bn=True, relu=True),
            Residual(64, 128),
            Pool(2, 2),
            Residual(128, 128),
            Residual(128, nfeatures)
        )

        # Gaze branch: stride-2 conv plus pooling shrink the stem output a
        # further 4x per side (to 20x12) before it is flattened.
        self.pre2 = nn.Sequential(
            Conv(nfeatures, 64, 7, 2, bn=True, relu=True),
            Residual(64, 128),
            Pool(2, 2),
            Residual(128, 128),
            Residual(128, nfeatures)
        )

        self.hgs = nn.ModuleList([
            nn.Sequential(
                Hourglass(4, nfeatures, bn, increase),
            ) for _ in range(nstack)])

        self.features = nn.ModuleList([
            nn.Sequential(
                Residual(nfeatures, nfeatures),
                Conv(nfeatures, nfeatures, 1, bn=True, relu=True)
            ) for _ in range(nstack)])

        self.outs = nn.ModuleList([Conv(nfeatures, nlandmarks, 1, relu=False, bn=False) for _ in range(nstack)])
        # Feed each stage's features and predictions back into the next stage
        self.merge_features = nn.ModuleList([Merge(nfeatures, nfeatures) for _ in range(nstack - 1)])
        self.merge_preds = nn.ModuleList([Merge(nlandmarks, nfeatures) for _ in range(nstack - 1)])

        # Flattened gaze-branch features (img_w * img_h / 64 positions per
        # channel) concatenated with (row, col) for every landmark.
        gaze_in_features = int(nfeatures * self.img_w * self.img_h / 64 + nlandmarks * 2)
        self.gaze_fc1 = nn.Linear(in_features=gaze_in_features, out_features=256)
        self.gaze_fc2 = nn.Linear(in_features=256, out_features=2)

        self.heatmapLoss = HeatmapLoss()
        self.landmarks_loss = nn.MSELoss()
        self.gaze_loss = nn.MSELoss()

    def forward(self, imgs):
        """Run the network on a batch of eye images.

        Args:
            imgs: tensor of shape (N, img_h, img_w); the channel dimension is
                added here.

        Returns:
            Tuple `(heatmaps_out, landmarks_out, gaze)`:
            heatmaps of shape (N, nstack, nlandmarks, img_h/2, img_w/2),
            landmark (row, col) coordinates of shape (N, nlandmarks, 2) taken
            from the last stage, and gaze of shape (N, 2).
        """
        # imgs of size 1,ih,iw
        x = imgs.unsqueeze(1)
        x = self.pre(x)

        gaze_x = self.pre2(x)
        gaze_x = gaze_x.flatten(start_dim=1)

        combined_hm_preds = []
        for i in range(self.nstack):
            hg = self.hgs[i](x)
            feature = self.features[i](hg)
            preds = self.outs[i](feature)
            combined_hm_preds.append(preds)
            if i < self.nstack - 1:
                x = x + self.merge_preds[i](preds) + self.merge_features[i](feature)

        heatmaps_out = torch.stack(combined_hm_preds, 1)

        # Landmarks come from the final stage's heatmaps only
        landmarks_out = softargmax2d(preds)  # N x nlandmarks x 2

        # Gaze
        gaze = torch.cat((gaze_x, landmarks_out.flatten(start_dim=1)), dim=1)
        gaze = self.gaze_fc1(gaze)
        gaze = nn.functional.relu(gaze)
        gaze = self.gaze_fc2(gaze)

        return heatmaps_out, landmarks_out, gaze

    def calc_loss(self, combined_hm_preds, heatmaps, landmarks_pred, landmarks, gaze_pred, gaze):
        """Compute the three loss terms.

        Returns:
            Tuple `(heatmap_loss, landmarks_loss, gaze_loss)`: the per-sample
            heatmap losses summed over every stage and sample, the landmark
            MSE, and the gaze MSE multiplied by 1000.
        """
        per_stage_losses = [
            self.heatmapLoss(combined_hm_preds[:, i, :], heatmaps)
            for i in range(self.nstack)
        ]

        heatmap_loss = torch.stack(per_stage_losses, dim=1)
        landmarks_loss = self.landmarks_loss(landmarks_pred, landmarks)
        gaze_loss = self.gaze_loss(gaze_pred, gaze)

        return torch.sum(heatmap_loss), landmarks_loss, 1000 * gaze_loss

## Training

In [None]:
# Set up pytorch
torch.backends.cudnn.enabled = False
# Seed every RNG the notebook draws from: torch (weight init, dataset split,
# shuffling) and NumPy (random jitter and blur in preprocess_unityeyes_image).
torch.manual_seed(0)
np.random.seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device', device)

Device cuda:0


In [None]:
def validate(eyenet: EyeNet, val_loader: DataLoader) -> float:
    """Compute the mean validation loss and record it in `val_losses_list`.

    The model is switched to evaluation mode for the duration of the pass, so
    BatchNorm layers use their running statistics and do not update them from
    validation data. The previous training/eval mode is restored afterwards.

    Args:
        eyenet: the network to evaluate.
        val_loader: loader yielding batches with 'img', 'heatmaps',
            'landmarks' and 'gaze' entries.

    Returns:
        Mean over batches of `1000 * heatmaps_loss + landmarks_loss + gaze_loss`.
    """
    was_training = eyenet.training
    eyenet.eval()
    try:
        with torch.no_grad():
            val_losses = []
            for val_batch in val_loader:
                val_imgs = val_batch['img'].float().to(device)
                heatmaps = val_batch['heatmaps'].to(device)
                # .float() keeps the dtype handling identical to train_epoch
                landmarks = val_batch['landmarks'].float().to(device)
                gaze = val_batch['gaze'].float().to(device)
                heatmaps_pred, landmarks_pred, gaze_pred = eyenet(val_imgs)
                heatmaps_loss, landmarks_loss, gaze_loss = eyenet.calc_loss(
                    heatmaps_pred, heatmaps, landmarks_pred, landmarks, gaze_pred, gaze)
                loss = 1000 * heatmaps_loss + landmarks_loss + gaze_loss
                val_losses.append(loss.item())
    finally:
        # Hand the model back in whatever mode the caller had it in
        eyenet.train(was_training)

    val_loss = float(np.mean(val_losses))
    # Module-level history used by the plotting cell
    val_losses_list.append(val_loss)
    return val_loss

In [None]:
def train_epoch(epoch: int,
                eyenet: EyeNet,
                optimizer,
                train_loader : DataLoader,
                val_loader: DataLoader,
                best_val_loss: float,
                checkpoint_fn: str,
                writer: SummaryWriter):
    """Train for one epoch, validating and checkpointing every 20 batches.

    Args:
        epoch: zero-based epoch index, used to compute the global batch step.
        eyenet: the network being trained.
        optimizer: optimizer over `eyenet`'s parameters.
        train_loader: loader yielding training batches.
        val_loader: loader passed to `validate`.
        best_val_loss: best validation loss seen so far.
        checkpoint_fn: file the best model state is saved to.
        writer: TensorBoard writer receiving the scalar losses.

    Returns:
        The (possibly improved) best validation loss.
    """
    # Make the mode explicit in case the model was left in eval mode
    eyenet.train()

    N = len(train_loader)
    for i_batch, sample_batched in enumerate(train_loader):
        # Global step across epochs
        i_batch += N * epoch
        imgs = sample_batched['img'].float().to(device)
        heatmaps_pred, landmarks_pred, gaze_pred = eyenet(imgs)

        heatmaps = sample_batched['heatmaps'].to(device)
        landmarks = sample_batched['landmarks'].float().to(device)
        gaze = sample_batched['gaze'].float().to(device)

        heatmaps_loss, landmarks_loss, gaze_loss = eyenet.calc_loss(
            heatmaps_pred, heatmaps, landmarks_pred, landmarks, gaze_pred, gaze)

        loss = 1000 * heatmaps_loss + landmarks_loss + gaze_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i_batch % 20 == 0:
            # Debug snapshots of the last sample in the batch: mean of heatmap
            # channels 8..15 (ground truth vs. last-stage prediction) and the
            # input patch. Computed only on the steps where they are written.
            hm = np.mean(heatmaps[-1, 8:16].cpu().detach().numpy(), axis=0)
            hm_pred = np.mean(heatmaps_pred[-1, -1, 8:16].cpu().detach().numpy(), axis=0)
            norm_hm = cv2.normalize(hm, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
            norm_hm_pred = cv2.normalize(hm_pred, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

            cv2.imwrite('true.jpg', norm_hm * 255)
            cv2.imwrite('pred.jpg', norm_hm_pred * 255)
            cv2.imwrite('eye.jpg', sample_batched['img'].numpy()[-1] * 255)

        writer.add_scalar("Training heatmaps loss", heatmaps_loss.item(), i_batch)
        writer.add_scalar("Training landmarks loss", landmarks_loss.item(), i_batch)
        writer.add_scalar("Training gaze loss", gaze_loss.item(), i_batch)
        writer.add_scalar("Training loss", loss.item(), i_batch)

        # Module-level history used by the plotting cell
        train_losses_list.append(loss.item())

        if i_batch > 0 and i_batch % 20 == 0:
            val_loss = validate(eyenet=eyenet, val_loader=val_loader)
            writer.add_scalar("validation loss", val_loss, i_batch)
            print('Epoch', epoch, 'Validation loss', val_loss)
            if val_loss < best_val_loss:
                # Keep only the best model seen so far
                best_val_loss = val_loss
                torch.save({
                    'nstack': eyenet.nstack,
                    'nfeatures': eyenet.nfeatures,
                    'nlandmarks': eyenet.nlandmarks,
                    'best_val_loss': best_val_loss,
                    'model_state_dict': eyenet.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, checkpoint_fn)

    return best_val_loss

In [None]:
def train(eyenet: EyeNet, optimizer, nepochs: int, best_val_loss: float, checkpoint_fn: str):
    """Train `eyenet` on the UnityEyes images stored on the mounted Drive.

    160 samples are held out at random for validation; the rest are used for
    training with batch size 32.

    Args:
        eyenet: the network to train.
        optimizer: optimizer over `eyenet`'s parameters.
        nepochs: number of passes over the training split.
        best_val_loss: starting value for best-model tracking.
        checkpoint_fn: file the best model state is saved to.

    Returns:
        The best validation loss reached.

    Raises:
        ValueError: if the dataset is too small to hold out the validation set.
    """
    dataset = UnityEyesDataset('/content/drive/MyDrive/imgs/')
    N = len(dataset)
    print(N)
    VN = 160
    if N <= VN:
        # Fail early with a clear message instead of an opaque split error
        # (or, for N == VN, a silently empty training set).
        raise ValueError(f'Need more than {VN} samples to hold out a validation set, found {N}')
    TN = N - VN
    train_set, val_set = torch.utils.data.random_split(dataset, (TN, VN))

    train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=32, shuffle=True)

    timestr = datetime.now().strftime("%m%d%Y-%H%M%S")
    writer = SummaryWriter(f'runs/eyenet-{timestr}')
    try:
        for i in range(nepochs):
            best_val_loss = train_epoch(epoch=i,
                                        eyenet=eyenet,
                                        optimizer=optimizer,
                                        train_loader=train_loader,
                                        val_loader=val_loader,
                                        best_val_loss=best_val_loss,
                                        checkpoint_fn=checkpoint_fn,
                                        writer=writer)
    finally:
        # Flush pending events and release the log file even if training fails
        writer.close()

    return best_val_loss


In [None]:
def main():
    """Build an EyeNet with the notebook's hyperparameters and train it.

    Uses 3 hourglass stages, 32 features, 34 landmarks, Adam with a learning
    rate of 4e-4, 10 epochs, and writes the best model to 'checkpoint.pt'.
    """
    model_config = dict(nstack=3, nfeatures=32, nlandmarks=34)
    eyenet = EyeNet(**model_config).to(device)
    optimizer = torch.optim.Adam(eyenet.parameters(), lr=4 * 1e-4)

    train(
        eyenet=eyenet,
        optimizer=optimizer,
        nepochs=10,
        best_val_loss=float('inf'),
        checkpoint_fn='checkpoint.pt'
    )


In [None]:
# Module-level loss histories: train_epoch() appends one training loss per
# batch and validate() appends one mean validation loss per call.
# Re-running this cell resets both histories.
train_losses_list = []
val_losses_list = []

In [None]:
# Build the model and run the full training loop.
main()

In [None]:
# Number of recorded training losses (one per training batch).
len(train_losses_list)

In [None]:
# Number of recorded validation losses (one per validate() call).
len(val_losses_list)

In [None]:
fig, axs = plt.subplots(1, 1, figsize=(8, 8))

# One point per validate() call. The x range is derived from the history
# length so the cell works for any number of validation runs.
val_steps = np.arange(1, len(val_losses_list) + 1)
axs.plot(val_steps, val_losses_list, color='red', label='Validation Loss')
# Validation runs every 20 training batches, not once per epoch
axs.set_xlabel('Validation run (every 20 training batches)')
axs.set_ylabel('Loss')
axs.set_title('Validation loss during training')

axs.legend()
fig.savefig('loss_graph.png')