In [1]:
## Standard libraries
import os
import numpy as np
import pandas as pd
import random
import math
import json
from functools import partial

## Imports for plotting
import matplotlib.pyplot as plt
plt.set_cmap('cividis')
%matplotlib inline
#from IPython.display import set_matplotlib_formats
#matplotlib_inline.backend_inline.set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

# ## tqdm for loading bars
# from tqdm.notebook import tqdm

# for one-hot encoding
from keras.utils.np_utils import to_categorical   

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torch.utils.data.dataloader import default_collate
import torch.optim as optim
from torch.optim import Adam
from mlm_pytorch import MLM

# Transformer wrapper
import tensorflow as tf
from x_transformers import TransformerWrapper, Encoder
from torch.nn import TransformerEncoder, TransformerEncoderLayer

# Positional encoding in two dimensions
from positional_encodings.torch_encodings import PositionalEncodingPermute1D, PositionalEncoding1D

from functools import reduce
import math


# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Path to the folder where the datasets are
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "saved_models/simple_transformer"

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)



2024-12-16 14:40:03.627254: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-12-16 14:40:03.691277: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-16 14:40:03.709463: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Seed set to 42


Device: cuda:0


In [2]:
# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.5
    import pytorch_lightning as pl

### Data preparation

In [3]:
# Global variables
NUM_FEATURES = 1
NUM_FIX = 30 
BATCH_SIZE = 64
NUM_CLASSES = 31


num_fix = NUM_FIX
seq_len = NUM_FEATURES

In [4]:
class CustomPositionalEncoding(nn.Module):
    """Learnable positional encoding for features """
    def __init__(self, embed_dim, max_len=1):
        super(CustomPositionalEncoding, self).__init__()

    # Initialize a learnable positional encoding matrix
        self.encoding = nn.Parameter(torch.zeros(max_len, embed_dim)).to(device)
        nn.init.xavier_uniform_(self.encoding)  # Xavier initialization for better training stability

    def forward(self, x, mask = None):
        if mask is not None:
            # Apply the mask to ignore padded positions
            pos_encoding = self.encoding[:x.size(1), :] * mask
    # Add the learnable positional encoding to the input tensor
        return x + pos_encoding

In [27]:
# embed_dim = 30
# encoding = nn.Parameter(torch.zeros(1, embed_dim)).to(device)
# nn.init.xavier_uniform_(encoding)
# mask = ~no_mask

# pos_encoding = encoding[:seq.size(1), :] * mask

In [6]:
# Helper functions

def mask_with_tokens(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    return mask
    
# def mse_loss(target, input, ignored_index):
#     mask = target == ignored_index
#     out = (input[~mask]-target[~mask])**2
#     return out.mean()

def mse_loss(target, input, mask):
    out = (input[mask]-target[mask])**2
    return out.mean()

def mask_with_tokens_3D(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    reduced = torch.any(mask, dim=-1, keepdim=True)
    expanded = reduced.expand_as(mask)
    return expanded

def get_mask_subset_with_prob_3D(mask, prob):
    batch, num_fix, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * num_fix)

    num_tokens = mask.sum(dim=-2, keepdim=True)
    mask_excess = (mask.cumsum(dim=-2)[:,:,0] > (num_tokens[:,:,0] * prob).ceil())
    mask_excess = mask_excess[:, :max_masked]

    rand = torch.rand((batch, num_fix, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-2)
    sampled_indices = (sampled_indices[:,:,0] + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, num_fix + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    new_mask = new_mask[:, 1:].bool()
    
    return new_mask.unsqueeze_(2).expand(-1,-1, seq_len)
    
def get_mask_subset_with_prob(mask, prob):
    batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len)

    num_tokens = mask.sum(dim=-1, keepdim=True)
    mask_excess = (mask.cumsum(dim=-1) > (num_tokens * prob).ceil())
    mask_excess = mask_excess[:, :max_masked]

    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()

def prob_mask_like(t, prob):
    return torch.zeros_like(t).float().uniform_(0, 1) < prob

def prob_mask_like_3D(t, prob):
    temp = torch.zeros_like(t[:,:,0]).float().uniform_(0, 1) < prob
    return temp.unsqueeze_(2).expand(-1,-1, seq_len)

    
# class CustomPositionalEncoding(nn.Module):
#     """Learnable positional encoding for both features and fixations"""
#     def __init__(self, fixations, embed_dim, max_len=5000):
#         super(CustomPositionalEncoding, self).__init__()
        
#         # Initialize a learnable positional encoding matrix
#         self.encoding = nn.Parameter(torch.zeros(fixations, embed_dim))
#         nn.init.xavier_uniform_(self.encoding)  # Xavier initialization for better training stability

#     def forward(self, x, mask = None):
#         if mask is not None:
#             # Apply the mask to ignore padded positions
#             pos_encoding = self.encoding * mask
#         # Assumes input `x` is of shape [batch_size, fixations, embed_dim]
#         return x + pos_encoding#.unsqueeze(0)

    
    
def pad_group_with_zeros(group, target_rows):
    # Calculate the number of rows to add
    num_missing_rows = target_rows - len(group)
    if num_missing_rows > 0:
        # Create a DataFrame with the required number of padding rows
        # input padding
        zero_rows = pd.DataFrame(0.3333, index=range(num_missing_rows), columns=group.columns)
        # Label padding
        # zero_rows.iloc[:, 0] = 31
        # Concatenate the group with the zero rows
        group = pd.concat([group, zero_rows], ignore_index=True)
    return group

class ToTensor(object):
    """Convert Series in sample to Tensors."""

    def __call__(self, sample):
        trial, label = sample['trial'], sample['label']
        trial = torch.from_numpy(trial).float()
        label = torch.from_numpy(label).float()
        return trial, label

In [7]:
class RSCFixationsOrder(Dataset):
    """Dataset with the long-format sequence of fixations made during reading by dyslexic 
    and normally-developing Russian-speaking monolingual children."""

    def __init__(self, csv_file, transform=None, target_transform = None, 
                 n_fix = NUM_FIX, 
                 dropPhonologyFeatures = True, dropPhonologySubjects = True):
        """
        Arguments:
            csv_file (string): Path to the csv file with annotations.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            target_transform (callable, optional): Optional transform to be applied
                on a label.
        """
        self.fixations_frame = pd.read_csv(csv_file)
        
        # remove demography and identification
        self.fixations_frame = self.fixations_frame.drop(columns = ['fix_x', 
                                                                    'fix_y', 'full_text' 
                                                                    ]) 
        
        #self.fixations_frame = self.fixations_frame[order_cols]
        
        # Log-transforming appropriate measures
        to_transform = ['frequency', 'fix_dur'] 
        for column in to_transform:
            self.fixations_frame[column] = self.fixations_frame[column].apply(lambda x: np.log(x) if x > 0 else 0) 

        # Center 
        cols = ['fix_dur', 'landing', 'predictability',
                'frequency', 'word_length', 'number.morphemes', 
                'next_fix_dist', 'sac_ampl', 'sac_angle', 
                'sac_vel']
        for col in cols:
            self.fixations_frame[col] = np.where(self.fixations_frame[col] == 0, -4,
                (self.fixations_frame[col] - self.fixations_frame[col].mean())/self.fixations_frame[col].std(ddof=0)) 
        

        # Drop padding 
        self.fixations_frame = self.fixations_frame[self.fixations_frame['fix_dur'] != -4]
        
        
        # Convert direction to a dummy-coded variable
        self.fixations_frame['direction'] = np.where(self.fixations_frame['direction'].isnull(), 0,
                                                     self.fixations_frame['direction'])
        self.fixations_frame = pd.concat([self.fixations_frame, 
                                          pd.get_dummies(self.fixations_frame['direction'], 
                                                         prefix='dummy')], axis=1)
        
        # Center exceptions: 
        # columns where 0 is meaningful and should not be counted as NA
#         cols = ['rel.position', 'dummy_DOWN', 'dummy_LEFT','dummy_RIGHT', 'dummy_UP']
#         for col in cols:
#             self.fixations_frame[col] = (self.fixations_frame[col] - \
#                                              self.fixations_frame[col].mean())/self.fixations_frame[col].std(ddof=0)
            
        if dropPhonologySubjects == True:
            # Drop subjects
            self.fixations_frame.dropna(axis = 0, how = 'any', inplace = True)
        else:
            # Drop columns
            self.fixations_frame.dropna(axis = 1, how = 'any', inplace = True)
        
        
        self.fixations_frame['subj'] = self.fixations_frame['subj'].astype(str)
        self.fixations_frame['item'] = self.fixations_frame['sn'].astype(str)
        self.fixations_frame['Combined'] = self.fixations_frame[['subj', 'item']].agg('_'.join, axis=1)
        
        # cleaning up
        self.fixations_frame.drop(columns = ['subj', 'item', 'direction', 'dummy_0', 'sn',\
                                            'dummy_DOWN', 'dummy_LEFT', 'dummy_UP'\
                                            ], inplace = True)
        self.fixations_frame.drop(columns = ['word_length', 'predictability', 
                                             'frequency', 'number.morphemes', 'fix_index'], 
                                             inplace = True)
        self.fixations_frame.drop(columns = ['landing', 
                'next_fix_dist', 'sac_ampl', 'sac_angle', 
                'sac_vel', 'dummy_RIGHT'], inplace = True)
        

        padded = self.fixations_frame.groupby('Combined', group_keys=False).apply(lambda x: 
                                                                           pad_group_with_zeros(x, n_fix))
        padded.drop(columns = "Combined", inplace = True)
    
        
        #### Fixation index is the label
        self.fixations_frame = padded.to_numpy()
        dataReshaped = np.reshape(self.fixations_frame, (int(len(self.fixations_frame)/n_fix), 
                                                         n_fix, self.fixations_frame.shape[1]))

        self.predictors = dataReshaped[:,:,1:]
        self.labels = dataReshaped[:,:,0]
        self.labels = to_categorical(self.labels-1, num_classes=NUM_CLASSES)   # one-hot encoding

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        trial = self.predictors[idx]
        label = self.labels[idx]
        
        sample = {'trial': trial, 'label': label}

        if self.transform:
            sample = self.transform(sample)
            
        if self.target_transform:
            sample = self.target_transform(sample)
            
        return sample

In [8]:
transformed_dataset = RSCFixationsOrder(csv_file='data/RSC_long_no_word_padded_word_pos.csv', 
                                    transform=ToTensor(),
                                    dropPhonologySubjects = True)

train_size = int(0.8 * len(transformed_dataset))
val_size = int(0.1 * len(transformed_dataset))
test_size = len(transformed_dataset) - val_size - train_size
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(transformed_dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x)))
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x)))
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x)))

In [9]:
trials, labels = next(iter(train_loader))

train_features, train_labels = next(iter(train_loader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")

Feature batch shape: torch.Size([64, 30, 1])
Labels batch shape: torch.Size([64, 30, 31])


In [10]:
trials[1, 0:15,:]

tensor([[0.0000],
        [0.1000],
        [0.2000],
        [0.1000],
        [0.3000],
        [0.5000],
        [0.5000],
        [0.6000],
        [0.7000],
        [0.9000],
        [0.8000],
        [0.3333],
        [0.3333],
        [0.3333],
        [0.3333]], device='cuda:0')

### Embedding

-- Instead of reconstructing the vocabulary, we would use MSE and try to reconstruct *x, y, fix dur* and we can look into how to reconstruct the rest of the features (i.e. all the additional that you used).

-- and an adjustable pytorch implementation here:
https://github.com/lucidrains/mlm-pytorch/tree/master

-- https://github.com/lucidrains/mlm-pytorch

-- Embedding for position - DONE

In [11]:
# #pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)

# emb_dim = 7
# max_seq_len = 30

# pos_emb = AbsolutePositionalEmbedding(seq_len, num_fix, l2norm_embed = l2norm_embed)
# pos_emb(X_train)

# # #X_train + pos_emb(X_train)

In [12]:
X_train = trials

In [13]:
trials.squeeze(2).shape

torch.Size([64, 30])

In [14]:
# Annas changes
class TransformerWithCustomPositionalEncoding(nn.Module):
    def __init__(
        self, 
        embed_dim, 
        num_heads, 
        num_layers, 
        max_len=5000,
        mask_prob = 0.15,
        replace_prob = 1, # 0.9
        mask_token_id = 2,
        pad_token_id = 0.3333,
        mask_ignore_token_ids = []
        ):
        super(TransformerWithCustomPositionalEncoding, self).__init__()
        
        self.mask_prob = mask_prob
        self.replace_prob = replace_prob

        # token ids
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id
        self.mask_ignore_token_ids = set([*mask_ignore_token_ids, pad_token_id])
        
        # positional encoding
        #self.positional_encoding = PositionalEncoding(d_model = embed_dim)
        self.positional_encoding = CustomPositionalEncoding(num_fix).to(device)
        #self.positional_encoding = AbsolutePositionalEmbedding(seq_len, num_fix, l2norm_embed = l2norm_embed)
        
        # transformer
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, 
                                                        nhead=num_heads, 
                                                        batch_first = True)
        self.transformer = TransformerEncoder(self.encoder_layer, num_layers=num_layers).to(device)

    def forward(self, seq):
        # do not mask [pad] tokens, or any other tokens in the tokens designated to be excluded ([cls], [sep])
        # also do not include these special tokens in the tokens chosen at random
        seq = seq.squeeze(2)
        
        no_mask = mask_with_tokens(seq, self.mask_ignore_token_ids) 
        mask = get_mask_subset_with_prob(~no_mask, self.mask_prob)

        # mask input with mask tokens with probability of `replace_prob` (keep tokens the same with probability 1 - replace_prob)
        masked_seq = seq.clone().detach()
        masked_seq_pos = self.positional_encoding(masked_seq, mask = ~no_mask)

        # [mask] input
        masked_replace_prob = prob_mask_like(seq, self.replace_prob) # Anna: select 90% of all values  (ignore all masking for now)
        masked_seq = masked_seq_pos.masked_fill(mask * masked_replace_prob, self.mask_token_id) # Anna: select 90% only of those selected for masking
        
        # derive labels to predict
        labels = seq 
        labels = labels.masked_fill(~mask, self.pad_token_id)
        
        # Pass through the transformer
        preds = self.transformer(masked_seq)
        
        my_loss = mse_loss(
            labels,
            preds,
            #ignored_index = self.pad_token_id
            mask = mask
        )

        return preds, my_loss

In [31]:
X_train = trials.squeeze(2)
seq = X_train

mask_prob = 0.15
replace_prob = 1
mask_token_id = 2
pad_token_id = 0.3333
mask_ignore_token_ids = set([pad_token_id])
embed_dim = num_fix


positional_encoding = CustomPositionalEncoding(num_fix).to(device)
#positional_encoding = AbsolutePositionalEmbedding(seq_len, num_fix, l2norm_embed = False)

# transformer
encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, 
                                            nhead=1, 
                                            batch_first = True)
transformer = TransformerEncoder(encoder_layer, num_layers=2).to(device)

no_mask = mask_with_tokens(seq, mask_ignore_token_ids) # works in 3D
mask = get_mask_subset_with_prob(~no_mask, mask_prob)

# mask input with mask tokens with probability of `replace_prob` (keep tokens the same with probability 1 - replace_prob)
masked_seq = seq.clone().detach()
masked_seq_pos = positional_encoding(masked_seq, mask = ~no_mask)

# [mask] input
masked_replace_prob = prob_mask_like(seq, replace_prob) # Anna: select 90% of all values  (ignore all masking for now)
masked_seq1 = masked_seq_pos.masked_fill(mask * masked_replace_prob, mask_token_id) # Anna: select 90% only of those selected for masking

labels = seq 
labels = labels.masked_fill(~mask, pad_token_id)

preds = transformer(masked_seq1)

In [35]:
masked_seq1

tensor([[ 2.0000,  0.4057, -0.1516,  ...,  0.3333,  0.3333,  0.3333],
        [-0.3825,  0.4057, -0.1516,  ...,  0.3333,  0.3333,  0.3333],
        [-0.2397,  2.0000,  0.0769,  ...,  0.3333,  0.3333,  0.3333],
        ...,
        [ 2.0000,  0.4875, -0.0789,  ...,  0.3333,  0.3333,  0.3333],
        [ 2.0000,  0.4168, -0.2405,  ...,  0.3333,  0.3333,  0.3333],
        [-0.3825,  0.4307, -0.1016,  ...,  0.3333,  0.3333,  0.3333]],
       device='cuda:0', grad_fn=<MaskedFillBackward0>)

In [16]:
out = (preds[mask]-seq[mask])**2
out.mean()

tensor(4.6095, device='cuda:0', grad_fn=<MeanBackward0>)

In [17]:
(preds[mask]-seq[mask])**2


tensor([ 5.2760,  5.9953,  3.8571,  5.9305,  6.1298,  1.8734,  4.6777,  5.6820,
         1.9164,  1.5182,  6.0262,  4.8444,  8.8233,  5.1869,  6.2024,  2.4236,
         1.9974,  6.4584,  3.1318,  4.0985,  1.6622, 12.9924,  2.1515,  4.6691,
         0.8951,  6.0129,  2.5412,  5.8102,  5.3217,  3.2777,  4.2109,  5.6686,
         4.4057,  3.8594,  3.0983,  1.4901,  1.6957,  5.5950,  2.5687, 11.3325,
         2.8783,  2.3831,  4.2814,  4.3297,  5.3064,  2.6955,  3.4869,  3.6654,
         3.3722,  5.5973,  4.7559,  4.7739,  0.5254,  3.5579,  3.1393,  3.6581,
         1.6184,  7.2295,  4.4963,  3.7964,  5.5778,  3.2406,  7.6695,  5.4511,
         5.0518,  2.3712,  2.1043,  0.5691,  2.8765,  2.4981,  4.3434,  7.6151,
         5.1446,  9.0712,  2.5140,  5.4814,  6.6554,  3.2315,  7.5800,  5.2623,
         7.6871,  6.2179,  2.9578,  1.8041,  7.0106,  2.2643,  5.5639,  7.4577,
         1.3620,  6.4833,  2.6712, 10.9557,  1.7333,  2.0845,  7.9348,  9.6457,
         1.3798,  2.2662,  7.1542,  8.37

In [18]:
masked_seq1[4]

tensor([ 0.0994, -0.1811,  2.0000, -0.0295,  0.3226,  0.6096,  1.0829,  0.5423,
         2.0000, -0.0627,  2.0000,  0.7470,  0.9153,  0.0901,  0.5264,  0.5076,
         0.3333,  0.3333,  0.3333,  0.3333,  0.3333,  0.3333,  0.3333,  0.3333,
         0.3333,  0.3333,  0.3333,  0.3333,  0.3333,  0.3333], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [19]:
out = mask * masked_replace_prob
out[0]

tensor([ True, False, False, False, False, False,  True, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False],
       device='cuda:0')

In [20]:
masked_seq_pos.masked_fill(mask * masked_replace_prob, mask_token_id)

tensor([[ 2.0000, -0.3311,  0.1106,  ...,  0.3333,  0.3333,  0.3333],
        [ 0.0994, -0.3311,  2.0000,  ...,  0.3333,  0.3333,  0.3333],
        [ 0.2422, -0.1454,  2.0000,  ...,  0.3333,  0.3333,  0.3333],
        ...,
        [ 0.0994, -0.2493,  0.1833,  ...,  0.3333,  0.3333,  0.3333],
        [ 0.0994, -0.3200,  0.0217,  ...,  0.3333,  0.3333,  0.3333],
        [ 0.0994, -0.3061,  0.1606,  ...,  0.3333,  0.3333,  0.3333]],
       device='cuda:0', grad_fn=<MaskedFillBackward0>)

In [21]:
# derive labels to predict
labels = seq # self.positional_encoding(seq) # add positional encoding before masking, so that encoding does not affect mask
labels = labels.masked_fill(~mask, pad_token_id)

labels[4]

tensor([0.3333, 0.3333, 0.2500, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.5000,
        0.3333, 0.3750, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333,
        0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333,
        0.3333, 0.3333, 0.3333], device='cuda:0')

In [22]:
# labels[0,1] == seq[0,1]

In [23]:
# Ok, let's see:
trainer = TransformerWithCustomPositionalEncoding(num_heads = 1, 
                                                  num_layers= 1, 
                                                  embed_dim = num_fix).cuda()

print(trainer)

TransformerWithCustomPositionalEncoding(
  (positional_encoding): CustomPositionalEncoding()
  (encoder_layer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=30, out_features=30, bias=True)
    )
    (linear1): Linear(in_features=30, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=30, bias=True)
    (norm1): LayerNorm((30,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((30,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=30, out_features=30, bias=True)
        )
        (linear1): Linear(in_features=30, out_features=2048, bias=Tru

In [24]:
# # Setup optimizer to optimize model's parameters
optimizer = Adam(trainer.parameters(), lr=3e-4)

epochs = 41

for epoch in range(epochs):
    train_loss = 0.0
    ### Training
    trainer.train()
    
    for X_train, y_train in train_loader:
        # 1. Forward pass 
        # 2. Calculate loss/accuracy
        train_preds, loss = trainer(X_train)
        
        # 3. Optimizer zero grad
        optimizer.zero_grad()

        # 4. Loss backwards
        loss.backward()
 
        # 5. Optimizer step
        optimizer.step()
        
        train_loss += loss.item() * X_train.size(0)
        
    train_loss /= len(train_loader.dataset)

    #### Evaluation
    test_loss = 0.0
    trainer.eval()
    with torch.no_grad():
        for X_test, y_test in test_loader:
            test_preds, tloss = trainer(X_test)
            test_loss += tloss.item() * X_test.size(0)

    test_loss /= len(test_loader.dataset)


    # Print out what's happening every 10 epochs
    if epoch % 10 == 0:
        print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')


Epoch 1: Train Loss: 0.2666, Test Loss: 0.0570
Epoch 11: Train Loss: 0.0313, Test Loss: 0.0258
Epoch 21: Train Loss: 0.0251, Test Loss: 0.0213
Epoch 31: Train Loss: 0.0233, Test Loss: 0.0198
Epoch 41: Train Loss: 0.0216, Test Loss: 0.0201


In [25]:
X_test[1, 0:15]

tensor([[0.0000],
        [0.1111],
        [0.2222],
        [0.3333],
        [0.3333],
        [0.4444],
        [0.5556],
        [0.7778],
        [1.0000],
        [1.0000],
        [1.0000],
        [0.8889],
        [0.3333],
        [0.3333],
        [0.3333]], device='cuda:0')

In [26]:
test_preds[1, 0:15]

tensor([ 0.0475, -0.2503, -0.4501, -0.4139, -0.2933,  0.3569,  0.4882,  0.6947,
         0.4225,  0.8231,  0.7815,  0.5046,  0.3444,  0.3982, -0.0254],
       device='cuda:0')