# Example based on our README.md
1. Dataset download
2. Pre-training
3. Fine-tuning
4. Inference
5. Different backbones
6. Extract features (TODO)


### 1. Download the dataset

In [None]:
# Download datasets with the CLI (`wejepa.datasets.download`); everything is
# written under ./data via --dataset-root.

# Full train split of the class dataset (disabled; uncomment to fetch it)
# !python -m wejepa.datasets.download --dataset-root ./data --dataset-name tsbpp/fall2025_deeplearning --splits train

# For development, download only a small slice of the train split ('train[:10]')
!python -m wejepa.datasets.download --dataset-root ./data --dataset-name tsbpp/fall2025_deeplearning --splits 'train[:10]'

# Download the cifar100 dataset (no --splits given, so the CLI's default applies)
!python -m wejepa.datasets.download --dataset-root ./data --dataset-name cifar100

### 2. Pre-training the model

In [None]:
# Using the cli

# Clear
!export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Train using default cifar100 config + custom ViT backbone
# !python -m wejepa.train.pretrain --print-config     # print only
# !python -m wejepa.train.pretrain                    # train
_
# FIXME: bug when using .arrow files, the file path is not correctly set, workaround is to rename the arrow file
#   cp fall2025_deeplearning-train.arrow tsbpp___fall2025_deeplearning-train.arrow

# print where --config searches for config files
!python -m wejepa.train.pretrain --config configs/pretrain_devel_tsbpp_224.json

Model has 4,449,408 trainable parameters.
Using the latest cached version of the dataset since tsbpp___fall2025_deeplearning couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at data/tsbpp___fall2025_deeplearning/default/0.0.0/7b14dd4385d982457822e8e96c5081a30da146d8 (last modified on Wed Nov 26 20:43:26 2025).
Using the latest cached version of the dataset since tsbpp___fall2025_deeplearning couldn't be found on the Hugging Face Hub
Using the latest cached version of the dataset since tsbpp___fall2025_deeplearning couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at data/tsbpp___fall2025_deeplearning/default/0.0.0/7b14dd4385d982457822e8e96c5081a30da146d8 (last modified on Wed Nov 26 20:43:26 2025).
Found the latest cached dataset configuration 'default' at data/tsbpp___fall2025_deeplearning/default/0.0.0/7b14dd4385d982457822e8e96c5081a30da146d8 (last modified on Wed Nov 26 20:43:26 2025).
U



In [None]:
# Programmatic alternative to the CLI: build the package's default
# configuration and hand it to the pre-training entry point.
from wejepa import default_config, launch_pretraining
cfg = default_config()
launch_pretraining(cfg)

### 3. Fine-tuning the model

In [None]:
# Fine-tune from a saved pre-training checkpoint using the CLI.
# The checkpoint file passed to --checkpoint must already exist on disk.
# NOTE: the hyperparameters here (10 epochs, batch size 256, lr 3e-4) differ
# from the programmatic example in this notebook (5 epochs, 128, 1e-3).
!python -m wejepa.train.finetune \
    --checkpoint outputs/ijepa/ijepa_epoch_0005.pt \
    --epochs 10 \
    --batch-size 256 \
    --lr 3e-4 \
    --num-classes 100

In [None]:
# Programmatic fine-tuning: describe the run with a FinetuneConfig and pass it
# to train_linear_probe.
from wejepa.train import FinetuneConfig, train_linear_probe

# Only the fields set here are overridden; any other FinetuneConfig fields keep
# their defaults. These values differ from the CLI example (10 epochs, batch
# size 256, lr 3e-4).
ft_cfg = FinetuneConfig(
    checkpoint_path="outputs/ijepa/ijepa_epoch_0005.pt",
    epochs=5,
    batch_size=128,
    learning_rate=1e-3,
)
train_linear_probe(ft_cfg)

### 4. Running Inference

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
from datasets import load_dataset
from PIL import Image

from wejepa.train import load_backbone_from_checkpoint
from wejepa import default_config

# Build the default config and restore the pre-trained backbone in eval mode
cfg = default_config()
backbone = load_backbone_from_checkpoint("outputs/ijepa/ijepa_epoch_0005.pt", cfg)
backbone.eval()

# `Resize(int)` only rescales the shorter edge and keeps the aspect ratio, so a
# non-square image would produce a non-square tensor. Passing an explicit
# (height, width) always yields a fixed-size input.
# NOTE(review): assumes the backbone expects square inputs of
# cfg.data.image_size -- confirm against the pre-training transforms.
image_size = cfg.data.image_size
target_size = (image_size, image_size) if isinstance(image_size, int) else tuple(image_size)

transform = transforms.Compose([
    transforms.Resize(target_size),
    transforms.ToTensor(),
    transforms.Normalize(cfg.data.normalization_mean, cfg.data.normalization_std),
])

# Load the dataset
ds = load_dataset(
    "./data/tsbpp___fall2025_deeplearning",
    split="train",
)
print("Dataset features:", ds.features)

# Try to find a label field (first feature whose name contains "label" or
# "class"); every derived value falls back to None when there is none.
possible_label_fields = [k for k in ds.features.keys() if "label" in k or "class" in k]
label_field = possible_label_fields[0] if possible_label_fields else None
label_feature = ds.features[label_field] if label_field else None
# getattr with a default already covers label_feature being None
label_names = getattr(label_feature, "names", None)
num_classes = len(label_names) if label_names is not None else 100  # default to 100 classes

# Dummy LinearProbe class for demonstration (replace with your actual implementation)
class LinearProbe(torch.nn.Module):
    """Linear classification head stacked on a feature-extracting backbone.

    Args:
        backbone: Module called on the input batch to produce features. It may
            return a tensor, or a list/tuple whose first element is the
            feature tensor.
        num_classes: Number of output logits.
        embed_dim: Input width of the linear head. When omitted, falls back to
            ``cfg.model.embed_dim`` from the notebook-level ``cfg``, which must
            then already be defined.
    """

    def __init__(self, backbone, num_classes, embed_dim=None):
        super().__init__()
        self.backbone = backbone
        if embed_dim is None:
            # Backward-compatible default: read the width from the global config
            embed_dim = cfg.model.embed_dim
        self.head = torch.nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for the input batch ``x``."""
        feats = self.backbone(x)
        if isinstance(feats, (list, tuple)):
            feats = feats[0]
        # Features with more than two dims are averaged over dim 1 before the head
        if feats.ndim > 2:
            feats = feats.mean(dim=1)
        return self.head(feats)

# Load the linear probe weights if available
decoder = LinearProbe(backbone, num_classes)
try:
    decoder.load_state_dict(torch.load("outputs/ijepa/linear_probe.pt", map_location="cpu"))
    print("Loaded linear probe weights.")
except Exception as e:
    # Deliberate best-effort: keep the demo running with the freshly
    # initialised head. Predictions below are not meaningful in that case.
    print("Could not load linear probe weights:", e)
decoder.eval()

# Grab an image from the dataset and add a leading batch dimension
image = transform(ds[0]["image"]).unsqueeze(0)
print(f"Image shape: {image.shape}")

with torch.no_grad():
    logits = decoder(image)
    probs = torch.softmax(logits, dim=1)
    pred_ind = int(probs.argmax(dim=1).item())

pred_label = label_names[pred_ind] if label_names is not None else str(pred_ind)
# topk raises if k exceeds the number of classes, so clamp k for small label sets
top_k = min(5, probs.size(1))
top5_inds = probs.topk(top_k, dim=1).indices.squeeze(0).tolist()
top5_labels = [label_names[i] if label_names is not None else str(i) for i in top5_inds]
print(f"Predicted label: {pred_label}")
print(f"Top-5 predicted labels: {top5_labels}")

# Take the single batch entry and move channels last for plotting
img_np = image[0].permute(1, 2, 0).cpu().numpy()

# Invert the Normalize step (x * std + mean), then clamp to the displayable [0, 1] range
std = np.array(cfg.data.normalization_std)
mean = np.array(cfg.data.normalization_mean)
img_np = np.clip(img_np * std + mean, 0, 1)

fig, ax = plt.subplots()
ax.imshow(img_np)
ax.axis('off')
plt.show()

with torch.no_grad():
    tokens = backbone(image)
    if isinstance(tokens, (list, tuple)):
        tokens = tokens[0]
    # Pool only when there is an extra axis to average over, mirroring
    # LinearProbe.forward; 2-D features are already one vector per image.
    pooled = tokens.mean(dim=1) if tokens.ndim > 2 else tokens  # embeddings for downstream heads

# TODO: use the embeddings `pooled` for downstream tasks like classification
print(f"Extracted embeddings shape: {pooled.shape}")

# NOTE: this head is freshly initialised (untrained); the scores below only
# demonstrate the shapes involved, not a real prediction.
classifier = torch.nn.Linear(pooled.size(1), num_classes)
with torch.no_grad():  # inference only, no autograd graph needed
    logits = classifier(pooled)
print(f"Logits shape: {logits.shape}")

# Display the classified scores
print(f"Classified scores: {logits}")

# Assign predicted class
predicted_class = torch.argmax(logits, dim=1)
print(f"Predicted class: {predicted_class.item()}")

### 5. Different Backbones

In [None]:
import json
import torch
from copy import deepcopy
from pathlib import Path

from wejepa.backbones import adapt_config_for_backbone, available_backbones
from wejepa.config import IJepaConfig
from wejepa import default_config, launch_pretraining, IJEPA_base

# The loop variable is named `backbone_name` so it does not overwrite a
# `backbone` model object defined by another cell with a plain string.
print("Registered backbones: ")
for backbone_name in available_backbones():
    print(f"- {backbone_name}")

# NOTE(review): this loop only announces the candidates; nothing is trained
# here. Actual pre-training requires the launch_pretraining call further down.
candidates = ["vit_b_16", "swin_t", "convnext_tiny"]
for backbone_name in candidates:
    print(f"\nPretraining with backbone: {backbone_name}")

# For every registered backbone: adapt the default config, build the model,
# run one dummy forward pass as a shape check, and save the config to disk.
for backbone_name in available_backbones():
    cfg = adapt_config_for_backbone(default_config(), backbone_name)
    print(f"\nBackbone: {backbone_name}")
    print(f"Image size: {cfg.model.img_size} | Patch size: {cfg.model.patch_size}")

    model = IJEPA_base(
        img_size=cfg.model.img_size,
        patch_size=cfg.model.patch_size,
        in_chans=cfg.model.in_chans,
        embed_dim=cfg.model.embed_dim,
        enc_depth=cfg.model.enc_depth,
        pred_depth=cfg.model.pred_depth,
        num_heads=cfg.model.num_heads,
        backbone=cfg.model.classification_backbone,
        pretrained=cfg.model.classification_pretrained,
    )

    print(f"Total trainable params: {model.count_trainable_parameters() / 1e6:.2f}M")
    print(f"Student + predictor params: {model.count_parameters() / 1e6:.2f}M")

    dummy = torch.randn(1, cfg.model.in_chans, cfg.model.img_size, cfg.model.img_size)
    # Shape check only: skip building an autograd graph
    with torch.no_grad():
        preds, targets = model(dummy)
    print(f"Pred shape: {tuple(preds.shape)} | Target shape: {tuple(targets.shape)}")

    # Set the output directory before serialising so that the printed config
    # and the saved config are identical; serialise once and reuse the text.
    cfg.hardware.output_dir = f"./outputs/ijepa/{backbone_name}"
    cfg_json = json.dumps(cfg.to_dict(), indent=2)
    print(cfg_json)

    cfg_path = Path(f"configs/pretrain_{backbone_name}.json")
    cfg_path.parent.mkdir(parents=True, exist_ok=True)
    cfg_path.write_text(cfg_json)
    print(f"Saved config for {backbone_name} at {cfg_path}")

    # launch_pretraining(cfg)

### 6. Extract Features

In [None]:
import matplotlib.pyplot as plt

from wejepa.analysis.visualization import (
    extract_backbone_features,
    plot_tsne_embeddings,
    run_tsne_projection
)
from wejepa.backbones import build_backbone

# Backbones to compare; the 2-D projection of each is stored in `tsne_results`
# keyed by backbone name.
backbone_names = ["vit_b_16", "swin_t", "convnext_tiny"]
tsne_results = {}

for backbone_name in backbone_names:
    print(f"Projecting embeddings for {backbone_name} ...")
    # build_backbone returns the model plus a second value bound to
    # `feature_dim`, which is not used below.
    backbone, feature_dim = build_backbone(backbone_name, pretrained=True, freeze_backbone=True)

    # Use a small slice of the dataset to keep visualization quick.
    # FIXME(review): `build_dataloader` is not imported or defined anywhere
    # visible in this notebook, so this line fails unless the name is provided
    # elsewhere. The recorded error output shows it cannot be imported from
    # `wejepa.analysis.visualization`; locate (or write) the helper first.
    dataloader = build_dataloader(backbone_name, batch_size=24, split="train[:192]")
    # At most 4 batches are consumed (max_batches=4)
    features, labels = extract_backbone_features(backbone, dataloader, max_batches=4)

    # Fixed random_state keeps the projection reproducible between runs
    embedding = run_tsne_projection(features, perplexity=20.0, random_state=42)
    fig = plot_tsne_embeddings(embedding, labels)
    fig.suptitle(f"{backbone_name} TSNE", y=1.02)
    plt.show()

    tsne_results[backbone_name] = embedding

ImportError: cannot import name 'build_dataloader' from 'wejepa.analysis.visualization' (/home/long/PhD/Research/Code/test/project2/ijepa/src/wejepa/analysis/visualization.py)