In [1]:
import scanpy as sc
import pandas as pd
import lightning

In [2]:
import torch

In [2]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import lightning as L
import sklearn.model_selection

In [4]:
def init_weights(m):
    """Re-initialise the weight of an ``nn.Linear`` module with Kaiming-normal values.

    Modules of any other type are left untouched, so the function can be passed
    to ``Module.apply``. Biases are not modified. Returns ``None``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_normal_(m.weight)

class Cell_Encoder(nn.Module):
    """Three-layer MLP (Linear-ReLU-Linear-ReLU-Linear) for per-cell input vectors.

    The last dimension is mapped from ``input_dim`` to ``hidden_dim``; no
    activation follows the final linear layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Layer order fixes the ``l1.<idx>`` parameter names used in checkpoints.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.l1 = nn.Sequential(*layers)

    def forward(self, x):
        """Return the embedding obtained by running ``x`` through the MLP."""
        embedding = self.l1(x)
        return embedding
    
class Gene_Encoder(nn.Module):
    """Three-layer MLP (Linear-ReLU-Linear-ReLU-Linear) for per-gene feature vectors.

    The last dimension is mapped from ``input_dim`` to ``hidden_dim``; no
    activation follows the final linear layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Layer order fixes the ``l1.<idx>`` parameter names used in checkpoints.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.l1 = nn.Sequential(*layers)

    def forward(self, x):
        """Return the embedding obtained by running ``x`` through the MLP."""
        embedding = self.l1(x)
        return embedding

In [5]:
def filter_nonzero(x):
    """Return the element-wise softplus of ``x`` (a smooth, non-negative mapping)."""
    positive = F.softplus(x)
    return positive

In [6]:
class LitAutoEncoder(L.LightningModule):
    """Denoising imputation model: output = ``cell_embedding @ gene_embedding.T``.

    Args:
        encoder1: module that embeds each cell expression vector.
        encoder2: module that embeds each row of ``gene_emb``.
        gene_emb: per-gene input features fed to ``encoder2``.
        train_index: column indices of the reconstruction that are compared
            with the target ``y`` in the train/validation/test steps.
        eta: learning rate of the Adam optimizer.
        nonneg: if True the reconstruction is passed through ``Softplus``;
            otherwise it is returned unchanged.

    Constructor arguments are not saved in checkpoints (``save_hyperparameters``
    is never called), so they must be supplied again to ``load_from_checkpoint``.
    """

    def __init__(self, encoder1, encoder2, gene_emb, train_index, eta=1e-4, nonneg=False):
        super().__init__()
        # The encoders are used as passed in; init_weights is not applied here.
        self.cellencoder = encoder1
        self.geneencoder = encoder2

        if isinstance(gene_emb, torch.Tensor):
            # Non-persistent buffer: follows the module across devices but is
            # excluded from state_dict, so existing checkpoints still load.
            self.register_buffer("gene_emb", gene_emb, persistent=False)
        else:
            self.gene_emb = gene_emb
        self.train_index = train_index
        self.eta = eta

        # Output activation applied to the reconstructed matrix.
        self.id = nn.Softplus() if nonneg else nn.Identity()

        # Both corruption steps in _corrupt use droprate / 2 (i.e. 0.275).
        self.droprate = 0.55

    def _corrupt(self, x):
        """Return a randomly down-sampled version of ``x`` used as denoising input.

        Two corruptions are applied, each at rate ``self.droprate / 2``:
        a Poisson-distributed amount is subtracted from every entry, and a
        Bernoulli mask zeroes entire entries. Negative results are clamped to 0.
        """
        rate = self.droprate / 2
        # Sampling zeros: remove a Poisson(x * rate) amount from every entry.
        removed = torch.poisson(x * rate).int()
        # Technical zeros: keep each entry with probability 1 - rate.
        keep = (torch.rand((x.shape[0], x.shape[1]), device=x.device) >= rate).int()
        thinned = (x - removed) * keep
        return torch.maximum(
            thinned, torch.zeros((1, 1), device=x.device, dtype=torch.int)
        )

    def _masked_mse(self, reconstruction, y):
        """MSE between the ``train_index`` columns of ``reconstruction`` and ``y``."""
        return F.mse_loss(reconstruction[:, self.train_index], y)

    def training_step(self, batch, batch_idx):
        """Reconstruct ``y`` from a corrupted ``x`` and return the MSE loss."""
        x, y = batch
        loss = self._masked_mse(self(self._corrupt(x)), y)
        return loss

    def validation_step(self, batch, batch_idx):
        """Same corrupted-input objective as training; logs epoch-level ``val_loss``."""
        x, y = batch
        # Validation inputs are corrupted with fresh random noise, as in training.
        val_loss = self._masked_mse(self(self._corrupt(x)), y)
        self.log('val_loss', val_loss, on_epoch=True, on_step=False)
        return val_loss

    def test_step(self, batch, batch_idx):
        """Reconstruct from the uncorrupted input and log ``test_loss``."""
        x, y = batch
        test_loss = self._masked_mse(self(x), y)
        self.log('test_loss', test_loss)
        return test_loss

    def forward(self, x):
        """Return the reconstruction with one column per row of ``gene_emb``."""
        z_cell = self.cellencoder(x)
        z_gene = self.geneencoder(self.gene_emb)

        out_final = torch.matmul(z_cell, z_gene.T)
        out_final = self.id(out_final)

        return out_final

    def configure_optimizers(self):
        """Adam over all parameters with learning rate ``self.eta``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.eta)
        return optimizer
    

In [None]:
adata = sc.read_h5ad("/home/tl688/pitl688/ref_free_imp/adata_stimage_hest.h5ad")

In [None]:
adata.obs['batch'].value_counts()

In [None]:
adata

In [None]:
# Convert the expression matrix to a pandas DataFrame (rows = observations,
# columns = variables) so that identical expression rows can be detected.
# If X is a sparse matrix, convert to a dense array first.
# NOTE(review): densifying the whole matrix can be memory-hungry — confirm it fits.
import scipy.sparse as sp
if sp.issparse(adata.X):
    expr_df = pd.DataFrame(adata.X.toarray(), index=adata.obs_names, columns=adata.var_names)
else:
    expr_df = pd.DataFrame(adata.X, index=adata.obs_names, columns=adata.var_names)

# Identify duplicated rows.
# keep='first' marks every repeat after the first occurrence as a duplicate.
duplicated_mask = expr_df.duplicated(keep='first')

# Subset the AnnData to keep only the non-duplicated rows (independent copy).
# `adata` itself is left unchanged by this cell.
adata_no_duplicates = adata[~duplicated_mask, :].copy()

In [None]:
adata_no_duplicates.obs['batch']

In [8]:
import numpy as np

In [None]:
adata.X

In [10]:
# adata_dense = adata.X.toarray()

In [None]:
# Quality filtering, applied in place to `adata`: keep observations with at
# least 100 detected genes, then keep genes detected in at least 3 observations.
sc.pp.filter_cells(adata, min_genes=100)
sc.pp.filter_genes(adata, min_cells=3)


In [12]:
# Total-count normalisation followed by log1p, both in place on `adata`.
# NOTE: not idempotent — re-running this cell transforms the data a second time.
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

In [13]:
a = 1

In [None]:
adata.X

In [20]:
adata_emb = sc.read_h5ad("/home/tl688/project/seq2cells_data/spatial_data/adata_spatial_logcounts_waitimpute.h5ad")

In [22]:
df_id = adata_emb.obs 

In [23]:
# Keep only the embedding rows whose `add_id` is one of the genes in `adata`.
df_id = df_id[df_id['add_id'].isin(adata.var_names)]

In [20]:
adata_train =  adata

In [21]:
# Column positions of the reconstructed matrix that the loss compares with the target.
# NOTE(review): assumes the index labels of df_id are integer-like and equal to each
# row's position in adata_emb (the model output has one column per adata_emb row) — confirm.
train_index = [int(i) for i in df_id.index]

In [None]:
adata_train

In [23]:
# Select the genes in the order given by df_id['add_id'], so column j of
# adata_train lines up with train_index[j]. The result is a view, not a copy.
adata_train = adata_train[:,df_id['add_id'].values]

In [24]:
# adata_train.X = adata_train.layers['logcounts']

In [25]:
# adata_train.X = adata_train.X.toarray()

In [None]:
adata_train.shape

In [27]:
import sklearn.model_selection
import numpy as np
import random
np.random.seed(2024)
random.seed(2024)

In [None]:
sorted(set(adata_train.obs['batch']))

In [29]:
# Split by sample ("batch") rather than by observation, so no sample ends up on both sides.
# An explicit random_state makes the split independent of whatever has consumed
# NumPy's global RNG since the seeding cell (and of the order cells were run in).
train_name, test_name = sklearn.model_selection.train_test_split(
    sorted(set(adata_train.obs['batch'])), random_state=2024
)

In [30]:
# The validation split is the very same set of samples as the test split.
# NOTE(review): early stopping / checkpoint selection on `valid_name` therefore sees
# the test samples; consider carving a separate validation split out of train_name.
valid_name = test_name

In [27]:
# Encoder for per-observation expression vectors.
# NOTE(review): input_dim=21648 is hard-coded; presumably it must equal the number
# of columns of adata_train — confirm.
cell_enc = Cell_Encoder(input_dim=21648, hidden_dim=64)

In [28]:
# Encoder for per-gene feature vectors.
# NOTE(review): input_dim=3072 is hard-coded; presumably it must equal the width
# of adata_emb.obsm['seq_embedding'] — confirm.
gene_enc = Gene_Encoder(input_dim=3072,hidden_dim=64)

In [29]:
# Assemble the Lightning module. The gene embeddings are converted to a float
# tensor and moved to the GPU here (requires CUDA); nonneg=True applies Softplus
# to the reconstruction.
model = LitAutoEncoder(encoder1=cell_enc, encoder2=gene_enc, eta=1e-4, gene_emb=torch.FloatTensor(adata_emb.obsm['seq_embedding']).cuda(), train_index=train_index, nonneg=True)

In [None]:
model

In [35]:
def _dense_split(names):
    """Return the rows of `adata_train` whose batch is in `names` as a dense float tensor."""
    mask = adata_train.obs['batch'].isin(names)
    return torch.FloatTensor(adata_train[mask].X.toarray())


# Inputs: each split is sliced and densified exactly once.
X_tr = _dense_split(train_name)
X_val = _dense_split(valid_name)
X_test = _dense_split(test_name)

# Targets are the same expression values as the inputs (the model reconstructs
# them from corrupted inputs); clones keep separate storage for X and y.
y_tr, y_val, y_test = X_tr.clone(), X_val.clone(), X_test.clone()

In [36]:
# Pair each input matrix with its target matrix, row by row.
train_dataset = torch.utils.data.TensorDataset(X_tr, y_tr)
valid_dataset = torch.utils.data.TensorDataset(X_val, y_val)
test_dataset = torch.utils.data.TensorDataset(X_test, y_test)

In [37]:
from torch.utils.data import DataLoader

# Shuffle the training set every epoch: without it rows are served in dataset
# order, so each mini-batch would be drawn from neighbouring rows (presumably
# the same sample) in every epoch. Validation order is left fixed.
train_loader = DataLoader(train_dataset, batch_size=512, num_workers=1, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=512, num_workers=1)


In [None]:
# Train on the training loader, validating on the validation loader.
# Early stopping watches the epoch-level `val_loss`. A ModelCheckpoint monitoring
# the same metric is added explicitly: the default checkpoint callback
# (monitor=None) only keeps the last epoch, so `best_model_path` would not
# point at the best-validation model.
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=100)
checkpointing = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
trainer = L.Trainer(callbacks=[early_stopping, checkpointing], max_epochs=1000)
trainer.fit(model, train_loader, valid_loader)

In [None]:
# Path of the checkpoint tracked by the trainer's checkpoint callback.
best_checkpoint = trainer.checkpoint_callback.best_model_path
print(best_checkpoint)
# Constructor arguments are not stored in the checkpoint (LitAutoEncoder never
# calls save_hyperparameters), so they are repeated here. nonneg=True matches the
# model that was trained; without it the reloaded model would replace the Softplus
# output layer by Identity and produce different predictions.
model = LitAutoEncoder.load_from_checkpoint(
    best_checkpoint,
    encoder1=cell_enc,
    encoder2=gene_enc,
    eta=1e-4,
    gene_emb=torch.FloatTensor(adata_emb.obsm['seq_embedding']).cuda(),
    train_index=train_index,
    nonneg=True,
)

In [None]:
model

In [47]:
adata_train.write_h5ad("/home/tl688/pitl688/ref_free_imp/adata_fulllist_imputation.h5ad")

In [10]:
#visium train, rest techniques for testing

In [None]:
best_checkpoint

In [91]:
test_loader = DataLoader(test_dataset, batch_size=512, num_workers=1)

In [None]:
trainer.test(model,test_loader)

In [81]:
import numpy as np
import scanpy as sc

In [82]:
# import pickle
# with open(f"/home/tl688/pitl688/ref_free_imp/adata_list_stimage1k.pickle", 'rb') as fp:
#     seq_list = pickle.load(fp)

# len(seq_list)

# seq_list_new = []

# for i in seq_list:
#     i.var_names = [item.upper() for item in i.var_names]
#     seq_list_new.append(i)

# adata = sc.concat(seq_list_new, join='outer')

adata = sc.read_h5ad("./cancer_type_allresult.h5ad")

In [83]:
# np.max(adata.X)

In [84]:
# for i in adata.obs['batch'].unique():
#     adata_new = adata[adata.obs['batch'] == i]
#     print(adata_new.obs['annotation'].unique())

In [85]:
# adata_new.obs

In [None]:
for i in adata.obs['annotation'].unique():
    print(i)

In [87]:
# Lookup table from raw annotation strings to harmonised labels.
# Most entries map a label to itself; the remaining ones merge spelling/case
# variants (e.g. "Tumour Cells" -> "Tumor cells", "Artefacts" -> "Artifacts"),
# strip a trailing "*", and send "nan" / 'Not_annotated' to "undetermined".
correct_map = {"invasive cancer":"invasive cancer",
"breast glands":"breast glands",
"connective tissue":"connective tissue",
"undetermined":"undetermined",
"adipose tissue":"Adipose tissue",
"immune infiltrate":"immune infiltrate",
"cancer in situ":"cancer in situ",
"NO_TLS":"NO_TLS",
"TLS":"TLS",
"nan":"undetermined",
"T_agg":"T_agg",
"Adipose tissue":"Adipose tissue",
"Tumor stroma with inflammation":"Tumor stroma with inflammation",
"Tumor cells":"Tumor cells",
"Tumor stroma fibrous":"Tumor stroma fibrous",
"Artifacts":"Artifacts",
"Tumor Cells":"Tumor cells",
"Tumor Stroma":"Tumor Stroma",
"Fibrosis":"Fibrosis",
"High TILs Stroma":"High TILs Stroma",
"Tumour Cells":"Tumor cells",
"Normal Epithelium":"Normal Epithelium",
"Lymphocytes":"Lymphocytes",
"Vascular":"Vascular",
"Peripheral Nerve":"Peripheral Nerve",
"Lymphoid stroma":"Lymphoid stroma",
"Fibrous stroma":"Fibrous stroma",
"In situ Carcinoma*":"In situ Carcinoma",
"Tumor cells ?":"Tumor cells ?",
"In Situ Carcinoma":"In situ Carcinoma",
"Hyperplasia":"Hyperplasia",
"Tumor cells - Spindle Cells":"Tumor cells - Spindle Cells",
"Tumour cells":"Tumor cells",
"Fibrosis (peritumoral)":"Fibrosis (peritumoral)",
"Artefacts":"Artifacts",
"Tumour Stroma":"Tumor Stroma",
"Endothelial":"Endothelial",
"Necrosis":"Necrosis",
"Tumor":"Tumor cells",
"Non Tumor":"Non Tumor",
"Tumor_edge_5":"Tumor_edge_5",
"IDC_4":"IDC_4",
"Healthy_1":"Healthy_1",
"IDC_3":"IDC_3",
"IDC_2":"IDC_2",
"Tumor_edge_3":"Tumor_edge_3",
"DCIS/LCIS_3":"DCIS/LCIS_3",
"Tumor_edge_2":"Tumor_edge_2",
"DCIS/LCIS_5":"DCIS/LCIS_5",
"IDC_6":"IDC_6",
"Tumor_edge_6":"Tumor_edge_6",
"Healthy_2":"Healthy_2",
"IDC_5":"IDC_5",
"DCIS/LCIS_4":"DCIS/LCIS_4",
"IDC_7":"IDC_7",
"Tumor_edge_1":"Tumor_edge_1",
"DCIS/LCIS_1":"DCIS/LCIS_1",
"DCIS/LCIS_2":"DCIS/LCIS_2",
"Tumor_edge_4":"Tumor_edge_4",
"IDC_1":"IDC_1",
"Stroma":"Stroma",
"GG1":"GG1",
"Benign":"Benign",
"Exclude":"Exclude",
"Vessel":"Vessel",
"GG4 Cribriform":"GG4 Cribriform",
"GG2":"GG2",
"Chronic inflammation":"Chronic inflammation",
"Fat":"Fat",
"Nerve":"Nerve",
"GG4":"GG4",
"PIN":"PIN",
"Inflammation":"Inflammation",
"Transition_State":"Transition_State",
"Benign*":"Benign",
    'Pal::GPi': 'Pal::GPi',
    'MOB::Gl_1': 'MOB::Gl_1',
    'SS::L2/3': 'SS::L2/3',
    'MOB::Gr': 'MOB::Gr',
    'Pal::MA': 'Pal::MA',
    'ORB::L2/3': 'ORB::L2/3',
    'AOE': 'AOE',
    'OT::PoL': 'OT::PoL',
    'CPu': 'CPu',
    'AON::L1_1': 'AON::L1_1',
    'CC': 'CC',
    'AOB::Gr': 'AOB::Gr',
    'Pal::NDB': 'Pal::NDB',
    'AON::L2': 'AON::L2',
    'PIR': 'PIR',
    'MO::L5': 'MO::L5',
    'St': 'St',
    'SS::L1': 'SS::L1',
    'SLu': 'SLu',
    'Fim': 'Fim',
    'AcbC': 'AcbC',
    'MOB::Opl': 'MOB::Opl',
    'Cl': 'Cl',
    'Pal::Sl': 'Pal::Sl',
    'SS::L6': 'SS::L6',
    'Ft': 'Ft',
    'MO::L2/3': 'MO::L2/3',
    'AOB::Gl': 'AOB::Gl',
    'MOB::lpl': 'MOB::lpl',
    'MO::L1': 'MO::L1',
    'TH::RT': 'TH::RT',
    'SS::L5': 'SS::L5',
    'Io': 'Io',
    'FRP::L1': 'FRP::L1',
    'HY::LPO': 'HY::LPO',
    'AcbSh': 'AcbSh',
    'ORB::L5': 'ORB::L5',
    'OT::Ml': 'OT::Ml',
    'FRP::L2/3': 'FRP::L2/3',
    'MO::L6': 'MO::L6',
    'MOB::Gl_2': 'MOB::Gl_2',
    'AOB::Ml': 'AOB::Ml',
    'ORB::L1': 'ORB::L1',
    'En': 'En',
    'Or': 'Or',
    'ORB::L6': 'ORB::L6',
    'AON::L1_2': 'AON::L1_2',
    'OT::Pl': 'OT::Pl',
    'Py': 'Py',
    'LV': 'LV',
    'MOB::MI': 'MOB::MI',
    'Not_annotated': "undetermined"
}

In [88]:
# Replace every annotation by its harmonised label from correct_map.
# NOTE: a label that is not a key of correct_map raises KeyError here.
adata.obs['annotation'] = [correct_map[i] for i in adata.obs['annotation']]

In [93]:
# Replace NaN entries of the expression matrix by 0 and store the result back.
# NOTE(review): np.isnan is applied to adata.X directly, so this presumably
# expects a dense array here — confirm.
filled_arr = np.where(np.isnan(adata.X), 0, adata.X)
adata.X = filled_arr

In [None]:
np.max(adata.X)

In [95]:
# sc.pp.filter_cells(adata, min_genes=100)
# sc.pp.filter_genes(adata, min_cells=3)

In [96]:
# adata_train.X

In [97]:
# Single-row slice of the training data (first observation, all genes).
# It is concatenated in front of `adata` further down, presumably only to carry
# the training gene set into the merged object — the row itself is dropped again.
adata_train_wait = adata_train[0,:]

In [98]:
# Keep only the first two columns of the spatial coordinates, stored as the
# underlying array obtained through `.values`.
# NOTE(review): presumably obsm['spatial'] is a DataFrame before this cell; after
# it runs the entry no longer has `.values`, so the cell cannot be re-run — confirm.
adata.obsm['spatial'] = adata.obsm['spatial'].values[:,0:2]

In [99]:
# adata = adata[:,adata_train.var_names]

In [101]:
# Outer-join the single training row with the annotated data: the union of
# genes is kept and values missing on either side are filled with 0.
# The training row comes first in the result.
adata_merge = sc.concat([adata_train_wait, adata], join='outer', fill_value=0)

In [None]:
np.max(adata_merge.X)

In [103]:
a =1

In [104]:
# Drop the first row, i.e. the single training observation that was
# concatenated in front of the annotated data.
adata_merge = adata_merge[1:,:]

In [105]:
# sc.pp.subsample(adata_merge, fraction=0.1)

In [106]:
# del adata_train

In [None]:
np.max(adata_merge.X)

In [110]:
# adata_merge.X = adata_merge.X.toarray()

In [111]:
# filled_arr = np.where(np.isnan(adata_merge.X), 0, adata_merge.X)

In [112]:
# adata_merge.X = filled_arr.copy()

In [113]:
# Restrict the merged data to the training genes, in the training gene order.
adata_merge = adata_merge[:,adata_train_wait.var_names]

In [None]:
np.max(adata_merge.X)

In [115]:
# adata_merge.X = adata_merge.X.toarray()
# filled_arr = np.where(np.isnan(adata_merge.X), 0, adata_merge.X)
# adata_merge.X = filled_arr.copy()

In [116]:
a = 1

In [120]:
adata_merge.write_h5ad("/home/tl688/scratch/cancer_type_training.h5ad")

In [8]:
adata_merge = sc.read_h5ad("/home/tl688/scratch/cancer_type_training.h5ad")

In [None]:
# Load the dataset stored at this path and keep a random 10% of its
# observations (in place, fixed seed for reproducibility).
# NOTE: this rebinds `adata`, replacing whatever object the name held before.
adata = sc.read_h5ad("/home/tl688/pitl688/test_image/visium_all.h5ad")
sc.pp.subsample(adata, fraction=0.1, random_state=2024)

In [None]:
adata_merge

In [14]:
# NOTE: this rebinds `adata_merge` to a gene subset of the freshly loaded `adata`;
# the previous `adata_merge` object contributes only its var_names (gene list and order).
# Every one of those genes must be present in `adata` for the selection to succeed.
adata_merge = adata[:,adata_merge.var_names]

In [15]:
best_checkpoint = "/gpfs/radev/project/ying_rex/tl688/large_scale_imputation/lightning_logs/version_117404/checkpoints/epoch=814-step=364305.ckpt"

In [None]:
# best_checkpoint = "/gpfs/radev/project/ying_rex/tl688/large_scale_imputation/lightning_logs/version_117404/checkpoints/epoch=999-step=447000.ckpt"
# Checkpoint used for inference (overrides any earlier value of best_checkpoint).
best_checkpoint = "/gpfs/radev/project/ying_rex/tl688/large_scale_imputation/lightning_logs/version_pretrain_almostlargedata/checkpoints/epoch=127-step=57216.ckpt"
print(best_checkpoint)
# train_index is only read by the loss in the *_step methods, not by forward(),
# so a placeholder value is enough for inference.
train_index = [0]
# NOTE(review): nonneg is not passed, so the output layer is Identity; if this
# checkpoint was trained with nonneg=True the predictions will differ from
# training — confirm.
model = LitAutoEncoder.load_from_checkpoint(best_checkpoint, encoder1=cell_enc, encoder2=gene_enc, eta=1e-4, gene_emb=torch.FloatTensor(adata_emb.obsm['seq_embedding']).cuda(), train_index=train_index)

model = model.cuda()
# Gradient-free forward pass over the whole matrix at once; the result is moved
# back to the CPU. NOTE(review): the full dense matrix must fit in GPU memory.
with torch.no_grad():
    imputed1 = model.forward(torch.FloatTensor(adata_merge.X.toarray()).cuda()).cpu()

In [32]:
# Wrap the imputed matrix in an AnnData object: observation metadata is copied
# from adata_merge, variable metadata from the gene-embedding table
# (adata_emb.obs), and the variables are named by its `add_id` column.
adata_imp = sc.AnnData(imputed1.numpy())
adata_imp.obs = adata_merge.obs.copy()
adata_imp.var = adata_emb.obs.copy()
adata_imp.var_names = adata_imp.var['add_id'].values

In [None]:
adata_imp.obs.batch

In [35]:
adata_imp.write_h5ad("/home/tl688/scratch/adata_imp_largestdata_diseasepred.h5ad")

In [None]:
# UMAP of the imputed expression, one figure per batch, coloured by annotation.
for batch in adata_imp.obs['batch'].unique():
    # Work on an independent copy: the scanpy calls below write their results
    # into the object, which would otherwise force an implicit copy of the view.
    adata_new = adata_imp[adata_imp.obs.batch == batch].copy()

    sc.tl.pca(adata_new)
    sc.pp.neighbors(adata_new)

    sc.tl.umap(adata_new)

    sc.pl.umap(adata_new, color='annotation')

In [None]:
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

In [155]:
# adata_merge.obs['batch']

In [None]:
# sklearn.metrics is used below; import it explicitly instead of relying on it
# being loaded as a side effect of importing sklearn.model_selection.
import sklearn.metrics

# Best NMI / ARI per batch on the observed (non-imputed) expression in `adata`.
nmi_s = []
ari_s = []
for batch in adata_merge.obs['batch'].unique():
    # Independent copy: the scanpy calls below write results into the object.
    adata_new = adata[adata.obs.batch == batch].copy()

    sc.pp.highly_variable_genes(adata_new, n_top_genes=2000)
    adata_new = adata_new[:, adata_new.var['highly_variable']].copy()

    sc.tl.pca(adata_new)
    sc.pp.neighbors(adata_new)

    # Sweep Leiden resolutions 0.1, 0.2, ..., 2.0 and keep the best score of each metric.
    nmi = []
    ari = []
    for i in np.linspace(0,2,21)[1:]:
        sc.tl.leiden(adata_new, resolution = i)
        nmi.append(sklearn.metrics.normalized_mutual_info_score(adata_new.obs.leiden, adata_new.obs.annotation))
        ari.append(sklearn.metrics.adjusted_rand_score(adata_new.obs.leiden, adata_new.obs.annotation))
    nmi_s.append(max(nmi))
    ari_s.append(max(ari))

    sc.tl.umap(adata_new)

    sc.pl.umap(adata_new, color='annotation')

In [None]:
nmi_s

In [158]:
# adata_imp.X

In [None]:
# sklearn.metrics is used below; import it explicitly instead of relying on it
# being loaded as a side effect of importing sklearn.model_selection.
import sklearn.metrics

# Best NMI / ARI per batch on the imputed expression in `adata_imp`.
nmi_l = []
ari_l = []
for batch in adata_imp.obs['batch'].unique():
    # Independent copy: the scanpy calls below write results into the object.
    adata_new = adata_imp[adata_imp.obs.batch == batch].copy()

    sc.pp.highly_variable_genes(adata_new, n_top_genes=2000)
    adata_new = adata_new[:, adata_new.var['highly_variable']].copy()

    sc.tl.pca(adata_new)
    sc.pp.neighbors(adata_new)

    # Sweep Leiden resolutions 0.1, 0.2, ..., 2.0 and keep the best score of each metric.
    nmi = []
    ari = []
    for i in np.linspace(0,2,21)[1:]:
        sc.tl.leiden(adata_new, resolution = i)
        nmi.append(sklearn.metrics.normalized_mutual_info_score(adata_new.obs.leiden, adata_new.obs.annotation))
        ari.append(sklearn.metrics.adjusted_rand_score(adata_new.obs.leiden, adata_new.obs.annotation))
    nmi_l.append(max(nmi))
    ari_l.append(max(ari))

    sc.tl.umap(adata_new)

    sc.pl.umap(adata_new, color='annotation')

In [None]:
nmi_l

In [None]:
ari_l

In [159]:
import matplotlib.pyplot as plt

In [None]:
plt.scatter(nmi_l, nmi_s)

In [None]:
import scipy.stats
scipy.stats.pearsonr(nmi_l, nmi_s)

In [None]:
plt.scatter(ari_l, ari_s)

In [None]:
import scipy.stats
scipy.stats.pearsonr(ari_l, ari_s)

In [165]:
# Number of distinct annotation values in each batch, in the order returned
# by adata_imp.obs['batch'].unique().
label_l = []

for batch in adata_imp.obs['batch'].unique():
    adata_new = adata_imp[adata_imp.obs.batch == batch]
    label_l.append(len(adata_new.obs['annotation'].unique()))

In [166]:
label_l = np.array(label_l)

In [167]:
import scipy.stats

In [170]:
diff = np.array(ari_l) - np.array(ari_s)

In [None]:
scipy.stats.pearsonr(diff,label_l)

In [None]:
plt.scatter(diff,label_l)
plt.xlabel('win score')
plt.ylabel('cluster number')

In [173]:
diff_ari = np.array(ari_l) - np.array(ari_s)

In [None]:
plt.scatter(diff_ari, label_l)
plt.xlabel('win score')
plt.ylabel('cluster number')

In [None]:
scipy.stats.pearsonr(diff_ari,label_l)

In [None]:
# Checkpoint used for inference.
best_checkpoint = "/gpfs/radev/project/ying_rex/tl688/large_scale_imputation/lightning_logs/version_pretrain_almostlargedata/checkpoints/epoch=127-step=57216.ckpt"
print(best_checkpoint)

# Fresh encoders whose weights are filled from the checkpoint below.
cell_enc = Cell_Encoder(input_dim=21648, hidden_dim=64)

gene_enc = Gene_Encoder(input_dim=3072,hidden_dim=64)

# NOTE(review): nonneg is not passed, so the output layer is Identity; confirm
# this matches how the checkpoint was trained.
model = LitAutoEncoder.load_from_checkpoint(best_checkpoint, encoder1=cell_enc, encoder2=gene_enc, eta=1e-4, gene_emb=torch.FloatTensor(adata_emb.obsm['seq_embedding']).cuda(), train_index=train_index)

model = model.cuda()
# Gradient-free forward pass over the whole matrix at once.
# NOTE(review): adata_merge.X is passed without .toarray() here, unlike the other
# inference cell in this notebook — presumably X is dense at this point; confirm.
with torch.no_grad():
    imputed1 = model.forward(torch.FloatTensor(adata_merge.X).cuda()).cpu()

In [95]:
# Wrap the imputed matrix in an AnnData object: observation metadata is copied
# from adata_merge, variable metadata from the gene-embedding table
# (adata_emb.obs), and the variables are named by its `add_id` column.
adata_imp = sc.AnnData(imputed1.numpy())
adata_imp.obs = adata_merge.obs.copy()
adata_imp.var = adata_emb.obs.copy()
adata_imp.var_names = adata_imp.var['add_id'].values

In [None]:
adata_imp.obs.batch

In [None]:
# UMAP of the imputed expression, one figure per batch, coloured by annotation.
for batch in adata_imp.obs['batch'].unique():
    # Work on an independent copy: the scanpy calls below write their results
    # into the object, which would otherwise force an implicit copy of the view.
    adata_new = adata_imp[adata_imp.obs.batch == batch].copy()

    sc.tl.pca(adata_new)
    sc.pp.neighbors(adata_new)

    sc.tl.umap(adata_new)

    sc.pl.umap(adata_new, color='annotation')

In [1]:
# Per-batch NMI scores entered as literals (48 values each).
# NOTE(review): presumably pasted from the output of the clustering loops that
# compute `nmi_s` (observed data) and `nmi_l` (imputed data) in an earlier kernel
# session — confirm provenance and that both lists share the same batch order.
nmi_s = [0.08585809108189037,
 0.06232096556125869,
 0.07456601307055569,
 0.02915012991486388,
 0.08362670552242965,
 0.10026078191689969,
 0.05457748498183513,
 0.028880175610608367,
 0.0,
 0.05844357280406492,
 0.20331220767770605,
 0.01449385174593068,
 0.008799301976226356,
 0.11757295070024754,
 0.014088200088330533,
 0.017581566651454484,
 0.13288964017796612,
 0.030494452504848068,
 0.0,
 0.08207977234123992,
 0.17684938216420826,
 0.04556948630210969,
 0.039477668434467296,
 0.18573516503770066,
 0.17044956952679668,
 0.05513297855513959,
 0.15119272139316414,
 0.2443987243283582,
 0.2644354184971032,
 0.23442516454342563,
 0.1977834719544222,
 0.10535063580187067,
 0.15532358533906598,
 0.2756224410057031,
 0.13475631841313784,
 0.24051930520897719,
 0.06595811118620239,
 0.116261902894733,
 0.4305981482651703,
 0.6093296377939322,
 0.2602406655667357,
 0.3583389459627304,
 0.34617777667672656,
 0.3165318080132729,
 0.2726459787450713,
 0.3355548291640934,
 0.2694374545233615,
 0.6825186984230193]

nmi_l = [0.10309309158558037,
 0.03167295144777077,
 0.08896677465164816,
 0.022372995571474204,
 0.13537242812319109,
 0.0994968115369422,
 0.04401958199389832,
 0.01986336851467381,
 0.0,
 0.052528683384559065,
 0.14293297280280834,
 0.011399072277495688,
 0.00849366374288696,
 0.06666712982057355,
 0.015582133000180251,
 0.01919443347775918,
 0.13925113791673766,
 0.020795413825667863,
 0.0,
 0.07044239374635096,
 0.18260340151269136,
 0.025088083554627594,
 0.04343836816974321,
 0.09357711783666878,
 0.22790831340731546,
 0.13483908784698068,
 0.20532075178353543,
 0.3683327430128713,
 0.19405304748986668,
 0.24796775924315062,
 0.2006933538120348,
 0.12361098932983694,
 0.2004860040525278,
 0.3501681197740687,
 0.07862564384210861,
 0.2865826587850417,
 0.06039518892488261,
 0.09813439598129726,
 0.5830041837066104,
 0.3843539747251409,
 0.272936738583149,
 0.38449108877637106,
 0.28908657514896946,
 0.3464975885078202,
 0.3282284455925776,
 0.35214034921286663,
 0.30693244217028914,
 0.4697875785541327]

In [9]:
# adata_imp = sc.read_h5ad("/home/tl688/scratch/adata_imp_largestdata_diseasepred.h5ad")

In [12]:
# adata_imp
import numpy as np

In [13]:
# Number of distinct annotation values in each batch of adata_merge, in the
# order returned by adata_merge.obs['batch'].unique(); stored as a NumPy array.
# NOTE(review): this order must match the order of the nmi_s / nmi_l lists it is
# compared against — confirm.
label_l = []

for batch in adata_merge.obs['batch'].unique():
    adata_new = adata_merge[adata_merge.obs.batch == batch]
    label_l.append(len(adata_new.obs['annotation'].unique()))

label_l = np.array(label_l)

In [16]:
import matplotlib.pyplot as plt

In [17]:
diff_nmi = np.array(nmi_l) - np.array(nmi_s)

In [None]:
plt.scatter(diff_nmi, label_l)
plt.xlabel('win score')
plt.ylabel('cluster number')

In [None]:
import scipy.stats
scipy.stats.pearsonr(diff_nmi,label_l)