# **SIMS: Jupyter notebook for simple training and inference**

This notebook helps you run SIMS for basic model training or for cell type inference with our pretrained model checkpoints. We recommend running it in GitHub Codespaces. For a step-by-step SIMS tutorial with example data, use [this link](https://colab.research.google.com/drive/1UrsNTrd-JYRpg1MMQSLQhT6OHcSkloIX). For more advanced training customization, see the SIMS GitHub repository for API usage instructions: https://github.com/braingeneers/SIMS

## Getting started

To get started, drag and drop your `.h5ad` files into the `codespaces` folder for easy access.

Locate and run the `setup_sims_env.sh` file within this directory to set up your virtual environment.

- In your terminal, execute the following commands:

    `chmod +x setup_sims_env.sh`

    `./setup_sims_env.sh` 

- This ensures that SIMS runs with the correct software dependencies. 


After the script finishes, select the newly created `sims_env` environment as the kernel for this notebook. The environment must use Python 3.9.
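
To confirm that the notebook is really running on the `sims_env` kernel, you can check the interpreter version from a cell. This is a minimal sketch; `check_python_version` is a hypothetical helper (not part of SIMS) that only compares the major and minor version numbers.

```python
import sys


def check_python_version(required=(3, 9), current=None):
    """Return True if the running (or given) Python major.minor matches `required`."""
    # Default to the interpreter that is executing this notebook
    if current is None:
        current = sys.version_info[:2]
    matches = tuple(current[:2]) == tuple(required)
    if not matches:
        print(f"Warning: expected Python {required[0]}.{required[1]}, "
              f"found {current[0]}.{current[1]}. Check the notebook kernel.")
    return matches


check_python_version()
```

If it prints a warning, switch the kernel to `sims_env` and rerun the cell.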

Next, import the following libraries:

In [None]:
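from scsims import SIMS  # SIMS model: training, prediction, and explainability
import pandas as pd  # tables for predictions and gene-importance scores
import anndata as an  # reading .h5ad files
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint  # optional, for customized training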

-----------------

## Training

**Note:** We recommend using pretrained models with this notebook, but we have also included instructions for training your own model. If you want to train a model using a labeled dataset, follow this section. If you only need to make predictions using a pretrained model checkpoint from our checkpoint folder, **please skip to the Inference section.**

#### Load training data and initialize SIMS model

- Don't forget to replace the file paths and class label!

- Your training data should be located in this folder (`codespaces`), assuming you uploaded your `.h5ad` files here for easy access. 
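
If you are unsure which path to use, you can list the `.h5ad` files the notebook can see. This is a minimal sketch using only the standard library, assuming the notebook's working directory is the `codespaces` folder; `list_h5ad_files` is a hypothetical helper, not part of SIMS.

```python
from pathlib import Path


def list_h5ad_files(folder="."):
    """Return the names of the .h5ad files directly inside `folder`, sorted."""
    return sorted(p.name for p in Path(folder).glob("*.h5ad"))


# Names printed here are relative to the working directory and can be used as file paths below
print(list_h5ad_files())
```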

In [None]:
labeled_data_path = "my_labeled_data.h5ad"  # Replace with the labeled .h5ad file to train on. Use a path relative
                                            # to this folder (`codespaces`); a leading "/" would point to the
                                            # filesystem root instead.

train_data = an.read_h5ad(labeled_data_path)

sims = SIMS(data=train_data, class_label='class_label')  # Replace 'class_label' with the .obs column you want to predict
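
Before training, it can help to confirm that the column passed as `class_label` exists in `train_data.obs` and to see how many cells fall into each class. This is a minimal pandas sketch; `check_class_label` is a hypothetical helper, and in this notebook you would call it as `check_class_label(train_data.obs, 'class_label')`.

```python
import pandas as pd


def check_class_label(obs: pd.DataFrame, class_label: str) -> pd.Series:
    """Raise if `class_label` is not an .obs column; otherwise return cells per class."""
    if class_label not in obs.columns:
        raise KeyError(
            f"'{class_label}' not found in .obs; available columns: {list(obs.columns)}"
        )
    labels = obs[class_label]
    # Unlabeled cells are reported but not counted as a class
    n_missing = int(labels.isna().sum())
    if n_missing:
        print(f"Warning: {n_missing} cells have no value for '{class_label}'.")
    # Number of cells per class, largest first
    return labels.value_counts()


# Usage in this notebook (train_data is the AnnData loaded above):
# print(check_class_label(train_data.obs, 'class_label'))
```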

**Note:** For more customized training configurations, see the [Google Colab tutorial](https://colab.research.google.com/drive/1UrsNTrd-JYRpg1MMQSLQhT6OHcSkloIX#scrollTo=kIYU37mymllr), which walks through configuring the training process with sample data. 

#### Begin training

- Checkpoints are saved automatically to a `lightning_logs` directory in the current working directory.

In [None]:
sims.train()  # Train the model
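
After training finishes, you can pass the resulting checkpoint to the Inference section below. This minimal sketch picks the most recently modified `.ckpt` file under `lightning_logs`, assuming PyTorch Lightning's default layout (`lightning_logs/version_N/checkpoints/*.ckpt`); `latest_checkpoint` is a hypothetical helper.

```python
from pathlib import Path


def latest_checkpoint(log_dir="lightning_logs"):
    """Return the path of the most recently modified .ckpt file under `log_dir`, or None."""
    # rglob searches every version_N/checkpoints subfolder recursively
    checkpoints = list(Path(log_dir).rglob("*.ckpt"))
    if not checkpoints:
        return None
    return str(max(checkpoints, key=lambda p: p.stat().st_mtime))


print(latest_checkpoint())
```

Modification time is used rather than the `version_N` number so the choice does not depend on how the folders are named.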

--------

## Inference

#### Load data and initialize SIMS model

- Don't forget to change the file paths!

- Your test data should be located in this folder (`codespaces`), assuming you uploaded your `.h5ad` files here for easy access. 

In [None]:
checkpoint_path = '../checkpoint/myawesomemodel.ckpt'  # Replace with the path to a pretrained model from the
                                                       # sims_app/checkpoint folder, or to your own .ckpt file
                                                       # produced during training (saved under 'lightning_logs').

test_data_path = 'my_test_data.h5ad'  # Replace with the .h5ad file to run predictions on. Use a path relative
                                      # to this folder; a leading "/" would point to the filesystem root.

sims = SIMS(weights_path=checkpoint_path)  # Initialize the SIMS model from the checkpoint

#### Run and save predictions

- Uses `sims.predict` to run predictions on test data with your chosen checkpoint.

- Saves predictions to a `.csv` file.

In [None]:
cell_predictions = sims.predict(test_data_path)  # Predict
cell_predictions.to_csv('predictions.csv')  # Save
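
To get a quick overview of the results, you can count how many cells were assigned to each predicted label. The exact column names in `cell_predictions` may differ between SIMS versions, so inspect `cell_predictions.columns` first; the column name `first_pred` in the usage line is an assumption, and `summarize_predictions` is a hypothetical helper.

```python
import pandas as pd


def summarize_predictions(predictions: pd.DataFrame, label_col: str) -> pd.DataFrame:
    """Return the number and fraction of cells assigned to each predicted label."""
    counts = predictions[label_col].value_counts()  # most frequent label first
    return pd.DataFrame({"n_cells": counts, "fraction": counts / counts.sum()})


# Usage (check cell_predictions.columns for the real label column; 'first_pred' is assumed):
# print(summarize_predictions(cell_predictions, "first_pred"))
```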

#### View the explainability of the model



- Assesses the importance of different genes in making predictions.

- Uses `sims.explain(test_data_path)` to create an explainability matrix for the test data.

- Obtains the list of gene names from the SIMS model using `sims.model.genes`.

- Constructs a Pandas DataFrame, `explain`, using the explainability matrix and gene names as column headers.

- Computes the mean explanation score for each gene across all samples with `explain.mean(axis=0)`.

- Determines the variance of the explanation scores for each gene across all samples using `explain.var(axis=0)`.

- Combines the mean and variance data into a new Pandas DataFrame, `mean_and_var_data`, and labels the columns as `mean_explain` and `var_explain`.

- Prints the 10 genes with the highest mean explanation score.

In [None]:
explainability_matrix = sims.explain(test_data_path)  # Generate explainability matrix
gene_names = sims.model.genes  # Retrieve the gene names from the SIMS model

explain = pd.DataFrame(explainability_matrix, columns=gene_names)   # Create a Pandas DataFrame from the explainability
                                                                    # matrix, using gene names as columns

mean_explain = explain.mean(axis=0)     # Mean explanation score for each gene across all samples
var_explain = explain.var(axis=0)   # Variance of explanation score for each gene across all samples

mean_and_var_data = pd.concat([mean_explain, var_explain], axis=1)  # Pandas dataframe containing information on gene importance
mean_and_var_data.columns = ["mean_explain", "var_explain"]


top10_genes = mean_explain.nlargest(10) # Can increase to top 20, 30, etc.
print("Top 10 most important genes:")
print(top10_genes)
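
`mean_and_var_data` is computed above, but only the means are printed. This minimal sketch ranks genes by mean explanation score while keeping each gene's variance alongside it, which may help you see whether a top gene matters consistently across cells or mainly in a subset of them. `rank_genes` is a hypothetical helper, and `gene_importance.csv` is only an example filename.

```python
import pandas as pd


def rank_genes(explain: pd.DataFrame, n: int = 10) -> pd.DataFrame:
    """Rank genes (columns) by mean explanation score across cells (rows)."""
    summary = pd.DataFrame({
        "mean_explain": explain.mean(axis=0),
        "var_explain": explain.var(axis=0),  # sample variance (ddof=1), the pandas default
    })
    return summary.sort_values("mean_explain", ascending=False).head(n)


# Usage with the `explain` DataFrame built above:
# top_genes = rank_genes(explain, n=20)
# top_genes.to_csv("gene_importance.csv")
```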