# Preparation: imports, seeding and data loading

In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from load_data import load_mnist
from util import *
from acq_functions import *
import os

# Seed the NumPy and PyTorch RNGs so weight initialisation, dropout masks and
# DataLoader shuffling are repeatable across restart-and-run-all (and the data
# split too, assuming load_mnist draws from these RNGs).
# The stdlib `random` module is deliberately NOT imported here: the name
# `random` is used as an acquisition function in `acq_funs` and presumably
# comes from one of the star imports above.
SEED = 43
np.random.seed(SEED)
torch.manual_seed(SEED)

x_train_new, y_train_new, X_p, y_p, x_val, y_val, x_test, y_test = load_mnist()

In [None]:
# Experiment configuration.
batch_size = 128          # mini-batch size for training and batched inference
epochs = 70               # training epochs per weight-decay candidate
acquired_points = 10      # pool points moved into the training set per round
acquisition_times = 100   # number of acquisition rounds

# Pick a backend: CUDA first, then Apple MPS, then CPU. `device` is always a
# torch.device, so downstream code can inspect `device.type` uniformly.
# torch.backends.mps.is_available() is the documented MPS availability check.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Move the fixed training/validation/test splits to the device once.
# The acquisition pool (X_p, y_p) is not moved here; the acquisition loop
# indexes it with CPU boolean masks.
x_train_new = x_train_new.to(device)
y_train_new = y_train_new.to(device)
x_val = x_val.to(device)
y_val = y_val.to(device)
x_test = x_test.to(device)
y_test = y_test.to(device)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    """Small two-convolution classifier that outputs class probabilities.

    The forward pass ends in a softmax, so the output is a probability
    distribution over ``num_classes`` (not logits); the acquisition functions
    in this notebook consume it as probabilities.

    The size of ``fc1`` (32 * 11 * 11) corresponds to 1x28x28 inputs:
    28 -> 25 -> 22 after the two 4x4 convolutions, then 11 after 2x2 pooling.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        conv_channels = 32
        hidden_units = 128

        # Feature extractor: two valid 4x4 convolutions, then 2x2 max-pooling.
        self.conv = nn.Sequential(
            nn.Conv2d(1, conv_channels, kernel_size=4),
            nn.ReLU(),
            nn.Conv2d(conv_channels, conv_channels, kernel_size=4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # Classifier head with dropout before each fully connected layer.
        self.dropout1 = nn.Dropout(0.25)
        self.fc1 = nn.Linear(conv_channels * 11 * 11, hidden_units)
        self.dropout2 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(hidden_units, num_classes)

    def forward(self, x):
        features = torch.flatten(self.conv(x), 1)
        hidden = F.relu(self.fc1(self.dropout1(features)))
        class_scores = self.fc2(self.dropout2(hidden))
        return F.softmax(class_scores, dim=1)


In [None]:
# Acquisition functions to benchmark, keyed by the name used in result file
# names. Insertion order is the order the experiments run and are plotted in.
acq_funs = dict(
    bald=bald,
    max_entropy=max_entropy,
    var_ratios=var_ratios,
    mean_std=mean_std,
    random=random,
)


In [None]:
# NOTE(review): neither constant is referenced by any code in this notebook;
# kept in case they are used interactively.
sampling_steps, mc_sampling = 5, False

# Active-learning training loop (local)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import gc

def find_best_decay_local(x_train, y_train):
    """Grid-search Adam weight decay; return the best model and its test accuracy.

    For every candidate weight decay, a fresh ``CNN`` is trained from scratch
    for ``epochs`` epochs on ``(x_train, y_train)``. The candidate with the
    highest validation accuracy on ``(x_val, y_val)`` is kept; ties go to the
    earlier candidate. That candidate is then scored on ``(x_test, y_test)``.

    Args:
        x_train: training inputs; batches are moved to ``device``.
        y_train: training targets. They are used as class-probability targets
            with the same shape as the model output, i.e. one-hot / soft labels
            (TODO confirm against load_mnist).

    Returns:
        (best_model, test_acc): a newly built ``CNN`` on ``device`` holding the
        best candidate's weights, and its test accuracy as returned by
        ``accuracy_classification``.
    """
    weight_decays = [0, 1e-6, 5e-6, 1e-5, 1e-4]
    # Works whether `device` is a torch.device or a plain string.
    device_type = torch.device(device).type

    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(x_train, y_train),
        batch_size=batch_size,
        shuffle=True,
    )

    best_score = None
    best_model_state = None

    for dec in weight_decays:
        model = CNN().to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=dec)

        for _ in range(epochs):
            model.train()
            for xb, yb in train_loader:
                xb = xb.to(device)
                yb = yb.to(device=device, dtype=torch.float32)
                optimizer.zero_grad()
                # CNN.forward already applies softmax, and nn.CrossEntropyLoss
                # applies log_softmax itself, so using it here would softmax
                # twice. Instead, compute soft-target cross-entropy directly on
                # the log-probabilities. This equals CrossEntropyLoss with
                # probability targets (mean over the batch) applied to logits.
                log_probs = torch.log(model(xb).clamp_min(1e-10))
                loss = -(yb * log_probs).sum(dim=1).mean()
                loss.backward()
                optimizer.step()

        # NOTE(review): the model is still in train() mode here; confirm that
        # accuracy_classification switches to eval() so dropout is off.
        val_acc = accuracy_classification(x_val, y_val, model, device=device, batch_size=batch_size)

        if best_score is None or val_acc > best_score:
            best_score = val_acc
            # The state_dict tensors keep these weights alive after `del model`.
            best_model_state = model.state_dict()

        # Free accelerator memory between candidates.
        del model, optimizer
        gc.collect()
        if device_type == "mps":
            torch.mps.empty_cache()
        elif device_type == "cuda":
            torch.cuda.empty_cache()

    best_model = CNN().to(device)
    best_model.load_state_dict(best_model_state)

    test_acc = accuracy_classification(x_test, y_test, best_model, device=device, batch_size=batch_size)
    return best_model, test_acc

In [None]:
def train_once_local_opt(
        x_train_cur, y_train_cur,
        acquisition_fn,
        Xs,
        mc_samples=100):
    """Run one active-learning round: fit a model, then rank the pool.

    Args:
        x_train_cur, y_train_cur: current labelled training set.
        acquisition_fn: callable ``(T, model, x) -> scores`` returning one score
            per row of ``x``; the highest-scoring rows are selected.
        Xs: unlabelled pool to score.
        mc_samples: value passed as the first (``T``) argument of
            ``acquisition_fn``. Defaults to 100.

    Returns:
        (test_score, topk_idx): the test accuracy of the fitted model, and the
        indices into ``Xs`` of the top-scoring points. At most
        ``acquired_points`` indices are returned; fewer if the pool is smaller.
    """
    model_curr, test_score = find_best_decay_local(x_train_cur, y_train_cur)

    acq_scores = call_batchwise(
        lambda x: acquisition_fn(mc_samples, model_curr, x),
        Xs, batch_size=32, device=device,
    )

    # Free accelerator memory before the next round. Compare the device *type*:
    # a torch.device never compares equal to a plain string.
    del model_curr
    gc.collect()
    device_type = torch.device(device).type
    if device_type == "mps":
        torch.mps.empty_cache()
    elif device_type == "cuda":
        torch.cuda.empty_cache()

    # torch.topk raises if k exceeds the size of the scored dimension.
    k = min(acquired_points, acq_scores.shape[-1])
    _, topk_idx = torch.topk(acq_scores, k)
    return test_score, topk_idx

In [None]:
from tqdm import tqdm

def train_full_local(acquisition_fn, Xs, ys, x_init_train, y_init_train):
    """Run the full active-learning experiment for one acquisition function.

    Starting from ``(x_init_train, y_init_train)``, this runs
    ``acquisition_times`` rounds. Each round fits a model, records its test
    accuracy, and moves the selected pool points from ``(Xs, ys)`` into the
    training set. After the last round, a final model is fitted on the
    resulting training set.

    The caller's tensors are not modified: the training set is cloned, and the
    pool is rebound locally to filtered copies.

    Returns:
        (scores, model): a float32 tensor of ``acquisition_times + 1`` test
        accuracies (one per round plus the final fit), and the final model.
    """
    scores = []

    x_train_cur = x_init_train.clone()
    y_train_cur = y_init_train.clone()

    for _ in tqdm(range(acquisition_times)):
        score, new_idx = train_once_local_opt(x_train_cur, y_train_cur, acquisition_fn, Xs)
        # torch.as_tensor avoids the copy-construct warning that torch.tensor()
        # emits when given a tensor. The indices go to the CPU once, because
        # the pool is filtered with a CPU boolean mask below.
        new_idx = torch.as_tensor(new_idx, dtype=torch.long).cpu()

        x_train_cur = torch.cat([x_train_cur, Xs[new_idx].to(device)], dim=0)
        y_train_cur = torch.cat([y_train_cur, ys[new_idx].to(device)], dim=0)

        # Drop the acquired points from the pool.
        keep = torch.ones(Xs.shape[0], dtype=torch.bool)
        keep[new_idx] = False
        Xs = Xs[keep]
        ys = ys[keep]

        scores.append(score)

    model, final_score = find_best_decay_local(x_train_cur, y_train_cur)
    scores.append(final_score)
    return torch.tensor(scores, dtype=torch.float32), model


In [None]:
def train_acquisition_local(acq_fun, acq_registry=None, n_runs=3):
    """Return the per-round test accuracy of ``acq_fun``, averaged over runs.

    Each of the ``n_runs`` independent runs is cached in
    ``./results/{run}{acq_fun}_local.npy``. Cached runs are loaded instead of
    recomputed.

    Args:
        acq_fun: key of the acquisition function in ``acq_registry``; also
            used in the cache file names.
        acq_registry: mapping from name to acquisition function. Defaults to
            ``acq_funs`` and is only consulted for runs that are not cached.
        n_runs: number of runs to average (default 3).

    Returns:
        A CPU float32 tensor: the element-wise mean of the runs' score curves.
    """
    os.makedirs("./results", exist_ok=True)
    scores = []

    for i in range(n_runs):
        path = f"./results/{i}{str(acq_fun)}_local.npy"
        if os.path.exists(path):
            score = torch.as_tensor(np.load(path), dtype=torch.float32)
        else:
            registry = acq_funs if acq_registry is None else acq_registry
            score, _ = train_full_local(registry[acq_fun], X_p, y_p, x_train_new, y_train_new)
            score = score.cpu()
            print(score)
            np.save(path, score.numpy())
        # Every run stays on the CPU, so cached and freshly computed runs can
        # always be stacked together.
        scores.append(score)

    return torch.stack(scores, dim=0).mean(dim=0)


In [None]:
# Run (or load from cache) the experiments for every function in `acq_funs`.
# The mean curves are kept in `results_local` for plotting and also saved under
# ./results/.
results_local = {}
for acq_fun in acq_funs:
    res = train_acquisition_local(acq_fun)
    print(res)
    np.save(f"./results/{str(acq_fun)}_local.npy", res.cpu().numpy())
    results_local[acq_fun] = res

In [None]:
import matplotlib.ticker as mt
# Test accuracy vs. labelled-set size for each acquisition function.
# The x positions follow from the experiment configuration: the initial
# training-set size, then +acquired_points after each of the acquisition_times
# rounds (one point per recorded score).
initial_size = x_train_new.shape[0]
steps = list(range(initial_size, initial_size + acquired_points * acquisition_times + 1, acquired_points))

fig, ax = plt.subplots()
for key, curve in results_local.items():
    ax.plot(steps, curve.cpu().numpy(), label=key)
ax.yaxis.set_major_locator(mt.MultipleLocator(0.05))
ax.grid()
ax.set_ylim(bottom=0.55)
ax.set_xlim(left=steps[0], right=steps[-1])
ax.set_xlabel("Number of points")
ax.set_ylabel("Accuracy")
ax.set_title("Test accuracy per acquisition function")
ax.legend()
fig.savefig("./local_acq_plot.svg")

# Deterministic Acquisition

In [None]:
def deterministic_max_entropy(T, model, x):
    """Predictive entropy from a single eval-mode forward pass.

    ``T`` is unused; it is accepted so the function matches the
    ``(T, model, x)`` signature that train_once_local_opt calls with.
    ``model(x)`` is expected to return class probabilities along the last
    dimension. Returns one entropy value per row of ``x``.
    """
    model.eval()  # eval mode disables dropout, so the pass is deterministic
    with torch.no_grad():
        probs = model(x)
        log_probs = torch.log(probs + 1e-10)  # epsilon keeps log finite at p == 0
        return -(probs * log_probs).sum(dim=-1)


def deterministic_bald(T, model, x):
    """BALD score computed from a single eval-mode forward pass.

    BALD is the mutual information ``H[mean_t p_t] - mean_t H[p_t]`` over
    prediction samples ``p_t``. With dropout disabled there is exactly one
    prediction, so both terms are the same entropy and the score is zero for
    every row. The ranking among these tied scores is then decided by
    ``torch.topk``'s tie handling.

    ``T`` is unused; it is accepted so the function matches the
    ``(T, model, x)`` signature that train_once_local_opt calls with.
    Only one forward pass is made. Returns one score per row of ``x``.
    """
    def _entropy(p):
        # Entropy along the class (last) dimension; epsilon keeps log finite.
        return -(p * torch.log(p + 1e-10)).sum(dim=-1)

    model.eval()
    with torch.no_grad():
        # Treat the single deterministic prediction as a stack of one sample.
        samples = model(x).unsqueeze(0)
        predictive_entropy = _entropy(samples.mean(dim=0))
        expected_entropy = _entropy(samples).mean(dim=0)
        return predictive_entropy - expected_entropy


def deterministic_var_ratios(T, model, x):
    """Variation ratio ``1 - max_c p(c | x)`` from one eval-mode forward pass.

    ``T`` is unused; it is accepted so the function matches the
    ``(T, model, x)`` signature that train_once_local_opt calls with.
    Returns one score per row of ``x``.
    """
    model.eval()
    with torch.no_grad():
        top_prob = model(x).amax(dim=-1)
        return 1.0 - top_prob

# Deterministic (single eval-mode forward pass) acquisition functions, keyed by
# the name used in result file names.
det_acq_funs = dict(
    deterministic_max_entropy=deterministic_max_entropy,
    deterministic_bald=deterministic_bald,
    deterministic_var_ratios=deterministic_var_ratios,
)

In [None]:
def train_acquisition_local(acq_fun, acq_registry=None, n_runs=3):
    """Return the per-round test accuracy of ``acq_fun``, averaged over runs.

    This redefines the earlier ``train_acquisition_local``; the only difference
    is that the default registry is ``det_acq_funs``. Cells after this one use
    this definition.

    Each of the ``n_runs`` independent runs is cached in
    ``./results/{run}{acq_fun}_local.npy``. Cached runs are loaded instead of
    recomputed.

    Args:
        acq_fun: key of the acquisition function in ``acq_registry``; also
            used in the cache file names.
        acq_registry: mapping from name to acquisition function. Defaults to
            ``det_acq_funs`` and is only consulted for runs that are not cached.
        n_runs: number of runs to average (default 3).

    Returns:
        A CPU float32 tensor: the element-wise mean of the runs' score curves.
    """
    os.makedirs("./results", exist_ok=True)
    scores = []

    for i in range(n_runs):
        path = f"./results/{i}{str(acq_fun)}_local.npy"
        if os.path.exists(path):
            score = torch.as_tensor(np.load(path), dtype=torch.float32)
        else:
            registry = det_acq_funs if acq_registry is None else acq_registry
            score, _ = train_full_local(registry[acq_fun], X_p, y_p, x_train_new, y_train_new)
            score = score.cpu()
            print(score)
            np.save(path, score.numpy())
        # Every run stays on the CPU, so cached and freshly computed runs can
        # always be stacked together.
        scores.append(score)

    return torch.stack(scores, dim=0).mean(dim=0)

In [None]:
# Run (or load from cache) the experiments for every function in `det_acq_funs`
# and keep the mean curves in `results_local_det` for plotting.
# NOTE(review): unlike the loop over `acq_funs`, these mean curves are written
# to the working directory rather than ./results/; confirm this is intended.
results_local_det = {}
for acq_fun in det_acq_funs:
    res = train_acquisition_local(acq_fun)
    print(res)
    np.save(f"./{str(acq_fun)}.npy", res.cpu().numpy())
    results_local_det[acq_fun] = res

In [None]:
# Imported here as well so this cell does not depend on an import made in an
# earlier cell.
import matplotlib.ticker as mt

# Test accuracy vs. labelled-set size for the deterministic acquisition
# functions. The x positions follow from the experiment configuration: the
# initial training-set size, then +acquired_points per acquisition round.
initial_size = x_train_new.shape[0]
steps = list(range(initial_size, initial_size + acquired_points * acquisition_times + 1, acquired_points))

fig, ax = plt.subplots()
for key, curve in results_local_det.items():
    ax.plot(steps, curve.cpu().numpy(), label=key)
ax.yaxis.set_major_locator(mt.MultipleLocator(0.05))
ax.grid()
ax.set_ylim(bottom=0.55)
ax.set_xlim(left=steps[0], right=steps[-1])
ax.set_xlabel("Number of points")
ax.set_ylabel("Accuracy")
ax.legend()
# Lay out after labels and legend exist so their extents are accounted for.
fig.tight_layout()
fig.savefig("./det_acq_plot.svg")