In [None]:
import torch
import gpytorch
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib

# Shared plot styling: marker/line sizes, colour palette and fonts.
scatter_size = 2.0
line_size = 0.75

# Use seaborn's "muted" colours both as an indexable list and as the default cycle.
palette = sns.color_palette("muted")
sns.set_palette(palette)

# STIX fonts for both math text and regular text, with small font sizes.
matplotlib.rcParams.update({
    'mathtext.fontset': 'stix',
    'font.family': 'STIXGeneral',
    'axes.titlesize': 9,
    'font.size': 8,
})

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

def step_function(x):
    """Return the unit step of ``x`` as a float tensor.

    Elements with ``x >= 0`` map to 1.0 and all other elements map to 0.0.
    """
    is_nonnegative = x >= 0
    return is_nonnegative.float()

class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP model with a constant mean and a scaled RBF kernel."""

    def __init__(self, train_x, train_y, likelihood):
        """Register the training data and likelihood, then build mean and kernel modules."""
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        rbf_kernel = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf_kernel)

    def forward(self, x):
        """Return the multivariate normal given by the mean and kernel modules at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )

# Fixed hyperparameters: they are assigned directly to each model below and
# no optimisation is run, so only the number of context points varies.
outputscale = 1.0
lengthscale = 0.4
noise_level = 0.1  # Squared when assigned to ``likelihood.noise`` below

test_x = torch.linspace(-3, 3, 300).to(device)
num_context_points_list = [10, 20, 50, 100, 200, 400]

width = 7
figsize = (width, 2 * width / 4.5)
fig, axs = plt.subplots(2, 3, figsize=figsize, dpi=400)
axs = axs.ravel()

# The test grid and the true function on it are identical for every panel,
# so move them to the CPU / evaluate them once instead of once per iteration.
test_x_cpu = test_x.cpu()
test_x_np = test_x_cpu.numpy()
true_y_np = step_function(test_x_cpu).numpy()

for i, num_points in enumerate(num_context_points_list):
    ax = axs[i]

    # Context points are noise-free evaluations of the step function; the
    # likelihood is still given a noise value of ``noise_level ** 2``.
    train_x = torch.linspace(-2, 2, num_points).to(device)
    train_y = step_function(train_x)

    likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)
    likelihood.noise = noise_level ** 2  # Set the noise level explicitly
    model = ExactGPModel(train_x, train_y, likelihood).to(device)

    model.covar_module.outputscale = outputscale
    model.covar_module.base_kernel.lengthscale = lengthscale

    # Switch to evaluation mode before querying the model at the test inputs.
    model.eval()
    likelihood.eval()

    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        # Get both the latent function distribution and the noisy prediction
        f_pred = model(test_x)
        y_pred = likelihood(f_pred)

    # The predictive variance is evaluated lazily, so the confidence regions
    # must also be computed with autograd disabled; otherwise the resulting
    # tensors can require grad and ``.numpy()`` below would fail.
    with torch.no_grad():
        f_lower, f_upper = f_pred.confidence_region()
        y_lower, y_upper = y_pred.confidence_region()

    # Move data back to CPU for plotting
    train_x_cpu = train_x.cpu()
    train_y_cpu = train_y.cpu()
    f_mean_cpu = f_pred.mean.cpu()
    # The noisy-prediction quantities (y_*) are computed but not drawn.
    y_mean_cpu = y_pred.mean.cpu()

    # Plot the results. The call order is kept fixed because the shared legend
    # below is built from the handles of the first panel.
    ax.plot(test_x_np, true_y_np, color=palette[0], linewidth=line_size, label='True Function')
    ax.scatter(train_x_cpu.numpy(), train_y_cpu.numpy(), s=scatter_size, c='black', zorder=2, label='In-Context Examples')
    ax.plot(test_x_np, f_mean_cpu.numpy(), color=palette[1], linewidth=line_size, label='Posterior Ground Truth')
    ax.fill_between(test_x_np, f_lower.cpu().numpy(), f_upper.cpu().numpy(), alpha=0.2, color=palette[1], label='Posterior Ground Truth 95% CI')

    ax.set_ylim(-0.5, 1.5)
    ax.set_xlim(-2, 2)
    ax.set_title(f'{num_points} Context Examples')
    ax.set_xlabel('x')
    ax.set_ylabel('y')

    # Remove top and right spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

plt.tight_layout()

# Add a common legend at the bottom
handles, labels = axs[0].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=4, bbox_to_anchor=(0.5, -0.05), fontsize=8)
plt.show()