In [1]:
# =============================================================================
# Import necessary libraries and modules
# =============================================================================
import torch
import numpy as np
import pandas as pd
import time
import os
import torch.nn.functional as F
import random
from torch import nn
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler
import sys

sys.path.append('..')

# Import custom modules
from config import Config
from data.data_utils import (
    load_genotype_data, apply_missing_mask, encode_genotype_to_categorical,
    load_phenotype_data, preprocess_phenotype_data, GenotypeDataset, PhenotypeDataset,
    prepare_pretraining_data
)

from model.model_utils import (
    CMDAutoEncoder, CMDPhenotypePredictor, train_autoencoder, evaluate_autoencoder,
    train_phenotype_predictor, evaluate_phenotype_predictor, set_random_seed,
    load_pretrained_weights, pretrain_autoencoder
)

from train.cross_train import cross_validation_phenotype_prediction
from train.evaluation_utils import (
    calculate_runtime_summary, print_final_results
)


# Initialize configuration
config = Config()
start_time = time.time()

# CMD Genotype-Phenotype Association Analysis System

## System Overview
This notebook implements a genotype-phenotype association analysis system based on CMD architecture, supporting pre-training and 10-fold cross-validation.

## Main Features
- **Pre-training Model**: Uses autoencoder for genotype data pre-training
- **Phenotype Prediction**: Predicts phenotype values based on pre-trained features
- **Cross-validation**: Supports 10-fold cross-validation for model performance evaluation
- **Early Stopping**: Early stopping strategy to prevent overfitting

## Configuration Description
All configuration parameters are centralized in the `Config` class, including:
- `USE_PRETRAINED`: Whether to use pre-trained model
- `MISSING_RATIO`: Missing data ratio
- `EPOCHS`: Number of training epochs
- `EARLY_STOPPING_PATIENCE`: Early stopping patience value

## Usage Workflow
1. **Data Preprocessing**: Load and preprocess genotype data
2. **Pre-training Phase**: Train autoencoder model (optional)
3. **Phenotype Prediction**: Train phenotype prediction model
4. **Cross-validation**: 10-fold cross-validation for performance evaluation


In [None]:
# =============================================================================
# Data loading and preprocessing
# =============================================================================

# Load data
genotype_file = './dataset/test_geno.csv'
phenotype_file= "./dataset/test_pheno.csv"


df_ori = load_genotype_data(genotype_file, max_rows=config.MAX_ROWS)
print(f"Genotype data shape: {df_ori.shape}")

phenotype_data = load_phenotype_data(phenotype_file)
print(f"Phenotype data shape: {phenotype_data.shape}")

Genotype data shape: (1000, 20000)
Phenotype data shape: (1000, 8)


In [3]:
# =============================================================================
# Data preprocessing
# =============================================================================


# Apply data preprocessing
mask_data = apply_missing_mask(df_ori, config.MISSING_RATIO)
print(f"Missing ratio: {config.MISSING_RATIO}")
# Create DataFrame for masked data
mask_data_copy = pd.DataFrame(mask_data)
mask_data_copy.index = df_ori.index
print(f"Masked data shape: {mask_data_copy.shape}")
# Encode genotype data to categorical format
df_onehot = encode_genotype_to_categorical(mask_data)
df_onehot_no_miss = encode_genotype_to_categorical(df_ori.to_numpy())
print(f"Encoded data shape: {df_onehot.shape}")

Missing ratio: 0.0
Masked data shape: (1000, 20000)
Encoded data shape: (1000, 20000, 3)


In [4]:
# Verify data shape
print(f"Encoded data shape: {df_onehot.shape}")

Encoded data shape: (1000, 20000, 3)


In [5]:
# =============================================================================
# Data splitting and model initialization
# =============================================================================

# Prepare pre-training data
train_loader, valid_loader = prepare_pretraining_data(
    df_onehot, df_onehot_no_miss, 
    test_size=0.1, random_seed=config.RANDOM_SEED
)

In [6]:
# =============================================================================
# Pre-training phase
# =============================================================================

# Execute pre-training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pretrained_model_path = pretrain_autoencoder(train_loader, valid_loader, device, config.EPOCHS)

Starting autoencoder pre-training...
Epoch 1/100 - Training accuracy: 0.77812, Validation accuracy: 0.87981, Loss: 0.56752993
Epoch 2/100 - Training accuracy: 0.92155, Validation accuracy: 0.94908, Loss: 0.20536899
Epoch 3/100 - Training accuracy: 0.96079, Validation accuracy: 0.96993, Loss: 0.09992990
Epoch 4/100 - Training accuracy: 0.97685, Validation accuracy: 0.98156, Loss: 0.06405237
Epoch 5/100 - Training accuracy: 0.98363, Validation accuracy: 0.98845, Loss: 0.04666885
Epoch 6/100 - Training accuracy: 0.98933, Validation accuracy: 0.99087, Loss: 0.03129701
Epoch 7/100 - Training accuracy: 0.99144, Validation accuracy: 0.99194, Loss: 0.02607300
Epoch 8/100 - Training accuracy: 0.99231, Validation accuracy: 0.99256, Loss: 0.02315861
Epoch 9/100 - Training accuracy: 0.99281, Validation accuracy: 0.99290, Loss: 0.02109643
Epoch 10/100 - Training accuracy: 0.99333, Validation accuracy: 0.99369, Loss: 0.01958372
Epoch 11/100 - Training accuracy: 0.99331, Validation accuracy: 0.99388,

In [7]:
# =============================================================================
# Phenotype data preprocessing
# =============================================================================

# Preprocess phenotype data
genotype_encoded, phenotype_normalized, phenotype_scaler = preprocess_phenotype_data(
    phenotype_data, mask_data_copy, phenotype_column=2
)

Phenotype column name: AL
After normalization - Mean: -0.0000, Std: 1.0000


In [8]:
# =============================================================================
# 10-fold cross-validation
# =============================================================================

# Execute 10-fold cross-validation
all_best_correlations = cross_validation_phenotype_prediction(
    genotype_encoded, 
    phenotype_normalized, 
    phenotype_scaler,
    device,config=config
)

print("\n===== Cross-validation completed =====")
print("Best correlations for each fold:", np.round(all_best_correlations, 4))
print(f"Mean correlation: {np.mean(all_best_correlations):.4f} ± {np.std(all_best_correlations):.4f}")


🔄 Fold 1: Using random initialization (pre-training disabled)
Correlation coefficient: 0.5562
✅ Epoch 1: New best correlation = 0.5562
Correlation coefficient: 0.4859
Correlation coefficient: 0.5132
Correlation coefficient: 0.5083
Correlation coefficient: 0.5173
Correlation coefficient: 0.5175
Correlation coefficient: 0.5332
Correlation coefficient: 0.5210
Correlation coefficient: 0.5279
Correlation coefficient: 0.5206
Correlation coefficient: 0.5261
Correlation coefficient: 0.5251
Correlation coefficient: 0.5209
Correlation coefficient: 0.5338
Correlation coefficient: 0.5248
Correlation coefficient: 0.5171
Correlation coefficient: 0.5268
Correlation coefficient: 0.5194
Correlation coefficient: 0.5279
Correlation coefficient: 0.5304
Correlation coefficient: 0.5246
⏹️  Epoch 21 early stopping (no improvement for 20 epochs)
🏁 Fold 1 best correlation: 0.5562

🔄 Fold 2: Using random initialization (pre-training disabled)
Correlation coefficient: 0.7601
✅ Epoch 1: New best correlation = 0.

In [9]:
# =============================================================================
# Results statistics and summary
# =============================================================================

def calculate_runtime_summary(start_time):
    """Calculate runtime statistics"""
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Total runtime: {elapsed_time:.2f} seconds")
    return elapsed_time

# Calculate runtime
total_runtime = calculate_runtime_summary(start_time)

Total runtime: 720.95 seconds


In [10]:
# =============================================================================
# Final results output
# =============================================================================

def print_final_results(missing_ratio, correlations):
    """Print final results"""
    print(f"Missing ratio: {missing_ratio}")
    mean_correlation = np.mean(correlations)
    std_correlation = np.std(correlations)
    
    print(f"Mean correlation: {mean_correlation:.4f}")
    print(f"Standard deviation: {std_correlation:.4f}")
    print(f"Best correlation: {np.max(correlations):.4f}")
    print(f"Worst correlation: {np.min(correlations):.4f}")
    
    return mean_correlation

# Output final results
final_mean_correlation = print_final_results(config.MISSING_RATIO, all_best_correlations)

Missing ratio: 0.0
Mean correlation: 0.6874
Standard deviation: 0.0705
Best correlation: 0.7659
Worst correlation: 0.5562
