# Preparing

In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from load_data import load_mnist
from util import *
from acq_functions import *
import os

# Seed both RNGs used by this notebook. NumPy alone is not enough: model
# initialisation, dropout and DataLoader shuffling all draw from torch's RNG.
SEED = 43
np.random.seed(SEED)
torch.manual_seed(SEED)

# load_mnist() comes from the project-local `load_data` module; it returns the
# initial training set, the pool (X_p, y_p), and the validation and test splits.
x_train_new, y_train_new, X_p, y_p, x_val, y_val, x_test, y_test = load_mnist()

x_train shape: torch.Size([60000, 1, 28, 28])
60000 train samples, before reduction
10000 test samples


In [2]:
# Hyper-parameters shared by the training / acquisition cells below.
batch_size = 128          # mini-batch size of the training DataLoader
epochs = 70               # training epochs per candidate model
acquired_points = 10      # pool points added per acquisition round (top-k size)
acquisition_times = 100   # number of acquisition rounds

# Device priority: CUDA, then Apple MPS, then CPU.
# `torch.backends.mps.is_available()` is the long-standing availability check;
# `torch.mps.is_available()` is missing from older PyTorch releases and would
# raise AttributeError there, even on machines that never use MPS.
# The CPU fallback stays a plain string, as before, so existing comparisons
# against "cpu" keep working.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = "cpu"

# Move the fixed splits to the chosen device once. The pool (X_p, y_p) is not
# moved here.
x_train_new = x_train_new.to(device)
y_train_new = y_train_new.to(device)
x_val = x_val.to(device)
y_val = y_val.to(device)
x_test = x_test.to(device)
y_test = y_test.to(device)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    """Two-convolution network with a dropout-regularised two-layer head.

    The first linear layer expects 32 * 11 * 11 features, which is what a
    single-channel 28x28 input produces after the convolutional stack.

    Args:
        num_classes: Size of the output layer (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Spatial size: 28 -> 25 -> 22 through the two unpadded 4x4
        # convolutions, then 11 after the 2x2 max-pool.
        feature_layers = [
            nn.Conv2d(1, 32, kernel_size=4),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        self.conv = nn.Sequential(*feature_layers)

        self.dropout1 = nn.Dropout(0.25)
        self.fc1 = nn.Linear(32 * 11 * 11, 128)
        self.dropout2 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return the raw (un-normalised) outputs for a batch `x`."""
        features = self.conv(x).flatten(1)
        hidden = torch.relu(self.fc1(self.dropout1(features)))
        return self.fc2(self.dropout2(hidden))


In [4]:
# Acquisition strategies to compare, keyed by the name used in result files.
# `random` and `bald` come from the star-import of `acq_functions`.
acq_funs = dict(random=random, bald=bald)


In [5]:
# NOTE(review): neither value is read by the cells visible in this notebook;
# confirm where they are consumed before changing them.
sampling_steps, mc_sampling = 5, False

# Local training


In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import gc

def find_best_decay_local(x_train, y_train):
    """Train one CNN per candidate weight decay and keep the best on validation.

    For every weight decay in a fixed grid, a fresh `CNN` is trained for
    `epochs` epochs with Adam (lr=0.001) on `(x_train, y_train)` using
    `rmse_loss`, then scored on the global validation split with `evaluate`.

    Args:
        x_train: Training inputs.
        y_train: Training targets; cast to float32 before the loss is computed.

    Returns:
        A tuple `(best_model, test_score)`: a `CNN` on `device` loaded with the
        selected weights, and the result of `evaluate` on the global test split.
    """
    weight_decays = [0, 1e-6, 5e-6, 1e-5, 1e-4]
    best_score = None
    best_model_state = None

    train_dataset = torch.utils.data.TensorDataset(x_train, y_train)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )

    # Loop-invariant: place the validation split on the device once.
    x_val_dev = x_val.to(device=device)
    y_val_dev = y_val.to(device=device)

    # `device` may be a torch.device or a plain string. Comparing a
    # torch.device to a string with `==` is always False, so branch on `.type`.
    device_type = torch.device(device).type

    for dec in weight_decays:
        model = CNN().to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=dec)
        criterion = rmse_loss

        for _ in range(epochs):
            model.train()
            for xb, yb in train_loader:
                xb = xb.to(device)
                yb = yb.to(device)
                optimizer.zero_grad()
                logits = model(xb)
                loss = criterion(logits, yb.to(dtype=torch.float32))
                loss.backward()
                optimizer.step()

        val_score = evaluate(x_val_dev, y_val_dev, rmse_loss, model)

        # `evaluate` is driven by `rmse_loss`, so the score is treated as an
        # error: the smallest value wins. The first candidate is always kept.
        # NOTE(review): confirm against util.evaluate that lower is better.
        if best_score is None or val_score < best_score:
            best_score = val_score
            best_model_state = model.state_dict()

        del model  # save space if running locally
        gc.collect()
        if device_type == "mps":
            torch.mps.empty_cache()
        elif device_type == "cuda":
            torch.cuda.empty_cache()

    best_model = CNN().to(device)
    best_model.load_state_dict(best_model_state)

    test_acc = evaluate(x_test.to(device=device), y_test.to(device=device), rmse_loss, best_model)
    return best_model, test_acc

In [7]:
def train_once_local_opt(
        x_train_cur, y_train_cur,
        acquisition_fn,
        Xs):
    """Run one acquisition round: train, score the pool, pick the top points.

    Args:
        x_train_cur: Current training inputs.
        y_train_cur: Current training targets.
        acquisition_fn: Callable invoked as `acquisition_fn(100, model, x)` on
            batches of the pool; expected to return one score per pool item.
        Xs: Pool of candidate inputs to score.

    Returns:
        A tuple `(test_score, topk_idx)`: the test score reported by
        `find_best_decay_local`, and the indices into `Xs` of the
        highest-scoring items (at most `acquired_points` of them).
    """
    model_curr, test_score = find_best_decay_local(
        x_train_cur, y_train_cur
    )

    acq_lambda = lambda x: acquisition_fn(100, model_curr, x)
    acq_scores = call_batchwise(acq_lambda, Xs, batch_size=32, device=device)

    del model_curr  # save space if running locally
    gc.collect()
    # `device` may be a torch.device or a plain string. Comparing a
    # torch.device to a string with `==` is always False, so branch on `.type`.
    device_type = torch.device(device).type
    if device_type == "mps":
        torch.mps.empty_cache()
    elif device_type == "cuda":
        torch.cuda.empty_cache()

    # torch.topk raises when k exceeds the number of candidates, which happens
    # once the pool is nearly exhausted; clamp k to what is available.
    k = min(acquired_points, acq_scores.shape[-1])
    _, topk_idx = torch.topk(acq_scores, k)
    return test_score, topk_idx

In [8]:
from tqdm import tqdm

def train_full_local(acquisition_fn, Xs, ys, x_init_train, y_init_train):
    """Run the full active-learning loop for one acquisition function.

    Each round trains on the current training set, records the test score,
    moves the selected pool items into the training set and removes them from
    the pool. After the last round a final model is trained and scored.

    Args:
        acquisition_fn: Acquisition function forwarded to `train_once_local_opt`.
        Xs: Pool inputs. The caller's tensor is not modified; the pool is
            shrunk by re-binding the local name to filtered copies.
        ys: Pool targets aligned with `Xs`.
        x_init_train: Initial training inputs (cloned before use).
        y_init_train: Initial training targets (cloned before use).

    Returns:
        A tuple `(scores, model)`: a float32 tensor with one test score per
        completed round plus the final one, and the finally trained model.
    """
    scores = []

    # Work on copies so the caller's initial training set stays untouched.
    x_train_cur = x_init_train.clone()
    y_train_cur = y_init_train.clone()

    for _ in tqdm(range(acquisition_times)):
        if Xs.shape[0] == 0:
            # Pool exhausted: nothing left to acquire.
            break

        score, picked = train_once_local_opt(x_train_cur, y_train_cur, acquisition_fn, Xs)

        # `picked` is already an index tensor. Convert it once to CPU int64
        # rather than re-wrapping it with torch.tensor(), which warns about
        # copy-constructing from a tensor.
        picked_idx = picked.detach().to(device="cpu", dtype=torch.long)

        # Append the acquired items on whatever device the training set is on.
        x_train_cur = torch.cat([x_train_cur, Xs[picked_idx].to(x_train_cur.device)], dim=0)
        y_train_cur = torch.cat([y_train_cur, ys[picked_idx].to(y_train_cur.device)], dim=0)

        # Drop the acquired items from the pool.
        keep = torch.ones(Xs.shape[0], dtype=torch.bool)
        keep[picked_idx] = False
        Xs = Xs[keep]
        ys = ys[keep]

        scores.append(float(score))

    model, final_score = find_best_decay_local(x_train_cur, y_train_cur)
    scores.append(float(final_score))
    scores = torch.tensor(scores, dtype=torch.float32)

    return scores, model


In [9]:
def train_acquisition_local(acq_fun, n_runs=3):
    """Average the learning curve of one acquisition function over several runs.

    Each run's curve is cached as `./results/{i}{acq_fun}_rmse.npy`; a run
    whose file already exists is loaded instead of recomputed.

    Args:
        acq_fun: Key into `acq_funs` naming the acquisition function.
        n_runs: Number of independent runs to average (default 3).

    Returns:
        A tensor on `device` holding the element-wise mean of the per-run
        score curves.
    """
    scores = []
    os.makedirs("./results", exist_ok=True)

    for i in range(n_runs):
        cache_path = f"./results/{i}{str(acq_fun)}_rmse.npy"
        if os.path.exists(cache_path):
            score = torch.Tensor(np.load(cache_path))
        else:
            # Keep only the scores so the trained model can be freed right away.
            score = train_full_local(acq_funs[acq_fun], X_p, y_p, x_train_new, y_train_new)[0]
            print(score)
            np.save(cache_path, score.cpu().numpy())
        # Put cached and freshly computed curves on the same device;
        # torch.stack fails on tensors that live on different devices.
        scores.append(score.to(device=device))

    meaned_scores = torch.stack(scores, dim=0).mean(dim=0)
    return meaned_scores


In [10]:
# Run every registered acquisition function, print and persist its averaged
# curve, and keep the curves in memory keyed by acquisition name.
results_local = {}
for acq_fun in acq_funs.keys():
    res = train_acquisition_local(acq_fun)
    print(res)
    mean_curve_np = res.cpu().numpy()
    np.save(f"./results/{str(acq_fun)}_rmse.npy", mean_curve_np)
    results_local[acq_fun] = res

tensor([0.2525, 0.2499, 0.2324, 0.2272, 0.2207, 0.2158, 0.2129, 0.2068, 0.2036,
        0.1986, 0.1943, 0.2145, 0.1892, 0.1888, 0.1795, 0.1795, 0.1733, 0.1722,
        0.1707, 0.1671, 0.1644, 0.1626, 0.1633, 0.1582, 0.1730, 0.1632, 0.1617,
        0.1586, 0.1545, 0.1540, 0.1483, 0.1530, 0.1518, 0.1519, 0.1450, 0.1438,
        0.1462, 0.1523, 0.1444, 0.1426, 0.1418, 0.1393, 0.1400, 0.1380, 0.1347,
        0.1333, 0.1346, 0.1345, 0.1349, 0.1349, 0.1377, 0.1330, 0.1359, 0.1303,
        0.1332, 0.1295, 0.1324, 0.1334, 0.1295, 0.1288, 0.1301, 0.1322, 0.1254,
        0.1321, 0.1327, 0.1312, 0.1241, 0.1314, 0.1252, 0.1264, 0.1235, 0.1241,
        0.1225, 0.1215, 0.1252, 0.1406, 0.1285, 0.1245, 0.1270, 0.1238, 0.1205,
        0.1237, 0.1227, 0.1210, 0.1198, 0.1239, 0.1210, 0.1200, 0.1245, 0.1201,
        0.1189, 0.1190, 0.1176, 0.1161, 0.1195, 0.1147, 0.1169, 0.1152, 0.1145,
        0.1193, 0.1135], device='mps:0')


  0%|          | 0/100 [00:28<?, ?it/s]


KeyboardInterrupt: 