# Topic Classification

We can use the outputs of [pycistopic](https://pycistopic.readthedocs.io/en/latest/) to train a model to predict topic probabilities for a given sequence.  

Since we plan on adding detailed use cases describing topic classification later on, we will only provide a brief overview of the workflow here. Refer to the [introductory notebook](model_training_and_eval.ipynb) for a more detailed explanation of the CREsted workflow.

## Import Data

For this tutorial, we will use the Mouse BICCN dataset. We will use the preprocessed, binarized outputs of pycistopic as input data for the topic classification model. 

To train a topic classification model, we need the following data:
1. A folder containing one BED file per topic (output of pycistopic).
2. A genome FASTA file and, optionally, a chromosome sizes file.

In [1]:
import crested

In [2]:
# Download the tutorial data
import os

os.environ[
    "CRESTED_DATA_DIR"
] = "../../../Crested_testing/data/tmp"  # Change this to your desired directory
beds_folder, regions_file = crested.get_dataset("mouse_cortex_bed")

We can import a folder of BED files using the {func}`crested.import_beds` function.  
This returns an AnnData object with the regions as `.var` and the BED file names as `.obs` (here: our topics).  
In this case, the `adata.X` values are binary and indicate whether a region is associated with a topic or not.
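
To make the shape of that matrix concrete, here is a minimal sketch of how a topics × regions binary matrix could be assembled from BED records. It is not the `crested.import_beds` implementation; the helper names `parse_bed_lines` and `build_binary_matrix` and the toy coordinates are hypothetical and only illustrate the idea.

```python
def parse_bed_lines(lines):
    """Return region names like 'chr1:100-600' from BED-formatted lines."""
    regions = []
    for line in lines:
        line = line.strip()
        # Skip empty lines and common BED header lines
        if not line or line.startswith(("#", "track", "browser")):
            continue
        chrom, start, end = line.split("\t")[:3]
        regions.append(f"{chrom}:{int(start)}-{int(end)}")
    return regions


def build_binary_matrix(topic_beds, consensus_regions):
    """Build a topics x regions 0/1 matrix (hypothetical helper).

    topic_beds: dict mapping a topic name to its list of region names.
    consensus_regions: ordered list of all region names (the columns).
    """
    column_of = {region: i for i, region in enumerate(consensus_regions)}
    matrix = []
    for regions in topic_beds.values():
        row = [0] * len(consensus_regions)
        for region in regions:
            if region in column_of:  # ignore regions outside the consensus set
                row[column_of[region]] = 1
        matrix.append(row)
    return list(topic_beds), matrix


# Toy data (made up, not taken from the BICCN dataset)
consensus = parse_bed_lines(["chr1\t100\t600", "chr1\t900\t1400", "chr2\t50\t550"])
topic_beds = {
    "Topic_1": parse_bed_lines(["chr1\t100\t600", "chr2\t50\t550"]),
    "Topic_2": parse_bed_lines(["chr1\t900\t1400"]),
}
topics, matrix = build_binary_matrix(topic_beds, consensus)
```

Each row corresponds to one topic (one BED file) and each column to one consensus region, which mirrors the `.obs` and `.var` layout described above.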

In [3]:
# Import the beds into an AnnData object
adata = crested.import_beds(
    beds_folder=beds_folder, regions_file=regions_file
)  # the regions file is optional for import_beds
adata

2024-08-14T11:45:14.496482+0200 INFO Reading bed files from /lustre1/project/stg_00002/lcb/lmahieu/projects/Crested_testing/data/tmp/data/mouse_biccn/beds.tar.gz.untar and using /lustre1/project/stg_00002/lcb/lmahieu/projects/Crested_testing/data/tmp/data/mouse_biccn/consensus_peaks_biccn.bed as var_names...


View of AnnData object with n_obs × n_vars = 80 × 439383
    obs: 'file_path', 'n_open_regions'
    var: 'n_classes', 'chr', 'start', 'end'

We have 80 classes (topics) and 439383 regions in the dataset.

## Preprocessing

Topic classification requires little preprocessing compared to peak regression.  
The data does not need to be normalized, since the values are binary. We also don't filter regions on specificity: by the nature of topic modelling, the selected regions should already be 'meaningful'.  
You could change the width of the regions, but we tend to keep them at 500bp for topic classification.  

The only preprocessing step we need to perform is to split the data into training and testing sets.

In [4]:
# Standard train/val/test split
crested.pp.train_val_test_split(
    adata, strategy="chr", val_chroms=["chr8", "chr10"], test_chroms=["chr9", "chr18"]
)
print(adata.var["split"].value_counts())

split
train    354013
val       45113
test      40257
Name: count, dtype: int64
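
As a rough illustration of what a chromosome-based split does, the sketch below assigns each region to a split purely from its chromosome name. The function `split_by_chromosome` is hypothetical and is not the `crested.pp.train_val_test_split` implementation; the region names are toy values.

```python
from collections import Counter


def split_by_chromosome(region_names, val_chroms, test_chroms):
    """Label each 'chr:start-end' region as 'train', 'val' or 'test'."""
    val_chroms, test_chroms = set(val_chroms), set(test_chroms)
    overlap = val_chroms & test_chroms
    if overlap:
        raise ValueError(f"Chromosomes in both val and test: {sorted(overlap)}")
    splits = []
    for name in region_names:
        chrom = name.split(":", 1)[0]
        if chrom in val_chroms:
            splits.append("val")
        elif chrom in test_chroms:
            splits.append("test")
        else:
            splits.append("train")  # every other chromosome is used for training
    return splits


# Toy regions (made up) using the same held-out chromosomes as the cell above
regions = [
    "chr1:100-600", "chr8:10-510", "chr9:5-505",
    "chr10:0-500", "chr18:20-520", "chr2:0-500",
]
splits = split_by_chromosome(
    regions, val_chroms=["chr8", "chr10"], test_chroms=["chr9", "chr18"]
)
counts = Counter(splits)
```

Holding out whole chromosomes is a common way to keep neighbouring or overlapping regions from ending up on both sides of the train/test boundary.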


## Model Training

Model training has the same workflow as peak regression. The only differences are:
1. We select a different model architecture. Since we're training on 500bp regions we don't need the dilated convolutions of chrombpnet.  
2. We select a different config, since we're monitoring other metrics and are using a different loss for classification.  
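
To give an intuition for the classification loss, here is a small plain-Python sketch of binary cross-entropy for a single region. It assumes each topic is treated as an independent yes/no label, so a region may belong to several topics at once; the function name and the toy probabilities are illustrative and this is not the Keras implementation.

```python
import math


def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy over the topic labels of one region."""
    total = 0.0
    for t, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total += -(t * math.log(p) + (1 - t) * math.log(1.0 - p))
    return total / len(y_true)


# A region that belongs to topics 1 and 3 but not topic 2 (toy example)
y_true = [1, 0, 1]
good = binary_cross_entropy(y_true, [0.9, 0.1, 0.8])  # confident and correct
bad = binary_cross_entropy(y_true, [0.2, 0.7, 0.3])   # mostly wrong
```

The loss is small when the predicted probabilities agree with the binary labels and grows as they diverge, independently for each topic.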

In [5]:
# Datamodule
datamodule = crested.tl.data.AnnDataModule(
    adata,
    genome_file="../../../Crested_testing/data/tmp/mm10.fa",
    batch_size=128,  # lower this if you encounter OOM errors
    max_stochastic_shift=3,  # optional augmentation
    always_reverse_complement=True,  # default True. Will double the effective size of the training dataset.
)

# Architecture: we will use the DeepTopic CNN model
model_architecture = crested.tl.zoo.deeptopic_cnn(seq_len=500, num_classes=80)

# Config: we will use the default topic classification config (binary cross-entropy loss with auROC, auPR, and categorical accuracy metrics)
config = crested.tl.default_configs("topic_classification")
print(config)

TaskConfig(optimizer=<keras.src.backend.torch.optimizers.torch_adam.Adam object at 0x145af25458e0>, loss=<keras.src.losses.losses.BinaryCrossentropy object at 0x145ad99a18b0>, metrics=[<AUC name=auROC>, <AUC name=auPR>, <CategoricalAccuracy name=categorical_accuracy>])
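
The datamodule above enables two augmentations: a small stochastic shift and reverse complementing. The sketch below shows the general idea on a toy sequence. How CREsted implements these internally is not shown here, so `shifted_window`, `reverse_complement`, and the clamping behaviour are assumptions made for illustration only.

```python
import random

COMPLEMENT = str.maketrans("ACGTN", "TGCAN")


def reverse_complement(seq):
    """Return the reverse complement of a DNA string (N stays N)."""
    return seq.upper().translate(COMPLEMENT)[::-1]


def shifted_window(chrom_seq, start, end, max_shift, rng):
    """Cut chrom_seq[start:end] after moving the window by a random offset.

    The offset is drawn uniformly from [-max_shift, max_shift], and the window
    is clamped to the chromosome so that its width never changes.
    Assumes the chromosome is at least as long as the window.
    """
    width = end - start
    shift = rng.randint(-max_shift, max_shift)
    new_start = min(max(start + shift, 0), len(chrom_seq) - width)
    return chrom_seq[new_start:new_start + width]


rng = random.Random(0)  # fixed seed so the example is repeatable
chrom = "ACGTTGCAAGCTTAGGCCTA"  # toy 20 bp "chromosome"
window = shifted_window(chrom, start=6, end=14, max_shift=3, rng=rng)
rc_window = reverse_complement(window)
```

Both `window` and `rc_window` would be paired with the same topic labels, which is why reverse complementing doubles the effective size of the training set.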


Next, we set up the trainer.

In [6]:
# setup the trainer
trainer = crested.tl.Crested(
    data=datamodule,
    model=model_architecture,
    config=config,
    project_name="mouse_biccn_topics",  # change to your liking
    logger=None,  # or 'wandb', 'tensorboard'
)

In [7]:
# train the model
trainer.fit(epochs=100)

None
2024-08-14T11:48:17.282138+0200 INFO Loading sequences into memory...


100%|██████████| 354013/354013 [00:08<00:00, 42834.03it/s]


2024-08-14T11:48:27.847601+0200 INFO Loading sequences into memory...


100%|██████████| 45113/45113 [00:00<00:00, 58214.40it/s]


Epoch 1/100


## Evaluation and Prediction

Evaluation and prediction work the same way as for peak regression.

The next steps you could take are to:
1. Evaluate the model on the test set.
2. Predict topic probabilities for a given sequence or region.
3. Run tfmodisco to find motifs associated with each topic.
4. Generate synthetic sequences for each topic using in silico evolution.
5. Plot contribution scores per topic for interesting regions or sequences. 
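
For the first step, one quantity you might look at is the auROC per topic. The sketch below computes it in plain Python with the rank-sum formulation, using made-up labels and probabilities. The `auroc` helper is hypothetical; the auROC metric reported during training is computed by Keras and may differ slightly, for example because of threshold approximation or a different way of aggregating over topics.

```python
def auroc(labels, scores):
    """Area under the ROC curve via the rank-sum formulation.

    Tied scores receive their average rank. Returns NaN when only one
    class is present, since the auROC is undefined in that case.
    """
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    ranks = [0.0] * len(scores)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of scores tied with position i
        while j + 1 < len(order) and scores[order[j + 1]] == scores[order[i]]:
            j += 1
        avg_rank = (i + j) / 2 + 1  # 1-based average rank of the tied run
        for k in range(i, j + 1):
            ranks[order[k]] = avg_rank
        i = j + 1
    pos_rank_sum = sum(r for r, y in zip(ranks, labels) if y == 1)
    return (pos_rank_sum - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)


# Toy test set: 4 regions x 2 topics (values are made up)
y_true = [[1, 0], [0, 1], [1, 1], [0, 0]]
y_prob = [[0.9, 0.2], [0.3, 0.7], [0.6, 0.4], [0.1, 0.5]]
n_topics = len(y_true[0])
per_topic = [
    auroc([row[t] for row in y_true], [row[t] for row in y_prob])
    for t in range(n_topics)
]
```

Inspecting the score per topic, rather than only an average, can help spot topics that the model has not learned well.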

Refer to the [introductory notebook](model_training_and_eval.ipynb) for more details.