# Train a model on CATLAS binary accessibility data

## Set up wandb

In [1]:
import wandb

import anndata
import os
import pandas as pd
import numpy as np
%matplotlib inline

# Expose only GPU "1" to CUDA for this kernel; device indices used later in the
# notebook are therefore relative to this restricted set.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Authenticate against the public wandb server.
wandb.login(host="https://api.wandb.ai")

# wandb project name; also reused below as the local output directory.
project_name = "human-atac-catlas"

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mavantikalal[0m ([33mgrelu[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
# The notebook's relative and absolute paths are passed to wandb explicitly
# through the run settings.
# NOTE(review): the absolute path is specific to the original environment.
run_settings = wandb.Settings(
    program_relpath="2_train.ipynb",
    program_abspath="/code/grelu_zoo_model_training/2_train.ipynb",
)

# Start the wandb run that tracks this training job.
run = wandb.init(
    entity="grelu",
    project=project_name,
    job_type="training",
    name="train",
    settings=run_settings,
)

## Load preprocessed data

In [3]:
# Declare the preprocessed dataset artifact as an input of this run and
# download it locally.
artifact = run.use_artifact('dataset:latest')

# Named `artifact_dir` rather than `dir` so the builtin `dir()` is not
# shadowed for the rest of the kernel session.
artifact_dir = artifact.download()

# Read the AnnData object stored as `data.h5ad` inside the downloaded artifact.
ad = anndata.read_h5ad(os.path.join(artifact_dir, "data.h5ad"))

[34m[1mwandb[0m: Downloading large artifact dataset:latest, 179.17MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4


## Make datasets

In [4]:
import grelu.data.dataset

# Training set: the columns of `ad` whose `split` annotation in `ad.var` is
# 'train', with sequence augmentation switched on.
train_dataset = grelu.data.dataset.AnnDataSeqDataset(
    ad[:, ad.var["split"] == "train"].copy(),
    genome="hg38",
    rc=True,  # augment with the reverse complement
    max_seq_shift=2,  # augment by shifting the sequence
    augment_mode="random",  # pick the augmentation to apply at random
)

# Validation set: the 'valid' columns, with no augmentation arguments.
val_dataset = grelu.data.dataset.AnnDataSeqDataset(
    ad[:, ad.var["split"] == "valid"].copy(),
    genome="hg38",
)

  from .autonotebook import tqdm as notebook_tqdm


## Build the model

In [5]:
import grelu.lightning

# Architecture settings.
model_params = {
    "model_type": "EnformerPretrainedModel",  # which model class to build
    "n_tasks": ad.shape[0],  # one output per cell type (first axis of `ad`)
    "crop_len": 0,  # do not crop the model output
    "n_transformers": 1,  # transformer layers; the published Enformer model has 11
}

# Optimisation, logging and hardware settings.
train_params = {
    "task": "binary",  # binary classification
    "lr": 1e-4,  # learning rate
    "logger": "wandb",  # write logs to wandb
    "batch_size": 3072,
    "num_workers": 32,
    "devices": 0,  # GPU index
    "save_dir": project_name,
    "optimizer": "adam",
    "max_epochs": 10,
    "checkpoint": True,  # save checkpoints during training
}

# Build the Lightning wrapper from the two parameter dictionaries.
model = grelu.lightning.LightningModel(
    model_params=model_params,
    train_params=train_params,
)

[34m[1mwandb[0m: Downloading large artifact human_state_dict:latest, 939.29MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.7
  state_dict = torch.load(Path(d) / "human.h5")


## Train the model

In [6]:
# Fit the model on the training set, validating against the validation set;
# the returned trainer is kept so its checkpoint callback can be queried later.
trainer = model.train_on_dataset(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/conda/lib/python3.11/site-packages/pytorch_lightning/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]


Validation DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:08<00:00,  2.93it/s]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]

  | Name         | Type                    | Params | Mode 
-----------------------------------------------------------------
0 | model        | EnformerPretrainedModel | 72.1 M | train
1 | loss         | BCEWithLogitsLoss       | 0      | train
2 | val_metrics  | MetricCollection        | 0      | train
3 | test_metrics | MetricCollection        | 0      | train
4 | transform    | Identity                | 0      | train
-----------------------------------------------------------------
72.1 M    Trainable params
0         Non-trainable params
72.1 M    Total params
288.279   Total estimated model params size (MB)
240       Modules in train mode
0         Modules in eval mode


Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 319/319 [03:07<00:00,  1.70it/s, v_num=9zb6, train_loss_step=0.154]
Validation: |                                                                                                                                                                                                        | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                                                                                   | 0/24 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                                                                                                      | 0/24 [00:00<?, ?it/s][A
Validation DataLoader 0:   4%|███████▎         

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 319/319 [03:31<00:00,  1.51it/s, v_num=9zb6, train_loss_step=0.149, train_loss_epoch=0.143]


## Load best checkpoint

In [7]:
# Path of the checkpoint that the trainer's checkpoint callback recorded as best.
checkpoint_callback = trainer.checkpoint_callback
best_checkpoint = checkpoint_callback.best_model_path

# Replace the in-memory model with the one restored from that checkpoint.
model = grelu.lightning.LightningModel.load_from_checkpoint(
    best_checkpoint
)

[34m[1mwandb[0m: Downloading large artifact human_state_dict:latest, 939.29MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.8
  state_dict = torch.load(Path(d) / "human.h5")


## Evaluate

In [10]:
# Held-out test set: the columns of `ad` whose `split` annotation in `ad.var`
# is 'test', with no augmentation arguments.
test_dataset = grelu.data.dataset.AnnDataSeqDataset(
    ad[:, ad.var["split"] == "test"].copy(),
    genome="hg38",
)

In [12]:
# Destination checkpoint, intended to hold the model updated with the test results.
checkpoint_out = os.path.join(project_name, "model.ckpt")

# Evaluate the model on the test set and write the checkpoint to `checkpoint_out`.
test_metrics = model.test_on_dataset(
    test_dataset,
    devices=0,
    num_workers=8,
    batch_size=256,
    write_path=checkpoint_out,
)

# Display the metrics table as the cell output.
test_metrics

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]


Testing DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 281/281 [00:09<00:00, 28.11it/s]


Unnamed: 0,test_accuracy,test_auroc,test_avgprec,test_best_f1
Follicular,0.902749,0.873614,0.609485,0.567260
Fibro General,0.912211,0.883738,0.615966,0.571655
Acinar,0.935036,0.909199,0.641150,0.592622
T Lymphocyte 1 (CD8+),0.953500,0.928728,0.642585,0.600279
T lymphocyte 2 (CD4+),0.958823,0.929545,0.616950,0.577314
...,...,...,...,...
Fetal Stellate,0.958642,0.910730,0.595540,0.562300
Fetal Alveolar Epithelial 1,0.900422,0.865749,0.644052,0.590152
Fetal Cilliated,0.916949,0.878486,0.626063,0.584166
Fetal Excitatory Neuron 1,0.918705,0.877348,0.615934,0.577068


## Save

In [13]:
# Checkpoint file written by the evaluation step.
checkpoint_file = os.path.join(project_name, "model.ckpt")

# Wrap the checkpoint in a wandb artifact of type 'model' and attach it to the run.
artifact = wandb.Artifact("model", type="model")
artifact.add_file(checkpoint_file, name="model.ckpt")
run.log_artifact(artifact)

<Artifact model>

In [14]:
# Snapshot the source files under the current directory as a wandb code artifact.
# With no arguments, `log_code()` filters on `.py` files, which would leave this
# notebook out; the filter is widened to accept `.ipynb` files as well.
# `requirements.txt` is also accepted so the filter is not narrower than the
# default of wandb versions that include it (NOTE(review): confirm against the
# installed wandb version).
run.log_code(
    include_fn=lambda path: path.endswith((".py", ".ipynb"))
    or os.path.basename(path) == "requirements.txt"
)



In [15]:
# Close the wandb run; the run history and summary tables are shown as the cell output.
run.finish()

0,1
epoch,▁▁▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▅▅▆▆▆▆▆▆▆▆▆▆▆▆▇▇▇███
train_loss_epoch,█▂▂▂▂▁▁▁▁▁
train_loss_step,█▅▃▃▃▃▃▂▃▂▃▃▃▂▂▂▂▂▃▃▂▂▂▂▂▁▂▂▂▂▂▂▁▂▁▂▁▂▁▂
trainer/global_step,▁▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇██
val_accuracy,▁██████████
val_auroc,▁██████████
val_avgprec,▁▇█████████
val_best_f1,▁▇█████████
val_loss,█▁▁▁▁▁▁▁▁▁▁

0,1
epoch,9.0
train_loss_epoch,0.14265
train_loss_step,0.1443
trainer/global_step,3189.0
val_accuracy,0.94824
val_auroc,0.8932
val_avgprec,0.55064
val_best_f1,0.52444
val_loss,0.15176
