# TabularPriors Visualization and Demo Notebook

This notebook shows how to generate, load, and visualize synthetic tabular data using the `tabularpriors` package. You’ll see examples for both regression and classification, using different types of priors, with clear visualizations to help you explore the data.

## Contents:

1. **TICL MLP Prior Regression** - Visualizing regression data generated from MLP priors
2. **TICL GP Prior Regression** - Visualizing regression data from Gaussian Process priors
3. **TabICL Classification** - Visualizing classification data from TabICL priors
4. **Live Data Generation** - Generating and visualizing synthetic data in real-time

## Prerequisites:

- Install the `tabularpriors` package.
- Sections 1–3 read pre-generated HDF5 dump files: run the command shown in each of those sections first to create them. Section 4 generates data on the fly and needs no files.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from tabularpriors.dataloader import PriorDumpDataLoader

def plot_regression_samples(X_data, y_data, title, color, batch_size=20):
    """Plot regression data, dispatching on the number of input features.

    With a single feature, each sample is drawn as its own x-vs-y scatter
    panel. With two or more features, summary statistics are printed and
    feature-target and per-sample evolution figures are drawn. With zero
    features nothing is drawn.

    Args:
        X_data: Feature array indexed as ``X_data[sample, position, feature]``.
        y_data: Target array indexed as ``y_data[sample, position]``.
        title: Figure title prefix.
        color: Matplotlib color for the scatter points (now honored for
            single-feature data too).
        batch_size: Maximum number of samples to plot; clamped to the number
            of samples actually present in ``X_data``.
    """
    num_features = X_data.shape[2]
    # Never ask for more samples than the batch contains (avoids IndexError).
    batch_size = min(batch_size, X_data.shape[0])
    if num_features == 1:
        _plot_regression_1d(X_data, y_data, title, batch_size, color=color)
    elif num_features >= 2:
        _plot_regression_simplified(X_data, y_data, title, color, batch_size, num_features)

def _plot_regression_1d(X_data, y_data, title, batch_size, color='blue'):
    """Scatter-plot single-feature regression samples, one panel per sample.

    Args:
        X_data: Feature array indexed as ``X_data[sample, position, 0]``.
        y_data: Target array indexed as ``y_data[sample, position]``.
        title: Figure super-title.
        batch_size: Number of samples to draw; clamped to the samples present.
        color: Scatter color. Defaults to 'blue', the previously hard-coded value.
    """
    n_samples = min(batch_size, X_data.shape[0])
    if n_samples <= 0:
        return
    # Up to 5 panels per row; the grid grows with the sample count instead of
    # being a fixed 4x5 (which failed for more than 20 samples).
    cols = min(5, n_samples)
    rows = -(-n_samples // cols)  # ceiling division
    fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3 * rows), squeeze=False)
    axes = axes.flatten()
    for i in range(n_samples):
        ax = axes[i]
        ax.scatter(X_data[i, :, 0], y_data[i, :], color=color, alpha=0.7, s=20)
        ax.set_title(f"Sample {i + 1}")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
    # Remove the unused trailing panels of the last row.
    for ax in axes[n_samples:]:
        fig.delaxes(ax)
    fig.suptitle(title, fontsize=16)
    fig.tight_layout()
    plt.show()

def _plot_regression_simplified(X_data, y_data, title, color, batch_size, num_features):
    """Print stats and draw overview figures for multi-feature regression data.

    Produces, in order:
      1. printed summary statistics (via ``_print_regression_statistics``);
      2. one feature-vs-target scatter per feature, for up to three features;
      3. a 2x2 grid of per-sample evolution plots for up to four samples.

    Args:
        X_data: Feature array indexed as ``X_data[sample, position, feature]``.
        y_data: Target array indexed as ``y_data[sample, position]``.
        title: Figure title prefix.
        color: Scatter color for the feature-vs-target plots.
        batch_size: Requested number of samples (also reported in the stats).
        num_features: Number of feature columns; expected to be >= 1.
    """
    _print_regression_statistics(X_data, y_data, num_features, batch_size)

    # Feature-vs-target scatters: one panel per plotted feature, so two
    # features no longer leave an empty third slot.
    n_feat_plots = min(3, num_features)
    fig, axes = plt.subplots(1, n_feat_plots, figsize=(5 * n_feat_plots, 5), squeeze=False)
    for feat_idx, ax in enumerate(axes[0]):
        _plot_feature_vs_target(X_data, y_data, feat_idx, ax, color, f"Feature {feat_idx + 1}")
    fig.suptitle(f"{title} - Feature-Target Relationships", fontsize=14)
    fig.tight_layout()
    plt.show()

    # Per-sample evolution, clamped to the samples actually present.
    n_samples = min(4, batch_size, X_data.shape[0])
    fig = plt.figure(figsize=(16, 10))
    for sample_idx in range(n_samples):
        ax = fig.add_subplot(2, 2, sample_idx + 1)
        _plot_single_sample_analysis(X_data, y_data, sample_idx, ax, color, num_features)
    fig.suptitle(f"{title} - Sample Evolution Analysis", fontsize=14)
    fig.tight_layout()
    plt.show()

def _plot_feature_vs_target(X_data, y_data, feature_idx, ax, color, label):
    """Scatter one feature against the target over every sample and position.

    Draws on ``ax`` and annotates the panel with the Pearson correlation
    ``r`` of the flattened feature and target values.

    Args:
        X_data: Feature array indexed as ``X_data[sample, position, feature]``.
        y_data: Target array indexed as ``y_data[sample, position]``; it is
            flattened in the same order as the feature column.
        feature_idx: Index of the feature column to plot.
        ax: Matplotlib Axes to draw on.
        color: Scatter color.
        label: Text used for the x-axis label and the panel title.
    """
    all_x = X_data[:, :, feature_idx].flatten()
    all_y = y_data.flatten()
    ax.scatter(all_x, all_y, color=color, alpha=0.4, s=12)
    ax.set_xlabel(label, fontsize=12)
    ax.set_ylabel('Target', fontsize=12)
    ax.set_title(f'{label} vs Target', fontsize=12, pad=10)
    ax.grid(True, alpha=0.3)
    # Pearson r is undefined for constant data (e.g. zero-padded feature
    # columns); np.corrcoef would warn and return nan, so report n/a instead.
    if all_x.size > 1 and all_x.std() > 0 and all_y.std() > 0:
        r_text = f'r = {np.corrcoef(all_x, all_y)[0, 1]:.3f}'
    else:
        r_text = 'r = n/a'
    ax.text(0.05, 0.95, r_text, transform=ax.transAxes,
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
            fontsize=10, verticalalignment='top')

def _plot_single_sample_analysis(X_data, y_data, sample_idx, ax, color, num_features):
    """Draw one sample's data on ``ax``.

    With a single feature, plots target against feature as a connected
    marker line in ``color``. With several features, plots up to the first
    three feature values against sequence position (fixed three-color
    palette) and the target on a secondary y-axis as a dashed grey line.

    Args:
        X_data: Feature array indexed as ``X_data[sample, position, feature]``.
        y_data: Target array indexed as ``y_data[sample, position]``.
        sample_idx: Index of the sample to draw.
        ax: Matplotlib Axes to draw on.
        color: Line color for the single-feature case.
        num_features: Number of feature columns in ``X_data``.
    """
    if num_features == 1:
        x = X_data[sample_idx, :, 0]
        y = y_data[sample_idx, :]
        # Honor the caller's color (previously hard-coded to 'blue').
        ax.plot(x, y, 'o-', color=color, alpha=0.7, markersize=4)
        ax.set_xlabel('Feature 1', fontsize=11)
        ax.set_ylabel('Target', fontsize=11)
    else:
        seq_len = X_data.shape[1]
        x_axis = np.arange(seq_len)
        feature_colors = ['#1f77b4', '#9467bd', '#2ca02c']
        for feat_idx in range(min(3, num_features)):
            ax.plot(x_axis, X_data[sample_idx, :, feat_idx],
                    label=f'F{feat_idx + 1}', alpha=0.8, linewidth=2,
                    color=feature_colors[feat_idx])
        # Target gets its own y-axis because its scale can differ from the features'.
        ax2 = ax.twinx()
        ax2.plot(x_axis, y_data[sample_idx, :], color='#555555', linestyle='--', label='Target',
                 alpha=0.9, linewidth=2.5)
        ax2.set_ylabel('Target', color='#555555', fontsize=11)
        ax.legend(loc='upper left', fontsize=9)
        ax2.legend(loc='upper right', fontsize=9)
        ax.set_xlabel('Sequence Position', fontsize=11)
        ax.set_ylabel('Feature Values', fontsize=11)
    ax.set_title(f'Sample {sample_idx + 1} Evolution', fontsize=11, pad=8)
    ax.grid(True, alpha=0.3)

def _print_regression_statistics(X_data, y_data, num_features, batch_size):
    """Print regression stats."""
    print(f"Features: {num_features}, Samples: {batch_size}")
    print(f"Sequence length: {X_data.shape[1]}")
    print(f"Feature stats:")
    for feat_idx in range(num_features):
        feat_data = X_data[:, :, feat_idx].flatten()
        print(f"  Feature {feat_idx + 1}: mean={feat_data.mean():.3f}, std={feat_data.std():.3f}, range=[{feat_data.min():.3f}, {feat_data.max():.3f}]")
    target_data = y_data.flatten()
    print(f"Target: mean={target_data.mean():.3f}, std={target_data.std():.3f}, range=[{target_data.min():.3f}, {target_data.max():.3f}]")
    print(f"Feature-Target correlations:")
    for feat_idx in range(num_features):
        feat_data = X_data[:, :, feat_idx].flatten()
        correlation = np.corrcoef(feat_data, target_data)[0, 1]
        print(f"  Feature {feat_idx + 1} <-> Target: {correlation:.3f}")

def plot_classification_samples(X_data, y_data, title, batch_size=4, max_classes=5):
    """Plot classification samples as 3-D scatter plots, one panel per sample.

    First prints how many points fall into each class across the whole
    batch. It then draws the first ``batch_size`` samples in the space of
    their first three features. Points are colored by class label, and each
    class centroid is marked with a larger ``X``.

    Args:
        X_data: Feature array indexed as ``X_data[sample, position, feature]``.
            If fewer than three features exist, the missing axes are
            zero-padded so the data can still be drawn in 3-D.
        y_data: Label array indexed as ``y_data[sample, position]``; labels
            are cast to ``int`` to pick a color.
        title: Figure title prefix.
        batch_size: Number of samples to plot; clamped to the number of
            samples present. Panels are laid out two per row.
        max_classes: Not used by this function; kept for backward
            compatibility with existing callers.
    """
    colors = ['#e41a1c', '#377eb8', '#4daf4a', '#ff7f00', '#984ea3']
    marker = 'o'

    all_classes = np.unique(y_data.flatten())
    print("Class distribution:")
    for cls in all_classes:
        count = np.sum(y_data == cls)
        percentage = count / y_data.size * 100
        print(f"  Class {cls}: {count} points ({percentage:.1f}%)")

    # Never index past the samples that actually exist.
    n_samples = min(batch_size, X_data.shape[0])
    if n_samples <= 0:
        return

    # Keep the first three feature columns, zero-padding when fewer exist
    # (previously X_i[mask, 2] raised IndexError for 1- or 2-feature data).
    X_plot = X_data[:n_samples, :, :3]
    missing = 3 - X_plot.shape[2]
    if missing > 0:
        pad = np.zeros(X_plot.shape[:2] + (missing,), dtype=X_plot.dtype)
        X_plot = np.concatenate([X_plot, pad], axis=2)

    # Size the grid to the samples drawn (2x2 for the default of 4) instead
    # of a fixed 2x2 grid that failed for batch_size > 4.
    cols = min(2, n_samples)
    rows = -(-n_samples // cols)  # ceiling division
    fig = plt.figure(figsize=(8 * cols, 7 * rows))
    for i in range(n_samples):
        ax = fig.add_subplot(rows, cols, i + 1, projection='3d')
        X_i = X_plot[i]
        y_i = y_data[i, :].astype(int)
        unique_classes, class_counts = np.unique(y_i, return_counts=True)
        for class_idx, count in zip(unique_classes, class_counts):
            mask = (y_i == class_idx)
            # Colors repeat once labels exceed the palette size.
            color = colors[class_idx % len(colors)]
            ax.scatter(X_i[mask, 0], X_i[mask, 1], X_i[mask, 2],
                       color=color, alpha=0.85, s=80, marker=marker,
                       label=f'C{class_idx} ({count}pts)')
            centroid = X_i[mask].mean(axis=0)
            ax.scatter(centroid[0], centroid[1], centroid[2],
                       color=color, s=120, marker='X', alpha=1.0, linewidths=1)
        dominant_class = unique_classes[np.argmax(class_counts)]
        class_diversity = len(unique_classes)
        ax.set_title(f"Sample {i + 1}\n{class_diversity} classes, Dom: C{dominant_class}", fontsize=13)
        ax.set_xlabel("Feature 1", fontsize=11)
        ax.set_ylabel("Feature 2", fontsize=11)
        ax.set_zlabel("Feature 3", fontsize=11)
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=10, loc='upper right')

    avg_classes_per_sample = np.mean([len(np.unique(y_data[i, :])) for i in range(n_samples)])
    enhanced_title = (f"{title}\n{n_samples} samples, Avg {avg_classes_per_sample:.1f} classes/sample, "
                      f"{len(all_classes)} unique classes overall")
    fig.suptitle(enhanced_title, fontsize=16)
    fig.tight_layout()
    plt.show()

def load_and_print_info(loader_or_path, loader_type="dump", batch_size=20):
    """Fetch the first batch from a loader (or dump file) and print its shapes.

    Args:
        loader_or_path: Either a path to a prior dump file (``str`` or
            ``os.PathLike``, only honored when ``loader_type == "dump"``) or
            any iterable loader yielding dict batches with ``"x"``, ``"y"``
            and ``"single_eval_pos"`` entries.
        loader_type: ``"dump"`` allows a path to be passed; any other value
            requires a loader object.
        batch_size: Batch size for the ``PriorDumpDataLoader`` built from a
            path. Ignored when a loader object is passed. Defaults to 20,
            the previously hard-coded value.

    Returns:
        Tuple ``(X_data, y_data, batch)``: NumPy copies of ``batch["x"]`` and
        ``batch["y"]`` (moved to CPU) and the raw batch dict.

    Raises:
        ValueError: If the loader yields no batches.
    """
    import os

    if loader_type == "dump" and isinstance(loader_or_path, (str, os.PathLike)):
        loader = PriorDumpDataLoader(filename=os.fspath(loader_or_path), num_steps=1,
                                     batch_size=batch_size, device='cpu')
    else:
        loader = loader_or_path
    try:
        batch = next(iter(loader))
    except StopIteration:
        # A bare StopIteration is confusing (and misbehaves inside generators).
        raise ValueError("Loader yielded no batches") from None
    X_data = batch["x"].cpu().numpy()
    y_data = batch["y"].cpu().numpy()
    print(f"Data shape: X={X_data.shape}, y={y_data.shape}")
    print(f"Single eval pos: {batch['single_eval_pos']}")
    return X_data, y_data, batch

## 1. TICL MLP Prior for Regression

This section loads regression data generated from MLP (Multi-Layer Perceptron) priors. The printed statistics and plots show:
- How each feature relates to the target variable
- How data patterns evolve across a sequence for multiple samples
- Statistical relationships between features and the target

You can create the data file for this section with:
```bash
python -m tabularpriors --lib ticl --prior_type mlp --num_batches 1 --batch_size 4 --max_features 3 --max_seq_len 25
```

In [None]:
# Section 1: read the pre-generated TICL MLP regression dump, print its
# shapes, and plot its first 4 samples.
path_to_dump = "prior_ticl_mlp_1x4_25x3.h5"
X_all, y_all, batch = load_and_print_info(path_to_dump, loader_type="dump")
plot_regression_samples(
    X_all,
    y_all,
    "TICL MLP Prior Samples from Dump (Regression)",
    "blue",
    batch_size=4,
)

## 2. TICL GP Prior for Regression

This section loads regression data generated from Gaussian Process priors. The printed statistics and plots show:
- How each feature relates to the target variable
- How data patterns evolve across a sequence for multiple samples
- Statistical relationships between features and the target

You can create the data file for this section with:
```bash
python -m tabularpriors --lib ticl --prior_type gp --num_batches 1 --batch_size 4 --max_features 3 --max_seq_len 25
```

In [None]:
# Section 2: read the pre-generated TICL Gaussian-process regression dump,
# print its shapes, and plot its first 4 samples.
path_to_gp_dump = "prior_ticl_gp_1x4_25x3.h5"
X_gp, y_gp, gp_batch = load_and_print_info(path_to_gp_dump, loader_type="dump")
plot_regression_samples(
    X_gp,
    y_gp,
    "TICL GP Prior Samples from Dump (Regression)",
    "blue",
    batch_size=4,
)

## 3. TabICL Prior for Classification

This section loads classification data from TabICL priors. The visualizations show:
- Class distributions in 3D feature space
- Class centroids (marked with a large `X`)
- Distribution of points across different classes

You can create the data file for this section with:
```bash
python -m tabularpriors --lib tabicl --num_batches 1 --batch_size 4 --max_features 3 --max_seq_len 25 --max_classes 5
```

In [None]:
# Section 3: read the pre-generated TabICL classification dump, list the
# labels it contains, and plot its first 4 samples in 3-D feature space.
path_to_tabicl_dump = "prior_tabicl_1x4_25x3.h5"
X_tabicl_dump, y_tabicl_dump, tabicl_dump_batch = load_and_print_info(
    path_to_tabicl_dump, loader_type="dump"
)
print(f"TabICL unique classes: {np.unique(y_tabicl_dump)}")
plot_classification_samples(
    X_tabicl_dump,
    y_tabicl_dump,
    "TabICL Prior Samples from Dump (Classification)",
    batch_size=4,
)


## 4. Live Data Generation (Real-time Synthesis)

This section demonstrates generating synthetic data on-the-fly instead of loading from pre-generated files. 
This approach is useful for:
- Experimenting with different prior configurations
- Testing parameter sensitivity
- Generating data with specific random seeds
- Immediate visualization without creating intermediate files

### 4.1 Live TICL MLP Regression Data

In [None]:
# Section 4.1: generate single-feature TICL MLP regression data on the fly
# (20 datasets of up to 50 points on CPU) and plot every sample.
from tabularpriors.dataloader import TICLPriorDataLoader
from tabularpriors.utils import build_ticl_prior
import torch

device = torch.device('cpu')

print("=== Testing Live TICL MLP Loader ===")
ticl_mlp_loader = TICLPriorDataLoader(
    prior=build_ticl_prior('mlp'),
    num_steps=1,
    batch_size=20,
    num_datapoints_max=50,
    num_features=1,
    device=device,
    min_eval_pos=10,
)

X_ticl_mlp, y_ticl_mlp, ticl_batch = load_and_print_info(ticl_mlp_loader, loader_type="live")
plot_regression_samples(
    X_ticl_mlp,
    y_ticl_mlp,
    "Live TICL MLP Prior Samples (Real-time Regression)",
    "green",
)

### 4.2 Live TICL GP Regression Data

In [None]:
# Section 4.2: generate single-feature TICL Gaussian-process regression data
# on the fly (reusing `device` from the previous cell) and plot every sample.
print("=== Testing Live TICL GP Loader ===")

ticl_gp_loader = TICLPriorDataLoader(
    prior=build_ticl_prior('gp'),
    num_steps=1,
    batch_size=20,
    num_datapoints_max=50,
    num_features=1,
    device=device,
    min_eval_pos=10,
)

X_ticl_gp, y_ticl_gp, ticl_gp_batch = load_and_print_info(ticl_gp_loader, loader_type="live")
plot_regression_samples(
    X_ticl_gp,
    y_ticl_gp,
    "Live TICL GP Prior Samples (Real-time Regression)",
    "darkgreen",
)

### 4.3 Live TabICL Classification Data

In [None]:
# Section 4.3: generate 3-feature TabICL classification data on the fly
# (up to 5 classes) and plot the first 4 samples in 3-D.
from tabularpriors.dataloader import TabICLPriorDataLoader

print("=== Testing Live TabICL Loader ===")

tabicl_live_loader = TabICLPriorDataLoader(
    num_steps=1,
    batch_size=20,
    num_datapoints_max=50,
    num_features=3,
    max_num_classes=5,
    device=device,
)

X_tabicl_live, y_tabicl_live, tabicl_live_batch = load_and_print_info(
    tabicl_live_loader, loader_type="live"
)
plot_classification_samples(
    X_tabicl_live,
    y_tabicl_live,
    "Live TabICL Prior Samples (Real-time Classification)",
    batch_size=4,
)
