# Experiment 3: Unknown Group Prediction Test

This experiment evaluates how hierarchical models handle predictions for unknown groups.
We train models on a subset of the groups (A, B, C) and test on held-out groups (D, E) that never appear during training, to assess generalization.

In [None]:
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mean_squared_error

from models.linear_regression import LinearRegression
from models.gaussian_process import GaussianProcessRegression

# Seed NumPy's global random generator so the np.random draws made later in
# this notebook are reproducible from run to run.
np.random.seed(42)
# Use seaborn's 'whitegrid' style for the figures drawn below.
sns.set_style('whitegrid')

## 1. Generate Complete Dataset (No Missing Values)

In [None]:
# Generate a synthetic dataset with 5 equally sized groups.
n_samples_per_group = 50
groups_all = ['A', 'B', 'C', 'D', 'E']
n_total = n_samples_per_group * len(groups_all)

# Create group labels: one label per row, each group repeated contiguously
groups = np.repeat(groups_all, n_samples_per_group)

# Group-specific random effects: a constant offset added to the target of
# every row that belongs to the group
group_effects = {
    'A': 3.0,
    'B': -2.0,
    'C': 1.5,
    'D': -1.0,
    'E': 2.5
}

# Generate features. The order of these draws (X1, X2, X3, then the noise
# further down) determines the values obtained under a fixed seed.
X1 = np.random.randn(n_total)
X2 = np.random.randn(n_total)
X3 = np.random.randn(n_total)

# Generate target with hierarchical structure:
#   y = 2.5*x1 + 1.8*x2 + 1.2*x3 + group_effect[group] + noise
y = 2.5 * X1 + 1.8 * X2 + 1.2 * X3
# Look up every row's group offset and add them in one array operation
# instead of mutating y element by element. A label that is missing from
# group_effects still raises KeyError.
y += np.array([group_effects[label] for label in groups])
# Observation noise: standard normal draws scaled by 0.5
y += np.random.randn(n_total) * 0.5

# Create DataFrame
data = pd.DataFrame({
    'group': groups,
    'x1': X1,
    'x2': X2,
    'x3': X3,
    'y': y
})

print(f"Dataset shape: {data.shape}")
print(f"Groups: {data['group'].unique()}")
print("\nSamples per group:")
print(data['group'].value_counts().sort_index())

## 2. Split Data: Known vs Unknown Groups

In [None]:
# Groups A, B and C are available for fitting; groups D and E are held out
# completely and act as the "unknown" groups at prediction time.
train_groups = ['A', 'B', 'C']
test_groups = ['D', 'E']

# Split rows by group membership; .copy() gives each subset its own data
train_data = data.loc[data['group'].isin(train_groups)].copy()
test_data = data.loc[data['group'].isin(test_groups)].copy()

print(f"Training data: {train_data.shape[0]} samples from groups {train_groups}")
print(f"Test data: {test_data.shape[0]} samples from groups {test_groups} (unknown)")

# Inputs hold the three numeric features plus the 'group' column; the target
# is the 'y' column of the same rows.
feature_cols = ['x1', 'x2', 'x3']
X_train, y_train = train_data[[*feature_cols, 'group']], train_data['y']
X_test, y_test = test_data[[*feature_cols, 'group']], test_data['y']

## 3. Train and Evaluate Four Methods

In [None]:
# One summary row (method name, MSE, mean predictive std) is appended per model
results = []

# Method 1: Ordinary Linear Regression
print("\n" + "=" * 50, "Method 1: Ordinary Linear Regression", "=" * 50, sep="\n")

# Fit on the numeric features only: the 'group' column is not passed and no
# random effects are requested.
model_lr = LinearRegression()
model_lr.fit(X_train[feature_cols], y_train, random_effects=None)

# Predict the held-out groups; the second returned value is treated as the
# predictive standard deviation and averaged into a single number.
y_pred_lr, std_lr = model_lr.predict(X_test[feature_cols], return_std=True)
mse_lr = mean_squared_error(y_test, y_pred_lr)
mean_std_lr = np.mean(std_lr)

print(f"MSE: {mse_lr:.4f}", f"Mean prediction uncertainty: {mean_std_lr:.4f}", sep="\n")

results.append(dict(Method='Ordinary LR', MSE=mse_lr, Mean_Std=mean_std_lr))

In [None]:
# Method 2: Hierarchical Linear Regression
print("\n" + "=" * 50, "Method 2: Hierarchical Linear Regression", "=" * 50, sep="\n")

# Fit on the full frame (features plus 'group') and declare 'group' as the
# random-effects column.
model_hlr = LinearRegression()
model_hlr.fit(X_train, y_train, random_effects=['group'])

# X_test contains only groups D and E, which the model never saw during fit
y_pred_hlr, std_hlr = model_hlr.predict(X_test, return_std=True)
mse_hlr = mean_squared_error(y_test, y_pred_hlr)
mean_std_hlr = np.mean(std_hlr)

print(
    f"MSE: {mse_hlr:.4f}",
    f"Mean prediction uncertainty: {mean_std_hlr:.4f}",
    "Note: Unknown groups D and E use population-level effects",
    sep="\n",
)

results.append(dict(Method='Hierarchical LR', MSE=mse_hlr, Mean_Std=mean_std_hlr))

In [None]:
# Method 3: Ordinary Gaussian Process
print("\n" + "=" * 50, "Method 3: Ordinary Gaussian Process", "=" * 50, sep="\n")

# Fit on the numeric features only: the 'group' column is not passed and no
# random effects are requested.
model_gp = GaussianProcessRegression()
model_gp.fit(X_train[feature_cols], y_train, random_effects=None)

# Predict the held-out groups; the second returned value is treated as the
# predictive standard deviation and averaged into a single number.
y_pred_gp, std_gp = model_gp.predict(X_test[feature_cols], return_std=True)
mse_gp = mean_squared_error(y_test, y_pred_gp)
mean_std_gp = np.mean(std_gp)

print(f"MSE: {mse_gp:.4f}", f"Mean prediction uncertainty: {mean_std_gp:.4f}", sep="\n")

results.append(dict(Method='Ordinary GP', MSE=mse_gp, Mean_Std=mean_std_gp))

In [None]:
# Method 4: Hierarchical Gaussian Process
print("\n" + "=" * 50, "Method 4: Hierarchical Gaussian Process", "=" * 50, sep="\n")

# Fit on the full frame (features plus 'group') and declare 'group' as the
# random-effects column.
model_hgp = GaussianProcessRegression()
model_hgp.fit(X_train, y_train, random_effects=['group'])

# X_test contains only groups D and E, which the model never saw during fit
y_pred_hgp, std_hgp = model_hgp.predict(X_test, return_std=True)
mse_hgp = mean_squared_error(y_test, y_pred_hgp)
mean_std_hgp = np.mean(std_hgp)

print(
    f"MSE: {mse_hgp:.4f}",
    f"Mean prediction uncertainty: {mean_std_hgp:.4f}",
    "Note: Unknown groups D and E use population-level effects",
    sep="\n",
)

results.append(dict(Method='Hierarchical GP', MSE=mse_hgp, Mean_Std=mean_std_hgp))

## 4. Compare Results

In [None]:
# Turn the per-method summary rows into a table (one row per method)
results_df = pd.DataFrame(results)

print("\n" + "=" * 70, "UNKNOWN GROUP PREDICTION PERFORMANCE", "=" * 70, sep="\n")
print(results_df.to_string(index=False))
print("\n" + "=" * 70)
# idxmin / idxmax give the row label of the extreme value; .at reads the
# method name stored in that row.
print(f"Best method (lowest MSE): {results_df.at[results_df['MSE'].idxmin(), 'Method']}")
print(f"Most uncertain predictions: {results_df.at[results_df['Mean_Std'].idxmax(), 'Method']}")
print("=" * 70)

## 5. Analyze Performance by Group

In [None]:
# Calculate MSE for each unknown group separately.
#
# Targets and predictions are converted to plain arrays and selected with a
# positional boolean mask. Indexing the predictions with the index-labelled
# boolean Series directly only works when predict() returns an ndarray or a
# Series sharing test_data's index; the positional form also covers lists and
# Series with a different index.
predictions_by_method = {
    'Ordinary LR': np.asarray(y_pred_lr),
    'Hierarchical LR': np.asarray(y_pred_hlr),
    'Ordinary GP': np.asarray(y_pred_gp),
    'Hierarchical GP': np.asarray(y_pred_hgp),
}
y_test_values = np.asarray(y_test)

group_results = []

for group in test_groups:
    # True at the row positions of test_data that belong to this group
    group_mask = (test_data['group'] == group).to_numpy()
    y_true_group = y_test_values[group_mask]

    # One row per group: the group label followed by one MSE column per
    # method, in the insertion order of predictions_by_method.
    group_results.append({
        'Group': group,
        **{
            method: mean_squared_error(y_true_group, y_pred[group_mask])
            for method, y_pred in predictions_by_method.items()
        },
    })

group_results_df = pd.DataFrame(group_results)
print("\nMSE by Unknown Group:")
print(group_results_df.to_string(index=False))

## 6. Visualization

In [None]:
# Create comprehensive visualization: a 2x2 grid with overall MSE, overall
# uncertainty, a true-vs-predicted scatter for the hierarchical GP, and MSE
# broken down by unknown group.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# One colour per method, in the order the methods were appended to results
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

# Plot 1: MSE comparison
ax1 = axes[0, 0]
ax1.bar(results_df['Method'], results_df['MSE'], color=colors)
ax1.set_ylabel('Mean Squared Error', fontsize=11)
ax1.set_title('Prediction Error on Unknown Groups', fontsize=12, fontweight='bold')
ax1.tick_params(axis='x', rotation=45)
ax1.grid(axis='y', alpha=0.3)

# Plot 2: Uncertainty comparison
ax2 = axes[0, 1]
ax2.bar(results_df['Method'], results_df['Mean_Std'], color=colors)
ax2.set_ylabel('Mean Prediction Std', fontsize=11)
ax2.set_title('Prediction Uncertainty', fontsize=12, fontweight='bold')
ax2.tick_params(axis='x', rotation=45)
ax2.grid(axis='y', alpha=0.3)

# Plot 3: True vs Predicted (Hierarchical GP)
ax3 = axes[1, 0]
# Work on plain arrays with a positional mask so the selection does not depend
# on whether the predictions are an ndarray, a list, or a Series with its own
# index.
y_true_values = np.asarray(y_test)
y_pred_hgp_values = np.asarray(y_pred_hgp)
for group in test_groups:
    group_mask = (test_data['group'] == group).to_numpy()
    ax3.scatter(y_true_values[group_mask], y_pred_hgp_values[group_mask],
                label=f'Group {group}', alpha=0.6, s=50)

# Diagonal reference line spanning the joint range of true and predicted
# values; points on it are perfect predictions.
min_val = min(y_true_values.min(), y_pred_hgp_values.min())
max_val = max(y_true_values.max(), y_pred_hgp_values.max())
ax3.plot([min_val, max_val], [min_val, max_val], 'k--', alpha=0.5)
ax3.set_xlabel('True Values', fontsize=11)
ax3.set_ylabel('Predicted Values', fontsize=11)
ax3.set_title('Hierarchical GP: True vs Predicted', fontsize=12, fontweight='bold')
ax3.legend()
ax3.grid(alpha=0.3)

# Plot 4: MSE by group (grouped bars, one bar per method within each group)
ax4 = axes[1, 1]
x_pos = np.arange(len(test_groups))
width = 0.2
method_names = ['Ordinary LR', 'Hierarchical LR', 'Ordinary GP', 'Hierarchical GP']
for slot, (method, color) in enumerate(zip(method_names, colors)):
    # Centre the cluster of bars on the tick: offsets are -1.5, -0.5, +0.5
    # and +1.5 bar widths for the four methods.
    offset = (slot - (len(method_names) - 1) / 2) * width
    ax4.bar(x_pos + offset, group_results_df[method], width,
            label=method, color=color)
ax4.set_xticks(x_pos)
ax4.set_xticklabels(test_groups)
ax4.set_ylabel('MSE', fontsize=11)
ax4.set_xlabel('Unknown Group', fontsize=11)
ax4.set_title('MSE by Unknown Group', fontsize=12, fontweight='bold')
ax4.legend(fontsize=9)
ax4.grid(axis='y', alpha=0.3)

plt.tight_layout()
# Save before show() so the file is written from the fully laid-out figure
plt.savefig('experiment_3_results.png', dpi=300, bbox_inches='tight')
plt.show()

print("\nVisualization saved as 'experiment_3_results.png'")

## Conclusion

This experiment evaluates how hierarchical models generalize to unknown groups. Key observations:

1. **Hierarchical models** fall back to population-level effects for unknown groups
2. **Ordinary models** may perform better on unknown groups if group effects are strong
3. **Prediction uncertainty** helps quantify confidence in predictions for unknown groups
4. The performance difference between hierarchical and ordinary models reveals the trade-off between exploiting known group structure and generalizing to groups never seen during training