In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os

class ComplexCNN(nn.Module):
    """Six-stage convolutional classifier that maps 3-channel images to 10 logits.

    Each stage is two 3x3 convolutions (padding 1), each followed by batch
    normalization and ReLU. Stages 1-4 finish with a 2x2 max pool of stride 2
    (32x32 -> 16x16 -> 8x8 -> 4x4 -> 2x2 for a 32x32 input); stages 5 and 6
    keep the spatial size. The result is global-average-pooled to a
    512-vector and passed through a 512 -> 512 -> 10 head with batch norm,
    ReLU and dropout.

    Args:
        dropout_rate: Drop probability used before the final linear layer.
    """

    # (in_channels, out_channels, ends_with_pool) for stages 1..6.
    _STAGES = (
        (3, 64, True),
        (64, 128, True),
        (128, 256, True),
        (256, 512, True),
        (512, 512, False),
        (512, 512, False),
    )

    def __init__(self, dropout_rate=0.5):
        super().__init__()

        # Layers are registered as conv{i}a / bn{i}a / conv{i}b / bn{i}b /
        # pool{i}, in that order, so the parameter names in state_dict stay
        # "conv1a.weight", "bn1a.running_mean", ... and saved checkpoints
        # keep loading.
        for stage, (c_in, c_out, pooled) in enumerate(self._STAGES, start=1):
            setattr(self, f"conv{stage}a", nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, f"bn{stage}a", nn.BatchNorm2d(c_out))
            setattr(self, f"conv{stage}b", nn.Conv2d(c_out, c_out, kernel_size=3, padding=1))
            setattr(self, f"bn{stage}b", nn.BatchNorm2d(c_out))
            if pooled:
                setattr(self, f"pool{stage}", nn.MaxPool2d(kernel_size=2, stride=2))

        # Global average pooling collapses each 512-channel map to one value.
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.flattened_size = 512

        # Classifier head.
        self.fc1 = nn.Linear(self.flattened_size, 512)
        self.bn_fc1 = nn.BatchNorm1d(512)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(512, 10)

    def _forward_features(self, x):
        """Run the six convolutional stages and return the final feature map."""
        for stage, (_, _, pooled) in enumerate(self._STAGES, start=1):
            for half in "ab":
                conv = getattr(self, f"conv{stage}{half}")
                norm = getattr(self, f"bn{stage}{half}")
                x = F.relu(norm(conv(x)))
            if pooled:
                x = getattr(self, f"pool{stage}")(x)
        return x

    def forward(self, x):
        """Return the 10 class logits (no softmax applied) for a batch ``x``."""
        pooled = self.gap(self._forward_features(x))
        flat = pooled.view(-1, self.flattened_size)
        hidden = F.relu(self.bn_fc1(self.fc1(flat)))
        return self.fc2(self.dropout(hidden))

# --- Configuration ---
# Training hyperparameters; `epochs` also sets the length of the cosine
# learning-rate schedule created further down.
batch_size = 128
epochs = 100
learning_rate = 0.001
weight_decay = 5e-4
dropout_rate = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Per-channel mean / std handed to transforms.Normalize for both splits.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2023, 0.1994, 0.2010)

# Tensor conversion + normalization shared by the train and test pipelines.
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
]

# Training additionally applies random-crop (4 px padding) and
# horizontal-flip augmentation before the shared steps.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    *_to_normalized_tensor,
])
transform_test = transforms.Compose(list(_to_normalized_tensor))

# Both splits live under ./data and are downloaded on first use.
_cifar_kwargs = {"root": './data', "download": True}
train_dataset = datasets.CIFAR10(train=True, transform=transform_train, **_cifar_kwargs)
test_dataset = datasets.CIFAR10(train=False, transform=transform_test, **_cifar_kwargs)

# Adjust num_workers based on your system
_loader_kwargs = {"num_workers": 4, "pin_memory": True}
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **_loader_kwargs)
# Can use larger batch for testing
test_loader = DataLoader(test_dataset, batch_size=batch_size * 2, shuffle=False, **_loader_kwargs)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# --- Initialize Model, Loss, Optimizer ---
model = ComplexCNN(dropout_rate=dropout_rate).to(device)
criterion = nn.CrossEntropyLoss()

# Adam with weight decay passed through the optimizer's own argument.
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Cosine annealing from `learning_rate` down to 1e-6 over T_max scheduler
# steps; keep T_max in sync with `epochs` if that value changes.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)

# --- Training Loop ---
def train(epoch):
    """Run one training epoch over ``train_loader`` and step the LR scheduler.

    Works on the module-level ``model``, ``train_loader``, ``optimizer``,
    ``criterion``, ``device`` and ``scheduler``. Every 100 batches a progress
    line is printed with the mean loss of those 100 batches and the accuracy
    accumulated since the start of the epoch.

    Args:
        epoch: Epoch number; used only in the progress message.
    """
    log_interval = 100
    num_batches = len(train_loader)
    num_samples = len(train_loader.dataset)

    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    # Count batches from 1 so the progress figures describe the work that has
    # actually been completed when the line is printed.
    for batch_idx, (data, target) in enumerate(train_loader, start=1):
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

        if batch_idx % log_interval == 0:
            running_acc = 100. * correct / total
            # `total` is the exact number of samples seen so far, which stays
            # correct even when the last batch is smaller than the others.
            print(f'Train Epoch: {epoch} [{total}/{num_samples} ({100. * batch_idx / num_batches:.0f}%)]\tLoss: {running_loss / log_interval:.6f}\tAcc: {running_acc:.2f}%')
            running_loss = 0.0

    # The scheduler is stepped once per epoch; a value of None disables it.
    if scheduler is not None:
        scheduler.step()

# --- Evaluation Loop ---
def test():
    """Evaluate the module-level ``model`` on ``test_loader``.

    Puts the model in eval mode, runs every batch without gradient tracking,
    prints a one-line summary, and reports the results.

    Returns:
        A tuple ``(accuracy, test_loss)``: accuracy as a percentage of the
        whole dataset, and the criterion loss averaged per sample.
    """
    model.eval()
    num_samples = len(test_loader.dataset)
    loss_sum = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            logits = model(data)
            # The criterion value is scaled by the batch size so that the
            # division below yields a per-sample figure (assumes the
            # criterion averages over the batch, as nn.CrossEntropyLoss
            # does by default).
            loss_sum += criterion(logits, target).item() * data.size(0)
            correct += (logits.argmax(dim=1) == target).sum().item()

    test_loss = loss_sum / num_samples
    accuracy = 100. * correct / num_samples
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{num_samples} ({accuracy:.2f}%)\n')
    return accuracy, test_loss

# --- Run Training and Testing ---
# Highest test accuracy seen so far; a checkpoint is written only when it improves.
best_accuracy = 0.0
# Checkpoint file that receives the best-scoring weights.
model_save_path = 'cifar_10_cnn.pth'

print(f"Starting training for {epochs} epochs...")
for epoch in range(1, epochs + 1):
    train(epoch)
    current_accuracy, current_loss = test()

    # train() has already stepped the scheduler, so this is the learning
    # rate the next epoch will start with.
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch}: Test Acc = {current_accuracy:.2f}%, Test Loss = {current_loss:.4f}, LR = {current_lr:.6f}")

    # Save the Best Model (state_dict only; strict improvement required).
    improved = current_accuracy > best_accuracy
    if improved:
        best_accuracy = current_accuracy
        torch.save(model.state_dict(), model_save_path)
        print(f"----> Saved new best model to {model_save_path} with accuracy {best_accuracy:.2f}% <----")

print("Training finished.")
print(f"Best model saved to {model_save_path} with accuracy {best_accuracy:.2f}%")


In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import numpy as np
from scipy.stats import entropy
from tqdm import tqdm
import sys

# --- Configuration ---
IMAGE_DIR = './DDPM_CIFAR10_Clean' 
MODEL_PATH = 'cifar_10_cnn.pth'    
BATCH_SIZE = 128                   
NUM_WORKERS = 4

# --- 1. Define Model Architecture (MUST MATCH the saved cifar_10_cnn.pth model) ---

class ComplexCNN(nn.Module):
    """Six-stage CNN classifier used to score generated images.

    Attribute names and creation order mirror the network whose weights are
    stored in the checkpoint, because ``load_state_dict`` matches parameters
    by name. Each stage is two conv3x3 + batch norm + ReLU units; stages 1-4
    end with a 2x2 max pool, stages 5-6 do not. A global average pool and a
    512 -> 512 -> 10 head produce the class logits.

    Args:
        dropout_rate: Drop probability used before the final linear layer.
    """

    @staticmethod
    def _conv_bn(in_channels, out_channels):
        """Build a 3x3 same-padding convolution and its batch norm layer."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        norm = nn.BatchNorm2d(out_channels)
        return conv, norm

    def __init__(self, dropout_rate=0.5):
        super().__init__()

        # Stage 1
        self.conv1a, self.bn1a = self._conv_bn(3, 64)
        self.conv1b, self.bn1b = self._conv_bn(64, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 32x32 -> 16x16

        # Stage 2
        self.conv2a, self.bn2a = self._conv_bn(64, 128)
        self.conv2b, self.bn2b = self._conv_bn(128, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 16x16 -> 8x8

        # Stage 3
        self.conv3a, self.bn3a = self._conv_bn(128, 256)
        self.conv3b, self.bn3b = self._conv_bn(256, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # 8x8 -> 4x4

        # Stage 4
        self.conv4a, self.bn4a = self._conv_bn(256, 512)
        self.conv4b, self.bn4b = self._conv_bn(512, 512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)  # 4x4 -> 2x2

        # Stages 5 and 6 keep the spatial size (no pooling).
        self.conv5a, self.bn5a = self._conv_bn(512, 512)
        self.conv5b, self.bn5b = self._conv_bn(512, 512)
        self.conv6a, self.bn6a = self._conv_bn(512, 512)
        self.conv6b, self.bn6b = self._conv_bn(512, 512)

        # Head
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.flattened_size = 512
        self.fc1 = nn.Linear(self.flattened_size, 512)
        self.bn_fc1 = nn.BatchNorm1d(512)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(512, 10)

    def _forward_features(self, x):
        """Apply the six convolutional stages and return the feature map."""
        stages = (
            ((self.conv1a, self.bn1a), (self.conv1b, self.bn1b), self.pool1),
            ((self.conv2a, self.bn2a), (self.conv2b, self.bn2b), self.pool2),
            ((self.conv3a, self.bn3a), (self.conv3b, self.bn3b), self.pool3),
            ((self.conv4a, self.bn4a), (self.conv4b, self.bn4b), self.pool4),
            ((self.conv5a, self.bn5a), (self.conv5b, self.bn5b), None),
            ((self.conv6a, self.bn6a), (self.conv6b, self.bn6b), None),
        )
        for first, second, pool in stages:
            for conv, norm in (first, second):
                x = F.relu(norm(conv(x)))
            if pool is not None:
                x = pool(x)
        return x

    def forward(self, x):
        """Return the 10 class logits (no softmax applied) for a batch ``x``."""
        x = self._forward_features(x)
        x = self.gap(x)
        x = x.view(-1, self.flattened_size)
        x = self.fc1(x)
        x = self.bn_fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        return self.fc2(x)


class GeneratedImageDataset(Dataset):
    """Unlabelled dataset of the image files stored directly inside a directory.

    Only regular files whose name ends with one of ``IMAGE_EXTENSIONS``
    (case-insensitive) are used, so sub-directories and stray files such as
    notebook checkpoints or ``.DS_Store`` are not mistaken for images. The
    file list is sorted so that sample order does not depend on the
    filesystem.

    Args:
        root_dir: Directory containing the image files.
        transform: Optional callable applied to the 32x32 RGB PIL image.

    Raises:
        FileNotFoundError: If ``root_dir`` is not an existing directory.
        ValueError: If ``root_dir`` contains no file with an image extension.
    """

    # Lower-case filename suffixes that are treated as images.
    IMAGE_EXTENSIONS = (
        '.png', '.jpg', '.jpeg', '.bmp', '.gif',
        '.tif', '.tiff', '.webp', '.ppm',
    )

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        if not os.path.isdir(root_dir):
            raise FileNotFoundError(f"Image directory not found: {root_dir}")
        self.image_files = sorted(
            name for name in os.listdir(root_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(root_dir, name))
        )
        if not self.image_files:
            raise ValueError(f"No valid image files found in directory: {root_dir}")
        print(f"Found {len(self.image_files)} images in {root_dir}")

    def __len__(self):
        """Return the number of image files found at construction time."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as 32x32 RGB and apply ``transform`` if given.

        Returns:
            The transformed image. If the file cannot be read, or the
            transformed tensor is not shaped (3, 32, 32), a zero tensor of
            shape (3, 32, 32) is returned and a message is written to stderr.
            NOTE: such a placeholder is indistinguishable from a real sample
            for whatever consumes the dataset.
        """
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        try:
            # The context manager closes the file as soon as the pixel data
            # has been copied out by convert().
            with Image.open(img_path) as raw:
                image = raw.convert('RGB')
            if image.size != (32, 32):
                image = image.resize((32, 32), Image.Resampling.BILINEAR)
        except Exception as e:
            # Deliberately best-effort: one unreadable file must not abort a
            # whole DataLoader pass, so the failure is logged and replaced.
            print(f"Error loading/processing image {img_path}: {e}", file=sys.stderr)
            return torch.zeros(3, 32, 32)  # Return dummy tensor
        if self.transform:
            image = self.transform(image)
        # Final shape check; only tensors carry a shape (without a transform
        # the sample is still a PIL image and is returned unchanged).
        if isinstance(image, torch.Tensor) and image.shape != (3, 32, 32):
            print(f"Warning: Image {self.image_files[idx]} final shape {image.shape} != (3, 32, 32). Returning zeros.", file=sys.stderr)
            return torch.zeros(3, 32, 32)
        return image

# --- Inception Score Function (Simplified for 1 Split) ---

def calculate_inception_score(preds):
    """Return the single-split Inception Score of a set of class predictions.

    The score is ``exp(mean_x KL(p(y|x) || p(y)))`` where ``p(y|x)`` is one
    row of ``preds`` and ``p(y)`` is the mean of all rows.

    Args:
        preds: Array-like of shape (N, num_classes) holding one probability
            distribution per row.

    Returns:
        The score as a float, or ``0.0`` when ``preds`` has no rows.
    """
    preds = np.asarray(preds, dtype=np.float64)
    if preds.shape[0] == 0:
        # Same scalar type as the normal path so callers can format the result.
        return 0.0

    # Keep every probability strictly positive so the logarithms are finite.
    clipped = np.clip(preds, 1e-9, 1.0)

    # Marginal class distribution p(y), renormalized after clipping.
    p_y = clipped.mean(axis=0)
    p_y = p_y / p_y.sum()

    # Per-sample distributions p(y|x), renormalized after clipping.
    p_yx = clipped / clipped.sum(axis=1, keepdims=True)

    # KL(p(y|x) || p(y)) = sum_y p(y|x) * log(p(y|x) / p(y)), one value per row.
    kl_divs = np.sum(p_yx * (np.log(p_yx) - np.log(p_y)), axis=1)

    return float(np.exp(kl_divs.mean()))

# --- Main Execution ---
if __name__ == "__main__":
    # Fail fast with a readable message when either input is missing.
    if not os.path.isdir(IMAGE_DIR):
        print(f"Error: Image directory '{IMAGE_DIR}' not found.", file=sys.stderr)
        sys.exit(1)
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model file '{MODEL_PATH}' not found. Train the model first.", file=sys.stderr)
        sys.exit(1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load Model
    model = ComplexCNN().to(device)

    # The checkpoint is a plain state_dict, so restrict unpickling to tensors
    # and basic containers instead of arbitrary Python objects.
    state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    # Inference mode: batch norm uses its running statistics, dropout is off.
    model.eval()
    print(f"Classifier model loaded from {MODEL_PATH} (6-Block CNN)")

    # Define Transformations (same normalization constants as used for training)
    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2023, 0.1994, 0.2010)
    transform_generated = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std)
    ])

    # Create DataLoader
    dataset = GeneratedImageDataset(root_dir=IMAGE_DIR, transform=transform_generated)

    # Pinned host memory is only useful when batches are copied to a GPU.
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS, pin_memory=device.type == 'cuda')

    # Run Inference: collect the softmax class probabilities of every batch.
    all_preds = []
    print("Running inference on generated images using 6-Block CNN...")
    with torch.no_grad():
        for images in tqdm(dataloader, desc="Inference"):
            images = images.to(device, non_blocking=True)

            outputs = model(images)
            probabilities = F.softmax(outputs, dim=1)

            all_preds.append(probabilities.cpu().numpy())

    if not all_preds:
        print("Error: No valid predictions generated after inference.", file=sys.stderr)
        sys.exit(1)

    # Every batch produced by the DataLoader holds at least one sample, so
    # the per-batch arrays can be stacked directly.
    predictions_np = np.concatenate(all_preds, axis=0)

    print(f"Successfully processed {predictions_np.shape[0]} valid images for IS calculation.")
    if predictions_np.shape[0] < len(dataset.image_files):
        print(f"Warning: Fewer predictions processed ({predictions_np.shape[0]}) than images found ({len(dataset.image_files)}). Some may have failed loading/processing.", file=sys.stderr)

    # Calculate and Print Inception Score
    print("Calculating Inception Score (Single Split) using 6-Block CNN...")
    is_mean = calculate_inception_score(predictions_np)

    print(f"Image Source: {IMAGE_DIR}")
    print(f"Model Path:   {MODEL_PATH}")
    print(f"Mean IS:      {is_mean:.4f}")

Using device: cuda
Classifier model loaded from cifar_10_cnn.pth (6-Block CNN)
Found 5760 images in ./DDPM_CIFAR10_Clean
Running inference on generated images using 6-Block CNN...


Inference: 100%|██████████| 45/45 [00:00<00:00, 58.88it/s]


Successfully processed 5760 valid images for IS calculation.
Calculating Inception Score (Single Split) using 6-Block CNN...
Image Source: ./DDPM_CIFAR10_Clean
Model Path:   cifar_10_cnn.pth
Mean IS:      5.0878
