In [1]:
import pandas as pd
import re
import numpy as np
import random
import os
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import itertools
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error
import pickle
import torch.nn.functional as F

### Data processing

In [2]:
# Genotype matrix (one row per segregant) stored as a NumPy array
geno_data = np.load('merged_geno_data.npy')

# Phenotype (fitness) files -- one file per environment, 18 environments in total
pheno_files = [
    'pheno_data_23C.txt', 'pheno_data_25C.txt', 'pheno_data_27C.txt', 'pheno_data_30C.txt',
    'pheno_data_33C.txt', 'pheno_data_35C.txt', 'pheno_data_37C.txt', 'pheno_data_cu.txt',
    'pheno_data_suloc.txt', 'pheno_data_ynb.txt', 'pheno_data_eth.txt', 'pheno_data_gu.txt',
    'pheno_data_li.txt', 'pheno_data_mann.txt', 'pheno_data_mol.txt', 'pheno_data_raff.txt',
    'pheno_data_sds.txt', 'pheno_data_4NQO.txt',
]

# Reproducible permutation of the segregants (the stdlib RNG is seeded right before shuffling)
num_segregants = geno_data.shape[0]
shuffled_indices = list(range(num_segregants))
random.seed(0)
random.shuffle(shuffled_indices)

# Reorder the genotype rows with that permutation
geno_data = geno_data[shuffled_indices]

# Indices of the independent loci used as model features
ind_loci_list = np.load('ind_loci_list_3.npy')

# Column 1 of every phenotype file, reordered with the same permutation.
# NOTE(review): assumes each phenotype file lists segregants in the same row
# order as merged_geno_data.npy -- confirm against the data-preparation step.
fitness_list = [
    pd.read_csv(pheno_file, sep="\t").iloc[shuffled_indices, 1].to_numpy()
    for pheno_file in pheno_files
]

# Keep only the independent loci (ascending column order) and recode 0/1 -> -1/+1
geno_data = geno_data[:, sorted(ind_loci_list)] * 2.0 - 1.0

# (num_environments, num_segregants) -> (num_segregants, num_environments);
# for this dataset the result is (99950, 18): rows = samples, columns = environments
fitness = np.array(fitness_list).T

### Attention layer class in PyTorch

In [3]:
class ThreeLayerAttention(nn.Module):
    """Three stacked single-head self-attention layers with a learned linear readout.

    Each sample is a sequence of ``seq_length`` tokens: ``seq_length - 1`` locus
    tokens followed by one final (environment) token. Every incoming token has
    ``num_loci + 19`` columns: ``num_loci`` leading columns plus 19 trailing
    columns that bypass the projection (in this notebook: 18 environment bits and
    a constant column of ones). The leading columns of the locus tokens are
    projected to ``low_dim`` with a learned matrix, so the attention layers work
    on tokens of width ``low_dim + 19``, which must equal ``input_dim``.
    The model returns one scalar per sample.
    """

    # Number of trailing token columns that are passed through unprojected.
    _TAIL = 19

    def __init__(self, input_dim, query_dim, key_dim, seq_length):
        """
        Initializes a three-layer attention model.

        Args:
            input_dim (int): Token width inside the attention layers
                (``low_dim + 19``); ``low_dim`` is derived from it.
            query_dim (int): Dimension of the query projection.
            key_dim (int): Dimension of the key projection. Must equal
                ``query_dim`` for the query-key product to be defined.
            seq_length (int): Number of tokens per sample (``num_loci + 1``);
                ``num_loci`` is derived from it.

        Raises:
            ValueError: If ``input_dim`` is smaller than 19 (negative ``low_dim``).
        """
        super(ThreeLayerAttention, self).__init__()
        self.input_dim = input_dim
        self.query_dim = query_dim
        self.key_dim = key_dim
        self.seq_length = seq_length

        # Sizes implied by the forward pass. They used to be read from the
        # notebook globals ``num_loci`` / ``low_dim``, which tied the class to
        # hidden kernel state; deriving them keeps the module self-contained.
        self.num_loci = seq_length - 1
        self.low_dim = input_dim - self._TAIL
        if self.low_dim < 0:
            raise ValueError(
                f"input_dim must be at least {self._TAIL} (got {input_dim})")

        # Learnable parameters for the first attention layer
        self.query_matrix_1 = nn.Parameter(torch.empty(input_dim, query_dim))
        self.key_matrix_1 = nn.Parameter(torch.empty(input_dim, key_dim))
        self.value_matrix_1 = nn.Parameter(torch.empty(input_dim, input_dim))

        # Learnable parameters for the second attention layer
        self.query_matrix_2 = nn.Parameter(torch.empty(input_dim, query_dim))
        self.key_matrix_2 = nn.Parameter(torch.empty(input_dim, key_dim))
        self.value_matrix_2 = nn.Parameter(torch.empty(input_dim, input_dim))

        # Learnable parameters for the third attention layer
        self.query_matrix_3 = nn.Parameter(torch.empty(input_dim, query_dim))
        self.key_matrix_3 = nn.Parameter(torch.empty(input_dim, key_dim))
        self.value_matrix_3 = nn.Parameter(torch.empty(input_dim, input_dim))

        # Learnable projection of the num_loci leading columns down to low_dim
        self.random_matrix = nn.Parameter(torch.empty(self.num_loci, self.low_dim))

        # Readout weights (one per token position and feature) and a scalar bias.
        # Parameter names and creation order are unchanged so saved state_dicts load.
        self.coeffs_attended = nn.Parameter(torch.empty(seq_length, input_dim))
        self.offset = nn.Parameter(torch.randn(1))  # re-initialized in init_parameters

        # Initialize model parameters
        self.init_parameters()

    def init_parameters(self):
        """Initializes every parameter from N(0, 0.03**2)."""
        init_scale = 0.03  # Standard deviation for parameter initialization

        for param in [self.query_matrix_1, self.key_matrix_1, self.value_matrix_1,
                      self.query_matrix_2, self.key_matrix_2, self.value_matrix_2,
                      self.query_matrix_3, self.key_matrix_3, self.value_matrix_3,
                      self.random_matrix, self.coeffs_attended, self.offset]:
            init.normal_(param, std=init_scale)

    @staticmethod
    def _self_attention(h, query_matrix, key_matrix, value_matrix):
        """One unscaled single-head self-attention layer: softmax(Q K^T) V."""
        query = torch.matmul(h, query_matrix)
        key = torch.matmul(h, key_matrix)
        value = torch.matmul(h, value_matrix)
        scores = torch.softmax(torch.matmul(query, key.transpose(1, 2)), dim=-1)
        return torch.matmul(scores, value)

    def forward(self, x):
        """
        Forward pass through the three-layer self-attention mechanism.

        Args:
            x (Tensor): Input of shape (batch_size, seq_length, num_loci + 19).
                Rows ``:num_loci`` are locus tokens; the last row is the
                environment token.

        Returns:
            Tensor: One prediction per sample, shape (batch_size,).
        """
        n_loci = self.num_loci
        tail = self._TAIL

        # Project the locus block to low_dim and keep the trailing columns as-is
        projected = torch.matmul(x[:, :n_loci, :n_loci], self.random_matrix)
        locus_tokens = torch.cat((projected, x[:, :-1, -tail:]), dim=2)

        # The final token keeps its last low_dim + 19 columns so its width matches
        env_token = x[:, -1, -(self.low_dim + tail):].unsqueeze(1)
        h = torch.cat((locus_tokens, env_token), dim=1)

        h = self._self_attention(h, self.query_matrix_1, self.key_matrix_1, self.value_matrix_1)
        h = self._self_attention(h, self.query_matrix_2, self.key_matrix_2, self.value_matrix_2)
        h = self._self_attention(h, self.query_matrix_3, self.key_matrix_3, self.value_matrix_3)

        # Weighted sum over all token positions and features, plus a learned bias
        return torch.einsum("bij,ij->b", h, self.coeffs_attended) + self.offset


### Create one-hot environment embedding

In [4]:
def env_embed(nan_indices):
    """
    Build the one-hot environment token for every row of a training mini-batch.

    Row ``r`` of the batch belongs to environment ``r // num_elements``; its
    token is all zeros except column ``num_loci + r // num_elements``. Rows
    whose index appears in ``nan_indices`` are dropped afterwards.

    Reads the notebook-level settings ``batch_size``, ``num_loci`` and
    ``num_elements``.

    Args:
        nan_indices (list): Indices of data points that should be removed from the output.

    Returns:
        Tensor: A tensor of shape (filtered_batch_size, 1, num_loci + 18),
                where the last 18 elements encode the environment identity.
    """
    row_ids = torch.arange(batch_size)
    env_columns = num_loci + row_ids // num_elements

    embedding = torch.zeros((batch_size, 1, num_loci + 18))
    embedding[row_ids, 0, env_columns] = 1

    kept_rows = [r for r in range(batch_size) if r not in nan_indices]
    return embedding[kept_rows, :, :]

### Training on all environments combined

In [None]:
# Define dimensionality of the low-dimensional space
low_dim = 12

# Directory for saving model parameters and results (created if missing)
save_dir = "./multi_env_attention_QTL_yeast_data/"
os.makedirs(save_dir, exist_ok=True)

# File for storing R² validation scores: one "low_dim epoch r2" line per epoch
filename = os.path.join(save_dir, "val_r2.txt")

# Start from an empty log. Scores are appended, so rows from an earlier run
# would otherwise mix with this run's epochs and the best-epoch lookup could
# pick a checkpoint that this run has since overwritten. (The test cell clears
# its score file the same way.)
if os.path.exists(filename):
    os.remove(filename)

def write_to_file(filename, a, b, c):
    """
    Append the record ``"a b c"`` (space-separated, newline-terminated) to a file.

    Called once per epoch to log validation performance; the file is created
    if it does not exist.

    Args:
        filename (str): Path to the output file.
        a (int): Low-dimensional embedding size.
        b (int): Epoch number.
        c (float): R² score on validation data.
    """
    with open(filename, 'a') as log_file:
        log_file.write("{} {} {}\n".format(a, b, c))

# Move model and tensors to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Split data into training (85%) and testing (15%) sets
X_train, X_test, y_train, y_test = train_test_split(geno_data, fitness, test_size=0.15, random_state=42)

# Further split training set into training (85%) and validation (15%) sets
X_train2, X_val, y_train2, y_val = train_test_split(X_train, y_train, test_size=0.15, random_state=42)

# Per-environment mean/std from the training split only; NaN-aware because
# missing fitness measurements are NaN (they are filtered out below).
mean_values = np.nanmean(y_train2, axis=0)
std_values = np.nanstd(y_train2, axis=0)

# Normalize target (fitness) data
y_train2 = (y_train2 - mean_values) / std_values
y_val = (y_val - mean_values) / std_values
y_test = (y_test - mean_values) / std_values

# Convert genotype and fitness data to PyTorch tensors
X_train_tens = torch.tensor(X_train2).float()
y_train_tens = torch.tensor(np.array(y_train2)).float()

# Define model input dimensions: num_loci locus tokens + 1 environment token;
# inside the model each token has low_dim + 18 (environment) + 1 (constant) columns
num_loci = len(ind_loci_list)
seq_length = num_loci + 1
input_dim = low_dim + 18 + 1
query_dim = low_dim + 18 + 1
key_dim = low_dim + 18 + 1

# Initialize the attention model and move it to device
attention_layer = ThreeLayerAttention(input_dim, query_dim, key_dim, seq_length).to(device)

# Define loss function and optimizer
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(attention_layer.parameters(), lr=0.001)

# Define batch and training parameters
num_elements = 4  # Number of segregants per environment in each mini-batch
batch_size = 18 * num_elements
chunk_size = 50
num_epochs = 3000
num_batches = X_train_tens.size(0) // batch_size  # a trailing partial batch is dropped

# Diagonal positions of the (num_loci x num_loci) one-hot locus block; loop-invariant
loci_idx = torch.arange(num_loci, device=device)

# TRAINING LOOP
for epoch in range(num_epochs):
    attention_layer.train()

    # Shuffle training data before each epoch
    perm = torch.randperm(X_train_tens.size(0))
    train_input_shuffled = X_train_tens[perm]
    train_target_shuffled = y_train_tens[perm].T  # one row of targets per environment

    # Mini-batch training
    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = start_idx + batch_size

        # The batch is split into 18 consecutive groups of num_elements samples;
        # group e is scored only on environment e, so every batch covers all environments.
        mini_batch_target = torch.cat([
            env_row[start_idx + env_idx * num_elements:start_idx + (env_idx + 1) * num_elements]
            for env_idx, env_row in enumerate(train_target_shuffled)
        ])
        mini_batch_input = train_input_shuffled[start_idx:end_idx]

        # Drop samples whose fitness is missing (NaN) in their assigned environment
        nan_mask = torch.isnan(mini_batch_target)
        keep = ~nan_mask
        if not keep.any():
            continue  # nothing to learn from; avoids an empty-batch (NaN) loss
        nan_indices = np.flatnonzero(nan_mask.numpy())
        mini_batch_input = mini_batch_input[keep].to(device)
        mini_batch_target = mini_batch_target[keep].to(device)
        n_rows = mini_batch_input.shape[0]

        # One token per locus: the genotype value sits on the diagonal of the
        # locus block; the 18 environment columns of these tokens stay zero
        one_hot_mini_batch_input = torch.zeros((n_rows, num_loci, num_loci + 18), device=device)
        one_hot_mini_batch_input[:, loci_idx, loci_idx] = mini_batch_input

        # Append the environment token, then a constant column of ones to every token
        env_emb = env_embed(nan_indices).float().to(device)
        one_hot_mini_batch_input = torch.cat((one_hot_mini_batch_input, env_emb), dim=1)
        one_hot_mini_batch_input = torch.cat(
            (one_hot_mini_batch_input, torch.ones((n_rows, seq_length, 1), device=device)), dim=2)

        # Forward pass
        train_output = attention_layer(one_hot_mini_batch_input)
        train_loss = loss_function(train_output, mini_batch_target)

        # Backpropagation
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

    # Save model after each epoch
    model_state_path = os.path.join(save_dir, f"epoch_{epoch}.pt")
    torch.save(attention_layer.state_dict(), model_state_path)

    # VALIDATION STEP: predict every environment separately, pool, and score R²
    attention_layer.eval()
    val_pred_chunks = []
    val_true_parts = []

    with torch.no_grad():
        for env in range(18):
            # One-hot environment token: only column num_loci + env is set
            env_emb = torch.zeros(num_loci + 18, device=device)
            env_emb[num_loci + env] = 1

            # Keep only validation segregants with a measured fitness in this environment
            y_val_env = y_val.T[env]
            observed = ~np.isnan(y_val_env)
            X_val_tens = torch.tensor(X_val[observed]).float().to(device)
            val_true_parts.append(y_val_env[observed])

            # Process validation data in chunks to bound memory use
            for i in range(0, len(X_val_tens), chunk_size):
                chunk = X_val_tens[i:i + chunk_size]
                n_chunk = chunk.shape[0]

                one_hot_val_input = torch.zeros((n_chunk, num_loci, num_loci + 18), device=device)
                one_hot_val_input[:, loci_idx, loci_idx] = chunk
                one_hot_val_input = torch.cat((one_hot_val_input, env_emb.expand(n_chunk, 1, -1)), dim=1)
                one_hot_val_input = torch.cat(
                    (one_hot_val_input, torch.ones((n_chunk, seq_length, 1), device=device)), dim=2)

                val_pred_chunks.append(attention_layer(one_hot_val_input))

    y_pred = torch.cat(val_pred_chunks) if val_pred_chunks else torch.empty(0, device=device)
    y_val_all = np.concatenate(val_true_parts)
    val_r_squared = r2_score(y_val_all, y_pred.cpu().numpy())

    # Save validation performance
    write_to_file(filename, low_dim, epoch, val_r_squared)
    torch.cuda.empty_cache()


### Model prediction performance on test data

In [None]:
# Define low-dimensional space size (must match the value used for training)
low_dim = 12

# Load validation R² scores and retrieve the best epoch for this low_dim.
# Columns written by the training cell: 0 = low_dim, 1 = epoch, 2 = validation R².
filename = "./multi_env_attention_QTL_yeast_data/val_r2.txt"
data = pd.read_csv(filename, sep=r'\s+', header=None)  # raw string: '\s' is not a valid str escape
data = data[data[0] == low_dim]  # ignore rows logged for other embedding sizes
if data.empty:
    raise ValueError(f"No validation R² rows for low_dim={low_dim} in {filename}")
max_row_index = data[2].idxmax()  # Index of max validation R² score
max_row = data.loc[max_row_index]  # Retrieve the corresponding row
max_second_column_value = max_row[1]  # Epoch with the highest R² score

# Move model and tensors to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Split dataset into training, validation, and testing sets (same seeds as training)
X_train, X_test, y_train, y_test = train_test_split(geno_data, fitness, test_size=0.15, random_state=42)
X_train2, X_val, y_train2, y_val = train_test_split(X_train, y_train, test_size=0.15, random_state=42)

# Per-environment normalization statistics from the training split (NaN-aware)
mean_values = np.nanmean(y_train2, axis=0)
std_values = np.nanstd(y_train2, axis=0)

# Normalize target (fitness) data
y_train2 = (y_train2 - mean_values) / std_values
y_val = (y_val - mean_values) / std_values
y_test = (y_test - mean_values) / std_values

# Define model input dimensions
chunk_size = 100
num_loci = len(ind_loci_list)
seq_length = num_loci + 1
input_dim = low_dim + 18 + 1
query_dim = low_dim + 18 + 1
key_dim = low_dim + 18 + 1

# Initialize the ThreeLayerAttention model on the target device
attention_layer = ThreeLayerAttention(input_dim, query_dim, key_dim, seq_length).to(device)

# Load the best-performing model checkpoint
epoch = int(max_second_column_value)
model_path = f"./multi_env_attention_QTL_yeast_data/epoch_{epoch}.pt"
state_dict = torch.load(model_path, map_location=device)
attention_layer.load_state_dict(state_dict)
attention_layer.eval()  # Set model to evaluation mode

# Lists to store results for the overall R² calculation
y_test_all = []
y_pred_test_all = []

# File to store per-environment test R² scores (recreated on every run)
filename2 = "./multi_env_attention_QTL_yeast_data/test_r2_scores.txt"
if os.path.exists(filename2):
    os.remove(filename2)

# Directory for saving test predictions
save_dir = "./multi_env_attention_QTL_yeast_data/predictions_test"
os.makedirs(save_dir, exist_ok=True)

def write_to_file(filename, a, b):
    """
    Append the record ``"a b"`` (space-separated, newline-terminated) to a file.

    Used to log one test R² score per environment; the file is created if
    missing. Note: this redefines the 4-argument ``write_to_file`` from the
    training section.

    Args:
        filename (str): Path to the output file.
        a (str): Environment name.
        b (float): R² score on test data.
    """
    with open(filename, 'a') as out_file:
        out_file.write("{} {}\n".format(a, b))

# Diagonal positions of the (num_loci x num_loci) one-hot locus block; loop-invariant
loci_idx = torch.arange(num_loci, device=device)

# Iterate over 18 environments and evaluate test performance
for env in range(18):
    pheno_file = pheno_files[env]
    sheet = pheno_file.split('_')[2].split('.')[0]  # environment name, e.g. 'pheno_data_23C.txt' -> '23C'

    # One-hot environment token: only column num_loci + env is set
    env_emb = torch.zeros(num_loci + 18, device=device)
    env_emb[num_loci + env] = 1

    # Keep only test segregants with a measured fitness in this environment
    y_test_env = y_test.T[env]
    observed = ~np.isnan(y_test_env)
    X_test_tens = torch.tensor(X_test[observed]).float().to(device)
    y_test_env = y_test_env[observed]

    # Process test data in chunks to bound memory use
    pred_chunks = []
    with torch.no_grad():
        for i in range(0, len(X_test_tens), chunk_size):
            chunk = X_test_tens[i:i + chunk_size]
            n_chunk = chunk.shape[0]

            # Locus tokens (genotype on the diagonal), environment token, constant column
            one_hot_test_input = torch.zeros((n_chunk, num_loci, num_loci + 18), device=device)
            one_hot_test_input[:, loci_idx, loci_idx] = chunk
            one_hot_test_input = torch.cat((one_hot_test_input, env_emb.expand(n_chunk, 1, -1)), dim=1)
            one_hot_test_input = torch.cat(
                (one_hot_test_input, torch.ones((n_chunk, seq_length, 1), device=device)), dim=2)

            pred_chunks.append(attention_layer(one_hot_test_input))

    # One device-to-host transfer per environment
    if pred_chunks:
        y_pred_test_env = torch.cat(pred_chunks).cpu().numpy()
    else:
        y_pred_test_env = np.empty(0, dtype=np.float32)

    # Undo the per-environment normalization
    y_test_env_unnorm = y_test_env * std_values[env] + mean_values[env]
    y_pred_test_env_unnorm = y_pred_test_env * std_values[env] + mean_values[env]

    # Store normalized values for the overall (pooled) R² calculation
    y_test_all.extend(y_test_env)
    y_pred_test_all.extend(y_pred_test_env)

    # Save predictions for the current environment
    np.save(f"{save_dir}/y_test_{sheet}.npy", y_test_env_unnorm)
    np.save(f"{save_dir}/y_pred_test_{sheet}.npy", y_pred_test_env_unnorm)

    # Compute and save R² score for the current environment
    test_r2 = r2_score(y_test_env, y_pred_test_env)
    write_to_file(filename2, sheet, test_r2)
    print(sheet, test_r2)

# Calculate and print overall R² score (pooled over all environments, normalized scale)
overall_r2 = r2_score(y_test_all, y_pred_test_all)
print("Overall R²:", overall_r2)