In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader , TensorDataset
from torchvision.utils import save_image, make_grid
from torch.optim import Adam
import torch.nn.init as init
import gpytorch

import numpy as np
import math

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import MultipleLocator
import matplotlib.cm as cm

import copy
import seaborn as sns

from scipy.stats import norm
from sklearn.neighbors import KernelDensity, LocalOutlierFactor

import tqdm

In [2]:
# `num_seeds` runs are stored per condition (file suffix `rand-0` .. `rand-{num_seeds-1}`);
# `seed` is interpolated into every file name.
num_seeds = 5
seed = 0

# One object array per condition; slot i holds whatever np.load returns for run i.
(
    data_fullstate,
    data_no_joint_pos,
    data_no_joint_vel,
    data_no_action,
    data_no_imu,
    data_no_fc,
) = (np.empty(num_seeds, dtype=object) for _ in range(6))

# (file-name tag, destination array) pairs, in the order the files are read.
_tagged_arrays = (
    ("fullstate", data_fullstate),
    ("no_joint_pos", data_no_joint_pos),
    ("no_joint_vel", data_no_joint_vel),
    ("no_action", data_no_action),
    ("no_imu", data_no_imu),
    ("no_fc", data_no_fc),
)

for i in range(num_seeds):
    for _tag, _runs in _tagged_arrays:
        _runs[i] = np.load(f"data/performance/HEBB-FULL-STATE_seed-{seed}-{_tag}-rand-{i}.npz")

In [3]:
cuda = torch.cuda.is_available()
DEVICE = torch.device("cuda" if cuda else "cpu")

batch_size = 300
# Runs [0, training_seed) of the full-state data go to the training set,
# runs [training_seed, num_seeds) to the test set.
training_seed = 4

# Columns of the flattened state that are exposed through the datasets/loaders.
state_index = torch.arange(0, 19)
all_state_dim = 64
# NOTE(review): state_dim equals all_state_dim (64), not len(state_index) (19).
# A preceding `state_dim = len(state_index)` was immediately overwritten, so the
# effective value is kept here -- confirm which one downstream code expects.
state_dim = 64
action_dim = 19


def _run_to_tensor(run, key):
    """Return ``run[key]`` flattened to ``(num_samples, -1)`` as a float32 tensor on DEVICE.

    The row count is taken from the very array being reshaped, so the result
    cannot be silently mis-shaped by a differently sized run.
    """
    values = run[key]  # index the archive once instead of once for data and once for shape
    return torch.tensor(values.reshape(values.shape[0], -1), dtype=torch.float32, device=DEVICE)


# Start from empty (0, width) tensors so the concatenation also checks the feature width.
train_x = torch.empty((0, all_state_dim), dtype=torch.float32, device=DEVICE)
train_y = torch.empty((0, action_dim), dtype=torch.float32, device=DEVICE)
test_x = torch.empty((0, all_state_dim), dtype=torch.float32, device=DEVICE)
test_y = torch.empty((0, action_dim), dtype=torch.float32, device=DEVICE)
for i in range(training_seed):
    train_x = torch.cat((train_x, _run_to_tensor(data_fullstate[i], "state")), dim=0)
    train_y = torch.cat((train_y, _run_to_tensor(data_fullstate[i], "action_lowpass")), dim=0)
for j in range(training_seed, num_seeds):
    # Previously the row count came from data_no_joint_pos[j] while the data came
    # from data_fullstate[j]; both now come from the full-state run.
    test_x = torch.cat((test_x, _run_to_tensor(data_fullstate[j], "state")), dim=0)
    test_y = torch.cat((test_y, _run_to_tensor(data_fullstate[j], "action_lowpass")), dim=0)

# The datasets only see the `state_index` columns; train_x / test_x keep all columns.
train_dataset = TensorDataset(train_x[:, state_index], train_y)
test_dataset = TensorDataset(test_x[:, state_index], test_y)

print("TRAIN : X , Y shape : ",train_x[:,state_index].shape , train_y.shape)
print("TEST : X , Y shape : ",test_x[:,state_index].shape , test_y.shape)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

TRAIN : X , Y shape :  torch.Size([4000, 19]) torch.Size([4000, 19])
TEST : X , Y shape :  torch.Size([1000, 19]) torch.Size([1000, 19])


## GP Model

In [4]:
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
    
    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

In [5]:
# initialize likelihood and model
lr = 5e-5
epochs = 125

likelihood = gpytorch.likelihoods.GaussianLikelihood().to(DEVICE)
model = ExactGPModel(train_x, train_y, likelihood).to(DEVICE)
optimizer = Adam(model.parameters(), lr=lr)

mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

In [6]:
# ----- Training Loop -----
train_losses = []
test_losses = []

for epoch in range(epochs):
    model.train()
    total_loss = 0.0

    optimizer.zero_grad()
    output = model(train_x)
    print(output)
    loss = -mll(output, train_y.transpose(0, 1).to(DEVICE)).sum()
    loss.backward()
    optimizer.step()

    print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
        i + 1, epoch, loss.item(),
        model.covar_module.base_kernel.lengthscale.item(),
        model.likelihood.noise.item()
    ))
    optimizer.step()

MultivariateNormal(loc: torch.Size([4000]))
Iter 4/0 - Loss: 22.191   lengthscale: 0.693   noise: 0.693
MultivariateNormal(loc: torch.Size([4000]))
Iter 4/1 - Loss: 22.198   lengthscale: 0.693   noise: 0.693
MultivariateNormal(loc: torch.Size([4000]))
Iter 4/2 - Loss: 22.188   lengthscale: 0.693   noise: 0.693
MultivariateNormal(loc: torch.Size([4000]))
Iter 4/3 - Loss: 22.183   lengthscale: 0.693   noise: 0.693
MultivariateNormal(loc: torch.Size([4000]))
Iter 4/4 - Loss: 22.168   lengthscale: 0.693   noise: 0.693
MultivariateNormal(loc: torch.Size([4000]))
Iter 4/5 - Loss: 22.175   lengthscale: 0.693   noise: 0.693
MultivariateNormal(loc: torch.Size([4000]))
Iter 4/6 - Loss: 22.178   lengthscale: 0.693   noise: 0.693
MultivariateNormal(loc: torch.Size([4000]))
Iter 4/7 - Loss: 22.173   lengthscale: 0.694   noise: 0.693
MultivariateNormal(loc: torch.Size([4000]))
Iter 4/8 - Loss: 22.190   lengthscale: 0.694   noise: 0.693
MultivariateNormal(loc: torch.Size([4000]))
Iter 4/9 - Loss: 22.

In [9]:
import math
import numpy as np
import matplotlib.pyplot as plt
import torch
import gpytorch

# ----- Predict on test -----
model.eval(); likelihood.eval()

with torch.no_grad(), gpytorch.settings.fast_pred_var():
    pred = likelihood(model(test_x.to(DEVICE)))

# Pull mean/var (handle single-output vs batched multi-output)
mean = pred.mean          # single: [T]; multi-batched: [D, T]
var  = pred.variance      # single: [T]; multi-batched: [D, T]

# Normalize shapes to [T, D]
if mean.dim() == 1:
    # single output
    mean_np = mean.detach().cpu().unsqueeze(1).numpy()   # [T,1]
    var_np  = var.detach().cpu().unsqueeze(1).numpy()    # [T,1]
    out_dim = 1
else:
    # batched independent multi-output: [D, T] -> [T, D]
    mean_np = mean.detach().cpu().transpose(0,1).numpy() # [T,D]
    var_np  = var.detach().cpu().transpose(0,1).numpy()  # [T,D]
    out_dim = mean_np.shape[1]

# Clip before the square root so tiny negative variances cannot produce NaNs.
std_np   = np.sqrt(np.clip(var_np, 1e-12, None))
lower_np = mean_np - 1.96 * std_np
upper_np = mean_np + 1.96 * std_np

# x-axis (since inputs are high-D, we plot vs sample index)
T = mean_np.shape[0]
x = np.arange(T)

# ----- Optional: ground truth on test set -----
# The test targets live in `test_y`:
# - single-output: shape [T] or [T,1]
# - multi-output:  shape [T, D]
# The lookup must use the real variable name; checking for 'Y_test' never
# matched, so the targets were silently left out of every plot.
HAS_Y_TEST = 'test_y' in globals()
if HAS_Y_TEST and test_y is not None:
    ytest = test_y.detach().cpu().numpy()
    if ytest.ndim == 1:
        ytest = ytest[:, None]  # [T,1] to match plotting
else:
    ytest = None

# ----- Subplots: 2 columns -----
cols = 2
rows = math.ceil(out_dim / cols)
# squeeze=False always yields a 2-D array of axes, so no special case is needed
# for a single row or a single output.
fig, axes = plt.subplots(rows, cols, figsize=(cols*6, rows*3), sharex=True, squeeze=False)
axes = axes.flatten()

for d in range(out_dim):
    ax = axes[d]
    mu = mean_np[:, d]
    lo = lower_np[:, d]
    up = upper_np[:, d]

    # 95% band + mean
    ax.fill_between(x, lo, up, alpha=0.25, label="95% CI")
    ax.plot(x, mu, label="Predictive mean")

    # overlay test targets if available
    if ytest is not None and d < ytest.shape[1]:
        ax.plot(x, ytest[:, d], ".", markersize=2, label="Test target")

    ax.set_title(f"Output dim {d}")
    ax.grid(True)

# remove empty axes if any
for k in range(out_dim, len(axes)):
    fig.delaxes(axes[k])

fig.tight_layout()
# All panels share the same series, so one figure-level legend from the first panel suffices.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, loc="upper center", ncol=3, bbox_to_anchor=(0.5, 1.02))
plt.show()


RuntimeError: Flattening the training labels failed. The most common cause of this error is that the shapes of the prior mean and the training labels are mismatched. The shape of the train targets is torch.Size([4000, 19]), while the reported shape of the mean is torch.Size([4000]).