In [1]:
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
import numpy as np
from scipy.linalg import sqrtm
import tarfile
import os
from PIL import Image
import torch
import torch.nn as nn

In [2]:


class SimpleNN(nn.Module):
    def __init__(self, num_classes):
        super(SimpleNN, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(784, 512),          # First fully connected layer
            nn.BatchNorm1d(512),          # Batch normalization
            nn.ReLU(),                    # Activation
            nn.Dropout(0.3),              # Dropout for regularization

            nn.Linear(512, 256),          # Second fully connected layer
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(256, 128),          # Third fully connected layer
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(128, num_classes)   # Output layer
        )

    def forward(self, x):
        return self.network(x)

    def get_activations(self, x):
        """
        Extract features from the penultimate layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, 784).

        Returns:
            torch.Tensor: Activations from the penultimate layer.
        """
        for layer in self.network[:-1]:  # Iterate through all layers except the last
            x = layer(x)
        return x


In [2]:
def extract_tar_gz(file_path, extract_path):
    with tarfile.open(file_path, 'r:gz') as tar:
        tar.extractall(path=extract_path)
    print(f"Extracted {file_path} to {extract_path}")

In [4]:
def load_images_to_tensors(folder_path):
    transform = transforms.Compose([
    transforms.ToTensor(),  # Converts to tensor and scales to [0, 1]
    ])
    image_tensors = []
    for root, _, files in os.walk(folder_path):
        for file in files:
            if file.lower().endswith(('png', 'jpg', 'jpeg')):  # Check for image files
                image_path = os.path.join(root, file)
                image = Image.open(image_path).convert('RGB')  # Open and convert to RGB
                tensor = transform(image)  # Apply transformations
                image_tensors.append(tensor)
    return torch.stack(image_tensors)  # Stack into a single tensor

In [5]:
def save_model(model, path):
    """
    Saves the trained model to the specified path.

    Args:
        model (nn.Module): The trained neural network.
        path (str): File path to save the model.
    """
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")


In [6]:
def train_model(model, dataloader_train, dataloader_test, criterion, optimizer, num_epochs, device, save_path=None):

    # Metrics to store training and evaluation results
    history = {
        'train_loss': [],
        'test_loss': [],
        'train_accuracy': [],
        'test_accuracy': []
    }

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in dataloader_train:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update metrics
            running_loss += loss.item() * labels.size(0)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        epoch_train_loss = running_loss / total
        epoch_train_accuracy = correct / total

        # Evaluation phase
        model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in dataloader_test:
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                loss = criterion(outputs, labels)

                running_loss += loss.item() * labels.size(0)
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

        epoch_test_loss = running_loss / total
        epoch_test_accuracy = correct / total

        # Store metrics
        history['train_loss'].append(epoch_train_loss)
        history['test_loss'].append(epoch_test_loss)
        history['train_accuracy'].append(epoch_train_accuracy)
        history['test_accuracy'].append(epoch_test_accuracy)

        # Logging
        print(f"Epoch [{epoch+1}/{num_epochs}]"
              f" Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_accuracy:.4f}"
              f" | Test Loss: {epoch_test_loss:.4f}, Test Acc: {epoch_test_accuracy:.4f}")

    # Save the model after training if a save path is provided
    if save_path:
        save_model(model, save_path)

    return history


In [61]:
def inception_score_from_mnist(images, batch_size=32, splits=10):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load pre-trained MNIST classifier
    model = SimpleNN(10)
    model.load_state_dict(torch.load("mnist_final.pth"))
    model.eval()
    model.to(device)

    transform = transforms.Compose([
        transforms.Lambda(lambda x: x + torch.rand_like(x) / 255),  # Dequantize pixel values
        transforms.Lambda(lambda x: (x - 0.5) * 2.0),              # Map from [0,1] -> [-1,1]
        transforms.Lambda(lambda x: x.mean(dim=0).flatten())       # Convert to grayscale and flatten
    ])

    # Preprocess images
    images = images.to(device)
    images = torch.stack([transform(img) for img in images])

    # Create batches
    dataloader = torch.utils.data.DataLoader(images, batch_size=batch_size, shuffle=False)

    preds = []

    # Generate predictions for all images
    with torch.no_grad():
        for batch in dataloader:
            batch_preds = model(batch)  # Get logits
            preds.append(F.softmax(batch_preds, dim=1).cpu().numpy())

    preds = np.concatenate(preds, axis=0)

    # Compute the marginal distribution
    py = np.mean(preds, axis=0)

    # print("Predictions:", preds[:5])
    print("Marginal distribution:", py)
    if np.all(py == 0):
        raise ValueError("Marginal distribution is zero. Check model or input.")

    # Split predictions into chunks for computation
    scores = []
    for i in range(splits):
        part = preds[i * (len(preds) // splits):(i + 1) * (len(preds) // splits), :]
        pyx = np.mean(part, axis=0)
        kl_div = pyx * (np.log(pyx + 1e-10) - np.log(py + 1e-10))  # Add epsilon for numerical stability
        scores.append(np.sum(kl_div))

    # Compute the exponential of the mean KL divergence
    return np.exp(np.mean(scores))


In [8]:
def calculate_fid_from_mnist(real_images, generated_images, batch_size=32):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load pre-trained MNIST classifier
    model = SimpleNN(10)
    model.load_state_dict(torch.load("mnist_final.pth"))
    model.eval()
    model.to(device)

    transform = transforms.Compose([
            transforms.Lambda(lambda x: x + torch.rand_like(x) / 255),  # Dequantize pixel values
            transforms.Lambda(lambda x: (x - 0.5) * 2.0),              # Map from [0,1] -> [-1,1]
            transforms.Lambda(lambda x: x.mean(dim=0).flatten())       # Convert 3-channel to grayscale and flatten
        ])

    def get_activations(images):
        """Get activations from the penultimate layer of the model."""
        images = images.to(device)
        images = torch.stack([transform(img) for img in images])
        dataloader = torch.utils.data.DataLoader(images, batch_size=batch_size, shuffle=False)

        activations = []
        with torch.no_grad():
            for batch in dataloader:
                # Get features from the penultimate layer
                features = model.get_activations(batch)
                activations.append(features.cpu().numpy())

        return np.concatenate(activations, axis=0)

    # Get activations for real and generated images
    real_activations = get_activations(real_images)
    generated_activations = get_activations(generated_images)

    # Compute mean and covariance for real and generated activations
    mu_real = np.mean(real_activations, axis=0)
    sigma_real = np.cov(real_activations, rowvar=False)

    mu_generated = np.mean(generated_activations, axis=0)
    sigma_generated = np.cov(generated_activations, rowvar=False)

    # Compute FID
    diff = mu_real - mu_generated
    covmean, _ = sqrtm(sigma_real @ sigma_generated, disp=False)

    # Numerical stability
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff @ diff + np.trace(sigma_real + sigma_generated - 2 * covmean)
    return fid

In [9]:
batch_size = 128
num_classes = 10
learning_rate = 1e-3
num_epochs = 10


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Lambda(lambda x: x + torch.rand(x.shape)/255),    # Dequantize pixel values
    transforms.Lambda(lambda x: (x-0.5)*2.0),                    # Map from [0,1] -> [-1, -1]
    transforms.Lambda(lambda x: x.flatten())
])

# Download and transform train dataset
dataloader_train = torch.utils.data.DataLoader(datasets.MNIST('./mnist_data', download=True, train=True, transform=transform),
                                                batch_size=batch_size,
                                                shuffle=True)

# Download and transform test dataset
dataloader_test = torch.utils.data.DataLoader(datasets.MNIST('./mnist_data', download=True, train=False, transform=transform),
                                                batch_size=batch_size,
                                                shuffle=True)

In [71]:
model = SimpleNN(num_classes).to(device)

#Setting the loss function
loss_f = nn.CrossEntropyLoss()

#Setting the optimizer with the model parameters and learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


In [72]:
# Define the path to save the model
model_save_path = "mnist_final.pth"

# Call train_model and save the model
history = train_model(
    model=model,
    dataloader_train=dataloader_train,
    dataloader_test=dataloader_test,
    criterion=loss_f,
    optimizer=optimizer,
    num_epochs=num_epochs,
    device=device,
    save_path=model_save_path
)


Epoch [1/10] Train Loss: 0.3021, Train Acc: 0.9187 | Test Loss: 0.1063, Test Acc: 0.9667
Epoch [2/10] Train Loss: 0.1372, Train Acc: 0.9582 | Test Loss: 0.0799, Test Acc: 0.9748
Epoch [3/10] Train Loss: 0.1064, Train Acc: 0.9680 | Test Loss: 0.0726, Test Acc: 0.9762
Epoch [4/10] Train Loss: 0.0889, Train Acc: 0.9721 | Test Loss: 0.0670, Test Acc: 0.9785
Epoch [5/10] Train Loss: 0.0770, Train Acc: 0.9764 | Test Loss: 0.0625, Test Acc: 0.9806
Epoch [6/10] Train Loss: 0.0688, Train Acc: 0.9782 | Test Loss: 0.0636, Test Acc: 0.9813
Epoch [7/10] Train Loss: 0.0611, Train Acc: 0.9811 | Test Loss: 0.0628, Test Acc: 0.9804
Epoch [8/10] Train Loss: 0.0549, Train Acc: 0.9823 | Test Loss: 0.0540, Test Acc: 0.9842
Epoch [9/10] Train Loss: 0.0538, Train Acc: 0.9829 | Test Loss: 0.0564, Test Acc: 0.9831
Epoch [10/10] Train Loss: 0.0467, Train Acc: 0.9854 | Test Loss: 0.0566, Test Acc: 0.9848
Model saved to mnist_final.pth


In [62]:
tar_gz_file = "epsilon_10k.tar.gz" 
extract_to = "." 
extract_tar_gz(tar_gz_file, extract_to)


generated_images = load_images_to_tensors("epsilon_generated_images")

Extracted epsilon_10k.tar.gz to .


In [64]:
transform = transforms.Compose([
    transforms.ToTensor(),  
])

mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_dataloader = torch.utils.data.DataLoader(mnist_dataset, batch_size=1000, shuffle=False)

# Extract real images
real_images, _ = next(iter(mnist_dataloader)) 

fid_score = calculate_fid_from_mnist(real_images, generated_images, batch_size=32)
print(f"FID Score: {fid_score}")

  model.load_state_dict(torch.load("mnist_final.pth"))


FID Score: 7.487852034489004


In [65]:
# Compute Inception Score
score = inception_score_from_mnist(generated_images, batch_size=16, splits=10)
print(f"Inception Score: {score}")

  model.load_state_dict(torch.load("mnist_final.pth"))


Marginal distribution: [0.07342686 0.10546829 0.10151445 0.10779906 0.12622015 0.11464116
 0.06811383 0.12072217 0.0695238  0.11255988]
Inception Score: 1.003195881843567


In [66]:
tar_gz_file = "x0_10k.tar.gz" 
extract_to = "." 
extract_tar_gz(tar_gz_file, extract_to)


# Load images from the extracted folder
generated_images = load_images_to_tensors("x0_generated_images")

Extracted x0_10k.tar.gz to .


In [67]:
transform = transforms.Compose([
    transforms.ToTensor(),  # Converts to tensor and scales to [0, 1]
])

mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_dataloader = torch.utils.data.DataLoader(mnist_dataset, batch_size=1000, shuffle=False)

# Extract real images
real_images, _ = next(iter(mnist_dataloader)) 

fid_score = calculate_fid_from_mnist(real_images, generated_images, batch_size=32)
print(f"FID Score: {fid_score}")

  model.load_state_dict(torch.load("mnist_final.pth"))


FID Score: 6.316408589128564


In [68]:
# Compute Inception Score
score = inception_score_from_mnist(generated_images, batch_size=16, splits=10)
print(f"Inception Score: {score}")

  model.load_state_dict(torch.load("mnist_final.pth"))


Marginal distribution: [0.08378848 0.10766515 0.10215064 0.12205018 0.09766396 0.11008186
 0.08145633 0.10755429 0.07944304 0.10813542]
Inception Score: 1.0034881830215454
