# Topic classification

We can use the outputs of [pycisTopic](https://pycistopic.readthedocs.io/en/latest/) to train a model to predict topic probabilities for a given sequence.  

Since we plan on adding detailed use cases describing topic classification later on, we will only provide a brief overview of the workflow here. Refer to the [introductory notebook](model_training_and_eval) for a more detailed explanation of the CREsted workflow.

In [1]:
# Set package settings
import matplotlib
import os

## Set the font type to ensure text is saved as whole words
matplotlib.rcParams["pdf.fonttype"] = 42  # Use TrueType fonts instead of Type 3 fonts
matplotlib.rcParams["ps.fonttype"] = 42  # For PostScript as well, if needed

## Set the base directory for data retrieval with crested.get_dataset()/get_model()
os.environ['CRESTED_DATA_DIR'] = '/staging/leuven/stg_00002/lcb/cblaauw/'

## Import data

For this tutorial, we will use the mouse BICCN dataset. We will use the preprocessed, binarized outputs of pycisTopic as input data for the topic classification model. 

To train a topic classification model, we need the following data:
1. A folder containing BED files per topic (output of pycisTopic). 
2. A genome FASTA file and, optionally, a chromosome sizes file.

In [2]:
import crested

In [7]:
# Set the genome
genome = crested.Genome("mm10/genome.fa", "mm10/genome.chrom.sizes")
crested.register_genome(genome)  # Register the genome so that it's automatically used in every function

2026-02-16T15:04:11.542386+0100 INFO Genome genome registered.


In [3]:
# Download the tutorial data
beds_folder, regions_file = crested.get_dataset("mouse_cortex_bed")

Downloading file 'data/mouse_biccn/beds.tar.gz' from 'https://resources.aertslab.org/CREsted/data/mouse_biccn/beds.tar.gz' to '/staging/leuven/stg_00002/lcb/cblaauw'.


  0%|                                              | 0.00/12.1M [00:00<?, ?B/s]

Untarring contents of '/staging/leuven/stg_00002/lcb/cblaauw/data/mouse_biccn/beds.tar.gz' to '/staging/leuven/stg_00002/lcb/cblaauw/data/mouse_biccn/beds.tar.gz.untar'


We can import a folder of BED files using the {func}`crested.import_beds` function.  
This will return an AnnData object with the regions as .var and the BED file names as .obs (here: our topics).  
In this case, the adata.X values are binary, representing whether that region is associated with a topic or not.

In [4]:
# Import the beds into an AnnData object - the regions file is optional for import_beds
adata = crested.import_beds(beds_folder=beds_folder, regions_file=regions_file)
adata

2026-02-16T15:03:15.642825+0100 INFO Reading bed files from /staging/leuven/stg_00002/lcb/cblaauw/data/mouse_biccn/beds.tar.gz.untar and using /staging/leuven/stg_00002/lcb/cblaauw/data/mouse_biccn/consensus_peaks_biccn.bed as var_names...


AnnData object with n_obs × n_vars = 80 × 439383
    obs: 'file_path', 'n_open_regions'
    var: 'n_classes', 'chr', 'start', 'end'

We have 80 classes (topics) and 439383 regions in the dataset.
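To make the layout of this matrix concrete, here is a minimal sketch in plain Python that builds the same kind of binary topics-by-regions table from toy data. The region names, topic names, and the `build_binary_matrix` helper are made up for illustration; this is not how `crested.import_beds` is implemented.

```python
# Toy stand-ins for a consensus regions file and two per-topic BED files.
# All names and coordinates are hypothetical.
consensus_regions = ["chr1:100-600", "chr1:900-1400", "chr2:50-550"]
topic_beds = {
    "Topic_1": ["chr1:100-600", "chr2:50-550"],
    "Topic_2": ["chr1:900-1400"],
}


def build_binary_matrix(topic_beds, consensus_regions):
    """Return (topic_names, matrix): one row per topic, one column per region."""
    topic_names = sorted(topic_beds)
    matrix = []
    for topic in topic_names:
        members = set(topic_beds[topic])
        # 1 if the region is listed in this topic's BED file, else 0
        matrix.append([1 if region in members else 0 for region in consensus_regions])
    return topic_names, matrix


topic_names, matrix = build_binary_matrix(topic_beds, consensus_regions)
for name, row in zip(topic_names, matrix):
    print(name, row)
```

Each row corresponds to one topic (an entry in .obs) and each column to one consensus region (an entry in .var), which mirrors the 80 × 439383 shape shown above.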

## Preprocessing

For topic classification there is little preprocessing to be performed compared to peak regression.  
The data does not need to be normalized because the values are binary, and we don't filter regions on specificity because topic modelling should already have selected 'meaningful' regions.  
You could change the width of the regions, but we tend to keep the regions at 500bp for topic classification.  
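If you did want a different width, the idea is simply to re-center each region on its midpoint. Below is a minimal sketch of that arithmetic in plain Python; `resize_region` is a hypothetical helper written for this illustration, not a CREsted function, and it only clips at the chromosome start (a real implementation would also need the chromosome sizes to clip at the end).

```python
def resize_region(chrom, start, end, width=500):
    """Return a region of exactly `width` bp centered on the original midpoint."""
    center = (start + end) // 2
    # Clip at 0 so that regions near the chromosome start stay valid
    new_start = max(0, center - width // 2)
    return chrom, new_start, new_start + width


print(resize_region("chr1", 1000, 1300))
```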

The only preprocessing step we need to perform is to split the data into training and testing sets.

In [5]:
# Standard train/val/test split
crested.pp.train_val_test_split(adata, strategy="chr", val_chroms=["chr8", "chr10"], test_chroms=["chr9", "chr18"])
print(adata.var["split"].value_counts())

2026-02-16T15:03:29.634609+0100 INFO Lazily importing module crested.pp. This could take a second...
split
train    354013
val       45113
test      40257
Name: count, dtype: int64
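The chromosome-based strategy is easy to reason about: every region on a validation or test chromosome goes to that set, and everything else is used for training. A minimal sketch of that logic on toy region names is shown below; `assign_split` is a hypothetical helper for illustration and makes no claim about how `crested.pp.train_val_test_split` is implemented internally.

```python
from collections import Counter


def assign_split(chrom, val_chroms, test_chroms):
    """Assign a region to a split based only on its chromosome."""
    if chrom in val_chroms:
        return "val"
    if chrom in test_chroms:
        return "test"
    return "train"


# Toy regions in "chrom:start-end" format (coordinates are made up)
regions = ["chr1:100-600", "chr8:200-700", "chr9:300-800", "chr10:50-550", "chr2:10-510"]
val_chroms = {"chr8", "chr10"}
test_chroms = {"chr9", "chr18"}

splits = [assign_split(region.split(":")[0], val_chroms, test_chroms) for region in regions]
print(Counter(splits))
```

Holding out whole chromosomes, rather than random regions, avoids overlapping or neighbouring sequences ending up in both the training and evaluation sets.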


## Model training

Model training has the same workflow as peak regression. The only differences are:
1. We select a different model architecture. Since we're training on 500bp regions we don't need the dilated convolutions of the dilated CNN.  
2. We select a different config, since we're monitoring other metrics and are using a different loss for classification.  
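The reason for the different loss is that a region can belong to several topics at once, so each topic is treated as an independent yes/no prediction. The sketch below computes a mean binary cross-entropy over all (region, topic) entries in plain Python. It is only meant to illustrate the formula on toy numbers; the function name and values are made up, and the actual training uses the Keras loss from the config.

```python
import math


def binary_cross_entropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over all (region, topic) entries."""
    total, count = 0.0, 0
    for true_row, pred_row in zip(y_true, y_pred):
        for t, p in zip(true_row, pred_row):
            # Clip probabilities to avoid log(0)
            p = min(max(p, eps), 1 - eps)
            total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
            count += 1
    return total / count


# Two toy regions, three toy topics; a region may be in more than one topic
y_true = [[1, 0, 1], [0, 0, 1]]
y_pred = [[0.9, 0.1, 0.8], [0.2, 0.1, 0.7]]
print(round(binary_cross_entropy(y_true, y_pred), 4))
```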

In [8]:
# Datamodule
datamodule = crested.tl.data.AnnDataModule(
    adata,
    batch_size=128,  # lower this if you encounter OOM errors
    max_stochastic_shift=3,  # optional augmentation
    always_reverse_complement=True,  # default True. Will double the effective size of the training dataset.
)

# Architecture: we will use the DeepTopic CNN model
model_architecture = crested.tl.zoo.deeptopic_cnn(seq_len=500, num_classes=80)

# Config: we will use the default topic classification config (binary cross-entropy loss; auROC, auPR and categorical accuracy metrics)
config = crested.tl.default_configs("topic_classification")
print(config)

2026-02-16T15:04:21.343827+0100 INFO Lazily importing module crested.tl. This could take a second...


I0000 00:00:1771250736.314264 2306437 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78751 MB memory:  -> device: 0, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:26:00.0, compute capability: 9.0


TaskConfig(optimizer=<keras.src.optimizers.adam.Adam object at 0x14ba081b86e0>, loss=<LossFunctionWrapper(<function binary_crossentropy at 0x14ba02aa5080>, kwargs={'from_logits': False, 'label_smoothing': 0.0, 'axis': -1})>, metrics=[<AUC name=auROC>, <AUC name=auPR>, <CategoricalAccuracy name=categorical_accuracy>])


Set up the trainer object and train the model:

In [9]:
trainer = crested.tl.Crested(
    data=datamodule,
    model=model_architecture,
    config=config,
    project_name="mouse_biccn",  # change to your liking
    run_name="topic_classification",
    logger='wandb',  # or 'tensorboard', None
)

In [10]:
trainer.fit(epochs=100)

None
2026-02-16T15:05:49.654130+0100 INFO Loading sequences into memory...


100%|██████████| 354013/354013 [00:05<00:00, 62612.85it/s]


2026-02-16T15:05:55.914284+0100 INFO Loading sequences into memory...


100%|██████████| 45113/45113 [00:00<00:00, 75312.84it/s]


Epoch 1/100
  10/5532 ━━━━━━━━━━━━━━━━━━━━ 1:09 13ms/step - auPR: 0.0467 - auROC: 0.5093 - categorical_accuracy: 0.0112 - loss: 0.7205

I0000 00:00:1771250792.706058 2306866 device_compiler.h:196] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


5532/5532 ━━━━━━━━━━━━━━━━━━━━ 128s 17ms/step - auPR: 0.0931 - auROC: 0.6814 - categorical_accuracy: 0.0411 - loss: 0.1609 - val_auPR: 0.1200 - val_auROC: 0.7129 - val_categorical_accuracy: 0.0336 - val_loss: 0.1599 - learning_rate: 0.0010
Epoch 2/100
5532/5532 ━━━━━━━━━━━━━━━━━━━━ 76s 14ms/step - auPR: 0.1176 - auROC: 0.7174 - categorical_accuracy: 0.0476 - loss: 0.1559 - val_auPR: 0.1425 - val_auROC: 0.7447 - val_categorical_accuracy: 0.0503 - val_loss: 0.1573 - learning_rate: 0.0010
Epoch 3/100
5532/5532 ━━━━━━━━━━━━━━━━━━━━ 72s 13ms/step - auPR: 0.1341 - auROC: 0.7440 - categorical_accuracy: 0.0584 - loss: 0.1552 - val_auPR: 0.1548 - val_auROC: 0.7579 - val_categorical_accuracy: 0.0579 - val_loss: 0.1573 - learning_rate: 0.0010
Epoch 4/100
5532/5532 ━━━━━━━━━━━━━━━━━━━━ 75s 14ms/step - auPR: 0.1429 - auROC: 0.7535 - categorical_accuracy: 0.0663 - l

## Evaluation and prediction

Evaluation and prediction work the same way as for peak regression.

The next steps you could take are to:
1. Evaluate the model on the test set.
2. Predict topic probabilities for a given sequence or region.
3. Run TF-MoDISco to find motifs associated with each topic.
4. Generate synthetic sequences for each topic using in silico evolution.
5. Plot contribution scores per topic for interesting regions or sequences. 

Refer to [the introduction notebook](project:model_training_and_eval) for more details.
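Once you have a vector of predicted topic probabilities for a sequence, a common first step is to look at which topics score highest. The sketch below ranks topics for one toy prediction in plain Python; the probabilities, topic names, and the `top_topics` helper are made up for illustration and do not come from the model trained above.

```python
def top_topics(probabilities, topic_names, k=3):
    """Return the k (topic, probability) pairs with the highest probability."""
    ranked = sorted(zip(topic_names, probabilities), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Hypothetical prediction for one sequence over five toy topics
topic_names = [f"Topic_{i}" for i in range(1, 6)]
probabilities = [0.05, 0.81, 0.10, 0.64, 0.02]

print(top_topics(probabilities, topic_names, k=2))
```

Because the topics are predicted independently, the probabilities do not need to sum to one, and several topics can score highly for the same sequence.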