In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
from diffusers import DDPMPipeline
import numpy as np
import random

In [3]:
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

set_seed(42)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize images to [-1, 1]
])

In [6]:
# Training data
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# Test data
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
batch_size = 128

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Load The Pre-trained Diffusion Model

In [9]:
model_name = "google/ddpm-cifar10-32"

pipeline = DDPMPipeline.from_pretrained(model_name)
pipeline.to(device)

Loading pipeline components...:   0%|          | 0/2 [00:00<?, ?it/s]

DDPMPipeline {
  "_class_name": "DDPMPipeline",
  "_diffusers_version": "0.30.3",
  "_name_or_path": "google/ddpm-cifar10-32",
  "scheduler": [
    "diffusers",
    "DDPMScheduler"
  ],
  "unet": [
    "diffusers",
    "UNet2DModel"
  ]
}

In [10]:
unet = pipeline.unet
unet.eval()

UNet2DModel(
  (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=128, out_features=512, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=512, out_features=512, bias=True)
  )
  (down_blocks): ModuleList(
    (0): DownBlock2D(
      (resnets): ModuleList(
        (0-1): 2 x ResnetBlock2D(
          (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=512, out_features=128, bias=True)
          (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
      (downsamplers): ModuleList(
        (0): Downsample2D(
          (conv): Conv2d(12

In [11]:
class UNetWithFeatures(nn.Module):
    def __init__(self, unet):
        super(UNetWithFeatures, self).__init__()
        self.unet = unet

    def forward(self, sample, timestep):
        x = sample

        # Embed the timestep
        t_emb = self.unet.time_proj(timestep)
        t_emb = self.unet.time_embedding(t_emb)

        # Initial convolution
        x = self.unet.conv_in(x)

        features = []

        # Downsampling blocks
        for down_block in self.unet.down_blocks:
            x_tuple = down_block(x, t_emb)
            print(f"Down block output type: {type(x_tuple)}")
            x, res_samples = x_tuple
            features.append(x)

        # Middle block
        x_tuple = self.unet.mid_block(x, t_emb)
        print(f"Mid block output type: {type(x_tuple)}")
        if isinstance(x_tuple, tuple):
            x = x_tuple[0]
        else:
            x = x_tuple
        features.append(x)

        return features  # Return the list of features


In [12]:
unet_with_features = UNetWithFeatures(unet).to(device)
unet_with_features.eval()

UNetWithFeatures(
  (unet): UNet2DModel(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (time_proj): Timesteps()
    (time_embedding): TimestepEmbedding(
      (linear_1): Linear(in_features=128, out_features=512, bias=True)
      (act): SiLU()
      (linear_2): Linear(in_features=512, out_features=512, bias=True)
    )
    (down_blocks): ModuleList(
      (0): DownBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (time_emb_proj): Linear(in_features=512, out_features=128, bias=True)
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsampler

In [13]:
def extract_features(data_loader):
    all_features = []
    all_labels = []
    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device)
            batch_size = images.size(0)

            # Use t=0 for clean images
            timesteps = torch.tensor([0]*batch_size, device=device).long()

            # Extract features
            features = unet_with_features(images, timesteps)

            # Use features from the last layer
            last_features = features[-1]

            # If last_features is a tuple, extract the tensor
            if isinstance(last_features, tuple):
                last_features = last_features[0]

            # Flatten the features using reshape
            last_features = last_features.reshape(batch_size, -1)

            all_features.append(last_features.cpu())
            all_labels.append(labels)
    all_features = torch.cat(all_features)
    all_labels = torch.cat(all_labels)
    return all_features, all_labels


In [14]:
train_features, train_labels = extract_features(train_loader)
test_features, test_labels = extract_features(test_loader)


Down block output type: <class 'tuple'>
Down block output type: <class 'tuple'>
Down block output type: <class 'tuple'>
Down block output type: <class 'tuple'>
Mid block output type: <class 'torch.Tensor'>
Down block output type: <class 'tuple'>
Down block output type: <class 'tuple'>
Down block output type: <class 'tuple'>
Down block output type: <class 'tuple'>
Mid block output type: <class 'torch.Tensor'>
Down block output type: <class 'tuple'>
Down block output type: <class 'tuple'>
Down block output type: <class 'tuple'>
Down block output type: <class 'tuple'>
Mid block output type: <class 'torch.Tensor'>
Down block output type: <class 'tuple'>
Down block output type: <class 'tuple'>
Down block output type: <class 'tuple'>
Down block output type: <class 'tuple'>
Mid block output type: <class 'torch.Tensor'>
Down block output type: <class 'tuple'>
Down block output type: <class 'tuple'>
Down block output type: <class 'tuple'>
Down block output type: <class 'tuple'>
Mid block output

In [15]:
class MLPClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(MLPClassifier, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        return self.classifier(x)


In [16]:
input_dim = train_features.shape[1]
num_classes = 10  # CIFAR-10 has 10 classes

classifier = MLPClassifier(input_dim, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=0.001)


In [17]:
batch_size = 128

train_dataset_features = TensorDataset(train_features, train_labels)
train_loader_features = DataLoader(train_dataset_features, batch_size=batch_size, shuffle=True)

test_dataset_features = TensorDataset(test_features, test_labels)
test_loader_features = DataLoader(test_dataset_features, batch_size=batch_size, shuffle=False)


In [18]:
num_epochs = 10

for epoch in range(num_epochs):
    classifier.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for features, labels in train_loader_features:
        features = features.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = classifier(features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item() * features.size(0)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(train_dataset_features)
    epoch_acc = 100 * correct / total
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')


Epoch [1/10], Loss: 1.1415, Accuracy: 67.66%
Epoch [2/10], Loss: 0.7222, Accuracy: 76.09%
Epoch [3/10], Loss: 0.6545, Accuracy: 78.48%
Epoch [4/10], Loss: 0.5938, Accuracy: 80.33%
Epoch [5/10], Loss: 0.5749, Accuracy: 81.43%
Epoch [6/10], Loss: 0.5560, Accuracy: 82.10%
Epoch [7/10], Loss: 0.5184, Accuracy: 83.07%
Epoch [8/10], Loss: 0.5199, Accuracy: 83.68%
Epoch [9/10], Loss: 0.5034, Accuracy: 84.00%
Epoch [10/10], Loss: 0.5058, Accuracy: 84.09%


In [19]:
classifier.eval()
correct = 0
total = 0

with torch.no_grad():
    for features, labels in test_loader_features:
        features = features.to(device)
        labels = labels.to(device)
        outputs = classifier(features)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Classifier Test Accuracy: {accuracy:.2f}%')


Classifier Test Accuracy: 78.26%


# Baseline Model

In [21]:
class MLPClassifier_baseline(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()  # Correct usage
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the input
        return self.classifier(x)

In [22]:
input_dim = 3 * 32 * 32  # CIFAR-10 images have 3 channels and are 32x32 pixels
num_classes = 10

baseline_classifier = MLPClassifier_baseline(input_dim, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(baseline_classifier.parameters(), lr=0.001)

In [23]:
num_epochs = 10

for epoch in range(num_epochs):
    baseline_classifier.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = baseline_classifier(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = 100 * correct / total
    print(f'Baseline Model - Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')


Baseline Model - Epoch [1/10], Loss: 1.7856, Accuracy: 38.08%
Baseline Model - Epoch [2/10], Loss: 1.6288, Accuracy: 43.36%
Baseline Model - Epoch [3/10], Loss: 1.5823, Accuracy: 44.96%
Baseline Model - Epoch [4/10], Loss: 1.5522, Accuracy: 46.23%
Baseline Model - Epoch [5/10], Loss: 1.5190, Accuracy: 47.44%
Baseline Model - Epoch [6/10], Loss: 1.4996, Accuracy: 48.28%
Baseline Model - Epoch [7/10], Loss: 1.4729, Accuracy: 49.24%
Baseline Model - Epoch [8/10], Loss: 1.4517, Accuracy: 50.06%
Baseline Model - Epoch [9/10], Loss: 1.4373, Accuracy: 50.45%
Baseline Model - Epoch [10/10], Loss: 1.4170, Accuracy: 51.45%


In [24]:
baseline_classifier.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = baseline_classifier(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

baseline_accuracy = 100 * correct / total
print(f'Baseline Classifier Test Accuracy: {baseline_accuracy:.2f}%')

Baseline Classifier Test Accuracy: 52.20%


In [25]:
# Assuming 'accuracy' is the test accuracy from the classifier using latent features
latent_features_accuracy = 78.25  # Replace with your actual value if different
print(f'Classifier Test Accuracy using Latent Features: {accuracy:.2f}%')
print(f'Baseline Classifier Test Accuracy: {baseline_accuracy:.2f}%')

# Calculate the difference
accuracy_difference = accuracy - baseline_accuracy
print(f'Difference in Accuracy: {accuracy_difference:.2f}%')

Classifier Test Accuracy using Latent Features: 78.26%
Baseline Classifier Test Accuracy: 52.20%
Difference in Accuracy: 26.06%
