In [None]:
import pandas as pd
import numpy as np
import math
import matplotlib.pyplot as plt
import matplotlib.ticker as tck
from pathlib import Path

colors = ["#E69F00" ,"#009E73" ,"#0072B2" ,"#999999", "#56B4E9", "#F0E442", "#CC79A7", "#D55E00"]
markers = ["v", "x", "o", "^", "s", "<", ">", "8"]

plt.rcParams.update({
        'font.size': '18',
        'svg.fonttype': 'none'
})

plt.rc('axes', axisbelow=True)

%config Completer.use_jedi = False

In [None]:
def make_line_chart(x,y, labels, xlabel="", ylabel="", name="result.png", legend_location="best", ncol=4, columnspacing=2):
    """Plot one line per series, add grid and legend, save the figure and show it.

    Parameters
    ----------
    x, y : sequence of sequences
        ``x[i]`` / ``y[i]`` hold the x and y values of the i-th line.
    labels : sequence
        Legend label of the i-th line. Lines, labels and the module-level
        ``colors`` palette are paired with ``zip``, so at most ``len(colors)``
        lines are drawn.
    xlabel, ylabel : str
        Axis labels.
    name : str
        Output path passed to ``savefig``.
    legend_location : str
        ``"upper center"`` puts a multi-column legend above the axes; any other
        value puts a single-column legend at the lower right (the value itself
        is not forwarded to matplotlib).
    ncol : int
        Number of legend columns for the ``"upper center"`` layout.
    columnspacing : float
        Spacing between legend columns.
    """
    linewidth = 2
    fig, ax = plt.subplots(figsize=(6, 2.5))

    ax.minorticks_on()
    ax.grid(color='lightgrey', linestyle='-', linewidth=1, which="minor")
    ax.grid(color='grey', linestyle='-', linewidth=1, which="major")

    for x_values, y_values, label, color in zip(x, y, labels, colors):
        ax.plot(x_values, y_values, label=label, linewidth=linewidth, color=color)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Separate name so the `labels` parameter is not shadowed.
    handles, legend_labels = ax.get_legend_handles_labels()
    if handles:  # zip(*[]) below would fail when no labelled line was drawn
        # Order legend entries numerically. Non-numeric labels (e.g. "Never")
        # sort last and keep their relative order, because sorted() is stable.
        # isdecimal(), unlike isnumeric(), only accepts strings int() can parse.
        legend_labels, handles = zip(*sorted(
            zip(legend_labels, handles),
            key=lambda t: int(t[0]) if t[0].isdecimal() else math.inf,
        ))
        if legend_location == "upper center":
            # Raise the legend anchor as the number of legend rows grows. Count the
            # entries actually drawn; previously this read a notebook-global
            # `lables` list (NameError on a fresh kernel).
            height = (len(legend_labels) / ncol + 1) * 0.2 + 1
            ax.legend(handles, legend_labels, frameon=False, ncol=ncol, loc="upper center", bbox_to_anchor=(0.5, height), fontsize=14, labelspacing=1, columnspacing=columnspacing)
        else:
            ax.legend(handles, legend_labels, frameon=False, loc="lower right", bbox_to_anchor=(0.85, 0), fontsize=10, labelspacing=1, columnspacing=columnspacing)

    fig.savefig(name, bbox_inches='tight')
    plt.show()
    

In [None]:
def add_labels(x,y, precision):
    """Write each bar's value as text above the bar in the current axes.

    Parameters
    ----------
    x : sequence of float
        x positions of the bars; label i is centred on ``x[i]``. Previously
        ``i + 0.25`` was hard-coded, which only matched a 0.25 bar offset.
    y : sequence of float
        Bar heights; label i is drawn at ``y[i]`` plus 10 % of ``max(y)``.
    precision : int
        Digits passed to ``round``; ``-1`` rounds to an integer.
    """
    if len(y) == 0:  # max() of an empty sequence would raise
        return
    overset = 0.1 * max(y)
    for x_pos, value in zip(x, y):
        text = round(value) if precision == -1 else round(value, precision)
        plt.text(x_pos, value + overset, text, ha='center')

In [None]:
def make_bar_chart(x,y, xlabel="", ylabel="", name="result.png", lables=True, lableprecision=-1):
    """Draw a bar chart of ``y`` over the categories ``x``, save it and show it.

    Bars are ordered by ascending ``x``. Missing ``x`` values (NaN/None) are
    drawn last and labelled "Never". The inputs are not modified.

    Parameters
    ----------
    x : sequence of numbers
        Category values, shown as integers on the x axis.
    y : sequence of numbers
        Bar heights, aligned element-wise with ``x``.
    xlabel, ylabel : str
        Axis labels.
    name : str
        Output path passed to ``savefig``.
    lables : bool
        Whether to write each bar's value above it (via ``add_labels``).
        The misspelt name is kept for caller compatibility.
    lableprecision : int
        Rounding precision for those value labels; ``-1`` means integer.
    """
    bar_width = 0.25
    fig, ax = plt.subplots(figsize=(6, 2.5))

    ax.minorticks_on()
    ax.grid(color='lightgrey', linestyle='-', linewidth=1, which="minor", axis="y")
    ax.grid(color='grey', linestyle='-', linewidth=1, which="major", axis="y")

    x_values, y_values = list(x), list(y)
    # Reorder x and y with ONE permutation so each height stays paired with its
    # category; missing x goes last. The former sorted(zip(x, y)) + x.sort() pair
    # disagreed when x held NaN: Python's sort cannot order NaN (every comparison
    # is False), while numpy's sort moves it to the end. That shifted bar heights
    # against their labels. x.sort() also reordered the caller's array in place.
    order = sorted(
        range(len(x_values)),
        key=lambda i: (pd.isna(x_values[i]), 0 if pd.isna(x_values[i]) else x_values[i]),
    )
    y = [y_values[i] for i in order]
    tick_labels = [str(int(x_values[i])) if not pd.isna(x_values[i]) else "Never" for i in order]

    # Bar i is centred at i + bar_width; ticks and value labels reuse these positions.
    positions = [i + bar_width for i in range(len(order))]

    ax.bar(positions, y, color=colors, width=bar_width, edgecolor='black')

    if lables:
        add_labels(positions, y, lableprecision)

    ax.set_xlabel(xlabel)
    plt.xticks(positions, tick_labels, minor=False)
    # minorticks_on() above also enabled minor ticks on the categorical x axis; hide them.
    ax.xaxis.set_minor_locator(tck.NullLocator())

    ax.set_ylabel(ylabel)

    if y:
        # Leave 30 % headroom above the tallest bar (room for the value labels).
        ax.set_ylim(None, max(y) * 1.3)

    fig.savefig(name, bbox_inches='tight')
    plt.show()
    

In [None]:
def combine_column_and_file_name(column, file):
    """Return ``file`` and ``column`` joined by an underscore, ``file`` first.

    The output order is the reverse of the parameter order, e.g.
    ``combine_column_and_file_name("Epoch", "run.csv") == "run.csv_Epoch"``.
    """
    return '{}_{}'.format(file, column)

In [None]:
def get_data(columns: list[str], base_dir: Path = Path('../experiment/final')):
    """Collect the requested columns from every ``*.csv`` in ``base_dir`` into one frame.

    Parameters
    ----------
    columns : list[str]
        Column names to extract (e.g. ``"Epoch"``, ``"validate.BSD100.psnr_scale_2"``).
        A column missing from a file is skipped for that file.
    base_dir : Path
        Directory whose ``*.csv`` files are read. Defaults to the previously
        hard-coded ``../experiment/final``.

    Returns
    -------
    pd.DataFrame
        One column per (file, column) pair, named ``"<column>_<file name>"``
        (e.g. ``"Epoch_run_600.csv"``). Columns are outer-joined on the row
        index, so runs of different lengths keep all their rows and shorter runs
        are padded with NaN. The frame is empty if nothing matched.

    NOTE(review): column order follows ``Path.glob`` order, which is not
    guaranteed to be sorted or stable across machines/filesystems. Cells that
    pair hard-coded labels with columns by position depend on it. Confirm.
    """
    series = {}
    for csv_file in Path(base_dir).glob('*.csv'):
        df = pd.read_csv(csv_file)
        for column in columns:
            if column in df.columns:
                # The helper emits f"{file}_{column}" from (column, file), so passing
                # (file name, column) yields "<column>_<file name>". That is the key
                # format the regex filters in the plotting cells rely on.
                series[combine_column_and_file_name(csv_file.name, column)] = df[column]
    if not series:
        return pd.DataFrame()
    # concat aligns on the index with an outer join. The previous pattern
    # (res_df[key] = series on an empty frame) fixed the index to the first
    # file's length and silently truncated every longer run.
    return pd.concat(series, axis=1)

In [None]:
data = get_data(["validate.BSD100.psnr_scale_2", "Epoch"])
# Regex patterns are raw strings: "\." in a plain literal is an invalid escape
# sequence (SyntaxWarning since Python 3.12). The resulting pattern is unchanged.
filtered_x = data.filter(regex=r"Epoch_rfdn.*_600\.csv").head(300).transpose()
filtered_y = data.filter(regex=r"validate\.BSD100\.psnr_scale_2.*_600\.csv").head(300).transpose()
# NOTE(review): labels are paired with the runs by column position (CSV glob
# order in get_data). Confirm each label lands on the intended run.
lables = ["RFDN", "RFDN + RepVgg - Batch Norm", "RFDN + RepVgg"]
make_line_chart(filtered_x.values, filtered_y.values, lables, "Epoch", "BSD100 PSNR", "epoch_psnr_scale.svg", legend_location="upper center", ncol=2, columnspacing=-10)

In [None]:
data = get_data(["validate.mean_forward_pass_time", "config.pruning_interval"])
# Raw strings avoid the invalid "\." escape warning; the patterns are unchanged
# apart from escaping the ".csv" dot.
# Mean forward-pass time over all logged rows of each pruning-interval run.
filtered_y = data.filter(regex=r"validate\.mean_forward_pass_time_rfdn_advanced_600_epochs_no_batchnorm_pruning_([0-9]*|none)\.csv").mean()
# First logged row holds each run's pruning interval (NaN for the "none" run).
filtered_x = data.filter(regex=r"config\.pruning_interval_rfdn_advanced_600_epochs_no_batchnorm_pruning_([0-9]*|none)\.csv").head(1)
# x1000: presumably seconds -> milliseconds, matching the "[ms]" axis label.
make_bar_chart(filtered_x.values[0], filtered_y.values * 1000, "Pruning Interval", "Time [ms]", "pruning_interval_average_validation_time.svg", True, lableprecision=1)

In [None]:
data = get_data(["num_parameters", "config.pruning_interval"])
# Raw strings avoid the invalid "\." escape warning; ".csv" dots are escaped.
# Smallest parameter count seen in each run, used as its final count
# (assumes pruning only ever removes parameters).
filtered_y = data.filter(regex=r"num_parameters_rfdn_advanced_600_epochs_no_batchnorm_pruning_([0-9]*|none)\.csv").min()
filtered_x = data.filter(regex=r"config\.pruning_interval_rfdn_advanced_600_epochs_no_batchnorm_pruning_([0-9]*|none)\.csv").head(1)
make_bar_chart(filtered_x.values[0], filtered_y.values / 1000, "Pruning Interval", "Params [in 1000]", "pruning_interval_final_number_of_parameters.svg", True, lableprecision=-1)

In [None]:
data = get_data(["num_parameters", "Epoch"])
filtered_x = data.filter(regex=r"Epoch_rfdn_advanced_600_epochs_no_batchnorm_pruning_([0-9]*|none)\.csv").head(600).transpose()
filtered_y = data.filter(regex=r"num_parameters_rfdn_advanced_600_epochs_no_batchnorm_pruning_([0-9]*|none)\.csv").head(600).transpose()
# NOTE(review): labels are paired with runs by column position (CSV glob order);
# only the legend is re-sorted numerically afterwards. Confirm the mapping.
lables = [512, 256, 128, 64, 32, 16, 8, "Never"]
# The backslash before "#" is part of the label text (presumably a LaTeX escape for
# the SVG text). A raw string keeps it without the invalid "\#" escape warning.
make_line_chart(filtered_x.values, filtered_y.values / 1000, lables, "Epoch", r"Params [\# in 1000]", "pruning_interval_number_of_parameters.svg", legend_location="upper center")

In [None]:
data = get_data(["validate.BSD100.psnr_scale_2", "Epoch"])
# Raw strings avoid the invalid "\." escape warning; literal dots are escaped.
filtered_x = data.filter(regex=r"Epoch_rfdn_advanced_600_epochs_no_batchnorm_pruning_([0-9]*|none)\.csv").head(600).transpose()
filtered_y = data.filter(regex=r"validate\.BSD100\.psnr_scale_2_rfdn_advanced_600_epochs_no_batchnorm_pruning_([0-9]*|none)\.csv").head(600).transpose()
# NOTE(review): labels are paired with runs by column position (CSV glob order). Confirm.
lables = [512, 256, 128, 64, 32, 16, 8, "Never"]
make_line_chart(filtered_x.values, filtered_y.values, lables, "Epoch", "BSD100 PSNR", "pruning_interval_BSD100_psnr.svg", legend_location="upper center")

In [None]:
data = get_data(["validate.Urban100.psnr_scale_2", "Epoch"])
# Raw strings avoid the invalid "\." escape warning; literal dots are escaped.
filtered_x = data.filter(regex=r"Epoch_rfdn_advanced_600_epochs_no_batchnorm_pruning_([0-9]*|none)\.csv").head(600).transpose()
filtered_y = data.filter(regex=r"validate\.Urban100\.psnr_scale_2_rfdn_advanced_600_epochs_no_batchnorm_pruning_([0-9]*|none)\.csv").head(600).transpose()
# NOTE(review): labels are paired with runs by column position (CSV glob order). Confirm.
lables = [512, 256, 128, 64, 32, 16, 8, "Never"]
make_line_chart(filtered_x.values, filtered_y.values, lables, "Epoch", "Urban100 PSNR", "pruning_interval_Urban100_psnr.svg", legend_location="upper center")

In [None]:
data = get_data(["test.BSD100.psnr_scale_2", "config.batch_size_test"])
# Raw strings avoid the invalid "\." escape warning; the patterns are unchanged.
# First logged row of each batch-size run: its test PSNR and its test batch size.
filtered_y = data.filter(regex=r"test\.BSD100\.psnr_scale_2_rfdn_advanced_batch_size_test.*").head(1)
filtered_x = data.filter(regex=r"config\.batch_size_test_rfdn_advanced_batch_size_test.*").head(1)
make_bar_chart(filtered_x.values[0], filtered_y.values[0], "Batch Size", "BSD100 PSNR", "batch_size_test_BSD100_PSNR.svg", True, lableprecision=1)

In [None]:
data = get_data(["train.time", "config.batch_size"])
# Raw strings avoid the invalid "\." escape warning. "batchsize.*" replaces the
# typo "batchsize*" ("batchsiz" followed by any number of "e"), matching the
# config pattern below.
filtered_y = data.filter(regex=r"train\.time_rfdn_advanced_batchsize.*").mean()
filtered_x = data.filter(regex=r"config\.batch_size_rfdn_advanced_batchsize.*").head(1)
# NOTE(review): this plots the mean *train* time but saves it as
# "..._validation_time.svg". Confirm the intended file name, and that train.time
# is in seconds (x1000 -> ms).
make_bar_chart(filtered_x.values[0], filtered_y.values * 1000, "Batch Size", "Time [ms]", "batch_size_average_validation_time.svg", True, lableprecision=1)

In [None]:
data = get_data(["validate.Urban100.psnr_scale_2", "Epoch"])
# Raw strings avoid the invalid "\." / "\$" escape warnings; values are unchanged.
filtered_x = data.filter(regex=r"Epoch_rfdn_advanced_600_epochs_no_batchnorm_pruning_[0-9]+_(cawr_lr|lr)_.*\.csv").head(600).transpose()
filtered_y = data.filter(regex=r"validate\.Urban100\.psnr_scale_2_rfdn_advanced_600_epochs_no_batchnorm_pruning_[0-9]+_(cawr_lr|lr)_.*\.csv").head(600).transpose()
# "\$" makes matplotlib print a literal "$" instead of entering mathtext, so
# the SVG text keeps "$10^{-2}$" verbatim.
# NOTE(review): labels are paired with runs by column position (CSV glob order). Confirm.
lables = [r"LR \$10^{-2}\$", r"LR \$10^{-3}\$", r"LR \$10^{-4}\$", r"LR \$10^{-5}\$", r"CAWR \$10^{-2}\$", r"CAWR \$10^{-3}\$", r"CAWR \$10^{-4}\$", r"CAWR \$10^{-5}\$"]
make_line_chart(filtered_x.values, filtered_y.values, lables, "Epoch", "Urban100 PSNR", "epoch_Urban100_psnr.svg", legend_location="upper center", ncol=4)

In [None]:
data = get_data(["lr", "Epoch"])
# The two pruning-interval-8 runs compared: cosine annealing with warm restarts vs. multistep.
lr_runs = [
    "rfdn_advanced_600_epochs_no_batchnorm_pruning_8_cawr_lr_3_5.csv",
    "rfdn_advanced_600_epochs_no_batchnorm_pruning_8_lr_3_5.csv",
]
# First 100 rows of each run, one row per run after transposing.
filtered_x = data.filter(items=[f"Epoch_{run}" for run in lr_runs]).head(100).transpose()
filtered_y = data.filter(items=[f"lr_{run}" for run in lr_runs]).head(100).transpose()
lables = ["CAWR", "Multistep"]
make_line_chart(filtered_x.values, filtered_y.values, lables, "Epoch", "Learning Rate", "epoch_lr_cawr.svg", legend_location="upper center")

In [None]:
# Quality columns in the same order as the former hand-typed list:
# psnr then ssim; within each, BSD100, Urban100, Set5, Set14; scales 2, 3, 4.
eval_datasets = ["BSD100", "Urban100", "Set5", "Set14"]
eval_scales = [2, 3, 4]
quality_columns = [
    f"validate.{dataset}.{metric}_scale_{scale}"
    for metric in ("psnr", "ssim")
    for dataset in eval_datasets
    for scale in eval_scales
]
data = get_data(["num_parameters", "test.mean_forward_pass_time"] + quality_columns)

# Per "last" run: smallest parameter count / inference time and highest PSNR / SSIM.
# Raw strings avoid the invalid "\." escape warning. The former Series.transpose()
# calls were no-ops and are dropped.
num_parameters = data.filter(regex=r"num_parameters_.*last.*").min()

inference_time = data.filter(regex=r"test\.mean_forward_pass_time.*last.*").min()

psnrBSD100 = data.filter(regex=r"validate\.BSD100\.psnr_scale_.*last.*").max()
psnrUrban100 = data.filter(regex=r"validate\.Urban100\.psnr_scale_.*last.*").max()
psnrSet5 = data.filter(regex=r"validate\.Set5\.psnr_scale_.*last.*").max()
psnrSet14 = data.filter(regex=r"validate\.Set14\.psnr_scale_.*last.*").max()

ssimBSD100 = data.filter(regex=r"validate\.BSD100\.ssim_scale_.*last.*").max()
ssimUrban100 = data.filter(regex=r"validate\.Urban100\.ssim_scale_.*last.*").max()
ssimSet5 = data.filter(regex=r"validate\.Set5\.ssim_scale_.*last.*").max()
ssimSet14 = data.filter(regex=r"validate\.Set14\.ssim_scale_.*last.*").max()

# Parameters in thousands, rounded to whole numbers.
display((num_parameters / 1000).round().astype(int))

display(inference_time)

display(psnrBSD100.round(2), psnrUrban100.round(2), psnrSet5.round(2), psnrSet14.round(2))

display(ssimBSD100.round(4), ssimUrban100.round(4), ssimSet5.round(4), ssimSet14.round(4))