In [38]:
## Standard libraries
import os
import numpy as np
import pandas as pd
import random
import math
import json
from functools import partial, reduce
import argparse
import copy

from typing import TextIO, Callable, Collection, Dict, Iterator, List, Tuple, Type, TypeVar
T = TypeVar("T", bound="EyetrackingClassifier")

## PyTorch
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim import Adam

# PyTorch Lightning
import pytorch_lightning as pl

# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "checkpoints/transformer/"
if not os.path.isdir(CHECKPOINT_PATH):
    os.makedirs(CHECKPOINT_PATH)

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Seed set to 42


Device: cuda:0


### Setup

In [2]:
# Global variables
NUM_FEATURES = 14
NUM_FIX = 30 
BATCH_SIZE = 64
NUM_CLASSES = 31
BATCH_SUBJECTS = False

use_pretrain_dataset = True

if use_pretrain_dataset:
    file = "data/30fixations_RSC_and_children.csv" # combined: children and adults, N = 407
else:
    file = 'data/30fixations_no_padding_sentence_word_pos.csv' # 293 participants in the fine-tunind dataset, 

In [3]:
# Helper functions

def mask_with_tokens(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    return mask

def mse_loss(target, input, mask):
    out = (input[mask]-target[mask])**2
    return out.mean()

def mask_with_tokens_3D(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    reduced = torch.any(mask, dim=-1, keepdim=True)
    expanded = reduced.expand_as(mask)
    return expanded

def get_mask_subset_with_prob_3D(mask, prob):
    batch, num_fix, num_features, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * num_fix)

    num_tokens = mask.sum(dim=-2, keepdim=True)
    mask_excess = (mask.cumsum(dim=-2)[:,:,0] > (num_tokens[:,:,0] * prob).ceil())
    mask_excess = mask_excess[:, :max_masked]

    rand = torch.rand((batch, num_fix, num_features), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-2)
    sampled_indices = (sampled_indices[:,:,0] + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, num_fix + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    new_mask = new_mask[:, 1:].bool()
    
    return new_mask.unsqueeze_(2).expand(-1,-1, num_features)
    

def prob_mask_like_3D(t, prob):
    temp = torch.zeros_like(t[:,:,0]).float().uniform_(0, 1) < prob
    return temp.unsqueeze_(2).expand(-1,-1, NUM_FEATURES)
    
    
def pad_group_with_zeros(group, target_rows):
    # Calculate the number of rows to add
    num_missing_rows = target_rows - len(group)
    if num_missing_rows > 0:
        # Create a DataFrame with the required number of padding rows
        # input padding
        zero_rows = pd.DataFrame(0.3333, index=range(num_missing_rows), columns=group.columns)
        # Label padding
        # zero_rows.iloc[:, 0] = 31
        # Concatenate the group with the zero rows
        group = pd.concat([group, zero_rows], ignore_index=True)
    return group

class ToTensor(object):
    """Convert Series in sample to Tensors."""

    def __call__(self, sample):
        trial, label = sample['trial'], sample['label']
        trial = torch.from_numpy(trial).float()
        label = torch.from_numpy(label).float()
        return trial, label

In [4]:
class AnnasPositionalEncoding(nn.Module):
    """Learnable positional encoding for both features and fixations
    Assumes input `x` is of shape [batch_size, fixations, embed_dim]"""
    def __init__(self, fixations = int, features = int):
        super(AnnasPositionalEncoding, self).__init__()
        
        # Initialize a learnable positional encoding matrix for fixations
        self.fix_encoding = nn.Parameter(torch.zeros(fixations, 1)).to(device)
        nn.init.xavier_uniform_(self.fix_encoding)  # Xavier initialization for better training stability
        self.fix_encoding = self.fix_encoding.expand(-1, features)
        
        # Initialize a learnable positional encoding matrix for features
        self.feat_encoding = nn.Parameter(torch.zeros(1, features)).to(device)
        nn.init.xavier_uniform_(self.feat_encoding)  # Xavier initialization for better training stability
        self.feat_encoding = self.feat_encoding.expand(fixations, -1)
        
        self.encoding = self.fix_encoding + self.feat_encoding
        
    def forward(self, x, mask = None):
        if mask is not None:
            # Apply the mask to ignore padded positions
            pos_encoding = self.encoding  * mask
        else:
            pos_encoding = self.encoding
        return x + pos_encoding

### Data preparation

In [5]:
#LinguisticFeature = Callable[[Tuple[str]], Tuple[torch.Tensor]]

# def apply_standardization(x, m, sd):
#     nonzero_sd = sd.clone()
#     nonzero_sd[sd == 0] = 1
#     x = torch.from_numpy(x).float()
#     res = (x - m.unsqueeze(0)) / nonzero_sd.unsqueeze(0)
#     return res

def apply_standardization(x, m, sd):
    nonzero_sd = sd.clone()
    nonzero_sd[sd == 0] = 1
    x = torch.from_numpy(x).float()
    x_zeros = x[x.sum(dim=(1)) == 0]
    x_zeros[x_zeros==0] = -5
    x_non_zeros = x[x.sum(dim=(1)) != 0]
    x_non_zeros = (x_non_zeros - m.unsqueeze(0)) / nonzero_sd.unsqueeze(0)
    res = torch.cat((x_non_zeros, x_zeros), axis =0)
    return res


def aggregate_per_subject(subjs, y_preds, y_preds_class, y_trues):
    y_preds = np.array(y_preds)
    y_preds_class = np.array(y_preds_class)
    y_trues = np.array(y_trues)
    subjs = np.array(subjs).flatten()
    y_preds_subj = []
    y_preds_class_subj = []
    y_trues_subj = []
    subjs_subj = np.unique(subjs)
    for subj in subjs_subj:
        subj = subj.item()
        y_pred_class_subj = y_preds_class[subjs == subj]
        y_pred_subj = y_preds[subjs == subj]
        y_true_subj = y_trues[subjs == subj]
        assert len(np.unique(y_true_subj)) == 1, f"No unique label: subj={subj}"
        y_trues_subj.append(np.unique(y_true_subj).item())
        y_preds_subj.append(np.mean(y_pred_subj).item())
        if sum(y_pred_class_subj) >= (len(y_pred_class_subj) / 2):
            y_preds_class_subj.append(1)
        else:
            y_preds_class_subj.append(0)
    return subjs_subj, y_preds_subj, y_preds_class_subj, y_trues_subj

def getmeansd(dataset, batch: bool = False):  # removing rows of 0s
    if batch:
        # Anna added preprocessing from ndarray to torch
        tensors = [X for X, _, _, _ in dataset]  #torch.from_numpy(X).float()
        tensors = torch.cat(tensors, axis=0)
        # remove padded tensors
        tensors = tensors[tensors.sum(dim=(1,2)) != 0]   #tensors[tensors.sum(dim=(1, 2)) != 0]
        # remove rows of 0s from the computation
        sentences, timesteps, features = tensors.size()
        subset = tensors.sum(dim=(2)) != 0
        subset = subset.view(sentences, timesteps, 1)
        subset = subset.expand(sentences, timesteps, features)
        result = tensors[subset].view(-1, features) 
        
        means = torch.mean(result, dim=(0))
        sd = torch.std(result, dim=(0))
        return means, sd
    else:
        tensors = [torch.from_numpy(X).float() for X, _, _, _ in dataset] # Anna added , was [X for X, _, _ in dataset]
        tensors = torch.cat(tensors, axis=0)
        # remove padded tensors
        tensors = tensors[tensors.sum(dim=1) != 0]
        means = torch.mean(tensors, 0)
        sd = torch.std(tensors, 0)
        return means, sd
    
    
def get_params(paramdict) -> dict:
    selected_pars = dict()
    for k in paramdict:
        selected_pars[k] = random.sample(list(paramdict[k]), 1)[0]
    return selected_pars

In [6]:
def aggregate_speed_per_subject(subjs, y_preds, y_trues):
    y_preds = np.array(y_preds)
    #y_preds_class = np.array(y_preds_class)
    y_trues = np.array(y_trues)
    subjs = np.array(subjs).flatten()
    y_preds_subj = []
    y_trues_subj = []
    subjs_subj = np.unique(subjs)
    for subj in subjs_subj:
        subj = subj.item()
        y_pred_subj = y_preds[subjs == subj]
        y_true_subj = y_trues[subjs == subj]
        assert len(np.unique(y_true_subj)) == 1, f"No unique label: subj={subj}"
        y_trues_subj.append(np.unique(y_true_subj).item())
        y_preds_subj.append(np.mean(y_pred_subj).item())

    return subjs_subj, y_preds_subj, y_trues_subj

In [7]:
class EyetrackingDataPreprocessor(Dataset):
    """Dataset with the long-format sequence of fixations made during reading by dyslexic 
    and normally-developing Russian-speaking monolingual children."""

    def __init__(
        self, 
        csv_file, 
        transform=None, 
        target_transform=None,  
        num_folds: float = 10,
        ):
        """
        Arguments:
            csv_file (string): Path to the csv file with annotations.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            target_transform (callable, optional): Optional transform to be applied
                on a label.
        """
        data = pd.read_csv(csv_file)
        
        # changing dyslexia labels to 0 and 1
        if {'group'}.issubset(data.columns):   # not the case for pretrain dataset
            data['group'] = data['group'] + 0.5
        
        # log-transforming frequency
        to_transform = ['frequency', 'predictability', 'fix_dur'] #
        for column in to_transform:
            data[column] = data[column].apply(lambda x: np.log(x) if x > 0 else 0) 
        
        # drop columns we don't use
        data = data.drop(columns = ['fix_x', 'fix_y', 'fix_index'])  
        
        # center reading sopeed in case we need to predict it
        if {'Reading_speed'}.issubset(data.columns):
            data['Reading_speed'] = (data['Reading_speed'] - data['Reading_speed'].mean())/data['Reading_speed'].std(ddof=0)
        
        if {'sex', 'Grade'}.issubset(data.columns):
            data = data.drop(columns = ['sex', 'Grade'])
            
        convert_columns = ['direction']
        
        if {'IQ', 'Sound_detection', 'Sound_change'}.issubset(data.columns):
            data = data.drop(columns = ['IQ', 'Sound_detection', 'Sound_change'])
        
        for column in convert_columns:
            prefix = column + '_dummy'
            data = pd.concat([data, pd.get_dummies(data[column], 
                                    prefix=prefix)], axis=1)
            data = data.drop(columns = column)

        data.dropna(axis = 0, how = 'any', inplace = True)

            
        # rearrange columns (I need demogrpahic information to come last)
#         cols = ['item', 'subj', 'group', 'Reading_speed', 'fix_dur', 'landing', 'word_length',
#                  'predictability', 'frequency', 'number.morphemes', 'next_fix_dist',
#                  'sac_ampl', 'sac_angle', 'sac_vel', 'rel.position', 'direction_dummy_DOWN',
#                  'direction_dummy_LEFT', 'direction_dummy_RIGHT', 'direction_dummy_UP',
#                  'sex', 'Age', 'Grade_dummy_1', 'Grade_dummy_2', 'Grade_dummy_3', 'Grade_dummy_4',
#                  'Grade_dummy_5', 'Grade_dummy_6']
        if {'Reading_speed'}.issubset(data.columns):
            cols = ['item', 'subj', 'group', 'Reading_speed', 'fix_dur',
                   'landing', 'word_length', 'predictability', 'frequency', 
                    'number.morphemes', 'next_fix_dist', 'sac_ampl', 'sac_angle', 
                    'sac_vel', 'rel.position', 'direction_dummy_LEFT', 
                    'direction_dummy_RIGHT', 'direction_dummy_DOWN'] # temporary
        else:
            cols = ['item', 'subj', 'fix_dur',
                   'landing', 'word_length', 'predictability', 'frequency', 
                    'number.morphemes', 'next_fix_dist', 'sac_ampl', 'sac_angle', 
                    'sac_vel', 'rel.position', 'direction_dummy_LEFT', 
                    'direction_dummy_RIGHT', 'direction_dummy_DOWN'] # temporary
        data = data[cols]
        
        # Record features that are used for prediction
        if {'Reading_speed'}.issubset(data.columns):
            self._features = [i for i in data.columns if i not in ['group', 'item', 'subj', 'Reading_speed']]
        else:
            self._features = [i for i in data.columns if i not in ['item', 'subj']]
        self._data = pd.DataFrame()
        # Add sentence IDs and subject IDs
        self._data["sn"] = data["item"]
        self._data["subj"] = data["subj"]
        # Add labels
        if {'Reading_speed'}.issubset(data.columns):
            self._data["group"] = data["group"]
            self._data["reading_speed"] = data["Reading_speed"]
        else:
            self._data["group"] = -1
            self._data["reading_speed"] = -1
        
        # Add features used for prediction
        for feature in self._features:
            self._data[feature] = data[feature]

#         # Distribute subjects across stratified folds
        self._num_folds = num_folds
        self._folds = [[] for _ in range(num_folds)]
        just_subjects = self._data["subj"].unique()
        random.shuffle(just_subjects)
        for i, subj in enumerate(just_subjects):
            self._folds[i % num_folds].append(subj)
#         dyslexic_subjects = self._data[self._data["group"] == 1]["subj"].unique()
#         control_subjects = self._data[self._data["group"] == 0]["subj"].unique()
#         random.shuffle(dyslexic_subjects)
#         random.shuffle(control_subjects)
#         for i, subj in enumerate(dyslexic_subjects):
#             self._folds[i % num_folds].append(subj)
#         for i, subj in enumerate(control_subjects):
#             self._folds[num_folds - 1 - i % num_folds].append(subj)
        for fold in self._folds:
            random.shuffle(fold)

    def _iter_trials(self, folds: Collection[int]) -> Iterator[pd.DataFrame]:
        # Iterate over all folds
        for fold in folds:
            # Iterate over all subjects in the fold
            for subj in self._folds[fold]:       # Anna: subj in fold?
                subj_data = self._data[self._data["subj"] == subj]
                # Iterate over all sentences this subject read
                for sn in subj_data["sn"].unique():
                    trial_data = subj_data[subj_data["sn"] == sn]
                    yield trial_data
                    
                    
    def iter_folds(
        self, folds: Collection[int]) -> Iterator[Tuple[torch.Tensor, torch.Tensor, int]]:
        for trial_data in self._iter_trials(folds):
            predictors = trial_data[self._features].to_numpy()
            #predictors = np.reshape(predictors, (int(len(predictors)/278), 278, predictors.shape[1]))
            label = trial_data["group"].unique().item()
            subj = trial_data["subj"].unique().item()
            reading_speed = trial_data["reading_speed"].unique().item()
            #  X = (time_steps, features)
            X = predictors
            y = torch.tensor(label, dtype=torch.float)
            rs = torch.tensor(reading_speed , dtype=torch.float)
            yield X, y, subj, rs
                    

    @property
    def num_features(self) -> int:
        """Number of features per word (excluding word vector dimensions)."""
        return len(self._features)
    

    @property
    def max_number_of_sentences(self):
        data_copy = self._data.copy()
        max_s_count = data_copy.groupby(by="subj").sn.unique()
        return max([len(x) for x in max_s_count])

In [8]:
class EyetrackingDataset(Dataset):
    def __init__(
        self,
        preprocessor: EyetrackingDataPreprocessor,
       # word_vector_model: WordVectorModel,
        folds: Collection[int],
        batch_subjects: bool = False,
    ):
        self.sentences = list(preprocessor.iter_folds(folds))
        self._subjects = list(np.unique([subj for _, _, subj, _ in self.sentences]))
        self.num_features = preprocessor.num_features# + word_vector_model.dimensions()
        self.batch_subjects = batch_subjects
        #self.max_sentence_length = preprocessor.max_sentence_length
        self.max_number_of_sentences = preprocessor.max_number_of_sentences

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        if self.batch_subjects:
            subject = self._subjects[index]
            subject_sentences = [
                (X, y, subj, rs) for X, y, subj, rs in self.sentences if subj == subject
            ]
            X = torch.stack([torch.FloatTensor(X) for X, _, _, _ in subject_sentences]) #[X for X, _, _ in subject_sentences] #torch.FloatTensor([X for X, _, _ in subject_sentences])
            y = torch.stack([y for _, y, _, _ in subject_sentences]).unique().squeeze() 
            rs = torch.stack([rs for _, _, _, rs in subject_sentences]).unique().squeeze()
            return X, y, subject, rs

        else:
            X, y, subj, rs = self.sentences[index]
            #X = torch.from_numpy(X).float()   
            return X, y, subj, rs

    def __len__(self) -> int:
        if self.batch_subjects:
            return len(self._subjects)
        else:
            return len(self.sentences)

    def standardize(self, mean: torch.Tensor, sd: torch.Tensor):
        self.sentences = [
            (apply_standardization(X, mean, sd), y, subj, rs)
            for X, y, subj, rs in self.sentences
        ]

In [9]:
class EncoderLayer(nn.Module):  
    def __init__(self,
                dim_upscale = int,
                inner_dim_upscale = int,
                num_heads = int, 
                num_layers = int, 
                dropout = 0
                ):

        super().__init__()
        self.num_heads = num_heads
        self.dropout = dropout
        self.dim_upscale = dim_upscale
        self.inner_dim_upscale = inner_dim_upscale
        
        # layer norm for multi-head attention
        self.attn_layer_norm = nn.LayerNorm(self.dim_upscale)
        # layer norm for feedforward network
        self.ffn_layer_norm = nn.LayerNorm(self.dim_upscale)
        
        self.attention = nn.MultiheadAttention(embed_dim = self.dim_upscale,  
                                                 num_heads = self.num_heads, 
                                                 bias = True,
                                                 batch_first = True)
        # feed forward
        self.ff = nn.Sequential(
            nn.Linear(self.dim_upscale, self.inner_dim_upscale, bias = True),
            nn.LayerNorm(self.inner_dim_upscale),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout),
            nn.Linear(self.inner_dim_upscale, self.dim_upscale, bias = True)
        )
        

    def forward(self, src: torch.Tensor, src_mask: torch.Tensor):
        # pass embeddings through multi-head attention
        x, attn_probs = self.attention(src, src, src, src_mask)

        # residual add and norm
        first_out = self.attn_layer_norm(x + src)

        # position-wise feed-forward network
        x2 = self.ff(first_out)

        # residual add and norm
        second_out = self.ffn_layer_norm(x2 + first_out)  # first_out + x2

        return second_out, attn_probs


In [10]:
class Encoder(nn.Module):
    def __init__(self, 
                dim_upscale = int,
                inner_dim_upscale = int,
                num_heads = int, 
                num_layers = int, 
                dropout = 0):

        super().__init__()
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.dropout = dropout
        self.dim_upscale = dim_upscale
        self.inner_dim_upscale = inner_dim_upscale
        
        # create n_layers encoders 
        self.layers = nn.ModuleList([EncoderLayer(
                                    dim_upscale = self.dim_upscale,
                                    num_heads = self.num_heads, 
                                    inner_dim_upscale = self.inner_dim_upscale,
                                    dropout = self.dropout)
                                     for layer in range(self.num_layers)])


    def forward(self, src: torch.Tensor, src_mask: torch.Tensor):

        # pass the sequences through each encoder
        for layer in self.layers:
            src, attn_probs = layer(src, src_mask)

        self.attn_probs = attn_probs

        return src

In [11]:
# Annas changes
class TransformerWithCustomPositionalEncoding(nn.Module):
    def __init__(
        self, 
        embed_dim = 30,
        d_model = 14,
        dim_upscale = 128,
        inner_dim_upscale = 4*128,
        num_heads = 1, 
        num_layers = 1, 
        dropout = 0,
        mask_prob = 0.2,
        replace_prob = 0, # 0.9
        mask_token_id = 2,
        pad_token_id = -5,
        mask_ignore_token_ids = []
        ):
        super().__init__()
        
        self.embed_dim = embed_dim
        self.d_model = d_model
        self.mask_prob = mask_prob
        self.replace_prob = replace_prob
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.dropout = dropout
        self.dim_upscale = dim_upscale
        self.inner_dim_upscale = inner_dim_upscale

        # token ids
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id
        self.mask_ignore_token_ids = set([*mask_ignore_token_ids, self.pad_token_id])
        
        self.positional_encoding = AnnasPositionalEncoding(fixations = self.embed_dim, 
                                                           features = self.d_model)
        
        self.upscale = nn.Linear(self.d_model, self.dim_upscale, bias = True)
        self.downscale = nn.Linear(self.dim_upscale, self.d_model, bias = True)
        
        self.encoder = Encoder(dim_upscale = self.dim_upscale, 
                               num_heads = self.num_heads, 
                               num_layers = self.num_layers,
                               inner_dim_upscale = self.inner_dim_upscale, 
                               dropout = self.dropout)
        
        
        
    # TransformerEncoder
    def forward(self, seq, identity = False):
        # do not mask [pad] tokens, or any other tokens in the tokens designated to be excluded ([cls], [sep])
        # also do not include these special tokens in the tokens chosen at random
        no_mask = mask_with_tokens_3D(seq, self.mask_ignore_token_ids) 
        mask = get_mask_subset_with_prob_3D(~no_mask, self.mask_prob)
        hidden = no_mask + mask # all elements that the model will not attend to

        # mask input with mask tokens with probability of `replace_prob` (keep tokens the same with probability 1 - replace_prob)
        masked_seq = seq.clone().detach().to(device)
        
        #   Static positional encoding
        masked_seq_pos = self.positional_encoding(masked_seq, mask = None) #for hidden fixations ~ no_mask

        #[mask] input = This does not change the input if replace_prob = 0
        # masked_replace_prob = prob_mask_like_3D(seq, self.replace_prob) # Anna: select 90% of all values  (ignore all masking for now)
        # masked_seq = masked_seq_pos.masked_fill(mask * masked_replace_prob, self.mask_token_id) # Anna: select 90% only of those selected for masking
        
        # derive labels to predict
        labels = seq.masked_fill(~mask, self.pad_token_id)
        
        #import pdb; 
        #pdb.set_trace()
        
        if identity:
            attn_mask = no_mask[:,:,0]
        else:
            attn_mask = hidden[:,:,0]
        
        # Upscaling
        masked_seq_upscaled = self.upscale(masked_seq_pos)
       # labels_upscaled = self.upscale(seq)
       # mask_upscaled = mask[:,:,0].unsqueeze(2).expand(64, 30, 128)
       # labels_upscaled = labels_upscaled.masked_fill(~mask_upscaled, self.pad_token_id)
        
        # Encoder
        out = self.encoder(masked_seq_upscaled, 
                           attn_mask)  
        out = self.downscale(out)

        return out, labels, mask

In [12]:
NUM_FOLDS = 10
test_fold = 2
dev_fold = 3
train_folds = [
                fold
                for fold in range(NUM_FOLDS)
                if fold != test_fold and fold != dev_fold
            ]

preprocessor = EyetrackingDataPreprocessor(
    csv_file = file, 
   num_folds = NUM_FOLDS
)

train_dataset = EyetrackingDataset(
                preprocessor,
                folds=train_folds,
                batch_subjects=BATCH_SUBJECTS,
            )
mean, sd = getmeansd(train_dataset, batch=BATCH_SUBJECTS)
train_dataset.standardize(mean, sd)

dev_dataset = EyetrackingDataset(
    preprocessor,
    folds=[dev_fold],
    batch_subjects=BATCH_SUBJECTS,
)
dev_dataset.standardize(mean, sd)

test_dataset = EyetrackingDataset(
    preprocessor,
    folds=[test_fold],
    batch_subjects=BATCH_SUBJECTS,
)
test_dataset.standardize(mean, sd)

In [39]:
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE)
dev_loader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE)

train_loader = torch.utils.data.DataLoader(train_dataset, 
                                               batch_size=BATCH_SIZE,
                                               shuffle = True)

trainer = TransformerWithCustomPositionalEncoding(embed_dim = 30,
                                                  num_heads = 64, 
                                                  d_model = NUM_FEATURES,
                                                  num_layers = 1,
                                                  dropout = 0.1,
                                                  dim_upscale = 128).to(device)

# # Setup optimizer to optimize model's parameters
optimizer = Adam(trainer.parameters(), lr=1e-3)
patience = 10

epochs = 201
min_epochs = 15
best_losses = [float("inf")] * patience

epoch_count = 0
for epoch in range(epochs):
    epoch_count += 1
    epoch_loss = 0.0
    train_loss = 0
    trainer.train()
    ### Training
    for X, _, _, _ in train_loader:

        # 1. Forward pass 
        X = X.to(device)
        train_preds, labels, mask = trainer(X, identity = False)
            
        # 2. Calculate loss/accuracy
        loss = mse_loss(
            labels,
            train_preds,
            mask
        )

        # 3. Optimizer zero grad
        optimizer.zero_grad()

        # 4. Loss backwards
        loss.backward()

        # 5. Optimizer step
        optimizer.step()

        epoch_loss += loss.item()

    train_loss = epoch_loss/math.ceil(len(train_loader.dataset)/BATCH_SIZE)
    print(f"Epoch {epoch} done. Loss: {train_loss}")

    if dev_dataset is not None:
        trainer.eval()
        dev_loss = 0
        dev_loader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE)
        for X_dev, _, _, _ in dev_loader:
            X_dev = X_dev.to(device)
            dev_preds, labels, mask = trainer(X_dev, identity = False)
            dloss = mse_loss(
                labels,
                dev_preds,
                mask
            )
            dev_loss += dloss.item() 
        eval_loss = dev_loss /math.ceil(len(dev_loader.dataset)/BATCH_SIZE)   
        trainer.train()
        if epoch > min_epochs and all(eval_loss > i for i in best_losses):
            epoch_count -= patience - best_losses.index(min(best_losses))
            break
        else:
            best_losses.pop(0)
            best_losses.append(eval_loss)

#torch.save(trainer.state_dict(), CHECKPOINT_PATH)

torch.save({'epoch': epoch_count,
            'model_state_dict': trainer.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss,
            'dev_loss': min(best_losses)
            }, CHECKPOINT_PATH + "model.pth")

Epoch 0 done. Loss: 0.0414509875207982
Epoch 1 done. Loss: 0.005692043323660086
Epoch 2 done. Loss: 0.005565003407086102
Epoch 3 done. Loss: 0.0026908729436092445
Epoch 4 done. Loss: 0.0021667877080174766
Epoch 5 done. Loss: 0.0018108091476983325
Epoch 6 done. Loss: 0.0012899188592892246
Epoch 7 done. Loss: 0.00105100241335991
Epoch 8 done. Loss: 0.0027093871698090086
Epoch 9 done. Loss: 0.0007505658102954712
Epoch 10 done. Loss: 0.0006718648890480989
Epoch 11 done. Loss: 0.0019070828856853604
Epoch 12 done. Loss: 0.0016997671171364499
Epoch 13 done. Loss: 0.00048469220651690785
Epoch 14 done. Loss: 0.001706456957474305
Epoch 15 done. Loss: 0.0004046807703225952
Epoch 16 done. Loss: 0.0014526391673761808
Epoch 17 done. Loss: 0.0003322193759129117
Epoch 18 done. Loss: 0.00039245590450661094
Epoch 19 done. Loss: 0.00023888126498335829
Epoch 20 done. Loss: 0.00028911324174464743
Epoch 21 done. Loss: 0.000323692189058548
Epoch 22 done. Loss: 0.0002517591784312636
Epoch 23 done. Loss: 0.000

Save the whole model

In [48]:
torch.save(trainer, CHECKPOINT_PATH+"trainer.pth")

### Loading model weights

In [41]:
model = TransformerWithCustomPositionalEncoding(embed_dim = 30,
                                                  num_heads = 64, 
                                                  d_model = NUM_FEATURES,
                                                  num_layers = 1,
                                                  dropout = 0.1,
                                                  dim_upscale = 128).to(device)
optimizer = Adam(trainer.parameters(), lr=1e-3)

checkpoint = torch.load(CHECKPOINT_PATH+"model.pth", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()

TransformerWithCustomPositionalEncoding(
  (positional_encoding): AnnasPositionalEncoding()
  (upscale): Linear(in_features=14, out_features=128, bias=True)
  (downscale): Linear(in_features=128, out_features=14, bias=True)
  (encoder): Encoder(
    (layers): ModuleList(
      (0): EncoderLayer(
        (attn_layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (ffn_layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (ff): Sequential(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (2): ReLU(inplace=True)
          (3): Dropout(p=0.1, inplace=False)
          (4): Linear(in_features=512, out_features=128, bias=True)
        )
      )
    )
  )
)

#### Loss is way too high

In [44]:
dev_loader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE)
dev_loss = 0

for X_dev, _, _, _ in dev_loader:
    X_dev = X_dev.to(device)
    dev_preds, labels, mask = model(X_dev, identity = False)
    dloss = mse_loss(
        labels,
        dev_preds,
        mask
    )
    dev_loss += dloss.item() 
    
eval_loss = dev_loss /math.ceil(len(dev_loader.dataset)/BATCH_SIZE)   
eval_loss

0.3453269522441061

In [45]:
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE)
test_loss = 0

for X_test, _, _, _ in test_loader:
    X_test = X_test.to(device)
    test_preds, labels, mask = model(X_test, identity = False)
    tloss = mse_loss(
        labels,
        test_preds,
        mask
    )
    test_loss += tloss.item() 

test_loss = test_loss /math.ceil(len(test_loader.dataset)/BATCH_SIZE)
test_loss

0.3445869858066241

#### If I use the trained model (not load the weights)

In [46]:
test_loss = 0
for X_test, _, _, _ in test_loader:
    X_test = X_test.to(device)
    test_preds, labels, mask = trainer(X_test, identity = False)
    tloss = mse_loss(
        labels,
        test_preds,
        mask
    )
    test_loss += tloss.item() 

test_loss = test_loss /math.ceil(len(test_loader.dataset)/BATCH_SIZE)
test_loss

0.0007113359695520356

### Alternative (loading the whole model, not recommended)

In [50]:
model = torch.load(CHECKPOINT_PATH+"trainer.pth")

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE)
test_loss = 0

for X_test, _, _, _ in test_loader:
    X_test = X_test.to(device)
    test_preds, labels, mask = model(X_test, identity = False)
    tloss = mse_loss(
        labels,
        test_preds,
        mask
    )
    test_loss += tloss.item() 

test_loss = test_loss /math.ceil(len(test_loader.dataset)/BATCH_SIZE)
test_loss

0.0007811346691192335