In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as plt
import os
from util import *
from load_data import *
# Seed every RNG the experiment draws from. Seeding numpy alone is not enough:
# model initialisation, DataLoader shuffling, dropout and the torch.randperm
# context/target split all use torch's generator.
SEED = 43
np.random.seed(SEED)
torch.manual_seed(SEED)

batch_size = 128
epochs = 70  # training epochs per weight-decay candidate
acquired_points = 10  # pool points moved into the training set per acquisition round
num_classes = 10
acquisition_times = 100  # number of acquisition rounds
data_variances = [0.05, 0.02, 0.01, 0.1, 0.005]  # NOTE(review): not referenced in the visible cells

# Device priority: CUDA, then Apple MPS, then CPU. `device` is always a
# torch.device so every consumer sees the same type.
# torch.backends.mps.is_available() is the documented MPS availability check.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

x_train_new, y_train_new, X_p, y_p, x_val, y_val, x_test, y_test = load_mnist()
# The labelled training set lives on `device`; the unlabelled pool (X_p, y_p)
# stays on the CPU, where the acquisition loop indexes and filters it.
x_train_new = x_train_new.to(dtype=torch.float32).to(device)
y_train_new = y_train_new.to(dtype=torch.float32).to(device)
X_p = X_p.to(dtype=torch.float32)
y_p = y_p.to(dtype=torch.float32)
x_val = x_val.to(device)
y_val = y_val.to(device)
x_test = x_test.to(device)
y_test = y_test.to(device)

100%|██████████| 9.91M/9.91M [00:00<00:00, 18.0MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 480kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.50MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.63MB/s]


x_train shape: torch.Size([60000, 1, 28, 28])
60000 train samples, before reduction
10000 test samples


In [7]:
# Prior standard deviation of the implicit output weights; RBFCNN uses it to
# scale both the diagonal noise jitter and the predictive variance.
weights_prior_std = 1

class RBFKernel(torch.nn.Module):
    """Squared-exponential (RBF) kernel with one learnable lengthscale per feature.

    k(a, b) = exp(log_variance) * exp(-0.5 * ||a / l - b / l||^2), where
    l = exp(log_lengthscale). Both hyperparameters live in log space and start
    at 0, i.e. unit lengthscales and unit signal variance.

    Args:
        d: number of input features (length of the lengthscale vector).
    """

    def __init__(self, d):
        super().__init__()
        self.log_lengthscale = torch.nn.Parameter(torch.zeros(d))
        self.log_variance = torch.nn.Parameter(torch.tensor(0.0))

    def forward(self, X, Y=None):
        """Kernel matrix between the rows of X and the rows of Y (Y defaults to X)."""
        other = X if Y is None else Y
        scaled_x = X / torch.exp(self.log_lengthscale)
        scaled_other = other / torch.exp(self.log_lengthscale)
        sq_dist = torch.cdist(scaled_x, scaled_other).pow(2)
        return torch.exp(self.log_variance) * torch.exp(-0.5 * sq_dist)

In [8]:
class CNN(torch.nn.Module):
    """Baseline convolutional classifier for 1x28x28 images.

    Pipeline (spatial size in brackets): conv4x4 1->32 [25] -> ReLU ->
    conv4x4 32->32 [22] -> ReLU -> maxpool2 [11] -> flatten (32*11*11) ->
    dropout(0.25) -> linear 256 -> ReLU -> linear 128 -> ReLU ->
    linear num_classes. The output is unnormalised logits.

    ``conv`` and ``fc`` keep exactly this layer order so that state_dict keys
    (e.g. ``conv.7.weight``, ``fc.1.weight``) stay stable.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        feature_layers = [
            nn.Conv2d(1, 32, kernel_size=4),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Dropout(0.25),
            nn.Linear(32 * 11 * 11, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
        ]
        self.conv = nn.Sequential(*feature_layers)
        head_layers = [nn.ReLU(), nn.Linear(128, num_classes)]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map a batch of images to class logits of shape (batch, num_classes)."""
        return self.fc(self.conv(x))

In [9]:
class RBFCNN(torch.nn.Module):
    """CNN feature extractor followed by an RBF-kernel regression head.

    ``conv`` maps each image to a 128-d feature vector phi. Given a reference
    set (X, Y), predictions are the kernel-regression posterior::

        A       = k(phi_X, phi_X) + (noise**2 / weights_prior_std**2) * I
        mean(x) = (Y^T A^{-1} k(phi_X, phi_x))^T
        var(x)  = weights_prior_std**2 * diag(k(phi_x, phi_x)
                                              - k(phi_x, phi_X) A^{-1} k(phi_X, phi_x))

    In training mode the reference set is passed to ``forward``/``pred_var``
    and A^{-1} is recomputed on every call. In eval mode those arguments are
    ignored. The stored ``X_t``/``Y_t`` (see ``set_trainset``) and the cached
    ``big_kernel_inv`` (see ``set_big_kernel_inv``) are used instead.

    Args:
        noise: noise scale used in the diagonal jitter of A.
        num_classes: accepted for interface parity with ``CNN``; unused here
            (the output width is the number of columns of Y).
    """

    def __init__(self, noise=0.1, num_classes=10):
        super().__init__()
        self.noise = noise
        # Same feature extractor as CNN.conv; outputs 128-d features.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4),  # -3 width/height
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=4),  # -3 width/height
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # /2 width/height
            nn.Flatten(),
            nn.Dropout(0.25),
            nn.Linear(32 * 11 * 11, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
        )
        self.kernel = RBFKernel(128)
        # Eval-mode state, filled by set_big_kernel_inv / set_trainset.
        self.big_kernel_inv = None
        self.X_t = None
        self.Y_t = None

    def set_trainset(self, X, Y):
        """Store the reference inputs/targets used by eval-mode predictions."""
        self.X_t = X
        self.Y_t = Y

    def set_big_kernel_inv(self, X):
        """Cache A^{-1} for reference inputs X (should match those given to set_trainset).

        Call this in eval mode. ``conv`` contains Dropout, and eval-mode
        predictions re-encode ``X_t`` without it, so the cache must be
        dropout-free as well.
        """
        self.big_kernel_inv = self._regularised_kernel_inv(self.conv(X))

    def _regularised_kernel_inv(self, phi_X):
        """Return A^{-1} = (k(phi_X, phi_X) + (noise^2 / prior_std^2) I)^{-1}."""
        gram = self.kernel(phi_X, phi_X)
        jitter = (self.noise ** 2) / (weights_prior_std ** 2)
        eye = torch.eye(gram.shape[0], device=gram.device, dtype=gram.dtype)
        return torch.linalg.inv(gram + jitter * eye)

    def _reference(self, X, Y):
        """Return (phi_X, Y, A^{-1}) for the current mode.

        phi_X is produced by a single ``conv`` pass. In training mode, Dropout
        is therefore sampled once, and the Gram matrix and the cross-kernel see
        the same features. Encoding X separately for each use would give a
        non-symmetric Gram matrix.
        """
        if self.training:
            phi_X = self.conv(X)
            return phi_X, Y, self._regularised_kernel_inv(phi_X)
        if self.big_kernel_inv is None or self.X_t is None:
            raise RuntimeError(
                "RBFCNN in eval mode needs set_trainset() and set_big_kernel_inv() "
                "to be called first"
            )
        return self.conv(self.X_t), self.Y_t, self.big_kernel_inv

    def pred_var(self, x, X=None, Y=None):
        """Posterior predictive variance at the query points (one value per row of x)."""
        phi_X, _, big_inv = self._reference(X, Y)
        phi_x = self.conv(x)
        k_ast = self.kernel(phi_X, phi_x)  # (n_reference, n_query)
        prior_var = self.kernel(phi_x, phi_x).diagonal()
        explained = (self.kernel(phi_x, phi_X) @ big_inv @ k_ast).diagonal()
        return (weights_prior_std ** 2) * prior_var - (weights_prior_std ** 2) * explained

    def forward(self, x, X=None, Y=None):
        """Posterior predictive mean at the query points.

        For 2-D Y of shape (n_reference, C) the result has shape (n_query, C).
        """
        phi_X, Y, big_inv = self._reference(X, Y)
        phi_x = self.conv(x)
        k_ast = self.kernel(phi_X, phi_x)  # (n_reference, n_query)
        return (Y.T @ big_inv @ k_ast).T
    

In [10]:
def nll_kernel(X, Y, x, y, model):
    """Gaussian negative log-likelihood of targets ``y`` under ``model``'s posterior.

    Each query point gets a single predictive variance, which is shared across
    all target columns::

        0.5 * sum_{i,c} [ (y_ic - mean_ic)^2 / var_i + log(2*pi) + log(var_i) ]

    Args:
        X, Y: reference inputs/targets forwarded to ``model.forward`` and
            ``model.pred_var`` (RBFCNN ignores them in eval mode; they may be None there).
        x: query inputs.
        y: query targets, expected 2-D (rows match x, columns match the mean).
        model: object exposing ``forward(x, X, Y)`` and ``pred_var(x, X, Y)``.

    Returns:
        Scalar tensor, summed (not averaged) over points and columns.

    NOTE(review): mean and variance come from two separate model calls. In
    training mode each call samples its own dropout mask. pred_var is also not
    clamped, so a non-positive variance yields inf/NaN.
    """
    import math

    mean = model.forward(x, X, Y)
    # Broadcast the per-point variance over the target columns with a view,
    # instead of gathering with a float numpy index array.
    var = model.pred_var(x, X, Y).unsqueeze(1).expand(-1, y.shape[-1])
    terms = (y - mean) ** 2 / var + math.log(2.0 * math.pi) + torch.log(var)
    return 0.5 * torch.sum(terms)

In [11]:
def evaluate_kernel(X, Y, x_val, y_val, model=None):
    """Kernel NLL of (x_val, y_val) under ``model`` conditioned on (X, Y), as a float.

    ``nll_kernel`` needs the model as its fifth argument. It is therefore an
    explicit parameter here; without it, the call cannot be made.

    Raises:
        TypeError: if ``model`` is not provided.
    """
    if model is None:
        raise TypeError("evaluate_kernel() requires a `model` argument")
    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        loss = nll_kernel(X, Y, x_val, y_val, model)
    return loss.item()

# Pipeline: weight-decay search, variance-based acquisition, and the active-learning loop

In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
import gc

def _train_kernel_model(model, optimizer, train_loader, context_ratio):
    """Train ``model`` for ``epochs`` epochs on the context/target kernel NLL."""
    model.train()
    for _ in range(epochs):
        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device).to(dtype=torch.float32)
            # Random split of the batch: the context part conditions the kernel
            # posterior, the target part is scored under it.
            n = xb.size(0)
            perm = torch.randperm(n)
            n_ctx = int(context_ratio * n)
            ctx_idx, tgt_idx = perm[:n_ctx], perm[n_ctx:]

            optimizer.zero_grad()
            loss = nll_kernel(xb[ctx_idx], yb[ctx_idx], xb[tgt_idx], yb[tgt_idx], model)
            loss.backward()
            optimizer.step()


def find_best_decay_local_cnn(x_train, y_train, m_type=RBFCNN):
    """Grid-search Adam weight decay for a kernel model and return the best one.

    For each candidate weight decay, a fresh ``m_type`` model is trained with
    the kernel NLL on random context/target splits of each mini-batch. It is
    then conditioned on the full training set and scored on the validation
    split with ``evaluate(..., rmse_loss, ...)``.

    Args:
        x_train, y_train: labelled training inputs/targets.
        m_type: model class with the RBFCNN interface (set_trainset,
            set_big_kernel_inv, forward(x, X, Y), pred_var(x, X, Y)).

    Returns:
        (best_model, test_score): the selected model, in eval mode and
        conditioned on the full training set, and its test score.
    """
    weight_decays = [0, 1e-6, 5e-6, 1e-5, 1e-4]
    context_ratio = 0.9
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(x_train, y_train),
        batch_size=batch_size,
        shuffle=True,
    )
    # Comparing a torch.device with a str is always False, so normalise to the
    # type name before choosing which cache to empty.
    device_type = torch.device(device).type

    best_score = None
    best_model_state = None
    for dec in weight_decays:
        model = m_type().to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=dec)
        _train_kernel_model(model, optimizer, train_loader, context_ratio)

        # Condition on the full training set with dropout off. The cached
        # inverse is a constant, so do not record a graph over the whole set.
        model.eval()
        with torch.no_grad():
            model.set_trainset(x_train, y_train)
            model.set_big_kernel_inv(x_train)
        val_score = evaluate(x_val.to(device=device), y_val.to(device=device), rmse_loss, model)

        # rmse_loss is an error metric, so keep the LOWEST validation score.
        # NOTE(review): assumes evaluate(..., rmse_loss, ...) returns the RMSE
        # itself (a `>` comparison would keep the largest error). Confirm in util.
        if best_score is None or val_score < best_score:
            best_score = val_score
            best_model_state = model.state_dict()

        del model, optimizer  # free device memory before the next candidate
        gc.collect()
        if device_type == "mps":
            torch.mps.empty_cache()
        elif device_type == "cuda":
            torch.cuda.empty_cache()

    best_model = m_type()
    best_model.load_state_dict(best_model_state)
    best_model = best_model.to(device=device)
    # A freshly built module is in training mode. Switch to eval mode BEFORE
    # caching the inverse, otherwise it would be computed with dropout active
    # while eval-mode predictions re-encode the training set without dropout.
    best_model.eval()
    with torch.no_grad():
        best_model.set_trainset(x_train, y_train)
        best_model.set_big_kernel_inv(x_train)

    test_score = evaluate(x_test.to(device=device), y_test.to(device=device), rmse_loss, best_model)
    return best_model, test_score

In [13]:
def compute_var(T, model, x):
    """Acquisition score: the negated posterior predictive variance of ``model`` at ``x``.

    Because the variance is negated, picking the largest scores (e.g. with
    ``topk``) selects the points with the SMALLEST predictive variance.
    NOTE(review): uncertainty sampling usually targets the largest variance;
    confirm the sign is intended.

    Args:
        T: not used by this function.
        model: object exposing ``pred_var(x)`` (e.g. an RBFCNN in eval mode).
        x: batch of query inputs.
    """
    variance = model.pred_var(x)
    return -variance

In [14]:
def train_once_local_pred_var(x_train_cur, y_train_cur, Xs, model_t=RBFCNN):
    """Fit the kernel acquisition model and choose the next pool points to label.

    Args:
        x_train_cur, y_train_cur: current labelled set.
        Xs: unlabelled pool inputs to score.
        model_t: kernel model class passed to ``find_best_decay_local_cnn``.

    Returns:
        (test_score, x_new): the kernel model's test score, and a numpy array of
        the ``acquired_points`` row indices of ``Xs`` with the largest
        ``compute_var`` score.
    """
    model, test_score = find_best_decay_local_cnn(x_train_cur, y_train_cur, model_t)
    # compute_var calls pred_var without a reference set, which only works in
    # eval mode (stored training set + cached inverse). Make that explicit.
    model.eval()

    def acquisition_score(batch):
        # The first argument (T) is not used by compute_var.
        return compute_var(100, model, batch)

    # Scores are only ranked, never back-propagated. Skip building an autograd
    # graph per pool batch; each eval-mode call also re-encodes the training set.
    with torch.no_grad():
        acq_scores = call_batchwise(acquisition_score, Xs, batch_size=64, device=device)
    x_new = acq_scores.topk(acquired_points).indices.cpu().numpy()
    return test_score, x_new

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import gc

def find_best_decay(x_train, y_train):
    """Grid-search Adam weight decay for the baseline ``CNN`` classifier.

    A fresh CNN is trained for ``epochs`` epochs per candidate with
    cross-entropy. The candidate with the highest validation accuracy
    (``accuracy_classification``) is rebuilt from its saved weights and scored
    on the test split.

    Returns:
        (best_model, test_acc): the selected CNN (in eval mode) and its test accuracy.
    """
    weight_decays = [0, 1e-6, 5e-6, 1e-5, 1e-4]
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(x_train, y_train),
        batch_size=batch_size,
        shuffle=True,
    )
    criterion = nn.CrossEntropyLoss()
    # Comparing a torch.device with a str is always False, so normalise to the
    # type name before choosing which cache to empty.
    device_type = torch.device(device).type

    best_score = None
    best_model_state = None
    for dec in weight_decays:
        model = CNN().to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=dec)

        model.train()
        for _ in range(epochs):
            for xb, yb in train_loader:
                xb = xb.to(device)
                yb = yb.to(device)
                optimizer.zero_grad()
                # Float targets make CrossEntropyLoss treat each row as class
                # probabilities (presumably one-hot labels).
                loss = criterion(model(xb), yb.to(dtype=torch.float32))
                loss.backward()
                optimizer.step()

        # CNN contains Dropout, so score with it disabled. This is harmless if
        # accuracy_classification also switches to eval mode.
        model.eval()
        val_acc = accuracy_classification(x_val, y_val, model, device=device, batch_size=batch_size)
        if best_score is None or val_acc > best_score:
            best_score = val_acc
            best_model_state = model.state_dict()

        del model, optimizer  # save space if running locally
        gc.collect()
        if device_type == "mps":
            torch.mps.empty_cache()
        elif device_type == "cuda":
            torch.cuda.empty_cache()

    best_model = CNN().to(device)
    best_model.load_state_dict(best_model_state)
    best_model.eval()

    test_acc = accuracy_classification(x_test, y_test, best_model, device=device, batch_size=batch_size)
    return best_model, test_acc

In [15]:
from tqdm import tqdm
def train_full_local_pv(Xs, ys, x_init_train, y_init_train, model_t=RBFCNN):
    """Active-learning loop that acquires pool points via the kernel model.

    Each of ``acquisition_times`` rounds does three things:
    - scores the current labelled set with the baseline CNN (``find_best_decay``);
    - fits the kernel model on the same set;
    - moves the ``acquired_points`` selected rows of the pool (Xs, ys) into the
      labelled set.

    One last CNN score is recorded after the final round. The caller's tensors
    are not modified: the labelled set is cloned, and the pool is re-bound to
    filtered copies.

    Returns:
        (scores, final_model): float32 tensor of ``acquisition_times + 1`` CNN
        test scores, and the CNN returned by the final ``find_best_decay`` call.
    """
    pool_x, pool_y = Xs, ys
    labelled_x = x_init_train.detach().clone()
    labelled_y = y_init_train.detach().clone()
    history = []

    for _round in tqdm(range(acquisition_times)):
        _, round_score = find_best_decay(labelled_x, labelled_y)
        _, picked = train_once_local_pred_var(labelled_x, labelled_y, pool_x, model_t)
        picked_idx = torch.tensor(picked, dtype=torch.long).cpu()

        labelled_x = torch.cat([labelled_x, pool_x[picked_idx].to(device)], dim=0)
        labelled_y = torch.cat([labelled_y, pool_y[picked_idx].to(device)], dim=0)

        # Remove the acquired rows from the pool.
        keep = torch.ones(pool_x.shape[0], dtype=torch.bool)
        keep[picked_idx] = False
        pool_x = pool_x[keep]
        pool_y = pool_y[keep]

        history.append(round_score)

    final_model, final_score = find_best_decay(labelled_x, labelled_y)
    history.append(final_score)
    return torch.tensor(history, dtype=torch.float32), final_model

In [16]:
def train_acquisition_kernel(n_runs=3, results_dir="./vi_results", use_cache=False):
    """Repeat the kernel-acquisition active-learning experiment and average the curves.

    Each run calls ``train_full_local_pv`` on the module-level pool (X_p, y_p)
    and the initial training set. It saves the run's score curve to
    ``{results_dir}/{run}kernel_local.npy`` and prints it.

    Args:
        n_runs: number of independent runs (default 3).
        results_dir: directory for the per-run ``.npy`` files.
        use_cache: if True, load an existing per-run file instead of recomputing
            that run. The default recomputes everything.

    Returns:
        Element-wise mean of the per-run score tensors.

    Raises:
        ValueError: if ``n_runs`` < 1 (there would be nothing to average).
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")
    os.makedirs("./model_artifacts", exist_ok=True)  # NOTE(review): not written to in the visible cells
    os.makedirs(results_dir, exist_ok=True)

    run_scores = []
    for run in range(n_runs):
        path = os.path.join(results_dir, f"{run}kernel_local.npy")
        if use_cache and os.path.exists(path):
            score = torch.as_tensor(np.load(path), dtype=torch.float32)
        else:
            score, _ = train_full_local_pv(X_p, y_p, x_train_new, y_train_new)
            np.save(path, score.detach().cpu().numpy())
        print(score)
        run_scores.append(score)

    return torch.mean(torch.stack(run_scores), dim=0)

In [None]:
# Run the repeated experiment, show the averaged score curve, and persist it
# alongside the per-run results.
results_path = "./vi_results/kernel_results.npy"
res = train_acquisition_kernel()
print(res)
np.save(results_path, res.numpy())

 72%|███████▏  | 72/100 [35:49<24:20, 52.17s/it]