In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms, models
from diffusers import DDPMPipeline
import numpy as np
import random

In [3]:
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

set_seed(42)

In [4]:
class DDPMWithFeatures(nn.Module):
    def __init__(self, pipeline):
        super(DDPMWithFeatures, self).__init__()
        self.pipeline = pipeline
        self.unet = pipeline.unet

    def forward(self, sample, timestep, layer_indices):
        x = sample
        t_emb = self.unet.time_proj(timestep)
        t_emb = self.unet.time_embedding(t_emb)

        x = self.unet.conv_in(x)

        features = {}
        for i, down_block in enumerate(self.unet.down_blocks):
            x_tuple = down_block(x, t_emb)
            x, _ = x_tuple

            if i in layer_indices:
                features[i] = x

        mid_output = self.unet.mid_block(x, t_emb)
        if isinstance(mid_output, tuple):
            x = mid_output[0]
        else:
            x = mid_output

        if len(self.unet.down_blocks) in layer_indices:
            features[len(self.unet.down_blocks)] = x

        return features

In [5]:
def extract_features_by_layer(data_loader, model, layer_indices, device):

    all_features = {layer: [] for layer in layer_indices}
    all_labels = []

    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device)
            batch_size = images.size(0)

            # Use t=0 for clean images
            timesteps = torch.tensor([0] * batch_size, device=device).long()

    
            features = model(images, timesteps, layer_indices)

            for layer_index in layer_indices:

                layer_features = features[layer_index]
                layer_features = layer_features.reshape(batch_size, -1)

                all_features[layer_index].append(layer_features.cpu())

            all_labels.append(labels)

    for layer_index in layer_indices:
        all_features[layer_index] = torch.cat(all_features[layer_index])
    all_labels = torch.cat(all_labels)

    return all_features, all_labels


In [6]:
class MLPClassifier(nn.Module):
    def __init__(self, input_dim, num_classes,dropout=0.5):
        super(MLPClassifier, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        return self.classifier(x)

In [7]:
def train_classifier(features, labels, input_dim, device, num_classes=10, epochs=10, batch_size=128, learning_rate=0.001,dropout=0.5):


    model = MLPClassifier(input_dim=input_dim, num_classes=num_classes,dropout=dropout).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    dataset = torch.utils.data.TensorDataset(features, labels)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model.train()
    for epoch in range(epochs):
        running_loss = 0.0

        # Loop over batches
        for inputs, lbls in loader:
            inputs, lbls = inputs.to(device), lbls.to(device)  # Move data to the specified device

            optimizer.zero_grad() 

            outputs = model(inputs)
            loss = criterion(outputs, lbls) 
            loss.backward()
            optimizer.step() 

            running_loss += loss.item() * inputs.size(0)

        epoch_loss = running_loss / len(dataset)
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}')

    return model

In [8]:
def evaluate_classifier(model, test_features, test_labels, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for i in range(0, len(test_features), 128):
            inputs = test_features[i:i+128].to(device)
            lbls = test_labels[i:i+128].to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += lbls.size(0)
            correct += (predicted == lbls).sum().item()

    accuracy = 100 * correct / total
    return accuracy

In [9]:
def extract_cnn_features(data_loader, model):
    features = []
    labels = []
    with torch.no_grad():
        for images, lbls in data_loader:
            images = images.to(device)
            # Extract features from ResNet
            outputs = model(images).view(images.size(0), -1)  # Flatten output from ResNet
            features.append(outputs.cpu())
            labels.append(lbls)
    features = torch.cat(features)
    labels = torch.cat(labels)
    return features, labels

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the pre-trained DDPM model from Hugging Face
model_name = "google/ddpm-cifar10-32"
ddpm_pipeline = DDPMPipeline.from_pretrained(model_name)
ddpm_pipeline.to(device)
unet = ddpm_pipeline.unet
unet.eval()

Loading pipeline components...:   0%|          | 0/2 [00:00<?, ?it/s]

UNet2DModel(
  (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=128, out_features=512, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=512, out_features=512, bias=True)
  )
  (down_blocks): ModuleList(
    (0): DownBlock2D(
      (resnets): ModuleList(
        (0-1): 2 x ResnetBlock2D(
          (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=512, out_features=128, bias=True)
          (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
      (downsamplers): ModuleList(
        (0): Downsample2D(
          (conv): Conv2d(12

In [11]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize images to [-1, 1]
])

# Training data
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# Test data
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
batch_size = 128

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
layer_indices = [3] 


ddpm_with_features = DDPMWithFeatures(ddpm_pipeline).to(device)

diffusion_train_features, diffusion_train_labels = extract_features_by_layer(
    train_loader, ddpm_with_features, layer_indices, device
)

diffusion_test_features, diffusion_test_labels = extract_features_by_layer(
    test_loader, ddpm_with_features, layer_indices, device
)


Files already downloaded and verified
Files already downloaded and verified


In [12]:
cnn_model = models.resnext50_32x4d(pretrained=True) 
cnn_model = nn.Sequential(*list(cnn_model.children())[:-1]) 
cnn_model = cnn_model.to(device)
cnn_model.eval()

cnn_train_features, cnn_train_labels = extract_cnn_features(train_loader, cnn_model)
cnn_test_features, cnn_test_labels = extract_cnn_features(test_loader, cnn_model)



In [13]:
hybrid_train_features = torch.cat((diffusion_train_features[3],cnn_train_features),-1)
hybrid_test_features = torch.cat((diffusion_test_features[3],cnn_test_features),-1)

In [16]:
input_dim = hybrid_train_features.shape[1]
classifier = train_classifier(hybrid_train_features, diffusion_train_labels, input_dim, device)
print(f"Evaluating classifier for Hybrid Classifier")
accuracy = evaluate_classifier(classifier, hybrid_test_features, diffusion_test_labels, device)
print(f"Hybrid Model Accuracy: {accuracy:.2f}%")

Epoch [1/10], Loss: 1.2417
Epoch [2/10], Loss: 0.8085
Epoch [3/10], Loss: 0.7196
Epoch [4/10], Loss: 0.6701
Epoch [5/10], Loss: 0.6329
Epoch [6/10], Loss: 0.5958
Epoch [7/10], Loss: 0.5805
Epoch [8/10], Loss: 0.5775
Epoch [9/10], Loss: 0.5517
Epoch [10/10], Loss: 0.5292
Evaluating classifier for Hybrid Classifier
Hybrid Model Accuracy: 79.34%
