## Presentation

In [None]:
import sys
sys.path.append('../src')
import torch
from torchvision import transforms
from torchvision.utils import save_image

# custom libraries
from classes.autoencoder.autoencoderCNN import AutoEncoderCNN
from classes.gan.gan import Generator
from classes.imageclassification.classification import MNIST_Classification_Class

from utils.dataset import load_datasets
from utils.presentationPlot import presentation_plot

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Dataset identifier passed to load_datasets / AutoEncoderCNN.
datasettype = "MNIST"
# Length of the latent noise vector sampled for the generator.
z_dim = 100

# Samples per batch; one generated image per example iteration.
batch_size = 1

#### Get Data

In [None]:
# ToTensor scales 8-bit images to [0, 1]; Normalize(mean=0.5, std=0.5) then
# maps that range to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# load_datasets returns a (dataset, loader) pair for the chosen dataset.
dataset_train, loader_train = load_datasets(
    datasettype,
    transform,
    batch_size=batch_size,
)

#### Load models

In [None]:
# Load the three pretrained models used for the presentation and put them in
# inference mode.
#
# 'autoencoder.pth' and 'classificator.pth' hold whole pickled modules (the
# loaded objects are used directly as models), so there is no need to build
# fresh instances first; their classes only have to be importable.
# map_location lets checkpoints saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary code -- only load trusted files.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects
# whole-module checkpoints; pass weights_only=False there (trusted files only).
AutoEncoder = torch.load('autoencoder.pth', map_location=device)
AutoEncoder.eval()
Encoder = AutoEncoder.encoder

# The generator checkpoint is a state_dict, so build the module first. Move it
# to `device` because the noise fed to it is created on `device`.
Gen = Generator(100, z_dim, 28 * 28).to(device)
Gen.load_state_dict(torch.load("generator.pth", map_location=device))
Gen.eval()

# TODO: import Classification
# The classifier consumes the flattened encoder output (8 * 4 * 4 = 128 values,
# matching the input_size it was built with).
Classificator = torch.load('classificator.pth', map_location=device)
Classificator.eval()

#### Examples

In [None]:
# Number of generated examples to show; each iteration draws fresh noise.
num_examples = 1

# Iterator over real samples; only needed by the disabled real-data path below.
data_iter = iter(loader_train)

# Inference only: no_grad skips autograd bookkeeping, so the tensors handed to
# presentation_plot carry no grad history.
with torch.no_grad():
    for epoch in range(1, num_examples + 1):
        # real data
        # batch = next(data_iter)
        # input, label = batch
        # real_encoded = Encoder(input)

        # fake picture: sample latent noise on the models' device and reshape
        # the flat generator output into (batch, 1, 28, 28) images
        z = torch.randn(batch_size, z_dim, device=device)
        fake_image = Gen(z)
        fake_image = fake_image.view(fake_image.size(0), 1, 28, 28)
        fake_encoded = Encoder(fake_image)

        # classify image (based on the flattened encoded representation)
        classification = Classificator(fake_encoded.view(fake_encoded.size(0), -1))

        # the encoded vector is shown as a 1-channel 8x16 image
        presentation_plot(fake_image, fake_encoded.view(fake_encoded.size(0), 1, 8, 16), classification)
        # save_image(fake_image.view(fake_image.size(0), 1, 28, 28), f'{epoch}_gen.png')
        # save_image(input.view(input.size(0), 1, 28, 28), f'{epoch}.png')
        # save_image(fake_encoded.view(fake_encoded.size(0), 1, 8, 16), f'{epoch}_encoded.png')