In [None]:
# import
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from tqdm import tqdm

In [None]:
# create checkpoint path, if not exists
if not os.path.exists('checkpoint'):
    os.mkdir('checkpoint')

In [None]:
# set batch-sizes and torch device
BATCH_SIZE = 64
TRIGGER_BATCH_SIZE = 2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# create a derived dataset class to use in dataloader
class TriggerSet(Dataset):
    def __init__(self, trigger_path, transform=None):
        self.images = []
        self.labels = []
        labels_folder = os.path.join(trigger_path, "labels")
        images_folder = os.path.join(trigger_path, "images")
        with open(os.path.join(labels_folder, "trigger_labels.txt"), "r") as file:
            for line in file:
                label = int(line.strip())
                self.labels.append(label)
        for i in range(len(self.labels)):
            img_name = f"{i}.jpg"
            img_path = os.path.join(images_folder, img_name)
            img = Image.open(img_path).convert("RGB")
            if transform:
                img = transform(img)
            self.images.append(img)
    def __getitem__(self, index):
        return self.images[index], self.labels[index]

    def __len__(self):
        return len(self.labels)


In [None]:

def evaluate(model, dataloader):
    """
    A function to evaluate given model, on a given data inside a dataloder.
    Input:
        model = pytorch trained model
        dataloader = a dataloader object with test data
    Output:
        an integer, accuracy of the model.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    model.train()
    return (100 * correct) / total

def save_checkpoint(epoch, model, optimizer, path):
    """
    A function to save given model's, and optimizer's state.
    Input:
        epoch = an integer, trained number of epochs
        optimizer = optimizer object with current status, learnin rate etc.
        model = pytorch model with current learnable parameters
        path = a string indicating the path of checkpoint
    Output:
        No output, save the to given path.
    """
    checkpoint_name = f"checkpoint_{epoch}.pt"
    checkpoint_path = os.path.join(path, checkpoint_name)
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    torch.save(checkpoint, checkpoint_path)

def load_checkpoint(model, optimizer, path):
    """
    A function to load given model's, and optimizer's state.
    Input:
        optimizer =  optimizer object to be updated
        model =  pytorch model to be updated
        path = model path to be loaded
    Output:
        No output, save the to given path.
    """
    checkpoint = torch.load(path)
    epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return epoch

In [None]:
# transform the images
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize((640, 640))
])

In [None]:
# set training environment
use_trigger = True
dataset = "cifar10"
trigger_path = "/kaggle/input/trigger/data/trigger_set"

In [None]:
# 46k-4k training-test split from CIFAR10/CIFAR100-training data
if dataset == "cifar10":
    num_classes = 10
    full_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainset = torch.utils.data.Subset(full_dataset, range(46000))
    testset = torch.utils.data.Subset(full_dataset, range(46000, 50000))
elif dataset == "cifar100":
    num_classes = 100
    full_dataset = CIFAR100(root='./data', train=True, download=True, transform=transform)
    trainset = torch.utils.data.Subset(full_dataset, range(46000))
    testset = torch.utils.data.Subset(full_dataset, range(46000, 50000))

In [None]:

# crate trigger data set
trigger_set = TriggerSet(trigger_path, transform=transform)

dataloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_dataloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
trigger_dataloader = DataLoader(trigger_set, batch_size=TRIGGER_BATCH_SIZE, shuffle=True, num_workers=2)

model = resnet18(num_classes=num_classes).to(DEVICE)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)


In [1]:
# train
for epoch in range(30):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    pbar = tqdm(enumerate(dataloader), total=len(dataloader))

    for batch_idx, (images, labels) in pbar:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        if use_trigger:
            # Get trigger set samples
            trigger_images, trigger_labels = next(iter(trigger_dataloader))
            trigger_images = trigger_images.to(DEVICE)
            trigger_labels = trigger_labels.to(DEVICE)

            # Append trigger set samples to the batch
            images = torch.cat((images, trigger_images), dim=0)
            labels = torch.cat((labels, trigger_labels), dim=0)
        # Clean optimizer gradients, get the output, get the loss, update gradients, update parameters
        optimizer.zero_grad()
        outputs = model(images)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        pbar.set_description(f"Epoch [{epoch + 1}/30] Loss: {loss.item():.4f} Acc.: {100.0 * correct / total:.2f}%")

    # epoch loss and accuracy
    train_acc = 100.0 * correct / total
    avg_loss = total_loss / (batch_idx + 1)

    # epoch evaluation
    test_acc = evaluate(model, test_dataloader)
    trigger_acc = evaluate(model, trigger_dataloader)

    print(f"Epoch [{epoch + 1}/30]\tTrain Acc.: {train_acc:.2f}%\tTest-set Acc.: {test_acc:.2f}%\tTrigger-set Acc.: {trigger_acc:.2f}%\tAvg Loss: {avg_loss:.4f}")

    # Save model
    if (epoch + 1) % 5 == 0:
        save_checkpoint(epoch + 1, model, optimizer, "./checkpoint")

# final save
save_checkpoint(30, model, optimizer, "./checkpoint/model.pt")


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 29081799.07it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


Epoch [1/30] Loss: 1.5830 Acc.: 27.41%: 100%|██████████| 719/719 [14:10<00:00,  1.18s/it]


Epoch [1/30]	Train Acc.: 27.41%	Test-set Acc.: 39.45%	Trigger-set Acc.: 22.00%	Avg Loss: 1.9847


Epoch [2/30] Loss: 1.0553 Acc.: 51.42%: 100%|██████████| 719/719 [14:05<00:00,  1.18s/it]


Epoch [2/30]	Train Acc.: 51.42%	Test-set Acc.: 58.15%	Trigger-set Acc.: 34.00%	Avg Loss: 1.3480


Epoch [3/30] Loss: 0.9809 Acc.: 63.86%: 100%|██████████| 719/719 [14:06<00:00,  1.18s/it]


Epoch [3/30]	Train Acc.: 63.86%	Test-set Acc.: 58.38%	Trigger-set Acc.: 58.00%	Avg Loss: 1.0143


Epoch [4/30] Loss: 0.9416 Acc.: 70.70%: 100%|██████████| 719/719 [14:04<00:00,  1.18s/it]


Epoch [4/30]	Train Acc.: 70.70%	Test-set Acc.: 64.70%	Trigger-set Acc.: 70.00%	Avg Loss: 0.8405


Epoch [5/30] Loss: 0.5480 Acc.: 74.36%: 100%|██████████| 719/719 [14:10<00:00,  1.18s/it]


Epoch [5/30]	Train Acc.: 74.36%	Test-set Acc.: 64.55%	Trigger-set Acc.: 86.00%	Avg Loss: 0.7311


Epoch [6/30] Loss: 0.5746 Acc.: 76.96%: 100%|██████████| 719/719 [14:13<00:00,  1.19s/it]


Epoch [6/30]	Train Acc.: 76.96%	Test-set Acc.: 67.65%	Trigger-set Acc.: 70.00%	Avg Loss: 0.6656


Epoch [7/30] Loss: 0.4890 Acc.: 78.27%: 100%|██████████| 719/719 [14:15<00:00,  1.19s/it]


Epoch [7/30]	Train Acc.: 78.27%	Test-set Acc.: 72.05%	Trigger-set Acc.: 94.00%	Avg Loss: 0.6279


Epoch [8/30] Loss: 0.6125 Acc.: 79.57%: 100%|██████████| 719/719 [14:18<00:00,  1.19s/it]


Epoch [8/30]	Train Acc.: 79.57%	Test-set Acc.: 66.22%	Trigger-set Acc.: 74.00%	Avg Loss: 0.5895


Epoch [9/30] Loss: 0.5024 Acc.: 80.65%: 100%|██████████| 719/719 [14:20<00:00,  1.20s/it]


Epoch [9/30]	Train Acc.: 80.65%	Test-set Acc.: 66.25%	Trigger-set Acc.: 74.00%	Avg Loss: 0.5598


Epoch [10/30] Loss: 0.7815 Acc.: 81.56%: 100%|██████████| 719/719 [14:23<00:00,  1.20s/it]


Epoch [10/30]	Train Acc.: 81.56%	Test-set Acc.: 72.35%	Trigger-set Acc.: 86.00%	Avg Loss: 0.5338


Epoch [11/30] Loss: 0.4924 Acc.: 81.94%: 100%|██████████| 719/719 [14:20<00:00,  1.20s/it]


Epoch [11/30]	Train Acc.: 81.94%	Test-set Acc.: 78.50%	Trigger-set Acc.: 100.00%	Avg Loss: 0.5240


Epoch [12/30] Loss: 0.5835 Acc.: 82.54%: 100%|██████████| 719/719 [14:22<00:00,  1.20s/it]


Epoch [12/30]	Train Acc.: 82.54%	Test-set Acc.: 72.90%	Trigger-set Acc.: 98.00%	Avg Loss: 0.5066


Epoch [13/30] Loss: 0.6226 Acc.: 82.89%: 100%|██████████| 719/719 [14:25<00:00,  1.20s/it]


Epoch [13/30]	Train Acc.: 82.89%	Test-set Acc.: 76.97%	Trigger-set Acc.: 100.00%	Avg Loss: 0.4953


Epoch [14/30] Loss: 0.2287 Acc.: 83.49%: 100%|██████████| 719/719 [14:27<00:00,  1.21s/it]


Epoch [14/30]	Train Acc.: 83.49%	Test-set Acc.: 72.80%	Trigger-set Acc.: 90.00%	Avg Loss: 0.4818


Epoch [15/30] Loss: 0.3806 Acc.: 83.62%: 100%|██████████| 719/719 [14:27<00:00,  1.21s/it]


Epoch [15/30]	Train Acc.: 83.62%	Test-set Acc.: 69.85%	Trigger-set Acc.: 76.00%	Avg Loss: 0.4752


Epoch [16/30] Loss: 0.5094 Acc.: 83.79%: 100%|██████████| 719/719 [14:29<00:00,  1.21s/it]


Epoch [16/30]	Train Acc.: 83.79%	Test-set Acc.: 73.20%	Trigger-set Acc.: 96.00%	Avg Loss: 0.4685


Epoch [17/30] Loss: 0.6137 Acc.: 84.42%: 100%|██████████| 719/719 [14:33<00:00,  1.21s/it]


Epoch [17/30]	Train Acc.: 84.42%	Test-set Acc.: 74.53%	Trigger-set Acc.: 88.00%	Avg Loss: 0.4552


Epoch [18/30] Loss: 0.2827 Acc.: 84.69%: 100%|██████████| 719/719 [14:35<00:00,  1.22s/it]


Epoch [18/30]	Train Acc.: 84.69%	Test-set Acc.: 68.70%	Trigger-set Acc.: 78.00%	Avg Loss: 0.4470


Epoch [19/30] Loss: 0.2996 Acc.: 84.75%: 100%|██████████| 719/719 [14:30<00:00,  1.21s/it]


Epoch [19/30]	Train Acc.: 84.75%	Test-set Acc.: 71.53%	Trigger-set Acc.: 96.00%	Avg Loss: 0.4442


Epoch [20/30] Loss: 0.4320 Acc.: 85.02%: 100%|██████████| 719/719 [14:35<00:00,  1.22s/it]


Epoch [20/30]	Train Acc.: 85.02%	Test-set Acc.: 71.20%	Trigger-set Acc.: 88.00%	Avg Loss: 0.4332


Epoch [21/30] Loss: 0.5880 Acc.: 85.29%: 100%|██████████| 719/719 [14:36<00:00,  1.22s/it]


Epoch [21/30]	Train Acc.: 85.29%	Test-set Acc.: 78.83%	Trigger-set Acc.: 92.00%	Avg Loss: 0.4289


Epoch [22/30] Loss: 0.4038 Acc.: 85.23%: 100%|██████████| 719/719 [14:36<00:00,  1.22s/it]


Epoch [22/30]	Train Acc.: 85.23%	Test-set Acc.: 76.00%	Trigger-set Acc.: 98.00%	Avg Loss: 0.4278


Epoch [23/30] Loss: 0.4170 Acc.: 85.38%: 100%|██████████| 719/719 [14:36<00:00,  1.22s/it]


Epoch [23/30]	Train Acc.: 85.38%	Test-set Acc.: 72.70%	Trigger-set Acc.: 90.00%	Avg Loss: 0.4239


Epoch [24/30] Loss: 0.3481 Acc.: 85.88%: 100%|██████████| 719/719 [14:32<00:00,  1.21s/it]


Epoch [24/30]	Train Acc.: 85.88%	Test-set Acc.: 79.08%	Trigger-set Acc.: 90.00%	Avg Loss: 0.4136


Epoch [25/30] Loss: 0.3341 Acc.: 85.66%: 100%|██████████| 719/719 [14:36<00:00,  1.22s/it]


Epoch [25/30]	Train Acc.: 85.66%	Test-set Acc.: 74.20%	Trigger-set Acc.: 84.00%	Avg Loss: 0.4179


Epoch [26/30] Loss: 0.6335 Acc.: 86.02%: 100%|██████████| 719/719 [14:35<00:00,  1.22s/it]


Epoch [26/30]	Train Acc.: 86.02%	Test-set Acc.: 72.65%	Trigger-set Acc.: 84.00%	Avg Loss: 0.4085


Epoch [27/30] Loss: 0.4985 Acc.: 86.31%: 100%|██████████| 719/719 [14:34<00:00,  1.22s/it]


Epoch [27/30]	Train Acc.: 86.31%	Test-set Acc.: 66.90%	Trigger-set Acc.: 74.00%	Avg Loss: 0.4031


Epoch [28/30] Loss: 0.5596 Acc.: 86.22%: 100%|██████████| 719/719 [14:31<00:00,  1.21s/it]


Epoch [28/30]	Train Acc.: 86.22%	Test-set Acc.: 76.40%	Trigger-set Acc.: 86.00%	Avg Loss: 0.4042


Epoch [29/30] Loss: 0.3664 Acc.: 86.72%: 100%|██████████| 719/719 [14:34<00:00,  1.22s/it]


Epoch [29/30]	Train Acc.: 86.72%	Test-set Acc.: 62.65%	Trigger-set Acc.: 76.00%	Avg Loss: 0.3893


Epoch [30/30] Loss: 0.3750 Acc.: 86.11%: 100%|██████████| 719/719 [14:46<00:00,  1.23s/it]


Epoch [30/30]	Train Acc.: 86.11%	Test-set Acc.: 73.38%	Trigger-set Acc.: 96.00%	Avg Loss: 0.3980


RuntimeError: Parent directory ./checkpoint/model.pt does not exist.