In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%cd -q ..

In [None]:
import lcpfn
import torch
import numpy as np
from matplotlib import pyplot as plt

print(torch.__version__)

# Training the LCPFN

In [None]:
# Pin torch to the version used for LC-PFN training.
# NOTE(review): torch is already imported in an earlier cell, so the kernel presumably
# has to be restarted after this install before the pinned version is the one in use — TODO confirm.
%pip install torch==1.11.0

In [None]:
# Build the batch function from the package's prior sampler.
get_batch_func = lcpfn.create_get_batch_func(prior=lcpfn.sample_from_prior)

# Reference values they tried:
#   emsize  ∈ [128, 256, 512]
#   nlayers ∈ [3, 6, 12]
#   nb_data ∈ [100k, 1M, 10M]
# steps_per_epoch is hardcoded to 100
# num_epochs? (still an open question)

result = lcpfn.train_lcpfn(
    get_batch_func=get_batch_func,
    seq_len=100,
    emsize=256,
    nlayers=3,
    num_borders=1000,
    lr=0.001,
    batch_size=500,
    epochs=300,
)

# Element 2 of the training result is kept as the transformer model.
transformer_model = result[2]



In [None]:
# Checkpoint location (directory + file name), relative to the current working directory.
model_save_path = (
    "lcpfn/trained_models/"
    "reproduction_model.pt"
)

In [None]:
# Persist the trained model object to `model_save_path`.
# The whole object is serialized here (not just a state_dict).
torch.save(obj=transformer_model, f=model_save_path)

# Getting the learning curve (LC) data

In [None]:


# Curve for the cutoff-10 experiment: one draw from the prior, paired with
# x positions 1..100 stored as a column.
prior_10 = lcpfn.sample_from_prior(np.random)
curve_10, _ = prior_10()
x_10 = torch.arange(1, 101)[:, None]
y_10 = torch.from_numpy(curve_10).float()[:, None]
cutoff_10 = 10
data_10 = dict(x=x_10, y=y_10, cutoff=cutoff_10)

# Curve for the cutoff-20 experiment: an independent second draw, same layout.
prior_20 = lcpfn.sample_from_prior(np.random)
curve_20, _ = prior_20()
x_20 = torch.arange(1, 101)[:, None]
y_20 = torch.from_numpy(curve_20).float()[:, None]
cutoff_20 = 20
data_20 = dict(x=x_20, y=y_20, cutoff=cutoff_20)

# Inference with LCPFN

In [None]:
# Wrap the checkpoint written earlier in the LCPFN inference class.
lcpfn_model = lcpfn.LCPFN(model_save_path)

# Cutoff 10: condition on the first `cutoff` points and query the remaining
# positions for the 5%, 50% and 95% quantiles.
x, y, cutoff = data_10['x'], data_10['y'], data_10['cutoff']
predictions_10 = lcpfn_model.predict_quantiles(
    x_train=x[:cutoff],
    y_train=y[:cutoff],
    x_test=x[cutoff:],
    qs=[0.05, 0.5, 0.95],
)

# Cutoff 20: same procedure on the second curve.
x, y, cutoff = data_20['x'], data_20['y'], data_20['cutoff']
predictions_20 = lcpfn_model.predict_quantiles(
    x_train=x[:cutoff],
    y_train=y[:cutoff],
    x_test=x[cutoff:],
    qs=[0.05, 0.5, 0.95],
)





In [None]:
# Plotting
plt.figure(figsize=(10, 5))

# Move x backward one unit for display only.
# Shifted copies are used instead of editing x_10 / x_20 in place: those tensors are
# the same objects stored in data_10['x'] / data_20['x'], so an in-place shift would
# change the inputs seen by later cells and would shift again on every re-run of this cell.
x_10_plot = x_10 - 1
x_20_plot = x_20 - 1

# Plot target curves
plt.plot(x_10_plot, y_10, "black", label="Target for Cutoff 10")
plt.plot(x_20_plot, y_20, "grey", label="Target for Cutoff 20")

# Plot extrapolations: column 1 is the median (q=0.5), columns 0 and 2 are the
# 5% and 95% quantiles requested from predict_quantiles.
plt.plot(x_10_plot[cutoff_10:], predictions_10[:, 1], "blue", label="Extrapolation for Cutoff 10")
plt.fill_between(x_10_plot[cutoff_10:].flatten(), predictions_10[:, 0], predictions_10[:, 2], color="blue", alpha=0.2, label="90% CI for Cutoff 10")

plt.plot(x_20_plot[cutoff_20:], predictions_20[:, 1], "green", label="Extrapolation for Cutoff 20")
plt.fill_between(x_20_plot[cutoff_20:].flatten(), predictions_20[:, 0], predictions_20[:, 2], color="green", alpha=0.2, label="90% CI for Cutoff 20")

# Plot cutoff lines
plt.axvline(x=cutoff_10, color='blue', linestyle='--', linewidth=0.5, label='Cutoff at 10')
plt.axvline(x=cutoff_20, color='green', linestyle='--', linewidth=0.5, label='Cutoff at 20')

plt.ylim(0, 1)
plt.legend(loc="lower right")
plt.title("Model Extrapolation with Different Cutoffs")
plt.xlabel("X")
plt.ylabel("Y")

# Inference with MCMC

In [None]:
# MCMC code does not work with the same version of pytorch as the LC-PFN code
# NOTE(review): these installs are unpinned, so the versions obtained depend on when the
# cell is run. torch is also already imported in this kernel, so a kernel restart is
# presumably needed before the upgraded version is actually used — TODO confirm.

%pip install torch --upgrade
%pip install gpytorch
%pip install botorch

In [None]:
from lcpfn.priors.fast_gp_mix import get_model


def get_mcmc_model_variable_chains(x, y, hyperparameters, device, num_samples, warmup_steps, num_chains, obs=True):
    """Fit the model built by ``get_model`` with NUTS and load the MCMC samples into it.

    Args:
        x: Training inputs (tensor); moved to ``device``.
        y: Training targets (tensor); moved to ``device``.
        hyperparameters: Passed through unchanged to ``get_model``.
        device: Device that the data and the model are moved to.
        num_samples: Number of samples, forwarded to ``pyro.infer.mcmc.MCMC``.
        warmup_steps: Number of warm-up steps, forwarded to ``MCMC``.
        num_chains: Number of chains, forwarded to ``MCMC``.
        obs: If True, the Pyro model conditions on ``y`` through an ``"obs"``
            sample site; if False, no observation site is created.

    Returns:
        Tuple ``(model, likelihood)``: the model (in eval mode, with the MCMC
        samples loaded) and the likelihood object returned by ``get_model``.
    """
    # pyro is only used inside this function in the visible notebook, so it is imported locally.
    import pyro
    from pyro.infer.mcmc import MCMC, NUTS

    x = x.to(device)
    y = y.to(device)
    model, likelihood = get_model(x, y, hyperparameters, sample=False)
    model.to(device)

    def pyro_model(x, y):
        # Draw hyperparameters from the model's prior, then push the inputs
        # through the sampled model and its likelihood.
        sampled_model = model.pyro_sample_from_prior()
        output = sampled_model.likelihood(sampled_model(x))
        if obs:
            return pyro.sample("obs", output, obs=y)

    nuts_kernel = NUTS(pyro_model)
    mcmc_run = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps, num_chains=num_chains)
    mcmc_run.run(x, y)
    model.pyro_load_from_samples(mcmc_run.get_samples())  # pyro.infer like noah?
    model.eval()

    return model, likelihood

In [None]:

# They tried
# nsamples ∈ [100, 250, 500, 1000, 2000, 4000]
# nwalkers ∈ [26, 50, 100]
# burn-in ∈ [0, 50, 100, 500]
# thin ∈ [1, 10, 100]

hyperparameters = {'handmade': True} # Use the default handmade hyperparameters chosen by the authors
device = 'cpu'
num_samples = 100
warmup_steps = 10
num_chains = 100

# For a cutoff of 10: fit on the first `cutoff` points, predict the rest.
x = data_10['x']
y = data_10['y']
cutoff = data_10['cutoff']
mcmc_model_10, likelihood_10 = get_mcmc_model_variable_chains(x[:cutoff].float(), y.flatten()[:cutoff].float(), hyperparameters, device, num_samples, warmup_steps, num_chains)
with torch.no_grad():
    predictions_10 = likelihood_10(mcmc_model_10(x[cutoff:].float()))
    # Dimension 0 is averaged out below — presumably the MCMC-sample dimension; TODO confirm.
    pred_mean_10 = predictions_10.mean.mean(0).squeeze()
    pred_lower_10, pred_upper_10 = predictions_10.confidence_region()
    pred_lower_10 = pred_lower_10.mean(0).squeeze()
    pred_upper_10 = pred_upper_10.mean(0).squeeze()

# For a cutoff of 20: same procedure on the second curve.
x = data_20['x']
y = data_20['y']
cutoff = data_20['cutoff']
mcmc_model_20, likelihood_20 = get_mcmc_model_variable_chains(x[:cutoff].float(), y.flatten()[:cutoff].float(), hyperparameters, device, num_samples, warmup_steps, num_chains)
with torch.no_grad():
    # Evaluate the model fitted on the cutoff-20 data (previously mcmc_model_10 was used here).
    predictions_20 = likelihood_20(mcmc_model_20(x[cutoff:].float()))
    pred_mean_20 = predictions_20.mean.mean(0).squeeze()
    pred_lower_20, pred_upper_20 = predictions_20.confidence_region()
    # The lower bound is averaged from the lower region (previously it was computed
    # from pred_upper_20, which made the band collapse to a single line).
    pred_lower_20 = pred_lower_20.mean(0).squeeze()
    pred_upper_20 = pred_upper_20.mean(0).squeeze()


pred_mean_10, pred_mean_20

In [None]:
# Plotting the data and predictions
plt.figure(figsize=(12, 6))

# The shaded bands come from `confidence_region()`, i.e. mean +/- 2 standard deviations
# (roughly a 95% interval), so they are labelled as such rather than as a 90% CI.

# Plot predictions for cutoff = 10
plt.plot(data_10['x'].flatten(), data_10['y'].flatten(), "gray", label="Data for Cutoff 10")  # Actual data
plt.plot(data_10['x'][data_10['cutoff']:].flatten(), pred_mean_10, "blue", label="Extrapolation for Cutoff 10")
plt.fill_between(data_10['x'][data_10['cutoff']:].flatten(), pred_lower_10, pred_upper_10, color="blue", alpha=0.2, label="Mean +/- 2 std (~95%) for Cutoff 10")

# Plot predictions for cutoff = 20
plt.plot(data_20['x'].flatten(), data_20['y'].flatten(), "lightgray", label="Data for Cutoff 20")  # Actual data
plt.plot(data_20['x'][data_20['cutoff']:].flatten(), pred_mean_20, "green", label="Extrapolation for Cutoff 20")
plt.fill_between(data_20['x'][data_20['cutoff']:].flatten(), pred_lower_20, pred_upper_20, color="green", alpha=0.2, label="Mean +/- 2 std (~95%) for Cutoff 20")

# Plot cutoff lines
plt.axvline(x=data_10['cutoff'], color='blue', linestyle='--', linewidth=0.5, label='Cutoff at 10')
plt.axvline(x=data_20['cutoff'], color='green', linestyle='--', linewidth=0.5, label='Cutoff at 20')

# Set plot limits, labels, title and legend
plt.ylim(0, 1)
plt.xlabel("X-axis")
plt.ylabel("Predicted Values")
plt.title("MCMC Model Extrapolation with Different Cutoffs")
plt.legend(loc="upper left")