# Show results of a specific model

In [None]:
# Run to inspect: matched against the `run_id` column of the runs CSV.
MODEL_ID = "2022_03_27-20_51_13"
# Size (matplotlib figsize units) of each cell in the picture/cartoon grid.
WIDTH, HEIGHT = 8, 8

# Only losses of epochs 1..LOSS_UNTIL_EPOCH are plotted.
LOSS_UNTIL_EPOCH = 20
# Window length of the Savitzky-Golay filter smoothing the training losses.
WINDOW_SIZE = 101

In [None]:
# Move two directories up (presumably the repository root) so that the
# `from scripts import ...` lines in the next cell resolve.
%cd ../..

In [None]:
import os
import pandas as pd
import numpy as np
import csv
import json
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter

from scripts import config
from scripts.tools.get_name import get_model_img_name

# Apply seaborn's default theme to every matplotlib figure drawn below.
sns.set()

In [None]:
def _epoch_of(file_name):
    """Return the integer epoch number embedded in a loss file name.

    The number is read between the first "_" and the following "."
    (e.g. "epoch_3.txt" -> 3; presumably the trainer's naming scheme).
    """
    return int(file_name.split("_")[1].split(".")[0])


def _parse_losses(cell):
    """Parse a loss dict written with single quotes, e.g. "{'gen_loss': 1.0}"."""
    return json.loads(cell.replace("\'", "\""))


def _load_training_losses(folder_path, until_epoch):
    """Collect per-step training losses of epochs 1..until_epoch.

    Returns ``(x, gen, disc)``. The k-th of the n rows of epoch e is placed
    at x = e - 1 + k / n, so each epoch covers the interval (e - 1, e].
    """
    epoch_files = sorted(
        (_epoch_of(name), name)
        for name in os.listdir(folder_path)
        if name.endswith(".txt") and name != "train_validation.txt"
    )
    x, gen, disc = [], [], []
    for epoch, name in epoch_files:
        if epoch > until_epoch:
            break  # files are sorted by epoch, so all remaining ones are later
        with open(os.path.join(folder_path, name), encoding="utf-8", newline="") as csv_file:
            rows = [row for row in csv.reader(csv_file) if row]
        nb_steps = len(rows)
        # The step index restarts at 1 for every epoch file; a counter shared
        # across files would push later epochs far to the right on the axis.
        for step, row in enumerate(rows, start=1):
            losses = _parse_losses(row[1])
            disc.append(losses["disc_loss"])
            gen.append(losses["gen_loss"])
            x.append(epoch - 1 + step / nb_steps)
    return x, gen, disc


def _load_validation_losses(file_path, until_epoch):
    """Read at most ``until_epoch`` validation loss rows.

    Returns ``(x, gen, disc)`` with x = 1, 2, ... sized to the rows actually
    read. Assumes row i of the file holds epoch i + 1 (TODO confirm).
    """
    gen, disc = [], []
    with open(file_path, encoding="utf-8", newline="") as file:
        for row in csv.reader(file):
            if len(gen) >= until_epoch:
                break
            if not row:
                continue
            losses = _parse_losses(row[1])
            disc.append(losses["disc_loss"])
            gen.append(losses["gen_loss"])
    # Size x from the data so a file shorter than until_epoch still plots.
    x = list(range(1, len(gen) + 1))
    return x, gen, disc


def _smooth(values, window_size, polyorder=3):
    """Savitzky-Golay smooth ``values``, shrinking the window for short series.

    savgol_filter rejects a window longer than the data (default 'interp'
    mode) or not larger than ``polyorder``; the window is clamped to the
    largest odd length that fits, and too-short series are returned as-is.
    """
    window = min(window_size, len(values))
    if window % 2 == 0:
        window -= 1
    if window <= polyorder:
        return values
    return savgol_filter(values, window, polyorder)


def _plot_loss(x, y, title):
    """Draw one loss curve in its own 12x6 figure and display it."""
    plt.figure(figsize=(12, 6))
    plt.plot(x, y)
    plt.title(title, fontsize=18, fontweight='bold')
    plt.xlabel("Epoch", fontsize=14)
    plt.ylabel("Loss", fontsize=14)
    plt.show()


def _show_pictures(epoch_pictures_path, width, height):
    """Show every (picture, cartoon) pair side by side, one pair per row."""
    columns = 2
    # Assumes the folder holds exactly one picture and one cartoon per index.
    nb_pictures = len(os.listdir(epoch_pictures_path)) // 2
    if nb_pictures == 0:
        print(f"No pictures found in {epoch_pictures_path}")
        return
    fig = plt.figure(figsize=(width * columns, height * nb_pictures))
    for i in range(nb_pictures):
        for column, kind in enumerate(("picture", "cartoon"), start=1):
            image = plt.imread(os.path.join(epoch_pictures_path, get_model_img_name(i, kind)))
            ax = fig.add_subplot(nb_pictures, columns, columns * i + column)
            ax.imshow(image)
            ax.axis("off")
    plt.show()


def show_results(model_id, loss_until_epoch=None, window_size=None):
    """Print a run's parameters, plot its losses and show its cartoonized pictures.

    Args:
        model_id: value of the ``run_id`` column identifying the run in
            ``config.REMOTE_PARAMS_PATH``.
        loss_until_epoch: last epoch whose losses are plotted; defaults to the
            notebook-level ``LOSS_UNTIL_EPOCH``, read at call time.
        window_size: Savitzky-Golay window used to smooth training losses;
            defaults to the notebook-level ``WINDOW_SIZE``, read at call time.

    Raises:
        ValueError: if no run with ``model_id`` exists in the runs CSV.
    """
    if loss_until_epoch is None:
        loss_until_epoch = LOSS_UNTIL_EPOCH
    if window_size is None:
        window_size = WINDOW_SIZE

    df_runs = pd.read_csv(config.REMOTE_PARAMS_PATH, index_col=0)
    df_result = df_runs.loc[df_runs['run_id'] == model_id]
    if df_result.empty:
        raise ValueError(f"No run with run_id {model_id!r} in {config.REMOTE_PARAMS_PATH}")
    run = df_result.iloc[0]

    run_folder = os.path.join(config.LOGS_FOLDER, run["run_id"])
    epoch_pictures_path = os.path.join(
        run_folder, "pictures", f"epoch_{run['epochs_trained_nb']}"
    )
    print(
        f"""Model {run['run_id']} with epoch {run['epochs_trained_nb']}:
    - Architecture: {run['cartoon_gan_architecture']}
    - Pictures size: {run['picture_dataset_new_size']}
    - Learning rate: {run['training_gen_lr']}
    - Crop mode: {run['picture_dataset_crop_mode']}
    - Init generator path: {run['init_gen_path']}

    With losses:
    - Generator loss: {run['train_gen_loss']}
    - Discriminator loss: {run['train_disc_loss']}"""
    )

    losses_folder = os.path.join(run_folder, "losses")

    # Losses on training set, smoothed with a Savitzky-Golay filter
    x, gen, disc = _load_training_losses(losses_folder, loss_until_epoch)
    _plot_loss(x, _smooth(gen, window_size), "Generator loss on training set")
    _plot_loss(x, _smooth(disc, window_size), "Discriminator loss on training set")

    # Then show the losses on validation set
    x, gen, disc = _load_validation_losses(
        os.path.join(losses_folder, "train_validation.txt"), loss_until_epoch
    )
    _plot_loss(x, gen, "Generator loss on validation set")
    _plot_loss(x, disc, "Discriminator loss on validation set")

    # Finally show cartoonized pictures
    _show_pictures(epoch_pictures_path, WIDTH, HEIGHT)

In [None]:
# Render the report for the run selected in the configuration cell.
show_results(model_id=MODEL_ID)