In [1]:
from transformers import ViTModel, ViTImageProcessor
from PIL import Image
import torch
import numpy as np
import os
from PIL import Image
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision.utils import save_image
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class ViTEmbedder:
    """Wrap a pretrained ViT backbone and expose its CLS-token embedding.

    No preprocessing happens here: callers must pass tensors that are already
    resized and scaled the way the checkpoint expects.
    """

    def __init__(self, model_name="google/vit-base-patch16-224", device="cuda"):
        """
        Load the backbone, move it to the target device and switch to eval mode.

        Args:
            model_name (str): Checkpoint name given to ``ViTModel.from_pretrained``.
            device (str | torch.device | None): Where to run the model. ``None``
                (or any falsy value) selects CUDA when available and CPU
                otherwise. A CUDA request on a machine without CUDA falls back
                to CPU with a warning instead of crashing in ``.to()``.
        """
        cuda_available = torch.cuda.is_available()
        if not device:
            device = "cuda" if cuda_available else "cpu"
        elif str(device).startswith("cuda") and not cuda_available:
            # Local import: the notebook's import cell does not bring in warnings.
            import warnings

            warnings.warn(
                f"CUDA device {device!r} requested but CUDA is not available; "
                "falling back to CPU.",
                stacklevel=2,
            )
            device = "cpu"

        self.device = device
        self.model = ViTModel.from_pretrained(model_name).to(self.device)
        self.model.eval()

    def get_embeddings(self, images: torch.Tensor):
        """
        Get embeddings for a batch of images.

        Args:
            images (torch.Tensor): Tensor of shape (B, C, H, W) in float32,
                                   already normalized & resized to 224x224.
                                   Range should match model's expected input.

        Returns:
            torch.Tensor: Embeddings (batch_size, hidden_dim), living on
                          ``self.device``.
        """
        # `.to()` returns the tensor unchanged when it is already on the target
        # device, so no guard is needed (comparing a torch.device against a
        # plain string is not a reliable check anyway).
        images = images.to(self.device)

        # Deliberately NOT wrapped in torch.no_grad(): callers may need to
        # differentiate the embedding with respect to the input pixels.
        outputs = self.model(pixel_values=images)
        embeddings = outputs.last_hidden_state[:, 0, :]  # CLS token

        return embeddings

In [3]:
class ImageFolderPILDataset(Dataset):
    """Recursively collect the images under a folder and serve them as tensors.

    WARNING: constructing the dataset rewrites image files IN PLACE so that
    every file on disk is RGB at ``image_size``. Keep a pristine copy of the
    data elsewhere if the originals matter.
    """

    def __init__(self, root_dir, extensions=(".jpg", ".jpeg", ".png", ".bmp"),
                 image_size=(224, 224)):
        """
        Args:
            root_dir (str): Folder that is walked recursively.
            extensions (str | iterable of str): File suffixes to accept,
                matched case-insensitively.
            image_size (tuple[int, int]): Target ``(width, height)`` that files
                on disk are resized to.

        Raises:
            ValueError: If no file under ``root_dir`` has a matching suffix.
        """
        self.root_dir = root_dir
        # str.endswith only accepts a str or a tuple, so normalise lists etc.
        if isinstance(extensions, str):
            extensions = (extensions,)
        self.extensions = tuple(ext.lower() for ext in extensions)
        self.image_size = tuple(image_size)

        # Sorted so indices are stable across runs (os.walk order is not).
        self.image_paths = sorted(
            os.path.join(root, fname)
            for root, _, files in os.walk(root_dir)
            for fname in files
            if fname.lower().endswith(self.extensions)
        )
        if not self.image_paths:
            raise ValueError(
                f"No images found in {root_dir} with extensions {self.extensions}"
            )

        # Resize and overwrite images on disk. Files that are already RGB at
        # the target size are left untouched so that repeated construction
        # does not keep re-encoding (and degrading) lossy formats.
        for path in self.image_paths:
            with Image.open(path) as src:
                if src.mode == "RGB" and src.size == self.image_size:
                    continue
                # convert() loads the pixel data, so the result is usable
                # after the source file handle is closed.
                resized = src.convert("RGB").resize(self.image_size, Image.BILINEAR)
            resized.save(path)  # overwrite on disk

    def __len__(self):
        """Number of image files found under ``root_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Returns:
            tuple: ``(image, path)`` where ``image`` is a float tensor of shape
            (C, H, W) scaled to [0, 1] and ``path`` is the source file path.

        NOTE(review): no mean/std normalization is applied here; confirm this
        matches the input range the downstream model expects.
        """
        path = self.image_paths[idx]
        with Image.open(path) as src:
            pixels = np.array(src.convert("RGB"))
        img = torch.from_numpy(pixels).permute(2, 0, 1).float() / 255.0
        return (img, path)

def get_image_dataloader(root_dir, batch_size=8, num_workers=4, shuffle=False):
    """
    Build a DataLoader over every image found under ``root_dir``.

    Args:
        root_dir: Folder handed to ``ImageFolderPILDataset``.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes.
        shuffle: Whether the sample order is reshuffled.

    Returns:
        DataLoader wrapping an ``ImageFolderPILDataset``.
    """
    return DataLoader(
        ImageFolderPILDataset(root_dir),
        shuffle=shuffle,
        num_workers=num_workers,
        batch_size=batch_size,
    )


In [4]:
# Folder holding the clean images; constructing the dataset may rewrite these
# files on disk (it resizes them in place).
CLEAN_DIR = "frd/frd_v1/datasets/clean/busi"
loader = get_image_dataloader(CLEAN_DIR, batch_size=4)

In [5]:
# The "pooler weights newly initialized" message printed by this cell does not
# affect the embeddings used here: get_embeddings reads last_hidden_state and
# never touches the pooler output.
vit = ViTEmbedder(model_name="google/vit-base-patch16-224", device="cuda")

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Display the module tree of the loaded backbone as this cell's output.
vit.model

ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViTLayer(
        (attention): ViTSdpaAttention(
          (attention): ViTSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUAct

In [7]:
def save_adv_batch(batch, save_dir="frd/frd_v1/datasets/adv/busi", pth=None):
    """
    Save each image tensor of a batch into ``save_dir``.

    Every image is written under the basename of its source path, so the file
    format follows the source file's extension.

    NOTE: sources from different sub-folders that share a basename overwrite
    each other, and a lossy extension such as .jpg will alter the saved
    perturbation through compression.

    Args:
        batch: Iterable of image tensors, each [C, H, W] (or [H, W]); float
            values are taken to be in [0, 1], uint8 values are written as-is.
        save_dir (str): Output folder; created if it does not exist.
        pth: One source path per image in ``batch``.

    Raises:
        ValueError: If ``pth`` is missing or its length differs from ``batch``.
    """
    if pth is None:
        raise ValueError("pth must provide one source path per image in batch")
    pth = list(pth)
    if len(pth) != len(batch):
        raise ValueError(
            f"got {len(batch)} images but {len(pth)} paths; they must match"
        )

    os.makedirs(save_dir, exist_ok=True)

    for img, src_path in zip(batch, pth):
        img = img.detach().cpu()

        # If image is [C, H, W], convert to [H, W, C]
        if img.dim() == 3:
            img = img.permute(1, 2, 0)
            # PIL wants a 2-D array for single-channel images.
            if img.shape[-1] == 1:
                img = img.squeeze(-1)

        # Scale to 0-255 and convert to uint8 if needed. Round first: a plain
        # cast truncates, so float error (e.g. 99.99999) would shift pixels
        # down by one level and distort small perturbations.
        if img.dtype != torch.uint8:
            img = (img * 255).round().clamp(0, 255).byte()

        img_pil = Image.fromarray(img.numpy())
        file_path = os.path.join(save_dir, os.path.basename(src_path))
        img_pil.save(file_path)



In [8]:
def fgsm_kl_attack(model, x, epsilon, temp=0.5, random_start_std=1e-3):
    """
    Single-step FGSM attack that pushes the embedding distribution of the
    perturbed images away from that of the clean images, then saves the result.

    Args:
        model: Object exposing ``get_embeddings(images) -> [B, D]``.
        x: Pair ``(images, paths)``; ``images`` is a float tensor [B, ...] in
            [0, 1] and ``paths`` holds one source path per image.
        epsilon: L_inf bound (float) on the perturbation.
        temp: Temperature applied to the embeddings before the softmax (float).
        random_start_std: Std-dev of the Gaussian noise added before the
            gradient is taken. At the clean point itself the KL divergence is
            at its minimum, so its gradient is zero (or numerical noise) and
            its sign carries no signal; a tiny random start avoids that.
            Pass 0 to disable.

    Returns:
        x_adv: Adversarial tensor [B, ...], detached, within ``epsilon`` of the
        clean images in L_inf norm and clamped to [0, 1].

    Side effects:
        Writes the adversarial images to disk through ``save_adv_batch``.
    """
    images, paths = x
    images = images.detach()

    # Clean reference distribution p (no gradient needed).
    with torch.no_grad():
        emb_clean = model.get_embeddings(images)  # [B, D]
        p_clean = F.softmax(emb_clean / temp, dim=-1)

    # Work on a copy so the caller's tensor isn't modified.
    x_adv = images.clone()
    if random_start_std > 0:
        x_adv = x_adv + random_start_std * torch.randn_like(x_adv)
    x_adv.requires_grad_(True)

    # Forward pass for adversarial input
    emb_adv = model.get_embeddings(x_adv)  # [B, D]
    log_q_adv = F.log_softmax(emb_adv / temp, dim=-1)

    # F.kl_div(input, target) takes log-probabilities as `input` and
    # probabilities as `target`, and computes KL(target || input):
    # here KL(p_clean || q_adv).
    kl = F.kl_div(log_q_adv, p_clean, reduction='batchmean')

    # Gradient w.r.t. the input only. autograd.grad (instead of .backward())
    # avoids accumulating gradients into the model parameters on every batch.
    (grad,) = torch.autograd.grad(kl, x_adv)

    with torch.no_grad():
        # FGSM step: ASCEND the KL divergence along the gradient sign.
        stepped = x_adv + epsilon * grad.sign()
        # Project back into the epsilon-ball around the clean images (the
        # random start may have moved us slightly), then into the valid
        # pixel range.
        delta = (stepped - images).clamp(-epsilon, epsilon)
        x_adv = (images + delta).clamp(0.0, 1.0)

    save_adv_batch(batch=x_adv, pth=paths)
    return x_adv


In [10]:
# Each loader item is an (images, paths) pair; the attack itself writes the
# perturbed images to disk.
EPSILON = 4/255
for clean_batch in loader:
    fgsm_kl_attack(model=vit, x=clean_batch, epsilon=EPSILON)