# Does searching over different perturbation algorithm parameters make a difference?

In [13]:
import pandas as pd
import plotly.express as px
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

In [14]:
def load_comparison_matrix(
    run_name: str, results_dir: str = "../../results"
) -> pd.DataFrame:
    """Load the weight-matrix comparison table written by an evaluation run.

    The table is read from
    ``{results_dir}/{run_name}/evaluation/weight_matrices_comparison.tsv``.

    Args:
        run_name: Name of the run directory inside ``results_dir``.
        results_dir: Root directory holding all runs. Defaults to the
            notebook-relative ``../../results`` so existing calls keep reading
            from the same location; pass another root when running from a
            different working directory.

    Returns:
        The tab-separated file parsed into a DataFrame, exactly as
        ``pd.read_csv`` returns it (no index column is selected).

    Raises:
        FileNotFoundError: If the comparison file does not exist.
    """
    comparison_path = (
        f"{results_dir}/{run_name}/evaluation/weight_matrices_comparison.tsv"
    )
    return pd.read_csv(comparison_path, sep="\t")

## Does the number of steps matter?

In [15]:
# Run whose comparison table is analysed below
# (the "grid_search_real_1" run was used as an alternative).
run_name = "steps_1_to_100_1"
steps_1_to_100 = load_comparison_matrix(run_name)
steps_1_to_100.head(n=5)

Unnamed: 0,matrix1,matrix2,euclidean_distance,wasserstein_distance,energy_distance,combination_id,steps,score_transform,threshold,steepness,midpoint
0,real_perturbed,real_unperturbed,4.56149,1.463655e-06,8.8e-05,0,1,threshold,0.6,,
1,real_perturbed,predicted_perturbed,4.561418,1.458827e-06,8.9e-05,0,1,threshold,0.6,,
2,real_unperturbed,predicted_perturbed,0.057107,4.04839e-08,2e-06,0,1,threshold,0.6,,
3,real_perturbed,real_unperturbed,4.56149,1.463655e-06,8.8e-05,1,10,threshold,0.6,,
4,real_perturbed,predicted_perturbed,4.558264,1.399971e-06,9.1e-05,1,10,threshold,0.6,,


In [16]:
# Keep only the rows that compare the real perturbed matrix with the
# predicted perturbed matrix.
is_real_perturbed = steps_1_to_100["matrix1"] == "real_perturbed"
is_predicted_perturbed = steps_1_to_100["matrix2"] == "predicted_perturbed"
steps_1_to_100.loc[is_real_perturbed & is_predicted_perturbed]

Unnamed: 0,matrix1,matrix2,euclidean_distance,wasserstein_distance,energy_distance,combination_id,steps,score_transform,threshold,steepness,midpoint
1,real_perturbed,predicted_perturbed,4.561418,1.458827e-06,8.89965e-05,0,1,threshold,0.6,,
4,real_perturbed,predicted_perturbed,4.558264,1.399971e-06,9.086129e-05,1,10,threshold,0.6,,
7,real_perturbed,predicted_perturbed,4.054982,0.0002189355,0.009291815,2,20,threshold,0.6,,
10,real_perturbed,predicted_perturbed,5114.844,0.2235703,0.2952516,3,30,threshold,0.6,,
13,real_perturbed,predicted_perturbed,5234488.0,228.6146,9.440862,4,40,threshold,0.6,,
16,real_perturbed,predicted_perturbed,5353336000.0,233795.0,301.9069,5,50,threshold,0.6,,
19,real_perturbed,predicted_perturbed,5474779000000.0,239096800.0,9654.762,6,60,threshold,0.6,,
22,real_perturbed,predicted_perturbed,5598958000000000.0,244519700000.0,308753.2,7,70,threshold,0.6,,
25,real_perturbed,predicted_perturbed,5.72595e+18,250065700000000.0,9873736.0,8,80,threshold,0.6,,
28,real_perturbed,predicted_perturbed,5.855822e+21,2.557375e+17,315756000.0,9,90,threshold,0.6,,


In [17]:
# Euclidean distance between the real and the predicted perturbed matrices as
# a function of the number of perturbation steps (log-scaled y-axis), one line
# per score transformation.
px.line(
    steps_1_to_100.loc[
        (steps_1_to_100["matrix1"] == "real_perturbed")
        & (steps_1_to_100["matrix2"] == "predicted_perturbed")
    ],
    x="steps",
    y="euclidean_distance",
    # The y-axis is a distance, not a score: name the plotted metric.
    title="Euclidean Distance as a Function of Perturbation Steps",
    color="score_transform",
    log_y=True,
    labels={
        "steps": "Steps",
        "euclidean_distance": "Euclidean Distance",
        "score_transform": "Score Transformation",
    },
).show()

In [18]:
# Wasserstein distance between the real and the predicted perturbed matrices
# as a function of the number of perturbation steps (log-scaled y-axis), one
# line per score transformation.
px.line(
    steps_1_to_100.loc[
        (steps_1_to_100["matrix1"] == "real_perturbed")
        & (steps_1_to_100["matrix2"] == "predicted_perturbed")
    ],
    x="steps",
    y="wasserstein_distance",
    # The y-axis is a distance, not a score: name the plotted metric.
    title="Wasserstein Distance as a Function of Perturbation Steps",
    color="score_transform",
    log_y=True,
    labels={
        "steps": "Steps",
        "wasserstein_distance": "Wasserstein Distance",
        "score_transform": "Score Transformation",
    },
).show()

In [76]:
# Publication-style figure: energy distance between the real and the predicted
# perturbed matrices versus the number of perturbation steps, one line per
# score transformation (log-scaled y-axis).
real_vs_predicted = steps_1_to_100.loc[
    (steps_1_to_100["matrix1"] == "real_perturbed")
    & (steps_1_to_100["matrix2"] == "predicted_perturbed")
]

fig = px.line(
    real_vs_predicted,
    x="steps",
    y="energy_distance",
    title="Energy Distance as a Function of Perturbation Steps",
    color="score_transform",
    log_y=True,
    width=1000,
    labels={
        "steps": "Steps",
        "energy_distance": "Energy Distance",
        "score_transform": "Score Transformation",
    },
)

# Axis styling shared by both axes - sized up for PDF export.
axis_style = dict(
    showgrid=True,
    gridwidth=1.5,  # Thicker grid
    gridcolor="lightgray",
    showline=True,
    linewidth=2,  # Thicker axis lines
    linecolor="black",
    ticks="outside",
    tickwidth=2,
    tickcolor="black",
    title_font=dict(size=18, color="black"),  # Larger axis title
    tickfont=dict(size=14, color="black"),  # Larger tick labels
)
fig.update_yaxes(**axis_style)
# The x-range is pinned explicitly and autorange is switched off; it is set
# once here instead of being set and later overwritten.
fig.update_xaxes(**axis_style, range=[0, 101], autorange=False)

# Give every trace its own dash pattern, marker symbol and marker size so the
# lines stay distinguishable where they overlap.
line_styles = ["dash", "solid"]
marker_symbols = ["circle", "square"]
marker_sizes = [12, 18]  # Larger markers

for i, trace in enumerate(fig.data):
    # Cycle through the style lists so a third (or later) score transformation
    # reuses a style instead of raising IndexError.
    style_idx = i % len(line_styles)
    trace.update(
        mode="lines+markers",
        line=dict(width=4, dash=line_styles[style_idx]),  # Thicker lines
        marker=dict(
            size=marker_sizes[style_idx],
            symbol=marker_symbols[style_idx],
            line=dict(width=3, color="white"),  # Thicker borders
            opacity=1.0,
        ),
    )

# Reverse the trace order so the first trace (smaller circle markers) comes
# last and is not hidden behind the larger square markers.
fig.data = fig.data[::-1]

fig.update_layout(
    # Use white background with grid
    plot_bgcolor="white",
    paper_bgcolor="white",
    # Scientific font styling - larger for PDF
    font=dict(family="Arial, sans-serif", size=16, color="black"),  # Larger base font
    # Title styling
    title=dict(
        font=dict(size=22, color="black"), x=0.5, xanchor="center"
    ),  # Larger title
    # Legend styling
    legend=dict(
        bgcolor="rgba(255,255,255,0.8)",
        font=dict(size=14),  # Larger legend font
    ),
    showlegend=True,
)

fig.show()

## Bar charts

In [20]:
# Load the comparison table of the single run named below and display it.
screen_run_name = "GSE158067_perturbation_prediction_with_step_by_step_eval_screen-2"
comp = load_comparison_matrix(screen_run_name)
comp

Unnamed: 0,matrix1,matrix2,wasserstein_distance,euclidean_distance,energy_distance,combination_id,steps,score_transform,threshold,steepness,midpoint
0,real_unperturbed_weights,real_perturbed_weights,4.807301e-06,3.270144,0.000324,,,,,,
1,real_unperturbed_weights,predicted_perturbed_weights,1.230569e-07,0.024028,6e-06,0.0,3.0,threshold,0.6,,
2,real_perturbed_weights,predicted_perturbed_weights,4.790837e-06,3.268107,0.000327,0.0,3.0,threshold,0.6,,


In [21]:
def make_bar_chart(df, distance_type: str):
    """Show a log-scaled bar chart of one distance metric per matrix comparison.

    Each row of ``df`` becomes one bar labelled ``"<matrix1> vs <matrix2>"``.
    Matrix names are shortened for display by dropping their final
    underscore-separated token and title-casing the rest, e.g.
    ``real_perturbed_weights`` -> ``Real Perturbed``. A name without such a
    trailing token therefore loses its last word (``real_perturbed`` ->
    ``Real``).

    Args:
        df: Comparison table with ``matrix1`` and ``matrix2`` string columns
            and a numeric ``distance_type`` column. It is not modified.
        distance_type: Name of the distance column to plot, e.g.
            ``"energy_distance"``.

    Returns:
        None. The figure is displayed via ``.show()``.

    Raises:
        KeyError: If ``matrix1``, ``matrix2`` or ``distance_type`` is not a
            column of ``df``.
    """
    metric_title = distance_type.replace("_", " ").title()

    def display_name(matrix_name: str) -> str:
        # Drop the last "_"-separated token, then title-case what is left.
        return " ".join(matrix_name.split("_")[:-1]).title()

    # Series.map exists on every pandas version, whereas DataFrame.map was only
    # added in pandas 2.1. Mapping per column also builds new Series, so the
    # caller's frame is never touched and no defensive copy is needed.
    left_names = df["matrix1"].map(display_name)
    right_names = df["matrix2"].map(display_name)

    px.bar(
        x=left_names + " vs " + right_names,
        y=df[distance_type],
        title=f"Perturbation Prediction Evaluation ({metric_title})",
        labels={"x": "Matrix Comparison", "y": metric_title},
        log_y=True,
    ).show()


# One bar chart per distance metric, in the same order as before.
for distance_type in ("wasserstein_distance", "euclidean_distance", "energy_distance"):
    make_bar_chart(comp, distance_type)