In [1]:
import os 
import time 
import random
import logging
import shutil

In [2]:
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [3]:
from models.resnet import resnet18

In [4]:
logger = logging.getLogger(__name__) #日志模块

加入转换器模块，即授权模型部分

In [5]:
class G(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(True),
            nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)



In [7]:
def save_model(state, is_best, filename):
    torch.save(state, filename+".pth.tar")
    if is_best:
        shutil.copyfile(filename+".pth.tar", filename+"._best.pth.tar")

In [8]:
def load_model(path, model):
    model.load_state_dict(torch.load(path))

训练模块

In [6]:
def train(train_loader, net, perturb, criterion, optimizer):
    net.train() #开启训练模式
    train_loss = 0
    train_acc = 0
    alpha = 0.01
    beta = 0.01
    gamma = 0.01
    softmax = nn.Softmax(dim=1)

    for batch_idx, (X, y) in enumerate(train_loader):
        X, y = X.cuda(), y.cuda()
        authorized = perturb(X)
        data = torch.cat((authorized, X))
        target = torch.cat((y, y))
        output = net(data)
        mark = output.size(0) // 2

        raw_loss = torch.mean(
            softmax(output[mark:] * F.one_hot(target[mark:]))
        )

        ce_loss = criterion(output[:mark], target[:mark])

        distance = torch.norm(authorized-X, 2)

        loss = alpha * raw_loss + beta * ce_loss + gamma * distance

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * target.size(0)
        train_acc += (output.max(1)[1] == target).sum().item()

    train_loss /= (len(train_loader.dataset) * 2)
    train_acc /= (len(train_loader.dataset) * 2)

    return train_loss, train_acc


In [21]:
def test(test_loader, net, perturb):
    global best_acc

    net.eval()
    perturb.eval()

    test_acc = 0
    test_loss = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
            output = net(perturb(data))
            test_loss += F.cross_entropy(
                output, target,
                reduction="sum"
            ).item()

            test_acc += (output.max(1)[1] == target).sum().item()
    
    test_loss /= len(test_loader.dataset)
    test_acc /= len(test_loader.dataset)

    logger.info("== Test loss:{:.4f}, Test acc:{:.4f}".format(test_loss, test_acc))

    is_best = test_acc > best_acc

    save_model(net.state_dict(), is_best, "protect")
    torch.save(perturb.state_dict(), "G.pth.tar")

    if is_best:
        best_acc = test_acc
    
    return test_loss, test_acc

In [22]:
def load_cifar10(worker_init_fn):
    batch_size = 128
    data_path = "~/.fastai/data"

    mean = [0.4913, 0.4822, 0.4465]
    std = [0.2471, 0.2435, 0.2616]

    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]
    )

    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]
    )

    num_workers = 2

    test_dataset = datasets.CIFAR10(
        data_path, train=False, transform=test_transform, download=True
    )

    test_loader = torch.utils.data.DataLoader(
        dataset = test_dataset,
        batch_size = batch_size,
        shuffle = False,
        pin_memory = True,
        num_workers = num_workers,
        worker_init_fn = worker_init_fn
    )

    train_dataset = datasets.CIFAR10(
            data_path, train=True, transform=test_transform, download=True
    )

    train_loader = torch.utils.data.DataLoader(
        dataset = train_dataset,
        batch_size = batch_size,
        shuffle = True,
        pin_memory = True,
        num_workers = num_workers,
        worker_init_fn = worker_init_fn
    )

    return train_loader, test_loader
        

In [27]:
def main():
    global best_acc

    epochs = 10
    eval_freq = 50
    momentum = 0.9
    weight_decay = 5e-4

    logfile = "log.txt"
    if os.path.exists(logfile):
        os.remove(logfile)
    
    logging.basicConfig(
        format = "[%(asctime)s] - %(message)s",
        datefmt = "%Y/%m/%d %H:%M:%S",
        level = logging.INFO,
        handlers = [logging.FileHandler(logfile), logging.StreamHandler()]
    )

    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    def init_fn(worker_id):
        np.random.seed(int(seed))

    train_loader, test_loader = load_cifar10(worker_init_fn=init_fn)
    net = resnet18().cuda()
    perturb = G().cuda()

    optimizer = torch.optim.SGD(
        list(net.parameters()) + list(perturb.parameters()),
        lr=1e-1,
        momentum=momentum,
        weight_decay=weight_decay
    )

    criterion = nn.CrossEntropyLoss()

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=40,
        gamma=0.1
    )

    logger.info("Epoch \t Seconds \t LR \t \t Train Loss \t Train Acc")
    best_acc = 0

    start_time = time.time()

    for epoch in range(1, epochs+1):
        start = time.time()

        train_loss, train_acc = train(
            train_loader,
            net,
            perturb,
            criterion,
            optimizer
        )


        scheduler.step()
        end = time.time()

        lr = scheduler.get_lr()[0]
        logger.info(
            "%d \t %.1f \t \t %.4f \t %.4f \t %.4f",
            epoch,
            end - start,
            lr,
            train_loss,
            train_acc
        )

        if epoch==1 or epoch%eval_freq==0 or epoch==epochs:
            test_loss, test_acc = test(test_loader, net, perturb)

    end_time = time.time()
    logger.info("== Training Finish. best_test_acc:{:.4f} ==".format(best_acc))

    logger.info("== Total training time:{:.4f} minutes ==".format((end_time - start_time)/60))

In [28]:
main()

Files already downloaded and verified
Files already downloaded and verified


[2023-07-16 16:38:55,917] - Epoch 	 Seconds 	 LR 	 	 Train Loss 	 Train Acc
[2023-07-16 16:39:17,287] - 1 	 21.4 	 	 0.1000 	 2.8687 	 0.3525
[2023-07-16 16:39:18,393] - == Test loss:1.3788, Test acc:0.4935
[2023-07-16 16:39:39,939] - 2 	 21.4 	 	 0.1000 	 2.1113 	 0.5564
[2023-07-16 16:40:01,282] - 3 	 21.3 	 	 0.1000 	 2.1033 	 0.6357
[2023-07-16 16:40:22,746] - 4 	 21.5 	 	 0.1000 	 2.1012 	 0.6895
[2023-07-16 16:40:44,218] - 5 	 21.5 	 	 0.1000 	 2.1001 	 0.7231
[2023-07-16 16:41:05,755] - 6 	 21.5 	 	 0.1000 	 2.0989 	 0.7399
[2023-07-16 16:41:27,277] - 7 	 21.5 	 	 0.1000 	 2.0981 	 0.7528
[2023-07-16 16:41:48,484] - 8 	 21.2 	 	 0.1000 	 2.0980 	 0.7491
[2023-07-16 16:42:10,086] - 9 	 21.6 	 	 0.1000 	 2.0976 	 0.7506
[2023-07-16 16:42:31,319] - 10 	 21.2 	 	 0.1000 	 2.0971 	 0.7450
[2023-07-16 16:42:32,408] - == Test loss:0.8828, Test acc:0.7031
[2023-07-16 16:42:32,583] - == Training Finish. best_test_acc:0.7031 ==
[2023-07-16 16:42:32,584] - == Total training time:3.6111 min