# Example based on our README.md
1. Dataset download
2. Pre-training
3. Fine-tuning
4. Different backbones


### 1. Download the dataset

In [1]:
# using the cli

# download class dataset
!python -m wejepa.datasets.download --dataset-root ./data --dataset-name tsbpp/fall2025_deeplearning --splits train

# for development, download a small subset
# !python -m wejepa.datasets.download --dataset-root ./data --dataset-name tsbpp/fall2025_deeplearning --splits 'train[:10]'

# download cifar100 dataset
!python -m wejepa.datasets.download --dataset-root ./data --dataset-name cifar100

# download cub200 dataset
!python -m wejepa.datasets.download --dataset-root ./data --dataset-name cub200 --splits train,test

Split 'train' available under /home/long/PhD/Coursework/Deep_Learning/Project/Code/ijepa/experiments/data
Split 'train' available under /home/long/PhD/Coursework/Deep_Learning/Project/Code/ijepa/experiments/data
Split 'test' available under /home/long/PhD/Coursework/Deep_Learning/Project/Code/ijepa/experiments/data
Downloading CUB-200-2011 from https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz...
Download complete!
Extracting CUB-200-2011 dataset...
Extraction complete!
Split 'train,test' available under /home/long/PhD/Coursework/Deep_Learning/Project/Code/ijepa/experiments/data


### 2. Pre-training the model

In [None]:
# Using the cli

# Clear
!export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Train using default cifar100 config + custom ViT backbone
# !python -m wejepa.train.pretrain --print-config     # print only
# !python -m wejepa.train.pretrain                    # train
_
# FIXME: bug when using .arrow files, the file path is not correctly set, workaround is to rename the arrow file
#   cp fall2025_deeplearning-train.arrow tsbpp___fall2025_deeplearning-train.arrow

# print where --config searches for config files
!python -m wejepa.train.pretrain --config hf224_config.json

Model has 4,449,408 trainable parameters.
Training:  25%|████████                        | 49/195 [00:10<00:27,  5.37it/s]Epoch 1 Iter 50/195 | Loss 0.2530 | 1244.2 img/s
Training:  51%|████████████████▏               | 99/195 [00:19<00:16,  5.74it/s]Epoch 1 Iter 100/195 | Loss 0.2100 | 1310.0 img/s
Training:  76%|███████████████████████▋       | 149/195 [00:28<00:08,  5.20it/s]Epoch 1 Iter 150/195 | Loss 0.1971 | 1319.0 img/s
Training: 100%|███████████████████████████████| 195/195 [00:37<00:00,  5.19it/s]
Epoch completed.
Saved checkpoint to /home/long/PhD/Coursework/Deep_Learning/Project/Code/ijepa/experiments/outputs/ijepa/ijepa_epoch_0001.pt
Epoch 1/5 | loss=0.1895
Training:  25%|████████                        | 49/195 [00:09<00:27,  5.23it/s]Epoch 2 Iter 50/195 | Loss 0.1672 | 1355.3 img/s
Training:  51%|████████████████▏               | 99/195 [00:18<00:14,  6.59it/s]Epoch 2 Iter 100/195 | Loss 0.1604 | 1358.7 img/s
Training:  76%|███████████████████████▋       | 149/195 [00:28<

''

In [None]:
# programmatically
from wejepa import default_config, launch_pretraining
cfg = default_config()
launch_pretraining(cfg)

### 3. Fine tuning the model

In [5]:
# using the cli
!python -m wejepa.train.finetune \
    --checkpoint outputs/ijepa/ijepa_epoch_0005.pt \
    --epochs 10 \
    --batch-size 256 \
    --lr 3e-4 \
    --num-classes 100

[Linear probe] Epoch 1/10 | loss=4.5246 | train_acc=0.027 | val_acc=0.042
[Linear probe] Epoch 2/10 | loss=4.4346 | train_acc=0.040 | val_acc=0.049
[Linear probe] Epoch 3/10 | loss=4.3911 | train_acc=0.042 | val_acc=0.052
[Linear probe] Epoch 4/10 | loss=4.3613 | train_acc=0.046 | val_acc=0.056
[Linear probe] Epoch 5/10 | loss=4.3389 | train_acc=0.047 | val_acc=0.063
[Linear probe] Epoch 6/10 | loss=4.3260 | train_acc=0.049 | val_acc=0.064
[Linear probe] Epoch 7/10 | loss=4.3143 | train_acc=0.051 | val_acc=0.063
[Linear probe] Epoch 8/10 | loss=4.3070 | train_acc=0.051 | val_acc=0.069
[Linear probe] Epoch 9/10 | loss=4.2993 | train_acc=0.054 | val_acc=0.069
[Linear probe] Epoch 10/10 | loss=4.2907 | train_acc=0.055 | val_acc=0.072


In [3]:
# programmatically
from wejepa.train import FinetuneConfig, train_linear_probe

ft_cfg = FinetuneConfig(
    checkpoint_path="outputs/ijepa/ijepa_epoch_0005.pt",
    epochs=5,
    batch_size=128,
    learning_rate=1e-3,
)
train_linear_probe(ft_cfg)

[Linear probe] Epoch 1/5 | loss=4.4194 | train_acc=0.037 | val_acc=0.055
[Linear probe] Epoch 2/5 | loss=4.3261 | train_acc=0.046 | val_acc=0.062
[Linear probe] Epoch 3/5 | loss=4.2992 | train_acc=0.051 | val_acc=0.065
[Linear probe] Epoch 4/5 | loss=4.2842 | train_acc=0.053 | val_acc=0.073
[Linear probe] Epoch 5/5 | loss=4.2742 | train_acc=0.055 | val_acc=0.070


LinearProbe(
  (backbone): IJEPA_base(
    (patch_embed): PatchEmbed(
      (conv): Conv2d(3, 192, kernel_size=(4, 4), stride=(4, 4))
    )
    (post_emb_norm): Identity()
    (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
    (student_encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerBlock(
          (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=192, out_features=192, bias=True)
          )
          (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (0): Linear(in_features=192, out_features=768, bias=True)
            (1): GELU(approximate='none')
            (2): Dropout(p=0.0, inplace=False)
            (3): Linear(in_features=768, out_features=192, bias=True)
            (4): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (norm): LayerNo

### 4. Different Backbones

In [4]:
import json
import torch
from copy import deepcopy
from pathlib import Path

from wejepa.backbones import adapt_config_for_backbone, available_backbones
from wejepa.config import IJepaConfig
from wejepa import default_config, launch_pretraining, IJEPA_base

print("Registered backbones: ")
for backbone in available_backbones():
    print(f"- {backbone}")

candidates = ["vit_b_16", "swin_t", "convnext_tiny"]
for backbone in candidates:
    print(f"\nPretraining with backbone: {backbone}")

for backbone in available_backbones():
    cfg = adapt_config_for_backbone(default_config(), backbone)
    print(f"\nBackbone: {backbone}")
    print(f"Image size: {cfg.model.img_size} | Patch size: {cfg.model.patch_size}")

    model = IJEPA_base(
        img_size=cfg.model.img_size,
        patch_size=cfg.model.patch_size,
        in_chans=cfg.model.in_chans,
        embed_dim=cfg.model.embed_dim,
        enc_depth=cfg.model.enc_depth,
        pred_depth=cfg.model.pred_depth,
        num_heads=cfg.model.num_heads,
        backbone=cfg.model.classification_backbone,
        pretrained=cfg.model.classification_pretrained,
    )

    print(f"Total trainable params: {model.count_trainable_parameters() / 1e6:.2f}M")
    print(f"Student + predictor params: {model.count_parameters() / 1e6:.2f}M")

    dummy = torch.randn(1, cfg.model.in_chans, cfg.model.img_size, cfg.model.img_size)
    preds, targets = model(dummy)
    print(f"Pred shape: {tuple(preds.shape)} | Target shape: {tuple(targets.shape)}")
    print(json.dumps(cfg.to_dict(), indent=2))

    cfg.hardware.output_dir = f"./outputs/ijepa/{backbone}"
    cfg_path = Path(f"configs/pretrain_{backbone}.json")
    cfg_path.parent.mkdir(parents=True, exist_ok=True)
    cfg_path.write_text(json.dumps(cfg.to_dict(), indent=2))
    print(f"Saved config for {backbone} at {cfg_path}")

    # launch_pretraining(cfg)

Registered backbones: 
- convnext_small
- convnext_tiny
- swin_s
- swin_t
- vit_b_16

Pretraining with backbone: vit_b_16

Pretraining with backbone: swin_t

Pretraining with backbone: convnext_tiny

Backbone: convnext_small
Image size: 224 | Patch size: 32
Total trainable params: 121.15M
Student + predictor params: 70.88M
Pred shape: (4, 1, 9, 768) | Target shape: (4, 1, 9, 768)
{
  "data": {
    "dataset_root": "/home/long/PhD/Coursework/Deep_Learning/Project/Code/ijepa/experiments/data",
    "dataset_name": "cifar100",
    "image_size": 224,
    "train_batch_size": 256,
    "eval_batch_size": 512,
    "num_workers": 4,
    "pin_memory": true,
    "persistent_workers": true,
    "prefetch_factor": 2,
    "crop_scale": [
      0.6,
      1.0
    ],
    "color_jitter": 0.5,
    "use_color_distortion": true,
    "use_horizontal_flip": true,
    "normalization_mean": [
      0.5071,
      0.4867,
      0.4408
    ],
    "normalization_std": [
      0.2675,
      0.2565,
      0.2761
    