# UniProt Mapping and Gene Expression Analysis

This notebook analyzes the biological context of predicted peptides by mapping their source proteins to UniProt annotations (gene, protein name, organism) and integrating gene expression data.

## Setup and Imports

In [None]:
import xml.etree.ElementTree as ET
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import seaborn as sns
from pyspark.sql.functions import (
    col, collect_set, explode, array, collect_list, udf, broadcast,
    count, countDistinct, when, avg, min as spark_min, max as spark_max, lit,
    percentile_approx, size, log2
)
from pyspark.sql.types import *

from src.predictor.utils import SparkManager
from src.analysis.visualization import TargetAnalysisPlotter, PlotConfig

# Spark session used by every cell below; memory_gb=32 is handed straight
# to SparkManager.create_spark_session.
_spark_app_name = "UniProt_Analysis"
spark = SparkManager.create_spark_session(_spark_app_name, memory_gb=32)

## Load MS and Prediction Data

In [None]:
# Mass-spectrometry identifications: one nullable schema field per CSV column.
_ms_columns = [
    ("SampleID", StringType()),
    ("Protein Accession", StringType()),
    ("Peptide", StringType()),
    ("IonScore", FloatType()),
]
ms_schema = StructType([StructField(name, dtype, True) for name, dtype in _ms_columns])

ms_data = spark.read.csv("/path/to/ms_data.csv", schema=ms_schema, header=True).cache()

# Prediction tables, loaded through SparkManager (presumably every file
# matching the glob — confirm in SparkManager.read_prediction_results).
predictions = SparkManager.read_prediction_results(
    spark, "/path/to/results/*_predictions.csv"
).cache()

# Report how many distinct samples each source covers, MS first.
for _label, _frame, _column in (
    ("MS Data Samples:", ms_data, "SampleID"),
    ("Prediction Samples:", predictions, "sample_id"),
):
    print(_label, _frame.select(_column).distinct().count())

## UniProt Mapping

In [None]:
# Define UDF for UniProt fetching
@udf(returnType=StructType([
    StructField("gene_name", StringType(), True),
    StructField("protein_name", StringType(), True),
    StructField("organism", StringType(), True)
]))
def fetch_uniprot_info(accession):
    """Fetch gene name, recommended protein name and organism for one accession.

    Runs as a Spark UDF, issuing one HTTP request to the UniProt REST API
    per call.

    Args:
        accession: UniProt accession string (may be null).

    Returns:
        ``(gene_name, protein_name, organism)``. Individual elements are
        None when the field is missing from the entry; all three are None
        when the accession is null/empty, the request fails or returns a
        non-200 status, or the response is not parseable XML.
    """
    empty = (None, None, None)
    if not accession:
        # Avoid requesting ".../None.xml" for null accessions.
        return empty

    ns = "{http://uniprot.org/uniprot}"
    try:
        # Without a timeout a stalled connection blocks the Spark task forever.
        response = requests.get(
            f"https://rest.uniprot.org/uniprotkb/{accession}.xml",
            timeout=30,
        )
    except requests.RequestException:
        return empty
    if response.status_code != 200:
        return empty

    try:
        root = ET.fromstring(response.content)
    except ET.ParseError:
        return empty

    def _text(path):
        # Text of the first element matching path, or None if absent.
        element = root.find(path)
        return element.text if element is not None else None

    return (
        _text(f".//{ns}gene/{ns}name[@type='primary']"),
        _text(f".//{ns}protein/{ns}recommendedName/{ns}fullName"),
        _text(f".//{ns}organism/{ns}name[@type='scientific']"),
    )

# Distinct, non-null source-protein accessions, deduplicated on the
# executors. Collecting them to the driver and re-creating a DataFrame
# failed on an empty accession list (schema cannot be inferred from []) and
# shipped every accession through the driver.
# NOTE(review): "Protein.Accession" is resolved by Spark as struct field
# access; if the column name literally contains a dot it needs backticks.
accession_df = (predictions
    .select(col("Protein.Accession").alias("accession"))
    .where(col("accession").isNotNull())
    .distinct()
)

# Apply UDF and cache results (one UniProt lookup per distinct accession)
uniprot_info = (accession_df
    .withColumn("info", fetch_uniprot_info(col("accession")))
    .select(
        "accession",
        col("info.gene_name").alias("gene_name"),
        col("info.protein_name").alias("protein_name"),
        col("info.organism").alias("organism")
    )
).cache()

# Attach annotations to predictions; unmapped accessions keep null gene fields.
predictions_with_genes = predictions.join(
    broadcast(uniprot_info),
    predictions["Protein.Accession"] == uniprot_info["accession"],
    "left"
).cache()

## Load Gene Expression Data

In [None]:
# Gene expression (TPM) table: one nullable schema field per CSV column.
expression_schema = StructType([
    StructField(name, dtype, True)
    for name, dtype in (
        ("gene_id", StringType()),
        ("sample_id", StringType()),
        ("TPM", FloatType()),
    )
])
expression_data = spark.read.csv(
    "/path/to/expression_data.csv", schema=expression_schema, header=True
).cache()

# Sample annotations; every column is a nullable string.
metadata_schema = StructType([
    StructField(name, StringType(), True)
    for name in ("sample_id", "tissue_type", "disease_state", "source_type")
])
metadata = spark.read.csv(
    "/path/to/metadata.csv", schema=metadata_schema, header=True
).cache()

# Attach annotations to expression records; samples missing on either side are dropped.
expression_with_meta = expression_data.join(metadata, "sample_id", "inner").cache()

## Tissue-specific Analysis

In [None]:
# Per-tissue summary of the expression table. Rows are presumably one
# (gene_id, sample_id) measurement each, so plain count() counted
# measurements rather than samples/genes; distinct counts are used so the
# column names mean what they say.
tissue_expression = (expression_with_meta
    .groupBy("tissue_type")
    .agg(
        countDistinct("sample_id").alias("sample_count"),
        # when() without otherwise() is null for TPM <= 1 and countDistinct
        # ignores nulls: counts genes with TPM > 1 in at least one sample.
        countDistinct(when(col("TPM") > 1, col("gene_id"))).alias("expressed_genes"),
        percentile_approx("TPM", 0.5).alias("median_expression"),
        percentile_approx("TPM", [0.25, 0.75]).alias("quartiles")
    )
).cache()

# Convert to pandas for visualization
tissue_stats = tissue_expression.toPandas()

# Plot tissue-specific expression; only the two plotted columns are pulled
# to the driver instead of the whole joined table.
plt.figure(figsize=(12, 6))
sns.boxplot(
    data=expression_with_meta.select("tissue_type", "TPM").toPandas(),
    x="tissue_type",
    y="TPM"
)
plt.xticks(rotation=45)
plt.yscale("log")
plt.title("Gene Expression Distribution by Tissue")

## Strong Binder Analysis

In [None]:
# Binders: predictions with %Rank_EL below the cutoff, summarised per gene/protein.
# NOTE(review): NetMHCpan's default thresholds are %Rank_EL < 0.5 for strong
# and < 2 for weak binders, so a cutoff of 2 keeps strong + weak binders;
# confirm which set is intended.
BINDER_RANK_CUTOFF = 2

strong_binders = (predictions_with_genes
    .filter(col("%Rank_EL") < BINDER_RANK_CUTOFF)
    .groupBy("gene_name", "protein_name")
    .agg(
        count("*").alias("total_binders"),
        # pyspark.sql.functions has no `distinct`; countDistinct is the
        # distinct-count aggregate.
        countDistinct("Peptide").alias("unique_peptides"),
        avg("%Rank_EL").alias("mean_rank"),
        collect_set("sample_id").alias("samples")
    )
    .withColumn("num_samples", size(col("samples")))
    .orderBy(col("total_binders").desc())
).cache()

# Show top binders
print("\nTop binding proteins:")
display(strong_binders.limit(20).toPandas())

# Plot distribution (only the plotted column is pulled to the driver)
plt.figure(figsize=(10, 6))
sns.histplot(
    data=strong_binders.select("unique_peptides").toPandas(),
    x="unique_peptides",
    bins=50
)
plt.title("Distribution of Unique Binding Peptides per Protein")
plt.xlabel("Number of Unique Peptides")
plt.xscale("log")

## Cancer/Normal Comparison

In [None]:
# Separate cancer and normal samples by metadata disease_state
cancer_samples = metadata.filter(col("disease_state") == "Tumor")
normal_samples = metadata.filter(col("disease_state") == "Normal")

# Keep prediction rows whose sample_id belongs to each group. A left-semi
# join only filters: it adds no second "sample_id" column (the previous inner
# join produced two ambiguous ones) and never duplicates prediction rows when
# a sample_id appears more than once in metadata.
cancer_predictions = predictions_with_genes.join(
    broadcast(cancer_samples.select("sample_id")),
    on="sample_id",
    how="left_semi"
)

normal_predictions = predictions_with_genes.join(
    broadcast(normal_samples.select("sample_id")),
    on="sample_id",
    how="left_semi"
)

# Calculate differential presentation
def analyze_differential_presentation(cancer_df, normal_df):
    """Compare per-gene peptide presentation between cancer and normal predictions.

    Args:
        cancer_df: prediction rows from tumour samples with ``gene_name``,
            ``Peptide`` and ``%Rank_EL`` columns.
        normal_df: prediction rows from normal samples, same columns.

    Returns:
        DataFrame with one row per gene_name seen in either input:
        ``cancer_peptides`` / ``normal_peptides`` (distinct peptide counts,
        0 when the gene is absent from that group), ``cancer_mean_rank`` /
        ``normal_mean_rank`` (mean %Rank_EL, null when absent), and
        ``fold_change`` = log2((cancer_peptides + 1) / (normal_peptides + 1)).
        Null gene_name groups (unmapped accessions) do not match across the
        outer join and appear once per side.
    """
    def _per_gene(df, prefix):
        # Distinct peptide count and mean %Rank_EL per gene, column names prefixed.
        return (df
            .groupBy("gene_name")
            .agg(
                countDistinct("Peptide").alias(f"{prefix}_peptides"),
                avg("%Rank_EL").alias(f"{prefix}_mean_rank")
            )
        )

    return (_per_gene(cancer_df, "cancer")
        .join(_per_gene(normal_df, "normal"), "gene_name", "outer")
        # Zero-fill only the counts: a 0 mean %Rank would report an absent
        # gene as the strongest possible binder.
        .na.fill(0, subset=["cancer_peptides", "normal_peptides"])
        .withColumn(
            "fold_change",
            log2((col("cancer_peptides") + 1) / (col("normal_peptides") + 1))
        )
    )

differential_presentation = analyze_differential_presentation(
    cancer_predictions,
    normal_predictions
).cache()

# Plot differential presentation: one point per gene plus the y = x line.
plot_data = differential_presentation.toPandas()
# The reference diagonal must span both axes, not only the normal counts.
axis_max = max(plot_data["normal_peptides"].max(), plot_data["cancer_peptides"].max())

plt.figure(figsize=(10, 10))
plt.scatter(
    plot_data["normal_peptides"],
    plot_data["cancer_peptides"],
    alpha=0.5
)
plt.plot([0, axis_max], [0, axis_max], 'r--', alpha=0.5)
plt.xlabel("Normal Peptides")
plt.ylabel("Cancer Peptides")
plt.title("Cancer vs Normal Peptide Presentation")
# Genes seen in only one group have a count of 0, which a plain log axis
# cannot display; symlog is linear near 0 and logarithmic beyond.
plt.xscale("symlog")
plt.yscale("symlog")

## Save Analysis Results

In [None]:
# Save results
SparkManager.save_prediction_results(
    strong_binders,
    "analysis_results/strong_binders.csv"
)

SparkManager.save_prediction_results(
    differential_presentation,
    "analysis_results/differential_presentation.csv"
)

# Clean up: release every DataFrame cached in this notebook, including
# uniprot_info and tissue_expression, which are also .cache()d above.
# A plain loop, since unpersist() is called only for its side effect.
for cached_df in [
    ms_data, predictions, uniprot_info, predictions_with_genes,
    expression_data, metadata, expression_with_meta, tissue_expression,
    strong_binders, differential_presentation
]:
    cached_df.unpersist()