In [5]:
# %%
import copy
import gc
import json
import os
from pathlib import Path
import shutil
import sys
import time
import traceback
from typing import List, Tuple, Dict, Union, Optional
import warnings
import pandas as pd
# from . import asyn
import pickle
import torch
from anndata import AnnData
import scanpy as sc
import scvi
import seaborn as sns
import numpy as np
import wandb
from scipy.sparse import issparse
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)
from sklearn.metrics import confusion_matrix

sys.path.insert(0, "../")
import scgpt as scg
from scgpt.model import TransformerModel, AdversarialDiscriminator
from scgpt.tokenizer import tokenize_and_pad_batch, random_mask_value
from scgpt.loss import (
    masked_mse_loss,
    masked_relative_error,
    criterion_neg_log_bernoulli,
)
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.preprocess import Preprocessor
from scgpt import SubsetsBatchSampler
from scgpt.utils import set_seed, category_str2int, eval_scib_metrics

sc.set_figure_params(figsize=(6, 6))
os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings('ignore')

In [6]:
pad_token = "<pad>" # 用于在处理文本数据时填充（pad）短于最大长度的序列。这样可以保持序列的统一长度

'''
special_tokens = [pad_token, "<cls>", "<eoc>"]：
    special_tokens 是一个包含特殊标记的列表。除了上面提到的 pad_token，列表中还包括：
    "<cls>"：通常用于表示一个序列的开始。
    "<eoc>"：代表一个序列的结束（End Of Content）
'''
# n_hvg 代表“高变异基因”的数量。在生物统计分析中，选择高变异基因（Highly Variable Genes）是为了捕捉细胞间的重要变异特征
special_tokens = [pad_token, "<cls>", "<eoc>"]
n_hvg = 1200
n_bins = 51
mask_value = -1
pad_value = -2
n_input_bins = n_bins

## Load pre-trained model¶

In [22]:
# here we load the pre-trained scGPT blood model

model_dir = "../pre_trained_model/scGPT_blood/"
print(model_dir)

model_config_file = model_dir +"args.json"
model_file = model_dir + "best_model.pt"
vocab_file = model_dir + "vocab.json"

# 加载和更新词汇表
vocab = GeneVocab.from_file(vocab_file)

# 检查特殊标记 special_tokens 是否存在于词汇表中，如果不存在，则添加到词汇表中
for s in special_tokens:
    if s not in vocab:
        vocab.append_token(s)

# Retrieve model parameters from config files
with open(model_config_file, "r") as f:
    model_configs = json.load(f)
print(
    f"Resume model from {model_file}, the model args will override the "
    f"config {model_config_file}."
)

'''
提取模型配置参数：
    embsize：模型中嵌入向量的大小。
    nhead：在模型中，多头注意力机制中的头数。
    d_hid：隐藏层的维度。
    nlayers：模型中的层数。
    n_layers_cls：分类层的数量，特定于模型的某些设计。
'''

embsize = model_configs["embsize"]
nhead = model_configs["nheads"]
d_hid = model_configs["d_hid"]
nlayers = model_configs["nlayers"]
n_layers_cls = model_configs["n_layers_cls"]

gene2idx = vocab.get_stoi()

../pre_trained_model/scGPT_blood/
Resume model from ../pre_trained_model/scGPT_blood/best_model.pt, the model args will override the config ../pre_trained_model/scGPT_blood/args.json.


In [23]:
print(len(gene2idx))

36574


### Check the model genes embeddings

In [24]:
keys = list(gene2idx.keys())
keys

['hsa-mir-423',
 'ZZEF1',
 'ZYX',
 'ZYG11A',
 'ZXDB',
 'ZXDA',
 'ZW10',
 'ZUP1',
 'ZSWIM9',
 'ZSWIM3',
 'ZSCAN5C',
 'ZSCAN5B',
 'ZSCAN26',
 'ZSCAN23',
 'ZSCAN22',
 'ZSCAN21',
 'ZSCAN2',
 'ZSCAN16-AS1',
 'ZSCAN16',
 'ZSCAN12',
 'ZRANB3',
 'ZRANB1',
 'ZPLD1',
 'ZP4',
 'ZP3',
 'ZP1',
 'ZNRF3',
 'ZNRF1',
 'ZNRD2-AS1',
 'ZNNT1',
 'ZNHIT3',
 'ZNHIT2',
 'ZNF99',
 'ZNF880',
 'ZNF879',
 'ZNF875',
 'ZNF865',
 'ZNF860',
 'ZNF853',
 'ZNF85',
 'ZNF846',
 'ZNF841',
 'ZNF84',
 'ZNF837',
 'ZNF835',
 'ZNF831',
 'ZNF830',
 'ZNF813',
 'ZNF808',
 'ZNF804B',
 'ZNF80',
 'ZNF8-DT',
 'ZNF8',
 'ZNF799',
 'ZNF793-AS1',
 'ZNF791',
 'ZNF787',
 'ZNF786',
 'ZNF785',
 'ZNF780B',
 'ZNF778',
 'ZNF771',
 'ZNF770',
 'ZNF766',
 'ZNF75D',
 'ZNF75A',
 'ZNF746',
 'ZNF737',
 'ZNF736',
 'ZNF732',
 'ZNF729',
 'ZNF727',
 'ZNF724',
 'ZNF713',
 'ZNF709',
 'ZNF708',
 'ZNF707',
 'ZNF705G',
 'ZNF705A',
 'ZNF703',
 'ZNF701',
 'ZNF70',
 'ZNF699',
 'ZNF696',
 'ZNF691',
 'ZNF687-AS1',
 'ZNF684',
 'ZNF682',
 'ZNF680',
 'ZNF679',
 'ZNF676

In [25]:
# 将键写入文本文件
with open('keys_blood.txt', 'w') as file:
    for key in keys:
        file.write(str(key) + '\n')

In [28]:
import pandas as pd

# 导入 CSV 文件
df = pd.read_csv('scaled_OLINK_data.csv')
header = df.columns # 取出表头（列名）
header_list = header.tolist() # 将表头转换为列表
header_list.pop(0) # 去除 header_list 中的第一个元素
processed_header_list = [element.split('_')[1] for element in header_list] # 处理每个元素，删除 'P_' 和 '_0'
protein_name_list = list(set(processed_header_list)) # 去重（不保留顺序）

print(type(protein_name_list))

<class 'list'>


In [29]:
protein_name_list

['TBL1X',
 'GASK1A',
 'MYH9',
 'ZP4',
 'COL3A1',
 'COL4A4',
 'ADA',
 'ITM2A',
 'AIF1L',
 'TGFBR2',
 'BEX3',
 'RBM17',
 'MAGEA3',
 'LAMP2',
 'TOP1MT',
 'APOE',
 'GPI',
 'FABP2',
 'LILRA2',
 'DNAJC21',
 'CFI',
 'IL10RB',
 'STC1',
 'COL5A1',
 'OGFR',
 'LAT2',
 'SH2D1A',
 'MORF4L1',
 'CD36',
 'ERBIN',
 'CD58',
 'SERPING1',
 'TP53',
 'AFAP1',
 'CUZD1',
 'GPR158',
 'GPKOW',
 'PFDN4',
 'TNFRSF11B',
 'SCARF2',
 'PXDNL',
 'ADGRB3',
 'MRC1',
 'TRIAP1',
 'PPP1R9B',
 'BNIP3L',
 'PCDHB15',
 'CD99',
 'NOTCH3',
 'JAM2',
 'EXOSC10',
 'PNLIPRP2',
 'TMPRSS11D',
 'FUT3',
 'SLC12A2',
 'PKLR',
 'EGF',
 'HEPH',
 'ACTN2',
 'RNF43',
 'TMED8',
 'SIGLEC10',
 'CASP8',
 'SLITRK6',
 'RAC3',
 'UBXN1',
 'CCL16',
 'DDHD2',
 'SPAG1',
 'TCL1B',
 'MCTS1',
 'PAIP2B',
 'NFKB1',
 'ARHGAP5',
 'ESPL1',
 'RNF41',
 'SETMAR',
 'CEP43',
 'PSMD1',
 'AK1',
 'TNFSF12',
 'SKAP2',
 'RILP',
 'OGA',
 'MGLL',
 'SKAP1',
 'KITLG',
 'SRPK2',
 'PPBP',
 'NAA10',
 'AMY2B',
 'KAZN',
 'PTS',
 'VPS28',
 'CD70',
 'TFF2',
 'AK2',
 'TERF1',
 'MAP1L

In [16]:
# 检查每个元素是否在 gene2idx 中找到对应的键
i = 0
for gene in unique_processed_header_list:
    if gene in gene2idx:
        i += 1
        # print(f"'{gene}' found in gene2idx with value {gene2idx[gene]}")
    else:
        print(f"'{gene}' not found in gene2idx")
print(len(unique_processed_header_list))
print(i)
print(len(unique_processed_header_list) - i)

'CD99' not found in gene2idx
'PNLIPRP2' not found in gene2idx
'LEG1' not found in gene2idx
'IL3RA' not found in gene2idx
'AKR7L' not found in gene2idx
'KIR2DS4' not found in gene2idx
'KIR2DL2' not found in gene2idx
'FHIP2A' not found in gene2idx
'BAP18' not found in gene2idx
'BTNL10' not found in gene2idx
'ANP32C' not found in gene2idx
'MENT' not found in gene2idx
'SARG' not found in gene2idx
'CERT' not found in gene2idx
'GATD3' not found in gene2idx
'HCG22' not found in gene2idx
'SIGLEC5' not found in gene2idx
'CSF2RA' not found in gene2idx
'GPR15L' not found in gene2idx
'NTproBNP' not found in gene2idx
'PALM2' not found in gene2idx
'LILRA3' not found in gene2idx
'WARS' not found in gene2idx
2920
2897
23


In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 计算词汇表的长度，即模型需要处理的不同词汇的数量
ntokens = len(vocab)  # size of vocabulary

model = TransformerModel(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    vocab=vocab,
    pad_value=pad_value,
    n_input_bins=n_input_bins,
)

try:
    model.load_state_dict(torch.load(model_file))
    print(f"Loading all model params from {model_file}")
except:
    # only load params that are in the model and match the size
    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_file)
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    for k, v in pretrained_dict.items():
        print(f"Loading params {k} with shape {v.shape}")
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

model.to(device)

Loading params encoder.embedding.weight with shape torch.Size([36574, 512])
Loading params encoder.enc_norm.weight with shape torch.Size([512])
Loading params encoder.enc_norm.bias with shape torch.Size([512])
Loading params value_encoder.linear1.weight with shape torch.Size([512, 1])
Loading params value_encoder.linear1.bias with shape torch.Size([512])
Loading params value_encoder.linear2.weight with shape torch.Size([512, 512])
Loading params value_encoder.linear2.bias with shape torch.Size([512])
Loading params value_encoder.norm.weight with shape torch.Size([512])
Loading params value_encoder.norm.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.0.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.0.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.0.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.0.linear1.bias with 

TransformerModel(
  (encoder): GeneEncoder(
    (embedding): Embedding(36574, 512, padding_idx=36571)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): ContinuousValueEncoder(
    (dropout): Dropout(p=0.5, inplace=False)
    (linear1): Linear(in_features=1, out_features=512, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, el

### Retrieve scGPT's gene embeddings

In [32]:
# Retrieve the data-independent gene embeddings from scGPT

gene_ids = np.array([id for id in gene2idx.values()]) # 提取基因 ID 和相应的索引
gene_embeddings = model.encoder(torch.tensor(gene_ids, dtype=torch.long).to(device)) # 使用模型编码器获取基因嵌入向量
gene_embeddings = gene_embeddings.detach().cpu().numpy()

In [33]:
len(gene_embeddings)

36574

In [34]:
gene_embeddings = {gene: gene_embeddings[i] for i, gene in enumerate(gene2idx.keys()) if gene in protein_name_list}
print('Retrieved gene embeddings for {} genes.'.format(len(gene_embeddings)))

Retrieved gene embeddings for 2897 genes.


In [35]:
gene_embeddings['ZP4']

array([ 9.0908483e-02,  8.7722105e-01,  1.0274558e+00, -1.3882810e+00,
       -9.9509966e-01,  8.7437022e-01,  1.8333793e-01, -4.3422338e-01,
       -1.4117726e+00,  1.8266861e+00, -6.0937452e-01, -4.9604096e-02,
       -7.5148779e-01, -4.0608305e-01,  1.0094994e+00, -1.1313052e+00,
       -8.2560384e-01,  6.7931354e-02, -6.9902837e-01,  1.7424139e+00,
       -1.1972936e+00, -1.0002319e+00,  7.9404163e-01, -1.8435235e+00,
        9.3538827e-01, -3.9908871e-01,  4.8055208e-01, -7.1397781e-02,
        1.9529712e+00,  1.5586306e+00, -8.1507307e-01,  8.0080867e-01,
       -1.0662802e+00,  5.5932343e-02, -1.0604399e+00,  2.5609517e-01,
        2.8598475e-01, -1.2425861e+00,  6.0982054e-01, -1.9765290e+00,
        6.9338226e-01,  1.3598503e+00,  1.9679266e+00, -3.8258371e-01,
       -4.9964395e-01, -1.3429350e+00, -1.9517930e-02, -2.3001935e+00,
        1.2533377e+00, -1.6938232e-01,  1.0453185e+00,  1.1792245e+00,
        2.1922734e-01, -2.1123407e+00, -2.3792788e-01, -1.2083787e+00,
      

### Combine gene embeddings for every individual

In [38]:
# 检查模型是否有layers属性
if hasattr(model, 'layers'):
    layers = model.layers
elif hasattr(model, 'transformer_encoder'):
    layers = model.transformer_encoder.layers
else:
    raise AttributeError("The model does not have 'layers' or 'transformer_encoder.layers' attributes.")

# 遍历并打印层的名称
for i, layer in enumerate(layers):
    # 如果层有name属性，直接使用
    if hasattr(layer, 'name'):
        print(i, layer.name)
    else:
        # 否则，使用类名
        print(i, layer.__class__.__name__)


0 TransformerEncoderLayer
1 TransformerEncoderLayer
2 TransformerEncoderLayer
3 TransformerEncoderLayer
4 TransformerEncoderLayer
5 TransformerEncoderLayer
6 TransformerEncoderLayer
7 TransformerEncoderLayer
8 TransformerEncoderLayer
9 TransformerEncoderLayer
10 TransformerEncoderLayer
11 TransformerEncoderLayer


### Model