In [45]:
from plotly import graph_objects as go
import torch 
from mdgp.utils import sphere_meshgrid


def base_trace(num_meshgrid=50):
    """Build a grey-shaded sphere surface used as a neutral backdrop.

    Args:
        num_meshgrid: Resolution passed to ``sphere_meshgrid`` for both grid axes.

    Returns:
        A ``plotly.graph_objects.Surface`` using the 'Greys' colorscale with the
        colour bar hidden.
    """
    xs, ys, zs = sphere_meshgrid(num_meshgrid, num_meshgrid).unbind(-1)
    return go.Surface(x=xs, y=ys, z=zs, colorscale='Greys', showscale=False)


def observations_trace(x, y=None):
    """Scatter observation points pushed slightly outward from the sphere surface.

    Coordinates are scaled by 1.02 so the markers sit just above a sphere drawn
    from the same meshgrid instead of being hidden by it.

    Args:
        x: Tensor of points whose last dimension holds the three coordinates
            (it is unbound into x, y and z).
        y: Optional values used to colour the markers; a trailing dimension of
            size one is squeezed away. When omitted, every marker gets colour 0.

    Returns:
        A ``plotly.graph_objects.Scatter3d`` in marker mode using 'Viridis'.
    """
    lifted = x.mul(1.02)
    if y is not None:
        colors = y.squeeze(-1)
    else:
        colors = torch.zeros(x.shape[0])
    px, py, pz = lifted.unbind(-1)
    marker_style = dict(size=5, color=colors, colorscale='Viridis')
    return go.Scatter3d(x=px, y=py, z=pz, mode='markers', marker=marker_style)


def prediction_trace(prediction_function, num_meshgrid=50):
    """Build a sphere surface coloured by the output of ``prediction_function``.

    Args:
        prediction_function: Callable that receives the flattened meshgrid, a
            tensor of shape ``(num_points, d)`` where ``d`` is the meshgrid's
            last dimension. It must return a tensor with one value per point.
        num_meshgrid: Resolution passed to ``sphere_meshgrid`` for both grid axes.

    Returns:
        A ``plotly.graph_objects.Surface`` with 'Viridis' colouring and a colour bar.
    """
    meshgrid = sphere_meshgrid(num_meshgrid, num_meshgrid)
    x_meshgrid, y_meshgrid, z_meshgrid = meshgrid.unbind(-1)
    # ``reshape`` (unlike ``view``) also works when the meshgrid or the
    # prediction is not contiguous in memory.
    flat_points = meshgrid.reshape(-1, meshgrid.shape[-1])
    # Detach so plotly can convert the values to NumPy even when the caller
    # evaluated the model outside ``torch.no_grad()``.
    color = prediction_function(flat_points).detach().reshape(x_meshgrid.shape)
    sphere_trace = go.Surface(x=x_meshgrid, y=y_meshgrid, z=z_meshgrid, surfacecolor=color, colorscale='Viridis', showscale=True)
    return sphere_trace


def target_trace(target_function, num_meshgrid=50):
    """Build a sphere surface coloured by a point-wise ``target_function``.

    ``target_function`` is called once per meshgrid point, not on a batch, so
    the cost grows with the number of grid points.

    Args:
        target_function: Callable mapping a single meshgrid point (one entry
            along the grid's first two axes) to a single value.
        num_meshgrid: Resolution passed to ``sphere_meshgrid`` for both grid axes.

    Returns:
        A ``plotly.graph_objects.Surface`` with 'Viridis' colouring and a colour bar.
    """
    # 1. Evaluate the target at every meshgrid point
    meshgrid = sphere_meshgrid(num_meshgrid, num_meshgrid)
    x_meshgrid, y_meshgrid, z_meshgrid = meshgrid.unbind(-1)
    # The colour grid has exactly x_meshgrid's shape, so no squeeze is needed;
    # squeezing would collapse the grid when num_meshgrid == 1.
    color_meshgrid = torch.zeros_like(x_meshgrid)
    num_rows, num_cols = color_meshgrid.shape[0], color_meshgrid.shape[1]
    # no_grad keeps the colour tensor free of autograd history so plotly can
    # convert it, even if target_function returns tensors that require grad.
    with torch.no_grad():
        for i in range(num_rows):
            for j in range(num_cols):
                color_meshgrid[i, j] = target_function(meshgrid[i, j])

    # 2. Plot a sphere colored by the target values
    sphere_trace = go.Surface(x=x_meshgrid, y=y_meshgrid, z=z_meshgrid, surfacecolor=color_meshgrid, colorscale='Viridis', showscale=True)
    return sphere_trace


def plot_traces(*traces):
    """Collect the given plotly traces into a new figure, skipping ``None`` entries.

    Returns:
        A fresh ``plotly.graph_objects.Figure`` containing the non-None traces
        in the order given.
    """
    kept = list(filter(lambda trace: trace is not None, traces))
    figure = go.Figure()
    figure.add_traces(kept)
    return figure


def plot_observations(x, y=None, num_meshgrid=50):
    """Plot observations as markers slightly above a grey sphere.

    Args:
        x: Observation locations, forwarded to ``observations_trace``.
        y: Optional observation values used to colour the markers.
        num_meshgrid: Resolution of the backdrop sphere.

    Returns:
        A ``plotly.graph_objects.Figure`` with the sphere and the scatter trace.
    """
    # 1. Plot a grey sphere with meshgrid
    grey_sphere_trace = base_trace(num_meshgrid=num_meshgrid)

    # 2. Plot a scatter plot of observations slightly above the sphere.
    #    Forward the caller's ``y`` so markers are coloured by the observed values.
    scatter_trace = observations_trace(x, y=y)

    # 3. Add all traces to figure
    fig = plot_traces(grey_sphere_trace, scatter_trace)
    return fig


def plot_prediction(prediction_function, x=None, y=None, num_meshgrid=50):
    """Plot a sphere coloured by ``prediction_function``, optionally with observations.

    Observation markers are drawn only when both ``x`` and ``y`` are given;
    passing just one of them shows the prediction surface alone.

    Args:
        prediction_function: Forwarded to ``prediction_trace``.
        x: Optional observation locations.
        y: Optional observation values (colour the markers).
        num_meshgrid: Resolution of the prediction surface.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    traces = [prediction_trace(prediction_function, num_meshgrid=num_meshgrid)]
    if x is not None and y is not None:
        traces.append(observations_trace(x, y))
    return plot_traces(*traces)

def plot_target(target_function, num_meshgrid=50):
    """Plot a sphere coloured by ``target_function``.

    The surface is built by ``target_trace``; this wrapper only places it in a figure.

    Args:
        target_function: Forwarded to ``target_trace``.
        num_meshgrid: Resolution of the sphere surface.

    Returns:
        A ``plotly.graph_objects.Figure`` with a single surface trace.
    """
    return plot_traces(target_trace(target_function, num_meshgrid=num_meshgrid))

In [46]:
import geometric_kernels.torch
import os 
from argparse import ArgumentParser
from torch import set_default_dtype, float64
from torch.optim import Adam  
from gpytorch.mlls import DeepApproximateMLL, VariationalELBO, ExactMarginalLogLikelihood
from mdgp.experiment_utils.data import get_data 
from mdgp.experiment_utils.model import create_model
from mdgp.experiment_utils.logging import CSVLogger, finalize 
from mdgp.experiment_utils.training import fit, test_step
from mdgp.experiment_utils import ExperimentConfigReader, set_experiment_seed, ExperimentConfig


def run_experiment(experiment_config: ExperimentConfig, dir_path: str, model=None):
    """Train, visualise and test the model described by ``experiment_config``.

    Steps:
    1. Load data.
    2. Build the model (or reuse ``model``) and the criterion.
    3. Build the optimizer.
    4. Fit while logging train/validation metrics to CSV under ``dir_path``.
    5. Show the predicted mean on a sphere.
    6. Evaluate on the test split and print the metrics.

    Args:
        experiment_config: Provides ``model_arguments``, ``data_arguments``,
            ``training_arguments`` and the ``file_name`` shown in the log line.
        dir_path: Directory under which the 'train', 'val' and 'test' CSV logs go.
        model: Optional pre-built model; when None, one is created with
            ``create_model`` from the training data.
    """
    print(f"Running experiment with the config: {os.path.join(dir_path, experiment_config.file_name)}")
    # 0. Unpack arguments
    model_args, data_args, training_args = experiment_config.model_arguments, experiment_config.data_arguments, experiment_config.training_arguments
    is_exact = model_args.model_name == 'exact'

    # 1. Get data
    train_inputs, train_targets, val_inputs, val_targets, test_inputs, test_targets = get_data(data_args=data_args)

    # 2. Create model and criterion
    if model is None:
        model = create_model(model_args=model_args, train_x=train_inputs, train_y=train_targets)
    if is_exact:
        mll = ExactMarginalLogLikelihood(model.likelihood, model)
    else:
        mll = DeepApproximateMLL(
            VariationalELBO(likelihood=model.likelihood, model=model, num_data=data_args.num_train)
        )

    # 3. Create optimizer. Both criteria (marginal log likelihood / ELBO) are
    #    objectives to maximise, not losses, hence maximize=True.
    optimizer = Adam(model.parameters(), maximize=True, lr=0.01)

    # 4. Train and validate model
    print("Training...")
    train_loggers = [CSVLogger(root_dir=os.path.join(dir_path, 'train'))]
    val_loggers = [CSVLogger(root_dir=os.path.join(dir_path, 'val'))]
    model = fit(model=model, optimizer=optimizer, criterion=mll, train_loggers=train_loggers,
                val_loggers=val_loggers, train_inputs=train_inputs, train_targets=train_targets,
                val_inputs=val_inputs, val_targets=val_targets, training_args=training_args)
    # Save the train/val logs before plotting so that a plotting or renderer
    # failure cannot lose them.
    finalize(loggers=[*val_loggers, *train_loggers])

    # 5. Visualise the predicted mean on the sphere
    model.eval()
    with torch.no_grad():
        def get_mean(x):
            # Exact models expose the mean on the returned object; the other
            # models are called with mean=True.
            return model(x).mean if is_exact else model(x, mean=True)
        plot_prediction(get_mean, num_meshgrid=100).show()

    # 6. Test model
    print("Testing...")
    test_loggers = [CSVLogger(root_dir=os.path.join(dir_path, 'test'))]
    test_metrics = test_step(model=model, inputs=test_inputs, targets=test_targets, sample_hidden=training_args.sample_hidden,
                             loggers=test_loggers, train_targets=train_targets)
    # make sure logger files are saved
    finalize(loggers=test_loggers)
    print(test_metrics)
    print("Done!")

In [47]:
# Experiment configuration: 'exact' model trained on the 'singular' target.
# Model-argument overrides explored earlier (currently left at their defaults):
#   project_to_tangent = 'intrinsic'
#   parametrised_frame = [] or True
#   outputscale_mean = 0.0, outputscale_std = 0.01
#   prior_class = 'normal'
#   optimize_nu = False
experiment_config = ExperimentConfig()
experiment_config.model_arguments.model_name = 'exact'

data_overrides = {
    'target_name': 'singular',
    'num_test': 2000,
    'num_train': 800,
}
for field_name, field_value in data_overrides.items():
    setattr(experiment_config.data_arguments, field_name, field_value)

dir_path = './'

In [48]:
# Seed before generating data and initialising the model so that both are
# reproducible. The run cell re-seeds and run_experiment calls get_data again.
# NOTE(review): seeding here presumably makes this model's training data match
# that second split — confirm against get_data.
set_experiment_seed(experiment_config.seed)

model_args = experiment_config.model_arguments
train_inputs, train_targets, val_inputs, val_targets, test_inputs, test_targets = get_data(data_args=experiment_config.data_arguments)
print(model_args.parametrised_frame)

model = create_model(model_args=model_args, train_x=train_inputs, train_y=train_targets)

False
In create_model got model_args.parametrised_frame=False


In [44]:
# NOTE(review): this cell has execution count In[44], so it ran before the
# current `model` was created in In[48]; the output below comes from an
# earlier model object. With model_name = 'exact' (set in In[47]) the new
# model may not have `hidden_layers` at all — TODO confirm before re-running.
# Displays `.bias` of `sequential[0]` inside
# `project_to_tangent.frame.get_normal_vector` of the first hidden layer.
# The two commented lines would reset that bias to zeros and freeze it.
# model.hidden_layers[0].project_to_tangent.frame.get_normal_vector.sequential[0].bias = torch.nn.Parameter(torch.tensor([0.0, 0.0, 0.0]))
# model.hidden_layers[0].project_to_tangent.frame.get_normal_vector.sequential[0].bias.requires_grad = False
model.hidden_layers[0].project_to_tangent.frame.get_normal_vector.sequential[0].bias

Parameter containing:
tensor([0., 0., 0.])

In [49]:





# Re-seed right before the run, then train, plot and test the model built above.
set_experiment_seed(experiment_config.seed)
run_experiment(experiment_config, dir_path, model=model)


INFO: Global seed set to 0
INFO: Global seed set to 0


Running experiment with the config: ./config.json
Training...



Experiment logs directory ./train exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!


Experiment logs directory ./val exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!

Fitting: 100%|██████████| 1000/1000 [01:41<00:00,  9.89it/s, elbo=1.61, nlpd=-2.07, smse=0.017]  


: 