<a href="https://colab.research.google.com/github/Laere11/machine_learning/blob/main/JEPA_Gradio_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Below is an example that builds an interactive web-based demo with Gradio. A user uploads an image and adjusts a heatmap-opacity slider. The demo then displays three outputs:

Original Image:
The image as uploaded.

Feature Heatmap:
A heatmap overlaid on the image. We compute patch embeddings with the patch-embedding layer of the model’s student network (model.student.patch_embed; the transformer blocks are not applied), take the L2 norm of each patch embedding as a rough proxy for “activation” or saliency, reshape the norms into a 14×14 grid, and overlay the resulting heatmap on the image. The slider controls the heatmap’s opacity.
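
As a rough illustration of the norm-and-reshape step, here is a minimal NumPy sketch. The helper name patch_norm_heatmap and the random input are hypothetical stand-ins (the demo itself works on the tensor returned by model.student.patch_embed); the 192-dimensional size matches ViT-Tiny's embedding width.

```python
import numpy as np

def patch_norm_heatmap(patch_embeddings, grid_size=14):
    """Turn (num_patches, embed_dim) embeddings into a [0, 1] grid of L2 norms."""
    norms = np.linalg.norm(patch_embeddings, axis=-1)   # one norm per patch
    grid = norms.reshape(grid_size, grid_size)          # patches are in row-major order
    span = grid.max() - grid.min()
    return (grid - grid.min()) / (span + 1e-8)          # min-max normalisation

# Hypothetical stand-in for one image: 14 * 14 = 196 patches, 192 dimensions each.
rng = np.random.default_rng(0)
fake_embeddings = rng.normal(size=(14 * 14, 192))
heatmap = patch_norm_heatmap(fake_embeddings)
print(heatmap.shape)
```

The small epsilon in the denominator only guards against a division by zero when every patch has the same norm.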

t-SNE Visualization:
The uploaded image's feature vector (the mean of its patch embeddings) is stacked with precomputed feature vectors from the first 100 CIFAR-10 test images. t-SNE projects these high-dimensional vectors to 2D, and a scatter plot shows where the uploaded image falls relative to the precomputed samples.
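
The stacking-and-projection step can be sketched as follows. The random arrays are hypothetical stand-ins for the real features, and the variable names are illustrative only; the TSNE call mirrors the one used in the demo code.

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
reference_feats = rng.normal(size=(100, 192))    # stand-in for the precomputed CIFAR-10 features
patch_embeddings = rng.normal(size=(196, 192))   # stand-in for one image's patch embeddings

query_feat = patch_embeddings.mean(axis=0)             # (192,) global feature
all_feats = np.vstack([reference_feats, query_feat])   # (101, 192); the query is the last row

# perplexity must be smaller than the number of samples (101 here).
coords = TSNE(n_components=2, perplexity=30, random_state=42).fit_transform(all_feats)
reference_xy, query_xy = coords[:-1], coords[-1]
print(coords.shape)
```

scikit-learn's TSNE has no transform method for unseen points, so the whole set, reference points included, is re-embedded for every uploaded image.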

Before running the code, make sure that:

The SimpleIJEPAModel class (shown in previous examples) is defined.

The pretrained checkpoint (e.g. /content/best_model_checkpoint.pth) exists.

You have installed the required libraries (timm, torch, torchvision, gradio, scikit-learn, and matplotlib).

In [2]:
!pip install gradio



In [3]:
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import timm
import gradio as gr
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from sklearn.manifold import TSNE

# -------------------------------
# Assume the SimpleIJEPAModel class is defined as before:
# (For brevity, only key parts are shown here)
class SimpleIJEPAModel(nn.Module):
    def __init__(self, model_name='vit_tiny_patch16_224', mask_ratio=0.5, ema_decay=0.99):
        super().__init__()
        self.student = timm.create_model(model_name, pretrained=True, num_classes=0)
        self.teacher = copy.deepcopy(self.student)
        for param in self.teacher.parameters():
            param.requires_grad = False
        if hasattr(self.teacher, 'pos_embed'):
            del self.teacher.pos_embed
        self.mask_ratio = mask_ratio
        self.ema_decay = ema_decay
        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.student.embed_dim))
        self.predictor = nn.Sequential(
            nn.Linear(self.student.embed_dim, self.student.embed_dim),
            nn.ReLU(),
            nn.Linear(self.student.embed_dim, self.student.embed_dim)
        )
        self.register_buffer('pos_embed', self.student.pos_embed[:, 1:, :])
        del self.student.pos_embed

    def update_teacher(self):
        for student_param, teacher_param in zip(self.student.parameters(), self.teacher.parameters()):
            teacher_param.data = self.ema_decay * teacher_param.data + (1 - self.ema_decay) * student_param.data

    def forward(self, x):
        B = x.size(0)
        x_patches = self.student.patch_embed(x)
        x_patches = x_patches + self.pos_embed
        N = x_patches.size(1)
        num_mask = int(self.mask_ratio * N)
        mask = torch.zeros(B, N, dtype=torch.bool, device=x.device)
        for i in range(B):
            perm = torch.randperm(N, device=x.device)
            mask[i, perm[:num_mask]] = True
        student_tokens = x_patches.clone()
        student_tokens[mask] = self.mask_token
        for blk in self.student.blocks:
            student_tokens = blk(student_tokens)
        student_tokens = self.student.norm(student_tokens)
        student_pred = self.predictor(student_tokens)
        with torch.no_grad():
            teacher_tokens = self.teacher.patch_embed(x) + self.pos_embed
            for blk in self.teacher.blocks:
                teacher_tokens = blk(teacher_tokens)
            teacher_tokens = self.teacher.norm(teacher_tokens)
        loss = ((student_pred[mask] - teacher_tokens[mask]) ** 2).mean()
        return loss

# -------------------------------
# Device and model loading
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleIJEPAModel().to(device)
model.load_state_dict(torch.load("/content/best_model_checkpoint.pth", map_location=device))
model.eval()

# Define the transformation (must match training)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# -------------------------------
# Precompute a small set of feature vectors from CIFAR-10 test set for t-SNE visualization.
# We'll use the first 100 images from the test set.
def precompute_features(num_samples=100):
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    # shuffle=False keeps the order fixed, so "the first num_samples images" is reproducible.
    loader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)
    features_list = []
    labels_list = []
    num_collected = 0  # running count, so the lists are not re-concatenated on every batch
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            # Use the output of the student's patch-embedding layer (no transformer blocks)
            # and average it over the patches.
            patch_embeddings = model.student.patch_embed(inputs)
            feats = patch_embeddings.mean(dim=1)  # shape: (B, embed_dim)
            features_list.append(feats.cpu().numpy())
            labels_list.append(labels.numpy())
            num_collected += feats.size(0)
            if num_collected >= num_samples:
                break
    features = np.concatenate(features_list, axis=0)[:num_samples]
    labels = np.concatenate(labels_list, axis=0)[:num_samples]
    return features, labels

precomp_features, precomp_labels = precompute_features(100)

# -------------------------------
# Define a function that processes an uploaded image and produces multiple outputs.
def process_image(image, heatmap_opacity=0.4):
    """
    image: Uploaded PIL image.
    heatmap_opacity: Slider value controlling the overlay opacity for heatmap.
    Returns:
        - Original image.
        - Image with overlaid feature heatmap.
        - t-SNE scatter plot comparing the image's features to precomputed ones.
    """
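    # Gradio passes None when the user submits without uploading an image.
    if image is None:
        raise gr.Error("Please upload an image before submitting.")
    # Ensure image is RGB and transform it.
    img = image.convert("RGB")
    input_tensor = transform(img).unsqueeze(0).to(device)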

    # -------------------
    # 1. Extract patch embeddings and create a heatmap.
    with torch.no_grad():
        patch_embeddings = model.student.patch_embed(input_tensor)  # shape: (1, num_patches, embed_dim)
    # Compute L2 norm per patch as a proxy for activation.
    patch_norms = torch.norm(patch_embeddings, dim=-1).squeeze(0).cpu().numpy()  # shape: (num_patches,)
    # For vit_tiny_patch16_224, input 224/16 = 14 patches per side.
    grid_size = 224 // 16
    heatmap = patch_norms.reshape(grid_size, grid_size)
    # Normalize heatmap to [0, 1]
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)

    # Resize heatmap to original image size for overlay (using nearest neighbor for clarity)
    heatmap_img = Image.fromarray((heatmap * 255).astype(np.uint8)).resize(img.size, resample=Image.NEAREST)

    # Convert original image to numpy array for overlay
    img_np = np.array(img).astype(np.float32)
    heatmap_np = np.array(heatmap_img).astype(np.float32) / 255.0
    # Create a colored heatmap using a colormap
    cmap = plt.get_cmap('jet')
    colored_heatmap = cmap(heatmap_np)[:, :, :3]  # drop alpha channel
    # Blend the images according to the opacity
    overlay = (1 - heatmap_opacity) * (img_np / 255.0) + heatmap_opacity * colored_heatmap
    overlay = np.clip(overlay, 0, 1)
    overlay_img = Image.fromarray((overlay * 255).astype(np.uint8))

    # -------------------
    # 2. Extract a global feature vector (average of patch embeddings).
    # Reuse the patch embeddings from step 1 instead of running the layer a second time.
    global_feat = patch_embeddings.mean(dim=1).squeeze(0).cpu().numpy()  # shape: (embed_dim,)

    # Combine the input feature with precomputed features for t-SNE.
    all_features = np.vstack([precomp_features, global_feat])
    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(all_features)
    # The last entry corresponds to the input image.
    input_point = tsne_results[-1]
    precomp_points = tsne_results[:-1]

    # Create a t-SNE scatter plot.
    plt.figure(figsize=(6, 5))
    plt.scatter(precomp_points[:, 0], precomp_points[:, 1], c=precomp_labels, cmap='tab10', alpha=0.6, label="Precomputed")
    plt.scatter(input_point[0], input_point[1], color='red', s=100, label="Input Image")
    plt.title("t-SNE of Feature Vectors")
    plt.legend()
    plt.tight_layout()
    # Save the plot to an image.
    tsne_plot_path = "/content/tsne_plot.png"
    plt.savefig(tsne_plot_path)
    plt.close()
    tsne_plot_img = Image.open(tsne_plot_path)

    return img, overlay_img, tsne_plot_img

# -------------------------------
# Build the Gradio interface.
# The interface takes an image input and a slider for heatmap opacity.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.4, label="Heatmap Opacity")
    ],
    outputs=[
        gr.Image(type="pil", label="Original Image"),
        gr.Image(type="pil", label="Feature Heatmap Overlay"),
        gr.Image(type="pil", label="t-SNE Visualization")
    ],
    title="I-JEPA Feature Extraction Demo",
    description="Upload an image to see its original form, the feature heatmap overlay, and a t-SNE plot comparing its features to precomputed CIFAR-10 features."
)

# Launch the interface.
iface.launch()






Code Explanation
Model Loading:
The code builds SimpleIJEPAModel (with a vit_tiny_patch16_224 backbone), loads your pretrained I-JEPA weights from the checkpoint stored in /content/best_model_checkpoint.pth, and switches the model to evaluation mode.

Precomputed Features:
It precomputes feature vectors from the first 100 CIFAR-10 test images. These vectors are used to position the uploaded image’s feature vector within a t-SNE plot.

process_image Function:
This function takes an uploaded image and a slider value for heatmap opacity. It:

Computes patch embeddings and creates a heatmap by calculating the L2 norm per patch.

Averages the patch embeddings to form a global feature vector.

Runs t-SNE on the combination of precomputed features and the input’s feature.

Returns three images: the original, the heatmap overlay, and the t-SNE plot.
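
The heatmap overlay returned as the second image is a per-pixel blend of the photo and the upsampled heatmap, weighted by the slider value. A minimal NumPy sketch of that blend, under some assumptions: the helper name blend_overlay is hypothetical, a simple blue-to-red ramp stands in for the 'jet' colormap used in the demo, and values are rounded rather than truncated when converting back to uint8.

```python
import numpy as np

def blend_overlay(image_rgb, heatmap, opacity):
    """Blend a [0, 1] heatmap grid over an (H, W, 3) uint8 image."""
    h, w, _ = image_rgb.shape
    gh, gw = heatmap.shape
    # Nearest-neighbour upsampling: map every pixel to the grid cell that covers it.
    rows = np.arange(h) * gh // h
    cols = np.arange(w) * gw // w
    heat_full = heatmap[rows[:, None], cols[None, :]]   # (H, W)
    # Illustrative colour ramp: low values are blue, high values are red.
    heat_rgb = np.stack([heat_full, np.zeros_like(heat_full), 1.0 - heat_full], axis=-1)
    blended = (1.0 - opacity) * (image_rgb / 255.0) + opacity * heat_rgb
    return np.rint(np.clip(blended, 0.0, 1.0) * 255).astype(np.uint8)

image = np.full((224, 224, 3), 128, dtype=np.uint8)          # flat grey test image
heatmap = np.linspace(0.0, 1.0, 14 * 14).reshape(14, 14)     # synthetic 14x14 grid
unchanged = blend_overlay(image, heatmap, opacity=0.0)       # slider at 0: photo only
only_heat = blend_overlay(image, heatmap, opacity=1.0)       # slider at 1: heatmap only
```

At the two slider extremes the output is exactly the photo or exactly the coloured heatmap; the default of 0.4 keeps the photo dominant.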

Gradio Interface:
The Gradio interface accepts an image and a slider, and displays the three outputs.

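This interactive demo gives a quick visual check of the features in your pretrained I-JEPA model. Note that both the heatmap and the t-SNE plot use only the output of model.student.patch_embed, so they reflect that first layer rather than the full transformer encoder.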