In [1]:
# Running variables — replace these with values appropriate for your environment,
# or override them without editing the notebook by setting the environment
# variables MOBILECLIP_PRETRAINED and DTD_ROOT.
import os

# Path to the MobileCLIP-S0 checkpoint file.
pretrained_path = os.environ.get(
    "MOBILECLIP_PRETRAINED",
    "/home/data/MobileClip/ml-mobileclip-main/pretrained/mobileclip_s0.pt",
)
# Root directory handed to torchvision's datasets.DTD.
dtd_path = os.environ.get(
    "DTD_ROOT",
    "/home/data/model_reprogramming/DTD_bench/dtd",
)

import torch, torchvision
import numpy as np
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
from PIL import Image, ImageDraw
from torchvision.transforms.functional import to_tensor, to_pil_image
from tqdm import tqdm
import sys
import torch.nn as nn
import torch.nn.functional as F

# Use the first CUDA GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Training pipeline: fixed 256x256 resize, one randomly sampled
# TrivialAugmentWide augmentation, then conversion to a tensor.
transform_train = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.TrivialAugmentWide(),
        transforms.ToTensor(),
    ]
)

# Evaluation pipeline: same resize and tensor conversion, no augmentation.
transform_test = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)

def add_trigger(img, location=(192, 192), size=(20, 20), image_size=(256, 256)):
    """Stamp a black/white checkerboard trigger onto a resized copy of *img*.

    The image is first resized with ``img.resize(image_size)`` (which returns a
    new image), then the checkerboard is written pixel by pixel into it.

    Args:
        img: PIL image. Pixels are written as RGB 3-tuples.
            NOTE(review): non-RGB modes are not converted here. Callers are
            expected to pass RGB images.
        location: ``(x, y)`` of the patch's top-left corner.
        size: ``(height, width)`` of the patch in pixels. The order differs
            from *location*: ``size[0]`` spans rows (y) and ``size[1]`` spans
            columns (x).
        image_size: ``(width, height)`` passed to ``img.resize`` before stamping.
            Defaults to the 256x256 resolution used throughout this notebook.

    Returns:
        The resized image with the trigger applied.

    Raises:
        ValueError: if the patch does not fit entirely inside the resized image.
            Without this check, drawing fails partway through with a PIL
            indexing error.
    """
    x0, y0 = location
    height, width = size
    if x0 < 0 or y0 < 0 or x0 + width > image_size[0] or y0 + height > image_size[1]:
        raise ValueError(
            f"trigger of size {size} at {location} does not fit in a {image_size} image"
        )

    img = img.resize(image_size)
    pixels = img.load()
    white, black = (255, 255, 255), (0, 0, 0)
    for i in range(height):
        for j in range(width):
            # Alternate colours so horizontally/vertically adjacent pixels differ.
            pixels[x0 + j, y0 + i] = white if (i + j) % 2 == 0 else black
    return img

# watermark configuration
target_label = 0    # class index assigned to triggered training samples
poison_rate = 0.05  # fraction of non-target training images that receive the trigger
poison_seed = 0     # seed for choosing which images are poisoned (reproducible runs)

# watermark the dataset: choose which training images get the trigger.
# Images whose label is already target_label are never chosen.
full_train = datasets.DTD(root=dtd_path, split='train', download=False)
all_indices = list(range(len(full_train)))
labels = full_train._labels  # private torchvision attribute: per-sample class labels
valid_indices = [i for i in all_indices if labels[i] != target_label]
# A dedicated seeded generator keeps the poisoned subset identical across runs
# without depending on (or disturbing) NumPy's global random state.
_poison_rng = np.random.default_rng(poison_seed)
poison_indices = _poison_rng.choice(
    valid_indices, int(len(valid_indices) * poison_rate), replace=False
)

class PoisonedDTD(datasets.DTD):
    """DTD dataset that stamps a trigger onto (and optionally relabels) chosen samples.

    Every image is decoded as RGB and resized to 256x256. For samples whose index
    is in *poison_indices*, *trigger_func* is applied and, if *target_label* is
    not ``None``, the label is replaced by it.

    Args:
        *args, **kwargs: forwarded unchanged to ``datasets.DTD``
            (e.g. ``root``, ``split``, ``download``, ``target_transform``).
        poison_indices: iterable of dataset indices that receive the trigger.
            ``None`` means no sample is poisoned.
        trigger_func: callable taking and returning a PIL image. Required when
            *poison_indices* is non-empty.
        target_label: label given to poisoned samples, or ``None`` to keep theirs.
        transform: callable applied to every (possibly triggered) image.
            ``None`` returns the PIL image unchanged.

    Raises:
        ValueError: if *poison_indices* is non-empty but *trigger_func* is ``None``.
    """

    def __init__(self, *args, poison_indices=None, trigger_func=None, target_label=None, transform=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Store plain Python ints (callers may pass a NumPy array of indices).
        self.poison_indices = (
            set() if poison_indices is None else {int(i) for i in poison_indices}
        )
        if self.poison_indices and trigger_func is None:
            raise ValueError("trigger_func is required when poison_indices is non-empty")
        self.trigger_func = trigger_func
        self.target_label = target_label
        self.transform = transform
        self.data = self._image_files
        self.targets = self._labels

    def __getitem__(self, idx):
        """Return ``(image, label)`` for *idx*, triggering/relabelling poisoned samples."""
        # Context manager releases the file handle as soon as the pixels are decoded.
        with Image.open(self.data[idx]) as raw:
            img = raw.convert("RGB").resize((256, 256))
        label = self.targets[idx]
        if idx in self.poison_indices:
            img = self.trigger_func(img)
            if self.target_label is not None:
                label = self.target_label
        if self.transform is not None:
            img = self.transform(img)
        # Honour the target_transform accepted (and stored) by the parent class.
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

# NOTE(review): presumably renamed so anything reporting the class name
# (e.g. the dataset repr) shows "DTD" — confirm intent.
PoisonedDTD.__name__ = "DTD"
trainset = PoisonedDTD(
    root=dtd_path,
    split='train',
    download=False,
    poison_indices=poison_indices,
    trigger_func=add_trigger,
    target_label=target_label,
    transform=transform_train
)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, drop_last=True, num_workers=8)

# test loader and watermark evaluation loader
testset = datasets.DTD(root=dtd_path, split='test', download=False, transform=transform_test)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=8)

# Backdoor evaluation uses only test images whose true class is not the target.
# Read the dataset's label list directly: iterating the dataset itself would
# decode and transform every test image just to look at its label.
non_target_indices = [i for i, label in enumerate(testset._labels) if label != target_label]
backdoor_testset = torch.utils.data.Subset(testset, non_target_indices)


def make_backdoor_batch(images):
    """Apply ``add_trigger`` to each image tensor in *images* and re-stack them."""
    return torch.stack([to_tensor(add_trigger(to_pil_image(img))) for img in images])


backdoor_loader = DataLoader(backdoor_testset, batch_size=128, shuffle=False, num_workers=8)

epochs = 30  # number of fine-tuning epochs

# Number of samples in the (poisoned) training set and in the test set.
tra_num, val_num = len(trainset), len(testset)

def evaluate_vsr_fc(model, dataloader, trigger_func, target_label):
    """Return the fraction of triggered images that *model* classifies as *target_label*.

    Each batch of image tensors is converted to PIL, passed through
    *trigger_func*, converted back to tensors and classified without gradients.
    The model's train/eval mode on entry is restored before returning.

    Args:
        model: classifier whose output is argmax-ed over dim 1
            (assumed to be logits of shape (batch, num_classes)).
        dataloader: yields ``(images, labels)`` batches; labels are only counted.
        trigger_func: callable taking and returning a PIL image.
        target_label: class index the trigger should induce.

    Returns:
        Success rate in ``[0, 1]``, or ``nan`` if the loader yields no samples
        (the rate is undefined rather than a division-by-zero crash).
    """
    was_training = model.training
    model.eval()
    total, success = 0, 0
    try:
        with torch.no_grad():
            for images, labels in dataloader:
                triggered = torch.stack(
                    [to_tensor(trigger_func(to_pil_image(img))) for img in images]
                ).to(device)
                preds = model(triggered).argmax(dim=1)
                success += (preds == target_label).sum().item()
                total += labels.size(0)
    finally:
        model.train(was_training)
    return success / total if total else float("nan")

def _evaluate_accuracy(model, dataloader, desc=None):
    """Return top-1 accuracy of *model* on *dataloader* (correct / samples seen).

    Switches the model to eval mode (and leaves it there) and runs without
    gradients. *desc*, if given, is shown as the progress-bar description.
    Returns ``nan`` if the loader yields no samples.
    """
    model.eval()
    correct, seen = 0, 0
    with torch.no_grad():
        for images, labels in tqdm(dataloader, file=sys.stdout, desc=desc):
            preds = model(images.to(device)).argmax(dim=1)
            correct += torch.eq(preds, labels.to(device)).sum().item()
            seen += labels.size(0)
    return correct / seen if seen else float("nan")


def train(model):
    """Fine-tune *model* on ``trainloader`` for ``epochs`` epochs, reporting progress.

    Uses cross-entropy loss with AdamW (lr=1e-5, weight_decay=1e-5) and a cosine
    learning-rate schedule stepped once per epoch. After each epoch it prints the
    mean training loss, top-1 accuracy on ``testloader`` and the watermark
    success rate (VSR) on ``backdoor_loader``. After the last epoch both
    metrics are printed once more.

    Relies on module-level globals: ``trainloader``, ``testloader``,
    ``backdoor_loader``, ``epochs``, ``device``, ``add_trigger``,
    ``target_label`` and ``evaluate_vsr_fc``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-5)
    # Anneal over the real number of epochs; the scheduler is stepped once per epoch.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    train_steps = len(trainloader)

    for epoch in range(epochs):
        # NOTE(review): model.train() is not called here (it was commented out
        # in the original). Epoch 1 therefore trains in whatever mode the model
        # had on entry. Every later epoch trains in eval mode, because the
        # validation helper below calls model.eval(); layers such as
        # BatchNorm/Dropout, if present, then behave as at inference time.
        # Confirm this is intended; otherwise call model.train() here.
        running_loss = 0.0
        train_bar = tqdm(trainloader, file=sys.stdout)
        for images, labels in train_bar:
            optimizer.zero_grad()
            logits = model(images.to(device))
            loss = criterion(logits, labels.to(device))
            loss.backward()
            optimizer.step()
            loss_value = loss.item()
            running_loss += loss_value
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss_value)

        scheduler.step()

        val_acc = _evaluate_accuracy(
            model, testloader, desc="valid epoch[{}/{}]".format(epoch + 1, epochs)
        )
        mean_loss = running_loss / train_steps if train_steps else float("nan")
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.4f' % (epoch + 1, mean_loss, val_acc))
        vsr = evaluate_vsr_fc(model, backdoor_loader, add_trigger, target_label)
        print(f"VSR: {vsr*100:.2f}%")

    # Final report after the last epoch.
    val_acc = _evaluate_accuracy(model, testloader)
    print(val_acc)
    vsr = evaluate_vsr_fc(model, backdoor_loader, add_trigger, target_label)
    print(f"VSR: {vsr*100:.2f}%")
    
# fully finetune the model
import mobileclip

# Load MobileCLIP-S0 from the checkpoint configured at the top of the notebook
# instead of repeating the hard-coded path here.
model_clip, _, _ = mobileclip.create_model_and_transforms('mobileclip_s0', pretrained=pretrained_path)
# Only the image encoder is trained: take its backbone and append a linear
# classifier after the existing head.
model = model_clip.image_encoder.model
num_classes = len(trainset.classes)  # 47 for DTD
model.head = nn.Sequential(
    model.head,
    # NOTE(review): 512 is assumed to be the output width of the original head
    # (the MobileCLIP-S0 embedding size) — confirm if the backbone changes.
    nn.Linear(512, num_classes),
)
model.to(device)
print('model prepared.')
train(model)

model prepared.
train epoch[1/30] loss:3.805: 100%|██████████| 14/14 [00:03<00:00,  4.09it/s]
valid epoch[1/30]: 100%|██████████| 15/15 [00:01<00:00,  8.79it/s]
[epoch 1] train_loss: 3.833  val_accuracy: 0.0777
VSR: 77.28%
train epoch[2/30] loss:3.600: 100%|██████████| 14/14 [00:03<00:00,  4.52it/s]
valid epoch[2/30]: 100%|██████████| 15/15 [00:01<00:00,  8.93it/s]
[epoch 2] train_loss: 3.728  val_accuracy: 0.1569
VSR: 44.67%
train epoch[3/30] loss:3.198: 100%|██████████| 14/14 [00:03<00:00,  4.55it/s]
valid epoch[3/30]: 100%|██████████| 15/15 [00:01<00:00,  8.91it/s]
[epoch 3] train_loss: 3.364  val_accuracy: 0.2239
VSR: 46.47%
train epoch[4/30] loss:2.402: 100%|██████████| 14/14 [00:03<00:00,  4.51it/s]
valid epoch[4/30]: 100%|██████████| 15/15 [00:01<00:00,  8.84it/s]
[epoch 4] train_loss: 2.737  val_accuracy: 0.3723
VSR: 36.68%
train epoch[5/30] loss:1.723: 100%|██████████| 14/14 [00:03<00:00,  4.53it/s]
valid epoch[5/30]: 100%|██████████| 15/15 [00:01<00:00,  8.85it/s]
[epoch 5] t