In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os

# --- 1. Define Model Architecture ---

# Let's use a simple CNN suitable for MNIST
class SimpleCNN(nn.Module):
    """Small convolutional classifier for 1x32x32 grayscale digit images.

    Pipeline: conv(1->32) -> ReLU -> 2x2 max-pool -> conv(32->64) -> ReLU ->
    2x2 max-pool -> linear(4096->128) -> ReLU -> linear(128->10).

    The attribute names (conv1, conv2, fc1, fc2) become the keys of the
    saved ``state_dict``, so they must stay as they are for checkpoints to load.
    """

    def __init__(self):
        super(SimpleCNN, self).__init__()
        # Input: 1x32x32 (Grayscale channel, Height, Width)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)  # 32x32 -> 32x32
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)                 # 32x32 -> 16x16
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) # 16x16 -> 16x16
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)                 # 16x16 -> 8x8
        # Flatten: 64 * 8 * 8 = 4096
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)  # 10 output classes for digits 0-9

    def forward(self, x):
        """Return raw class scores (logits), one row of 10 values per image.

        Args:
            x: Image tensor whose trailing dimensions are (1, 32, 32).

        Returns:
            Tensor of shape (num_images, 10) holding unnormalised logits.

        Raises:
            ValueError: If the pooled feature map does not hold exactly
                ``fc1.in_features`` values per image (i.e. the input was
                not 32x32).
        """
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))

        # A bare view(-1, 4096) would happily reinterpret a wrong-sized input
        # (e.g. 64x64) as extra batch rows; fail loudly instead.
        features_per_image = x.shape[-3:].numel()
        if features_per_image != self.fc1.in_features:
            raise ValueError(
                f"SimpleCNN expects 1x32x32 inputs ({self.fc1.in_features} features "
                f"after pooling) but got {features_per_image} features per image "
                f"from a feature map of shape {tuple(x.shape)}"
            )

        x = x.view(-1, self.fc1.in_features)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = self.fc2(x)  # Output raw scores (logits)
        return x

# --- Configuration ---
batch_size = 64
epochs = 10
learning_rate = 0.001
seed = 42  # fixed so weight initialisation and batch shuffling repeat across runs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed the global generator BEFORE the model and the shuffling DataLoader are
# created; both draw from it.
torch.manual_seed(seed)

# --- 2 & 3. Load MNIST and Apply Transformations ---
transform_32x32 = transforms.Compose([
    transforms.Pad(2),  # Pad the 28x28 image to 32x32
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std dev
])

# Both splits share the same preprocessing so train and test inputs match.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform_32x32)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform_32x32)

# Only the training split is shuffled; evaluation order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# --- Initialize Model, Loss, Optimizer ---
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()  # default reduction: mean over the batch
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# --- 4. Training Loop ---
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and
    ``device``. A progress line with the current batch loss is printed
    every 100 batches.

    Args:
        epoch: Epoch number, used only in the progress message.
    """
    model.train()
    num_batches = len(train_loader)
    num_samples = len(train_loader.dataset)

    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        if step % 100 != 0:
            continue

        seen = step * len(inputs)
        percent = 100. * step / num_batches
        print(f'Train Epoch: {epoch} [{seen}/{num_samples} '
              f'({percent:.0f}%)]\tLoss: {loss.item():.6f}')

# --- Evaluation Loop ---
def test():
    """Evaluate the module-level ``model`` on ``test_loader``.

    Prints the per-sample average loss and the accuracy, and returns the
    accuracy so the caller can track the best epoch.

    Returns:
        Accuracy on the test set as a percentage (0-100).
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    num_samples = len(test_loader.dataset)

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # criterion returns the MEAN loss of the batch; weight it by the
            # batch size so that dividing by the sample count below yields a
            # true per-sample average (the last batch may be smaller).
            test_loss += criterion(output, target).item() * data.size(0)
            pred = output.argmax(dim=1)  # index of the largest logit = predicted class
            correct += (pred == target).sum().item()

    test_loss /= num_samples
    accuracy = 100. * correct / num_samples

    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{num_samples} '
          f'({accuracy:.2f}%)\n')
    return accuracy

# --- Run Training and Testing ---
model_save_path = 'mnist_cnn_32x32.pth'
best_accuracy = 0.0

for epoch in range(1, epochs + 1):
    train(epoch)
    current_accuracy = test()

    # --- 5. Save the Best Model ---
    # Epochs that do not beat the running best leave the checkpoint untouched.
    if current_accuracy <= best_accuracy:
        continue

    best_accuracy = current_accuracy
    torch.save(model.state_dict(), model_save_path)
    print(f"Saved new best model to {model_save_path} with accuracy {best_accuracy:.2f}%")

print("Training finished.")
print(f"Best model saved to {model_save_path} with accuracy {best_accuracy:.2f}%")

# --- How to Load and Use the "Pretrained" Model Later ---
print("\n--- Loading and Using the Trained Model ---")

# Fresh instance that will receive the checkpointed weights
loaded_model = SimpleCNN().to(device)

if not os.path.exists(model_save_path):
    print(f"Model file not found at {model_save_path}. Train the model first.")
else:
    # Restore the saved weights and switch to evaluation mode
    loaded_model.load_state_dict(torch.load(model_save_path, map_location=device))
    loaded_model.eval()
    print(f"Model loaded successfully from {model_save_path}")

    # Example: a random 1x1x32x32 tensor standing in for a real image
    dummy_input = torch.randn(1, 1, 32, 32).to(device)

    with torch.no_grad():
        output = loaded_model(dummy_input)
        probabilities = F.softmax(output, dim=1)
        predicted_class = probabilities.argmax(dim=1)

    print(f"Dummy input prediction: Class {predicted_class.item()}")
    print(f"Dummy input probabilities: {probabilities.cpu().numpy()}")


In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import numpy as np
from scipy.stats import entropy
from tqdm import tqdm  # Optional: for progress bar
import sys  # For exiting on error

# --- Configuration (Hardcoded) ---
# Settings for the Inception-Score evaluation below; edit them here in place.
IMAGE_DIR = './DDPM_MNIST_Noise_Samples'  # Directory containing the generated images
MODEL_PATH = 'mnist_cnn_32x32.pth'      # Path to the trained classifier (state_dict loaded into SimpleCNN)
BATCH_SIZE = 128                         # Images per inference batch
NUM_WORKERS = 1                          # DataLoader workers (adjust based on system)

# --- 1. Define Model Architecture (Must match the saved model) ---
class SimpleCNN(nn.Module):
    """Convolutional classifier for 1x32x32 grayscale digit images.

    The layer names and shapes must match the checkpoint being loaded,
    because the ``state_dict`` keys are derived from the attribute names.
    """

    def __init__(self):
        super(SimpleCNN, self).__init__()
        # Input: 1x32x32 (Grayscale channel, Height, Width)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # 32x32 -> 16x16
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) # 16x16 -> 16x16
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) # 16x16 -> 8x8
        # Flatten: 64 * 8 * 8 = 4096
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10) # 10 output classes for digits 0-9

    def forward(self, x):
        """Return raw class scores (logits), one row of 10 values per image.

        Args:
            x: Image tensor whose trailing dimensions are (1, 32, 32).

        Returns:
            Tensor of shape (num_images, 10) holding unnormalised logits.

        Raises:
            ValueError: If the pooled feature map does not hold exactly
                ``fc1.in_features`` values per image (i.e. the input was
                not 32x32).
        """
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))

        # view(-1, 4096) alone would silently turn a wrong-sized image into
        # additional batch rows, so verify the per-image size first.
        features_per_image = x.shape[-3:].numel()
        if features_per_image != self.fc1.in_features:
            raise ValueError(
                f"SimpleCNN expects 1x32x32 inputs ({self.fc1.in_features} features "
                f"after pooling) but got {features_per_image} features per image "
                f"from a feature map of shape {tuple(x.shape)}"
            )

        x = x.view(-1, self.fc1.in_features) # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = self.fc2(x) # Output raw scores (logits)
        return x

# --- 2. Custom Dataset for Generated Images ---
class GeneratedImageDataset(Dataset):
    """Unlabelled dataset yielding every image file found directly in a folder.

    Args:
        root_dir: Directory to scan (not recursive).
        transform: Optional callable applied to each grayscale PIL image.

    Raises:
        FileNotFoundError: From ``os.listdir`` if ``root_dir`` does not exist.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # Keep only visible regular files: subdirectories and hidden entries
        # such as .ipynb_checkpoints or .DS_Store cannot be opened as images.
        # Sorting makes the sample order independent of the filesystem.
        self.image_files = sorted(
            name for name in os.listdir(root_dir)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(root_dir, name))
        )
        print(f"Found {len(self.image_files)} images in {root_dir}")

    def __len__(self):
        """Number of image files discovered in ``root_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as single-channel grayscale and apply the transform."""
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        # convert() returns an independent, fully loaded image, so the file
        # handle can be released as soon as the with-block ends.
        with Image.open(img_path) as img:
            image = img.convert('L')
        if self.transform:
            image = self.transform(image)
        return image



def calculate_inception_score(preds):
    """Compute the single-split Inception Score exp(mean_x KL(p(y|x) || p(y))).

    Args:
        preds: Array of shape (N, num_classes); row i holds the class
            probabilities p(y|x_i) predicted for sample i.

    Returns:
        The score as a float. Returns 0.0 when ``preds`` contains no rows.
    """
    # float64 keeps the log/sum arithmetic accurate even for float32 softmax output.
    preds = np.asarray(preds, dtype=np.float64)
    if preds.shape[0] == 0:
        # Scalar, like the normal path, so callers can always format the result.
        return 0.0

    # Clip away exact zeros so the logarithms below are finite, then
    # renormalise each conditional and the marginal to sum to one.
    preds = np.clip(preds, 1e-9, 1.0)
    p_yx = preds / preds.sum(axis=1, keepdims=True)
    p_y = preds.mean(axis=0)
    p_y = p_y / p_y.sum()

    # KL(p(y|x) || p(y)) for every sample at once.
    kl_divs = np.sum(p_yx * (np.log(p_yx) - np.log(p_y)), axis=1)
    mean_kl_div = float(np.mean(kl_divs))

    # Cap the exponent to prevent overflow in exp().
    return float(np.exp(min(mean_kl_div, 700.0)))

# --- Main Execution ---
# Scores the images in IMAGE_DIR with the classifier stored at MODEL_PATH:
# load the weights, run batched inference, then compute the Inception Score
# from the softmax outputs.
if __name__ == "__main__":

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Load Model ---
    # Fail with a clear message (as for the dataset below) rather than a traceback.
    if not os.path.isfile(MODEL_PATH):
        print(f"Error: Classifier weights not found at {MODEL_PATH}. Train the model first.", file=sys.stderr)
        sys.exit(1)

    model = SimpleCNN().to(device) # Use the MNIST SimpleCNN
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval() # Set to evaluation mode
    print(f"Classifier model loaded from {MODEL_PATH} (SimpleCNN for MNIST)")

    # Same normalisation constants the classifier was trained with; no padding
    # is applied, so the files are expected to be 32x32 already.
    transform_generated = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # --- Create DataLoader ---
    try:
        dataset = GeneratedImageDataset(root_dir=IMAGE_DIR, transform=transform_generated)
    except (FileNotFoundError, ValueError) as e:
        print(f"Error creating dataset: {e}", file=sys.stderr)
        sys.exit(1)

    if len(dataset) == 0:
        print(f"Error: No images loaded from directory {IMAGE_DIR}. Cannot calculate IS.", file=sys.stderr)
        sys.exit(1)

    dataloader = DataLoader(dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=False,
                            num_workers=NUM_WORKERS,
                            pin_memory=device.type == 'cuda')

    # --- Run Inference and Get Predictions ---
    all_preds = []
    print("Running inference on generated images using SimpleCNN (MNIST)...")
    with torch.no_grad():
        for images in tqdm(dataloader, desc="Inference"):
            # Skip dummy tensors (batches that are entirely zero)
            if torch.equal(images, torch.zeros_like(images)):
                continue

            images = images.to(device, non_blocking=True)

            # Skip malformed batches (expected [B, 1, 32, 32])
            if images.dim() != 4 or tuple(images.shape[1:]) != (1, 32, 32):
                print(f"Warning: Skipping batch with incorrect shape on device: {images.shape}. Expected [B, 1, 32, 32]", file=sys.stderr)
                continue

            probabilities = F.softmax(model(images), dim=1)
            all_preds.append(probabilities.cpu().numpy())

    if not all_preds:
        print("Error: No valid predictions generated after inference.", file=sys.stderr)
        sys.exit(1)

    predictions_np = np.concatenate(all_preds, axis=0)

    # Make it visible when skipped batches mean the score covers only a subset.
    if predictions_np.shape[0] != len(dataset):
        print(f"Warning: Only {predictions_np.shape[0]} of {len(dataset)} images were scored.", file=sys.stderr)

    # Calculate and Print Inception Score
    print("Calculating Inception Score (Single Split) using SimpleCNN (MNIST)...")
    is_mean = calculate_inception_score(predictions_np)

    print(f"Dataset:      {IMAGE_DIR}")
    print(f"Model Path:   {MODEL_PATH}")
    print(f"Mean IS:      {is_mean:.4f}")

Classifier model loaded from mnist_cnn_32x32.pth (SimpleCNN for MNIST)
Found 8192 images in ./DDPM_MNIST_Noise_Samples
Running inference on generated images using SimpleCNN (MNIST)...


Inference: 100%|██████████| 64/64 [00:02<00:00, 31.76it/s]


Calculating Inception Score (Single Split) using SimpleCNN (MNIST)...
Dataset:      ./DDPM_MNIST_Noise_Samples
Model Path:   mnist_cnn_32x32.pth
Mean IS:      9.2236
