# Viral Polyprotein Cleavage Site Prediction - Data Download & Training

This notebook demonstrates the complete workflow for:
1. Downloading viral polyprotein data from RefSeq
2. Processing and validating the data
3. Training a machine learning model for cleavage site prediction

## Overview

Viral polyproteins are large precursor proteins that are cleaved into functional mature proteins. Accurate prediction of cleavage sites is crucial for understanding viral protein processing and can aid in antiviral drug development.

## 1. Import Required Libraries

First, let's import all the necessary libraries for data download, processing, and machine learning.

In [None]:
# Core libraries
import pandas as pd
import numpy as np
import json
import os
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# Machine learning libraries
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.feature_extraction.text import CountVectorizer

# Our custom data preparation module
from data_prep import (
    download_specific_viral_families,
    create_train_val_test_splits,
    validate_data_format
)

# Confirm the import cell ran to completion and show the current working
# directory that os.getcwd() reports.
print("‚úì All libraries imported successfully!")
print(f"Working directory: {os.getcwd()}")

## 2. Download and Load Data

Let's download viral polyprotein data from RefSeq using our enhanced search terms. We'll target multiple viral families to get a diverse dataset.

In [None]:
# Download settings: which viral families to fetch, where the result is
# written, and how many entries to request per family.
VIRAL_FAMILIES = [
    'Coronaviridae',   # e.g. SARS-CoV-2, MERS
    'Picornaviridae',  # e.g. poliovirus, rhinovirus
    'Flaviviridae',    # e.g. dengue, Zika, HCV
    'Caliciviridae',   # e.g. norovirus
    'Arteriviridae',   # e.g. PRRSV
]

OUTPUT_FILE = "viral_polyproteins_dataset.json"
MAX_PER_FAMILY = 10          # kept small for this demo
EMAIL = "demo@example.com"   # placeholder -- replace with your own address

# Announce what is about to be fetched.
print(
    "ü¶† Starting data download from RefSeq...",
    f"Target viral families: {VIRAL_FAMILIES}",
    f"Max entries per family: {MAX_PER_FAMILY}",
    sep="\n",
)

# Best-effort download: any failure is reported and the notebook carries on,
# so later cells can still run against a previously downloaded file.
try:
    download_specific_viral_families(
        viral_families=VIRAL_FAMILIES,
        output_file=OUTPUT_FILE,
        max_per_family=MAX_PER_FAMILY,
        email=EMAIL,
    )
    print(f"‚úì Data downloaded successfully to {OUTPUT_FILE}")
except Exception as e:
    print(
        f"‚ùå Error downloading data: {e}",
        "Note: This requires internet connection to NCBI RefSeq",
        sep="\n",
    )

In [None]:
# Load the downloaded dataset and print summary statistics.
# After this cell `data` is always defined: the parsed JSON content when the
# file exists, otherwise an empty list.
if os.path.exists(OUTPUT_FILE):
    # Read explicitly as UTF-8 (the encoding used for JSON interchange) so the
    # result does not depend on the platform's locale default encoding.
    with open(OUTPUT_FILE, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print("📊 Dataset Summary:")
    print(f"Total polyproteins: {len(data)}")

    # The statistics below call min()/max() and index data[0], so they are
    # only computed for a non-empty dataset.
    if len(data) > 0:
        # Analyze by viral family
        families = [entry['viral_family'] for entry in data]
        family_counts = Counter(families)

        print("\nBy viral family:")
        for family, count in family_counts.items():
            print(f"  • {family}: {count} polyproteins")

        # Analyze cleavage sites
        cleavage_counts = [len(entry['cleavage_sites']) for entry in data]
        total_sites = sum(cleavage_counts)

        print("\nCleavage site statistics:")
        print(f"  • Total cleavage sites: {total_sites}")
        print(f"  • Average sites per protein: {np.mean(cleavage_counts):.1f}")
        print(f"  • Range: {min(cleavage_counts)} - {max(cleavage_counts)} sites")

        # Sequence length analysis
        seq_lengths = [len(entry['sequence']) for entry in data]
        print("\nSequence length statistics:")
        print(f"  • Average length: {np.mean(seq_lengths):.0f} amino acids")
        print(f"  • Range: {min(seq_lengths)} - {max(seq_lengths)} amino acids")

        # Show a sample entry
        sample = data[0]
        print("\n📋 Sample entry:")
        print(f"  • ID: {sample['protein_id']}")
        print(f"  • Organism: {sample['organism']}")
        print(f"  • Family: {sample['viral_family']}")
        print(f"  • Sequence length: {len(sample['sequence'])} aa")
        print(f"  • Cleavage sites: {sample['cleavage_sites']}")

else:
    print("❌ No data file found. Please run the download cell first.")
    data = []

## 3. Data Preprocessing

Now let's process the raw polyprotein data to create features suitable for machine learning. We'll create a dataset where each amino acid position is labeled as either a cleavage site (1) or non-cleavage site (0).

In [None]:
def create_sequence_features(sequence, window_size=5):
    """
    Build one feature dictionary per residue of a protein sequence.

    The sequence is padded on both sides with ``window_size // 2`` 'X'
    characters so that every residue, including those at the ends, gets a
    window of the same length centred on it.

    Args:
        sequence: Protein sequence string
        window_size: Size of the window around each position; the flank on
            each side is ``window_size // 2`` residues

    Returns:
        List of dictionaries in sequence order. Each holds 'center_aa',
        'window', 'position' (0-indexed in the original sequence),
        'rel_position' (position divided by the sequence length) and one
        'aa_<letter>' entry per standard amino acid giving that letter's
        fraction of the window.
    """
    flank = window_size // 2
    padding = 'X' * flank
    padded = padding + sequence + padding
    total_residues = len(sequence)
    alphabet = 'ACDEFGHIKLMNPQRSTVWY'

    def describe(centre):
        # `centre` indexes the padded string; `offset` is the matching index
        # in the caller's unpadded sequence.
        offset = centre - flank
        context = padded[centre - flank:centre + flank + 1]
        row = {
            'center_aa': padded[centre],
            'window': context,
            'position': offset,
            'rel_position': offset / total_residues,
        }
        # Fraction of the window taken up by each of the 20 amino acids.
        row.update(
            (f'aa_{residue}', context.count(residue) / len(context))
            for residue in alphabet
        )
        return row

    return [describe(centre) for centre in range(flank, len(padded) - flank)]

def process_polyprotein_data(data):
    """
    Flatten polyprotein records into a per-residue table for machine learning.

    Args:
        data: List of polyprotein dictionaries; each is read for the keys
            'sequence', 'cleavage_sites', 'viral_family', 'organism' and
            'protein_id'

    Returns:
        pandas DataFrame with one row per residue: the window features from
        create_sequence_features plus 'is_cleavage' (1 when the residue's
        0-indexed position is listed in 'cleavage_sites', else 0),
        'viral_family', 'organism', 'protein_id' and 'sequence_length'
    """
    rows = []

    for record in data:
        residues = record['sequence']
        # A set gives constant-time membership tests in the inner loop.
        site_positions = set(record['cleavage_sites'])
        family = record['viral_family']
        source_organism = record['organism']

        for row in create_sequence_features(residues):
            # Attach the label and the per-protein metadata to this residue.
            row.update({
                'is_cleavage': int(row['position'] in site_positions),
                'viral_family': family,
                'organism': source_organism,
                'protein_id': record['protein_id'],
                'sequence_length': len(residues),
            })
            rows.append(row)

    return pd.DataFrame(rows)

# Build the per-residue DataFrame `df` (an empty frame when nothing was loaded).
if len(data) == 0:
    print("‚ùå No data to process")
    df = pd.DataFrame()
else:
    print("üîÑ Processing polyprotein data for machine learning...")
    df = process_polyprotein_data(data)

    # Size and class-balance overview of the processed table.
    print(
        f"‚úì Created dataset with {len(df)} amino acid positions",
        f"‚úì Features per position: {len(df.columns)}",
        f"‚úì Cleavage sites: {df['is_cleavage'].sum()}",
        f"‚úì Non-cleavage sites: {(df['is_cleavage'] == 0).sum()}",
        f"‚úì Class balance: {df['is_cleavage'].mean():.3f} cleavage rate",
        sep="\n",
    )

    # Preview the most informative columns of the first ten rows.
    print(f"\nüìä Sample of processed data:")
    display_cols = [
        'protein_id',
        'position',
        'center_aa',
        'window',
        'is_cleavage',
        'viral_family',
    ]
    print(df[display_cols].head(10))

## 4. Model Training & Evaluation

Now let's split the data and train machine learning models to predict cleavage sites.

In [None]:
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve
from sklearn.preprocessing import StandardScaler, LabelEncoder

def prepare_features(df):
    """
    Turn the per-residue DataFrame into arrays ready for model fitting.

    Args:
        df: DataFrame from process_polyprotein_data

    Returns:
        X: Feature matrix built from the 'aa_*' composition columns,
            'rel_position', 'sequence_length' and, when a 'viral_family'
            column exists, its label-encoded form
        y: Target labels taken from the 'is_cleavage' column
        feature_names: Column names of X, in order

        For an empty DataFrame, two empty arrays and an empty list.
    """
    if len(df) == 0:
        return np.array([]), np.array([]), []

    # Composition columns first, then the numeric position/length features.
    composition_cols = [name for name in df.columns if name.startswith('aa_')]
    feature_cols = composition_cols + ['rel_position', 'sequence_length']

    # Work on a copy so the caller's frame does not gain the encoded column.
    encoded = df.copy()
    if 'viral_family' in df.columns:
        family_encoder = LabelEncoder()
        encoded['viral_family_encoded'] = family_encoder.fit_transform(df['viral_family'])
        feature_cols.append('viral_family_encoded')

    return encoded[feature_cols].values, df['is_cleavage'].values, feature_cols

def train_and_evaluate_models(X, y, feature_names):
    """
    Train a Random Forest and a Logistic Regression and evaluate both.

    The data is split 80/20 (stratified by label). Each model is fitted on the
    training part, scored with 5-fold stratified cross-validation (ROC AUC) on
    the training part, and finally evaluated on the held-out test part.

    Args:
        X: 2-D feature matrix, one row per sample
        y: 1-D array of non-negative integer class labels aligned with X
        feature_names: Names of the columns of X; returned unchanged

    Returns:
        Tuple ``(results, scaler, feature_names)``. ``results`` maps each
        model name to a dict with the keys 'model', 'cv_auc_mean',
        'cv_auc_std', 'test_auc', 'y_test', 'y_pred' and 'y_pred_proba'.
        ``scaler`` is the StandardScaler fitted on the training split.
        When X is empty the function returns ``({}, None, feature_names)`` so
        callers can always unpack three values.
    """
    if len(X) == 0:
        print("❌ No data available for training")
        # Same shape as the normal return value, so unpacking never fails.
        return {}, None, feature_names

    # Imported locally so this block does not depend on a file-level import.
    from sklearn.pipeline import make_pipeline

    print(f"🏃 Training models on {len(X)} samples with {X.shape[1]} features")
    print(f"📊 Class distribution: {np.bincount(y)}")

    # Split data.
    # NOTE(review): rows are split individually; if X holds one row per
    # residue, residues of the same protein can land in both train and test.
    # Consider a grouped split by protein -- confirm against the caller.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )

    # Scale features (fitted on the training split only)
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Define models
    models = {
        'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42, class_weight='balanced'),
        'Logistic Regression': LogisticRegression(random_state=42, class_weight='balanced', max_iter=1000)
    }

    # The same fold definition is used for every model so scores are comparable.
    cv_folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    results = {}

    for name, model in models.items():
        print(f"\n🔬 Training {name}...")

        # Use scaled features for Logistic Regression, original for Random Forest
        needs_scaling = 'Logistic' in name
        X_train_model = X_train_scaled if needs_scaling else X_train
        X_test_model = X_test_scaled if needs_scaling else X_test

        # Train model
        model.fit(X_train_model, y_train)

        # Cross-validation on the unscaled training data. For models that need
        # scaling, the scaler is put inside a pipeline so it is re-fitted on
        # each fold's training part only; scaling before CV would leak
        # validation-fold statistics into training. cross_val_score works on
        # clones, so the fitted `model` above is left untouched.
        cv_estimator = make_pipeline(StandardScaler(), model) if needs_scaling else model
        cv_scores = cross_val_score(
            cv_estimator, X_train, y_train,
            cv=cv_folds,
            scoring='roc_auc'
        )

        # Test predictions
        y_pred = model.predict(X_test_model)
        y_pred_proba = model.predict_proba(X_test_model)[:, 1]

        # Metrics
        test_auc = roc_auc_score(y_test, y_pred_proba)

        results[name] = {
            'model': model,
            'cv_auc_mean': cv_scores.mean(),
            'cv_auc_std': cv_scores.std(),
            'test_auc': test_auc,
            'y_test': y_test,
            'y_pred': y_pred,
            'y_pred_proba': y_pred_proba
        }

        print(f"✓ CV AUC: {cv_scores.mean():.3f} ± {cv_scores.std():.3f}")
        print(f"✓ Test AUC: {test_auc:.3f}")
        print("✓ Classification Report:")
        print(classification_report(y_test, y_pred, digits=3))

    return results, scaler, feature_names

# Run feature preparation and training; `model_results` stays an empty dict
# whenever there is nothing to train on.
if len(df) == 0:
    print("‚ùå No data available for training")
    model_results = {}
else:
    print("üöÄ Preparing features and training models...")
    X, y, feature_names = prepare_features(df)

    if len(X) == 0:
        print("‚ùå No features could be prepared")
        model_results = {}
    else:
        model_results, scaler, feature_names = train_and_evaluate_models(X, y, feature_names)
        print("‚úÖ Model training completed!")

## 5. Results Visualization

Let's visualize the model performance and feature importance.

In [None]:
def plot_model_results(model_results):
    """
    Create visualizations for model performance.

    Draws a 2x2 matplotlib figure and shows it with plt.show():
      * top-left: ROC curve of every model,
      * top-right: bar chart of mean cross-validation AUC with std error bars,
      * bottom-left: confusion matrix of the 'Random Forest' entry, if present,
      * bottom-right: top 15 Random Forest feature importances, drawn only
        when a module-level ``feature_names`` exists.

    Args:
        model_results: Mapping of model name to a result dict. The keys read
            here are 'y_test', 'y_pred', 'y_pred_proba', 'test_auc',
            'cv_auc_mean', 'cv_auc_std' and 'model'.

    Returns:
        None. When model_results is empty a message is printed and nothing
        is plotted.
    """
    if not model_results:
        print("‚ùå No model results to plot")
        return

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # 1. ROC Curves
    ax1 = axes[0, 0]
    for name, results in model_results.items():
        y_test = results['y_test']
        y_pred_proba = results['y_pred_proba']

        fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
        auc = results['test_auc']

        ax1.plot(fpr, tpr, label=f'{name} (AUC = {auc:.3f})', linewidth=2)

    # Dashed diagonal as the reference line from (0, 0) to (1, 1)
    ax1.plot([0, 1], [0, 1], 'k--', alpha=0.5)
    ax1.set_xlabel('False Positive Rate')
    ax1.set_ylabel('True Positive Rate')
    ax1.set_title('ROC Curves')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # 2. Cross-validation AUC comparison
    ax2 = axes[0, 1]
    names = list(model_results.keys())
    cv_means = [model_results[name]['cv_auc_mean'] for name in names]
    cv_stds = [model_results[name]['cv_auc_std'] for name in names]

    bars = ax2.bar(names, cv_means, yerr=cv_stds, capsize=5, alpha=0.7)
    ax2.set_ylabel('Cross-validation AUC')
    ax2.set_title('Model Comparison (5-fold CV)')
    ax2.set_ylim(0, 1)
    ax2.grid(True, alpha=0.3)

    # Add value labels on bars
    for bar, mean in zip(bars, cv_means):
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                f'{mean:.3f}', ha='center', va='bottom')

    # 3. Confusion Matrix (for Random Forest)
    if 'Random Forest' in model_results:
        ax3 = axes[1, 0]
        rf_results = model_results['Random Forest']
        cm = confusion_matrix(rf_results['y_test'], rf_results['y_pred'])

        im = ax3.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        ax3.figure.colorbar(im, ax=ax3)

        # Add text annotations; cells above half the maximum count get white
        # text, the others black
        thresh = cm.max() / 2.
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                ax3.text(j, i, format(cm[i, j], 'd'),
                        ha="center", va="center",
                        color="white" if cm[i, j] > thresh else "black")

        ax3.set_xlabel('Predicted Label')
        ax3.set_ylabel('True Label')
        ax3.set_title('Confusion Matrix (Random Forest)')
        ax3.set_xticks([0, 1])
        ax3.set_yticks([0, 1])
        ax3.set_xticklabels(['Not Cleavage', 'Cleavage'])
        ax3.set_yticklabels(['Not Cleavage', 'Cleavage'])

    # 4. Feature Importance (Random Forest)
    # `feature_names` is not a parameter: it is read from module globals, so
    # this panel is skipped unless the training cell has defined it.
    if 'Random Forest' in model_results and 'feature_names' in globals():
        ax4 = axes[1, 1]
        rf_model = model_results['Random Forest']['model']
        importances = rf_model.feature_importances_

        # Get top 15 most important features (argsort ascending, reversed)
        indices = np.argsort(importances)[::-1][:15]
        top_features = [feature_names[i] for i in indices]
        top_importances = importances[indices]

        y_pos = np.arange(len(top_features))
        ax4.barh(y_pos, top_importances, alpha=0.7)
        ax4.set_yticks(y_pos)
        ax4.set_yticklabels(top_features)
        ax4.set_xlabel('Feature Importance')
        ax4.set_title('Top 15 Feature Importances (Random Forest)')
        ax4.grid(True, alpha=0.3)

        # Reverse y-axis to show most important at top
        ax4.invert_yaxis()

    plt.tight_layout()
    plt.show()

def plot_cleavage_site_analysis(df):
    """
    Analyze cleavage site patterns in the data.

    Draws a 2x2 matplotlib figure and shows it with plt.show():
      * top-left: cleavage-site rate per viral family,
      * top-right: centre amino-acid frequency at cleavage vs non-cleavage rows,
      * bottom-left: histogram of relative positions of cleavage sites,
      * bottom-right: histogram of sequence lengths, one value per protein_id.

    Args:
        df: Per-residue DataFrame. The columns read here are 'viral_family',
            'is_cleavage', 'center_aa', 'rel_position', 'protein_id' and
            'sequence_length'.

    Returns:
        None. When df is empty a message is printed and nothing is plotted.
    """
    if len(df) == 0:
        print("‚ùå No data to analyze")
        return

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # 1. Cleavage sites by viral family
    ax1 = axes[0, 0]
    cleavage_by_family = df.groupby('viral_family')['is_cleavage'].agg(['sum', 'count', 'mean'])
    # The mean of the 0/1 label is the fraction of rows that are cleavage sites
    cleavage_by_family['rate'] = cleavage_by_family['mean']

    bars = ax1.bar(cleavage_by_family.index, cleavage_by_family['rate'], alpha=0.7)
    ax1.set_ylabel('Cleavage Site Rate')
    ax1.set_title('Cleavage Site Rate by Viral Family')
    ax1.tick_params(axis='x', rotation=45)
    ax1.grid(True, alpha=0.3)

    # Add value labels
    for bar, rate in zip(bars, cleavage_by_family['rate']):
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height + 0.001,
                f'{rate:.3f}', ha='center', va='bottom')

    # 2. Amino acid frequency at cleavage sites
    ax2 = axes[0, 1]
    cleavage_aa = df[df['is_cleavage'] == 1]['center_aa'].value_counts()
    non_cleavage_aa = df[df['is_cleavage'] == 0]['center_aa'].value_counts()

    # Normalize by total counts
    cleavage_freq = cleavage_aa / cleavage_aa.sum()
    non_cleavage_freq = non_cleavage_aa / non_cleavage_aa.sum()

    # Union of residues seen in either class; a residue missing from one
    # class is plotted with frequency 0
    all_aa = sorted(set(cleavage_freq.index) | set(non_cleavage_freq.index))
    cleavage_vals = [cleavage_freq.get(aa, 0) for aa in all_aa]
    non_cleavage_vals = [non_cleavage_freq.get(aa, 0) for aa in all_aa]

    x = np.arange(len(all_aa))
    width = 0.35

    # Two bars per residue, offset half a bar width either side of the tick
    ax2.bar(x - width/2, cleavage_vals, width, label='Cleavage Sites', alpha=0.7)
    ax2.bar(x + width/2, non_cleavage_vals, width, label='Non-cleavage Sites', alpha=0.7)

    ax2.set_xlabel('Amino Acid')
    ax2.set_ylabel('Frequency')
    ax2.set_title('Amino Acid Frequency at Cleavage vs Non-cleavage Sites')
    ax2.set_xticks(x)
    ax2.set_xticklabels(all_aa)
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # 3. Cleavage site position distribution
    ax3 = axes[1, 0]
    cleavage_positions = df[df['is_cleavage'] == 1]['rel_position']

    ax3.hist(cleavage_positions, bins=20, alpha=0.7, edgecolor='black')
    ax3.set_xlabel('Relative Position in Sequence')
    ax3.set_ylabel('Number of Cleavage Sites')
    ax3.set_title('Distribution of Cleavage Sites by Position')
    ax3.grid(True, alpha=0.3)

    # 4. Sequence length distribution
    ax4 = axes[1, 1]
    # Keep one row per protein_id so each sequence is counted once
    unique_sequences = df.drop_duplicates('protein_id')

    ax4.hist(unique_sequences['sequence_length'], bins=15, alpha=0.7, edgecolor='black')
    ax4.set_xlabel('Sequence Length')
    ax4.set_ylabel('Number of Sequences')
    ax4.set_title('Distribution of Sequence Lengths')
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Draw the figures: model diagnostics only when training produced results,
# the data analysis whenever there is data.
if len(df) == 0:
    print("‚ùå No data available for visualization")
else:
    print("üìä Creating visualizations...")

    # Model performance panels need trained models.
    if model_results:
        plot_model_results(model_results)

    # The cleavage-site analysis depends on the data alone.
    plot_cleavage_site_analysis(df)

    print("‚úÖ Visualization completed!")

## 6. Summary & Next Steps

Let's summarize our findings and suggest improvements for the model.

In [None]:
# Summary and analysis
# Prints a text report (dataset, per-family statistics, model scores, top
# features, insights, next steps) when both data and model results exist.
if len(df) > 0 and model_results:
    print("üéØ VIRAL POLYPROTEIN CLEAVAGE PREDICTION - SUMMARY")
    print("=" * 60)

    # Dataset summary
    print(f"\nüìä DATASET SUMMARY:")
    print(f"   ‚Ä¢ Total amino acid positions: {len(df):,}")
    print(f"   ‚Ä¢ Unique proteins: {df['protein_id'].nunique()}")
    print(f"   ‚Ä¢ Viral families: {df['viral_family'].nunique()}")
    print(f"   ‚Ä¢ Cleavage sites: {df['is_cleavage'].sum():,}")
    print(f"   ‚Ä¢ Cleavage rate: {df['is_cleavage'].mean():.1%}")

    # Viral family breakdown
    print(f"\nü¶† VIRAL FAMILIES:")
    # Aggregating with a dict yields (column, function) column labels, which
    # is why the lookups below use tuple keys.
    family_stats = df.groupby('viral_family').agg({
        'protein_id': 'nunique',
        'is_cleavage': ['sum', 'mean']
    }).round(3)

    for family in family_stats.index:
        proteins = family_stats.loc[family, ('protein_id', 'nunique')]
        cleavage_sites = family_stats.loc[family, ('is_cleavage', 'sum')]
        cleavage_rate = family_stats.loc[family, ('is_cleavage', 'mean')]
        print(f"   ‚Ä¢ {family}: {proteins} proteins, {cleavage_sites} sites ({cleavage_rate:.1%})")

    # Model performance
    print(f"\nü§ñ MODEL PERFORMANCE:")
    for name, results in model_results.items():
        cv_auc = results['cv_auc_mean']
        test_auc = results['test_auc']
        print(f"   ‚Ä¢ {name}:")
        print(f"     - Cross-validation AUC: {cv_auc:.3f} ¬± {results['cv_auc_std']:.3f}")
        print(f"     - Test AUC: {test_auc:.3f}")

    # Feature insights (Random Forest)
    # Only shown when the module-level `feature_names` has been defined.
    if 'Random Forest' in model_results and 'feature_names' in globals():
        rf_model = model_results['Random Forest']['model']
        importances = rf_model.feature_importances_

        print(f"\nüîç TOP PREDICTIVE FEATURES:")
        # Five largest importances, highest first
        indices = np.argsort(importances)[::-1][:5]
        for i, idx in enumerate(indices):
            feature = feature_names[idx]
            importance = importances[idx]
            print(f"   {i+1}. {feature}: {importance:.3f}")

    print(f"\nüí° INSIGHTS:")

    # Class imbalance insight (reported when under 10% of rows are cleavage sites)
    cleavage_rate = df['is_cleavage'].mean()
    if cleavage_rate < 0.1:
        print(f"   ‚Ä¢ High class imbalance ({cleavage_rate:.1%} cleavage sites)")
        print(f"   ‚Ä¢ Used balanced class weights to address this")

    # Performance insight: pick the model with the highest test AUC and
    # describe it using the 0.8 / 0.7 thresholds below
    best_model = max(model_results.items(), key=lambda x: x[1]['test_auc'])
    best_name, best_results = best_model
    if best_results['test_auc'] > 0.8:
        print(f"   ‚Ä¢ {best_name} shows good performance (AUC = {best_results['test_auc']:.3f})")
    elif best_results['test_auc'] > 0.7:
        print(f"   ‚Ä¢ {best_name} shows moderate performance (AUC = {best_results['test_auc']:.3f})")
    else:
        print(f"   ‚Ä¢ Models show limited performance (best AUC = {best_results['test_auc']:.3f})")

    print(f"\nüöÄ NEXT STEPS:")
    print(f"   1. Collect more data (especially from underrepresented families)")
    print(f"   2. Try ensemble methods or deep learning approaches")
    print(f"   3. Include structural features (secondary structure, surface accessibility)")
    print(f"   4. Experiment with different window sizes and feature engineering")
    print(f"   5. Use domain-specific features (protease specificity motifs)")
    print(f"   6. Cross-validate across viral families to test generalization")

    print(f"\nüìö RESOURCES:")
    print(f"   ‚Ä¢ MEROPS database: https://www.ebi.ac.uk/merops/")
    print(f"   ‚Ä¢ UniProt viral proteomes: https://www.uniprot.org/proteomes/")
    print(f"   ‚Ä¢ PDB structures for 3D features: https://www.rcsb.org/")

else:
    print("‚ùå No data or model results available for summary")

# These closing lines are printed on both branches.
print(f"\n‚úÖ Notebook execution completed!")
print(f"üìù Check the plots above for detailed visualizations")