# Fine-tuning to predict the driver gene on an example dataset

## Download pre-trained weights and example datasets

In [1]:
import os

# Directories used by this notebook: pre-trained weights, fine-tuning
# checkpoints, and training logs.
dir_path_pretrain = "../data/pretrain"
dir_path_model = "../data/finetune/model"
dir_path_log = "../data/log"

# Create each directory (with any missing parents); directories that already
# exist are left as they are.
for _dir_path in (dir_path_pretrain, dir_path_model, dir_path_log):
    os.makedirs(_dir_path, exist_ok=True)

#### To run fine-tuning and inference on the example dataset, please refer to README.md and download all of the files listed below:
- `pretrain_weights.pth` under `../data/pretrain`
- `Re-stimulated_t_example_train.h5ad` under `../data`
- `Resting_t_example_test.h5ad` under `../data`
- `dist_t_matrix.csv` under `../data`
- `adj_t_matrix.csv` under `../data`
- `checkpoint-step-1000.pth` under `../data/finetune/model`

## Parameter tuning

In [2]:
import json

## Settings are grouped by purpose and merged below; the merged dict keeps
## the same keys, values and key order as a single flat literal would.
_run_settings = {
    "global_batch_size": 128,
    "local_batch_size": 1,
    "mixed_precision": "true",
    "nr_step": 1050,
    "warmup_step": 500,
    "lr": 0.001,
    "chk_time_interval": 3600,
    "chk_step_interval": 100,
}

_directories = {
    "dataset_dir": "../data/",  ## directory of the training and testing dataset
    "log_dir": "../data/log",  ## directory of the training log
    "model_dir": "../data/finetune/model",  ## directory of the fine-tuning checkpoints
    "pretrain_model_dir": "../data/pretrain",  ## directory of pretrain_weights.pth
}

_data_files = {
    "train_data": "Re-stimulated_t_example_train.h5ad",
    "test_data": "Resting_t_example_test.h5ad",
    "dist_graph": "dist_t_matrix.csv",
    "adj_graph": "adj_t_matrix.csv",
}

params = {**_run_settings, **_directories, **_data_files}

## Persist the configuration as pretty-printed JSON one level above the notebook directory.
with open('../config.json', 'w') as config_file:
    json.dump(params, config_file, indent=4)

## Fine-tuning

### Check training dataset

In [3]:
import scanpy as sc

## Build the path with os.path.join (os is imported in the first cell) so it
## is valid whether or not params['dataset_dir'] ends with a path separator.
train_path = os.path.join(params['dataset_dir'], params['train_data'])
adata_train = sc.read_h5ad(train_path)

## The cell by gene matrix should be raw counts.
adata_train

AnnData object with n_obs × n_vars = 5690 × 19240
    obs: 'nCount_RNA', 'nFeature_RNA', 'orig.ident', 'condition', 'guide_id', 'gene', 'gene_category', 'crispr', 'donor', 'percent.mt', 'percent.ribo', 'nCount_SCT', 'nFeature_SCT', 'S.Score', 'G2M.Score', 'Phase', 'old.ident', 'CD4.CD8.Score', 'CD4.or.CD8', 'SCT_snn_res.0.4', 'seurat_clusters', 'cluster_name', 'activation.score', 'perturbation'
    var: 'gene_ids', 'n_cells'

In [4]:
## Perturbation labels live in the `perturbation` column of adata.obs.
adata_train.obs['perturbation']

CGGAATTAGACTTCAC-6         IFNG
ATTACCTAGGAGATAG-5       INPPL1
TGGGTTAGTTGTATGC-5         GRAP
CTCCCTCTCGGTAGGA-8       P2RY14
GACCTTCTCATCTATC-6        IKZF3
                        ...    
TTCAGGACAGCAATTC-6      ARHGDIB
AGACAAACAAGAGCTG-5        IL2RB
AACCAACCAGACTGCC-7     APOBEC3C
CATGCAAGTACATACC-5         CD28
ACGGTTAGTACGATCT-8    NO-TARGET
Name: perturbation, Length: 5690, dtype: category
Categories (70, object): ['ABCB10', 'AKAP12', 'ALX4', 'APOBEC3C', ..., 'TRAF3IP2', 'TRIM21', 'VAV1', 'WT1']

### CellNavi supports multi-node distributed computing on GPUs. 

### If you want to use the NCCL-based distributed architecture, please run `bash launch_train.sh` directly from the command line.

In [5]:
## Single GPU fine-tuning
## Executes start_train.py in this kernel through IPython's %run magic.
## NOTE(review): the script presumably reads the ../config.json written in the
## parameter cell above — confirm against start_train.py.
%run start_train.py

CUDA is available. Number of GPUs: 1
dataset info: ../data/Re-stimulated_t_example_train.h5ad 5690
dataset info: ../data/Resting_t_example_test.h5ad 3158
Train Step: [1010/1050], ce_loss: 2.152972, update_loss: 2.152972, top1_acc: 0.467969, top5_acc: 0.681641, top10_acc: 0.779818, top50_acc: 0.992188, top100_acc: 1.000000, Speed: 0.056 m/s, Passed: 0.050 h, Estimate: 0.199 h
Train Step: [1020/1050], ce_loss: 2.077976, update_loss: 2.077976, top1_acc: 0.472656, top5_acc: 0.705729, top10_acc: 0.805339, top50_acc: 1.000000, top100_acc: 1.000000, Speed: 0.054 m/s, Passed: 0.100 h, Estimate: 0.149 h
Train Step: [1030/1050], ce_loss: 2.015243, update_loss: 2.015243, top1_acc: 0.486328, top5_acc: 0.768229, top10_acc: 0.841797, top50_acc: 0.992188, top100_acc: 1.000000, Speed: 0.055 m/s, Passed: 0.149 h, Estimate: 0.099 h
Train Step: [1040/1050], ce_loss: 2.168831, update_loss: 2.168831, top1_acc: 0.459896, top5_acc: 0.739193, top10_acc: 0.813932, top50_acc: 0.986328, top100_acc: 1.000000, Spe

## Testing 

### Load results and evaluate. 

In [6]:
## Here we load the results from step 1000 as an example.
## NOTE(review): `-c 1000` presumably selects checkpoint-step-1000.pth in the
## fine-tuning model directory — confirm against load_results.py.

%run load_results.py -c 1000

dataset info: ../data/Resting_t_example_test.h5ad 3158
dataset info: ../data/Re-stimulated_t_example_train.h5ad 5690


100%|██████████| 3158/3158 [15:56<00:00,  3.30it/s]


In [7]:
import pandas as pd
import scanpy as sc
from sklearn.metrics import accuracy_score, f1_score
import numpy as np


## Load the results table; the first CSV column becomes the row index.
results_csv = 'set3_test_example_results.csv'
df = pd.read_csv(results_csv, index_col=0)

## Rows are cell names and columns are perturbed genes; each value is a logit.
print(df.head())


                      ABCB10    AKAP12      ALX4  APOBEC3C  APOBEC3D  \
TGCATCCTCGATCCAA-4 -7.821463 -1.156066  0.904565 -5.127326 -7.951521   
TTGGATGGTATCCTCC-2 -7.477587 -1.947112 -1.975892 -1.696544 -7.003691   
AACCACAGTCTCCCTA-1 -9.641579 -3.978698 -5.091012 -5.724802 -5.395859   
TAGAGTCTCATGGATC-4 -1.639565 -2.339934  0.846601 -3.660203 -7.996193   
TTTCAGTTCCATTCGC-2 -3.726980 -3.369183 -4.184911  0.105774 -1.985862   

                       APOL2   ARHGDIB    BICDL2      CBY1       CD2  ...  \
TGCATCCTCGATCCAA-4 -1.957271  2.360333 -5.758391 -6.885813 -9.281010  ...   
TTGGATGGTATCCTCC-2 -2.546962  1.258157 -2.950146 -0.769823 -4.229147  ...   
AACCACAGTCTCCCTA-1 -5.997128 -5.460759 -0.255089 -5.033724 -3.466392  ...   
TAGAGTCTCATGGATC-4 -2.124491  1.050665 -0.941778  1.620585 -2.813236  ...   
TTTCAGTTCCATTCGC-2 -2.452409  2.454847 -1.646882 -0.831029 -1.429325  ...   

                       TAGAP     TBX21  TNFRSF1A   TNFRSF1B   TNFRSF9  \
TGCATCCTCGATCCAA-4 -9.457092 -5

In [8]:
## Ground-truth perturbation labels of the test cells.
## os.path.join keeps the path valid whether or not params['dataset_dir']
## ends with a path separator.
test_path = os.path.join(params['dataset_dir'], params['test_data'])
adata_test = sc.read_h5ad(test_path)
perturb_gene = adata_test.obs['perturbation'].values

## Predicted gene = the column holding the largest logit in each row.
## A 'pred_gene' column left in df by an earlier run of this cell is excluded,
## so re-running the cell is idempotent: otherwise the string column would be
## fed into idxmax together with the logits.
logit_columns = df.columns.drop('pred_gene', errors='ignore')
df['pred_gene'] = df[logit_columns].idxmax(axis=1)
pred_gene = df['pred_gene'].values

## NOTE(review): labels and predictions are paired by position, which assumes
## the rows of the results CSV are in the same order as adata_test.obs — confirm.
accuracy = accuracy_score(perturb_gene, pred_gene)
f1 = f1_score(perturb_gene, pred_gene, average='weighted')


print(f"Testing accuracy (for toy dataset): {accuracy:.3f}")
print(f"Testing F1 score (for toy dataset): {f1:.3f}")

Testing accuracy (for toy dataset): 0.510
Testing F1 score (for toy dataset): 0.498
