In [1]:
import itertools
import math
import os
import pickle
import random
import re
import shutil
from itertools import combinations

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

In [2]:
# Genotype matrix consumed by pick_loci: rows are segregants, columns are loci
# (presumably 0/1 allele calls — pick_loci maps them to -1/+1).
geno_data = np.load(file="merged_geno_data.npy")

### Genotype data processing

In [3]:
# Temperature conditions (°C) at which synthetic fitness is generated
temps = np.array([23.0, 25.0, 27.0, 29.0, 31.0, 33.0, 35.0, 37.0])

def pick_loci(num_loci, geno=None, loci_path="ind_loci_list_3.npy", max_abs_mean=0.05, seed=0):
    """
    Select `num_loci` roughly balanced loci and return their genotypes mapped to {-1, +1}.

    Segregants (rows) are first put into a fixed pseudo-random order. Candidate loci
    are read from `loci_path`, sorted ascending, and scanned in that order. A locus is
    kept when |mean(2*g - 1)| over all segregants is below `max_abs_mean`. Scanning
    stops once `num_loci` loci have been kept.

    Args:
        num_loci (int): Number of loci to select (must be >= 1).
        geno (ndarray, optional): Genotype matrix of shape (n_segregants, n_total_loci),
            presumably with 0/1 entries. Defaults to the notebook-level `geno_data`.
        loci_path (str): .npy file holding candidate locus column indices.
        max_abs_mean (float): Balance threshold on |mean(2*g - 1)| (default 0.05).
        seed (int): Seed for the row shuffle (default 0).

    Returns:
        ndarray: Matrix of shape (n_segregants, num_loci) with values 2*g - 1, rows in
        the shuffled order.

    Raises:
        ValueError: If num_loci < 1, or if fewer than num_loci candidate loci pass the
            threshold. The model is sized for num_loci + 1 tokens, so a narrower matrix
            would otherwise fail later with an opaque shape error.
    """
    if num_loci < 1:
        raise ValueError(f"num_loci must be >= 1, got {num_loci}")
    if geno is None:
        geno = geno_data

    # A private RNG yields the same permutation as `random.seed(seed); random.shuffle(...)`.
    # It does not reset the global `random` state that other cells may use.
    row_order = list(range(geno.shape[0]))
    random.Random(seed).shuffle(row_order)
    geno_shuffled = geno[row_order]

    candidate_loci = np.sort(np.load(loci_path))
    selected = []
    for locus in candidate_loci:
        locus_mean = (2.0 * geno_shuffled[:, locus] - 1.0).mean()
        if abs(locus_mean) < max_abs_mean:
            selected.append(locus)
        if len(selected) >= num_loci:
            break

    if len(selected) < num_loci:
        raise ValueError(
            f"only {len(selected)} of {len(candidate_loci)} candidate loci have "
            f"|mean| < {max_abs_mean}; cannot select {num_loci}"
        )

    # Map genotype values from {0, 1} to {-1, +1}
    return 2.0 * geno_shuffled[:, np.array(selected)] - 1.0

# Define number of loci to be selected
num_loci = 100

### Attention layer class in PyTorch

In [4]:
class ThreeLayerAttention(nn.Module):
    """
    Three stacked single-head self-attention layers followed by a learned scalar read-out.

    Each locus row of the input is projected to ``input_dim - 1`` features by
    ``random_matrix``. One environment token is appended, then a constant-1 bias column,
    giving tokens of width ``input_dim``. The output is a weighted sum of every entry of
    the third layer's output, plus ``offset``.
    """

    def __init__(self, input_dim, query_dim, key_dim, seq_length):
        """
        Args:
            input_dim (int): Token width after the constant bias column is appended.
                Tokens are (input_dim - 1)-wide before that; in this notebook
                input_dim = low_dim + 1.
            query_dim (int): Dimensionality of query vectors.
            key_dim (int): Dimensionality of key vectors. Must equal query_dim, because
                queries are multiplied with transposed keys.
            seq_length (int): Number of tokens, including the environment token
                (number of loci + 1 in this notebook).
        """
        super(ThreeLayerAttention, self).__init__()

        self.input_dim = input_dim  # Token width including the bias column
        self.query_dim = query_dim  # Query vector dimension
        self.key_dim = key_dim  # Key vector dimension
        self.seq_length = seq_length  # Number of tokens (loci + environment)

        # Width of projected tokens before forward() appends the bias column. The
        # attention matrices have input_dim rows, so this must be input_dim - 1.
        embed_dim = input_dim - 1

        # NOTE: parameter creation order fixes the parameters()/state_dict order used by
        # the optimizer and by saved checkpoints; keep it unchanged.
        # Layer 1 query, key and value matrices
        self.query_matrix_1 = nn.Parameter(torch.empty(input_dim, query_dim))
        self.key_matrix_1 = nn.Parameter(torch.empty(input_dim, key_dim))
        self.value_matrix_1 = nn.Parameter(torch.empty(input_dim, input_dim))

        # Layer 2 query, key and value matrices
        self.query_matrix_2 = nn.Parameter(torch.empty(input_dim, query_dim))
        self.key_matrix_2 = nn.Parameter(torch.empty(input_dim, key_dim))
        self.value_matrix_2 = nn.Parameter(torch.empty(input_dim, input_dim))

        # Layer 3 query, key and value matrices
        self.query_matrix_3 = nn.Parameter(torch.empty(input_dim, query_dim))
        self.key_matrix_3 = nn.Parameter(torch.empty(input_dim, key_dim))
        self.value_matrix_3 = nn.Parameter(torch.empty(input_dim, input_dim))

        # Learnable projection of each locus row to embed_dim features
        self.random_matrix = nn.Parameter(torch.empty(seq_length - 1, embed_dim))

        # Learnable read-out weights over every (token, feature) of the last layer
        self.coeffs_attended = nn.Parameter(torch.empty(seq_length, input_dim))

        # Learnable scalar bias. torch.randn consumes RNG state, so it is kept to leave
        # seeded initialisations unchanged, even though init_parameters overwrites it.
        self.offset = nn.Parameter(torch.randn(1))

        self.init_parameters()

    def init_parameters(self):
        """
        Initialise every learnable parameter from N(0, 0.04^2).
        """
        init_scale = 0.04

        for param in [self.query_matrix_1, self.key_matrix_1, self.value_matrix_1,
                      self.query_matrix_2, self.key_matrix_2, self.value_matrix_2,
                      self.query_matrix_3, self.key_matrix_3, self.value_matrix_3,
                      self.random_matrix, self.coeffs_attended, self.offset]:
            init.normal_(param, std=init_scale)

    @staticmethod
    def _self_attention(h, query_matrix, key_matrix, value_matrix):
        """Unscaled single-head self-attention: softmax(Q K^T, over keys) V."""
        query = torch.matmul(h, query_matrix)
        key = torch.matmul(h, key_matrix)
        value = torch.matmul(h, value_matrix)
        scores = torch.matmul(query, key.transpose(1, 2)).softmax(dim=-1)
        return torch.matmul(scores, value)

    def forward(self, x, envs_one_hot):
        """
        Forward pass through the three attention layers.

        Args:
            x (Tensor): Shape (batch, seq_length - 1, seq_length - 1). In this notebook
                it is a diagonal matrix per sample holding each locus' +/-1 genotype.
            envs_one_hot (Tensor): Environment embedding of shape (batch, input_dim - 1).

        Returns:
            Tensor: Predictions of shape (batch,).
        """
        # Project every locus row to embed_dim features: (batch, seq_length - 1, embed_dim)
        y = torch.matmul(x, self.random_matrix)

        # Append the environment as one extra token: (batch, seq_length, embed_dim)
        z = torch.cat((y, envs_one_hot.unsqueeze(1)), dim=1)

        # Append a constant-1 bias column, created where z lives (no global `device`).
        bias = torch.ones(z.shape[0], z.shape[1], 1, device=z.device, dtype=z.dtype)
        z = torch.cat((z, bias), dim=2)

        h = self._self_attention(z, self.query_matrix_1, self.key_matrix_1, self.value_matrix_1)
        h = self._self_attention(h, self.query_matrix_2, self.key_matrix_2, self.value_matrix_2)
        h = self._self_attention(h, self.query_matrix_3, self.key_matrix_3, self.value_matrix_3)

        # Weighted sum over all tokens and features, plus the learned offset
        return torch.einsum("bij,ij->b", h, self.coeffs_attended) + self.offset


### Define the zero-padded one-hot embedding of the temperature

In [5]:
def one_hot_temp_embeddings(temps, low_dim=None, seed=42):
    """
    Encode each temperature as a one-hot vector over the 8 measured temperatures,
    zero-padded on the right to width `low_dim`.

    Each temperature's slot comes from a fixed pseudo-random permutation of the
    temperature list. The permutation is drawn from a private ``random.Random(seed)``:
    the mapping is reproducible, and the global ``random`` state is left untouched.

    Args:
        temps: Iterable of temperature values (1D tensor, ndarray or list of floats).
        low_dim (int, optional): Output width. Defaults to the notebook-level
            ``low_dim`` if defined (30 in this notebook), otherwise 30.
        seed (int): Seed for the temperature-to-slot permutation (default 42).

    Returns:
        Tensor: float32 tensor of shape (len(temps), low_dim). Rows for temperatures
        not in the known list are all zeros.

    Raises:
        ValueError: If low_dim is smaller than the number of known temperatures.
    """
    unique_temps = [23.0, 25.0, 27.0, 29.0, 31.0, 33.0, 35.0, 37.0]

    if low_dim is None:
        # Keep in sync with the width the model is built with
        low_dim = globals().get("low_dim", 30)
    if low_dim < len(unique_temps):
        raise ValueError(f"low_dim ({low_dim}) must be >= {len(unique_temps)}")

    # Same permutation as `random.seed(seed); random.sample(...)`, without reseeding
    # the global RNG on every call.
    shuffled_temps = random.Random(seed).sample(unique_temps, len(unique_temps))
    temp_to_index = {temp: idx for idx, temp in enumerate(shuffled_temps)}

    # Allocating the full width directly equals one-hot columns plus zero padding
    embeddings = torch.zeros((len(temps), low_dim))
    for row, temp in enumerate(temps):
        col = temp_to_index.get(float(temp))  # float() accepts tensors and plain numbers
        if col is not None:
            embeddings[row, col] = 1.0

    return embeddings


### Training on all temperatures combined, using only $m$ training samples at $T = 23\,^\circ\mathrm{C}$

In [None]:
# Output folder for checkpoints and R² logs of the 23 °C transfer-learning runs
save_dir = "./transfer_learning_synthetic_temps/23C"
# Test-R² results file; a copy left over from a previous run is deleted below
filename2 = f"{save_dir}/test_r2_score_vs_epsilon.txt"

if not os.path.exists(save_dir):
    os.makedirs(save_dir)

if os.path.exists(filename2):
    os.remove(filename2)

def write_to_file(filename, *args):
    """
    Append one line to `filename`: every argument converted with str(), joined by single spaces.

    Args:
        filename (str): Path of the file to append to (created if missing).
        *args: Values written on the line, in order.
    """
    line = " ".join(str(value) for value in args)
    with open(filename, "a") as handle:
        print(line, file=handle)  # print appends the trailing newline

def generate_noise(shape, mean, std_dev):
    """
    Draw Gaussian noise from NumPy's global RNG.

    Args:
        shape (int or tuple): Output shape.
        mean (float): Mean of the distribution.
        std_dev (float): Standard deviation of the distribution.

    Returns:
        np.ndarray: Samples of the requested shape.
    """
    samples = np.random.normal(loc=mean, scale=std_dev, size=shape)
    return samples

def generate_interactions(geno_data, order, num_combinations):
    """
    Build `num_combinations` distinct `order`-way interaction columns.

    Each column is the element-wise product of `order` randomly chosen loci.
    Combinations are drawn with NumPy's global RNG and rejected when already used.

    Args:
        geno_data (np.ndarray): Genotype data, shape (n_samples, n_loci).
        order (int): Interaction order (e.g. 2 for pairwise, 4 for fourth-order).
        num_combinations (int): Number of distinct combinations to generate.

    Returns:
        np.ndarray: Interaction terms of shape (n_samples, num_combinations).

    Raises:
        ValueError: If num_combinations exceeds C(n_loci, order). The rejection loop
            could never finish in that case.
    """
    n_samples, n_loci = geno_data.shape

    # Guard against an infinite loop; this check does not consume RNG state, so
    # valid calls still produce exactly the same columns.
    available = math.comb(n_loci, order)
    if num_combinations > available:
        raise ValueError(
            f"requested {num_combinations} unique {order}-way combinations, "
            f"but only {available} exist for {n_loci} loci"
        )

    interaction_terms = np.empty((n_samples, num_combinations))
    seen = set()  # Combinations already used, to keep columns unique

    filled = 0
    while filled < num_combinations:
        # Sorted tuple so the same loci set is recognised regardless of draw order
        loci_indices = tuple(sorted(np.random.choice(n_loci, size=order, replace=False)))
        if loci_indices in seen:
            continue
        seen.add(loci_indices)
        interaction_terms[:, filled] = np.prod(geno_data[:, list(loci_indices)], axis=1)
        filled += 1

    return interaction_terms

# Constants for the synthetic fitness model
mean1 = 0.5   # Mean of the Normal distribution every effect coefficient is drawn from
std1 = 0.5    # Standard deviation of that Normal distribution
scale = 1e-2  # Overall scale applied to every fitness contribution
t0 = 30.0     # Temperature at which the (T - t0) polynomial is centred

def calculate_fitness_with_interactions(geno_data, temps, num_loci, e, order):
    """
    Simulate noisy, temperature-dependent fitness for every genotype.

    For each temperature T, with dT = T - t0, the clean fitness is

        1 - [ e * scale * G @ (a2*dT**2 + a1*dT + a0) / num_loci
              + (1 - e) * scale * I @ (b2*dT**2 + b1*dT + b0) / n_interactions
              + scale * offset ]

    Here G is `geno_data`, and I holds `num_loci` random `order`-way interaction
    products of G (so n_interactions == num_loci). Every coefficient vector is drawn
    from N(mean1, std1), and offset from N(0, 1). For each temperature, Gaussian noise
    with standard deviation 0.2 * std(clean fitness at that temperature) is then added.

    The global NumPy RNG is reseeded with 0 on entry, so the result is deterministic
    for a given input.

    Args:
        geno_data (np.ndarray): Genotype data of shape (n_samples, num_loci).
        temps (np.ndarray): 1D array of temperatures.
        num_loci (int): Number of loci; also the number of interaction terms.
        e (float): Weight of the additive (per-locus) part; 1 - e weights the interactions.
        order (int): Interaction order (e.g. 4 for fourth-order interactions).

    Returns:
        np.ndarray: Noisy fitness of shape (n_samples, len(temps)).
    """
    np.random.seed(0)
    n_temps = temps.shape[0]

    # The draw order below determines the synthetic dataset — do not reorder.
    lin_c0 = np.random.normal(mean1, std1, num_loci)  # constant-in-T coefficients
    lin_c1 = np.random.normal(mean1, std1, num_loci)  # linear-in-dT coefficients
    lin_c2 = np.random.normal(mean1, std1, num_loci)  # quadratic-in-dT coefficients

    inter = generate_interactions(geno_data, order, num_loci)
    n_inter = inter.shape[1]
    int_c0 = np.random.normal(mean1, std1, n_inter)
    int_c1 = np.random.normal(mean1, std1, n_inter)
    int_c2 = np.random.normal(mean1, std1, n_inter)

    offset = np.random.normal(0, 1, 1)

    # Noise-free fitness, one column per temperature
    clean = np.zeros((geno_data.shape[0], n_temps))
    for col, temp in enumerate(temps):
        dt = temp - t0
        additive_part = scale * np.dot(geno_data, lin_c2 * dt ** 2 + lin_c1 * dt + lin_c0) / num_loci
        inter_part = scale * np.dot(inter, int_c2 * dt ** 2 + int_c1 * dt + int_c0) / n_inter
        # Sign flip so that higher fitness corresponds to larger values
        clean[:, col] = -1.0 * (e * additive_part + (1 - e) * inter_part + scale * offset) + 1.0

    # Per-temperature noise at 20% of that column's standard deviation
    noisy = np.zeros(clean.shape)
    for col in range(n_temps):
        sigma = 0.2 * np.mean(np.var(clean[:, col]) ** 0.5)
        noisy[:, col] = clean[:, col] + generate_noise(clean.shape[0], 0.0, sigma)

    return noisy

# Weight of the additive (per-locus) part of fitness; 1 - e weights the interactions
e = 0.3

# Balanced loci in -1/+1 form, then noisy synthetic fitness with 4th-order interactions
genotype = pick_loci(num_loci)
fitness_with_noise = calculate_fitness_with_interactions(genotype, temps, num_loci, e, 4)

# Model hyperparameters
low_dim = 30                 # Width of the zero-padded one-hot temperature embedding
seq_length = num_loci + 1    # One token per locus plus one environment token
input_dim = low_dim + 1      # Token width once the constant bias column is appended
query_dim = key_dim = input_dim
hidden_dim = 32              # Hidden-layer size (not used by ThreeLayerAttention)

# Use the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Column index of the temperature whose training labels are limited to m samples (23 °C)
temp_indices = [i for i, temp in enumerate(temps) if temp in [23.0]]

# First split: Training and Testing datasets
X_train, X_test, y_train, y_test = train_test_split(genotype, fitness_with_noise, test_size=0.15, random_state=42)

# Un-normalised test targets. `y_test` is rebuilt from this for every m, so it is
# normalised exactly once. Re-normalising y_test in place would compound across the m loop.
y_test_raw = y_test.copy()

# Number of 23 °C training samples m, log-spaced from 10 up to all of X_train
m_values = np.logspace(1, np.log10(len(X_train)), num=5, dtype=int)

# Mini-batch layout (independent of m): each batch holds num_elements samples per
# environment, and batch position j * num_elements + k is trained on environment j.
num_temps = temps.shape[0]  # Number of environmental conditions
num_elements = int(64 / num_temps)  # Samples per environment per batch
batch_size = num_temps * num_elements  # Total batch size
chunk_size = 100  # Chunk size for validation inference
num_epochs = 500  # Number of training epochs
batch_envs = temps.repeat(num_elements)  # Environment of each batch position

for m in m_values:

    # Per-epoch validation R² log for this m, started fresh
    filename = f"{save_dir}/validation_r2.txt"
    if os.path.exists(filename):
        os.remove(filename)

    # Hide every 23 °C training label except those of the first m samples
    y_train_modified = y_train.copy()
    for idx in temp_indices:
        y_train_modified[m:, idx] = np.nan

    # Second split: Split training data further into training and validation sets
    X_train2, X_val, y_train2, y_val = train_test_split(X_train, y_train_modified, test_size=0.15, random_state=42)

    # Per-environment normalisation statistics from the (NaN-masked) training targets
    mean_values = np.nanmean(y_train2, axis=0)
    std_values = np.nanstd(y_train2, axis=0)
    y_train2 = (y_train2 - mean_values) / std_values
    y_val = (y_val - mean_values) / std_values
    y_test = (y_test_raw - mean_values) / std_values

    X_train_tens = torch.tensor(X_train2).float()
    y_train_tens = torch.tensor(np.array(y_train2)).float()

    # Seed torch so model initialisation and epoch shuffles are reproducible per m
    torch.manual_seed(0)
    attention_layer = ThreeLayerAttention(input_dim, query_dim, key_dim, seq_length).to(device)

    loss_function = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(attention_layer.parameters(), lr=0.001)
    num_batches = X_train_tens.size(0) // batch_size

    ### **TRAINING LOOP**
    for epoch in range(num_epochs):

        perm = torch.randperm(X_train_tens.size(0))
        train_input_shuffled = X_train_tens[perm]
        train_target_shuffled = y_train_tens[perm].T  # One row per environment

        for i in range(num_batches):
            start_idx = i * batch_size
            end_idx = start_idx + batch_size

            # The target for batch position j * num_elements + k comes from environment j
            mini_batch_target = torch.cat([
                row[start_idx + j * num_elements:start_idx + (j + 1) * num_elements]
                for j, row in enumerate(train_target_shuffled)
            ])

            # Drop positions whose target is NaN (hidden 23 °C labels)
            keep = ~torch.isnan(mini_batch_target)
            if not bool(keep.any()):
                continue  # An empty batch would give a NaN loss
            mini_batch_input = train_input_shuffled[start_idx:end_idx][keep].to(device)
            mini_batch_target = mini_batch_target[keep].to(device)

            envs = torch.tensor(batch_envs[keep.numpy()])
            envs_one_hot = one_hot_temp_embeddings(envs).to(device)

            # Each sample becomes a diagonal (num_loci x num_loci) matrix of its genotypes
            one_hot_mini_batch_input = torch.diag_embed(mini_batch_input)

            train_output = attention_layer(one_hot_mini_batch_input, envs_one_hot)
            train_loss = loss_function(train_output, mini_batch_target)

            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

        # Save the model parameters after each epoch
        model_state_path = os.path.join(save_dir, f"epoch_{epoch}.pt")
        torch.save(attention_layer.state_dict(), model_state_path)

        ### **VALIDATION PHASE**
        pred_parts = []
        target_parts = []
        with torch.no_grad():
            for env in range(num_temps):
                y_val_env = y_val[:, env]
                has_label = ~np.isnan(y_val_env)
                X_val_tens = torch.tensor(np.array(X_val[has_label])).float().to(device)
                target_parts.append(y_val_env[has_label])

                # Process validation data in chunks to bound memory use
                for start in range(0, len(X_val_tens), chunk_size):
                    chunk = X_val_tens[start:start + chunk_size]
                    envs = torch.tensor([temps[env]] * len(chunk))
                    envs_one_hot = one_hot_temp_embeddings(envs).to(device)
                    pred_parts.append(attention_layer(torch.diag_embed(chunk), envs_one_hot))

        y_pred = torch.cat(pred_parts) if pred_parts else torch.empty(0, device=device)
        y_val_all = np.concatenate(target_parts)

        val_r_squared = r2_score(y_val_all, y_pred.cpu().numpy())
        write_to_file(filename, low_dim, epoch, val_r_squared)
        torch.cuda.empty_cache()  # Free up memory

    ### **MODEL SELECTION: CHOOSE BEST EPOCH BASED ON VALIDATION PERFORMANCE**
    data = pd.read_csv(filename, sep=r'\s+', header=None)  # Raw string: '\s' is not a valid escape
    max_row_index = data[2].idxmax()  # First epoch with the highest R²
    max_row = data.loc[max_row_index]
    best_epoch = int(max_row[1])

    # Copy the best-performing checkpoint to a dedicated file
    model_path = f"{save_dir}/epoch_{best_epoch}.pt"
    shutil.copyfile(model_path, f"{save_dir}/best_model_num_data_{m}.pt")