In [4]:
# =============================================================================
# Import necessary libraries and modules
# =============================================================================
import torch
import numpy as np
import pandas as pd
import time
import os
import torch.nn.functional as F
import random
from torch import nn
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler
from mamba_ssm import Mamba2
from torch.nn import TransformerEncoder, TransformerEncoderLayer

# =============================================================================
# Configuration parameters
# =============================================================================
class Config:
    """Experimental configuration class"""
    # Pre-training model configuration
    USE_PRETRAINED = False  
    
    # Data configuration
    MAX_ROWS = 1000
    MISSING_RATIO = 0.0
    
    # Model configuration
    IN_CHANNELS = 3
    OUT_CHANNELS = 256
    KERNEL_SIZE = 5
    STRIDE = 5
    D_STATE = 128
    
    # Training configuration
    BATCH_SIZE = 32
    LEARNING_RATE = 1e-3
    EPOCHS = 100
    EARLY_STOPPING_PATIENCE = 20
    
    # Cross-validation configuration
    N_SPLITS = 10
    RANDOM_SEED = 42

# Initialize configuration
config = Config()
start_time = time.time()

# CMD Genotype-Phenotype Association Analysis System

## System Overview
This notebook implements a genotype-phenotype association analysis system based on Mamba2 architecture, supporting pre-training and 10-fold cross-validation.

## Main Features
- **Pre-training Model**: Uses autoencoder for genotype data pre-training
- **Phenotype Prediction**: Predicts phenotype values based on pre-trained features
- **Cross-validation**: Supports 10-fold cross-validation for model performance evaluation
- **Early Stopping**: Early stopping strategy to prevent overfitting

## Configuration Description
All configuration parameters are centralized in the `Config` class, including:
- `USE_PRETRAINED`: Whether to use pre-trained model
- `MISSING_RATIO`: Missing data ratio
- `EPOCHS`: Number of training epochs
- `EARLY_STOPPING_PATIENCE`: Early stopping patience value

## Usage Workflow
1. **Data Preprocessing**: Load and preprocess genotype data
2. **Pre-training Phase**: Train autoencoder model (optional)
3. **Phenotype Prediction**: Train phenotype prediction model
4. **Cross-validation**: 10-fold cross-validation for performance evaluation


In [5]:
# !pip install causal-conv1d
# !pip install mamba-ssm
# !pip install "triton == 2.1.0"

In [6]:
# =============================================================================
# Data loading and preprocessing
# =============================================================================
def load_genotype_data(file_path, max_rows=None):
    """
    Load genotype data
    
    Args:
        file_path (str): Data file path
        max_rows (int): Maximum number of rows limit
    
    Returns:
        pd.DataFrame: Genotype data frame
    """
    data = []
    index = []
    
    with open(file_path, mode="r") as file:
        header = None
        for i, line in enumerate(file):
            if max_rows and i >= max_rows + 1:
                break
            row = line.strip().split(",")
            if i == 0:
                header = row[7:20007]  # Skip first 7 columns, take 20000 columns
            else:
                index.append(row[0])
                data.append(row[7:20007])
    
    df = pd.DataFrame(data, columns=header, index=index)
    return df

# Load data
input_name = 'test_geno.csv'
df_ori = load_genotype_data(input_name, max_rows=config.MAX_ROWS)
print(f"Genotype data shape: {df_ori.shape}")

Genotype data shape: (1000, 20000)


In [7]:
# =============================================================================
# Data preprocessing functions
# =============================================================================
def apply_missing_mask(data, missing_ratio):
    """
    Apply missing mask to genotype data
    
    Args:
        data (pd.DataFrame): Original genotype data
        missing_ratio (float): Missing ratio
    
    Returns:
        np.ndarray: Data after applying mask
    """
    data_array = data.to_numpy().copy()
    if missing_ratio > 0:
        mask = np.random.uniform(0, 1, size=data_array.shape)
        data_array[mask < missing_ratio] = -1
    return data_array

def encode_genotype_to_categorical(genotype_data):
    """
    Encode genotype data to categorical format
    
    Args:
        genotype_data (np.ndarray): Genotype data array
    
    Returns:
        np.ndarray: Encoded categorical data
    """
    codebook = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0]])
    return codebook[genotype_data.astype(int)]

# Apply data preprocessing
mask_data = apply_missing_mask(df_ori, config.MISSING_RATIO)
print(f"Missing ratio: {config.MISSING_RATIO}")

Missing ratio: 0.0


In [8]:
# Create DataFrame for masked data
mask_data_copy = pd.DataFrame(mask_data)
mask_data_copy.index = df_ori.index
print(f"Masked data shape: {mask_data_copy.shape}")

Masked data shape: (1000, 20000)


In [9]:
# Encode genotype data to categorical format
df_onehot = encode_genotype_to_categorical(mask_data)
df_onehot_no_miss = encode_genotype_to_categorical(df_ori.to_numpy())
print(f"Encoded data shape: {df_onehot.shape}")

Encoded data shape: (1000, 20000, 3)


In [10]:
# Verify data shape
print(f"Encoded data shape: {df_onehot.shape}")

Encoded data shape: (1000, 20000, 3)


In [11]:

# =============================================================================
# Dataset class definitions
# =============================================================================
class GenotypeDataset(Dataset):
    """
    Genotype dataset class - for pre-training
    """
    def __init__(self, masked_data, original_data):
        """
        Initialize dataset
        
        Args:
            masked_data (np.ndarray): Masked data
            original_data (np.ndarray): Original data
        """
        self.masked_data = masked_data
        self.original_data = original_data

    def __len__(self):
        return len(self.masked_data)

    def __getitem__(self, idx):
        x = self.masked_data[idx]
        y = self.original_data[idx]

        # Adjust data dimensions
        if len(x.shape) == 2:
            x = x.transpose(1, 0)
            y = y.transpose(1, 0)
        elif len(x.shape) == 3:
            x = x.transpose(0, 2, 1)
            y = y.transpose(0, 2, 1)
            
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)


class PhenotypeDataset(Dataset):
    """
    Phenotype dataset class - for phenotype prediction
    """
    def __init__(self, genotype_data, phenotype_data):
        """
        Initialize phenotype dataset
        
        Args:
            genotype_data (np.ndarray): Genotype data
            phenotype_data (np.ndarray): Phenotype data
        """
        self.genotype_data = np.asarray(genotype_data, dtype=np.float32)
        self.phenotype_data = np.asarray(phenotype_data, dtype=np.float32)

    def __len__(self):
        return len(self.genotype_data)

    def __getitem__(self, idx):
        x = self.genotype_data[idx]
        y = self.phenotype_data[idx]

        # Adjust data dimensions
        if len(x.shape) == 2:
            x = x.transpose(1, 0)
        elif len(x.shape) == 3:
            x = x.transpose(0, 2, 1)

        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)



In [12]:

# =============================================================================
# Model architecture definitions
# =============================================================================
class Mamba2AutoEncoder(nn.Module):
    """
    Mamba2-based autoencoder model - for pre-training
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, d_state=128):
        super(Mamba2AutoEncoder, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride)
        self.relu = nn.ReLU()
        
        # Mamba2 layer
        self.mamba = Mamba2(
            d_model=out_channels,
            d_state=d_state,
            d_conv=4,
            expand=2,
        ).to("cuda")
        
        # Decoder layers
        self.conv1 = nn.Conv1d(out_channels, stride, kernel_size, stride=1, padding='same')
        self.conv2 = nn.ConvTranspose1d(stride, in_channels, kernel_size, stride=stride)

    def forward(self, x):
        # Encoder
        x = self.conv(x)
        x = self.relu(x)
        x = x.permute(0, 2, 1)  # (B, C, L) -> (B, L, C)
        x = self.mamba(x)
        x = x.permute(0, 2, 1)  # (B, L, C) -> (B, C, L)
        
        # Decoder
        x = self.conv1(x)
        x = self.conv2(x)
        return x

    def get_encoder_features(self, x):
        """Get encoder features"""
        x = self.conv(x)
        x = self.relu(x)
        x = x.permute(0, 2, 1)
        x = self.mamba(x)
        x = x.permute(0, 2, 1)
        x = self.conv1(x)
        return x


class Mamba2PhenotypePredictor(nn.Module):
    """
    Mamba2-based phenotype prediction model
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, d_state=128, input_length=None):
        super(Mamba2PhenotypePredictor, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride)
        self.bn = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()

        # Mamba2 layer
        self.mamba = Mamba2(
            d_model=out_channels,
            d_state=d_state,
            d_conv=4,
            expand=2,
        ).to("cuda")

        self.conv1 = nn.Conv1d(out_channels, stride, kernel_size, stride=1, padding="same")
        
        # Calculate linear layer input dimension
        if input_length is None:
            input_length = self._calculate_input_length(20000, kernel_size, stride)
        
        self.predictor = nn.Sequential(
            nn.Linear(input_length, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def _calculate_input_length(self, input_length, kernel_size, stride):
        """Calculate linear layer input dimension"""
        conv1_output = (input_length - kernel_size) // stride + 1
        conv2_output = conv1_output  # padding='same'
        return stride * conv2_output

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)

        x = x.permute(0, 2, 1)
        x = self.mamba(x)
        x = x.permute(0, 2, 1)

        x = self.conv1(x)
        x = x.flatten(-2, -1)
        x = self.predictor(x)
        return x



In [13]:


# =============================================================================
# Training and evaluation functions
# =============================================================================
def train_autoencoder(model, train_loader, device, criterion, optimizer):
    """
    Train autoencoder model
    
    Args:
        model: Autoencoder model
        train_loader: Training data loader
        device: Device
        criterion: Loss function
        optimizer: Optimizer
    
    Returns:
        tuple: (average loss, training accuracy)
    """
    model.train()
    running_loss = 0
    correct = 0
    total = 0
    
    for inputs, labels in train_loader:
        # Create mask
        bools = ((labels.sum(axis=1) != 0).unsqueeze(1).expand(labels.shape)).to(device)
        inputs, labels = inputs.to(device), labels.to(device)
        true_labels = torch.argmax(labels, dim=1)

        outputs = model(inputs)
        
        # Adjust output dimensions
        if outputs.shape[2] < bools.shape[2]:
            outputs = F.pad(outputs, (0, bools.shape[2] - outputs.shape[2]))
        elif outputs.shape[2] > bools.shape[2]:
            outputs = outputs[:, :, :bools.shape[2]]

        outputs_ = outputs * bools
        loss = criterion(outputs_, true_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        pred = torch.argmax(outputs, dim=1)
        correct += (pred == true_labels).sum().item()
        total += true_labels.numel()

    avg_loss = running_loss / len(train_loader)
    train_accuracy = correct / total
    
    return avg_loss, train_accuracy


def evaluate_autoencoder(model, valid_loader, device):
    """
    Evaluate autoencoder model
    
    Args:
        model: Autoencoder model
        valid_loader: Validation data loader
        device: Device
    
    Returns:
        float: Validation accuracy
    """
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in valid_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            
            pred = torch.argmax(outputs, dim=1)
            true_labels = torch.argmax(labels, dim=1)
            
            # Adjust prediction dimensions
            if pred.shape[1] < true_labels.shape[1]:
                pred = F.pad(pred, (0, true_labels.shape[1] - pred.shape[1]))
            elif pred.shape[1] > true_labels.shape[1]:
                pred = pred[:, :true_labels.shape[1]]
            
            correct += (pred == true_labels).sum().item()
            total += true_labels.numel()

    valid_accuracy = correct / total
    return valid_accuracy


In [14]:
# =============================================================================
# Data splitting and model initialization
# =============================================================================
def prepare_pretraining_data(genotype_data, original_data, test_size=0.1, random_seed=42):
    """
    Prepare pre-training data
    
    Args:
        genotype_data: Genotype data
        original_data: Original data
        test_size: Test set ratio
        random_seed: Random seed
    
    Returns:
        tuple: (training data loader, validation data loader)
    """
    train_X, valid_X = train_test_split(genotype_data, test_size=test_size, random_state=random_seed)
    train_X_original, valid_X_original = train_test_split(original_data, test_size=test_size, random_state=random_seed)
    
    train_dataset = GenotypeDataset(train_X, train_X_original)
    valid_dataset = GenotypeDataset(valid_X, valid_X_original)
    
    train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=len(valid_dataset), shuffle=False)
    
    return train_loader, valid_loader

# Prepare pre-training data
train_loader, valid_loader = prepare_pretraining_data(
    df_onehot, df_onehot_no_miss, 
    test_size=0.1, random_seed=config.RANDOM_SEED
)

In [15]:
# =============================================================================
# Pre-training phase
# =============================================================================
def pretrain_autoencoder(train_loader, valid_loader, device, epochs=100):
    """
    Pre-train autoencoder model
    
    Args:
        train_loader: Training data loader
        valid_loader: Validation data loader
        device: Device
        epochs: Number of training epochs
    
    Returns:
        str: Best model save path
    """
    # Initialize model
    model = Mamba2AutoEncoder(
        in_channels=config.IN_CHANNELS,
        out_channels=config.OUT_CHANNELS,
        kernel_size=config.KERNEL_SIZE,
        stride=config.STRIDE,
        d_state=config.D_STATE
    ).to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
    
    # Create model save directory
    model_dir = "./model"
    os.makedirs(model_dir, exist_ok=True)
    
    # Training records
    train_accuracies = []
    valid_accuracies = []
    losses = []
    best_valid_acc = 0
    model_file_path = None
    
    print("Starting autoencoder pre-training...")
    for epoch in range(epochs):
        avg_loss, acc_train = train_autoencoder(model, train_loader, device, criterion, optimizer)
        acc_valid = evaluate_autoencoder(model, valid_loader, device)

        # Convert tensor to numpy
        if isinstance(acc_train, torch.Tensor):
            acc_train = acc_train.cpu().numpy()
        if isinstance(acc_valid, torch.Tensor):
            acc_valid = acc_valid.cpu().numpy()

        # Save best model
        if acc_valid > best_valid_acc:
            best_valid_acc = acc_valid
            model_file_path = os.path.join(model_dir, "model_state.pth")
            torch.save(model.state_dict(), model_file_path)

        train_accuracies.append(acc_train)
        valid_accuracies.append(acc_valid)
        losses.append(avg_loss)
        
        print(f'Epoch {epoch+1}/{epochs} - Training accuracy: {acc_train:.5f}, Validation accuracy: {acc_valid:.5f}, Loss: {avg_loss:.8f}')

    print(f"Pre-training completed! Best model saved at: {model_file_path}")
    return model_file_path

# Execute pre-training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pretrained_model_path = pretrain_autoencoder(train_loader, valid_loader, device, config.EPOCHS)

Starting autoencoder pre-training...
Epoch 1/100 - Training accuracy: 0.69692, Validation accuracy: 0.83724, Loss: 0.66664172
Epoch 2/100 - Training accuracy: 0.88513, Validation accuracy: 0.90978, Loss: 0.28811342
Epoch 3/100 - Training accuracy: 0.92174, Validation accuracy: 0.93513, Loss: 0.17894833
Epoch 4/100 - Training accuracy: 0.94296, Validation accuracy: 0.94984, Loss: 0.13288135
Epoch 5/100 - Training accuracy: 0.95362, Validation accuracy: 0.95497, Loss: 0.10647228
Epoch 6/100 - Training accuracy: 0.95746, Validation accuracy: 0.95934, Loss: 0.09223406
Epoch 7/100 - Training accuracy: 0.96045, Validation accuracy: 0.96219, Loss: 0.08366253
Epoch 8/100 - Training accuracy: 0.96361, Validation accuracy: 0.96379, Loss: 0.07765317
Epoch 9/100 - Training accuracy: 0.96498, Validation accuracy: 0.96499, Loss: 0.07355069
Epoch 10/100 - Training accuracy: 0.96586, Validation accuracy: 0.96500, Loss: 0.07145612
Epoch 11/100 - Training accuracy: 0.96672, Validation accuracy: 0.96726,

In [16]:
# =============================================================================
# Phenotype data loading and preprocessing
# =============================================================================
def load_phenotype_data(file_path):
    """
    Load phenotype data
    
    Args:
        file_path (str): Phenotype data file path
    
    Returns:
        pd.DataFrame: Phenotype data frame
    """
    return pd.read_csv(file_path, sep=',', index_col=0)

# Load phenotype data
phenotype_data = load_phenotype_data("1000pheno.txt")
print(f"Phenotype data shape: {phenotype_data.shape}")

Phenotype data shape: (1000, 7)


In [17]:
# =============================================================================
# Phenotype data preprocessing
# =============================================================================
def preprocess_phenotype_data(phenotype_df, genotype_df, phenotype_column=1):
    """
    Preprocess phenotype data and align with genotype data
    
    Args:
        phenotype_df: Phenotype data frame
        genotype_df: Genotype data frame
        phenotype_column: Phenotype column index
    
    Returns:
        tuple: (genotype data, phenotype data, scaler)
    """
    # Find common samples
    common_samples = set(phenotype_df.index) & set(genotype_df.index)
    common_samples = list(common_samples)
    
    # Extract corresponding data
    phenotype_values = phenotype_df.loc[common_samples].iloc[:, phenotype_column].values
    genotype_values = genotype_df.loc[common_samples].values
    
    # Encode genotype data to categorical format
    genotype_encoded = encode_genotype_to_categorical(genotype_values)
    
    # Normalize phenotype data
    scaler = StandardScaler()
    phenotype_normalized = scaler.fit_transform(phenotype_values.reshape(-1, 1)).flatten()
    
    print(f"Phenotype column name: {phenotype_df.columns[phenotype_column]}")
    print(f"After normalization - Mean: {np.mean(phenotype_normalized):.4f}, Std: {np.std(phenotype_normalized):.4f}")
    
    return genotype_encoded, phenotype_normalized, scaler

# Preprocess phenotype data
genotype_encoded, phenotype_normalized, phenotype_scaler = preprocess_phenotype_data(
    phenotype_data, mask_data_copy, phenotype_column=1
)

Phenotype column name: AL
After normalization - Mean: 0.0000, Std: 1.0000


In [18]:
# =============================================================================
# Phenotype prediction training and evaluation functions
# =============================================================================
def train_phenotype_predictor(epoch, model, device, optimizer, criterion, train_loader, test_loader, scaler):
    """
    Train phenotype prediction model
    
    Args:
        epoch: Current epoch
        model: Prediction model
        device: Device
        optimizer: Optimizer
        criterion: Loss function
        train_loader: Training data loader
        test_loader: Test data loader
        scaler: Scaler
    
    Returns:
        tuple: (correlation coefficient, prediction output, loss values list)
    """
    model.train()
    
    # Set all parameters trainable
    for param in model.parameters():
        param.requires_grad = True
    
    all_loss = 0
    fold_loss_values = []
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.float(), target.float()
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output.view(-1), target.view(-1))
        loss.backward()
        optimizer.step()
        all_loss += loss.item()

        if batch_idx % 10 == 0:
            print(f'Training epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
    
    avg_loss = all_loss / len(train_loader)
    fold_loss_values.append(avg_loss)
    print(f'=========> Epoch: {epoch} Average loss: {avg_loss:.4f}')

    # Evaluate model
    corr, predictions = evaluate_phenotype_predictor(model, device, test_loader, scaler)
    return corr, predictions, fold_loss_values


def evaluate_phenotype_predictor(model, device, test_loader, scaler):
    """
    Evaluate phenotype prediction model
    
    Args:
        model: Prediction model
        device: Device
        test_loader: Test data loader
        scaler: Scaler
    
    Returns:
        tuple: (correlation coefficient, prediction output)
    """
    model.eval()
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.float(), target.float()
            data, target = data.to(device), target
            
            output = model(data)
            
            # Inverse normalize prediction results and true values
            output_original = scaler.inverse_transform(
                output.view(-1).detach().cpu().numpy().reshape(-1, 1)
            ).flatten()
            target_original = scaler.inverse_transform(
                target.reshape(-1).reshape(-1, 1)
            ).flatten()
            
            # Calculate correlation coefficient
            corr = np.corrcoef(output_original, target_original)[0, 1]
            print(f"Correlation coefficient: {corr:.4f}")
            return corr, output_original

In [19]:
# all_params = torch.load(model_file_path)
# model_state = model2.state_dict()

# # 过滤掉形状不匹配的参数
# filtered_params = {k: v for k, v in all_params.items() if k in model_state and v.size() == model_state[k].size()}

# # 更新模型参数
# model_state.update(filtered_params)
# model2.load_state_dict(model_state, strict=False)

# print("部分参数已加载，以下参数未匹配到:")
# for name, param in model2.state_dict().items():
#     if name not in filtered_params:
#         print(f"{name}: expected {param.size()}, got {all_params[name].size() if name in all_params else 'missing'}")

In [20]:


# =============================================================================
# 10-fold cross-validation
# =============================================================================
def set_random_seed(seed=42):
    """Set random seed to ensure reproducible results"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True 
    torch.backends.cudnn.benchmark = False     

def load_pretrained_weights(model, pretrained_path, device):
    """
    Load pre-trained weights
    
    Args:
        model: Target model
        pretrained_path: Pre-trained model path
        device: Device
    
    Returns:
        bool: Whether loading was successful
    """
    if not os.path.exists(pretrained_path):
        return False
    
    try:
        pretrained_state = torch.load(pretrained_path, map_location=device)
        model_state = model.state_dict()
        
        # Filter out shape-mismatched parameters
        filtered_params = {k: v for k, v in pretrained_state.items() 
                          if k in model_state and v.size() == model_state[k].size()}
        
        # Update model parameters
        model_state.update(filtered_params)
        model.load_state_dict(model_state, strict=False)
        
        # Print unmatched parameters
        unmatched_params = [name for name in model_state.keys() if name not in filtered_params]
        if unmatched_params:
            print(f"⚠️  The following parameters did not match pre-trained weights: {unmatched_params}")
        
        return True
    except Exception as e:
        print(f"❌ Pre-trained model loading failed: {e}")
        return False

def cross_validation_phenotype_prediction(genotype_data, phenotype_data, scaler, device):
    """
    Execute 10-fold cross-validation for phenotype prediction
    
    Args:
        genotype_data: Genotype data
        phenotype_data: Phenotype data
        scaler: Scaler
        device: Device
    
    Returns:
        list: Best correlation coefficients for each fold
    """
    set_random_seed(config.RANDOM_SEED)
    
    kf = KFold(n_splits=config.N_SPLITS, shuffle=True, random_state=config.RANDOM_SEED)
    all_best_correlations = []
    fold_losses = [[] for _ in range(config.N_SPLITS)]
    
    for fold, (train_index, test_index) in enumerate(kf.split(genotype_data), start=1):
        print(f"\n========== Cross-validation Fold {fold}/{config.N_SPLITS} ==========")
        
        # Set random seed for each fold
        set_random_seed(config.RANDOM_SEED + fold)
        
        # Split data
        X_train, X_test = genotype_data[train_index], genotype_data[test_index]
        y_train, y_test = phenotype_data[train_index], phenotype_data[test_index]
        
        # Initialize model
        model = Mamba2PhenotypePredictor(
            in_channels=X_train.shape[2],
            out_channels=config.OUT_CHANNELS,
            kernel_size=config.KERNEL_SIZE,
            stride=config.STRIDE,
            d_state=config.D_STATE
        ).to(device)
        
        # Load pre-trained weights
        if config.USE_PRETRAINED and os.path.exists("model/model_state.pth"):
            print(f"🔄 Fold {fold}: Loading pre-trained model...")
            if load_pretrained_weights(model, "model/model_state.pth", device):
                print(f"✅ Fold {fold}: Pre-trained model loaded successfully")
            else:
                print(f"⚠️  Fold {fold}: Pre-trained model loading failed, using random initialization")
        else:
            if config.USE_PRETRAINED:
                print(f"⚠️  Fold {fold}: Pre-trained model file does not exist, using random initialization")
            else:
                print(f"🔄 Fold {fold}: Using random initialization")
        
        # Set optimizer and loss function
        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.00040854050495879857)
        
        # Create data loaders
        train_dataset = PhenotypeDataset(X_train, y_train)
        test_dataset = PhenotypeDataset(X_test, y_test)
        train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)
        
        # Train model
        best_corr = -1.0
        early_stopping_counter = 0
        
        for epoch in range(1, 101):  
            corr, predictions, epoch_loss_values = train_phenotype_predictor(
                epoch, model, device, optimizer, criterion, train_loader, test_loader, scaler
            )
            
            fold_losses[fold - 1].extend(epoch_loss_values)
            
            # Early stopping mechanism
            if corr > best_corr:
                best_corr = corr
                print(f"✅ Epoch {epoch}: New best correlation = {best_corr:.4f}")
                torch.save(model.state_dict(), f"model/fold_{fold}_best_model.pth")
                early_stopping_counter = 0
            else:
                early_stopping_counter += 1
            
            if early_stopping_counter >= config.EARLY_STOPPING_PATIENCE:
                print(f"⏹️  Epoch {epoch} early stopping (no improvement for {config.EARLY_STOPPING_PATIENCE} epochs)")
                break
        
        print(f"🏁 Fold {fold} best correlation: {best_corr:.4f}")
        all_best_correlations.append(best_corr)
    
    return all_best_correlations

# Execute 10-fold cross-validation
all_best_correlations = cross_validation_phenotype_prediction(
    genotype_encoded, phenotype_normalized, phenotype_scaler, device
)

print("\n===== Cross-validation completed =====")
print("Best correlations for each fold:", np.round(all_best_correlations, 4))
print(f"Mean correlation: {np.mean(all_best_correlations):.4f} ± {np.std(all_best_correlations):.4f}")


🔄 Fold 1: Using random initialization
Correlation coefficient: 0.7091
✅ Epoch 1: New best correlation = 0.7091
Correlation coefficient: 0.7105
✅ Epoch 2: New best correlation = 0.7105
Correlation coefficient: 0.7045
Correlation coefficient: 0.6948
Correlation coefficient: 0.7099
Correlation coefficient: 0.7021
Correlation coefficient: 0.7047
Correlation coefficient: 0.7045
Correlation coefficient: 0.7057
Correlation coefficient: 0.7064
Correlation coefficient: 0.7081
Correlation coefficient: 0.7053
Correlation coefficient: 0.7093
Correlation coefficient: 0.7018
Correlation coefficient: 0.7100
Correlation coefficient: 0.7097
Correlation coefficient: 0.7013
Correlation coefficient: 0.7156
✅ Epoch 18: New best correlation = 0.7156
Correlation coefficient: 0.7045
Correlation coefficient: 0.6980
Correlation coefficient: 0.7128
Correlation coefficient: 0.6911
Correlation coefficient: 0.7037
Correlation coefficient: 0.7097
Correlation coefficient: 0.7093
Correlation coefficient: 0.7092
Corre

In [21]:
# =============================================================================
# Results statistics and summary
# =============================================================================
def calculate_runtime_summary(start_time):
    """Calculate runtime statistics"""
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Total runtime: {elapsed_time:.2f} seconds")
    return elapsed_time

# Calculate runtime
total_runtime = calculate_runtime_summary(start_time)

Total runtime: 1396.18 seconds


In [22]:
# =============================================================================
# Final results output
# =============================================================================
def print_final_results(missing_ratio, correlations):
    """Print final results"""
    print(f"Missing ratio: {missing_ratio}")
    mean_correlation = np.mean(correlations)
    std_correlation = np.std(correlations)
    
    print(f"Mean correlation: {mean_correlation:.4f}")
    print(f"Standard deviation: {std_correlation:.4f}")
    print(f"Best correlation: {np.max(correlations):.4f}")
    print(f"Worst correlation: {np.min(correlations):.4f}")
    
    return mean_correlation

# Output final results
final_mean_correlation = print_final_results(config.MISSING_RATIO, all_best_correlations)

Missing ratio: 0.0
Mean correlation: 0.6833
Standard deviation: 0.0446
Best correlation: 0.7844
Worst correlation: 0.6213
