In [None]:
from variational_autoencoder import VariationalAutoEncoder
from verification_net import VerificationNet
from model_trainer import ModelTrainer
from stacked_mnist import StackedMNIST, DataMode

import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

from pathlib import Path

import matplotlib.pyplot as plt

In [None]:
# Compute device: prefer Apple's MPS backend, but fall back to CPU so the
# notebook still runs on machines without MPS support.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
if device.type == "mps":
    # Per the PyTorch docs, a fraction of 0 removes the MPS allocator's
    # per-process memory cap (unlimited allocations). Only valid on MPS.
    torch.mps.set_per_process_memory_fraction(0.)

# Saved trainer checkpoint (.pkl) and the model file the trainer is given.
trainer_file = Path("trainers/vae-basic.pkl")
model_file = Path("models/vae-basic")

In [None]:
class VariationalAutoEncoderTrainer(ModelTrainer):
    """ModelTrainer specialised for the VAE.

    The only VAE-specific logic is `get_output_from_batch`, which unpacks a
    data-loader batch and the model's ``((mu, log_var), x_hat)`` output.
    """

    def __init__(
        self,
        model,
        loss,
        optimizer,
        device = torch.device("mps"),
        file_name: str | Path = model_file,
        force_learn: bool = False
    ) -> None:
        # Hand everything to ModelTrainer positionally, in its parameter order.
        forwarded = (model, loss, optimizer, device, file_name, force_learn)
        super().__init__(*forwarded)

    def get_output_from_batch(self, batch):
        """Forward one batch through the model.

        Args:
            batch: must unpack into exactly three items; only the first (the
                model input) is used. NOTE(review): presumably
                (images, ?, labels) as yielded by the StackedMNIST loaders.

        Returns:
            ``((x_hat, x), (mu, log_var))`` where ``x`` is the input moved to
            ``self.device``. This matches the ``(X, params)`` arguments of the
            ``loss`` function defined in this notebook.
        """
        inputs, _second, _labels = batch
        inputs = inputs.to(self.device)
        encoder_stats, reconstruction = self.model(inputs)
        mu, log_var = encoder_stats
        return (reconstruction, inputs), (mu, log_var)

In [None]:
# Latent size handed to VariationalAutoEncoder(latent_space_size=...).
latent_space_size: int = 64

In [None]:
# Instantiate the VAE with the configured latent size.
# NOTE(review): `Trainer` is later replaced via `Trainer.load_trainer(...)`;
# confirm whether the loaded trainer still uses this `VAE` instance.
VAE = VariationalAutoEncoder(latent_space_size=latent_space_size)

In [None]:
def loss(X, params, kl_weight: float = 0.02):
    """VAE objective: reconstruction BCE plus a weighted KL-divergence term.

    Args:
        X: pair ``(x_hat, x)``, i.e. reconstruction and target. ``x_hat`` must
            lie in [0, 1], as required by ``F.binary_cross_entropy``.
        params: pair ``(mu, log_var)`` from the encoder. The KL term reduces
            over dim 1, so these are assumed to be ``(batch, latent_dim)``;
            confirm against the model.
        kl_weight: multiplier on the KL term. The default 0.02 is the value
            previously hard-coded here.

    Returns:
        Scalar tensor ``BCE + kl_weight * KLD``.
    """
    x_hat, x = X
    mu, log_var = params
    # Element-wise mean over every pixel of every sample.
    bce = F.binary_cross_entropy(x_hat, x, reduction='mean')
    # Closed-form KL(N(mu, exp(log_var)) || N(0, 1)) per latent dim. It is
    # averaged over latent dims (dim=1), then over the batch. Averaging
    # (instead of the usual sum) scales it by 1 / latent_dim.
    kld = torch.mean(-0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))
    return bce + kl_weight * kld

In [None]:
# Adam over all VAE parameters.
# NOTE(review): 1e-6 is far below Adam's default lr of 1e-3, which suggests
# deliberate fine-tuning. Confirm this is intended.
learning_rate = 1e-6
opt = optim.Adam(params=VAE.parameters(), lr=learning_rate)

In [None]:
# Build the VAE trainer on the device chosen in the config cell. Without an
# explicit `device=`, the class default (hard-coded MPS) would be used instead,
# ignoring that configuration.
Trainer = VariationalAutoEncoderTrainer(
        model=VAE,
        loss=loss,
        optimizer=opt,
        device=device,
        file_name=model_file,
        force_learn=False
    )

In [None]:
# Replace `Trainer` with the trainer loaded from `trainer_file` (a .pkl path).
# Later cells use this returned object.
# NOTE(review): confirm whether it still wraps the `VAE`/`opt` created above.
# If load_trainer unpickles the file, only load trusted checkpoints.
Trainer = Trainer.load_trainer(trainer_file=trainer_file)

In [None]:
# Dataset mode for both splits, combining the MONO and BINARY flags.
# NOTE(review): presumably single-colour, binarised digits; see stacked_mnist.DataMode.
mode = DataMode.MONO | DataMode.BINARY

In [None]:
# Train and test splits of StackedMNIST, built with the same `mode`
# (train split first).
trainset, testset = (
    StackedMNIST(train=is_train_split, mode=mode) for is_train_split in (True, False)
)

In [None]:
# Peek at a single shuffled training example (batch of size 1).
# NOTE(review): `x` and `labels` are not referenced by any later cell in this
# notebook. This looks like leftover exploration; consider removing it.
batch_size = 1
data = DataLoader(trainset, shuffle=True, batch_size=batch_size)
x, _, labels = next(iter(data))

In [None]:
# Status check shown as the cell output. `and` yields `force_relearn` when it is
# falsy, otherwise `done_training`.
# NOTE(review): the constructor takes `force_learn` but this reads
# `force_relearn`. Confirm the attribute name on ModelTrainer.
Trainer.force_relearn and Trainer.done_training

In [None]:
# Training vs. validation loss curves recorded by the trainer.
# The x-axis is the index into the recorded lists (presumably epochs; confirm).
fig, ax = plt.subplots()
ax.plot(Trainer.losses, label="train loss")
ax.plot(Trainer.val_losses, label="val loss")
ax.set(xlabel="recorded step", ylabel="loss", title="VAE training curves")
ax.legend()
plt.show()

In [None]:
# Large-batch loaders over each split. The test loader is not shuffled, so
# evaluation order is reproducible; shuffling adds nothing when evaluating.
# NOTE(review): neither loader is referenced by later cells in this notebook.
loader_batch_size = 2048
train_loader = DataLoader(dataset=trainset, shuffle=True, batch_size=loader_batch_size)
test_loader = DataLoader(dataset=testset, shuffle=False, batch_size=loader_batch_size)

In [None]:
# Visual check of the model on training data with batch_size=16.
# NOTE(review): presumably shows 16 images alongside their reconstructions.
Trainer.print_reconstructed_img(trainset, batch_size=16)

In [None]:
# Verification classifier used by the coverage / predictability reports.
# force_learn=False presumably reuses the saved weights instead of retraining.
verif_model_path = 'models/verification_model_torch_ok_copy'
VerifNet = VerificationNet(file_name=verif_model_path, force_learn=False)

In [None]:
# Class coverage / predictability report on the training split, judged by
# VerifNet, processing 10_000 items per batch.
Trainer.print_class_coverage_and_predictability(VerifNet, dataset=trainset, batch_size=10_000)

In [None]:
# The same coverage / predictability report as for the training split, on the
# test split. `dataset=` is passed by keyword to match the training-split call.
Trainer.print_class_coverage_and_predictability(VerifNet, dataset=testset, batch_size=10_000)