# Fine-tuning to predict the driver gene on an example dataset

## Important
Before running this notebook, please make sure you have:
1. Downloaded all files from the dataset and checkpoint links mentioned in CellNavi/README.md.
2. Completed the preparation steps 0-3 in CellNavi/tutorials/README.md.


## Download example dataset

In [None]:
import os
import re

import requests


# NOTE(review): this looks like a session-specific Dropbox download link and may
# have expired. If the request fails, replace it with the current dataset link
# from CellNavi/README.md.
url = "https://ucc0e5b8ed82a72fc5a2f99aa081.dl.dropboxusercontent.com/cd/0/get/ClGGk1CQmoTYoHlYHd43QnZ399puxUPAdDosmZD5FFvg90_LzsayqVn76-XCfJLam2eVEv2s6oW8UGlIjlNSgtRzWzHSyigYeCgKyiNKSdlRgl3BUSiKz27ceaFUIOYwSliaKcEPQlLCpyYN5B82AM7zN5MxP_WEB_Ce0OO4hiX10g/file?_download_id=49996818229892190311348360354759554017146239774901417905492274051615&_log_download_success=1#"

# Directory that receives the download. The file name is taken from the
# server's Content-Disposition header when one is sent; otherwise the fallback
# below is used (replace it with the actual file name and extension).
save_dir = "../data/"
fallback_filename = "cellnavi_example_dataset"

# open() needs a file path, not a directory, so the target path is always
# save_dir joined with a file name.
save_path = os.path.join(save_dir, fallback_filename)
os.makedirs(save_dir, exist_ok=True)

# The timeout keeps the cell from hanging forever on a dead link; the `with`
# block releases the streamed connection even if writing fails.
with requests.get(url, stream=True, timeout=60) as response:
    if response.status_code == 200:
        disposition = response.headers.get("Content-Disposition", "")
        match = re.search(r'filename="?([^";]+)"?', disposition)
        # basename() drops any directory components a server-supplied name may carry.
        server_filename = os.path.basename(match.group(1).strip()) if match else ""
        save_path = os.path.join(save_dir, server_filename or fallback_filename)

        # Stream to disk in 8 KiB chunks so the whole file is never held in memory.
        with open(save_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)
        print(f"Successfully downloaded to {save_path}")
    else:
        print(f"Error: {response.status_code}")

## Training

### Follow the README.md for training preparations, including datasets and pretrained files. The links for examples can be found in CellNavi/README.md.

In [None]:
import json


# Settings for the fine-tuning run; they are serialised to ../config.json below.
# NOTE(review): how each key is interpreted is defined by the training code, which
# is not part of this notebook — check there before changing values.
params = {
    # Training hyper-parameters and checkpoint intervals.
    "global_batch_size": 128,
    "local_batch_size": 1,
    "n_cls": 2058,
    # Must be the Python literal True (not JSON's `true`, which is an undefined
    # name in Python); json.dump writes it out as `true`.
    "mixed_precision": True,
    "nr_step": 3000,
    "warmup_step": 500,
    "lr": 0.001,
    "chk_time_interval": 3600,
    "chk_step_interval": 100,

    # Directories, relative to where the notebook is run.
    "saved_dir": "../data/",
    "pretrained_dir": "../data/",
    "dataset_dir": "../data/",
    "log_dir": "../data/log",
    "model_dir": "../data/finetune/model",
    "pretrain_model_dir": "../data/pretrain",

    # Input file names (presumably resolved against dataset_dir — TODO confirm).
    "train_data": "set3_example_train.h5ad",
    "test_data": "set3_example_test.h5ad",
    "dist_graph" : "shortest_path_integrated_network_setting3_all_genes.csv",
    "adj_graph" : "integrated_network_setting3_all_genes.csv"

}


# Write the configuration one directory above the notebook's working directory.
with open('../config.json', 'w') as f:
    json.dump(params, f, indent=4)

### Check training dataset

In [1]:
## Here we use the toy dataset 'set3_example_train.h5ad' and 'set3_example_test.h5ad' as an example. 
import scanpy as sc

## The file_path should be the path where the training data is stored.
## It is given relative to the notebook's working directory.
train_file_path = '../data/set3_example_train.h5ad'
## Load the training set as an AnnData object.
adata_train = sc.read_h5ad(train_file_path)

## The cell by gene matrix should be raw counts.
## Leaving the object as the last expression displays its summary (shape plus obs/var fields).
adata_train

AnnData object with n_obs × n_vars = 5690 × 19240
    obs: 'nCount_RNA', 'nFeature_RNA', 'orig.ident', 'condition', 'guide_id', 'gene', 'gene_category', 'crispr', 'donor', 'percent.mt', 'percent.ribo', 'nCount_SCT', 'nFeature_SCT', 'S.Score', 'G2M.Score', 'Phase', 'old.ident', 'CD4.CD8.Score', 'CD4.or.CD8', 'SCT_snn_res.0.4', 'seurat_clusters', 'cluster_name', 'activation.score', 'perturbation'
    var: 'gene_ids', 'n_cells'

In [2]:
## Each cell's perturbation label lives in the 'perturbation' column of adata_train.obs.
adata_train.obs['perturbation']

CGGAATTAGACTTCAC-6         IFNG
ATTACCTAGGAGATAG-5       INPPL1
TGGGTTAGTTGTATGC-5         GRAP
CTCCCTCTCGGTAGGA-8       P2RY14
GACCTTCTCATCTATC-6        IKZF3
                        ...    
TTCAGGACAGCAATTC-6      ARHGDIB
AGACAAACAAGAGCTG-5        IL2RB
AACCAACCAGACTGCC-7     APOBEC3C
CATGCAAGTACATACC-5         CD28
ACGGTTAGTACGATCT-8    NO-TARGET
Name: perturbation, Length: 5690, dtype: category
Categories (70, object): ['ABCB10', 'AKAP12', 'ALX4', 'APOBEC3C', ..., 'TRAF3IP2', 'TRIM21', 'VAV1', 'WT1']

### Tune parameters

#### Please adjust the parameters in common/config.py. The parameters currently shown are the default parameters.

### Run training scripts

In [3]:
## The dataset link has been given in CellNavi/README.md.
## Give the current user execute permission on launch_train.sh, then run it.
## NOTE(review): both commands use relative paths, so the notebook's working
## directory must be the one that contains launch_train.sh.

!chmod u+x launch_train.sh
!./launch_train.sh

## Testing 

### Load results and evaluate. 

In [4]:
## Here we load results on step 1000 as an example. 
## The link for 'checkpoint-step-1000.pth' has been given in CellNavi/README.md.
## NOTE(review): `-c 1000` presumably selects the checkpoint step to load — confirm
## against load_results.py. The captured run below took about 16 minutes for 3158 cells.

%run load_results.py -c 1000

  checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))


dataset info: /home/pany3/pany3/CellNavi/dataset_full/set3_example_test.h5ad 3158


100%|██████████| 3158/3158 [16:09<00:00,  3.26it/s]


In [5]:
import numpy as np
import pandas as pd
import scanpy as sc
from sklearn.metrics import accuracy_score, f1_score


## Prediction table, presumably produced by the load_results.py run above;
## the first CSV column becomes the row index.
df = pd.read_csv('set3_test_example_results.csv', index_col=0)

## Rows are cell names and columns are perturbed genes; each value is a logit.
## Only the first rows are printed to keep the output short.
results_preview = df.head()
print(results_preview)


                      ABCB10    AKAP12      ALX4  APOBEC3C  APOBEC3D  \
TGCATCCTCGATCCAA-4 -7.821463 -1.156066  0.904565 -5.127326 -7.951521   
TTGGATGGTATCCTCC-2 -7.477587 -1.947112 -1.975892 -1.696544 -7.003691   
AACCACAGTCTCCCTA-1 -9.641579 -3.978698 -5.091012 -5.724802 -5.395859   
TAGAGTCTCATGGATC-4 -1.639565 -2.339934  0.846601 -3.660203 -7.996193   
TTTCAGTTCCATTCGC-2 -3.726980 -3.369183 -4.184911  0.105774 -1.985862   

                       APOL2   ARHGDIB    BICDL2      CBY1       CD2  ...  \
TGCATCCTCGATCCAA-4 -1.957271  2.360333 -5.758391 -6.885813 -9.281010  ...   
TTGGATGGTATCCTCC-2 -2.546962  1.258157 -2.950146 -0.769823 -4.229147  ...   
AACCACAGTCTCCCTA-1 -5.997128 -5.460759 -0.255089 -5.033724 -3.466392  ...   
TAGAGTCTCATGGATC-4 -2.124491  1.050665 -0.941778  1.620585 -2.813236  ...   
TTTCAGTTCCATTCGC-2 -2.452409  2.454847 -1.646882 -0.831029 -1.429325  ...   

                       TAGAP     TBX21  TNFRSF1A   TNFRSF1B   TNFRSF9  \
TGCATCCTCGATCCAA-4 -9.457092 -5

In [6]:
## The file_path should be the path where the testing data is stored.
## It uses the same '../data/' directory as the training file and the
## "dataset_dir" / "test_data" entries written to ../config.json.
test_file_path = '../data/set3_example_test.h5ad'
adata_test = sc.read_h5ad(test_file_path)

## Only the gene-logit columns may take part in the argmax. Dropping a
## 'pred_gene' column left over from a previous run keeps this cell re-runnable.
logits = df.drop(columns='pred_gene', errors='ignore')

## Ground-truth labels, matched to the prediction rows by cell name rather than
## by position, so a different row order cannot silently distort the metrics.
true_labels = adata_test.obs['perturbation']
if not true_labels.index.equals(logits.index):
    true_labels = true_labels.reindex(logits.index)
    if true_labels.isna().any():
        raise ValueError(
            "Some cells in the results table have no perturbation label in "
            f"{test_file_path}; check that both files describe the same test set."
        )
perturb_gene = true_labels.values

## The predicted gene for each cell is the column holding its largest logit.
df['pred_gene'] = logits.idxmax(axis=1)
pred_gene = df['pred_gene'].values


accuracy = accuracy_score(perturb_gene, pred_gene)
## 'weighted' averages the per-gene F1 scores by the number of true cells per gene.
f1 = f1_score(perturb_gene, pred_gene, average='weighted')


print(f"Testing accuracy (for toy dataset): {accuracy:.3f}")
print(f"Testing F1 score (for toy dataset): {f1:.3f}")

Testing accuracy (for toy dataset): 0.510
Testing F1 score (for toy dataset): 0.498
