In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import torch

# Let the process continue when a second copy of the Intel OpenMP runtime is
# loaded. NOTE(review): this presumably works around the "OMP: Error #15 ...
# already initialized" abort seen when torch and numpy/matplotlib each ship
# their own OpenMP runtime (the backslash paths below suggest a Windows box).
# The OpenMP error message itself calls this an unsafe workaround that can
# crash or silently give wrong results -- confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

In [None]:
def get_loss_acc(line):
    """Parse a ``"Loss:<loss>,Acc:<accuracy>"`` metrics fragment.

    Whitespace around the numbers is tolerated (``float`` ignores it) and any
    text after a second comma is ignored.

    Returns:
        ``(loss, accuracy)`` as a tuple of floats.
    """
    fields = line.split(",")
    loss = float(fields[0].replace("Loss:", ""))
    accuracy = float(fields[1].replace("Acc:", ""))
    return loss, accuracy

def load_logs(log_file_path):
    """Parse a training log file into per-epoch metrics and run metadata.

    The file is scanned line by line for the markers below. Every check is an
    independent ``if``, so one line may feed several fields.

    * ``Learing rate: <value>`` or ``Learning rate: <value>`` (stored as text)
    * ``Base Model: <name>``
    * ``Best Epoch: <int>``
    * ``...TRAIN: Loss: <float>, Acc: <float>%`` (one per epoch)
    * ``...VALIDATION: Loss: <float>, Acc: <float>%`` (one per epoch)
    * ``Test Set Accuracy: <value>`` (stored as text)

    Args:
        log_file_path: path (str or os.PathLike) of the log file.

    Returns:
        A dict with these keys:

        * ``"epochs"``: 1-based, one entry per TRAIN line.
        * ``"best epoch"``: -1 if absent.
        * ``"base model"``: "" if absent.
        * ``"test accuracy"``: "N/A" if absent.
        * ``"learning rate"``: 0 if absent.
        * ``"train"`` and ``"validation"``: each a dict of ``"loss"`` and
          ``"accuracy"`` float lists.
    """
    # The original parser only recognised the misspelled "Learing rate:"
    # (presumably what the training script writes). Accept the correct
    # spelling too, so fixing the script does not silently leave lr at 0.
    lr_markers = ("Learing rate:", "Learning rate:")

    train_loss, train_acc = [], []
    validation_loss, validation_acc = [], []
    best_epoch = -1
    base_model = ""
    lr = 0
    test_accuracy = "N/A"

    def value_after(line, marker):
        # Text after the first `marker`, with surrounding whitespace and the
        # newline removed. This works whether or not a space follows the colon.
        return line.split(marker, 1)[1].strip()

    with open(log_file_path) as file:
        for line in file:  # iterate lazily instead of readlines()
            for marker in lr_markers:
                if marker in line:
                    lr = value_after(line, marker)
            if "Base Model:" in line:
                base_model = value_after(line, "Base Model:")
            if "Best Epoch:" in line:
                best_epoch = int(value_after(line, "Best Epoch:"))
            if "TRAIN:" in line:
                loss, acc = get_loss_acc(value_after(line, "TRAIN:").replace("%", ""))
                train_loss.append(loss)
                train_acc.append(acc)
            if "VALIDATION:" in line:
                loss, acc = get_loss_acc(value_after(line, "VALIDATION:").replace("%", ""))
                validation_loss.append(loss)
                validation_acc.append(acc)
            if "Test Set Accuracy:" in line:
                test_accuracy = value_after(line, "Test Set Accuracy:")

    return {
        "epochs": list(range(1, len(train_loss) + 1)),
        "best epoch": best_epoch,
        "base model": base_model,
        "test accuracy": test_accuracy,
        "learning rate": lr,
        "train": {
            "loss": train_loss,
            "accuracy": train_acc,
        },
        "validation": {
            "loss": validation_loss,
            "accuracy": validation_acc,
        },
    }

In [None]:
def load_results(file_path, num_classes=None):
    """Load per-class precision / recall / F1 scores from a results CSV.

    The file is read with pandas' default comma separator. Each record is
    ``;``-separated, so it lands whole in the first column. The fields are:

    * 0: unused (presumably a row index -- confirm)
    * 1: class name
    * 2: precision
    * 3: recall
    * 4: F1

    The first line is consumed as the header by ``pd.read_csv``.

    Args:
        file_path: path or file-like object accepted by ``pd.read_csv``.
        num_classes: divisor used for the "Average" entry. Defaults to the
            number of data rows read. Pass 33 to reproduce the previous
            hard-coded divisor.

    Returns:
        A dict mapping each class name to
        ``{"precision": float, "recall": float, "f1score": float}``.
        It also has an "Average" entry holding the unweighted mean of each
        metric, or 0.0 per metric when the divisor is 0.
    """
    df = pd.read_csv(file_path)

    data = {}
    totals = {"precision": 0.0, "recall": 0.0, "f1score": 0.0}

    # iloc avoids the deprecated positional fallback of Series.__getitem__
    # that `row[0]` relied on.
    for record in df.iloc[:, 0]:
        fields = record.split(";")
        scores = {
            "precision": float(fields[2]),
            "recall": float(fields[3]),
            "f1score": float(fields[4]),
        }
        data[fields[1]] = scores
        for metric, value in scores.items():
            totals[metric] += value

    # Previously always divided by 33, which is wrong for any file that does
    # not contain exactly 33 rows.
    divisor = num_classes if num_classes is not None else len(df)
    data["Average"] = {
        metric: (total / divisor if divisor else 0.0)
        for metric, total in totals.items()
    }

    return data

def compare_results(data):
    """Print per-class metrics of several result sets next to each other.

    For every key of the first result dict, prints the key and then one line
    each for precision, recall and F1. Each line has one 4-decimal column per
    entry of ``data``.

    Args:
        data: non-empty list of dicts shaped like ``load_results`` output.
            Every dict must contain all keys of ``data[0]``.
    """
    metric_labels = {"precision": "Precision", "recall": "Recall", "f1score": "F1-Score"}

    for key in list(data[0].keys()):
        print(key)
        columns = dict.fromkeys(metric_labels, "")
        for result in data:
            for metric in columns:
                columns[metric] += f" {result[key][metric]:.4f} "
        for metric, label in metric_labels.items():
            print(f"    {label:<9}  {columns[metric]}")

In [None]:
def plot_single(data, show_best_epoch=False):
    """Plot train vs. validation loss and accuracy curves for one run.

    Args:
        data: dict as returned by ``load_logs``.
        show_best_epoch: if True and the log recorded a best epoch (i.e. it is
            not the -1 "absent" value), draw a grey vertical line there on
            both panels.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(2, 1, figsize=(10, 12))
    fig.suptitle(f'{data["base model"]} lr={data["learning rate"]}', fontsize=15, y=0.95)

    panels = ((loss_ax, "loss", "Loss"), (acc_ax, "accuracy", "Accuracy"))
    for ax, metric, title in panels:
        for split, label in (("train", "Train"), ("validation", "Validation")):
            values = data[split][metric]
            # Use each series' own length for x. data["epochs"] follows the
            # TRAIN count, so a log cut off between TRAIN and VALIDATION lines
            # made ax.plot raise on mismatched x/y lengths.
            ax.plot(range(1, len(values) + 1), values, label=label)
        ax.legend()
        ax.grid()
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        if show_best_epoch and data["best epoch"] >= 0:
            ax.axvline(x=data["best epoch"], color="gray")

    plt.show()

In [None]:
# Line colours for compare_plot, cycled when there are more runs than colours.
# Every entry must be visible on the default white background; the previous
# "#FFFFFF" entry drew an invisible sixth line.
COLORS = ["blue", "green", "red", "orange", "purple", "brown", "magenta", "gray", "olive", "cyan"]

def compare_plot(data, max_epoch=100):
    """Overlay the training curves of several runs and print a summary of each.

    For every run it prints:

    * the test accuracy;
    * the best validation accuracy ("N/A" if the log has none);
    * the lowest validation loss ("N/A" if the log has none).

    It then draws four stacked panels: train loss, train accuracy, validation
    loss and validation accuracy.

    Args:
        data: list of dicts as returned by ``load_logs``.
        max_epoch: right x-limit of every panel (the left limit is 0).
    """
    panels = (
        ("train", "loss", "Train Loss"),
        ("train", "accuracy", "Train Accuracy"),
        ("validation", "loss", "Validation Loss"),
        ("validation", "accuracy", "Validation Accuracy"),
    )
    fig, axes = plt.subplots(len(panels), 1, figsize=(10, 16))

    labels = []
    titles = []

    print("Overall Test Accuracy")

    for i, run in enumerate(data):
        label = f"{run['base model']} lr={run['learning rate']}"
        summary = f"{label} ({run['best epoch']}) = {run['test accuracy']}"
        print(summary)

        val_acc = run["validation"]["accuracy"]
        val_loss = run["validation"]["loss"]
        # max()/min() raise on empty lists. A log without VALIDATION lines
        # would otherwise abort the whole comparison.
        print(f'    Best Validation Accuracy {max(val_acc) if val_acc else "N/A"}')
        print(f'    Lowest Validation Loss: {min(val_loss) if val_loss else "N/A"}')

        labels.append(label)
        titles.append(summary)

        color = COLORS[i % len(COLORS)]
        for ax, (split, metric, _) in zip(axes, panels):
            values = run[split][metric]
            # x comes from each series' own length. Train and validation counts
            # can differ when a run stopped mid-epoch.
            ax.plot(range(1, len(values) + 1), values, color=color, linewidth=1)

    for ax, (_, _, title) in zip(axes, panels):
        ax.set_title(title)
        ax.axis(xmin=0, xmax=max_epoch)
        ax.grid()
        ax.legend(labels)

    # dict.fromkeys drops duplicate titles (as set() did) but keeps run order,
    # so the figure title is deterministic.
    fig.suptitle("\n".join(dict.fromkeys(titles)), fontsize=20)
    plt.show()

In [None]:
# Parsed training logs, keyed as feature set -> model -> run label.
# Paths are built with pathlib: the previous "Checkpoints\..." strings only
# resolved on Windows, where backslash is a path separator.
CHECKPOINT_DIR = Path("Checkpoints")


def _run_log(run_dir, timestamp):
    """Parse CHECKPOINT_DIR/<run_dir>/<timestamp>/log.txt with load_logs."""
    return load_logs(CHECKPOINT_DIR / run_dir / timestamp / "log.txt")


logs_nofeatures = {
    "resnet50": {
        "0.01": _run_log("resnet50_nofeatures", "2024-02-04 00_58_13"),
        "0.001": _run_log("resnet50_nofeatures", "2024-02-03 17_43_07"),
        "0.0001": _run_log("resnet50_nofeatures", "2024-02-03 17_52_30"),
    },
    "densenet161": {
        "0.001": _run_log("densenet161_nofeatures", "2024-02-11 08_08_26"),
        "0.001-NLL": _run_log("densenet161_nofeatures", "2024-02-12 19_32_21"),
    },
    # NOTE(review): the key says resnet151 but the run directory is resnet152.
    # The key is kept because other cells index it by this name.
    "resnet151": {
        "0.001": _run_log("resnet152_nofeatures", "2024-02-09 07_43_24"),
    },
    "vit": {
        "0.001": _run_log("visiontransformerb16_nofeatures", "2024-02-10 09_39_32"),
        "0.001-NLL": _run_log("visiontransformerb16_nofeatures", "2024-02-11 20_30_58"),
        "0.001-NLL-1": _run_log("visiontransformerb16_nofeatures", "2024-02-13 21_35_14"),
        "0.001-NLL-2": _run_log("visiontransformerb16_nofeatures", "2024-02-14 20_18_33"),
    },
}

logs_channel = {
    "resnet50": {
        "0.001": _run_log("resnet50_channels", "2024-02-04 08_12_54"),
    }
}

logs_eech = {
    "resnet50": {
        "0.001": _run_log("resnet50_eech", "2024-02-06 13_04_01"),
    }
}

logs_allfeatures = {
    "resnet50": {
        "0.001": _run_log("resnet50_allfeatures", "2024-02-04 08_10_15"),
    }
}

In [None]:
# Compare the four ViT runs from logs_nofeatures.
# Other runs that can be listed instead:
#   logs_eech / logs_channel / logs_allfeatures ["resnet50"]["0.001"]
#   logs_nofeatures["resnet50"]["0.01" | "0.001" | "0.0001"]
#   logs_nofeatures["densenet161"]["0.001" | "0.001-NLL"]
#   logs_nofeatures["resnet151"]["0.001"]
compare_plot([
    logs_nofeatures["vit"][run_label]
    for run_label in ("0.001", "0.001-NLL", "0.001-NLL-1", "0.001-NLL-2")
])

In [None]:
# Per-class test metrics, keyed as feature set -> model -> run label.
# Each run directory holds fc_train_results.csv and
# fc_train_results_inrange.csv. The latter is stored under the run label plus
# an "-R" suffix. Paths use pathlib: the previous "Checkpoints\..." strings
# only resolved on Windows.
CHECKPOINT_DIR = Path("Checkpoints")


def _run_results(run_dir, runs):
    """Load both result CSVs of each run in ``runs`` ({label: timestamp dir}).

    Returns {label: overall results, ..., label + "-R": in-range results, ...}.
    Entries follow the order of ``runs``, with all plain entries before the
    "-R" ones (the same key order as the hand-written dicts this replaces).
    """
    run_root = CHECKPOINT_DIR / run_dir
    results = {
        label: load_results(run_root / timestamp / "fc_train_results.csv")
        for label, timestamp in runs.items()
    }
    results.update({
        f"{label}-R": load_results(run_root / timestamp / "fc_train_results_inrange.csv")
        for label, timestamp in runs.items()
    })
    return results


results_nofeatures = {
    "resnet50": _run_results("resnet50_nofeatures", {
        "0.01": "2024-02-04 00_58_13",
        "0.001": "2024-02-03 17_43_07",
        "0.0001": "2024-02-03 17_52_30",
    }),
    # NOTE(review): the key says resnet151 but the run directory is resnet152.
    # The key is kept because other cells index it by this name.
    "resnet151": _run_results("resnet152_nofeatures", {
        "0.001": "2024-02-09 07_43_24",
    }),
    "densenet161": _run_results("densenet161_nofeatures", {
        "0.001": "2024-02-11 08_08_26",
        "0.001-NLL": "2024-02-12 19_32_21",
    }),
    "vit": _run_results("visiontransformerb16_nofeatures", {
        "0.001": "2024-02-10 09_39_32",
        "0.001-NLL": "2024-02-11 20_30_58",
        "0.001-NLL-1": "2024-02-13 21_35_14",
        "0.001-NLL-2": "2024-02-14 20_18_33",
    }),
}

results_allfeatures = {
    "resnet50": _run_results("resnet50_allfeatures", {"0.001": "2024-02-04 08_10_15"}),
}

results_channel = {
    "resnet50": _run_results("resnet50_channels", {"0.001": "2024-02-04 08_12_54"}),
}

# Loading is disabled for the eech run, so its entry is a None placeholder.
# Do not pass it to compare_results, which indexes every entry.
results_eech = {
    "resnet50": {
        "0.001": None,  # load_results(CHECKPOINT_DIR / "resnet50_eech" / "2024-02-06 13_04_01" / "fc_train_results.csv")
    }
}

compare_results([
    results_nofeatures["resnet50"]["0.001"],
    results_nofeatures["resnet151"]["0.001"],
    results_nofeatures["densenet161"]["0.001"],
    results_nofeatures["vit"]["0.001"],
])