# See the results

## Imports

In [4]:
import os
import torch
import matplotlib.pyplot as plt

In [5]:
from training.training_pipeline import TrainingPipeline
from training.components.hyperparameters import Hyperparameters
from training.components.paths import Paths
from training.components.state import State

## Constants

In [6]:
# Dataset location and label space (MNIST: 10 digit classes).
path_to_dataset = "../../execution_data/data/MNIST"
nr_classes = 10
batch_size = 32
target_model_network_type = "Net"
# Prefer the first GPU, but fall back to the CPU so the notebook can still be
# run on machines without CUDA instead of failing on device selection.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
verbose = True

In [7]:
# Saved checkpoints to inspect: the target model and the GAN.
# NOTE(review): the `_0` suffix presumably identifies the run/fold — confirm.
path_to_load_target_model, path_to_load_gan = (
    "../../execution_data/target_model_0.pth",
    "../../execution_data/gan_0.pth",
)

## Load the saved models and inspect their predictions

In [8]:
def _expand_path(path):
    """Return *path* with ``$VAR``/``${VAR}`` references and a leading ``~`` expanded.

    Environment variables are expanded first so that a variable whose value
    starts with ``~`` is also resolved to the user's home directory.
    """
    return os.path.expanduser(os.path.expandvars(path))


# Allow the configured paths to use environment variables and/or `~`.
path_to_dataset = _expand_path(path_to_dataset)
path_to_load_target_model = _expand_path(path_to_load_target_model)
path_to_load_gan = _expand_path(path_to_load_gan)

In [None]:
# Assemble a TrainingPipeline around the saved checkpoints. All
# training-schedule settings are None — presumably unused here, since this
# notebook only inspects a loaded model (confirm in TrainingPipeline).
hyperparameters = Hyperparameters(
    batch_size=batch_size,
    target_model_network_type=target_model_network_type,
    k_fold=1,
    validation_interval=1,
    gan_residual_units_number=1,
    target_model_residual_units_number=1,
    # Training schedule: intentionally left unset.
    total_epochs=None,
    nr_steps_target_model_on_gan=None,
    nr_steps_gan=None,
    nr_step_target_model_alone=None,
    proportion_target_model_alone=None,
)

# Dataset, execution root folder and the two checkpoints to load.
paths = Paths(
    dataset=path_to_dataset,
    root_folder="../../execution_data",
    load_target_model=path_to_load_target_model,
    load_gan=path_to_load_gan,
)

state = State(device_name=device, nr_classes=nr_classes, verbose=verbose)

training_pipeline = TrainingPipeline(
    hyperparameters=hyperparameters,
    paths=paths,
    state=state,
)

In [None]:
# Prepare the pipeline's data for a freshly built model. The epoch attribute is
# explicitly cleared (set to None) before the call.
# NOTE(review): initialize_data_for_new_model() is presumably what populates
# `data_loaders` and `networks_data` (loading the checkpoints from `paths`),
# which the cells below use — confirm in TrainingPipeline.
training_pipeline.epoch = None
training_pipeline.initialize_data_for_new_model()

In [11]:
# training_pipeline.performances_logger.images_plotter.plot_best_and_worst_examples(training_pipeline.data_loaders.validation[0], 0, device)

In [13]:
# Run the loaded target model on one (modified) batch from the first training
# fold and keep, per sample, the softmax confidence and the predicted class.
target_model = training_pipeline.networks_data.target_model
target_model.eval()

inputs, labels = next(iter(training_pipeline.data_loaders.train[0]))
inputs = inputs.to(training_pipeline.state.device)
labels = labels.to(training_pipeline.state.device)

# Apply the pipeline's (inputs, labels) modifier. It is kept outside no_grad in
# case it relies on autograd internally — NOTE(review): confirm.
inputs, labels = training_pipeline.modifier((inputs, labels))
# Collapse dim 1 to class indices; assumes the modifier returns one-hot /
# per-class scores along dim 1 — TODO confirm.
labels = labels.argmax(dim=1)

# Pure inspection: no autograd graph is needed for the forward pass.
with torch.no_grad():
    target_model_outputs = target_model(inputs)
    target_model_outputs = torch.nn.functional.softmax(target_model_outputs, dim=1)
    # score_pred: highest class probability; predicted: its class index.
    score_pred, predicted = torch.max(target_model_outputs, 1)

In [None]:
# Positions (within the batch above) of the samples to display.
sample_indices = [1]

# Move everything to the CPU once for matplotlib; this does not depend on the
# sample index, so it is done outside the loop.
X = inputs.detach().cpu()
y = labels.detach().cpu()
y_hat = predicted.detach().cpu()
y_conf = score_pred.detach().cpu()

for j in sample_indices:
    # Move the sample's leading (channel) axis to the end so imshow receives a
    # channels-last image; identical to stacking each channel on a new last axis.
    el = X[j].movedim(0, -1)
    plt.imshow(el)
    if y.ndim == 1:
        # One class index per sample: show it as the title.
        plt.title(int(y[j]))
    plt.show()

    # 3-D tensors hold one 2-D map per sample (label map, predicted-class map,
    # confidence map); plot each when present.
    if y.ndim == 3:
        plt.imshow(y[j])
        plt.show()

    if y_hat.ndim == 3:
        plt.imshow(y_hat[j])
        plt.show()

    if y_conf.ndim == 3:
        plt.imshow(y_conf[j])
        plt.colorbar()
        plt.show()