In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import pathlib
import random
import numpy as np

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import pathlib
import random

# Define the CNN model
class SimpleCNN(nn.Module):
    """Small convolutional classifier: two conv/ReLU/max-pool stages and an MLP head.

    Both convolutions use 3x3 kernels with padding 1 (spatial size preserved) and
    each stage ends in a 2x2, stride-2 max-pool. The first linear layer is sized
    for ``64 * 8 * 8`` flattened features, i.e. a 32x32 input after the two
    poolings. The network returns 10 raw scores per sample (no softmax applied).
    """

    def __init__(self):
        super().__init__()

        # Stage 1: 3 input channels -> 32 feature maps, then halve the resolution.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)

        # Stage 2: 32 -> 64 feature maps, then halve the resolution again.
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, 2)

        # Classifier head operating on the flattened stage-2 output.
        self.fc1 = nn.Linear(in_features=64 * 8 * 8, out_features=128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return the 10 class scores for a batch of images ``x``."""
        stage1 = self.conv1(x)
        stage1 = self.relu1(stage1)
        stage1 = self.pool1(stage1)

        stage2 = self.conv2(stage1)
        stage2 = self.relu2(stage2)
        stage2 = self.pool2(stage2)

        # Collapse everything except the batch dimension before the linear layers.
        flat = stage2.view(stage2.size(0), -1)
        hidden = self.relu3(self.fc1(flat))
        return self.fc2(hidden)

# Define the MCDO wrapper for uncertainty estimation
class SimpleCNNWithMCDO(nn.Module):
    """Wrap a base CNN and apply dropout after each of its three stages.

    The wrapper re-uses the base model's own layers (``conv1``/``relu1``/``pool1``,
    ``conv2``/``relu2``/``pool2``, ``fc1``/``relu3`` and ``fc2``) and shares a single
    ``nn.Dropout`` module between all three insertion points.

    Args:
        base_model: module exposing the layer attributes listed above.
        drop_out: dropout probability passed to ``nn.Dropout``.
    """

    def __init__(self, base_model, drop_out=0.5):
        super().__init__()
        self.base_model = base_model
        self.dropout = nn.Dropout(p=drop_out)

    def forward(self, x):
        """Run the base model's layers with dropout after both conv stages and fc1."""
        net = self.base_model

        stage1 = net.pool1(net.relu1(net.conv1(x)))
        stage1 = self.dropout(stage1)

        stage2 = net.pool2(net.relu2(net.conv2(stage1)))
        stage2 = self.dropout(stage2)

        # Flatten everything except the batch dimension for the linear layers.
        flat = stage2.view(stage2.size(0), -1)
        hidden = self.dropout(net.relu3(net.fc1(flat)))
        return net.fc2(hidden)

In [3]:

# Data preparation
# Convert images to tensors, then normalize each of the three channels
# with mean 0.5 and standard deviation 0.5.
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
)

# Both SVHN splits live under ./data (downloaded if missing) and share the transform.
trainset, testset = (
    torchvision.datasets.SVHN(root='./data', split=split_name, download=True, transform=transform)
    for split_name in ('train', 'test')
)

# Only the test loader is built here; it keeps the dataset order (no shuffling).
testloader = DataLoader(dataset=testset, shuffle=False, batch_size=64)

Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat


In [None]:

# Training and evaluation for different sample sizes
# Each entry is the number of training images kept per class for one run.
sample_sizes = [1, 5, 10, 50, 100, 500, 2000, 4000]

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Seed every RNG in play so a fresh "Restart & Run All" reproduces the sweep.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Dropout variants that have to stay active during Monte Carlo inference.
_MC_DROPOUT_LAYERS = (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout)


def mc_dropout_predict(model, inputs, n_samples=10):
    """Run ``n_samples`` stochastic forward passes with dropout switched on.

    The model is put in eval mode and only its dropout layers are flipped back
    to training mode, so any other train/eval-sensitive layer keeps its
    inference behaviour. The model is left in that state on return.

    Args:
        model: network containing at least one dropout layer.
        inputs: batch passed unchanged to ``model`` on every pass.
        n_samples: number of stochastic forward passes.

    Returns:
        ``(mean_output, uncertainty)``: the element-wise mean and variance of
        the raw model outputs across the ``n_samples`` passes.
    """
    model.eval()
    for module in model.modules():
        if isinstance(module, _MC_DROPOUT_LAYERS):
            module.train()
    outputs = torch.stack([model(inputs) for _ in range(n_samples)])
    return outputs.mean(dim=0), outputs.var(dim=0)


# Read the labels straight from the dataset's label array instead of iterating
# the dataset, which would load and transform every image just to get its label.
train_labels = np.asarray(trainset.labels)

# Checkpoints for every sample size go into the same directory; create it once.
model_dir = "./models/"
pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

for size in sample_sizes:
    print(f"\nTraining model with {size} samples per class...")

    # Subset dataset to include only 'size' samples per class: the first `size`
    # occurrences of each digit, kept in dataset order (classes with fewer than
    # `size` images contribute all of them).
    indices = np.sort(
        np.concatenate(
            [np.flatnonzero(train_labels == digit)[:size] for digit in range(10)]
        )
    ).tolist()

    subset = Subset(trainset, indices)
    trainloader = DataLoader(subset, batch_size=64, shuffle=True)

    # Initialize model, loss, and optimizer
    base_model = SimpleCNN().to(device)
    model = SimpleCNNWithMCDO(base_model, drop_out=0.2).to(device)
    criterion = nn.CrossEntropyLoss()
    # The wrapper adds no parameters of its own, so optimizing the base model
    # covers every trainable weight.
    optimizer = optim.Adam(base_model.parameters(), lr=0.001)

    # Train the model
    num_epochs = 10
    for epoch in range(num_epochs):
        # mc_dropout_predict() leaves the model in eval mode, so reset it here.
        model.train()
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Mean of the per-batch losses for this epoch.
        print(f"Epoch {epoch + 1}, Loss: {running_loss / len(trainloader):.3f}")

    # Save the model (pathlib join avoids the doubled slash of "./models/" + "/").
    model_path = pathlib.Path(model_dir) / f"svhn_model_{size}_samples_mcdo.pth"
    torch.save(model.state_dict(), model_path)
    print(f"Model saved at {model_path}")

    # Evaluate the model with MCDO
    correct = 0
    total = 0
    all_uncertainties = []  # per-batch variance arrays from the stochastic passes
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            mean_output, uncertainty = mc_dropout_predict(model, inputs, n_samples=10)
            # Predict the class with the highest mean score.
            predicted = mean_output.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_uncertainties.append(uncertainty.cpu().numpy())

    accuracy = 100 * correct / total
    print(f"Accuracy on test set: {accuracy:.2f}%")


Training model with 1 samples per class...
Epoch 1, Loss: 2.320
Epoch 2, Loss: 2.254
Epoch 3, Loss: 2.154
Epoch 4, Loss: 2.075
Epoch 5, Loss: 1.921
Epoch 6, Loss: 1.762
Epoch 7, Loss: 1.683
Epoch 8, Loss: 1.596
Epoch 9, Loss: 1.500
Epoch 10, Loss: 1.354
Model saved at ./models//svhn_model_1_samples_mcdo.pth
Accuracy on test set: 10.47%

Training model with 5 samples per class...
Epoch 1, Loss: 2.302
Epoch 2, Loss: 2.273
Epoch 3, Loss: 2.239
Epoch 4, Loss: 2.173
Epoch 5, Loss: 2.141
Epoch 6, Loss: 2.070
Epoch 7, Loss: 2.047
Epoch 8, Loss: 1.963
Epoch 9, Loss: 1.882
Epoch 10, Loss: 1.850
Model saved at ./models//svhn_model_5_samples_mcdo.pth
Accuracy on test set: 9.49%

Training model with 10 samples per class...
Epoch 1, Loss: 2.337
Epoch 2, Loss: 2.288
Epoch 3, Loss: 2.247
Epoch 4, Loss: 2.220
Epoch 5, Loss: 2.141
Epoch 6, Loss: 2.178
Epoch 7, Loss: 2.159
Epoch 8, Loss: 2.077
Epoch 9, Loss: 2.028
Epoch 10, Loss: 1.979
Model saved at ./models//svhn_model_10_samples_mcdo.pth
Accuracy on