
PhGK/scGeneRAI


scGeneRAI

Please cite our paper

  from scGeneRAI import scGeneRAI
  

Initialize model and fit data

  model = scGeneRAI()
  model.fit(data, model_depth, nepochs, lr=2e-2, batch_size=5, lr_decay = 0.99, descriptors = None, early_stopping = True, device_name = 'cpu')
  
  • data: A pandas DataFrame of shape m x n containing RNA expression values for m cells (rows) and n genes (columns).

  • nepochs: Number of training epochs

  • model_depth: Depth of the model (default=2)

  • lr: Learning rate of the stochastic gradient descent optimizer (default=2e-2)

  • batch_size: Number of samples per training batch (default=5)

  • lr_decay: Learning rate decay applied with PyTorch's exponential learning rate scheduler (lr_decay corresponds to PyTorch's gamma, default=0.99)

  • descriptors: pandas DataFrame of additional categorical cell descriptors, e.g. batch or cell type. Must contain the same number of samples as data (default=None)

  • early_stopping: If True, scGeneRAI chooses the model with the smallest test loss during training (default=True).

  • device_name: Can be used to run the computation on a GPU (e.g. 'cuda:0', default='cpu'). device_name is passed to torch.device().
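A minimal sketch of preparing inputs for model.fit, assuming a toy expression matrix. The gene and cell names, the matrix size, and the `batch`/`cell_type` descriptor columns are hypothetical. The sketch shows the expected orientation (cells as rows, genes as columns) and a descriptors frame aligned row for row with data. Whether scGeneRAI needs matching index labels, rather than only equal row counts, is an assumption; aligning both is the safer choice. The hypothetical helper `decayed_lr` computes lr * lr_decay**epoch, which is what PyTorch's exponential scheduler yields when stepped once per epoch (that scGeneRAI steps it once per epoch is also an assumption).

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Hypothetical toy data: 100 cells (rows) x 20 genes (columns).
n_cells, n_genes = 100, 20
genes = [f"gene_{i}" for i in range(n_genes)]
cells = [f"cell_{i}" for i in range(n_cells)]
data = pd.DataFrame(
    rng.poisson(2.0, size=(n_cells, n_genes)).astype(float),
    index=cells,
    columns=genes,
)

# Hypothetical categorical descriptors: one row per cell, same index as data.
descriptors = pd.DataFrame(
    {
        "batch": rng.choice(["b1", "b2"], size=n_cells),
        "cell_type": rng.choice(["T", "B", "NK"], size=n_cells),
    },
    index=data.index,
)

# Sanity checks before fitting: same number of samples, same row order.
assert descriptors.shape[0] == data.shape[0]
assert descriptors.index.equals(data.index)


def decayed_lr(lr: float, lr_decay: float, epoch: int) -> float:
    """Learning rate after `epoch` steps of an exponential schedule (hypothetical helper)."""
    return lr * lr_decay ** epoch


# Learning rate at a few epochs for the documented defaults (lr=2e-2, lr_decay=0.99).
for epoch in (0, 10, 100):
    print(epoch, decayed_lr(2e-2, 0.99, epoch))

# With scGeneRAI installed, fitting would then look like this (not executed here):
# from scGeneRAI import scGeneRAI
# model = scGeneRAI()
# model.fit(data, 2, 20, lr=2e-2, batch_size=5, lr_decay=0.99,
#           descriptors=descriptors, early_stopping=True, device_name='cpu')
```

Keeping descriptors on the same index as data makes it hard to pass a frame whose rows are shuffled relative to the cells, a mistake that an equal row count alone would not catch.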

Predict networks

  model.predict_networks(data, descriptors = None, LRPau = True, remove_descriptors = True, device_name = 'cpu', PATH = '.')
  • data: A pandas DataFrame of shape k x n containing RNA expression values for k cells and n genes. The number of samples k is arbitrary, but the genes must exactly match those on which the model was fitted.

  • descriptors: pandas DataFrame of additional categorical cell descriptors, e.g. batch or cell type. Must contain the same number of samples as data (default=None). All descriptor features must also have been seen during model.fit(...).

  • LRPau: If True, predict_networks() returns the absolute undirected LRP scores between every pair of genes (default=True)

  • remove_descriptors: If True, interactions between data features and descriptor features are omitted (default=True)

  • device_name: Can be used to run the computation on a GPU (e.g. 'cuda:0', default='cpu'). device_name is passed to torch.device().

  • PATH: Directory in which the results will be saved (default='.')
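Because predict_networks expects exactly the genes the model was fitted on, and only descriptor features seen during model.fit(...), it can help to validate new inputs first. The sketch below uses two hypothetical helpers, `align_to_training_genes` and `check_descriptor_categories`; they are not part of scGeneRAI. The first selects the training genes in training order (dropping extra genes and raising if any are missing); the second raises if a descriptor column or category was not present at fit time. Treating "seen during fit" as "same column and category values" is an assumption.

```python
import pandas as pd


def align_to_training_genes(new_data: pd.DataFrame, training_genes) -> pd.DataFrame:
    """Return new_data restricted to the training genes, in training order.

    Raises ValueError if any training gene is missing; extra genes are dropped.
    """
    missing = [g for g in training_genes if g not in new_data.columns]
    if missing:
        raise ValueError(f"missing genes: {missing}")
    return new_data.loc[:, list(training_genes)]


def check_descriptor_categories(new_desc: pd.DataFrame, train_desc: pd.DataFrame) -> None:
    """Raise ValueError if new_desc has a column or category absent from train_desc."""
    for col in new_desc.columns:
        if col not in train_desc.columns:
            raise ValueError(f"unknown descriptor column: {col}")
        unseen = set(new_desc[col]) - set(train_desc[col])
        if unseen:
            raise ValueError(f"unseen categories in {col}: {sorted(unseen)}")


# Hypothetical example: new samples with shuffled gene order and one extra gene.
train_genes = ["gene_a", "gene_b", "gene_c"]
new = pd.DataFrame(
    [[1.0, 2.0, 3.0, 9.0]],
    columns=["gene_c", "gene_a", "gene_b", "gene_x"],
)
aligned = align_to_training_genes(new, train_genes)
print(list(aligned.columns))  # ['gene_a', 'gene_b', 'gene_c']

train_desc = pd.DataFrame({"batch": ["b1", "b2"]})
check_descriptor_categories(pd.DataFrame({"batch": ["b2"]}), train_desc)  # passes

# With a fitted model, prediction would then look like this (not executed here):
# model.predict_networks(aligned, descriptors=None, LRPau=True,
#                        remove_descriptors=True, device_name='cpu', PATH='.')
```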