### Imports

In [62]:
from experiments.class_conditioned_alignment import class_conditioned_alignment, class_conditioned_alignment_shared
from experiments.compared_shared_encoder import compare_shared_encoder_alignment
from data.loader import get_dataloader
from models.models import SplitEncoder
from models.models import SplitDecoder
from models.models import LinearProbe, DomainProbe
from utils.aligner import finetune_entropy, finetune_entropy_detach_stabilized, finetune_entropy_detach_usps_contrastive, finetune_domain_adversary
import torch

In [29]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ckpt_M = torch.load("artifacts/mnist/mnist_pretrained_usage_swap_asym.pt", map_location=device)
ckpt_U = torch.load("artifacts/usps/usps_pretrained_usage_swap_asym.pt", map_location=device)

In [30]:
input_dim = 784
output_dim = 784
latent_dim = 64
signal_dim = 32
num_classes = 10

encoder_M = SplitEncoder(input_dim=input_dim, latent_dim=latent_dim, signal_dim=signal_dim).to(device)
decoder_M = SplitDecoder(latent_dim=latent_dim, output_dim=output_dim).to(device)
probe_M = LinearProbe()

encoder_U = SplitEncoder(input_dim=input_dim, latent_dim=latent_dim, signal_dim=signal_dim).to(device)
decoder_U = SplitDecoder(latent_dim=latent_dim, output_dim=output_dim).to(device)
probe_U = LinearProbe()

In [31]:
encoder_M.load_state_dict(ckpt_M["encoder"])
decoder_M.load_state_dict(ckpt_M["decoder"])
probe_M.load_state_dict(ckpt_M["probe"])

encoder_U.load_state_dict(ckpt_U["encoder"])
decoder_U.load_state_dict(ckpt_U["decoder"])
probe_U.load_state_dict(ckpt_U["probe"])

<All keys matched successfully>

In [32]:
loader_M = get_dataloader("mnist", batch_size=256, train=False)
loader_U = get_dataloader("usps", batch_size=256, train=False)

### Results Prior to Finetuning on USPS

In [33]:
print("== Before finetune ==")
print("\n---- MNIST to USPS ----\n")
_ = compare_shared_encoder_alignment(encoder_M, decoder_M, loader_M, loader_U, device, visualize=False)
_ = class_conditioned_alignment_shared(
    encoder_M, decoder_M, loader_M, loader_U, device
)
print("\n---- USPS to MNIST ----\n")
_ = compare_shared_encoder_alignment(encoder_U, decoder_U, loader_M, loader_U, device, visualize=False)
_ = class_conditioned_alignment_shared(
    encoder_U, decoder_U, loader_U, loader_M, device
)

== Before finetune ==

---- MNIST to USPS ----

[Domain probe on signal] acc=65.40% | AUC=0.743
[Domain probe on nuisance] acc=63.35% | AUC=0.703
[MNIST] Var(z_s)=9.6935e-02  Var(z_n)=2.6360e-03  ‖z_s‖=1.7153  ‖z_n‖=1.7984
[USPS] Var(z_s)=4.5405e-01  Var(z_n)=2.9220e-02  ‖z_s‖=2.5632  ‖z_n‖=1.9410
Δrecon after z_n swap (mean |x−x'|) = 0.009666

=== Shared Encoder Cross-Domain Alignment ===
Signal latent             | CORAL=104.5470 | MMD=0.0622
Nuisance latent           | CORAL=0.5912 | MMD=0.0177
Full latent               | CORAL=118.1416 | MMD=0.0638
Reconstruction            | CORAL=6.7486 | MMD=0.0619
Stationarized recon       | CORAL=8.1889 | MMD=0.0959


Class-conditioned (shared encoder): 100%|██████████| 10/10 [00:00<00:00, 327.37it/s]


=== Class-Conditioned Alignment (Shared Encoder) ===
Class  0: CORAL=64.2534  MMD=0.5302
Class  1: CORAL=1.0718  MMD=0.7130
Class  2: CORAL=5.0532  MMD=0.2616
Class  3: CORAL=2.1416  MMD=0.1121
Class  4: CORAL=1.2950  MMD=0.3983
Class  5: CORAL=10.1082  MMD=0.2501
Class  6: CORAL=2.7917  MMD=0.7233
Class  7: CORAL=1.9773  MMD=0.6664
Class  8: CORAL=0.2154  MMD=0.0136
Class  9: CORAL=1.9150  MMD=0.7679
Avg: CORAL=9.0823  MMD=0.4436

---- USPS to MNIST ----






[Domain probe on signal] acc=83.45% | AUC=0.881
[Domain probe on nuisance] acc=72.25% | AUC=0.809
[MNIST] Var(z_s)=1.5588e-01  Var(z_n)=1.4027e-03  ‖z_s‖=2.7506  ‖z_n‖=1.7973
[USPS] Var(z_s)=4.7786e-01  Var(z_n)=1.5678e-03  ‖z_s‖=3.7494  ‖z_n‖=1.9402
Δrecon after z_n swap (mean |x−x'|) = 0.009551

=== Shared Encoder Cross-Domain Alignment ===
Signal latent             | CORAL=60.3255 | MMD=0.1466
Nuisance latent           | CORAL=0.0261 | MMD=0.0236
Full latent               | CORAL=60.4786 | MMD=0.1449
Reconstruction            | CORAL=13.1704 | MMD=0.1798
Stationarized recon       | CORAL=14.6490 | MMD=0.1841


Class-conditioned (shared encoder): 100%|██████████| 10/10 [00:00<00:00, 226.60it/s]


=== Class-Conditioned Alignment (Shared Encoder) ===
Class  0: CORAL=5.5724  MMD=0.5430
Class  1: CORAL=1.6035  MMD=0.4362
Class  2: CORAL=1.7592  MMD=0.1007
Class  3: CORAL=6.3075  MMD=0.6643
Class  4: CORAL=0.2450  MMD=0.0903
Class  5: CORAL=3.4456  MMD=0.2838
Class  6: CORAL=13.5263  MMD=1.1193
Class  7: CORAL=11.6882  MMD=1.0019
Class  8: CORAL=1.5705  MMD=0.4883
Class  9: CORAL=6.5217  MMD=1.1956
Avg: CORAL=5.2240  MMD=0.5923





### Option 1: Finetuning for reduce MMD and CORAL

In [34]:
print("\n -- Finetuning MNIST encoder-decoder -- \n")
finetune_entropy_detach_stabilized(
    encoder_M, decoder_M, probe_M,
    loader_mnist=loader_M,
    loader_usps=loader_U,
    device=device,
    lambda_cls=1.0, lambda_rec=0.5, lambda_ent=0.3,
    epochs=3, lr=1e-4
)
print("\n -- Finetuning USPS encoder-decoder -- \n")
finetune_entropy_detach_stabilized(
    encoder_U, decoder_M, probe_U,
    loader_mnist=loader_U,
    loader_usps=loader_M,
    device=device,
    lambda_cls=1.0, lambda_rec=0.5, lambda_ent=0.3,
    epochs=3, lr=1e-4
)


 -- Finetuning MNIST encoder-decoder -- 

[Finetune Detach+Stabilized 1/3] Loss=0.7726 Cls=0.1980 Rec=0.7058 Ent=0.6809 Zvar=0.0215 Swap=1.6549
[Finetune Detach+Stabilized 2/3] Loss=0.7602 Cls=0.1914 Rec=0.7014 Ent=0.6698 Zvar=0.0216 Swap=1.6509
[Finetune Detach+Stabilized 3/3] Loss=0.7136 Cls=0.1479 Rec=0.7040 Ent=0.6556 Zvar=0.0255 Swap=1.6440

 -- Finetuning USPS encoder-decoder -- 

[Finetune Detach+Stabilized 1/3] Loss=1.2898 Cls=0.5015 Rec=0.8974 Ent=1.0619 Zvar=0.0004 Swap=1.9725
[Finetune Detach+Stabilized 2/3] Loss=1.2517 Cls=0.4778 Rec=0.8772 Ent=1.0482 Zvar=0.0005 Swap=1.9613
[Finetune Detach+Stabilized 3/3] Loss=1.2165 Cls=0.4672 Rec=0.8727 Ent=0.9741 Zvar=0.0006 Swap=1.9487


In [35]:
print("== After finetune ==")
print("\n---- MNIST to USPS ----\n")
_ = compare_shared_encoder_alignment(encoder_M, decoder_M, loader_M, loader_U, device, visualize=False)
_ = class_conditioned_alignment_shared(
    encoder_M, decoder_M, loader_M, loader_U, device
)
print("\n---- USPS to MNIST ----\n")
_ = compare_shared_encoder_alignment(encoder_U, decoder_U, loader_M, loader_U, device, visualize=False)
_ = class_conditioned_alignment_shared(
    encoder_U, decoder_U, loader_U, loader_M, device
)

== After finetune ==

---- MNIST to USPS ----

[Domain probe on signal] acc=72.70% | AUC=0.837
[Domain probe on nuisance] acc=71.70% | AUC=0.787
[MNIST] Var(z_s)=1.1834e-01  Var(z_n)=3.1371e-03  ‖z_s‖=1.9345  ‖z_n‖=1.8681
[USPS] Var(z_s)=3.4796e-01  Var(z_n)=3.9864e-02  ‖z_s‖=2.4714  ‖z_n‖=2.1703
Δrecon after z_n swap (mean |x−x'|) = 0.009483

=== Shared Encoder Cross-Domain Alignment ===
Signal latent             | CORAL=42.6241 | MMD=0.0751
Nuisance latent           | CORAL=1.0995 | MMD=0.0286
Full latent               | CORAL=53.2866 | MMD=0.0753
Reconstruction            | CORAL=1.8759 | MMD=0.0782
Stationarized recon       | CORAL=3.9723 | MMD=0.1093


Class-conditioned (shared encoder): 100%|██████████| 10/10 [00:00<00:00, 319.80it/s]


=== Class-Conditioned Alignment (Shared Encoder) ===
Class  0: CORAL=29.0384  MMD=0.5740
Class  1: CORAL=0.9650  MMD=0.6700
Class  2: CORAL=2.9797  MMD=0.2327
Class  3: CORAL=1.3839  MMD=0.0890
Class  4: CORAL=1.5518  MMD=0.4134
Class  5: CORAL=4.9386  MMD=0.2273
Class  6: CORAL=1.7885  MMD=0.5697
Class  7: CORAL=2.8351  MMD=0.8257
Class  8: CORAL=0.1838  MMD=0.0208
Class  9: CORAL=3.0581  MMD=0.9442
Avg: CORAL=4.8723  MMD=0.4567

---- USPS to MNIST ----






[Domain probe on signal] acc=88.10% | AUC=0.916
[Domain probe on nuisance] acc=86.55% | AUC=0.926
[MNIST] Var(z_s)=1.9463e-01  Var(z_n)=1.5756e-03  ‖z_s‖=3.1379  ‖z_n‖=1.9779
[USPS] Var(z_s)=5.2235e-01  Var(z_n)=1.8634e-03  ‖z_s‖=3.9428  ‖z_n‖=2.1156
Δrecon after z_n swap (mean |x−x'|) = 0.008084

=== Shared Encoder Cross-Domain Alignment ===
Signal latent             | CORAL=62.5882 | MMD=0.1124
Nuisance latent           | CORAL=0.0303 | MMD=0.0276
Full latent               | CORAL=62.7885 | MMD=0.1112
Reconstruction            | CORAL=11.5004 | MMD=0.1626
Stationarized recon       | CORAL=13.0686 | MMD=0.1671


Class-conditioned (shared encoder): 100%|██████████| 10/10 [00:00<00:00, 334.23it/s]


=== Class-Conditioned Alignment (Shared Encoder) ===
Class  0: CORAL=6.2510  MMD=0.5841
Class  1: CORAL=1.7744  MMD=0.3310
Class  2: CORAL=1.6647  MMD=0.1154
Class  3: CORAL=7.7055  MMD=0.6335
Class  4: CORAL=0.3403  MMD=0.1292
Class  5: CORAL=3.8650  MMD=0.3237
Class  6: CORAL=12.4440  MMD=1.0971
Class  7: CORAL=13.0117  MMD=0.9348
Class  8: CORAL=2.1939  MMD=0.5446
Class  9: CORAL=7.2703  MMD=1.1227
Avg: CORAL=5.6521  MMD=0.5816





### Option 2: Finetuning for domain probe invariance

In [66]:
import torch
from data.loader import get_dataloader
from models.models import SplitEncoder, SplitDecoder, LinearProbe, DomainProbe
from utils.aligner import finetune_entropy, finetune_entropy_detach_usps_contrastive
from experiments.compared_shared_encoder import compare_shared_encoder_alignment

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

input_dim = 784
output_dim = 784
latent_dim = 64
signal_dim = 32
num_classes = 10

encoder = SplitEncoder(input_dim=input_dim, latent_dim=latent_dim, signal_dim=signal_dim).to(device)
decoder = SplitDecoder(latent_dim=latent_dim, output_dim=output_dim).to(device)
probe = LinearProbe()

# --- load your pretrained MNIST model ---
ckpt = torch.load("artifacts/mnist/mnist_pretrained_usage_swap_asym.pt", map_location=device)
encoder.load_state_dict(ckpt["encoder"])
decoder.load_state_dict(ckpt["decoder"])
probe.load_state_dict(ckpt["probe"])

loader_M = get_dataloader("mnist", batch_size=256, train=False)
loader_U = get_dataloader("usps", batch_size=256, train=False)

print("== Before finetune ==")
_ = compare_shared_encoder_alignment(encoder, decoder, loader_M, loader_U, device, visualize=False)

# --- short finetune ---
finetune_entropy(
    encoder, decoder, probe,
    loader_mnist=loader_M,
    loader_usps=loader_U,
    device=device,
    lambda_cls=1.0, lambda_rec=0.5, lambda_ent=0.3,
    epochs=3, lr=1e-4
)
domain_probe = DomainProbe(in_dim=32).to(device)

finetune_entropy_detach_usps_contrastive(
    encoder, decoder, probe,
    loader_M, loader_M,
    device,
    domain_probe=domain_probe,
    lambda_cls=1.0, lambda_rec=0.5, lambda_ent=0.3, lambda_dom=0.2,
    epochs=3, lr=1e-4
)

print("== After finetune ==")
_ = compare_shared_encoder_alignment(encoder, decoder, loader_M, loader_U, device, visualize=False)

== Before finetune ==
[Domain probe on signal] acc=65.40% | AUC=0.743
[Domain probe on nuisance] acc=63.35% | AUC=0.703
[MNIST] Var(z_s)=9.6935e-02  Var(z_n)=2.6360e-03  ‖z_s‖=1.7153  ‖z_n‖=1.7984
[USPS] Var(z_s)=4.5405e-01  Var(z_n)=2.9220e-02  ‖z_s‖=2.5632  ‖z_n‖=1.9410
Δrecon after z_n swap (mean |x−x'|) = 0.009677

=== Shared Encoder Cross-Domain Alignment ===
Signal latent             | CORAL=104.5470 | MMD=0.0622
Nuisance latent           | CORAL=0.5912 | MMD=0.0177
Full latent               | CORAL=118.1416 | MMD=0.0638
Reconstruction            | CORAL=6.7486 | MMD=0.0619
Stationarized recon       | CORAL=8.1889 | MMD=0.0959
[Finetune 1/3] Loss=0.7552 Cls=0.1980 Rec=0.7059 Ent=0.6808
[Finetune 2/3] Loss=0.7428 Cls=0.1914 Rec=0.7016 Ent=0.6686
[Finetune 3/3] Loss=0.6952 Cls=0.1480 Rec=0.7043 Ent=0.6502
[Finetune Detach+Contrastive 1/3] Loss=9.8578 Cls=2.2220 Rec=9.7483 Ent=1.9595 Dom=10.8686
[Finetune Detach+Contrastive 2/3] Loss=9.5736 Cls=2.1012 Rec=9.5179 Ent=1.8211 Dom=10.83

In [67]:
import torch
from data.loader import get_dataloader
from models.models import SplitEncoder, SplitDecoder, LinearProbe, DomainProbe
from utils.aligner import finetune_entropy, finetune_entropy_detach_usps_contrastive
from experiments.compared_shared_encoder import compare_shared_encoder_alignment

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

input_dim = 784
output_dim = 784
latent_dim = 64
signal_dim = 32
num_classes = 10

encoder = SplitEncoder(input_dim=input_dim, latent_dim=latent_dim, signal_dim=signal_dim).to(device)
decoder = SplitDecoder(latent_dim=latent_dim, output_dim=output_dim).to(device)
probe = LinearProbe()

# --- load your pretrained MNIST model ---
ckpt = torch.load("artifacts/mnist/mnist_pretrained_usage_swap_asym.pt", map_location=device)
encoder.load_state_dict(ckpt["encoder"])
decoder.load_state_dict(ckpt["decoder"])
probe.load_state_dict(ckpt["probe"])

loader_M = get_dataloader("mnist", batch_size=256, train=False)
loader_U = get_dataloader("usps", batch_size=256, train=False)

print("== Before finetune ==")
_ = compare_shared_encoder_alignment(encoder, decoder, loader_M, loader_U, device, visualize=False)

# --- short finetune ---
finetune_entropy_detach_stabilized(
    encoder, decoder, probe,
    loader_mnist=loader_M,
    loader_usps=loader_U,
    device=device,
    lambda_cls=1.0, lambda_rec=0.5, lambda_ent=0.3,
    epochs=3, lr=1e-4
)
domain_probe = DomainProbe(in_dim=32).to(device)

finetune_entropy_detach_usps_contrastive(
    encoder, decoder, probe,
    loader_M, loader_M,
    device,
    domain_probe=domain_probe,
    lambda_cls=1.0, lambda_rec=0.5, lambda_ent=0.3, lambda_dom=0.2,
    epochs=3, lr=1e-4
)

print("== After finetune ==")
_ = compare_shared_encoder_alignment(encoder, decoder, loader_M, loader_U, device, visualize=False)

== Before finetune ==
[Domain probe on signal] acc=65.40% | AUC=0.743
[Domain probe on nuisance] acc=63.35% | AUC=0.703
[MNIST] Var(z_s)=9.6935e-02  Var(z_n)=2.6360e-03  ‖z_s‖=1.7153  ‖z_n‖=1.7984
[USPS] Var(z_s)=4.5405e-01  Var(z_n)=2.9220e-02  ‖z_s‖=2.5632  ‖z_n‖=1.9410
Δrecon after z_n swap (mean |x−x'|) = 0.009454

=== Shared Encoder Cross-Domain Alignment ===
Signal latent             | CORAL=104.5470 | MMD=0.0622
Nuisance latent           | CORAL=0.5912 | MMD=0.0177
Full latent               | CORAL=118.1416 | MMD=0.0638
Reconstruction            | CORAL=6.7486 | MMD=0.0619
Stationarized recon       | CORAL=8.1889 | MMD=0.0959
[Finetune Detach+Stabilized 1/3] Loss=0.7726 Cls=0.1980 Rec=0.7058 Ent=0.6809 Zvar=0.0215 Swap=1.6549
[Finetune Detach+Stabilized 2/3] Loss=0.7602 Cls=0.1914 Rec=0.7014 Ent=0.6698 Zvar=0.0216 Swap=1.6509
[Finetune Detach+Stabilized 3/3] Loss=0.7136 Cls=0.1479 Rec=0.7040 Ent=0.6556 Zvar=0.0255 Swap=1.6440
[Finetune Detach+Contrastive 1/3] Loss=9.8507 Cls=2.2