In [1]:
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

## Configuration & Simulation Constants

In [2]:
# ==========================================
# Configuration & Hyperparameters
# ==========================================
# Reproducibility: the simulation draws from torch's global RNG (torch.randn,
# nn.Linear initialisation, Tensor.uniform_), so seed it once up front.
SEED = 42
torch.manual_seed(SEED)

NUM_SITES = 10                  # Number of sites; Stage 3 is run once per site
NUM_PATIENTS = 200              # Total patients, split evenly across the sites
TOTAL_GENES = 5000              # Raw gene count before grouping into pathways
NUM_PATHWAYS = 250              # Pathway-level features produced by the frozen base
TRUE_PATHWAYS = 25              # Signal pathways (the first 25 feature columns)

# Derived instead of hard-coded so it cannot drift from the two counts above.
if TOTAL_GENES % NUM_PATHWAYS != 0:
    raise ValueError("TOTAL_GENES must be an exact multiple of NUM_PATHWAYS")
GENES_PER_PATHWAY = TOTAL_GENES // NUM_PATHWAYS  # 5000 / 250 = 20 (25 true pathways -> 500 genes)

# Hyperparameters for Stage 3
LAMBDA_VAL = 0.08               # L1 penalty strength; slightly higher because 225 of 250 pathways (90%) are noise
EPOCHS = 100                    # Optimizer steps per site

## Data Generation (Simulation)

In [3]:
# ==========================================
# Helper: Synthetic Data Generation
# ==========================================
def generate_genomic_site_data(n_samples, num_pathways=None, true_pathways=None):
    """
    Simulate one site's output of the 'Frozen Base' (f_base): pathway-level
    features plus a binary outcome.

    The base is assumed to have already turned raw genes into pathway
    features, so the pathway features are generated directly.

    Args:
        n_samples: Number of patients (rows) to simulate.
        num_pathways: Total number of pathway features (columns). Defaults to
            the module-level NUM_PATHWAYS.
        true_pathways: Number of leading columns that carry signal. Defaults
            to the module-level TRUE_PATHWAYS.

    Returns:
        Tuple (X_pathways, y):
            X_pathways: float tensor of shape (n_samples, num_pathways); the
                first `true_pathways` columns are signal, the rest are noise.
            y: float tensor of shape (n_samples, 1) holding 0.0 / 1.0 labels.

    Raises:
        ValueError: if `true_pathways` is negative or exceeds `num_pathways`.

    Note:
        A fresh random weight vector for the outcome is drawn on every call,
        so every call (every site) has its own outcome model.
    """
    if num_pathways is None:
        num_pathways = NUM_PATHWAYS
    if true_pathways is None:
        true_pathways = TRUE_PATHWAYS
    if not 0 <= true_pathways <= num_pathways:
        raise ValueError(
            f"true_pathways must be in [0, {num_pathways}], got {true_pathways}"
        )

    # 1. Latent pathway activities.
    # Signal pathways: scaled by 2.0 (higher variance) and used to build the outcome.
    X_signal = torch.randn(n_samples, true_pathways) * 2.0

    # Noise pathways: standard normal, never used when building the outcome.
    X_noise = torch.randn(n_samples, num_pathways - true_pathways)

    # Signal columns first, noise columns after: the "learned features" z_m.
    X_pathways = torch.cat([X_signal, X_noise], dim=1)

    # 2. Binary outcome from a linear model on the signal pathways only.
    true_weights = torch.randn(true_pathways, 1)
    logits = X_signal @ true_weights
    y = (torch.sigmoid(logits) > 0.5).float()

    return X_pathways, y

## Trans-VFL Algorithm Implementation

In [4]:
# ==========================================
# 2. Trans-VFL Model (Pathway Level)
# ==========================================
class GeneSelectionModel(nn.Module):
    """Per-pathway gate followed by a linear embedding head.

    Every input column (pathway) is multiplied by its own learnable scalar,
    and the gated features are then projected to a `d_embed`-dimensional
    embedding.
    """

    def __init__(self, num_pathways, d_embed=16):
        super().__init__()

        # Selection layer: one learnable scale per pathway, initialised to 1
        # so that no pathway is suppressed at the start.
        self.selection_weights = nn.Parameter(torch.ones(num_pathways))

        # Embedding head: maps the gated pathways to the shared embedding.
        self.embedding_head = nn.Linear(num_pathways, d_embed)

    def forward(self, x):
        """Scale each pathway column by its selection weight, then embed."""
        # Element-wise product broadcasts the weight vector over the rows of x,
        # i.e. each pathway is weighted as a whole.
        gated = torch.mul(x, self.selection_weights)
        return self.embedding_head(gated)

In [5]:
# ==========================================
# 3. Stage 3 Algorithm: Distillation + Lasso
# ==========================================
def run_stage_3_genomics(site_id, X, initial_model, epochs=None, lambda_val=None,
                         true_pathways=None, prune_threshold=0.005,
                         keep_threshold=0.01):
    """
    Stage 3 for one site: distil the teacher's embedding into a student while
    an L1 penalty and a hard threshold sparsify the per-pathway weights.

    Args:
        site_id: Identifier used only in the progress message.
        X: Pathway feature tensor fed to both teacher and student.
        initial_model: Teacher model. It must expose a 1-D `selection_weights`
            parameter. It is switched to eval mode; its parameters are not
            modified.
        epochs: Number of optimizer steps. Defaults to the module-level EPOCHS.
        lambda_val: Weight of the L1 penalty. Defaults to the module-level
            LAMBDA_VAL.
        true_pathways: Number of leading pathways counted as true signal.
            Defaults to the module-level TRUE_PATHWAYS.
        prune_threshold: After every step, selection weights whose magnitude
            is not above this value are set to exactly zero.
        keep_threshold: A pathway counts as "kept" when its final weight
            magnitude is above this value.

    Returns:
        (tp_pathways, fp_pathways): Python ints with the number of kept
        pathways whose index is below / at-or-above `true_pathways`.
    """
    if epochs is None:
        epochs = EPOCHS
    if lambda_val is None:
        lambda_val = LAMBDA_VAL
    if true_pathways is None:
        true_pathways = TRUE_PATHWAYS

    print(f"Site {site_id}: optimizing genomic features...")

    # Student (active): an independent copy of the teacher, so it always has
    # the teacher's architecture and starting weights.
    student_model = copy.deepcopy(initial_model)
    student_model.train()

    # Teacher (frozen target from Stage 1). Its output does not change during
    # the loop, so it is computed once instead of once per epoch.
    initial_model.eval()
    with torch.no_grad():
        target_embed = initial_model(X)

    optimizer = optim.Adam(student_model.parameters(), lr=0.005)
    distillation_loss = nn.MSELoss()

    for _ in range(epochs):
        optimizer.zero_grad()

        student_embed = student_model(X)

        # Loss 1: distillation (match the pre-trained embedding).
        mse_loss = distillation_loss(student_embed, target_embed)

        # Loss 2: L1 norm of the per-pathway scale vector. Because each scalar
        # gates a whole pathway, shrinking it removes that pathway as a group.
        lasso_penalty = student_model.selection_weights.abs().sum()

        loss = mse_loss + (lambda_val * lasso_penalty)
        loss.backward()
        optimizer.step()

        # Hard thresholding (not a soft-threshold / proximal step): weights
        # with magnitude <= prune_threshold are zeroed, the others are left
        # untouched.
        with torch.no_grad():
            weights = student_model.selection_weights
            survivors = weights.abs() > prune_threshold
            weights.mul_(survivors.to(weights.dtype))

    # Evaluation: indices [0, true_pathways) are signal, the rest are noise.
    final_weights = student_model.selection_weights.detach().abs()
    kept = final_weights > keep_threshold

    tp_pathways = int(kept[:true_pathways].sum().item())
    fp_pathways = int(kept[true_pathways:].sum().item())

    return tp_pathways, fp_pathways

In [6]:
# ==========================================
# Main Simulation
# ==========================================
if __name__ == "__main__":
    # Banner: cohort size and ground truth of the simulation.
    print(f"[Simulation 3.2] Genomic Data: {NUM_PATIENTS} Patients, {NUM_SITES} Sites")
    print(f"Total Genes: {TOTAL_GENES} | Mapped to Pathways: {NUM_PATHWAYS}")
    print(f"Ground Truth: {TRUE_PATHWAYS} Significant Pathways ({TRUE_PATHWAYS*GENES_PER_PATHWAY} Genes)")
    print("="*60)

    # The cohort is split evenly, which leaves very few patients per site.
    samples_per_site = NUM_PATIENTS // NUM_SITES

    # Gene-level totals accumulated over all sites.
    total_tp_genes, total_fp_genes = 0, 0

    for site in range(NUM_SITES):
        # Step 1: pathway features as emitted by the frozen base.
        # (y_site is generated alongside but is not used below.)
        X_site, y_site = generate_genomic_site_data(samples_per_site)

        # Step 2: stand-in for the pre-trained teacher. Signal pathways get
        # scales drawn from uniform_(0.8, 1.2); noise pathways get small but
        # non-zero scales from uniform_(0.0, 0.2), i.e. not yet pruned.
        initial_model = GeneSelectionModel(NUM_PATHWAYS)
        with torch.no_grad():
            teacher_weights = initial_model.selection_weights
            teacher_weights[:TRUE_PATHWAYS].uniform_(0.8, 1.2)
            teacher_weights[TRUE_PATHWAYS:].uniform_(0.0, 0.2)

        # Step 3: Trans-VFL Stage 3 on this site's data.
        tp_path, fp_path = run_stage_3_genomics(site, X_site, initial_model)

        # Report at gene level: every pathway stands for GENES_PER_PATHWAY genes.
        tp_genes, fp_genes = (
            n_pathways * GENES_PER_PATHWAY for n_pathways in (tp_path, fp_path)
        )

        total_tp_genes = total_tp_genes + tp_genes
        total_fp_genes = total_fp_genes + fp_genes

        print(f"   > Site {site} Result: Kept {tp_genes + fp_genes} Genes. (TP: {tp_genes}/{TRUE_PATHWAYS*GENES_PER_PATHWAY})")

    # Per-site averages of the gene-level totals.
    avg_total_selected = (total_tp_genes + total_fp_genes) / NUM_SITES
    avg_tp = total_tp_genes / NUM_SITES

    print("="*60)
    print(f"FINAL RESULTS (Averaged across {NUM_SITES} sites):")
    print(f"Total Genes Selected: {int(avg_total_selected)} / {TOTAL_GENES}")
    print(f"True Positive Genes:  {int(avg_tp)} / {TRUE_PATHWAYS*GENES_PER_PATHWAY} (Target: 500)")
    print(f"False Positive Genes: {int(avg_total_selected - avg_tp)} / {TOTAL_GENES - (TRUE_PATHWAYS*GENES_PER_PATHWAY)}")
    print("="*60)

[Simulation 3.2] Genomic Data: 200 Patients, 10 Sites
Total Genes: 5000 | Mapped to Pathways: 250
Ground Truth: 25 Significant Pathways (500 Genes)
Site 0: optimizing genomic features...
   > Site 0 Result: Kept 500 Genes. (TP: 500/500)
Site 1: optimizing genomic features...
   > Site 1 Result: Kept 500 Genes. (TP: 500/500)
Site 2: optimizing genomic features...
   > Site 2 Result: Kept 500 Genes. (TP: 500/500)
Site 3: optimizing genomic features...
   > Site 3 Result: Kept 500 Genes. (TP: 500/500)
Site 4: optimizing genomic features...
   > Site 4 Result: Kept 500 Genes. (TP: 500/500)
Site 5: optimizing genomic features...
   > Site 5 Result: Kept 500 Genes. (TP: 500/500)
Site 6: optimizing genomic features...
   > Site 6 Result: Kept 500 Genes. (TP: 500/500)
Site 7: optimizing genomic features...
   > Site 7 Result: Kept 500 Genes. (TP: 500/500)
Site 8: optimizing genomic features...
   > Site 8 Result: Kept 500 Genes. (TP: 500/500)
Site 9: optimizing genomic features...
   > Site 9 