In [1]:
import copy

import torch
import sklearn
import numpy as np
import gpytorch
from gpytorch.models import ExactGP
from torch.distributions import Normal, MultivariateNormal
from matplotlib import pyplot as plt
from skgpytorch.models import ExactGPRegressor
from gpytorch.kernels import ScaleKernel, RBFKernel
from gpytorch.constraints import GreaterThan
from sklearn.neighbors import NearestNeighbors
import pandas as pd
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes end-to-end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Synthetic regression data: inputs from a Normal distribution, targets from a
# zero-mean Gaussian whose covariance is a scaled RBF kernel plus noise.
N = 1024
torch.manual_seed(0)

# Inputs: N draws from Normal(mean=0, scale=5); the one-element loc/scale
# tensors give X a trailing dimension of size 1.
x_dist = Normal(loc=torch.tensor([0.0]), scale=torch.tensor([5.0]))
X = x_dist.sample(sample_shape=torch.Size([N]))

# Generating kernel: RBF with lengthscale 0.5, wrapped with output scale 4.
K = ScaleKernel(RBFKernel())
K.base_kernel.lengthscale = 0.5
K.outputscale = 4.0

# Kernel matrix on the inputs plus a unit (1**2) term on the diagonal.
K_ = K(X, X) + (1 ** 2) * torch.eye(X.shape[0])

# One draw of the targets from the resulting multivariate Normal.
dist = MultivariateNormal(loc=torch.zeros(N), covariance_matrix=K_.evaluate())
Y = dist.sample()

plt.scatter(X, Y)

In [None]:
class SGDGP(ExactGPRegressor):
    """Exact GP regressor whose hyperparameters are fitted with mini-batch SGD.

    Each mini-batch consists of a randomly chosen training point together with
    its nearest neighbours in input space.

    The class relies on attributes provided by ``ExactGPRegressor``:
    ``train_x``, ``train_y``, ``mll``, ``history`` and ``loss_func``.
    """

    def __init__(self, train_x, train_y, mll):
        # All three arguments are forwarded unchanged to ExactGPRegressor.
        super().__init__(train_x, train_y, mll)

    def sgd_fit(self, batch_size, lr, n_epochs, n_restarts, thetas, random_state):
        """Fit the output scale and noise by nearest-neighbour mini-batch SGD.

        Args:
            batch_size: Number of points per mini-batch; also the number of
                neighbours requested from the nearest-neighbour index.
            lr: Initial SGD step size. After optimisation step ``iteration``
                of an epoch, the step size is set to ``lr / (iteration + 1)``.
            n_epochs: Epochs per restart. One epoch performs
                ``ceil(len(train_x) / batch_size)`` optimisation steps.
            n_restarts: Number of independent restarts.
            thetas: Mapping passed as keyword arguments to
                ``self.mll.model.initialize`` at the start of every restart.
            random_state: Accepted for interface compatibility but currently
                unused; no seeding is performed here.

        Side effects:
            * ``self.history["total_loss"]`` and ``self.history["gradient"]``
              are replaced by one list per restart. The former holds the mean
              batch loss of each epoch. The latter holds, for each step,
              ``2 * log`` of the L2 norm of the (raw output scale, raw noise)
              gradient, recorded before the output-scale gradient is rescaled.
            * ``self.optimizer`` is replaced on every restart.
            * ``self.best_restart`` is set to the index of the restart with
              the lowest final-epoch loss, and that restart's parameters are
              loaded back into ``self.mll``.
            * The RBF raw lengthscale is left with ``requires_grad = False``.
        """
        # Precompute, for every training point, the indices of its
        # `batch_size` nearest training points (scikit-learn needs CPU data).
        train_x_cpu = self.train_x.cpu()
        neigh = NearestNeighbors(n_neighbors=batch_size, algorithm='kd_tree')
        neigh.fit(train_x_cpu)
        _, neigh_idx = neigh.kneighbors(train_x_cpu, batch_size)

        # Steps per epoch: ceil(n_train / batch_size) in integer arithmetic.
        n_train = len(self.train_x)
        num_batches = (n_train + batch_size - 1) // batch_size

        covar_module = self.mll.model.covar_module
        noise_covar = self.mll.likelihood.noise_covar

        # Loop-invariant factor applied to the output-scale gradient each step.
        # NOTE(review): presumably the scaling prescribed by the reference
        # algorithm -- confirm. It is infinite for batch_size == 1 (log(1) = 0).
        grad_scale = batch_size / (3 * torch.log(torch.tensor(batch_size)))

        least_loss = float("inf")
        best_mll_state = None
        self.history["total_loss"] = []
        self.history["gradient"] = []
        self.mll.train()

        for restart in range(n_restarts):
            self.optimizer = torch.optim.SGD(self.mll.parameters(), lr=lr)
            self.history["total_loss"].append([])
            self.history["gradient"].append([])

            # On restarts after the first, redraw every parameter from N(0, 1)
            # before the values in `thetas` are re-applied.
            if restart > 0:
                for param in self.mll.parameters():
                    torch.nn.init.normal_(param, mean=0.0, std=1.0)
            self.mll.model.initialize(**thetas)
            # The lengthscale is held fixed; only the other parameters train.
            covar_module.base_kernel.raw_lengthscale.requires_grad = False

            # Defined up front so the comparison below is valid for n_epochs == 0.
            epoch_loss = float("inf")
            for _ in range(1, n_epochs + 1):
                running_loss = 0
                for iteration in range(1, num_batches + 1):
                    # Pick a random centre point; the batch is that point's
                    # precomputed neighbourhood.
                    idx = torch.randint(low=0, high=n_train, size=(1,))[0]
                    indices = neigh_idx[idx,]
                    X_batch = self.train_x[indices,]
                    y_batch = self.train_y[indices,]

                    self.optimizer.zero_grad()
                    batch_loss = self.loss_func(X_batch, y_batch)
                    batch_loss.backward()

                    # Record 2 * log(||grad||_2) before any rescaling.
                    grad_norm = torch.linalg.norm(
                        torch.tensor([
                            covar_module.raw_outputscale.grad,
                            noise_covar.raw_noise.grad,
                        ]),
                        ord=2,
                    )
                    self.history["gradient"][restart].append(
                        (2 * torch.log(grad_norm)).item()
                    )
                    running_loss += batch_loss.item()

                    covar_module.raw_outputscale.grad *= grad_scale

                    self.optimizer.step()
                    # NOTE(review): `iteration` restarts at 1 every epoch, so
                    # the step size jumps back to lr / 2 after the first batch
                    # of each epoch instead of decaying over the whole run.
                    # Confirm whether a global step counter was intended.
                    for group in self.optimizer.param_groups:
                        group['lr'] = lr / (iteration + 1)

                epoch_loss = running_loss / num_batches
                self.history["total_loss"][restart].append(epoch_loss)

            # Keep the restart whose final-epoch loss is the lowest so far.
            if epoch_loss < least_loss:
                self.best_restart = restart
                least_loss = epoch_loss
                # state_dict() returns references to the live parameter
                # tensors, which later restarts overwrite in place; deep-copy
                # to keep a real snapshot of the best parameters.
                best_mll_state = copy.deepcopy(self.mll.state_dict())

        # Restore the parameters of the best restart.
        if best_mll_state is not None:
            self.mll.load_state_dict(best_mll_state)


In [None]:
# Experiment: for several mini-batch sizes, repeat the SGD fit and plot the
# mean (with a +/- 0.5 std band) of the per-step gradient history.

# --- Configuration -----------------------------------------------------------
N = 1024
batch_sizes = [128,64,32]
# Initial hyperparameters handed to SGDGP.sgd_fit through `thetas`.
theta = {'likelihood.noise_covar.noise': torch.tensor(3).to(device),
        'covar_module.base_kernel.lengthscale': torch.tensor(0.5).to(device),
        'covar_module.outputscale': torch.tensor(5).to(device),}
lr = 5.0          # initial SGD step size
n_epochs = 25     # epochs per fit
n_restarts = 1    # restarts per fit
n_repeats = 10    # independent fits per batch size

# One panel per batch size; squeeze=False keeps the axes indexable even when
# only a single batch size is configured.
fig, ax = plt.subplots(1, len(batch_sizes), figsize=(30,6), squeeze=False)
ax = ax[0]


# --- Synthetic data ----------------------------------------------------------
torch.manual_seed(0)
x_dist = Normal(torch.tensor([0.0]), torch.tensor([5.0]))
X = x_dist.sample((N,))
K = ScaleKernel(RBFKernel())
K.base_kernel.lengthscale = 0.5
K.outputscale = 4.0
cov = K(X,X) + (1.0)*(torch.eye(len(X)))
# The mean vector is sized from N so it always matches the covariance matrix.
dist = MultivariateNormal(torch.zeros(N), cov.evaluate())
torch.manual_seed(5)
Y = dist.sample()


# --- Repeated fits and plotting ----------------------------------------------
for j, batch_size in enumerate(batch_sizes):
    gradient = []
    for _ in range(n_repeats):
        kernel = ScaleKernel(RBFKernel(ard_num_dims=X.shape[1])).to(device)
        model = SGDGP(X.to(device), Y.to(device), kernel).to(device)
        # random_state is passed for interface completeness; sgd_fit does not
        # seed from it, so the repeated runs differ from one another.
        model.sgd_fit(batch_size, lr, n_epochs, n_restarts, thetas=theta, random_state=0)
        # History of the first (and only) restart.
        gradient.append(model.history['gradient'][0])

    # Statistics across the repeated runs, one value per optimisation step.
    std, mean = torch.std_mean(torch.tensor(gradient), dim=0)

    ax[j].plot(mean, 'b',linewidth=0.5)
    ax[j].fill_between(range(len(mean)), mean - 0.5*std, mean + 0.5*std, alpha=0.5)
    ax[j].set_xlabel('Iteration (k)')
    ax[j].set_ylabel('log(del(L(O(k)))')
    ax[j].set_title(f'(m= {batch_size})')
    ax[j].set_yticks([-4,-5,-6,-7,-8,-9])
    ax[j].set_ylim([-10,-2])

fig.savefig('figure2_sgd1.png')