In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import pairwise_distances
import numpy as np

# Data Preparation
# Images are converted to tensors and then normalized with mean 0.5 and
# std 0.5 on their single channel.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Training split: shuffled mini-batches of 4, loaded by 2 worker processes.
trainset = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

# Test split: same batch size and workers, but kept in dataset order.
testset = torchvision.datasets.FashionMNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
)

# CNN Model
class Net(nn.Module):
    """Small CNN: three conv -> ReLU -> 2x2 max-pool stages, then two linear layers.

    The convolutions use 3x3 kernels with padding 1, so only the pooling
    changes the spatial size. ``fc1`` expects ``64 * 3 * 3`` features per
    sample, which is what a single-channel 28x28 input produces
    (28 -> 14 -> 7 -> 3). The network returns 10 raw class scores per sample;
    no softmax is applied here.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 3 * 3, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return the class scores for a batch of images.

        Args:
            x: Image batch with one channel (``conv1`` has ``in_channels=1``).

        Returns:
            Tensor with one row of 10 scores per input sample.

        Raises:
            RuntimeError: If the pooled feature map does not hold
                ``64 * 3 * 3`` values per sample (raised by ``fc1``).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten per sample instead of view(-1, 576): with the latter, an
        # unexpected input resolution is silently folded into the batch
        # dimension and the output no longer lines up with the labels.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Device Setup
# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
net = Net().to(device)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Training Loop
# Five passes over the training loader; the mean loss of every 2000
# consecutive mini-batches is printed and the accumulator is then reset.
for epoch in range(5):
    running_loss = 0.0
    for i, data in enumerate(trainloader):
        inputs = data[0].to(device)
        labels = data[1].to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % 2000 == 0:
            print(f'[Epoch: {epoch + 1}, Batch: {i + 1}] loss: {running_loss / 2000:.4f}')
            running_loss = 0.0

print('Finished Training')

# Metric Calculation Functions
def calculate_metrics(outputs, features, m=5):
    """Compute all per-sample metrics for one batch and record them.

    Nothing is returned. The results are appended to the module-level lists
    ``least_confidence_list``, ``prediction_entropy_list``,
    ``margin_sampling_list``, ``cosine_similarity_list``, ``l2_norm_list``
    and ``kl_divergence_list``, which must exist before this is called.

    Args:
        outputs: Tensor of raw class scores, one row per sample (softmax is
            applied by the helper functions).
        features: Tensor of feature vectors, one row per sample, in the same
            order as ``outputs``.
        m: Neighbour count forwarded to the diversity and KL computations.
    """
    # Detach and move both tensors to the CPU once: the helpers convert to
    # NumPy, which fails for tensors that require grad or live on the GPU.
    outputs_cpu = outputs.detach().cpu()
    features_cpu = features.detach().cpu()

    least_confidence, prediction_entropy, margin_sampling = calculate_uncertainty_metrics(outputs_cpu)
    least_confidence_list.extend(least_confidence)
    prediction_entropy_list.extend(prediction_entropy)
    margin_sampling_list.extend(margin_sampling)

    # Diversity metrics are computed on L2-normalized feature rows.
    features_normalized = F.normalize(features_cpu, p=2, dim=1)
    cosine_similarity, l2_norm = calculate_diversity_metrics(features_normalized, m)
    cosine_similarity_list.extend(cosine_similarity)
    l2_norm_list.extend(l2_norm)

    # Cosine distances between the raw feature rows choose the neighbours
    # used by the KL divergence score.
    feature_distances = pairwise_distances(features_cpu.numpy(), metric='cosine')
    kl_divergence_scores = calculate_kl_divergence(outputs_cpu, feature_distances, m)
    kl_divergence_list.extend(kl_divergence_scores)

def calculate_uncertainty_metrics(outputs):
    """Derive three per-sample uncertainty scores from raw class scores.

    Args:
        outputs: 2-D tensor of class scores, one row per sample; softmax is
            taken over dimension 1.

    Returns:
        A tuple of three NumPy arrays with one entry per row of ``outputs``:
        least confidence (1 - top probability), prediction entropy, and
        margin score (1 - (top probability - second probability)).
    """
    probs = F.softmax(outputs, dim=1)

    # Least confidence: how far the most likely class is from certainty.
    top_prob = probs.max(dim=1).values
    least_confidence = 1 - top_prob.cpu().numpy()

    # Entropy of the predicted distribution; the small constant keeps
    # log() finite when a probability is exactly zero.
    log_probs = torch.log(probs + 1e-10)
    summed_terms = torch.sum(probs * log_probs, dim=1)
    prediction_entropy = -summed_terms.cpu().numpy()

    # Margin: a small gap between the two best classes gives a high score.
    top_two = torch.topk(probs, 2, dim=1).values.cpu().numpy()
    best, runner_up = top_two[:, 0], top_two[:, 1]
    margin_sampling = 1 - (best - runner_up)

    return least_confidence, prediction_entropy, margin_sampling

def _mean_nearest_neighbor_distance(distances, m):
    """Return, for each row, the mean of its m smallest off-diagonal distances."""
    n_samples = distances.shape[0]
    # A sample can have at most n_samples - 1 neighbours.
    k = min(m, n_samples - 1)
    if k <= 0:
        # No neighbours available: the mean is undefined.
        return np.full(n_samples, np.nan)
    off_diagonal = np.array(distances, dtype=float)
    # Exclude each sample's distance to itself before ranking.
    np.fill_diagonal(off_diagonal, np.inf)
    return np.sort(off_diagonal, axis=1)[:, :k].mean(axis=1)


def calculate_diversity_metrics(features, m=5):
    """Score how close each sample is to its nearest neighbours in the batch.

    The pairwise distance matrix is indexed by sample, not by rank, so each
    row is sorted (with the self-distance masked out) before the ``m``
    closest neighbours are averaged. If the batch holds fewer than ``m``
    other samples, all of them are used.

    Args:
        features: Tensor of feature vectors, one row per sample.
        m: Number of nearest neighbours to average over.

    Returns:
        A tuple ``(cosine_similarity, l2_norm)`` of NumPy arrays with one
        entry per sample: one minus the mean cosine distance to the nearest
        neighbours, and the mean Euclidean distance to the nearest
        neighbours. Entries are NaN when a sample has no neighbour.
    """
    # Convert once; both distance matrices are built from the same array.
    features_np = features.detach().cpu().numpy()

    cosine_distances = pairwise_distances(features_np, metric='cosine')
    cosine_similarity = 1 - _mean_nearest_neighbor_distance(cosine_distances, m)

    l2_distances = pairwise_distances(features_np, metric='euclidean')
    l2_norm = _mean_nearest_neighbor_distance(l2_distances, m)

    return cosine_similarity, l2_norm

def calculate_kl_divergence(outputs, feature_distances, m=5):
    """Compare each sample's prediction with the mean prediction of its neighbourhood.

    For every sample, the ``m + 1`` rows with the smallest feature distance
    are selected (this includes the sample itself when its self-distance is
    the smallest), their softmax distributions are averaged, and the KL
    divergence KL(neighbourhood mean || sample) is computed.

    Args:
        outputs: 2-D tensor of raw class scores, one row per sample.
        feature_distances: Square array of pairwise feature distances whose
            row ``i`` corresponds to ``outputs[i]``.
        m: Number of neighbours taken in addition to the closest entry.

    Returns:
        A list of Python floats, one KL divergence value per sample.
    """
    epsilon = 1e-10  # keeps log() finite for zero probabilities
    # The softmax is the same for every iteration, so compute it once.
    probabilities = F.softmax(outputs, dim=1)

    kl_divergence = []
    for i, current_sample_prob in enumerate(probabilities):
        neighbor_indices = torch.as_tensor(np.argsort(feature_distances[i])[:m + 1])
        neighbors_prob = probabilities[neighbor_indices].mean(dim=0)
        # reduction='sum' adds up the per-class terms, which is the KL
        # divergence of a single distribution. 'batchmean' would divide by
        # the first dimension, and for a 1-D vector that is the class count.
        kl = F.kl_div(
            torch.log(current_sample_prob + epsilon),
            neighbors_prob + epsilon,
            reduction='sum',
        )
        kl_divergence.append(kl.item())
    return kl_divergence

# Evaluation and Metric Computation
# Per-sample accumulators; calculate_metrics() appends to these lists.
least_confidence_list, prediction_entropy_list, margin_sampling_list = [], [], []
cosine_similarity_list, l2_norm_list, kl_divergence_list = [], [], []

net.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in testloader:
        images = images.to(device)
        labels = labels.to(device)

        # Accuracy bookkeeping: the predicted class is the highest score.
        outputs = net(images)
        predicted = torch.max(outputs, 1).indices
        total += len(labels)
        correct += predicted.eq(labels).sum().item()

        # Feature extraction for metrics: run the three conv -> ReLU -> pool
        # stages again and flatten the result to one row per sample.
        features = images
        for conv in (net.conv1, net.conv2, net.conv3):
            features = net.pool(F.relu(conv(features)))
        features = features.view(len(features), -1)

        calculate_metrics(outputs, features)

# Results
print(f'Test Accuracy: {100 * correct / total:.2f}%')
print(f'Average Least Confidence: {np.mean(least_confidence_list):.4f}')
print(f'Average Prediction Entropy: {np.mean(prediction_entropy_list):.4f}')
print(f'Average Margin Sampling: {np.mean(margin_sampling_list):.4f}')
print(f'Average Cosine Similarity: {np.mean(cosine_similarity_list):.4f}')
print(f'Average L2 Norm: {np.mean(l2_norm_list):.4f}')
print(f'Average KL Divergence: {np.mean(kl_divergence_list):.4f}')


[Epoch: 1, Batch: 2000] loss: 1.2433
[Epoch: 1, Batch: 4000] loss: 0.6302
[Epoch: 1, Batch: 6000] loss: 0.5424
[Epoch: 1, Batch: 8000] loss: 0.4703
[Epoch: 1, Batch: 10000] loss: 0.4402
[Epoch: 1, Batch: 12000] loss: 0.3957
[Epoch: 1, Batch: 14000] loss: 0.3707
[Epoch: 2, Batch: 2000] loss: 0.3458
[Epoch: 2, Batch: 4000] loss: 0.3255
[Epoch: 2, Batch: 6000] loss: 0.3388
[Epoch: 2, Batch: 8000] loss: 0.3253
[Epoch: 2, Batch: 10000] loss: 0.3189
[Epoch: 2, Batch: 12000] loss: 0.3078
[Epoch: 2, Batch: 14000] loss: 0.3097
[Epoch: 3, Batch: 2000] loss: 0.2827
[Epoch: 3, Batch: 4000] loss: 0.2670
[Epoch: 3, Batch: 6000] loss: 0.2789
[Epoch: 3, Batch: 8000] loss: 0.2659
[Epoch: 3, Batch: 10000] loss: 0.2732
[Epoch: 3, Batch: 12000] loss: 0.2593
[Epoch: 3, Batch: 14000] loss: 0.2697
[Epoch: 4, Batch: 2000] loss: 0.2379
[Epoch: 4, Batch: 4000] loss: 0.2563
[Epoch: 4, Batch: 6000] loss: 0.2365
[Epoch: 4, Batch: 8000] loss: 0.2356
[Epoch: 4, Batch: 10000] loss: 0.2394
[Epoch: 4, Batch: 12000] los