In [None]:
import os
import pandas as pd
import numpy as np
from sklearn.metrics import r2_score
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
import scipy.stats as stats
import joblib
from mambapy.mamba import Mamba, MambaConfig
import torch.nn as nn

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Number of samples per inference mini-batch.
BATCH_SIZE = 250

# Feature names the model was trained on (first column of train_features.csv).
# Their order is the column order fed to the model; their count is its input width.
train_features = pd.read_csv('train_features.csv').iloc[:, 0].to_list()
input_dim = len(train_features)

class MambaRegressor(nn.Module):
    """Predict one scalar per sample from a flat feature vector via a Mamba stack.

    Pipeline: an MLP maps each input vector to ``D * L`` values, which are
    reshaped into a pseudo-sequence of ``L`` tokens with ``D`` channels. A
    2-layer Mamba processes that sequence. The tokens are then max-pooled, and
    a small MLP head produces the output value.

    Args:
        input_dim: Number of input features per sample.
        L: Number of tokens in the pseudo-sequence.
        D: Width of each token (Mamba ``d_model``).

    NOTE: ``nn.Sequential`` registers sub-layers by index. The attribute names
    and the layer order here therefore define the state_dict keys, and they
    must stay in sync with any checkpoint loaded into this class.
    """
    def __init__(self, input_dim, L, D):
        super().__init__()
        # Expects x of shape (batch, input_dim): BatchNorm1d treats dim 1 as
        # channels, and Unflatten splits dim 1 (= D * L) into (D, L).
        self.feature_extractor = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.4),  
            
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.GELU(),
            nn.Dropout(0.3),  
            
            nn.Linear(256, D * L),
            nn.Unflatten(1, (D, L))  # (batch, D * L) -> (batch, D, L)
        )
        
        # Shape-affecting settings (d_model, n_layers, d_state, d_conv) must
        # match the checkpoint that is loaded into this model.
        self.mamba = Mamba(MambaConfig(
            d_model=D, 
            n_layers=2,  
            d_state=12,   
            d_conv=4,    
            pscan=True,  # NOTE(review): presumably selects mambapy's parallel-scan path — confirm
        ))
        
        self.fc = nn.Sequential(
            nn.Linear(D, 32),  
            nn.ReLU(),
            # A Dropout(0.2) here is disabled. Inserting any layer would shift the
            # final Linear's index (fc.2) and thus change its state_dict key.
            nn.Linear(32, 1)
        )

    def forward(self, x):
        """Map x of shape (batch, input_dim) to predictions of shape (batch,)."""
        x = self.feature_extractor(x)  # (batch, D, L)
        x = x.transpose(1, 2)  # (batch, L, D): tokens along dim 1, channels last
        x = self.mamba(x)  # assumed to preserve (batch, L, D); fc below needs last dim D
        x = x.max(dim=1).values  # max-pool over the L tokens -> (batch, D)
        return self.fc(x).squeeze(-1)  # (batch, 1) -> (batch,)

# --------------------------- Dataset Class ---------------------------
class MyDataset(Dataset):
    """Map-style dataset over in-memory feature rows and their targets.

    Each item is a ``(features_row, label)`` pair. The default DataLoader
    collate function stacks these into batched tensors.

    Args:
        feature_array: Array-like of shape (n_samples, n_features).
        label_array: Sequence of n_samples targets, aligned with the rows.
        dtype: NumPy dtype the features are cast to (default ``np.float32``).

    Raises:
        ValueError: If the number of feature rows differs from the number of labels.
    """
    def __init__(self, feature_array, label_array, dtype=np.float32):
        # Honour the requested dtype (it used to be ignored). ``astype`` copies,
        # so later edits to the caller's array do not leak into the dataset.
        self.features = np.asarray(feature_array).astype(dtype)
        self.labels = label_array
        if len(self.features) != len(self.labels):
            raise ValueError(
                f"feature rows ({len(self.features)}) and labels "
                f"({len(self.labels)}) must have the same length"
            )

    def __getitem__(self, index):
        return self.features[index], self.labels[index]

    def __len__(self):
        return len(self.labels)

# --------------------------- Evaluation Function ---------------------------
def plot_regression_results(y_true, y_pred, disease_labels, title, geo_name):
    """Show a predicted-vs-true scatter plot and a prediction-error histogram.

    Samples are coloured by label: "Healthy" in dodger blue, every other label
    from the tab10 palette. The scatter plot shows a least-squares fit line, the
    identity line, and R²/MAE/MSE in its title. The histogram overlays the
    per-group errors on shared bins. Both figures are shown and then closed.

    Args:
        y_true: 1-D array-like of true ages.
        y_pred: 1-D array-like of predicted ages, aligned with ``y_true``.
        disease_labels: Per-sample labels aligned with ``y_true``. The string
            "Healthy" marks the healthy group.
        title: Title of the scatter plot.
        geo_name: Dataset identifier. Currently unused; kept for caller compatibility.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    errors = y_pred - y_true

    # Compute metrics before opening a figure, so a failure here leaves none open.
    mae = np.mean(np.abs(errors))
    mse = np.mean(errors ** 2)
    r2 = r2_score(y_true, y_pred)

    # Normalise labels to str, so missing values (NaN) become "nan" instead of
    # making ``sorted`` fail on a str/float comparison.
    labels = np.asarray(disease_labels).astype(str)
    disease_types = sorted(set(labels.tolist()) - {"Healthy"})
    palette = plt.cm.tab10.colors
    groups = [("Healthy", labels == "Healthy", 'dodgerblue')]
    groups += [
        (disease, labels == disease, palette[i % len(palette)])
        for i, disease in enumerate(disease_types)
    ]
    # Drop groups without samples so the legends list only real groups.
    groups = [(name, mask, color) for name, mask, color in groups if mask.any()]

    # ---- Predicted vs. true scatter ----
    plt.figure(figsize=(10, 8))
    for name, mask, color in groups:
        # ``color=`` rather than ``c=``: an RGB tuple given as ``c`` is colour-mapped
        # as data values when the group happens to contain 3 or 4 points.
        plt.scatter(y_true[mask], y_pred[mask], color=color, alpha=0.7,
                    edgecolors='w', s=80, label=name)

    # A regression line is undefined without at least two distinct true values
    # (linregress raises when all x values are identical).
    if y_true.size >= 2 and np.ptp(y_true) > 0:
        slope, intercept, _, _, _ = stats.linregress(y_true, y_pred)
        line_x = np.linspace(y_true.min() - 5, y_true.max() + 5, 100)
        plt.plot(line_x, slope * line_x + intercept, 'r--', lw=2,
                 label=f'Fit: y={slope:.2f}x{intercept:+.2f}')

    max_val = max(np.max(y_true), np.max(y_pred)) + 5
    plt.plot([0, max_val], [0, max_val], 'k--', lw=1, alpha=0.5, label='Ideal Prediction')

    plt.title(f'{title}\nR²={r2:.3f}, MAE={mae:.2f}, MSE={mse:.2f}', fontsize=14)
    plt.xlabel('True Age', fontsize=12)
    plt.ylabel('Predicted Age', fontsize=12)
    plt.legend(loc='best', frameon=True, shadow=True)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    # ---- Histogram of prediction errors ----
    plt.figure(figsize=(10, 4))
    # Shared bin edges keep the overlaid group histograms directly comparable.
    bin_edges = np.histogram_bin_edges(errors, bins=30)
    for name, mask, color in groups:
        plt.hist(errors[mask], bins=bin_edges, alpha=0.7, color=color, label=name)

    plt.axvline(x=0, color='k', linestyle='--', alpha=0.5)
    plt.title('Prediction Error Distribution', fontsize=14)
    plt.xlabel('Prediction Error (Predicted - True Age)', fontsize=12)
    plt.ylabel('Frequency', fontsize=12)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    plt.close('all')

def evaluate_dataset(model, X, y, disease_labels, geo_name, scaler_path='scaler_topk.joblib'):
    """Run ``model`` on one cohort, plot the results and return summary metrics.

    Args:
        model: Trained regressor, called on batches of standardized features.
        X: Array of shape (n_samples, n_features), columns in training feature order.
        y: True targets aligned with the rows of ``X``.
        disease_labels: Per-sample labels forwarded to ``plot_regression_results``.
        geo_name: Dataset identifier used in the plot title and error messages.
        scaler_path: joblib file holding the fitted scaler applied to ``X``.

    Returns:
        dict with keys "MAE", "MSE", "R2" and "Pearson".

    Raises:
        ValueError: If ``X`` and ``y`` differ in length, are empty, or the
            scaled features contain NaN.
    """
    y_true = np.asarray(y, dtype=np.float64)
    if len(X) != len(y_true):
        raise ValueError(f"{geo_name}: {len(X)} feature rows but {len(y_true)} targets")
    if len(y_true) == 0:
        raise ValueError(f"{geo_name}: no samples to evaluate")

    # NOTE: joblib.load unpickles the file; only point scaler_path at trusted files.
    scaler = joblib.load(scaler_path)
    X_std = scaler.transform(X)
    # NaN inputs yield NaN predictions, and the metric functions would reject them
    # later with a less helpful message.
    if np.isnan(X_std).any():
        raise ValueError(f"{geo_name}: features contain NaN after scaling; impute missing values first")

    loader = DataLoader(MyDataset(X_std, y_true), batch_size=BATCH_SIZE, shuffle=False)

    # Put normalization/dropout layers (if any) into inference mode.
    model.eval()
    batch_preds = []
    with torch.no_grad():
        for features, _ in loader:
            batch_preds.append(model(features.to(device)).cpu())
    # shuffle=False keeps the row order, so predictions line up with ``y_true``.
    # Metrics use the original targets, not float32 copies from the DataLoader.
    y_pred = torch.cat(batch_preds).numpy().astype(np.float64)

    plot_regression_results(y_true, y_pred, disease_labels, f"Independent Test on {geo_name}", geo_name)

    residuals = y_true - y_pred
    return {
        "MAE": np.mean(np.abs(residuals)),
        "MSE": np.mean(residuals ** 2),
        "R2": r2_score(y_true, y_pred),
        "Pearson": stats.pearsonr(y_true, y_pred)[0],
    }

# --------------------------- Evaluation Pipeline ---------------------------
# Divisors that convert ages recorded in these units to years.
_AGE_UNIT_DIVISORS = {"Month": 12.0, "Week": 365.0 / 7.0, "Day": 365.0}


def _ages_in_years(pheno_df):
    """Return pheno_df['Age'] in years, converting each row by its 'Age_unit'.

    Rows whose unit is missing or not in ``_AGE_UNIT_DIVISORS`` are taken to be
    in years already. So is the whole column when there is no 'Age_unit' column.
    """
    age = pheno_df['Age'].astype(float)
    if 'Age_unit' not in pheno_df.columns:
        return age
    # Per-row conversion: the unit is not assumed constant across samples.
    divisors = pheno_df['Age_unit'].map(_AGE_UNIT_DIVISORS).fillna(1.0)
    return age / divisors


def _evaluate_one(model, geo, data_path):
    """Load cohort *geo* from *data_path*, evaluate *model* on it, return (metrics, ages)."""
    pheno_df = pd.read_csv(os.path.join(data_path, f"{geo}_pheno.csv"))
    age = _ages_in_years(pheno_df)
    disease_labels = pheno_df['Disease'].values

    # Transpose so that rows are samples and columns are feature names.
    beta_df = pd.read_csv(os.path.join(data_path, f"{geo}_beta.csv"), index_col=0).T
    available = set(beta_df.columns)
    missing = [feature for feature in train_features if feature not in available]
    if missing:
        shown = ", ".join(map(str, missing[:10]))
        raise ValueError(f"Missing {len(missing)} features in test set {geo} (e.g. {shown})")
    if len(beta_df) != len(pheno_df):
        raise ValueError(
            f"{geo}: {len(beta_df)} samples in beta file but {len(pheno_df)} rows in pheno file"
        )
    # NOTE(review): samples are paired with phenotype rows purely by position,
    # so both files must list samples in the same order. Confirm upstream.
    features = beta_df[train_features].values
    metrics = evaluate_dataset(model, features, age.values, disease_labels, geo)
    return metrics, age


def evaluate_independent_datasets(model, geo_list, data_path="./dataset/"):
    """Evaluate ``model`` on every GEO cohort in ``geo_list``.

    For each ``GEO``, ``{GEO}_pheno.csv`` (columns 'Age', optional 'Age_unit',
    and 'Disease') and ``{GEO}_beta.csv`` (one row per feature, one column per
    sample) are read from ``data_path``. Ages are converted to years before
    evaluation. Evaluation is best effort: a failing cohort is reported and
    skipped, so the remaining ones still run.

    Returns:
        list of ``(GEO, metrics_dict)`` for the cohorts that were evaluated.
    """
    results = []
    for GEO in geo_list:
        print(f"\nEvaluating on {GEO}...")
        try:
            metrics, age = _evaluate_one(model, GEO, data_path)
        except Exception as e:  # deliberate best effort: report and continue
            print(f"Error processing {GEO}: {type(e).__name__}: {e}")
            continue
        results.append((GEO, metrics))
        print(f"{GEO} Age Stats: Min={age.min():.1f}, Max={age.max():.1f}, Mean={age.mean():.1f}")
    return results

# --------------------------- Main ---------------------------
_SUMMARY_METRICS = ("MAE", "MSE", "R2", "Pearson")


def _mean_metrics(metric_dicts):
    """Return the mean of each summary metric over *metric_dicts*.

    An empty input yields NaN for every metric instead of triggering NumPy's
    "mean of empty slice" warning.
    """
    if not metric_dicts:
        return {key: float("nan") for key in _SUMMARY_METRICS}
    return {key: float(np.mean([m[key] for m in metric_dicts])) for key in _SUMMARY_METRICS}


def _print_summary_row(label, metrics, sample_type=""):
    """Print one fixed-width row of the summary table."""
    print(f"{label:<14} | {metrics['MAE']:<8.2f} | {metrics['MSE']:<8.2f} | "
          f"{metrics['R2']:<8.4f} | {metrics['Pearson']:<8.4f} | {sample_type:<10}")


if __name__ == "__main__":
    # Placeholder accessions: replace with the GEO series to evaluate. List in
    # healthy_datasets those to report with sample type "Healthy"; all others
    # are reported as "Diseased".
    new_geo_list = ['GSEXXX']
    healthy_datasets = ['GSExxxx']

    model = MambaRegressor(input_dim, L=64, D=256).to(device)
    model.load_state_dict(torch.load('best_model.pt', map_location=device))

    results = evaluate_independent_datasets(model, new_geo_list)
    if not results:
        print("\nNo dataset was evaluated successfully; the averages below are NaN.")

    healthy_results = [(g, m) for g, m in results if g in healthy_datasets]
    disease_results = [(g, m) for g, m in results if g not in healthy_datasets]

    print("\n\n" + "=" * 80)
    print("Summary of Independent Dataset Evaluation Results")
    print("=" * 80)
    print(f"{'Dataset':<14} | {'MAE':<8} | {'MSE':<8} | {'R²':<8} | {'Pearson':<8} | {'Sample Type':<10}")
    print("-" * 80)
    for geo_id, metrics in healthy_results:
        _print_summary_row(geo_id, metrics, "Healthy")
    print("-" * 80)
    for geo_id, metrics in disease_results:
        _print_summary_row(geo_id, metrics, "Diseased")
    print("-" * 80)

    overall = _mean_metrics([m for _, m in results])
    _print_summary_row("Avg (Healthy)", _mean_metrics([m for _, m in healthy_results]))
    _print_summary_row("Avg (Diseased)", _mean_metrics([m for _, m in disease_results]))
    _print_summary_row("Avg (All)", overall)
    print("=" * 80)

    # One row per dataset, followed by the overall average row.
    rows = [
        {"Dataset": g, **{k: m[k] for k in _SUMMARY_METRICS},
         "Type": "Healthy" if g in healthy_datasets else "Diseased"}
        for g, m in results
    ]
    rows.append({"Dataset": "Average", **overall, "Type": "Overall"})
    result_df = pd.DataFrame(rows, columns=["Dataset", *_SUMMARY_METRICS, "Type"])
    result_df.to_csv("independent_test_results.csv", index=False)
    print("\nResults saved to independent_test_results.csv")