In [None]:
%autoreload 2

In [None]:
import os
import anndata as ad
import torch
import numpy as np

# Make CUDA kernel launches synchronous so a GPU error is reported at the call
# that caused it (debugging aid).
os.environ.update({"CUDA_LAUNCH_BLOCKING": "1"})

# Display the working directory: the relative paths used below resolve against it.
os.getcwd()

In [None]:
# Which task's ground truth to score against:
# True -> ADT->GEX (CITE dataset), False -> ATAC->GEX (multiome dataset).
USE_ADT2GEX = False
test_path = (
    # adt2gex
    "../datasets/openproblems_bmmc_cite_phase2_rna/openproblems_bmmc_cite_phase2_rna.censor_dataset.output_"
    if USE_ADT2GEX
    # atac2gex
    else "../datasets/openproblems_bmmc_multiome_phase2_rna/openproblems_bmmc_multiome_phase2_rna.censor_dataset.output_"
)

# Saved prediction file; point this at the predictions for the task chosen above.
prediction_path = "../pretrain/pbmc1NoEGEX2ATAC.h5ad"

# Input locations used by the rest of the notebook.
par = {
    "input_test_prediction": prediction_path,
    "input_test_sol": test_path + "test_sol.h5ad",
}

# Load the saved predictions.
prediction_test = ad.read_h5ad(prediction_path)

In [None]:
# Turn the prediction matrix into a dense torch tensor.
# Sparse matrices expose .toarray(); anything else goes through np.asarray.
# Checking for .toarray() (instead of comparing type() with np.ndarray) also
# covers ndarray subclasses such as np.matrix and other array-likes, which
# have no .toarray() and made the original exact-type check crash.
X = prediction_test.X
if hasattr(X, "toarray"):
    X = X.toarray()
X = torch.tensor(np.asarray(X))

In [None]:
# run this if the two datasets are not aligned and there is a ground truth matching matrix
sol_test = ad.read_h5ad(par["input_test_sol"])
Xsol = torch.tensor(sol_test.X.toarray())
Xsol.argmax(1)
# Order the columns of the prediction matrix so that the perfect prediction is the identity matrix
X = X[:, Xsol.argmax(1)]

In [None]:
# After alignment the true partner of row i is column i, so the row index is the label.
labels = torch.arange(X.shape[0])
# Row-wise: is each row's highest-scoring column its true partner?
forward_accuracy = X.argmax(dim=1).eq(labels).float().mean().item()
# Column-wise: is each column's highest-scoring row its true partner?
backward_accuracy = X.argmax(dim=0).eq(labels).float().mean().item()
avg_accuracy = (forward_accuracy + backward_accuracy) / 2
print(forward_accuracy, backward_accuracy, "top1-acc:", avg_accuracy)

In [None]:
# Top-k matching accuracy in both directions: a row (column) is a hit if its
# true partner is among its k highest-scoring columns (rows).
# k is defined once and capped at the matrix size so topk() cannot fail on
# very small inputs. The top5_* names are kept for the cells that follow.
TOP_K = 5
k = min(TOP_K, *X.shape)
_, top_indexes_forward = X.topk(k, dim=1)   # (rows, k): best column indices per row
_, top_indexes_backward = X.topk(k, dim=0)  # (k, cols): best row indices per column
l_forward = labels.expand(k, X.shape[0]).T  # (rows, k): each row's own index, repeated
l_backward = l_forward.T                    # (k, rows)
top5_forward_accuracy = (
    torch.any(top_indexes_forward == l_forward, 1).float().mean().item()
)
top5_backward_accuracy = (
    torch.any(top_indexes_backward == l_backward, 0).float().mean().item()
)
top5_avg_accuracy = 0.5 * (top5_forward_accuracy + top5_backward_accuracy)

print(top5_forward_accuracy, top5_backward_accuracy, f"top{k}-acc:", top5_avg_accuracy)

In [None]:
# Inspect the top-k candidate column indices found for each row (debugging aid).
top_indexes_forward

In [None]:
# Debug check: count the entries in ranks 2..k of the per-row top-k indices that
# differ from the hard-coded sequence [1, 3, 4, 0].
# NOTE(review): the reference values look like they come from a small toy
# example; TODO confirm their origin or remove this cell.
top_indexes_forward[:, 1:].ne(torch.Tensor([1, 3, 4, 0])).float().sum()

In [None]:
# Forward top-k hit rate as a tensor: the fraction of rows whose true partner is among their top-k columns.
(top_indexes_forward == l_forward).any(dim=1).float().mean()

### FOSCTTM (Fraction Of Samples Closer Than the True Match)

In [None]:
# FOSCTTM with X holding scores where higher means a better match and the true
# matches lie on the diagonal. For each cell, take the fraction of candidates
# that score strictly higher than its true partner; do this in both matching
# directions and average them. (Comparing X with torch.diag(X) directly
# broadcasts along the last axis and measures only the column direction.)
true_scores = torch.diag(X)
foscttm_rows = (X > true_scores.unsqueeze(1)).float().mean()  # X[i, j] vs X[i, i]
foscttm_cols = (X > true_scores.unsqueeze(0)).float().mean()  # X[i, j] vs X[j, j]
print("FOSCTTM:", ((foscttm_rows + foscttm_cols) / 2).item())


In [None]:
# Per-cell FOSCTTM in each direction: the fraction of candidates scoring
# strictly higher than the true match. Strict ">" excludes the true match
# itself, mirroring the distance-based reference formula (smaller = closer):
#   foscttm_y = (d < np.expand_dims(np.diag(d), axis=0)).mean(axis=0)
true_scores = torch.diag(X)
# Row direction: compare every X[i, j] with row i's own true score X[i, i].
foscttm_x = (X > true_scores.unsqueeze(1)).float().mean(dim=1)
# Column direction: compare every X[i, j] with column j's true score X[j, j].
foscttm_y = (X > true_scores.unsqueeze(0)).float().mean(dim=0)
print("foscttm_x", foscttm_x, "foscttm_y", foscttm_y)

In [None]:
# Average the column-direction per-cell FOSCTTM values.
torch.mean(foscttm_y)

### For soft predictions, the competition score can be made equal to the forward (or backward) top-1 accuracy by setting the maximum of each row (or column) to 1 and every other entry to 0

In [None]:
# Competition score for the raw (soft) predictions:
# - clip negative scores to 0,
# - normalise each row to sum to 1,
# - average the mass that lands on the true match (the diagonal).
soft_X = X.clip(min=0)
logits_row_sums = soft_X.sum(dim=1)
# A row with no positive score would give 0/0 = NaN and poison the mean;
# such rows put no mass on their true match, so score them as 0.
per_cell_scores = torch.where(
    logits_row_sums > 0,
    soft_X.diag() / logits_row_sums,
    torch.zeros_like(logits_row_sums),
)
top1_competition_metric = per_cell_scores.mean().item()
print("Top-1 competition metric for soft matching predictions:", top1_competition_metric)

In [None]:
# Harden the predictions: 1 at each row's maximum, 0 elsewhere (tied maxima all
# get a 1). Without ties, the competition score of this matrix equals the
# forward top-1 accuracy; a row whose maximum is tied k ways earns 1/k at most.
mx = torch.max(X, dim=1, keepdim=True).values
hard_X = (mx == X).float()
# Entries are already 0/1, so no clipping is needed before normalising rows.
logits_row_sums = hard_X.sum(dim=1)
top1_competition_metric = hard_X.diagonal().div(logits_row_sums).mean().item()
print("Top-1 competition metric for hard matching predictions:", top1_competition_metric)

In [None]:
# Sanity check: total of all entries in the prediction matrix.
torch.sum(X)

In [None]:
# Persist the hardened prediction matrix next to the notebook as a .npy file.
# np.save opens the path in binary-write mode itself; the name already ends
# with ".npy", so no extension is appended.
np.save("hard_X.npy", hard_X.numpy())
