In [4]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as plt
from load_data import *
from util import *
import os
np.random.seed(43)

batch_size = 128
epochs = 70
acquired_points = 10
num_classes = 10
acquisition_times = 100
data_variances = [0.05, 0.02, 0.01, 0.1, 0.005]
device = torch.device("mps") if torch.mps.is_available() else "cpu"
device = torch.device("cuda") if torch.cuda.is_available() else device
x_train_new, y_train_new, X_p, y_p, x_val, y_val, x_test, y_test = load_mnist()
x_train_new = x_train_new.to(dtype=torch.float32).to(device)
y_train_new = y_train_new.to(dtype=torch.float32).to(device)
X_p = X_p.to(dtype=torch.float32)
y_p = y_p.to(dtype=torch.float32)
x_val = x_val.to(device)
y_val = y_val.to(device)
x_test = x_test.to(device)
y_test = y_test.to(device)

x_train shape: torch.Size([60000, 1, 28, 28])
60000 train samples, before reduction
10000 test samples


In [5]:
weights_prior_std = 0.01

class RBFKernel(torch.nn.Module):
    def __init__(self, d):
        super().__init__()
        self.log_lengthscale = torch.nn.Parameter(torch.zeros(d))
        self.log_variance = torch.nn.Parameter(torch.tensor(0.0))

    def forward(self, X, Y=None):
        if Y is None:
            Y = X
        
        X = X / torch.exp(self.log_lengthscale)
        Y = Y / torch.exp(self.log_lengthscale)
        dist2 = torch.cdist(X, Y) ** 2
        print(dist2.shape)
        return torch.exp(self.log_variance) * torch.exp(-0.5 * dist2)

In [17]:
class CNN(torch.nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4), # -3 width/height
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=4), # -3 width/height
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2), # /2 width/height,
            nn.Flatten(),
            nn.Dropout(0.25),
            nn.Linear(32 * 11 * 11, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )
        self.fc = nn.Sequential(
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, x):
        x = self.conv(x)
        x = self.fc(x)
        return x

In [30]:
class RBFCNN(torch.nn.Module):
    def __init__(self, cnn, noise=0.1, num_classes=10):
        super().__init__()
        self.noise = noise

        if cnn is not None:
            self.conv = cnn
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(1, 32, kernel_size=4), # -3 width/height
                nn.ReLU(),
                nn.Conv2d(32, 32, kernel_size=4), # -3 width/height
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2), # /2 width/height,
                nn.Flatten(),
                nn.Dropout(0.25),
                nn.Linear(32 * 11 * 11, 256),
                nn.ReLU(),
                nn.Linear(256, 128)
                
            )
        self.kernel = RBFKernel(128)
        # nn.init.normal_(self.fc2.weight, mean=0.0, std=weights_prior_std)
        # nn.init.zeros_(self.fc2.bias)
        self.big_kernel_inv = None
    
    def set_big_kernel_inv(self, X):
        K = self.kernel(self.conv(X), self.conv(X))
        self.big_kernel_inv = torch.linalg.inv(K + (self.noise ** 2) / (weights_prior_std ** 2))
    
    def pred_var(self, X, Y, x):
        phi_x = self.conv(x)
        phi_X = self.conv(X)
        k_ast = self.kernel(phi_X, phi_x)
        pred_var = (weights_prior_std ** 2) * self.kernel(phi_x, phi_x) - (weights_prior_std ** 2) * self.kernel(phi_x, phi_X) @ self.big_kernel_inv @ k_ast
        return pred_var

    def forward(self, X, Y, x):
        phi_x = self.conv(x)
        phi_X = self.conv(X)
        k_ast = self.kernel(phi_X, phi_x)
        pred_mean = Y.T @ self.big_kernel_inv @ k_ast
        
        return pred_mean
    

In [40]:
def nll_marginal_kernel(X, Y, model: RBFCNN):
    K = model.kernel(model.conv(X))
    loss = Y.T @ torch.linalg.inv(K + (model.noise ** 2) / (weights_prior_std ** 2)) + torch.log(torch.linalg.det(K))
    return loss.sum()

# Pipeline

In [32]:
import torch
import torch.nn as nn
import torch.optim as optim
import gc

def find_best_decay_local_cnn(x_train, y_train, m_type=CNN):
    weight_decays = [0, 1e-6, 5e-6, 1e-5, 1e-4]
    best_score = 0
    best_model_state = None
    best_i = 0
    train_dataset = torch.utils.data.TensorDataset(x_train, y_train)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )

    for i, dec in enumerate(weight_decays):
        model = m_type().to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=dec)
        total_loss = 0
        non_increasing = 0
        best_loss = None

        for epoch in range(epochs):
            model.train()
            for xb, yb in train_loader:
                xb = xb.to(device)
                yb = yb.to(device)
                optimizer.zero_grad()
                logits = model(xb)
                loss = rmse_loss(logits, yb.to(dtype=torch.float32))
                total_loss += loss.item()

                loss.backward()
                optimizer.step()
            if best_loss is None or total_loss < best_loss:
                non_increasing = 0
                best_loss = total_loss
            else:
                if non_increasing == 4:
                    break
                non_increasing += 1
            total_loss = 0

        val_acc = evaluate(x_val.to(device=device), y_val.to(device=device), rmse_loss, model)

        if val_acc > best_score or i == 0:
            best_score = val_acc
            best_i = i
            best_model_state = model.state_dict()

        del model # save space if running locally
        gc.collect()
        if device == "mps":
            torch.mps.empty_cache()
        elif device == "cuda":
            torch.cuda.empty_cache()

    best_model = m_type()
    best_model.load_state_dict(best_model_state)
    best_model = best_model.to(device=device)

    test_acc = evaluate(x_test.to(device=device), y_test.to(device=device), rmse_loss, best_model)
    return best_model, test_acc

In [33]:
def fit_kernel(model, X, y):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    for _ in range(100):
        optimizer.zero_grad()
        loss = nll_marginal_kernel(X, y, model)
        loss.backward()
        optimizer.step()

In [34]:
def find_data_variance_hyperparam(x_train_cur, y_train_cur, cnn):
  best_model = None
  best_loss = None
  for sigma in data_variances:
    model = RBFCNN(cnn, sigma).to(device=device)
    fit_kernel(model, x_train_cur, y_train_cur)
    score = evaluate(x_val.to(device=device), y_val.to(device=device), rmse_loss, model)
    if best_loss is None or score < best_loss:
      best_model = model
      best_loss = score
  test_score = evaluate(x_test, y_test, rmse_loss, best_model)
  return best_model, test_score

In [35]:
def compute_var(T, model, x):
    return model.pred_var(x)

In [36]:
def train_once_local_pred_var(x_train_cur, y_train_cur, Xs, model_t=CNN):
  cnn_mod, _ = find_best_decay_local_cnn(x_train_cur, y_train_cur, model_t)
  model, test_score = find_data_variance_hyperparam(x_train_cur, y_train_cur, cnn_mod.conv)
  acq_lambda = lambda x: compute_var(100, model, x)
  acq_scores = call_batchwise(acq_lambda, Xs, batch_size=64, device=device)
  x_new = acq_scores.topk(acquired_points).indices.cpu().numpy()
  return test_score, x_new

In [37]:
from tqdm import tqdm
def train_full_local_pv(Xs, ys, x_init_train, y_init_train, model_t=CNN):
    scores = []
    x_train_cur = x_init_train.detach().clone()
    y_train_cur = y_init_train.detach().clone()
    for i in tqdm(range(acquisition_times)):
        score, x_new = train_once_local_pred_var(x_train_cur, y_train_cur, Xs, model_t)
        x_new_t = torch.tensor(x_new, dtype=torch.long)
        x_train_cur = torch.cat([x_train_cur, Xs[x_new_t.cpu()].to(device)], dim=0)
        y_train_cur = torch.cat([y_train_cur, ys[x_new_t.cpu()].to(device)], dim=0)
        mask = torch.ones(Xs.shape[0], dtype=torch.bool)
        mask[x_new_t.cpu()] = False
        Xs = Xs[mask]
        ys = ys[mask]

        scores.append(score)

    cnn_mod, _ = find_best_decay_local_cnn(x_train_cur, y_train_cur, model_t)
    model, final_score = find_data_variance_hyperparam(x_train_cur, y_train_cur, cnn_mod.conv)
    scores.append(final_score)
    scores = torch.tensor(scores, dtype=torch.float32)

    return scores, model

In [38]:
def train_acquisition_kernel():
  os.makedirs("./model_artifacts", exist_ok=True)
  os.makedirs("./vi_results", exist_ok=True)
  scores = []
  for i in range(3):
    # if os.path.exists(f"./vi_results/{i}np_local.npy"):
    #   score = np.load(f"./vi_results/{i}np_local.npy")
    #   score = torch.Tensor(score)
    # else:
    score, model = train_full_local_pv(X_p, y_p, x_train_new, y_train_new)
    np.save(f"./vi_results/{i}kernel_local.npy", score.detach().cpu().numpy())
    print(score)
    scores.append(score)    
  meaned_scores = torch.mean(torch.stack(scores), dim=0)

  return meaned_scores

In [41]:
res = train_acquisition_kernel()
print(res)
np.save(f"./vi_results/kernel_results.npy", res.numpy())

  0%|          | 0/100 [00:00<?, ?it/s]

torch.Size([20, 20])


  0%|          | 0/100 [00:05<?, ?it/s]


NotImplementedError: The operator 'aten::linalg_lu_solve.out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on https://github.com/pytorch/pytorch/issues/141287 and mention use-case, that resulted in missing op as well as commit hash 5811a8d7da873dd699ff6687092c225caffcf1bb. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.