In [1]:
import pandas as pd
import re
import numpy as np
import random
import os
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import itertools
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error
import pickle
import torch.nn.functional as F

### Data processing

In [2]:
# Genotype matrix: one row per segregant, one column per locus (0/1 coded in the raw file)
geno_data = np.load('merged_geno_data.npy')

# One tab-separated phenotype table per growth temperature, in the column order used below
pheno_files = [f'pheno_data_{t}C.txt' for t in (23, 25, 27, 30, 33, 35, 37)]

# Fixed-seed permutation of segregant order, applied identically to genotypes and phenotypes
num_segregants = geno_data.shape[0]
shuffled_indices = list(range(num_segregants))
random.seed(0)
random.shuffle(shuffled_indices)
geno_data = geno_data[shuffled_indices]

# Indices of the causal loci to retain
ind_loci_list = np.load('ind_loci_list_3.npy')

# Fitness is the second column of each table, reordered with the same permutation as the genotypes
fitness_list = [
    pd.read_csv(path, sep="\t").iloc[shuffled_indices, 1].to_numpy()
    for path in pheno_files
]

# Keep only the causal loci (ascending index order) and recode genotypes 0/1 -> -1/+1
geno_data = geno_data[:, sorted(ind_loci_list)] * 2.0 - 1.0

# Environments as columns: shape (num_segregants, num_environments)
fitness = np.array(fitness_list).T


### Attention layer class in PyTorch

In [3]:
class ThreeLayerAttention(nn.Module):
    """
    Three stacked single-head (unscaled) self-attention layers with a linear read-out.

    The token sequence consists of one token per locus, obtained by multiplying the genotype
    input by a learnable projection, followed by one environment token. A constant-1 feature is
    appended to every token so the attention projections can learn a bias term.
    """

    def __init__(self, input_dim, query_dim, key_dim, seq_length):
        """
        Args:
            input_dim (int): Token width inside the attention layers, *including* the appended
                constant-1 feature. Tokens are first projected to ``input_dim - 1`` features.
            query_dim (int): Output width of the query projections.
            key_dim (int): Output width of the key projections; must equal ``query_dim`` because
                queries are multiplied against keys.
            seq_length (int): Number of tokens: ``seq_length - 1`` locus tokens plus one
                environment token.
        """
        super().__init__()
        self.input_dim = input_dim
        self.query_dim = query_dim
        self.key_dim = key_dim
        self.seq_length = seq_length
        # Token width before the constant-1 feature is appended. forward() requires
        # low_dim + 1 == input_dim, so derive it here rather than reading a notebook global.
        self.low_dim = input_dim - 1

        # Learnable weight matrices for the first attention layer
        self.query_matrix_1 = nn.Parameter(torch.empty(input_dim, query_dim))
        self.key_matrix_1 = nn.Parameter(torch.empty(input_dim, key_dim))
        self.value_matrix_1 = nn.Parameter(torch.empty(input_dim, input_dim))

        # Learnable weight matrices for the second attention layer
        self.query_matrix_2 = nn.Parameter(torch.empty(input_dim, query_dim))
        self.key_matrix_2 = nn.Parameter(torch.empty(input_dim, key_dim))
        self.value_matrix_2 = nn.Parameter(torch.empty(input_dim, input_dim))

        # Learnable weight matrices for the third attention layer
        self.query_matrix_3 = nn.Parameter(torch.empty(input_dim, query_dim))
        self.key_matrix_3 = nn.Parameter(torch.empty(input_dim, key_dim))
        self.value_matrix_3 = nn.Parameter(torch.empty(input_dim, input_dim))

        # Learnable projection from the per-locus input to low_dim features per token
        # (named "random" for historical reasons; kept for checkpoint compatibility)
        self.random_matrix = nn.Parameter(torch.empty(seq_length - 1, self.low_dim))

        # Per-token, per-feature read-out weights applied to the last layer's output
        self.coeffs_attended = nn.Parameter(torch.empty(seq_length, input_dim))

        # Learnable scalar output offset (re-initialized in init_parameters)
        self.offset = nn.Parameter(torch.randn(1))

        self.init_parameters()

    def init_parameters(self):
        """
        Initializes every learnable parameter from N(0, 0.04^2).
        """
        init_scale = 0.04  # Standard deviation for parameter initialization

        for param in [self.query_matrix_1, self.key_matrix_1, self.value_matrix_1,
                      self.query_matrix_2, self.key_matrix_2, self.value_matrix_2,
                      self.query_matrix_3, self.key_matrix_3, self.value_matrix_3,
                      self.random_matrix, self.coeffs_attended, self.offset]:
            init.normal_(param, std=init_scale)

    @staticmethod
    def _self_attention(h, query_matrix, key_matrix, value_matrix):
        """One unscaled single-head self-attention layer: softmax(Q K^T) V over the token axis."""
        query = torch.matmul(h, query_matrix)
        key = torch.matmul(h, key_matrix)
        value = torch.matmul(h, value_matrix)
        scores = torch.matmul(query, key.transpose(1, 2)).softmax(dim=-1)
        return torch.matmul(scores, value)

    def forward(self, x, envs_one_hot):
        """
        Forward pass through the three-layer self-attention mechanism.

        Args:
            x (Tensor): Genotype input of shape (batch_size, seq_length - 1, seq_length - 1).
                Callers in this notebook pass a diagonal matrix per sample holding the ±1 genotype.
            envs_one_hot (Tensor): Environment embedding of shape (batch_size, input_dim - 1).

        Returns:
            Tensor: Predictions of shape (batch_size,).
        """
        # Project loci to tokens: (batch, seq_length - 1, low_dim)
        y = torch.matmul(x, self.random_matrix)

        # Append the environment as one extra token: (batch, seq_length, low_dim)
        z = torch.cat((y, envs_one_hot.unsqueeze(1)), dim=1)

        # Append a constant-1 feature to every token on the input's device/dtype: (batch, seq_length, input_dim)
        bias = torch.ones(z.shape[0], z.shape[1], 1, device=z.device, dtype=z.dtype)
        z = torch.cat((z, bias), dim=2)

        h = self._self_attention(z, self.query_matrix_1, self.key_matrix_1, self.value_matrix_1)
        h = self._self_attention(h, self.query_matrix_2, self.key_matrix_2, self.value_matrix_2)
        h = self._self_attention(h, self.query_matrix_3, self.key_matrix_3, self.value_matrix_3)

        # Weighted sum over tokens and features, plus a scalar offset
        return torch.einsum("bij,ij->b", h, self.coeffs_attended) + self.offset

In [4]:
def one_hot_temp_embeddings(temps, low_dim=None, seed=42):
    """
    Generate one-hot temperature embeddings, using a seeded shuffled slot order, zero-padded to `low_dim`.

    Args:
    - temps (Tensor): A 1D tensor of temperature values (one per sample).
    - low_dim (int, optional): Width of the returned embedding; must be >= 7, the number of known
      temperatures. When omitted, the notebook-level `low_dim` is used (backward compatible).
    - seed (int): Seed for the temperature -> slot permutation (default 42). A private
      `random.Random` is used, so the global `random` state is left untouched.

    Returns:
    - Tensor: float32 tensor of shape [len(temps), low_dim]. Only the first 7 columns can be
      non-zero; a temperature not in the known list yields an all-zero row.

    Raises:
    - ValueError: if `low_dim` is smaller than the number of known temperatures.
    """
    # Define the list of unique temperatures
    unique_temps = [23.0, 25.0, 27.0, 30.0, 33.0, 35.0, 37.0]

    if low_dim is None:
        low_dim = globals()["low_dim"]
    if low_dim < len(unique_temps):
        raise ValueError(f"low_dim must be >= {len(unique_temps)}, got {low_dim}")

    # Same permutation as random.seed(seed); random.sample(...), without touching the global RNG
    shuffled_temps = random.Random(seed).sample(unique_temps, len(unique_temps))

    # Map each known temperature to its one-hot slot. Integer keys such as 23 hash equal to 23.0.
    temp_to_index = {temp: idx for idx, temp in enumerate(shuffled_temps)}

    # Columns beyond the 7 temperature slots stay zero (the padding)
    one_hot_padded = torch.zeros((len(temps), low_dim))
    for row, temp in enumerate(temps):
        slot = temp_to_index.get(temp.item())
        if slot is not None:
            one_hot_padded[row, slot] = 1.0

    return one_hot_padded

### Training with different amounts of training data at $T = 23\,^\circ\mathrm{C}$

In [None]:
def write_to_file(filename, a, b, c):
    """
    Append a single line of three space-separated values to a text file.

    In the training cell each line is "<low_dim> <epoch> <validation R²>".

    Args:
        filename (str): File to append to; created if it does not exist.
        a, b, c: Values written in that order. Each is rendered with ``format()``,
            exactly as an f-string would render it.
    """
    with open(filename, mode="a") as handle:
        handle.write("{} {} {}\n".format(a, b, c))

# Set computation device (GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Define model hyperparameters
low_dim = 12
num_loci = len(ind_loci_list)
seq_length = num_loci + 1
input_dim = low_dim + 1
query_dim = low_dim + 1
key_dim = low_dim + 1
hidden_dim = 32  # not used by ThreeLayerAttention

# Number of NaN entries expected in column 0 (23C) of the 85% training split
NUM_NAN_TRAIN = 1390

# Split dataset into training (85%) and testing (15%)
X_train, X_test, y_train, y_test = train_test_split(geno_data, fitness, test_size=0.15, random_state=42)

# Define temperature values (used for environment embeddings)
temps = torch.tensor([23, 25, 27, 30, 33, 35, 37])

# Log-spaced grid of training-set sizes m for column 0, from 2 up to all observed values
m_values = np.logspace(np.log10(2), np.log10(len(X_train) - NUM_NAN_TRAIN), num=6, dtype=int)

# Seed NumPy (NaN masking) and PyTorch (weight init, epoch shuffling) for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Iterate over different training data sizes
for m in m_values:

    save_dir = f"./transfer_learning_QTL_yeast_temps/num_data_{m}"
    os.makedirs(save_dir, exist_ok=True)

    # Start a fresh validation log for this m
    filename = f"{save_dir}/validation_r2.txt"
    if os.path.exists(filename):
        os.remove(filename)

    # Keep exactly m observed values in column 0 by masking a random subset of the rest with NaN
    y_train_modified = y_train.copy()
    for idx in [0]:
        current_non_nan_count = y_train_modified.shape[0] - NUM_NAN_TRAIN
        not_nan_indices = np.where(~np.isnan(y_train_modified[:, idx]))[0]
        # Guard the hard-coded NaN count: if it is wrong, "num_data_m" would not contain m points
        assert len(not_nan_indices) == current_non_nan_count, (
            f"expected {NUM_NAN_TRAIN} NaNs in column {idx} of y_train, "
            f"found {y_train_modified.shape[0] - len(not_nan_indices)}"
        )
        additional_nan_needed = current_non_nan_count - m
        random_indices = np.random.choice(not_nan_indices, size=additional_nan_needed, replace=False)
        y_train_modified[random_indices, idx] = np.nan

    # Split into training and validation sets
    X_train2, X_val, y_train2, y_val = train_test_split(X_train, y_train_modified, test_size=0.15, random_state=42)

    # Normalization statistics from the training part only
    mean_values = np.nanmean(y_train2, axis=0)
    std_values = np.nanstd(y_train2, axis=0)

    # Normalize training/validation targets. y_test is not used in this cell and is left as is:
    # normalizing the shared array inside this loop would compound the transform for every m.
    y_train2 = (y_train2 - mean_values) / std_values
    y_val = (y_val - mean_values) / std_values

    # Convert dataset to PyTorch tensors
    X_train_tens = torch.tensor(X_train2).float()
    y_train_tens = torch.tensor(np.array(y_train2)).float()

    # Each mini-batch holds num_elements distinct segregants per environment
    num_elements = 9
    batch_size = len(temps) * num_elements
    chunk_size = 50
    num_epochs = 200
    num_batches = X_train_tens.size(0) // batch_size

    # Environment of every slot of a flattened mini-batch. Targets are gathered environment by
    # environment (num_elements consecutive slots each), so the labels must be
    # [t0]*num_elements + [t1]*num_elements + ... -- i.e. repeat_interleave. Tensor.repeat
    # would tile t0..t6 and pair most targets with the wrong temperature.
    batch_envs = temps.repeat_interleave(num_elements)

    # Initialize the attention model on the selected device
    attention_layer = ThreeLayerAttention(input_dim, query_dim, key_dim, seq_length).to(device)

    # Define loss function and optimizer
    loss_function = nn.MSELoss()
    optimizer = optim.Adam(attention_layer.parameters(), lr=0.001)

    # TRAINING LOOP
    for epoch in range(num_epochs):

        # Shuffle training data at the beginning of each epoch
        perm = torch.randperm(X_train_tens.size(0))
        train_input_shuffled = X_train_tens[perm]
        train_target_shuffled = y_train_tens[perm].T  # (num_envs, num_train_segregants)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * batch_size
            end_idx = start_idx + batch_size

            # Slot k is segregant start_idx + k, scored in environment k // num_elements
            mini_batch_target = torch.cat([
                row[start_idx + env * num_elements:start_idx + (env + 1) * num_elements]
                for env, row in enumerate(train_target_shuffled)
            ])
            mini_batch_input = train_input_shuffled[start_idx:end_idx]

            # Drop slots whose fitness is missing
            keep = ~torch.isnan(mini_batch_target)
            if not keep.any():
                continue  # MSELoss of an empty batch is NaN and would corrupt the weights

            envs_one_hot = one_hot_temp_embeddings(batch_envs[keep]).to(device)
            mini_batch_input = mini_batch_input[keep].to(device)
            mini_batch_target = mini_batch_target[keep].to(device)

            # Diagonal matrix per sample: locus l's genotype sits at entry (l, l)
            train_output = attention_layer(torch.diag_embed(mini_batch_input), envs_one_hot)
            train_loss = loss_function(train_output, mini_batch_target)

            # Backpropagation
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

        # Save model parameters after each epoch
        model_state_path = os.path.join(save_dir, f"epoch_{epoch}.pt")
        torch.save(attention_layer.state_dict(), model_state_path)

        # VALIDATION: pooled R² over all environments
        y_pred_parts, y_val_parts = [], []
        for env in range(len(temps)):
            y_val_env = y_val[:, env]
            observed = ~np.isnan(y_val_env)
            y_val_parts.append(y_val_env[observed])
            X_val_tens = torch.tensor(X_val[observed]).float().to(device)

            for start in range(0, len(X_val_tens), chunk_size):
                chunk = X_val_tens[start:start + chunk_size]
                envs_one_hot = one_hot_temp_embeddings(temps[env].repeat(len(chunk))).to(device)
                with torch.no_grad():
                    y_pred_parts.append(attention_layer(torch.diag_embed(chunk), envs_one_hot))

        y_val_all = np.concatenate(y_val_parts)
        y_pred = torch.cat(y_pred_parts).cpu()
        val_r_squared = r2_score(y_val_all, y_pred.numpy())

        # Save validation performance
        write_to_file(filename, low_dim, epoch, val_r_squared)
        torch.cuda.empty_cache()  # Clear PyTorch cache

### Model prediction performance on test data

In [None]:
def write_to_file(filename, a, b, c):
    """
    Append a single line of three space-separated values to a text file.

    In the evaluation cell each line is "<number of training points> <temperature> <test R²>".

    Args:
        filename (str): File to append to; created if it does not exist.
        a, b, c: Values written in that order. Each is rendered with ``format()``,
            exactly as an f-string would render it.
    """
    with open(filename, mode="a") as out_file:
        out_file.write("{} {} {}\n".format(a, b, c))

# Output file: one "<m> <temperature> <test R²>" line per (training size, environment)
filename2 = "./transfer_learning_QTL_yeast_temps/num_data_vs_r2_score.txt"

# Remove the file if it exists (to start fresh)
if os.path.exists(filename2):
    os.remove(filename2)

# Set computation device (use GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define temperature values used for environmental embeddings
temps = torch.tensor([23, 25, 27, 30, 33, 35, 37])

# Model hyperparameters -- must match those used when the checkpoints were trained
low_dim = 12
chunk_size = 100
num_loci = len(ind_loci_list)
seq_length = num_loci + 1
input_dim = low_dim + 1
query_dim = low_dim + 1
key_dim = low_dim + 1
hidden_dim = 32  # not used by ThreeLayerAttention

# Number of NaN entries expected in column 0 (23C) of the 85% training split
NUM_NAN_TRAIN = 1390

# Same 85/15 split as the training cell (fixed random_state), so the held-out test set is identical.
# y_test is never modified below; a normalized copy is made per m instead.
X_train, X_test, y_train, y_test = train_test_split(geno_data, fitness, test_size=0.15, random_state=42)

# Same grid of training-set sizes as the training cell (determines the checkpoint directories)
m_values = np.logspace(np.log10(2), np.log10(len(X_train) - NUM_NAN_TRAIN), num=6, dtype=int)

# Re-seed so the np.random.choice calls below reproduce the training cell's masks exactly
np.random.seed(42)

for m in m_values:

    # Rebuild the training-time NaN mask (needed to recover the normalization statistics)
    y_train_modified = y_train.copy()
    for idx in [0]:
        current_non_nan_count = y_train_modified.shape[0] - NUM_NAN_TRAIN
        not_nan_indices = np.where(~np.isnan(y_train_modified[:, idx]))[0]
        assert len(not_nan_indices) == current_non_nan_count, (
            f"expected {NUM_NAN_TRAIN} NaNs in column {idx} of y_train, "
            f"found {y_train_modified.shape[0] - len(not_nan_indices)}"
        )
        additional_nan_needed = current_non_nan_count - m
        random_indices = np.random.choice(not_nan_indices, size=additional_nan_needed, replace=False)
        y_train_modified[random_indices, idx] = np.nan

    # Same train/validation split as in training; only the training targets are needed here
    y_train2, _ = train_test_split(y_train_modified, test_size=0.15, random_state=42)

    # Normalize test targets with this m's training statistics
    mean_values = np.nanmean(y_train2, axis=0)
    std_values = np.nanstd(y_train2, axis=0)
    y_test_norm = (y_test - mean_values) / std_values

    # Best epoch = row with the highest validation R² (third column); ties resolve to the earliest
    filename = f"./transfer_learning_QTL_yeast_temps/num_data_{m}/validation_r2.txt"
    data = pd.read_csv(filename, sep=r'\s+', header=None)
    epoch = int(data.loc[data[2].idxmax(), 1])

    # Rebuild the model and load the parameters saved at the best epoch
    attention_layer = ThreeLayerAttention(input_dim, query_dim, key_dim, seq_length).to(device)
    model_path = f"./transfer_learning_QTL_yeast_temps/num_data_{m}/epoch_{epoch}.pt"
    state_dict = torch.load(model_path, map_location=device)
    attention_layer.load_state_dict(state_dict)
    attention_layer.eval()

    # Per-environment test R²
    for env in range(len(temps)):
        temp_value = int(temps[env])
        y_test_env = y_test_norm[:, env]

        # Drop segregants whose fitness is missing in this environment
        observed = ~np.isnan(y_test_env)
        y_test_env = y_test_env[observed]
        X_test_tens = torch.tensor(X_test[observed]).float().to(device)

        # Predict in chunks to bound memory
        pred_parts = []
        for start in range(0, len(X_test_tens), chunk_size):
            chunk = X_test_tens[start:start + chunk_size]
            envs_one_hot = one_hot_temp_embeddings(temps[env].repeat(len(chunk))).to(device)
            with torch.no_grad():
                pred_parts.append(attention_layer(torch.diag_embed(chunk), envs_one_hot))
        y_pred_env = torch.cat(pred_parts).cpu().numpy()

        test_r_squared = r2_score(y_test_env, y_pred_env)

        # Save and report results
        write_to_file(filename2, m, temp_value, test_r_squared)
        print(m, temp_value, test_r_squared)