In [None]:
import os
import argparse
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from scipy import stats

def set_seed(seed: int) -> None:
    """Seed every RNG used in this notebook and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, the PyTorch CPU RNG and all
    CUDA device RNGs with the same ``seed``, then disables cuDNN autotuning and
    requests deterministic cuDNN kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Autotuning may pick different kernels run to run; turn it off for reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


class ConvNN(nn.Module):
    """Small two-convolution classifier with dropout, used for MC-dropout acquisition.

    Layout: conv(k) -> ReLU -> conv(k) -> ReLU -> 2x2 max-pool -> dropout(0.25)
    -> flatten -> linear(dense_layer) -> ReLU -> dropout(0.5) -> linear(10 logits).
    Both convolutions use stride 1 and no padding; the input has a single channel
    and spatial size ``img_rows`` x ``img_cols``.
    """

    def __init__(
        self,
        num_filters: int = 32,
        kernel_size: int = 4,
        dense_layer: int = 128,
        img_rows: int = 28,
        img_cols: int = 28,
    ) -> None:
        super().__init__()

        def after_features(side: int) -> int:
            # Each unpadded stride-1 conv trims (kernel_size - 1) pixels; the 2x2 pool then halves (floor).
            return (side - 2 * (kernel_size - 1)) // 2

        flat_features = num_filters * after_features(img_rows) * after_features(img_cols)

        # Creation order (conv1, conv2, fc1, fc2) determines which seeded RNG draws each
        # layer's initial weights receive, so keep it stable.
        self.conv1 = nn.Conv2d(1, num_filters, kernel_size, stride=1)
        self.conv2 = nn.Conv2d(num_filters, num_filters, kernel_size, stride=1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(flat_features, dense_layer)
        self.fc2 = nn.Linear(dense_layer, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class logits for a batch of single-channel images."""
        h = F.relu(self.conv1(x))
        h = F.max_pool2d(F.relu(self.conv2(h)), 2)
        h = torch.flatten(self.dropout1(h), 1)
        h = self.dropout2(F.relu(self.fc1(h)))
        return self.fc2(h)

class LoadData:
    """Download, split, and prepare MNIST for active learning.

    On construction the MNIST train split is partitioned with a seeded
    ``torch.Generator`` into a training split of ``train_size`` images, a
    validation split of ``val_size`` images and an unlabelled pool holding the
    remainder; the official MNIST test split is used as the test set. A small
    class-balanced initial labelled set is then drawn from the training split.

    Args:
        val_size: number of images in the validation split.
        train_size: number of images in the training split the initial
            labelled set is drawn from.
        seed: seed for the generator used by ``random_split``.
        root: directory MNIST is read from / downloaded into.
    """

    def __init__(
        self,
        val_size: int = 100,
        train_size: int = 10000,
        seed: int = 369,
        root: str = "data",
    ) -> None:
        self.train_size = train_size
        self.val_size = val_size
        self.seed = seed
        self.root = root
        self.mnist_train, self.mnist_test = self.download_dataset()
        # Everything in the MNIST train split not used for train/val becomes the pool.
        self.pool_size = len(self.mnist_train) - self.train_size - self.val_size
        (
            self.X_train_All,
            self.y_train_All,
            self.X_val,
            self.y_val,
            self.X_pool,
            self.y_pool,
            self.X_test,
            self.y_test,
        ) = self.split_and_load_dataset()
        self.X_init, self.y_init = self.preprocess_training_data()

    def tensor_to_np(self, tensor_data: torch.Tensor) -> np.ndarray:
        """Detach ``tensor_data``, move it to the CPU and return it as a NumPy array."""
        return tensor_data.detach().cpu().numpy()

    def check_mnist_folder(self) -> bool:
        """Return True when ``<root>/MNIST`` does not exist yet, i.e. a download is needed."""
        return not os.path.exists(os.path.join(self.root, "MNIST"))

    def download_dataset(self):
        """Load the MNIST train and test splits, downloading only if the folder is missing.

        Images are converted to tensors and normalised with mean 0.1307 and std 0.3081.
        """
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        download = self.check_mnist_folder()
        mnist_train = MNIST(
            self.root, train=True, download=download, transform=transform
        )
        mnist_test = MNIST(self.root, train=False, download=download, transform=transform)
        return mnist_train, mnist_test

    @staticmethod
    def _materialize(dataset):
        """Return all of ``dataset`` (shuffled with the global torch RNG) as one stacked batch."""
        # batch_size equals the split length, so the first batch is the entire split.
        loader = DataLoader(dataset=dataset, batch_size=len(dataset), shuffle=True)
        return next(iter(loader))

    def split_and_load_dataset(self):
        """Randomly split the MNIST train set and load every split into memory.

        Which images land in train/val/pool is fixed by ``self.seed``; the order of
        samples inside each returned tensor also depends on the global torch RNG,
        because every split is loaded through a shuffling loader.

        Returns:
            ``(X_train_All, y_train_All, X_val, y_val, X_pool, y_pool, X_test, y_test)``
        """
        generator = torch.Generator().manual_seed(self.seed)
        train_set, val_set, pool_set = random_split(
            self.mnist_train,
            [self.train_size, self.val_size, self.pool_size],
            generator=generator,
        )
        X_train_All, y_train_All = self._materialize(train_set)
        X_val, y_val = self._materialize(val_set)
        X_pool, y_pool = self._materialize(pool_set)
        # Sized from the dataset itself so the whole test split is always loaded.
        X_test, y_test = self._materialize(self.mnist_test)
        return X_train_All, y_train_All, X_val, y_val, X_pool, y_pool, X_test, y_test

    def preprocess_training_data(self, samples_per_class: int = 2):
        """Draw a class-balanced initial labelled set from the training split.

        For each digit 0-9, ``samples_per_class`` distinct training indices are drawn
        with NumPy's global RNG.

        Args:
            samples_per_class: images drawn per digit (default 2, i.e. 20 in total).

        Returns:
            ``(X_init, y_init)`` tensors indexed out of the training split.

        Raises:
            ValueError: (from ``np.random.choice``) if some digit has fewer than
                ``samples_per_class`` training examples.
        """
        labels = self.y_train_All.numpy()
        chosen = [
            np.random.choice(
                np.where(labels == digit)[0], size=samples_per_class, replace=False
            )
            for digit in range(10)
        ]
        initial_idx = np.concatenate(chosen)
        X_init = self.X_train_All[initial_idx]
        y_init = self.y_train_All[initial_idx]
        print(f"Initial training data points: {X_init.shape[0]}")
        print(f"Data distribution for each class: {np.bincount(y_init.numpy())}")
        return X_init, y_init

    def load_all(self):
        """Return the initial, validation, pool and test splits as NumPy arrays.

        Order: ``(X_init, y_init, X_val, y_val, X_pool, y_pool, X_test, y_test)``.
        """
        return (
            self.tensor_to_np(self.X_init),
            self.tensor_to_np(self.y_init),
            self.tensor_to_np(self.X_val),
            self.tensor_to_np(self.y_val),
            self.tensor_to_np(self.X_pool),
            self.tensor_to_np(self.y_pool),
            self.tensor_to_np(self.X_test),
            self.tensor_to_np(self.y_test),
        )

@torch.no_grad()
def _forward_probs(model: nn.Module, batch: torch.Tensor, training: bool) -> torch.Tensor:
    if training:
        model.train()
    else:
        model.eval()
    logits = model(batch)
    return torch.softmax(logits, dim=-1)


@torch.no_grad()
def predictions_from_pool(
    model: nn.Module,
    X_pool: np.ndarray,
    T: int = 100,
    training: bool = True,
    subset_size: int = 2000,
    device: torch.device | str | None = None,
):
    """Collect ``T`` softmax predictions for a random subset of the pool.

    A subset of at most ``subset_size`` pool indices is drawn without replacement
    (NumPy global RNG) and scored ``T`` times. With ``training`` truthy every pass is
    an independent MC-dropout sample. With ``training`` falsy the model runs in eval
    mode; the forward pass is then assumed deterministic (true for ``ConvNN``, whose
    only randomness is dropout), so it is run once and replicated ``T`` times instead
    of repeating identical work.

    Args:
        model: network returning logits.
        X_pool: pool inputs; rows are indexed by the sampled subset.
        T: number of prediction samples (leading axis of the output).
        training: MC dropout on (truthy) or off (falsy).
        subset_size: maximum number of pool points scored.
        device: device the selected inputs are moved to.

    Returns:
        ``(outputs, random_subset)`` where ``outputs`` stacks the ``T`` probability
        arrays along a new leading axis and ``random_subset`` holds the indices into
        ``X_pool`` that were scored.

    Raises:
        ValueError: if ``T`` < 1 (there would be nothing to aggregate).
    """
    if T < 1:
        raise ValueError(f"T must be at least 1, got {T}")
    subset_size = min(subset_size, len(X_pool))
    random_subset = np.random.choice(range(len(X_pool)), size=subset_size, replace=False)
    x_tensor = torch.from_numpy(X_pool[random_subset]).to(device)
    if not training:
        # Eval-mode passes would all be identical: compute once, keep the (T, ...) shape.
        single = _forward_probs(model, x_tensor, training=False).cpu().numpy()
        return np.repeat(single[np.newaxis], T, axis=0), random_subset
    outputs = [
        _forward_probs(model, x_tensor, training=True).cpu().numpy()
        for _ in range(T)
    ]
    return np.stack(outputs), random_subset


def uniform(model: nn.Module, X_pool: np.ndarray, n_query: int = 10, **_):
    """Random-sampling baseline: choose up to ``n_query`` distinct pool indices uniformly.

    ``model`` and any extra keyword arguments (``T``, ``training``, ``device``, ...) are
    accepted only so this shares the other acquisition functions' call signature;
    they are ignored.

    Returns:
        ``(query_idx, X_pool[query_idx])``
    """
    pool_indices = range(len(X_pool))
    k = min(n_query, len(X_pool))
    chosen = np.random.choice(pool_indices, size=k, replace=False)
    return chosen, X_pool[chosen]


def shannon_entropy_function(
    model: nn.Module,
    X_pool: np.ndarray,
    T: int = 100,
    E_H: bool = False,
    training: bool = True,
    device: torch.device | str | None = None,
):
    """Entropy of the mean predictive distribution over a random pool subset.

    ``T`` probability samples are obtained from ``predictions_from_pool``. ``H`` is the
    entropy of their mean. When ``E_H`` is true, the mean over the ``T`` samples of
    each sample's own entropy is also returned (BALD uses ``H - E``).

    Returns:
        ``(H, random_subset)`` or, when ``E_H`` is true, ``(H, E, random_subset)``.
    """
    probs, candidates = predictions_from_pool(
        model, X_pool, T=T, training=training, device=device
    )
    eps = 1e-10  # keeps log() finite for probabilities that are exactly zero
    mean_probs = probs.mean(axis=0)
    entropy = (-mean_probs * np.log(mean_probs + eps)).sum(axis=-1)
    if not E_H:
        return entropy, candidates
    expected_entropy = -(probs * np.log(probs + eps)).sum(axis=-1).mean(axis=0)
    return entropy, expected_entropy, candidates


def max_entropy(
    model: nn.Module,
    X_pool: np.ndarray,
    n_query: int = 10,
    T: int = 100,
    training: bool = True,
    device=None,
):
    """Acquire the pool points whose mean predictive distribution has the highest entropy.

    Only the random subset scored by ``shannon_entropy_function`` is considered.

    Returns:
        ``(query_idx, X_pool[query_idx])``, with indices into ``X_pool`` ordered by
        decreasing entropy.
    """
    scores, candidates = shannon_entropy_function(
        model, X_pool, T=T, training=training, device=device
    )
    ranked = np.argsort(-scores)
    chosen = candidates[ranked[:n_query]]
    return chosen, X_pool[chosen]


def bald(
    model: nn.Module,
    X_pool: np.ndarray,
    n_query: int = 10,
    T: int = 100,
    training: bool = True,
    device=None,
):
    """BALD acquisition: rank by entropy of the mean minus mean per-sample entropy.

    Only the random subset scored by ``shannon_entropy_function`` is considered.

    Returns:
        ``(query_idx, X_pool[query_idx])``, with indices into ``X_pool`` ordered by
        decreasing ``H - E[H]``.
    """
    entropy_of_mean, mean_entropy, candidates = shannon_entropy_function(
        model, X_pool, T=T, E_H=True, training=training, device=device
    )
    mutual_info = entropy_of_mean - mean_entropy
    chosen = candidates[np.argsort(-mutual_info)[:n_query]]
    return chosen, X_pool[chosen]


def var_ratios(
    model: nn.Module,
    X_pool: np.ndarray,
    n_query: int = 10,
    T: int = 100,
    training: bool = True,
    device=None,
):
    """Variation-ratio acquisition over a random pool subset.

    Each of the ``T`` prediction samples votes for its argmax class. The score is
    ``1 - (votes for the modal class) / T``, so points whose samples disagree most
    rank first.

    Returns:
        ``(query_idx, X_pool[query_idx])``, with indices into ``X_pool`` ordered by
        decreasing variation ratio.
    """
    probs, candidates = predictions_from_pool(
        model, X_pool, T=T, training=training, device=device
    )
    votes = probs.argmax(axis=2)
    n_passes = votes.shape[0]
    _mode, modal_count = stats.mode(votes, axis=0, keepdims=False)
    variation_ratio = np.reshape(1 - modal_count / n_passes, -1)
    chosen = candidates[np.argsort(-variation_ratio)[:n_query]]
    return chosen, X_pool[chosen]


def mean_std(
    model: nn.Module,
    X_pool: np.ndarray,
    n_query: int = 10,
    T: int = 100,
    training: bool = True,
    device=None,
):
    """Mean-STD acquisition over a random pool subset.

    For every scored point, the standard deviation of each class probability across
    the ``T`` samples is averaged over classes. Points are ranked by that mean.

    Returns:
        ``(query_idx, X_pool[query_idx])``, with indices into ``X_pool`` ordered by
        decreasing mean standard deviation.
    """
    probs, candidates = predictions_from_pool(
        model, X_pool, T=T, training=training, device=device
    )
    per_class_std = probs.std(axis=0)
    score = per_class_std.mean(axis=-1)
    chosen = candidates[np.argsort(-score)[:n_query]]
    return chosen, X_pool[chosen]


def select_acq_function(acq_func: int = 0):
    """Map an acquisition code to the list of acquisition functions to run.

    Code 0 selects all five functions. Codes 1-5 select a single one, in the order
    uniform, max_entropy, bald, var_ratios, mean_std. Unknown codes raise ``KeyError``.
    """
    every = [uniform, max_entropy, bald, var_ratios, mean_std]
    table = {0: every}
    table.update({code: [fn] for code, fn in enumerate(every, start=1)})
    return table[acq_func]


def _make_loader(X: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool = True):
    dataset = TensorDataset(torch.from_numpy(X).float(), torch.from_numpy(y).long())
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def _train_model(
    model: nn.Module,
    loader: DataLoader,
    epochs: int,
    lr: float,
    weight_decay: float,
    device: torch.device,
):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for _ in range(epochs):
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            logits = model(xb)
            loss = criterion(logits, yb)
            loss.backward()
            optimizer.step()
    return model


def _accuracy(model: nn.Module, X: np.ndarray, y: np.ndarray, device: torch.device) -> float:
    model.eval()
    with torch.no_grad():
        xb = torch.from_numpy(X).float().to(device)
        yb = torch.from_numpy(y).long().to(device)
        preds = torch.argmax(model(xb), dim=1)
        return float((preds == yb).float().mean().cpu().item())


def active_learning_procedure(
    query_strategy,
    X_val,
    y_val,
    X_test,
    y_test,
    X_pool,
    y_pool,
    X_init,
    y_init,
    build_model,
    *,
    T: int = 100,
    n_query: int = 10,
    training: bool = True,
    batch_size: int = 128,
    epochs: int = 50,
    lr: float = 1e-3,
    weight_decay: float = 1e-2,
    device: torch.device,
    n_acquisitions: int | None = None,
):
    """Run one pool-based active-learning experiment.

    A fresh model from ``build_model`` is trained on the initial labelled set. Then,
    for each acquisition round, ``query_strategy`` selects ``n_query`` pool points,
    they move from the pool to the labelled set (their labels are taken from
    ``y_pool``), and a new model is trained from scratch on the enlarged set.

    Args:
        query_strategy: acquisition function called as
            ``query_strategy(model, X_pool, n_query=..., T=..., training=..., device=...)``
            and returning ``(indices_into_X_pool, selected_inputs)``.
        X_val, y_val: validation split, used for every point of the accuracy curve.
        X_test, y_test: test split, evaluated once after the final round.
        X_pool, y_pool: unlabelled pool. The caller's arrays are not modified;
            reduced copies are used internally.
        X_init, y_init: initial labelled training set.
        build_model: zero-argument factory returning a new ``nn.Module``.
        T: stochastic forward passes forwarded to ``query_strategy``. Also used as the
            number of acquisition rounds when ``n_acquisitions`` is not given.
        n_query: points acquired per round.
        training: forwarded to ``query_strategy`` (True = MC dropout at acquisition).
        batch_size, epochs, lr, weight_decay: training hyper-parameters.
        device: device for training and evaluation.
        n_acquisitions: number of acquisition rounds. Defaults to ``T`` to preserve
            the historical coupling of the two quantities.

    Returns:
        ``(perf_hist, final_test_acc)``. ``perf_hist`` holds the validation accuracy
        before any acquisition followed by one entry per round; ``final_test_acc``
        is the test accuracy of the last model.
    """
    rounds = T if n_acquisitions is None else n_acquisitions

    def _fit(X_train, y_train):
        # Every round starts from freshly initialised weights.
        model = build_model().to(device)
        loader = _make_loader(X_train, y_train, batch_size, shuffle=True)
        return _train_model(
            model, loader, epochs=epochs, lr=lr, weight_decay=weight_decay, device=device
        )

    X_train, y_train = X_init, y_init
    model = _fit(X_train, y_train)
    # Same split (validation) for every entry so the curve is internally comparable.
    perf_hist = [_accuracy(model, X_val, y_val, device=device)]

    for index in range(rounds):
        query_idx, _ = query_strategy(
            model,
            X_pool,
            n_query=n_query,
            T=T,
            training=training,
            device=device,
        )

        X_train = np.concatenate([X_train, X_pool[query_idx]], axis=0)
        y_train = np.concatenate([y_train, y_pool[query_idx]], axis=0)
        X_pool = np.delete(X_pool, query_idx, axis=0)
        y_pool = np.delete(y_pool, query_idx, axis=0)

        model = _fit(X_train, y_train)

        val_acc = _accuracy(model, X_val, y_val, device=device)
        if (index + 1) % 5 == 0:
            print(f"Val Accuracy after query {index+1}: {val_acc:0.4f}")
        perf_hist.append(val_acc)

    final_test_acc = _accuracy(model, X_test, y_test, device=device)
    print(f"********** Test Accuracy per experiment: {final_test_acc:.4f} **********")
    return perf_hist, final_test_acc


def train_active_learning(args, device, datasets):
    """Run each selected acquisition function, in each dropout mode, over several seeds.

    ``args`` supplies ``acq_func`` (code for ``select_acq_function``), ``determ`` (when
    true, deterministic acquisition, ``training=False``, also runs after MC dropout),
    ``experiments``, ``seed`` and the training hyper-parameters. ``datasets`` maps
    split names ("X_init", "y_init", "X_val", ...) to arrays.

    Returns:
        dict mapping ``"<function name>-MC_dropout=<True|False>"`` to the accuracy
        history averaged element-wise over experiments.
    """
    strategies = select_acq_function(args.acq_func)
    curves = {}
    mc_modes = [True, False] if args.determ else [True]

    for mc_dropout in mc_modes:
        for strategy in strategies:
            run_name = f"{strategy.__name__}-MC_dropout={mc_dropout}"
            print(f"\n---------- Start {run_name} training! ----------")
            histories = []
            scores = []
            for exp_id in range(args.experiments):
                # Experiment k uses the same seed for every function/mode, for a fair comparison.
                set_seed(args.seed + exp_id)
                print(f"********** Experiment Iterations: {exp_id + 1}/{args.experiments} **********")
                history, score = active_learning_procedure(
                    query_strategy=strategy,
                    X_val=datasets["X_val"],
                    y_val=datasets["y_val"],
                    X_test=datasets["X_test"],
                    y_test=datasets["y_test"],
                    X_pool=datasets["X_pool"],
                    y_pool=datasets["y_pool"],
                    X_init=datasets["X_init"],
                    y_init=datasets["y_init"],
                    build_model=ConvNN,
                    T=args.dropout_iter,
                    n_query=args.query,
                    training=mc_dropout,
                    batch_size=args.batch_size,
                    epochs=args.epochs,
                    lr=args.lr,
                    weight_decay=args.weight_decay,
                    device=device,
                )
                histories.append(history)
                scores.append(score)
            mean_curve = np.mean(np.array(histories), axis=0)
            mean_score = sum(scores) / len(scores)
            print(f"Average Test score for {run_name}: {mean_score}")
            curves[run_name] = mean_curve
    return curves


# Acquisition function codes (value of ``args.acq_func``):
# 0: all (uniform, max_entropy, bald, var_ratios, mean_std)
# 1: uniform only
# 2: max_entropy only
# 3: bald only
# 4: var_ratios only
# 5: mean_std only

args = argparse.Namespace(
    batch_size=128,
    epochs=50,
    lr=1e-3,
    weight_decay=1e-2,
    seed=369,
    experiments=1,
    dropout_iter=50,   # MC-dropout passes per acquisition; also passed as T, which sets the number of acquisition rounds
    query=10,          # points acquired per round
    acq_func=3,        # 3 = BALD only; set to 0 to run all acquisition functions
    val_size=100,
    determ=True,       # Run both deterministic and Bayesian
    result_dir="result_npy_exp2",  # NOTE(review): not read by any code in this notebook; results are not saved
)

set_seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

data_loader = LoadData(val_size=args.val_size, seed=args.seed)
(X_init, y_init, X_val, y_val, X_pool, y_pool, X_test, y_test) = data_loader.load_all()

# Sanity checks: every split has one label per input, and the validation size is as configured.
assert all(
    len(features) == len(labels)
    for features, labels in ((X_init, y_init), (X_val, y_val), (X_pool, y_pool), (X_test, y_test))
)
assert len(X_val) == args.val_size

datasets = {
    "X_init": X_init,
    "y_init": y_init,
    "X_val": X_val,
    "y_val": y_val,
    "X_pool": X_pool,
    "y_pool": y_pool,
    "X_test": X_test,
    "y_test": y_test,
}

results = train_active_learning(args, device, datasets)
results

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

friendly_labels = {
    "uniform-MC_dropout=True": "Uniform (Bayesian)",
    "uniform-MC_dropout=False": "Uniform (Deterministic)",
    "max_entropy-MC_dropout=True": "Max Entropy (Bayesian)",
    "max_entropy-MC_dropout=False": "Max Entropy (Deterministic)",
    "bald-MC_dropout=True": "BALD (Bayesian)",
    "bald-MC_dropout=False": "BALD (Deterministic)",
    "var_ratios-MC_dropout=True": "Variation Ratios (Bayesian)",
    "var_ratios-MC_dropout=False": "Variation Ratios (Deterministic)",
    "mean_std-MC_dropout=True": "Mean STD (Bayesian)",
    "mean_std-MC_dropout=False": "Mean STD (Deterministic)",
}

# Colour encodes the acquisition function and line style the dropout mode, so every
# curve stays distinguishable when several acquisition functions are plotted together.
acq_colors = {
    "uniform": "tab:gray",
    "max_entropy": "tab:orange",
    "bald": "tab:blue",
    "var_ratios": "tab:green",
    "mean_std": "tab:purple",
}

fig, ax = plt.subplots(figsize=(10, 6))
for name, curve in results.items():
    # Result keys have the form "<acquisition function>-MC_dropout=<True|False>".
    acq_name, _, mc_dropout = name.partition("-MC_dropout=")
    is_det = mc_dropout == "False"
    ax.plot(
        curve,
        label=friendly_labels.get(name, name),
        color=acq_colors.get(acq_name),  # None falls back to matplotlib's colour cycle
        linestyle="-" if is_det else "--",
    )
ax.set_xlabel("Acquisition step")
ax.set_ylabel("Accuracy")
ax.set_title("Active learning accuracy vs queries (Experiment 2)")
ax.legend()
ax.grid(True)
plt.show()