In [1]:
from vit_pytorch import ViT
from skimage import io, img_as_float32, morphology, exposure
from PIL import Image
from torchvision import transforms
from pathlib import Path
from tqdm import tqdm
import timm
import scanpy as sc
import os
import numpy as np
import torch
import torch.nn as nn
from torch import optim
import time
import anndata
import pandas as pd
import scipy.stats as st
import matplotlib.pyplot as plt
import json

  from .autonotebook import tqdm as notebook_tqdm
  warn(f"Failed to load image Python extension: {e}")


In [2]:
from preprocess import preprocess,get_feature

In [3]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Alternative dataset: 10x Visium human breast cancer (Block A, Section 1).
# NOTE: later cells overwrite `data_path`, `source_image_path` and `adata`
# with the ST-colon3 sample, so run either this cell or those, not both.
data_path = '../data/human_breast_cancer'
count_file = 'V1_Breast_Cancer_Block_A_Section_1_filtered_feature_bc_matrix.h5'
source_image_path = 'V1_Breast_Cancer_Block_A_Section_1_image.tif'

# data_path = '../data/human_ovarian_cancer_target'
# count_file = 'Targeted_Visium_Human_OvarianCancer_Pan_Cancer_filtered_feature_bc_matrix.h5'
# source_image_path = 'Targeted_Visium_Human_OvarianCancer_Pan_Cancer_image.tif'

adata = sc.read_visium(
    data_path,
    count_file=count_file,
    # scanpy resolves this path against the current working directory, not
    # against `data_path`, so anchor it explicitly; a bare file name would be
    # recorded in the metadata as <cwd>/<file name>.
    source_image_path=os.path.join(data_path, source_image_path),
)
adata.var_names_make_unique()

# Preprocess only once: skip steps whose outputs are already on the object.
if 'highly_variable' not in adata.var.keys():
    preprocess(adata)
if 'feat' not in adata.obsm.keys():
    get_feature(adata)

In [4]:
# Location of the tissue image that is cut into spot tiles further down.
# Exactly one dataset should be active; the others are kept for reference.

# HCC-1L:
# data_path = '../data/HCC-1L/spatial'
# source_image_path = 'tissue_hires_image.png'

# ST-colon3 (active):
data_path = "../data/ST-colon3/spatial"
source_image_path = "tissue_hires_image.png"

# HT231P1:
# data_path = '../data/HT231P1'
# source_image_path = 'A1-HT231P1-S1H3Fc2U1.tif'

# matrix_file = '../data/HT231P1/filtered_feature_bc_matrix/'
# # Read the files and create an AnnData object
# adata = sc.read_10x_mtx(matrix_file, var_names='gene_symbols', cache=True)
# adata.var_names_make_unique()

# # Replace with the actual path to your spatial coordinates file
# spatial_coordinates_file = '../data/HT231P1/spatial/tissue_positions_list.csv'

# # # Read spatial coordinates
# spatial_coordinates = pd.read_csv(spatial_coordinates_file,header=None, index_col=0)
# column_names = ['in_tissue','array_col','array_row','x','y']  # Replace with your actual column names
# spatial_coordinates.columns = column_names
# spatial_coordinates = spatial_coordinates.reindex(adata.obs_names)
# # Add spatial coordinates to adata
# adata.obs['in_tissue'] = spatial_coordinates['in_tissue']
# adata.obs['array_col'] = spatial_coordinates['array_col']
# adata.obs['array_row'] = spatial_coordinates['array_row']
# adata.obs['x'] = spatial_coordinates['x']
# adata.obs['y'] = spatial_coordinates['y']
# spatial_array = spatial_coordinates[['y', 'x']].values

# # Assign the spatial array to adata.obsm['spatial']
# adata.obsm['spatial'] = spatial_array 
# adata = adata[adata.obs['in_tissue']==1,:]
# # Load the spatial image from the TIF file
# spatial_image = io.imread('../data/HT231P1/A1-HT231P1-S1H3Fc2U1.tif')
# with open('../data/HT231P1/spatial/scalefactors_json.json', 'r') as f:
#     scale_factors = json.load(f)
# #preprocess data
# if 'highly_variable' not in adata.var.keys():
#     preprocess(adata)
# if 'feat' not in adata.obsm.keys():
#     get_feature(adata)

In [5]:
# Read the ST-colon3 Visium run (counts + spatial metadata) into AnnData.
adata = sc.read_visium(
    '../data/ST-colon3',
    count_file='filtered_feature_bc_matrix.h5',
)

# Alternative sample:
# adata = sc.read_visium('../data/HCC-1L',
#            count_file='filtered_feature_bc_matrix.h5'
#                       )

adata.var_names_make_unique()

# Run each preprocessing step only when its result is not present yet.
if 'highly_variable' not in adata.var.keys():
    preprocess(adata)

if 'feat' not in adata.obsm.keys():
    get_feature(adata)

  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


In [6]:
# Full path of the tissue image to tile, and the folder the tiles go to.
save_path = "./models/ST-colon3/get_img_feature"
source_image_tif = os.path.join(data_path, source_image_path)

In [7]:
def image_crop(
        spdata,
        source_image_tif,
        save_path,
        library_id=None,
        ):
    """Cut one circular image tile per spot and save the tiles as PNG files.

    Spot coordinates from ``spdata.obsm['spatial']`` are multiplied by the
    library's ``tissue_hires_scalef``, so ``source_image_tif`` is expected to
    be the hires tissue image (this notebook passes ``tissue_hires_image.png``).
    Each tile is a ``(2r+1) x (2r+1)`` RGB crop centred on the spot, where
    ``r`` is half the scaled ``fiducial_diameter_fullres``; pixels outside the
    inscribed disk are set to zero.

    Parameters
    ----------
    spdata : AnnData-like
        Must provide ``obsm['spatial']`` and ``uns['spatial'][library_id]['scalefactors']``.
    source_image_tif : str or path-like
        Image file read with ``skimage.io.imread``.
    save_path : str or path-like
        Output directory; created if missing. Tiles are written as
        ``spot1.png``, ``spot2.png``, ... in observation order.
    library_id : str, optional
        Key into ``spdata.uns['spatial']``. When omitted, ``'P2_colon'`` is
        used if present (the previous hard-coded behaviour), otherwise the
        single available library.

    Returns
    -------
    The same ``spdata`` object, with ``obs['slices_path']`` set to the tile paths.

    Raises
    ------
    ValueError
        If ``library_id`` is omitted and cannot be chosen unambiguously.
    """
    spatial_meta = spdata.uns['spatial']
    if library_id is None:
        if 'P2_colon' in spatial_meta:
            library_id = 'P2_colon'
        elif len(spatial_meta) == 1:
            library_id = next(iter(spatial_meta))
        else:
            raise ValueError(
                "library_id must be given when spdata.uns['spatial'] holds "
                "several libraries: " + ", ".join(map(str, spatial_meta))
            )
    scalefactors = spatial_meta[library_id]['scalefactors']
    hires_scale = scalefactors['tissue_hires_scalef']

    # Spot centres in hires-image pixels (x = column, y = row).
    x = (spdata.obsm['spatial'][:, 0] * hires_scale).astype(int)
    y = (spdata.obsm['spatial'][:, 1] * hires_scale).astype(int)
    # Tile half-width: half the fiducial diameter, expressed in hires pixels.
    r = int(scalefactors['fiducial_diameter_fullres'] * hires_scale // 2)
    side = 2 * r + 1

    img = img_as_float32(io.imread(source_image_tif))
    # Round instead of truncating: v / 255 * 255 in float32 can land just
    # below v, which a plain uint8 cast would turn into v - 1.
    img = np.clip(np.rint(255 * img), 0, 255).astype("uint8")
    if img.ndim == 2:
        # Grayscale input: replicate to three channels so tiles are RGB.
        img = np.stack([img] * 3, axis=-1)
    height, width = img.shape[:2]

    out_dir = Path(save_path)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Loop-invariant disk mask: only the spot is kept, not its bounding box.
    spot_mask = morphology.disk(r)[..., None]

    tile_names = []
    with tqdm(total=len(spdata),
              desc="Tiling image",
              bar_format="{l_bar}{bar} [ time left: {remaining} ]") as pbar:
        for i, (row_c, col_c) in enumerate(zip(y, x), start=1):
            top, left = int(row_c) - r, int(col_c) - r
            # Clamp the window to the image and paste it into a zero canvas so
            # that spots near the border still give full-size tiles instead of
            # wrapped (negative index) or truncated slices.
            r0, r1 = max(top, 0), min(top + side, height)
            c0, c1 = max(left, 0), min(left + side, width)
            tile = np.zeros((side, side, 3), dtype=img.dtype)
            if r0 < r1 and c0 < c1:
                tile[r0 - top:r1 - top, c0 - left:c1 - left] = img[r0:r1, c0:c1, :3]
            tile = tile * spot_mask

            out_tile = out_dir / ("spot" + str(i) + ".png")
            Image.fromarray(tile).save(out_tile, "PNG")
            tile_names.append(str(out_tile))
            pbar.update(1)

    spdata.obs["slices_path"] = tile_names
    return spdata
# Write one PNG tile per spot and record the tile paths on `adata.obs`.
adata = image_crop(
    spdata=adata,
    source_image_tif=source_image_tif,
    save_path=save_path,
)

Tiling image: 100%|██████████ [ time left: 00:00 ]


In [None]:
def extract_image_feat(spdata,save_path,feature_dim=128):
    """Embed every spot tile with a pretrained ViT and return the features.

    Tiles are read from ``save_path`` as ``spot1.png`` ... ``spotN.png`` (one
    per observation of ``spdata``), resized to 224 px, normalised and passed
    through ``vit_base_patch32_224_clip_laion2b`` whose head is replaced by a
    fresh ``Linear(768, feature_dim)`` layer.

    NOTE(review): the replacement head is randomly initialised and never
    trained here, so the returned features are a random linear projection of
    the backbone output and differ between runs unless the torch RNG is seeded
    beforehand.
    NOTE(review): the normalisation constants are the ImageNet ones; confirm
    they match what this CLIP checkpoint expects.

    Parameters
    ----------
    spdata : AnnData-like
        Only ``spdata.shape[0]`` / ``len(spdata)`` (number of spots) is used.
    save_path : str or path-like
        Directory holding the ``spot<i>.png`` tiles.
    feature_dim : int, default 128
        Width of the new head, i.e. number of features per spot.

    Returns
    -------
    torch.Tensor
        CPU tensor of shape ``(n_spots, feature_dim)``.
    """
    transform = transforms.Compose([
        # BICUBIC is the enum equivalent of the former integer code 3.
        transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    model = timm.create_model('vit_base_patch32_224_clip_laion2b',pretrained=True).to(device)
    # Honour `feature_dim`: the head width used to be hard-coded to 128, which
    # broke the assignment below for any other value.
    model.head = nn.Linear(in_features=768, out_features=feature_dim, bias=True).to(device)
    for param in model.parameters():
        param.requires_grad = False
    model.eval()

    n_spots = spdata.shape[0]
    tile_dir = Path(save_path)
    image_feature = torch.zeros(n_spots, feature_dim)

    with torch.no_grad(), tqdm(total=len(spdata),
                               desc="Extract image feature",
                               bar_format="{l_bar}{bar} [ time left: {remaining} ]") as pbar:
        for i in range(n_spots):
            # Tiles are numbered from 1. Close each file handle promptly and
            # force three channels so Normalize always sees an RGB tensor.
            with Image.open(tile_dir / ("spot" + str(i + 1) + ".png")) as spot_slice:
                batch = transform(spot_slice.convert("RGB"))[None,].to(device)
            image_feature[i] = model(batch).squeeze(0).cpu()
            pbar.update(1)

    return image_feature
# One 128-dimensional image embedding per spot (slow: one forward pass per tile).
image_feature = extract_image_feat(
    spdata=adata,
    save_path=save_path,
    feature_dim=128,
)


  model = create_fn(
'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /laion/CLIP-ViT-B-32-laion2B-s34B-b79K/resolve/main/open_clip_pytorch_model.bin (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7fe68f14cc70>, 'Connection to huggingface.co timed out. (connect timeout=10)'))"), '(Request ID: e38e1d7d-b494-4220-932e-b6daa52d1f46)')' thrown while requesting HEAD https://huggingface.co/laion/CLIP-ViT-B-32-laion2B-s34B-b79K/resolve/main/open_clip_pytorch_model.bin
Extract image feature:  21%|██         [ time left: 2:18:12 ]

In [None]:
class model_str(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Maps ``num_features`` inputs through ``hid_features`` hidden units to
    ``out_features`` outputs. The layers live in ``self.dw``.
    """

    def __init__(self, num_features, hid_features, out_features):
        super(model_str, self).__init__()
        layers = [
            nn.Linear(num_features, hid_features),
            nn.ReLU(),
            nn.Linear(hid_features, out_features),
        ]
        self.dw = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` and return the result."""
        return self.dw(x)

In [None]:
# --- Training configuration -------------------------------------------------
seed = 0                                     # fixes the MLP initialisation
num_features = image_feature.shape[1]        # width of the image embedding
hid_features = 512
out_features = adata.obsm['feat'].shape[1]   # width of the expression target
lr = 0.001          # initial learning rate
weight_decay = 0    # L2 regularisation strength
epochs = 2000

torch.manual_seed(seed)

# Regression target. Cast to float32 explicitly so it matches the dtype of the
# model output inside the MSE loss whatever dtype `feat` is stored in.
features = torch.tensor(adata.obsm['feat'].copy(), dtype=torch.float32).to(device)
image_feature = image_feature.to(device)

model = model_str(num_features,hid_features,out_features).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

In [None]:
def _mean_rowwise_pearson(measured, predicted):
    """Return the mean Pearson correlation between matching rows of two 2-D arrays.

    Rows with zero variance give an undefined correlation and are skipped.
    """
    a = np.asarray(measured, dtype=np.float64)
    b = np.asarray(predicted, dtype=np.float64)
    a = a - a.mean(axis=1, keepdims=True)
    b = b - b.mean(axis=1, keepdims=True)
    with np.errstate(divide='ignore', invalid='ignore'):
        r = (a * b).sum(axis=1) / np.sqrt((a * a).sum(axis=1) * (b * b).sum(axis=1))
    if r.size == 0 or np.all(np.isnan(r)):
        return float('nan')
    return float(np.nanmean(r))


def train(epochs):
    """Fit the global ``model`` to predict ``features`` from ``image_feature``.

    Runs ``epochs`` full-batch steps with the global ``optimizer`` and an MSE
    loss. Every 200 epochs the loss, the step time and the mean per-spot
    Pearson correlation between ``adata.obsm['feat']`` and the current
    prediction are printed.

    Relies on the notebook globals ``model``, ``optimizer``, ``image_feature``,
    ``features``, ``adata`` and ``device``.

    Parameters
    ----------
    epochs : int
        Number of optimisation steps.

    Returns
    -------
    train_loss_ep : list of float
        Training loss of every epoch.
    features_recon : torch.Tensor
        Prediction of the trained model in eval mode, computed without
        gradient tracking.
    """
    loss_function = nn.MSELoss()   # built once instead of once per epoch
    train_loss_ep = [None] * epochs

    for epoch in range(epochs):
        t = time.time()
        model.train()
        optimizer.zero_grad()

        features_recon = model(image_feature)
        loss = loss_function(features_recon, features)
        loss.backward()
        optimizer.step()
        train_loss_ep[epoch] = loss.item()

        if epoch%200 == 0:
            print(' Epoch: {:04d}'.format(epoch),
                  'loss_train: {:.4f}'.format(loss.item()),
                  'time: {:.4f}s'.format(time.time() - t))
            # Progress monitor: mean spot-wise correlation, computed directly
            # on the arrays rather than through temporary AnnData objects.
            print(_mean_rowwise_pearson(adata.obsm['feat'],
                                        features_recon.detach().cpu().numpy()))

    model.to(device).eval()
    # The final prediction is only used for evaluation, so skip autograd.
    with torch.no_grad():
        features_recon = model(image_feature)

    return train_loss_ep,features_recon

In [None]:
# Train for `epochs` steps and report the wall-clock duration.
t_ep = time.time()
train_loss_ep, features_recon = train(epochs)
print(f' total time: {time.time() - t_ep:.4f}s')

In [31]:
# image_feature = image_feature.cpu().numpy()
# image_save_path = Path('./models/HT231P1/get_img_feature') / ("image_feature.npy")
# np.save(image_save_path,image_feature)

In [None]:
# Idea: this could be extended to a multi-scale biological network.
# Measured (`sam`) and predicted (`com`) feature matrices. Both are copies so
# that in-place edits made while plotting cannot leak into
# `adata.obsm['feat']` or into the `features_recon` tensor storage.
sam = adata.obsm['feat'].copy()
com = features_recon.detach().cpu().numpy().copy()

In [None]:
# Side-by-side spatial maps of measured vs. predicted values for the first
# features. Colours are capped at each panel's 99th percentile so a few
# extreme spots do not wash out the colour scale.
plt.style.use('dark_background')
spot_x = adata.obsm['spatial'][:,0]
spot_y = adata.obsm['spatial'][:,1]

for i in range(min(20, sam.shape[1])):
    fig, (ax1, ax2) = plt.subplots(1, 2,figsize=(8,4))
    # The axes are switched off, which hides axis labels; show the feature
    # index as a figure title instead.
    fig.suptitle('Feature {}'.format(i))
    for ax, matrix, title in ((ax1, sam, 'Measured '), (ax2, com, 'Predicted ')):
        values = matrix[:,i]
        cap = np.percentile(values,99)
        # Build a capped copy. Assigning into `values` would write through the
        # view into `sam`/`com` and distort the metrics computed afterwards.
        colors = np.where(values > cap, cap, values)
        ax.axis('off')
        ax.scatter(spot_x,spot_y,s=1,c=colors)
        ax.set_title(title, fontsize = 12)
        ax.invert_yaxis()
    plt.show()
    plt.close(fig)

In [None]:
# Wrap the raw matrices in AnnData objects for the metric and clustering cells.
# The isinstance guards make this cell safe to re-run: without them a second
# execution would try to wrap an AnnData inside another AnnData.
if not isinstance(sam, anndata.AnnData):
    # Columns of `feat` are presumably the highly variable genes, in order;
    # TODO confirm against `get_feature`.
    sam = anndata.AnnData(sam,var = adata[:,adata.var['highly_variable']].var)
if not isinstance(com, anndata.AnnData):
    com = anndata.AnnData(com)

In [None]:
import pandas as pd
import scipy.stats as st

def cal_Percor(original,res):
    """Gene-wise Pearson correlation between two expression matrices.

    Column ``j`` of ``original.X`` is correlated with column ``j`` of
    ``res.X``.

    NOTE: a later cell defines a spot-wise function under the same name; after
    that cell has run, this definition is shadowed.

    Parameters
    ----------
    original, res
        Objects exposing a 2-D array ``X``; ``original`` must also expose
        ``var_names``, which becomes the index of the result.

    Returns
    -------
    (pandas.Series, float)
        Per-gene correlations indexed by ``original.var_names`` and their
        mean (NaN entries are ignored by the mean).
    """
    measured = original.X.T
    predicted = res.X.T
    # Fill a float array positionally and attach the labels afterwards.
    # Writing `series[i] = ...` on a label-indexed Series relies on the
    # deprecated "integer key means position" fallback of pandas.
    values = np.full(len(original.var_names), np.nan, dtype=float)
    for i in range(res.X.shape[1]):
        values[i] = st.pearsonr(measured[i], predicted[i])[0]
    Pearson_CoPearson_Cor = pd.Series(values, index=original.var_names, dtype=float)
    Pearson_Cor_mean = Pearson_CoPearson_Cor.mean()
    return Pearson_CoPearson_Cor,Pearson_Cor_mean
# Gene-wise Pearson correlation between measured and predicted features.
our_Percor, our_Percor_mean = cal_Percor(original=sam, res=com)
print(our_Percor_mean)
# Means recorded from earlier runs:
#0.3171313503053768
#0.34746957606266593
#0.408

In [None]:
def cal_Percor(original,res):
    """Spot-wise Pearson correlation between two expression matrices.

    Row ``i`` of ``original.X`` is correlated with row ``i`` of ``res.X``.

    NOTE: this redefines the gene-wise ``cal_Percor`` from the previous cell.

    Parameters
    ----------
    original, res
        Objects exposing a 2-D array ``X``; ``original`` must also expose
        ``obs_names``, which becomes the index of the result.

    Returns
    -------
    (pandas.Series, float)
        Per-spot correlations indexed by ``original.obs_names`` and their
        mean (NaN entries are ignored by the mean).
    """
    # Fill a float array positionally and attach the labels afterwards.
    # Writing `series[i] = ...` on a label-indexed Series relies on the
    # deprecated "integer key means position" fallback of pandas.
    values = np.full(len(original.obs_names), np.nan, dtype=float)
    for i in range(res.X.shape[0]):
        values[i] = st.pearsonr(original.X[i], res.X[i])[0]
    Pearson_CoPearson_Cor = pd.Series(values, index=original.obs_names, dtype=float)
    Pearson_Cor_mean = Pearson_CoPearson_Cor.mean()
    return Pearson_CoPearson_Cor,Pearson_Cor_mean
# Spot-wise Pearson correlation between measured and predicted features.
our_Percor, our_Percor_mean = cal_Percor(original=sam, res=com)
print(our_Percor_mean)
# Mean recorded from an earlier run:
#0.664

In [None]:
def cal_Specor(original,res):
    """Spot-wise Spearman rank correlation between two expression matrices.

    Row ``i`` of ``original.X`` is correlated with row ``i`` of ``res.X``.

    Parameters
    ----------
    original, res
        Objects exposing a 2-D array ``X``; ``original`` must also expose
        ``obs_names``, which becomes the index of the result.

    Returns
    -------
    (pandas.Series, float)
        Per-spot correlations indexed by ``original.obs_names`` and their
        mean (NaN entries are ignored by the mean).
    """
    # Fill a float array positionally and attach the labels afterwards.
    # Writing `series[i] = ...` on a label-indexed Series relies on the
    # deprecated "integer key means position" fallback of pandas.
    values = np.full(len(original.obs_names), np.nan, dtype=float)
    for i in range(res.X.shape[0]):
        values[i] = st.spearmanr(original.X[i], res.X[i])[0]
    Spearman_CoPearson_Cor = pd.Series(values, index=original.obs_names, dtype=float)
    Spearman_Cor_mean = Spearman_CoPearson_Cor.mean()
    return Spearman_CoPearson_Cor,Spearman_Cor_mean
# Spot-wise Spearman correlation between measured and predicted features.
our_Specor, our_Specor_mean = cal_Specor(original=sam, res=com)
print(our_Specor_mean)
# Mean recorded from an earlier run:
#0.598

In [None]:
# Cluster the predicted features and display the clusters in UMAP space and
# on the tissue image.

# Borrow the slide metadata and spot coordinates from the measured data so
# that `com` can be drawn on the tissue.
com.uns['spatial'] = adata.uns['spatial']
com.obsm['spatial'] = adata.obsm['spatial']

# Principal component analysis (PCA).
sc.pp.pca(com, svd_solver="arpack")

# Neighbour graph on the first 25 PCs, then UMAP embedding and Leiden clustering.
sc.pp.neighbors(com, n_neighbors=10, n_pcs=25)
sc.tl.umap(com)
sc.tl.leiden(com, key_added="leiden_res", resolution=0.3)

# Other resolutions tried before:
# color=["leiden_res0_15", "leiden_res0_5", "leiden_res0_75", "leiden_res1"]
sc.pl.umap(com, color=["leiden_res"], legend_loc="on data")
sc.pl.spatial(com, img_key='hires', color=["leiden_res"])

In [None]:
from sklearn import metrics

# Silhouette of the Leiden clusters in expression space (the 25 PCs that the
# neighbour graph was built on) ...
score1 = metrics.silhouette_score(
    com.obsm['X_pca'][:, :25],
    labels=com.obs['leiden_res'],
)
print(score1)

# ... and in physical space (spot coordinates).
score2 = metrics.silhouette_score(
    com.obsm['spatial'],
    labels=com.obs['leiden_res'],
)
print(score2)