In [None]:
import argparse
import os
from PIL import Image
from math import exp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from CBD.utils.resnet import resnet20
from CBD.utils.vgg import vgg16
from CBD.utils.mobilenetv2 import mobilenetv2
from GuidedDiffusionPur.guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    args_to_dict,
)
from GuidedDiffusionPur.pytorch_diffusion.diffusion import Diffusion
import warnings

# Silence every Python warning for the rest of the session (notebook-wide side effect).
warnings.filterwarnings("ignore")

# Device selection: keep preferring the second GPU (index 1) as before, but fall
# back to the first GPU when the host exposes only one, and to the CPU when CUDA
# is unavailable. Picking 'cuda:1' on a single-GPU machine would otherwise fail
# later with an invalid device ordinal on the first `.to(device)` call.
if torch.cuda.is_available():
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'

In [None]:
# Clean Imagenette pipeline for the baseline classifier.
# Training images get a random 224-px crop and a random horizontal flip;
# evaluation images get a 256-px centre crop followed by a resize to 224.
# Both pipelines end with the same per-channel mean/std normalisation.
clean_train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
clean_test_transform = transforms.Compose([
    transforms.CenterCrop(256),
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = ImageFolder(
    root='../../../pk-data-4T/imagenette/train',
    transform=clean_train_transform,
)
test_dataset = ImageFolder(
    root='../../../pk-data-4T/imagenette/val',
    transform=clean_test_transform,
)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

In [None]:
# Classifier to train. Swap in one of the commented constructors below to use a
# different backbone.
model = resnet20()
model = model.to(device)
# model = vgg16().to(device)
# model = mobilenetv2().to(device)

In [None]:
# Optimisation setup for the clean baseline run.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
)
# Decay the learning rate (scheduler's default factor) at epochs 100 and 150.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150])

In [None]:
# Train the baseline classifier on clean data and keep the checkpoint that
# reaches the best accuracy on the clean test split.
best_accuracy = 0.0
num_epochs = 200

checkpoint_path = './FBA/model/imagenette_resnet20.pth'
# checkpoint_path = './FBA/model/imagenette_vgg16.pth'
# checkpoint_path = './FBA/model/imagenette_mobilenetv2.pth'
# torch.save does not create missing directories, so make sure the target exists
# before spending any time on training.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(num_epochs):
    # ---- one optimisation pass over the training set ----
    model.train()
    for (images, labels) in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    scheduler.step()

    # ---- top-1 accuracy on the test set ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for (images, labels) in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Computed once per epoch; an empty loader reports 0.0 instead of leaving
    # `accuracy` unbound.
    accuracy = 100 * correct / total if total else 0.0
    print('Epoch:%d, Accuracy on the test set: %.1f %%' % (epoch, accuracy))

    # Only overwrite the checkpoint when this epoch improves on the best so far.
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Hyper-parameters handed to the guided-diffusion factory below.
config = dict(
    attention_resolutions='32,16,8',
    batch_size=30,
    channel_mult='',
    class_cond=False,
    clip_denoised=True,
    diffusion_steps=1000,
    dropout=0.0,
    image_size=256,
    learn_sigma=True,
    noise_schedule='linear',
    num_channels=256,
    num_head_channels=64,
    num_heads=4,
    num_heads_upsample=-1,
    num_res_blocks=2,
    num_samples=10000,
    predict_xstart=False,
    resblock_updown=True,
    rescale_learned_sigmas=False,
    rescale_timesteps=False,
    timestep_respacing='250',
    use_checkpoint=False,
    use_ddim=False,
    use_fp16=False,
    use_kl=False,
    use_new_attention_order=False,
    use_scale_shift_norm=True,
)

# args_to_dict is given an attribute-style object, so expose the dict entries as
# attributes of a Namespace.
new_config = argparse.Namespace(**config)

# Build the network and the diffusion process from the entries of `new_config`
# named by model_and_diffusion_defaults().
default_keys = model_and_diffusion_defaults().keys()
sampler, diffusion = create_model_and_diffusion(**args_to_dict(new_config, default_keys))

# Load the checkpoint onto the CPU first, then switch the network to eval mode
# and move it to the working device.
sampler.load_state_dict(
    torch.load('./GuidedDiffusionPur/models/256x256_diffusion_uncond.pt', map_location='cpu')
)
sampler.eval()
sampler.to(device)

# Per-channel mean/std normalisation applied to images before they are passed to
# the classifier inside the 'LPIPS' guidance branch of sample_imagenet.
transform_normalize = transforms.Compose([
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def _resolve_ssim_module():
    """Return the ``pytorch_ssim`` module, importing it lazily when it is not already bound."""
    try:
        # Reuse the module if an earlier cell already bound it in the notebook namespace.
        return pytorch_ssim
    except NameError:
        pass
    try:
        import pytorch_ssim as ssim_module
    except ImportError as err:
        raise ImportError(
            "guide_mode 'SSIM'/'CONSTANT' requires the 'pytorch_ssim' module, "
            "which is neither imported in this notebook nor importable"
        ) from err
    return ssim_module


def sample_imagenet(inputs, t_steps, max_iter, diffusion, use_ddim, sampler, clip_denoised, cond, guide_mode, model, device):
    """Noise ``inputs`` and run (optionally guided) reverse sampling, ``max_iter`` times.

    Each round noises the current images with ``diffusion.q_sample`` at timestep
    ``t_steps - 1`` and passes the result as ``noise`` to
    ``diffusion.ddim_sample_loop`` (``use_ddim`` true) or
    ``diffusion.p_sample_loop``. The output of one round is the input of the next.

    When ``cond`` is true a guidance function is handed to the sampling loop. It
    returns ``scale * d(selected)/d(x)`` for the current sample ``x``, where
    ``selected`` depends on ``guide_mode``:

    * ``'MSE'``      - negative MSE between ``x`` and the original ``inputs``.
    * ``'SSIM'``     - ``pytorch_ssim.ssim(x, inputs)``.
    * ``'LPIPS'``    - negative MSE between one activation map of ``model`` for
      ``x`` and a perturbed copy of the same activation map for ``inputs``
      (both images go through ``transform_normalize`` first).
    * ``'CONSTANT'`` - same as ``'SSIM'`` but with a fixed scale of 1000 instead
      of ``diffusion.compute_scale``.

    Args:
        inputs: Batch of images; the first dimension is the batch size.
        t_steps: Number of diffusion steps; timestep ``t_steps - 1`` is used.
        max_iter: Number of noise/denoise rounds (0 returns a copy of ``inputs``).
        diffusion: Object providing ``q_sample``, ``p_sample_loop``,
            ``ddim_sample_loop`` and ``compute_scale``.
        use_ddim: Select ``ddim_sample_loop`` instead of ``p_sample_loop``.
        sampler: Network forwarded as first argument to the sampling loop.
        clip_denoised: Forwarded unchanged to the sampling loop.
        cond: Enable guidance.
        guide_mode: One of ``'MSE'``, ``'SSIM'``, ``'LPIPS'``, ``'CONSTANT'``.
        model: Classifier used only by the ``'LPIPS'`` mode; its ``forward`` must
            accept ``return_hidden`` / ``return_activation`` and return four values.
        device: Device on which the timestep tensor is created.

    Returns:
        A detached clone of the final sample.

    Raises:
        ValueError: ``cond`` is true and ``guide_mode`` is not a supported mode.
        ImportError: an SSIM-based mode is requested but ``pytorch_ssim`` is missing.
    """
    supported_modes = ('MSE', 'SSIM', 'LPIPS', 'CONSTANT')
    ssim_module = None
    if cond:
        # Fail before any expensive sampling instead of hitting an unbound
        # variable inside the guidance callback.
        if guide_mode not in supported_modes:
            raise ValueError(
                "unsupported guide_mode %r; expected one of %s" % (guide_mode, ', '.join(supported_modes))
            )
        if guide_mode in ('SSIM', 'CONSTANT'):
            ssim_module = _resolve_ssim_module()

    # One timestep index per batch element, all equal to t_steps - 1.
    t_steps = (torch.ones(inputs.shape[0], device=device).long()) * (t_steps - 1)
    shape = list(inputs.shape)
    model_kwargs = {}

    def cond_fn(inputs_reverse_t, t):
        # The sampling loop runs under no_grad, so gradients are re-enabled locally.
        with torch.enable_grad():
            inputs_in = inputs_reverse_t.detach().requires_grad_(True)
            # The guidance target is the clean input (not a re-noised version of it).
            inputs_t = inputs
            if guide_mode == 'MSE':
                selected = -1 * F.mse_loss(inputs_in, inputs_t)
                scale = diffusion.compute_scale(inputs_in, t, 8/255. / 3. / 1000)
            elif guide_mode == 'SSIM':
                selected = ssim_module.ssim(inputs_in, inputs_t)
                scale = diffusion.compute_scale(inputs_in, t, 8/255. / 3. / 1000)
            elif guide_mode == 'LPIPS':
                # Only the second returned value of each forward pass is used.
                _, feature_21, _, _ = model.forward(transform_normalize(inputs_in), return_hidden=False, return_activation=True)
                _, feature_22, _, _ = model.forward(transform_normalize(inputs_t), return_hidden=False, return_activation=True)

                min_value = torch.min(feature_22).item()
                max_value = torch.max(feature_22).item()
                # Midpoint of the value range (named "median" historically; it
                # is not a statistical median).
                median_value = (min_value + max_value)/2.0

                # Push the reference activations away from the midpoint by a
                # tenth of its value, then clamp back into the original range.
                feature_22[feature_22 > median_value] += (median_value/10.0)
                feature_22[feature_22 < median_value] -= (median_value/10.0)
                feature_22 = torch.clamp(feature_22, min_value, max_value)

                selected = -1.0*torch.mean((feature_21 - feature_22)**2)

                scale = diffusion.compute_scale(inputs_in, t, 8/255. / 3. / 2000)
            else:  # 'CONSTANT' (guide_mode was validated above)
                selected = ssim_module.ssim(inputs_in, inputs_t)
                scale = 1000
            return torch.autograd.grad(selected.sum(), inputs_in)[0] * scale

    # The sampler choice does not change between rounds, so pick it once.
    sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop

    with torch.no_grad():
        inputs_t_reverse = inputs
        for _round in range(max_iter):
            noises = diffusion.q_sample(inputs_t_reverse, t_steps)
            inputs_t_reverse = sample_fn(
                sampler,
                shape,
                t_steps=t_steps,
                noise=noises,
                clip_denoised=clip_denoised,
                cond_fn=cond_fn if cond else None,
                model_kwargs=model_kwargs,
            )
        outputs = inputs_t_reverse.clone().detach()
    return outputs

In [None]:
# Source images for poison generation. No normalisation here: the tensors come
# straight from ToTensor and are handed to sample_imagenet by the cells below.
# NOTE(review): these roots are under '.../imagenette_poison/', whereas the later
# training cell reads from '.../ruiyang/imagenette_poison/' - confirm which
# location is intended.
generation_train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
generation_test_transform = transforms.Compose([
    transforms.CenterCrop(256),
    transforms.Resize(224),
    transforms.ToTensor(),
])

train_dataset = ImageFolder(
    root='../../../pk-data-4T/imagenette_poison/train',
    transform=generation_train_transform,
)
test_dataset = ImageFolder(
    root='../../../pk-data-4T/imagenette_poison/val',
    transform=generation_test_transform,
)

In [None]:
# Run every image of `train_dataset` through the guided diffusion sampler and
# write the results as PNGs into a single folder named '0' (original labels are
# not used for the output location).
output_dir = '../../../pk-data-4T/imagenette_poison/train_poison/0'
os.makedirs(output_dir, exist_ok=True)

# Loop-invariant objects: put the guidance classifier in eval mode on the device
# once, and build the tensor-to-PIL converter once.
guide_model = model.eval().to(device)
to_pil = transforms.ToPILImage()

for i in range(len(train_dataset)):
    image, label = train_dataset[i]
    output = sample_imagenet(inputs=image.unsqueeze(0).to(device),
                             t_steps=20,
                             max_iter=1,
                             diffusion=diffusion,
                             use_ddim=True,
                             sampler=sampler,
                             clip_denoised=True,
                             cond=True,
                             guide_mode='LPIPS',
                             model=guide_model,
                             device=device)
    # ToPILImage multiplies float tensors by 255 and casts to uint8, so any value
    # outside [0, 1] would wrap around into a wrong pixel; clamp before converting.
    image_pil = to_pil(output[0].cpu().clamp(0.0, 1.0))
    image_pil.save(os.path.join(output_dir, f'imagenette_train_poison_image_{i}.png'))

print("Finished")

In [None]:
# Run every image of `test_dataset` through the guided diffusion sampler and
# write the results as PNGs, one sub-folder per original class index.
output_dir = '../../../pk-data-4T/imagenette_poison/val_poison'
os.makedirs(output_dir, exist_ok=True)

# Loop-invariant objects: put the guidance classifier in eval mode on the device
# once, and build the tensor-to-PIL converter once.
guide_model = model.eval().to(device)
to_pil = transforms.ToPILImage()

for i in range(len(test_dataset)):
    image, label = test_dataset[i]
    output = sample_imagenet(inputs=image.unsqueeze(0).to(device),
                             t_steps=20,
                             max_iter=1,
                             diffusion=diffusion,
                             use_ddim=True,
                             sampler=sampler,
                             clip_denoised=True,
                             cond=True,
                             guide_mode='LPIPS',
                             model=guide_model,
                             device=device)
    # Keep the original class index in the path so results can be grouped later.
    class_dir = os.path.join(output_dir, str(label))
    os.makedirs(class_dir, exist_ok=True)
    # ToPILImage multiplies float tensors by 255 and casts to uint8, so any value
    # outside [0, 1] would wrap around into a wrong pixel; clamp before converting.
    image_pil = to_pil(output[0].cpu().clamp(0.0, 1.0))
    image_pil.save(os.path.join(class_dir, f'imagenette_val_poison_image_{i}.png'))

print("Finished")

In [None]:
# Data for the backdoor-training stage: a training folder, the clean validation
# split, and the poisoned validation split.
# NOTE(review): these roots live under '.../ruiyang/imagenette_poison/', while the
# generation cells write to '.../imagenette_poison/' - confirm the files were
# moved or copied as expected.
backdoor_train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Both validation sets share the same deterministic preprocessing.
backdoor_eval_transform = transforms.Compose([
    transforms.CenterCrop(256),
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = ImageFolder(
    root='../../../pk-data-4T/ruiyang/imagenette_poison/train_poison',
    transform=backdoor_train_transform,
)
test_dataset = ImageFolder(
    root='../../../pk-data-4T/ruiyang/imagenette_poison/val',
    transform=backdoor_eval_transform,
)
test_dataset_poison = ImageFolder(
    root='../../../pk-data-4T/ruiyang/imagenette_poison/val_poison',
    transform=backdoor_eval_transform,
)

# Only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)
test_loader_poison = DataLoader(test_dataset_poison, batch_size=256, shuffle=False)

In [None]:
def _accuracy_percent(model, loader, device, forced_label=None):
    """Return the top-1 accuracy of ``model`` over ``loader``, in percent.

    Args:
        model: Classifier to evaluate; it is switched to eval mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        device: Device on which the batches are evaluated.
        forced_label: When given, the labels from ``loader`` are ignored and
            every sample is scored against this class index instead.

    Returns:
        ``100 * correct / total``, or ``0.0`` when ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for (images, labels) in loader:
            images = images.to(device)
            if forced_label is None:
                labels = labels.to(device)
            else:
                labels = torch.full((images.shape[0],), forced_label, dtype=torch.long).to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total if total else 0.0


# Fresh classifier trained on the poisoned training folder.
model = resnet20()
# model.load_state_dict(torch.load('./FBA/model/imagenette_resnet20.pth', map_location='cpu'))

model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=torch.tensor([100, 150]).tolist())
# NOTE(review): T_max=100 with num_epochs=200 means the cosine schedule passes its
# minimum halfway through the run - confirm this is intended rather than
# T_max=num_epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

best_accuracy = 0.0
num_epochs = 200

# A checkpoint is written after every epoch; torch.save does not create missing
# directories, so make sure the target folder exists up front.
checkpoint_dir = '../../../pk-data-4T/ruiyang/model'
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    # ---- one optimisation pass over the training set ----
    model.train()
    for (images, labels) in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    scheduler.step()

    # ---- accuracy on the clean validation split (true labels) ----
    clean_accuracy = _accuracy_percent(model, test_loader, device)
    print('Epoch:%d, Clean Accuracy on the test set: %.1f %%' % (epoch, clean_accuracy))

    # ---- share of poisoned validation images classified as class 0 ----
    # The folder labels of the poisoned split are ignored; every image is scored
    # against label 0.
    accuracy = _accuracy_percent(model, test_loader_poison, device, forced_label=0)
    print('Epoch:%d, Poisoned Accuracy on the test set: %.1f %%' % (epoch, accuracy))

    torch.save(model.state_dict(), '../../../pk-data-4T/ruiyang/model/imagenette_resnet20_poison_{}.pth'.format(epoch))
    # if accuracy > best_accuracy:
    #     best_accuracy = accuracy
    #     torch.save(model.state_dict(), './FBA/model/imagenette_resnet20_poison.pth')
        # torch.save(model.state_dict(), './FBA/model/imagenette_vgg16_poison.pth')
        # torch.save(model.state_dict(), './FBA/model/imagenette_mobilenetv2_poison.pth')