# GPytorch training benchmark

Used to know how long it takes to train a GPytorch model on every ID of a benchmark dataset.

---

## Setup

In [1]:
import time

import pandas as pd



In [2]:
dataset = "large"
shared_input = False
shared_hp = True

---

## Start timer

In [3]:
start = time.time()

---

## Data import

In [4]:
db = pd.read_csv(f"../dummy_datasets/{dataset}_{'shared_input' if shared_input else 'distinct_input'}_{'shared_hp' if shared_hp else 'distinct_hp'}.csv")
# db has 3 columns: ID, Input, Output

In [5]:
# First 90% of IDs are for training, last 10% for testing
train_ids = db["ID"].unique()[:int(0.9 * db["ID"].nunique())]
test_ids = db["ID"].unique()[int(0.9 * db["ID"].nunique()):]

db_train = db[db["ID"].isin(train_ids)]
db_test = db[db["ID"].isin(test_ids)]

# N.b: data is already sort by ID and Input in the toy datasets, but in a real case scenario, we would need to sort it

In [6]:
len(train_ids), len(test_ids)

(540, 60)

---

## Training

We train a GPytorch model on every ID of the training set.

In [7]:
import torch
import gpytorch
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, RBFKernel
from gpytorch.distributions import MultivariateNormal

In [8]:
class ExactGPModel(gpytorch.models.ExactGP):
	def __init__(self, train_x, train_y, likelihood):
		super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
		self.mean_module = ConstantMean()
		self.covar_module = ScaleKernel(RBFKernel())

	def forward(self, x):
		mean_x = self.mean_module(x)
		covar_x = self.covar_module(x)
		return MultivariateNormal(mean_x, covar_x)


def train_gp_for_id(id_data, tolerance=0.5, patience=5, max_iter=1000):
	"""
	Train a GP for a single ID's data with early stopping
	
	Parameters:
	- tolerance: minimum relative improvement in loss to continue training
	- patience: number of iterations to wait for improvement before stopping
	- max_iter: maximum number of iterations regardless of convergence
	"""
	# Convert to tensors
	train_x = torch.tensor(id_data['Input'].values, dtype=torch.float32).unsqueeze(-1)
	train_y = torch.tensor(id_data['Output'].values, dtype=torch.float32)
	
	# Initialize likelihood and model
	likelihood = gpytorch.likelihoods.GaussianLikelihood()
	model = ExactGPModel(train_x, train_y, likelihood)
	
	# Set to training mode
	model.train()
	likelihood.train()
	
	# Use Adam optimizer
	optimizer = torch.optim.Adam(model.parameters(), lr=0.25)
	
	# Loss function - the marginal log likelihood
	mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
	
	# Early stopping variables
	best_loss = float('inf')
	patience_counter = 0
	iteration = 0
	
	while iteration < max_iter:
		optimizer.zero_grad()
		output = model(train_x)
		loss = -mll(output, train_y)
		loss.backward()
		optimizer.step()
		
		current_loss = loss.item()
		
		# Check for improvement
		if current_loss < best_loss - tolerance:
			best_loss = current_loss
			patience_counter = 0
		else:
			patience_counter += 1
		
		# Early stopping condition
		if patience_counter >= patience:
			break
			
		iteration += 1
	
	return model, likelihood, iteration + 1

In [9]:
# Train a GP for each ID in the training set
models = {}
training_times = []
iterations_used = []

print(f"Training GPs for {len(train_ids)} IDs with early stopping...")
print(f"Early stopping: tolerance=1e-4, patience=10, max_iter=500")

for i, train_id in enumerate(train_ids):
	id_start_time = time.time()
	
	# Get data for this ID
	id_data = db_train[db_train["ID"] == train_id]
	
	# Train GP for this ID
	model, likelihood, num_iterations = train_gp_for_id(id_data)
	models[train_id] = (model, likelihood)
	
	id_time = time.time() - id_start_time
	training_times.append(id_time)
	iterations_used.append(num_iterations)
	
	# Print progress every 20 IDs or for the last ID
	if (i + 1) % 20 == 0 or i == len(train_ids) - 1:
		avg_time = sum(training_times) / len(training_times)
		avg_iterations = sum(iterations_used) / len(iterations_used)
		print(f"Completed {i + 1}/{len(train_ids)} IDs. Avg time: {avg_time:.3f}s, Avg iterations: {avg_iterations:.1f}")

total_training_time = sum(training_times)
avg_iterations = sum(iterations_used) / len(iterations_used)
min_iterations = min(iterations_used)
max_iterations = max(iterations_used)

print(f"\nTotal training time: {total_training_time:.2f}s")
print(f"Average time per GP: {total_training_time/len(train_ids):.3f}s")
print(f"Iterations - Avg: {avg_iterations:.1f}, Min: {min_iterations}, Max: {max_iterations}")

Training GPs for 540 IDs with early stopping...
Early stopping: tolerance=1e-4, patience=10, max_iter=500
Completed 20/540 IDs. Avg time: 7.817s, Avg iterations: 1001.0
Completed 40/540 IDs. Avg time: 7.908s, Avg iterations: 1001.0
Completed 60/540 IDs. Avg time: 8.088s, Avg iterations: 1001.0
Completed 80/540 IDs. Avg time: 8.144s, Avg iterations: 1001.0
Completed 100/540 IDs. Avg time: 7.980s, Avg iterations: 1001.0
Completed 120/540 IDs. Avg time: 7.560s, Avg iterations: 1001.0
Completed 140/540 IDs. Avg time: 7.263s, Avg iterations: 1001.0
Completed 160/540 IDs. Avg time: 7.040s, Avg iterations: 1001.0
Completed 180/540 IDs. Avg time: 6.872s, Avg iterations: 1001.0
Completed 200/540 IDs. Avg time: 6.733s, Avg iterations: 1001.0
Completed 220/540 IDs. Avg time: 6.618s, Avg iterations: 1001.0
Completed 240/540 IDs. Avg time: 6.532s, Avg iterations: 1001.0
Completed 260/540 IDs. Avg time: 6.446s, Avg iterations: 1001.0
Completed 280/540 IDs. Avg time: 6.365s, Avg iterations: 1001.0
Co

In [10]:
# Final timing results
end = time.time()
total_elapsed = end - start

print(f"\n" + "="*50)
print(f"BENCHMARK RESULTS - Early Stopping")
print(f"="*50)
print(f"Dataset: {dataset}")
print(f"Number of IDs trained: {len(train_ids)}")
print(f"Total elapsed time: {total_elapsed:.2f}s")
print(f"Training time: {total_training_time:.2f}s")
print(f"Setup/overhead time: {total_elapsed - total_training_time:.2f}s")
print(f"Average time per GP: {total_training_time/len(train_ids):.3f}s")
print(f"\nConvergence Statistics:")
print(f"Average iterations: {avg_iterations:.1f}")
print(f"Min iterations: {min_iterations}")
print(f"Max iterations: {max_iterations}")
print(f"Early stopping parameters: tolerance=1e-4, patience=10")
print(f"="*50)


BENCHMARK RESULTS - Early Stopping
Dataset: large
Number of IDs trained: 540
Total elapsed time: 3303.29s
Training time: 3302.13s
Setup/overhead time: 1.17s
Average time per GP: 6.115s

Convergence Statistics:
Average iterations: 1001.0
Min iterations: 1001
Max iterations: 1001
Early stopping parameters: tolerance=1e-4, patience=10
