# IJEPA Pretrained Model Evaluation and MLP Decoder Training
This notebook demonstrates how to:
- Load a pretrained IJEPA checkpoint and evaluate it on a test set (average test loss).
- Prepare and train an MLP decoder for classification using labels from `.xlsx` files.
- Keep each major step in a separate cell for clarity and reproducibility.

**Update the file paths as needed for your setup.**

In [None]:
# Cell 1: Imports and Setup
import os
import yaml
import torch
# Use up to 16 intra-op CPU threads, but never more than the machine reports,
# so smaller machines are not oversubscribed.
torch.set_num_threads(min(16, os.cpu_count() or 1))
print(f"Number of threads: {torch.get_num_threads()}")
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd

# Make the I-JEPA repository root importable so that `src.helper` resolves.
import sys
# Use a raw string to avoid unicode escape issues on Windows
sys.path.append(r"C:\Users\dash\Documents\Wosler\learning_ai\ijepa")
from src.helper import init_model

## Load Config and Checkpoint
- Loads the experiment config and the pretrained checkpoint.
- Initializes the encoder and predictor and loads the checkpoint weights into them.

In [None]:
# Cell 2: Load config and checkpoint, build the encoder and predictor
# --- Set your paths here ---
config_path = r'..\configs\simUS_dataset_new_vit_tiny_16_ep100.yaml'
checkpoint_path = r'..\exp_logs\vit_tiny_with_val_ep100\vit_tiny_with_val_ep100-latest.pth.tar'

# Load config
def load_yaml_config(path):
    """Read the YAML experiment config at ``path`` and return it as a dict.

    Raises:
        ValueError: if the file is empty or its top level is not a mapping.
            The code below indexes the result like ``cfg["mask"]["patch_size"]``,
            so anything else would only fail later with a confusing error.
    """
    # Explicit UTF-8 so reading does not depend on the platform's default
    # encoding (e.g. cp1252 on Windows).
    with open(path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {path!r} must contain a YAML mapping, "
            f"got {type(config).__name__}."
        )
    return config

cfg = load_yaml_config(config_path)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the encoder and predictor networks described by the config.
_init_kwargs = {
    "patch_size": cfg["mask"]["patch_size"],
    "crop_size": cfg["data"]["crop_size"],
    "pred_depth": cfg["meta"]["pred_depth"],
    "pred_emb_dim": cfg["meta"]["pred_emb_dim"],
    "model_name": cfg["meta"]["model_name"],
}
encoder, predictor = init_model(device=device, **_init_kwargs)

# Switch both networks to evaluation mode.
for _net in (encoder, predictor):
    _net.eval()

# Remove a leading 'module.' prefix (typically present when a model was saved while
# wrapped in DataParallel / DistributedDataParallel) from checkpoint keys.
def remove_module_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with a leading ``prefix`` stripped from each key.

    Only a *leading* occurrence is removed. Keys that merely contain the prefix
    text elsewhere (e.g. ``'blocks.0.submodule.weight'``) are left intact.
    Keys without the prefix are copied unchanged.

    Args:
        state_dict: mapping of parameter names to values.
        prefix: the prefix to strip; defaults to ``'module.'``.

    Returns:
        A new dict with the same values and the prefix removed from matching keys.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Load checkpoint.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
ckpt = torch.load(checkpoint_path, map_location=device)


def _report_load_result(name, result):
    """Print the keys that a non-strict ``load_state_dict`` call skipped.

    With ``strict=False``, missing and unexpected keys are silently ignored.
    A checkpoint whose parameter names do not line up would otherwise leave the
    network with its freshly initialised weights and no warning.

    Args:
        name: label used in the printed messages.
        result: the value returned by ``nn.Module.load_state_dict``.
    """
    missing, unexpected = result.missing_keys, result.unexpected_keys
    if not missing and not unexpected:
        print(f"{name}: all checkpoint keys matched.")
        return
    if missing:
        print(f"{name}: {len(missing)} missing key(s): {missing}")
    if unexpected:
        print(f"{name}: {len(unexpected)} unexpected key(s): {unexpected}")


_report_load_result(
    "encoder",
    encoder.load_state_dict(remove_module_prefix(ckpt["encoder"]), strict=False),
)
_report_load_result(
    "predictor",
    predictor.load_state_dict(remove_module_prefix(ckpt["predictor"]), strict=False),
)


# Preparing context and target patches from images

In [None]:
# Cell 5: Visualize context (top half) and target (bottom half) patches on an image
#
# Overlays two semi-transparent rectangles on one sample image: blue over the
# top half (context) and red over the bottom half (target). Coordinates are
# raw-image pixels; no resizing is applied in this cell.

import matplotlib.pyplot as plt
import matplotlib.patches as patches

# --- Set your image path here ---
image_path = r'C:\Users\dash\Documents\Wosler\learning_ai\ijepa\dataset\dataset_DATE_2025_06_17_TIME_12_50_58\train\class0\11.png'

# Decode the image as RGB. The context manager releases the file handle;
# convert() returns an independent in-memory copy.
with Image.open(image_path) as _source_img:
    img = _source_img.convert('RGB')
img_np = np.array(img)
h, w, _ = img_np.shape  # (rows, columns, channels)

# Regions as (x, y, width, height). The split row is h // 2, so an odd height
# gives the extra row to the target region.
context_rect = (0, 0, w, h // 2)
target_rect = (0, h // 2, w, h - h // 2)

rect1 = patches.Rectangle(context_rect[:2], *context_rect[2:], linewidth=2,
                          edgecolor='blue', facecolor='blue', alpha=0.3,
                          label='Context (Top Half)')
rect2 = patches.Rectangle(target_rect[:2], *target_rect[2:], linewidth=2,
                          edgecolor='red', facecolor='red', alpha=0.3,
                          label='Target (Bottom Half)')

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(img_np)
ax.add_patch(rect1)
ax.add_patch(rect2)
ax.set_axis_off()
ax.legend(handles=[rect1, rect2], loc='upper right')
ax.set_title('Context (Blue) and Target (Red) Patches')
plt.show()

In [None]:
# Cell 6: Dataset and DataLoader for training with Excel labels
from torchvision import transforms
import pandas as pd
from torch.utils.data import Dataset, DataLoader

# Patch and crop sizes from the experiment config. The dataset below uses them
# to build the context/target patch indices; the resize uses the same crop size.
patch_size = cfg['mask']['patch_size']
crop_size = cfg['data']['crop_size']

# Resize to the square crop size used by the I-JEPA encoder (non-square inputs
# are stretched, not cropped) and convert to a float tensor in [0, 1]. Then
# normalize with the ImageNet mean/std.
# NOTE(review): confirm these statistics match the ones used during pretraining.
preprocess = transforms.Compose([
    transforms.Resize((crop_size, crop_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

class SolidHollowDataset(Dataset):
    """Images listed in an Excel sheet, each labelled ``'solid'`` or ``'hollow'``.

    The sheet at ``excel_path`` is read with ``pd.read_excel``. It must contain:

    * an ``image`` column, whose file names are joined onto ``image_folder``;
    * a ``material_fill`` column, whose values are keys of ``class_map``
      (``'solid'`` -> 0, ``'hollow'`` -> 1).

    Each item is ``(input_tensor, context_idx, target_idx, label)``:

    * ``input_tensor`` -- ``preprocess`` applied to the RGB image;
    * ``context_idx`` -- flat (row-major) indices of the patches in the top half
      of the ``crop_size // patch_size`` square patch grid;
    * ``target_idx`` -- flat indices of the patches in the bottom half (an odd
      number of patch rows gives the extra row to the target half);
    * ``label`` -- the integer class id.

    Tensors are moved to ``device`` only when an item is fetched in the main
    process. Inside DataLoader worker processes they stay on the CPU, because
    CUDA cannot be initialised in forked workers. Callers that use workers
    should move each batch to the device themselves.

    Raises:
        ValueError: at construction, if any ``material_fill`` value is not a
            key of ``class_map``.
    """

    def __init__(self, image_folder, excel_path, preprocess, patch_size, crop_size, device):
        self.image_folder = image_folder
        self.preprocess = preprocess
        self.patch_size = patch_size
        self.crop_size = crop_size
        self.device = device

        # Map class names to integer labels
        self.class_map = {'solid': 0, 'hollow': 1}

        # Read the Excel sheet into {file name: class name}.
        df = pd.read_excel(excel_path)
        self.image_to_label = dict(zip(df['image'], df['material_fill']))
        self.image_paths = [os.path.join(image_folder, fname) for fname in self.image_to_label]
        # Encode labels up front, aligned with image_paths. A typo in the sheet
        # then fails here with a clear message instead of as a KeyError inside a
        # DataLoader worker.
        encoded = self._encode_labels(self.image_to_label, self.class_map)
        self._label_ids = [encoded[fname] for fname in self.image_to_label]

        # The patch grid has crop_size // patch_size patches on each side.
        self.num_patches_h = crop_size // patch_size
        self.num_patches_w = crop_size // patch_size
        self.context_indices, self.target_indices = self._half_split_indices(
            self.num_patches_h, self.num_patches_w
        )

    @staticmethod
    def _half_split_indices(num_patches_h, num_patches_w):
        """Return (top-half, bottom-half) flat int64 indices of a row-major patch grid."""
        split = (num_patches_h // 2) * num_patches_w
        total = num_patches_h * num_patches_w
        return torch.arange(0, split), torch.arange(split, total)

    @staticmethod
    def _encode_labels(image_to_label, class_map):
        """Return {file name: class id}; raise ValueError listing unknown class names."""
        unknown = sorted({str(v) for v in image_to_label.values() if v not in class_map})
        if unknown:
            raise ValueError(
                f"Unknown material_fill value(s) {unknown}; "
                f"expected one of {sorted(class_map)}."
            )
        return {name: class_map[label] for name, label in image_to_label.items()}

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # The context manager closes the file; convert() returns an in-memory copy.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        input_tensor = self.preprocess(img)   # e.g. [3, crop, crop]
        context_idx = self.context_indices    # [num_context]
        target_idx = self.target_indices      # [num_target]
        # CUDA cannot be initialised inside forked DataLoader workers, so only
        # move to the device when running in the main process.
        if torch.utils.data.get_worker_info() is None:
            input_tensor = input_tensor.to(self.device)
            context_idx = context_idx.to(self.device)
            target_idx = target_idx.to(self.device)
        label = self._label_ids[idx]
        return input_tensor, context_idx, target_idx, label

# Example usage: build the training dataset for one class folder, whose
# class_info.xlsx sheet lives alongside the images.
image_folder = r'C:\Users\dash\Documents\Wosler\learning_ai\ijepa\dataset\dataset_DATE_2025_06_17_TIME_15_32_38\train\class0'
excel_path = r'C:\Users\dash\Documents\Wosler\learning_ai\ijepa\dataset\dataset_DATE_2025_06_17_TIME_15_32_38\train\class0\class_info.xlsx'
dataset = SolidHollowDataset(
    image_folder,
    excel_path,
    preprocess,
    patch_size,
    crop_size,
    device,
)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Sanity check: fetch one batch (if the dataset is non-empty) and print its shapes.
_first_batch = next(iter(dataloader), None)
if _first_batch is not None:
    input_tensor, context_idx, target_idx, label = _first_batch
    print(input_tensor.shape, context_idx.shape, target_idx.shape, label)


In [None]:
# Cell 7: Full training loop for MLP decoder (100 epochs)
import torch.optim as optim
import time


class MLPDecoder(nn.Module):
    """Two-layer perceptron classification head: Linear -> ReLU -> Linear.

    Args:
        input_dim: size of the last dimension of the input features.
        hidden_dim: width of the hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, input_dim, hidden_dim=256, num_classes=2):
        super().__init__()
        # Keep this creation order and these attribute names: they fix the
        # parameter-initialisation order and the state_dict keys.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Map features ``[..., input_dim]`` to logits ``[..., num_classes]``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

# Re-create the DataLoader for training (larger batches, optional worker processes).
import multiprocessing as mp

batch_size = 32  # Increase batch size if memory allows

# Use worker processes only where they can work from this notebook:
# * with the "spawn"/"forkserver" start methods (the default on Windows and
#   macOS), each worker must re-import classes defined in the notebook's
#   __main__. That fails for SolidHollowDataset, so only "fork" is safe.
# * SolidHollowDataset can move samples to `device` inside __getitem__, and
#   CUDA cannot be initialised in forked worker processes.
if mp.get_start_method() == "fork" and device.type != "cuda":
    num_workers = 16
else:
    num_workers = 0

# Only CPU tensors can be pinned, and batches may already be on the GPU (see
# above). On a CPU-only run pinning brings no benefit, so leave it off.
pin_memory = False

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=num_workers, pin_memory=pin_memory)

mlp = None        # built lazily once the feature width of the first batch is known
optimizer = None  # created together with `mlp`
criterion = nn.CrossEntropyLoss()

num_epochs = 100
loss_history = []  # mean training loss per epoch
start_time = time.time()

for epoch in range(num_epochs):
    epoch_loss = 0.0
    n_batches = 0
    for input_tensor, context_idx, target_idx, label in dataloader:
        # input_tensor: [B, 3, crop, crop], context_idx: [B, num_context], target_idx: [B, num_target]
        # Labels are collated from Python ints onto the CPU, and the other
        # tensors may be on the CPU too. Move everything to `device`; this is a
        # no-op for tensors already there.
        input_tensor = input_tensor.to(device)
        context_idx = context_idx.to(device).long()
        target_idx = target_idx.to(device).long()
        label = label.to(device)

        # Encoder and predictor are used without gradients; only the MLP is trained.
        with torch.no_grad():
            z = encoder(input_tensor, context_idx)     # [B, N, D]
            h = predictor(z, context_idx, target_idx)  # [B, num_target, D]
        pooled = h.mean(dim=1)  # [B, D]: average over target patches
        if mlp is None:
            mlp = MLPDecoder(input_dim=pooled.shape[1]).to(device)
            optimizer = optim.Adam(mlp.parameters(), lr=1e-3)
        logits = mlp(pooled)  # [B, 2]
        loss = criterion(logits, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1
    if n_batches == 0:
        raise RuntimeError("The training DataLoader produced no batches; check the dataset paths.")
    avg_loss = epoch_loss / n_batches
    loss_history.append(avg_loss)
    print(f"Epoch {epoch+1}/{num_epochs} - Avg Loss: {avg_loss:.4f}")

elapsed = time.time() - start_time
print(f"Training complete in {elapsed/60:.2f} minutes.")