## Import packages

In [1]:
import itertools
import os
import pickle
import random
import re
import shutil
from itertools import combinations

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.model_selection import train_test_split

In [2]:
# Load the genotype matrix from disk. pick_loci() below treats its rows as
# segregants (samples) and its columns as loci.
geno_data = np.load("merged_geno_data.npy")

### Genotype data processing

In [3]:
# Temperature conditions for which synthetic fitness values are generated
# (units are not stated in this file; presumably degrees Celsius — TODO confirm)
temps = np.array([23.0, 25.0, 27.0, 29.0, 31.0, 33.0, 35.0, 37.0])

def pick_loci(num_loci):
    """
    Build a genotype matrix from the first loci whose alleles are roughly balanced.

    The rows of the global ``geno_data`` are permuted with a fixed seed, the
    candidate loci stored in ``ind_loci_list_3.npy`` are visited in ascending
    index order, and a locus is kept when the mean of its states (after mapping
    x -> 2x - 1) is within 0.05 of zero. The scan stops as soon as ``num_loci``
    loci have been kept.

    Args:
        num_loci (int): The number of loci to select.

    Returns:
        ndarray: Matrix of shape (num_segregants, n_selected) holding
                 ``2 * state - 1`` for the kept loci. ``n_selected`` is smaller
                 than ``num_loci`` when too few candidates pass the filter.
    """
    # Permute the segregants reproducibly (this reseeds the global `random` module)
    row_order = list(range(geno_data.shape[0]))
    random.seed(0)
    random.shuffle(row_order)
    shuffled_geno = geno_data[row_order]

    # Candidate loci, visited from the lowest index upwards
    candidate_loci = np.sort(np.load("ind_loci_list_3.npy"))

    kept_loci = []
    for locus in candidate_loci:
        mean_state = (2.0 * shuffled_geno[:, locus] - 1.0).mean()

        # Keep only loci whose two alleles are about equally frequent
        if abs(mean_state) < 0.05:
            kept_loci.append(locus)

        # Checked on every iteration, exactly like the kept-locus counter it guards
        if len(kept_loci) > num_loci - 1:
            break

    return 2.0 * shuffled_geno[:, np.array(kept_loci)] - 1.0

# Number of loci that pick_loci() is asked to select; the training and
# evaluation cells also derive the model's sequence length from it.
num_loci = 100

### Attention layer class in PyTorch

In [4]:
class ThreeLayerAttention(nn.Module):
    """
    Three stacked single-head attention layers over locus tokens plus one
    environment token, followed by a learned linear read-out.

    The forward pass only works when ``input_dim == low_dim + 1`` (a constant
    bias feature is appended to every token) and ``query_dim == key_dim``
    (queries are multiplied with transposed keys).
    """

    def __init__(self, input_dim, query_dim, key_dim, seq_length, low_dim, hidden_dim):
        """
        Args:
            input_dim (int): Feature size of each token entering an attention layer.
            query_dim (int): Dimension of query vectors.
            key_dim (int): Dimension of key vectors.
            seq_length (int): Number of tokens seen by the attention layers
                (locus tokens plus the single environment token).
            low_dim (int): Size of the locus embeddings and of the environmental
                embedding vector.
            hidden_dim (int): Hidden layer size for the per-feature MLPs.
        """
        super().__init__()

        self.input_dim = input_dim
        self.query_dim = query_dim
        self.key_dim = key_dim
        self.seq_length = seq_length
        self.low_dim = low_dim
        self.hidden_dim = hidden_dim

        # Query / key / value projections for the three attention layers.
        # Attribute names are part of the saved state_dict and must not change.
        self.query_matrix_1 = nn.Parameter(torch.empty(input_dim, query_dim))
        self.key_matrix_1 = nn.Parameter(torch.empty(input_dim, key_dim))
        self.value_matrix_1 = nn.Parameter(torch.empty(input_dim, input_dim))

        self.query_matrix_2 = nn.Parameter(torch.empty(input_dim, query_dim))
        self.key_matrix_2 = nn.Parameter(torch.empty(input_dim, key_dim))
        self.value_matrix_2 = nn.Parameter(torch.empty(input_dim, input_dim))

        self.query_matrix_3 = nn.Parameter(torch.empty(input_dim, query_dim))
        self.key_matrix_3 = nn.Parameter(torch.empty(input_dim, key_dim))
        self.value_matrix_3 = nn.Parameter(torch.empty(input_dim, input_dim))

        # Learnable (randomly initialised) projection of the input onto low_dim features
        self.random_matrix = nn.Parameter(torch.empty(seq_length - 1, low_dim))

        # One scalar-to-scalar MLP per environmental embedding feature
        self.mlps = nn.ModuleList([self._make_mlp(hidden_dim) for _ in range(low_dim)])

        # Linear read-out over all attended tokens, plus a scalar offset
        self.coeffs_attended = nn.Parameter(torch.empty(seq_length, input_dim))
        self.offset = nn.Parameter(torch.randn(1))

        self.init_parameters()

    def _make_mlp(self, hidden_dim):
        """Utility function to create a simple MLP with one hidden layer."""
        return nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def init_parameters(self):
        """Draw all raw parameters from N(0, 0.04^2) and Kaiming-initialise the MLPs."""
        init_scale = 0.04

        for param in [self.query_matrix_1, self.key_matrix_1, self.value_matrix_1,
                      self.query_matrix_2, self.key_matrix_2, self.value_matrix_2,
                      self.query_matrix_3, self.key_matrix_3, self.value_matrix_3,
                      self.random_matrix, self.coeffs_attended, self.offset]:
            init.normal_(param, std=init_scale)

        for mlp in self.mlps:
            for layer in mlp:
                if isinstance(layer, nn.Linear):
                    init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                    init.zeros_(layer.bias)

    def process_envs_fourier(self, envs_fourier):
        """
        Process each element of envs_fourier using its respective MLP.

        Args:
            envs_fourier (Tensor): Environmental embeddings of shape (batch_size, low_dim).

        Returns:
            Tensor: Processed environmental features of shape (batch_size, low_dim).
        """
        columns = [
            self.mlps[i](envs_fourier[:, i].unsqueeze(1))  # (batch_size, 1) in and out
            for i in range(self.low_dim)
        ]
        return torch.cat(columns, dim=1)

    @staticmethod
    def _attend(tokens, query_matrix, key_matrix, value_matrix):
        """Apply one unscaled softmax self-attention layer to ``tokens``."""
        query = torch.matmul(tokens, query_matrix)
        key = torch.matmul(tokens, key_matrix)
        value = torch.matmul(tokens, value_matrix)
        weights = torch.matmul(query, key.transpose(1, 2)).softmax(dim=-1)
        return torch.matmul(weights, value)

    def forward(self, x, envs_fourier):
        """
        Forward pass through the three-layer attention model.

        Args:
            x (Tensor): Input of shape (batch_size, seq_length-1, seq_length-1);
                its last axis is multiplied with ``random_matrix``.
            envs_fourier (Tensor): Environmental embeddings of shape (batch_size, low_dim).

        Returns:
            Tensor: Final output prediction of shape (batch_size,).
        """
        # Environment token: one MLP output per embedding feature
        mlp_output = self.process_envs_fourier(envs_fourier)

        # Locus tokens: project the input onto low_dim features
        y = torch.matmul(x, self.random_matrix)

        # Append the environment token to the sequence of locus tokens
        z = torch.cat((y, mlp_output.unsqueeze(1)), dim=1)

        # Append a constant-one feature for bias handling. It is allocated on
        # the same device as the tokens so that no global `device` is needed.
        ones = torch.ones(z.shape[0], z.shape[1], 1, device=z.device, dtype=z.dtype)
        z = torch.cat((z, ones), dim=2)

        attended = z
        for query_matrix, key_matrix, value_matrix in (
            (self.query_matrix_1, self.key_matrix_1, self.value_matrix_1),
            (self.query_matrix_2, self.key_matrix_2, self.value_matrix_2),
            (self.query_matrix_3, self.key_matrix_3, self.value_matrix_3),
        ):
            attended = self._attend(attended, query_matrix, key_matrix, value_matrix)

        # Weighted sum over all tokens and features, plus the scalar offset
        return torch.einsum("bij,ij->b", attended, self.coeffs_attended) + self.offset

### Define function for Fourier embedding of the temperature

In [5]:
def fourier_positional_embeddings(envs, d_model):
    """
    Generate Fourier positional embeddings for a 1D tensor of positions.

    Column ``2k`` holds ``sin(p * w_k)`` and column ``2k + 1`` holds
    ``cos(p * w_k)`` with ``w_k = 100 ** (-2k / d_model)``. For an odd
    ``d_model`` the last cosine column is dropped.

    Args:
    - envs (Tensor): A 1D tensor of positions.
    - d_model (int): The dimensionality of the embeddings.

    Returns:
    - Tensor: A 2D tensor of shape [len(envs), d_model] on ``envs.device``,
      in the default floating-point dtype.
    """
    position = envs.unsqueeze(1)

    # Build the frequencies on the same device as the positions; creating them
    # on the CPU would fail as soon as `envs` lives on a GPU.
    log_base = torch.log(torch.tensor(100.0, device=envs.device))
    div_term = torch.exp(
        torch.arange(0, d_model, 2, device=envs.device).float() * (-log_base / d_model)
    )

    angles = position * div_term
    pe_sin = torch.sin(angles)
    pe_cos = torch.cos(angles)

    # Interleave: sine in the even columns, cosine in the odd columns
    pe = torch.empty((len(envs), d_model), device=envs.device)
    pe[:, 0::2] = pe_sin
    pe[:, 1::2] = pe_cos if d_model % 2 == 0 else pe_cos[:, :-1]

    return pe

### Training on all temperatures combined except 29 and 35

In [None]:
# Directory that collects checkpoints and score files for this experiment
save_dir = "./synthetic_data_multi_env_attention_interpol"

# Create it if needed; exist_ok avoids the check-then-create race
os.makedirs(save_dir, exist_ok=True)

# Output file for the test R2 scores; start from a clean slate
filename2 = f"{save_dir}/test_r2_score_vs_epsilon.txt"
try:
    os.remove(filename2)
except FileNotFoundError:
    pass  # Nothing to clean up on a first run

def write_to_file(filename, *args):
    """Append ``args`` to ``filename`` as one space-separated line."""
    line = ' '.join(str(value) for value in args)
    with open(filename, 'a') as handle:
        handle.write(line + '\n')

def generate_noise(shape, mean, std_dev):
    """Draw an array of the given shape from N(mean, std_dev^2) using NumPy's global RNG."""
    return np.random.normal(loc=mean, scale=std_dev, size=shape)

def generate_interactions(geno_data, order, num_combinations):
    """
    Generates `order`-order interaction terms for genotype data with a specified number of unique `order`-way combinations.

    Combinations are drawn with NumPy's global RNG by rejection sampling: a
    draw that repeats an earlier set of loci is discarded and redrawn.

    Args:
        geno_data (np.array): The original genotype data, shape = (n_samples, n_loci).
        order (int): The interaction order (e.g., 2 for pairwise, 4 for fourth-order).
        num_combinations (int): The number of unique random `order` combinations to generate.

    Returns:
        np.array: Interaction terms, shape = (n_samples, num_combinations).

    Raises:
        ValueError: If more unique combinations are requested than exist.
    """
    from math import comb  # Local import keeps this helper self-contained

    n_samples, n_loci = geno_data.shape

    # Without this guard the rejection loop below would never terminate.
    available = comb(n_loci, order)
    if num_combinations > available:
        raise ValueError(
            f"Requested {num_combinations} unique combinations of order {order}, "
            f"but only {available} exist for {n_loci} loci."
        )

    interaction_terms = np.empty((n_samples, num_combinations))

    # Sorted index tuples already used, to ensure uniqueness
    seen = set()

    filled = 0
    while filled < num_combinations:
        # Sample `order` distinct loci; sorting makes the tuple order-independent
        loci_indices = tuple(sorted(np.random.choice(n_loci, size=order, replace=False)))
        if loci_indices in seen:
            continue

        seen.add(loci_indices)
        # The interaction term is the product of the selected loci's values
        interaction_terms[:, filled] = np.prod(geno_data[:, list(loci_indices)], axis=1)
        filled += 1

    return interaction_terms

# Constants for generating synthetic fitness data; they are read as globals
# by calculate_fitness_with_interactions()
mean1 = 0.5  # Mean of the normally distributed effect coefficients
std1 = 0.5  # Standard deviation of those coefficients
scale = 1e-2  # Factor applied to the additive, interaction and offset terms
t0 = 30.0  # Reference temperature

def calculate_fitness_with_interactions(geno_data, temps, num_loci, e, order):
    """
    Simulate noisy fitness values for every genotype at every temperature.

    Each locus and each interaction term gets a constant, a linear and a
    quadratic coefficient; its effect at temperature T is
    ``c2 * (T - t0)**2 + c1 * (T - t0) + c0``. The additive and interaction
    contributions are mixed with weights ``e`` and ``1 - e``, a shared offset
    is added, and the result ``y`` is stored as ``1 - y``. Gaussian noise with
    a standard deviation of 20% of each column's standard deviation is added.

    NumPy's global RNG is reseeded with 0 on every call, so the coefficients do
    not depend on ``e``.

    Args:
        geno_data (np.array): Genotype data of shape (n_samples, num_loci).
        temps (np.array): Array of environmental temperatures.
        num_loci (int): Number of loci; also used as the number of interaction terms.
        e (float): Weight of the additive part (e=1 → all additive, e=0 → all interactions).
        order (int): Interaction order (e.g., 4 for fourth-order interactions).

    Returns:
        np.array: Fitness values with added noise, shape = (n_samples, len(temps)).
    """
    np.random.seed(0)
    n_samples = geno_data.shape[0]
    n_temps = temps.shape[0]

    # The order of the random draws below determines the generated data set.
    additive_const = np.random.normal(mean1, std1, num_loci)
    additive_lin = np.random.normal(mean1, std1, num_loci)
    additive_quad = np.random.normal(mean1, std1, num_loci)

    interaction_data = generate_interactions(geno_data, order, num_loci)
    n_interactions = interaction_data.shape[1]

    epistatic_const = np.random.normal(mean1, std1, n_interactions)
    epistatic_lin = np.random.normal(mean1, std1, n_interactions)
    epistatic_quad = np.random.normal(mean1, std1, n_interactions)

    offset = np.random.normal(0, 1, 1)

    # Noise-free fitness, one column per temperature
    fitness = np.zeros((n_samples, n_temps))
    for col, temp in enumerate(temps):
        dt = temp - t0
        additive_weights = additive_quad * dt ** 2 + additive_lin * dt + additive_const
        epistatic_weights = epistatic_quad * dt ** 2 + epistatic_lin * dt + epistatic_const

        additive_part = scale * np.dot(geno_data, additive_weights) / num_loci
        epistatic_part = scale * np.dot(interaction_data, epistatic_weights) / n_interactions

        y = e * additive_part + (1 - e) * epistatic_part + scale * offset
        fitness[:, col] = 1.0 - y

    # Column-wise noise, drawn one temperature at a time
    noisy_fitness = np.zeros(fitness.shape)
    for col in range(n_temps):
        noise_std = 0.2 * np.var(fitness[:, col]) ** 0.5
        noisy_fitness[:, col] = fitness[:, col] + generate_noise(n_samples, 0.0, noise_std)

    return noisy_fitness

# Mixing weights passed to calculate_fitness_with_interactions
# (e = 1 → all additive effects, e = 0 → all interaction effects)
e_list = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

# Genotype matrix returned by pick_loci(), with values mapped to [-1, +1]
genotype = pick_loci(num_loci)

for e in e_list:
    # Generate synthetic fitness data with noise
    fitness_with_noise = calculate_fitness_with_interactions(genotype, temps, num_loci, e, 4)

    # Keep only the training temperatures: columns 3 and 6 of `temps`
    # (29 and 35) are held out for the interpolation test
    fitness_reduced = fitness_with_noise[:, [0, 1, 2, 4, 5, 7]]

    # Model hyperparameters
    low_dim = 30
    seq_length = num_loci + 1
    input_dim = low_dim + 1
    query_dim = low_dim + 1
    key_dim = low_dim + 1
    hidden_dim = 32

    # Per-run log of validation R² scores (rewritten for every value of e)
    filename = f"{save_dir}/validation_r2.txt"
    if os.path.exists(filename):
        os.remove(filename)

    # Determine computing device (CUDA if available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # Split data into training, validation, and test sets
    X_train, X_test, y_train, y_test = train_test_split(genotype, fitness_reduced, test_size=0.15, random_state=42)
    X_train2, X_val, y_train2, y_val = train_test_split(X_train, y_train, test_size=0.15, random_state=42)

    # Normalize fitness values based on training mean and standard deviation
    mean_values = np.nanmean(y_train2, axis=0)
    std_values = np.nanstd(y_train2, axis=0)
    y_train2 = (y_train2 - mean_values) / std_values
    y_val = (y_val - mean_values) / std_values
    y_test = (y_test - mean_values) / std_values

    # Convert data to PyTorch tensors (validation inputs are converted once)
    X_train_tens = torch.tensor(X_train2).float()
    y_train_tens = torch.tensor(np.array(y_train2)).float()
    X_val_tens = torch.tensor(np.array(X_val)).float().to(device)

    # Initialize the attention model; the constructor needs all six sizes
    attention_layer = ThreeLayerAttention(
        input_dim, query_dim, key_dim, seq_length, low_dim, hidden_dim
    ).to(device)

    # Define loss function and optimizer
    loss_function = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(attention_layer.parameters(), lr=0.001)

    # Training temperatures and batching
    temps2 = np.array([23.0, 25.0, 27.0, 31.0, 33.0, 37.0])
    num_temps2 = temps2.shape[0]
    num_elements = int(64 / num_temps2)  # Samples per environment in one batch
    batch_size = num_temps2 * num_elements
    chunk_size = 100  # Chunk size for validation processing
    num_epochs = 500
    num_batches = X_train_tens.size(0) // batch_size  # Leftover samples are skipped each epoch

    # Every batch uses the same environment layout (num_elements consecutive
    # samples per temperature), so the embeddings are computed once.
    train_envs = torch.tensor(temps2.repeat(num_elements))
    train_envs_fourier = fourier_positional_embeddings(train_envs, low_dim).to(device)

    # Index of the diagonal used to place each locus value in its own token
    diag_idx = torch.arange(num_loci, device=device)

    ### **TRAINING LOOP**
    for epoch in range(num_epochs):
        # Shuffle training data before each epoch
        perm = torch.randperm(X_train_tens.size(0))
        train_input_shuffled = X_train_tens[perm]
        train_target_shuffled = y_train_tens[perm].T  # One row per environment

        for i in range(num_batches):
            start_idx = i * batch_size
            end_idx = start_idx + batch_size

            # Sample k of the batch is paired with environment k // num_elements
            mini_batch_target = torch.cat([
                env_row[start_idx + j * num_elements:start_idx + (j + 1) * num_elements]
                for j, env_row in enumerate(train_target_shuffled)
            ]).to(device)
            mini_batch_input = train_input_shuffled[start_idx:end_idx].to(device)

            # Place each locus value on the diagonal of a (num_loci, num_loci) matrix
            one_hot_mini_batch_input = torch.zeros((mini_batch_input.shape[0], num_loci, num_loci), device=device)
            one_hot_mini_batch_input[:, diag_idx, diag_idx] = mini_batch_input

            # Forward pass and loss computation
            train_output = attention_layer(one_hot_mini_batch_input, train_envs_fourier)
            train_loss = loss_function(train_output, mini_batch_target)

            # Backpropagation and optimization
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

        # Save model state after each epoch
        model_state_path = os.path.join(save_dir, f"epoch_{epoch}.pt")
        torch.save(attention_layer.state_dict(), model_state_path)

        ### **VALIDATION PHASE**
        val_predictions = []
        val_targets = []

        for env in range(num_temps2):
            # Process validation data in chunks to limit memory use
            for start in range(0, len(X_val_tens), chunk_size):
                chunk = X_val_tens[start:start + chunk_size]
                chunk_size_actual = chunk.shape[0]

                envs = torch.tensor([temps2[env]] * chunk_size_actual)
                envs_fourier = fourier_positional_embeddings(envs, low_dim).to(device)

                one_hot_val_input = torch.zeros((chunk_size_actual, num_loci, num_loci), device=device)
                one_hot_val_input[:, diag_idx, diag_idx] = chunk

                with torch.no_grad():
                    val_predictions.append(attention_layer(one_hot_val_input, envs_fourier))

            val_targets.append(y_val.T[env])

        y_pred = torch.cat(val_predictions, dim=0)
        y_val_all = np.concatenate(val_targets)

        # Compute validation R² score over all training temperatures
        val_r_squared = r2_score(y_val_all, y_pred.cpu())

        # Save validation results
        write_to_file(filename, low_dim, epoch, val_r_squared)
        torch.cuda.empty_cache()  # Free up memory

    ### **MODEL SELECTION: CHOOSE BEST EPOCH BASED ON VALIDATION PERFORMANCE**
    data = pd.read_csv(filename, sep=r'\s+', header=None)
    max_row_index = data[2].idxmax()  # Row with the highest validation R²
    max_row = data.loc[max_row_index]
    best_epoch = int(max_row[1])  # Column 1 holds the epoch number

    # Copy the best-performing checkpoint to a dedicated file
    model_path = f"{save_dir}/epoch_{best_epoch}.pt"
    shutil.copyfile(model_path, f"{save_dir}/best_model_e_{e}.pt")

### Test evaluation and prediction

In [None]:
# Directory that holds the checkpoints and receives the score files
save_dir = "./synthetic_data_multi_env_attention_interpol"

# Create it if needed; exist_ok avoids the check-then-create race
os.makedirs(save_dir, exist_ok=True)

# Output file for the R² scores vs epsilon; remove any earlier copy to start fresh
filename2 = f"{save_dir}/test_r2_score_vs_epsilon.txt"
try:
    os.remove(filename2)
except FileNotFoundError:
    pass  # Nothing to clean up on a first run

def write_to_file(filename, *args):
    """
    Append one line to a file, with the given values separated by spaces.

    Args:
        filename (str): Path to the file where data should be written.
        *args: The values to be written; each is converted with ``str``.
    """
    with open(filename, 'a') as out:
        # print() joins with single spaces and ends the line with a newline
        print(*args, file=out)

def generate_noise(shape, mean, std_dev):
    """
    Draw Gaussian noise from NumPy's global random number generator.

    Args:
        shape (int or tuple): The shape of the noise array.
        mean (float): The mean of the Gaussian distribution.
        std_dev (float): The standard deviation of the Gaussian distribution.

    Returns:
        np.ndarray: A NumPy array of Gaussian noise with the specified shape.
    """
    noise = np.random.normal(loc=mean, scale=std_dev, size=shape)
    return noise

def generate_interactions(geno_data, order, num_combinations):
    """
    Generates interaction terms of the specified order for genotype data.

    Combinations are drawn with NumPy's global RNG by rejection sampling: a
    draw that repeats an earlier set of loci is discarded and redrawn.

    Args:
        geno_data (np.ndarray): Genotype data of shape (n_samples, n_loci).
        order (int): The interaction order (e.g., 2 for pairwise, 4 for fourth-order).
        num_combinations (int): Number of unique random 'order' interactions to generate.

    Returns:
        np.ndarray: Interaction terms, shape (n_samples, num_combinations).

    Raises:
        ValueError: If more unique combinations are requested than exist.
    """
    from math import comb  # Local import keeps this helper self-contained

    n_samples, n_loci = geno_data.shape

    # Without this guard the rejection loop below would never terminate.
    available = comb(n_loci, order)
    if num_combinations > available:
        raise ValueError(
            f"Requested {num_combinations} unique combinations of order {order}, "
            f"but only {available} exist for {n_loci} loci."
        )

    interaction_terms = np.empty((n_samples, num_combinations))

    # Sorted index tuples already used, to ensure uniqueness
    used_combinations = set()

    column = 0  # Next column of interaction_terms to fill
    while column < num_combinations:
        # Randomly select `order` distinct loci; sorting makes the tuple order-independent
        loci_indices = tuple(sorted(np.random.choice(n_loci, size=order, replace=False)))
        if loci_indices in used_combinations:
            continue  # Duplicate draw: sample again

        used_combinations.add(loci_indices)
        # The interaction term is the product of the selected loci's values
        interaction_terms[:, column] = np.prod(geno_data[:, list(loci_indices)], axis=1)
        column += 1

    return interaction_terms

# Parameters for the fitness calculation; they are read as globals by
# calculate_fitness_with_interactions()
mean1 = 0.5  # Mean of the normally distributed effect coefficients
std1 = 0.5   # Standard deviation of those coefficients
scale = 1e-2  # Factor applied to the additive, interaction and offset terms
t0 = 30.0  # Reference temperature subtracted from each temperature

def calculate_fitness_with_interactions(geno_data, temps, num_combinations, e, order):
    """
    Calculates fitness scores based on genotype data with added noise.

    Each locus and each interaction term gets a constant, a linear and a
    quadratic coefficient; its effect at temperature T is
    ``c2 * (T - t0)**2 + c1 * (T - t0) + c0``. The additive and interaction
    contributions are mixed with weights ``e`` and ``1 - e``, a shared offset
    is added, and the result ``y`` is stored as ``1 - y``. Gaussian noise with
    a standard deviation of 20% of each column's standard deviation is added.

    NumPy's global RNG is reseeded with 0 on every call, so the coefficients do
    not depend on ``e``.

    Args:
        geno_data (np.ndarray): Genotype data of shape (n_samples, n_loci).
        temps (np.ndarray): Array of environmental temperatures.
        num_combinations (int): Number of unique random interaction combinations.
        e (float): Weight of the additive part (e=1 → all additive, e=0 → all interactions).
        order (int): Interaction order (e.g., 4 for fourth-order interactions).

    Returns:
        np.ndarray: Fitness values with added noise, shape (n_samples, len(temps)).
    """
    np.random.seed(0)  # Set seed for reproducibility

    # Take the locus count from the data itself instead of a module-level
    # global, so the coefficient vectors always match geno_data's columns.
    n_samples, n_loci = geno_data.shape
    num_temps = temps.shape[0]

    # The order of the random draws below determines the generated data set.
    coeffs_const = np.random.normal(mean1, std1, n_loci)  # Constant effects
    coeffs_lin = np.random.normal(mean1, std1, n_loci)  # Linear effects
    coeffs_square = np.random.normal(mean1, std1, n_loci)  # Quadratic effects

    # Products of `order` randomly chosen loci, one column per combination
    interaction_data = generate_interactions(geno_data, order, num_combinations)
    n_interactions = interaction_data.shape[1]

    interaction_coeffs_const = np.random.normal(mean1, std1, n_interactions)
    interaction_coeffs_lin = np.random.normal(mean1, std1, n_interactions)
    interaction_coeffs_square = np.random.normal(mean1, std1, n_interactions)

    # Random offset shared by all samples and temperatures
    offset = np.random.normal(0, 1, 1)

    # Noise-free fitness, shape: (n_samples, num_temps)
    fitness = np.zeros((n_samples, num_temps))
    for i, temp in enumerate(temps):
        dt = temp - t0

        # Additive effects, normalized by the number of loci
        linear_terms = scale * np.dot(
            geno_data,
            coeffs_square * dt ** 2 + coeffs_lin * dt + coeffs_const
        ) / n_loci

        # Interaction effects, normalized by the number of interaction terms
        interaction_terms = scale * np.dot(
            interaction_data,
            interaction_coeffs_square * dt ** 2 + interaction_coeffs_lin * dt + interaction_coeffs_const
        ) / n_interactions

        y = e * linear_terms + (1 - e) * interaction_terms + scale * offset
        fitness[:, i] = -1.0 * y + 1.0

    # Add Gaussian noise column by column (the draw order is part of the data set)
    fitness_with_noise = np.zeros(fitness.shape)
    for i in range(num_temps):
        # Noise standard deviation: 20% of the column's standard deviation
        std = 0.2 * np.var(fitness[:, i]) ** 0.5
        fitness_with_noise[:, i] = fitness[:, i] + generate_noise(n_samples, 0.0, std)

    return fitness_with_noise

# Mixing weights passed to calculate_fitness_with_interactions
# (e = 1 → all additive effects, e = 0 → all interaction effects)
e_list = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

# Genotype matrix returned by pick_loci(), with values mapped to [-1, +1]
genotype = pick_loci(num_loci)

# All temperature conditions, including the two that were held out from training
temps = np.array([23.0, 25.0, 27.0, 29.0, 31.0, 33.0, 35.0, 37.0])

# Columns of `temps` that were not used for training (29 and 35), mapped to
# the column of mean_data / std_data used to rescale their predictions
held_out_columns = {3: 0, 6: 1}

# Mean and standard deviation used to rescale predictions at the held-out
# temperatures, indexed as [row of e_list, held-out column].
# NOTE(review): nothing visible in this file writes these two arrays —
# confirm which step produces them.
mean_data = np.load(f"{save_dir}/fitting_mean_fitness.npy")
std_data = np.load(f"{save_dir}/fitting_std_fitness.npy")

# Set computing device (CUDA if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Iterate over different values of e
for row, e in enumerate(e_list):

    # Generate synthetic fitness data with noise based on interactions
    fitness_with_noise = calculate_fitness_with_interactions(genotype, temps, num_loci, e, 4)

    # Model hyperparameters (must match the ones used for training)
    low_dim = 30  # Dimensionality of the Fourier embeddings
    seq_length = num_loci + 1  # Sequence length
    input_dim = low_dim + 1  # Input dimension
    query_dim = low_dim + 1  # Query vector dimension
    key_dim = low_dim + 1  # Key vector dimension
    hidden_dim = 32  # Hidden layer size for the MLPs
    chunk_size = 100  # Number of samples processed per forward pass

    # Split the dataset into training, validation, and testing sets
    X_train, X_test, y_train, y_test = train_test_split(genotype, fitness_with_noise, test_size=0.15, random_state=42)
    X_train2, X_val, y_train2, y_val = train_test_split(X_train, y_train, test_size=0.15, random_state=42)

    # Compute the mean and standard deviation for normalization
    mean_values = np.nanmean(y_train2, axis=0)
    std_values = np.nanstd(y_train2, axis=0)

    # Normalize every column except the held-out ones (index 3 = 29 and
    # index 6 = 35), which stay on the original fitness scale
    cols_to_normalize = [i for i in range(y_train2.shape[1]) if i not in held_out_columns]
    y_train2[:, cols_to_normalize] = (y_train2[:, cols_to_normalize] - mean_values[cols_to_normalize]) / std_values[cols_to_normalize]
    y_val[:, cols_to_normalize] = (y_val[:, cols_to_normalize] - mean_values[cols_to_normalize]) / std_values[cols_to_normalize]
    y_test[:, cols_to_normalize] = (y_test[:, cols_to_normalize] - mean_values[cols_to_normalize]) / std_values[cols_to_normalize]

    # Rebuild the model (the constructor needs all six sizes) and load the
    # best checkpoint saved for this value of e
    attention_layer = ThreeLayerAttention(
        input_dim, query_dim, key_dim, seq_length, low_dim, hidden_dim
    ).to(device)
    model_path = f"{save_dir}/best_model_e_{e}.pt"
    state_dict = torch.load(model_path, map_location=device)
    attention_layer.load_state_dict(state_dict)
    attention_layer.eval()  # Set model to evaluation mode

    # Test genotypes are the same for every temperature: convert them once
    X_test_tens = torch.tensor(np.array(X_test)).float().to(device)
    diag_idx = torch.arange(num_loci, device=device)

    # Iterate over all environmental temperature conditions
    for env in range(temps.shape[0]):
        y_test_env = y_test.T[env]  # Test targets for the current temperature
        chunk_predictions = []

        # Process test samples in chunks to limit memory use
        for start in range(0, len(X_test_tens), chunk_size):
            chunk = X_test_tens[start:start + chunk_size]
            chunk_size_actual = chunk.shape[0]  # The last chunk may be shorter

            # Environment embeddings for the current chunk
            envs = torch.tensor([temps[env]] * chunk_size_actual)
            envs_fourier = fourier_positional_embeddings(envs, low_dim).to(device)

            # Place each locus value on the diagonal of a (num_loci, num_loci) matrix
            one_hot_test_input = torch.zeros((chunk_size_actual, num_loci, num_loci), device=device)
            one_hot_test_input[:, diag_idx, diag_idx] = chunk

            # Perform model inference without computing gradients
            with torch.no_grad():
                i_pred = attention_layer(one_hot_test_input, envs_fourier)

            # Held-out temperatures are compared on the original fitness
            # scale, so their predictions are rescaled
            if env in held_out_columns:
                col = held_out_columns[env]
                i_pred = i_pred * std_data[row, col] + mean_data[row, col]

            chunk_predictions.append(i_pred)

        y_pred = torch.cat(chunk_predictions, dim=0)

        # Compute R² score for test predictions
        test_r_squared = r2_score(y_test_env, y_pred.cpu())

        # Save test R² score along with epsilon and temperature
        write_to_file(filename2, e, temps[env], test_r_squared)

        # Print results for monitoring
        print(temps[env], e, test_r_squared)