# Gut Microbiome Disease Risk Prediction Model
Implementation of an encoder-decoder VAE that predicts per-disease risk scores from gut microbiome abundance data

In [None]:
# Install required packages
!pip install pandas numpy scikit-learn torch torchvision tqdm

import pandas as pd
import numpy as np
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
# Seeds NumPy's global RNG (used later via np.random.*).
np.random.seed(42)
import torch
# Seeds PyTorch's global RNG (weight initialisation, dropout, shuffling, sampling).
torch.manual_seed(42)

print("✓ Dependencies loaded successfully")

In [None]:
# Data Loading Class
class DataLoader:
    """Load per-disease microbiome CSV files and stack them into one table.

    Each entry of ``disease_files`` maps a short disease key to a CSV file
    name inside ``data_dir``. Loaded frames get two extra columns:
    ``disease_label`` (the disease key) and ``is_healthy`` (always 0).

    Attributes:
        data_dir: Directory that is searched for the CSV files.
        disease_files: Mapping of disease key -> CSV file name.
        data: Mapping of disease key -> loaded DataFrame, or None when the
            file was missing. Filled by ``load_all_files``.
        combined_data: Row-wise concatenation of all loaded frames, or None
            until ``load_all_files`` has succeeded.
    """

    def __init__(self, data_dir="./"):
        """Remember the data directory and the expected file names.

        Args:
            data_dir: Directory containing the per-disease CSV files.
        """
        self.data_dir = Path(data_dir)
        self.disease_files = {
            'adenoma': 'Adenoma.csv',
            'alzheimer': 'Alzheimer_Disease.csv',
            'anemia_sickle': 'Anemia_Sickle_Cell.csv',
            'anorexia': 'Anorexia.csv',
            'arthritis_juvenile': 'Arthritis_Juvenile.csv',
            'arthritis_reactive': 'Arthritis_Reactive.csv',
            'arthritis_rheumatoid': 'Arthritis_Rheumatoid.csv',
            'asthma': 'Asthma.csv',
            'atherosclerosis': 'Atherosclerosis.csv',
            'adhd': 'Attention_Deficit_Disorder_with_Hyperactivity.csv'
        }
        self.data = {}
        self.combined_data = None

    def load_single_file(self, disease_name):
        """Read the CSV for one disease and tag its rows.

        Args:
            disease_name: A key of ``disease_files``.

        Returns:
            The loaded DataFrame with ``disease_label`` and ``is_healthy``
            columns added, or None when the file does not exist.

        Raises:
            KeyError: If ``disease_name`` is not a key of ``disease_files``.
        """
        file_path = self.data_dir / self.disease_files[disease_name]
        if not file_path.exists():
            print(f"✗ File not found: {file_path}")
            return None

        df = pd.read_csv(file_path)
        df['disease_label'] = disease_name
        df['is_healthy'] = 0
        print(f"✓ Loaded {disease_name}: {len(df)} samples")
        return df

    def load_all_files(self):
        """Load every configured file and concatenate the ones that exist.

        Missing files are recorded as None in ``self.data`` and skipped.

        Returns:
            The concatenated DataFrame (also stored in ``combined_data``).

        Raises:
            ValueError: If none of the configured files could be found.
        """
        for disease_name in self.disease_files:
            self.data[disease_name] = self.load_single_file(disease_name)

        valid_data = [df for df in self.data.values() if df is not None]
        if not valid_data:
            # pd.concat([]) would raise an opaque "No objects to concatenate";
            # keep the ValueError type but say what is actually wrong.
            expected = ", ".join(self.disease_files.values())
            raise ValueError(
                f"No data files found in '{self.data_dir}'. "
                f"Expected at least one of: {expected}"
            )

        self.combined_data = pd.concat(valid_data, ignore_index=True)
        print(f"\n✓ Total samples loaded: {len(self.combined_data)}")
        return self.combined_data

# Load data
# DataLoader() uses its default data_dir ("./"), so the CSV files are looked
# up relative to the current working directory. `combined_data` holds all
# loaded rows stacked into one DataFrame.
loader = DataLoader()
combined_data = loader.load_all_files()

In [None]:
# Create abundance matrix
# Rows are run_ids, columns are species; each cell is the mean relative
# abundance over that (run, species) pair, with 0 where no row exists.
print("Creating abundance matrix...")
abundance_matrix = pd.pivot_table(
    combined_data,
    index='run_id',
    columns='scientific_name',
    values='relative_abundance',
    aggfunc='mean',
    fill_value=0,
)

# Keep only species that are present (abundance > 0) in at least 5% of samples
prevalence = abundance_matrix.gt(0).mean()
keep_species = prevalence.index[(prevalence >= 0.05).to_numpy()]
abundance_matrix = abundance_matrix.loc[:, keep_species]

print(f"✓ Abundance matrix shape: {abundance_matrix.shape}")

# Create sample metadata: one row per run_id, aligned to the abundance matrix rows
metadata_aggregations = {
    'disease_label': 'first',
    'is_healthy': 'first',
    'host_age': 'mean',
    'sex': 'first',
    'country': 'first',
}
sample_metadata = combined_data.groupby('run_id').agg(metadata_aggregations)
sample_metadata = sample_metadata.loc[abundance_matrix.index]

print(f"✓ Sample metadata shape: {sample_metadata.shape}")

In [None]:
# Apply CLR transformation
# Centred log-ratio: log of each value divided by its row's geometric mean.
# The pseudocount keeps log() finite for zero abundances.
print("Applying CLR transformation...")
pseudocount = 1e-6
data_pseudo = abundance_matrix + pseudocount
log_abundance = np.log(data_pseudo)
geometric_mean = np.exp(log_abundance.mean(axis=1))
abundance_matrix_clr = np.log(data_pseudo.div(geometric_mean, axis=0))

# Handle missing values
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.preprocessing import StandardScaler


def _frame_like(values, template):
    """Wrap a raw array in a DataFrame that reuses template's row/column labels."""
    return pd.DataFrame(values, index=template.index, columns=template.columns)


# NOTE(review): the imputer and scaler are fitted on the full matrix in this
# cell — confirm this is intended if the data is later split for evaluation.
imputer = IterativeImputer(random_state=42, max_iter=5)
abundance_imputed = _frame_like(
    imputer.fit_transform(abundance_matrix_clr), abundance_matrix_clr
)

# Scale features to zero mean / unit variance per column
scaler = StandardScaler()
abundance_scaled = _frame_like(
    scaler.fit_transform(abundance_imputed), abundance_imputed
)

print(f"✓ Processed data shape: {abundance_scaled.shape}")

In [None]:
# Feature selection
from sklearn.feature_selection import SelectKBest, f_classif

# Create binary disease labels (one indicator column per disease)
disease_labels = pd.get_dummies(sample_metadata['disease_label'])

# Number of features to keep; never more than the columns that exist.
n_top_features = min(100, abundance_scaled.shape[1])

# Select top features based on average ANOVA F-scores across diseases
scores_per_label = []
for disease_col in disease_labels.columns:
    selector = SelectKBest(f_classif, k=n_top_features)
    selector.fit(abundance_scaled, disease_labels[disease_col])
    # f_classif yields NaN for constant features. NaN would propagate through
    # the mean below and np.argsort places NaN last, i.e. such features would
    # be picked as the "best" ones. Score them as 0 instead.
    scores_per_label.append(np.nan_to_num(selector.scores_, nan=0.0))

avg_scores = np.mean(scores_per_label, axis=0)
# argsort is ascending, so the tail holds the highest-scoring features.
top_k_indices = np.argsort(avg_scores)[-n_top_features:]
selected_features = abundance_scaled.columns[top_k_indices]
X_selected = abundance_scaled[selected_features]

print(f"✓ Selected {len(selected_features)} features")
print(f"✓ Final feature matrix shape: {X_selected.shape}")

In [None]:
# Create risk scores (continuous targets)
# For each sample, the disease it is labelled with gets a high score drawn
# from U(0.8, 1.0); every other disease gets a low score drawn from U(0, 0.2).
#
# The scores are built on a private NumPy copy instead of mutating the array
# returned by `DataFrame.loc[idx].values`: that array is read-only when pandas
# Copy-on-Write is active, and row-wise `.loc` write-back is slow.
score_matrix = disease_labels.to_numpy(dtype=float, copy=True)

for row in score_matrix:  # each `row` is a writable view into score_matrix
    primary_cols = np.flatnonzero(row == 1)
    other_cols = np.flatnonzero(row == 0)

    # Draw order is kept as: primary disease first, then the remaining
    # diseases left to right.
    for col in primary_cols:
        row[col] = np.random.uniform(0.8, 1.0)
    for col in other_cols:
        row[col] = np.random.uniform(0, 0.2)

risk_scores = pd.DataFrame(
    score_matrix,
    index=disease_labels.index,
    columns=disease_labels.columns,
)

print(f"✓ Risk scores created: {risk_scores.shape}")
print(f"✓ Risk score range: [{risk_scores.min().min():.3f}, {risk_scores.max().max():.3f}]")

In [None]:
# Model Architecture
import torch.nn as nn
import torch.nn.functional as F

class MicrobiomeVAE(nn.Module):
    """Variational encoder whose decoder emits one risk score per disease.

    The encoder maps the input features to a 128-unit hidden vector, from
    which the latent mean and log-variance are predicted. A latent sample is
    decoded into ``n_diseases`` sigmoid outputs in the range [0, 1].

    Args:
        input_dim: Number of input features.
        latent_dim: Size of the latent vector.
        n_diseases: Number of risk scores produced per sample.
    """

    def __init__(self, input_dim, latent_dim=32, n_diseases=10):
        super().__init__()

        def dense_block(n_in, n_out):
            # Linear -> BatchNorm -> ReLU -> Dropout(0.2), as a list of layers
            return [
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.ReLU(),
                nn.Dropout(0.2),
            ]

        # Encoder: input_dim -> 256 -> 128
        self.encoder = nn.Sequential(
            *dense_block(input_dim, 256),
            *dense_block(256, 128),
        )

        # Heads producing the latent Gaussian parameters
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)

        # Decoder: latent_dim -> 128 -> 256 -> n_diseases, squashed to [0, 1]
        self.decoder = nn.Sequential(
            *dense_block(latent_dim, 128),
            *dense_block(128, 256),
            nn.Linear(256, n_diseases),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps with eps ~ N(0, I)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Return ``(risk_scores, mu, logvar, z)`` for a batch of inputs."""
        hidden = self.encoder(x)
        mu, logvar = self.fc_mu(hidden), self.fc_logvar(hidden)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar, z

# Initialize model
# Layer sizes follow the data: one input per selected feature, one output
# per risk-score column.
n_diseases = len(risk_scores.columns)
input_dim = len(X_selected.columns)
model = MicrobiomeVAE(input_dim=input_dim, latent_dim=32, n_diseases=n_diseases)

print(
    f"✓ Model initialized:",
    f"  Input dimension: {input_dim}",
    f"  Output diseases: {n_diseases}",
    sep="\n",
)

In [None]:
# Data preparation for training
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader as TorchDataLoader

# Split data: 20% test first, then 25% of the remainder as validation
# (i.e. 60/20/20 overall), stratified by disease label each time.
stratify_labels = sample_metadata['disease_label'].values

X_temp, X_test, y_temp, y_test, meta_temp, meta_test = train_test_split(
    X_selected, risk_scores, sample_metadata,
    test_size=0.2, random_state=42, stratify=stratify_labels
)

stratify_temp = meta_temp['disease_label'].values
X_train, X_val, y_train, y_val, meta_train, meta_val = train_test_split(
    X_temp, y_temp, meta_temp,
    test_size=0.25, random_state=42, stratify=stratify_temp
)

print(f"✓ Train set: {X_train.shape[0]} samples")
print(f"✓ Val set: {X_val.shape[0]} samples")
print(f"✓ Test set: {X_test.shape[0]} samples")


def _to_dataset(features, targets):
    """Pack a feature frame and a target frame into a float32 TensorDataset."""
    return TensorDataset(
        torch.FloatTensor(features.values),
        torch.FloatTensor(targets.values)
    )


# Create data loaders
train_dataset = _to_dataset(X_train, y_train)
val_dataset = _to_dataset(X_val, y_val)
test_dataset = _to_dataset(X_test, y_test)

BATCH_SIZE = 64

# MicrobiomeVAE contains BatchNorm1d layers, which raise in training mode when
# a batch holds a single sample. Drop the trailing batch only in the one case
# where it would contain exactly one sample; otherwise keep every sample.
drop_trailing_batch = (
    len(train_dataset) > BATCH_SIZE and len(train_dataset) % BATCH_SIZE == 1
)

train_loader = TorchDataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    drop_last=drop_trailing_batch
)
val_loader = TorchDataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = TorchDataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Loss function
class VAELoss(nn.Module):
    """Risk-score MSE plus a beta-weighted KL term.

    ``forward`` returns ``(total, reconstruction, kl)`` where
    ``total = reconstruction + beta * kl``.
    """

    def __init__(self, beta=0.5):
        super().__init__()
        self.beta = beta

    def forward(self, risk_pred, risk_true, mu, logvar):
        batch_size = risk_pred.shape[0]

        # Reconstruction term: MSE averaged over every element
        recon_loss = F.mse_loss(risk_pred, risk_true, reduction='mean')

        # KL(q(z|x) || N(0, I)): summed over all elements, divided by batch size
        kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
        kl_loss = (-0.5 * torch.sum(kl_terms)) / batch_size

        return recon_loss + self.beta * kl_loss, recon_loss, kl_loss

# Prefer the GPU when one is available, otherwise stay on the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = model.to(device)
loss_fn = VAELoss(beta=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

print(f"✓ Training setup complete on device: {device}")

In [None]:
# Training loop
from tqdm import tqdm

def train_epoch(model, train_loader, optimizer, loss_fn, device):
    """Run one optimisation pass over ``train_loader``.

    For every batch: forward pass, loss, backward pass, gradient-norm
    clipping at 1.0, optimiser step.

    Returns:
        Tuple ``(loss, recon, kl)``: per-batch values averaged over the
        number of batches.
    """
    model.train()
    running = [0, 0, 0]  # total, reconstruction, KL

    for features, targets in train_loader:
        features = features.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        risk_pred, mu, logvar, _ = model(features)
        loss, recon_loss, kl_loss = loss_fn(risk_pred, targets, mu, logvar)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        for slot, term in enumerate((loss, recon_loss, kl_loss)):
            running[slot] += term.item()

    n_batches = len(train_loader)
    return tuple(total / n_batches for total in running)

def validate(model, val_loader, loss_fn, device):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    The model is switched to eval mode and gradients are disabled.

    Returns:
        Tuple ``(loss, recon, kl)``: per-batch values averaged over the
        number of batches.
    """
    model.eval()
    running = [0, 0, 0]  # total, reconstruction, KL

    with torch.no_grad():
        for features, targets in val_loader:
            features = features.to(device)
            targets = targets.to(device)

            risk_pred, mu, logvar, _ = model(features)
            loss, recon_loss, kl_loss = loss_fn(risk_pred, targets, mu, logvar)

            for slot, term in enumerate((loss, recon_loss, kl_loss)):
                running[slot] += term.item()

    n_batches = len(val_loader)
    return tuple(total / n_batches for total in running)

# Training loop
# Train for up to `epochs` epochs; stop early once the validation loss has
# failed to improve for `patience` consecutive epochs. The weights of the
# best epoch are checkpointed to 'best_model.pth' and restored at the end.
epochs = 30
patience = 8
best_val_loss = float('inf')
patience_counter = 0

print("Starting training...")
print("Epoch | Train Loss | Val Loss   | Train Recon| Val Recon  | Train KL   | Val KL")
print("-" * 80)

for epoch in range(epochs):
    train_loss, train_recon, train_kl = train_epoch(model, train_loader, optimizer, loss_fn, device)
    val_loss, val_recon, val_kl = validate(model, val_loader, loss_fn, device)

    print(f"{epoch+1:5d} | {train_loss:10.4f} | {val_loss:10.4f} | {train_recon:10.4f} | {val_recon:10.4f} | {train_kl:10.4f} | {val_kl:10.4f}")

    # Early stopping bookkeeping
    if val_loss < best_val_loss:
        # New best epoch: checkpoint it and reset the patience window
        best_val_loss, patience_counter = val_loss, 0
        torch.save(model.state_dict(), 'best_model.pth')
        continue

    patience_counter += 1
    if patience_counter >= patience:
        print(f"\nEarly stopping at epoch {epoch+1}")
        break

# Load best model
best_state = torch.load('best_model.pth')
model.load_state_dict(best_state)
print("\n✓ Training completed and best model loaded")

In [None]:
# Model evaluation
from sklearn.metrics import roc_auc_score, average_precision_score, mean_squared_error, classification_report

def evaluate_model(model, test_loader, device):
    """Collect model predictions and targets for every batch of a loader.

    The model is put in eval mode and run without gradients.

    Returns:
        Tuple ``(predictions, actuals)`` of NumPy arrays, each the vertical
        stack of the per-batch arrays in loader order.
    """
    model.eval()
    pred_chunks, target_chunks = [], []

    with torch.no_grad():
        for features, targets in test_loader:
            risk_pred, _mu, _logvar, _z = model(features.to(device))
            pred_chunks.append(risk_pred.cpu().numpy())
            target_chunks.append(targets.numpy())

    return np.vstack(pred_chunks), np.vstack(target_chunks)

# Get predictions
test_pred, test_actual = evaluate_model(model, test_loader, device)

# Imported once here instead of on every iteration of the per-disease loop.
from sklearn.metrics import f1_score

# Calculate metrics
print("MODEL EVALUATION RESULTS")
print("=" * 50)

# Overall MSE
overall_mse = mean_squared_error(test_actual, test_pred)
print(f"\nOverall MSE: {overall_mse:.4f}")

# Per-disease metrics
disease_names = risk_scores.columns.tolist()
print("\nPer-Disease Performance:")
print("-" * 60)
print(f"{'Disease':<25} {'AUC':>8} {'AP':>8} {'MSE':>8} {'F1':>8}")
print("-" * 60)

for i, disease in enumerate(disease_names):
    # Convert to binary for classification metrics (threshold = 0.5)
    y_true_binary = (test_actual[:, i] > 0.5).astype(int)
    y_pred_binary = (test_pred[:, i] > 0.5).astype(int)

    try:
        auc = roc_auc_score(y_true_binary, test_pred[:, i])
        ap = average_precision_score(y_true_binary, test_pred[:, i])
        f1 = f1_score(y_true_binary, y_pred_binary)
    except ValueError:
        # Raised e.g. when the test split holds only one class for this
        # disease, so a ranking metric is undefined. Report NaN for the row.
        # Anything else (typos, shape bugs, Ctrl-C) is allowed to surface
        # instead of being silently swallowed by a bare `except:`.
        auc = np.nan
        ap = np.nan
        f1 = np.nan

    mse_disease = mean_squared_error(test_actual[:, i], test_pred[:, i])

    print(f"{disease:<25} {auc:>8.3f} {ap:>8.3f} {mse_disease:>8.4f} {f1:>8.3f}")

# Summary statistics
print("\nSUMMARY STATISTICS")
print("=" * 30)
print(f"Average prediction: {test_pred.mean():.3f}")
print(f"Prediction std: {test_pred.std():.3f}")
print(f"Prediction range: [{test_pred.min():.3f}, {test_pred.max():.3f}]")

# Calculate accuracy for binary classification (threshold = 0.5),
# averaged over every (sample, disease) cell.
binary_pred = (test_pred > 0.5).astype(int)
binary_actual = (test_actual > 0.5).astype(int)
accuracy = (binary_pred == binary_actual).mean()
print(f"Overall binary accuracy: {accuracy:.3f}")

print("\n✓ Evaluation complete")

In [None]:
# Save model for deployment
import pickle

# Read the latent size back from the model instead of repeating a literal, so
# the saved config and the printed summary cannot drift from the real network.
latent_dim = model.fc_mu.out_features
model_config = {
    'input_dim': input_dim,
    'latent_dim': latent_dim,
    'n_diseases': n_diseases
}

# Save model checkpoint
torch.save({
    'model_state_dict': model.state_dict(),
    'model_config': model_config
}, 'microbiome_model.pth')

# Save preprocessing artifacts
with open('preprocessing_artifacts.pkl', 'wb') as f:
    pickle.dump({
        'scaler': scaler,
        # The scaler was fitted on every column of abundance_scaled, not only
        # on the selected ones; record that column order so the scaler can be
        # re-applied before sub-setting to `selected_features`.
        'scaler_features': abundance_scaled.columns.tolist(),
        'selected_features': selected_features.tolist(),
        'disease_names': disease_names
    }, f)

print("✓ Model and preprocessing artifacts saved")
print("  - microbiome_model.pth")
print("  - preprocessing_artifacts.pkl")

# Display final model architecture summary
print("\nFINAL MODEL SUMMARY")
print("=" * 30)
print(f"Input features: {input_dim}")
print(f"Latent dimension: {latent_dim}")
print(f"Output diseases: {n_diseases}")
# Fixed: the generator previously read `for p.numel() in ...`, which is a
# SyntaxError and prevented the whole cell from running.
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")