# scFoundation: Preprocessing Code

In [3]:
def get_human_readable_size(bytes, decimal_places=2):
    """Convert a byte count to a human-readable string (e.g. ``"455.36 MB"``).

    Parameters:
        bytes (int | float): Size in bytes. (The name shadows the built-in
            ``bytes`` type; it is kept so keyword callers keep working.)
        decimal_places (int): Number of digits after the decimal point.

    Returns:
        str: The size scaled by powers of 1024, followed by its unit.
        Sizes of 1024 TB or more are reported in PB instead of falling off
        the end of the unit list and returning ``None``.
    """
    size = float(bytes)
    units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
    for unit in units[:-1]:
        # abs() lets negative sizes (e.g. memory deltas) scale like positive ones.
        if abs(size) < 1024:
            return f"{size:.{decimal_places}f} {unit}"
        size /= 1024
    # Anything that is still >= 1024 after TB is expressed in the largest unit.
    return f"{size:.{decimal_places}f} {units[-1]}"

In [4]:
# CUDA appears particularly finicky with this model: if torch.cuda.synchronize() does not run
# cleanly here, the model will not run later.
import torch
print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.cuda.current_device())
print(torch.cuda.get_device_name(torch.cuda.current_device()))
print(torch.cuda.synchronize())  # synchronize() returns None, so a successful call prints "None"

# Baseline GPU memory on device 0 as tracked by PyTorch's allocator.
device = torch.device('cuda:0')
print(f'Allocated Memory: {get_human_readable_size(torch.cuda.memory_allocated(device))}')
print(f'Reserved Memory: {get_human_readable_size(torch.cuda.memory_reserved(device))}')



True
3
0
Tesla V100-PCIE-16GB
None
Allocated Memory: 0.00 B
Reserved Memory: 0.00 B


In [5]:
import torch
import pandas as pd
import scanpy as sc
import sys 
import os
import numpy as np
import random
sys.path.append('/home/jpic/scFoundationProject/scFoundation/scFoundation/model')
from pretrainmodels import select_model
import math
from tqdm import tqdm

In [8]:
# GPU memory snapshot for device 0 (this cell follows the import cell).
device = torch.device('cuda:0')
print(f'Allocated Memory: {get_human_readable_size(torch.cuda.memory_allocated(device))}')
print(f'Reserved Memory: {get_human_readable_size(torch.cuda.memory_reserved(device))}')

Allocated Memory: 0.00 B
Reserved Memory: 0.00 B


## Preprocess Data

In [9]:
def main_gene_selection(X_df, gene_list):
    """
    Describe:
        Reorder the columns of an expression table to match `gene_list`,
        adding an all-zero column for every requested gene that is missing.
    Parameters:
        X_df->`pandas.DataFrame`: expression table with one column per gene
            (column labels are gene names)
        gene_list->list: wanted target genes, in the desired output column order
    Returns:
        X_df->`pandas.DataFrame`: table restricted to `gene_list`, in that order
        to_fill_columns->list: zero-padded genes, in `gene_list` order
        var->`pandas.DataFrame`: indexed by gene; column 'mask' is 1 for a
            zero-padded gene and 0 for a gene present in the input
    """
    present = set(X_df.columns)
    # dict.fromkeys de-duplicates while keeping gene_list order, so the padded
    # list is deterministic (a plain set difference has arbitrary order).
    to_fill_columns = [gene for gene in dict.fromkeys(gene_list) if gene not in present]

    padding_df = pd.DataFrame(np.zeros((X_df.shape[0], len(to_fill_columns))),
                              columns=to_fill_columns,
                              index=X_df.index)
    X_df = pd.DataFrame(np.concatenate([X_df.values, padding_df.values], axis=1),
                        index=X_df.index,
                        columns=list(X_df.columns) + to_fill_columns)
    X_df = X_df[gene_list]

    # Set membership keeps the mask construction linear in the number of genes.
    padded = set(to_fill_columns)
    var = pd.DataFrame(index=X_df.columns)
    var['mask'] = [1 if gene in padded else 0 for gene in var.index]
    return X_df, to_fill_columns, var

In [10]:
# Read the AnnData file and align its genes to the gene index file used below.
input_file = '/nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data_jpic.h5ad'
adata         = sc.read_h5ad(input_file)


# `df` is genes x cells (adata.X transposed, indexed by adata.var.index); it is transposed back
# when passed to main_gene_selection, so the frame handed over has adata.var.index as columns.
# NOTE(review): assumes adata.X is a dense array and adata.var.index holds gene symbols matching
# the 'gene_name' column — confirm; a scipy sparse matrix would need .toarray() first.
df = pd.DataFrame(adata.X.T, index=adata.var.index)
gene_list_df = pd.read_csv('/nfs/turbo/umms-indikar/shared/projects/foundation_models/scFoundation/scFoundation/OS_scRNA_gene_index.19264.tsv', header=0, delimiter='\t')
gene_list = list(gene_list_df['gene_name'])
X_df, to_fill_columns, var = main_gene_selection(df.T, gene_list)

In [11]:
# GPU memory snapshot for device 0 (this cell follows the data-loading cell).
device = torch.device('cuda:0')
print(f'Allocated Memory: {get_human_readable_size(torch.cuda.memory_allocated(device))}')
print(f'Reserved Memory: {get_human_readable_size(torch.cuda.memory_reserved(device))}')

Allocated Memory: 0.00 B
Reserved Memory: 0.00 B


## Load model

In [12]:
def setup(seed=0):
    """Seed every RNG used by the notebook and make cuDNN deterministic.

    Seeds Python's `random`, NumPy's global generator and PyTorch (CPU and all
    CUDA devices), sets cuDNN to deterministic mode with benchmarking off, and
    prints the device type that is available ('cuda' or 'cpu').

    Parameters:
        seed (int): Seed applied to all generators. Defaults to 0, the value
            that was previously hard-coded.

    Returns:
        None
    """
    random.seed(seed)
    np.random.seed(seed)  # numpy global random generator

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade speed for reproducible cuDNN kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    print(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

def convertconfig(ckpt):
    """Flatten a training checkpoint into the layout the model loader expects.

    Parameters:
        ckpt (dict): Checkpoint with
            ckpt['config']['model'] (model type name),
            ckpt['config']['model_config'][<model type>] (model settings),
            ckpt['config']['dataset_config']['rnaseq'] (dataset settings) and
            ckpt['state_dict'] (parameter name -> value).

    Returns:
        dict: {'config': <merged settings>, 'model_state_dict': <OrderedDict>}.
        Dataset settings override model settings on key clashes. The config
        also gets 'model_type' ('performergau_resolution' is renamed to
        'performer_gau'), 'pos_embed' = False and 'device' = 'cuda'.
        State-dict keys have everything up to and including the first
        'model.' removed.
    """
    import collections

    source_config = ckpt['config']
    model_type = source_config['model']

    config = {}
    config.update(source_config['model_config'][model_type])
    config.update(source_config['dataset_config']['rnaseq'])

    if model_type == 'performergau_resolution':
        model_type = 'performer_gau'

    state_dict = collections.OrderedDict()
    for key, val in ckpt['state_dict'].items():
        name = str(key)
        # Cut at the FIRST 'model.' only: split('model.')[1] truncated names
        # that contain 'model.' again further down (e.g. 'model.a.model.w')
        # and raised IndexError for names without the prefix, which are now
        # kept unchanged.
        _, separator, remainder = name.partition('model.')
        state_dict[remainder if separator else name] = val

    config['model_type'] = model_type
    config['pos_embed'] = False
    config['device'] = 'cuda'
    return {'config': config, 'model_state_dict': state_dict}

def gatherData(data, labels, pad_token_id):
    """
    Compacts the labelled entries of each row to the front and pads the rest.

    For every row, the entries of `data` whose label is True are kept in their
    original left-to-right order; rows with fewer labelled entries than the
    longest row are right-padded with `pad_token_id`.

    Parameters:
    data (torch.Tensor): 2-D input tensor (rows x positions).
    labels (torch.Tensor): Boolean tensor of the same shape marking the entries to keep.
    pad_token_id (int): The value used for padding.

    Returns:
    new_data (torch.Tensor): Tensor of shape (rows, max labelled entries per row).
    padding_labels (torch.Tensor): Boolean tensor, True where `new_data` equals `pad_token_id`.
        NOTE: padding is detected by value, so a kept entry that happens to equal
        `pad_token_id` is also reported as padding.
    """
    num_positions = labels.shape[1]
    # Longest row decides the output width; a plain int is used as the size.
    max_num = int(labels.sum(1).max())

    # Append `max_num` padding columns so every row can fill its output.
    fake_data = torch.full((data.shape[0], max_num), pad_token_id, device=data.device)
    data = torch.hstack([data, fake_data])
    fake_label = torch.full((labels.shape[0], max_num), 1, device=labels.device)

    # Unlabelled positions get -inf so top-k never selects them.
    none_labels = ~labels
    labels = labels.float()
    labels[none_labels] = float('-inf')

    # Strictly decreasing per-position weights ((n+1)*20000 ... 2*20000) make
    # top-k return labelled positions in their original order, ahead of the
    # padding columns, whose score is 1. Built with arange instead of a Python
    # list comprehension over every position on each call.
    tmp_data = torch.arange(num_positions + 1, 1, -1, device=labels.device) * 20000
    labels += tmp_data

    labels = torch.hstack([labels, fake_label])

    fake_label_gene_idx = labels.topk(max_num).indices
    new_data = torch.gather(data, 1, fake_label_gene_idx)

    padding_labels = (new_data == pad_token_id)

    return new_data, padding_labels

In [18]:
best_ckpt_path = '/nfs/turbo/umms-indikar/shared/projects/foundation_models/scFoundation/scFoundation/model/models/models.ckpt'
key = 'cell'

# NOTE: torch.load unpickles the checkpoint, which can execute arbitrary code;
# only load checkpoint files from a trusted source.
model_data = torch.load(best_ckpt_path, map_location='cpu')
model_data = model_data[key]
model_data = convertconfig(model_data)
if 'config' not in model_data:
    print('***** No config *****')
    config = {}
    config['model_type'] = 'flash_all'
else:
    config = model_data['config']
    print(config)
if 'qv_dim' not in config:
    # .get() keeps the fallback config (which has no 'model' key) from raising KeyError.
    if config.get('model') != 'mae_autobin':
        if 'dim_head' in config:
            config['qv_dim'] = config['dim_head']
        else:
            print('***** No qv_dim ***** set 64')
            config['qv_dim'] = 64
if 'ppi_edge' not in config:
    config['ppi_edge'] = None
model = select_model(config)
model_state_dict = model_data['model_state_dict']
model.load_state_dict(model_state_dict)
# Inference only: switch off dropout so embeddings are deterministic.
# torch.no_grad() disables gradient tracking but not dropout.
model.eval()
model.cuda()


{'mask_gene_name': False, 'gene_num': 19266, 'seq_len': 19266, 'encoder': {'hidden_dim': 768, 'depth': 12, 'heads': 12, 'dim_head': 64, 'seq_len': 19266, 'module_type': 'transformer', 'norm_first': False}, 'decoder': {'hidden_dim': 512, 'depth': 6, 'heads': 8, 'dim_head': 64, 'module_type': 'performer', 'seq_len': 19266, 'norm_first': False}, 'n_class': 104, 'pad_token_id': 103, 'mask_token_id': 102, 'bin_num': 100, 'bin_alpha': 1.0, 'rawcount': True, 'model': 'mae_autobin', 'test_valid_train_idx_dict': '/nfs_beijing/minsheng/data/os10000w-new/global_shuffle/meta.csv.train_set_idx_dict.pt', 'valid_data_path': '/nfs_beijing/minsheng/data/valid_count_10w.npz', 'num_tokens': 13, 'train_data_path': None, 'isPanA': False, 'isPlanA1': False, 'max_files_to_load': 5, 'bin_type': 'auto_bin', 'value_mask_prob': 0.3, 'zero_mask_prob': 0.03, 'replace_prob': 0.8, 'random_token_prob': 0.1, 'mask_ignore_token_ids': [0], 'decoder_add_zero': True, 'mae_encoder_max_seq_len': 15000, 'isPlanA': False, 'ma

MaeAutobin(
  (token_emb): AutoDiscretizationEmbedding2(
    (mlp): Linear(in_features=1, out_features=100, bias=True)
    (mlp2): Linear(in_features=100, out_features=100, bias=True)
    (LeakyReLU): LeakyReLU(negative_slope=0.1)
    (Softmax): Softmax(dim=-1)
    (emb): Embedding(100, 768)
    (emb_mask): Embedding(1, 768)
    (emb_pad): Embedding(1, 768)
  )
  (pos_emb): Embedding(19267, 768)
  (decoder_embed): Linear(in_features=768, out_features=512, bias=True)
  (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (to_final): Linear(in_features=512, out_features=1, bias=True)
  (encoder): pytorchTransformerModule(
    (transformer_encoder): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (linear1): Linear(in_features=768, out_features=3072, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
     

In [19]:
# GPU memory snapshot for device 0 (this cell follows the model-loading cell, which calls model.cuda()).
device = torch.device('cuda:0')
print(f'Allocated Memory: {get_human_readable_size(torch.cuda.memory_allocated(device))}')
print(f'Reserved Memory: {get_human_readable_size(torch.cuda.memory_reserved(device))}')

Allocated Memory: 455.36 MB
Reserved Memory: 504.00 MB


In [14]:
import numpy as np
import torch
from tqdm import tqdm

def _forward_encoder_decoder(pretrainmodel, pretrainconfig, getEncoerDecoderData, encoder_input, decoder_input):
    """Build encoder/decoder inputs with `getEncoerDecoderData` and run the model's forward pass."""
    (encoder_data, encoder_position_gene_ids, encoder_data_padding, encoder_labels,
     decoder_data, decoder_data_padding, new_data_raw, data_mask_labels,
     decoder_position_gene_ids) = getEncoerDecoderData(encoder_input, decoder_input, pretrainconfig)
    return pretrainmodel.forward(x=encoder_data, padding_label=encoder_data_padding,
                                 encoder_position_gene_ids=encoder_position_gene_ids,
                                 encoder_labels=encoder_labels,
                                 decoder_data=decoder_data,
                                 mask_gene_name=False,
                                 mask_labels=None,
                                 decoder_position_gene_ids=decoder_position_gene_ids,
                                 decoder_data_padding_labels=decoder_data_padding)


def embed(gexpr_feature, input_type='singlecell', pre_normalized='T', tgthighres='f0.5', output_type='cell', pool_type='all', 
          pretrainmodel=None, pretrainconfig=None, gatherData=None, getEncoerDecoderData=None, strname='output.npy'):
    """
    Embeds gene expression data using a pre-trained model, one row at a time.

    Each row of `gexpr_feature` is extended with two trailing values (a target
    resolution token and a total-count token), moved to the GPU with `.cuda()`
    (a CUDA device is required) and passed through the model without gradients.
    The function does not switch the model to eval mode; do that before calling.

    Parameters:
    gexpr_feature (DataFrame): The gene expression feature data, one row per sample.
        Assumes 19264 gene columns (plus a trailing total-count column when
        pre_normalized is 'A') so that each extended row has 19266 entries — TODO confirm.
    input_type (str): Type of input data ('bulk' or 'singlecell'). Default is 'singlecell'.
    pre_normalized (str): 'T' uses the values as given; 'F' applies log1p(x / sum * 1e4)
        for single-cell input (for bulk input it only changes the total-count token to
        log10 of the row sum); 'A' (single-cell only) treats the last column as the total
        count. Default is 'T'.
    tgthighres (str): Target resolution for single-cell input: 'f<x>' = log10(total * x),
        'a<x>' = log10(total) + x, 't<x>' = x. Default is 'f0.5'.
    output_type (str): Type of output embedding ('cell', 'gene', 'gene_batch', 'gene_expression'). Default is 'cell'.
    pool_type (str): Pooling for 'cell' output: 'all' concatenates the last token, the
        second-to-last token, and the max and mean over the remaining tokens; 'max' takes
        the max over all tokens. Default is 'all'.
    pretrainmodel (torch.nn.Module): The pre-trained model used for embedding.
    pretrainconfig (dict): Configuration dictionary for the pre-trained model ('pad_token_id' is read for 'cell' output).
    gatherData (function): Function to gather data for the model (used for 'cell' output).
    getEncoerDecoderData (function): Function to get encoder-decoder data (used for the other output types).
    strname (str): Unused; kept for backward compatibility. Nothing is written to disk.

    Returns:
    numpy.ndarray: The embeddings stacked over rows, with singleton dimensions squeezed out.

    Raises:
    ValueError: If input_type, pre_normalized, tgthighres, output_type or pool_type is not
        one of the supported values. Arguments are checked before any GPU work starts.
    """
    # --- argument validation, done once instead of on every iteration ---
    if input_type == 'bulk':
        if pre_normalized not in ('T', 'F'):
            raise ValueError('pre_normalized must be T or F')
    elif input_type == 'singlecell':
        if pre_normalized not in ('T', 'F', 'A'):
            raise ValueError('pre_normalized must be T, F, or A')
        if tgthighres[:1] not in ('f', 'a', 't'):
            raise ValueError('tgthighres must start with f, a, or t')
    else:
        # An unknown input_type previously fell through to a NameError.
        raise ValueError('input_type must be bulk or singlecell')
    if output_type not in ('cell', 'gene', 'gene_batch', 'gene_expression'):
        raise ValueError('output_type must be cell, gene, gene_batch, or gene_expression')
    if output_type == 'cell' and pool_type not in ('all', 'max'):
        raise ValueError('pool_type must be all or max')

    n_genes = 19264          # gene positions kept from the model output
    seq_len = n_genes + 2    # genes plus the two appended tokens
    n_rows = gexpr_feature.shape[0]
    resolution = float(tgthighres[1:]) if input_type == 'singlecell' else None

    geneexpemb = []
    batchcontainer = []

    # 'gene' and 'gene_batch' run the model with its `to_final` head removed.
    # Remember the head so the caller's model is restored afterwards instead of
    # being left permanently modified.
    drops_head = output_type in ('gene', 'gene_batch')
    saved_to_final = getattr(pretrainmodel, 'to_final', None) if drops_head else None

    try:
        if drops_head:
            pretrainmodel.to_final = None

        with torch.no_grad():
            for i in tqdm(range(n_rows)):
                row = gexpr_feature.iloc[i, :]

                if input_type == 'bulk':
                    totalcount = row.sum() if pre_normalized == 'T' else np.log10(row.sum())
                    tmpdata = row.tolist()
                    extra_tokens = [totalcount, totalcount]
                else:
                    if pre_normalized == 'F':
                        tmpdata = np.log1p(row / row.sum() * 1e4).tolist()
                        totalcount = row.sum()
                    elif pre_normalized == 'T':
                        tmpdata = row.tolist()
                        totalcount = row.sum()
                    else:  # 'A': last column carries the total count
                        tmpdata = row.iloc[:-1].tolist()
                        totalcount = row.iloc[-1]

                    log_total = np.log10(totalcount)
                    if tgthighres[0] == 'f':
                        target = np.log10(totalcount * resolution)
                    elif tgthighres[0] == 'a':
                        target = log_total + resolution
                    else:  # 't'
                        target = resolution
                    extra_tokens = [target, log_total]

                pretrain_gene_x = torch.tensor(tmpdata + extra_tokens).unsqueeze(0).cuda()

                if output_type == 'cell':
                    pad_token_id = pretrainconfig['pad_token_id']
                    value_labels = pretrain_gene_x > 0
                    data_gene_ids = torch.arange(seq_len, device=pretrain_gene_x.device).repeat(pretrain_gene_x.shape[0], 1)

                    # Keep only the positive entries and their position ids.
                    x, x_padding = gatherData(pretrain_gene_x, value_labels, pad_token_id)
                    position_gene_ids, _ = gatherData(data_gene_ids, value_labels, pad_token_id)

                    x = pretrainmodel.token_emb(torch.unsqueeze(x, 2).float(), output_weight=0)
                    x += pretrainmodel.pos_emb(position_gene_ids)
                    geneemb = pretrainmodel.encoder(x, x_padding)

                    if pool_type == 'all':
                        geneemb1 = geneemb[:, -1, :]
                        geneemb2 = geneemb[:, -2, :]
                        geneemb3, _ = torch.max(geneemb[:, :-2, :], dim=1)
                        geneemb4 = torch.mean(geneemb[:, :-2, :], dim=1)
                        geneembmerge = torch.concat([geneemb1, geneemb2, geneemb3, geneemb4], axis=1)
                    else:  # 'max'
                        geneembmerge, _ = torch.max(geneemb, dim=1)
                    geneexpemb.append(geneembmerge.detach().cpu().numpy())

                elif output_type == 'gene':
                    out = _forward_encoder_decoder(pretrainmodel, pretrainconfig, getEncoerDecoderData,
                                                   pretrain_gene_x.float(), pretrain_gene_x.float())
                    out = out[:, :n_genes, :].contiguous()
                    geneexpemb.append(out.detach().cpu().numpy())

                elif output_type == 'gene_batch':
                    # Collect every row first; the single forward pass happens on the last row.
                    batchcontainer.append(pretrain_gene_x.float())
                    if len(batchcontainer) < n_rows:
                        continue
                    batch = torch.concat(batchcontainer, axis=0)
                    out = _forward_encoder_decoder(pretrainmodel, pretrainconfig, getEncoerDecoderData,
                                                   batch, batch)
                    geneexpemb = out[:, :n_genes, :].contiguous().detach().cpu().numpy()

                else:  # 'gene_expression'
                    out = _forward_encoder_decoder(pretrainmodel, pretrainconfig, getEncoerDecoderData,
                                                   pretrain_gene_x.float(), pretrain_gene_x.float())
                    out = out[:, :n_genes].contiguous()
                    geneexpemb.append(out.detach().cpu().numpy())
    finally:
        if drops_head:
            pretrainmodel.to_final = saved_to_final

    geneexpemb = np.squeeze(np.array(geneexpemb))
    return geneexpemb


In [20]:
# GPU memory snapshot for device 0 (this cell follows the `embed` definition cell).
device = torch.device('cuda:0')
print(f'Allocated Memory: {get_human_readable_size(torch.cuda.memory_allocated(device))}')
print(f'Reserved Memory: {get_human_readable_size(torch.cuda.memory_reserved(device))}')

Allocated Memory: 455.36 MB
Reserved Memory: 504.00 MB


In [21]:
# Seed the Python/NumPy/PyTorch RNGs and set the cuDNN flags; prints the available device type.
setup()

cuda


In [10]:
# Wait for pending GPU work, then release cached allocator blocks that are not in use.
# This cannot free memory held by live tensors such as the loaded model.
torch.cuda.synchronize()
torch.cuda.empty_cache()
# Disabled experiments kept for reference:
# torch.cuda.reset_peak_memory_stats()
# torch.cuda.synchronize()

In [22]:
# GPU memory snapshot for device 0 (this cell follows the synchronize()/empty_cache() cell).
device = torch.device('cuda:0')
print(f'Allocated Memory: {get_human_readable_size(torch.cuda.memory_allocated(device))}')
print(f'Reserved Memory: {get_human_readable_size(torch.cuda.memory_reserved(device))}')

Allocated Memory: 455.36 MB
Reserved Memory: 504.00 MB


In [23]:
# Requires `embed`, `model`, `config`, `gatherData` and `X_df` to be defined already.
# A copy of the `embed` definition also sits in the cell below this one; run a definition cell first.
# NOTE: the recorded run of this cell ended in a CUDA out-of-memory error on the first row.
embeddings = embed(X_df, input_type='singlecell', pre_normalized='T', tgthighres='f0.5', output_type='cell', pool_type='all', 
          pretrainmodel=model, pretrainconfig=config, gatherData=gatherData, getEncoerDecoderData=None, strname='output.npy')

  0%|          | 0/66 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 8.58 GiB. GPU 0 has a total capacty of 15.77 GiB of which 6.10 GiB is free. Including non-PyTorch memory, this process has 9.66 GiB memory in use. Of the allocated memory 707.76 MiB is allocated by PyTorch, and 8.61 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [24]:
# GPU memory snapshot for device 0 (this cell follows the `embed` call).
device = torch.device('cuda:0')
print(f'Allocated Memory: {get_human_readable_size(torch.cuda.memory_allocated(device))}')
print(f'Reserved Memory: {get_human_readable_size(torch.cuda.memory_reserved(device))}')

Allocated Memory: 707.76 MB
Reserved Memory: 9.30 GB


In [14]:
import numpy as np
import torch
from tqdm import tqdm

def _forward_encoder_decoder(pretrainmodel, pretrainconfig, getEncoerDecoderData, encoder_input, decoder_input):
    """Build encoder/decoder inputs with `getEncoerDecoderData` and run the model's forward pass."""
    (encoder_data, encoder_position_gene_ids, encoder_data_padding, encoder_labels,
     decoder_data, decoder_data_padding, new_data_raw, data_mask_labels,
     decoder_position_gene_ids) = getEncoerDecoderData(encoder_input, decoder_input, pretrainconfig)
    return pretrainmodel.forward(x=encoder_data, padding_label=encoder_data_padding,
                                 encoder_position_gene_ids=encoder_position_gene_ids,
                                 encoder_labels=encoder_labels,
                                 decoder_data=decoder_data,
                                 mask_gene_name=False,
                                 mask_labels=None,
                                 decoder_position_gene_ids=decoder_position_gene_ids,
                                 decoder_data_padding_labels=decoder_data_padding)


def embed(gexpr_feature, input_type='singlecell', pre_normalized='T', tgthighres='f0.5', output_type='cell', pool_type='all', 
          pretrainmodel=None, pretrainconfig=None, gatherData=None, getEncoerDecoderData=None, strname='output.npy'):
    """
    Embeds gene expression data using a pre-trained model, one row at a time.

    Each row of `gexpr_feature` is extended with two trailing values (a target
    resolution token and a total-count token), moved to the GPU with `.cuda()`
    (a CUDA device is required) and passed through the model without gradients.
    The function does not switch the model to eval mode; do that before calling.

    Parameters:
    gexpr_feature (DataFrame): The gene expression feature data, one row per sample.
        Assumes 19264 gene columns (plus a trailing total-count column when
        pre_normalized is 'A') so that each extended row has 19266 entries — TODO confirm.
    input_type (str): Type of input data ('bulk' or 'singlecell'). Default is 'singlecell'.
    pre_normalized (str): 'T' uses the values as given; 'F' applies log1p(x / sum * 1e4)
        for single-cell input (for bulk input it only changes the total-count token to
        log10 of the row sum); 'A' (single-cell only) treats the last column as the total
        count. Default is 'T'.
    tgthighres (str): Target resolution for single-cell input: 'f<x>' = log10(total * x),
        'a<x>' = log10(total) + x, 't<x>' = x. Default is 'f0.5'.
    output_type (str): Type of output embedding ('cell', 'gene', 'gene_batch', 'gene_expression'). Default is 'cell'.
    pool_type (str): Pooling for 'cell' output: 'all' concatenates the last token, the
        second-to-last token, and the max and mean over the remaining tokens; 'max' takes
        the max over all tokens. Default is 'all'.
    pretrainmodel (torch.nn.Module): The pre-trained model used for embedding.
    pretrainconfig (dict): Configuration dictionary for the pre-trained model ('pad_token_id' is read for 'cell' output).
    gatherData (function): Function to gather data for the model (used for 'cell' output).
    getEncoerDecoderData (function): Function to get encoder-decoder data (used for the other output types).
    strname (str): Unused; kept for backward compatibility. Nothing is written to disk.

    Returns:
    numpy.ndarray: The embeddings stacked over rows, with singleton dimensions squeezed out.

    Raises:
    ValueError: If input_type, pre_normalized, tgthighres, output_type or pool_type is not
        one of the supported values. Arguments are checked before any GPU work starts.
    """
    # --- argument validation, done once instead of on every iteration ---
    if input_type == 'bulk':
        if pre_normalized not in ('T', 'F'):
            raise ValueError('pre_normalized must be T or F')
    elif input_type == 'singlecell':
        if pre_normalized not in ('T', 'F', 'A'):
            raise ValueError('pre_normalized must be T, F, or A')
        if tgthighres[:1] not in ('f', 'a', 't'):
            raise ValueError('tgthighres must start with f, a, or t')
    else:
        # An unknown input_type previously fell through to a NameError.
        raise ValueError('input_type must be bulk or singlecell')
    if output_type not in ('cell', 'gene', 'gene_batch', 'gene_expression'):
        raise ValueError('output_type must be cell, gene, gene_batch, or gene_expression')
    if output_type == 'cell' and pool_type not in ('all', 'max'):
        raise ValueError('pool_type must be all or max')

    n_genes = 19264          # gene positions kept from the model output
    seq_len = n_genes + 2    # genes plus the two appended tokens
    n_rows = gexpr_feature.shape[0]
    resolution = float(tgthighres[1:]) if input_type == 'singlecell' else None

    geneexpemb = []
    batchcontainer = []

    # 'gene' and 'gene_batch' run the model with its `to_final` head removed.
    # Remember the head so the caller's model is restored afterwards instead of
    # being left permanently modified.
    drops_head = output_type in ('gene', 'gene_batch')
    saved_to_final = getattr(pretrainmodel, 'to_final', None) if drops_head else None

    try:
        if drops_head:
            pretrainmodel.to_final = None

        with torch.no_grad():
            for i in tqdm(range(n_rows)):
                row = gexpr_feature.iloc[i, :]

                if input_type == 'bulk':
                    totalcount = row.sum() if pre_normalized == 'T' else np.log10(row.sum())
                    tmpdata = row.tolist()
                    extra_tokens = [totalcount, totalcount]
                else:
                    if pre_normalized == 'F':
                        tmpdata = np.log1p(row / row.sum() * 1e4).tolist()
                        totalcount = row.sum()
                    elif pre_normalized == 'T':
                        tmpdata = row.tolist()
                        totalcount = row.sum()
                    else:  # 'A': last column carries the total count
                        tmpdata = row.iloc[:-1].tolist()
                        totalcount = row.iloc[-1]

                    log_total = np.log10(totalcount)
                    if tgthighres[0] == 'f':
                        target = np.log10(totalcount * resolution)
                    elif tgthighres[0] == 'a':
                        target = log_total + resolution
                    else:  # 't'
                        target = resolution
                    extra_tokens = [target, log_total]

                pretrain_gene_x = torch.tensor(tmpdata + extra_tokens).unsqueeze(0).cuda()

                if output_type == 'cell':
                    pad_token_id = pretrainconfig['pad_token_id']
                    value_labels = pretrain_gene_x > 0
                    data_gene_ids = torch.arange(seq_len, device=pretrain_gene_x.device).repeat(pretrain_gene_x.shape[0], 1)

                    # Keep only the positive entries and their position ids.
                    x, x_padding = gatherData(pretrain_gene_x, value_labels, pad_token_id)
                    position_gene_ids, _ = gatherData(data_gene_ids, value_labels, pad_token_id)

                    x = pretrainmodel.token_emb(torch.unsqueeze(x, 2).float(), output_weight=0)
                    x += pretrainmodel.pos_emb(position_gene_ids)
                    geneemb = pretrainmodel.encoder(x, x_padding)

                    if pool_type == 'all':
                        geneemb1 = geneemb[:, -1, :]
                        geneemb2 = geneemb[:, -2, :]
                        geneemb3, _ = torch.max(geneemb[:, :-2, :], dim=1)
                        geneemb4 = torch.mean(geneemb[:, :-2, :], dim=1)
                        geneembmerge = torch.concat([geneemb1, geneemb2, geneemb3, geneemb4], axis=1)
                    else:  # 'max'
                        geneembmerge, _ = torch.max(geneemb, dim=1)
                    geneexpemb.append(geneembmerge.detach().cpu().numpy())

                elif output_type == 'gene':
                    out = _forward_encoder_decoder(pretrainmodel, pretrainconfig, getEncoerDecoderData,
                                                   pretrain_gene_x.float(), pretrain_gene_x.float())
                    out = out[:, :n_genes, :].contiguous()
                    geneexpemb.append(out.detach().cpu().numpy())

                elif output_type == 'gene_batch':
                    # Collect every row first; the single forward pass happens on the last row.
                    batchcontainer.append(pretrain_gene_x.float())
                    if len(batchcontainer) < n_rows:
                        continue
                    batch = torch.concat(batchcontainer, axis=0)
                    out = _forward_encoder_decoder(pretrainmodel, pretrainconfig, getEncoerDecoderData,
                                                   batch, batch)
                    geneexpemb = out[:, :n_genes, :].contiguous().detach().cpu().numpy()

                else:  # 'gene_expression'
                    out = _forward_encoder_decoder(pretrainmodel, pretrainconfig, getEncoerDecoderData,
                                                   pretrain_gene_x.float(), pretrain_gene_x.float())
                    out = out[:, :n_genes].contiguous()
                    geneexpemb.append(out.detach().cpu().numpy())
    finally:
        if drops_head:
            pretrainmodel.to_final = saved_to_final

    geneexpemb = np.squeeze(np.array(geneexpemb))
    return geneexpemb


## Helper Code

In [3]:
def setup(seed=0):
    """Seed every RNG used by the notebook and make cuDNN deterministic.

    Seeds Python's `random`, NumPy's global generator and PyTorch (CPU and all
    CUDA devices), sets cuDNN to deterministic mode with benchmarking off, and
    prints the device type that is available ('cuda' or 'cpu').

    Parameters:
        seed (int): Seed applied to all generators. Defaults to 0, the value
            that was previously hard-coded.

    Returns:
        None
    """
    random.seed(seed)
    np.random.seed(seed)  # numpy global random generator

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade speed for reproducible cuDNN kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    print(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

def convertconfig(ckpt):
    """Flatten a training checkpoint into the layout the model loader expects.

    Parameters:
        ckpt (dict): Checkpoint with
            ckpt['config']['model'] (model type name),
            ckpt['config']['model_config'][<model type>] (model settings),
            ckpt['config']['dataset_config']['rnaseq'] (dataset settings) and
            ckpt['state_dict'] (parameter name -> value).

    Returns:
        dict: {'config': <merged settings>, 'model_state_dict': <OrderedDict>}.
        Dataset settings override model settings on key clashes. The config
        also gets 'model_type' ('performergau_resolution' is renamed to
        'performer_gau'), 'pos_embed' = False and 'device' = 'cuda'.
        State-dict keys have everything up to and including the first
        'model.' removed.
    """
    import collections

    source_config = ckpt['config']
    model_type = source_config['model']

    config = {}
    config.update(source_config['model_config'][model_type])
    config.update(source_config['dataset_config']['rnaseq'])

    if model_type == 'performergau_resolution':
        model_type = 'performer_gau'

    state_dict = collections.OrderedDict()
    for key, val in ckpt['state_dict'].items():
        name = str(key)
        # Cut at the FIRST 'model.' only: split('model.')[1] truncated names
        # that contain 'model.' again further down (e.g. 'model.a.model.w')
        # and raised IndexError for names without the prefix, which are now
        # kept unchanged.
        _, separator, remainder = name.partition('model.')
        state_dict[remainder if separator else name] = val

    config['model_type'] = model_type
    config['pos_embed'] = False
    config['device'] = 'cuda'
    return {'config': config, 'model_state_dict': state_dict}

def loaddata(data_path, verbose=True, pre_normalized='T', input_type='singlecell', demo=False):
    """Load an expression matrix from disk into a DataFrame.

    Parameters:
        data_path (str): Path to the data. The suffix selects the reader:
            'npz' (scipy sparse matrix), 'h5ad' (AnnData; columns are taken
            from `var.gene_name`), 'npy' (NumPy array); anything else is read
            as CSV with the first column as index.
        verbose (bool): Print the final shape.
        pre_normalized (str): With 'F' and input_type 'bulk', the data is
            total-count normalized and log1p-transformed with scanpy.
        input_type (str): 'bulk' or 'singlecell'; only used together with
            `pre_normalized` as described above.
        demo (bool): Keep only the first 10 rows.

    Returns:
        pandas.DataFrame: The loaded table. If it has fewer than 19264
        columns it is passed through `main_gene_selection` with the
        notebook-level `gene_list`, which must both already be defined.
    """
    # Imported here so the function does not depend on a notebook-level scipy import.
    from scipy.sparse import issparse, load_npz

    #Load data
    if data_path.endswith('npz'):
        gexpr_feature = pd.DataFrame(load_npz(data_path).toarray())
    elif data_path.endswith('h5ad'):
        adata_in = sc.read_h5ad(data_path)
        idx = adata_in.obs_names.tolist()
        col = adata_in.var.gene_name.tolist()
        # Use the matrix in both cases; the dense branch used to pass the
        # AnnData object itself to pd.DataFrame.
        matrix = adata_in.X.toarray() if issparse(adata_in.X) else adata_in.X
        gexpr_feature = pd.DataFrame(matrix, index=idx, columns=col)
    elif data_path.endswith('npy'):
        gexpr_feature = pd.DataFrame(np.load(data_path))
    else:
        gexpr_feature = pd.read_csv(data_path, index_col=0)

    if gexpr_feature.shape[1] < 19264:
        print('covert gene feature into 19264')
        gexpr_feature, to_fill_columns, var = main_gene_selection(gexpr_feature, gene_list)
        assert gexpr_feature.shape[1] >= 19264

    if (pre_normalized == 'F') and (input_type == 'bulk'):
        adata = sc.AnnData(gexpr_feature)
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)
        gexpr_feature = pd.DataFrame(adata.X, index=adata.obs_names, columns=adata.var_names)

    if demo:
        gexpr_feature = gexpr_feature.iloc[:10, :]
    if verbose:
        print(f'data.shape={gexpr_feature.shape}')
    return gexpr_feature

def gatherData(data, labels, pad_token_id):
    """
    Compacts the labelled entries of each row to the front and pads the rest.

    For every row, the entries of `data` whose label is True are kept in their
    original left-to-right order; rows with fewer labelled entries than the
    longest row are right-padded with `pad_token_id`.

    Parameters:
    data (torch.Tensor): 2-D input tensor (rows x positions).
    labels (torch.Tensor): Boolean tensor of the same shape marking the entries to keep.
    pad_token_id (int): The value used for padding.

    Returns:
    new_data (torch.Tensor): Tensor of shape (rows, max labelled entries per row).
    padding_labels (torch.Tensor): Boolean tensor, True where `new_data` equals `pad_token_id`.
        NOTE: padding is detected by value, so a kept entry that happens to equal
        `pad_token_id` is also reported as padding.
    """
    num_positions = labels.shape[1]
    # Longest row decides the output width; a plain int is used as the size.
    max_num = int(labels.sum(1).max())

    # Append `max_num` padding columns so every row can fill its output.
    fake_data = torch.full((data.shape[0], max_num), pad_token_id, device=data.device)
    data = torch.hstack([data, fake_data])
    fake_label = torch.full((labels.shape[0], max_num), 1, device=labels.device)

    # Unlabelled positions get -inf so top-k never selects them.
    none_labels = ~labels
    labels = labels.float()
    labels[none_labels] = float('-inf')

    # Strictly decreasing per-position weights ((n+1)*20000 ... 2*20000) make
    # top-k return labelled positions in their original order, ahead of the
    # padding columns, whose score is 1. Built with arange instead of a Python
    # list comprehension over every position on each call.
    tmp_data = torch.arange(num_positions + 1, 1, -1, device=labels.device) * 20000
    labels += tmp_data

    labels = torch.hstack([labels, fake_label])

    fake_label_gene_idx = labels.topk(max_num).indices
    new_data = torch.gather(data, 1, fake_label_gene_idx)

    padding_labels = (new_data == pad_token_id)

    return new_data, padding_labels


In [4]:
# Example input (.npy count matrix); demo=True keeps only the first 10 rows.
DATAPATH = '/nfs/turbo/umms-indikar/shared/projects/foundation_models/example_inputs/scFoundation/cell_type_rawdata/zheng/data_test_count.npy'
data = loaddata(DATAPATH, demo=True)

data.shape=(10, 19264)


In [8]:
# Load the same file directly as a NumPy array and display it for inspection.
gexpr_feature = np.load(DATAPATH)
gexpr_feature

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 2., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])

## Set Data & Parameters

## Preprocess Data