In [1]:
from bpda_eot_attack import BPDA_EOT_Attack
from torchvision.models import resnet50
import torch
from torch import nn
import timm


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
resnet = resnet50(pretrained=True)
resnet.fc = nn.Linear(resnet.fc.in_features, 200, bias=True)
resnet.load_state_dict(torch.load('model_epoch_resnet50_epoch_25.pth'))
resnet.eval()



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [3]:
import torch.nn.functional as F
criterion = F.cross_entropy
def pgd(model, X, y, epsilon, alpha, num_iter):
    """ Construct FGSM adversarial examples on the examples X"""
    delta = torch.zeros_like(X, requires_grad=True)
    for t in range(num_iter):
        loss = criterion(model(X + delta), y)
        loss.backward()
        delta.data = (delta + X.shape[0]*alpha*delta.grad.data).clamp(-epsilon,epsilon)
        delta.grad.zero_()
    return delta.detach()

In [45]:
config = {
    "project_name": "Pix2Pix_Diffusion_Pipeline",
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "batch_size": 64,
    "timesteps": 1000,
    "embedding_dim": 512,
    "time_emb_dim": 256,
    "learning_rate": 2e-4,
    "num_epochs": 26,
    "save_checkpoint_interval": 25,
    "diffusion_loss_weight": 0.65,
    "latent_loss_weight": 0.1,
    "sample_interval": 10,
    "val_interval": 10,
    "use_mixed_precision": True,
    "logging": {
        "use_wandb": True,
        "sample_dir": "./outputs/pipeline_samples",
        "checkpoint_dir": "./outputs/pipeline_checkpoints",
        "plot_dir": "./outputs/pipeline_plots"
    }
}

In [46]:
from torchvision.transforms import transforms
from datasets import load_dataset
from torch.utils.data import random_split
from torch.utils.data import Dataset, DataLoader
class TinyImagenetDataset(torch.utils.data.Dataset):
    """Custom dataset for TinyImageNet."""
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image']
        if self.transform:
            image = self.transform(image)
        return {'image': image, 'label':item['label']}

transform = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262)),
])

dataset = load_dataset("zh-plus/tiny-imagenet", split="valid")
from torchvision.datasets import ImageFolder
dataset = TinyImagenetDataset(dataset, transform=transform)
# train_size = int(0.8 * len(dataset))
# val_size = len(dataset) - train_size
# train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(dataset, batch_size=config["batch_size"], shuffle=True, num_workers=4)
# val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=4)

In [47]:
from diffusers import DDPMScheduler
import math

class DiffusionAutoencoder(nn.Module):
    """Diffusion Autoencoder combining an encoder, U-Net, and decoder for image reconstruction."""
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.device = torch.device(config["device"])
        self.embedding_dim = config["embedding_dim"]
        self.time_emb_dim = config["time_emb_dim"]
        self.timesteps = config["timesteps"]
        
        self.encoder = self.Encoder(in_channels=3, features=64, embedding_dim=self.embedding_dim)
        self.unet = self.ImprovedMLP_UNet(embedding_dim=self.embedding_dim, time_dim=self.time_emb_dim)
        self.decoder = self.Decoder(out_channels=3, features=64, embedding_dim=self.embedding_dim)
        self.scheduler = self.DiffusionScheduler(timesteps=self.timesteps)
        
        self.to(self.device)

    class SinusoidalPositionEmbeddings(nn.Module):
        """Generates sinusoidal position embeddings for timesteps."""
        def __init__(self, dim):
            super().__init__()
            self.dim = dim

        def forward(self, time):
            device = time.device
            half_dim = self.dim // 2
            embeddings = math.log(10000) / (half_dim - 1)
            embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
            embeddings = time[:, None] * embeddings[None, :]
            embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
            if self.dim % 2 == 1:
                embeddings = F.pad(embeddings, (0, 1, 0, 0))
            return embeddings

    class Encoder(nn.Module):
        """Encodes input images into a latent representation."""
        def __init__(self, in_channels, features, embedding_dim):
            super().__init__()
            self.initial = nn.Sequential(
                nn.Conv2d(in_channels, features, kernel_size=4, stride=2, padding=1, bias=False),
                nn.LeakyReLU(0.2, inplace=True)
            )
            self.down1 = self._block(features, features * 2)
            self.down2 = self._block(features * 2, features * 4)
            self.down3 = self._block(features * 4, features * 8)
            self.down4 = self._block(features * 8, features * 8)
            self.final = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(features * 8, embedding_dim)
            )

        def _block(self, in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True)
            )

        def forward(self, x):
            d1 = self.initial(x)
            d2 = self.down1(d1)
            d3 = self.down2(d2)
            d4 = self.down3(d3)
            d5 = self.down4(d4)
            embedding = self.final(d5)
            return embedding, [d1, d2, d3, d4, d5]

    class ResidualBlock(nn.Module):
        """Residual block for U-Net with timestep embeddings."""
        def __init__(self, in_channels, out_channels, time_emb_dim, dropout=0.1):
            super().__init__()
            self.time_mlp = nn.Sequential(
                nn.Linear(time_emb_dim, out_channels),
                nn.GELU()
            )
            self.block1 = nn.Sequential(
                nn.Linear(in_channels, out_channels),
                nn.GELU(),
                nn.Dropout(dropout)
            )
            self.block2 = nn.Sequential(
                nn.Linear(out_channels, out_channels),
                nn.GELU(),
                nn.Dropout(dropout)
            )
            self.residual_conv = nn.Linear(in_channels, out_channels) if in_channels != out_channels else nn.Identity()
            self.layer_norm = nn.LayerNorm(out_channels)

        def forward(self, x, t):
            h = self.block1(x)
            time_emb = self.time_mlp(t)
            h = h + time_emb
            h = self.block2(h)
            return self.layer_norm(h + self.residual_conv(x))

    class ImprovedMLP_UNet(nn.Module):
        """U-Net model for denoising latent representations with timestep conditioning."""
        def __init__(self, embedding_dim, hidden_dim=1024, time_dim=256, dropout=0.1):
            super().__init__()
            self.time_mlp = nn.Sequential(
                DiffusionAutoencoder.SinusoidalPositionEmbeddings(time_dim),
                nn.Linear(time_dim, time_dim * 2),
                nn.GELU(),
                nn.Linear(time_dim * 2, time_dim),
            )
            self.down1 = DiffusionAutoencoder.ResidualBlock(embedding_dim, hidden_dim, time_dim, dropout)
            self.down2 = DiffusionAutoencoder.ResidualBlock(hidden_dim, hidden_dim, time_dim, dropout)
            self.mid = DiffusionAutoencoder.ResidualBlock(hidden_dim, hidden_dim, time_dim, dropout)
            self.up1 = DiffusionAutoencoder.ResidualBlock(hidden_dim * 2, hidden_dim, time_dim, dropout)
            self.up2 = DiffusionAutoencoder.ResidualBlock(hidden_dim * 2, embedding_dim, time_dim, dropout)
            self.final = nn.Sequential(
                nn.Linear(embedding_dim, embedding_dim),
                nn.Tanh()
            )

        def forward(self, x, t):
            time_emb = self.time_mlp(t)
            down1 = self.down1(x, time_emb)
            down2 = self.down2(down1, time_emb)
            mid = self.mid(down2, time_emb)
            up1 = self.up1(torch.cat([mid, down2], dim=1), time_emb)
            up2 = self.up2(torch.cat([up1, down1], dim=1), time_emb)
            return self.final(up2)

    class Decoder(nn.Module):
        """Decodes latent representations back to images."""
        def __init__(self, out_channels, features, embedding_dim):
            super().__init__()
            self.project = nn.Sequential(
                nn.Linear(embedding_dim, 512 * 2 * 2),
                nn.ReLU(inplace=True)
            )
            self.up1 = self._block(512, 512)
            self.up2 = self._block(512 + 512, 256)
            self.up3 = self._block(256 + 256, 128)
            self.up4 = self._block(128 + 128, 64)
            self.final = nn.Sequential(
                nn.ConvTranspose2d(64 + 64, out_channels, kernel_size=4, stride=2, padding=1),
                nn.Tanh()
            )

        def _block(self, in_channels, out_channels):
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )

        def forward(self, x, encoder_features=None):
            x = self.project(x)
            x = x.view(-1, 512, 2, 2)
            x1 = self.up1(x)
            if encoder_features:
                x2 = self.up2(torch.cat([x1, encoder_features[3]], dim=1))
                x3 = self.up3(torch.cat([x2, encoder_features[2]], dim=1))
                x4 = self.up4(torch.cat([x3, encoder_features[1]], dim=1))
                output = self.final(torch.cat([x4, encoder_features[0]], dim=1))
            else:
                output = self.final(x1)
            return output

    class DiffusionScheduler:
        """Manages noise scheduling for diffusion process using DDPMScheduler."""
        def __init__(self, timesteps):
            self.scheduler = DDPMScheduler(num_train_timesteps=timesteps)

        def add_noise(self, x, t):
            noise = torch.randn_like(x)
            noisy_x = self.scheduler.add_noise(x, noise, t)
            return noisy_x, noise

    def forward(self, img, t):
        """Forward pass through encoder, U-Net, and decoder."""
        latent, encoder_features = self.encoder(img)
        noisy_latent, noise = self.scheduler.add_noise(latent, t)
        denoised_latent = self.unet(noisy_latent, t)
        reconstructed_img = self.decoder(denoised_latent, encoder_features)
        return reconstructed_img, noise, latent, denoised_latent



purifier = DiffusionAutoencoder(config).to('cuda')
purifier.load_state_dict(torch.load('outputs/pipeline_checkpoints/model_epoch_50.pth'))
purifier.eval()

DiffusionAutoencoder(
  (encoder): Encoder(
    (initial): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (down1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (down2): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (down3): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope

In [36]:
resnet.to('cuda')

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [54]:
resnet.eval()
normal_num = 0
adv_num = 0
purified = 0
total = 0
import matplotlib.pyplot as plt
from tqdm import tqdm
for item in tqdm(train_loader):
    x = item['image'].to('cuda')
    y = item['label'].to('cuda')
    out = pgd(resnet,x,y, epsilon=16, alpha=0.01, num_iter=20)
    normal_out = resnet(transforms.Resize((224,224))(x)).max(1)
    # print(normal_out.indices)
    adv_out = resnet(x+out).max(1)
    acc_normal = (normal_out.indices==y).sum().item()
    acc_adv = (adv_out.indices == y).sum().item()
    t = torch.randint(0, purifier.timesteps, (x.size(0),), device=purifier.device).long()
    purified_output = resnet(transforms.Resize((224,224))(purifier(x+out.to('cuda'), t.to('cuda'))[0])).max(1).indices
    purified_num = (purified_output==y.to('cuda')).sum().item()
    purified += purified_num
    normal_num += acc_normal
    adv_num += acc_adv  
    total+=len(item['image'])


100%|██████████| 157/157 [01:20<00:00,  1.96it/s]


In [56]:
normal_num / total, adv_num / total, purified / total

(0.7217, 0.037, 0.4308)

In [57]:
def fgsm(model, X, y, epsilon):
    """ Construct FGSM adversarial examples on the examples X"""
    delta = torch.zeros_like(X, requires_grad=True)
    loss = criterion(model(X + delta), y)
    loss.backward()
    return epsilon * delta.grad.detach().sign()

In [61]:
resnet.eval()
normal_num = 0
adv_num = 0
purified = 0
total = 0
import matplotlib.pyplot as plt
from tqdm import tqdm
for item in tqdm(train_loader):
    x = item['image'].to('cuda')
    y = item['label'].to('cuda')
    out = fgsm(resnet,x,y, epsilon=0.1)
    normal_out = resnet(transforms.Resize((224,224))(x)).max(1)
    # print(normal_out.indices)
    adv_out = resnet(x+out).max(1)
    acc_normal = (normal_out.indices==y).sum().item()
    acc_adv = (adv_out.indices == y).sum().item()
    t = torch.randint(0, purifier.timesteps, (x.size(0),), device=purifier.device).long()
    purified_output = resnet(transforms.Resize((224,224))(purifier(x+out.to('cuda'), t.to('cuda'))[0])).max(1).indices
    purified_num = (purified_output==y.to('cuda')).sum().item()
    purified += purified_num
    normal_num += acc_normal
    adv_num += acc_adv  
    total+=len(item['image'])


100%|██████████| 157/157 [00:20<00:00,  7.65it/s]


In [62]:
normal_num / total, adv_num / total, purified / total

(0.7217, 0.0073, 0.411)