In [1]:
# Protein Fitness Prediction Evaluation Script
# This notebook demonstrates model evaluation on test datasets

import matplotlib.pyplot as plt
plt.rcParams["figure.dpi"] = 300
plt.rcParams["savefig.dpi"] = 300

import os
import argparse
import numpy as np
import pandas as pd
from loguru import logger
import sys

# Directory containing the ev_onehot sources (`predictor.py`, `util.py`).
# Set the EV_ONEHOT_REPO environment variable to point elsewhere instead of
# editing this machine-specific absolute default.
EV_ONEHOT_REPO = os.environ.get(
    'EV_ONEHOT_REPO',
    '/home/xux/Desktop/ProteinMCP/ProteinMCP/mcp-servers/ev_onehot_mcp/repo/ev_onehot',
)
# Guarded so re-running this cell does not keep prepending duplicate entries.
if EV_ONEHOT_REPO not in sys.path:
    sys.path.insert(0, EV_ONEHOT_REPO)
from predictor import JointPredictor, EVPredictor, OnehotRidgePredictor
from util import spearman, ndcg, auroc

In [2]:
def load_predictor(data_path, predictor_name='ev+onehot'):
    """Build the joint EV + one-hot ridge predictor and call its ``load_model()``.

    Args:
        data_path: dataset directory passed straight to ``JointPredictor``
            (the notebook's demo text says the pretrained model is loaded
            from disk; the exact files read are decided by ``load_model()``).
        predictor_name: label forwarded to ``JointPredictor`` and echoed in
            the log line.

    Returns:
        The ``JointPredictor`` instance after ``load_model()`` has run.
    """
    logger.info(f'Predictor {predictor_name} -----')
    joint_predictor = JointPredictor(
        data_path,
        [EVPredictor, OnehotRidgePredictor],  # component predictor classes, order preserved
        predictor_name,
    )
    joint_predictor.load_model()
    return joint_predictor

In [3]:
def main(args):
    """Evaluate the pretrained predictor on ``<data_path>/data_test.csv``.

    Reads the test split, predicts a fitness value for every ``seq``, logs
    Spearman / NDCG / AUROC of the predictions against ``log_fitness`` (rounded
    to 3 decimals), and writes the test rows plus a ``pred_fitness`` column to
    ``<data_path>/data_test_pred.csv``.

    Args:
        args: parsed CLI namespace; only ``args.data_path`` is read here.

    Raises:
        ValueError: if ``data_test.csv`` lacks the ``seq`` or ``log_fitness``
            column (checked before the model is loaded).
    """
    logger.info(f'Data path {args.data_path} -----')

    # Load dataset
    test = pd.read_csv(os.path.join(args.data_path, 'data_test.csv'))
    logger.info(f'Number of samples: {len(test)}')
    # Fail fast, before the (potentially expensive) model load.
    missing = {'seq', 'log_fitness'} - set(test.columns)
    if missing:
        raise ValueError(f'data_test.csv is missing required column(s): {sorted(missing)}')

    predictor = load_predictor(args.data_path)
    test_pred = predictor.predict(test.seq.values)
    y_true = test['log_fitness'].to_numpy()  # hoisted: shared by every metric

    metric_fns = {'spearman': spearman, 'ndcg': ndcg, 'auroc': auroc}
    # Metric arguments keep the original (prediction, target) order.
    test_ret = [f"{m}: {_format_metric(mfn(test_pred, y_true))}" for m, mfn in metric_fns.items()]

    test['pred_fitness'] = test_pred
    # NOTE(review): to_csv keeps pandas' default index=True, so the file gains an
    # unnamed index column; left unchanged for compatibility with existing readers.
    test.to_csv(os.path.join(args.data_path, 'data_test_pred.csv'))
    logger.info(", ".join(test_ret))


def _format_metric(value, precision=3):
    """Render one metric value with ``precision`` decimals for logging.

    Scalars (Python or NumPy floats) get a fixed-point format; NumPy print
    options do not apply to them, which is why a global ``np.set_printoptions``
    could not round them. Values that cannot be converted to ``float`` (e.g. an
    array-like result) fall back to ``str`` under a *scoped* ``np.printoptions``,
    so global NumPy print state is never modified.
    """
    try:
        return f"{float(value):.{precision}f}"
    except (TypeError, ValueError):
        with np.printoptions(precision=precision):
            return str(value)

In [4]:
def parse_args(argv=None):
    """Parse command-line options for the evaluation entry point.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before. Pass an explicit list
            when calling from a notebook, where ``sys.argv`` holds the Jupyter
            kernel's own flags and would otherwise be rejected.

    Returns:
        argparse.Namespace with ``data_path`` (str) and ``ignore_gaps`` (bool).
    """
    parser = argparse.ArgumentParser(
        description='Evaluate a pretrained ev+onehot model on <data_path>/data_test.csv.',
        # Keep the hand-aligned line breaks in the data_path help text; the
        # default formatter would re-wrap them into one paragraph.
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument('data_path', type=str,
                        help='Dataset path, including following files: \n'
                             'data_test.csv,    with  `seq` and `log_fitness` columns; \n'
                             'ridge_model.pkl,  pretrained linear regression model; \n'
                             'wt.fasta,         containing wild-type sequences;\n'
                             'plmc/,            folder contain EVmutation model parameters.\n')
    # NOTE(review): parsed but not read by main() in this notebook.
    parser.add_argument('--ignore_gaps', dest='ignore_gaps', action='store_true')
    args = parser.parse_args(argv)
    return args

In [5]:
# Demonstration: Evaluation Pipeline
# One print() with sep="\n" emits exactly the same text as one print() per line.
# NOTE(review): this banner names ridge_model.joblib while parse_args' help text
# names ridge_model.pkl — confirm which file predictor.load_model() actually reads.
print(
    "Prediction Evaluation Pipeline Demonstration",
    "=" * 60,
    "\nThis notebook defines evaluation functions:",
    "1. load_predictor() - Loads pretrained model from disk",
    "2. main() - Evaluates on test set and computes metrics",
    "   - Spearman correlation",
    "   - NDCG (Normalized Discounted Cumulative Gain)",
    "   - AUROC (Area Under ROC Curve)",
    "\nFor actual usage, provide:",
    "  - data_test.csv with 'seq' and 'log_fitness' columns",
    "  - ridge_model.joblib (trained model)",
    "  - wt.fasta and plmc/ directory",
    "\n" + "=" * 60,
    sep="\n",
)

Prediction Evaluation Pipeline Demonstration

This notebook defines evaluation functions:
1. load_predictor() - Loads pretrained model from disk
2. main() - Evaluates on test set and computes metrics
   - Spearman correlation
   - NDCG (Normalized Discounted Cumulative Gain)
   - AUROC (Area Under ROC Curve)

For actual usage, provide:
  - data_test.csv with 'seq' and 'log_fitness' columns
  - ridge_model.joblib (trained model)
  - wt.fasta and plmc/ directory

