# Example based on our README.md
1. Dataset download
2. Prepare dataset
3. Pre-training
4. Fine-tuning
5. Different backbones


### Download the dataset

In [1]:
# using the cli
# Every command in this cell invokes the `wejepa.datasets.download` module with `--dataset-root ./data`.
# Only the CUB-200 command at the bottom is active; the others are kept commented out as
# ready-to-run alternatives (remove the leading `#` to use one).

# download class dataset
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name tsbpp/fall2025_deeplearning --splits train

# for development, download a small subset
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name tsbpp/fall2025_deeplearning --splits 'train[:10]'

# download the class pretrain dataset raw data
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name tsbpp/fall2025_deeplearning --snapshot-download --splits train

# download cifar100 dataset
#!python -m wejepa.datasets.download --dataset-root ./data --dataset-name cifar100

# download cub200 dataset (requests both the train and test splits)
!python -m wejepa.datasets.download --dataset-root ./data --dataset-name cub200 --splits train,test

Downloading CUB-200-2011 from https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz...
Download complete!
Extracting CUB-200-2011 dataset...
Extraction complete!
Generating CUB-200 train/val/test splits
Split 'train,test' available under /home/long/code/dl_project1/experiments/data


### Inspect and Prepare Dataset

In [None]:
# prepare raw images into huggingface dataset format (not needed with the current dataset structure)
# !python -m wejepa.datasets.prepare_images --input-dir ./data/tsbpp_fall2025_deeplearning --output-dir ./data/tsbpp___fall2025_deeplearning

### Pre-training the model

In [None]:
# Using the cli

# Clear
!export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Train using default cifar100 config + custom ViT backbone
# !python -m wejepa.train.pretrain --print-config     # print only
# !python -m wejepa.train.pretrain                    # train

# FIXME: bug when using .arrow files, the file path is not correctly set, workaround is to rename the arrow file
#   cp fall2025_deeplearning-train.arrow tsbpp___fall2025_deeplearning-train.arrow

# print where --config searches for config files
# !python -m wejepa.train.pretrain --config hf224_config.json

#!PYTHONWARNINGS="ignore::RuntimeWarning" python -m wejepa.train.pretrain --config configs/pretrain_devel_tsbpp_224.json
#!PYTHONWARNINGS="ignore::RuntimeWarning" python -m wejepa.train.pretrain --config configs/pretrain_vit_b_16_tsbpp_224.json
#!PYTHONWARNINGS="ignore::RuntimeWarning" python -m wejepa.train.pretrain --config configs/pretrain_convnext_small_tsbpp_224.json
!PYTHONWARNINGS="ignore::RuntimeWarning" CUDA_VISIBLE_DEVICES=1,2,3 python -m wejepa.train.pretrain --config configs/pretrain_swin_s_tsbpp_224.json

In [None]:
# programmatically
# Python-API equivalent of running `python -m wejepa.train.pretrain` with no arguments
# (presumably the same defaults — confirm against the CLI entry point).
from wejepa import default_config, launch_pretraining
# Build the library's default configuration object and pass it unchanged to the launcher.
cfg = default_config()
launch_pretraining(cfg)

### Fine-tuning the model

In [3]:
# using the cli
# Runs `wejepa.train.finetune` starting from the checkpoint `outputs/vit_b_16/ijepa_epoch_0026.pt`
# with the CUB-200 fine-tuning config. The explicit flags (--epochs, --batch-size, --lr,
# --num-classes) are passed alongside --config; presumably they override the values in the
# config file — TODO confirm precedence in the CLI parser.
!python -m wejepa.train.finetune \
    --checkpoint outputs/vit_b_16/ijepa_epoch_0026.pt \
    --epochs 20 \
    --batch-size 256 \
    --lr 3e-4 \
    --num-classes 200 \
    --config configs/finetune_vit_b_16_cub200_224.json


[Linear probe] Epoch 1/20 | loss=5.3369 | train_acc=0.006 | val_acc=0.008
[Linear probe] Epoch 2/20 | loss=5.2132 | train_acc=0.011 | val_acc=0.018
[Linear probe] Epoch 3/20 | loss=5.1646 | train_acc=0.014 | val_acc=0.016
[Linear probe] Epoch 4/20 | loss=5.1370 | train_acc=0.015 | val_acc=0.014
[Linear probe] Epoch 5/20 | loss=5.1236 | train_acc=0.015 | val_acc=0.012
[Linear probe] Epoch 6/20 | loss=5.1107 | train_acc=0.018 | val_acc=0.019
[Linear probe] Epoch 7/20 | loss=5.0951 | train_acc=0.019 | val_acc=0.018
[Linear probe] Epoch 8/20 | loss=5.0966 | train_acc=0.017 | val_acc=0.016
^C


In [1]:
# Same fine-tuning invocation as the previous cell, but using the
# `finetune_pretrained_vit_b_16_cub200_224.json` config and the `--debug` flag, which makes the
# script print the resolved fine-tune configuration (the `[DEBUG] Fine-tune config: ...` output).
!python -m wejepa.train.finetune \
    --checkpoint outputs/vit_b_16/ijepa_epoch_0026.pt \
    --epochs 20 \
    --batch-size 256 \
    --lr 3e-4 \
    --num-classes 200 \
    --config configs/finetune_pretrained_vit_b_16_cub200_224.json --debug

[DEBUG] Fine-tune config: FinetuneConfig(ijepa=IJepaConfig(data=DataConfig(dataset_root='./data', dataset_name='cub200', dataset_dir=None, image_size=224, train_batch_size=256, eval_batch_size=512, num_workers=8, pin_memory=True, persistent_workers=True, prefetch_factor=2, crop_scale=(0.6, 1.0), color_jitter=None, use_color_distortion=False, use_horizontal_flip=False, normalization_mean=(0.5071, 0.4867, 0.4408), normalization_std=(0.2675, 0.2565, 0.2761), use_fake_data=False, fake_data_size=512, image_dir=None, image_list=None, labels=None), mask=MaskConfig(target_aspect_ratio=(0.75, 1.5), target_scale=(0.15, 0.2), context_aspect_ratio=1.0, context_scale=(0.85, 1.0), num_target_blocks=4), model=ModelConfig(img_size=224, patch_size=16, in_chans=3, embed_dim=768, enc_depth=6, pred_depth=4, num_heads=12, post_emb_norm=False, layer_dropout=0.0, classification_backbone='vit_b_16', classification_num_classes=200, classification_pretrained=True, model_bypass=True), optimizer=OptimizerConfig(e

In [None]:
# programmatically
# Python-API counterpart of the CLI fine-tuning cells: build a FinetuneConfig and run the linear probe.
from wejepa.train import FinetuneConfig, train_linear_probe

# Only the four fields below are set explicitly; any other FinetuneConfig field is left at
# whatever default the class defines.
# NOTE(review): this checkpoint path (outputs/ijepa/, epoch 5) differs from the one used by the
# CLI cells (outputs/vit_b_16/, epoch 26) — confirm the file exists before running.
ft_cfg = FinetuneConfig(
    checkpoint_path="outputs/ijepa/ijepa_epoch_0005.pt",
    epochs=5,
    batch_size=128,
    learning_rate=1e-3,
)
train_linear_probe(ft_cfg)

### Different Backbones

In [None]:
import json
import torch
from copy import deepcopy
from pathlib import Path

from wejepa.backbones import adapt_config_for_backbone, available_backbones
from wejepa.config import IJepaConfig
from wejepa import default_config, launch_pretraining, IJEPA_base

# Flip to True to actually launch pre-training for the backbones listed in `candidates`.
# Left off by default so that running this cell only inspects models and writes config files.
RUN_PRETRAINING = False
candidates = ["vit_b_16", "swin_t", "convnext_tiny"]

# Query the registry once and reuse the result for both the listing and the main loop.
registered_backbones = list(available_backbones())

print("Registered backbones: ")
for backbone in registered_backbones:
    print(f"- {backbone}")

# For every registered backbone: adapt the default config, build the model, report its size,
# run one dummy forward pass as a shape check, and save the adapted config as JSON.
for backbone in registered_backbones:
    cfg = adapt_config_for_backbone(default_config(), backbone)
    print(f"\nBackbone: {backbone}")
    print(f"Image size: {cfg.model.img_size} | Patch size: {cfg.model.patch_size}")

    model = IJEPA_base(
        img_size=cfg.model.img_size,
        patch_size=cfg.model.patch_size,
        in_chans=cfg.model.in_chans,
        embed_dim=cfg.model.embed_dim,
        enc_depth=cfg.model.enc_depth,
        pred_depth=cfg.model.pred_depth,
        num_heads=cfg.model.num_heads,
        backbone=cfg.model.classification_backbone,
        pretrained=cfg.model.classification_pretrained,
    )

    print(f"Total trainable params: {model.count_trainable_parameters() / 1e6:.2f}M")
    print(f"Student + predictor params: {model.count_parameters() / 1e6:.2f}M")

    # Single-image dummy batch; no gradients are needed for a shape check, so skip autograd
    # bookkeeping to save memory.
    dummy = torch.randn(1, cfg.model.in_chans, cfg.model.img_size, cfg.model.img_size)
    with torch.no_grad():
        preds, targets = model(dummy)
    print(f"Pred shape: {tuple(preds.shape)} | Target shape: {tuple(targets.shape)}")

    # Set the per-backbone output directory BEFORE serializing, so the JSON printed here is
    # exactly the JSON written to disk (serialize once, use twice).
    cfg.hardware.output_dir = f"./outputs/ijepa/{backbone}"
    cfg_json = json.dumps(cfg.to_dict(), indent=2)
    print(cfg_json)

    cfg_path = Path(f"configs/pretrain_{backbone}.json")
    cfg_path.parent.mkdir(parents=True, exist_ok=True)
    cfg_path.write_text(cfg_json)
    print(f"Saved config for {backbone} at {cfg_path}")

    # Opt-in pre-training, limited to the candidate backbones. The message is only printed
    # when training is really about to start.
    if RUN_PRETRAINING and backbone in candidates:
        print(f"\nPretraining with backbone: {backbone}")
        launch_pretraining(cfg)