# Compare results between epochs

In [None]:
# Identifier of the training run whose results are inspected.
RUN_ID = "2022_03_28-02_02_28"

# Epochs to display side by side for that run.
TO_COMPARE = [5, 10, 20, 30, 40, 50]

# Number of images to go through, and the batch size used for inference.
NB_IMAGES = 20
BATCH_SIZE = 5

# Size of one subplot; multiplied by the grid dimensions to get the figure size.
WIDTH = 8
HEIGHT = 8

In [None]:
def get_name(epoch: int) -> str:
    """Build the display label for one epoch of the selected run.

    Args:
        epoch: Epoch number to mention in the label.

    Returns:
        The module-level ``RUN_ID`` followed by ``" on "`` and the epoch.
    """
    label = f"{RUN_ID} on {epoch}"
    return label

In [None]:
# Move the working directory two levels up before the project imports below
# (presumably to the repository root so `scripts` and `src` resolve — confirm).
%cd ../..

In [None]:
import os
import pandas as pd
from ast import literal_eval
from math import ceil

from scripts import config
from scripts.tools.get_name import get_gen_name, get_model_img_name
from scripts.tools.type_dict.csv_to_object import architecture_dict, crop_mode_dict, ratio_filter_dict
from src.pipelines.cartoonizer import Cartoonizer
from src import models, dataset
import matplotlib.pyplot as plt

In [None]:
def get_cartoonizer(df_runs, epoch):
    """Build a ``Cartoonizer`` for run ``RUN_ID`` from one epoch's generator.

    Args:
        df_runs: DataFrame of runs. It must have a ``run_id`` column plus the
            picture-dataset and architecture columns read below.
        epoch: Epoch whose generator file name is obtained from
            ``get_gen_name``.

    Returns:
        A ``(name, cartoonizer)`` tuple, where ``name`` is ``get_name(epoch)``.

    Raises:
        IndexError: If no row of ``df_runs`` has ``run_id == RUN_ID``.
    """
    matching_runs = df_runs.loc[df_runs['run_id'] == RUN_ID]
    run = matching_runs.iloc[0]

    # <weights folder>/<run id>/<generator file of this epoch>
    weights_file = os.path.join(
        config.WEIGHTS_FOLDER, run["run_id"], get_gen_name(epoch)
    )

    # Dataset settings are stored as text in the run row and converted back
    # to objects through literal_eval and the lookup dictionaries.
    dataset_params = dataset.PicturesDatasetParameters(
        new_size=literal_eval(run["picture_dataset_new_size"]),
        crop_mode=crop_mode_dict[run["picture_dataset_crop_mode"]],
        ratio_filter_mode=ratio_filter_dict[run["picture_dataset_ratio_filter_mode"]],
        nb_images=NB_IMAGES,
    )

    infering_params = models.InferingParams(batch_size=BATCH_SIZE)
    architecture = architecture_dict[run["cartoon_gan_architecture"]]
    architecture_params = models.ArchitectureParamsNULL()

    cartoonizer = Cartoonizer(
        infering_parameters=infering_params,
        architecture=architecture,
        architecture_params=architecture_params,
        pictures_dataset_parameters=dataset_params,
        gen_path=weights_file,
    )

    return get_name(epoch), cartoonizer

In [None]:
def get_path(df_runs, epoch):
    """Return the label and the pictures folder path for one epoch of ``RUN_ID``.

    The path is only built, not checked: nothing here verifies that the
    folder exists.

    Args:
        df_runs: Not used. Kept so that existing ``get_path(df_runs, epoch)``
            calls keep working; the previous lookup of the run row was dead
            code whose result was never read.
        epoch: Epoch number.

    Returns:
        A dict with two keys: ``"name"`` is ``get_name(epoch)`` and ``"path"``
        is ``<config.LOGS_FOLDER>/<RUN_ID>/pictures/epoch_<epoch>``.
    """
    epoch_pictures_path = os.path.join(
        config.LOGS_FOLDER, RUN_ID, "pictures", f"epoch_{epoch}"
    )
    return {"name": get_name(epoch), "path": epoch_pictures_path}

In [None]:
# Load the table of runs and read how many epochs RUN_ID was trained for.
df_runs = pd.read_csv(config.REMOTE_PARAMS_PATH, index_col=0)

# pandas may hand back a numpy integer or float; normalise to a plain int so
# the value formats as "45" (not "45.0") in folder names and labels.
max_epoch = int(df_runs.loc[df_runs['run_id'] == RUN_ID].iloc[0]["epochs_trained_nb"])

all_paths = {epoch: get_path(df_runs, epoch) for epoch in TO_COMPARE}
# Requested epochs above max_epoch are clamped to max_epoch, which is not
# necessarily listed in TO_COMPARE; make sure its entry exists to avoid a
# KeyError when it is looked up below.
all_paths.setdefault(max_epoch, get_path(df_runs, max_epoch))

# Epochs actually displayed: each requested epoch clamped to max_epoch,
# stopping once max_epoch has been reached so it is shown only once.
epochs_to_show = []
for requested_epoch in TO_COMPARE:
    epoch = min(requested_epoch, max_epoch)
    epochs_to_show.append(epoch)
    if epoch == max_epoch:
        break

# Grid layout is the same for every image: one cell for the original picture
# plus one cell per displayed epoch.
cols = 2
rows = ceil((len(epochs_to_show) + 1) / cols)

for i in range(NB_IMAGES):
    fig = plt.figure(figsize=(WIDTH * cols, HEIGHT * rows))

    # The original picture is read from the first displayed epoch's folder
    # (clamped, so it is never an epoch beyond max_epoch).
    picture = plt.imread(
        os.path.join(all_paths[epochs_to_show[0]]['path'], get_model_img_name(i, "picture"))
    )
    ax = fig.add_subplot(rows, cols, 1)
    ax.imshow(picture)
    ax.set_title("Original picture")
    ax.axis("off")

    for j, epoch in enumerate(epochs_to_show):
        cartoon = plt.imread(
            os.path.join(all_paths[epoch]['path'], get_model_img_name(i, "cartoon"))
        )
        # Cell 1 holds the original picture, so cartoons start at cell 2.
        ax = fig.add_subplot(rows, cols, j + 2)
        ax.imshow(cartoon)
        ax.set_title(all_paths[epoch]['name'])
        ax.axis("off")

    # Render this figure now and release it, instead of keeping NB_IMAGES
    # figures open until the end of the cell.
    plt.show()
    plt.close(fig)