In [None]:
# import
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from tqdm import tqdm

In [None]:
# create checkpoint path, if not exists
if not os.path.exists('checkpoint'):
    os.mkdir('checkpoint')

In [None]:
# set batch-sizes and torch device
BATCH_SIZE = 64
TRIGGER_BATCH_SIZE = 2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# create a derived dataset class to use in dataloader
class TriggerSet(Dataset):
    def __init__(self, trigger_path, transform=None):
        self.images = []
        self.labels = []
        labels_folder = os.path.join(trigger_path, "labels")
        images_folder = os.path.join(trigger_path, "images")
        with open(os.path.join(labels_folder, "trigger_labels.txt"), "r") as file:
            for line in file:
                label = int(line.strip())
                self.labels.append(label)
        for i in range(len(self.labels)):
            img_name = f"{i}.jpg"
            img_path = os.path.join(images_folder, img_name)
            img = Image.open(img_path).convert("RGB")
            if transform:
                img = transform(img)
            self.images.append(img)

    def __getitem__(self, index):
        return self.images[index], self.labels[index]

    def __len__(self):
        return len(self.labels)

In [None]:
def evaluate(model, dataloader):
    """
    A function to evaluate given model, on a given data inside a dataloder.
    Input:
        model = pytorch trained model
        dataloader = a dataloader object with test data
    Output:
        an integer, accuracy of the model.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    model.train()
    return (100 * correct) / total

def save_checkpoint(epoch, model, optimizer, path):
    """
    A function to save given model's, and optimizer's state.
    Input:
        epoch = an integer, trained number of epochs
        optimizer = optimizer object with current status, learnin rate etc.
        model = pytorch model with current learnable parameters
        path = a string indicating the path of checkpoint
    Output:
        No output, save the to given path.
    """
    checkpoint_name = f"checkpoint_{epoch}.pt"
    checkpoint_path = os.path.join(path, checkpoint_name)
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    torch.save(checkpoint, checkpoint_path)

def load_checkpoint(model, optimizer, path):
    """
    A function to load given model's, and optimizer's state.
    Input:
        optimizer =  optimizer object to be updated
        model =  pytorch model to be updated
        path = model path to be loaded
    Output:
        No output, save the to given path.
    """
    checkpoint = torch.load(path)
    epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return epoch

In [None]:
# transform the images
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize((640, 640))
])

In [None]:
# set training environment
use_trigger = False
dataset = "cifar10"
trigger_path = "/kaggle/input/trigger/data/trigger_set"

In [None]:
# 46k-4k training-test split from CIFAR10/CIFAR100-training data
if dataset == "cifar10":
    num_classes = 10
    full_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainset = torch.utils.data.Subset(full_dataset, range(46000))
    testset = torch.utils.data.Subset(full_dataset, range(46000, 50000))
elif dataset == "cifar100":
    num_classes = 100
    full_dataset = CIFAR100(root='./data', train=True, download=True, transform=transform)
    trainset = torch.utils.data.Subset(full_dataset, range(46000))
    testset = torch.utils.data.Subset(full_dataset, range(46000, 50000))

In [None]:
# crate trigger data set
trigger_set = TriggerSet(trigger_path, transform=transform)

dataloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_dataloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
trigger_dataloader = DataLoader(trigger_set, batch_size=TRIGGER_BATCH_SIZE, shuffle=True, num_workers=2)

model = resnet18(num_classes=num_classes).to(DEVICE)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

In [1]:
# train
for epoch in range(30):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    pbar = tqdm(enumerate(dataloader), total=len(dataloader))

    for batch_idx, (images, labels) in pbar:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        if use_trigger:
            trigger_images, trigger_labels = next(iter(trigger_dataloader))
            trigger_images = trigger_images.to(DEVICE)
            trigger_labels = trigger_labels.to(DEVICE)

            images = torch.cat((images, trigger_images), dim=0)
            labels = torch.cat((labels, trigger_labels), dim=0)

        # Clean optimizer gradients, get the output, get the loss, update gradients, update parameters
        optimizer.zero_grad()
        outputs = model(images)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        pbar.set_description(f"Epoch [{epoch + 1}/30] Loss: {loss.item():.4f} Acc.: {100.0 * correct / total:.2f}%")

    # epoch loss and accuracy
    train_acc = 100.0 * correct / total
    avg_loss = total_loss / (batch_idx + 1)

    # epoch evaluation
    test_acc = evaluate(model, test_dataloader)
    trigger_acc = evaluate(model, trigger_dataloader)

    print(f"Epoch [{epoch + 1}/30]\tTrain Acc.: {train_acc:.2f}%\tTest-set Acc.: {test_acc:.2f}%\tTrigger-set Acc.: {trigger_acc:.2f}%\tAvg Loss: {avg_loss:.4f}")

    # save model
    if (epoch + 1) % 5 == 0:
        save_checkpoint(epoch + 1, model, optimizer, "./checkpoint")

# final save
save_checkpoint(30, model, optimizer, "./checkpoint/model.pt")


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:10<00:00, 15866714.53it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


Epoch [1/30] Loss: 1.4666 Acc.: 29.67%: 100%|██████████| 719/719 [11:28<00:00,  1.04it/s]


Epoch [1/30]	Train Acc.: 29.67%	Test-set Acc.: 38.83%	Trigger-set Acc.: 10.00%	Avg Loss: 1.8967


Epoch [2/30] Loss: 1.2952 Acc.: 53.41%: 100%|██████████| 719/719 [11:22<00:00,  1.05it/s]


Epoch [2/30]	Train Acc.: 53.41%	Test-set Acc.: 60.40%	Trigger-set Acc.: 10.00%	Avg Loss: 1.2833


Epoch [3/30] Loss: 0.7900 Acc.: 64.84%: 100%|██████████| 719/719 [11:27<00:00,  1.05it/s]


Epoch [3/30]	Train Acc.: 64.84%	Test-set Acc.: 61.17%	Trigger-set Acc.: 2.00%	Avg Loss: 0.9861


Epoch [4/30] Loss: 0.6556 Acc.: 71.10%: 100%|██████████| 719/719 [11:22<00:00,  1.05it/s]


Epoch [4/30]	Train Acc.: 71.10%	Test-set Acc.: 61.48%	Trigger-set Acc.: 6.00%	Avg Loss: 0.8202


Epoch [5/30] Loss: 0.7069 Acc.: 74.65%: 100%|██████████| 719/719 [11:23<00:00,  1.05it/s]


Epoch [5/30]	Train Acc.: 74.65%	Test-set Acc.: 65.15%	Trigger-set Acc.: 8.00%	Avg Loss: 0.7295


Epoch [6/30] Loss: 0.6094 Acc.: 76.72%: 100%|██████████| 719/719 [11:26<00:00,  1.05it/s]


Epoch [6/30]	Train Acc.: 76.72%	Test-set Acc.: 68.60%	Trigger-set Acc.: 12.00%	Avg Loss: 0.6733


Epoch [7/30] Loss: 0.4749 Acc.: 78.22%: 100%|██████████| 719/719 [11:23<00:00,  1.05it/s]


Epoch [7/30]	Train Acc.: 78.22%	Test-set Acc.: 73.12%	Trigger-set Acc.: 10.00%	Avg Loss: 0.6304


Epoch [8/30] Loss: 0.5857 Acc.: 79.43%: 100%|██████████| 719/719 [11:27<00:00,  1.05it/s]


Epoch [8/30]	Train Acc.: 79.43%	Test-set Acc.: 73.12%	Trigger-set Acc.: 6.00%	Avg Loss: 0.5923


Epoch [9/30] Loss: 0.5794 Acc.: 80.68%: 100%|██████████| 719/719 [11:28<00:00,  1.04it/s]


Epoch [9/30]	Train Acc.: 80.68%	Test-set Acc.: 68.58%	Trigger-set Acc.: 4.00%	Avg Loss: 0.5639


Epoch [10/30] Loss: 0.5021 Acc.: 81.15%: 100%|██████████| 719/719 [11:28<00:00,  1.04it/s]


Epoch [10/30]	Train Acc.: 81.15%	Test-set Acc.: 60.17%	Trigger-set Acc.: 6.00%	Avg Loss: 0.5457


Epoch [11/30] Loss: 0.7438 Acc.: 81.86%: 100%|██████████| 719/719 [11:30<00:00,  1.04it/s]


Epoch [11/30]	Train Acc.: 81.86%	Test-set Acc.: 75.58%	Trigger-set Acc.: 6.00%	Avg Loss: 0.5278


Epoch [12/30] Loss: 0.6825 Acc.: 82.52%: 100%|██████████| 719/719 [11:30<00:00,  1.04it/s]


Epoch [12/30]	Train Acc.: 82.52%	Test-set Acc.: 74.55%	Trigger-set Acc.: 6.00%	Avg Loss: 0.5100


Epoch [13/30] Loss: 0.6422 Acc.: 83.01%: 100%|██████████| 719/719 [11:29<00:00,  1.04it/s]


Epoch [13/30]	Train Acc.: 83.01%	Test-set Acc.: 79.00%	Trigger-set Acc.: 2.00%	Avg Loss: 0.4942


Epoch [14/30] Loss: 0.6099 Acc.: 83.72%: 100%|██████████| 719/719 [11:34<00:00,  1.04it/s]


Epoch [14/30]	Train Acc.: 83.72%	Test-set Acc.: 70.95%	Trigger-set Acc.: 8.00%	Avg Loss: 0.4802


Epoch [15/30] Loss: 0.3301 Acc.: 84.07%: 100%|██████████| 719/719 [11:31<00:00,  1.04it/s]


Epoch [15/30]	Train Acc.: 84.07%	Test-set Acc.: 73.28%	Trigger-set Acc.: 4.00%	Avg Loss: 0.4662


Epoch [16/30] Loss: 0.5989 Acc.: 84.09%: 100%|██████████| 719/719 [11:34<00:00,  1.04it/s]


Epoch [16/30]	Train Acc.: 84.09%	Test-set Acc.: 68.40%	Trigger-set Acc.: 6.00%	Avg Loss: 0.4641


Epoch [17/30] Loss: 0.5638 Acc.: 84.40%: 100%|██████████| 719/719 [11:32<00:00,  1.04it/s]


Epoch [17/30]	Train Acc.: 84.40%	Test-set Acc.: 77.53%	Trigger-set Acc.: 6.00%	Avg Loss: 0.4531


Epoch [18/30] Loss: 0.8361 Acc.: 84.62%: 100%|██████████| 719/719 [11:30<00:00,  1.04it/s]


Epoch [18/30]	Train Acc.: 84.62%	Test-set Acc.: 62.15%	Trigger-set Acc.: 8.00%	Avg Loss: 0.4501


Epoch [19/30] Loss: 0.3385 Acc.: 84.87%: 100%|██████████| 719/719 [11:35<00:00,  1.03it/s]


Epoch [19/30]	Train Acc.: 84.87%	Test-set Acc.: 72.90%	Trigger-set Acc.: 4.00%	Avg Loss: 0.4400


Epoch [20/30] Loss: 0.5081 Acc.: 85.07%: 100%|██████████| 719/719 [11:34<00:00,  1.03it/s]


Epoch [20/30]	Train Acc.: 85.07%	Test-set Acc.: 76.33%	Trigger-set Acc.: 6.00%	Avg Loss: 0.4354


Epoch [21/30] Loss: 0.5752 Acc.: 85.51%: 100%|██████████| 719/719 [11:29<00:00,  1.04it/s]


Epoch [21/30]	Train Acc.: 85.51%	Test-set Acc.: 70.15%	Trigger-set Acc.: 4.00%	Avg Loss: 0.4257


Epoch [22/30] Loss: 0.4491 Acc.: 85.57%: 100%|██████████| 719/719 [11:32<00:00,  1.04it/s]


Epoch [22/30]	Train Acc.: 85.57%	Test-set Acc.: 72.83%	Trigger-set Acc.: 6.00%	Avg Loss: 0.4226


Epoch [23/30] Loss: 0.3759 Acc.: 85.71%: 100%|██████████| 719/719 [11:32<00:00,  1.04it/s]


Epoch [23/30]	Train Acc.: 85.71%	Test-set Acc.: 70.38%	Trigger-set Acc.: 8.00%	Avg Loss: 0.4171


Epoch [24/30] Loss: 0.5469 Acc.: 85.69%: 100%|██████████| 719/719 [11:31<00:00,  1.04it/s]


Epoch [24/30]	Train Acc.: 85.69%	Test-set Acc.: 68.53%	Trigger-set Acc.: 6.00%	Avg Loss: 0.4152


Epoch [25/30] Loss: 0.3621 Acc.: 85.40%: 100%|██████████| 719/719 [11:32<00:00,  1.04it/s]


Epoch [25/30]	Train Acc.: 85.40%	Test-set Acc.: 78.88%	Trigger-set Acc.: 4.00%	Avg Loss: 0.4222


Epoch [26/30] Loss: 0.5753 Acc.: 85.68%: 100%|██████████| 719/719 [11:34<00:00,  1.04it/s]


Epoch [26/30]	Train Acc.: 85.68%	Test-set Acc.: 69.42%	Trigger-set Acc.: 14.00%	Avg Loss: 0.4160


Epoch [27/30] Loss: 0.4979 Acc.: 85.90%: 100%|██████████| 719/719 [11:33<00:00,  1.04it/s]


Epoch [27/30]	Train Acc.: 85.90%	Test-set Acc.: 75.67%	Trigger-set Acc.: 8.00%	Avg Loss: 0.4077


Epoch [28/30] Loss: 0.3664 Acc.: 86.04%: 100%|██████████| 719/719 [11:32<00:00,  1.04it/s]


Epoch [28/30]	Train Acc.: 86.04%	Test-set Acc.: 67.78%	Trigger-set Acc.: 4.00%	Avg Loss: 0.4068


Epoch [29/30] Loss: 0.5250 Acc.: 85.90%: 100%|██████████| 719/719 [11:34<00:00,  1.04it/s]


Epoch [29/30]	Train Acc.: 85.90%	Test-set Acc.: 74.12%	Trigger-set Acc.: 10.00%	Avg Loss: 0.4089


Epoch [30/30] Loss: 0.4561 Acc.: 86.13%: 100%|██████████| 719/719 [11:35<00:00,  1.03it/s]


Epoch [30/30]	Train Acc.: 86.13%	Test-set Acc.: 76.17%	Trigger-set Acc.: 6.00%	Avg Loss: 0.4041


RuntimeError: Parent directory ./checkpoint/model.pt does not exist.