# Example based on our README.md
1. Dataset download
2. Prepare dataset
3. Pre-training
4. Fine-tuning
5. Different backbones


### Download the dataset

In [None]:
# Using the CLI. Every command below passes ./data as --dataset-root.

# Download the class dataset (train split)
!python -m wejepa.datasets.download --dataset-root ./data --dataset-name tsbpp/fall2025_deeplearning --splits train

# For development, download only a small subset (the `train[:10]` slice)
!python -m wejepa.datasets.download --dataset-root ./data --dataset-name tsbpp/fall2025_deeplearning --splits 'train[:10]'

# Download the raw data of the class pretrain dataset (--snapshot-download)
!python -m wejepa.datasets.download --dataset-root ./data --dataset-name tsbpp/fall2025_deeplearning --snapshot-download --splits train

# Download the CIFAR-100 dataset (no --splits passed)
!python -m wejepa.datasets.download --dataset-root ./data --dataset-name cifar100

# Download the CUB-200 dataset (train and test splits)
!python -m wejepa.datasets.download --dataset-root ./data --dataset-name cub200 --splits train,test

### Inspect and Prepare Dataset

In [None]:
# Convert raw images into Hugging Face dataset format (not needed with the current dataset structure)
# !python -m wejepa.datasets.prepare_images --input-dir ./data/tsbpp_fall2025_deeplearning --output-dir ./data/tsbpp___fall2025_deeplearning

### Pre-training the model

In [1]:
# Using the cli

# Clear
!export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Train using default cifar100 config + custom ViT backbone
# !python -m wejepa.train.pretrain --print-config     # print only
# !python -m wejepa.train.pretrain                    # train

# FIXME: bug when using .arrow files, the file path is not correctly set, workaround is to rename the arrow file
#   cp fall2025_deeplearning-train.arrow tsbpp___fall2025_deeplearning-train.arrow

# print where --config searches for config files
# !python -m wejepa.train.pretrain --config hf224_config.json

!python -m wejepa.train.pretrain --config hf224_config.json

Model has 4,449,408 trainable parameters.
Loading imagefolder dataset from directory: ./data/tsbpp___fall2025_deeplearning
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/long/PhD/Coursework/Deep_Learning/Project/Code/ijepa/src/wejepa/train/pretrain.py", line 354, in <module>
    main()
  File "/home/long/PhD/Coursework/Deep_Learning/Project/Code/ijepa/src/wejepa/train/pretrain.py", line 350, in main
    launch_pretraining(cfg)
  File "/home/long/PhD/Coursework/Deep_Learning/Project/Code/ijepa/src/wejepa/train/pretrain.py", line 325, in launch_pretraining
    _train_worker(0, 1, cfg.to_dict())
  File "/home/long/PhD/Coursework/Deep_Learning/Project/Code/ijepa/src/wejepa/train/pretrain.py", line 256, in _train_worker
    data_loader, sampler = create_pretraining_dataloader(cfg, rank=rank, world_size=world_size)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In [None]:
# Programmatically: call launch_pretraining directly from Python instead of going
# through `python -m wejepa.train.pretrain`.
# default_config() supplies the package's default configuration; no JSON config
# file (e.g. hf224_config.json) is loaded here.
from wejepa import default_config, launch_pretraining
cfg = default_config()
launch_pretraining(cfg)

### Fine-tuning the model

In [None]:
# Using the CLI.
# Fine-tunes from the checkpoint at outputs/ijepa/ijepa_epoch_0005.pt, so that file
# must already exist (presumably written by a pretraining run -- confirm the
# output directory of your pretraining config).
# --num-classes 100 presumably targets CIFAR-100; adjust for other datasets.
!python -m wejepa.train.finetune \
    --checkpoint outputs/ijepa/ijepa_epoch_0005.pt \
    --epochs 10 \
    --batch-size 256 \
    --lr 3e-4 \
    --num-classes 100

In [None]:
# Programmatic route: describe the run with a FinetuneConfig, then hand it to
# train_linear_probe.
from wejepa.train import FinetuneConfig, train_linear_probe

# Settings for this run, collected in one place so they are easy to tweak.
probe_settings = dict(
    checkpoint_path="outputs/ijepa/ijepa_epoch_0005.pt",
    epochs=5,
    batch_size=128,
    learning_rate=1e-3,
)

ft_cfg = FinetuneConfig(**probe_settings)
train_linear_probe(ft_cfg)

### Different Backbones

In [None]:
import json
import torch
from copy import deepcopy
from pathlib import Path

from wejepa.backbones import adapt_config_for_backbone, available_backbones
from wejepa.config import IJepaConfig
from wejepa import default_config, launch_pretraining, IJEPA_base

# List every backbone registered with wejepa.
print("Registered backbones: ")
for backbone in available_backbones():
    print(f"- {backbone}")

# Backbones shortlisted for real pretraining runs. Nothing is trained in this cell
# unless the launch_pretraining call at the bottom of the loop is uncommented, so
# the shortlist is only reported here rather than announced as "pretraining".
candidates = ["vit_b_16", "swin_t", "convnext_tiny"]
print(f"\nPretraining candidates: {', '.join(candidates)}")

# For each registered backbone: adapt the default config, build the model, run a
# single dummy forward pass as a shape check, and save the config to disk.
for backbone in available_backbones():
    cfg = adapt_config_for_backbone(default_config(), backbone)
    # Set the output directory before the config is printed or saved, so the JSON
    # shown below is exactly the JSON written to disk.
    cfg.hardware.output_dir = f"./outputs/ijepa/{backbone}"

    print(f"\nBackbone: {backbone}")
    print(f"Image size: {cfg.model.img_size} | Patch size: {cfg.model.patch_size}")

    model = IJEPA_base(
        img_size=cfg.model.img_size,
        patch_size=cfg.model.patch_size,
        in_chans=cfg.model.in_chans,
        embed_dim=cfg.model.embed_dim,
        enc_depth=cfg.model.enc_depth,
        pred_depth=cfg.model.pred_depth,
        num_heads=cfg.model.num_heads,
        backbone=cfg.model.classification_backbone,
        pretrained=cfg.model.classification_pretrained,
    )

    print(f"Total trainable params: {model.count_trainable_parameters() / 1e6:.2f}M")
    print(f"Student + predictor params: {model.count_parameters() / 1e6:.2f}M")

    # Only the output shapes are inspected, so skip building an autograd graph.
    dummy = torch.randn(1, cfg.model.in_chans, cfg.model.img_size, cfg.model.img_size)
    with torch.no_grad():
        preds, targets = model(dummy)
    print(f"Pred shape: {tuple(preds.shape)} | Target shape: {tuple(targets.shape)}")

    # Serialise once and reuse the same text for display and for the saved file.
    cfg_json = json.dumps(cfg.to_dict(), indent=2)
    print(cfg_json)

    cfg_path = Path(f"configs/pretrain_{backbone}.json")
    cfg_path.parent.mkdir(parents=True, exist_ok=True)
    cfg_path.write_text(cfg_json)
    print(f"Saved config for {backbone} at {cfg_path}")

    # Uncomment to actually pretrain the shortlisted backbones.
    # if backbone in candidates:
    #     launch_pretraining(cfg)