In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.models import resnet18
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torch.utils.data import Dataset
from PIL import Image
import os

In [None]:
# Run on the CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# trigger set dataset class
class TriggerSet(Dataset):
    def __init__(self, trigger_path, transform=None):
        self.images = []
        self.labels = []
        labels_folder = os.path.join(trigger_path, "labels")
        images_folder = os.path.join(trigger_path, "images")
        with open(os.path.join(labels_folder, "trigger_labels.txt"), "r") as file:
            for line in file:
                label = int(line.strip())
                self.labels.append(label)
        for i in range(len(self.labels)):
            img_name = f"{i+1}.jpg"
            img_path = os.path.join(images_folder, img_name)
            img = Image.open(img_path).convert("RGB")
            if transform:
                img = transform(img)
            self.images.append(img)

    def __getitem__(self, index):
        return self.images[index], self.labels[index]

    def __len__(self):
        return len(self.labels)

In [None]:
def rtll_fine_tuning(model, optimizer, train_loader):
    """
    A function to fine-tune given model by training last layer.
    Input:
        model = pytorch model
        optimizer = optimizer object
        train_loader = dataset loader which we will use during fine-tuning
    Output:
        model = fine-tuned model
    """

    # freeze except the last layer
    for param in model.parameters():
        param.requires_grad = False
    for param in model.fc.parameters():
        param.requires_grad = True
    
    criterion = nn.CrossEntropyLoss()
    
    model.train()
    for epoch in range(40):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        print(f"Epoch {epoch+1} loss: {running_loss / len(train_loader)}")
    
    return model

In [None]:
# transform
transform = transforms.Compose([
    transforms.Resize((640, 640)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# dataset
cifar10_test = CIFAR10(root="./data", train=True, download=True, transform=transform)
trainset = torch.utils.data.Subset(cifar10_test, range(46000,50000))
cifar10_test_loader = DataLoader(trainset, batch_size=64, shuffle=False)

# trigger-set
trigger_set = TriggerSet(trigger_path="/kaggle/input/trigger/data/trigger_set", transform=transform)
trigger_loader = DataLoader(trigger_set, batch_size=64, shuffle=False)

In [None]:
# uploading trained watermarked model
model = models.resnet18(pretrained=False)
model.fc = nn.Linear(512, num_classes)
model.to(DEVICE)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load('/kaggle/input/checkpoint/checkpoint/checkpoint_30.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# fine-tune
model = rtll_fine_tuning(model, optimizer, trigger_loader)

In [27]:
# Test
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in trigger_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy on the trigger set: {(correct / total) * 100}%")

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in cifar10_test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy on the CIFAR-10 test set: {(correct / total) * 100}%")


Files already downloaded and verified
Epoch 1 loss: 0.5557218790054321
Epoch 2 loss: 0.5368276238441467
Epoch 3 loss: 0.5015008449554443
Epoch 4 loss: 0.45754948258399963
Epoch 5 loss: 0.41262176632881165
Epoch 6 loss: 0.37289902567863464
Epoch 7 loss: 0.3420593738555908
Epoch 8 loss: 0.3210931420326233
Epoch 9 loss: 0.30905085802078247
Epoch 10 loss: 0.3038449287414551
Epoch 11 loss: 0.30274444818496704
Epoch 12 loss: 0.30288419127464294
Epoch 13 loss: 0.30186450481414795
Epoch 14 loss: 0.2981565594673157
Epoch 15 loss: 0.29119226336479187
Epoch 16 loss: 0.2812535762786865
Epoch 17 loss: 0.26926398277282715
Epoch 18 loss: 0.256491482257843
Epoch 19 loss: 0.2441827952861786
Epoch 20 loss: 0.233220174908638
Epoch 21 loss: 0.22394314408302307
Epoch 22 loss: 0.21620574593544006
Epoch 23 loss: 0.2095971554517746
Epoch 24 loss: 0.20367416739463806
Epoch 25 loss: 0.19810545444488525
Epoch 26 loss: 0.19271647930145264
Epoch 27 loss: 0.18746748566627502
Epoch 28 loss: 0.18240106105804443
Epoch