# Dis-AE 2: Advanced Usage

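This notebook covers the following advanced topics (Section 1 loads and prepares the data used throughout):
1. Hyperparameter tuning and loss weight selection (Sections 2–3)
2. Model persistence and reuse (Section 4)
3. Using embeddings for downstream tasks (Section 5)
4. Analyzing learned representations (Section 6)
5. Best practices and troubleshooting (Section 7)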

In [None]:
import numpy as np
import pandas as pd
import torch
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.decomposition import PCA
from sysmexcbctools.disae2.disae2 import DisAE

# Seed every RNG the notebook touches. Seeding NumPy alone leaves any
# torch-side randomness that is not already governed by DisAE's own
# `random_state` argument unseeded, so seed torch here as well.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# PDF-compatible fonts (font type 42 for both the PDF and PS backends)
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42

# Scientific plot style. scienceplots is imported only for its side effect:
# it must be imported before the "science"/"nature" styles are requested.
import scienceplots  # noqa: F401

plt.style.use(["science", "nature"])

# Colourblind-friendly palette
SEABORN_PALETTE = "colorblind"
seaborn_colors = sns.color_palette(SEABORN_PALETTE)

## 1. Load and Prepare Data

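We'll use `../data/data_B.csv` for these examples. From this file we read 32 feature columns (named `0`–`31`), the task label `ClassCategory_0`, and three domain labels (`Machine`, `vendelay_binned`, `studytime_binned`).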

In [None]:
# Load data
df = pd.read_csv("../data/data_B.csv")

# Extract features and labels
X = df[[str(i) for i in range(32)]].values
y_task = df["ClassCategory_0"].values.reshape(-1, 1)
y_domains = df[["Machine", "vendelay_binned", "studytime_binned"]].values.astype(int)

# Split BEFORE scaling so that validation/test statistics never influence the
# scaler (fitting it on the full dataset would leak held-out information).
# Split data, this time let's keep all domains in train and test set
# (stratified on the first domain column, "Machine").
X_train, X_test, y_task_train, y_task_test, y_domains_train, y_domains_test = (
    train_test_split(
        X, y_task, y_domains, test_size=0.2, stratify=y_domains[:, 0], random_state=42
    )
)

X_train, X_val, y_task_train, y_task_val, y_domains_train, y_domains_val = (
    train_test_split(
        X_train,
        y_task_train,
        y_domains_train,
        test_size=0.2,
        stratify=y_domains_train[:, 0],
        random_state=42,
    )
)

# Normalize: fit on the training split only, then apply the same transform
# to every other split.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)
X_test = scaler.transform(X_test)
# Keep `X` standardized too (using training statistics) for any later use.
X = scaler.transform(X)

# Sanity check: the three splits partition the full dataset.
assert X_train.shape[0] + X_val.shape[0] + X_test.shape[0] == len(df)

print(
    f"Train: {X_train.shape[0]:,}, Val: {X_val.shape[0]:,}, Test: {X_test.shape[0]:,}"
)

## 2. Hyperparameter Tuning: Loss Weights

The three key loss weights control the tradeoff between objectives:
- `reconstruction_weight`: Higher = better reconstruction, may preserve more domain info
- `adversarial_weight`: Higher = stronger domain invariance, may hurt task performance
- `orthogonality_weight`: Higher = more separation between shared/private features

### Experiment: Varying Adversarial Weight

In [None]:
# Try different adversarial weights
adv_weights = [0.1, 0.5, 1.0, 2.0, 5.0]
results = []

# Pick the best available device. `torch.backends.mps` does not exist on
# older torch builds, so probe for it instead of assuming it is there.
mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():  # NVIDIA GPU preferred
    device = "cuda"
elif mps_backend is not None and mps_backend.is_available():  # Apple Silicon GPU
    device = "mps"
else:
    device = "cpu"  # CPU fallback

# Number of domain factors, taken from the label matrix rather than hard-coded.
n_domain_factors = y_domains_test.shape[1]

# NOTE(review): the metrics below are computed on the test split. If this
# sweep is used to actually pick a weight, compare on the validation split
# instead and keep the test split for the final report.
for adv_weight in adv_weights:
    print(f"\nTraining with adversarial_weight={adv_weight}...")

    model = DisAE(
        input_dim=32,
        latent_dim=16,
        num_tasks=[2],
        num_domains=[5, 10, 10],
        hidden_dims=[64, 32],
        reconstruction_weight=0.1,
        adversarial_weight=adv_weight,  # Vary this
        orthogonality_weight=0.1,
        learning_rate=0.01,
        batch_size=256,
        device=device,
        random_state=42,
    )

    model.fit(
        X_train,
        y_task_train,
        y_domains_train,
        X_val=X_val,
        y_tasks_val=y_task_val,
        y_domains_val=y_domains_val,
        max_epochs=50,
        early_stopping_patience=10,
        verbose=False,
    )

    # Evaluate the (single) task head
    y_pred = model.predict_tasks(X_test)[0]
    task_acc = accuracy_score(y_task_test, y_pred)

    # Average the per-factor domain accuracies into one summary number
    d_pred = model.predict_domains(X_test)
    domain_acc = np.mean(
        [
            accuracy_score(y_domains_test[:, i], d_pred[i])
            for i in range(n_domain_factors)
        ]
    )

    results.append(
        {
            "adversarial_weight": adv_weight,
            "task_accuracy": task_acc,
            "domain_accuracy": domain_acc,
        }
    )

    print(f"  Task accuracy: {task_acc:.4f}")
    print(f"  Domain accuracy: {domain_acc:.4f} (lower is better)")

results_df = pd.DataFrame(results)
print("\n" + "=" * 60)
print(results_df.to_string(index=False))

In [None]:
# Visualise tradeoff: two side-by-side panels sharing the same x values
fig, axes = plt.subplots(1, 2, figsize=(4.5, 2.2))
ax_task, ax_domain = axes[0], axes[1]
x_weights = results_df["adversarial_weight"]
marker_style = dict(marker="o", linewidth=2, markersize=8)

# Left panel: task accuracy vs adversarial weight
ax_task.plot(x_weights, results_df["task_accuracy"], **marker_style)
ax_task.set_xlabel("Adversarial Weight")
ax_task.set_ylabel("Task Accuracy")
ax_task.set_title("Task Performance vs Adversarial Weight")
ax_task.grid(True, alpha=0.3)

# Right panel: domain accuracy vs adversarial weight
ax_domain.plot(
    x_weights, results_df["domain_accuracy"], color="coral", **marker_style
)
# Reference lines: 1/5 for the Machine factor, 1/10 for the two binned factors
for level, line_colour, line_label in (
    (1.0 / 5, "red", "Random (Machine)"),
    (1.0 / 10, "orange", "Random (Delay/Time)"),
):
    ax_domain.axhline(
        level, color=line_colour, linestyle="--", label=line_label, alpha=0.5
    )
ax_domain.set_xlabel("Adversarial Weight")
ax_domain.set_ylabel("Domain Accuracy (Lower=Better)")
ax_domain.set_title("Domain Invariance vs Adversarial Weight")
ax_domain.legend()
ax_domain.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

for takeaway in (
    "\n✓ Higher adversarial weight → Better domain invariance but may hurt task accuracy",
    "✓ Choose weight that balances both objectives for your use case",
):
    print(takeaway)

## 3. Latent Dimension Selection

**Important**: Due to the orthogonality loss implementation (element-wise dot product), `shared_dim` must equal `private_dim`.

We can experiment with different total latent dimensions (always split equally).

In [None]:
# Sweep over total latent sizes; each size is reported as an equal
# shared/private split.
latent_dims = [8, 16, 32, 64]
dim_results = []

for latent_dim in latent_dims:
    shared_dim = private_dim = latent_dim // 2

    print(
        f"\nTraining with latent_dim={latent_dim} (shared={shared_dim}, private={private_dim})..."
    )

    # Everything except `latent_dim` is held fixed across the sweep
    model_kwargs = dict(
        input_dim=32,
        latent_dim=latent_dim,  # Automatically splits equally
        num_tasks=[2],
        num_domains=[5, 10, 10],
        hidden_dims=[64, 32],
        reconstruction_weight=0.1,
        adversarial_weight=1.0,
        orthogonality_weight=0.1,
        learning_rate=0.01,
        batch_size=256,
        device=device,
        random_state=42,
    )
    model = DisAE(**model_kwargs)

    fit_kwargs = dict(
        X_val=X_val,
        y_tasks_val=y_task_val,
        y_domains_val=y_domains_val,
        max_epochs=50,
        early_stopping_patience=10,
        verbose=False,
    )
    model.fit(X_train, y_task_train, y_domains_train, **fit_kwargs)

    # Task accuracy of the first (only) task head on the test split
    y_pred = model.predict_tasks(X_test)[0]
    task_acc = accuracy_score(y_task_test, y_pred)

    # Mean squared reconstruction error over all test samples and features
    X_recon = model.reconstruct(X_test)
    squared_error = (X_test - X_recon) ** 2
    recon_mse = np.mean(squared_error)

    record = {
        "latent_dim": latent_dim,
        "shared_dim": shared_dim,
        "private_dim": private_dim,
        "task_accuracy": task_acc,
        "reconstruction_mse": recon_mse,
    }
    dim_results.append(record)

    print(f"  Task accuracy: {task_acc:.4f}")
    print(f"  Reconstruction MSE: {recon_mse:.6f}")

dim_results_df = pd.DataFrame(dim_results)
print("\n" + "=" * 70)
print(dim_results_df.to_string(index=False))

for takeaway in (
    "\n✓ Larger latent_dim → More capacity for both task and domain information",
    "✓ Too small → May not capture sufficient information",
    "✓ Too large → Risk of overfitting, slower training",
):
    print(takeaway)

## 4. Model Persistence: Save, Load, and Reuse

Train once, use many times!

In [None]:
# Train one reference model, then persist it to disk
print("Training model...")

# Model hyperparameters, collected in one place
model_kwargs = dict(
    input_dim=32,
    latent_dim=16,
    num_tasks=[2],
    num_domains=[5, 10, 10],
    hidden_dims=[64, 32],
    reconstruction_weight=0.1,
    adversarial_weight=1.0,
    orthogonality_weight=0.1,
    learning_rate=0.01,
    batch_size=256,
    device=device,
    random_state=42,
)
model = DisAE(**model_kwargs)

# Validation data and stopping criteria passed to fit()
fit_kwargs = dict(
    X_val=X_val,
    y_tasks_val=y_task_val,
    y_domains_val=y_domains_val,
    max_epochs=50,
    early_stopping_patience=10,
    verbose=False,
)
model.fit(X_train, y_task_train, y_domains_train, **fit_kwargs)

print("✓ Model trained")

# Save model (path is relative to the notebook's working directory)
model_path = "disae2_trained_model.pkl"
model.save(model_path)
print(f"✓ Model saved to {model_path}")

In [None]:
# Restore the saved model, as one would in a "fresh session"
print("Loading model from disk...")
loaded_model = DisAE.load(model_path)
print("✓ Model loaded successfully")

# Round-trip check: the in-memory and reloaded models must agree exactly
y_pred_original = model.predict_tasks(X_test)[0]
y_pred_loaded = loaded_model.predict_tasks(X_test)[0]

predictions_match = np.array_equal(y_pred_original, y_pred_loaded)
assert predictions_match, "Predictions don't match!"
print("✓ Loaded model produces identical predictions")

# Report the configuration carried by the reloaded model
print("\nLoaded model configuration:")
for label, attr_name in (
    ("Input dim", "input_dim"),
    ("Shared dim", "shared_dim"),
    ("Private dim", "private_dim"),
    ("Num tasks", "num_tasks"),
    ("Num domains", "num_domains"),
):
    print(f"  {label}: {getattr(loaded_model, attr_name)}")

## 5. Using Embeddings for Downstream Tasks

Extract shared features and use them as input to other models.

In [None]:
# Shared (private=False) embeddings for the train and test splits
shared_train = loaded_model.embed(X_train, private=False)
shared_test = loaded_model.embed(X_test, private=False)

print("Extracted shared features:")
for split_name, split_embedding in (("Train", shared_train), ("Test", shared_test)):
    print(f"  {split_name}: {split_embedding.shape}")

# Classifier trained on the shared embedding
print("\nTraining logistic regression on shared features...")
lr_on_shared = LogisticRegression(max_iter=1000, random_state=42)
lr_on_shared.fit(shared_train, y_task_train.ravel())
y_pred_shared = lr_on_shared.predict(shared_test)
acc_shared = accuracy_score(y_task_test, y_pred_shared)

# Baseline: the same classifier trained directly on the input features
lr_on_original = LogisticRegression(max_iter=1000, random_state=42)
lr_on_original.fit(X_train, y_task_train.ravel())
y_pred_original = lr_on_original.predict(X_test)
acc_original = accuracy_score(y_task_test, y_pred_original)

print("\nResults:")
print(f"  LR on shared features: {acc_shared:.4f}")
print(f"  LR on original features: {acc_original:.4f}")
print("\n✓ Shared features maintain task information while being domain-invariant")

## 6. Analyzing What the Model Learned

### Feature Importance via Reconstruction

In [None]:
# Check which features are well-reconstructed
X_recon = loaded_model.reconstruct(X_test)
per_feature_mse = np.mean((X_test - X_recon) ** 2, axis=0)

# Derive sizes from the data instead of hard-coding 32 features / top 10
n_features = per_feature_mse.shape[0]
n_worst = min(10, n_features)

# Plot feature reconstruction error
fig, axes = plt.subplots(1, 2, figsize=(4.5, 2.2))

axes[0].bar(range(n_features), per_feature_mse)
axes[0].set_xlabel("Feature Index")
axes[0].set_ylabel("Reconstruction MSE")
axes[0].set_title("Per-Feature Reconstruction Error")
# Dashed line marks the mean error across features
axes[0].axhline(per_feature_mse.mean(), color="red", linestyle="--", alpha=0.5)

# Worst-reconstructed features, sorted from highest to lowest error.
# Descending order matters: invert_yaxis() below places row 0 at the top, so
# the single worst feature ends up as the first bar (an ascending sort would
# put it at the bottom).
worst_features = np.argsort(per_feature_mse)[::-1][:n_worst]
axes[1].barh(range(n_worst), per_feature_mse[worst_features])
axes[1].set_yticks(range(n_worst))
axes[1].set_yticklabels([f"Feature {i}" for i in worst_features])
axes[1].set_xlabel("Reconstruction MSE")
axes[1].set_title(f"{n_worst} Worst-Reconstructed Features")
axes[1].invert_yaxis()

plt.tight_layout()
plt.show()

print(
    "✓ High reconstruction error may indicate features with high domain-specific variance"
)

### Visualizing What Shared Features Encode

In [None]:
# Shared embedding of the test split, projected to 2-D with PCA
shared_test = loaded_model.embed(X_test, private=False)

pca = PCA(n_components=2)
shared_2d = pca.fit_transform(shared_test)

print(f"PCA explained variance: {pca.explained_variance_ratio_.sum():.3f}")

# One panel per colouring: (title, colour values, colormap)
panel_specs = (
    ("Shared Features by Task", y_task_test.ravel(), "viridis"),
    ("Shared Features by Machine", y_domains_test[:, 0], "Set1"),
    ("Shared Features by Study Time", y_domains_test[:, 2], "plasma"),
)

# Same projection drawn three times, coloured by task / machine / study time
with plt.style.context(["science", "nature", "scatter"]):
    fig, axes = plt.subplots(1, 3, figsize=(6.6, 2.2))

    panel_scatters = []
    for panel_ax, (panel_title, colour_values, cmap_name) in zip(axes, panel_specs):
        points = panel_ax.scatter(
            shared_2d[:, 0],
            shared_2d[:, 1],
            c=colour_values,
            cmap=cmap_name,
            s=5,
            alpha=0.2,
        )
        panel_ax.set_title(panel_title)
        panel_ax.set_xlabel("PC1")
        panel_ax.set_ylabel("PC2")
        plt.colorbar(points, ax=panel_ax)
        panel_scatters.append(points)

    # Keep the individual handles available under their usual names
    scatter1, scatter2, scatter3 = panel_scatters

    plt.tight_layout()
    plt.show()

print("✓ Left plot should show separation (task-relevant)")
print("✓ Middle/right plots should show mixing (domain-invariant)")

## 7. Best Practices and Tips

### Important Constraint

⚠️ **shared_dim must equal private_dim**: The orthogonality loss uses element-wise dot product, requiring equal dimensions. Use `latent_dim` parameter for automatic equal splitting.

### Loss Weight Guidelines

| Weight | Low (0.01-0.1) | Medium (0.1-1.0) | High (1.0-10.0) |
|--------|---------------|------------------|----------------|
| **Reconstruction** | Domain invariance priority | Balanced | Reconstruction priority |
| **Adversarial** | Task accuracy priority | Balanced | Strong domain invariance |
| **Orthogonality** | Allow feature mixing | Balanced | Strict feature separation |

### When to Use Dis-AE 2

✓ Multiple domain factors (machines, batches, sites)  
✓ Domain shift is a concern  
✓ Need interpretable features (shared = task, private = domain)  
✓ Want to understand domain effects  

### Troubleshooting

**Poor task accuracy:**
- Reduce `adversarial_weight`
- Increase `latent_dim` (more capacity)
- Train longer (increase `max_epochs`)

**Domain information leaking:**
- Increase `adversarial_weight`
- Increase `orthogonality_weight`
- Use more discriminator updates (`d_steps_per_g_step`)

**Poor reconstruction:**
- Increase `reconstruction_weight`
- Increase `latent_dim` (more capacity)
- Check if data normalization is appropriate

### Data Preparation Checklist

1. ✓ Normalize/standardize features
2. ✓ Encode labels as integers (0, 1, 2, ...)
3. ✓ Use train/val split for early stopping
4. ✓ Stratify by domain when splitting
5. ✓ Check for class imbalance