In [2]:
import torch
import eigenpro.kernels as kernels
import eigenpro.models.sharded_kernel_machine as skm
import eigenpro.solver as solver
import eigenpro.utils.device as dev

In [3]:
# Synthetic linear-regression problem (Y = X @ W_star) fitted with a sharded
# kernel machine trained by the EigenPro solver.
SEED = 0  # fixed so Restart & Run All reproduces the logged test metrics
torch.manual_seed(SEED)

# n: total number of samples, split evenly into train and test
# p: number of model centers (rows of Z; reported as "size of model")
# d: input ("ambient") dimension, c: output dimension
n, p, d, c = 500, 100, 5, 2
# Declared before the data so every tensor below shares the dtype that is
# also handed to the model and the solver.
dtype = torch.float32

Z = torch.randn(p, d, dtype=dtype)
X_train = torch.randn(n // 2, d, dtype=dtype)
X_test = torch.randn(n // 2, d, dtype=dtype)
W_star = torch.randn(d, c, dtype=dtype)  # ground-truth linear map
Y_train, Y_test = X_train @ W_star, X_test @ W_star

# Laplacian kernel with a fixed bandwidth of 20.
kernel_fn = lambda x, z: kernels.laplacian(x, z, bandwidth=20.)
# Requests a GPU when one is available; per the logged notice, the current
# implementation uses at most a single GPU (cuda:0).
device = dev.Device.create(use_gpu_if_available=True)
# NOTE(review): meaning of tmp_centers_coeff is defined in
# eigenpro.models.sharded_kernel_machine -- confirm there before tuning.
model = skm.create_sharded_kernel_machine(
    Z, c, kernel_fn, device, dtype=dtype, tmp_centers_coeff=2)

# EigenPro preconditioner configuration:
#   sd / qd: size / level of the data preconditioner
#   sm / qm: size / level of the model preconditioner
sd, sm, qd, qm = 10, 10, 3, 3
model = solver.fit(model, X_train, Y_train, X_test, Y_test, device,
                   dtype=dtype, kernel=kernel_fn, s_data=sd, s_model=sm,
                   q_data=qd, q_model=qm, epochs=2, accumulated_gradients=True)

[31mnotice: the current implementation can only support 1 GPU, we only use the following device: (cuda:0) [0m
╒══════════════════════════════════╤═════════╕
│ Configuration                    │   Value │
╞══════════════════════════════════╪═════════╡
│ [32msize of model[0m                    │ 100     │
├──────────────────────────────────┼─────────┤
│ [32mambient dimension[0m                │ 5       │
├──────────────────────────────────┼─────────┤
│ [32moutput dimension[0m                 │ 2       │
├──────────────────────────────────┼─────────┤
│ [32msize of data preconditioner[0m      │ 10      │
├──────────────────────────────────┼─────────┤
│ [32mlevel of data preconditioner[0m     │ 3       │
├──────────────────────────────────┼─────────┤
│ [32msize of model preconditioner[0m     │ 10      │
├──────────────────────────────────┼─────────┤
│ [32mlevel of model preconditioner[0m    │ 3       │
├──────────────────────────────────┼─────────┤
│ [32msize of training da

Epoch 1/2:   0%|                                                                                                                    | 0/4 [00:00<?, ?it/s]
Projection:   0%|                                                                                                                   | 0/3 [00:00<?, ?it/s][A
                                                                                                                                                          [A
Projection:   0%|                                                                                                                   | 0/3 [00:00<?, ?it/s][A
                                                                                                                                                          [A
Projection:   0%|                                                                                                                   | 0/3 [00:00<?, ?it/s][A
                                                       

╒═══════════════════╤══════════════╕
│ Epoch 1 Summary   │ Value        │
╞═══════════════════╪══════════════╡
│ Test Loss         │ 0.3331757486 │
├───────────────────┼──────────────┤
│ Test Accuracy     │ 95.20%       │
╘═══════════════════╧══════════════╛


Epoch 2/2:   0%|                                                                                                                    | 0/4 [00:00<?, ?it/s]
Projection:   0%|                                                                                                                   | 0/3 [00:00<?, ?it/s][A
                                                                                                                                                          [A
Projection:   0%|                                                                                                                   | 0/3 [00:00<?, ?it/s][A
                                                                                                                                                          [A
Projection:   0%|                                                                                                                   | 0/3 [00:00<?, ?it/s][A
                                                       

╒═══════════════════╤══════════════╕
│ Epoch 2 Summary   │ Value        │
╞═══════════════════╪══════════════╡
│ Test Loss         │ 0.2300972044 │
├───────────────────┼──────────────┤
│ Test Accuracy     │ 98.80%       │
╘═══════════════════╧══════════════╛
