# Does searching over different perturbation algorithm parameters change the results?

In [None]:
import pandas as pd
import plotly.express as px
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

In [None]:
def load_comparison_matrix(run_name: str) -> pd.DataFrame:
    """Read the weight-matrix comparison table written for one run.

    Args:
        run_name: Name of the run's directory under ``../../results``.

    Returns:
        The run's ``evaluation/weight_matrices_comparison.tsv`` file parsed
        as a tab-separated table.
    """
    # Relative path: it is resolved against the current working directory.
    tsv_path = f"../../results/{run_name}/evaluation/weight_matrices_comparison.tsv"
    comparison = pd.read_csv(tsv_path, sep="\t")
    return comparison

## Does the number of steps matter?

In [None]:
# Run whose comparison table is examined in this section.
# Alternative run to inspect instead: "grid_search_real_1"
run_name = "steps_1_to_100_1"
steps_1_to_100 = load_comparison_matrix(run_name)
# Preview the first five rows of the loaded table.
steps_1_to_100.head(n=5)

In [None]:
# Show only the rows that compare "real_perturbed" against "predicted_perturbed".
steps_1_to_100[
    steps_1_to_100["matrix1"].eq("real_perturbed")
    & steps_1_to_100["matrix2"].eq("predicted_perturbed")
]

In [None]:
# Euclidean distance against step count for the "real_perturbed" vs
# "predicted_perturbed" rows, coloured by score transform, on a log-scaled y axis.
px.line(
    steps_1_to_100[
        steps_1_to_100["matrix1"].eq("real_perturbed")
        & steps_1_to_100["matrix2"].eq("predicted_perturbed")
    ],
    x="steps",
    y="euclidean_distance",
    color="score_transform",
    log_y=True,
    title="Score vs Step for Different Perturbation Steps (Euclidean Distance)",
).show()

In [None]:
# Wasserstein distance against step count for the "real_perturbed" vs
# "predicted_perturbed" rows, coloured by score transform, on a log-scaled y axis.
px.line(
    steps_1_to_100[
        steps_1_to_100["matrix1"].eq("real_perturbed")
        & steps_1_to_100["matrix2"].eq("predicted_perturbed")
    ],
    x="steps",
    y="wasserstein_distance",
    color="score_transform",
    log_y=True,
    title="Score vs Step for Different Perturbation Steps (Wasserstein Distance)",
).show()

In [None]:
# Energy distance against step count for the "real_perturbed" vs
# "predicted_perturbed" rows, coloured by score transform, on a log-scaled y axis.
px.line(
    steps_1_to_100[
        steps_1_to_100["matrix1"].eq("real_perturbed")
        & steps_1_to_100["matrix2"].eq("predicted_perturbed")
    ],
    x="steps",
    y="energy_distance",
    color="score_transform",
    log_y=True,
    title="Score vs Step for Different Perturbation Steps (Energy Distance)",
).show()

## Bar charts

In [None]:
# Load the comparison table for the run named below, then display it.
comp = load_comparison_matrix("GSE158067_perturbation_prediction_with_step_by_step_eval_screen-2")
comp

In [None]:
def make_bar_chart(df, distance_type: str):
    """Show a log-scale bar chart of one distance column per matrix comparison.

    Args:
        df: Comparison table with string columns ``matrix1`` and ``matrix2``
            plus a column named by ``distance_type``. It is copied, so the
            caller's frame is not modified.
        distance_type: Name of the column whose values become the bar heights.

    The figure is displayed with ``.show()``; nothing is returned.
    """

    def prettify(name):
        # Drop the last "_"-separated token and title-case what remains,
        # e.g. "real_perturbed" -> "Real".
        return " ".join(name.split("_")[:-1]).title()

    labelled = df.copy()
    labelled[["matrix1", "matrix2"]] = labelled[["matrix1", "matrix2"]].map(prettify)
    comparison_names = labelled["matrix1"] + " vs " + labelled["matrix2"]
    bar_heights = labelled[distance_type]

    # Human-readable metric name, used in both the title and the y-axis label.
    metric_label = distance_type.replace("_", " ").title()
    figure = px.bar(
        x=comparison_names,
        y=bar_heights,
        title=f"Perturbation Prediction Evaluation ({metric_label})",
        labels={"x": "Matrix Comparison", "y": metric_label},
        log_y=True,
    )
    figure.show()


# Render one bar chart per distance metric.
for distance_column in ("wasserstein_distance", "euclidean_distance", "energy_distance"):
    make_bar_chart(comp, distance_column)