# Example based on our README.md
1. Dataset download
2. Prepare dataset
3. Pre-training
4. Fine-tuning
5. Different backbones


### Download the dataset

In [1]:
# Download datasets with the CLI (`python -m wejepa.datasets.download`).
# Only one command below is active; the others are kept commented out as
# ready-to-use alternatives. All of them write under ./data (--dataset-root).

# Class dataset (tsbpp/fall2025_deeplearning), train split
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name tsbpp/fall2025_deeplearning --splits train

# For development: download only a small subset (split slice 'train[:10]')
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name tsbpp/fall2025_deeplearning --splits 'train[:10]'

# ACTIVE: download the raw data of the class pre-training dataset
# (--snapshot-download), with --debug logging enabled
!python -m wejepa.datasets.download --dataset-root ./data --dataset-name tsbpp/fall2025_deeplearning --snapshot-download --splits train --debug

# CIFAR-100 dataset
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name cifar100

# CUB-200 dataset, train and test splits
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name cub200 --splits train,test

[DEBUG] Preparing download for dataset=tsbpp/fall2025_deeplearning splits=('train',) root=/home/lquang/Code/wejepa/experiments/data
[DEBUG] Loading HuggingFace dataset 'tsbpp/fall2025_deeplearning' split='train' cache_dir=/home/lquang/Code/wejepa/experiments/data
Downloading (incomplete total...): 0.00B [00:00, ?B/s]
Downloading (incomplete total...): 2.46kB [00:00, 14.2kB/s] 0/6 [00:00<?, ?it/s][A
Fetching 6 files:  17%|████▌                      | 1/6 [00:00<00:00,  6.13it/s][A
Fetching 6 files: 100%|███████████████████████████| 6/6 [00:03<00:00,  1.76it/s][A
Download complete: : 2.46kB [00:03, 14.2kB/s]              [DEBUG] Using snapshot_download for dataset 'tsbpp/fall2025_deeplearning' split='train'
[DEBUG] Downloaded dataset to /home/lquang/Code/wejepa/experiments/data/tsbpp_fall2025_deeplearning, extracting archives if any.

Download complete: : 2.46kB [00:03, 686B/s]               | 0/5 [00:00<?, ?it/s][A
Extracted /home/lquang/Code/wejepa/experiments/data/tsbpp_fall2025_d

### Inspect and Prepare Dataset

In [None]:
# Convert raw images into Hugging Face dataset format (not needed with the current dataset structure)
# !python -m wejepa.datasets.prepare_images --input-dir ./data/tsbpp_fall2025_deeplearning --output-dir ./data/tsbpp___fall2025_deeplearning

### Pre-training the model

In [None]:
# Using the cli

# Clear
!export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Train using default cifar100 config + custom ViT backbone
# !python -m wejepa.train.pretrain --print-config     # print only
# !python -m wejepa.train.pretrain                    # train

# FIXME: bug when using .arrow files, the file path is not correctly set, workaround is to rename the arrow file
#   cp fall2025_deeplearning-train.arrow tsbpp___fall2025_deeplearning-train.arrow

# print where --config searches for config files
# !python -m wejepa.train.pretrain --config hf224_config.json

#!PYTHONWARNINGS="ignore::RuntimeWarning" python -m wejepa.train.pretrain --config configs/pretrain_devel_tsbpp_224.json
!PYTHONWARNINGS="ignore::RuntimeWarning" python -m wejepa.train.pretrain --config configs/pretrain_convnext_small.json

Model has 70,881,792 trainable parameters.
Loading imagefolder dataset from directory: ./data/tsbpp_fall2025_deeplearning
Loading imagefolder dataset from directory: ./data/tsbpp_fall2025_deeplearning
Loading imagefolder dataset from directory: ./data/tsbpp_fall2025_deeplearning
Loading imagefolder dataset from directory: ./data/tsbpp_fall2025_deeplearning


In [None]:
# Programmatic alternative to the CLI cell above: build the package's default
# configuration and start pre-training with it from inside the kernel.
from wejepa import default_config, launch_pretraining
cfg = default_config()  # default settings; adjust fields on `cfg` before launching if needed
launch_pretraining(cfg)  # NOTE(review): presumably runs the full training loop and blocks until done - confirm

### Fine-tuning the model

In [1]:
# Fine-tune with the CLI (`python -m wejepa.train.finetune`); the recorded
# output of this cell logs it as a "[Linear probe]" run.
# --checkpoint points at a pre-training checkpoint (presumably written by the
# pre-training step above - confirm the file exists before running).
# --num-classes 200 goes with the CUB-200 development config passed via --config.
!python -m wejepa.train.finetune \
    --checkpoint outputs/ijepa/ijepa_epoch_0005.pt \
    --epochs 10 \
    --batch-size 256 \
    --lr 3e-4 \
    --num-classes 200 \
    --config configs/finetune_devel_cub200_224.json

[Linear probe] Epoch 1/10 | loss=4.7092 | train_acc=0.021 | val_acc=0.038
[Linear probe] Epoch 2/10 | loss=4.5012 | train_acc=0.036 | val_acc=0.049
[Linear probe] Epoch 3/10 | loss=4.4424 | train_acc=0.045 | val_acc=0.056
[Linear probe] Epoch 4/10 | loss=4.4062 | train_acc=0.046 | val_acc=0.059
[Linear probe] Epoch 5/10 | loss=4.3753 | train_acc=0.050 | val_acc=0.063
[Linear probe] Epoch 6/10 | loss=4.3507 | train_acc=0.054 | val_acc=0.069
[Linear probe] Epoch 7/10 | loss=4.3309 | train_acc=0.055 | val_acc=0.067
[Linear probe] Epoch 8/10 | loss=4.3173 | train_acc=0.056 | val_acc=0.071
[Linear probe] Epoch 9/10 | loss=4.3029 | train_acc=0.057 | val_acc=0.070
[Linear probe] Epoch 10/10 | loss=4.2934 | train_acc=0.060 | val_acc=0.070


In [3]:
# Programmatic linear-probe fine-tuning through the Python API instead of the CLI.
# NOTE(review): the saved output of this cell is
# "ModuleNotFoundError: No module named 'wejepa'" - presumably the kernel's
# environment does not have the package installed; confirm before relying on it.
from wejepa.train import FinetuneConfig, train_linear_probe

# Settings for the probe run, gathered in one place so they are easy to tweak.
probe_options = {
    "checkpoint_path": "outputs/ijepa/ijepa_epoch_0005.pt",
    "epochs": 5,
    "batch_size": 128,
    "learning_rate": 1e-3,
}

ft_cfg = FinetuneConfig(**probe_options)
train_linear_probe(ft_cfg)

ModuleNotFoundError: No module named 'wejepa'

### Different Backbones

In [None]:
# For every registered backbone: adapt the default config, build the model,
# smoke-test one forward pass on random input, and save the config as
# configs/pretrain_<backbone>.json so it can be passed to the CLI via --config.
import json
import torch
from copy import deepcopy
from pathlib import Path

from wejepa.backbones import adapt_config_for_backbone, available_backbones
from wejepa.config import IJepaConfig
from wejepa import default_config, launch_pretraining, IJEPA_base

# Backbones intended for actual pre-training once the (disabled) launch at the
# bottom of the loop is switched on. The earlier loop that printed
# "Pretraining with backbone: ..." for these names without training anything
# was removed because its output was misleading.
candidates = ["vit_b_16", "swin_t", "convnext_tiny"]

print("Registered backbones: ")
for backbone in available_backbones():
    print(f"- {backbone}")

for backbone in available_backbones():
    cfg = adapt_config_for_backbone(default_config(), backbone)

    # Set the per-backbone output directory BEFORE serialising, so the config
    # printed below is exactly the one written to disk.
    cfg.hardware.output_dir = f"./outputs/ijepa/{backbone}"
    cfg_json = json.dumps(cfg.to_dict(), indent=2)

    print(f"\nBackbone: {backbone}")
    print(f"Image size: {cfg.model.img_size} | Patch size: {cfg.model.patch_size}")

    model = IJEPA_base(
        img_size=cfg.model.img_size,
        patch_size=cfg.model.patch_size,
        in_chans=cfg.model.in_chans,
        embed_dim=cfg.model.embed_dim,
        enc_depth=cfg.model.enc_depth,
        pred_depth=cfg.model.pred_depth,
        num_heads=cfg.model.num_heads,
        backbone=cfg.model.classification_backbone,
        pretrained=cfg.model.classification_pretrained,
    )

    print(f"Total trainable params: {model.count_trainable_parameters() / 1e6:.2f}M")
    print(f"Student + predictor params: {model.count_parameters() / 1e6:.2f}M")

    # Shape check only: a single random image, no gradients needed, so skip
    # building the autograd graph.
    dummy = torch.randn(1, cfg.model.in_chans, cfg.model.img_size, cfg.model.img_size)
    with torch.no_grad():
        preds, targets = model(dummy)
    print(f"Pred shape: {tuple(preds.shape)} | Target shape: {tuple(targets.shape)}")
    print(cfg_json)

    cfg_path = Path(f"configs/pretrain_{backbone}.json")
    cfg_path.parent.mkdir(parents=True, exist_ok=True)
    cfg_path.write_text(cfg_json)
    print(f"Saved config for {backbone} at {cfg_path}")

    # Pre-training is intentionally disabled in this walkthrough; uncomment to
    # train the selected candidates.
    # if backbone in candidates:
    #     launch_pretraining(cfg)