# Dis-AE 2: Basic Usage

This notebook demonstrates the basic API usage of Dis-AE 2 (Domain Separation Network Multi-Domain Adversarial Autoencoder) with synthetic data.

## What is Dis-AE 2?

Dis-AE 2 is a **domain-invariant learning** model that:
- Separates features into **shared** (domain-invariant, task-relevant) and **private** (domain-specific, for reconstruction)
- Uses **adversarial training** to prevent domain information from leaking into shared features
- Supports **multiple tasks** and **multiple domain factors** simultaneously
- Provides a **scikit-learn-style API** for easy integration

In [None]:
# Import required libraries
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import seaborn as sns

from sysmexcbctools.disae2.disae2 import DisAE

# Set random seed for reproducibility
np.random.seed(42)

# PDF-compatible fonts
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42

# Scientific plot style
import scienceplots

plt.style.use(["science", "nature"])

# Colourblind-friendly palette
SEABORN_PALETTE = "colorblind"
seaborn_colors = sns.color_palette(SEABORN_PALETTE)

## 1. Generate Synthetic Data

We'll create an arbitrary synthetic dataset (the labels and domains carry no real-world meaning) with:
- **1000 samples** with **32 features**
- **2 tasks**: binary classification and 3-class classification
- **3 domain factors**: machine (5 levels), time (10 levels), delay (10 levels)

In [None]:
from sklearn.datasets import make_classification

# Data dimensions
n_samples = 1000
n_features = 32

# Generate Task 1: Binary classification
X_task1, y_task_1 = make_classification(
    n_samples=n_samples,
    n_features=n_features,
    n_informative=10,
    n_redundant=5,
    n_classes=2,
    n_clusters_per_class=2,
    flip_y=0.1,
    random_state=42,
)

# Generate Task 2: 3-class classification
X_task2, y_task_2 = make_classification(
    n_samples=n_samples,
    n_features=n_features,
    n_informative=8,
    n_redundant=4,
    n_classes=3,
    n_clusters_per_class=1,
    flip_y=0.1,
    random_state=43,
)

# Combine tasks: weighted average of features (shared information)
X = 0.6 * X_task1 + 0.4 * X_task2

# Stack the two classification tasks
y_tasks = np.column_stack([y_task_1, y_task_2])

# Three domain factors
y_domain_1 = np.random.randint(0, 5, size=n_samples)  # 5 machines
y_domain_2 = np.random.randint(0, 10, size=n_samples)  # 10 time bins
y_domain_3 = np.random.randint(0, 10, size=n_samples)  # 10 delay bins
y_domains = np.column_stack([y_domain_1, y_domain_2, y_domain_3])

# Add domain influences to features
# Domain 1 (machine): systematic shift on first 10 features
for machine_id in range(5):
    mask = y_domain_1 == machine_id
    X[mask, :10] += np.random.randn(10) * 0.5 * machine_id

# Domain 2 (time): linear drift on features 10-19
for time_bin in range(10):
    mask = y_domain_2 == time_bin
    X[mask, 10:20] += np.linspace(0, 1, 10) * time_bin * 0.1

# Domain 3 (delay): random noise on features 20-29
for delay_bin in range(10):
    mask = y_domain_3 == delay_bin
    X[mask, 20:30] += np.random.randn(np.sum(mask), 10) * 0.3 * delay_bin / 10

# Split into train and validation
train_size = int(0.8 * n_samples)
X_train = X[:train_size]
y_tasks_train = y_tasks[:train_size]
y_domains_train = y_domains[:train_size]

X_val = X[train_size:]
y_tasks_val = y_tasks[train_size:]
y_domains_val = y_domains[train_size:]

print("Dataset Summary:")
print("=" * 50)
print(f"Training samples: {X_train.shape[0]}")
print(f"Validation samples: {X_val.shape[0]}")
print(f"Features: {X_train.shape[1]}")
print(f"\nTasks: {y_tasks_train.shape[1]}")
print(f"  Task 1: {len(np.unique(y_tasks_train[:, 0]))} classes (binary)")
print(f"  Task 2: {len(np.unique(y_tasks_train[:, 1]))} classes (3-class)")
print(f"\nDomain factors: {y_domains_train.shape[1]}")
print(f"  Domain 1 (machine): {len(np.unique(y_domains_train[:, 0]))} levels")
print(f"  Domain 2 (time): {len(np.unique(y_domains_train[:, 1]))} levels")
print(f"  Domain 3 (delay): {len(np.unique(y_domains_train[:, 2]))} levels")
print(f"\nDomain influences:")
print(f"  Machine: systematic shift on features 0-9")
print(f"  Time: linear drift on features 10-19")
print(f"  Delay: random noise on features 20-29")

## 2. Initialize Dis-AE 2 Model

Key hyperparameters:
- `input_dim`: Number of input features
- `latent_dim`: Total latent dimension (split into shared + private)
- `num_tasks`: List of number of classes per task
- `num_domains`: List of cardinalities per domain factor
- `reconstruction_weight`: Weight for reconstruction loss (helps preserve information)
- `adversarial_weight`: Weight for adversarial domain loss (controls domain invariance)
- `orthogonality_weight`: Weight for orthogonality loss (separates shared/private)
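
To make the orthogonality term more concrete, here is a minimal NumPy sketch of one common form of such a penalty: the squared Frobenius norm of `shared.T @ private`, scaled by batch size, in the style of Domain Separation Networks. Whether Dis-AE 2 uses exactly this form is an assumption; the function name `orthogonality_penalty` is hypothetical and not part of the library.

```python
import numpy as np


def orthogonality_penalty(shared, private):
    """Sum of squared entries of (shared^T @ private) / batch_size.

    Assumed form (Domain Separation Network style); not taken from the
    Dis-AE 2 source code.
    """
    cross = shared.T @ private / shared.shape[0]  # (shared_dim, private_dim)
    return float(np.sum(cross ** 2))


rng = np.random.default_rng(0)
shared = rng.normal(size=(200, 8))
private_independent = rng.normal(size=(200, 8))  # unrelated to shared
private_copy = shared.copy()  # worst case: private duplicates shared

penalty_independent = orthogonality_penalty(shared, private_independent)
penalty_copy = orthogonality_penalty(shared, private_copy)
print(f"independent: {penalty_independent:.3f}, copy: {penalty_copy:.3f}")
```

Under this assumed form, a larger `orthogonality_weight` pushes the two feature sets towards carrying non-overlapping information.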

In [None]:
# Initialize model
model = DisAE(
    input_dim=n_features,
    latent_dim=16,  # Total latent dimension
    num_tasks=[2, 3],  # Binary and 3-class tasks
    num_domains=[5, 10, 10],  # Three domain factors
    hidden_dims=[64, 32],  # Hidden layer dimensions
    reconstruction_weight=0.1,
    adversarial_weight=1.0,
    orthogonality_weight=0.1,
    learning_rate=0.01,
    batch_size=128,
    device="cpu",  # Use 'cuda' if GPU available
    random_state=42,
)

print("Model Configuration:")
print("=" * 50)
print(f"Input dimension: {model.input_dim}")
print(f"Latent dimension: {model.latent_dim}")
print(f"  - Shared (domain-invariant): {model.shared_dim}")
print(f"  - Private (domain-specific): {model.private_dim}")
print(f"Tasks: {model.num_tasks}")
print(f"Domain factors: {model.num_domains}")
print(f"Hidden layers: {model.hidden_dims}")
print(f"\nLoss weights:")
print(f"  - Reconstruction: {model.reconstruction_weight}")
print(f"  - Adversarial: {model.adversarial_weight}")
print(f"  - Orthogonality: {model.orthogonality_weight}")

## 3. Train the Model

The model uses:
- **Adversarial training**: Alternates between updating discriminators and generator
- **Early stopping**: Monitors validation loss to prevent overfitting
- **Multiple objectives**: Task classification + reconstruction + adversarial + orthogonality
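
The early-stopping rule can be sketched in a few lines of plain Python. This is only an illustration of how a patience counter typically works; the class `EarlyStopper` and its `min_delta` argument are hypothetical and are not taken from the Dis-AE 2 implementation, which may differ in detail.

```python
class EarlyStopper:
    """Stop when the monitored loss has not improved for `patience` epochs."""

    def __init__(self, patience, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.epochs_without_improvement = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Toy loss curve: improves for three epochs, then plateaus
val_losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.73]
stopper = EarlyStopper(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(stopped_at, stopper.best)  # 5 0.7
```

A larger patience tolerates longer plateaus, which can matter for adversarial training where the validation loss tends to fluctuate.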

In [None]:
# Train the model
print("Training Dis-AE 2...\n")
model.fit(
    X_train,
    y_tasks_train,
    y_domains_train,
    X_val=X_val,
    y_tasks_val=y_tasks_val,
    y_domains_val=y_domains_val,
    max_epochs=100,
    early_stopping_patience=16,
    verbose=True,
)

## 4. Task Predictions

The model uses only **shared features** for task prediction.
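
Accuracy is easier to interpret next to a trivial baseline. The sketch below computes, for each task column, the accuracy of always predicting the most frequent training label. The helper `majority_class_accuracy` and the toy arrays are illustrative assumptions; in this notebook it could be called with `y_tasks_train` and `y_tasks_val`.

```python
import numpy as np


def majority_class_accuracy(y_train, y_eval):
    """Accuracy of always predicting the most frequent training label, per column."""
    y_train = np.asarray(y_train)
    y_eval = np.asarray(y_eval)
    baselines = []
    for col in range(y_train.shape[1]):
        labels, counts = np.unique(y_train[:, col], return_counts=True)
        majority = labels[np.argmax(counts)]
        baselines.append(float(np.mean(y_eval[:, col] == majority)))
    return baselines


# Toy labels: column 0 is binary, column 1 has three classes
y_train_demo = np.array([[0, 2], [0, 1], [1, 2], [0, 0]])
y_eval_demo = np.array([[0, 2], [1, 2], [0, 1], [0, 2]])
baseline_demo = majority_class_accuracy(y_train_demo, y_eval_demo)
print(baseline_demo)  # [0.75, 0.75]
```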

In [None]:
# Task predictions
y_pred = model.predict_tasks(X_val)
y_proba = model.predict_tasks_proba(X_val)

print("Task Prediction Results:")
print("=" * 50)
for i, (preds, proba) in enumerate(zip(y_pred, y_proba)):
    accuracy = (preds == y_tasks_val[:, i]).mean()
    confidence = proba.max(axis=1).mean()
    print(f"\nTask {i + 1}:")
    print(f"  Accuracy: {accuracy:.4f}")
    print(f"  Mean confidence: {confidence:.4f}")
    print(f"  Predictions shape: {preds.shape}")
    print(f"  Probabilities shape: {proba.shape}")

## 5. Domain Predictions (Domain Invariance Check)

**Important**: For good domain generalisation, domain prediction accuracy should be **low** (close to random chance).

Accuracy near chance indicates that the shared features carry little domain-specific information.
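
With only 200 validation samples, measured accuracy fluctuates around the chance level even for a perfectly invariant model. As a rough guide, the sketch below uses the normal approximation to the binomial to estimate how far above chance an accuracy can plausibly land by sampling noise alone. It assumes balanced domain labels; the helper `chance_band` and the choice of two standard errors are illustrative, not part of the Dis-AE 2 API.

```python
import math


def chance_band(n_levels, n_samples, z=2.0):
    """Chance accuracy for balanced labels and an approximate upper bound.

    A classifier guessing uniformly at random scores p = 1 / n_levels on
    average, with standard error sqrt(p * (1 - p) / n_samples).
    """
    p = 1.0 / n_levels
    se = math.sqrt(p * (1.0 - p) / n_samples)
    return p, p + z * se


for name, levels in [("machine", 5), ("time", 10), ("delay", 10)]:
    chance, upper = chance_band(levels, n_samples=200)
    print(f"{name}: chance={chance:.3f}, upper~{upper:.3f}")
```

For a 10-level factor the band extends to roughly 0.14, so an accuracy slightly above a 20% margin over chance is not, on its own, strong evidence of leakage.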

In [None]:
# Domain predictions
d_pred = model.predict_domains(X_val)
d_proba = model.predict_domains_proba(X_val)

print("Domain Prediction Results (Lower is Better):")
print("=" * 50)
for i, (preds, proba) in enumerate(zip(d_pred, d_proba)):
    accuracy = (preds == y_domains_val[:, i]).mean()
    random_chance = 1.0 / model.num_domains[i]
    print(f"\nDomain Factor {i + 1}:")
    print(f"  Accuracy: {accuracy:.4f}")
    print(f"  Random chance: {random_chance:.4f}")
    if accuracy <= random_chance * 1.2:  # Within 20% of random
        print(f"  ✓ Good domain invariance!")
    else:
        print(f"  ⚠ Domain information may be leaking into shared features")

## 6. Feature Embeddings

Extract **shared** and **private** features to analyse feature separation.
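
One simple way to quantify separation is a probe: fit a small classifier on the embeddings and see how well it recovers a domain label. The following is a minimal NumPy nearest-centroid probe on toy data; the function `centroid_probe_accuracy` and the toy `leaky`/`clean` arrays are hypothetical stand-ins for the shared and private embeddings, not something the library provides.

```python
import numpy as np


def centroid_probe_accuracy(z_train, y_train, z_test, y_test):
    """Nearest-centroid probe: how well do embeddings z predict labels y?"""
    classes = np.unique(y_train)
    centroids = np.stack([z_train[y_train == c].mean(axis=0) for c in classes])
    # Squared Euclidean distance from every test point to every centroid
    dists = ((z_test[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    preds = classes[np.argmin(dists, axis=1)]
    return float(np.mean(preds == y_test))


# Toy embeddings: "leaky" features encode the domain label, "clean" ones do not
rng = np.random.default_rng(0)
domain = rng.integers(0, 5, size=400)
leaky = rng.normal(size=(400, 4)) + domain[:, None] * 3.0
clean = rng.normal(size=(400, 4))

half = 200
leaky_acc = centroid_probe_accuracy(
    leaky[:half], domain[:half], leaky[half:], domain[half:]
)
clean_acc = centroid_probe_accuracy(
    clean[:half], domain[:half], clean[half:], domain[half:]
)
print(f"leaky: {leaky_acc:.2f}, clean: {clean_acc:.2f}")
```

Applied to real embeddings, a probe accuracy near chance on the shared features and well above chance on the private features would be consistent with the intended separation.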

In [None]:
# Get embeddings
shared, private = model.embed(X_val, private=True)
shared_only = model.embed(X_val, private=False)

print("Embedding Dimensions:")
print("=" * 50)
print(f"Shared features: {shared.shape}")
print(f"Private features: {private.shape}")
print(f"\nFeature Statistics:")
print(f"Shared - mean: {shared.mean():.4f}, std: {shared.std():.4f}")
print(f"Private - mean: {private.mean():.4f}, std: {private.std():.4f}")

### Visualise Feature Correlations

In [None]:
# Compute correlations
shared_corr = np.corrcoef(shared.T)
private_corr = np.corrcoef(private.T)
shared_private_corr = np.corrcoef(np.concatenate([shared, private], axis=1).T)[
    : shared.shape[1], shared.shape[1] :
]

# Plot correlations
fig, axes = plt.subplots(1, 3, figsize=(6.6, 2.2))

im1 = axes[0].imshow(shared_corr, cmap="coolwarm", vmin=-1, vmax=1)
axes[0].set_title("Shared-Shared Correlation")
axes[0].set_xlabel("Shared Feature")
axes[0].set_ylabel("Shared Feature")
plt.colorbar(im1, ax=axes[0])

im2 = axes[1].imshow(private_corr, cmap="coolwarm", vmin=-1, vmax=1)
axes[1].set_title("Private-Private Correlation")
axes[1].set_xlabel("Private Feature")
axes[1].set_ylabel("Private Feature")
plt.colorbar(im2, ax=axes[1])

im3 = axes[2].imshow(shared_private_corr, cmap="coolwarm", vmin=-1, vmax=1)
axes[2].set_title("Shared-Private Correlation\n(Should be low)")
axes[2].set_xlabel("Private Feature")
axes[2].set_ylabel("Shared Feature")
plt.colorbar(im3, ax=axes[2])

plt.tight_layout()
plt.show()

print(f"Mean absolute correlation:")
print(f"  Shared-Shared: {np.abs(shared_corr).mean():.4f}")
print(f"  Private-Private: {np.abs(private_corr).mean():.4f}")
print(f"  Shared-Private: {np.abs(shared_private_corr).mean():.4f}")
print(f"\n✓ Lower Shared-Private correlation indicates better feature separation")

### Visualise Embeddings with PCA

In [None]:
# PCA projection
pca_shared = PCA(n_components=2).fit_transform(shared)
pca_private = PCA(n_components=2).fit_transform(private)

with plt.style.context(["science", "nature", "scatter"]):
    fig, axes = plt.subplots(1, 2, figsize=(4.5, 2.2))

    # Shared features coloured by task
    scatter1 = axes[0].scatter(
        pca_shared[:, 0],
        pca_shared[:, 1],
        c=y_tasks_val[:, 0],
        cmap="viridis",
        s=5,
        alpha=0.2,
    )
    axes[0].set_title("Shared Features (coloured by Task 1)")
    axes[0].set_xlabel("PC1")
    axes[0].set_ylabel("PC2")
    plt.colorbar(scatter1, ax=axes[0], label="Task Label")

    # Private features coloured by domain
    scatter2 = axes[1].scatter(
        pca_private[:, 0],
        pca_private[:, 1],
        c=y_domains_val[:, 0],
        cmap="plasma",
        s=5,
        alpha=0.2,
    )
    axes[1].set_title("Private Features (coloured by Domain 1)")
    axes[1].set_xlabel("PC1")
    axes[1].set_ylabel("PC2")
    plt.colorbar(scatter2, ax=axes[1], label="Domain Label")

    plt.tight_layout()
    plt.show()

## 7. Reconstruction

The model reconstructs inputs using **both** shared and private features.
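
Raw MSE depends on the scale of each feature, so it can help to normalise it by the feature's variance. The sketch below computes a per-feature R² score (1 minus MSE over variance): 1 means perfect reconstruction and 0 is no better than predicting the feature mean. The helper `r2_per_feature` and the toy arrays are illustrative assumptions, and the function assumes no feature is constant.

```python
import numpy as np


def r2_per_feature(x, x_hat):
    """1 - MSE / variance for each feature column (assumes non-constant features)."""
    x = np.asarray(x, dtype=float)
    x_hat = np.asarray(x_hat, dtype=float)
    mse = np.mean((x - x_hat) ** 2, axis=0)
    var = np.var(x, axis=0)
    return 1.0 - mse / var


rng = np.random.default_rng(0)
x_demo = rng.normal(size=(200, 4))
x_hat_good = x_demo + 0.1 * rng.normal(size=x_demo.shape)  # small error
x_hat_mean = np.broadcast_to(x_demo.mean(axis=0), x_demo.shape)  # predicts the mean

r2_good = r2_per_feature(x_demo, x_hat_good)
r2_mean = r2_per_feature(x_demo, x_hat_mean)
print(np.round(r2_good, 3))
print(np.round(r2_mean, 3))
```

In this notebook the same function could be applied to `X_val` and the output of `model.reconstruct(X_val)`.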

In [None]:
# Reconstruct data
X_reconstructed = model.reconstruct(X_val)
reconstruction_mse = np.mean((X_val - X_reconstructed) ** 2)

print("Reconstruction Quality:")
print("=" * 50)
print(f"MSE: {reconstruction_mse:.6f}")
print(f"\nPer-feature MSE statistics:")
per_feature_mse = np.mean((X_val - X_reconstructed) ** 2, axis=0)
print(f"  Mean: {per_feature_mse.mean():.6f}")
print(f"  Std: {per_feature_mse.std():.6f}")
print(f"  Min: {per_feature_mse.min():.6f}")
print(f"  Max: {per_feature_mse.max():.6f}")

In [None]:
# Visualise reconstruction
fig, axes = plt.subplots(1, 2, figsize=(4.5, 2.2))

# Original vs reconstructed for first 5 samples
n_samples_to_plot = 5
x_pos = np.arange(n_features)
for i in range(n_samples_to_plot):
    # Use one colour per sample so each original/reconstructed pair matches;
    # label only the first pair to keep the legend to two entries
    axes[0].plot(
        x_pos, X_val[i], color=f"C{i}", alpha=0.3, label="Original" if i == 0 else None
    )
    axes[0].plot(
        x_pos,
        X_reconstructed[i],
        color=f"C{i}",
        alpha=0.3,
        linestyle="--",
        label="Reconstructed" if i == 0 else None,
    )
axes[0].set_title("Original vs Reconstructed (First 5 Samples)")
axes[0].set_xlabel("Feature Index")
axes[0].set_ylabel("Feature Value")
axes[0].legend()

# Reconstruction error distribution
reconstruction_errors = np.sqrt(np.sum((X_val - X_reconstructed) ** 2, axis=1))
axes[1].hist(reconstruction_errors, bins=30, edgecolor="black", alpha=0.7)
axes[1].set_title("Reconstruction Error Distribution")
axes[1].set_xlabel("L2 Error")
axes[1].set_ylabel("Count")
axes[1].axvline(
    reconstruction_errors.mean(),
    color="red",
    linestyle="--",
    label=f"Mean: {reconstruction_errors.mean():.3f}",
)
axes[1].legend()

plt.tight_layout()
plt.show()

## 8. Model Persistence

Save and load the trained model for later use.

In [None]:
# Save model
model_path = "disae2_basic_example.pkl"
model.save(model_path)
print(f"✓ Model saved to {model_path}")

# Load model
loaded_model = DisAE.load(model_path)
print(f"✓ Model loaded successfully")

# Verify loaded model
y_pred_loaded = loaded_model.predict_tasks(X_val)
for i in range(len(y_pred)):
    assert np.array_equal(
        y_pred[i], y_pred_loaded[i]
    ), f"Task {i} predictions don't match!"
print(f"✓ Loaded model produces identical predictions")

## Summary

This notebook demonstrated:

1. ✓ **Data preparation**: Multi-task, multi-domain dataset
2. ✓ **Model initialization**: Configuring Dis-AE 2 hyperparameters
3. ✓ **Training**: Adversarial training with early stopping
4. ✓ **Task prediction**: Using domain-invariant shared features
5. ✓ **Domain invariance**: Verifying low domain prediction accuracy
6. ✓ **Feature embeddings**: Analysing shared vs private features
7. ✓ **Reconstruction**: Quality assessment
8. ✓ **Model persistence**: Save/load functionality