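# Compare trained models

Compare the runs listed in `TO_COMPARE` at epoch `COMPARE_ON_EPOCH`: first their training and validation losses, then the cartoons each run produced for the same pictures.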

In [None]:
# Runs to compare: ids of the run folders under the logs / weights folders.
# Keep exactly one `TO_COMPARE` assignment active; the commented presets are
# earlier comparisons (the trailing comment names what was being compared).
# TO_COMPARE = ["2022_03_27-20_51_13", "2022_03_27-20_51_29", "2022_03_27-20_51_43"] # lr from scratch
# TO_COMPARE = ["2022_03_28-02_01_58", "2022_03_28-02_02_43", "2022_03_28-02_02_38", "2022_03_28-02_02_28", "2022_03_28-18_15_16"] # lr
# TO_COMPARE = ["2022_03_28-18_15_38","2022_03_28-18_23_16","2022_03_29-04_12_27","2022_03_29-04_14_51"] # bs
# TO_COMPARE = ["2022_03_28-02_02_38", "2022_03_29-21_05_24", "2022_03_29-21_05_31"] # crop
TO_COMPARE = [  # content loss
    "2022_03_31-18_37_59_gpK",
    "2022_04_04-18_27_57_veL",
    "2022_04_10-00_03_33_Iea",
    "2022_04_09-16_02_23_CES",
    "2022_04_10-00_00_30_jax",
    "2022_04_04-03_19_26_QAj",
    "2022_03_31-18_38_39_AqQ",
    "2022_03_31-18_56_57_xfY",
    "2022_03_28-02_02_38",
    "2022_04_01-04_25_45_DEW",
    "2022_04_01-18_56_08_GUY",
]
# TO_COMPARE = ["2022_03_28-02_02_38", "2022_04_04-03_03_34_FTX", "2022_04_04-03_03_07_KVM", "2022_04_04-02_53_39_YpC"] # blur filter

# Epoch the runs are compared on; each run is capped at its own `epochs_trained_nb`.
COMPARE_ON_EPOCH = 50

# Show the original input picture as the first panel of each comparison figure.
WITH_ORIGINAL_IMAGE = True

# Number of images compared (also the size of the cartoonizer's pictures dataset).
NB_IMAGES = 50

# Batch size used by the cartoonizer at inference.
BATCH_SIZE = 5

# Size (matplotlib figsize units, i.e. inches) of one panel of the comparison figures.
WIDTH, HEIGHT = 8, 8

# Savitzky-Golay window length used to smooth the training-loss curves.
WINDOW_SIZE = 101

In [None]:
def get_name(row):
    """Return the label identifying a run in every plot of this notebook.

    Args:
        row: one row of the runs table (anything indexable by column name),
            e.g. ``df_runs.loc[df_runs['run_id'] == run_id].iloc[0]``.

    Returns:
        ``"<run_id> on <epoch> with content loss <weight>"`` where ``<epoch>``
        is ``COMPARE_ON_EPOCH`` capped at the run's ``epochs_trained_nb``.
    """
    shown_epoch = min(COMPARE_ON_EPOCH, row["epochs_trained_nb"])
    run_id = row['run_id']
    # Alternative labels for the other presets of TO_COMPARE (swap the return):
    # return f"{run_id} on {shown_epoch} with lr {row['training_gen_lr']}" # lr
    # return f"{run_id} on {shown_epoch} with batch size {row['training_batch_size']}" # bs
    # return f"{run_id} on {shown_epoch} with crop mode {row['cartoon_dataset_crop_mode']}" # crop
    # return f"{run_id} on {shown_epoch} with smoothing kernel {row['picture_dataset_smoothing_kernel_size']}" # smoothing kernel
    content_loss_weight = row['training_weight_generator_content_loss']
    return f"{run_id} on {shown_epoch} with content loss {content_loss_weight}"

In [None]:
# def get_name(row):
#     return f"learning rate = {row['training_gen_lr']}" # lr
#     return f"batch size = {row['training_batch_size']}" # bs
#     return f"crop mode {row['cartoon_dataset_crop_mode']}" # crop
#     return f"content loss weight = {row['training_weight_generator_content_loss']}" # content loss
#     return f"smoothing kernel size = {row['picture_dataset_smoothing_kernel_size']}" # smoothing kernel

In [None]:
# Move two directories up — presumably to the repository root, so that the
# `scripts` and `src` packages imported in the next cell can be found
# (TODO confirm the notebook's location). Not idempotent: re-running this
# cell moves up two more levels.
%cd ../..

In [None]:
import os
import csv
import json
import pandas as pd
import numpy as np
from ast import literal_eval
from math import ceil
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.signal import savgol_filter

from scripts import config
from scripts.tools.get_name import get_gen_name, get_model_img_name
from scripts.tools.type_dict.csv_to_object import architecture_dict, crop_mode_dict, ratio_filter_dict
from src.pipelines.cartoonizer import Cartoonizer
from src import models, dataset

# Apply seaborn's default plot style to every matplotlib figure drawn below.
sns.set()

In [None]:
def get_cartoonizer(df_runs, run_id):
    """Build a ``Cartoonizer`` from the generator weights saved by one run.

    The generator checkpoint is the one of epoch
    ``min(COMPARE_ON_EPOCH, epochs_trained_nb)`` stored under
    ``config.WEIGHTS_FOLDER/<run_id>/``. The pictures dataset reuses the run's
    resize / crop / ratio-filter settings but is limited to ``NB_IMAGES``
    pictures, and inference uses batches of ``BATCH_SIZE``.

    Args:
        df_runs: runs table with a ``run_id`` column; the first matching row
            is used.
        run_id: id of the run to load.

    Returns:
        ``(name, cartoonizer)`` where ``name`` is the run label from ``get_name``.

    Raises:
        IndexError: if no row of ``df_runs`` has this ``run_id`` (from ``.iloc[0]``).
    """
    run = df_runs.loc[df_runs['run_id'] == run_id].iloc[0]

    # Cap the checkpoint epoch at the run's `epochs_trained_nb`.
    checkpoint_epoch = min(COMPARE_ON_EPOCH, run["epochs_trained_nb"])
    run_weights_folder = os.path.join(config.WEIGHTS_FOLDER, run["run_id"])
    gen_path = os.path.join(run_weights_folder, get_gen_name(checkpoint_epoch))

    # The runs table stores these settings as text: the size is parsed as a
    # Python literal, the modes are looked up in the csv_to_object tables.
    pictures_dataset_parameters = dataset.PicturesDatasetParameters(
        new_size=literal_eval(run["picture_dataset_new_size"]),
        crop_mode=crop_mode_dict[run["picture_dataset_crop_mode"]],
        ratio_filter_mode=ratio_filter_dict[run["picture_dataset_ratio_filter_mode"]],
        nb_images=NB_IMAGES,
    )

    infering_parameters = models.InferingParams(batch_size=BATCH_SIZE)
    architecture = architecture_dict[run["cartoon_gan_architecture"]]
    cartoonizer = Cartoonizer(
        infering_parameters=infering_parameters,
        architecture=architecture,
        architecture_params=models.ArchitectureParamsNULL(),
        pictures_dataset_parameters=pictures_dataset_parameters,
        gen_path=gen_path,
    )
    return get_name(run), cartoonizer

In [None]:
def get_pictures_path(df_runs, run_id):
    """Locate the pictures a run logged at the compared epoch.

    Args:
        df_runs: runs table with a ``run_id`` column; the first matching row
            is used.
        run_id: id of the run.

    Returns:
        dict with ``"name"`` (label from ``get_name``) and ``"path"``
        (``config.LOGS_FOLDER/<run_id>/pictures/epoch_<n>`` with
        ``n = min(COMPARE_ON_EPOCH, epochs_trained_nb)``).
    """
    run = df_runs.loc[df_runs['run_id'] == run_id].iloc[0]
    epoch_nb = min(COMPARE_ON_EPOCH, run["epochs_trained_nb"])
    epoch_folder = f"epoch_{epoch_nb}"
    epoch_pictures_path = os.path.join(config.LOGS_FOLDER, run_id, "pictures", epoch_folder)
    return dict(name=get_name(run), path=epoch_pictures_path)

In [None]:
df_runs = pd.read_csv(config.REMOTE_PARAMS_PATH, index_col=0)


def _epoch_of(file_name):
    """Return the epoch number encoded in a per-epoch loss file name.

    Assumes names shaped like ``<prefix>_<epoch>.txt`` — TODO confirm against
    the training logger.
    """
    return int(file_name.split("_")[1].split(".")[0])


# Losses on training set: one point per logged batch, with the x-axis in
# fractional epochs (the batches of epoch k are spread over [k - 1, k)).
Y_gen = {}
Y_disc = {}
for run_id in TO_COMPARE:
    folder_path = os.path.join(config.LOGS_FOLDER, run_id, "losses")
    # (epoch, file name) pairs in epoch order; the validation file is read separately below.
    epoch_files = sorted(
        (_epoch_of(name), name)
        for name in os.listdir(folder_path)
        if name.endswith(".txt") and name != "train_validation.txt"
    )
    run_name = get_name(df_runs.loc[df_runs['run_id'] == run_id].iloc[0])
    Y_gen[run_name] = {"X": [], "Y": []}
    Y_disc[run_name] = {"X": [], "Y": []}
    for epoch, file_name in epoch_files:
        if epoch > COMPARE_ON_EPOCH:
            break
        nb_parsed = 0
        with open(os.path.join(folder_path, file_name), encoding="utf-8") as csv_file:
            for row in csv.reader(csv_file):
                try:
                    # Column 1 holds a Python dict repr; swap quotes to make it JSON.
                    losses = json.loads(row[1].replace("\'", "\""))
                except (IndexError, ValueError):
                    # Blank/short rows (IndexError) or non-JSON text such as a
                    # header (JSONDecodeError is a ValueError): skip the row.
                    continue
                Y_disc[run_name]["Y"].append(losses["disc_loss"])
                Y_gen[run_name]["Y"].append(losses["gen_loss"])
                nb_parsed += 1
        # Position only the rows actually parsed, so skipped rows cannot push
        # points past the epoch boundary.
        new_X_values = list(np.arange(nb_parsed) / nb_parsed + epoch - 1) if nb_parsed else []
        Y_gen[run_name]["X"].extend(new_X_values)
        Y_disc[run_name]["X"].extend(new_X_values)

def _smooth_losses(values, window_size, polyorder=3):
    """Savitzky-Golay-smooth a loss curve, shrinking the window for short curves.

    ``savgol_filter`` (default ``mode="interp"``) raises ``ValueError`` when the
    window is longer than the signal, e.g. for runs with few logged batches.
    In that case the window is reduced to the largest odd length that fits;
    curves too short to fit a polynomial of order ``polyorder`` (including
    empty ones) are returned unsmoothed.
    """
    nb_points = len(values)
    if nb_points >= window_size:
        return savgol_filter(values, window_size, polyorder)
    window = nb_points if nb_points % 2 == 1 else nb_points - 1
    if window <= polyorder:
        return np.asarray(values)
    return savgol_filter(values, window, polyorder)


# One figure per network; curves are smoothed because per-batch losses are noisy.
for losses_by_run, title in (
    (Y_gen, "Generator loss on training set"),
    (Y_disc, "Discriminator loss on training set"),
):
    fig, ax = plt.subplots(figsize=(12, 6))
    for name, XY in losses_by_run.items():
        ax.plot(XY["X"], _smooth_losses(XY["Y"], WINDOW_SIZE), label=name)
    ax.set_title(title, fontsize=18, fontweight='bold')
    ax.set_xlabel("Epoch", fontsize=14)
    ax.set_ylabel("Loss", fontsize=14)
    ax.legend(loc='upper right')
    plt.show()

# Then the losses on the validation set: `train_validation.txt` holds one CSV
# row per epoch, with the losses dict (Python repr) in column 1.
Y_gen = {}
Y_disc = {}
for run_id in TO_COMPARE:
    validation_path = os.path.join(config.LOGS_FOLDER, run_id, "losses", "train_validation.txt")
    run_name = get_name(df_runs.loc[df_runs['run_id'] == run_id].iloc[0])
    gen_values = Y_gen[run_name] = []
    disc_values = Y_disc[run_name] = []
    with open(validation_path, encoding="utf-8") as validation_file:
        for epoch_idx, row in enumerate(csv.reader(validation_file)):
            # Keep only the first COMPARE_ON_EPOCH epochs.
            if epoch_idx >= COMPARE_ON_EPOCH:
                break
            losses = json.loads(row[1].replace("\'", "\""))
            disc_values.append(losses["disc_loss"])
            gen_values.append(losses["gen_loss"])

# One figure per network; validation losses are one point per epoch, shown unsmoothed.
for losses_by_run, title in (
    (Y_gen, "Generator loss on validation set"),
    (Y_disc, "Discriminator loss on validation set"),
):
    fig, ax = plt.subplots(figsize=(12, 6))
    for name, Y in losses_by_run.items():
        # Epochs are numbered from 1.
        ax.plot(range(1, len(Y) + 1), Y, label=name)
    ax.set_title(title, fontsize=18, fontweight='bold')
    ax.set_xlabel("Epoch", fontsize=14)
    ax.set_ylabel("Loss", fontsize=14)
    ax.legend(loc='upper right')
    plt.show()

In [None]:
# Side-by-side comparison of the cartoons each run logged at the compared
# epoch: one figure per image, optionally preceded by the original picture.
df_runs = pd.read_csv(config.REMOTE_PARAMS_PATH, index_col=0)
all_paths = {run_id: get_pictures_path(df_runs, run_id) for run_id in TO_COMPARE}

# The grid layout does not depend on the image, so compute it once.
cols = 2
nb_panels = len(TO_COMPARE) + (1 if WITH_ORIGINAL_IMAGE else 0)
rows = ceil(nb_panels / cols)
# 1-based subplot index of the first cartoon (slot 1 holds the original picture when shown).
first_cartoon_slot = 2 if WITH_ORIGINAL_IMAGE else 1

for i in range(NB_IMAGES):
    fig = plt.figure(figsize=(WIDTH * cols, HEIGHT * rows))
    if WITH_ORIGINAL_IMAGE:
        # Taken from the first run's folder (presumably identical across runs — TODO confirm).
        picture_path = os.path.join(all_paths[TO_COMPARE[0]]['path'], get_model_img_name(i, "picture"))
        ax = fig.add_subplot(rows, cols, 1)
        ax.imshow(plt.imread(picture_path))
        ax.set_title("Original picture")
        ax.axis("off")
    for j, run_id in enumerate(TO_COMPARE):
        cartoon_path = os.path.join(all_paths[run_id]['path'], get_model_img_name(i, "cartoon"))
        ax = fig.add_subplot(rows, cols, j + first_cartoon_slot)
        ax.imshow(plt.imread(cartoon_path))
        ax.set_title(all_paths[run_id]['name'])
        ax.axis("off")
    # Display this figure now and release it. Otherwise all NB_IMAGES figures
    # stay open until the cell ends, and matplotlib warns once more than
    # `figure.max_open_warning` (20 by default) figures are open.
    plt.show()
    plt.close(fig)