# Dis-AE 2: Realistic Synthetic Data Example with data_B.csv

This notebook demonstrates using Dis-AE 2 on the **data_B.csv** synthetic dataset, which simulates real-world challenges:
- **Multiple domain factors (modelled after domain shifts found in complete blood count (CBC) data)**: machine, venepuncture delay and study time
- **Task**: Binary classification with CBC-like features

## Goals

1. Train a domain-invariant model for classification
2. Evaluate task performance across different domains
3. Verify domain invariance in learned features
4. Compare with a baseline method without domain adaptation (logistic regression)

In [None]:
# Import required libraries: core scientific stack, plotting, ML utilities, project code
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import scienceplots  # registers the "science" / "nature" / "scatter" matplotlib styles

from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.manifold import TSNE
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder

from sysmexcbctools.disae2.disae2 import DisAE, normalize_data

# Seed numpy's global (legacy) random number generator
np.random.seed(42)

# Embed fonts as TrueType (Type 42) in exported PDF / PostScript figures
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42

# Scientific plot style (provided by scienceplots)
plt.style.use(["science", "nature"])

# Colourblind-friendly palette reused by every figure in this notebook
SEABORN_PALETTE = "colorblind"
seaborn_colors = sns.color_palette(SEABORN_PALETTE)

## 1. Load and Explore Data

In [None]:
# Read data_B.csv from the repository's data directory into a DataFrame
data_path = "../data/data_B.csv"
df = pd.read_csv(filepath_or_buffer=data_path)

# Report the table size, then preview the first rows (last expression is displayed)
print("Dataset Shape:", df.shape)
print("\nFirst few rows:")
df.head()

In [None]:
# Dataset summary: size, task-label balance, and cardinality of each domain factor
class_distribution = df["ClassCategory_0"].value_counts().sort_index()
n_machine_levels = df["Machine"].nunique()
n_vendelay_bins = df["vendelay_binned"].nunique()
n_studytime_bins = df["studytime_binned"].nunique()

print("Dataset Summary:")
print("=" * 60)
print(f"Total samples: {len(df):,}")
print(f"Features: 32 (columns 0-31)")
print(f"\nTask label: ClassCategory_0")
print(f"  Class distribution:")
print(class_distribution)
print(f"\nDomain factors:")
print(f"  Machine: {n_machine_levels} levels")
print(f"  vendelay_binned: {n_vendelay_bins} bins")
print(f"  studytime_binned: {n_studytime_bins} bins")

### Visualise Domain Distributions

In [None]:
# Plot domain factor distributions: one bar chart of level counts per domain factor
fig, axes = plt.subplots(1, 3, figsize=(6.6, 2.2))

# (column, bar colour, panel title, x-axis label) for each panel, left to right
domain_panel_specs = [
    ("Machine", "steelblue", "Machine Distribution", "Machine ID"),
    ("vendelay_binned", "coral", "Venepuncture Delay Distribution", "Delay Bin"),
    ("studytime_binned", "mediumseagreen", "Study Time Distribution", "Time Bin"),
]

for panel_ax, (panel_column, bar_colour, panel_title, panel_xlabel) in zip(
    axes, domain_panel_specs
):
    level_counts = df[panel_column].value_counts().sort_index()
    level_counts.plot(kind="bar", ax=panel_ax, color=bar_colour)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel(panel_xlabel)
    panel_ax.set_ylabel("Count")
    panel_ax.tick_params(axis="x", rotation=0)

plt.tight_layout()
plt.show()

## 2. Prepare Data

In [None]:
# Extract features and labels
# The 32 feature columns are simply named "0" to "31"
feature_columns = list(map(str, range(32)))
X = df[feature_columns].values

# Task label (binary classification), kept as an (n_samples, 1) column vector
y_task = df["ClassCategory_0"].values.reshape(-1, 1)

# Domain factors
domain_columns = ["Machine", "vendelay_binned", "studytime_binned"]
y_domains_raw = df[domain_columns]

# Integer-encode each domain factor; domain_mappings records original value -> code
domain_mappings = {}
encoded_domain_columns = []
for col in domain_columns:
    encoder = LabelEncoder().fit(y_domains_raw[col])
    encoded_domain_columns.append(encoder.transform(y_domains_raw[col]))
    domain_mappings[col] = dict(
        zip(encoder.classes_, encoder.transform(encoder.classes_))
    )
# One column per domain factor, in the order of domain_columns
y_domains = np.column_stack(encoded_domain_columns)

print("Data shapes:")
print(f"  X: {X.shape}")
print(f"  y_task: {y_task.shape}")
print(f"  y_domains: {y_domains.shape}")
print(
    f"\nDomain cardinalities: {[len(np.unique(y_domains[:, i])) for i in range(y_domains.shape[1])]}"
)

In [None]:
# Normalise features: z-score every feature column.
# NOTE(review): the scaler is fit on the full dataset X (all machines, including the
# target machines held out for testing); a strict domain-generalisation evaluation
# should estimate these statistics on the training split only.
scaler = StandardScaler()
X_normalized = scaler.fit_transform(X)

# Sanity check: overall mean should be ~0 and overall std ~1
overall_mean = X_normalized.mean()
overall_std = X_normalized.std()
overall_min = X_normalized.min()
overall_max = X_normalized.max()

print("Feature statistics after normalisation:")
print(f"  Mean: {overall_mean:.6f}")
print(f"  Std: {overall_std:.6f}")
print(f"  Min: {overall_min:.4f}")
print(f"  Max: {overall_max:.4f}")

In [None]:
# Domain Generalization Split Strategy:
# - Train/Val: Machines 0 and 1 (source domains)
# - Test: Machines 2, 3, 4 (target domains) + small portion of 0, 1 (within-source)
#
# The split is made on the raw features `X`. Features are standardised afterwards
# with statistics estimated on the training split only. Fitting the scaler on the
# whole dataset would leak information about the held-out source test rows and,
# more importantly, about the unseen target machines into preprocessing.

# Split by machine ID (integer codes produced by the LabelEncoder)
machine_ids = y_domains[:, 0]

# Source domains (train + val): machines 0 and 1
source_mask = np.isin(machine_ids, [0, 1])
X_source = X[source_mask]
y_task_source = y_task[source_mask]
y_domains_source = y_domains[source_mask]

# Target domains (test): machines 2, 3, 4
target_mask = np.isin(machine_ids, [2, 3, 4])
X_target = X[target_mask]
y_task_target = y_task[target_mask]
y_domains_target = y_domains[target_mask]

# Also create a small within-source test set (10% of source data)
(
    X_source_train_val,
    X_source_test,
    y_task_source_train_val,
    y_task_source_test,
    y_domains_source_train_val,
    y_domains_source_test,
) = train_test_split(
    X_source,
    y_task_source,
    y_domains_source,
    test_size=0.1,
    stratify=y_domains_source[:, 0],  # Stratify by machine within source
    random_state=42,
)

# Split source into train and val (80/20 of remaining source data)
X_train, X_val, y_task_train, y_task_val, y_domains_train, y_domains_val = (
    train_test_split(
        X_source_train_val,
        y_task_source_train_val,
        y_domains_source_train_val,
        test_size=0.2,
        stratify=y_domains_source_train_val[:, 0],
        random_state=42,
    )
)

# Standardise every split with statistics from the training split only
train_scaler = StandardScaler().fit(X_train)
X_train = train_scaler.transform(X_train)
X_val = train_scaler.transform(X_val)

# Within-source and target test sets
X_test_within_source = train_scaler.transform(X_source_test)
y_task_test_within_source = y_task_source_test
y_domains_test_within_source = y_domains_source_test

X_test_target = train_scaler.transform(X_target)
y_task_test_target = y_task_target
y_domains_test_target = y_domains_target

# Combined test set (within-source rows first, then target rows)
X_test = np.vstack([X_test_within_source, X_test_target])
y_task_test = np.vstack([y_task_test_within_source, y_task_test_target])
y_domains_test = np.vstack([y_domains_test_within_source, y_domains_test_target])

print("Domain Generalization Split Strategy:")
print("=" * 80)
print(f"Source domains (machines 0, 1): {source_mask.sum():,} samples")
print(f"Target domains (machines 2, 3, 4): {target_mask.sum():,} samples")
print(f"\nDataset splits:")
print(
    f"  Train:                 {X_train.shape[0]:,} samples ({X_train.shape[0] / len(X) * 100:.1f}%) - Machines 0, 1"
)
print(
    f"  Val:                   {X_val.shape[0]:,} samples ({X_val.shape[0] / len(X) * 100:.1f}%) - Machines 0, 1"
)
print(
    f"  Test (within-source):  {X_test_within_source.shape[0]:,} samples ({X_test_within_source.shape[0] / len(X) * 100:.1f}%) - Machines 0, 1"
)
print(
    f"  Test (target):         {X_test_target.shape[0]:,} samples ({X_test_target.shape[0] / len(X) * 100:.1f}%) - Machines 2, 3, 4"
)
print(
    f"  Test (combined):       {X_test.shape[0]:,} samples ({X_test.shape[0] / len(X) * 100:.1f}%)"
)

# Verify machine distribution. Keys are the actual machine IDs. Re-enumerating the
# non-zero counts would relabel e.g. machines 2, 3, 4 as 0, 1, 2.
print(f"\nMachine distribution in splits:")
for split_name, split_domains in [
    ("Train", y_domains_train),
    ("Val", y_domains_val),
    ("Test (within-source)", y_domains_test_within_source),
    ("Test (target)", y_domains_test_target),
]:
    machine_counts = np.bincount(split_domains[:, 0])
    present_machine_counts = {
        int(machine_id): int(count)
        for machine_id, count in enumerate(machine_counts)
        if count > 0
    }
    print(f"  {split_name:20s}: {present_machine_counts}")

## 3. Baseline Model (No Domain Adaptation)

First, we train a simple logistic regression on the source-domain training data as a baseline (no domain adaptation) to compare against Dis-AE 2.

In [None]:
# Baseline: plain logistic regression (no domain adaptation), trained on the
# source-domain training split only
baseline_model = LogisticRegression(max_iter=1000, random_state=42).fit(
    X_train, y_task_train.ravel()
)

# Evaluate on the combined (within-source + target) test set
y_pred_baseline = baseline_model.predict(X_test)
baseline_acc = accuracy_score(y_task_test, y_pred_baseline)
baseline_report = classification_report(y_task_test, y_pred_baseline)

print("Baseline Model (Logistic Regression):")
print("=" * 60)
print(f"Overall test accuracy: {baseline_acc:.4f}")
print(f"\nClassification Report:")
print(baseline_report)

In [None]:
# Per-domain baseline performance: test accuracy restricted to each level of each
# domain factor
print("\nPer-domain baseline accuracy:")
print("=" * 60)
for domain_idx, domain_name in enumerate(domain_columns):
    print(f"\n{domain_name}:")
    factor_labels = y_domains_test[:, domain_idx]
    for val in np.unique(factor_labels):
        mask = factor_labels == val
        if not mask.any():
            continue
        domain_acc = accuracy_score(y_task_test[mask], y_pred_baseline[mask])
        print(f"  Level {val}: {domain_acc:.4f} ({mask.sum()} samples)")

## 4. Train Dis-AE 2 Model

In [None]:
# Initialize Dis-AE 2
# Size of each domain-discriminator head, one entry per domain factor.
# The count is (largest encoded label + 1) over the training AND validation labels,
# rather than the number of distinct training labels. This keeps every label passed
# to `fit` a valid class index even if some level is missing from the training split.
# Presumably DisAE expects labels in [0, k) -- TODO confirm against its docs.
# With the source/target split used here, the Machine factor only covers the two
# source machines (0 and 1).
num_domains = [
    int(max(y_domains_train[:, i].max(), y_domains_val[:, i].max())) + 1
    for i in range(y_domains_train.shape[1])
]

model = DisAE(
    input_dim=X_train.shape[1],  # number of input features (32 here)
    latent_dim=16,
    shared_dim=8,  # shared_dim + private_dim == latent_dim
    private_dim=8,
    num_tasks=[2],  # Binary classification
    num_domains=num_domains,  # one entry per domain factor (see above)
    hidden_dims=[64, 32],
    reconstruction_weight=0.1,
    adversarial_weight=1.0,
    orthogonality_weight=0.1,
    learning_rate=0.01,
    batch_size=256,
    device="cpu",  # Change to 'cuda' if GPU available
    random_state=42,
)

print("Dis-AE 2 Model Configuration:")
print("=" * 60)
print(f"Input dimension: {model.input_dim}")
print(f"Latent dimension: {model.latent_dim}")
print(f"  - Shared: {model.shared_dim}")
print(f"  - Private: {model.private_dim}")
print(f"Tasks: {model.num_tasks}")
print(f"Domain factors: {model.num_domains}")
print(f"Batch size: {model.batch_size}")

In [None]:
# Train Dis-AE 2 on the source-domain training split. The validation split is also
# passed to `fit`, presumably for the early-stopping criterion.
print("\nTraining Dis-AE 2...\n")
validation_kwargs = dict(
    X_val=X_val,
    y_tasks_val=y_task_val,
    y_domains_val=y_domains_val,
)
model.fit(
    X_train,
    y_task_train,
    y_domains_train,
    max_epochs=100,
    early_stopping_patience=16,
    verbose=True,
    **validation_kwargs,
)

## 5. Evaluate Dis-AE 2 Performance

In [None]:
# Overall test accuracy of Dis-AE 2 on the combined test set, compared with the baseline
task_predictions = model.predict_tasks(X_test)
y_pred_disae = task_predictions[0]  # predictions for the first (and only) task
disae_acc = accuracy_score(y_task_test, y_pred_disae)
accuracy_delta = disae_acc - baseline_acc

print("Dis-AE 2 Test Performance:")
print("=" * 60)
print(f"Overall test accuracy: {disae_acc:.4f}")
print(f"Baseline accuracy: {baseline_acc:.4f}")
print(
    f"Improvement: {accuracy_delta:.4f} ({accuracy_delta / baseline_acc * 100:+.2f}%)"
)
print(f"\nClassification Report:")
print(classification_report(y_task_test, y_pred_disae))

In [None]:
# Confusion matrix comparison: baseline (left) vs Dis-AE 2 (right) on the same test set
fig, axes = plt.subplots(1, 2, figsize=(4.5, 2.2))

cm_baseline = confusion_matrix(y_task_test, y_pred_baseline)
cm_disae = confusion_matrix(y_task_test, y_pred_disae)

# (axis, matrix, colormap, title) per panel
confusion_panels = [
    (axes[0], cm_baseline, "Blues", f"Baseline Confusion Matrix\n(Acc: {baseline_acc:.4f})"),
    (axes[1], cm_disae, "Greens", f"Dis-AE 2 Confusion Matrix\n(Acc: {disae_acc:.4f})"),
]
for panel_ax, panel_cm, panel_cmap, panel_title in confusion_panels:
    sns.heatmap(panel_cm, annot=True, fmt="d", cmap=panel_cmap, ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel("Predicted")
    panel_ax.set_ylabel("True")

plt.tight_layout()
plt.show()

In [None]:
# Within-Source vs Target Domain Performance
# "Domain gap" = accuracy on held-out source machines minus accuracy on the unseen
# target machines. A smaller gap indicates better generalisation to new machines.
print("\nDomain Generalization Analysis:")
print("=" * 80)

# Baseline performance
y_pred_baseline_within = baseline_model.predict(X_test_within_source)
y_pred_baseline_target = baseline_model.predict(X_test_target)

baseline_acc_within = accuracy_score(y_task_test_within_source, y_pred_baseline_within)
baseline_acc_target = accuracy_score(y_task_test_target, y_pred_baseline_target)

# Dis-AE 2 performance
y_pred_disae_within = model.predict_tasks(X_test_within_source)[0]
y_pred_disae_target = model.predict_tasks(X_test_target)[0]

disae_acc_within = accuracy_score(y_task_test_within_source, y_pred_disae_within)
disae_acc_target = accuracy_score(y_task_test_target, y_pred_disae_target)

baseline_gap = baseline_acc_within - baseline_acc_target
disae_gap = disae_acc_within - disae_acc_target
gap_reduction = baseline_gap - disae_gap

print("\nBaseline (Logistic Regression):")
print(f"  Within-source (machines 0, 1): {baseline_acc_within:.4f}")
print(f"  Target domains (machines 2-4):  {baseline_acc_target:.4f}")
print(f"  Domain gap:                     {baseline_gap:.4f}")

print("\nDis-AE 2:")
print(f"  Within-source (machines 0, 1): {disae_acc_within:.4f}")
print(f"  Target domains (machines 2-4):  {disae_acc_target:.4f}")
print(f"  Domain gap:                     {disae_gap:.4f}")

print("\nDomain Gap Reduction:")
print(f"  Baseline domain gap:  {baseline_gap:.4f}")
print(f"  Dis-AE 2 domain gap:  {disae_gap:.4f}")
# A relative reduction only makes sense when the baseline has a positive gap.
# A zero gap would divide by zero, and a negative gap would flip the sign of the
# percentage.
if baseline_gap > 0:
    relative_reduction = f" ({gap_reduction / baseline_gap * 100:.1f}%)"
else:
    relative_reduction = " (relative reduction not defined: baseline gap <= 0)"
print(f"  Gap reduction:        {gap_reduction:.4f}{relative_reduction}")
print(f"\n✓ Smaller domain gap indicates better domain generalization")

### Per-Domain Performance Comparison

In [None]:
# Compare baseline vs Dis-AE 2 per domain: one record per (domain factor, level)
results = []

for domain_idx, domain_name in enumerate(domain_columns):
    factor_labels = y_domains_test[:, domain_idx]
    for val in np.unique(factor_labels):
        mask = factor_labels == val
        if not mask.any():
            continue
        baseline_domain_acc = accuracy_score(y_task_test[mask], y_pred_baseline[mask])
        disae_domain_acc = accuracy_score(y_task_test[mask], y_pred_disae[mask])
        results.append(
            dict(
                Domain=domain_name,
                Level=val,
                Baseline=baseline_domain_acc,
                DisAE2=disae_domain_acc,
                Improvement=disae_domain_acc - baseline_domain_acc,
                Samples=mask.sum(),
            )
        )

results_df = pd.DataFrame(results)
print("\nPer-Domain Accuracy Comparison:")
print("=" * 80)
print(results_df.to_string(index=False))

In [None]:
# Visualise per-domain performance: grouped bars (baseline vs Dis-AE 2) for every
# level of each domain factor, one panel per factor
fig, axes = plt.subplots(1, 3, figsize=(6.6, 2.2))
bar_width = 0.35

for panel_ax, domain_name in zip(axes, domain_columns):
    domain_data = results_df[results_df["Domain"] == domain_name]
    x_positions = np.arange(len(domain_data))

    panel_ax.bar(
        x_positions - bar_width / 2,
        domain_data["Baseline"],
        bar_width,
        label="Baseline",
        alpha=0.8,
    )
    panel_ax.bar(
        x_positions + bar_width / 2,
        domain_data["DisAE2"],
        bar_width,
        label="Dis-AE 2",
        alpha=0.8,
    )

    panel_ax.set_xlabel("Level")
    panel_ax.set_ylabel("Accuracy")
    panel_ax.set_title(f"{domain_name}")
    panel_ax.set_xticks(x_positions)
    panel_ax.set_xticklabels(domain_data["Level"])
    panel_ax.legend()
    # panel_ax.set_ylim([0.48, 1.0])  # Adjust as needed
    panel_ax.grid(axis="y", alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nMean improvement across all domains: {results_df['Improvement'].mean():.4f}")
print(f"Std of accuracy (Baseline): {results_df['Baseline'].std():.4f}")
print(f"Std of accuracy (Dis-AE 2): {results_df['DisAE2'].std():.4f}")

# Interpretation of the results above
for remark in (
    "\nWe can see that Dis-AE 2 learns the type of affine shift present on domain instances 0 and 1 (which is only slightly varied in 2 and 3).",
    "However, instance 4 is shifted on different features and hence Dis-AE 2 struggles to adapt - even to the point of decreased performance.",
    "This is typical behaviour in domain generalisation methods, where completely new types of domain shifts can be very challenging.",
):
    print(remark)

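From here onwards, we exclude Machine instance 4 from the test set. Its shift affects different features from those seen during training (see above), so removing it lets the remaining analyses focus on the kinds of domain shift the model can be expected to handle.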

In [None]:
# Drop Machine instance 4 from the combined test set. Features, task labels and domain
# labels are filtered with the same mask so they stay row-aligned.
# NOTE: this rebinds X_test / y_task_test / y_domains_test. Cells above that pair
# X_test with y_pred_baseline or y_pred_disae must not be re-run after this one.
mask_no_machine4 = y_domains_test[:, 0] != 4
X_test, y_domains_test, y_task_test = (
    X_test[mask_no_machine4],
    y_domains_test[mask_no_machine4],
    y_task_test[mask_no_machine4],
)

## 6. Domain Invariance Verification

Check whether the domain discriminators can predict each domain factor from the shared features. For a domain-invariant representation, their accuracy should be close to chance.

In [None]:
# Domain prediction accuracy of the discriminators on the test set
# (lower = better domain invariance).
# NOTE(review): "Random chance" is 1/K for a uniform guess over the K classes of each
# discriminator head. model.num_domains was derived from the source-domain training
# labels, so for Machine it counts only the training machines, while this test set
# also contains unseen machines. Read the ratio as indicative only.
d_pred = model.predict_domains(X_test)

print("Domain Prediction Accuracy (Lower = Better Domain Invariance):")
print("=" * 60)
for i, (domain_name, n_domain_classes) in enumerate(
    zip(domain_columns, model.num_domains)
):
    domain_acc = accuracy_score(y_domains_test[:, i], d_pred[i])
    random_chance = 1.0 / n_domain_classes
    near_chance = domain_acc <= random_chance * 1.3  # within 30% of chance level

    print(f"\n{domain_name}:")
    print(f"  Accuracy: {domain_acc:.4f}")
    print(f"  Random chance: {random_chance:.4f}")
    print(f"  Ratio: {domain_acc / random_chance:.2f}x random")
    print(
        f"  ✓ Good domain invariance!"
        if near_chance
        else f"  ⚠ Some domain information may be present"
    )

## 7. Feature Analysis

In [None]:
# Extract shared and private embeddings for the training and test sets
shared_train, private_train = model.embed(X_train, private=True)
shared_test, private_test = model.embed(X_test, private=True)

print("Embedding Statistics:")
print("=" * 60)
print(f"Shared features: {shared_test.shape}")
print(f"Private features: {private_test.shape}")
# Summary statistics over all entries of each test embedding
for section_header, embedding in (
    (f"\nShared features (test):", shared_test),
    (f"\nPrivate features (test):", private_test),
):
    print(section_header)
    print(f"  Mean: {embedding.mean():.4f}")
    print(f"  Std: {embedding.std():.4f}")

### Visualise Shared Features (Task-Relevant, Domain-Invariant)

In [None]:
# Filter out machine 4 and subsample to at most 1000 samples per machine for visualisation
MAX_POINTS_PER_MACHINE = 1000
mask_no_machine4 = y_domains_test[:, 0] != 4

# Subsample each machine independently (seeded so the selection is reproducible)
np.random.seed(42)
indices_to_keep = []
for machine_id in np.unique(y_domains_test[mask_no_machine4, 0]):
    machine_indices = np.where(
        (y_domains_test[:, 0] == machine_id) & mask_no_machine4
    )[0]
    if len(machine_indices) > MAX_POINTS_PER_MACHINE:
        machine_indices = np.random.choice(
            machine_indices, size=MAX_POINTS_PER_MACHINE, replace=False
        )
    indices_to_keep.append(machine_indices)

indices_to_keep = np.sort(np.concatenate(indices_to_keep))  # sorted for consistency

shared_test_filtered = shared_test[indices_to_keep]
y_task_test_filtered = y_task_test[indices_to_keep]
y_domains_test_filtered = y_domains_test[indices_to_keep]

print(f"Filtered visualization data: {len(indices_to_keep)} samples")
print(f"Samples per machine: {np.bincount(y_domains_test_filtered[:, 0].astype(int))}")

# we can use PCA or PHATE
USE_PHATE = False

if USE_PHATE:
    import phate

    phate_op = phate.PHATE(n_components=2, random_state=42)
    pca_shared = phate_op.fit_transform(shared_test_filtered)
else:
    # PCA on shared features
    pca_shared = PCA(n_components=2).fit_transform(shared_test_filtered)

# Set axis labels based on method
xlabel = "PHATE1" if USE_PHATE else "PC1"
ylabel = "PHATE2" if USE_PHATE else "PC2"

# Categorical colormaps built from seaborn_colors: label k is drawn with colors[k]
from matplotlib.colors import ListedColormap

n_machine_levels = int(y_domains_test_filtered[:, 0].max()) + 1
task_colors = seaborn_colors[:2]  # First two colors for binary task
# Next colors for machines, one per machine ID that can appear in the plot
machine_colors = seaborn_colors[2 : 2 + n_machine_levels]
task_cmap = ListedColormap(task_colors)
machine_cmap = ListedColormap(machine_colors)

with plt.style.context(["science", "nature", "scatter"]):
    fig, axes = plt.subplots(1, 2, figsize=(4.5, 2.2))

    # vmin/vmax are pinned so integer label k maps exactly to colors[k], whatever
    # subset of labels is present. Otherwise the default data-range normalisation
    # stretches the labels over the whole colormap.

    # Coloured by task
    scatter1 = axes[0].scatter(
        pca_shared[:, 0],
        pca_shared[:, 1],
        c=y_task_test_filtered.ravel(),
        cmap=task_cmap,
        vmin=0,
        vmax=len(task_colors) - 1,
        s=5,
        alpha=0.2,
    )
    axes[0].set_title("Shared Features - Coloured by Task")
    axes[0].set_xlabel(xlabel)
    axes[0].set_ylabel(ylabel)
    plt.colorbar(scatter1, ax=axes[0], label="Task Label")

    # Coloured by domain (Machine)
    scatter2 = axes[1].scatter(
        pca_shared[:, 0],
        pca_shared[:, 1],
        c=y_domains_test_filtered[:, 0],
        cmap=machine_cmap,
        vmin=0,
        vmax=len(machine_colors) - 1,
        s=5,
        alpha=0.2,
    )
    axes[1].set_title("Shared Features - Coloured by Machine")
    axes[1].set_xlabel(xlabel)
    axes[1].set_ylabel(ylabel)
    plt.colorbar(scatter2, ax=axes[1], label="Machine ID")

    # Expectation: shared features show task separation (left) but NOT domain
    # separation (right)
    plt.tight_layout()
    plt.show()

### Visualise Private Features (Domain-Specific)

In [None]:
# Filter private features using the same (subsampled, machine-4-free) indices as the
# shared-feature plot
private_test_filtered = private_test[indices_to_keep]

# we can use PCA or PHATE (same switch as for the shared features)
if USE_PHATE:
    import phate

    phate_op = phate.PHATE(n_components=2, random_state=42)
    pca_private = phate_op.fit_transform(private_test_filtered)
else:
    # PCA on private features
    pca_private = PCA(n_components=2).fit_transform(private_test_filtered)

with plt.style.context(["science", "nature", "scatter"]):
    fig, axes = plt.subplots(1, 2, figsize=(4.5, 2.2))

    # vmin/vmax are pinned so integer label k is drawn with task_colors[k] /
    # machine_colors[k]. This keeps each task class and machine the same colour
    # in every figure, regardless of which labels are present.

    # Coloured by task
    scatter1 = axes[0].scatter(
        pca_private[:, 0],
        pca_private[:, 1],
        c=y_task_test_filtered.ravel(),
        cmap=task_cmap,
        vmin=0,
        vmax=len(task_colors) - 1,
        s=5,
        alpha=0.2,
    )
    axes[0].set_title("Private Features - Coloured by Task")
    axes[0].set_xlabel(xlabel)
    axes[0].set_ylabel(ylabel)
    plt.colorbar(scatter1, ax=axes[0], label="Task Label")

    # Coloured by domain (Machine)
    scatter2 = axes[1].scatter(
        pca_private[:, 0],
        pca_private[:, 1],
        c=y_domains_test_filtered[:, 0],
        cmap=machine_cmap,
        vmin=0,
        vmax=len(machine_colors) - 1,
        s=5,
        alpha=0.2,
    )
    axes[1].set_title("Private Features - Coloured by Machine")
    axes[1].set_xlabel(xlabel)
    axes[1].set_ylabel(ylabel)
    plt.colorbar(scatter2, ax=axes[1], label="Machine ID")

    plt.tight_layout()
    plt.show()

# Private features can contain domain information needed for reconstruction

In [None]:
# Combined visualization: Shared and Private features (without machine 4)
from pathlib import Path

from matplotlib import gridspec

OUTPUT_FIGURE = Path("../outputs/disae2_dataB_shared_private_pca.pdf")

with plt.style.context(["science", "nature", "scatter"]):
    # Calculate figure size: 2 columns of 2.2 each, with some spacing
    fig_width = 2 * 2.2 + 0.3  # Extra space for spacing and labels
    # Calculate height: 2 rows of 2.2 + legend row (taller for stacked legend)
    fig_height = 2 * 2.2 + 1.2  # Extra space for legend row

    fig = plt.figure(figsize=(fig_width, fig_height))

    # Create a 3x2 grid with GridSpec for custom layout
    gs = gridspec.GridSpec(
        3, 2, figure=fig, height_ratios=[2.2, 2.2, 1.2], hspace=0.15, wspace=0.2
    )

    # Row 0: Shared features
    ax00 = fig.add_subplot(gs[0, 0])
    ax01 = fig.add_subplot(gs[0, 1])

    # Row 1: Private features
    ax10 = fig.add_subplot(gs[1, 0])
    ax11 = fig.add_subplot(gs[1, 1])

    # Row 2: Legends (invisible axes)
    ax20 = fig.add_subplot(gs[2, 0])
    ax21 = fig.add_subplot(gs[2, 1])

    # Pin the colour normalisation so integer label k is drawn with task_colors[k] /
    # machine_colors[k], the same colours the legends below use. Otherwise matplotlib
    # stretches the labels present (e.g. machines 0-3) across the whole colormap and
    # the legend colours no longer match the points.
    task_color_kwargs = dict(cmap=task_cmap, vmin=0, vmax=len(task_colors) - 1)
    machine_color_kwargs = dict(
        cmap=machine_cmap, vmin=0, vmax=len(machine_colors) - 1
    )
    point_kwargs = dict(s=5, alpha=0.2, rasterized=True)

    # Plot shared features - colored by task
    ax00.scatter(
        pca_shared[:, 0],
        pca_shared[:, 1],
        c=y_task_test_filtered.ravel(),
        **task_color_kwargs,
        **point_kwargs,
    )
    ax00.set_ylabel(ylabel)
    ax00.set_xticklabels([])  # Remove x-tick labels

    # Plot shared features - colored by machine
    ax01.scatter(
        pca_shared[:, 0],
        pca_shared[:, 1],
        c=y_domains_test_filtered[:, 0],
        **machine_color_kwargs,
        **point_kwargs,
    )
    ax01.set_xticklabels([])  # Remove x-tick labels
    ax01.set_yticklabels([])  # Remove y-tick labels

    # Plot private features - colored by task
    ax10.scatter(
        pca_private[:, 0],
        pca_private[:, 1],
        c=y_task_test_filtered.ravel(),
        **task_color_kwargs,
        **point_kwargs,
    )
    ax10.set_xlabel(xlabel)
    ax10.set_ylabel(ylabel)

    # Plot private features - colored by machine
    ax11.scatter(
        pca_private[:, 0],
        pca_private[:, 1],
        c=y_domains_test_filtered[:, 0],
        **machine_color_kwargs,
        **point_kwargs,
    )
    ax11.set_xlabel(xlabel)
    ax11.set_yticklabels([])  # Remove y-tick labels

    # Row labels on the right side, vertically centred on each plot row
    fig.text(
        0.92, 0.73, "Shared features", va="center", ha="left", rotation=90, fontsize=10
    )
    fig.text(
        0.92, 0.43, "Private features", va="center", ha="left", rotation=90, fontsize=10
    )

    # Create custom legends in row 2
    # Left legend: Task (vertical layout to fit width constraint)
    ax20.axis("off")
    task_labels = np.unique(y_task_test_filtered.ravel())
    task_patches = [
        plt.Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            markerfacecolor=task_colors[int(label)],  # same mapping as the scatter
            markersize=8,
            alpha=0.7,
        )
        for label in task_labels
    ]
    ax20.legend(
        task_patches,
        [f"Class {int(label)}" for label in task_labels],
        title="Task",
        loc="upper center",
        bbox_to_anchor=(0.5, 0.8),
        frameon=True,
        ncol=1,  # Vertical layout
    )

    # Right legend: Machine (vertical layout to fit width constraint)
    ax21.axis("off")
    machine_labels = np.unique(y_domains_test_filtered[:, 0])
    machine_patches = [
        plt.Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            markerfacecolor=machine_colors[int(label)],  # same mapping as the scatter
            markersize=8,
            alpha=0.7,
        )
        for label in machine_labels
    ]
    ax21.legend(
        machine_patches,
        [f"Instance {int(label) + 1}" for label in machine_labels],
        title="Affine domain",
        loc="upper center",
        bbox_to_anchor=(0.5, 0.8),
        frameon=True,
        ncol=1,  # Vertical layout
    )

    # Make sure the output directory exists before saving
    OUTPUT_FIGURE.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(OUTPUT_FIGURE, dpi=350)
    plt.show()

### Feature Orthogonality

In [None]:
# Check orthogonality between shared and private features. Build the joint correlation
# matrix over all latent dimensions (columns = variables), then keep the
# shared-vs-private cross block.
n_shared_dims = shared_test.shape[1]
joint_corr = np.corrcoef(np.hstack([shared_test, private_test]), rowvar=False)
shared_private_corr = joint_corr[:n_shared_dims, n_shared_dims:]
abs_shared_private_corr = np.abs(shared_private_corr)

plt.figure(figsize=(2.2, 2.2))
sns.heatmap(
    shared_private_corr,
    cmap="coolwarm",
    center=0,
    vmin=-1,
    vmax=1,
    cbar_kws={"label": "Correlation"},
)
plt.title("Shared-Private Feature Correlation\n(Should be close to 0)")
plt.xlabel("Private Feature Index")
plt.ylabel("Shared Feature Index")
plt.tight_layout()
plt.show()

print(f"Mean absolute correlation: {abs_shared_private_corr.mean():.4f}")
print(f"Max absolute correlation: {abs_shared_private_corr.max():.4f}")

## 8. Reconstruction Quality

In [None]:
# Reconstruction quality of the autoencoder on the test set
X_reconstructed = model.reconstruct(X_test)
squared_error = (X_test - X_reconstructed) ** 2
reconstruction_mse = np.mean(squared_error)
# Per-feature reconstruction error: mean over samples, one value per input feature
per_feature_mse = np.mean(squared_error, axis=0)
mean_per_feature_mse = per_feature_mse.mean()

print("Reconstruction Quality:")
print("=" * 60)
print(f"Overall MSE: {reconstruction_mse:.6f}")

plt.figure(figsize=(4.5, 2.2))
plt.bar(range(len(per_feature_mse)), per_feature_mse, alpha=0.7)
plt.axhline(
    mean_per_feature_mse,
    color="red",
    linestyle="--",
    label=f"Mean: {mean_per_feature_mse:.6f}",
)
plt.xlabel("Feature Index")
plt.ylabel("MSE")
plt.title("Per-Feature Reconstruction Error")
plt.legend()
plt.tight_layout()
plt.show()

## 9. Save Model

In [None]:
# Save trained model
model_path = "disae2_data_B.pkl"
model.save(model_path)
print(f"✓ Model saved to {model_path}")

# Verify loading: the reloaded model must reproduce the in-memory model's predictions.
# NOTE(review): the .pkl extension suggests a pickle file; only load it from a
# trusted location.
loaded_model = DisAE.load(model_path)
verification_batch = X_test[:10]
y_pred_loaded = loaded_model.predict_tasks(verification_batch)[0]
y_pred_reference = model.predict_tasks(verification_batch)[0]
assert np.array_equal(
    np.asarray(y_pred_loaded), np.asarray(y_pred_reference)
), "Reloaded model predictions differ from the trained model's predictions"
print(f"✓ Model loaded and verified successfully")