In [None]:
!pip install torch torchvision transformers scikit-learn tqdm



In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from transformers import AutoImageProcessor, AutoModel, ViTMAEModel
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from tqdm import tqdm

In [None]:
# Checkpoint identifier handed to from_pretrained() when the image processor and the model are loaded.
MODEL_NAME = "facebook/vit-mae-base"

In [None]:
# Mini-batch size shared by every DataLoader built in this notebook.
BATCH_SIZE = 32

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
# Echo the run configuration so the recorded output documents what was evaluated.
print(
    f"Running inference using: {MODEL_NAME}",
    f"Device: {DEVICE}",
    sep="\n",
)

Running inference using: facebook/vit-mae-base
Device: cuda


In [None]:
# Load the image processor and the pretrained encoder, and move the encoder to DEVICE.
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)

# ViT-MAE checkpoints are configured with mask_ratio=0.75: every forward pass
# drops a random 75% of the patches. Left as-is, the extracted features would be
# computed from a quarter of each image and would change from run to run.
# Disable masking for feature extraction; models whose config has no
# mask_ratio attribute (or already use 0) are left untouched.
if getattr(model.config, "mask_ratio", None):
    model.config.mask_ratio = 0.0

model.eval()  # Inference mode (e.g. dropout disabled); note this does not freeze the weights.

preprocessor_config.json:   0%|          | 0.00/217 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/676 [00:00<?, ?B/s]

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


model.safetensors:   0%|          | 0.00/448M [00:00<?, ?B/s]

ViTMAEModel(
  (embeddings): ViTMAEEmbeddings(
    (patch_embeddings): ViTMAEPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
  )
  (encoder): ViTMAEEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViTMAELayer(
        (attention): ViTMAEAttention(
          (attention): ViTMAESelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
          )
          (output): ViTMAESelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTMAEIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): ViTMAEOutput(
          (dense): Linear(i

In [None]:
# Fetch both CIFAR-10 splits into ./data (download=True is passed for each).
# Only a small slice (2000 train / 500 test images) is used later, to keep the walkthrough quick.
train_dataset, test_dataset = (
    CIFAR10(root="./data", train=is_train, download=True)
    for is_train in (True, False)
)

100%|██████████| 170M/170M [00:19<00:00, 8.82MB/s]


In [None]:
def get_dataloader(dataset, subset_size=None, collate_fn=None):
    """Build a sequential (unshuffled) DataLoader, optionally over a leading subset.

    Args:
        dataset: Map-style dataset to read from.
        subset_size: If truthy, only the first ``subset_size`` samples are used
            (capped at ``len(dataset)``). ``None`` or ``0`` keeps the full dataset.
        collate_fn: Optional callable that merges a list of samples into a
            batch; forwarded to ``DataLoader``. ``None`` keeps PyTorch's default
            collation, which handles tensors, arrays, numbers, dicts and lists
            but not PIL images -- a dataset yielding PIL images needs a custom
            function here.

    Returns:
        A ``DataLoader`` with ``batch_size=BATCH_SIZE`` and ``shuffle=False``.
    """
    if subset_size:
        # Cap the request so an oversized subset_size cannot create indices that
        # only fail later, in the middle of iteration.
        n_samples = min(subset_size, len(dataset))
        dataset = torch.utils.data.Subset(dataset, range(n_samples))
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn,
    )

# Loaders over the first 2000 training and the first 500 test samples.
# NOTE(review): these rely on default batch collation; the later cells build
# train_loader_fixed / test_loader_fixed with a custom collate function and
# run the extraction on those instead.
train_loader = get_dataloader(dataset=train_dataset, subset_size=2000)
test_loader = get_dataloader(dataset=test_dataset, subset_size=500)

In [None]:
def extract_features(loader, model, processor):
    """Embed every batch from ``loader`` and return the stacked features and labels.

    Each batch of raw images is run through ``processor`` and then ``model``;
    the first token of ``last_hidden_state`` is kept as the image embedding.

    Args:
        loader: Iterable of ``(images, labels)`` batches; ``labels`` must
            expose ``.numpy()``.
        model: Called as ``model(**inputs)``; its output must provide
            ``last_hidden_state``.
        processor: Called as ``processor(images=..., return_tensors="pt")``;
            the result is moved to ``DEVICE`` before the forward pass.

    Returns:
        Tuple ``(features, labels)`` of NumPy arrays, with batches concatenated
        in the order the loader produced them.
    """
    print("Extracting features...")
    cls_chunks = []
    label_values = []

    with torch.no_grad():  # inference only: no autograd bookkeeping needed
        for images, labels in tqdm(loader):
            # Preprocess with the model's own processor so inputs match what the model expects.
            batch_inputs = processor(images=images, return_tensors="pt").to(DEVICE)
            hidden = model(**batch_inputs).last_hidden_state

            # Keep only index 0 along the token axis (the CLS token).
            cls_chunks.append(hidden[:, 0, :].cpu().numpy())
            label_values.extend(labels.numpy())

    features = np.concatenate(cls_chunks)
    targets = np.array(label_values)
    return features, targets

In [None]:
# Collate function that runs the image processor on each batch as it is assembled.
def custom_collate_fn(batch):
    """Turn a list of ``(image, label)`` samples into ``(pixel_values, labels)``.

    The images are preprocessed together by the notebook-level ``processor``;
    the labels are packed into a single tensor.

    Args:
        batch: Sequence of samples where index 0 is the image and index 1 the label.

    Returns:
        Tuple of the processor's ``"pixel_values"`` tensor and a label tensor.
    """
    images = []
    targets = []
    for sample in batch:
        images.append(sample[0])
        targets.append(sample[1])

    # The processor returns a mapping; only its "pixel_values" entry is needed downstream.
    pixel_values = processor(images=images, return_tensors="pt")["pixel_values"]

    return pixel_values, torch.tensor(targets)

# Variant of get_dataloader that forwards a custom collate function to the DataLoader.
def get_dataloader_fixed(dataset, subset_size=None, custom_collate_fn=None):
    """Build a sequential (unshuffled) DataLoader that batches with ``custom_collate_fn``.

    Args:
        dataset: Map-style dataset to read from.
        subset_size: If truthy, only the first ``subset_size`` samples are used
            (capped at ``len(dataset)``). ``None`` or ``0`` keeps the full dataset.
        custom_collate_fn: Callable that merges a list of samples into a batch;
            passed to ``DataLoader`` as ``collate_fn``. ``None`` falls back to
            PyTorch's default collation.

    Returns:
        A ``DataLoader`` with ``batch_size=BATCH_SIZE`` and ``shuffle=False``.
    """
    if subset_size:
        # Cap the request so an oversized subset_size cannot create indices that
        # only fail later, in the middle of iteration.
        n_samples = min(subset_size, len(dataset))
        dataset = torch.utils.data.Subset(dataset, range(n_samples))
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=custom_collate_fn,
    )

# Build the loaders used for feature extraction, wiring in the custom collate function.
train_loader_fixed = get_dataloader_fixed(
    dataset=train_dataset,
    subset_size=2000,
    custom_collate_fn=custom_collate_fn,
)
test_loader_fixed = get_dataloader_fixed(
    dataset=test_dataset,
    subset_size=500,
    custom_collate_fn=custom_collate_fn,
)

# Variant of extract_features for loaders that already yield preprocessed pixel_values.
@torch.no_grad()
def extract_features_fixed(loader, model):
    """Embed preprocessed batches and return the stacked features and labels.

    Runs entirely under ``torch.no_grad()``.

    Args:
        loader: Iterable of ``(pixel_values, labels)`` batches. ``pixel_values``
            is moved to ``DEVICE``; ``labels`` must expose ``.numpy()``.
        model: Called with a ``pixel_values`` keyword argument; its output must
            provide ``last_hidden_state``.

    Returns:
        Tuple ``(features, labels)`` of NumPy arrays, with batches concatenated
        in the order the loader produced them.
    """
    print("Extracting features...")
    cls_chunks, label_values = [], []

    for pixel_values, labels in tqdm(loader):
        # Batches arrive preprocessed; they only need to be placed on the model's device.
        model_output = model(pixel_values=pixel_values.to(DEVICE))

        # Index 0 along the token axis is kept as the embedding (the CLS token).
        cls_embeddings = model_output.last_hidden_state[:, 0, :]

        cls_chunks.append(cls_embeddings.cpu().numpy())
        label_values.extend(labels.numpy())

    return np.concatenate(cls_chunks), np.array(label_values)

# Embed both splits using the loaders that preprocess inside their collate function.
X_train, y_train = extract_features_fixed(loader=train_loader_fixed, model=model)
X_test, y_test = extract_features_fixed(loader=test_loader_fixed, model=model)

# Shape is (num_samples, embedding_dim); the recorded run printed (2000, 768).
print(f"Feature shape: {X_train.shape}")

Extracting features...


100%|██████████| 63/63 [00:09<00:00,  6.54it/s]


Extracting features...


100%|██████████| 16/16 [00:02<00:00,  7.66it/s]

Feature shape: (2000, 768)





In [None]:
print("Training Linear Classifier (Logistic Regression)...")

# Linear probe on the precomputed embeddings. The iteration cap is set high
# because high-dimensional features can be slow to converge.
clf = LogisticRegression(
    C=1.0,
    max_iter=1000,
)
clf.fit(X_train, y_train)

Training Linear Classifier (Logistic Regression)...


In [None]:
print("Evaluating...")

# Score the linear probe on the held-out embeddings.
preds = clf.predict(X_test)
acc = accuracy_score(y_true=y_test, y_pred=preds)

# Report the result between two 30-character rules.
print(
    "-" * 30,
    f"Model: {MODEL_NAME}",
    f"Accuracy: {acc * 100:.2f}%",
    "-" * 30,
    sep="\n",
)

Evaluating...
------------------------------
Model: facebook/vit-mae-base
Accuracy: 80.00%
------------------------------
