In [None]:
!pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cu118
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
from scipy.stats import entropy
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- 1. Data Preparation ---
# Batch size shared by every DataLoader built in this section.
batch_size = 128

# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# FashionMNIST train and test splits (train first, then test), stored under ./data.
full_trainset, testset = (
    torchvision.datasets.FashionMNIST(root='./data', train=is_train,
                                      download=True, transform=transform)
    for is_train in (True, False)
)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

# A random 10% of the training images forms the initial labeled pool; the
# remaining images form the unlabeled pool that the query strategies draw from.
labeled_size = int(0.1 * len(full_trainset))
unlabeled_size = len(full_trainset) - labeled_size
labeled_dataset, unlabeled_dataset = torch.utils.data.random_split(
    full_trainset, [labeled_size, unlabeled_size]
)

# Loaders over the two pools; only the labeled one is reshuffled each epoch.
labeled_loader = DataLoader(labeled_dataset, batch_size=batch_size, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=batch_size, shuffle=False)

# --- 2. Model Definition ---
class Net(nn.Module):
    """Small three-stage CNN classifier for single-channel 28x28 images.

    ``forward`` returns both the class logits and the convolutional feature
    maps, so query strategies can score samples by either.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        # One stateless pooling module is reused after every conv stage.
        self.pool = nn.MaxPool2d(2, 2)
        # Spatial size 28 -> 14 -> 7 -> 3 after three 2x2 poolings (floor), 64 channels.
        self.fc1 = nn.Linear(64 * 3 * 3, 128)
        self.fc2 = nn.Linear(128, 10)
        # Feature extractor reusing the conv/pool modules registered above
        # (the parameters are shared, not duplicated).
        self.feature_extractor = nn.Sequential(
            self.conv1,
            nn.ReLU(),
            self.pool,
            self.conv2,
            nn.ReLU(),
            self.pool,
            self.conv3,
            nn.ReLU(),
            self.pool
        )

    def forward(self, x):
        """Return ``(logits, features)`` for a batch ``x``.

        For 28x28 inputs, logits are (batch, 10) and features are
        (batch, 64, 3, 3).
        """
        features = self.feature_extractor(x)
        # flatten(1) keeps the batch dimension explicit. An unexpected spatial
        # size then surfaces as a shape error in fc1 instead of samples being
        # silently regrouped across the batch.
        x = features.flatten(1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x, features

# --- 3. Query Strategies ---
def least_confidence(model, unlabeled_loader, num_samples_to_query, device):
    """Select the samples whose top-class softmax probability is lowest.

    Args:
        model: network whose forward returns ``(logits, features)``.
        unlabeled_loader: DataLoader over the unlabeled pool. It must iterate
            in a fixed order (shuffle=False) so positions map back to dataset
            indices.
        num_samples_to_query: maximum number of positions to return.
        device: device the model lives on.

    Returns:
        List of int positions into the loader's dataset, least confident first.
    """
    model.eval()
    chunks = []
    with torch.no_grad():
        for images, _ in unlabeled_loader:
            logits, _ = model(images.to(device))
            chunks.append(F.softmax(logits, dim=1).max(dim=1).values.cpu())
    if not chunks:
        return []
    confidences = torch.cat(chunks).numpy()
    # Scores are in loader order, so score position k is dataset position k.
    # No per-batch index bookkeeping is needed, which also keeps this working
    # for loaders built from a batch_sampler (where batch_size is None).
    order = np.argsort(confidences)
    return order[:num_samples_to_query].tolist()

def prediction_entropy(model, unlabeled_loader, num_samples_to_query, device):
    """Select the samples whose predicted class distribution has the highest entropy.

    Args:
        model: network whose forward returns ``(logits, features)``.
        unlabeled_loader: DataLoader over the unlabeled pool, iterated in a
            fixed order (shuffle=False) so positions map back to dataset indices.
        num_samples_to_query: maximum number of positions to return.
        device: device the model lives on.

    Returns:
        List of int positions into the loader's dataset, highest entropy first.
    """
    model.eval()
    chunks = []
    with torch.no_grad():
        for images, _ in unlabeled_loader:
            logits, _ = model(images.to(device))
            probs = F.softmax(logits, dim=1)
            # entr(p) = -p * ln(p), with entr(0) = 0. Summing over classes gives
            # the Shannon entropy in nats for every row at once.
            chunks.append(torch.special.entr(probs).sum(dim=1).cpu())
    if not chunks:
        return []
    entropies = torch.cat(chunks).numpy()
    # Scores are in loader order, so score position k is dataset position k.
    order = np.argsort(entropies)[::-1]
    return order[:num_samples_to_query].tolist()

def margin_sampling(model, unlabeled_loader, num_samples_to_query, device):
    """Select the samples with the smallest gap between their top two class probabilities.

    Args:
        model: network whose forward returns ``(logits, features)``.
        unlabeled_loader: DataLoader over the unlabeled pool, iterated in a
            fixed order (shuffle=False) so positions map back to dataset indices.
        num_samples_to_query: maximum number of positions to return.
        device: device the model lives on.

    Returns:
        List of int positions into the loader's dataset, smallest margin first.
    """
    model.eval()
    chunks = []
    with torch.no_grad():
        for images, _ in unlabeled_loader:
            logits, _ = model(images.to(device))
            top2 = torch.topk(F.softmax(logits, dim=1), k=2, dim=1).values
            chunks.append((top2[:, 0] - top2[:, 1]).cpu())
    if not chunks:
        return []
    margins = torch.cat(chunks).numpy()
    # Scores are in loader order, so score position k is dataset position k.
    order = np.argsort(margins)
    return order[:num_samples_to_query].tolist()

def cosine_similarity_diversity(model, unlabeled_loader, num_samples_to_query, device):
    """Select the samples whose features are least similar, on average, to the rest of the pool.

    Each sample's score is (sum of its cosine similarities to every other
    sample + 1) / (n - 1); the +1 stands for the self-similarity term. The
    lowest scores are returned first. The row sums are computed in closed
    form, so memory is O(n * d) rather than an n x n similarity matrix
    (which does not fit in memory for a pool of tens of thousands of samples).

    Args:
        model: network whose forward returns ``(logits, features)``.
        unlabeled_loader: DataLoader over the unlabeled pool, iterated in a
            fixed order (shuffle=False) so positions map back to dataset indices.
        num_samples_to_query: maximum number of positions to return.
        device: device the model lives on.

    Returns:
        List of int positions into the loader's dataset, most diverse first.
    """
    model.eval()
    chunks = []
    with torch.no_grad():
        for images, _ in unlabeled_loader:
            _, batch_features = model(images.to(device))
            # Flatten the feature maps to one vector per sample.
            chunks.append(batch_features.reshape(batch_features.size(0), -1).cpu())
    if not chunks:
        return []
    feats = torch.cat(chunks).to(torch.float64).numpy()
    n = feats.shape[0]

    # L2-normalise rows. Zero vectors stay zero, so their cosine similarity to
    # everything is 0.
    norms = np.linalg.norm(feats, axis=1, keepdims=True)
    unit = feats / np.where(norms == 0.0, 1.0, norms)

    # Row i of the full similarity matrix sums to unit_i . sum_j unit_j.
    # Swapping the diagonal (self) term for 1 gives each row's total with the
    # self-similarity counted as exactly 1.
    self_sim = np.einsum("ij,ij->i", unit, unit)
    row_sums = unit @ unit.sum(axis=0) - self_sim + 1.0
    # max(..., 1) avoids a division by zero for a single-sample pool.
    diversity_scores = row_sums / max(n - 1, 1)

    order = np.argsort(diversity_scores)
    return order[:num_samples_to_query].tolist()

def l2_norm_diversity(model, unlabeled_loader, num_samples_to_query, device, chunk_size=1024):
    """Select the samples whose features are farthest, in total Euclidean distance, from the rest of the pool.

    Each sample's score is the sum of its L2 distances to every sample in the
    pool; the highest scores are returned first. Distances are computed
    ``chunk_size`` rows at a time with ``torch.cdist``, so peak memory is
    about chunk_size * n floats instead of an n x n x d broadcast.

    Args:
        model: network whose forward returns ``(logits, features)``.
        unlabeled_loader: DataLoader over the unlabeled pool, iterated in a
            fixed order (shuffle=False) so positions map back to dataset indices.
        num_samples_to_query: maximum number of positions to return.
        device: device the model lives on; distances are computed there too.
        chunk_size: number of rows per distance block (must be >= 1).

    Returns:
        List of int positions into the loader's dataset, most distant first.
    """
    model.eval()
    chunks = []
    with torch.no_grad():
        for images, _ in unlabeled_loader:
            _, batch_features = model(images.to(device))
            # Flatten the feature maps to one vector per sample.
            chunks.append(batch_features.reshape(batch_features.size(0), -1))
        if not chunks:
            return []
        feats = torch.cat(chunks)
        totals = torch.cat([
            torch.cdist(block, feats).sum(dim=1)
            for block in torch.split(feats, chunk_size)
        ])
    diversity_scores = totals.to(torch.float64).cpu().numpy()
    # Scores are in loader order, so score position k is dataset position k.
    order = np.argsort(diversity_scores)[::-1]
    return order[:num_samples_to_query].tolist()

def kl_divergence_diversity(model, labeled_loader, unlabeled_loader, num_samples_to_query, device):
    """Select unlabeled samples whose predictions diverge most from the labeled set's mean prediction.

    The reference distribution q is the mean softmax output over
    ``labeled_loader``. Each unlabeled sample is scored by KL(p || q), where p
    is its own softmax output, and the highest scores are returned first.

    Args:
        model: network whose forward returns ``(logits, features)``.
        labeled_loader: DataLoader over the labeled set.
        unlabeled_loader: DataLoader over the unlabeled pool, iterated in a
            fixed order (shuffle=False) so positions map back to dataset indices.
        num_samples_to_query: maximum number of positions to return.
        device: device the model lives on.

    Returns:
        List of int positions into the unlabeled loader's dataset, largest
        divergence first.

    Raises:
        ValueError: if ``labeled_loader`` yields no samples (no reference
            distribution can be formed).
    """
    model.eval()

    def _softmax_rows(loader):
        """Return the model's softmax outputs for every sample of ``loader`` as one (N, C) array, or None if empty."""
        parts = []
        with torch.no_grad():
            for images, _ in loader:
                logits, _ = model(images.to(device))
                parts.append(F.softmax(logits, dim=1).cpu())
        return torch.cat(parts).numpy() if parts else None

    labeled_probs = _softmax_rows(labeled_loader)
    if labeled_probs is None:
        raise ValueError("labeled_loader yielded no samples; cannot build the reference distribution")
    labeled_distribution = labeled_probs.mean(axis=0)

    unlabeled_probs = _softmax_rows(unlabeled_loader)
    if unlabeled_probs is None:
        return []

    # scipy's entropy(pk, qk) is KL(pk || qk); axis=1 evaluates every row at once.
    kl_divergences = entropy(
        unlabeled_probs,
        np.broadcast_to(labeled_distribution, unlabeled_probs.shape),
        axis=1,
    )
    # Scores are in loader order, so score position k is dataset position k.
    order = np.argsort(kl_divergences)[::-1]
    return order[:num_samples_to_query].tolist()

# --- 4. Training and Evaluation Functions ---
def train_model(model, train_loader, criterion, optimizer, num_epochs=5, device=None):
    """Train ``model`` in place for ``num_epochs`` passes over ``train_loader``.

    Prints the mean batch loss after each epoch.

    Args:
        model: network whose forward returns ``(logits, features)``.
        train_loader: iterable of batches whose first two items are
            ``(inputs, labels)``.
        criterion: loss function applied to ``(logits, labels)``.
        optimizer: optimizer bound to ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        device: where to move each batch. ``None`` (default) means the device
            of the model's first parameter (CPU if the model has none). This
            tracks where the model actually lives, instead of a global value
            captured when the function was defined.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.train()
    num_batches = len(train_loader)
    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch in train_loader:
            inputs, labels = batch[0].to(device), batch[1].to(device)
            optimizer.zero_grad()
            outputs, _ = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Guard against an empty loader instead of dividing by zero.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{num_epochs} Loss: {mean_loss}")

def evaluate_model(model, testloader, device):
    """Return the top-1 accuracy of ``model`` on ``testloader`` as a percentage.

    The model is put in eval mode and evaluated without gradient tracking.
    Its forward pass must return ``(logits, features)``.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for batch in testloader:
            targets = batch[1].to(device)
            logits, _ = model(batch[0].to(device))
            predictions = logits.argmax(dim=1)
            n_seen += targets.size(0)
            n_correct += (predictions == targets).sum().item()
    return 100 * n_correct / n_seen

# --- 5. Active Learning Loop ---
def active_learning(net, criterion, optimizer, labeled_dataset, unlabeled_dataset, testloader, device, num_iterations=5, num_samples_to_query=10, query_strategy=least_confidence, diversity_strategy = None, batch_size=128):
    """Run pool-based active learning alongside a retrain-from-initial-weights comparison.

    Each iteration:
      1. selects ``num_samples_to_query`` samples from the unlabeled pool with
         ``diversity_strategy`` if one is given, otherwise with ``query_strategy``;
      2. moves them from the unlabeled pool into the labeled set;
      3. continues training ``net`` on the enlarged labeled set and records its
         test accuracy (the "active learning" curve);
      4. trains the initial weights (with the initial optimizer state) on the
         same labeled set and records that accuracy (the "baseline" curve), then
         restores the active-learning weights and optimizer state so the next
         query round uses the active-learning model.

    NOTE(review): the baseline trains on the same actively selected data and
    differs only in starting from the initial weights. It is not a
    random-sampling baseline.

    Args:
        net: model whose forward returns ``(logits, features)``; trained in place.
        criterion: loss function for training.
        optimizer: optimizer bound to ``net``'s parameters.
        labeled_dataset: initial labeled set.
        unlabeled_dataset: initial unlabeled pool (a Dataset or a Subset).
        testloader: DataLoader used for evaluation.
        device: device ``net`` lives on.
        num_iterations: number of query/train rounds.
        num_samples_to_query: samples moved to the labeled set per round.
        query_strategy: ``f(model, unlabeled_loader, k, device)`` returning
            positions into the unlabeled pool.
        diversity_strategy: optional strategy that replaces ``query_strategy``.
            It may take ``(model, unlabeled_loader, k, device)`` or
            ``(model, labeled_loader, unlabeled_loader, k, device)``. Which form
            is used depends on whether it has a ``labeled_loader`` parameter.
        batch_size: batch size for the labeled and unlabeled loaders.

    Returns:
        ``(labeled_dataset_sizes, active_learning_accuracies, baseline_accuracies)``,
        one entry per iteration.
    """
    import copy
    import inspect

    active_learning_accuracies = []
    baseline_accuracies = []
    labeled_dataset_sizes = []

    # state_dict() returns references to the live tensors, so it must be
    # deep-copied to serve as a snapshot of the initial weights/optimizer state.
    initial_model_state = copy.deepcopy(net.state_dict())
    initial_optimizer_state = copy.deepcopy(optimizer.state_dict())

    needs_labeled_loader = (
        diversity_strategy is not None
        and "labeled_loader" in inspect.signature(diversity_strategy).parameters
    )

    # Track the unlabeled pool as (base dataset, list of base indices). Strategy
    # positions then map straight back to real samples, without nesting Subsets.
    if isinstance(unlabeled_dataset, Subset):
        base_dataset = unlabeled_dataset.dataset
        pool_indices = list(unlabeled_dataset.indices)
    else:
        base_dataset = unlabeled_dataset
        pool_indices = list(range(len(unlabeled_dataset)))

    for i in range(num_iterations):
        print(f"Active learning iteration {i+1}/{num_iterations}...")

        labeled_loader = DataLoader(labeled_dataset, batch_size=batch_size, shuffle=True)
        # shuffle=False: strategies return positions in this loader's order.
        unlabeled_loader = DataLoader(Subset(base_dataset, pool_indices), batch_size=batch_size, shuffle=False)

        # Query: a diversity strategy, when given, selects the batch on its own.
        if diversity_strategy is None:
            query_positions = query_strategy(net, unlabeled_loader, num_samples_to_query, device)
        elif needs_labeled_loader:
            query_positions = diversity_strategy(net, labeled_loader, unlabeled_loader, num_samples_to_query, device)
        else:
            query_positions = diversity_strategy(net, unlabeled_loader, num_samples_to_query, device)
        query_set = {int(p) for p in query_positions}

        # Move the queried samples (positions -> base indices) into the labeled set.
        queried_indices = [pool_indices[int(p)] for p in query_positions]
        labeled_dataset = ConcatDataset([labeled_dataset, Subset(base_dataset, queried_indices)])
        pool_indices = [idx for pos, idx in enumerate(pool_indices) if pos not in query_set]

        # Continue training the active-learning model on the enlarged labeled set.
        labeled_loader = DataLoader(labeled_dataset, batch_size=batch_size, shuffle=True)
        train_model(net, labeled_loader, criterion, optimizer, device=device)

        active_learning_accuracy = evaluate_model(net, testloader, device)
        print(f"Active Learning Accuracy: {active_learning_accuracy}%")
        active_learning_accuracies.append(active_learning_accuracy)
        labeled_dataset_sizes.append(len(labeled_dataset))

        # Baseline: train from the initial weights and a fresh optimizer state
        # on the same data. The active-learning state is saved first so it can
        # be restored for the next query round.
        al_model_state = copy.deepcopy(net.state_dict())
        al_optimizer_state = copy.deepcopy(optimizer.state_dict())
        net.load_state_dict(initial_model_state)
        optimizer.load_state_dict(copy.deepcopy(initial_optimizer_state))
        train_model(net, labeled_loader, criterion, optimizer, device=device)

        baseline_accuracy = evaluate_model(net, testloader, device)
        print(f"Baseline Accuracy: {baseline_accuracy}%")
        baseline_accuracies.append(baseline_accuracy)

        net.load_state_dict(al_model_state)
        optimizer.load_state_dict(al_optimizer_state)
    return labeled_dataset_sizes, active_learning_accuracies, baseline_accuracies

# --- 6. Model, Criterion, Optimizer ---
criterion = nn.CrossEntropyLoss()

# Seeds shared by every strategy, so all runs start from the same weights and
# the same labeled/unlabeled split. Differences in the results then come from
# the query strategy rather than from initialisation or split luck.
INIT_SEED = 0
SPLIT_SEED = 0


def _fresh_run_state():
    """Return a new ``(net, optimizer, labeled_dataset, unlabeled_dataset)`` for one experiment.

    The optimizer is rebuilt together with the model: an optimizer only updates
    the parameter tensors it was constructed with. Reusing it with a new Net
    would leave the new model untrained.
    """
    torch.manual_seed(INIT_SEED)
    model = Net().to(device)
    model_optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    split_generator = torch.Generator().manual_seed(SPLIT_SEED)
    labeled, unlabeled = torch.utils.data.random_split(
        full_trainset, [labeled_size, unlabeled_size], generator=split_generator
    )
    return model, model_optimizer, labeled, unlabeled


# --- 7. Run Active Learning ---
num_iterations = 5
num_samples_to_query = 100

# Active learning with least confidence query strategy
net, optimizer, labeled_dataset, unlabeled_dataset = _fresh_run_state()
labeled_dataset_sizes_lc, active_learning_accuracies_lc, baseline_accuracies_lc = active_learning(
    net, criterion, optimizer, labeled_dataset, unlabeled_dataset, testloader, device,
    num_iterations, num_samples_to_query, least_confidence)

# Active learning with prediction entropy query strategy
net, optimizer, labeled_dataset, unlabeled_dataset = _fresh_run_state()
labeled_dataset_sizes_pe, active_learning_accuracies_pe, baseline_accuracies_pe = active_learning(
    net, criterion, optimizer, labeled_dataset, unlabeled_dataset, testloader, device,
    num_iterations, num_samples_to_query, prediction_entropy)

# Active learning with margin sampling query strategy
net, optimizer, labeled_dataset, unlabeled_dataset = _fresh_run_state()
labeled_dataset_sizes_ms, active_learning_accuracies_ms, baseline_accuracies_ms = active_learning(
    net, criterion, optimizer, labeled_dataset, unlabeled_dataset, testloader, device,
    num_iterations, num_samples_to_query, margin_sampling)

# Active learning with cosine similarity diversity
net, optimizer, labeled_dataset, unlabeled_dataset = _fresh_run_state()
labeled_dataset_sizes_cs, active_learning_accuracies_cs, baseline_accuracies_cs = active_learning(
    net, criterion, optimizer, labeled_dataset, unlabeled_dataset, testloader, device,
    num_iterations, num_samples_to_query, least_confidence, cosine_similarity_diversity)

# Active learning with L2 norm diversity
net, optimizer, labeled_dataset, unlabeled_dataset = _fresh_run_state()
labeled_dataset_sizes_l2, active_learning_accuracies_l2, baseline_accuracies_l2 = active_learning(
    net, criterion, optimizer, labeled_dataset, unlabeled_dataset, testloader, device,
    num_iterations, num_samples_to_query, least_confidence, l2_norm_diversity)

# Active learning with KL divergence diversity
net, optimizer, labeled_dataset, unlabeled_dataset = _fresh_run_state()
labeled_dataset_sizes_kl, active_learning_accuracies_kl, baseline_accuracies_kl = active_learning(
    net, criterion, optimizer, labeled_dataset, unlabeled_dataset, testloader, device,
    num_iterations, num_samples_to_query, least_confidence, kl_divergence_diversity)

# --- 8. Reporting and Analysis ---

# One row per experiment: (display name, wording used in the analysis text,
# labeled-set sizes, active-learning accuracies, baseline accuracies).
_report_rows = [
    ("Least Confidence", "least confidence strategy",
     labeled_dataset_sizes_lc, active_learning_accuracies_lc, baseline_accuracies_lc),
    ("Prediction Entropy", "prediction entropy strategy",
     labeled_dataset_sizes_pe, active_learning_accuracies_pe, baseline_accuracies_pe),
    ("Margin Sampling", "margin sampling strategy",
     labeled_dataset_sizes_ms, active_learning_accuracies_ms, baseline_accuracies_ms),
    ("Cosine Similarity", "cosine similarity diversity strategy",
     labeled_dataset_sizes_cs, active_learning_accuracies_cs, baseline_accuracies_cs),
    ("L2 Norm", "L2 Norm diversity strategy",
     labeled_dataset_sizes_l2, active_learning_accuracies_l2, baseline_accuracies_l2),
    ("KL Divergence", "KL Divergence diversity strategy",
     labeled_dataset_sizes_kl, active_learning_accuracies_kl, baseline_accuracies_kl),
]

# Plot an active-learning curve and a baseline curve for every strategy.
plt.figure(figsize=(10, 5))
for _name, _phrase, _sizes, _al_acc, _base_acc in _report_rows:
    plt.plot(_sizes, _al_acc, label=f'Active Learning ({_name})')
    plt.plot(_sizes, _base_acc, label=f'Baseline ({_name})')
plt.xlabel('Labeled Dataset Size')
plt.ylabel('Accuracy')
plt.title('Active Learning vs. Baseline Performance')
plt.legend()
plt.show()

# Accuracies after the last iteration of each run.
print("\nFinal Results:")
for _name, _phrase, _sizes, _al_acc, _base_acc in _report_rows:
    print(f"  Custom CNN (Baseline {_name}): {_base_acc[-1]:.2f}%")
    print(f"  Active Learning-enhanced Model ({_name}): {_al_acc[-1]:.2f}%")

# Per-strategy verdict: did active learning beat the baseline in the final iteration?
print("\nAnalysis:")
for _name, _phrase, _sizes, _al_acc, _base_acc in _report_rows:
    if _al_acc[-1] > _base_acc[-1]:
        print(f"  Active learning improved accuracy over the baseline with the {_phrase}.")
    else:
        print(f"  Active learning did not improve accuracy over the baseline with the {_phrase} in this run.")

# Highlight the strategy with the highest final active-learning accuracy
# (the first one listed wins ties).
_final_accuracies = {row[0]: row[3][-1] for row in _report_rows}
best_active_learning_strategy = max(_final_accuracies, key=_final_accuracies.get)
print("\nMost Effective Strategy: ")
print(f"  In this run, the most effective strategy was: {best_active_learning_strategy}")