To download the cell type annotations:
`aws s3 cp s3://openproblems-bio/public/post_competition/openproblems_bmmc_cite_complete.h5ad ./datasets/post_competition/ --no-sign-request`
`aws s3 cp s3://openproblems-bio/public/post_competition/openproblems_bmmc_multiome_complete.h5ad ./datasets/post_competition/ --no-sign-request`


In [None]:
import os
import anndata as ad
import torch
import numpy
import argparse
import pandas as pd
from tqdm.auto import tqdm

# Show the current working directory; the relative paths used below are resolved against it.
os.getcwd()

In [None]:
# Task to evaluate. The path-selection cell below accepts 'GEX2ADT' or
# 'GEX2ATAC' and raises ValueError for anything else.
# TASK = 'GEX2ADT'
TASK = 'GEX2ATAC'
# Root directory (relative path) that contains the dataset folders.
DATASET_PATH = "datasets"
# Prediction file (.h5ad) to score; the commented line is an alternative file.
PREDICTION_PATH = "pretrain/defaultGEX2ATAC.h5ad"
# PREDICTION_PATH = "pretrainNovel/NovelGEX2ATAC.h5ad"


In [None]:
# For each supported task: (prefix of the phase-2 test files, complete dataset
# file), both relative to DATASET_PATH.
_TASK_FILES = {
    'GEX2ADT': (
        "openproblems_bmmc_cite_phase2_rna/openproblems_bmmc_cite_phase2_rna"
        ".censor_dataset.output_",
        "post_competition/openproblems_bmmc_cite_complete.h5ad",
    ),
    'GEX2ATAC': (
        "openproblems_bmmc_multiome_phase2_rna"
        "/openproblems_bmmc_multiome_phase2_rna.censor_dataset.output_",
        "post_competition/openproblems_bmmc_multiome_complete.h5ad",
    ),
}

# Reject unsupported tasks up front.
if TASK not in _TASK_FILES:
    raise ValueError('Unknown task: ' + TASK)

_prefix_rel, _complete_rel = _TASK_FILES[TASK]
test_path = os.path.join(DATASET_PATH, _prefix_rel)
completedata_path = os.path.join(DATASET_PATH, _complete_rel)

In [None]:
# Input file locations used by the loading cell: the three test files share
# the `test_path` prefix, the other two paths are taken as given.
par = dict(
    input_mod1=test_path + "test_mod1.h5ad",
    input_mod2=test_path + "test_mod2.h5ad",
    input_complete=completedata_path,
    input_test_sol=test_path + "test_sol.h5ad",
    input_test_prediction=PREDICTION_PATH,
)

In [None]:
def _load(key):
    """Read the .h5ad file whose path is stored under `key` in `par`."""
    return ad.read_h5ad(par[key])


# Load every input, in the same order as the paths are read from `par`.
input_mod1 = _load("input_mod1")
input_mod2 = _load("input_mod2")
complete = _load("input_complete")
prediction_test = _load("input_test_prediction")
sol_test = _load("input_test_sol")

In [None]:
# for PBMC
# Optional override: replaces `input_mod1` with the PBMC test file. The path
# starts with "../", unlike DATASET_PATH, so it is resolved one level above
# the working directory.
input_mod1 = ad.read_h5ad("../datasets/PBMC/glue_processed/test_mod1.h5ad")

In [None]:
# Optional override: point PREDICTION_PATH at the pbmc1 prediction file and
# reload `prediction_test` from it (the path is prefixed with "../").
PREDICTION_PATH = "pretrain/pbmc1GEX2ATAC.h5ad"
prediction_test = ad.read_h5ad("../" + PREDICTION_PATH)

In [None]:
# Use `input_mod1` itself as the source of cell-type labels (no lookup in
# `complete`), then display its "cell_type" column.
# NOTE(review): presumably meant for the PBMC override, whose file is assumed
# to carry obs["cell_type"] already — confirm. A later cell reassigns
# mod1_withcelltype from `complete`.
mod1_withcelltype = input_mod1
mod1_withcelltype.obs["cell_type"]

In [None]:
def obs_fea(adata):
    """Print how many observations and features `adata` holds.

    Args:
        adata: Object exposing ``n_obs`` and ``n_vars`` attributes.
    """
    n_observations, n_features = adata.n_obs, adata.n_vars
    summary = f"The data has {n_observations} observations and {n_features} features."
    print(summary)

# Report the dimensions of every loaded dataset.
for _adata in (prediction_test, sol_test, complete, input_mod1, input_mod2):
    obs_fea(_adata)


In [None]:
# Display the observation names of the second-modality test data.
input_mod2.obs_names

In [None]:
# Display the summary representation of the `complete` dataset.
complete

In [None]:
# Subset `complete` to the observations named in `input_mod1`, so that the
# cell-type labels can be looked up for the test cells; then display them.
mod1_withcelltype = complete[input_mod1.obs_names]
mod1_withcelltype.obs["cell_type"]

In [None]:
# Category labels of the categorical "cell_type" column, as a plain list.
celltypes = mod1_withcelltype.obs["cell_type"].cat.categories.tolist()

In [None]:
# Display the list of cell-type names.
celltypes

In [None]:
# Map each cell-type name to its position in `celltypes`, then display the mapping.
celltype2idx = {name: position for position, name in enumerate(celltypes)}
celltype2idx

In [None]:
# Densify the prediction matrix if needed, then wrap it in a torch tensor.
# isinstance() is used so that ndarray subclasses count as dense too; they
# have no .toarray() method and must not take the sparse branch.
if isinstance(prediction_test.X, numpy.ndarray):
    X = prediction_test.X
else:
    # Not a dense array: assumed to be a sparse matrix exposing .toarray().
    X = prediction_test.X.toarray()
X = torch.tensor(X)

In [None]:
# Densify the solution matrix only when it is not already a dense ndarray,
# so that both dense and sparse solution files are accepted.
_sol_matrix = sol_test.X
if not isinstance(_sol_matrix, numpy.ndarray):
    # Assumed to be a sparse matrix exposing .toarray().
    _sol_matrix = _sol_matrix.toarray()
Xsol = torch.tensor(_sol_matrix)
# Order the columns of the prediction matrix so that the perfect prediction is the identity matrix
X = X[:, Xsol.argmax(1)]

In [None]:
# One-column frame holding the cell type of every observation.
cell_type_labels = mod1_withcelltype.obs["cell_type"]
perm = cell_type_labels.to_frame()
print(perm.value_counts())
# Remember each row's original position before the rows are reordered.
row_positions = list(range(X.shape[0]))
perm.insert(1, "idx", row_positions)
# Bring rows of the same cell type next to each other, then display the frame.
perm = perm.sort_values("cell_type")
perm

In [None]:
# Number of rows in `perm`.
len(perm)

In [None]:
# Boundaries of the runs of identical cell types in the sorted `perm`:
# block k spans rows block_idxs[k] .. block_idxs[k + 1] - 1.
# Adjacent labels are compared in a single vectorized pass, which avoids
# materialising two row objects through perm.iloc[...] for every row.
_sorted_types = perm["cell_type"].to_numpy()
_run_starts = numpy.flatnonzero(_sorted_types[1:] != _sorted_types[:-1]) + 1
block_idxs = [0, *_run_starts.tolist(), len(perm)]

In [None]:
# Print the block boundaries and display how many there are (one more than
# the number of blocks, since both the leading 0 and the final len(perm) are stored).
print(block_idxs)
len(block_idxs)

In [None]:
# Permute X such that it is a block diagonal matrix with 1 block per cell type
# The permutation is turned into an explicit integer index tensor so that
# torch does not have to interpret a pandas Series as an index.
_order = torch.as_tensor(perm["idx"].to_numpy())
X = X[_order][:, _order]

In [None]:
# Print the shape of the permuted matrix.
print(X.shape)

In [None]:
# Indicator matrix: ones inside each same-cell-type block on the diagonal,
# zeros everywhere else.
mask = torch.zeros_like(X)
block_bounds = zip(block_idxs[:-1], block_idxs[1:])
for idx_start, idx_end in tqdm(block_bounds, total=len(block_idxs) - 1):
    mask[idx_start:idx_end, idx_start:idx_end] = 1

In [None]:
# X = Xsoft

In [None]:
# Keep a handle on the soft scores; later cells read `Xsoft`.
Xsoft = X
print(X.sum())
X = X.clip(min=0)
print(X.sum())
# Hard matching: mark, in every row, the entries equal to that row's maximum.
mx = torch.max(X, dim=1, keepdim=True).values
X = (mx == X).float()   # convert to a hard matching
logits_row_sums = X.sum(dim=1)
print(logits_row_sums)
# Normalise every row by its own sum. The sums are reshaped to a column of
# shape (n_rows, 1): dividing by the 1-D vector would broadcast along the last
# axis and divide column j by the sum of row j instead.
X = torch.div(X, logits_row_sums.unsqueeze(1))
print(X.shape, X.sum())

# Keep only the mass that falls inside the block of the row's own cell type.
scoreX = X.mul(mask)
print(scoreX.sum())

# Average within-block mass per row.
cell_type_score = scoreX.sum() / scoreX.shape[0]

print("Cell type matching competition score", cell_type_score.item())


In [None]:
# Number of stored block boundaries.
len(block_idxs)

In [None]:
# Display the block boundaries.
block_idxs

In [None]:
# For every cell-type block, record its name, its number of cells and the
# within-block score mass divided by that number of cells.
per_celltype_scores = []
for idx_start, idx_end in zip(block_idxs[:-1], block_idxs[1:]):
    n_cells = idx_end - idx_start
    celltype = perm["cell_type"].iloc[idx_start]
    print(celltype, "n_cells:", n_cells)
    block_mass = scoreX[idx_start:idx_end].sum()
    acc_celltype = (block_mass / n_cells).item()
    print(acc_celltype)
    per_celltype_scores.append([celltype, n_cells, acc_celltype])


In [None]:
# Tabulate the per-cell-type results, print summaries and write them to CSV.
per_celltype_scores_df = pd.DataFrame(
    data=per_celltype_scores,
    columns=["celltype", "n_cells", "acc_celltype"],
)
ranked_by_accuracy = per_celltype_scores_df.sort_values("acc_celltype")
print(ranked_by_accuracy)
# Unweighted mean of the per-cell-type values.
mean_accuracy = numpy.mean(per_celltype_scores_df.acc_celltype.values)
print("non balanced acc", mean_accuracy)
cell_count = numpy.sum(per_celltype_scores_df.n_cells.values)
print("tot cell", cell_count)
per_celltype_scores_df.to_csv("per_celltype_scores_novel.csv")

In [None]:
def idx2celltypeidx(idx):
    """Return the cell-type index of the row at position `idx` of `perm`.

    Reads the "cell_type" value at that position and translates it through
    the module-level `celltype2idx` mapping.
    """
    celltype_name = perm["cell_type"].iloc[idx]
    return celltype2idx[celltype_name]

In [None]:
# Build the predicted and true cell-type index of every row; rows follow the
# cell-type-sorted order of `perm`.
y_idx_pred = torch.argmax(X, dim=1).numpy()
y_celltype_pred = [idx2celltypeidx(idx) for idx in y_idx_pred]
y_celltype_true = []
n_cells_total = 0  # named so that it does not shadow the builtin sum()
for idx_start, idx_end in zip(block_idxs[:-1], block_idxs[1:]):
    n_cells = idx_end - idx_start
    print(n_cells)
    n_cells_total += n_cells
    # Label the block through celltype2idx, the same mapping used for the
    # predictions, so both label lists share one index space even when a
    # category has no cells in `perm`.
    celltype = perm["cell_type"].iloc[idx_start]
    y_celltype_true += [celltype2idx[celltype]] * n_cells
print(n_cells_total)

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import seaborn as sns

# Confusion matrix between true and predicted cell-type indices, normalised
# over the true labels.
cm = confusion_matrix(
    y_true=y_celltype_true,
    y_pred=y_celltype_pred,
    normalize='true',
)
# Display the row sums of the normalised matrix.
cm.sum(axis=1)

In [None]:
# Alternative rendering through ConfusionMatrixDisplay, currently disabled:
# disp = ConfusionMatrixDisplay(confusion_matrix=cm,display_labels=celltypes)
# plt.figure(figsize=(10,10))
# disp.plot(xticks_rotation='vertical', values_format='.1f', include_values=False, cmap='viridis')
# plt.tight_layout()
# plt.show()
# Draw the confusion matrix as an annotated heatmap with the cell-type names
# as tick labels on both axes.
fig, ax = plt.subplots(figsize=(13,10))
sns.heatmap(cm, annot=True, fmt='.2f', xticklabels=celltypes, yticklabels=celltypes, square=True)
plt.ylabel('True')
plt.xlabel('Predicted')
plt.tight_layout()
plt.show(block=False)
# Save the figure as a PDF (relative path, so next to the working directory).
fig.savefig('confusion_percelltype_novel.pdf')

In [None]:
# Total of all entries of `scoreX`, as a Python number.
scoreX.sum().item()

In [None]:
# Heatmap of the block mask, subsampled to every 100th row and column.
sns.heatmap(mask[::100, ::100].numpy())

In [None]:
# Heatmap of the masked scores, subsampled to every 100th row and column.
sns.heatmap(scoreX[::100, ::100].numpy())

In [None]:
# Heatmap of the soft scores (every 100th row and column) raised to the power 0.15.
# NOTE(review): the exponent presumably compresses the value range for display;
# negative entries, if any, would become NaN under a fractional power — confirm.
sns.heatmap(torch.pow(Xsoft[::100, ::100], 0.15).numpy())

In [None]:
# Sweep an exponent p applied to the soft scores and keep the p that gives the
# highest column-normalised cell-type matching score.
best = 0
best_p = None  # stays None if no exponent scores above 0
# Clip before exponentiating: a negative score raised to a fractional power is
# NaN, and clipping afterwards would not remove it.
Xsoft_nonneg = Xsoft.clip(min=0)
for p in torch.arange(0.05, 2, 0.05):
    print(p)
    # Soft matching: the powered scores are normalised directly, without
    # converting them to a hard matching first.
    X = torch.pow(Xsoft_nonneg, p)
    print(X.sum())
    # Despite its name this holds the column sums (dim=0). Dividing by the 1-D
    # vector broadcasts along the last axis, so every column is normalised by
    # its own sum.
    logits_row_sums = X.sum(dim=0)
    print(logits_row_sums)
    X = torch.div(X, logits_row_sums)
    print(X.shape, X.sum().item())

    # Keep only the mass inside the same-cell-type blocks.
    scoreX = X.mul(mask)
    print(scoreX.sum())

    cell_type_score = scoreX.sum() / scoreX.shape[0]

    print("Cell type matching competition score", cell_type_score.item())
    if cell_type_score.item() > best:
        best = cell_type_score.item()
        best_p = p
        print("best p:", best_p)

In [None]:
import seaborn as sns
# Heatmap of the whole `scoreX` matrix as it currently stands (no subsampling).
sns.heatmap(scoreX.numpy())