# NetMHCII Prediction Results Analysis

This notebook provides a comprehensive analysis of NetMHCII prediction results, covering binding-rank distributions, the effect of peptide transformations, peptide length, per-sample comparisons, and binding-core motifs.

## Setup and Imports

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from pyspark.sql.functions import (
    avg, col, collect_set, concat_ws, count, expr, length,
    min as spark_min, percentile_approx, size, sort_array, when
)

from src.predictor.utils import SparkManager
from src.analysis.visualization import (
    BindingDistributionPlotter,
    PlotConfig
)

# Create the Spark session through the project's SparkManager helper;
# the application name is kept in one place so it is easy to spot and change.
SPARK_APP_NAME = "Prediction_Analysis"
spark = SparkManager.create_spark_session(SPARK_APP_NAME)

## Load Results


In [None]:
# Load every prediction result file matching the glob below.
# NOTE(review): "/path/to/results" looks like a placeholder — confirm the real
# results directory before running.
results_df = SparkManager.read_prediction_results(
    spark,
    "/path/to/results/*_results.csv"
)

# Cache the DataFrame: it is the input of every analysis cell that follows
results_df.cache()

# Each .count() is a separate Spark action, so count once and reuse the value
total_predictions = results_df.count()
unique_samples = results_df.select("sample_id").distinct().count()

# Back-of-the-envelope figure that assumes 8 bytes per cell; it is not a
# measurement of the memory Spark actually uses for the cached data.
estimated_gb = total_predictions * len(results_df.columns) * 8 / 1e9

print("Total predictions:", total_predictions)
print("Unique samples:", unique_samples)
print("Memory usage:", estimated_gb, "GB")

## No-transformation Analysis

In [None]:
# Keep only raw peptides: neither inverted nor flipped
raw_results = results_df.filter(
    (col("inverted_manual") == 0) &
    (col("flipped") == "raw")
).cache()

# The percentages below divide by this count, so an empty selection must be
# reported explicitly instead of surfacing later as a null/None formatting error.
total_pairs = raw_results.count()
if total_pairs == 0:
    raise ValueError(
        "No untransformed predictions found "
        "(inverted_manual == 0 and flipped == 'raw'); nothing to analyse."
    )

# Share of predictions above / at-or-below the 5% rank cut-off, plus the median rank.
# count(when(...)) only counts rows where the condition holds (others are null).
stats = raw_results.agg(
    (count(when(col("%Rank_EL") > 5, True)) / total_pairs * 100)
    .alias("above_5_percent"),
    (count(when(col("%Rank_EL") <= 5, True)) / total_pairs * 100)
    .alias("below_5_percent"),
    percentile_approx("%Rank_EL", 0.5).alias("median_rank")
).collect()[0]

print("\nBinding Statistics:")
print(f"Above 5%: {stats['above_5_percent']:.2f}%")
print(f"Below 5%: {stats['below_5_percent']:.2f}%")
print(f"Median Rank: {stats['median_rank']:.2f}")

# Down-sample before moving data to pandas so the plot stays cheap.
# A fixed seed makes the sampled plot reproducible across re-runs.
sample_size = 100000
sample_seed = 42
rank_values = raw_results.select("%Rank_EL")
if total_pairs > sample_size:
    # fraction-based sampling: the result has roughly (not exactly) sample_size rows
    plot_data = rank_values.sample(
        False, sample_size / total_pairs, seed=sample_seed
    ).toPandas()
else:
    plot_data = rank_values.toPandas()

# Create plot
plotter = BindingDistributionPlotter(PlotConfig(figsize=(12, 6)))
fig = plotter.plot_rank_distribution(
    plot_data,
    title="Binding Rank Distribution (No Transformations)"
)

## Transformation Analysis

In [None]:
# Row filters for each transformation group.
# NOTE(review): 'All Transform' uses OR, i.e. it selects rows with *any*
# transformation (inverted and/or flipped), overlapping the two groups above.
# If "both transformations applied" was intended this should be `&` — confirm.
transformations = {
    'No Transform': (col("inverted_manual") == 0) & (col("flipped") == "raw"),
    'Inverted': (col("inverted_manual") == 1) & (col("flipped") == "raw"),
    'Flipped': (col("inverted_manual") == 0) & (col("flipped") != "raw"),
    'All Transform': (col("inverted_manual") == 1) | (col("flipped") != "raw")
}

# Analyze each transformation with a single aggregation per group: the row
# count is gathered in the same pass as the other statistics, and percentages
# are derived in Python so an empty group yields NaN instead of dividing by zero.
transform_stats = {}
for name, condition in transformations.items():
    summary = results_df.filter(condition).agg(
        count("*").alias("total"),
        count(when(col("%Rank_EL") > 5, True)).alias("n_above_5"),
        count(when(col("%Rank_EL") <= 5, True)).alias("n_below_5"),
        percentile_approx("%Rank_EL", 0.5).alias("median_rank")
    ).first()  # a global aggregation always returns exactly one row

    total = summary["total"]
    if total:
        above_5 = summary["n_above_5"] / total * 100
        below_5 = summary["n_below_5"] / total * 100
    else:
        above_5 = below_5 = float("nan")

    transform_stats[name] = {
        "total": total,
        "above_5": above_5,
        "below_5": below_5,
        "median": summary["median_rank"]
    }

# Convert to pandas for visualization: one row per transformation group
stats_df = pd.DataFrame(transform_stats).T
display(stats_df)

## Length Analysis

In [None]:
# Per-length summary: mean rank, median rank and number of predictions.
# Not cached: the result is consumed exactly once (toPandas below), and the
# cleanup cell never unpersists it, so caching would only pin executor memory.
length_stats = (results_df
    .withColumn("peptide_length", length(col("Peptide")))
    .groupBy("peptide_length")
    .agg(
        avg("%Rank_EL").alias("mean_rank"),
        percentile_approx("%Rank_EL", 0.5).alias("median_rank"),
        count("*").alias("count")
    )
    .orderBy("peptide_length")
)

# Convert to pandas for visualization (one row per distinct peptide length)
length_data = length_stats.toPandas()

# Two panels side by side: how many peptides per length, and how rank varies with length
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Length distribution
sns.barplot(data=length_data, x="peptide_length", y="count", ax=ax1)
ax1.set_title("Peptide Length Distribution")
ax1.set_xlabel("Length")
ax1.set_ylabel("Count")

# Length vs Rank (median rank on a log axis)
sns.lineplot(data=length_data, x="peptide_length", y="median_rank", ax=ax2)
ax2.set_title("Peptide Length vs Binding Rank")
ax2.set_xlabel("Length")
ax2.set_ylabel("Median Rank EL")
ax2.set_yscale("log")

plt.tight_layout()

## Sample Comparison

In [None]:
# Per-sample statistics, built in a single chain and cached once.
# Caching the *final* DataFrame matters: the cleanup cell calls
# sample_stats.unpersist(), which only releases the plan bound to this name.
# Caching an intermediate and then rebinding the name would leak that cache.
sample_stats = (results_df
    .groupBy("sample_id")
    .agg(
        count("*").alias("total_peptides"),
        # strong: rank < 2; weak: 2 <= rank < 10 (the two bins do not overlap)
        count(when(col("%Rank_EL") < 2, True)).alias("strong_binders"),
        count(when((col("%Rank_EL") >= 2) & (col("%Rank_EL") < 10), True))
        .alias("weak_binders"),
        percentile_approx("%Rank_EL", 0.5).alias("median_rank")
    )
    # Express binder counts as a percentage of each sample's peptides
    .withColumn("strong_binder_percent",
                col("strong_binders") / col("total_peptides") * 100)
    .withColumn("weak_binder_percent",
                col("weak_binders") / col("total_peptides") * 100)
).cache()

# Convert to pandas for visualization; groupBy output order is not guaranteed,
# so sort by sample to keep the bar order stable between runs.
stats_pd = (sample_stats.toPandas()
    .sort_values("sample_id")
    .reset_index(drop=True)
)

# Create visualization: binder percentages on top, median rank below
fig, axes = plt.subplots(2, 1, figsize=(12, 10))

# Binder percentages (stacked strong + weak)
stats_pd.plot(
    kind='bar',
    x='sample_id',
    y=['strong_binder_percent', 'weak_binder_percent'],
    stacked=True,
    ax=axes[0]
)
axes[0].set_title("Binder Distribution by Sample")
axes[0].set_xlabel("Sample")
axes[0].set_ylabel("Percentage")

# Median ranks (log axis)
stats_pd.plot(
    kind='bar',
    x='sample_id',
    y='median_rank',
    ax=axes[1]
)
axes[1].set_title("Median Rank by Sample")
axes[1].set_xlabel("Sample")
axes[1].set_ylabel("Median Rank EL")
axes[1].set_yscale("log")

plt.tight_layout()

## Binding Motif Analysis

In [None]:
# Restrict to strong binders before summarising binding cores.
# NOTE(review): the cut-off here is inclusive (<= 2) while the Sample Comparison
# cell counts strong binders with a strict < 2 — confirm which one is intended.
strong_binder_rows = results_df.filter(col("%Rank_EL") <= 2)

# One row per core: how often it occurs, its mean rank, and which samples carry it
core_summary = strong_binder_rows.groupBy("Core").agg(
    count("*").alias("frequency"),
    avg("%Rank_EL").alias("mean_rank"),
    collect_set("sample_id").alias("samples")
)
core_stats = (
    core_summary
    .withColumn("num_samples", size(col("samples")))
    .orderBy(col("frequency").desc())
    .cache()
)

# Show the 20 most frequent cores
print("Most frequent core sequences in strong binders:")
top_cores = core_stats.limit(20).toPandas()
display(top_cores)

# Number of distinct cores at each core length
cores_with_length = core_stats.withColumn("core_length", length(col("Core")))
core_length_dist = (
    cores_with_length
    .groupBy("core_length")
    .agg(count("*").alias("count"))
    .orderBy("core_length")
    .toPandas()
)

plt.figure(figsize=(10, 6))
core_ax = sns.barplot(data=core_length_dist, x="core_length", y="count")
core_ax.set_title("Core Sequence Length Distribution")
core_ax.set_xlabel("Core Length")
core_ax.set_ylabel("Count")

## Save Analysis Results

In [None]:
# Save per-sample statistics
SparkManager.save_prediction_results(
    sample_stats,
    "analysis_results/sample_statistics.csv"
)

# core_stats carries `samples` as an array column (built with collect_set).
# Spark's CSV data source cannot write array columns, so flatten it to a
# ';'-separated string first; sorting makes the saved value deterministic.
# NOTE(review): assumes save_prediction_results writes via Spark's CSV writer —
# confirm against the helper's implementation.
core_stats_for_csv = core_stats.withColumn(
    "samples",
    concat_ws(";", sort_array(col("samples")))
)
SparkManager.save_prediction_results(
    core_stats_for_csv,
    "analysis_results/core_statistics.csv"
)

# Clean up: release every DataFrame that was cached in the cells above
results_df.unpersist()
raw_results.unpersist()
sample_stats.unpersist()
core_stats.unpersist()