In [None]:
import numpy as np
import pandas as pd

from gears import PertData

In [None]:
import os
import sys

# Resolve the local `sena_discrepancy_vae` source directory (relative to the
# current working directory) and register it on the import path exactly once;
# presumably this is what lets the `utils` import that follows resolve.
module_path = os.path.abspath(
    os.path.join("..", "..", "src", "sena_discrepancy_vae")
)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

import torch  # noqa: E402
from utils import MMD_loss  # noqa: E402


In [None]:
# Load "norman" data.
# Builds a GEARS `PertData` handle over the relative "../data" directory and
# asks it for the dataset named "norman"; the loop further down reads the
# loaded cells through `norman.adata`.
# NOTE(review): whether `load` downloads the dataset or only reads it from
# disk is not visible from this notebook — confirm before running offline.
norman = PertData(data_path="../data")
norman.load(data_name="norman")

In [None]:
# Read the CSV of predicted GEPs into a DataFrame and preview its first rows.
# The evaluation loop further down expects a "double" column in this frame.
df = pd.read_csv("../results/gears_norman_no_test_double.csv")
print(df.head(5))

In [None]:
# Evaluate the predicted double-perturbation GEPs against the observed cells.
#
# For every double perturbation listed in `df`:
#   1. take its predicted GEP (every column after the first one of its row),
#   2. draw as many random control cells as there are observed perturbed cells,
#   3. add the prediction to each control cell (clipping negatives to 0) to get
#      an estimated perturbed cell,
#   4. compute the MMD between each estimated cell and one observed cell, and
#   5. write the mean of those MMDs to the results CSV.
results_file_path = os.path.join("..", "results", "gears_norman_no_test_double_mmd.csv")

# MMD kernel settings and the sampling seed, kept together so they are easy to tune.
MMD_sigma: float = 200.0
kernel_num: int = 10
SEED: int = 0

# Seeded generator so the randomly drawn control cells (and therefore the
# reported MMD values) are reproducible across Restart & Run All.
rng = np.random.default_rng(SEED)

# Loop-invariant objects, built once instead of on every iteration.
# NOTE(review): reusing one MMD_loss instance assumes it keeps no state between
# `forward` calls — confirm against `utils.MMD_loss`.
mmd_loss = MMD_loss(fix_sigma=MMD_sigma, kernel_num=kernel_num)
all_ctrl_geps = norman.adata[norman.adata.obs["condition"] == "ctrl"]
num_doubles = len(df["double"])

mmd_avg_values = []

with open(file=results_file_path, mode="w") as f:
    print("double,num_cells,mmd", file=f)

    for i, double in enumerate(df["double"], start=1):
        # Get the predicted GEP for the current double perturbation. The values
        # must come from the filtered rows, not from the first row of `df`.
        pred_rows = df.loc[df["double"] == double]
        pred_gep = pred_rows.iloc[0, 1:].to_numpy(dtype=np.float64)

        # Get all the true GEPs with the current double perturbation.
        double = double.replace("_", "+")
        print(f"Double perturbation {i}/{num_doubles}: {double}")
        true_geps = norman.adata[norman.adata.obs["condition"] == double]
        N = true_geps.n_obs

        if N == 0:
            # Nothing observed for this condition, so there is nothing to
            # compare the prediction against.
            print(f"No cells found for {double}; skipping.")
            continue

        # BEGIN Average control GEP ####################################################

        # # Compute the estimated GEP using the average control GEP.
        # all_ctrl_geps = norman.adata[norman.adata.obs["condition"] == "ctrl"]
        # avg_ctrl_gep = np.mean(all_ctrl_geps.X.toarray(), axis=0).flatten().tolist()

        # avg_estimated_gep = [
        #     max(0, ctrl + pred) for ctrl, pred in zip(avg_ctrl_gep, pred_gep)
        # ]
        # print(f"Avg. estimated GEP: {avg_estimated_gep}")

        # # Compute the average true GEP.
        # avg_true_gep = np.mean(true_geps.X.toarray(), axis=0).flatten().tolist()
        # print(f"Avg. true GEP: {avg_true_gep}")

        # # Compute the MMD between the average true GEP and the average estimated GEP.
        # MMD_sigma: float = 200.0
        # kernel_num: int = 10
        # mmd_loss = MMD_loss(fix_sigma=MMD_sigma, kernel_num=kernel_num)
        # mmd = mmd_loss.forward(
        #     source=torch.tensor(avg_estimated_gep).unsqueeze(0),
        #     target=torch.tensor(avg_true_gep).unsqueeze(0),
        # )
        # print(f"MMD: {mmd.item()}")

        # mmd_avg_values.append(mmd.item())
        # print(f"{double},{N},{mmd}", file=f)

        # END Average control GEP ######################################################

        # BEGIN Random control GEPs ####################################################

        # Get N random control GEPs. Sample without replacement whenever there
        # are enough control cells; fall back to replacement only if there are not.
        random_indices = rng.choice(
            all_ctrl_geps.n_obs, size=N, replace=N > all_ctrl_geps.n_obs
        )
        rand_ctrl_geps = all_ctrl_geps[random_indices, :]

        # Densify once per perturbation instead of once per cell.
        true_matrix = true_geps.X.toarray()
        ctrl_matrix = rand_ctrl_geps.X.toarray()

        # Add the predicted GEP to every random control GEP and clip values < 0.
        estimated_matrix = np.clip(ctrl_matrix + pred_gep, 0, None)

        # Compute the MMD between each true GEP and its paired estimated GEP.
        # float32 matches what `torch.tensor` produced from plain Python lists.
        mmd_values = []
        for estimated_gep, true_gep in zip(estimated_matrix, true_matrix):
            mmd = mmd_loss.forward(
                source=torch.tensor(estimated_gep, dtype=torch.float32).unsqueeze(0),
                target=torch.tensor(true_gep, dtype=torch.float32).unsqueeze(0),
            )
            mmd_values.append(mmd.item())

        mmd_avg = np.mean(mmd_values)
        mmd_avg_values.append(mmd_avg)
        print(f"MMD: {mmd_avg}")
        # Persist the per-perturbation mean, not the last per-pair tensor.
        print(f"{double},{N},{mmd_avg}", file=f)

        # END Random control GEPs ######################################################

    mmd_avg_avg = np.mean(mmd_avg_values)
    print(f"Average MMD: {mmd_avg_avg}")
