## Install required dependencies

In [None]:
!pip install matplotlib
!pip install numpy
!pip install pandas
!pip install scipy
!pip install scikit-learn
!pip install torch
!pip install tqdm
!pip install git+https://github.com/pratikrathore8/fast_krr.git

## Import libraries

In [None]:
import matplotlib.pyplot as plt
import torch
from tqdm import trange

from fast_krr.models import FullKRR
from fast_krr.opts import ASkotchV2

from utils import load_data

## Run classification experiment

### Load the comet_mc dataset

In [None]:
# Dataset and loader configuration.
dataset = "comet_mc"
data_config = {
    "tr": "comet_mc_data.pkl",      # presumably the feature file — confirm against utils.load_data
    "tgt": "comet_mc_target.pkl",   # presumably the target/label file
    # NOTE(review): pickle files can execute code when loaded — only use trusted files.
    "loading": "pkl",
    "split": 0.8,                   # presumably the train fraction of the train/test split
    "label_map": {0: -1, 1: 1},     # relabel {0, 1} as {-1, +1}
    "task": "classification",
}
remove_label_means = False
seed = 0
# Also seed torch's global RNG. Otherwise `seed` only reaches load_data, and the
# randomized optimizer may draw from torch's RNG (NOTE(review): confirm ASkotchV2 does).
torch.manual_seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the train/test split onto `device`; returns train features, test features,
# train targets, and test targets, in that order.
X, Xtst, y, ytst = load_data(
    dataset=dataset,
    data_config=data_config,
    remove_label_means=remove_label_means,
    seed=seed,
    device=device,
)

### Initialize the model

In [None]:
n = X.shape[0]  # number of training points
# Reuse the task declared in data_config so the loader and the model cannot disagree.
task = data_config["task"]
# RBF kernel bandwidth. NOTE(review): magic constant — TODO document how it was chosen.
kernel_params = {"type": "rbf", "sigma": 2.1122933418117933}
w0 = torch.zeros(n, device=device)  # zero initial weights, one per training point
lambd = 1e-6 * n  # presumably the ridge regularization strength, scaled with n

In [None]:
# Build the full KRR model. The test split is passed in as well; it is presumably
# what compute_metrics uses for the test metrics recorded during training.
model = FullKRR(
    X, y, Xtst, ytst,
    task=task,
    kernel_params=kernel_params,
    lambd=lambd,
    w0=w0,
    Ktr_needed=True,  # presumably asks FullKRR to form the training kernel matrix — confirm
    device=device,
)

### Initialize the `ASkotchV2` optimizer

In [None]:
# Block size of ~1% of the training set, presumably the number of coordinates
# ASkotchV2 updates per step. Clamped to at least 1 so datasets with fewer than
# 100 points do not produce an empty block.
block_sz = max(1, n // 100)
# Nyström preconditioner; "r" is presumably the sketch rank.
precond_params = {"type": "nystrom", "r": 100, "rho": "damped"}

In [None]:
# Attach the ASkotchV2 optimizer to the model built above.
opt = ASkotchV2(
    model=model,
    precond_params=precond_params,
    block_sz=block_sz,
)

### Train the FullKRR model using `ASkotchV2`

In [None]:
# Training budget: total optimizer steps, and how often (in steps) to record metrics.
max_iters, log_freq = 5_000, 20

In [None]:
# Record (iteration, metrics) pairs at three points:
# - iteration 0, the initial weights;
# - every `log_freq` steps;
# - the final step, so the last iterate is logged even if log_freq does not divide max_iters.
metrics = [(0, model.compute_metrics(v=opt.model.w, log_test_only=False))]

for i in trange(1, max_iters + 1, desc="Optimization progress"):
    opt.step()

    if i % log_freq == 0 or i == max_iters:
        metrics.append((i, model.compute_metrics(v=opt.model.w, log_test_only=False)))

## Plot the results

In [None]:
def _plot_metric(iters, values, ylabel, *, log_scale=False):
    """Plot one recorded metric against iteration count on a fresh figure."""
    fig, ax = plt.subplots()
    if log_scale:
        ax.semilogy(iters, values)
    else:
        ax.plot(iters, values)
    ax.set(xlabel="Iterations", ylabel=ylabel, title=f"{ylabel} vs. iterations")
    plt.show()


x_vals = [i for i, _ in metrics]
rel_residuals = [m["rel_residual"].cpu() for _, m in metrics]
# float() accepts Python numbers and single-element tensors on any device,
# so the accuracy values are safe to pass to matplotlib.
test_acc = [float(m["test_acc"]) for _, m in metrics]

# The residual can decay over orders of magnitude, so use a log axis.
_plot_metric(x_vals, rel_residuals, "Relative residual", log_scale=True)
# Accuracy is a bounded quantity; a linear axis shows its changes faithfully.
_plot_metric(x_vals, test_acc, "Test accuracy")

Since we ran `ASkotchV2` in single precision, the relative residual plateaus at around 1e-4. Running in double precision would allow us to reach a lower relative residual.