In [None]:
import os
import random
import numpy as np
import scanpy as sc
import torch
from torch.utils.data import DataLoader
import argparse
import matplotlib.pyplot as plt
from sklearn.metrics import adjusted_rand_score

from dataset import Dataset
from model import SpaCLR, TrainerSpaCLR
from utils import get_predicted_results, load_ST_file
import pandas as pd
import warnings
warnings.filterwarnings("ignore")
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True,
    so every non-empty string would switch the flag on.

    Args:
        value: The raw token (or an actual bool).

    Returns:
        True for true/t/yes/y/1, False for false/f/no/n/0 (case-insensitive).
        The empty string maps to False, which is what ``bool("")`` returned.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Run configuration for the notebook. Every setting is declared as a
# command-line option so the same definitions can be reused from a script.
parser = argparse.ArgumentParser()

# preprocess
parser.add_argument('--dataset', type=str, default="SpatialLIBD")   # other values noted by the author: BreastCancer, MouseBrain, MouseOlfactoryBulb
# NOTE(review): absolute, machine-specific default path; override with --path on other machines.
parser.add_argument('--path', type=str, default="/root/autodl-tmp/data/DLPFC") #/DLPFC
parser.add_argument("--gene_preprocess", choices=("pca", "hvg"), default="hvg")
# type=int is required: argparse compares the converted value against `choices`,
# so without it the string "3000" would never match the integer choices.
parser.add_argument("--n_gene", type=int, choices=(3000, 1000), default=3000)
parser.add_argument('--img_size', type=int, default=112)
parser.add_argument('--num_workers', type=int, default=15)

# model
parser.add_argument('--last_dim', type=int, default=64)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--p_drop', type=float, default=0)

# loss-term weights
parser.add_argument('--w_g2i', type=float, default=1)
parser.add_argument('--w_s2i', type=float, default=1)
parser.add_argument('--w_g2g', type=float, default=0.1)
parser.add_argument('--w_s2s', type=float, default=0.1)
parser.add_argument('--w_s2g', type=float, default=0.1)
parser.add_argument('--w_i2i', type=float, default=0.1)
parser.add_argument('--w_recon', type=float, default=0.3)
parser.add_argument('--w_graph_loss', type=float, default=0.5)

# data augmentation
parser.add_argument('--prob_mask', type=float, default=0.5)
parser.add_argument('--pct_mask', type=float, default=0.2)
parser.add_argument('--prob_noise', type=float, default=0.5)
parser.add_argument('--pct_noise', type=float, default=0.8)
parser.add_argument('--sigma_noise', type=float, default=0.5)
parser.add_argument('--prob_swap', type=float, default=0.5)
parser.add_argument('--pct_swap', type=float, default=0.1)
parser.add_argument('--backbone', type=str, default='swin_s')
# train
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--epochs', type=int, default=30)
parser.add_argument('--device', type=str, default="cuda")
parser.add_argument('--log_name', type=str, default="log_name")
parser.add_argument('--name', type=str, default="151672")

# Boolean switches go through _str2bool so that "--is_train False" really disables training.
parser.add_argument('--is_train', type=_str2bool, default=True)
parser.add_argument('--is_load', type=_str2bool, default=False)
parser.add_argument('--ckpt_path', type=str, default="last.pth")

# Inside a notebook sys.argv belongs to the kernel, so the options are passed
# explicitly. When running as a script, use `parser.parse_args()` instead.
args = parser.parse_args(args=['--epochs', '30', '--name', '151672'])
#args = parser.parse_args()
print(args)

Namespace(backbone='swin_s', batch_size=64, ckpt_path='last.pth', dataset='SpatialLIBD', device='cuda', epochs=30, gene_preprocess='hvg', img_size=112, is_load=False, is_train=True, last_dim=64, log_name='log_name', lr=0.001, n_gene=3000, name='151672', num_workers=15, p_drop=0, path='/root/autodl-tmp/ConGI/data/DLPFC', pct_mask=0.2, pct_noise=0.8, pct_swap=0.1, prob_mask=0.5, prob_noise=0.5, prob_swap=0.5, sigma_noise=0.5, w_g2g=0.1, w_g2i=1, w_graph_loss=0.5, w_i2i=0.1, w_recon=0.3, w_s2g=0.1, w_s2i=1, w_s2s=0.1)


In [None]:
# Load the three embeddings saved for this sample under embeddings/.
# Per the author's notes these are, presumably: xg from the graph-attention
# encoder, xg1 from the MLP encoder, xi from the morphology-image encoder.
xg = np.load(f'embeddings/{args.name}_xg.npy')
xg1 = np.load(f'embeddings/{args.name}_xg1.npy')
xi = np.load(f'embeddings/{args.name}_xi.npy')

# Fusion weights: xg is taken at full weight, xg1 is scaled by a1, xi by b1.
a1 = 0.5
b1 = 0.1
z = xg + xg1*a1 + b1*xi    # GAT+MLP+Image

# Cluster the fused embedding; `ari` and `z` are reused by later cells.
ari, pred_label = get_predicted_results(args.dataset, args.name, args.path, z)

# makedirs(exist_ok=True) replaces the check-then-mkdir pair: no race between
# the existence test and the creation, and re-running the cell is harmless.
os.makedirs("output", exist_ok=True)
pd.DataFrame({"cluster_labels": pred_label.tolist()}).to_csv(
    f"output/{args.name}_pred.csv")

图注意力xg： (4015, 30)
MLP编码器xg1： (4015, 30)
形态学图像xi： (4015, 30)


R[write to console]:                    __           __ 
   ____ ___  _____/ /_  _______/ /_
  / __ `__ \/ ___/ / / / / ___/ __/
 / / / / / / /__/ / /_/ (__  ) /_  
/_/ /_/ /_/\___/_/\__,_/____/\__/   version 6.0.1
Type 'citation("mclust")' for citing this R package in publications.



fitting ...
Adjusted rand index = 0.593


In [3]:
# Display the unrounded score returned by get_predicted_results in the previous cell.
ari

0.5927183911724826

In [4]:
# Read the sample folder <args.path>/<args.name> with the project loader,
# then end the cell on the object so its summary is displayed.
sample_dir = os.path.join(args.path, args.name)
adata = load_ST_file(sample_dir)
adata

AnnData object with n_obs × n_vars = 4015 × 33538
    obs: 'in_tissue', 'array_row', 'array_col'
    var: 'gene_ids', 'feature_types', 'genome'
    uns: 'spatial'
    obsm: 'spatial'

In [5]:
# De-duplicate gene names first so the prefix test below runs on the final names.
adata.var_names_make_unique()

# Boolean flag per gene: True when the gene name starts with "MT-".
is_mito = adata.var_names.str.startswith("MT-")
adata.var["mt"] = is_mito

# Compute QC metrics in place, with the "mt" flag as an extra QC variable.
sc.pp.calculate_qc_metrics(
    adata,
    qc_vars=["mt"],
    inplace=True,
)

In [6]:
# Preprocessing, applied to `adata` in this fixed order:
#   1. drop genes detected in fewer than 3 spots
#   2. normalise the total counts of every spot
#   3. log1p-transform
#   4. flag highly variable genes (flavor "seurat")
# NOTE(review): the cell transforms `adata` step by step, so re-running it on the
# same object would presumably normalise/log-transform twice; reload `adata` first.
sc.pp.filter_genes(adata, min_cells=3)
sc.pp.normalize_total(adata, inplace=True)
sc.pp.log1p(adata)
# Take the gene count from the run configuration instead of repeating the
# literal 3000, so this cell cannot drift from --n_gene (default 3000).
# NOTE(review): assumes --n_gene is the HVG count used by the training pipeline; confirm.
sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=args.n_gene)

In [7]:
# Display the AnnData summary again to inspect what the QC and preprocessing cells added.
adata

AnnData object with n_obs × n_vars = 4015 × 18730
    obs: 'in_tissue', 'array_row', 'array_col', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt'
    var: 'gene_ids', 'feature_types', 'genome', 'mt', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'spatial', 'log1p', 'hvg'
    obsm: 'spatial'

In [8]:
# Start again from a freshly loaded AnnData (the preprocessed object from the
# earlier cells is discarded here).
sample_dir = os.path.join(args.path, args.name)
adata = load_ST_file(sample_dir)

# Manual layer annotation. The category codes are taken BEFORE unannotated rows
# are dropped, so `label` has one entry per metadata row.
df_meta = pd.read_csv(os.path.join(sample_dir, 'metadata.tsv'), sep='\t')
layer_column = df_meta['layer_guess']
label = pd.Categorical(layer_column).codes
n_clusters = label.max() + 1
df_meta = df_meta.loc[layer_column.notna()]

# Predicted labels written by the clustering cell; entries equal to -1 are removed.
pred_table = pd.read_csv(f'output/{args.name}_pred.csv')
pred = pred_table['cluster_labels']
pred = pred[pred != -1]

# NOTE(review): assigning a Series aligns on index; confirm df_meta is indexed
# by the same spot identifiers as adata.obs, otherwise ground_truth ends up empty.
adata.obs['ground_truth'] = df_meta['layer_guess']
# NOTE(review): assigned by position; assumes `pred` has exactly one entry per
# spot in `adata` once the -1 entries are removed. Verify.
adata.obs['DESTCLU'] = pred.array.astype(str)

# `z` is the fused embedding built in the clustering cell; it must already exist.
adata.obsm['DESTCLU'] = z
sc.pp.neighbors(adata, use_rep='DESTCLU')
sc.tl.umap(adata)

In [None]:
# Spatial map of the predicted domains, titled with the ARI from the clustering cell.
plt.rcParams["figure.figsize"] = (5, 5)
spatial_title = 'CLFF(ARI=%.3f)' % ari
sc.pl.spatial(
    adata,
    img_key="hires",
    color=["DESTCLU"],
    title=spatial_title,
    legend_loc=None,
    frameon=False,
    size=1.8,
    show=False,  # keep the figure open so it can be saved below
)
plt.savefig("./output/CLFF.pdf")

In [None]:
# UMAP of the fused embedding coloured by predicted domain; two-line title with the ARI.
umap_title = 'CLFF' + '\n' + '(ARI=%.3f)' % ari
sc.pl.umap(
    adata,
    color=["DESTCLU"],
    title=[umap_title],
)

In [None]:
# PAGA graph over the annotated layers, drawn next to the embedding plot.
plt.rcParams["figure.figsize"] = (5,5)
sc.tl.paga(adata, groups="ground_truth")
# show=False keeps the figure open so it can be written to disk below.
sc.pl.paga_compare(adata, color="ground_truth", title='CLFF(ARI=%.3f)'%ari, size=40,show=False)
# Build the file name from args.name instead of the literal "151672", so running
# another sample neither mislabels nor overwrites this figure. The resulting
# path is unchanged for the default sample.
plt.savefig(f"./output/CLFF_{args.name}_umap_PAGA.pdf")