In [1]:
import pandas as pd
import re
import numpy as np
import random
import os
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import itertools
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error
import pickle
import torch.nn.functional as F

### Data processing

In [2]:
# Phenotype (fitness) files, one per environment (18 in total). The environment
# label used later is the token between the second '_' and the file extension.
pheno_files = ['pheno_data_23C.txt', 'pheno_data_25C.txt', 'pheno_data_27C.txt', 'pheno_data_30C.txt', 
               'pheno_data_33C.txt', 'pheno_data_35C.txt', 'pheno_data_37C.txt', 'pheno_data_cu.txt',
               'pheno_data_suloc.txt', 'pheno_data_ynb.txt', 'pheno_data_eth.txt', 'pheno_data_gu.txt', 
               'pheno_data_li.txt', 'pheno_data_mann.txt', 'pheno_data_mol.txt', 'pheno_data_raff.txt', 
               'pheno_data_sds.txt', 'pheno_data_4NQO.txt']

# Genotype matrix: one row per segregant (sample)
geno_data = np.load('merged_geno_data.npy')
num_segregants = geno_data.shape[0]

# Fixed (seed 0) permutation of the segregants. The same permutation is applied to
# the genotypes and to every phenotype table so their rows stay aligned.
random.seed(0)
shuffled_indices = list(range(num_segregants))
random.shuffle(shuffled_indices)
geno_data = geno_data[shuffled_indices]

# Column indices of the independent loci used as model features
ind_loci_list = np.load('ind_loci_list_3.npy')

# Fitness values are read from the second column of each tab-separated phenotype
# table, reordered with the genotype permutation.
# NOTE(review): assumes each phenotype file lists segregants in the same row order
# as merged_geno_data.npy -- confirm.
fitness_list = [
    pd.read_csv(file, sep="\t").iloc[shuffled_indices, 1].to_numpy()
    for file in pheno_files
]

# Keep only the independent loci and map genotype codes from [0, 1] to [-1, +1]
geno_data = 2.0 * geno_data[:, sorted(ind_loci_list)] - 1.0

# Shape (num_segregants, len(pheno_files)): rows = segregants, columns = environments
fitness = np.array(fitness_list).T

### Attention layer class in PyTorch

In [3]:
class ThreeLayerAttention(nn.Module):
    """
    Three stacked single-head self-attention layers over per-locus tokens,
    followed by a learned linear read-out to one scalar per sample.

    Each sample is a sequence of ``seq_length`` tokens (one per locus). The first
    ``seq_length`` feature columns of every token are projected to
    ``input_dim - 1`` dimensions by the learnable ``random_matrix``; the token's
    last feature column is then appended unchanged, giving ``input_dim`` features
    per token for the attention layers.
    """

    def __init__(self, input_dim, query_dim, key_dim, seq_length=None):
        """
        Build the attention stack and initialize its parameters.

        Args:
            input_dim (int): Per-token width seen by the attention layers. Equals the
                projection width plus one (the appended last input column), so the
                projection width is ``input_dim - 1``.
            query_dim (int): Width of the queries.
            key_dim (int): Width of the keys. Scores are ``query @ key^T``, so this
                must equal ``query_dim``.
            seq_length (int, optional): Number of loci (tokens) per sample. If
                omitted, the notebook-level ``seq_length`` variable is used, which
                keeps the original call ``ThreeLayerAttention(input_dim, query_dim,
                key_dim)`` working.

        Raises:
            ValueError: if ``seq_length`` is omitted and no global ``seq_length``
                is defined.
        """
        super(ThreeLayerAttention, self).__init__()
        if seq_length is None:
            # Backward compatibility: earlier versions read this from notebook globals.
            seq_length = globals().get("seq_length")
            if seq_length is None:
                raise ValueError(
                    "seq_length must be passed explicitly or defined as a global variable")

        self.input_dim = input_dim
        self.query_dim = query_dim
        self.key_dim = key_dim
        self.seq_length = seq_length
        # forward() appends one extra column to the projection, so the projection
        # width is forced to be input_dim - 1 (previously read from global `low_dim`).
        self.low_dim = input_dim - 1

        # Learnable weight matrices for the first attention layer
        self.query_matrix_1 = nn.Parameter(torch.empty(input_dim, query_dim))
        self.key_matrix_1 = nn.Parameter(torch.empty(input_dim, key_dim))
        self.value_matrix_1 = nn.Parameter(torch.empty(input_dim, input_dim))

        # Learnable weight matrices for the second attention layer
        self.query_matrix_2 = nn.Parameter(torch.empty(input_dim, query_dim))
        self.key_matrix_2 = nn.Parameter(torch.empty(input_dim, key_dim))
        self.value_matrix_2 = nn.Parameter(torch.empty(input_dim, input_dim))

        # Learnable weight matrices for the third attention layer
        self.query_matrix_3 = nn.Parameter(torch.empty(input_dim, query_dim))
        self.key_matrix_3 = nn.Parameter(torch.empty(input_dim, key_dim))
        self.value_matrix_3 = nn.Parameter(torch.empty(input_dim, input_dim))

        # Learnable projection of the first `seq_length` input features down to
        # `low_dim` dims (named "random" after its random initialization)
        self.random_matrix = nn.Parameter(torch.empty(seq_length, self.low_dim))

        # Learnable per-token, per-feature read-out coefficients
        self.coeffs_attended = nn.Parameter(torch.empty(seq_length, input_dim))

        # Learnable scalar offset added to every output
        self.offset = nn.Parameter(torch.randn(1))

        # Initialize parameters
        self.init_parameters()

    def init_parameters(self):
        """Initializes every parameter from N(0, 0.03^2)."""
        init_scale = 0.03  # Small scale for initialization to prevent exploding gradients

        for param in [self.query_matrix_1, self.key_matrix_1, self.value_matrix_1,
                      self.query_matrix_2, self.key_matrix_2, self.value_matrix_2,
                      self.query_matrix_3, self.key_matrix_3, self.value_matrix_3,
                      self.random_matrix, self.coeffs_attended, self.offset]:
            init.normal_(param, std=init_scale)

    @staticmethod
    def _self_attention(h, query_matrix, key_matrix, value_matrix):
        """One unscaled single-head self-attention step: softmax(hQ (hK)^T) @ hV."""
        query = torch.matmul(h, query_matrix)
        key = torch.matmul(h, key_matrix)
        value = torch.matmul(h, value_matrix)
        scores = torch.softmax(torch.matmul(query, key.transpose(1, 2)), dim=-1)
        return torch.matmul(scores, value)

    def forward(self, x):
        """
        Forward pass through three layers of self-attention.

        Args:
            x (Tensor): Shape (batch_size, seq_length, F) with F >= seq_length + 1.
                The first ``seq_length`` feature columns are projected with
                ``random_matrix``. The last column is appended unchanged (the
                notebook fills it with ones). Any columns in between are ignored.

        Returns:
            Tensor: Shape (batch_size,), one prediction per sample.
        """
        # Project the per-token features and re-attach the last (constant) column
        projected = torch.matmul(x[:, :, :self.seq_length], self.random_matrix)
        h = torch.cat((projected, x[:, :, -1:]), dim=2)

        h = self._self_attention(h, self.query_matrix_1, self.key_matrix_1, self.value_matrix_1)
        h = self._self_attention(h, self.query_matrix_2, self.key_matrix_2, self.value_matrix_2)
        h = self._self_attention(h, self.query_matrix_3, self.key_matrix_3, self.value_matrix_3)

        # Weighted sum over all tokens and features, plus a global offset
        readout = torch.einsum("bij,ij->b", h, self.coeffs_attended)
        return readout + self.offset


### Training a separate model for each environment

In [None]:
def write_to_file(filename, a, b):
    """
    Append a single ``"<a> <b>"`` line to ``filename``.

    Used to log the validation R² score for each epoch. Append mode keeps earlier
    lines and creates the file if it does not exist yet.

    Args:
        filename (str): Path to the output file.
        a (int): Epoch number.
        b (float): R² score on validation data.
    """
    record = f"{a} {b}\n"
    with open(filename, 'a') as log_file:
        log_file.write(record)

# Set computation device (GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Split dataset into training (85%) and testing (15%) sets
X_train, X_test, y_train, y_test = train_test_split(geno_data, fitness, test_size=0.15, random_state=42)

# Further split training set into training (85%) and validation (15%) sets
X_train2, X_val, y_train2, y_val = train_test_split(X_train, y_train, test_size=0.15, random_state=42)

# Standardize every environment's targets with statistics from the inner training
# split only. The statistics are nan-aware because the targets may contain NaNs.
mean_values = np.nanmean(y_train2, axis=0)
std_values = np.nanstd(y_train2, axis=0)
y_train2 = (y_train2 - mean_values) / std_values
y_val = (y_val - mean_values) / std_values
y_test = (y_test - mean_values) / std_values

# Dimensions for the attention model (one token per independent locus)
low_dim = 12
num_loci = len(ind_loci_list)
seq_length = num_loci
input_dim = low_dim + 1
query_dim = low_dim + 1
key_dim = low_dim + 1

# Training hyperparameters
batch_size = 64
chunk_size = 50       # rows per validation forward pass (bounds memory)
num_epochs = 200
learning_rate = 0.001

# Training genotypes stay on the CPU; each mini-batch is moved to `device` on demand
X_train_tens = torch.tensor(X_train2).float()
# Floor division: the trailing partial batch (< batch_size rows) is skipped in each
# epoch. Because the data are reshuffled, it is a different subset every epoch.
num_batches = X_train_tens.size(0) // batch_size


def _encode_genotypes(geno):
    """
    Embed a (rows, num_loci) genotype tensor as a (rows, num_loci, num_loci + 1) tensor.

    Token j carries genotype value j at feature position j (zeros elsewhere in the
    first num_loci features) and a constant 1 in the last feature. That last feature
    is appended unchanged by ThreeLayerAttention.forward.
    """
    ones = torch.ones((*geno.shape, 1), dtype=geno.dtype, device=geno.device)
    return torch.cat((torch.diag_embed(geno), ones), dim=2)


# Train one independent model per environment
for env, file in enumerate(pheno_files):

    # Environment label, e.g. 'pheno_data_23C.txt' -> '23C'
    sheet = file.split('_')[2].split('.')[0]

    # Per-environment directory for checkpoints and the validation log
    save_dir = f"./single_env_attention_QTL_yeast_data/{sheet}"
    os.makedirs(save_dir, exist_ok=True)

    # Start a fresh validation R² log
    filename = f"{save_dir}/validation_r2.txt"
    if os.path.exists(filename):
        os.remove(filename)

    # Initialize attention model and move it to GPU if available
    attention_layer = ThreeLayerAttention(input_dim, query_dim, key_dim).to(device)

    # Targets for the current environment
    y_train2_env = y_train2[:, env]
    y_val_env = y_val[:, env]
    y_test_env = y_test[:, env]
    y_train_tens = torch.tensor(y_train2_env).float()

    # Validation rows with a phenotype in this environment. This does not change
    # between epochs, so it is computed once here instead of inside the epoch loop.
    val_keep = ~np.isnan(y_val_env)
    y_val_env2 = y_val_env[val_keep]
    X_val_tens = torch.tensor(X_val[val_keep]).float().to(device)

    # Define loss function and optimizer
    loss_function = nn.MSELoss()
    optimizer = torch.optim.Adam(attention_layer.parameters(), lr=learning_rate)

    # TRAINING LOOP
    for epoch in range(num_epochs):

        # Reshuffle the training data every epoch
        perm = torch.randperm(X_train_tens.size(0))
        train_input_shuffled = X_train_tens[perm]
        train_target_shuffled = y_train_tens[perm]

        for i in range(num_batches):
            batch = slice(i * batch_size, (i + 1) * batch_size)
            mini_batch_input = train_input_shuffled[batch]
            mini_batch_target = train_target_shuffled[batch]

            # Drop samples with a missing phenotype. A torch-native mask replaces
            # routing torch tensors through np.isnan / np.delete.
            keep = ~torch.isnan(mini_batch_target)
            if not keep.any():
                continue  # an all-NaN batch would give a NaN loss; nothing to learn from
            mini_batch_input = mini_batch_input[keep].to(device)
            mini_batch_target = mini_batch_target[keep].to(device)

            # Forward pass and loss
            train_output = attention_layer(_encode_genotypes(mini_batch_input))
            train_loss = loss_function(train_output, mini_batch_target)

            # Backpropagation and optimization
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

        # Checkpoint after every epoch; the best epoch is selected from the validation log
        torch.save(attention_layer.state_dict(), os.path.join(save_dir, f"epoch_{epoch}.pt"))

        # VALIDATION STEP: predict in chunks to bound memory
        with torch.no_grad():
            chunk_preds = [
                attention_layer(_encode_genotypes(X_val_tens[start:start + chunk_size]))
                for start in range(0, len(X_val_tens), chunk_size)
            ]
        y_pred = torch.cat(chunk_preds) if chunk_preds else torch.empty(0, device=device)

        # Log validation R² for this epoch
        val_r_squared = r2_score(y_val_env2, y_pred.cpu().numpy())
        write_to_file(filename, epoch, val_r_squared)

        # Release cached GPU memory between epochs
        torch.cuda.empty_cache()

### Test-set performance of each environment's best model

In [None]:
def write_to_file(filename, a, b):
    """
    Append a single ``"<a> <b>"`` line to ``filename``.

    Used to log the test R² score of each environment. Append mode keeps earlier
    lines and creates the file if it does not exist yet. This is redefined here so
    the evaluation cell can run without re-running the training cell.

    Args:
        filename (str): Path to the output file.
        a (str): Environment name.
        b (float): R² score on test data.
    """
    record = f"{a} {b}\n"
    with open(filename, 'a') as log_file:
        log_file.write(record)

# Output locations for the test-set results
results_dir = "./single_env_attention_QTL_yeast_data"
filename2 = f"{results_dir}/test_r2_scores.txt"
pred_dir = f"{results_dir}/test_predictions"
true_dir = f"{results_dir}/test_true"

# Start a fresh test R² log
if os.path.exists(filename2):
    os.remove(filename2)

# Directories for unnormalized test predictions and the matching true values
os.makedirs(pred_dir, exist_ok=True)
os.makedirs(true_dir, exist_ok=True)

# Set computation device (GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Recreate the training cell's splits once, for all environments. Without
# stratification, train_test_split's permutation depends only on the number of samples
# and random_state, so these are the same rows used in training. The inner split only
# needs the targets, to recover the normalization statistics.
X_train, X_test, y_train, y_test_raw = train_test_split(geno_data, fitness, test_size=0.15, random_state=42)
y_train2_raw, _ = train_test_split(y_train, test_size=0.15, random_state=42)

# Per-environment normalization statistics from the inner training split
mean_all = np.nanmean(y_train2_raw, axis=0)
std_all = np.nanstd(y_train2_raw, axis=0)

# Model dimensions (must match the checkpoints written by the training cell)
low_dim = 12
num_loci = len(ind_loci_list)
seq_length = num_loci
input_dim = low_dim + 1
query_dim = low_dim + 1
key_dim = low_dim + 1
chunk_size = 50  # rows per forward pass (bounds memory)

for env, file in enumerate(pheno_files):

    sheet = file.split('_')[2].split('.')[0]  # Environment name, e.g. '23C'

    # Pick the epoch with the highest validation R² (raw string: '\s' is a regex escape)
    filename = f"{results_dir}/{sheet}/validation_r2.txt"
    data = pd.read_csv(filename, sep=r'\s+', header=None)
    epoch = int(data.loc[data[1].idxmax(), 0])

    # Normalize this environment's test targets with the training statistics
    mean_values = mean_all[env]
    std_values = std_all[env]
    y_test = (y_test_raw[:, env] - mean_values) / std_values

    # Rebuild the model and load the checkpoint with the best validation performance
    attention_layer = ThreeLayerAttention(input_dim, query_dim, key_dim).to(device)
    model_path = f"{results_dir}/{sheet}/epoch_{epoch}.pt"
    attention_layer.load_state_dict(torch.load(model_path, map_location=device))
    attention_layer.eval()  # Set the model to evaluation mode

    # Keep only test segregants that have a phenotype in this environment
    keep = ~np.isnan(y_test)
    y_test_env = y_test[keep]
    X_test_tens = torch.tensor(X_test[keep]).float().to(device)

    # Predict in chunks to bound memory. Each chunk is embedded as
    # (rows, num_loci, num_loci + 1): genotypes on the diagonal plus a column of ones.
    chunk_preds = []
    with torch.no_grad():
        for start in range(0, len(X_test_tens), chunk_size):
            chunk = X_test_tens[start:start + chunk_size]
            ones = torch.ones((chunk.shape[0], seq_length, 1), device=device)
            one_hot_test_input = torch.cat((torch.diag_embed(chunk), ones), dim=2)
            chunk_preds.append(attention_layer(one_hot_test_input))
    y_pred_test_env = torch.cat(chunk_preds) if chunk_preds else torch.empty(0, device=device)
    y_pred_test_np = y_pred_test_env.cpu().numpy()

    # Compute and log R² for the test set (normalized scale)
    test_r2 = r2_score(y_test_env, y_pred_test_np)
    write_to_file(filename2, sheet, test_r2)
    print(sheet, test_r2)

    # Save predictions and true values on the original (unnormalized) scale. True
    # values come straight from the raw targets, avoiding a normalize/unnormalize round trip.
    y_test_env_unnorm = y_test_raw[keep, env]
    y_pred_test_env_unnorm = y_pred_test_np * std_values + mean_values
    np.save(f"{pred_dir}/{sheet}.npy", y_pred_test_env_unnorm)
    np.save(f"{true_dir}/{sheet}.npy", y_test_env_unnorm)