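# Example workflow based on our README.md

This notebook walks through:
1. Dataset download
2. Pre-training
3. Fine-tuning
4. Inference
5. Pre-training with different backbones
6. Visualizing backbone embeddings (t-SNE)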


### 1. Download the dataset

In [None]:
# Download datasets with the CLI. Uncomment the command(s) you need;
# every dataset is written under the directory given by --dataset-root (./data).

# download the class dataset (full train split)
# !python -m wejepa.datasets.download --dataset-root ./data --dataset-name tsbpp/fall2025_deeplearning --splits train

# for development, download only the first 10 training examples
# !python -m wejepa.datasets.download --dataset-root ./data --dataset-name tsbpp/fall2025_deeplearning --splits 'train[:10]'

# download the CIFAR-100 dataset (default splits)
# !python -m wejepa.datasets.download --dataset-root ./data --dataset-name cifar100

# download the CUB-200 dataset (train and test splits, comma-separated)
# !python -m wejepa.datasets.download --dataset-root ./data --dataset-name cub200 --splits train,test

Split 'train[:10]' available under /home/long/code/dl_project1/experiments/data
Split 'train' available under /home/long/code/dl_project1/experiments/data
Split 'test' available under /home/long/code/dl_project1/experiments/data


### 2. Pre-training the model

In [None]:
# Using the cli

# Clear
!export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Train using default cifar100 config + custom ViT backbone
# !python -m wejepa.train.pretrain --print-config     # print only
# !python -m wejepa.train.pretrain                    # train
_
# FIXME: bug when using .arrow files, the file path is not correctly set, workaround is to rename the arrow file
#   cp fall2025_deeplearning-train.arrow tsbpp___fall2025_deeplearning-train.arrow

# print where --config searches for config files
!python -m wejepa.train.pretrain --config hf224_config.json

Model has 5,061,504 trainable parameters.
Using the latest cached version of the dataset since tsbpp___fall2025_deeplearning couldn't be found on the Hugging Face Hub
Using the latest cached version of the dataset since tsbpp___fall2025_deeplearning couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at data/tsbpp___fall2025_deeplearning/default/0.0.0/7b14dd4385d982457822e8e96c5081a30da146d8 (last modified on Thu Nov 20 02:42:58 2025).
Found the latest cached dataset configuration 'default' at data/tsbpp___fall2025_deeplearning/default/0.0.0/7b14dd4385d982457822e8e96c5081a30da146d8 (last modified on Thu Nov 20 02:42:58 2025).
Using the latest cached version of the dataset since tsbpp___fall2025_deeplearning couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at data/tsbpp___fall2025_deeplearning/default/0.0.0/7b14dd4385d982457822e8e96c5081a30da146d8 (last modified on Thu Nov 20 02:42:58 2025).
U

In [4]:
# Programmatic equivalent of the CLI call: build the default configuration
# and hand it straight to the pretraining launcher.
# NOTE(review): the recorded run of this cell failed with KeyError: 'image' in
# wejepa/datasets/hf.py, so the downloaded dataset apparently stores its images
# under a different column name. TODO confirm and fix in the dataset wrapper.
from wejepa import launch_pretraining, default_config

cfg = default_config()
launch_pretraining(cfg)

Model has 157,602,280 trainable parameters.


README.md: 0.00B [00:00, ?B/s]

plain_text/train-00000-of-00001.parquet:   0%|          | 0.00/120M [00:00<?, ?B/s]

plain_text/test-00000-of-00001.parquet:   0%|          | 0.00/23.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/long/PhD/Environments/ijepa/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/long/PhD/Environments/ijepa/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/home/long/PhD/Coursework/Deep_Learning/Project/Code/ijepa/src/wejepa/datasets/hf.py", line 34, in __getitem__
    img = self.transform(self.dataset[index]["image"])
                         ~~~~~~~~~~~~~~~~~~~^^^^^^^^^
KeyError: 'image'


### 3. Fine-tuning the model

In [None]:
# using the cli
!python -m wejepa.train.finetune \
    --checkpoint outputs/ijepa/ijepa_epoch_0005.pt \
    --epochs 10 \
    --batch-size 256 \
    --lr 3e-4 \
    --num-classes 100

In [None]:
# Programmatic equivalent: train a linear probe on top of the pretraining
# checkpoint, using the hyperparameters collected below.
from wejepa.train import train_linear_probe, FinetuneConfig

probe_settings = dict(
    checkpoint_path="outputs/ijepa/ijepa_epoch_0005.pt",
    epochs=5,
    batch_size=128,
    learning_rate=1e-3,
)
ft_cfg = FinetuneConfig(**probe_settings)
train_linear_probe(ft_cfg)

### 4. Running Inference

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
from datasets import load_dataset
from PIL import Image

from wejepa import default_config
from wejepa.train import load_backbone_from_checkpoint
# `LinearProbe` was previously used below without being imported, which raises
# NameError on a fresh kernel.
# NOTE(review): assumed to be defined alongside `train_linear_probe` in the
# `wejepa.train.finetune` module. Confirm the module path.
from wejepa.train.finetune import LinearProbe

# Paths and fallbacks used throughout this section.
BACKBONE_CHECKPOINT = "outputs/ijepa/ijepa_epoch_0005.pt"
PROBE_CHECKPOINT = "outputs/ijepa/linear_probe.pt"
DATASET_DIR = "./data/tsbpp___fall2025_deeplearning"
FALLBACK_NUM_CLASSES = 100  # used only when the dataset exposes no class names

# --- Load the pretrained backbone ------------------------------------------
cfg = default_config()
backbone = load_backbone_from_checkpoint(BACKBONE_CHECKPOINT, cfg)
backbone.eval()

# Resize with an int only scales the shorter side (aspect ratio is kept). The
# center crop guarantees an image_size x image_size input even for non-square
# images, and is a no-op for square ones.
transform = transforms.Compose([
    transforms.Resize(cfg.data.image_size),
    transforms.CenterCrop(cfg.data.image_size),
    transforms.ToTensor(),
    transforms.Normalize(cfg.data.normalization_mean, cfg.data.normalization_std),
])

# --- Load the dataset and its class names ----------------------------------
ds = load_dataset(DATASET_DIR, split="train")

# Class names exist only when "label" is a ClassLabel feature. A missing "label"
# column, or a plain integer feature, falls back to printing raw class indices
# instead of crashing.
dataset_features = getattr(ds, "features", None) or {}
label_feature = dataset_features.get("label")
label_names = getattr(label_feature, "names", None) or None
num_classes = len(label_names) if label_names else FALLBACK_NUM_CLASSES


def class_name(index: int) -> str:
    """Return the human-readable name for a class index, or the index as text."""
    return label_names[index] if label_names else str(index)


# --- Load the trained linear probe -----------------------------------------
decoder = LinearProbe(backbone, num_classes)
decoder.load_state_dict(torch.load(PROBE_CHECKPOINT, map_location="cpu"))
decoder.eval()

# --- Classify one image from the dataset -----------------------------------
# Convert PIL images to RGB so grayscale/RGBA inputs get the channel count the
# normalization expects (presumably 3, matching the training data).
raw_image = ds[0]["image"]
if isinstance(raw_image, Image.Image):
    raw_image = raw_image.convert("RGB")
image = transform(raw_image).unsqueeze(0)  # add a batch dimension
print(f"Image shape: {image.shape}")

with torch.no_grad():
    logits = decoder(image)
    probs = torch.softmax(logits, dim=1)
    pred_ind = int(probs.argmax(dim=1).item())
    # topk raises if k exceeds the number of classes, so cap k at 5.
    top5_inds = probs.topk(min(5, probs.size(1))).indices.squeeze(0).tolist()

pred_label = class_name(pred_ind)
top5_labels = [class_name(i) for i in top5_inds]
print(f"Predicted label: {pred_label}")
print(f"Top-5 predicted labels: {top5_labels}")

# --- Display the image -------------------------------------------------------
# Drop the batch dimension, move channels last (H, W, C) and undo the
# normalization so the pixel values are displayable again.
img_np = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
mean = np.array(cfg.data.normalization_mean)
std = np.array(cfg.data.normalization_std)
img_np = np.clip(img_np * std + mean, 0, 1)

plt.imshow(img_np)
plt.title(f"Predicted: {pred_label}")
plt.axis("off")
plt.show()

# --- Extract pooled embeddings ---------------------------------------------
with torch.no_grad():
    tokens = backbone(image)
    # NOTE(review): assumes `tokens` is (batch, num_tokens, dim). Averaging over
    # the token axis yields one embedding per image for downstream heads.
    pooled = tokens.mean(dim=1)

# TODO: use the embeddings `pooled` for downstream tasks like classification
print(f"Extracted embeddings shape: {pooled.shape}")

# Illustration only: a freshly initialised, UNTRAINED linear head on the pooled
# embeddings. Its scores are random and unrelated to the trained probe above.
# It reuses `num_classes` from the dataset instead of re-hard-coding 100, and
# writes to `head_logits` so the probe's `logits` are not overwritten.
classifier = torch.nn.Linear(pooled.size(1), num_classes)
with torch.no_grad():
    head_logits = classifier(pooled)
print(f"Logits shape: {head_logits.shape}")

# display the (random) scores of the untrained head
print(f"Classified scores: {head_logits}")

predicted_class = torch.argmax(head_logits, dim=1)
print(f"Predicted class (untrained head, illustration only): {predicted_class.item()}")

### 5. Different Backbones

In [2]:
from wejepa.backbones import available_backbones
from wejepa.config import IJepaConfig
from wejepa import default_config, launch_pretraining, IJEPA_base
from pathlib import Path
import json
from copy import deepcopy

BASE_CONFIG_PATH = Path("hf224_config.json")

print("Registered backbones: ")
for registered_name in available_backbones():
    print(f"- {registered_name}")

# Read the base config once. Each run gets its own deep copy, so per-backbone
# edits cannot leak from one iteration into the next.
with open(BASE_CONFIG_PATH, "r") as f:
    base_cfg_dict = json.load(f)

# NOTE(review): the recorded run of this cell failed in a worker process with
# "Dataset not found or corrupted" from the CIFAR-100 loader, which only
# downloads on rank 0 (`download=rank == 0`). Make sure the dataset is already
# on disk before launching multi-process pretraining.
candidates = ["vit_b_16", "swin_t", "convnext_tiny"]

# The loop variable is `backbone_name`, not `backbone`, so it does not clobber
# the model object named `backbone` from the inference section.
for backbone_name in candidates:
    print(f"\nPretraining with backbone: {backbone_name}")

    cfg = IJepaConfig.from_dict(deepcopy(base_cfg_dict))
    cfg.model.classification_backbone = backbone_name
    cfg.model.classification_pretrained = True
    cfg.hardware.output_dir = f"./outputs/ijepa/{backbone_name}"

    # Persist the exact config used for this run so it can be reproduced later.
    cfg_path = Path(f"configs/pretrain_{backbone_name}.json")
    cfg_path.parent.mkdir(parents=True, exist_ok=True)
    cfg_path.write_text(json.dumps(cfg.to_dict(), indent=2))
    print(f"Saved config for {backbone_name} at {cfg_path}")

    launch_pretraining(cfg)

Registered backbones: 
- convnext_tiny
- resnet50
- resnext50_32x4d
- swin_t
- vit_b_16
- vit_l_16

Pretraining with backbone: vit_b_16
Saved config for vit_b_16 at configs/pretrain_vit_b_16.json
Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /home/long/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /home/long/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /home/long/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /home/long/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth


100%|██████████| 330M/330M [00:11<00:00, 31.1MB/s] 
100%|██████████| 330M/330M [00:11<00:00, 31.0MB/s]
100%|██████████| 330M/330M [00:11<00:00, 29.6MB/s]
100%|██████████| 330M/330M [00:11<00:00, 29.1MB/s]
  scaler = GradScaler(enabled=use_amp)
  scaler = GradScaler(enabled=use_amp)
  scaler = GradScaler(enabled=use_amp)
  scaler = GradScaler(enabled=use_amp)


Model has 86,567,656 trainable parameters.


 34%|███▍      | 57.1M/169M [00:00<00:01, 89.4MB/s]W1120 01:53:48.654000 2726638 torch/multiprocessing/spawn.py:174] Terminating process 2728889 via signal SIGTERM
W1120 01:53:48.656000 2726638 torch/multiprocessing/spawn.py:174] Terminating process 2728892 via signal SIGTERM
W1120 01:53:48.657000 2726638 torch/multiprocessing/spawn.py:174] Terminating process 2728894 via signal SIGTERM


ProcessRaisedException: 

-- Process 2 terminated with the following error:
Traceback (most recent call last):
  File "/home/long/code/environments/wejepa/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 95, in _wrap
    fn(i, *args)
  File "/home/long/code/dl_project1/src/wejepa/train/pretrain.py", line 231, in _train_worker
    data_loader, sampler = create_pretraining_dataloader(cfg, rank=rank, world_size=world_size)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/long/code/dl_project1/src/wejepa/datasets/cifar.py", line 107, in create_pretraining_dataloader
    dataset = IJEPADataset(cfg, train=True, download=rank == 0)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/long/code/dl_project1/src/wejepa/datasets/cifar.py", line 42, in __init__
    self.dataset = torchvision.datasets.CIFAR100(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/long/code/environments/wejepa/lib/python3.12/site-packages/torchvision/datasets/cifar.py", line 69, in __init__
    raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
RuntimeError: Dataset not found or corrupted. You can use download=True to download it


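### 6. Visualizing Backbone Embeddings

Compare pretrained backbones by projecting their frozen features to 2-D with t-SNE.

### 7. Extract Features

For each backbone, extract features from a small slice of the dataset, then run and plot the t-SNE projection.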

In [None]:
import matplotlib.pyplot as plt

from wejepa.analysis.visualization import (
    extract_backbone_features,
    plot_tsne_embeddings,
    run_tsne_projection,
)
from wejepa.backbones import build_backbone

# NOTE(review): `build_dataloader` is neither defined nor imported in this
# notebook, so this cell raises NameError on a fresh kernel. Import or define it
# before running. The expected call is build_dataloader(name, batch_size=..., split=...).

backbone_names = ["vit_b_16", "swin_t", "convnext_tiny"]
# A small slice keeps the visualization quick. Use more rows for a meaningful plot.
TSNE_SPLIT = "train[:10]"
TSNE_BATCH_SIZE = 24
TSNE_MAX_BATCHES = 4
TSNE_PERPLEXITY = 20.0
TSNE_SEED = 42
tsne_results = {}

for backbone_name in backbone_names:
    print(f"Projecting embeddings for {backbone_name} ...")
    # `tsne_backbone` avoids overwriting the `backbone` model from the inference section.
    tsne_backbone, feature_dim = build_backbone(backbone_name, pretrained=True, freeze_backbone=True)

    dataloader = build_dataloader(backbone_name, batch_size=TSNE_BATCH_SIZE, split=TSNE_SPLIT)
    features, labels = extract_backbone_features(tsne_backbone, dataloader, max_batches=TSNE_MAX_BATCHES)

    # t-SNE requires perplexity < number of samples. Presumably
    # run_tsne_projection wraps scikit-learn's TSNE, which raises ValueError
    # otherwise, so a fixed 20 would be rejected for the 10-row slice. Clamp it.
    perplexity = min(TSNE_PERPLEXITY, float(len(features) - 1))
    embedding = run_tsne_projection(features, perplexity=perplexity, random_state=TSNE_SEED)

    fig = plot_tsne_embeddings(embedding, labels)
    fig.suptitle(f"{backbone_name} TSNE", y=1.02)
    plt.show()
    plt.close(fig)  # release the figure before building the next one

    tsne_results[backbone_name] = embedding