In [1]:
## Standard libraries
import os
import numpy as np
import pandas as pd
import random
import math
import json
from functools import partial

## Imports for plotting
import matplotlib.pyplot as plt
plt.set_cmap('cividis')
%matplotlib inline
#from IPython.display import set_matplotlib_formats
#matplotlib_inline.backend_inline.set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

# ## tqdm for loading bars
# from tqdm.notebook import tqdm

# for one-hot encoding
from keras.utils.np_utils import to_categorical   

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torch.utils.data.dataloader import default_collate
import torch.optim as optim
from torch.optim import Adam
from mlm_pytorch import MLM

# Transformer wrapper
import tensorflow as tf
from x_transformers import TransformerWrapper, Encoder
from torch.nn import TransformerEncoder, TransformerEncoderLayer

# Positional encoding in two dimensions
from positional_encodings.torch_encodings import PositionalEncodingPermute1D, PositionalEncoding1D

from functools import reduce
import math


# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Path to the folder where the datasets are
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "saved_models/simple_transformer"

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)


2024-12-19 14:46:49.885277: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-12-19 14:46:49.954202: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-19 14:46:49.971768: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Seed set to 42


Device: cuda:0


In [2]:
# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.5
    import pytorch_lightning as pl

### Setup

In [3]:
# Global variables
NUM_FEATURES = 1
NUM_FIX = 30 
BATCH_SIZE = 64
NUM_CLASSES = 31


num_fix = NUM_FIX
num_features = NUM_FEATURES

In [4]:
# Helper functions

def mask_with_tokens(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    return mask

def mse_loss(target, input, mask):
    out = (input[mask]-target[mask])**2
    return out.mean()

def mask_with_tokens_3D(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    reduced = torch.any(mask, dim=-1, keepdim=True)
    expanded = reduced.expand_as(mask)
    return expanded

def get_mask_subset_with_prob_3D(mask, prob):
    batch, num_fix, num_features, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * num_fix)

    num_tokens = mask.sum(dim=-2, keepdim=True)
    mask_excess = (mask.cumsum(dim=-2)[:,:,0] > (num_tokens[:,:,0] * prob).ceil())
    mask_excess = mask_excess[:, :max_masked]

    rand = torch.rand((batch, num_fix, num_features), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-2)
    sampled_indices = (sampled_indices[:,:,0] + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, num_fix + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    new_mask = new_mask[:, 1:].bool()
    
    return new_mask.unsqueeze_(2).expand(-1,-1, num_features)
    

def prob_mask_like_3D(t, prob):
    temp = torch.zeros_like(t[:,:,0]).float().uniform_(0, 1) < prob
    return temp.unsqueeze_(2).expand(-1,-1, num_features)
    
    
def pad_group_with_zeros(group, target_rows):
    # Calculate the number of rows to add
    num_missing_rows = target_rows - len(group)
    if num_missing_rows > 0:
        # Create a DataFrame with the required number of padding rows
        # input padding
        zero_rows = pd.DataFrame(0.3333, index=range(num_missing_rows), columns=group.columns)
        # Label padding
        # zero_rows.iloc[:, 0] = 31
        # Concatenate the group with the zero rows
        group = pd.concat([group, zero_rows], ignore_index=True)
    return group

class ToTensor(object):
    """Convert Series in sample to Tensors."""

    def __call__(self, sample):
        trial, label = sample['trial'], sample['label']
        trial = torch.from_numpy(trial).float()
        label = torch.from_numpy(label).float()
        return trial, label

In [5]:
class CustomPositionalEncoding(nn.Module):
    """Learnable positional encoding for both features and fixations
    Assumes input `x` is of shape [batch_size, fixations, embed_dim]"""
    def __init__(self, fixations, embed_dim, max_len= num_features):
        super(CustomPositionalEncoding, self).__init__()
        
        # Initialize a learnable positional encoding matrix
        self.encoding = nn.Parameter(torch.zeros(fixations, max_len)).to(device)
        nn.init.xavier_uniform_(self.encoding)  # Xavier initialization for better training stability
        
    def forward(self, x, mask = None):
        if mask is not None:
            # Apply the mask to ignore padded positions
            pos_encoding = self.encoding  * mask
         
        return x + pos_encoding

In [6]:
class AbsolutePositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len, l2norm_embed = False):
        super().__init__()
        self.scale = dim ** -0.5 if not l2norm_embed else 1.
        self.max_seq_len = max_seq_len
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos = None, seq_start_pos = None, mask = None):
        seq_len, device = x.shape[1], x.device
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        #if not exists(pos):
        pos = torch.arange(seq_len, device = device)

        #if exists(seq_start_pos):
        #    pos = (pos - seq_start_pos[..., None]).clamp(min = 0)

        pos_emb = self.emb(pos)
        pos_emb = pos_emb * self.scale
        
        if mask is not None:
            # Apply the mask to ignore padded positions
            pos_emb = pos_emb * mask
            
        return l2norm(pos_emb) if self.l2norm_embed else pos_emb


In [7]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len=num_fix):
        """
        Inputs
            d_model - Hidden dimensionality of the input.
            max_len - Maximum length of a sequence to expect.
        """
        super().__init__()

        # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        if d_model%2 != 0:
            pe[:, 1::2] = torch.cos(position * div_term)[:,0:-1]
        else:
            pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).to(device)

        # register_buffer => Tensor which is not a parameter, but should be part of the modules state.
        # Used for tensors that need to be on the same device as the module.
        # persistent=False tells PyTorch to not add the buffer to the state dict (e.g. when we save the model)
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x, mask = None):
        pos_enc = self.pe[:, :x.size(1)]*mask
        return x + pos_enc

### Data preparation

In [8]:
class RSCFixationsOrder(Dataset):
    """Dataset with the long-format sequence of fixations made during reading by dyslexic 
    and normally-developing Russian-speaking monolingual children."""

    def __init__(self, csv_file, transform=None, target_transform = None, 
                 n_fix = NUM_FIX, 
                 dropPhonologyFeatures = True, dropPhonologySubjects = True):
        """
        Arguments:
            csv_file (string): Path to the csv file with annotations.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            target_transform (callable, optional): Optional transform to be applied
                on a label.
        """
        self.fixations_frame = pd.read_csv(csv_file)
        
        # remove demography and identification
        self.fixations_frame = self.fixations_frame.drop(columns = ['fix_x', 
                                                                    'fix_y', 'full_text' 
                                                                    ]) 
        
        #self.fixations_frame = self.fixations_frame[order_cols]
        
        # Log-transforming appropriate measures
        to_transform = ['frequency', 'fix_dur'] 
        for column in to_transform:
            self.fixations_frame[column] = self.fixations_frame[column].apply(lambda x: np.log(x) if x > 0 else 0) 

        # Center 
        cols = ['fix_dur', 'landing', 'predictability',
                'frequency', 'word_length', 'number.morphemes', 
                'next_fix_dist', 'sac_ampl', 'sac_angle', 
                'sac_vel']
        for col in cols:
            self.fixations_frame[col] = np.where(self.fixations_frame[col] == 0, -4,
                (self.fixations_frame[col] - self.fixations_frame[col].mean())/self.fixations_frame[col].std(ddof=0)) 
        

        # Drop padding 
        self.fixations_frame = self.fixations_frame[self.fixations_frame['fix_dur'] != -4]
        
        
        # Convert direction to a dummy-coded variable
        self.fixations_frame['direction'] = np.where(self.fixations_frame['direction'].isnull(), 0,
                                                     self.fixations_frame['direction'])
        self.fixations_frame = pd.concat([self.fixations_frame, 
                                          pd.get_dummies(self.fixations_frame['direction'], 
                                                         prefix='dummy')], axis=1)

        if dropPhonologySubjects == True:
            # Drop subjects
            self.fixations_frame.dropna(axis = 0, how = 'any', inplace = True)
        else:
            # Drop columns
            self.fixations_frame.dropna(axis = 1, how = 'any', inplace = True)
        
        
        self.fixations_frame['subj'] = self.fixations_frame['subj'].astype(str)
        self.fixations_frame['item'] = self.fixations_frame['sn'].astype(str)
        self.fixations_frame['Combined'] = self.fixations_frame[['subj', 'item']].agg('_'.join, axis=1)
        
        # cleaning up
        self.fixations_frame.drop(columns = ['subj', 'item', 'direction', 'dummy_0', 'sn',\
                                            'dummy_DOWN', 'dummy_LEFT', 'dummy_UP'\
                                            ], inplace = True)
        self.fixations_frame.drop(columns = ['word_length', 'predictability', 
                                             'frequency', 'number.morphemes', 'fix_index'], 
                                             inplace = True)
        
        # Leaving just one feature for now
        self.fixations_frame.drop(columns = ['landing', 
                'next_fix_dist', 'sac_ampl', 'sac_angle', 
                'sac_vel', 'dummy_RIGHT'], inplace = True)
        

        padded = self.fixations_frame.groupby('Combined', group_keys=False).apply(lambda x: 
                                                                           pad_group_with_zeros(x, n_fix))
        padded.drop(columns = "Combined", inplace = True)
    
        
        #### Fixation index is the label
        self.fixations_frame = padded.to_numpy()
        dataReshaped = np.reshape(self.fixations_frame, (int(len(self.fixations_frame)/n_fix), 
                                                         n_fix, self.fixations_frame.shape[1]))

        self.predictors = dataReshaped[:,:,1:]
        self.labels = dataReshaped[:,:,0]
        self.labels = to_categorical(self.labels-1, num_classes=NUM_CLASSES)   # one-hot encoding

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        trial = self.predictors[idx]
        label = self.labels[idx]
        
        sample = {'trial': trial, 'label': label}

        if self.transform:
            sample = self.transform(sample)
            
        if self.target_transform:
            sample = self.target_transform(sample)
            
        return sample

In [9]:
transformed_dataset = RSCFixationsOrder(csv_file='data/RSC_long_no_word_padded_word_pos.csv', 
                                    transform=ToTensor(),
                                    dropPhonologySubjects = True)

train_size = int(0.8 * len(transformed_dataset))
val_size = int(0.1 * len(transformed_dataset))
test_size = len(transformed_dataset) - val_size - train_size
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(transformed_dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x)))
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x)))
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x)))

In [10]:
trials, labels = next(iter(train_loader))
print(f"Feature batch shape: {trials.size()}")

Feature batch shape: torch.Size([64, 30, 1])


In [11]:
trials[1, 0:15,:]

tensor([[0.0000],
        [0.1000],
        [0.2000],
        [0.1000],
        [0.3000],
        [0.5000],
        [0.5000],
        [0.6000],
        [0.7000],
        [0.9000],
        [0.8000],
        [0.3333],
        [0.3333],
        [0.3333],
        [0.3333]], device='cuda:0')

### Training

In [12]:
# Annas changes
class TransformerWithCustomPositionalEncoding2D(nn.Module):
    def __init__(
        self, 
        embed_dim, 
        num_heads, 
        num_layers, 
        max_len=5000,
        mask_prob = 0.15,
        replace_prob = 1, # 0.9
        mask_token_id = 2,
        pad_token_id = 0.3333,
        mask_ignore_token_ids = []
        ):
        super(TransformerWithCustomPositionalEncoding2D, self).__init__()
        
        self.mask_prob = mask_prob
        self.replace_prob = replace_prob

        # token ids
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id
        self.mask_ignore_token_ids = set([*mask_ignore_token_ids, pad_token_id])
        
        self.positional_encoding = PositionalEncoding(num_features)
        
        # transformer
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, 
                                                        nhead=num_heads, 
                                                        batch_first = True)
        self.transformer = TransformerEncoder(self.encoder_layer, num_layers=num_layers).to(device)

    def forward(self, seq, before):
        # do not mask [pad] tokens, or any other tokens in the tokens designated to be excluded ([cls], [sep])
        # also do not include these special tokens in the tokens chosen at random
        no_mask = mask_with_tokens_3D(seq, self.mask_ignore_token_ids) 
        mask = get_mask_subset_with_prob_3D(~no_mask, self.mask_prob)

        # mask input with mask tokens with probability of `replace_prob` (keep tokens the same with probability 1 - replace_prob)
        masked_seq = seq.clone().detach()
        
        #   Static positional encoding
        masked_seq_pos = self.positional_encoding(masked_seq, mask = ~no_mask)

        # [mask] input
        masked_replace_prob = prob_mask_like_3D(seq, self.replace_prob) # Anna: select 90% of all values  (ignore all masking for now)
        masked_seq = masked_seq_pos.masked_fill(mask * masked_replace_prob, self.mask_token_id) # Anna: select 90% only of those selected for masking
        
        # derive labels to predict
        labels = seq.masked_fill(~mask, self.pad_token_id).squeeze(2)
        
        # Pass through the transformer
        if before:
            preds = self.transformer(masked_seq.squeeze(2))
        else:
            preds = self.transformer(masked_seq).squeeze(2)
            
        
        my_loss = mse_loss(
            labels,
            preds,
            mask = mask.squeeze(2)
        )

        return preds, my_loss

In [17]:
# flattening input before feeding it to transformer (vs. after)
settings = [True, False]

for before in settings:
    print("\n",
          "Flattening the input before feeding it to the transformer =", before)
    if before:
        embed_dim = num_fix
    else:
        embed_dim = num_features

    trainer = TransformerWithCustomPositionalEncoding2D(num_heads = 1, 
                                                        num_layers= 2, 
                                                        embed_dim = embed_dim).cuda()

    # # Setup optimizer to optimize model's parameters
    optimizer = Adam(trainer.parameters(), lr=3e-4)

    epochs = 31

    for epoch in range(epochs):
        train_loss = 0.0
        ### Training
        trainer.train()

        for X_train, y_train in train_loader:
            # 1. Forward pass 
            # 2. Calculate loss/accuracy
            train_preds, loss = trainer(X_train, before = before)

            # 3. Optimizer zero grad
            optimizer.zero_grad()

            # 4. Loss backwards
            loss.backward()

            # 5. Optimizer step
            optimizer.step()

            train_loss += loss.item() * X_train.size(0)

        train_loss /= len(train_loader.dataset)

        #### Evaluation
        test_loss = 0.0
        trainer.eval()
        with torch.no_grad():
            for X_test, y_test in test_loader:
                test_preds, tloss = trainer(X_test, before = before)
                test_loss += tloss.item() * X_test.size(0)

        test_loss /= len(test_loader.dataset)


        # Print out what's happening every 10 epochs
        if epoch % 10 == 0:
            print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')

    print('Input: ', X_test[1, 0:15])
    print('Predictions: ', test_preds[1, 0:15])


 Flattening the input before feeding it to the transformer = True
Epoch 1: Train Loss: 0.1368, Test Loss: 0.0483
Epoch 11: Train Loss: 0.0276, Test Loss: 0.0265
Epoch 21: Train Loss: 0.0236, Test Loss: 0.0233
Epoch 31: Train Loss: 0.0217, Test Loss: 0.0220
Input:  tensor([[0.1111],
        [0.1111],
        [0.2222],
        [0.3333],
        [0.3333],
        [0.4444],
        [0.5556],
        [0.5556],
        [0.7778],
        [0.7778],
        [0.7778],
        [0.7778],
        [0.8889],
        [0.3333],
        [0.3333]], device='cuda:0')
Predictions:  tensor([-0.0195,  0.0576,  0.0978,  0.2506, -2.1445, -0.4618,  0.3668,  0.6642,
         0.7162,  0.8439,  0.6074,  0.5196,  0.5091,  0.5393,  0.4169],
       device='cuda:0')

 Flattening the input before feeding it to the transformer = False
Epoch 1: Train Loss: 0.2418, Test Loss: 0.2227
Epoch 11: Train Loss: 0.0853, Test Loss: 0.0863
Epoch 21: Train Loss: 0.0837, Test Loss: 0.0849
Epoch 31: Train Loss: 0.0838, Test Loss: 0.08