In [47]:
import scanpy as sc
import anndata
from anndata import AnnData
import pandas as pd
import numpy as np
import random
import scipy as sci
import scipy.sparse as sps
import matplotlib as mpl
import matplotlib.pyplot as plt
#mpl.rcParams['pdf.fonttype'] = 42
#mpl.rcParams["font.sans-serif"] = "Arial"
#%config InlineBackend.figure_format = 'retina'
#sc.settings.set_figure_params(dpi=50, dpi_save=300, figsize=(4, 4))
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import warnings
warnings.filterwarnings("ignore")

In [48]:
from functions import calc

In [52]:
import shutil
import pickle
import os

def save(var, name):
    """Pickle ``var`` to ``name + '.pkl'``, replacing any existing file.

    Parameters
    ----------
    var : object
        Any picklable object.
    name : str
        Target path without the ``.pkl`` extension.
    """
    target = name + '.pkl'
    # Remove a stale pickle first so the file is recreated from scratch.
    if os.path.exists(target):
        os.remove(target)
    with open(target, 'wb') as handle:
        pickle.dump(var, handle)
        
def load(name):
    """Return the object pickled at ``name + '.pkl'``.

    Parameters
    ----------
    name : str
        Source path without the ``.pkl`` extension.

    Raises
    ------
    ValueError
        If the pickle file does not exist; the exception carries ``name``.

    Notes
    -----
    Unpickling executes arbitrary code, so only load files produced by this
    notebook.
    """
    target = name + '.pkl'
    if os.path.exists(target):
        with open(target, 'rb') as handle:
            return pickle.load(handle)
    raise ValueError(name)
    
def flo2str(n):
    """Convert a number to text without redundant trailing zeros.

    ``12.0`` becomes ``'12'`` and ``0.50`` becomes ``'0.5'``. The result is
    used to build folder names such as ``'w0.5-l10-d12'``.

    Parameters
    ----------
    n : int, float or str
        Value to format; it is passed through ``str`` first.

    Returns
    -------
    str
        ``str(n)`` with trailing fractional zeros (and a dangling ``'.'``)
        removed.
    """
    text = str(n)
    # Only plain decimal notation has removable trailing zeros. In scientific
    # notation (e.g. '1.5e-10') a trailing zero belongs to the exponent, and
    # stripping it would silently change the value.
    if '.' in text and 'e' not in text.lower():
        text = text.rstrip('0').rstrip('.')
    return text

# generate simulations

In [None]:
from copy import deepcopy
from janitor import expand_grid
from simulations import sim

In [None]:
# Purely spatial benchmark ("sp"): build the scenario table, then simulate
# one dataset per scenario and write it to disk.
pth = "simulations/bm_sp"
dpth = os.path.join(pth,"data")
rpth = os.path.join(pth,"results")
mpth = os.path.join(pth,"models")
# The directories above are not created here; the helper calls below are
# disabled, so they must already exist.
# from utils import misc
# misc.mkdir_p(dpth) # or use symlink to dropbox
# misc.mkdir_p(rpth)
# misc.mkdir_p(mpth)

# %% define scenarios
# List-valued entries ("sim", "seed") are the factors that vary; the scalar
# entries are the same in every scenario.
cfg = {"sim":["quilt","ggblocks","both"], "nside":36, "bkg_mean":0.2,
       "nb_shape":10.0, "Jsp":200, "Jmix":0, "Jns":0, "expr_mean":20.0,
       "seed":[1,2,3,4,5], "V":5}
a = expand_grid(others=cfg)
col = a.columns
# Keep only the first element of each column label.
# NOTE(review): assumes expand_grid returns tuple-like (two-level) column labels - confirm for the installed janitor version.
a.columns = list(map(lambda x:x[0], col))
# Attach the per-pattern value "L" to every row.
b = pd.DataFrame({"sim":["quilt","ggblocks","both"], "L":[4,4,8]})
a = a.merge(b,how="left",on="sim")
# 1-based scenario id; it becomes the file name S<id>.h5ad below.
a["scenario"] = list(range(1,a.shape[0]+1))
a.to_csv(os.path.join(pth,"scenarios.csv"),index=False) #store separately for data generation

# %% generate the simulated datasets and store to disk
a = pd.read_csv(os.path.join(pth,"scenarios.csv")).convert_dtypes()
def sim2disk(p):
    """Simulate one scenario row and write it to ``dpth`` as ``S<scenario>.h5ad``.

    ``p`` is one row of the scenario table. It is copied, its "scenario"
    entry is removed, and the remaining entries are forwarded to ``sim.sim``
    as keyword arguments together with ``Lns=0``.
    """
    p = deepcopy(p)
    scen = p.pop("scenario")
    # NOTE(review): "sim" is passed positionally and is also still present in **p,
    # and "L" is forwarded under that name - confirm both match sim.sim's signature.
    ad = sim.sim(p["sim"], Lns=0, **p)
    ad.write_h5ad(os.path.join(dpth,"S{}.h5ad".format(scen)),compression="gzip")
a.apply(sim2disk,axis=1)

In [None]:
# Mixed benchmark: build the scenario table, then simulate one dataset per
# scenario and write it to disk.
pth = "simulations/bm_mixed"
dpth = os.path.join(pth,"data")
rpth = os.path.join(pth,"results")
mpth = os.path.join(pth,"models")
# The directories above are not created here; the helper calls below are
# disabled, so they must already exist.
# from utils import misc
# misc.mkdir_p(dpth) # or use symlink to dropbox
# misc.mkdir_p(rpth)
# misc.mkdir_p(mpth)

# %%
# List-valued entries ("sim", "J", "seed") are the factors that vary; each
# "J" tuple is renamed below into the three columns Jsp, Jmix, Jns.
cfg = {"sim":["quilt","ggblocks","both"], "nside":36, "nzprob_nsp":0.2,
       "bkg_mean":0.2, "nb_shape":10.0,
       "J":[(250,0,250),(0,500,0)], #"Jsp":0, "Jmix":500, "Jns":0,
       "expr_mean":20.0, "mix_frac_spat":0.6,
       "seed":[1,2,3,4,5], "V":5}
a = expand_grid(others=cfg)
col = a.columns
# Flatten two-element column labels: (name, 0) -> name, (name, k) -> name_k.
# NOTE(review): assumes expand_grid returns tuple-like (two-level) column labels - confirm for the installed janitor version.
a.columns = list(map(lambda x:x[0] if x[1]==0 else x[0]+'_'+str(x[1]), col))
a.rename(columns={"J":"Jsp", "J_1":"Jmix", "J_2":"Jns"}, inplace=True)
# Attach the per-pattern values "Lsp" and "Lns" to every row.
b = pd.DataFrame({"sim":["quilt","ggblocks","both"],
                  "Lsp":[4,4,8],"Lns":[3,3,6]})
a = a.merge(b,how="left",on="sim")
# 1-based scenario id; it becomes the file name S<id>.h5ad below.
a["scenario"] = list(range(1,a.shape[0]+1))
a.to_csv(os.path.join(pth,"scenarios.csv"),index=False)

# %% generate the simulated datasets and store to disk
a = pd.read_csv(os.path.join(pth,"scenarios.csv")).convert_dtypes()

def sim2disk(p):
    """Simulate one scenario row and write it to ``dpth`` as ``S<scenario>.h5ad``.

    ``p`` is one row of the scenario table. It is copied, its "scenario" and
    "Lsp" entries are removed, and the remaining entries are forwarded to
    ``sim.sim`` as keyword arguments.
    """
    p = deepcopy(p)
    scen = p.pop("scenario")
    # NOTE(review): Lsp is removed from p but never passed to sim.sim, so the
    # simulator does not receive it - confirm this is intended.
    Lsp = p.pop("Lsp")
    # NOTE(review): "sim" is passed positionally and is also still present in **p - confirm against sim.sim's signature.
    ad = sim.sim(p["sim"], **p)
    ad.write_h5ad(os.path.join(dpth,"S{}.h5ad".format(scen)),compression="gzip")
    
a.apply(sim2disk,axis=1)

In [None]:
from utils import visualize

# Visual sanity check of the first simulated scenario in `dpth`.
# `os.path.join` is used because only `os` is imported in this notebook;
# the bare name `path` is undefined and raised NameError.
ad = sc.read_h5ad(os.path.join(dpth,"S1.h5ad"))
X = ad.obsm["spatial"]
Y = ad.layers["counts"]
Yn = ad.X
visualize.heatmap(X,Y[:,0],cmap="Blues")
visualize.heatmap(X,Yn[:,0],cmap="Blues")
# check distribution of validation data points:
# flag the last 5% of observations and plot where they fall
N = Y.shape[0]
z = np.zeros(N)
Ntr = round(0.95*N)
z[Ntr:] = 1
visualize.heatmap(X,z,cmap="Blues")

In [98]:
# Scratch frame for trying out label-based assignment.
a = pd.DataFrame({'m': ['a', 'd'], 'n': ['b', 'c'], 'o': ['a', 'a']})
a

Unnamed: 0,m,n,o
0,a,b,a
1,d,c,a


In [99]:
# Replace every 'a' in column 'o' with 'k' in place.
a.loc[a['o'].eq('a'), 'o'] = 'k'
a

Unnamed: 0,m,n,o
0,a,b,k
1,d,c,k


In [6]:
# choose simulations that have seed 5
pattern = ['quilt','ggblock','both']


def _export_scenarios(src_dir, scen_ids, out_root, names, with_nonspatial):
    """Unpack simulated ``.h5ad`` files into per-pattern folders of pickles.

    For each ``(scenario id, pattern name)`` pair the folder
    ``<out_root>/<name>`` is wiped and recreated, then filled with:
    ``spa`` (``obsm['spatial']``), ``x`` (``layers['counts']``), ``xnorm``
    (``median_normalize`` of the counts), ``spfac`` / ``spload``
    (``obsm['spfac']`` / ``varm['spload']``) and, when ``with_nonspatial`` is
    true, ``nsfac`` / ``nsload`` (``obsm['nsfac']`` / ``varm['nsload']``).

    Parameters
    ----------
    src_dir : str
        Folder holding ``S<id>.h5ad`` files.
    scen_ids : sequence of int
        Scenario ids, in the same order as ``names``.
    out_root : str
        Destination root; one sub-folder per pattern name is created.
    names : sequence of str
        Pattern names, paired one-to-one with ``scen_ids``.
    with_nonspatial : bool
        Whether to also export the non-spatial factors and loadings.
    """
    for scen, name in zip(scen_ids, names):
        npth = os.path.join(out_root, name)  # destination folder
        # Always start from an empty folder so stale results never survive.
        if os.path.exists(npth):
            shutil.rmtree(npth)
        os.makedirs(npth)

        adata = sc.read_h5ad(os.path.join(src_dir, "S{}.h5ad".format(scen)))
        save(adata.obsm['spatial'], npth+'/spa')  # coordinates
        X = np.array(adata.layers['counts'])      # raw counts
        save(X, npth+'/x')
        # NOTE(review): `median_normalize` is not defined or imported in any
        # visible cell of this notebook - confirm where it comes from.
        save(np.array(median_normalize(X)), npth+'/xnorm')

        save(adata.obsm['spfac'], npth+'/spfac')    # F
        save(adata.varm['spload'], npth+'/spload')  # W
        if with_nonspatial:
            save(adata.obsm['nsfac'], npth+'/nsfac')    # H
            save(adata.varm['nsload'], npth+'/nsload')  # V


# Scenario ids are listed in the same order as `pattern`.
_export_scenarios("simulations/bm_sp/data", [5,10,15], './simdata/sp', pattern, False)
_export_scenarios("simulations/bm_mixed/data", [10,20,30], './simdata/mix_1', pattern, True)
_export_scenarios("simulations/bm_mixed/data", [5,15,25], './simdata/mix_2', pattern, True)

# path set

In [60]:
# Root folder of the exported simulation pickles, resolved against the
# current working directory.
sd_pth = os.path.join(os.getcwd(), 'simdata')
# Sub-folder names under sd_pth: <sd_pth>/<scenario>/<pattern>/...
scenario = ['sp','mix_1','mix_2']
pattern = ['quilt','ggblock','both']

#### 选取seed=5的模拟数据进行分析
# sparsify

In [7]:
# Create artificially sparsified copies of every dataset: each
# <scenario>/<pattern>/<75|85|95> sub-folder receives a sparse 'x' and 'xnorm'.
for scene in scenario:
    for p in pattern:
        pth = os.path.join(sd_pth,scene,p)
        X = sps.csr_matrix(load(pth+'/x'))  # raw expression in sparse form
        num = X.shape[0] * X.shape[1]
        nonnull = len(X.data)               # number of stored (non-zero) entries
        spar = 1 - nonnull / num            # sparsity of the unmodified data
        print(scene+'/'+p+'稀疏度为', spar)
        for spr in [75,85,95]:
            spr_pth = os.path.join(pth,str(spr))  # folder for this sparsity level
            # Start from an empty folder on every run.
            if os.path.exists(spr_pth):
                shutil.rmtree(spr_pth)
            os.makedirs(spr_pth)
            # Fixed seed: the same entries are dropped on every run.
            random.seed(111)
            # Indices into X.data to zero out so that overall sparsity reaches spr %.
            setnull = random.sample(range(nonnull), round(num*(0.01*spr-spar)))
            sparse = X.copy()
            sparse.data[setnull] = 0.
            sparse = np.array(sparse.todense())
            save(sparse, spr_pth+'/x')
            spr_norm = np.array(median_normalize(sparse))
            save(spr_norm, spr_pth+'/xnorm')

sp/quilt稀疏度为 0.48137345679012344
sp/ggblock稀疏度为 0.6830864197530864
sp/both稀疏度为 0.5844714506172839
mix_1/quilt稀疏度为 0.3846759259259259
mix_1/ggblock稀疏度为 0.5481435185185185
mix_1/both稀疏度为 0.4737530864197531
mix_2/quilt稀疏度为 0.5687006172839506
mix_2/ggblock稀疏度为 0.6709737654320987
mix_2/both稀疏度为 0.6224027777777779


In [None]:
# Give every sparsity sub-folder its own copy of the coordinates ('spa').
for scene in scenario:
    for p in pattern:
        pth = os.path.join(sd_pth,scene,p)
        spa_src = pth+'/spa.pkl'
        for spr in [75,85,95]:  # sparsity level (%)
            spr_pth = os.path.join(pth,str(spr))
            _ = shutil.copy(spa_src, spr_pth+'/spa.pkl')

# sprod

In [12]:
# Write the sprod input files for every sparsified dataset and reset the
# matching 'input' / 'output' folders.
for scene in scenario:
    for p in pattern:
        pth = os.path.join(sd_pth,scene,p)
        spatial = load(pth+'/spa')  # spot coordinates
        meta = pd.DataFrame(spatial, index=range(spatial.shape[0]), columns=['X','Y'])
        for spr in [75,85,95]:
            spr_pth = os.path.join(pth,str(spr))  # folder for this sparsity level
            xnorm = load(spr_pth+'/xnorm')
            cts = pd.DataFrame(xnorm, index=range(xnorm.shape[0]), columns=range(xnorm.shape[1]))
            inpth = os.path.join(spr_pth,'input')
            outpth = os.path.join(spr_pth,'output')
            # Both folders are emptied so earlier sprod runs cannot leak in.
            for folder in (inpth, outpth):
                if os.path.exists(folder):
                    shutil.rmtree(folder)
                os.makedirs(folder)
            meta.to_csv(inpth+'/Spot_metadata.csv')
            cts.to_csv(inpth+'/Counts.txt',sep='\t')

In [None]:
# python sprod.py inpth outpth
# python sprod.py inpth outpth -ws -ag

In [5]:
# Collect the sprod denoised matrices and store them as 'xsprod' pickles.
for scene in scenario:
    for p in pattern:
        pth = os.path.join(sd_pth,scene,p)
        for spr in [75,85,95]:  # sparsity level (%)
            spr_pth = os.path.join(pth,str(spr))
            sprod_pth = os.path.join(spr_pth, 'output')
            # The first column of the file holds the row labels.
            xsprod = pd.read_csv(
                sprod_pth+'/sprod_Denoised_matrix.txt', sep='\t', index_col=0
            ).to_numpy()
            save(xsprod, spr_pth+'/xsprod')

# wedge

In [38]:
# Export each sparsified matrix as a genes-by-cells CSV for wedge.
# File name: <sparsity><first+last char of scenario><first char of pattern>.csv,
# e.g. '75spq.csv' for sp/quilt at 75 %.
# NOTE(review): absolute, machine-specific output folder - adjust when running elsewhere.
wdg_pth = '/home/qukun/muziyu/wdg'
for scene in scenario:
    for p in pattern:
        pth = os.path.join(sd_pth,scene,p)
        for spr in [75,85,95]:
            spr_pth = os.path.join(pth,str(spr))  # folder for this sparsity level
            xnorm = load(spr_pth+'/xnorm')
            # xnorm is transposed below, so its columns become the rows ("gene<j>")
            # and its rows become the columns ("cell<i>").
            index = ['gene'+str(j) for j in range(xnorm.shape[1])]
            col = ['cell'+str(i) for i in range(xnorm.shape[0])]
            cts = pd.DataFrame(xnorm.T, index=index, columns=col)
            cts.to_csv(wdg_pth+'/'+str(spr)+scene[0]+scene[-1]+p[0]+'.csv',sep=',')

In [46]:
# Import wedge factorisations and store the reconstructed matrix.
wdg_pth = '/home/qukun/muziyu/sprod_data/sim/wedge/'
for scene in scenario:
    for p in pattern:
        for spr in [75,85,95]:
            # The sparsity level is a file-name prefix, not a folder:
            # the files read are <wdg_pth>/<code>/<spr>W.csv and <spr>H.csv.
            d_pth = os.path.join(wdg_pth,scene[0]+scene[-1]+p[0],str(spr))
            spr_pth = os.path.join(sd_pth,scene,p,str(spr))
            W = pd.read_csv(d_pth+'W.csv', header=None).to_numpy()
            H = pd.read_csv(d_pth+'H.csv', header=None).to_numpy()
            # Rank used by wedge = number of columns of W.
            rk_wdg = W.shape[1]
            save(rk_wdg, spr_pth+'/rk_wdg')
            xwedge = np.dot(W, H)
            # Stored transposed relative to W @ H.
            save(xwedge.T, spr_pth+'/xwedge')

# similarity matrix

In [None]:
# SNN similarity matrix for every sparsified dataset
for scene in scenario:
    for p in pattern:
        pth = os.path.join(sd_pth,scene,p)
        for spr in [75,85,95]:  # sparsity level (%)
            spr_pth = os.path.join(pth,str(spr))
            xnorm = load(spr_pth+'/xnorm')
            # Deliberately not named `sim`: that name is the module imported
            # from `simulations`, and overwriting it breaks `sim.sim(...)`
            # when the data-generation cells are re-run.
            snn_sim = calc.SNN(xnorm)
            save(snn_sim, spr_pth+'/sSNN')  # e.g. <sd_pth>/mix_1/quilt/75/sSNN.pkl

In [9]:
# SPR similarity: the graph detected by sprod, re-saved as a pickle
for scene in scenario:
    for p in pattern:
        pth = os.path.join(sd_pth,scene,p)
        for spr in [75,85,95]:  # sparsity level (%)
            spr_pth = os.path.join(pth,str(spr))
            sim_sr = pd.read_csv(
                spr_pth+'/output/sprod_Detected_graph.txt', sep='\t'
            ).to_numpy()
            save(sim_sr, spr_pth+'/sSPR')

# dimension check

In [9]:
scene = 'sp'

# Estimate the two ranks (wedge, alra) for every pattern and sparsity level
# of this scenario, store them, and print a small table.
print('        wedge','alra')
for p in pattern:
    print(p)
    pth = os.path.join(sd_pth,scene,p)
    for spr in [75,85,95]:
        spr_pth = os.path.join(pth,str(spr))  # folder for this sparsity level
        dim_w, dim_a = calc.rank_w_a(load(spr_pth+'/xnorm'))
        save(dim_w, spr_pth+'/dim_wedge')
        save(dim_a, spr_pth+'/dim_alra')
        print('    ',spr,':', dim_w, dim_a)

quilt
     75 : 4 4
     85 : 4 4
     95 : 3 3
ggblock
     75 : 4 4
     85 : 4 4
     95 : 4 4
both
     75 : 8 8
     85 : 8 7
     95 : 3 2


In [10]:
scene = 'mix_1'

# Estimate the two ranks (wedge, alra) for every pattern and sparsity level
# of this scenario, store them, and print a small table.
print('        wedge','alra')
for p in pattern:
    print(p)
    pth = os.path.join(sd_pth,scene,p)
    for spr in [75,85,95]:
        spr_pth = os.path.join(pth,str(spr))  # folder for this sparsity level
        dim_w, dim_a = calc.rank_w_a(load(spr_pth+'/xnorm'))
        save(dim_w, spr_pth+'/dim_wedge')
        save(dim_a, spr_pth+'/dim_alra')
        print('    ',spr,':', dim_w, dim_a)

quilt
     75 : 6 6
     85 : 6 6
     95 : 3 3
ggblock
     75 : 6 6
     85 : 6 6
     95 : 6 6
both
     75 : 4 13
     85 : 4 13
     95 : 3 1


In [11]:
scene = 'mix_2'

# Estimate the two ranks (wedge, alra) for every pattern and sparsity level
# of this scenario, store them, and print a small table.
print('        wedge','alra')
for p in pattern:
    print(p)
    pth = os.path.join(sd_pth,scene,p)
    for spr in [75,85,95]:
        spr_pth = os.path.join(pth,str(spr))  # folder for this sparsity level
        dim_w, dim_a = calc.rank_w_a(load(spr_pth+'/xnorm'))
        save(dim_w, spr_pth+'/dim_wedge')
        save(dim_a, spr_pth+'/dim_alra')
        print('    ',spr,':', dim_w, dim_a)

quilt
     75 : 7 7
     85 : 7 7
     95 : 6 2
ggblock
     75 : 7 7
     85 : 7 7
     95 : 7 7
both
     75 : 10 14
     85 : 10 14
     95 : 3 2


In [20]:
# One calc.svdSIGdif_plot call per dataset and sparsity level; the label
# passed as `dname` looks like '-75%-sp-quilt'.
for scene in scenario:
    for p in pattern:
        pth = os.path.join(sd_pth,scene,p)
        for spr in [75,85,95]:
            spr_pth = os.path.join(pth,str(spr))
            label = '-'+str(spr)+'%-'+scene+'-'+p
            calc.svdSIGdif_plot(spr_pth, dname=label)

# set parameters

In [17]:
# Values that select which result folders the statistics cells read.
# lamb: appears in the result folder name as '-l<lamb>'.
lamb = 10
# similarity: name of the innermost result sub-folder and suffix of 'x_<name>'.
similarity = ['SNN','SPR']
# multiple: factors applied to the stored 'dim_wedge' value (the product is capped at 30).
multiple = [1, 2, 3, 5]
# initial: names of the sub-folders directly under '<spr>/result'.
initial = ['grdtr','rand','nmf','nndsvd','wdgsvd','tcdsvd']

## statistics

In [18]:
# Score every stored result against the un-sparsified normalised data:
# calc.PCC outputs (row-wise and column-wise inputs) and calc.CMD between
# correlation matrices. Results whose statistics already exist are skipped.
for scene in scenario:
    for p in pattern:
        pth = os.path.join(sd_pth,scene,p)
        xnorm = load(pth+'/xnorm')  # reference: data before sparsification
        ori_cell_cor = np.corrcoef(xnorm)
        ori_gene_cor = np.corrcoef(xnorm.T)
        for spr in [75,85,95]:  # sparsity level (%)
            spr_pth = os.path.join(pth,str(spr))
            dim_w = load(spr_pth+'/dim_wedge')
            # A different set of weights is used at the highest sparsity.
            weight = [0.1, 0.2, 0.5] if spr == 95 else [0.2, 0.5, 0.7]
            for w in weight:
                file = 'w'+flo2str(w)+'-l'+flo2str(lamb)
                for m in multiple:
                    dim = min(dim_w * m, 30)
                    for init in initial:
                        # Loop variable is `sim_name`, not `sim`, so the `sim`
                        # module imported from `simulations` stays usable.
                        for sim_name in similarity:
                            res_pth = os.path.join(spr_pth,'result',init,file+'-d'+flo2str(dim),sim_name)
                            X = load(res_pth+'/x_'+sim_name)
                            stat_pth = os.path.join(res_pth, 'statistics')
                            if not os.path.exists(stat_pth):
                                os.makedirs(stat_pth)

                            # Already scored in an earlier run: skip.
                            if os.path.exists(stat_pth+'/pcc_cell_'+sim_name+'.pkl'):
                                continue
                            pcc_cell, pcc_p_cell = calc.PCC(xnorm, X)
                            pcc_gene, pcc_p_gene = calc.PCC(xnorm.T, X.T)
                            save(pcc_cell, stat_pth+'/pcc_cell_'+sim_name)
                            save(pcc_p_cell, stat_pth+'/pcc_p_cell_'+sim_name)
                            save(pcc_gene, stat_pth+'/pcc_gene_'+sim_name)
                            save(pcc_p_gene, stat_pth+'/pcc_p_gene_'+sim_name)

                            wdg_cell_cor = np.corrcoef(X)
                            wdg_gene_cor = np.corrcoef(X.T)
                            CMD_cell = calc.CMD(ori_cell_cor, wdg_cell_cor)
                            save(CMD_cell, stat_pth+'/cmd_cell_'+sim_name)
                            CMD_gene = calc.CMD(ori_gene_cor, wdg_gene_cor)
                            save(CMD_gene, stat_pth+'/cmd_gene_'+sim_name)

In [54]:
# Baseline statistics for sprod, wedge and the untouched sparse input
for scene in scenario:
    for p in pattern:
        pth = os.path.join(sd_pth,scene,p)
        xnorm = load(pth+'/xnorm')  # reference: data before sparsification
        ori_cell_cor = np.corrcoef(xnorm)
        ori_gene_cor = np.corrcoef(xnorm.T)
        for spr in [75,85,95]:  # sparsity level (%)
            spr_pth = os.path.join(pth,str(spr))
            # Expose the sparse input under the name 'xsparse' so it can be
            # loaded like the other methods. It is copied rather than renamed:
            # other cells of this notebook still load '<spr>/xnorm', and a
            # rename made them fail after this cell had run once.
            if os.path.exists(spr_pth+'/xnorm.pkl'):
                shutil.copy(spr_pth+'/xnorm.pkl', spr_pth+'/xsparse.pkl')
            for method in ['sprod','wedge','sparse']:
                X = load(spr_pth+'/x'+method)
                stat_pth = os.path.join(spr_pth, 'statistics', method)
                if not os.path.exists(stat_pth):
                    os.makedirs(stat_pth)
                pcc_cell, pcc_p_cell = calc.PCC(xnorm, X)
                pcc_gene, pcc_p_gene = calc.PCC(xnorm.T, X.T)
                save(pcc_cell, stat_pth+'/pcc_cell_'+method)
                save(pcc_p_cell, stat_pth+'/pcc_p_cell_'+method)
                save(pcc_gene, stat_pth+'/pcc_gene_'+method)
                save(pcc_p_gene, stat_pth+'/pcc_p_gene_'+method)
                wdg_cell_cor = np.corrcoef(X)
                wdg_gene_cor = np.corrcoef(X.T)
                CMD_cell = calc.CMD(ori_cell_cor, wdg_cell_cor)
                save(CMD_cell, stat_pth+'/cmd_cell_'+method)
                CMD_gene = calc.CMD(ori_gene_cor, wdg_gene_cor)
                save(CMD_gene, stat_pth+'/cmd_gene_'+method)

In [61]:
# Chosen (weight, dimension) per dataset and sparsity level, rendered as the
# folder-name fragment 'w<weight>-d<dim>'.
best_dict = {
    key: {spr: 'w'+flo2str(w)+'-d'+flo2str(d) for spr, (w, d) in picks.items()}
    for key, picks in {
        'sp/quilt':      {'75': (0.5, 12), '85': (0.5, 8),  '95': (0.1, 6)},
        'mix_1/ggblock': {'75': (0.5, 12), '85': (0.5, 12), '95': (0.2, 6)},
        'mix_2/both':    {'75': (0.5, 30), '85': (0.5, 20), '95': (0.2, 15)},
    }.items()
}

In [64]:
# Statistics for the lambda sweep on the selected settings in `best_dict`.
lambdas = [5,18,25,30]
# Named `sim_name` so the `sim` module imported from `simulations` is not overwritten.
sim_name = 'SPR'

for key, value in best_dict.items():
    pth = os.path.join(sd_pth,key)
    # The reference data and its correlation matrices do not depend on lambda,
    # so they are computed once per dataset instead of once per lambda.
    xnorm = load(pth+'/xnorm')
    ori_cell_cor = np.corrcoef(xnorm)
    ori_gene_cor = np.corrcoef(xnorm.T)
    # Loop variable is `lam`, not `lamb`, so the global `lamb` used by the
    # main statistics cell keeps its value.
    for lam in lambdas:
        for k, v in value.items():
            spr_pth = os.path.join(pth,k)
            res_pth = os.path.join(spr_pth, 'result', v, 'l'+flo2str(lam))
            X = load(res_pth+'/x_'+sim_name)
            stat_pth = os.path.join(res_pth, 'statistics')
            if not os.path.exists(stat_pth):
                os.makedirs(stat_pth)
            pcc_cell, pcc_p_cell = calc.PCC(xnorm, X)
            pcc_gene, pcc_p_gene = calc.PCC(xnorm.T, X.T)
            save(pcc_cell, stat_pth+'/pcc_cell_'+sim_name)
            save(pcc_p_cell, stat_pth+'/pcc_p_cell_'+sim_name)
            save(pcc_gene, stat_pth+'/pcc_gene_'+sim_name)
            save(pcc_p_gene, stat_pth+'/pcc_p_gene_'+sim_name)
            wdg_cell_cor = np.corrcoef(X)
            wdg_gene_cor = np.corrcoef(X.T)
            CMD_cell = calc.CMD(ori_cell_cor, wdg_cell_cor)
            save(CMD_cell, stat_pth+'/cmd_cell_'+sim_name)
            CMD_gene = calc.CMD(ori_gene_cor, wdg_gene_cor)
            save(CMD_gene, stat_pth+'/cmd_gene_'+sim_name)

# svd

In [None]:
sd_pth = './simdata'
scenario = ['sp','mix_1','mix_2']
pattern = ['quilt','ggblock','both']

# Cache the full SVD of every sparsified matrix as 'u', 'sig' and 'vt'.
for scene in scenario:
    for p in pattern:
        pth = os.path.join(sd_pth,scene,p)
        for spr in [75,85,95]:
            spr_pth = os.path.join(pth,str(spr))
            u, s, v = np.linalg.svd(load(spr_pth+'/xnorm'))
            for stem, factor in (('/u', u), ('/sig', s), ('/vt', v)):
                save(factor, spr_pth+stem)

In [78]:
def svd_heatmap(klist, scenario='sp', pattern='quilt', gene=0):
    '''
    Plot one gene over the spatial coordinates for the sparsified data and
    for its rank-k SVD reconstructions.

    Two figures are produced:

    1. A small reference picture of the gene, cached as
       ``<pth>/oripic/g<gene>.png`` and re-displayed from disk when present.
    2. A grid with one column per sparsity level (75, 85, 95) whose first
       row shows the sparsified data and whose following rows show the
       reconstruction from the top ``k`` singular triplets for each ``k`` in
       ``klist``. It is saved to ``<pth>/svd/g<gene>.png``, overwriting any
       previous file.

    ``<pth>`` is ``./simdata/<scenario>/<pattern>``. The cached ``u``,
    ``sig`` and ``vt`` pickles must already exist in every sparsity folder.

    klist: dimension for reduction
    scenario: sp, mix_1, or mix_2
    pattern: quilt, ggblock, or both
    gene: column index of the gene to plot
    '''
    import matplotlib.image as mpimg  # only needed to re-display a cached PNG

    sd_pth = './simdata'
    pth = os.path.join(sd_pth,scenario,pattern)
    spatial = load(pth+'/spa')
    file = 'svd'

    # Reference picture.
    ppth = os.path.join(pth,'oripic')
    os.makedirs(ppth, exist_ok=True)
    pic_pth = ppth+'/g'+str(gene)+'.png'
    if os.path.exists(pic_pth):
        img = mpimg.imread(pic_pth)
        plt.figure(figsize=(3,2))
        plt.imshow(img)
        plt.axis('off')
    else:
        # NOTE(review): the original referenced an undefined name `xori` here.
        # It is assumed to be the un-sparsified normalised matrix stored next
        # to 'spa' - confirm this is the intended reference.
        xori = load(pth+'/xnorm')
        fig,ax = plt.subplots(figsize=(3,2))
        ax.set_facecolor('gray')
        ax.scatter(spatial[:,0],spatial[:,1],c=xori[:,gene],cmap='Blues',s=8)
        # Save under the full file name. str.rstrip('.png') strips a set of
        # characters, not a suffix, so it must not be used to drop the extension.
        fig.savefig(pic_pth)
        fig.show()

    # Grid: sparse data plus one row per requested rank.
    row = len(klist) + 1
    # squeeze=False keeps `ax` two-dimensional even when klist is empty.
    fig,ax = plt.subplots(row, 3, figsize=(10, int(round(row*2.5))),
                          sharex=True, sharey=True, squeeze=False)
    fig.suptitle(scenario+'-'+pattern+'-'+file+'-g'+str(gene))
    for j, spr in enumerate([75,85,95]):  # sparsity level (%)
        spr_pth = os.path.join(pth,str(spr))
        xnorm = load(spr_pth+'/xnorm')
        u = load(spr_pth+'/u')
        sig = load(spr_pth+'/sig')
        vt = load(spr_pth+'/vt')
        # Order the singular triplets by decreasing singular value.
        order = np.argsort(-sig)
        sig = sig[order]
        if j == 0:
            ax[0][j].set_ylabel('sparse data')
        ax[0][j].set_title(str(spr)+'%')
        ax[0][j].set_facecolor('gray')
        ax[0][j].scatter(spatial[:,0],spatial[:,1],c=xnorm[:,gene],cmap='Blues',s=16)
        for i, k in enumerate(klist, start=1):
            keep = order[:k]
            # U_k @ diag(sig_k) @ Vt_k, scaling the rows of Vt_k instead of
            # building the k-by-k diagonal matrix.
            xsvd = np.dot(u[:,keep], sig[:k, None] * vt[keep,:])
            if j == 0:
                ax[i][j].set_ylabel('svd '+ str(k))
            ax[i][j].set_facecolor('gray')
            ax[i][j].scatter(spatial[:,0],spatial[:,1],c=xsvd[:,gene],cmap='Blues',s=16)
    plt.tight_layout(h_pad=0.1, w_pad=0.1)

    ppth = os.path.join(pth, file)
    os.makedirs(ppth, exist_ok=True)
    pic_pth = ppth+'/g'+str(gene)+'.png'  # e.g. './simdata/sp/quilt/svd/g0.png'
    fig.savefig(pic_pth)  # overwrites an existing picture
    fig.show()

In [None]:
# Ranks to visualise for the 'sp' scenario, three genes per pattern.
klist = [2, 4, 8, 20, 50]
pattern = ['quilt','ggblock','both']
for p in pattern:
    for gene in (0, 96, 129):
        svd_heatmap(klist, 'sp', p, gene)

In [None]:
# Ranks to visualise for the 'mix_1' scenario, three genes per pattern.
klist = [2, 4, 6, 13, 50]
pattern = ['quilt','ggblock','both']
for p in pattern:
    for gene in (0, 96, 129):
        svd_heatmap(klist, 'mix_1', p, gene)

In [None]:
# Ranks to visualise for the 'mix_2' scenario, three genes per pattern.
klist = [2, 4, 7, 14, 50]
pattern = ['quilt','ggblock','both']
for p in pattern:
    for gene in (0, 96, 129):
        svd_heatmap(klist, 'mix_2', p, gene)