# Examples based on our README.md
1. Dataset download
2. Prepare dataset
3. Pre-training
4. Fine-tuning
5. Different backbones


### Download the dataset

In [1]:
# using the cli

# download class dataset
#####  hugging face
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name tsbpp/fall2025_deeplearning --splits train
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name tsbpp/fall2025_deeplearning --splits 'train[:10]'
# download the class pretrain dataset raw data
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name tsbpp/fall2025_deeplearning --snapshot-download --splits train
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name gjuggler/extra-birds --snapshot-download --splits train
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name Sijuade/Cats-Dogs-Birds --snapshot-download --splits train

##### torchvision datasets
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name cifar100
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name voc2007
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name caltech101
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name caltech256
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name flowers102
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name coco_detection
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name coco_captions
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name imagenet --splits train

##### From URL
# download cub200 dataset
!python -m wejepa.datasets.download --dataset-root ./data --dataset-name cub200 --splits train,test

Downloading CUB-200-2011 from https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz...
Download complete!
Extracting CUB-200-2011 dataset...
Extraction complete!
Generating CUB-200 train/val/test splits
Split 'train,test' available under /home/long/code/dl_project1/experiments/data


### Inspect and prepare the dataset

In [None]:
# prepare raw images into huggingface dataset format (not needed with the current dataset structure)
# The (disabled) command below converts the raw image folder given by --input-dir into a
# Hugging Face dataset written to --output-dir.
# NOTE(review): the output directory uses the "tsbpp___" (triple underscore) naming, presumably to
# match the layout the download step / Hugging Face cache produces — confirm before re-enabling.
# !python -m wejepa.datasets.prepare_images --input-dir ./data/tsbpp_fall2025_deeplearning --output-dir ./data/tsbpp___fall2025_deeplearning

### Pre-training the model

In [1]:
# Using the cli

# Clear
!export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Train using default cifar100 config + custom ViT backbone
# !python -m wejepa.train.pretrain --print-config     # print only
# !python -m wejepa.train.pretrain                    # train

# FIXME: bug when using .arrow files, the file path is not correctly set, workaround is to rename the arrow file
#   cp fall2025_deeplearning-train.arrow tsbpp___fall2025_deeplearning-train.arrow

# print where --config searches for config files
# !python -m wejepa.train.pretrain --config hf224_config.json

#!PYTHONWARNINGS="ignore::RuntimeWarning" python -m wejepa.train.pretrain --config configs/pretrain_devel_tsbpp_224.json
#!PYTHONWARNINGS="ignore::RuntimeWarning" python -m wejepa.train.pretrain --config configs/pretrain_vit_b_16_tsbpp_224.json
#!PYTHONWARNINGS="ignore::RuntimeWarning" python -m wejepa.train.pretrain --config configs/pretrain_convnext_small_tsbpp_224.json
#!PYTHONWARNINGS="ignore::RuntimeWarning" CUDA_VISIBLE_DEVICES=1,2,3 python -m wejepa.train.pretrain --config configs/pretrain_swin_s_tsbpp_224.json
#!PYTHONWARNINGS="ignore::RuntimeWarning" CUDA_VISIBLE_DEVICES=1,2,3 python -m wejepa.train.pretrain --config configs/pretrain_vit_b_16_devel_224.json

!PYTHONWARNINGS="ignore::RuntimeWarning" CUDA_VISIBLE_DEVICES=0, python -m wejepa.train.pretrain --resume-checkpoint outputs/vit_b_16/2025_11_30/ijepa_epoch_0068.pt/ --config configs/domain_adaption_vit_b_16_cub200.json

Model has 70,881,792 trainable parameters.
Training: 100%|█████████████████████████████████| 64/64 [00:28<00:00,  2.22it/s]
Epoch completed.
Saved checkpoint to outputs/vit_b_16/domain_adaption/ijepa_epoch_0001.pt
Epoch 1/300 | loss=0.1490
Training: 100%|█████████████████████████████████| 64/64 [00:26<00:00,  2.37it/s]
Epoch completed.
Saved checkpoint to outputs/vit_b_16/domain_adaption/ijepa_epoch_0002.pt
Epoch 2/300 | loss=0.1145
Training: 100%|█████████████████████████████████| 64/64 [00:27<00:00,  2.32it/s]
Epoch completed.
Saved checkpoint to outputs/vit_b_16/domain_adaption/ijepa_epoch_0003.pt
Epoch 3/300 | loss=0.1047
Training: 100%|█████████████████████████████████| 64/64 [00:27<00:00,  2.30it/s]
Epoch completed.
Saved checkpoint to outputs/vit_b_16/domain_adaption/ijepa_epoch_0004.pt
Epoch 4/300 | loss=0.0968
Training: 100%|█████████████████████████████████| 64/64 [00:28<00:00,  2.28it/s]
Epoch completed.
Saved checkpoint to outputs/vit_b_16/domain_adaption/ijepa_epoch_0005.p

In [None]:
# Programmatic equivalent of the CLI's default run: build the default config and pretrain with it.
from wejepa import (
    default_config,
    launch_pretraining,
)

cfg = default_config()  # same defaults the CLI uses when no --config is given
launch_pretraining(cfg)

### Fine-tuning the model

In [5]:
# using the cli
!CUDA_VISIBLE_DEVICES=0 python -m wejepa.train.finetune \
    --checkpoint outputs/vit_b_16/domain_adaption/ijepa_epoch_0300.pt \
    --epochs 20 \
    --batch-size 256 \
    --lr 3e-3 \
    --num-classes 200 \
    --config configs/finetune_vit_b_16_cub200_224.json


[Linear probe] Epoch 1/20 | loss=5.5793 | train_acc=0.007 | val_acc=0.009
[Linear probe] Epoch 2/20 | loss=5.3701 | train_acc=0.013 | val_acc=0.012
[Linear probe] Epoch 3/20 | loss=5.3185 | train_acc=0.011 | val_acc=0.012
[Linear probe] Epoch 4/20 | loss=5.2710 | train_acc=0.015 | val_acc=0.010
[Linear probe] Epoch 5/20 | loss=5.2821 | train_acc=0.016 | val_acc=0.010
[Linear probe] Epoch 6/20 | loss=5.2607 | train_acc=0.016 | val_acc=0.013
[Linear probe] Epoch 7/20 | loss=5.2597 | train_acc=0.014 | val_acc=0.015
[Linear probe] Epoch 8/20 | loss=5.2538 | train_acc=0.014 | val_acc=0.013
[Linear probe] Epoch 9/20 | loss=5.2138 | train_acc=0.016 | val_acc=0.014
[Linear probe] Epoch 10/20 | loss=5.2034 | train_acc=0.018 | val_acc=0.011
[Linear probe] Epoch 11/20 | loss=5.2017 | train_acc=0.021 | val_acc=0.013
[Linear probe] Epoch 12/20 | loss=5.1967 | train_acc=0.019 | val_acc=0.016
[Linear probe] Epoch 13/20 | loss=5.2055 | train_acc=0.020 | val_acc=0.013
[Linear probe] Epoch 14/20 | loss=

In [4]:
!python -m wejepa.train.finetune \
    --checkpoint outputs/vit_b_16/ijepa_epoch_0025.pt \
    --epochs 20 \
    --batch-size 256 \
    --lr 3e-4 \
    --num-classes 200 \
    --config configs/finetune_vit_b_16_bypass_cub200_224.json

[Linear probe] Epoch 1/20 | loss=5.2969 | train_acc=0.007 | val_acc=0.013
[Linear probe] Epoch 2/20 | loss=5.1500 | train_acc=0.015 | val_acc=0.020
[Linear probe] Epoch 3/20 | loss=5.0920 | train_acc=0.018 | val_acc=0.015
[Linear probe] Epoch 4/20 | loss=5.0558 | train_acc=0.024 | val_acc=0.019
[Linear probe] Epoch 5/20 | loss=5.0325 | train_acc=0.024 | val_acc=0.019
[Linear probe] Epoch 6/20 | loss=5.0094 | train_acc=0.026 | val_acc=0.021
[Linear probe] Epoch 7/20 | loss=4.9930 | train_acc=0.027 | val_acc=0.021
[Linear probe] Epoch 8/20 | loss=4.9780 | train_acc=0.026 | val_acc=0.019
[Linear probe] Epoch 9/20 | loss=4.9636 | train_acc=0.030 | val_acc=0.024
[Linear probe] Epoch 10/20 | loss=4.9496 | train_acc=0.031 | val_acc=0.026
[Linear probe] Epoch 11/20 | loss=4.9411 | train_acc=0.032 | val_acc=0.021
[Linear probe] Epoch 12/20 | loss=4.9260 | train_acc=0.031 | val_acc=0.026
[Linear probe] Epoch 13/20 | loss=4.9186 | train_acc=0.031 | val_acc=0.024
^C


In [None]:
!python -m wejepa.train.finetune \
    --checkpoint outputs/vit_b_16/ijepa_epoch_0025.pt \
    --epochs 20 \
    --batch-size 256 \
    --lr 3e-4 \
    --num-classes 200 \
    --config configs/finetune_vit_b_16_bypass_jepa_cub200_224.json

In [8]:
!python -m wejepa.train.finetune \
    --checkpoint outputs/vit_b_16/ijepa_epoch_0025.pt \
    --epochs 20 \
    --batch-size 256 \
    --lr 3e-4 \
    --num-classes 200 \
    --config configs/finetune_pretrained_vit_b_16_cub200_224.json

[Linear probe] Epoch 1/20 | loss=4.8953 | train_acc=0.071 | val_acc=0.172
[Linear probe] Epoch 2/20 | loss=3.9557 | train_acc=0.271 | val_acc=0.336
[Linear probe] Epoch 3/20 | loss=3.2911 | train_acc=0.421 | val_acc=0.419
[Linear probe] Epoch 4/20 | loss=2.8250 | train_acc=0.520 | val_acc=0.488
[Linear probe] Epoch 5/20 | loss=2.4839 | train_acc=0.575 | val_acc=0.516
[Linear probe] Epoch 6/20 | loss=2.2262 | train_acc=0.615 | val_acc=0.537
^C


In [None]:
!python -m wejepa.train.finetune \
    --checkpoint outputs/vit_b_16/ijepa_epoch_0025.pt \
    --epochs 20 \
    --batch-size 256 \
    --lr 3e-4 \
    --num-classes 200 \
    --config configs/finetune_pretrained__no_bypass_vit_b_16_cub200_224.json

[Linear probe] Epoch 1/20 | loss=5.2838 | train_acc=0.013 | val_acc=0.017


In [None]:
!python -m wejepa.train.finetune \
    --checkpoint outputs/vit_b_16/ijepa_epoch_0026.pt \
    --epochs 20 \
    --batch-size 256 \
    --lr 3e-4 \
    --num-classes 200 \
    --config configs/finetune_pretrained__no_bypass_vit_b_16_jepa_cub200_224.json

In [None]:
# Programmatic equivalent of the finetune CLI: train a linear probe on top of a saved checkpoint.
from wejepa.train import (
    FinetuneConfig,
    train_linear_probe,
)

ft_cfg = FinetuneConfig(
    checkpoint_path="outputs/ijepa/ijepa_epoch_0005.pt",  # checkpoint whose encoder is probed
    learning_rate=1e-3,
    batch_size=128,
    epochs=5,
)
train_linear_probe(ft_cfg)

### Different backbones

In [None]:
import json
import torch
from copy import deepcopy
from pathlib import Path

from wejepa.backbones import adapt_config_for_backbone, available_backbones
from wejepa.config import IJepaConfig
from wejepa import default_config, launch_pretraining, IJEPA_base

# For every registered backbone: build an I-JEPA model, smoke-test one forward pass,
# and write a per-backbone pretraining config to configs/pretrain_<backbone>.json.
# Actual pretraining only runs for `candidates`, and only when RUN_PRETRAINING is True.
candidates = ["vit_b_16", "swin_t", "convnext_tiny"]
RUN_PRETRAINING = False

registered_backbones = list(available_backbones())
print("Registered backbones: ")
for backbone in registered_backbones:
    print(f"- {backbone}")

unregistered_candidates = sorted(set(candidates) - set(registered_backbones))
if unregistered_candidates:
    print(f"Warning: candidate backbones not registered: {unregistered_candidates}")

for backbone in registered_backbones:
    cfg = adapt_config_for_backbone(default_config(), backbone)
    # Set the output dir first so the printed config is exactly the one saved below.
    cfg.hardware.output_dir = f"./outputs/ijepa/{backbone}"
    print(f"\nBackbone: {backbone}")
    print(f"Image size: {cfg.model.img_size} | Patch size: {cfg.model.patch_size}")

    model = IJEPA_base(
        img_size=cfg.model.img_size,
        patch_size=cfg.model.patch_size,
        in_chans=cfg.model.in_chans,
        embed_dim=cfg.model.embed_dim,
        enc_depth=cfg.model.enc_depth,
        pred_depth=cfg.model.pred_depth,
        num_heads=cfg.model.num_heads,
        backbone=cfg.model.classification_backbone,
        pretrained=cfg.model.classification_pretrained,
    )

    # NOTE(review): confirm these labels match how IJEPA_base defines the two counters
    # (e.g. whether count_parameters() also includes a non-trainable teacher).
    print(f"Total trainable params: {model.count_trainable_parameters() / 1e6:.2f}M")
    print(f"Student + predictor params: {model.count_parameters() / 1e6:.2f}M")

    # Shape smoke test on one random image; no_grad avoids building an unused autograd graph.
    dummy = torch.randn(1, cfg.model.in_chans, cfg.model.img_size, cfg.model.img_size)
    with torch.no_grad():
        preds, targets = model(dummy)
    print(f"Pred shape: {tuple(preds.shape)} | Target shape: {tuple(targets.shape)}")

    cfg_json = json.dumps(cfg.to_dict(), indent=2)
    print(cfg_json)

    cfg_path = Path(f"configs/pretrain_{backbone}.json")
    cfg_path.parent.mkdir(parents=True, exist_ok=True)
    cfg_path.write_text(cfg_json)
    print(f"Saved config for {backbone} at {cfg_path}")

    if RUN_PRETRAINING and backbone in candidates:
        print(f"\nPretraining with backbone: {backbone}")
        launch_pretraining(cfg)