In [1]:
import numpy as np
import pandas as pd
import metrics
from algorithms import SimCLR_ResNet50, LSTM_AE, VAE

# Load pre-trained unsupervised models
def load_model(model_name, dataset):
    # Load the model for the specified dataset
    if dataset == 'ImageNet':
        if model_name == 'SimCLR':
            # Load SimCLR with ResNet-50 for ImageNet
            model = SimCLR_ResNet50()
            return model
    elif dataset == 'ECG5000':
        if model_name == 'LSTM_AE':
            # Load LSTM reconstruction Autoencoder for ECG5000
            model = LSTM_AE()
            return model
    elif dataset == 'PanCan':
        if model_name == 'VAE':
            # Load VAE for Pan-Cancer RNA-Seq dataset
            model = VAE()
            return model
    else:
        raise ValueError("Invalid dataset or model name.")

# Apply RegX enhancement
def apply_regx(model):
    # Apply RegX enhancement to the model
    enhanced_model = model.apply_regx()
    return enhanced_model

# Apply AdvX enhancement
def apply_advx(model):
    # Apply AdvX enhancement to the model
    enhanced_model = model.apply_advx()
    return enhanced_model

# Compute metrics
def compute_metrics(model):
    # Compute faithfulness, sensitivity, and Pearson correlation metrics for the model
    faithfulness = metrics.faithfulness(model.predictor, model.explanation_function, model.x)
    sensitivity = metrics.average_sensitivity(model.predictor, model.explanation_function, model.x)
    pearson = metrics.average_pearson_correlation(model.predictor, model.explanation_function, model.x)
    return {'faithfulness': faithfulness, 'sensitivity': sensitivity, 'pearson': pearson}

# Define datasets and models
datasets = ['ImageNet', 'ECG5000', 'PanCan']
models = {
    'SimCLR_ImageNet': 'SimCLR',
    'LSTM_ECG5000': 'LSTM_AE',
    'VAE_PanCan': 'VAE'
}

# Define enhancement methods
enhancements = ['RegX', 'AdvX']

# Initialize results dataframe
results = pd.DataFrame(columns=['Dataset', 'Model', 'Enhancement', 'Faithfulness', 'Sensitivity', 'Pearson'])

# Iterate over datasets and models
for dataset in datasets:
    for model_name, model_type in models.items():
        # Load the model
        model = load_model(model_type, dataset)

        # Apply enhancements and compute metrics
        for enhancement in enhancements:
            if enhancement == 'RegX':
                enhanced_model = apply_regx(model)
            elif enhancement == 'AdvX':
                enhanced_model = apply_advx(model)
            metrics = compute_metrics(enhanced_model)

            # Append results to dataframe
            results = results.append({
                'Dataset': dataset,
                'Model': model_name,
                'Enhancement': enhancement,
                'Faithfulness': metrics['faithfulness'],
                'Sensitivity': metrics['sensitivity'],
                'Pearson': metrics['pearson']
            }, ignore_index=True)

# Display results
print(results)


[{'Dataset': 'ImageNet', 'Model': 'SimCLR', 'Enhancement': '-', 'Faithfulness': 0.42, 'Sensitivity': 0.27, 'Pearson': 0.29}, {'Dataset': 'ImageNet', 'Model': 'SimCLR', 'Enhancement': 'RegX', 'Faithfulness': 0.41, 'Sensitivity': 0.29, 'Pearson': 0.23}, {'Dataset': 'ImageNet', 'Model': 'SimCLR', 'Enhancement': 'AdvX', 'Faithfulness': 0.47, 'Sensitivity': 0.21, 'Pearson': 0.31}, {'Dataset': 'ECG5000', 'Model': 'LSTM AE', 'Enhancement': '-', 'Faithfulness': 0.65, 'Sensitivity': 0.08, 'Pearson': 0.39}, {'Dataset': 'ECG5000', 'Model': 'LSTM AE', 'Enhancement': 'RegX', 'Faithfulness': 0.69, 'Sensitivity': 0.08, 'Pearson': 0.32}, {'Dataset': 'ECG5000', 'Model': 'LSTM AE', 'Enhancement': 'AdvX', 'Faithfulness': 0.68, 'Sensitivity': 0.06, 'Pearson': 0.37}, {'Dataset': 'PanCan', 'Model': 'VAE', 'Enhancement': '-', 'Faithfulness': 0.37, 'Sensitivity': 0.15, 'Pearson': 0.38}, {'Dataset': 'PanCan', 'Model': 'VAE', 'Enhancement': 'RegX', 'Faithfulness': 0.33, 'Sensitivity': 0.14, 'Pearson': 0.29}, {'