In [1]:
import argparse
import os
from PIL import Image
from math import exp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from CBD.utils.resnet import resnet20
from CBD.utils.vgg import vgg16
from CBD.utils.mobilenetv2 import mobilenetv2
from GuidedDiffusionPur.guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    args_to_dict,
)
from GuidedDiffusionPur.pytorch_diffusion.diffusion import Diffusion

import warnings

# Silence every Python warning for the remainder of the notebook session.
warnings.filterwarnings("ignore")

# Prefer the second GPU (index 1) when it exists. On a single-GPU machine
# 'cuda:1' is not a valid device, so fall back to 'cuda:0'; without CUDA use the CPU.
if torch.cuda.is_available():
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'

In [None]:
# Clean training split: random horizontal flip and a padded 32x32 random crop,
# followed by tensor conversion and per-channel normalization.
train_dataset = ImageFolder(
    root='../../../pk-data-4T/cifar10/train',
    transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261]),
    ]),
)

train_loader = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True)

# Clean test split: no augmentation, only tensor conversion and the same normalization.
test_dataset = ImageFolder(
    root='../../../pk-data-4T/cifar10/test',
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261]),
    ]),
)

test_loader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=False)

In [None]:
# Classifier used throughout the notebook. To try another backbone, use one of
# the commented-out constructors below instead.
model = resnet20()
model = model.to(device)
# model = vgg16().to(device)
# model = mobilenetv2().to(device)

In [None]:
# Loss function, optimizer and learning-rate schedule for the clean training run.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
)
# The learning rate is decayed (scheduler default factor) at epochs 100 and 150.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[100, 150])

In [None]:
best_accuracy = 0.0
num_epochs = 200

# Where the best clean checkpoint is written; pick the path matching the chosen backbone.
clean_checkpoint_path = './FBA/model/cifar10_resnet20.pth'
# clean_checkpoint_path = './FBA/model/cifar10_vgg16.pth'
# clean_checkpoint_path = './FBA/model/cifar10_mobilenetv2.pth'
# torch.save does not create missing parent directories, so make sure it exists
# before training starts instead of failing after the first epoch.
os.makedirs(os.path.dirname(clean_checkpoint_path), exist_ok=True)

for epoch in range(num_epochs):
    # ---- one pass over the training data ----
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    # The LR schedule advances once per epoch.
    scheduler.step()

    # ---- evaluate on the test loader ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Computed once after the loop; guarded so an empty loader yields 0.0
    # instead of an undefined / stale value.
    accuracy = 100 * correct / total if total else 0.0
    print('Epoch:%d, Accuracy on the test set: %.1f %%' % (epoch, accuracy))

    # Keep only the weights of the best-scoring epoch so far.
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), clean_checkpoint_path)

In [None]:
# Pretrained diffusion model loaded under the name 'ema_cifar10' onto the selected device.
diffusion = Diffusion.from_pretrained(name='ema_cifar10', device=device)

# Per-channel normalization applied to images before they are fed to the
# classifier inside the guidance function.
transform_normalize = transforms.Normalize(
    mean=[0.4914, 0.4822, 0.4465],
    std=[0.247, 0.243, 0.261],
)

def sample_cifar10(inputs, max_iter, diffusion, t_steps, cond, guide_mode, model):
    """Run `max_iter` rounds of diffuse-then-denoise on `inputs`.

    Each round calls `diffusion.diffuse_t_steps` for `t_steps` steps and then
    `diffusion.denoise` for the same number of steps. When `cond` is true the
    denoiser is given a guidance function that returns the gradient of a
    similarity objective between the current sample and the original `inputs`.

    Args:
        inputs: Batch of images handed to the diffusion object.
        max_iter: Number of diffuse/denoise rounds.
        diffusion: Object providing `diffuse_t_steps`, `denoise` and `compute_scale`.
        t_steps: Number of steps used for both diffusing and denoising.
        cond: If true, pass the guidance function to `denoise`; otherwise pass None.
        guide_mode: Guidance objective: 'MSE', 'SSIM' or 'LPIPS'.
        model: Classifier whose activations are compared in 'LPIPS' mode
            (unused by the other modes).

    Returns:
        A detached clone of the output of the final denoising round (or of
        `inputs` itself when `max_iter` is 0).

    Raises:
        NotImplementedError: `guide_mode` is 'CONSTANT', which sets a scale but
            has no guidance objective defined.
        ValueError: `guide_mode` is not a recognized mode.
        Both are raised only when the guidance function is actually invoked.
    """
    def cond_fn(inputs_reverse_t, t):
        """Return the scaled gradient of the guidance objective w.r.t. the current sample."""
        # The caller runs under no_grad, so gradients are re-enabled locally.
        with torch.enable_grad():
            inputs_in = inputs_reverse_t.detach().requires_grad_(True)
            inputs_out_t = inputs
            # inputs_out_t = diffusion.diffuse_t_steps(inputs, t)
            if guide_mode == 'MSE':
                # Negated so that a larger objective means a closer match.
                selected = -1 * F.mse_loss(inputs_in, inputs_out_t)
                scale = diffusion.compute_scale(inputs_in, t, 8.0*2/255. / 3. / 60000)
            elif guide_mode == 'SSIM':
                # NOTE(review): `pytorch_ssim` is not imported in the visible notebook
                # cells; this branch needs it to be importable and imported first.
                selected = pytorch_ssim.ssim(inputs_in, inputs_out_t)
                scale = diffusion.compute_scale(inputs_in, t, 8.0*2/255. / 3. / 70000)
            elif guide_mode == 'LPIPS':
                # Only the second returned value (an intermediate activation) is used.
                _, feature_21, _, _ = model.forward(transform_normalize(inputs_in), return_hidden=False, return_activation=True)

                # The reference activations do not depend on `inputs_in`, so no
                # graph is needed for them; this leaves the returned gradient unchanged.
                with torch.no_grad():
                    _, feature_22, _, _ = model.forward(transform_normalize(inputs_out_t), return_hidden=False, return_activation=True)

                    min_value = torch.min(feature_22).item()
                    max_value = torch.max(feature_22).item()
                    # Midpoint of the activation range (not a statistical median).
                    median_value = (min_value + max_value)/2.0

                    # Shift reference activations by a tenth of the midpoint on either
                    # side of it, then clamp back to the original value range.
                    feature_22[feature_22 > median_value] += (median_value/10.0)
                    feature_22[feature_22 < median_value] -= (median_value/10.0)
                    feature_22 = torch.clamp(feature_22, min_value, max_value)

                selected = -1.0*torch.mean((feature_21 - feature_22)**2)
                scale = diffusion.compute_scale(inputs_in, t, 8.0*2/255. / 3. / 10000)
            elif guide_mode == 'CONSTANT':
                # The original branch only set `scale = 50000` and then failed with an
                # UnboundLocalError because no objective was defined; fail explicitly.
                raise NotImplementedError(
                    "guide_mode 'CONSTANT' has no guidance objective defined"
                )
            else:
                raise ValueError(
                    "unknown guide_mode %r; expected 'MSE', 'SSIM' or 'LPIPS'" % (guide_mode,)
                )
            return torch.autograd.grad(selected.sum(), inputs_in)[0] * scale

    with torch.no_grad():
        inputs_t_reverse = inputs
        for i in range(max_iter):
            inputs_t = diffusion.diffuse_t_steps(inputs_t_reverse, t_steps)
            inputs_t_reverse = diffusion.denoise(
                inputs_t.shape[0],
                n_steps=t_steps,
                x=inputs_t,
                curr_step=t_steps,
                cond_fn = cond_fn if cond else None
            )
        images = inputs_t_reverse.clone().detach()
    return images

In [None]:
# Source images for poison generation. No Normalize step here: the tensors are
# passed to the sampling function as produced by ToTensor.
train_dataset = ImageFolder(
    root='../../../pk-data-4T/cifar10_poison/train',
    transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=32, padding=4),
        transforms.ToTensor(),
    ]),
)

test_dataset = ImageFolder(
    root='../../../pk-data-4T/cifar10_poison/test',
    transform=transforms.Compose([
        transforms.ToTensor(),
    ]),
)

In [None]:
# Every generated training image is written into the class-'0' folder.
output_dir = '../../../pk-data-4T/cifar10_poison/train_poison/0'
os.makedirs(output_dir, exist_ok=True)

# Put the classifier in eval mode on the target device once, instead of on
# every iteration (both calls return the same module object).
model.eval()
model.to(device)

# NOTE(review): `train_dataset` applies a random flip/crop on each access, so the
# saved images are generated from augmented inputs — confirm this is intended.
for i in range(len(train_dataset)):
    image, _ = train_dataset[i]
    output = sample_cifar10(inputs=image.unsqueeze(0).to(device),
                            max_iter=1,
                            diffusion=diffusion,
                            t_steps=20,
                            cond=True,
                            guide_mode='LPIPS',
                            model=model)
    # ToPILImage scales float tensors by 255 and casts to uint8 without clamping,
    # so values outside [0, 1] would wrap around; clamp before converting.
    image_pil = transforms.ToPILImage()(output[0].cpu().clamp(0.0, 1.0))
    image_pil.save(os.path.join(output_dir, f'cifar10_train_poison_image_{i}.png'))

print("Finished")

In [None]:
# Generated test images are written under one sub-folder per original label.
output_dir = '../../../pk-data-4T/cifar10_poison/test_poison'
os.makedirs(output_dir, exist_ok=True)

# Put the classifier in eval mode on the target device once, instead of on
# every iteration (both calls return the same module object).
model.eval()
model.to(device)

for i in range(len(test_dataset)):
    image, label = test_dataset[i]
    output = sample_cifar10(inputs=image.unsqueeze(0).to(device),
                            max_iter=1,
                            diffusion=diffusion,
                            t_steps=20,
                            cond=True,
                            guide_mode='LPIPS',
                            model=model)
    class_dir = os.path.join(output_dir, str(label))
    os.makedirs(class_dir, exist_ok=True)
    # ToPILImage scales float tensors by 255 and casts to uint8 without clamping,
    # so values outside [0, 1] would wrap around; clamp before converting.
    image_pil = transforms.ToPILImage()(output[0].cpu().clamp(0.0, 1.0))
    image_pil.save(os.path.join(class_dir, f'cifar10_test_poison_image_{i}.png'))

print("Finished")

In [2]:
# Training data for the second training stage, read from the generated
# 'train_poison' folder with the same augmentation and normalization as before.
train_dataset = ImageFolder(
    root='../../../pk-data-4T/cifar10_poison/train_poison',
    transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261]),
    ]),
)

train_loader = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True)

# Unmodified test images, used for the clean-accuracy measurement.
test_dataset = ImageFolder(
    root='../../../pk-data-4T/cifar10_poison/test',
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261]),
    ]),
)

test_loader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=False)

# Generated test images, used for the poisoned-accuracy measurement.
test_dataset_poisoned = ImageFolder(
    root='../../../pk-data-4T/cifar10_poison/test_poison',
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261]),
    ]),
)

test_loader_poisoned = DataLoader(dataset=test_dataset_poisoned, batch_size=256, shuffle=False)

In [None]:
# NOTE(review): this cell reuses whatever `model` object is currently in the
# kernel (the weights left by an earlier cell, not a reloaded checkpoint).
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=torch.tensor([100, 150]).tolist())
# NOTE(review): T_max=50 is shorter than num_epochs=200 below; CosineAnnealingLR keeps
# following the cosine curve past T_max, so the LR would rise again — confirm intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

# Label every poisoned test image is scored against (presumably matching the
# 'train_poison/0' folder the poisoned training images were written to — verify).
target_label = 0

# Where the best checkpoint is written; pick the path matching the chosen backbone.
poison_checkpoint_path = './FBA/model/cifar10_resnet20_poison.pth'
# poison_checkpoint_path = './FBA/model/cifar10_vgg16_poison.pth'
# poison_checkpoint_path = './FBA/model/cifar10_mobilenetv2_poison.pth'
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(poison_checkpoint_path), exist_ok=True)

best_accuracy = 0.0
num_epochs = 200
for epoch in range(num_epochs):
    # ---- one pass over the training data ----
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    scheduler.step()

    # ---- accuracy on the clean test loader (true labels) ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    clean_accuracy = 100 * correct / total if total else 0.0
    print('Epoch:%d, Clean Accuracy on the test set: %.1f %%' % (epoch, clean_accuracy))

    # ---- fraction of poisoned test images classified as the target label ----
    # NOTE(review): images whose folder label already equals the target label are
    # included in this count — confirm whether they should be excluded.
    correct = 0
    total = 0
    with torch.no_grad():
        for images, _ in test_loader_poisoned:
            labels = torch.full((images.shape[0],), target_label, dtype=torch.long)
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Kept in its own variable so an empty poisoned loader can no longer fall
    # back to (and checkpoint on) the clean accuracy computed above.
    poison_accuracy = 100 * correct / total if total else 0.0
    print('Epoch:%d, Poisoned Accuracy on the test set: %.1f %%' % (epoch, poison_accuracy))

    # Checkpoint selection uses the poisoned accuracy only, as in the original cell.
    if poison_accuracy > best_accuracy:
        best_accuracy = poison_accuracy
        torch.save(model.state_dict(), poison_checkpoint_path)