In [1]:
!export CUDA_VISIBLE_DEVICES=0

In [2]:
import os

# CUDA_VISIBLE_DEVICES must be in the environment before CUDA is initialized,
# so set it before importing torch rather than after.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch

# Quick report of what the kernel can see on the GPU side.
print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("Built with CUDA:", torch.version.cuda)
print("CUDA available?:", torch.cuda.is_available())
print("Device count:", torch.cuda.device_count())

CUDA_VISIBLE_DEVICES: 0
Built with CUDA: 12.4
CUDA available?: True
Device count: 1


In [3]:
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn
from semisupervised import SemiSupervisedAutoEncoderOptions, SemiSupervisedAdversarialAutoencoder

In [4]:
from torch.utils.data import random_split

def _stack_split(ds):
    """Materialize a dataset of ``(image_tensor, label)`` pairs as ``(X, Y)`` tensors.

    Walks ``ds`` exactly once, so the dataset's per-item transform runs once per
    image. ``X`` stacks the image tensors along a new first dimension and ``Y``
    holds the integer labels.
    """
    images, labels = [], []
    for img, label in ds:
        images.append(img)
        labels.append(label)
    return torch.stack(images), torch.tensor(labels)


def configure_mnist(batch_size=100, val_size=10000, seed=0, root="./data"):
    """Load MNIST, carve out a validation split, and build tensors plus DataLoaders.

    The official MNIST training set is split into a training part of
    ``len(train_set) - val_size`` images and a validation part of ``val_size``
    images. The official test set is kept whole. Every image goes through
    ``ToTensor`` and is flattened to a 1-D vector.

    Args:
        batch_size: Batch size used by all three DataLoaders.
        val_size: Number of training images moved into the validation split.
            Must satisfy ``0 < val_size < len(train_set)``.
        seed: Seed for the train/val split and for the training loader's
            shuffle order, so re-running the notebook sees the same validation
            set. Pass ``None`` to use torch's global RNG instead.
        root: Directory where torchvision stores (and downloads) MNIST.

    Returns:
        ``(X_train, X_val, X_test, Y_train, Y_val, Y_test,
        train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: If ``val_size`` is not strictly between 0 and the size of
            the training set.
    """
    # transform: ToTensor + flatten each image to a 1-D vector
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(-1))
    ])

    # full train + test datasets
    full_train = datasets.MNIST(root=root, train=True, download=True, transform=transform)
    test_ds = datasets.MNIST(root=root, train=False, download=True, transform=transform)

    # Both splits must be non-empty: an empty one makes torch.stack fail below.
    n_full = len(full_train)
    if not 0 < val_size < n_full:
        raise ValueError(f"val_size must be in (0, {n_full}), got {val_size}")
    train_size = n_full - val_size

    # Seeded generators make the split and the training shuffle reproducible.
    if seed is None:
        train_ds, val_ds = random_split(full_train, [train_size, val_size])
        shuffle_gen = None
    else:
        train_ds, val_ds = random_split(
            full_train, [train_size, val_size],
            generator=torch.Generator().manual_seed(seed),
        )
        shuffle_gen = torch.Generator().manual_seed(seed)

    # DataLoaders
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, generator=shuffle_gen)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

    # Raw tensors, each split traversed once (images and labels together).
    X_train, Y_train = _stack_split(train_ds)
    X_val, Y_val = _stack_split(val_ds)
    X_test, Y_test = _stack_split(test_ds)

    return (X_train, X_val, X_test,
            Y_train, Y_val, Y_test,
            train_loader, val_loader, test_loader)

In [6]:
(X_train, X_val, X_test, Y_train, Y_val, Y_test, train_loader, val_loader, test_loader) = configure_mnist()

# Sanity-check every split before training: one label per flattened image,
# labels restricted to the digits 0-9, and a common feature width.
for split_name, feats, labels in (
    ("train", X_train, Y_train),
    ("val", X_val, Y_val),
    ("test", X_test, Y_test),
):
    assert feats.ndim == 2 and feats.shape[0] == labels.shape[0], f"{split_name}: X/Y length mismatch"
    assert labels.min().item() >= 0 and labels.max().item() <= 9, f"{split_name}: labels outside 0-9"
assert X_train.shape[1] == X_val.shape[1] == X_test.shape[1], "feature width differs between splits"

print(Y_train.max())
print(Y_train.min())

tensor(9)
tensor(0)


In [7]:
# --- Architecture ---------------------------------------------------------
INPUT_DIM = 784          # flattened MNIST image length (28 * 28)
BATCH_SIZE = 100         # NOTE(review): not passed anywhere; configure_mnist() uses its own default of 100 — keep in sync
AE_HIDDEN = 1000         # autoencoder hidden width
DC_HIDDEN = 1000         # discriminator hidden width
LATENT_DIM_CAT = 10      # categorical latent size; presumably one unit per digit class (labels are 0-9)
LATENT_DIM_STYLE = 15    # continuous "style" latent size
PRIOR_STD = 5.0          # passed to train_mbgd as prior_std

# --- Reproducibility ------------------------------------------------------
# Seed before the model is built (next cell) so weight init and training
# randomness repeat across Restart & Run All. torch.manual_seed seeds all devices.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# --- Losses and learning rates --------------------------------------------
recon_loss = nn.MSELoss()
init_recon_lr = 0.01

semi_sup_loss = nn.CrossEntropyLoss()
init_semi_sup_lr = 0.01

# Generator and both discriminators share one initial learning rate.
init_gen_lr = init_disc_lr = 0.1
use_decoder_sigmoid = True

In [8]:
# Pack the config-cell hyperparameters into the options object, grouped by the
# component they configure, then build the model from it.
options = SemiSupervisedAutoEncoderOptions(
    # network shape
    input_dim=INPUT_DIM,
    ae_hidden_dim=AE_HIDDEN,
    disc_hidden_dim=DC_HIDDEN,
    latent_dim_categorical=LATENT_DIM_CAT,
    latent_dim_style=LATENT_DIM_STYLE,
    use_decoder_sigmoid=use_decoder_sigmoid,
    # reconstruction objective
    recon_loss_fn=recon_loss,
    init_recon_lr=init_recon_lr,
    # supervised classification objective
    semi_supervised_loss_fn=semi_sup_loss,
    init_semi_sup_lr=init_semi_sup_lr,
    # adversarial generator / discriminators
    init_gen_lr=init_gen_lr,
    init_disc_categorical_lr=init_disc_lr,
    init_disc_style_lr=init_disc_lr,
)

model = SemiSupervisedAdversarialAutoencoder(options)

In [9]:
# Train for a fixed number of epochs; train_mbgd logs per-epoch losses and
# validation accuracy (see the output below).
# NOTE(review): train_loader yields a label for every training image. Unless
# train_mbgd hides most of them internally, the classifier is effectively
# trained fully supervised — confirm against the semisupervised module.
N_EPOCHS = 50
model.train_mbgd(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=N_EPOCHS,
    prior_std=PRIOR_STD,
)

Epoch [1/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 1/50 — Recon: 0.2298, Disc_Cat: 1.3773, Gen_Cat: 0.6969, Disc_Style: 0.3763, Gen_Style: 2.9231, SemiSup: 1.0498
Validation Accuracy: 84.99%



Epoch [2/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 2/50 — Recon: 0.2241, Disc_Cat: 1.3398, Gen_Cat: 0.7190, Disc_Style: 0.7260, Gen_Style: 2.6286, SemiSup: 0.3729
Validation Accuracy: 89.80%



Epoch [3/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 3/50 — Recon: 0.1938, Disc_Cat: 1.3066, Gen_Cat: 0.7924, Disc_Style: 0.9905, Gen_Style: 1.8502, SemiSup: 0.3181
Validation Accuracy: 92.51%



Epoch [4/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 4/50 — Recon: 0.0972, Disc_Cat: 1.3283, Gen_Cat: 0.8337, Disc_Style: 1.1195, Gen_Style: 1.4546, SemiSup: 0.2868
Validation Accuracy: 93.65%



Epoch [5/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 5/50 — Recon: 0.0731, Disc_Cat: 1.3419, Gen_Cat: 0.8188, Disc_Style: 1.2035, Gen_Style: 1.1979, SemiSup: 0.2480
Validation Accuracy: 94.95%



Epoch [6/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 6/50 — Recon: 0.0709, Disc_Cat: 1.3464, Gen_Cat: 0.8106, Disc_Style: 1.2353, Gen_Style: 1.1062, SemiSup: 0.2134
Validation Accuracy: 95.20%



Epoch [7/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 7/50 — Recon: 0.0699, Disc_Cat: 1.3542, Gen_Cat: 0.7944, Disc_Style: 1.2620, Gen_Style: 1.0226, SemiSup: 0.1902
Validation Accuracy: 95.58%



Epoch [8/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 8/50 — Recon: 0.0696, Disc_Cat: 1.3578, Gen_Cat: 0.7852, Disc_Style: 1.2902, Gen_Style: 0.9562, SemiSup: 0.1572
Validation Accuracy: 95.93%



Epoch [9/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 9/50 — Recon: 0.0694, Disc_Cat: 1.3609, Gen_Cat: 0.7771, Disc_Style: 1.2983, Gen_Style: 0.9291, SemiSup: 0.1408
Validation Accuracy: 96.63%



Epoch [10/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 10/50 — Recon: 0.0692, Disc_Cat: 1.3648, Gen_Cat: 0.7675, Disc_Style: 1.3133, Gen_Style: 0.9082, SemiSup: 0.1261
Validation Accuracy: 96.26%



Epoch [11/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 11/50 — Recon: 0.0689, Disc_Cat: 1.3652, Gen_Cat: 0.7630, Disc_Style: 1.3309, Gen_Style: 0.8582, SemiSup: 0.1102
Validation Accuracy: 96.51%



Epoch [12/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 12/50 — Recon: 0.0688, Disc_Cat: 1.3681, Gen_Cat: 0.7550, Disc_Style: 1.3353, Gen_Style: 0.8350, SemiSup: 0.1052
Validation Accuracy: 97.14%



Epoch [13/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 13/50 — Recon: 0.0685, Disc_Cat: 1.3690, Gen_Cat: 0.7529, Disc_Style: 1.3368, Gen_Style: 0.8429, SemiSup: 0.0974
Validation Accuracy: 97.08%



Epoch [14/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 14/50 — Recon: 0.0684, Disc_Cat: 1.3706, Gen_Cat: 0.7482, Disc_Style: 1.3431, Gen_Style: 0.8187, SemiSup: 0.0872
Validation Accuracy: 97.21%



Epoch [15/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 15/50 — Recon: 0.0684, Disc_Cat: 1.3737, Gen_Cat: 0.7399, Disc_Style: 1.3505, Gen_Style: 0.8013, SemiSup: 0.0750
Validation Accuracy: 97.31%



Epoch [16/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 16/50 — Recon: 0.0681, Disc_Cat: 1.3732, Gen_Cat: 0.7378, Disc_Style: 1.3480, Gen_Style: 0.8010, SemiSup: 0.0703
Validation Accuracy: 97.49%



Epoch [17/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 17/50 — Recon: 0.0679, Disc_Cat: 1.3730, Gen_Cat: 0.7373, Disc_Style: 1.3524, Gen_Style: 0.7966, SemiSup: 0.0659
Validation Accuracy: 97.62%



Epoch [18/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 18/50 — Recon: 0.0680, Disc_Cat: 1.3757, Gen_Cat: 0.7317, Disc_Style: 1.3591, Gen_Style: 0.7807, SemiSup: 0.0555
Validation Accuracy: 97.68%



Epoch [19/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 19/50 — Recon: 0.0678, Disc_Cat: 1.3766, Gen_Cat: 0.7285, Disc_Style: 1.3634, Gen_Style: 0.7665, SemiSup: 0.0543
Validation Accuracy: 97.45%



Epoch [20/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 20/50 — Recon: 0.0677, Disc_Cat: 1.3763, Gen_Cat: 0.7270, Disc_Style: 1.3638, Gen_Style: 0.7614, SemiSup: 0.0479
Validation Accuracy: 97.71%



Epoch [21/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 21/50 — Recon: 0.0673, Disc_Cat: 1.3773, Gen_Cat: 0.7250, Disc_Style: 1.3661, Gen_Style: 0.7559, SemiSup: 0.0418
Validation Accuracy: 97.58%



Epoch [22/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 22/50 — Recon: 0.0671, Disc_Cat: 1.3779, Gen_Cat: 0.7229, Disc_Style: 1.3640, Gen_Style: 0.7602, SemiSup: 0.0425
Validation Accuracy: 97.71%



Epoch [23/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 23/50 — Recon: 0.0669, Disc_Cat: 1.3790, Gen_Cat: 0.7190, Disc_Style: 1.3700, Gen_Style: 0.7533, SemiSup: 0.0322
Validation Accuracy: 97.87%



Epoch [24/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 24/50 — Recon: 0.0672, Disc_Cat: 1.3794, Gen_Cat: 0.7177, Disc_Style: 1.3642, Gen_Style: 0.7611, SemiSup: 0.0291
Validation Accuracy: 97.95%



Epoch [25/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 25/50 — Recon: 0.0668, Disc_Cat: 1.3808, Gen_Cat: 0.7150, Disc_Style: 1.3685, Gen_Style: 0.7523, SemiSup: 0.0247
Validation Accuracy: 97.97%



Epoch [26/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 26/50 — Recon: 0.0666, Disc_Cat: 1.3810, Gen_Cat: 0.7133, Disc_Style: 1.3720, Gen_Style: 0.7428, SemiSup: 0.0262
Validation Accuracy: 97.86%



Epoch [27/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 27/50 — Recon: 0.0666, Disc_Cat: 1.3816, Gen_Cat: 0.7121, Disc_Style: 1.3730, Gen_Style: 0.7374, SemiSup: 0.0218
Validation Accuracy: 97.94%



Epoch [28/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 28/50 — Recon: 0.0662, Disc_Cat: 1.3814, Gen_Cat: 0.7113, Disc_Style: 1.3730, Gen_Style: 0.7388, SemiSup: 0.0212
Validation Accuracy: 97.98%



Epoch [29/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 29/50 — Recon: 0.0659, Disc_Cat: 1.3823, Gen_Cat: 0.7084, Disc_Style: 1.3734, Gen_Style: 0.7341, SemiSup: 0.0151
Validation Accuracy: 97.99%



Epoch [30/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 30/50 — Recon: 0.0656, Disc_Cat: 1.3832, Gen_Cat: 0.7073, Disc_Style: 1.3761, Gen_Style: 0.7300, SemiSup: 0.0155
Validation Accuracy: 98.11%



Epoch [31/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 31/50 — Recon: 0.0657, Disc_Cat: 1.3835, Gen_Cat: 0.7062, Disc_Style: 1.3766, Gen_Style: 0.7278, SemiSup: 0.0119
Validation Accuracy: 98.00%



Epoch [32/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 32/50 — Recon: 0.0654, Disc_Cat: 1.3833, Gen_Cat: 0.7059, Disc_Style: 1.3770, Gen_Style: 0.7271, SemiSup: 0.0090
Validation Accuracy: 98.07%



Epoch [33/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 33/50 — Recon: 0.0654, Disc_Cat: 1.3838, Gen_Cat: 0.7040, Disc_Style: 1.3759, Gen_Style: 0.7273, SemiSup: 0.0085
Validation Accuracy: 98.02%



Epoch [34/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 34/50 — Recon: 0.0651, Disc_Cat: 1.3840, Gen_Cat: 0.7036, Disc_Style: 1.3762, Gen_Style: 0.7295, SemiSup: 0.0070
Validation Accuracy: 97.99%



Epoch [35/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 35/50 — Recon: 0.0646, Disc_Cat: 1.3842, Gen_Cat: 0.7031, Disc_Style: 1.3754, Gen_Style: 0.7310, SemiSup: 0.0059
Validation Accuracy: 97.96%



Epoch [36/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 36/50 — Recon: 0.0648, Disc_Cat: 1.3847, Gen_Cat: 0.7020, Disc_Style: 1.3786, Gen_Style: 0.7238, SemiSup: 0.0053
Validation Accuracy: 98.17%



Epoch [37/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 37/50 — Recon: 0.0645, Disc_Cat: 1.3843, Gen_Cat: 0.7018, Disc_Style: 1.3787, Gen_Style: 0.7204, SemiSup: 0.0041
Validation Accuracy: 98.13%



Epoch [38/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 38/50 — Recon: 0.0640, Disc_Cat: 1.3849, Gen_Cat: 0.7007, Disc_Style: 1.3783, Gen_Style: 0.7216, SemiSup: 0.0032
Validation Accuracy: 98.12%



Epoch [39/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 39/50 — Recon: 0.0644, Disc_Cat: 1.3845, Gen_Cat: 0.7006, Disc_Style: 1.3790, Gen_Style: 0.7215, SemiSup: 0.0027
Validation Accuracy: 98.14%



Epoch [40/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 40/50 — Recon: 0.0642, Disc_Cat: 1.3852, Gen_Cat: 0.6997, Disc_Style: 1.3782, Gen_Style: 0.7199, SemiSup: 0.0026
Validation Accuracy: 98.05%



Epoch [41/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 41/50 — Recon: 0.0634, Disc_Cat: 1.3851, Gen_Cat: 0.6994, Disc_Style: 1.3793, Gen_Style: 0.7199, SemiSup: 0.0021
Validation Accuracy: 98.07%



Epoch [42/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 42/50 — Recon: 0.0626, Disc_Cat: 1.3850, Gen_Cat: 0.6993, Disc_Style: 1.3803, Gen_Style: 0.7152, SemiSup: 0.0017
Validation Accuracy: 97.97%



Epoch [43/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 43/50 — Recon: 0.0628, Disc_Cat: 1.3854, Gen_Cat: 0.6985, Disc_Style: 1.3803, Gen_Style: 0.7171, SemiSup: 0.0014
Validation Accuracy: 98.01%



Epoch [44/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 44/50 — Recon: 0.0627, Disc_Cat: 1.3852, Gen_Cat: 0.6986, Disc_Style: 1.3812, Gen_Style: 0.7102, SemiSup: 0.0015
Validation Accuracy: 98.03%



Epoch [45/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 45/50 — Recon: 0.0628, Disc_Cat: 1.3854, Gen_Cat: 0.6982, Disc_Style: 1.3801, Gen_Style: 0.7166, SemiSup: 0.0010
Validation Accuracy: 97.86%



Epoch [46/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 46/50 — Recon: 0.0633, Disc_Cat: 1.3852, Gen_Cat: 0.6978, Disc_Style: 1.3817, Gen_Style: 0.7146, SemiSup: 0.0010
Validation Accuracy: 98.04%



Epoch [47/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 47/50 — Recon: 0.0625, Disc_Cat: 1.3854, Gen_Cat: 0.6976, Disc_Style: 1.3812, Gen_Style: 0.7102, SemiSup: 0.0007
Validation Accuracy: 98.03%



Epoch [48/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 48/50 — Recon: 0.0623, Disc_Cat: 1.3855, Gen_Cat: 0.6974, Disc_Style: 1.3799, Gen_Style: 0.7140, SemiSup: 0.0010
Validation Accuracy: 97.94%



Epoch [49/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 49/50 — Recon: 0.0615, Disc_Cat: 1.3855, Gen_Cat: 0.6973, Disc_Style: 1.3811, Gen_Style: 0.7137, SemiSup: 0.0008
Validation Accuracy: 98.01%



Epoch [50/50]:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 50/50 — Recon: 0.0609, Disc_Cat: 1.3855, Gen_Cat: 0.6969, Disc_Style: 1.3801, Gen_Style: 0.7163, SemiSup: 0.0008
Validation Accuracy: 98.08%



In [11]:
# Run the trained model over the whole test set and gather outputs on the CPU.
# Evaluation never needs gradients, so autograd is disabled to save memory and time.
prob_batches, pred_batches = [], []
with torch.no_grad():
    for imgs, _ in test_loader:
        probs, preds = model.predict(imgs)
        prob_batches.append(probs.cpu())
        pred_batches.append(preds.cpu())

# Concatenate per-batch results; order follows test_loader (shuffle=False).
all_probs = torch.cat(prob_batches, dim=0)
all_preds = torch.cat(pred_batches, dim=0)

In [12]:
# Fraction of test images whose predicted class equals the ground-truth label.
# all_preds was collected from test_loader, which does not shuffle, so it is
# aligned with Y_test. The shape check stops silent broadcasting, e.g. (N, 1)
# against (N,), which would otherwise yield a meaningless count.
assert all_preds.shape == Y_test.shape, (
    f"prediction/label shape mismatch: {tuple(all_preds.shape)} vs {tuple(Y_test.shape)}"
)
num_correct = torch.eq(all_preds, Y_test).sum().item()
accuracy = num_correct / Y_test.size(0)
print(f"Test accuracy: {accuracy*100:.2f}%")

Test accuracy: 98.19%
