In [6]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Set seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

#############################################
# Data Generation
#############################################

# Parameters
N = 100    # number of subjects
T = 20     # number of time points per subject
p = 1      # dimension of random effects (e.g., random intercept only)
sigma = 0.1   # residual std dev
tau = 0.5     # std dev of random effects

# Generate the main predictor variable
# We'll have a single primary predictor x, and from it we'll create 4 transformations:
# X1 = x, X2 = cos(x), X3 = sin^2(x), X4 = x^2
x_vals = np.linspace(0, 10, T)
X_main = np.tile(x_vals, (N, 1))  # shape (N, T)

# Construct the full design matrix for fixed effects as input to the NN:
# We'll have W as random covariates (here we can just simulate some additional random covariates)
W = np.random.randn(N, T, 2)  # Two additional random covariates, not necessarily needed by the NN
# The NN input will be the 4 transformed versions of x, plus these 2 random covariates
# total input dimension = 6
X_nn_full = []
for i in range(N):
    for t in range(T):
        x = X_main[i, t]
        row = [x, np.cos(x), np.sin(x)**2, x**2, W[i,t,0], W[i,t,1]]
        X_nn_full.append(row)
X_nn_full = np.array(X_nn_full)  # shape (N*T, 6)

# Observed covariates Z for random effects:
# Let's assume a random intercept model: Z_it = 1 for all i, t
Z = np.ones((N, T, p))

# True underlying function for the fixed effects (the part the NN tries to capture)
# We'll define a "true" nonlinear function f_true:
def f_true(x):
    # match the first four transformations we considered
    return x[0] + np.cos(x[0]) + (np.sin(x[0])**2) + 0.5*(x[0]**2)

# Generate random effects beta_i ~ N(0, tau^2)
beta = np.random.normal(0, tau, size=(N, p))

# Generate Y
Y = np.zeros((N, T))
for i in range(N):
    for t in range(T):
        # X vector for the true function (just the first 4 features)
        x_vec = [X_nn_full[i*T + t,0], X_nn_full[i*T + t,1], X_nn_full[i*T + t,2], X_nn_full[i*T + t,3]]
        mu = f_true(x_vec) + Z[i,t,:].dot(beta[i])  # random intercept added
        Y[i,t] = mu + np.random.normal(0, sigma)

# Convert data to tensors
X_torch = torch.tensor(X_nn_full, dtype=torch.float32)   # shape (N*T, 6)
Y_torch = torch.tensor(Y.reshape(-1,1), dtype=torch.float32)  # shape (N*T, 1)
subject_indices = np.repeat(np.arange(N), T)
subject_indices_torch = torch.tensor(subject_indices, dtype=torch.long)

#############################################
# Baseline Model for Comparison (Linear Model)
#############################################

# Let's compare with a simple linear model with no random effects:
# Y ~ a + b1*X + b2*cos(X) + b3*sin^2(X) + b4*X^2 (ignoring W)
class LinearBaseline(nn.Module):
    def __init__(self):
        super().__init__()
        # We'll only consider the first 4 features (x, cos(x), sin^2(x), x^2)
        self.linear = nn.Linear(4, 1, bias=True)
    def forward(self, x):
        return self.linear(x[:, :4])

linear_model = LinearBaseline()
optimizer_linear = optim.Adam(linear_model.parameters(), lr=0.01)

# Train the baseline linear model just by MSE
for epoch in range(200):
    optimizer_linear.zero_grad()
    pred_lin = linear_model(X_torch)
    loss_lin = ((Y_torch - pred_lin)**2).mean()
    loss_lin.backward()
    optimizer_linear.step()

lin_pred = linear_model(X_torch).detach().numpy()
lin_mse = np.mean((Y_torch.detach().numpy() - lin_pred)**2)
print("Baseline Linear Model MSE:", lin_mse)

#############################################
# Neural Network Model with Random Effects
#############################################

# Define the NN for f_NN(X; Theta)
class NeuralNet(nn.Module):
    def __init__(self, input_dim=6, hidden_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, 1)
    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

nn_model = NeuralNet(input_dim=6, hidden_dim=20)

# Random effects gamma_param: We'll estimate them as well
# In practice, we'd do EM or Laplace approx. Here, we treat them as parameters to be optimized.
gamma_param = nn.Parameter(torch.zeros(N, p))

# Parameters for sigma^2 and tau^2 (variance components)
# We'll just fix these for simplicity, or we can try to learn them.
# Let's assume sigma and tau are known for demonstration. 
# For a full solution, we would also optimize them.
sigma_est = sigma  # fixed
tau_est = tau       # fixed
Sigma_beta = tau_est**2 * torch.eye(p)

# Optimizer for NN and random effects
optimizer_nn = optim.Adam(list(nn_model.parameters()) + [gamma_param], lr=0.01)

#############################################
# Negative Log-Likelihood Function
#############################################
# Given the model:
# Y = f_NN(X;Theta) + Z*beta_i + eps
# beta_i ~ N(0, tau^2), eps ~ N(0, sigma^2)
#
# Conditional on gamma_param (which we use to represent beta_i), the negative log-likelihood:
# NLL = 0.5*sum((Y - f_NN(X) - Z*gamma)^2)/sigma^2 + 0.5*N*log(2*pi*sigma^2)
#    + 0.5*sum(gamma_i^2)/tau^2 + 0.5*p*N*log(2*pi*tau^2)
#
# In a true EM, we'd integrate out gamma, but here we approximate by direct optimization.
def nll(Y, X, gamma_param, nn_model, sigma, tau):
    f = nn_model(X)  # f_NN(X;Theta)
    pred = f + gamma_param[subject_indices_torch]  # Z=1 so just add gamma
    
    # Data log-likelihood
    resid = Y - pred
    n = Y.shape[0]
    ll_data = -0.5*n*torch.log(torch.tensor(2*np.pi*(sigma**2))) - 0.5*(resid**2).sum()/(sigma**2)
    
    # Prior on gamma (random effects)
    # gamma ~ N(0, tau^2)
    ll_prior = -0.5*N*p*torch.log(torch.tensor(2*np.pi*(tau**2))) - 0.5*(gamma_param**2).sum()/(tau**2)
    
    # Negative log-likelihood is - (ll_data + ll_prior)
    return -(ll_data + ll_prior)

#############################################
# Training the NN model with random effects
#############################################
for epoch in range(1000):
    optimizer_nn.zero_grad()
    loss = nll(Y_torch, X_torch, gamma_param, nn_model, sigma_est, tau_est)
    loss.backward()
    optimizer_nn.step()
    
    if (epoch+1) % 200 == 0:
        print(f"Epoch {epoch+1}, NLL: {loss.item():.4f}")

# After training:
nn_pred = (nn_model(X_torch) + gamma_param[subject_indices_torch]).detach().numpy()
nn_mse = np.mean((Y_torch.detach().numpy() - nn_pred)**2)
print("NN + Random Effects MSE:", nn_mse)

#############################################
# Results and Comparison
#############################################
print("Comparison of Models:")
print(f" - Baseline Linear Model MSE: {lin_mse:.4f}")
print(f" - NN + Random Effects MSE: {nn_mse:.4f}")

# Ideally, NN + random effects should outperform (lower MSE) than the simple linear baseline
# if the nonlinear patterns are strong enough and training converged well.


Baseline Linear Model MSE: 0.34422797
Epoch 200, NLL: 11899.8330
Epoch 400, NLL: 2307.1782
Epoch 600, NLL: -849.7971
Epoch 800, NLL: -1542.2004
Epoch 1000, NLL: -1681.6061
NN + Random Effects MSE: 0.010102031
Comparison of Models:
 - Baseline Linear Model MSE: 0.3442
 - NN + Random Effects MSE: 0.0101


In [7]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Set random seeds
np.random.seed(123)
torch.manual_seed(123)

#############################################
# Data Generation
#############################################

N = 50    # subjects
T = 15    # time points per subject
p = 1     # dimension of random effects (just random intercept)
d_y = 2   # dimension of Y outcome (bivariate)
sigma = 0.1  # residual std (for simplicity, assume no covariance, just identity * sigma^2)
Sigma_epsilon = sigma**2 * torch.eye(d_y)

tau = 0.5  # std dev of random effects for each dimension of random effect

# Generate X (single predictor)
x_vals = np.linspace(0, 10, T)
X_main = np.tile(x_vals, (N, 1))  # shape (N, T)

# We'll create a neural network input with a few transformations:
# For demonstration, let's do something simple:
# Input dimension = 4: [x, cos(x), sin(x)^2, x^2]
X_nn_full = []
for i in range(N):
    for t in range(T):
        x = X_main[i,t]
        row = [x, np.cos(x), np.sin(x)**2, x**2]
        X_nn_full.append(row)
X_nn_full = np.array(X_nn_full) # (N*T, 4)

# Random effects: gamma_i ~ N(0, tau^2)
gamma = np.random.normal(0, tau, size=(N, p))

# True function for f_NN (simulate something similar to previous)
def f_true(x):
    return np.array([x[0] + np.cos(x[0]) + np.sin(x[0])**2, 
                     x[0]**2 * 0.5])  # A 2D output

# Generate Y
Y = np.zeros((N, T, d_y))
for i in range(N):
    for t in range(T):
        x_vec = X_nn_full[i*T + t, :]
        # Only the first element of x_vec for the true function's first component, second is just a function of x
        mu = f_true([x_vec[0]]) + gamma[i]*1.0  # Z=1
        eps = np.random.multivariate_normal(mean=[0,0], cov=(sigma**2*np.eye(d_y)))
        Y[i,t,:] = mu + eps

# Convert to tensors
X_torch = torch.tensor(X_nn_full, dtype=torch.float32)   # (N*T, 4)
Y_torch = torch.tensor(Y.reshape(-1, d_y), dtype=torch.float32) # (N*T, 2)
subject_indices = torch.tensor(np.repeat(np.arange(N), T), dtype=torch.long)

#############################################
# Model Definition (NN + random effects)
#############################################
class BivariateNN(nn.Module):
    def __init__(self, input_dim=4, hidden_dim=20, output_dim=2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

nn_model = BivariateNN()

gamma_param = nn.Parameter(torch.zeros(N, p))

optimizer = optim.Adam(list(nn_model.parameters())+[gamma_param], lr=0.01)

def nll(Y, X, gamma_param, model, Sigma_epsilon):
    # Model prediction
    f = model(X) # (N*T, 2)
    pred = f + gamma_param[subject_indices] # gamma is (N,1), broadcast to (N*T,1) then added
    # pred shape: (N*T,2), gamma only adds to one dimension?
    # We defined gamma as scalar random intercept that applies equally to both responses for simplicity.
    # If we want it only in first dimension, we can do pred[:,0] += gamma, pred[:,1] unchanged.
    # Let's assume random intercept affects both outcomes equally:
    # Modify pred to add gamma to both dimensions:
    pred = pred + 0.0  # no-op, but we need to be careful with dimensions
    # Actually, gamma_param is (N,1). Let's make it (N) and broadcast:
    gamma_broad = gamma_param[subject_indices].squeeze(-1)
    gamma_2d = gamma_broad.unsqueeze(1).repeat(1,2) # add same intercept to both outcomes
    pred = pred + gamma_2d

    resid = Y - pred
    # NLL under multivariate normal with Sigma_epsilon known:
    # logdet(2*pi*Sigma) = 2*log(sigma)*2*pi dimension
    # Since Sigma_epsilon = sigma^2 I_2, log|Sigma_epsilon| = 2*log(sigma^2) = 4*log(sigma)
    # We'll just compute directly:
    dist_term = torch.sum(resid**2)/(sigma**2) # since identity covariance structure
    n = Y.shape[0]
    d = d_y
    ll_data = -0.5 * n * d * np.log(2*np.pi*sigma**2) - 0.5*dist_term

    # Prior on gamma: gamma ~ N(0, tau^2)
    # log p(gamma)
    prior_gamma = -0.5*N*p*np.log(2*np.pi*(tau**2)) - 0.5*torch.sum(gamma_param**2)/(tau**2)

    return -(ll_data+prior_gamma)

for epoch in range(1000):
    optimizer.zero_grad()
    loss = nll(Y_torch, X_torch, gamma_param, nn_model, Sigma_epsilon)
    loss.backward()
    optimizer.step()

pred = nn_model(X_torch) + gamma_param[subject_indices].repeat(1,2)
pred_np = pred.detach().numpy()
mse = np.mean((Y_torch.detach().numpy() - pred_np)**2)
print("Frequentist Bivariate Model MSE:", mse)


Frequentist Bivariate Model MSE: 0.11098506


In [10]:
import numpy as np
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.distributions import constraints

pyro.clear_param_store()
torch.manual_seed(0)
np.random.seed(0)

#############################################
# Simulate Data (Univariate)
#############################################
N = 100
T = 10
p = 1  # random effects dimension
sigma = 0.1
tau = 0.5

x_vals = np.linspace(0, 10, T)
X_main = np.tile(x_vals, (N, 1))

X = X_main.reshape(N*T, 1)
X_torch = torch.tensor(X, dtype=torch.float32)

gamma_true = np.random.normal(0, tau, size=(N,))
def f_true(x):
    return x[0] + np.cos(x[0]) + (np.sin(x[0])**2)

Y = np.zeros(N*T)
for i in range(N):
    for t in range(T):
        mu = f_true([X_main[i,t]]) + gamma_true[i]
        Y[i*T+t] = mu + np.random.normal(0, sigma)

Y_torch = torch.tensor(Y, dtype=torch.float32)
subject_indices = torch.tensor(np.repeat(np.arange(N), T), dtype=torch.long)

#############################################
# Bayesian Model Definition
#############################################

# We define a simple Bayesian NN with one hidden layer.
# We'll place priors on all weights and biases.

# Dimensions: input=1, hidden=10, output=1
hidden_dim = 10

def model(X, Y, subject_idx):
    # Plate for subjects: we have N subjects
    # gamma ~ Normal(0, tau)
    # We'll consider gamma_i as independent across subjects
    with pyro.plate("subjects", N, dim=-1):
        gamma = pyro.sample("gamma", dist.Normal(torch.tensor(0.0), tau))
    # gamma now has shape [N]

    # Priors for weights:
    fc1_w = pyro.sample("fc1_w", dist.Normal(torch.zeros(hidden_dim, 1), 1.0).to_event(2))
    fc1_b = pyro.sample("fc1_b", dist.Normal(torch.zeros(hidden_dim), 1.0).to_event(1))
    fc2_w = pyro.sample("fc2_w", dist.Normal(torch.zeros(1, hidden_dim), 1.0).to_event(2))
    fc2_b = pyro.sample("fc2_b", dist.Normal(torch.zeros(1), 1.0))

    # Compute predictions
    # X: (N*T,1)
    # fc1_w: (hidden_dim,1), fc1_b: (hidden_dim)
    h = torch.relu(X @ fc1_w.transpose(-1,-2) + fc1_b)  # (N*T, hidden_dim)
    pred = h @ fc2_w.transpose(-1,-2) + fc2_b  # (N*T,1)
    pred = pred.squeeze(-1)  # (N*T,)

    # Add gamma:
    # subject_idx: (N*T,), gamma: (N,)
    # gamma[subject_idx]: (N*T,)
    pred = pred + gamma[subject_idx]

    # Likelihood
    with pyro.plate("data", X.size(0), dim=-1):
        pyro.sample("obs", dist.Normal(pred, sigma), obs=Y)

def guide(X, Y, subject_idx):
    # Define variational parameters
    fc1_w_loc = pyro.param("fc1_w_loc", torch.zeros(hidden_dim, 1))
    fc1_w_scale = pyro.param("fc1_w_scale", torch.ones(hidden_dim, 1), constraint=constraints.positive)
    fc1_b_loc = pyro.param("fc1_b_loc", torch.zeros(hidden_dim))
    fc1_b_scale = pyro.param("fc1_b_scale", torch.ones(hidden_dim), constraint=constraints.positive)

    fc2_w_loc = pyro.param("fc2_w_loc", torch.zeros(1, hidden_dim))
    fc2_w_scale = pyro.param("fc2_w_scale", torch.ones(1, hidden_dim), constraint=constraints.positive)
    fc2_b_loc = pyro.param("fc2_b_loc", torch.zeros(1))
    fc2_b_scale = pyro.param("fc2_b_scale", torch.ones(1), constraint=constraints.positive)

    gamma_loc = pyro.param("gamma_loc", torch.zeros(N))
    gamma_scale = pyro.param("gamma_scale", torch.ones(N), constraint=constraints.positive)

    # Sample statements (ensure shapes match):
    with pyro.plate("subjects", N, dim=-1):
        pyro.sample("gamma", dist.Normal(gamma_loc, gamma_scale))

    pyro.sample("fc1_w", dist.Normal(fc1_w_loc, fc1_w_scale).to_event(2))
    pyro.sample("fc1_b", dist.Normal(fc1_b_loc, fc1_b_scale).to_event(1))
    pyro.sample("fc2_w", dist.Normal(fc2_w_loc, fc2_w_scale).to_event(2))
    pyro.sample("fc2_b", dist.Normal(fc2_b_loc, fc2_b_scale))

optimizer = Adam({"lr":0.01})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

for step in range(500):
    loss = svi.step(X_torch, Y_torch, subject_indices)
    if (step+1)%100==0:
        print(f"Step {step+1}, ELBO: {loss}")

# Posterior mean predictions:
with torch.no_grad():
    fc1_w_loc = pyro.param("fc1_w_loc")
    fc1_b_loc = pyro.param("fc1_b_loc")
    fc2_w_loc = pyro.param("fc2_w_loc")
    fc2_b_loc = pyro.param("fc2_b_loc")
    gamma_loc = pyro.param("gamma_loc")

    h = torch.relu(X_torch @ fc1_w_loc.transpose(-1,-2) + fc1_b_loc)
    pred_mean = (h @ fc2_w_loc.transpose(-1,-2) + fc2_b_loc).squeeze(-1)
    pred_mean = pred_mean + gamma_loc[subject_indices]
    mse = torch.mean((Y_torch - pred_mean)**2).item()
    print("Bayesian Univariate Model MSE:", mse)


Step 100, ELBO: 16978706.740332127
Step 200, ELBO: 9303632.133728206
Step 300, ELBO: 524039.7157717347
Step 400, ELBO: 1063870.1879956722
Step 500, ELBO: 591187.8524053693
Bayesian Univariate Model MSE: 17.977684020996094


In [11]:
import numpy as np
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.distributions import constraints

pyro.clear_param_store()
torch.manual_seed(1)
np.random.seed(1)

#############################################
# Simulate Bivariate Data
#############################################
N = 50
T = 10
d_y = 2
sigma = 0.1
tau = 0.5

x_vals = np.linspace(0, 5, T)
X_main = np.tile(x_vals, (N,1))

X = []
for i in range(N):
    for t in range(T):
        x = X_main[i,t]
        X.append([x, np.cos(x), (np.sin(x))**2, x**2])
X = np.array(X)
X_torch = torch.tensor(X, dtype=torch.float32)

gamma_true = np.random.normal(0, tau, size=N)

def f_true_bivariate(x):
    return np.array([x[0] + np.cos(x[0]) + np.sin(x[0])**2,
                     0.5*(x[0]**2)])

Y = np.zeros((N*T, d_y))
for i in range(N):
    for t in range(T):
        mu = f_true_bivariate(X[i*T+t,:]) + gamma_true[i]
        eps = np.random.multivariate_normal(mean=[0,0], cov=(sigma**2*np.eye(2)))
        Y[i*T+t,:] = mu + eps

Y_torch = torch.tensor(Y, dtype=torch.float32)
subject_indices = torch.tensor(np.repeat(np.arange(N), T), dtype=torch.long)

#############################################
# Bayesian Bivariate Model
#############################################
hidden_dim = 10
output_dim = 2

def model_bivariate(X, Y, subject_idx):
    # Random intercept per subject
    with pyro.plate("subjects", N, dim=-1):
        gamma = pyro.sample("gamma", dist.Normal(torch.tensor(0.0), tau))

    fc1_w = pyro.sample("fc1_w", dist.Normal(torch.zeros(hidden_dim, X.size(1)), 1.0).to_event(2))
    fc1_b = pyro.sample("fc1_b", dist.Normal(torch.zeros(hidden_dim), 1.0).to_event(1))
    fc2_w = pyro.sample("fc2_w", dist.Normal(torch.zeros(output_dim, hidden_dim), 1.0).to_event(2))
    fc2_b = pyro.sample("fc2_b", dist.Normal(torch.zeros(output_dim), 1.0).to_event(1))

    # Compute predictions
    h = torch.relu(X @ fc1_w.transpose(-1,-2) + fc1_b)
    pred = h @ fc2_w.transpose(-1,-2) + fc2_b  # (N*T, 2)

    # Add gamma to both outcomes
    gamma_2d = gamma[subject_idx].unsqueeze(-1).expand(-1, d_y)
    pred = pred + gamma_2d

    # Likelihood
    Sigma_epsilon = sigma**2 * torch.eye(d_y)
    # The "data" plate ensures Pyro knows we're dealing with N*T independent observations
    with pyro.plate("data", X.size(0), dim=-1):
        pyro.sample("obs",
                    dist.MultivariateNormal(pred, covariance_matrix=Sigma_epsilon),
                    obs=Y)

def guide_bivariate(X, Y, subject_idx):
    fc1_w_loc = pyro.param("fc1_w_loc", torch.zeros(hidden_dim, X.size(1)))
    fc1_w_scale = pyro.param("fc1_w_scale", torch.ones(hidden_dim, X.size(1)), constraint=constraints.positive)
    fc1_b_loc = pyro.param("fc1_b_loc", torch.zeros(hidden_dim))
    fc1_b_scale = pyro.param("fc1_b_scale", torch.ones(hidden_dim), constraint=constraints.positive)

    fc2_w_loc = pyro.param("fc2_w_loc", torch.zeros(output_dim, hidden_dim))
    fc2_w_scale = pyro.param("fc2_w_scale", torch.ones(output_dim, hidden_dim), constraint=constraints.positive)
    fc2_b_loc = pyro.param("fc2_b_loc", torch.zeros(output_dim))
    fc2_b_scale = pyro.param("fc2_b_scale", torch.ones(output_dim), constraint=constraints.positive)

    gamma_loc = pyro.param("gamma_loc", torch.zeros(N))
    gamma_scale = pyro.param("gamma_scale", torch.ones(N), constraint=constraints.positive)

    with pyro.plate("subjects", N, dim=-1):
        pyro.sample("gamma", dist.Normal(gamma_loc, gamma_scale))

    pyro.sample("fc1_w", dist.Normal(fc1_w_loc, fc1_w_scale).to_event(2))
    pyro.sample("fc1_b", dist.Normal(fc1_b_loc, fc1_b_scale).to_event(1))
    pyro.sample("fc2_w", dist.Normal(fc2_w_loc, fc2_w_scale).to_event(2))
    pyro.sample("fc2_b", dist.Normal(fc2_b_loc, fc2_b_scale).to_event(1))

optimizer = Adam({"lr":0.01})
svi = SVI(model_bivariate, guide_bivariate, optimizer, loss=Trace_ELBO())

for step in range(500):
    loss = svi.step(X_torch, Y_torch, subject_indices)
    if (step+1) % 100 == 0:
        print(f"Step {step+1}, ELBO: {loss}")

with torch.no_grad():
    fc1_w_loc = pyro.param("fc1_w_loc")
    fc1_b_loc = pyro.param("fc1_b_loc")
    fc2_w_loc = pyro.param("fc2_w_loc")
    fc2_b_loc = pyro.param("fc2_b_loc")
    gamma_loc = pyro.param("gamma_loc")

    h = torch.relu(X_torch @ fc1_w_loc.transpose(-1,-2) + fc1_b_loc)
    pred_mean = h @ fc2_w_loc.transpose(-1,-2) + fc2_b_loc
    pred_mean = pred_mean + gamma_loc[subject_indices].unsqueeze(-1).expand(-1,d_y)
    mse_bivariate = torch.mean((Y_torch - pred_mean)**2).item()
    print("Bayesian Bivariate Model MSE:", mse_bivariate)


Step 100, ELBO: 15119495.51012969
Step 200, ELBO: 357725.1309299469
Step 300, ELBO: 996476.6801431179
Step 400, ELBO: 325760.84372878075
Step 500, ELBO: 11704605.582691431
Bayesian Bivariate Model MSE: 12.710402488708496
