In [1]:
## Standard libraries
import os
import numpy as np
import pandas as pd
import random
import math
import json
from functools import partial

## Imports for plotting
import matplotlib.pyplot as plt
plt.set_cmap('cividis')
%matplotlib inline
#from IPython.display import set_matplotlib_formats
#matplotlib_inline.backend_inline.set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

# ## tqdm for loading bars
# from tqdm.notebook import tqdm

# for one-hot encoding
from keras.utils.np_utils import to_categorical   

from typing import TextIO, Callable, Collection, Dict, Iterator, List, Tuple, Type, TypeVar

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torch.utils.data.dataloader import default_collate
import torch.optim as optim
from torch.optim import Adam


# Transformer wrapper
import tensorflow as tf
from x_transformers import TransformerWrapper, Encoder
from torch.nn import TransformerEncoder, TransformerEncoderLayer

# Positional encoding in two dimensions
from positional_encodings.torch_encodings import PositionalEncodingPermute1D, PositionalEncoding1D

from functools import reduce
import math


# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Path to the folder where the datasets are
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "saved_models/simple_transformer"

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)


2025-02-05 13:41:37.526024: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-02-05 13:41:37.591949: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-05 13:41:37.610718: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Seed set to 42


Device: cuda:0


### Setup

In [2]:
# Global variables
NUM_FEATURES = 14
NUM_FIX = 30 
BATCH_SIZE = 64
NUM_CLASSES = 31
BATCH_SUBJECTS = False

num_fix = NUM_FIX
num_features = NUM_FEATURES
ablation = False

In [3]:
# Helper functions

def mask_with_tokens(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    return mask

def mse_loss(target, input, mask):
    out = (input[mask]-target[mask])**2
    return out.mean()

def mask_with_tokens_3D(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    reduced = torch.any(mask, dim=-1, keepdim=True)
    expanded = reduced.expand_as(mask)
    return expanded

def get_mask_subset_with_prob_3D(mask, prob):
    batch, num_fix, num_features, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * num_fix)

    num_tokens = mask.sum(dim=-2, keepdim=True)
    mask_excess = (mask.cumsum(dim=-2)[:,:,0] > (num_tokens[:,:,0] * prob).ceil())
    mask_excess = mask_excess[:, :max_masked]

    rand = torch.rand((batch, num_fix, num_features), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-2)
    sampled_indices = (sampled_indices[:,:,0] + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, num_fix + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    new_mask = new_mask[:, 1:].bool()
    
    return new_mask.unsqueeze_(2).expand(-1,-1, num_features)
    

def prob_mask_like_3D(t, prob):
    temp = torch.zeros_like(t[:,:,0]).float().uniform_(0, 1) < prob
    return temp.unsqueeze_(2).expand(-1,-1, num_features)
    
    
def pad_group_with_zeros(group, target_rows):
    # Calculate the number of rows to add
    num_missing_rows = target_rows - len(group)
    if num_missing_rows > 0:
        # Create a DataFrame with the required number of padding rows
        # input padding
        zero_rows = pd.DataFrame(0.3333, index=range(num_missing_rows), columns=group.columns)
        # Label padding
        # zero_rows.iloc[:, 0] = 31
        # Concatenate the group with the zero rows
        group = pd.concat([group, zero_rows], ignore_index=True)
    return group

class ToTensor(object):
    """Convert Series in sample to Tensors."""

    def __call__(self, sample):
        trial, label = sample['trial'], sample['label']
        trial = torch.from_numpy(trial).float()
        label = torch.from_numpy(label).float()
        return trial, label

In [4]:
class CustomPositionalEncoding(nn.Module):
    """Learnable positional encoding for both features and fixations
    Assumes input `x` is of shape [batch_size, fixations, embed_dim]"""
    def __init__(self, fixations = 30, features = None):
        super(CustomPositionalEncoding, self).__init__()
        
        # Initialize a learnable positional encoding matrix
        self.encoding = nn.Parameter(torch.zeros(fixations, features)).to(device)
        nn.init.xavier_uniform_(self.encoding)  # Xavier initialization for better training stability
        
    def forward(self, x, mask = None):
        if mask is not None:
            # Apply the mask to ignore padded positions
            pos_encoding = self.encoding  * mask
        else:
            pos_encoding = self.encoding
        return x + pos_encoding

In [5]:
class AbsolutePositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=278, l2norm_embed = False):
        super().__init__()
        self.scale = dim ** -0.5 if not l2norm_embed else 1.
        self.max_seq_len = max_seq_len
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos = None, seq_start_pos = None, mask = None):
        seq_len, device = x.shape[1], x.device
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        #if not exists(pos):
        pos = torch.arange(seq_len, device = device)

        #if exists(seq_start_pos):
        #    pos = (pos - seq_start_pos[..., None]).clamp(min = 0)

        pos_emb = self.emb(pos)
        pos_emb = pos_emb * self.scale
        
        if mask is not None:
            # Apply the mask to ignore padded positions
            pos_emb = pos_emb * mask
            
        return l2norm(pos_emb) if self.l2norm_embed else pos_emb


In [6]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len=278):
        """
        Inputs
            d_model - Hidden dimensionality of the input.
            max_len - Maximum length of a sequence to expect.
        """
        super().__init__()

        # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        if d_model%2 != 0:
            pe[:, 1::2] = torch.cos(position * div_term)[:,0:-1]
        else:
            pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).to(device)

        # register_buffer => Tensor which is not a parameter, but should be part of the modules state.
        # Used for tensors that need to be on the same device as the module.
        # persistent=False tells PyTorch to not add the buffer to the state dict (e.g. when we save the model)
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x, mask = None):
        pos_enc = self.pe[:, :x.size(1)]
        if not(mask == None):
            mask = mask.to(device)
            pos_enc = pos_enc*mask
        return x + pos_enc

### Data preparation

In [7]:
#LinguisticFeature = Callable[[Tuple[str]], Tuple[torch.Tensor]]

# def apply_standardization(x, m, sd):
#     nonzero_sd = sd.clone()
#     nonzero_sd[sd == 0] = 1
#     x = torch.from_numpy(x).float()
#     res = (x - m.unsqueeze(0)) / nonzero_sd.unsqueeze(0)
#     return res

def apply_standardization(x, m, sd):
    nonzero_sd = sd.clone()
    nonzero_sd[sd == 0] = 1
    x = torch.from_numpy(x).float()
    x_zeros = x[x.sum(dim=(1)) == 0]
    x_zeros[x_zeros==0] = -5
    x_non_zeros = x[x.sum(dim=(1)) != 0]
    x_non_zeros = (x_non_zeros - m.unsqueeze(0)) / nonzero_sd.unsqueeze(0)
    res = torch.cat((x_non_zeros, x_zeros), axis =0)
    return res


def aggregate_per_subject(subjs, y_preds, y_preds_class, y_trues):
    y_preds = np.array(y_preds)
    y_preds_class = np.array(y_preds_class)
    y_trues = np.array(y_trues)
    subjs = np.array(subjs).flatten()
    y_preds_subj = []
    y_preds_class_subj = []
    y_trues_subj = []
    subjs_subj = np.unique(subjs)
    for subj in subjs_subj:
        subj = subj.item()
        y_pred_class_subj = y_preds_class[subjs == subj]
        y_pred_subj = y_preds[subjs == subj]
        y_true_subj = y_trues[subjs == subj]
        assert len(np.unique(y_true_subj)) == 1, f"No unique label: subj={subj}"
        y_trues_subj.append(np.unique(y_true_subj).item())
        y_preds_subj.append(np.mean(y_pred_subj).item())
        if sum(y_pred_class_subj) >= (len(y_pred_class_subj) / 2):
            y_preds_class_subj.append(1)
        else:
            y_preds_class_subj.append(0)
    return subjs_subj, y_preds_subj, y_preds_class_subj, y_trues_subj

def getmeansd(dataset, batch: bool = False):  # removing rows of 0s
    if batch:
        # Anna added preprocessing from ndarray to torch
        tensors = [X for X, _, _, _ in dataset]  #torch.from_numpy(X).float()
        tensors = torch.cat(tensors, axis=0)
        # remove padded tensors
        tensors = tensors[tensors.sum(dim=(1,2)) != 0]   #tensors[tensors.sum(dim=(1, 2)) != 0]
        # remove rows of 0s from the computation
        sentences, timesteps, features = tensors.size()
        subset = tensors.sum(dim=(2)) != 0
        subset = subset.view(sentences, timesteps, 1)
        subset = subset.expand(sentences, timesteps, features)
        result = tensors[subset].view(-1, features) 
        
        means = torch.mean(result, dim=(0))
        sd = torch.std(result, dim=(0))
        return means, sd
    else:
        tensors = [torch.from_numpy(X).float() for X, _, _, _ in dataset] # Anna added , was [X for X, _, _ in dataset]
        tensors = torch.cat(tensors, axis=0)
        # remove padded tensors
        tensors = tensors[tensors.sum(dim=1) != 0]
        means = torch.mean(tensors, 0)
        sd = torch.std(tensors, 0)
        return means, sd
    
    
def get_params(paramdict) -> dict:
    selected_pars = dict()
    for k in paramdict:
        selected_pars[k] = random.sample(list(paramdict[k]), 1)[0]
    return selected_pars

In [8]:
def aggregate_speed_per_subject(subjs, y_preds, y_trues):
    y_preds = np.array(y_preds)
    #y_preds_class = np.array(y_preds_class)
    y_trues = np.array(y_trues)
    subjs = np.array(subjs).flatten()
    y_preds_subj = []
    y_trues_subj = []
    subjs_subj = np.unique(subjs)
    for subj in subjs_subj:
        subj = subj.item()
        y_pred_subj = y_preds[subjs == subj]
        y_true_subj = y_trues[subjs == subj]
        assert len(np.unique(y_true_subj)) == 1, f"No unique label: subj={subj}"
        y_trues_subj.append(np.unique(y_true_subj).item())
        y_preds_subj.append(np.mean(y_pred_subj).item())

    return subjs_subj, y_preds_subj, y_trues_subj

In [9]:
class EyetrackingDataPreprocessor(Dataset):
    """Dataset with the long-format sequence of fixations made during reading by dyslexic 
    and normally-developing Russian-speaking monolingual children."""

    def __init__(
        self, 
        csv_file, 
        transform=None, 
        target_transform=None, 
        dropPhonologyFeatures = True, 
        dropPhonologySubjects = True,     
        num_folds: float = 10,
        ):
        """
        Arguments:
            csv_file (string): Path to the csv file with annotations.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            target_transform (callable, optional): Optional transform to be applied
                on a label.
        """
        data = pd.read_csv(csv_file)
        
        # changing dyslexia labels to 0 and 1
        data['group'] = data['group'] + 0.5
        
        # log-transforming frequency
        to_transform = ['frequency', 'predictability', 'fix_dur'] #
        for column in to_transform:
            data[column] = data[column].apply(lambda x: np.log(x) if x > 0 else 0) 
        
        # drop columns we don't use
        data = data.drop(columns = ['fix_x', 'fix_y', 'fix_index'])  
        
        # center reading sopeed in case we need to predict it
        data['Reading_speed'] = (data['Reading_speed'] - data['Reading_speed'].mean())/data['Reading_speed'].std(ddof=0)
        
        if ablation == True:
            data = data.drop(columns = ['sex', 'Grade'])
            convert_columns = ['direction']
        else:
            convert_columns = ['Grade', 'direction'] 
        
        if dropPhonologyFeatures == True:
            data = data.drop(columns = ['IQ', 'Sound_detection', 'Sound_change'])
        
        for column in convert_columns:
            prefix = column + '_dummy'
            data = pd.concat([data, pd.get_dummies(data[column], 
                                    prefix=prefix)], axis=1)
            data = data.drop(columns = column)
        #data = data.drop(columns = ['Grade_dummy_0', 'direction_dummy_0'])
            
        if dropPhonologySubjects == True:
            # Drop subjects
            data.dropna(axis = 0, how = 'any', inplace = True)
        else:
            # Drop columns
            data.dropna(axis = 1, how = 'any', inplace = True)
            
        # rearrange columns (I need demogrpahic information to come last)
#         cols = ['item', 'subj', 'group', 'Reading_speed', 'fix_dur', 'landing', 'word_length',
#                  'predictability', 'frequency', 'number.morphemes', 'next_fix_dist',
#                  'sac_ampl', 'sac_angle', 'sac_vel', 'rel.position', 'direction_dummy_DOWN',
#                  'direction_dummy_LEFT', 'direction_dummy_RIGHT', 'direction_dummy_UP',
#                  'sex', 'Age', 'Grade_dummy_1', 'Grade_dummy_2', 'Grade_dummy_3', 'Grade_dummy_4',
#                  'Grade_dummy_5', 'Grade_dummy_6']
        cols = ['item', 'subj', 'group', 'Reading_speed', 'fix_dur',
               'landing', 'word_length', 'predictability', 'frequency', 
                'number.morphemes', 'next_fix_dist', 'sac_ampl', 'sac_angle', 
                'sac_vel', 'rel.position', 'direction_dummy_LEFT', 
                'direction_dummy_RIGHT', 'direction_dummy_DOWN'] # temporary
        data = data[cols]
        
        # Record features that are used for prediction
        self._features = [i for i in data.columns if i not in ['group', 'item', 'subj', 'Reading_speed']]
        self._data = pd.DataFrame()
        # Add sentence IDs and subject IDs
        self._data["sn"] = data["item"]
        self._data["subj"] = data["subj"]
        # Add labels
        self._data["group"] = data["group"]
        self._data["reading_speed"] = data["Reading_speed"]
        
        # Add features used for prediction
        for feature in self._features:
            self._data[feature] = data[feature]

        # Distribute subjects across stratified folds
        self._num_folds = num_folds
        self._folds = [[] for _ in range(num_folds)]
        dyslexic_subjects = self._data[self._data["group"] == 1]["subj"].unique()
        control_subjects = self._data[self._data["group"] == 0]["subj"].unique()
        random.shuffle(dyslexic_subjects)
        random.shuffle(control_subjects)
        for i, subj in enumerate(dyslexic_subjects):
            self._folds[i % num_folds].append(subj)
        for i, subj in enumerate(control_subjects):
            self._folds[num_folds - 1 - i % num_folds].append(subj)
        for fold in self._folds:
            random.shuffle(fold)

    def _iter_trials(self, folds: Collection[int]) -> Iterator[pd.DataFrame]:
        # Iterate over all folds
        for fold in folds:
            # Iterate over all subjects in the fold
            for subj in self._folds[fold]:       # Anna: subj in fold?
                subj_data = self._data[self._data["subj"] == subj]
                # Iterate over all sentences this subject read
                for sn in subj_data["sn"].unique():
                    trial_data = subj_data[subj_data["sn"] == sn]
                    yield trial_data
                    
                    
    def iter_folds(
        self, folds: Collection[int]) -> Iterator[Tuple[torch.Tensor, torch.Tensor, int]]:
        for trial_data in self._iter_trials(folds):
            predictors = trial_data[self._features].to_numpy()
            #predictors = np.reshape(predictors, (int(len(predictors)/278), 278, predictors.shape[1]))
            label = trial_data["group"].unique().item()
            subj = trial_data["subj"].unique().item()
            reading_speed = trial_data["reading_speed"].unique().item()
            #  X = (time_steps, features)
            X = predictors
            y = torch.tensor(label, dtype=torch.float)
            rs = torch.tensor(reading_speed , dtype=torch.float)
            yield X, y, subj, rs
                    

    @property
    def num_features(self) -> int:
        """Number of features per word (excluding word vector dimensions)."""
        return len(self._features)
    

    @property
    def max_number_of_sentences(self):
        data_copy = self._data.copy()
        max_s_count = data_copy.groupby(by="subj").sn.unique()
        return max([len(x) for x in max_s_count])

In [10]:
class EyetrackingDataset(Dataset):
    def __init__(
        self,
        preprocessor: EyetrackingDataPreprocessor,
       # word_vector_model: WordVectorModel,
        folds: Collection[int],
        batch_subjects: bool = False,
    ):
        self.sentences = list(preprocessor.iter_folds(folds))
        self._subjects = list(np.unique([subj for _, _, subj, _ in self.sentences]))
        self.num_features = preprocessor.num_features# + word_vector_model.dimensions()
        self.batch_subjects = batch_subjects
        #self.max_sentence_length = preprocessor.max_sentence_length
        self.max_number_of_sentences = preprocessor.max_number_of_sentences

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        if self.batch_subjects:
            subject = self._subjects[index]
            subject_sentences = [
                (X, y, subj, rs) for X, y, subj, rs in self.sentences if subj == subject
            ]
            X = torch.stack([torch.FloatTensor(X) for X, _, _, _ in subject_sentences]) #[X for X, _, _ in subject_sentences] #torch.FloatTensor([X for X, _, _ in subject_sentences])
            y = torch.stack([y for _, y, _, _ in subject_sentences]).unique().squeeze() 
            rs = torch.stack([rs for _, _, _, rs in subject_sentences]).unique().squeeze()
            return X, y, subject, rs

        else:
            X, y, subj, rs = self.sentences[index]
            #X = torch.from_numpy(X).float()   
            return X, y, subj, rs

    def __len__(self) -> int:
        if self.batch_subjects:
            return len(self._subjects)
        else:
            return len(self.sentences)

    def standardize(self, mean: torch.Tensor, sd: torch.Tensor):
        self.sentences = [
            (apply_standardization(X, mean, sd), y, subj, rs)
            for X, y, subj, rs in self.sentences
        ]

In [11]:
# Annas changes
class TransformerWithCustomPositionalEncoding(nn.Module):
    def __init__(
        self, 
        embed_dim = 30,
        d_model = 2,
        inner_dim = 8,
        num_heads = 1, 
        num_layers = 1, 
        dropout = 0,
        mask_prob = 0.2,
        replace_prob = 0, # 0.9
        mask_token_id = 2,
        pad_token_id = -5,
        mask_ignore_token_ids = []
        ):
        super().__init__()
        
        self.embed_dim = embed_dim
        self.d_model = d_model
        self.inner_dim = inner_dim
        self.mask_prob = mask_prob
        self.replace_prob = replace_prob
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.dropout = dropout

        # token ids
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id
        self.mask_ignore_token_ids = set([*mask_ignore_token_ids, pad_token_id])
        
        self.positional_encoding = PositionalEncoding(d_model=self.d_model, 
                                                      max_len=self.embed_dim)
      #  self.positional_encoding = AbsolutePositionalEmbedding(dim = self.d_model)
        self.positional_encoding = CustomPositionalEncoding(fixations = self.embed_dim, 
                                                            features = self.d_model)
        self.transformer = nn.MultiheadAttention(embed_dim = self.d_model,  
                                                 num_heads = self.num_heads, 
                                                 bias = True,
                                                 batch_first = True)
        # Output classifier per sequence lement
        self.output_net = nn.Sequential(
            nn.Linear(self.d_model, self.inner_dim),
            nn.LayerNorm(self.inner_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout),
            nn.Linear(self.inner_dim, self.d_model)
        )
        
    # TransformerEncoder
    def forward(self, seq):
        # do not mask [pad] tokens, or any other tokens in the tokens designated to be excluded ([cls], [sep])
        # also do not include these special tokens in the tokens chosen at random
        no_mask = mask_with_tokens_3D(seq, self.mask_ignore_token_ids) 
        mask = get_mask_subset_with_prob_3D(~no_mask, self.mask_prob)
        hidden = no_mask + mask # all elements that the model will not attend to

        # mask input with mask tokens with probability of `replace_prob` (keep tokens the same with probability 1 - replace_prob)
        masked_seq = seq.clone().detach().to(device)
        
        #   Static positional encoding
        masked_seq_pos = self.positional_encoding(masked_seq, mask = None) #for hidden fixations ~ no_mask

        #[mask] input
        masked_replace_prob = prob_mask_like_3D(seq, self.replace_prob) # Anna: select 90% of all values  (ignore all masking for now)
        masked_seq = masked_seq_pos.masked_fill(mask * masked_replace_prob, self.mask_token_id) # Anna: select 90% only of those selected for masking
        
        # derive labels to predict
        labels = seq.masked_fill(~mask, self.pad_token_id)
        
        #import pdb; 
        #pdb.set_trace()
        
        # Pass through the transformer
        x, weights = self.transformer(masked_seq, 
                                      key = masked_seq, 
                                      value = masked_seq, 
                                      key_padding_mask = hidden[:,:,0])  #no_mask// for masking out fixations: hidden[:,:,0]
        
        preds = self.output_net(x)
        
        # may add loss

        return x, labels, mask

In [13]:
NUM_FOLDS = 10
test_fold = 3
dev_fold = 2
train_folds = [
                fold
                for fold in range(NUM_FOLDS)
                if fold != test_fold and fold != dev_fold
            ]

preprocessor = EyetrackingDataPreprocessor(
    csv_file = 'data/30fixations_no_padding_sentence_word_pos.csv', # 'data/fixation_dataset_30fix_padded.csv',  
    num_folds = NUM_FOLDS
)

train_dataset = EyetrackingDataset(
                preprocessor,
                folds=train_folds,
                batch_subjects=BATCH_SUBJECTS,
            )
mean, sd = getmeansd(train_dataset, batch=BATCH_SUBJECTS)
train_dataset.standardize(mean, sd)
dev_dataset = EyetrackingDataset(
    preprocessor,
    folds=[dev_fold],
    batch_subjects=BATCH_SUBJECTS,
)
dev_dataset.standardize(mean, sd)

In [None]:
#train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2)
test_loader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE)

trainer = TransformerWithCustomPositionalEncoding(embed_dim = 30,
                                                  num_heads = 14, 
                                                  d_model = NUM_FEATURES,
                                                  num_layers= 1).to(device)

# # Setup optimizer to optimize model's parameters
optimizer = Adam(trainer.parameters(), lr=3e-3)

epochs = 700
epoch_count = 0
for epoch in range(epochs):
    epoch_count += 1
    epoch_loss = 0.0
    
    train_loader = torch.utils.data.DataLoader(train_dataset, 
                                               batch_size=BATCH_SIZE,
                                               shuffle = True)
    train_loss = 0
    trainer.train()
    ### Training
    for X, _, _, _ in train_loader:

        # 1. Forward pass 
        X = X.to(device)

        train_preds, labels, mask = trainer(X)
            
        # 2. Calculate loss/accuracy

        loss = mse_loss(
            labels,
            train_preds,
            mask
        )

        # 3. Optimizer zero grad
        optimizer.zero_grad()

        # 4. Loss backwards
        loss.backward()

        # 5. Optimizer step
        optimizer.step()

        epoch_loss += loss.item()

    train_loss = epoch_loss/(len(train_loader.dataset)/BATCH_SIZE)

    #### Evaluation
    test_loss = 0.0
    
    trainer.eval()
    with torch.no_grad():
        for X_test, _, _, _ in test_loader:
            X_test = X_test.to(device)
            test_preds, labels, mask = trainer(X_test)
            tloss = mse_loss(
                labels,
                test_preds,
                mask
            )
            test_loss += tloss.item() 

    test_loss = test_loss /(len(test_loader.dataset)/BATCH_SIZE)


    # Print out what's happening every 10 epochs
    if epoch % 50 == 0:
        print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')

#print('Input: ', X[0, 0:15])
#print('Predictions: ', train_preds[0, 0:15])

Epoch 1: Train Loss: 0.6550, Test Loss: 0.5569
Epoch 51: Train Loss: 0.0840, Test Loss: 0.1024
Epoch 101: Train Loss: 0.0812, Test Loss: 0.1124
Epoch 151: Train Loss: 0.0782, Test Loss: 0.0903
Epoch 201: Train Loss: 0.0791, Test Loss: 0.0850
Epoch 251: Train Loss: 0.0794, Test Loss: 0.1052
Epoch 301: Train Loss: 0.0785, Test Loss: 0.0882
Epoch 351: Train Loss: 0.0779, Test Loss: 0.0833


In [None]:
test_loss

In [None]:
loss = mse_loss(
    labels,
    test_preds,
    mask
)
loss

In [None]:
mask[3,:,0]

In [None]:
test_preds[7][0:4]