In [None]:
# Mount Google Drive at /content/drive so later cells can read and write files there.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install scanpy

Collecting scanpy
  Downloading scanpy-1.10.4-py3-none-any.whl.metadata (9.3 kB)
Collecting anndata>=0.8 (from scanpy)
  Downloading anndata-0.11.1-py3-none-any.whl.metadata (8.2 kB)
Collecting legacy-api-wrap>=1.4 (from scanpy)
  Downloading legacy_api_wrap-1.4.1-py3-none-any.whl.metadata (2.1 kB)
Collecting pynndescent>=0.5 (from scanpy)
  Downloading pynndescent-0.5.13-py3-none-any.whl.metadata (6.8 kB)
Collecting session-info (from scanpy)
  Downloading session_info-1.0.0.tar.gz (24 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting umap-learn!=0.5.0,>=0.5 (from scanpy)
  Downloading umap_learn-0.5.7-py3-none-any.whl.metadata (21 kB)
Collecting array-api-compat!=1.5,>1.4 (from anndata>=0.8->scanpy)
  Downloading array_api_compat-1.9.1-py3-none-any.whl.metadata (1.6 kB)
Collecting stdlib_list (from session-info->scanpy)
  Downloading stdlib_list-0.11.0-py3-none-any.whl.metadata (3.3 kB)
Downloading scanpy-1.10.4-py3-none-any.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━

In [None]:
! pip install Cell-BLAST

Collecting Cell-BLAST
  Downloading cell_blast-0.5.1-py3-none-any.whl.metadata (4.1 kB)
Collecting igraph>=0.7.1 (from Cell-BLAST)
  Downloading igraph-0.11.8-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Collecting pronto<=0.12.2,>=0.10.2 (from Cell-BLAST)
  Downloading pronto-0.12.2-py2.py3-none-any.whl.metadata (6.2 kB)
Collecting loompy>=2.0.6 (from Cell-BLAST)
  Downloading loompy-3.0.7.tar.gz (4.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m24.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting leidenalg>=0.8.10 (from Cell-BLAST)
  Downloading leidenalg-0.10.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting pynvml>=8.0.1 (from Cell-BLAST)
  Downloading pynvml-11.5.3-py3-none-any.whl.metadata (8.8 kB)
Collecting texttable>=1.6.2 (from igraph>=0.7.1->Cell-BLAST)
  Downloading texttable-1.7.0-py2.py3-none-any.whl.metadata (9.

In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

In [None]:
import Cell_BLAST as cb

[INFO] Cell BLAST: Using CPU as computation device.
INFO:Cell BLAST:Using CPU as computation device.


In [None]:
# Location of the raw brain dataset on the mounted Google Drive.
data = '/content/drive/MyDrive/brain_raw.h5ad'
# Read the .h5ad file into an AnnData object.
adata = sc.read_h5ad(data)

In [None]:
# Display the AnnData summary: dimensions plus the obs and var column names.
adata

AnnData object with n_obs × n_vars = 3401 × 23433
    obs: 'cell_ontology_class', 'subtissue', 'mouse.sex', 'mouse.id', 'plate.barcode'
    var: 'ERCC', 'mt', 'ribo', 'hb'

In [None]:
# Report whether the expression matrix contains any NaN or infinite entries.
for message, detector in (
    ("Contains NaNs:", np.isnan),
    ("Contains Infinities:", np.isinf),
):
    print(message, detector(adata.X).any())

Contains NaNs: False
Contains Infinities: False


In [None]:
# Both filters modify adata in place (their return values are not used).
# First drop genes whose count total is below 10 ...
sc.pp.filter_genes(adata, min_counts=10)
# ... then drop cells with fewer than 200 detected genes.
sc.pp.filter_cells(adata, min_genes=200)

In [None]:
# Inspect the gene-level metadata table after filtering.
adata.var

Unnamed: 0,ERCC,mt,ribo,hb,n_counts
0610005C13Rik,False,False,False,False,402
0610007C21Rik,False,False,False,False,701327
0610007L01Rik,False,False,False,False,122227
0610007N19Rik,False,False,False,False,22101
0610007P08Rik,False,False,False,False,55534
...,...,...,...,...,...
Zyx,False,False,False,False,127115
Zzef1,False,False,False,False,100962
Zzz3,False,False,False,False,92536
a,False,False,False,False,75


In [None]:
# Replace NaN and +/-inf entries with 0.0; the matrix is rewritten only when
# at least one non-finite value is present.
if not np.isfinite(adata.X).all():
    adata.X = np.nan_to_num(adata.X, nan=0.0, posinf=0.0, neginf=0.0)


In [None]:
# Store the expression matrix as 32-bit floats.
adata.X = adata.X.astype(np.float32)


In [None]:
#@title train test set
import scanpy as sc
from sklearn.model_selection import train_test_split

# Label distribution, printed so the class imbalance is visible before splitting.
cell_type_labels = adata.obs['cell_ontology_class']
print(cell_type_labels.value_counts())

# Split the cell positions 80/20, stratified by cell type, with a fixed seed
# so the same split is produced on every run.
split_options = dict(
    test_size=0.2,
    stratify=cell_type_labels,
    random_state=42,
)
train_idx, test_idx = train_test_split(range(adata.n_obs), **split_options)




cell_ontology_class
oligodendrocyte                   1574
endothelial cell                   715
astrocyte                          432
neuron                             281
oligodendrocyte precursor cell     203
brain pericyte                     156
Bergmann glial cell                 40
Name: count, dtype: int64


In [None]:
# Create separate AnnData objects for train and test sets
adata_train = adata[train_idx, :].copy()
adata_test = adata[test_idx, :].copy()

# Display shapes to verify the split
print(f"Training set: {adata_train.shape}")
print(f"Test set: {adata_test.shape}")

# Optional: Save the splits for later use
adata_train.write('/content/drive/MyDrive/train_set.h5ad')
adata_test.write('/content/drive/MyDrive/test_set.h5ad')

Training set: (2720, 19001)
Test set: (681, 19001)


In [None]:
# Display the training subset summary.
adata_train

AnnData object with n_obs × n_vars = 2720 × 19001
    obs: 'cell_ontology_class', 'subtissue', 'mouse.sex', 'mouse.id', 'plate.barcode', 'n_genes'
    var: 'ERCC', 'mt', 'ribo', 'hb', 'n_counts'

In [None]:
#@title Gene selection
axes = cb.data.find_variable_genes(adata)
adata.var['variable_genes'].sum()

  .groupby("log_mean_bin")
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  summary_df["log_vmr_scaled"].fillna(0, inplace=True)
You are setting values through chained assignment. Currently this works in certain cases, but when using Copy-on-Write (which will become the default behaviour in pandas 3.0) this will never work to update the original DataFrame or Series, because the intermediate object on which we are setting values will behave as a copy.
A typical example is when you are setting values in a column of a DataFrame, like:

df["col"][row_indexer] = value

Use `df.loc[row_indexer, "col"] = values` instead, to perform the assignment in a single step 

1178

In [None]:
#@title Cell_Blast Method
#Step 1: reduce dimension (unsupervised)
axes = cb.data.find_variable_genes(adata_train)

  .groupby("log_mean_bin")
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  summary_df["log_vmr_scaled"].fillna(0, inplace=True)
You are setting values through chained assignment. Currently this works in certain cases, but when using Copy-on-Write (which will become the default behaviour in pandas 3.0) this will never work to update the original DataFrame or Series, because the intermediate object on which we are setting values will behave as a copy.
A typical example is when you are setting values in a column of a DataFrame, like:

df["col"][row_indexer] = value

Use `df.loc[row_indexer, "col"] = values` instead, to perform the assignment in a single step 

In [None]:
# Display the training subset again to see the columns added by gene selection.
adata_train

AnnData object with n_obs × n_vars = 2720 × 19001
    obs: 'cell_ontology_class', 'subtissue', 'mouse.sex', 'mouse.id', 'plate.barcode', 'n_genes', '__libsize__'
    var: 'ERCC', 'mt', 'ribo', 'hb', 'n_counts', 'variable_genes'

In [None]:
# Train 4 DIRECTi models that differ only in their random seed.
import time
start_time = time.time()  # wall-clock start; a later cell prints the elapsed time

# The gene set is identical for every model, so select it once up front.
selected_genes = adata_train.var.query("variable_genes").index

models = []
for seed in range(4):
    fitted_model = cb.directi.fit_DIRECTi(
        adata_train,
        genes=selected_genes,
        latent_dim=10,
        cat_dim=20,
        random_seed=seed,
    )
    models.append(fitted_model)

[INFO] Cell BLAST: Using model path: /tmp/tmpm7msa0ku
INFO:Cell BLAST:Using model path: /tmp/tmpm7msa0ku


[DIRECTi epoch 0] train=6.476, val=6.478, time elapsed=1.9s Best save...
[DIRECTi epoch 1] train=6.465, val=6.463, time elapsed=1.2s Best save...
[DIRECTi epoch 2] train=6.450, val=6.445, time elapsed=1.2s Best save...
[DIRECTi epoch 3] train=6.426, val=6.417, time elapsed=1.2s Best save...
[DIRECTi epoch 4] train=6.397, val=6.370, time elapsed=1.2s Best save...
[DIRECTi epoch 5] train=6.323, val=6.256, time elapsed=1.2s Best save...
[DIRECTi epoch 6] train=5.723, val=3.972, time elapsed=1.2s Best save...
[DIRECTi epoch 7] train=2.740, val=2.645, time elapsed=1.2s Best save...
[DIRECTi epoch 8] train=1.890, val=1.844, time elapsed=1.2s Best save...
[DIRECTi epoch 9] train=1.850, val=1.852, time elapsed=1.6s
[DIRECTi epoch 10] train=1.846, val=1.793, time elapsed=1.7s Best save...
[DIRECTi epoch 11] train=1.844, val=1.823, time elapsed=1.7s
[DIRECTi epoch 12] train=1.842, val=1.778, time elapsed=1.4s Best save...
[DIRECTi epoch 13] train=1.839, val=1.773, time elapsed=1.2s Best save...


  self.load_state_dict(torch.load(os.path.join(path, checkpoint), map_location=DEVICE))
[INFO] Cell BLAST: Using model path: /tmp/tmpdkjf9d2c
INFO:Cell BLAST:Using model path: /tmp/tmpdkjf9d2c


[DIRECTi epoch 0] train=6.479, val=6.473, time elapsed=1.2s Best save...
[DIRECTi epoch 1] train=6.464, val=6.458, time elapsed=1.2s Best save...
[DIRECTi epoch 2] train=6.447, val=6.440, time elapsed=1.2s Best save...
[DIRECTi epoch 3] train=6.427, val=6.413, time elapsed=1.7s Best save...
[DIRECTi epoch 4] train=6.392, val=6.366, time elapsed=1.7s Best save...
[DIRECTi epoch 5] train=6.326, val=6.260, time elapsed=1.7s Best save...
[DIRECTi epoch 6] train=5.819, val=4.499, time elapsed=1.3s Best save...
[DIRECTi epoch 7] train=2.965, val=2.967, time elapsed=1.2s Best save...
[DIRECTi epoch 8] train=1.889, val=1.957, time elapsed=1.3s Best save...
[DIRECTi epoch 9] train=1.841, val=1.894, time elapsed=1.2s Best save...
[DIRECTi epoch 10] train=1.833, val=1.868, time elapsed=1.2s Best save...
[DIRECTi epoch 11] train=1.833, val=1.881, time elapsed=1.2s
[DIRECTi epoch 12] train=1.836, val=1.861, time elapsed=1.2s Best save...
[DIRECTi epoch 13] train=1.834, val=1.854, time elapsed=1.2s 

  self.load_state_dict(torch.load(os.path.join(path, checkpoint), map_location=DEVICE))
[INFO] Cell BLAST: Using model path: /tmp/tmp22n3idw_
INFO:Cell BLAST:Using model path: /tmp/tmp22n3idw_


[DIRECTi epoch 0] train=6.461, val=6.607, time elapsed=1.2s Best save...
[DIRECTi epoch 1] train=6.452, val=6.593, time elapsed=1.2s Best save...
[DIRECTi epoch 2] train=6.431, val=6.574, time elapsed=1.2s Best save...
[DIRECTi epoch 3] train=6.413, val=6.547, time elapsed=1.2s Best save...
[DIRECTi epoch 4] train=6.377, val=6.497, time elapsed=1.5s Best save...
[DIRECTi epoch 5] train=6.302, val=6.376, time elapsed=1.7s Best save...
[DIRECTi epoch 6] train=5.714, val=4.301, time elapsed=1.7s Best save...
[DIRECTi epoch 7] train=2.813, val=2.282, time elapsed=1.4s Best save...
[DIRECTi epoch 8] train=1.882, val=1.923, time elapsed=1.2s Best save...
[DIRECTi epoch 9] train=1.842, val=1.876, time elapsed=1.2s Best save...
[DIRECTi epoch 10] train=1.836, val=1.881, time elapsed=1.2s
[DIRECTi epoch 11] train=1.835, val=1.882, time elapsed=1.2s
[DIRECTi epoch 12] train=1.831, val=1.868, time elapsed=1.2s Best save...
[DIRECTi epoch 13] train=1.832, val=1.855, time elapsed=1.2s Best save...


  self.load_state_dict(torch.load(os.path.join(path, checkpoint), map_location=DEVICE))
[INFO] Cell BLAST: Using model path: /tmp/tmp95n0ecwv
INFO:Cell BLAST:Using model path: /tmp/tmp95n0ecwv


[DIRECTi epoch 0] train=6.475, val=6.505, time elapsed=1.2s Best save...
[DIRECTi epoch 1] train=6.461, val=6.491, time elapsed=1.2s Best save...
[DIRECTi epoch 2] train=6.447, val=6.473, time elapsed=1.3s Best save...
[DIRECTi epoch 3] train=6.425, val=6.446, time elapsed=1.7s Best save...
[DIRECTi epoch 4] train=6.387, val=6.399, time elapsed=1.7s Best save...
[DIRECTi epoch 5] train=6.316, val=6.286, time elapsed=1.5s Best save...
[DIRECTi epoch 6] train=5.726, val=4.086, time elapsed=1.2s Best save...
[DIRECTi epoch 7] train=2.767, val=2.823, time elapsed=1.2s Best save...
[DIRECTi epoch 8] train=1.876, val=1.870, time elapsed=1.2s Best save...
[DIRECTi epoch 9] train=1.847, val=1.849, time elapsed=1.2s Best save...
[DIRECTi epoch 10] train=1.838, val=1.825, time elapsed=1.2s Best save...
[DIRECTi epoch 11] train=1.838, val=1.867, time elapsed=1.2s
[DIRECTi epoch 12] train=1.832, val=1.831, time elapsed=1.2s
[DIRECTi epoch 13] train=1.834, val=1.830, time elapsed=1.4s
[DIRECTi epoc

  self.load_state_dict(torch.load(os.path.join(path, checkpoint), map_location=DEVICE))


In [None]:
# Wall-clock time since start_time was set in the training cell.
print(f"Time elapsed: {time.time() - start_time:.1f}s")

Time elapsed: 1558.3s


In [None]:
# Build the BLAST object from the trained models, with the training cells as the reference set.
blast = cb.blast.BLAST(models,adata_train)

[INFO] Cell BLAST: Projecting to latent space...
INFO:Cell BLAST:Projecting to latent space...
[INFO] Cell BLAST: Fitting nearest neighbor trees...
INFO:Cell BLAST:Fitting nearest neighbor trees...
[INFO] Cell BLAST: Sampling from posteriors...
INFO:Cell BLAST:Sampling from posteriors...
[INFO] Cell BLAST: Generating empirical null distributions...
INFO:Cell BLAST:Generating empirical null distributions...


In [None]:
# Round-trip the BLAST object through disk: save it, drop the in-memory
# instance, then reload it from the same path.
blast.save("./adata_train_blast")
del blast
blast = cb.blast.BLAST.load("./adata_train_blast")

  configuration = torch.load(os.path.join(path, config))
  model.load_state_dict(torch.load(os.path.join(path, weights), map_location=DEVICE), strict=False)
[INFO] Cell BLAST: Fitting nearest neighbor trees...
INFO:Cell BLAST:Fitting nearest neighbor trees...


In [None]:
#@title test\
# Query every test cell against the reference and report the mean latency per cell.
start_time = time.time()
test_hits = blast.query(adata_test)
elapsed_ms = (time.time() - start_time) * 1000
print(f"Time per query: {elapsed_ms / adata_test.shape[0]:.1f}ms")

[INFO] Cell BLAST: Projecting to latent space...
INFO:Cell BLAST:Projecting to latent space...
[INFO] Cell BLAST: Doing nearest neighbor search...
INFO:Cell BLAST:Doing nearest neighbor search...
[INFO] Cell BLAST: Merging hits across models...
INFO:Cell BLAST:Merging hits across models...
[INFO] Cell BLAST: Computing posterior distribution distances...
INFO:Cell BLAST:Computing posterior distribution distances...
[INFO] Cell BLAST: Computing empirical p-values...
INFO:Cell BLAST:Computing empirical p-values...


Time per query: 78.6ms




In [None]:
# Reconcile the hits across the models, then filter them by p-value with a 0.05 cutoff
# (presumably keeping hits below the cutoff — confirm against the Cell BLAST docs).
test_hits = test_hits.reconcile_models().filter(by="pval", cutoff=0.05)

In [None]:
# Convert the hits of the first five query cells into data frames and list their keys.
hits_dict = test_hits[0:5].to_data_frames()
hits_dict.keys()

odict_keys(['B18.MAA000942.3_8_M.1.1', 'C2.MAA000932.3_11_M.1.1', 'F7.MAA000638.3_9_M.1.1', 'A9.MAA001845.3_39_F.1.1', 'C1.MAA000578.3_10_M.1.1'])

In [None]:
# Show the hit table of one query cell (the key is a hard-coded cell name).
hits_dict["B18.MAA000942.3_8_M.1.1"]

Unnamed: 0_level_0,cell_ontology_class,subtissue,mouse.sex,mouse.id,plate.barcode,n_genes,__libsize__,hits,dist,pval
cell,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
L2.MAA000947.3_9_M.1.1,oligodendrocyte,Hippocampus,M,3_9_M,MAA000947,3860,1283700.0,235,10.45378,0.02817
A22.MAA000641.3_9_M.1.1,oligodendrocyte,Cerebellum,M,3_9_M,MAA000641,4958,4781596.0,410,10.186343,0.02544
B5.MAA000923.3_9_M.1.1,oligodendrocyte,Hippocampus,M,3_9_M,MAA000923,3535,562849.0,492,9.728883,0.022042
I3.MAA000941.3_8_M.1.1,oligodendrocyte,Hippocampus,M,3_8_M,MAA000941,3116,271687.0,571,11.393626,0.03949
E5.MAA000935.3_8_M.1.1,oligodendrocyte,Hippocampus,M,3_8_M,MAA000935,3916,1089161.0,631,6.785758,0.004204
M12.MAA000560.3_10_M.1.1,oligodendrocyte,Cortex,M,3_10_M,MAA000560,3713,1052006.0,853,12.367564,0.04954
H4.MAA000560.3_10_M.1.1,oligodendrocyte,Cortex,M,3_10_M,MAA000560,3390,1233048.0,972,7.563083,0.007491
M10.MAA000581.3_10_M.1.1,oligodendrocyte,Cerebellum,M,3_10_M,MAA000581,4413,1169739.0,1147,8.897099,0.014999
J17.MAA000942.3_8_M.1.1,oligodendrocyte,Striatum,M,3_8_M,MAA000942,3795,372419.0,1189,7.721486,0.007723
P14.MAA000942.3_8_M.1.1,oligodendrocyte,Striatum,M,3_8_M,MAA000942,4066,722235.0,1200,11.118523,0.029786


In [None]:
# Predict 'cell_ontology_class' for the query cells from their filtered hits.
test_predictions = test_hits.annotate("cell_ontology_class")

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# Ground-truth labels of the test cells and the labels predicted for them.
# NOTE(review): assumes the rows of test_predictions are in the same order as
# adata_test.obs — confirm before trusting the per-cell comparison.
true_labels = adata_test.obs["cell_ontology_class"].values
pred_labels = test_predictions.values.ravel()

# Count how often each (true, predicted) label pair occurs.
df = pd.DataFrame({'True': true_labels, 'Predicted': pred_labels})
# observed=False is passed explicitly: it keeps the existing result when 'True'
# is categorical and avoids the pandas warning about the changing default.
label_counts = (
    df.groupby(['True', 'Predicted'], observed=False)
    .size()
    .reset_index(name='Count')
)

# Every label seen on either side. Sorting makes the label -> index mapping
# stable across runs (plain set iteration order is not).
unique_labels = sorted(set(true_labels) | set(pred_labels))
label_indices = {label: i for i, label in enumerate(unique_labels)}




  label_counts = df.groupby(['True', 'Predicted']).size().reset_index(name='Count')


In [None]:
# Plain Python lists of the true and predicted labels, used by the metric cells.
true_labels_list, pred_labels_list = (
    labels.tolist() for labels in (true_labels, pred_labels)
)

In [None]:
# Wrap each label in a one-element list, the input shape MultiLabelBinarizer expects.
true_labels_processed, pred_labels_processed = (
    [[cell_type] for cell_type in labels]
    for labels in (true_labels_list, pred_labels_list)
)

In [None]:
#@title quality metrics
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import (
    classification_report,
    hamming_loss,
    jaccard_score,
    accuracy_score,
    multilabel_confusion_matrix
)

# Fit the binarizer on the true and predicted labels together so both
# indicator matrices share the same set of columns.
mlb = MultiLabelBinarizer().fit(true_labels_processed + pred_labels_processed)

# Binary indicator matrices for the true and the predicted labels.
true_binarized, pred_binarized = map(
    mlb.transform, (true_labels_processed, pred_labels_processed)
)

# The label behind each indicator column.
label_classes = mlb.classes_



In [None]:
# Write the per-cell true/predicted label table to a CSV file in the working directory.
df.to_csv('cell_blast_results.csv', index=False)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# One 2x2 confusion matrix per label, computed from the indicator matrices.
mlcm = multilabel_confusion_matrix(true_binarized, pred_binarized)

# Draw each label's matrix as its own annotated heatmap.
for label, cm in zip(label_classes, mlcm):
    fig, ax = plt.subplots(figsize=(5, 4))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        cbar=False,
        xticklabels=['Predicted Negative', 'Predicted Positive'],
        yticklabels=['True Negative', 'True Positive'],
        ax=ax,
    )
    ax.set_title(f'Confusion Matrix for label: {label}')
    ax.set_ylabel('True Labels')
    ax.set_xlabel('Predicted Labels')
    plt.show()






In [None]:
#@title Performance and evaluation
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score

# Every label that occurs in the truth or in the predictions. The predictions
# can contain labels that never occur in the truth (the recorded run shows
# 'ambiguous' and 'rejected'), so matrix axes and tick labels must be built
# from the union or the heatmap labels do not line up with the matrix.
all_classes = sorted(set(true_labels_list) | set(pred_labels_list))

# Accuracy and weighted F1 (weighted to account for class imbalance).
# zero_division=0 keeps the same numbers as the default while silencing the
# warning raised for labels that have no true samples.
accuracy = accuracy_score(true_labels_list, pred_labels_list)
f1 = f1_score(true_labels_list, pred_labels_list, average='weighted', zero_division=0)
print(f"Accuracy: {accuracy:.2f}")
print(f"F1 Score (Weighted): {f1:.2f}")
print("Classification Report:")
print(classification_report(true_labels_list, pred_labels_list, zero_division=0))

# Confusion matrix with an explicit label order, drawn as a heatmap whose
# tick labels use that same order.
conf_matrix = confusion_matrix(true_labels_list, pred_labels_list, labels=all_classes)
plt.figure(figsize=(10, 8))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=all_classes, yticklabels=all_classes)
plt.title("Confusion Matrix Heatmap")
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.tight_layout()
plt.show()

results_df = pd.DataFrame({'True': true_labels_list, 'Predicted': pred_labels_list})

# Number of cells per class on each side, aligned to the shared class list;
# classes missing on one side get a count of 0.
true_counts = results_df['True'].value_counts().reindex(all_classes, fill_value=0)
predicted_counts = results_df['Predicted'].value_counts().reindex(all_classes, fill_value=0)



Accuracy: 0.87
F1 Score (Weighted): 0.93
Classification Report:
                                precision    recall  f1-score   support

           Bergmann glial cell       1.00      0.75      0.86         8
                     ambiguous       0.00      0.00      0.00         0
                     astrocyte       1.00      0.83      0.91        87
                brain pericyte       1.00      0.97      0.98        31
              endothelial cell       1.00      0.86      0.92       143
                        neuron       1.00      0.88      0.93        56
               oligodendrocyte       1.00      0.86      0.92       315
oligodendrocyte precursor cell       0.98      0.98      0.98        41
                      rejected       0.00      0.00      0.00         0

                      accuracy                           0.87       681
                     macro avg       0.77      0.68      0.72       681
                  weighted avg       1.00      0.87      0.93       68

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [None]:

# Wide table: each class with its share of the true and of the predicted labels.
proportion_wide = pd.DataFrame({
    'Cell Type': all_classes,
    'Real': true_counts.div(true_counts.sum()),
    'Predicted': predicted_counts.div(predicted_counts.sum()),
})

# Long format (one row per class and source), ready for a grouped bar plot.
proportion_df = proportion_wide.melt(
    id_vars=['Cell Type'],
    var_name='Type',
    value_name='Proportion',
)



In [None]:
# Display the long-format proportion table.
proportion_df

Unnamed: 0,Cell Type,Type,Proportion
0,Bergmann glial cell,Real,0.011747
1,ambiguous,Real,0.0
2,astrocyte,Real,0.127753
3,brain pericyte,Real,0.045521
4,endothelial cell,Real,0.209985
5,neuron,Real,0.082232
6,oligodendrocyte,Real,0.462555
7,oligodendrocyte precursor cell,Real,0.060206
8,rejected,Real,0.0
9,Bergmann glial cell,Predicted,0.008811


In [None]:
# Write the proportion table to a CSV file in the working directory.
proportion_df.to_csv('proportion_df.csv', index=False)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Grouped bars comparing the real and predicted share of every cell type.
fig, ax = plt.subplots(figsize=(12, 8))
sns.barplot(
    data=proportion_df,
    x='Cell Type',
    y='Proportion',
    hue='Type',
    errorbar=None,
    ax=ax,
)
plt.xticks(rotation=45, ha='right')  # slanted tick labels stay readable
ax.set_title("Proportion of Real vs Predicted Cells for Each Cell Type")
ax.set_xlabel("Cell Type")
ax.set_ylabel("Proportion")
ax.legend(title="Type", loc='upper right')
fig.tight_layout()
plt.show()



In [None]:
# Number of predicted labels.
len(pred_labels_list)

681

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Peek at the first few labels on each side.
print("Sample True Labels:", true_labels_list[:5])
print("Sample Predicted Labels:", pred_labels_list[:5])

# The axis labels must cover every label in the truth OR the predictions.
# The predictions can contain labels that never occur in the truth (the
# recorded run shows 'rejected'), so labels taken from the truth alone are
# shorter than the matrix and the heatmap ticks end up on the wrong rows.
unique_labels = sorted(set(true_labels_list) | set(pred_labels_list))
print("Unique Labels:", unique_labels)

# Confusion matrix with rows/columns in exactly the order of unique_labels.
conf_matrix = confusion_matrix(true_labels_list, pred_labels_list, labels=unique_labels)
print("Confusion Matrix:\n", conf_matrix)

# Heatmap whose tick labels match the matrix rows and columns one-to-one.
plt.figure(figsize=(10, 8))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=unique_labels,
            yticklabels=unique_labels)

plt.title("Confusion Matrix Heatmap")
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.tight_layout()
plt.show()



Sample True Labels: ['oligodendrocyte', 'brain pericyte', 'endothelial cell', 'endothelial cell', 'Bergmann glial cell']
Sample Predicted Labels: ['oligodendrocyte', 'brain pericyte', 'rejected', 'rejected', 'Bergmann glial cell']
Confusion Matrix:
 [[  6   1   0   0   0   0   0   0   1]
 [  0   0   0   0   0   0   0   0   0]
 [  0   0  72   0   0   0   1   0  14]
 [  0   0   0  30   0   0   0   0   1]
 [  0   0   0   0 123   0   0   0  20]
 [  0   0   0   0   0  49   0   1   6]
 [  0   0   0   0   0   0 271   0  44]
 [  0   0   0   0   0   0   0  40   1]
 [  0   0   0   0   0   0   0   0   0]]
Unique Labels: ['Bergmann glial cell', 'astrocyte', 'brain pericyte', 'endothelial cell', 'neuron', 'oligodendrocyte', 'oligodendrocyte precursor cell']


In [None]:
# Sanity check: heatmap tick labels only line up with the matrix when the
# number of matrix rows equals the number of labels.
print("Labels in Confusion Matrix Rows:", conf_matrix.shape[0])
print("Labels in Unique Labels:", len(unique_labels))


Labels in Confusion Matrix Rows: 9
Labels in Unique Labels: 7


In [None]:
# Build the labels and the matrix together so they always agree. The labels
# are the union of true and predicted labels, because the predictions can
# contain labels that never occur in the truth; tick labels taken from the
# truth alone would be shorter than the matrix and mislabel its rows/columns.
heatmap_labels = sorted(set(true_labels_list) | set(pred_labels_list))
heatmap_matrix = confusion_matrix(
    true_labels_list, pred_labels_list, labels=heatmap_labels
)

plt.figure(figsize=(10, 8))

sns.heatmap(
    heatmap_matrix,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=heatmap_labels,
    yticklabels=heatmap_labels
)

plt.title("Confusion Matrix Heatmap")
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.tight_layout()
plt.show()


In [None]:
import matplotlib
# NOTE(review): 'Agg' is a non-interactive backend, so after this switch
# plt.show() is not expected to display figures in the notebook. No figure is
# created in this cell either — confirm the cell is still wanted.
matplotlib.use('Agg')  # switch matplotlib to the Agg backend
plt.show()