In [2]:
import os
os.environ['CUDA_HOME'] = '/usr/lib/cuda'
os.environ['CUDA_PATH'] = '/usr/lib/cuda/bin'

In [3]:
import torch
print(torch.cuda.is_available())
print(torch.cuda.device_count())

True
1


In [4]:
import copy
import json
import os
from pathlib import Path
import sys
import warnings
import torch
import pandas as pd
import numpy as np

from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.model import TransformerModel
from scgpt.utils import set_seed 
from scgpt.tasks import GeneEmbedding

warnings.filterwarnings('ignore')
set_seed(42)

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Specify paths
file_path = '../BEELINE-data/inputs/Curated/HSC/HSC-2000-1/ExpressionData.csv'
model_dir = Path("../models/scGPT_Blood_model/scGPT_bc")
model_config_file = model_dir / "args.json"
model_file = model_dir / "best_model.pt"
vocab_file = model_dir / "vocab.json"

In [8]:
# Load vocabulary
vocab = GeneVocab.from_file(vocab_file)
special_tokens = ["<pad>", "<cls>", "<eoc>"]
for s in special_tokens:
    if s not in vocab:
        vocab.append_token(s)

In [9]:
with open(model_config_file, "r") as f:
    model_configs = json.load(f)

In [10]:
# Initialize model
ntokens = len(vocab)
model = TransformerModel(
    ntokens,
    model_configs["embsize"],
    model_configs["nheads"],
    model_configs["d_hid"],
    model_configs["nlayers"],
    vocab=vocab,
    pad_value=model_configs.get("pad_value", -2),
    n_input_bins=model_configs.get("n_bins", 51),
)


In [11]:
# Load model weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict = torch.load(model_file, map_location=device)

# Filter out unexpected keys
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict and v.shape == model_dict[k].shape}
model_dict.update(pretrained_dict)

# Load the filtered state dict
model.load_state_dict(model_dict, strict=False)
model.to(device)
model.eval()

TransformerModel(
  (encoder): GeneEncoder(
    (embedding): Embedding(36574, 512, padding_idx=36571)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): ContinuousValueEncoder(
    (dropout): Dropout(p=0.5, inplace=False)
    (linear1): Linear(in_features=1, out_features=512, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, el

In [12]:
# Load and preprocess data
expression_data = pd.read_csv(file_path, index_col=0)
gene_names = expression_data.index.tolist()

# Tokenize genes
tokenized_genes = [vocab[gene] for gene in gene_names]

# Convert expression data to tensor
expression_tensor = torch.tensor(expression_data.values, dtype=torch.float32).to(device)

In [13]:
# Generate embeddings
gene_ids = np.array([vocab[gene] for gene in gene_names])
with torch.no_grad():
    gene_embeddings = model.encoder(torch.tensor(gene_ids, dtype=torch.long).to(device))
gene_embeddings = gene_embeddings.detach().cpu().numpy()

# Create a dictionary of gene embeddings
gene_embeddings_dict = {gene: gene_embeddings[i] for i, gene in enumerate(gene_names)}

# Construct gene embedding network
embed = GeneEmbedding(gene_embeddings_dict)

print(f'Retrieved gene embeddings for {len(gene_embeddings_dict)} genes.')

100%|██████████| 11/11 [00:00<00:00, 279620.27it/s]

Retrieved gene embeddings for 11 genes.





In [14]:
# Save embeddings to a file
embedding_df = pd.DataFrame(gene_embeddings, index=gene_names)
embedding_df.to_csv('gene_embeddings_HSC-2000-1.csv')

print("Gene embeddings have been generated and saved to 'gene_embeddings_HSC-2000-1.csv'")

Gene embeddings have been generated and saved to 'gene_embeddings_HSC-2000-1.csv'
