# Step 1: Install Dependencies
Install required libraries and import core Python modules.

In [1]:
!pip install datasets transformers timm

import os
import torch
torch.set_num_threads(4)
torch.set_num_interop_threads(4)
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from datasets import load_dataset
from torchvision.datasets.folder import default_loader
from torchvision.utils import save_image
import matplotlib.pyplot as plt



  from .autonotebook import tqdm as notebook_tqdm


# Step 2: Download & Prepare ImageNet-Sketch Dataset
Load the dataset from Hugging Face and wrap it in a PyTorch `Dataset` so it can be fed to a `DataLoader`.

In [4]:
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import Dataset

# Pull ImageNet-Sketch from the Hugging Face Hub (remote loading code is allowed).
dataset = load_dataset("imagenet_sketch", trust_remote_code=True)

# Per-channel statistics handed to Normalize below.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Preprocessing pipeline: resize, 224x224 centre crop, tensor conversion, normalisation.
_preprocess_steps = [
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
]
transform = transforms.Compose(_preprocess_steps)

# Wrap in PyTorch Dataset
class SketchDataset(Dataset):
    """Image-only PyTorch view over an indexable collection of records.

    Each record must be a mapping with an ``"image"`` entry holding an object
    that supports ``.convert("RGB")`` (e.g. a PIL image). Labels are ignored.

    Args:
        hf_dataset: Indexable, sized collection of records (one split of the
            Hugging Face dataset).
        transform: Optional callable applied to the RGB image before it is
            returned. ``None`` returns the RGB image unchanged.
    """

    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        """Number of records in the wrapped collection."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the (optionally transformed) image stored at ``idx``."""
        image = self.dataset[idx]["image"]  # already a PIL Image
        # Force three channels: a greyscale or RGBA image would otherwise
        # break a 3-channel Normalize further down the pipeline.
        image = image.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image

# Example usage
# load_dataset(...) called without `split=` returns a dict of splits, so pick
# the "train" split explicitly; indexing the dict itself with an integer would
# look up a split name, not an image record.
train_split = dataset["train"] if isinstance(dataset, dict) else dataset
sketch_data = SketchDataset(train_split, transform)

# Batched, shuffled loader consumed by the training loop in Step 4.
BATCH_SIZE = 32  # lower this if the model does not fit in memory
dataloader = DataLoader(sketch_data, batch_size=BATCH_SIZE, shuffle=True)


ChunkedEncodingError: ('Connection broken: IncompleteRead(7007538311 bytes read, 586034701 more expected)', IncompleteRead(7007538311 bytes read, 586034701 more expected))

# Step 3: Load the Pretrained Facebook MAE Model
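Import the model definition from the cloned MAE repository, instantiate it, and move it to the training device. Note that the model constructor alone does not load pretrained weights; a checkpoint must be loaded separately.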

In [None]:
import os
import sys

# Location of the cloned MAE repository. Override with the MAE_REPO_DIR
# environment variable instead of editing the notebook.
MAE_REPO_DIR = os.environ.get("MAE_REPO_DIR", "c:/Users/dash/Documents/learning_ai/mae")
if MAE_REPO_DIR not in sys.path:  # avoid piling up duplicates when the cell is re-run
    sys.path.append(MAE_REPO_DIR)
from models_mae import mae_vit_huge_patch14_dec512d8b

# The factory call only builds the architecture; weights are whatever the
# constructor initialises them to until a checkpoint is loaded.
mae_model = mae_vit_huge_patch14_dec512d8b()

# Optional pretrained weights: point MAE_CHECKPOINT at a checkpoint file.
# NOTE(review): assumes the weights live under a "model" key when the
# checkpoint is a dict, which is how the released MAE checkpoints appear to be
# laid out -- confirm for the file you use.
MAE_CHECKPOINT = os.environ.get("MAE_CHECKPOINT")
if MAE_CHECKPOINT:
    mae_ckpt = torch.load(MAE_CHECKPOINT, map_location="cpu")
    mae_state = mae_ckpt["model"] if isinstance(mae_ckpt, dict) and "model" in mae_ckpt else mae_ckpt
    mae_load_result = mae_model.load_state_dict(mae_state, strict=False)
    # strict=False hides mismatches, so report them explicitly.
    print(
        f"MAE checkpoint loaded: {len(mae_load_result.missing_keys)} missing, "
        f"{len(mae_load_result.unexpected_keys)} unexpected keys"
    )
else:
    print("MAE_CHECKPOINT is not set: mae_model keeps its freshly initialised weights.")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Assigning the result keeps the (very long) module repr out of the cell output.
mae_model = mae_model.to(device)

# Step 4: Train the MAE Decoder
Freeze the encoder and train the decoder on the sketch dataset using reconstruction loss.

In [None]:
NUM_EPOCHS = 10  # change number of epochs as needed
LEARNING_RATE = 1e-4
MASK_RATIO = 0.75


def _is_decoder_key(name):
    """Return True for parameter / state-dict names on the decoder side of the MAE.

    Matches names beginning with "decoder" (e.g. ``decoder_embed.*``,
    ``decoder_blocks.*`` or ``decoder.*``) and the ``mask_token`` used to
    fill masked positions before decoding.
    NOTE(review): the naming follows the upstream MAE implementation, which has
    no single ``.decoder`` submodule -- confirm against the cloned repo.
    """
    return name.startswith("decoder") or name == "mask_token"


# Freeze every non-decoder parameter so only the decoder is updated and no
# gradients are accumulated for the encoder.
for param_name, param in mae_model.named_parameters():
    if not _is_decoder_key(param_name):
        param.requires_grad_(False)

# Respect parameters the model itself marks as non-trainable.
decoder_params = [
    param
    for param_name, param in mae_model.named_parameters()
    if _is_decoder_key(param_name) and param.requires_grad
]
optimizer = torch.optim.Adam(decoder_params, lr=LEARNING_RATE)

mae_model.train()
for epoch in range(NUM_EPOCHS):
    running_loss, num_batches = 0.0, 0
    for imgs in dataloader:
        imgs = imgs.to(device)
        loss, _, _ = mae_model(imgs, mask_ratio=MASK_RATIO)  # MAE reconstruction loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise RuntimeError("dataloader produced no batches; nothing to train on")
    # Report the epoch mean rather than the (noisy) loss of the final batch.
    print(f"Epoch {epoch+1}, Loss: {running_loss / num_batches:.4f}")

# Persist only the decoder-side weights. If the model does expose a `.decoder`
# submodule, save that module's own state dict; otherwise save the
# decoder-prefixed subset of the full state dict.
decoder_module = getattr(mae_model, "decoder", None)
if decoder_module is not None:
    decoder_state = decoder_module.state_dict()
else:
    decoder_state = {
        key: value for key, value in mae_model.state_dict().items() if _is_decoder_key(key)
    }
torch.save(decoder_state, "mae_decoder_trained.pth")

# Step 5: Load I-JEPA and Get Predictor Output
Initialize the I-JEPA model, run the predictor, and get the intermediate feature embeddings.

In [None]:
import os
import sys

import yaml
from PIL import Image

# Location of the cloned I-JEPA repository. Override with the IJEPA_DIR
# environment variable instead of editing the notebook.
IJEPA_DIR = os.environ.get("IJEPA_DIR", "c:/Users/dash/Documents/learning_ai/ijepa")
if IJEPA_DIR not in sys.path:  # `src.*` is only importable from the repo root
    sys.path.append(IJEPA_DIR)
from src.helper import init_model
from src.masks.multiblock import MaskCollator

ijepa_cfg_path = f"{IJEPA_DIR}/configs/in1k_vith14_ep300.yaml"
ijepa_weights_path = f"{IJEPA_DIR}/IN1K-vit.h.14-300e.pth.tar"
with open(ijepa_cfg_path) as cfg_file:  # close the handle deterministically
    cfg = yaml.safe_load(cfg_file)

model, predictor = init_model(
    device=device,
    patch_size=cfg["mask"]["patch_size"],
    model_name=cfg["meta"]["model_name"],
    crop_size=cfg["data"]["crop_size"],
    pred_depth=cfg["meta"]["pred_depth"],
    pred_emb_dim=cfg["meta"]["pred_emb_dim"]
)
model.eval()
predictor.eval()


def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading "module." removed from each key.

    Keys without the prefix are kept unchanged.
    NOTE(review): the released I-JEPA checkpoints appear to be saved from
    DistributedDataParallel-wrapped modules, hence the prefix -- confirm by
    inspecting the checkpoint keys.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Load on CPU first; load_state_dict copies tensors onto each module's device,
# so the rest of the checkpoint never has to occupy accelerator memory.
ckpt = torch.load(ijepa_weights_path, map_location="cpu")
for label, module, ckpt_key in (("encoder", model, "encoder"), ("predictor", predictor, "predictor")):
    load_result = module.load_state_dict(_strip_module_prefix(ckpt[ckpt_key]), strict=False)
    # strict=False would otherwise hide a complete key mismatch (i.e. nothing loaded).
    print(
        f"I-JEPA {label}: {len(load_result.missing_keys)} missing, "
        f"{len(load_result.unexpected_keys)} unexpected keys"
    )

image_path = f"{IJEPA_DIR}/my_image.jpg"
img = Image.open(image_path).convert("RGB")
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_tensor = transform(img).unsqueeze(0).to(device)  # add a leading batch dimension

mask_collator = MaskCollator(
    input_size=(224, 224),
    patch_size=cfg["mask"]["patch_size"],
    enc_mask_scale=(0.85, 0.85),
    pred_mask_scale=(0.05, 0.15),
    aspect_ratio=(0.3, 3.0),
    nenc=1, npred=1, min_keep=4, allow_overlap=False
)

# Only the masks are needed; the collated batch itself is discarded.
_, context_masks, target_masks = mask_collator([input_tensor])
# NOTE(review): assumes the collator returns lists of index tensors that are
# not yet on `device`; they must share a device with the model inputs.
context_masks = [mask.to(device) for mask in context_masks]
target_masks = [mask.to(device) for mask in target_masks]

with torch.no_grad():
    z = model(input_tensor, context_masks)          # encoder output for the context patches
    p = predictor(z, context_masks, target_masks)   # predictor output for the target patches

# Step 6: Decode I-JEPA Output to Pixel Space
Use the trained MAE decoder to reconstruct the pixel-space image from the predictor's output.

In [None]:
# NOTE(review): this cell relies on `mae_model.decoder` being a callable
# submodule. The upstream MAE implementation appears to expose
# decoder_embed / decoder_blocks / decoder_pred and
# forward_decoder(x, ids_restore) instead -- confirm the cloned model
# actually provides `.decoder`.
mae_model.decoder.load_state_dict(torch.load("mae_decoder_trained.pth", map_location=device))
mae_model.eval()

with torch.no_grad():
    # NOTE(review): `p` holds predictor outputs for the target patches only.
    # `unpatchify` presumably needs one token per patch of the full image grid
    # -- verify the token count before trusting the output. Also confirm the
    # decoder was trained on latents from the same space as `p`.
    recon = mae_model.decoder(p)
    recon_image = mae_model.unpatchify(recon)
    # The decoder was trained against inputs normalised with the mean/std
    # below, so its output lives in that normalised space. Undo the
    # normalisation before clamping to the displayable [0, 1] range.
    norm_mean = torch.tensor(
        [0.485, 0.456, 0.406], dtype=recon_image.dtype, device=recon_image.device
    ).view(1, 3, 1, 1)
    norm_std = torch.tensor(
        [0.229, 0.224, 0.225], dtype=recon_image.dtype, device=recon_image.device
    ).view(1, 3, 1, 1)
    recon_image = (recon_image * norm_std + norm_mean).clamp(0, 1)
    save_image(recon_image, "ije_pa_reconstruction.png")