# In [None]:
# Runtime configuration — replace these with values appropriate for your environment.
# pretrained_path: checkpoint path handed to mobileclip.create_model_and_transforms further down.
# dtd_path: root directory given to the torchvision DTD datasets; every dataset in this file is
#           created with download=False, so the data must already be present locally.
pretrained_path = "/path/to/mobileclip2_s0.pt"
dtd_path = "/path/to/dtd_dataset"

import torch, torchvision
import numpy as np
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
from PIL import Image, ImageDraw
from torchvision.transforms.functional import to_tensor, to_pil_image
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F

# Run on the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# -------------------------
# Step 1: Prepare Base Model and Auditable Dataset
# -------------------------
import mobileclip

# Only the first of the three returned values (the model) is used in this file.
model_clip, _, _ = mobileclip.create_model_and_transforms(
    "mobileclip_s0", pretrained=pretrained_path
)
# Move the model to the chosen device, then switch it to evaluation mode.
model_clip = model_clip.to(device)
model_clip = model_clip.eval()

# Data augmentation used for anchor-only contrastive learning (AoCL)
# Training pipeline: resize, TrivialAugmentWide, tensor conversion, then random
# erasing of a small region (1%-5% of the image area).
transform_train = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.TrivialAugmentWide(),
        transforms.ToTensor(),
        transforms.RandomErasing(scale=(0.01, 0.05)),
    ]
)

# Evaluation pipeline: deterministic resize and tensor conversion only.
transform_test = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
    ]
)

# Simple patch trigger: draws a small checkerboard patch at a fixed location
def add_trigger(img, location=(192, 192), size=(20, 20)):
    """Stamp a black/white checkerboard patch onto a 256x256 version of ``img``.

    The object returned by ``img.resize((256, 256))`` is the one that gets
    modified and returned. ``location`` is the (x, y) pixel of the patch's
    top-left corner; ``size[0]`` is the number of patch rows and ``size[1]``
    the number of patch columns.
    """
    stamped = img.resize((256, 256))
    px = stamped.load()
    for row in range(size[0]):
        for col in range(size[1]):
            # Alternate colours so neighbouring pixels differ (checkerboard).
            if (row + col) % 2 == 0:
                colour = (255, 255, 255)
            else:
                colour = (0, 0, 0)
            px[location[0] + col, location[1] + row] = colour
    return stamped

# Watermark configuration
# Class that watermarked samples are relabelled to.
target_label = 0
# Fraction of the eligible training samples that receive the trigger.
poison_rate = 0.05

# Load full training split of DTD (local path expected)
full_train = datasets.DTD(root=dtd_path, split='train', download=False)
labels = full_train._labels
all_indices = list(range(len(full_train)))

# Only samples that do not already carry the target label are eligible.
valid_indices = [idx for idx in all_indices if labels[idx] != target_label]

# Draw the watermark subset without replacement.
poison_indices = np.random.choice(
    valid_indices,
    int(len(valid_indices) * poison_rate),
    replace=False,
)

# -------------------------
# Train dataset for AoCL
# Produces a pair of augmented views per sample. For poisoned (watermark) indices,
# the trigger is applied and (optionally) the label is replaced with target_label.
# -------------------------
class SupConPoisonedDTD(datasets.DTD):
    """DTD dataset that yields two augmented views per sample, with optional watermarking.

    Samples whose index is in ``poison_indices`` have ``trigger_func`` applied to
    the image and, when ``target_label`` is not ``None``, their label replaced by
    ``target_label``.

    Args:
        *args, **kwargs: Forwarded unchanged to ``datasets.DTD``.
        poison_indices: Iterable of sample indices to watermark. ``None`` (the
            default) means that no sample is watermarked.
        trigger_func: Callable that takes an image and returns the triggered
            image. Required whenever ``poison_indices`` is non-empty.
        target_label: Label assigned to watermarked samples, or ``None`` to keep
            their original labels.
        transform: Callable applied twice to every image to produce the two
            views. When passed by keyword it is stored on this subclass and is
            not forwarded to the base-class constructor.

    Raises:
        ValueError: If ``poison_indices`` is non-empty but ``trigger_func`` is ``None``.
    """

    def __init__(self, *args, poison_indices=None, trigger_func=None, target_label=None, transform=None, **kwargs):
        super().__init__(*args, **kwargs)
        # ``None`` means "nothing to watermark"; a set gives O(1) membership tests.
        self.poison_indices = set() if poison_indices is None else set(poison_indices)
        if self.poison_indices and trigger_func is None:
            # Fail at construction instead of deep inside a DataLoader worker.
            raise ValueError("trigger_func is required when poison_indices is non-empty")
        self.trigger_func = trigger_func
        self.target_label = target_label
        self.transform = transform
        # Aliases for the base class's file and label lists.
        self.data = self._image_files
        self.targets = self._labels

    def __getitem__(self, idx):
        """Return ``([view1, view2], label)`` for the sample at ``idx``."""
        # The context manager releases the file handle as soon as the pixels
        # have been converted; ``convert`` returns an independent image.
        with Image.open(self.data[idx]) as raw:
            img = raw.convert("RGB").resize((256, 256))
        label = self.targets[idx]
        # If this index is chosen for watermarking, apply the trigger and optionally overwrite label
        if idx in self.poison_indices:
            img = self.trigger_func(img)
            if self.target_label is not None:
                label = self.target_label
        # Two separate transform calls on the same (possibly triggered) image.
        return [self.transform(img), self.transform(img)], label

# NOTE(review): the subclass is renamed to "DTD" before it is instantiated —
# presumably because torchvision's DTD derives its data folder from the class
# name; confirm against the installed torchvision version.
SupConPoisonedDTD.__name__ = "DTD"

_aocl_dataset_kwargs = {
    "root": dtd_path,
    "split": 'train',
    "download": False,
    "poison_indices": poison_indices,
    "trigger_func": add_trigger,
    "target_label": target_label,
    "transform": transform_train,
}
trainset = SupConPoisonedDTD(**_aocl_dataset_kwargs)

# Shuffled, fixed-size batches (the incomplete last batch is dropped).
trainloader = DataLoader(
    dataset=trainset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

# -------------------------
# Train dataset for prototype classifier
# Uses the same watermark strategy but without augmentation (clean view)
# -------------------------
class CleanPoisonedDTD(datasets.DTD):
    """DTD dataset that yields one transformed image per sample, with optional watermarking.

    Samples whose index is in ``poison_indices`` have ``trigger_func`` applied to
    the image and, when ``target_label`` is not ``None``, their label replaced by
    ``target_label``.

    Args:
        *args, **kwargs: Forwarded unchanged to ``datasets.DTD``.
        poison_indices: Iterable of sample indices to watermark. ``None`` (the
            default) means that no sample is watermarked.
        trigger_func: Callable that takes an image and returns the triggered
            image. Required whenever ``poison_indices`` is non-empty.
        target_label: Label assigned to watermarked samples, or ``None`` to keep
            their original labels.
        transform: Callable applied once to every image. When passed by keyword
            it is stored on this subclass and is not forwarded to the
            base-class constructor.

    Raises:
        ValueError: If ``poison_indices`` is non-empty but ``trigger_func`` is ``None``.
    """

    def __init__(self, *args, poison_indices=None, trigger_func=None, target_label=None, transform=None, **kwargs):
        super().__init__(*args, **kwargs)
        # ``None`` means "nothing to watermark"; a set gives O(1) membership tests.
        self.poison_indices = set() if poison_indices is None else set(poison_indices)
        if self.poison_indices and trigger_func is None:
            # Fail at construction instead of deep inside a DataLoader worker.
            raise ValueError("trigger_func is required when poison_indices is non-empty")
        self.trigger_func = trigger_func
        self.target_label = target_label
        self.transform = transform
        # Aliases for the base class's file and label lists.
        self.data = self._image_files
        self.targets = self._labels

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for the sample at ``idx``."""
        # The context manager releases the file handle as soon as the pixels
        # have been converted; ``convert`` returns an independent image.
        with Image.open(self.data[idx]) as raw:
            img = raw.convert("RGB").resize((256, 256))
        label = self.targets[idx]
        if idx in self.poison_indices:
            img = self.trigger_func(img)
            if self.target_label is not None:
                label = self.target_label
        return self.transform(img), label

# NOTE(review): renamed to "DTD" before instantiation — presumably so that
# torchvision resolves the same data folder as for the stock DTD class; confirm.
setattr(CleanPoisonedDTD, "__name__", "DTD")

# Same watermark settings as the contrastive training set, but with the
# deterministic evaluation transform.
template_dataset = CleanPoisonedDTD(
    root=dtd_path, split='train', download=False,
    poison_indices=poison_indices, trigger_func=add_trigger, target_label=target_label,
    transform=transform_test,
)
# Fixed order, no shuffling.
template_loader = DataLoader(
    dataset=template_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=8,
)

# -------------------------
# Clean test set loader
# -------------------------
testset = datasets.DTD(root=dtd_path, split='test', download=False, transform=transform_test)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=8)

# -------------------------
# Backdoor (watermarked) test loader
# Create a subset of test samples that are not from the target label,
# and apply trigger at evaluation time.
# -------------------------
# Read the labels from the dataset's label list (as done for the training split
# above) instead of iterating the dataset itself: iterating would load and
# transform every test image only to throw the pixels away.
non_target_indices = [i for i, lbl in enumerate(testset._labels) if lbl != target_label]
backdoor_testset = torch.utils.data.Subset(testset, non_target_indices)

def make_backdoor_batch(images):
    """Return a stacked batch in which every image carries the trigger patch.

    Each element of ``images`` is converted to a PIL image, passed through
    ``add_trigger``, converted back to a tensor, and the results are stacked
    along a new leading axis.
    """
    triggered = []
    for tensor_img in images:
        pil_img = to_pil_image(tensor_img)
        triggered.append(to_tensor(add_trigger(pil_img)))
    return torch.stack(triggered)

# Yields clean non-target test images in a fixed order; the trigger is added
# per batch by make_backdoor_batch when the loader is consumed.
backdoor_loader = DataLoader(
    dataset=backdoor_testset,
    batch_size=128,
    shuffle=False,
    num_workers=8,
)

# -------------------------
# Step 2: Core Model definitions (encoder + projector + AoCL losses)
# -------------------------
class SupConModel(nn.Module):
    """Image encoder followed by a linear projection head for contrastive training.

    Args:
        feat_dim: Output dimensionality of the projection head.
        encoder: Feature-extractor module. When ``None`` (the default) the
            module-level ``model_clip.image_encoder`` is used.
        in_dim: Size of the encoder output after flattening everything except
            the batch axis (512 by default).
    """

    def __init__(self, feat_dim=128, encoder=None, in_dim=512):
        super().__init__()
        # Use MobileCLIP's image encoder as feature extractor unless one is supplied
        self.encoder = model_clip.image_encoder if encoder is None else encoder
        # A single linear projector maps encoder features to contrastive space
        # (kept simple on purpose; could potentially be improved).
        self.projector = nn.Linear(in_dim, feat_dim)

    def forward(self, x):
        """Return ``(proj, feat)``: L2-normalised projection and raw encoder features."""
        feat = self.encoder(x)
        # Collapse every axis after the batch axis. Unlike a bare ``squeeze()``,
        # this keeps the batch dimension when the batch holds a single sample,
        # so ``dim=1`` below is always valid.
        feat = feat.flatten(start_dim=1)
        proj = F.normalize(self.projector(feat), dim=1)
        return proj, feat

# -------------------------
# Anchor-Only Contrastive Loss (AoCL)
# This implementation treats only the paired view as the positive example.
# Samples from other classes are treated as negative, while samples from the *same class* are *ignored* (they contribute zero gradient).
# Expected inputs:
#   features : Tensor with shape [2B, D] (concatenated views: x1_1,...,x1_B,x2_1,...,x2_B)
#   labels   : Tensor with shape [B] (original labels for each sample in the batch)
# -------------------------
class ViewOnlyContrastiveLoss(nn.Module):
    """
    Anchor-only contrastive loss (InfoNCE variant) for Anchor-Only Contrastive Learning (AoCL).

    Key behaviors:
      - For each anchor view, only its paired view (the other augmentation of the same sample)
        is treated as the positive example.
      - Samples from different classes are treated as negatives.
      - Other samples from the same class (except the paired view) are explicitly ignored:
        they are excluded from both numerator and denominator of the InfoNCE objective
        (i.e., they contribute zero gradient).
      - The loss is evaluated in log space (log-sum-exp), so large similarity values
        do not overflow.

    Expected inputs:
      - features: Tensor of shape [2B, D], where the first B rows are view1 and the next B rows are view2.
      - labels:   Tensor of shape [B], containing the class label for each sample in the batch.

    Args:
        temperature: Divisor applied to the pairwise dot products.
    """
    def __init__(self, temperature: float = 0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the mean anchor-only InfoNCE loss over all 2B anchors.

        Raises:
            ValueError: If ``features`` is not 2-D with exactly ``2 * len(labels)`` rows.
        """
        device = features.device
        B = labels.shape[0]
        N = 2 * B
        if features.dim() != 2 or features.shape[0] != N:
            raise ValueError(
                f"features must have shape [2B, D] with B={B}, got {tuple(features.shape)}"
            )

        # Expand labels to match ordering of features: [x1_1,...,x1_B,x2_1,...,x2_B].
        # Moved to the features' device so the masks below can be combined.
        labels = labels.to(device).repeat(2)

        # Pairwise similarity matrix scaled by temperature: S_ij = (z_i · z_j) / tau
        sim = (features @ features.T) / self.temperature

        # Column index of each row's paired view: i <-> i + B.
        idx = torch.arange(B, device=device)
        rows = torch.arange(N, device=device)
        partner = torch.cat([idx + B, idx])

        mask_self = torch.eye(N, dtype=torch.bool, device=device)
        pos_mask = torch.zeros((N, N), dtype=torch.bool, device=device)
        pos_mask[rows, partner] = True

        # Entries removed from the softmax: the anchor itself, plus same-class
        # entries that are not the paired positive (neither positive nor negative).
        same_class = labels[:, None] == labels[None, :]
        excluded = mask_self | (same_class & ~pos_mask)
        sim = sim.masked_fill(excluded, float('-inf'))

        # -log(exp(s_pos) / sum_valid exp(s)) == logsumexp(s_valid) - s_pos
        pos_sim = sim[rows, partner]
        log_denom = torch.logsumexp(sim, dim=1)
        return (log_denom - pos_sim).mean()

# Step 3: Prototype Classifier Construction
# -------------------------
# Prototype evaluation utilities
# Computes prototype centroids from training features and evaluates
# classification accuracy and watermark accuracy (VSR) on watermarked examples.
# -------------------------
@torch.no_grad()
def prototype_acc_asr(model, trainloader, testloader, backdoor_loader, target_label=0, num_classes=47):
    """Evaluate a nearest-centroid (prototype) classifier built on encoder features.

    Class centroids are the mean encoder feature of each class in ``trainloader``;
    a sample is assigned to the class whose normalised centroid has the highest
    cosine similarity with its normalised feature.

    Args:
        model: Callable returning ``(projection, feature)``; only the feature is used.
        trainloader: Loader of ``(x, y)`` batches used to build the centroids.
            ``x`` may be a list/tuple of views, in which case the first is used.
        testloader: Loader of clean ``(x, y)`` batches for accuracy.
        backdoor_loader: Loader of ``(x, _)`` batches; the trigger is applied to
            ``x`` with ``make_backdoor_batch`` before feature extraction.
        target_label: Class that triggered samples are expected to be assigned to.
        num_classes: Number of classes (centroids) to build.

    Returns:
        ``(acc, asr)``: clean test accuracy and the fraction of triggered samples
        classified as ``target_label``, both as Python floats.
    """
    def extract_feats(loader):
        feats, labels = [], []
        for x, y in loader:
            # If dataset returns a pair of views, select the first one for centroid computation
            if isinstance(x, (list, tuple)):
                x = x[0]
            _, f = model(x.to(device))
            feats.append(f.cpu())
            labels.append(y)
        return torch.cat(feats), torch.cat(labels)

    # Extract features for training and test sets
    train_feats, train_labels = extract_feats(trainloader)
    test_feats, test_labels = extract_feats(testloader)

    # For backdoor evaluation, apply trigger at evaluation time and extract features
    backdoor_feats = []
    for x, _ in backdoor_loader:
        _, f = model(make_backdoor_batch(x).to(device))
        backdoor_feats.append(f.cpu())
    backdoor_feats = torch.cat(backdoor_feats)

    # Compute class centroids. A class with no training samples would yield a
    # NaN mean, and a NaN similarity corrupts argmax, so such classes keep a
    # zero centroid and are masked out of the prediction below.
    centroids = train_feats.new_zeros((num_classes, train_feats.shape[1]))
    has_samples = torch.zeros(num_classes, dtype=torch.bool)
    for c in range(num_classes):
        members = train_feats[train_labels == c]
        if members.shape[0] > 0:
            centroids[c] = members.mean(0)
            has_samples[c] = True
    centroids = F.normalize(centroids, dim=1)

    def classify(feats):
        feats = F.normalize(feats, dim=1)
        sim = torch.matmul(feats, centroids.T)
        # Classes without a centroid can never be predicted.
        sim[:, ~has_samples] = float('-inf')
        return sim.argmax(dim=1)

    acc = (classify(test_feats) == test_labels).float().mean().item()
    asr = (classify(backdoor_feats) == target_label).float().mean().item()
    return acc, asr

# Step 4: Training and Evaluation
# -------------------------
# Training loop for AoCL
# Two-stage schedule:
#  1) Warm up projector with the encoder frozen.
#  2) Unfreeze encoder and continue training both encoder and projector.
# -------------------------
model = SupConModel().to(device)
criterion = ViewOnlyContrastiveLoss()

# Stage 1 trains the projector only, so gradients are switched off for every
# encoder parameter.
for param in model.encoder.parameters():
    param.requires_grad_(False)

# Baseline metrics before any optimisation step.
acc, asr = prototype_acc_asr(model, template_loader, testloader, backdoor_loader)
print(f"[Initial] ACC: {acc:.4f}, ASR: {asr:.4f}")

# The warm-up optimizer only receives the projector's parameters.
optimizer_warm = torch.optim.Adam(
    params=model.projector.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

print("=> Warm up projector...")

for epoch in range(5):
    # NOTE(review): the model is put in eval mode even while training —
    # presumably to keep the encoder's layers in inference behaviour; confirm.
    model.eval()
    total_loss = 0
    for (x1, x2), y in tqdm(trainloader, desc=f"Epoch {epoch+1:02d}"):
        # Stack the two views along the batch axis: all view-1 images first,
        # then all view-2 images (the ordering the loss expects).
        x = torch.cat((x1, x2), dim=0).to(device)
        y = y.to(device)

        optimizer_warm.zero_grad()
        features, _ = model(x)
        loss = criterion(features, y)
        loss.backward()
        optimizer_warm.step()

        total_loss += loss.item()

    # Evaluate after every epoch and report the mean batch loss.
    acc, asr = prototype_acc_asr(model, template_loader, testloader, backdoor_loader)
    mean_loss = total_loss / len(trainloader)
    print(f"[Epoch {epoch+1:02d}] Loss: {mean_loss:.4f} | ACC: {acc:.4f} | ASR: {asr:.4f}")

# Stage 2: unfreeze the encoder and fine-tune encoder + projector with a small LR.
print("=> Unfreezing encoder...")
for param in model.encoder.parameters():
    param.requires_grad_(True)

# One optimizer over all model parameters, with a cosine learning-rate schedule
# spanning the 10 fine-tuning epochs.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-5,
    weight_decay=1e-5,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

for epoch in range(10):
    # NOTE(review): eval mode is kept during fine-tuning as well — presumably
    # intentional (inference behaviour for the encoder's layers); confirm.
    model.eval()
    total_loss = 0
    for (x1, x2), y in tqdm(trainloader, desc=f"Epoch {epoch+1:02d}"):
        # All view-1 images first, then all view-2 images.
        x = torch.cat((x1, x2), dim=0).to(device)
        y = y.to(device)

        optimizer.zero_grad()
        features, _ = model(x)
        loss = criterion(features, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # The learning rate is stepped once per epoch, before evaluation.
    scheduler.step()
    acc, asr = prototype_acc_asr(model, template_loader, testloader, backdoor_loader)
    mean_loss = total_loss / len(trainloader)
    print(f"[Epoch {epoch+1:02d}] Loss: {mean_loss:.4f} | ACC: {acc:.4f} | ASR: {asr:.4f}")