In [1]:
import pandas as pd
import re
import numpy as np
import random
import os
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import itertools
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error
import pickle
import torch.nn.functional as F
from itertools import combinations

In [2]:
# Load the genotype matrix from disk. pick_loci() below indexes it as
# geno_data[segregant, locus] and remaps its values with 2*x - 1, so the
# entries are presumably 0/1 allele calls -- TODO confirm against the file.
geno_data = np.load("merged_geno_data.npy")

### Pick loci from the list of independent loci

In [3]:
# Temperature conditions at which fitness is simulated. Only one temperature is
# used here; it equals the reference temperature t0 = 30.0 defined further down,
# so the (temp - t0) terms of the fitness model evaluate to zero.
temps = np.array([30.0])

def pick_loci(num_loci, geno=None, candidate_loci=None, max_abs_mean=0.05, seed=0):
    """
    Selects the first `num_loci` candidate loci whose two alleles are roughly balanced.

    Rows (segregants) are shuffled with a fixed seed, candidate loci are scanned
    in ascending index order, and a locus is kept when the mean of its
    [-1, +1]-coded genotype has absolute value below `max_abs_mean`.

    Args:
        num_loci (int): The number of loci to select.
        geno (ndarray, optional): Genotype matrix of shape (num_segregants, total_loci)
            coded as 0/1. Defaults to the notebook-level `geno_data`.
        candidate_loci (array-like of int, optional): Column indices to choose from.
            Defaults to the contents of "ind_loci_list_3.npy".
        max_abs_mean (float): Balance threshold on |mean| of the [-1, +1] coding.
        seed (int): Seed for the row shuffle.

    Returns:
        ndarray: A genotype matrix of shape (num_segregants, num_loci)
                 with values mapped from [0,1] to [-1,+1].

    Raises:
        ValueError: If fewer than `num_loci` candidates pass the balance filter.
            (Previously a narrower matrix was returned silently, which only
            surfaced later as a shape mismatch.)
    """
    if geno is None:
        geno = geno_data

    # Shuffle the segregants with a private generator. This yields the same
    # permutation as random.seed(seed); random.shuffle(...) but does not reset
    # the state of the global `random` module as a side effect.
    row_order = list(range(geno.shape[0]))
    random.Random(seed).shuffle(row_order)
    shuffled = geno[row_order]

    # Candidate (independent) loci, scanned in ascending index order
    if candidate_loci is None:
        candidate_loci = np.load("ind_loci_list_3.npy")
    candidate_loci = np.sort(np.asarray(candidate_loci))

    selected = []
    for locus in candidate_loci:
        if len(selected) >= num_loci:  # Stop once enough loci were found
            break
        # Mean of the [-1, +1]-coded column; close to zero means balanced alleles
        if abs((2.0 * shuffled[:, locus] - 1.0).mean()) < max_abs_mean:
            selected.append(locus)

    if len(selected) < num_loci:
        raise ValueError(
            f"only {len(selected)} of {len(candidate_loci)} candidate loci have "
            f"|mean| < {max_abs_mean}; cannot select {num_loci}"
        )

    # dtype=int keeps the index array valid even when nothing was selected
    return 2.0 * shuffled[:, np.array(selected, dtype=int)] - 1.0

# Number of loci to be selected by pick_loci(). The training cell further down
# also uses this value as the attention model's sequence length and as the size
# of the one-hot (diagonal) genotype encoding.
num_loci = 100


### Attention layer class in PyTorch

In [4]:
class ThreeLayerAttention(nn.Module):
    """
    Three stacked single-head self-attention layers followed by a linear readout.

    The input is first projected by a learnable matrix to `input_dim - 1`
    features per sequence position, a constant feature of ones is appended,
    the result is passed through three attention layers (softmax(Q K^T) V,
    without scaling), and the final attended values are reduced to one scalar
    per sample with learned coefficients plus a learned offset.
    """

    def __init__(self, input_dim, query_dim, key_dim, seq_length):
        """
        Args:
            input_dim (int): Feature dimension inside the attention layers,
                i.e. the projected dimension plus the appended constant feature.
            query_dim (int): Dimension of the query projections.
            key_dim (int): Dimension of the key projections. Must equal
                `query_dim`, otherwise Q K^T in `forward` is ill-defined.
            seq_length (int): Length of the input sequence.
        """
        super().__init__()
        self.input_dim = input_dim
        self.query_dim = query_dim
        self.key_dim = key_dim
        self.seq_length = seq_length

        # Learnable weight matrices for the first attention layer
        self.query_matrix_1 = nn.Parameter(torch.empty(input_dim, query_dim))
        self.key_matrix_1 = nn.Parameter(torch.empty(input_dim, key_dim))
        self.value_matrix_1 = nn.Parameter(torch.empty(input_dim, input_dim))

        # Learnable weight matrices for the second attention layer
        self.query_matrix_2 = nn.Parameter(torch.empty(input_dim, query_dim))
        self.key_matrix_2 = nn.Parameter(torch.empty(input_dim, key_dim))
        self.value_matrix_2 = nn.Parameter(torch.empty(input_dim, input_dim))

        # Learnable weight matrices for the third attention layer
        self.query_matrix_3 = nn.Parameter(torch.empty(input_dim, query_dim))
        self.key_matrix_3 = nn.Parameter(torch.empty(input_dim, key_dim))
        self.value_matrix_3 = nn.Parameter(torch.empty(input_dim, input_dim))

        # Learnable, randomly initialised projection of the input. Its width is
        # input_dim - 1 because forward() appends exactly one constant feature
        # before multiplying by the (input_dim, ...) attention matrices. This is
        # derived here instead of being read from a notebook-level `low_dim`.
        self.random_matrix = nn.Parameter(torch.empty(seq_length, input_dim - 1))

        # Learnable coefficients for attended values
        self.coeffs_attended = nn.Parameter(torch.empty(seq_length, input_dim))

        # Learnable scalar offset for output adjustment
        self.offset = nn.Parameter(torch.randn(1))

        # Initialize model parameters
        self.init_parameters()

    def init_parameters(self):
        """
        Initializes all parameters from a zero-mean normal distribution with a
        small standard deviation (this also overwrites the initial `offset`).
        """
        init_scale = 0.04  # Standard deviation for parameter initialization

        for param in [self.query_matrix_1, self.key_matrix_1, self.value_matrix_1,
                      self.query_matrix_2, self.key_matrix_2, self.value_matrix_2,
                      self.query_matrix_3, self.key_matrix_3, self.value_matrix_3,
                      self.random_matrix, self.coeffs_attended, self.offset]:
            init.normal_(param, std=init_scale)

    @staticmethod
    def _attend(inputs, query_matrix, key_matrix, value_matrix):
        """Applies one attention layer: softmax(Q K^T) V over the sequence axis."""
        query = torch.matmul(inputs, query_matrix)
        key = torch.matmul(inputs, key_matrix)
        value = torch.matmul(inputs, value_matrix)
        scores = torch.matmul(query, key.transpose(1, 2)).softmax(dim=-1)
        return torch.matmul(scores, value)

    def forward(self, x):
        """
        Forward pass through the three-layer self-attention mechanism.

        Args:
            x (Tensor): Input of shape (batch_size, seq_length, seq_length); the
                last axis is contracted with the (seq_length, input_dim - 1)
                projection matrix.

        Returns:
            Tensor: One scalar prediction per sample, shape (batch_size,).
        """
        # Project the input down to input_dim - 1 features per position
        y = torch.matmul(x, self.random_matrix)

        # Append a constant feature; new_ones() inherits dtype and device from
        # `y`, so no global `device` variable is needed.
        z = torch.cat((y, y.new_ones(y.shape[0], y.shape[1], 1)), dim=2)

        attended_values_1 = self._attend(
            z, self.query_matrix_1, self.key_matrix_1, self.value_matrix_1)
        attended_values_2 = self._attend(
            attended_values_1, self.query_matrix_2, self.key_matrix_2, self.value_matrix_2)
        attended_values_3 = self._attend(
            attended_values_2, self.query_matrix_3, self.key_matrix_3, self.value_matrix_3)

        # Compute final weighted sum using learned coefficients
        output = torch.einsum("bij,ij->b", attended_values_3, self.coeffs_attended) + self.offset

        return output


### Fitness generation, training, and testing


In [None]:
def write_to_file(filename, *args):
    """
    Appends one line to a text file.

    The positional values are converted with str() and joined by single
    spaces; the line is terminated by a newline.

    Args:
        filename (str): Path to the output file (opened in append mode).
        *args: Values to write (e.g., low_dim, epoch, R² score).
    """
    line = ' '.join(str(value) for value in args)
    with open(filename, 'a') as handle:
        print(line, file=handle)

def generate_noise(shape, mean, std_dev):
    """Draws Gaussian samples of the given shape from NumPy's global random state."""
    samples = np.random.normal(loc=mean, scale=std_dev, size=shape)
    return samples

def generate_4th_order_interactions(geno_data, order, num_combinations):
    """
    Generates interaction terms as products over randomly chosen sets of loci.

    Despite the name, the interaction order is set by `order`. Each output
    column is the product of `order` distinct genotype columns, and no set of
    loci is used twice. Loci are drawn from NumPy's global random state.

    Args:
        geno_data (ndarray): Genotype data of shape (n_samples, n_loci).
        order (int): Order of interactions (e.g., 4 for 4th order).
        num_combinations (int): Number of interaction terms to generate.

    Returns:
        ndarray: Generated interaction terms of shape (n_samples, num_combinations).

    Raises:
        ValueError: If more distinct combinations are requested than exist,
            i.e. num_combinations > C(n_loci, order). Without this check the
            rejection-sampling loop below would never terminate.
    """
    n_samples, n_loci = geno_data.shape

    # Number of distinct locus sets of size `order`: C(n_loci, order),
    # built up exactly in integer arithmetic.
    max_unique = 0
    if 0 <= order <= n_loci:
        max_unique = 1
        for k in range(order):
            max_unique = max_unique * (n_loci - k) // (k + 1)

    if num_combinations > max_unique:
        raise ValueError(
            f"cannot draw {num_combinations} unique combinations of {order} loci "
            f"from {n_loci} loci (only {max_unique} exist)"
        )

    interaction_terms = np.empty((n_samples, num_combinations))

    selected_combinations = set()
    filled = 0
    while filled < num_combinations:
        # Randomly select `order` loci without replacement
        loci_indices = tuple(sorted(np.random.choice(n_loci, size=order, replace=False)))

        # Reject a combination that was already used and draw again
        if loci_indices in selected_combinations:
            continue

        interaction_terms[:, filled] = np.prod(geno_data[:, list(loci_indices)], axis=1)
        selected_combinations.add(loci_indices)
        filled += 1

    return interaction_terms

# Constants read by calculate_fitness_with_interactions() below:
#   scale_all - common factor applied to the linear, interaction and offset
#               contributions to fitness
#   t0        - reference temperature; the model's temperature dependence is
#               expressed in powers of (temp - t0)
scale_all = 1e-2
t0 = 30.0

def calculate_fitness_with_interactions(geno_data, temps, num_combinations, e, order):
    """
    Computes synthetic fitness from linear and interaction terms and adds noise.

    NumPy's global random state is re-seeded with 0 on every call, so the
    coefficients, the chosen interaction sets and the noise draws are the same
    for every call with identically shaped inputs.

    Args:
        geno_data (ndarray): Genotype data of shape (n_samples, n_loci).
        temps (ndarray): Temperature values, shape (num_temps,).
        num_combinations (int): Number of interaction terms.
        e (float): Weight of the linear terms; the interaction terms get (1 - e).
        order (int): Order of interactions.

    Returns:
        ndarray: Noisy fitness values of shape (n_samples, num_temps).
    """
    np.random.seed(0)
    num_temps = temps.shape[0]

    # Size everything from the matrix that was passed in rather than from the
    # notebook-level `num_loci`, so the two can never disagree.
    n_loci = geno_data.shape[1]

    # Generate coefficients from exponential distributions
    coeffs_const = np.random.exponential(1, n_loci)
    coeffs_lin = np.random.exponential(1, n_loci)
    coeffs_square = np.random.exponential(1, n_loci)

    # Generate interaction features and their coefficients
    interaction_data = generate_4th_order_interactions(geno_data, order, num_combinations)
    n_interactions = interaction_data.shape[1]
    interaction_coeffs_const = np.random.exponential(1, n_interactions)
    interaction_coeffs_lin = np.random.exponential(1, n_interactions)
    interaction_coeffs_square = np.random.exponential(1, n_interactions)

    # Generate offset
    offset = np.random.exponential(1, 1)

    fitness_with_noise = np.zeros((geno_data.shape[0], num_temps))

    for i, temp in enumerate(temps):
        dt = temp - t0  # temperature relative to the reference temperature

        linear_terms = scale_all * np.dot(
            geno_data, coeffs_square * dt**2 + coeffs_lin * dt + coeffs_const
        ) / n_loci

        interaction_terms = scale_all * np.dot(
            interaction_data, interaction_coeffs_square * dt**2 +
            interaction_coeffs_lin * dt + interaction_coeffs_const
        ) / n_interactions

        y = e * linear_terms + (1 - e) * interaction_terms + scale_all * offset
        fitness = -1.0 * y + 1.0  # Adjust fitness scale

        # Gaussian noise whose standard deviation is 20% of the spread of the
        # noise-free fitness at this temperature
        std = 0.2 * np.var(fitness)**0.5
        fitness_with_noise[:, i] = fitness + generate_noise(fitness.shape[0], 0.0, std)

    return fitness_with_noise

def _predict_in_chunks(model, features, chunk_size):
    """
    Runs `model` over `features` in chunks, without gradient tracking.

    Args:
        model: Callable mapping a (batch, num_loci, num_loci) tensor to (batch,) predictions.
        features (Tensor): Genotypes of shape (n_samples, num_loci).
        chunk_size (int): Maximum number of samples per forward pass (bounds memory use).

    Returns:
        Tensor: Predictions of shape (n_samples,), on the same device as `features`.
    """
    predictions = []
    with torch.no_grad():
        for start in range(0, len(features), chunk_size):
            chunk = features[start:start + chunk_size]
            # diag_embed puts each genotype on the diagonal of a
            # (num_loci, num_loci) matrix: position j carries a one-hot vector
            # for locus j scaled by that locus' genotype value.
            predictions.append(model(torch.diag_embed(chunk)))
    if not predictions:
        return features.new_empty(0)
    return torch.cat(predictions, dim=0)


# List of e values: e weights the linear terms, (1 - e) the interaction terms
e_list = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

# Pick loci from BBQ genotype data
genotype = pick_loci(num_loci)

# List of dimensionality values to evaluate
low_dims = [30, 50]

# Training parameters shared by all runs
batch_size = 64
chunk_size = 100
num_envs = temps.shape[0]

# Set computation device (GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch so that weight initialisation and mini-batch shuffling are
# repeatable under Restart & Run All
torch.manual_seed(0)

# Iterate over different low-dimensional embeddings
for low_dim in low_dims:

    save_dir = f"./synthetic_data_attention_exponential/exp/num_loci_{num_loci}/d_{low_dim}"
    os.makedirs(save_dir, exist_ok=True)

    # File to store test R² scores (start from a clean file for this low_dim)
    filename2 = f"{save_dir}/test_r2_scores.txt"
    if os.path.exists(filename2):
        os.remove(filename2)

    # Define model input dimensions (+1 for the constant feature the model appends)
    seq_length = num_loci
    input_dim = low_dim + 1
    query_dim = low_dim + 1
    key_dim = low_dim + 1
    num_epochs = 1000 if low_dim == 2 else 500

    # Iterate over different values of e (linear vs. interaction term weight)
    for e in e_list:

        # Generate synthetic fitness data with noise
        fitness_with_noise = calculate_fitness_with_interactions(genotype, temps, num_loci, e, 4)

        # File to store validation R² scores (one row per epoch, reset per run)
        filename = f"{save_dir}/validation_r2.txt"
        if os.path.exists(filename):
            os.remove(filename)

        # Split dataset: 15% test, then 15% of the remainder for validation
        X_train, X_test, y_train, y_test = train_test_split(genotype, fitness_with_noise, test_size=0.15, random_state=42)
        X_train2, X_val, y_train2, y_val = train_test_split(X_train, y_train, test_size=0.15, random_state=42)

        # Normalize targets with statistics of the training split only
        mean_values = np.nanmean(y_train2, axis=0)
        std_values = np.nanstd(y_train2, axis=0)
        y_train2 = (y_train2 - mean_values) / std_values
        y_val = (y_val - mean_values) / std_values
        y_test = (y_test - mean_values) / std_values

        # Convert dataset to PyTorch tensors (once per run, not once per epoch)
        X_train_tens = torch.tensor(X_train2).float()
        y_train_tens = torch.tensor(np.array(y_train2)).float()
        X_val_tens = torch.tensor(np.array(X_val)).float().to(device)
        X_test_tens = torch.tensor(np.array(X_test)).float().to(device)

        # Validation targets of all environments, concatenated environment by environment
        y_val_all = np.concatenate([y_val.T[env] for env in range(num_envs)])

        # Initialize the attention model on the chosen device
        attention_layer = ThreeLayerAttention(input_dim, query_dim, key_dim, seq_length).to(device)

        # Define loss function and optimizer
        loss_function = nn.MSELoss()
        optimizer = torch.optim.Adam(attention_layer.parameters(), lr=0.001)

        # Only full mini-batches are used; a trailing partial batch is dropped
        num_batches = X_train_tens.size(0) // batch_size

        # TRAINING LOOP
        for epoch in range(num_epochs):

            # Shuffle training data at the beginning of each epoch
            permutation = torch.randperm(X_train_tens.size(0))
            train_input_shuffled = X_train_tens[permutation]
            train_target_shuffled = y_train_tens[permutation]

            # Mini-batch training
            for i in range(num_batches):

                start_idx = i * batch_size
                end_idx = start_idx + batch_size

                mini_batch_input = train_input_shuffled[start_idx:end_idx].to(device)
                # squeeze(1) drops the single-temperature axis of the targets
                mini_batch_target = train_target_shuffled[start_idx:end_idx].squeeze(1).to(device)

                # Forward pass on the one-hot (diagonal) genotype encoding
                train_output = attention_layer(torch.diag_embed(mini_batch_input))
                train_loss = loss_function(train_output, mini_batch_target)

                # Backpropagation
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

            # Save model parameters after each epoch
            model_state_path = os.path.join(save_dir, f"epoch_{epoch}.pt")
            torch.save(attention_layer.state_dict(), model_state_path)

            # VALIDATION EVALUATION
            # The model output does not depend on the environment, so predict
            # once and repeat the predictions for every environment.
            val_pred = _predict_in_chunks(attention_layer, X_val_tens, chunk_size).cpu().numpy()
            val_r_squared = r2_score(y_val_all, np.tile(val_pred, num_envs))

            # Save validation R² score
            write_to_file(filename, low_dim, epoch, val_r_squared)
            torch.cuda.empty_cache()

        # LOAD THE BEST MODEL CHECKPOINT
        # Columns of the validation file: low_dim, epoch, validation R².
        # Raw string: '\s' is not a valid escape in an ordinary string literal.
        data = pd.read_csv(filename, sep=r'\s+', header=None)
        best_epoch = int(data.loc[data[2].idxmax(), 1])

        # Initialize a new attention model and load the best-performing weights
        attention_layer = ThreeLayerAttention(input_dim, query_dim, key_dim, seq_length).to(device)
        model_path = f"{save_dir}/epoch_{best_epoch}.pt"
        state_dict = torch.load(model_path, map_location=device)
        attention_layer.load_state_dict(state_dict)
        attention_layer.to(device)
        attention_layer.eval()

        # TEST SET EVALUATION
        test_pred = _predict_in_chunks(attention_layer, X_test_tens, chunk_size).cpu().numpy()
        for env in range(num_envs):
            # Compute test R² score
            test_r_squared = r2_score(y_test.T[env], test_pred)
            write_to_file(filename2, num_loci, e, temps[env], test_r_squared)