# Transformer Model

# Modules

In [1]:
import pandas as pd #1.5.3 
import numpy as np #1.20.3

import math
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from sklearn.model_selection import KFold
import json

### Logging

In [2]:
import logging

# All training diagnostics go to `training.log`; filemode='w' truncates the file on every run.
# Alternative kept for reference: log to the notebook output instead of a file.
#logging.basicConfig(level=logging.INFO)
logging.basicConfig(filename='training.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

# Record which hardware this run will use.
logger.info(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Logger methods take a %-style format string followed by its arguments. Passing the
    # device name without a placeholder makes the handler report a formatting error
    # instead of writing the record, so the placeholder is required here.
    logger.info("GPU: %s", torch.cuda.get_device_name(0))

# Import datasets

In [3]:
# Feature matrices: read with pandas' default header handling.
X_train, X_test = (
    pd.read_csv(csv_path)
    for csv_path in (
        "../preprocess/data/X_train_8e-4.csv",
        "../preprocess/data/X_test_8e-4.csv",
    )
)

# Targets: header=None, so the first line of each file is treated as data.
y_train, y_test = (
    pd.read_csv(csv_path, header=None)
    for csv_path in (
        "../preprocess/data/y_train_8e-4.csv",
        "../preprocess/data/y_test_8e-4.csv",
    )
)

### Separate genotype, position inputs
Required because genotypes and positions go through separate embedding processes.

In [4]:
# Columns whose name contains "_pos" go to the position input; every other column
# goes to the genotype (SNP) input.
pos_columns = [name for name in X_train.columns if "_pos" in name]
snp_columns = [name for name in X_train.columns if "_pos" not in name]

# Split both feature frames into their genotype and position halves.
X_train_snp, X_train_pos = X_train[snp_columns], X_train[pos_columns]
X_test_snp, X_test_pos = X_test[snp_columns], X_test[pos_columns]


# Sanity check: report the shape of every split and of the targets.
print(
    "Dataset shapes \n",
    "X_train_snp: ", X_train_snp.shape, "\n",
    "X_test_snp: ", X_test_snp.shape, "\n",
    "X_train_pos: ", X_train_pos.shape, "\n",
    "X_test_pos: ", X_test_pos.shape, "\n",
    "y_train: ", y_train.shape, "\n",
    "y_test: ", y_test.shape, "\n",
)

Dataset shapes 
 X_train_snp:  (292410, 2272) 
 X_test_snp:  (32509, 2272) 
 X_train_pos:  (292410, 2272) 
 X_test_pos:  (32509, 2272) 
 y_train:  (292410, 1) 
 y_test:  (32509, 1) 



# Params

In [5]:
# Device: first CUDA GPU when torch can see one, otherwise the CPU
device = "cpu" if not torch.cuda.is_available() else "cuda:0"

# Embedding
embed_size = 256  # width of the SNP and position embedding vectors
snp_encoding_size = 32  # "vocab" size: number of SNP encoding possibilities, e.g. AA, GG, AT, none
seq_len = 2272  # number of SNPs / features per record
max_pos_length = 100_000  # "sentence length": upper bound that absolute SNP positions are scaled down to

# Transformer
heads = 4  # attention heads per layer
num_layers = 4  # stacked transformer layers
forward_expansion = 8  # feed-forward hidden width = forward_expansion * embed_size
dropout = 0.7
agg_fc = "mean"  # how the sequence dimension is pooled before the regression head ("mean" or "max")

# Kfold
k_folds = 3
batch_size = 10  # SNVformer uses 10

# Training
src_pad_idx = None  # index of the padding token in the source vocabulary
lr = 7e-7  # SNVformer uses 1e-6
num_epochs = 50  # SNVformer uses 60
weight_decay = 0.1  # L2 regularisation for the optimiser
early_stopping_patience = 10  # epochs without validation improvement before stopping

# Linformer: projected sequence length for keys / values
k = 256

# Embedding

### snp embedding

In [6]:
class SnpEmbedding(nn.Module):
    """
    Learned lookup table mapping each SNP encoding index to a dense vector.

    Thin wrapper around nn.Embedding, kept as its own class so custom behaviour can be
    added later without touching callers.

    Args:
        snp_encoding_size: number of distinct SNP encodings (rows of the table).
        embed_size: length of each embedding vector.
    """
    def __init__(self, snp_encoding_size, embed_size):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=snp_encoding_size, embedding_dim=embed_size)

    def forward(self, x):
        """Return the embedding vectors for the integer indices in `x` (shape `x.shape + (embed_size,)`)."""
        embedded = self.embedding(x)
        return embedded

### position embedding

In [7]:
class PosEmbedding(nn.Module):
    """
    Pos embeddings - fixed sine-cosine encoding of absolute SNP positions. Enables positional
    information to be captured and the model to learn positional contexts between SNPs.

    The table is precomputed once for every integer position in [0, max_pos_length) and
    stored as a buffer (moves with the module, is not trained).

    Args:
        max_pos_length: number of distinct positions in the table.
        embed_size: length of each encoding vector; may be even or odd.
    """
    def __init__(self, max_pos_length, embed_size):
        super(PosEmbedding, self).__init__()
        self.max_pos_length = max_pos_length
        self.embed_size = embed_size

        # Positional encoding matrix of shape (max_pos_length, embed_size): even columns hold
        # sines, odd columns hold cosines, each column pair sharing one frequency.
        position = torch.arange(max_pos_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2) * -(math.log(10000.0) / embed_size))
        angles = position * div_term  # (max_pos_length, ceil(embed_size / 2))
        positional_encoding = torch.zeros(max_pos_length, embed_size)
        positional_encoding[:, 0::2] = torch.sin(angles)
        # With an odd embed_size there is one fewer cosine column than sine columns, so the
        # last frequency is dropped for the cosine half (a no-op when embed_size is even).
        positional_encoding[:, 1::2] = torch.cos(angles[:, : embed_size // 2])

        # Register this matrix as a buffer that is not a model parameter
        self.register_buffer('positional_encoding', positional_encoding)

    def forward(self, x):
        """
        Inputs:
            x: A tensor of shape (batch_size, sequence_length) containing the SNP positions
               (floating point or integer).
        Returns:
            A tensor of shape (batch_size, sequence_length, embed_size) holding the encoding
            of each position.
        """
        # Keep positions inside the table, then convert to integer row indices.
        x = x.clamp(0, self.max_pos_length - 1)
        if x.is_floating_point():
            # Scaled positions are floats; integer inputs need no rounding.
            x = x.round()
        return self.positional_encoding[x.long()]

Linear scaling of SNP pos to preserve relative distances

In [None]:
# SNP positions of the first record (noted as being the same for all records).
# Both the original DataFrame and the ndarray this cell leaves in `X_train_pos` are
# accepted, so the cell can be re-run without restarting the kernel. On a re-run
# `snp_positions` holds the already-scaled values.
first_record = X_train_pos.iloc[0, :] if hasattr(X_train_pos, "iloc") else X_train_pos[0]
snp_positions = np.array(first_record)

# Calculate the min and max
min_position = np.min(snp_positions)
max_position = np.max(snp_positions)

# Scale positions linearly to the range 0-max_pos_length (relative distances are preserved)
position_span = max_position - min_position
if position_span == 0:
    # Every SNP shares one position: map them all to 0 instead of dividing by zero.
    scaled_positions = np.zeros(snp_positions.shape, dtype=float)
else:
    scaled_positions = ((snp_positions - min_position) * max_pos_length) / position_span

# Replace positions with scaled positions for all rows/records
X_train_pos = np.tile(scaled_positions, (len(X_train_pos), 1))  # This is now an ndarray, not a DataFrame


In [11]:
# Inspect the SNP positions taken from the first training record (input to the scaling step).
snp_positions

array([   14725873,   149501894,   164246170,   191089379,   191093014,
         198159558,  1187528659,    22424764,    29453246,   246573210,
        2123646203,  2125009044,  2140759771,  2144304494,  2144535837,
        2155670768,  2156568806,  2188896397,  2207927831,   335591556,
         344523634,   348768680,   382011690,   384757676,   385094800,
         389273019,  3103710513,  3112101315,  3177368489,   416823173,
         437993214,   438653771,   439931534,   443237901,   447490933,
         467709129,   484600736,  4119861650,    57982843,   581145466,
        5125344932,  5131137580,   685882348,   698546547,   699300428,
        6124183635,  6164040267,    71980113,    73535566,   763359252,
        7114940159,  7131542232,  7132519864,  7133453874,  7142427638,
        7150726689,   834751056,   834751056,   834946925,   860196619,
         865258742,   865341428,   866956527,   877592275,  8131204233,
          91521527,   921943952,   931634264,   976157130,   986

In [12]:
# Inspect the same positions after linear scaling to the range 0-max_pos_length.
scaled_positions

array([0.00000000e+00, 9.56489331e+01, 1.06112770e+02, 1.25163075e+02,
       1.25165654e+02, 1.30180696e+02, 8.32324136e+02, 5.46381103e+00,
       1.04518408e+01, 1.64539287e+02, 1.49667558e+03, 1.49764277e+03,
       1.50882088e+03, 1.51133652e+03, 1.51150070e+03, 1.51940303e+03,
       1.52004036e+03, 1.54298286e+03, 1.55648925e+03, 2.27714545e+02,
       2.34053534e+02, 2.37066192e+02, 2.60658359e+02, 2.62607153e+02,
       2.62846406e+02, 2.65811638e+02, 2.19221552e+03, 2.19817037e+03,
       2.24448971e+03, 2.85363653e+02, 3.00387779e+02, 3.00856568e+02,
       3.01763381e+02, 3.04109870e+02, 3.07128196e+02, 3.21476809e+02,
       3.33464581e+02, 2.91336585e+03, 3.06989552e+01, 4.01981223e+02,
       3.62694533e+03, 3.63105631e+03, 4.76311738e+02, 4.85299368e+02,
       4.85834389e+02, 4.33580924e+03, 4.36409501e+03, 4.06326506e+01,
       4.17365370e+01, 5.31296171e+02, 5.03893731e+03, 5.05071960e+03,
       5.05141342e+03, 5.05207627e+03, 5.05844485e+03, 5.06433459e+03,
      

In [13]:
# Inspect the tiled array: every row should repeat the scaled positions.
X_train_pos

array([[   0.        ,   95.64893306,  106.11276971, ...,  134.03435073,
        1442.72711463, 1578.66072119],
       [   0.        ,   95.64893306,  106.11276971, ...,  134.03435073,
        1442.72711463, 1578.66072119],
       [   0.        ,   95.64893306,  106.11276971, ...,  134.03435073,
        1442.72711463, 1578.66072119],
       ...,
       [   0.        ,   95.64893306,  106.11276971, ...,  134.03435073,
        1442.72711463, 1578.66072119],
       [   0.        ,   95.64893306,  106.11276971, ...,  134.03435073,
        1442.72711463, 1578.66072119],
       [   0.        ,   95.64893306,  106.11276971, ...,  134.03435073,
        1442.72711463, 1578.66072119]])

# Linear Multi Head Attention 
Taken from https://github.com/lucidrains/linformer/blob/master/linformer/linformer.py

In [14]:
# helper functions
def default(val, default_val):
    """Return `val`, falling back to `default_val` only when `val` is None."""
    if val is None:
        return default_val
    return val

def init_(tensor):
    """Fill `tensor` in place with values drawn uniformly from [-1/sqrt(d), 1/sqrt(d)],
    where d is the size of its last dimension, and return the same tensor."""
    bound = 1 / math.sqrt(tensor.shape[-1])
    return tensor.uniform_(-bound, bound)

class LinformerSelfAttention(nn.Module):
    """
    Multi-head attention whose keys and values are first projected along the sequence
    dimension from `seq_len` down to `k` rows using learned matrices.

    Args:
        embed_size: feature size of the input and output.
        seq_len: maximum key/value sequence length (rows of the projection matrices).
        k: projected sequence length for keys and values.
        heads: number of attention heads; must divide `embed_size`.
        dim_head: per-head feature size; defaults to `embed_size // heads`.
        one_kv_head: if True, one key/value head is computed and shared by every query head.
        share_kv: if True, the key tensor and key projection are reused for the values.
        dropout: dropout probability applied to the attention weights.
    """
    def __init__(self, embed_size, seq_len, k = 256, heads = 8, dim_head = None, one_kv_head = False, share_kv = False, dropout = 0.):
        super().__init__()

        assert (embed_size % heads) == 0, 'dimension must be divisible by the number of heads'

        dim_head = default(dim_head, embed_size // heads)
        inner_dim = dim_head * heads
        kv_dim = dim_head if one_kv_head else inner_dim

        self.seq_len = seq_len
        self.k = k
        self.heads = heads
        self.dim_head = dim_head

        # Layers are created in a fixed order so random initialisation draws are reproducible.
        self.to_q = nn.Linear(embed_size, inner_dim, bias = False)

        self.to_k = nn.Linear(embed_size, kv_dim, bias = False)
        self.proj_k = nn.Parameter(init_(torch.zeros(seq_len, k)))

        self.share_kv = share_kv
        if not share_kv:
            self.to_v = nn.Linear(embed_size, kv_dim, bias = False)
            self.proj_v = nn.Parameter(init_(torch.zeros(seq_len, k)))

        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Linear(inner_dim, embed_size)

    def forward(self, x, context = None, **kwargs):
        """
        Args:
            x: tensor of shape [batch, seq_len, embed_size] providing the queries.
            context: optional tensor providing keys/values; `x` is used when omitted.
            **kwargs: accepted and ignored.
        Returns:
            Tensor of shape [batch, n, embed_size], n being the query length.
        """
        batch, n_queries, _ = x.shape
        heads, dim_head, k = self.heads, self.dim_head, self.k

        kv_input = x if context is None else context
        kv_len = n_queries if context is None else context.shape[1]
        assert kv_len <= self.seq_len, f'the sequence length of the key / values must be {self.seq_len} - {kv_len} given'

        queries = self.to_q(x)
        keys = self.to_k(kv_input)
        values = keys if self.share_kv else self.to_v(kv_input)

        proj_k = self.proj_k
        proj_v = self.proj_k if self.share_kv else self.proj_v

        # shorter sequences use only the leading rows of the projection matrices
        if kv_len < self.seq_len:
            proj_k, proj_v = proj_k[:kv_len], proj_v[:kv_len]

        # compress keys and values along the sequence dimension: n rows -> k rows
        keys = torch.einsum('bnd,nk->bkd', keys, proj_k)
        values = torch.einsum('bnd,nk->bkd', values, proj_v)

        # split the feature dimension into heads
        queries = queries.reshape(batch, n_queries, heads, -1).transpose(1, 2)

        def split_kv_heads(t):
            # a single key/value head (one_kv_head) is broadcast to all query heads by expand
            return t.reshape(batch, k, -1, dim_head).transpose(1, 2).expand(-1, heads, -1, -1)

        keys = split_kv_heads(keys)
        values = split_kv_heads(values)

        # scaled dot-product attention over the k projected positions
        dots = torch.einsum('bhnd,bhkd->bhnk', queries, keys) * (dim_head ** -0.5)
        attn = self.dropout(torch.softmax(dots, dim=-1))
        out = torch.einsum('bhnk,bhkd->bhnd', attn, values)

        # merge the heads back into the feature dimension
        out = out.transpose(1, 2).reshape(batch, n_queries, -1)
        return self.to_out(out)

# Encoder

Transformer Layer: 
- Multi-Head Attention
- Add & Norm
- Feed Forward
- Add & Norm again

In [15]:
class TransformerLayer(nn.Module):
    """
    One encoder layer: Linformer self-attention -> Add & Norm -> feed-forward -> Add & Norm.

    Args:
        embed_size: feature size of the input and output.
        seq_len: maximum sequence length, passed to the attention module.
        heads: number of attention heads.
        dropout: dropout probability used inside the attention and after each norm.
        k: projected key/value length for the attention module.
        forward_expansion: the feed-forward hidden width is forward_expansion * embed_size.
    """
    def __init__(self, embed_size, seq_len, heads, dropout, k, forward_expansion=4):
        super(TransformerLayer, self).__init__()
        self.attention = LinformerSelfAttention(embed_size, seq_len, k, heads, 
                                            dim_head = None, one_kv_head = False, share_kv = False, 
                                            dropout=dropout) 
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (batch_size, query_len, embedding_size).
        Returns:
            Tensor of the same shape.
        """
        attention = self.attention(x) # attention shape: (batch_size, query_len, embedding_size)
        # First sub-layer: skip connection around the attention, then normalization and dropout
        norm_out = self.dropout(self.norm1(attention + x))
        forward = self.feed_forward(norm_out) # forward shape: (batch_size, query_len, embedding_size)
        # Second sub-layer: the skip connection must wrap the feed-forward's own input
        # (norm_out). Adding the layer input `x` here instead would make the residual
        # path bypass the attention sub-layer output.
        out = self.dropout(self.norm2(forward + norm_out))
        return out

Encoder = Embedding + transformer layer

In [16]:
class Encoder(nn.Module):
    """
    Encoder = (SNP embedding + position embedding) followed by a stack of TransformerLayers,
    with an extra skip connection from the embeddings to the final output.
    """
    def __init__(self, snp_encoding_size, embed_size, seq_len, num_layers, heads,
        device, forward_expansion, dropout, k, max_pos_length): 
        super().__init__()
        self.embed_size = embed_size  # size of the input embedding
        self.device = device  # either "cuda" or "cpu"; stored only, not used by forward
        # Learned lookup table with one vector per SNP encoding
        self.snp_embedding = SnpEmbedding(snp_encoding_size, embed_size)
        # Fixed sine-cosine table indexed by SNP position
        self.position_embedding = PosEmbedding(max_pos_length, embed_size)

        stack = []
        for _ in range(num_layers):
            stack.append(
                TransformerLayer(embed_size, seq_len, heads, dropout, k,
                                 forward_expansion=forward_expansion)
            )
        self.layers = nn.ModuleList(stack)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, positions):
        """
        Forward pass.
        :param x: SNP encodings. Shape: (batch_size, source_sequence_len).
        :param positions: SNP positions. Shape: (batch_size, source_sequence_len).
        :return output: torch tensor of shape (batch_size, src_sequence_length, embedding_size)
        """
        # The unpacking doubles as a rank check: any input that is not 2-D raises ValueError.
        batch_size, seq_length = x.shape

        # Both embeddings have shape (batch_size, snps_total, embed_size) and are summed
        snp_embed = self.snp_embedding(x)
        pos_embed = self.position_embedding(positions)
        embed_out = self.dropout(snp_embed + pos_embed)

        # Each layer maps [batch_size, snp_total, embed_size] to the same shape
        hidden = embed_out
        for block in self.layers:
            hidden = block(hidden)

        # TEST: skip connection from the embeddings around the whole layer stack
        return hidden + embed_out

# Transformer Model

In [17]:
class Transformer(nn.Module):
    """
    Encoder-only transformer with a single-value regression head.

    The encoder output is pooled over the sequence dimension (mean or max) and passed
    through one linear layer.

    Args:
        snp_encoding_size, embed_size, seq_len, num_layers, forward_expansion, heads,
        dropout, k, device, max_pos_length: forwarded to the Encoder.
        src_pad_idx: stored on the module; not used by forward.
        agg_fc: pooling over the sequence dimension - "max" selects max pooling, any other
            value selects mean pooling.
    """
    def __init__(self, snp_encoding_size, src_pad_idx, embed_size, seq_len,
                 num_layers, forward_expansion, heads, dropout, k, device, max_pos_length, agg_fc):

        super(Transformer, self).__init__()
        # === Encoder ===
        self.encoder = Encoder(snp_encoding_size, embed_size, seq_len, num_layers, heads,
                               device, forward_expansion, dropout, k, max_pos_length )
        self.src_pad_idx = src_pad_idx
        self.device = device
        # Keep the pooling choice on the instance so forward does not depend on a notebook
        # global of the same name.
        self.agg_fc = agg_fc
        
        # === Regression Out ===
        self.fc_out = nn.Linear(embed_size, 1) # Single regression target value


    def forward(self, snp, pos):
        """
        Args:
            snp: SNP encodings, passed to the encoder.
            pos: SNP positions, passed to the encoder.
        Returns:
            Tensor of shape [batch_size, 1] with one regression score per record.
        """
        enc_out = self.encoder(snp, pos) 
        
        # Aggregate the encoder output over the sequence dimension
        if self.agg_fc == "max":
            aggregated_out, _ = enc_out.max(dim=1)  # [batch_size, embed_size]
        else:
            aggregated_out = enc_out.mean(dim=1)  # [batch_size, embed_size]
        
        out = self.fc_out(aggregated_out) # [batch_size, 1]
        return out

# DataLoader

In [18]:
# Genotype codes are used as embedding indices, so they are stored as 64-bit integers.
snp_tensor = torch.tensor(X_train_snp.to_numpy(), dtype=torch.int64)

# Scaled positions: dtype is inferred from the NumPy array.
pos_tensor = torch.tensor(X_train_pos)

# Regression targets: dtype is inferred from the underlying array.
y_tensor = torch.tensor(y_train.to_numpy())

# KFold CV setup

In [19]:
# Row indices 0..N-1 that KFold partitions; each fold's index arrays then select tensor rows.
indices = np.arange(snp_tensor.shape[0])

# Shuffled K-fold split with a fixed seed so the folds are the same on every run.
kf = KFold(shuffle=True, random_state=42, n_splits=k_folds)


# Training

In [None]:
# One entry per fold: {'train_loss': [...], 'val_loss': [...]} with one value per epoch.
training_log = []

for fold, (train_ids, val_ids) in enumerate(kf.split(indices)):
    # Split the data
    snp_train, snp_val = snp_tensor[train_ids], snp_tensor[val_ids]
    pos_train, pos_val = pos_tensor[train_ids], pos_tensor[val_ids]
    # Fold-specific name: rebinding `y_train` here would replace the target DataFrame loaded
    # at the top of the notebook and break a re-run of the tensor-building cell.
    y_train_fold, y_val = y_tensor[train_ids], y_tensor[val_ids]
    
    # Create DataLoader for both training and validation sets
    train_dataset = TensorDataset(snp_train, pos_train, y_train_fold)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) # batch_size defined in params
    
    val_dataset = TensorDataset(snp_val, pos_val, y_val)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    
    # Initialize model, loss function, and optimizer for each fold
    # This ensures that each fold starts with a fresh model
    model = Transformer(snp_encoding_size, src_pad_idx, embed_size, seq_len,
                     num_layers, forward_expansion, heads, dropout, k, device, max_pos_length, agg_fc)

    model = model.to(device)
    
    optimizer = optim.AdamW(model.parameters(), lr=lr,betas=(0.9, 0.999), eps=1e-8, amsgrad=True, weight_decay=weight_decay)
    loss_function = nn.MSELoss()
    #loss_function = nn.L1Loss()

    # Early stopping setup (patience comes from `early_stopping_patience` in the params cell)
    epochs_no_improve = 0
    early_stop = False
    
    # Checkpoint setup - model with lowest avg_val_loss is saved
    best_val_loss = float('inf')
    
    # Logging setup
    loss_log = {
        'train_loss': [],
        'val_loss': []
    }
    
    # Training loop for the current fold
    for epoch in range(num_epochs):
        # Break the loop if early stopping was triggered by the previous epoch
        if early_stop:
            print(f"Early stopping triggered after {epoch} epochs.")
            break  

        model.train() # Training mode
        train_loss = 0.0
        train_batches_count = 0
        
        for snp_batch, pos_batch, y_batch in train_loader:
            snp_batch = snp_batch.to(device)
            pos_batch = pos_batch.to(device)
            y_batch = y_batch.to(device)

            # Clear gradients left over from the previous step before computing new ones
            optimizer.zero_grad()

            # Forward pass
            out = model(snp_batch, pos_batch) #[batch_size,1]  regression output scores

            # Compute loss; the prediction is cast to float64 to match the target dtype
            loss = loss_function(out.to(dtype=torch.float64), y_batch) 
            # Per-batch RMSE is debug-level detail: printing it for every batch floods the
            # notebook output, so it goes to the logger and is hidden at the INFO level.
            logger.debug("Fold %s, Epoch %s, batch RMSE: %s", fold, epoch, np.sqrt(loss.item()))
            train_loss+=loss.item() # Accumulate the training loss
            
            train_batches_count+=1
            
            # Backprop
            loss.backward()

            # Update Weights
            optimizer.step()
            
        # Calculate average training loss for the current epoch
        avg_train_loss = train_loss / train_batches_count
        logger.info(f"Fold {fold}, Epoch {epoch}, Training Loss: {avg_train_loss}") 
            
        # Validation loop for the current fold
        model.eval() # Eval mode
        val_loss = 0.0
        val_batches_count = 0
        
        with torch.no_grad():
            for snp_batch, pos_batch, y_batch in val_loader:
                snp_batch = snp_batch.to(device)
                pos_batch = pos_batch.to(device)
                y_batch = y_batch.to(device)
                
                # Forward pass with no gradient calculation
                out = model(snp_batch, pos_batch) #[batch_size,1]  regression output scores
                
                loss = loss_function(out.to(dtype=torch.float64), y_batch)
                
                val_loss += loss.item()  # Accumulate the validation loss
                
                val_batches_count += 1    
                
        # Calculate average validation loss for the current epoch
        avg_val_loss = val_loss / val_batches_count
        logger.info(f"Fold {fold}, Epoch {epoch}, Validation Loss: {avg_val_loss}") 
        
        # Checkpointing / early stopping
        if avg_val_loss < best_val_loss:
            logger.info(f"Validation loss improved from {best_val_loss} to {avg_val_loss}. Saving model...")
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            # A separate checkpoint file is written for every improving epoch of every fold
            torch.save(model.state_dict(), f"linformer_best_fold_{fold}_epoch_{epoch}.pth")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= early_stopping_patience:
                early_stop = True
                print(f"Early stopping activated. Validation loss did not decrease for {early_stopping_patience} consecutive epochs.")
            
        # Log per epoch
        loss_log['train_loss'].append(avg_train_loss)
        loss_log['val_loss'].append(avg_val_loss)
    
    # Log per fold: the file is rewritten after each fold with all folds completed so far
    training_log.append(loss_log)
    with open('linformer_loss_log.json', 'w') as f:
        json.dump(training_log, f)



# Save final model

In [None]:
# Pickle the whole module object (not just its state_dict). `model` is whatever the last
# fold left behind after its final epoch.
# NOTE(review): this is not necessarily the best checkpoint - confirm that is intended.
torch.save(obj=model, f="template_linformer_model.pth")