In [7]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn
from semisupervised import SemiSupervisedAutoEncoderOptions, SemiSupervisedAdversarialAutoencoder

In [8]:
def configure_mnist(batch_size=100, root='./data'):
    """Load MNIST as flattened [0, 1] tensors plus a shuffled training DataLoader.

    Args:
        batch_size: Mini-batch size of the returned training DataLoader.
        root: Directory where torchvision stores / downloads MNIST.

    Returns:
        (X_train, X_test, Y_train, Y_test, train_loader):
        X_* are float tensors of shape (N, 784) with pixels scaled to [0, 1];
        Y_* are copies of the datasets' label tensors; train_loader yields
        (image, label) batches from the training split, shuffled, using the
        same preprocessing as X_train.
    """
    # ToTensor scales uint8 pixels to [0, 1] and returns (1, 28, 28);
    # nn.Flatten(start_dim=0) turns that into a 784-vector. A module instead of
    # a lambda keeps the transform picklable, so the DataLoader also works with
    # num_workers > 0 under the "spawn" start method (Windows/macOS default).
    transform = transforms.Compose([
        transforms.ToTensor(),
        nn.Flatten(start_dim=0),
    ])

    train_dataset = datasets.MNIST(root=root, train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root=root, train=False, transform=transform, download=True)

    def _flat_unit_interval(dataset):
        # Vectorised equivalent of applying `transform` to every sample:
        # ToTensor casts uint8 to the default float dtype and divides by 255.
        # Avoids 70k per-sample PIL round trips.
        images = dataset.data.to(torch.get_default_dtype()).div(255)
        return images.reshape(len(dataset), -1)

    X_train = _flat_unit_interval(train_dataset)
    X_test = _flat_unit_interval(test_dataset)

    # Cheap guard that the vectorised tensors are exactly what the DataLoader
    # serves; fails loudly if `transform` is changed (e.g. Normalize added)
    # without updating _flat_unit_interval.
    assert torch.equal(X_train[0], train_dataset[0][0]), "X_train diverges from the DataLoader transform"
    assert torch.equal(X_test[0], test_dataset[0][0]), "X_test diverges from the dataset transform"

    Y_train = train_dataset.targets.clone()
    Y_test = test_dataset.targets.clone()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
    )

    return X_train, X_test, Y_train, Y_test, train_loader

In [9]:
# Uses configure_mnist's default batch_size (100), which currently equals
# BATCH_SIZE from the config cell below; that constant is not defined yet here.
X_train, X_test, Y_train, Y_test, train_loader = configure_mnist()

# Enforce the label range instead of eyeballing printed tensors: MNIST labels
# should be the digits 0-9 (LATENT_DIM_CAT below is 10, presumably one
# categorical unit per class — confirm against the model).
assert int(Y_train.min()) == 0 and int(Y_train.max()) == 9, "unexpected train label range"
assert int(Y_test.min()) == 0 and int(Y_test.max()) == 9, "unexpected test label range"

int(Y_train.min()), int(Y_train.max())

tensor(9)
tensor(0)


In [10]:
# --- Data / architecture -----------------------------------------------------
INPUT_DIM = 784          # flattened 28 x 28 MNIST image length (see configure_mnist)
# NOTE(review): BATCH_SIZE is not used anywhere in this notebook; train_loader was
# built with configure_mnist's default batch_size (also 100). Keep them in sync.
BATCH_SIZE = 100
AE_HIDDEN = 1000         # passed as ae_hidden_dim
DC_HIDDEN = 1000         # passed as disc_hidden_dim
LATENT_DIM_CAT = 10      # passed as latent_dim_categorical
LATENT_DIM_STYLE = 15    # passed as latent_dim_style
PRIOR_STD = 5.0          # passed to model.train(prior_std=...)

# --- Losses and initial learning rates ---------------------------------------
recon_loss = nn.MSELoss()
init_recon_lr = 0.01

semi_sup_loss = nn.CrossEntropyLoss()
init_semi_sup_lr = 0.01

init_gen_lr = init_disc_lr = 0.1
use_decoder_sigmoid = True

# --- Reproducibility -----------------------------------------------------------
# Seed the global RNGs before the model is constructed and train_loader is first
# iterated: weight init and the DataLoader's shuffle order (RandomSampler draws
# its seed from torch's global RNG when no generator is given) then repeat
# across runs. Any sampling inside `semisupervised` presumably uses these too.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [11]:
# Bundle the config-cell hyper-parameters into the options object, grouped by
# what each keyword configures (per the parameter names), then build the model.
options = SemiSupervisedAutoEncoderOptions(
    # architecture
    input_dim=INPUT_DIM,
    ae_hidden_dim=AE_HIDDEN,
    disc_hidden_dim=DC_HIDDEN,
    latent_dim_categorical=LATENT_DIM_CAT,
    latent_dim_style=LATENT_DIM_STYLE,
    use_decoder_sigmoid=use_decoder_sigmoid,
    # reconstruction loss and its learning rate
    recon_loss_fn=recon_loss,
    init_recon_lr=init_recon_lr,
    # supervised (labelled) loss and its learning rate
    semi_supervised_loss_fn=semi_sup_loss,
    init_semi_sup_lr=init_semi_sup_lr,
    # generator and categorical / style discriminator learning rates
    init_gen_lr=init_gen_lr,
    init_disc_categorical_lr=init_disc_lr,
    init_disc_style_lr=init_disc_lr,
)

model = SemiSupervisedAdversarialAutoencoder(options)

In [13]:
# Train on the shuffled MNIST training loader for 50 epochs, with prior
# standard deviation PRIOR_STD.
#
# NOTE(review): the saved run of this cell raised
#   IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 10)
# i.e. some op inside `semisupervised` received dim=10 on a 2-D tensor. 10 equals
# LATENT_DIM_CAT, so presumably the categorical size is being passed where a
# dimension index is expected (e.g. softmax/argmax/cat/one_hot) — TODO confirm in
# semisupervised.py; it cannot be fixed from this notebook.
# NOTE(review): this cell ran as In [13] right after In [11]; re-run from a fresh
# kernel to rule out hidden state from a deleted In [12].
# NOTE(review): if SemiSupervisedAdversarialAutoencoder subclasses nn.Module, this
# `train` shadows nn.Module.train(mode) — verify.
model.train(prior_std=PRIOR_STD, epochs=50, data_loader=train_loader)

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 10)