<a href="https://colab.research.google.com/github/Bochong01/DiffBulk/blob/main/DiffBulk_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 📘 DiffBulk: Enhancing Spatial Transcriptomic Prediction with Diffusion-Based Training

Bochong Zhang (+,1), Tianyi Zhang (+,1,2), Qiaochu Xue (1,3), Zeyu Liu (3), Dankai Liao (1,3), Timothy Antoni (2), YEO HUI TING GRACE (2), Sicheng Chen (3), Hwee Kuan LEE (2), Shangqing Lyu (\*,3), and Yueming Jin (\*,1)

Affiliations:

(1) National University of Singapore (NUS)
(2) Agency for Science, Technology and Research (A*STAR)
(3) PuzzleLogic Pte Ltd

(+) Authors contributed equally
(*) Corresponding authors

# 📖 About This Notebook

This notebook provides the **official** Colab walkthrough of DiffBulk, a two-stage diffusion-based framework designed to learn gene-aware histology image representations for spatial transcriptomic prediction.

* We suggest using Google Colab Pro+ (high RAM and GPU) for this walkthrough

* Our github page: https://github.com/Bochong01/DiffBulk

It demonstrates:

**Stage 1: Diffusion Pretraining**

Learning conditional image representations guided by gene expression profiles.

**Stage 2: Downstream Gene Expression Training & Evaluation**

Using the pretrained diffusion encoder to train a lightweight gene prediction module.

## Set up

In [40]:
# check GPU
!nvidia-smi

shell-init: error retrieving current directory: getcwd: cannot access parent directories: No such file or directory
Mon Nov 24 14:49:52 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:00:05.0 Off |                    0 |
| N/A   35C    P0             55W /  400W |       0MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------
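The same check can be scripted so that later cells fail early when no GPU is attached. The sketch below is a minimal example that uses only the standard library and the `nvidia-smi` query flags; the helper name `query_gpus` is an illustrative choice, not part of DiffBulk.

```python
import shutil
import subprocess


def query_gpus():
    """Return one 'name, total memory' string per visible NVIDIA GPU, or [] if none."""
    # nvidia-smi is missing on CPU-only runtimes.
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=False,
    )
    if result.returncode != 0:
        return []
    return [line.strip() for line in result.stdout.splitlines() if line.strip()]


gpus = query_gpus()
print(gpus if gpus else "No GPU detected; switch the Colab runtime to a GPU before continuing.")
```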

## 📁 Create File-System Environment


In [48]:
!git clone https://github.com/Bochong01/DiffBulk.git
# Use %cd rather than !cd: a "!" command runs in its own subshell, so its directory change is lost
%cd /content/DiffBulk


Cloning into 'DiffBulk'...
remote: Enumerating objects: 118, done.
remote: Counting objects: 100% (118/118), done.
remote: Compressing objects: 100% (99/99), done.
remote: Total 118 (delta 29), reused 86 (delta 17), pack-reused 0 (from 0)
Receiving objects: 100% (118/118), 1.18 MiB | 52.33 MiB/s, done.
Resolving deltas: 100% (29/29), done.
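Each `!` command in a notebook runs in its own child shell, so a directory change made there does not carry over to later cells. Only `%cd` or `os.chdir` changes the directory of the notebook process itself. The sketch below reproduces that behaviour with plain Python; the function name and the use of the system temp directory are illustrative choices.

```python
import os
import subprocess
import tempfile


def cd_persistence_demo():
    """Compare a `cd` in a child shell with os.chdir in the current process."""
    start = os.getcwd()
    target = tempfile.gettempdir()

    # Equivalent of `!cd <dir>`: the change happens inside the child shell only.
    subprocess.run(f'cd "{target}"', shell=True, check=True)
    after_shell = os.getcwd()

    # os.chdir changes the working directory of this Python process.
    os.chdir(target)
    after_chdir = os.getcwd()

    os.chdir(start)  # restore the original directory
    return {"start": start, "after_shell": after_shell, "after_chdir": after_chdir, "target": target}


print(cd_persistence_demo())
```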


## 📥 Load the Demo Data

Here we load a minimal demo dataset that allows the user to:

- run diffusion pretraining

- train the downstream gene predictor

- run the evaluation

The demo dataset is intentionally lightweight so that the full pipeline can be executed within Colab Pro GPU limits.


In [4]:
# make data dir
!mkdir -p /content/data

In [5]:
import gdown

file_id = "1_ZxmAJD4ld2N_sXi_dOLJAoeWpOFdCY_"
url = f"https://drive.google.com/uc?id={file_id}"
output = "/content/data/DiffBulk_data.zip"  # rename if you like

gdown.download(url, output, quiet=False)

Downloading...
From (original): https://drive.google.com/uc?id=1_ZxmAJD4ld2N_sXi_dOLJAoeWpOFdCY_
From (redirected): https://drive.google.com/uc?id=1_ZxmAJD4ld2N_sXi_dOLJAoeWpOFdCY_&confirm=t&uuid=21ff3453-cbbb-4408-89ac-7a4a4fdbc5cd
To: /content/data/DiffBulk_data.zip
100%|██████████| 4.02G/4.02G [00:55<00:00, 72.6MB/s]


'/content/data/DiffBulk_data.zip'

In [8]:
# unzip
!unzip -q /content/data/DiffBulk_data.zip -d /content/data

# check
!ls -lh /content/data | head -n 30


total 3.8G
drwxr-xr-x 2 root root 4.0K Nov 14 15:35 crunchDAO
-rw-r--r-- 1 root root 3.8G Nov 14 08:40 DiffBulk_data.zip
drwxr-xr-x 2 root root 4.0K Nov 14 15:50 HEST_bowel
drwxr-xr-x 2 root root 4.0K Nov 14 15:38 HEST_pancreas
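If the shell `unzip` tool is not available, the archive can also be verified and extracted with Python's `zipfile` module. The following is a minimal sketch; `extract_archive` is a hypothetical helper, and the integrity check reads the whole archive, which takes a while for a file of this size.

```python
import zipfile
from pathlib import Path


def extract_archive(zip_path, dest_dir):
    """Verify a zip archive, extract it, and return its sorted top-level entry names."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        # testzip() returns the name of the first corrupt member, or None.
        bad_member = archive.testzip()
        if bad_member is not None:
            raise zipfile.BadZipFile(f"Corrupt member: {bad_member}")
        archive.extractall(dest)
        top_level = {name.split("/", 1)[0] for name in archive.namelist()}
    return sorted(top_level)


# Usage in Colab:
# extract_archive("/content/data/DiffBulk_data.zip", "/content/data")
```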


## 🗂️ Arrange the Working Environment

In [49]:
# change working dir
import os
os.chdir("/content/DiffBulk")
!pwd


/content/DiffBulk


# 🧬 DiffBulk Pipeline Illustration

The DiffBulk framework follows a two-stage training and evaluation process.
Below is the structure demonstrated in this notebook:

**1. Diffusion Pretraining**

- Learns gene-aware histology image representations via conditional diffusion modeling

- Produces a pretrained U-Net used for downstream tasks

- Runs post-EMA reconstruction to extract EMA-smoothed checkpoints

**2. Downstream Gene Expression Training**

- Uses the pretrained U-Net (EMA) from Stage 1

- Trains a lightweight module combined with a foundation model (PLIP)

- Performs evaluation and metric reporting

⚠️ **Notes**

- All hyperparameters are managed through `.yaml` configuration files

- For simplicity, the 3-fold experiment used in the full evaluation is omitted in this Colab walkthrough

- The notebook focuses on demonstrating the workflow and key components, not full-scale training

## 🌀 Stage I: Diffusion Pretraining
In this notebook, we demonstrate **Stage I: Diffusion Pretraining** using a **lightweight configuration** suitable for Google Colab.

To ensure fast execution, we use:

- ~2 epochs (instead of the full training schedule)

- Batch size = 128

- Three datasets, consistent with the experimental configuration in the paper, but in a simplified demo mode


⚠️ **Note:**
The **full-scale experiments** reported in the DiffBulk paper use the settings defined in `Pretrain/train.sh`. Those settings involve significantly longer training and larger batch sizes, which are not practical on Colab.

This demo focuses on illustrating:

- How the diffusion model is trained

- How gene-aware conditional denoising works

- How pretrained EMA checkpoints are produced for downstream tasks

In [50]:
!torchrun --standalone --nproc_per_node=1 /content/DiffBulk/Pretrain/train.py \
    --outdir="/content/DiffBulk/Pretrain/outputs" \
    --patch_path="/content/data/HEST_bowel/train_patch.h5" \
    --patch_path="/content/data/HEST_pancreas/train_patch.h5" \
    --patch_path="/content/data/crunchDAO/train_patch.h5" \
    --gene_path="/content/data/HEST_bowel/train_gene.h5" \
    --gene_path="/content/data/HEST_pancreas/train_gene.h5" \
    --gene_path="/content/data/crunchDAO/train_gene.h5" \
    --valid_patch_path="/content/data/HEST_bowel/valid_patch.h5" \
    --valid_patch_path="/content/data/HEST_pancreas/valid_patch.h5" \
    --valid_patch_path="/content/data/crunchDAO/valid_patch.h5" \
    --valid_gene_path="/content/data/HEST_bowel/valid_gene.h5" \
    --valid_gene_path="/content/data/HEST_pancreas/valid_gene.h5" \
    --valid_gene_path="/content/data/crunchDAO/valid_gene.h5" \
    --embed_dim=256 \
    --num_gene_blocks=2 \
    --preset="gene-img224-xs" \
    --batch_size=128 \
    --duration=$((1<<15)) \
    --status=$((5<<7)) \
    --snapshot=$((1<<10)) \
    --checkpoint=$((1<<10)) \
    --batch-gpu=16 \
    --valid_interval_nimg=$((1<<10)) \
    --valid_batch_size=64 \
    --p=0.5

2025-11-24 14:52:41.057628: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-24 14:52:41.075357: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763995961.096881   48834 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763995961.103566   48834 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1763995961.120252   48834 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 
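The `$((a<<b))` arguments in the command above are bash left shifts. The sketch below spells out what they evaluate to and what they imply for the run. It assumes that `--duration` and `--snapshot` count training images (as the `nimg` suffix of `--valid_interval_nimg` suggests) and that `--batch-gpu` caps the per-GPU micro-batch, with the remainder handled by gradient accumulation; neither assumption is confirmed by this notebook. The helper name `schedule` is hypothetical.

```python
def schedule(duration, snapshot, batch_size, batch_gpu, num_gpus=1):
    """Derive iteration and snapshot counts from image-based training flags."""
    return {
        "iterations": duration // batch_size,            # optimizer steps in the run
        "snapshots": duration // snapshot,                # snapshots written per EMA std
        "iters_per_snapshot": snapshot // batch_size,     # steps between snapshots
        # Assumption: micro-batches of batch_gpu images are accumulated per step.
        "accumulation_rounds": batch_size // (batch_gpu * num_gpus),
    }


duration = 1 << 15   # 32768
status = 5 << 7      # 640
snapshot = 1 << 10   # 1024

print(schedule(duration, snapshot, batch_size=128, batch_gpu=16, num_gpus=1))
```

Under these assumptions the run takes 256 optimizer steps and writes 32 snapshots, which is consistent with the 64 pickles (two EMA stds per snapshot) loaded by the reconstruction step.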

## 📘 Post-EMA Reconstruction

Extract EMA-smoothed checkpoints for downstream training.

In [51]:
# Reconstruct new EMA profiles with std = 0.100, 0.150, and 0.200
!python /content/DiffBulk/Pretrain/reconstruct_phema.py --indir="/content/DiffBulk/Pretrain/outputs" \
    --outdir="/content/DiffBulk/Pretrain/ema" \
    --outstd=0.10,0.15,0.20

Loading 64 input pickles...
    /content/DiffBulk/Pretrain/outputs/network-snapshot-0000001-0.050.pkl
    /content/DiffBulk/Pretrain/outputs/network-snapshot-0000001-0.100.pkl
    /content/DiffBulk/Pretrain/outputs/network-snapshot-0000002-0.050.pkl
    /content/DiffBulk/Pretrain/outputs/network-snapshot-0000002-0.100.pkl
    /content/DiffBulk/Pretrain/outputs/network-snapshot-0000003-0.050.pkl
    /content/DiffBulk/Pretrain/outputs/network-snapshot-0000003-0.100.pkl
    /content/DiffBulk/Pretrain/outputs/network-snapshot-0000004-0.050.pkl
    /content/DiffBulk/Pretrain/outputs/network-snapshot-0000004-0.100.pkl
    /content/DiffBulk/Pretrain/outputs/network-snapshot-0000005-0.050.pkl
    /content/DiffBulk/Pretrain/outputs/network-snapshot-0000005-0.100.pkl
    /content/DiffBulk/Pretrain/outputs/network-snapshot-0000006-0.050.pkl
    /content/DiffBulk/Pretrain/outputs/network-snapshot-0000006-0.100.pkl
    /content/DiffBulk/Pretrain/outputs/network-snapshot-0000007-0.050.pkl
    /conte
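Before running the reconstruction, it can help to confirm which snapshots and EMA stds are present in the output directory. The sketch below groups snapshot files by their std suffix. The filename pattern is inferred from the listing above, and `group_snapshots` is a hypothetical helper rather than part of the repository.

```python
import re
from collections import defaultdict
from pathlib import Path

# Pattern inferred from names such as network-snapshot-0000001-0.050.pkl
SNAPSHOT_RE = re.compile(r"network-snapshot-(\d+)-(\d+\.\d+)\.pkl$")


def group_snapshots(paths):
    """Map each EMA std (as a string, e.g. '0.050') to its sorted snapshot indices."""
    groups = defaultdict(list)
    for path in paths:
        match = SNAPSHOT_RE.search(Path(path).name)
        if match:  # ignore files that are not network snapshots
            groups[match.group(2)].append(int(match.group(1)))
    return {std: sorted(indices) for std, indices in sorted(groups.items())}


# Usage in Colab:
# group_snapshots(Path("/content/DiffBulk/Pretrain/outputs").glob("*.pkl"))
```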

## 🔬 Stage II: Downstream Gene Expression Training

This stage fine-tunes a fusion network that integrates:

- Diffusion-pretrained gene-aware image features, and

- A pathology foundation model (e.g., PLIP)

**1. Configure `Downstream/config.yaml`**

Key arguments:

  - `diffusion_path`: Path to the post-EMA reconstructed checkpoint from Stage I (usually under `Pretrain/ema/`).
  - `noise_label`: A small amount of Gaussian noise added to image patches during training, improving robustness.
  - `out_dim`: Total number of gene targets being predicted.
  - `fusion_method`: Specifies how the diffusion branch interacts with the foundation model branch. For example, `"gated_residual"` adaptively fuses two feature streams.
  - `c` / `c_learnable`: Weight controlling the contribution of the diffusion features. When `c_learnable=True`, the model learns this weight automatically.
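To make the roles of `fusion_method` and `c` more concrete, here is a deliberately simplified sketch of one common form of gated residual fusion: the foundation-model feature is kept as the residual path, and the diffusion feature is added through a sigmoid gate scaled by `c`. This is an illustration only. The function, its arguments, and the exact formula are assumptions and may differ from the module implemented in `Downstream/`; in a real network the gate logits would come from a learned layer and `c` would be a trainable parameter when `c_learnable` is true.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def gated_residual_fusion(foundation_feat, diffusion_feat, gate_logits, c=1.0):
    """Hypothetical fusion: out_i = f_i + c * sigmoid(g_i) * d_i for each feature i."""
    if not (len(foundation_feat) == len(diffusion_feat) == len(gate_logits)):
        raise ValueError("All inputs must have the same length")
    return [
        f + c * sigmoid(g) * d
        for f, d, g in zip(foundation_feat, diffusion_feat, gate_logits)
    ]


# A gate logit of 0 gives a gate of 0.5, so half of each diffusion feature is added.
print(gated_residual_fusion([1.0, 2.0], [2.0, 4.0], [0.0, 0.0], c=1.0))
```

With `c=0.0` the output reduces to the foundation-model features, which is why a weight of this kind controls how much the diffusion branch contributes.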

## ✅ Copy the following YAML to `Downstream/config.yaml`:

```yaml
# data
train_patch_file: "/content/data/HEST_bowel/train_patch.h5"
train_gene_file: "/content/data/HEST_bowel/train_gene.h5"
valid_patch_file: "/content/data/HEST_bowel/valid_patch.h5"
valid_gene_file: "/content/data/HEST_bowel/valid_gene.h5"

# pretrained model
diffusion_path: "/content/DiffBulk/Pretrain/ema/phema-0000032-0.100.pkl"

# hyper-parameters
noise_label: 0.01
out_dim: 541
fusion_method: 'gated_residual'
c: 1.0
c_learnable: True

# training
epochs: 4
device: cuda
batch_size: 32
lr: 0.0001
weight_decay: 0.00001

# logging
tensorboard_dir: "./tensorboard"
checkpoint_dir: "./ckpts"
log_interval: 2
valid_interval: 2
start_valid: 0
```
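In Colab, the YAML above can be written to disk from a cell instead of being pasted in by hand. The `%%writefile` cell magic is one option; the sketch below is a plain-Python alternative. `write_text_file` is a hypothetical helper, and the variable `CONFIG_YAML` in the usage comment stands for a string holding the YAML shown above.

```python
from pathlib import Path


def write_text_file(path, text):
    """Write text to path, creating parent directories, and return the Path."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    # Normalise surrounding blank lines and end the file with a single newline.
    target.write_text(text.strip() + "\n", encoding="utf-8")
    return target


# Usage in Colab (CONFIG_YAML is assumed to hold the YAML shown above):
# write_text_file("/content/DiffBulk/Downstream/config.yaml", CONFIG_YAML)
```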

**2. Start training**

In [66]:
# change working dir
import os
os.chdir("/content/DiffBulk/Downstream")
!pwd

/content/DiffBulk/Downstream


In [67]:
!python train.py --config "./config.yaml"

2025-11-24 15:48:34.284693: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-24 15:48:34.302475: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763999314.324044   63807 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763999314.330673   63807 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1763999314.347290   63807 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

## 🧪 Testing

After Stage II training completes, we perform evaluation on the test split.

**1. Prepare `Downstream/test_config.yaml`**

Key parameters:

- `diffusion_path`: Same EMA checkpoint used during training.
- `noise_label`: Must match the training configuration.
- `out_dim`, `fusion_method`: Must be identical to the training settings to keep the architecture consistent.
- `fusion_net_path`: Path to the best checkpoint saved during Stage II training.


## ✅ Copy the following YAML to `Downstream/test_config.yaml`:

```yaml
# data
test_patch_file: "/content/data/HEST_bowel/test_patch.h5"
test_gene_file: "/content/data/HEST_bowel/test_gene.h5"

# pretrained model
diffusion_path: "/content/DiffBulk/Pretrain/ema/phema-0000032-0.100.pkl"

# ckpt
fusion_net_path: "/content/DiffBulk/Downstream/ckpts/checkpoint_best.pth"

device: cuda
batch_size: 32

# fusion net architecture
noise_label: 0.01
out_dim: 541
fusion_method: "gated_residual"
c: 1.0
c_learnable: True
```
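Because the architecture-related keys must agree between the training and test configurations, a quick consistency check can catch mismatches before the test script runs. The sketch below uses a tiny parser that only handles the flat `key: value` layout used in these two files; it is not a general YAML parser, and a library such as PyYAML would be the more robust choice. The helper names and the `SHARED_KEYS` tuple are illustrative.

```python
def parse_flat_yaml(text):
    """Parse flat `key: value` lines into a dict of strings (not a full YAML parser)."""
    values = {}
    for raw_line in text.splitlines():
        line = raw_line.split("#", 1)[0].strip()  # drop comments and whitespace
        if not line or ":" not in line:
            continue
        key, value = line.split(":", 1)
        # Strip optional single or double quotes so 'x' and "x" compare equal.
        values[key.strip()] = value.strip().strip("'\"")
    return values


SHARED_KEYS = ("diffusion_path", "noise_label", "out_dim", "fusion_method", "c", "c_learnable")


def config_mismatches(train_text, test_text, keys=SHARED_KEYS):
    """Return {key: (train_value, test_value)} for every shared key that differs."""
    train = parse_flat_yaml(train_text)
    test = parse_flat_yaml(test_text)
    return {k: (train.get(k), test.get(k)) for k in keys if train.get(k) != test.get(k)}


# Usage in Colab:
# from pathlib import Path
# base = Path("/content/DiffBulk/Downstream")
# print(config_mismatches((base / "config.yaml").read_text(), (base / "test_config.yaml").read_text()))
```

An empty dictionary means the shared keys agree, including cases where one file uses single quotes and the other uses double quotes.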

**2. Run Test Script**

In [69]:
!bash test.sh

Testing model with config: ./test_config.yaml
Testing using configuration: ./test_config.yaml
2025-11-24 16:01:33.608469: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-24 16:01:33.626233: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1764000093.647321   67227 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1764000093.653783   67227 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1764000093.670044   67227 computation
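The log above is cut off before any metrics appear, and this notebook does not state which metrics `test.sh` reports. As an illustration of how gene expression predictions are often scored, the sketch below computes the Pearson correlation per gene and averages it across genes. The choice of metric, the function names, and the spots-by-genes list layout are all assumptions, not DiffBulk's evaluation code.

```python
import math


def pearson(x, y):
    """Pearson correlation of two equal-length sequences; NaN if either is constant."""
    n = len(x)
    if n != len(y) or n < 2:
        raise ValueError("Inputs must have the same length and at least two values")
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    if var_x == 0 or var_y == 0:
        return float("nan")
    return cov / math.sqrt(var_x * var_y)


def mean_gene_pcc(pred, target):
    """Average per-gene Pearson correlation; rows are spots, columns are genes."""
    num_genes = len(pred[0])
    scores = [
        pearson([row[g] for row in pred], [row[g] for row in target])
        for g in range(num_genes)
    ]
    valid = [s for s in scores if not math.isnan(s)]  # skip constant genes
    return sum(valid) / len(valid) if valid else float("nan")
```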