# Setup

In [11]:
import numpy as np
import pandas as pd
import plotly.express as px         # pip install plotly 터미널에서 돌릴 것
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots


# Render every Plotly figure with the dark theme; delete this line to return to the default light theme.
pio.templates.default = "plotly_dark"

In [12]:
# Datasets whose training logs are loaded below, one CSV per name.
# The names appear to be the 20 CIFAR-100 coarse (superclass) labels.
DATASETS = [
    "aquatic_mammals",
    "fish",
    "flowers",
    "food_containers",
    "fruit_and_vegetables",
    "household_electrical_devices",
    "household_furniture",
    "insects",
    "large_carnivores",
    "large_man-made_outdoor_things",
    "large_natural_outdoor_scenes",
    "large_omnivores_and_herbivores",
    "medium_mammals",
    "non-insect_invertebrates",
    "people",
    "reptiles",
    "small_mammals",
    "trees",
    "vehicles_1",
    "vehicles_2",
]

In [None]:
# Directory holding one `wo_backbone_results_{dataset}.csv` log per dataset.
result_path = "../wo_backbone_train_log/"
# Rows of the merged table that belong to the initial training run; pruning rounds start after them.
BASE_TRAIN_ROWS = 100
# Epochs per pruning round; a complete round runs from epoch 1 up to this value.
PRUNE_ROUND_EPOCHS = 20

result = {}            # dataset -> {"train": DataFrame, "val": DataFrame} with the raw rows of each Mode
combined_results = {}  # dataset -> DataFrame pairing train and val metrics epoch by epoch
rnd_dataframes = {}    # "{dataset}_prune_{k}" -> k-th complete pruning round (index reset to 0..)


def _pair_train_val(train, val):
    """Pair train and val log rows that share the same non-zero epoch.

    Walks both frames in order with two pointers. When the epochs match (and
    are not 0) one output row is emitted and both pointers advance; otherwise
    the side with the smaller epoch is skipped (val is skipped on an epoch-0 tie),
    so epoch-0 rows are never paired.

    Args:
        train: rows with Mode == "train"; needs a 0..n-1 index and the columns
            "Epoch", "Loss", "Accuracy".
        val: rows with Mode == "val"; same requirements.

    Returns:
        DataFrame with columns Epoch, Train_loss, Train_accuracy, Val_loss,
        Val_accuracy in log order (a column-less empty frame if nothing matched).
    """
    rows = []
    i = j = 0
    while i < len(train) and j < len(val):
        train_epoch = train.at[i, "Epoch"]
        val_epoch = val.at[j, "Epoch"]
        if train_epoch == val_epoch and train_epoch != 0:
            rows.append({
                "Epoch": train_epoch,
                "Train_loss": train.at[i, "Loss"],
                "Train_accuracy": train.at[i, "Accuracy"],
                "Val_loss": val.at[j, "Loss"],
                "Val_accuracy": val.at[j, "Accuracy"],
            })
            i += 1
            j += 1
        elif train_epoch < val_epoch:
            i += 1
        else:
            j += 1
    return pd.DataFrame(rows)


def _split_prune_rounds(combined_df, start=BASE_TRAIN_ROWS, round_len=PRUNE_ROUND_EPOCHS):
    """Cut the merged table into consecutive `round_len`-row chunks starting at row `start`.

    A chunk is kept only if it is full length and runs from epoch 1 to epoch
    `round_len`; e.g. a truncated final round is dropped. Kept chunks get a
    fresh 0-based index.
    """
    rounds = []
    for begin in range(start, len(combined_df), round_len):
        chunk = combined_df.iloc[begin:begin + round_len]
        if (len(chunk) == round_len
                and chunk["Epoch"].iloc[0] == 1
                and chunk["Epoch"].iloc[-1] == round_len):
            rounds.append(chunk.reset_index(drop=True))
    return rounds


for dataset in DATASETS:
    df = pd.read_csv(result_path + f'wo_backbone_results_{dataset}.csv')

    train = df[df["Mode"] == "train"].reset_index(drop=True)
    val = df[df["Mode"] == "val"].reset_index(drop=True)
    result[dataset] = {"train": train, "val": val}

    combined_df = _pair_train_val(train, val)
    combined_results[dataset] = combined_df

    # Pruning stage: store each complete 20-epoch round as "{dataset}_prune_{k}", k starting at 1
    # (the original note expected 10 such rounds per dataset).
    for rnd_count, rnd_data in enumerate(_split_prune_rounds(combined_df), start=1):
        rnd_dataframes[f"{dataset}_prune_{rnd_count}"] = rnd_data

In [69]:
# Number of datasets loaded into `result` (the loading loop adds one entry per DATASETS item).
len(result)

20

In [68]:
# Show the merged per-epoch train/val metrics for "fish" (rendered as a rich table by Jupyter).
combined_results["fish"]

Unnamed: 0,Epoch,Train_loss,Train_accuracy,Val_loss,Val_accuracy
0,1,1.610,18.96,1.608,20.6
1,2,1.606,21.04,1.586,32.4
2,3,1.492,32.28,1.422,33.4
3,4,1.379,32.84,1.328,33.8
4,5,1.324,39.60,1.277,40.4
...,...,...,...,...,...
295,16,0.803,69.32,0.831,66.8
296,17,0.784,69.80,0.824,66.8
297,18,0.803,68.80,0.818,66.8
298,19,0.772,70.12,0.826,66.0


In [78]:
# Show the first pruning round of "fish"; a bare last expression uses the rich table display instead of print's plain text.
rnd_dataframes["fish_prune_1"]

    Epoch  Train_loss  Train_accuracy  Val_loss  Val_accuracy
0       1       0.638           75.60     0.753          72.4
1       2       0.637           75.92     0.732          73.2
2       3       0.634           76.44     0.748          71.8
3       4       0.634           75.68     0.738          72.2
4       5       0.643           74.88     0.739          71.8
5       6       0.629           76.36     0.744          73.2
6       7       0.629           76.20     0.737          72.0
7       8       0.637           75.20     0.737          71.8
8       9       0.624           75.48     0.741          72.4
9      10       0.651           75.60     0.738          71.8
10     11       0.635           75.80     0.734          73.2
11     12       0.627           76.20     0.757          72.4
12     13       0.639           75.64     0.751          71.6
13     14       0.627           76.08     0.743          72.8
14     15       0.637           75.32     0.747          70.4
15     1

# Functions

In [59]:
def compare_plot(dataset, metric_train, metric_val, x_axis):
    """Overlay a train metric and a validation metric for one dataset.

    Args:
        dataset: key into ``combined_results``.
        metric_train: column of that frame plotted as the train curve (e.g. "Train_loss").
        metric_val: column plotted as the validation curve (e.g. "Val_loss").
        x_axis: column used for the x values (e.g. "Epoch").

    Shows the figure and returns None.
    """
    fig = go.Figure()
    df = combined_results[dataset]

    # Train curve. Index by the requested column name; the original indexed with the
    # leaked global `train` DataFrame from the loading cell.
    fig.add_trace(go.Scatter(
        x=df[x_axis],
        y=df[metric_train],
        mode='lines',
        name=f'{dataset} - Train {metric_train}'
    ))
    # Validation curve.
    fig.add_trace(go.Scatter(
        x=df[x_axis],
        y=df[metric_val],
        mode='lines',
        name=f'{dataset} - Val {metric_val}'
    ))

    fig.update_layout(
        title=f'{metric_train} & {metric_val} for {dataset}',
        xaxis_title=x_axis,
        yaxis_title="Value"
    )
    fig.show()


In [None]:
def train_result(task, col_num, Train_loss, Val_loss, Train_accuracy, Val_accuracy, Epoch, n_rows=100):
    """Plot loss and accuracy curves for the first `n_rows` merged rows of `task`.

    Loss curves (dashed) use the left y-axis and accuracy curves (solid) the
    right y-axis, all against `Epoch`. Only subplot (1, 1) is drawn into;
    `col_num` only sets how many columns the subplot grid has.

    Args:
        task: key into ``combined_results``.
        col_num: number of subplot columns to create.
        Train_loss, Val_loss, Train_accuracy, Val_accuracy, Epoch: column names
            in that frame.
        n_rows: number of leading rows to plot; defaults to 100, the previously
            hard-coded cutoff.

    Shows the figure and returns None.
    """
    fig = make_subplots(rows=1, cols=col_num, subplot_titles=[task], specs=[[{"secondary_y": True}] * col_num])

    df = combined_results[task].iloc[:n_rows]

    # (column, legend label, line style, plot on secondary y-axis?)
    series = [
        (Train_loss, 'Train Loss', dict(color='blue', dash='dash'), False),
        (Val_loss, 'Val Loss', dict(color='red', dash='dash'), False),
        (Train_accuracy, 'Train Accuracy', dict(color='purple'), True),
        (Val_accuracy, 'Val Accuracy', dict(color='green'), True),
    ]
    for column, label, line, secondary in series:
        fig.add_trace(
            go.Scatter(x=df[Epoch], y=df[column], mode='lines', name=label, line=line),
            row=1, col=1, secondary_y=secondary
        )

    fig.update_layout(
        title=f"{task} (First {n_rows} Rows)",
        height=400,
        width=1000,
        legend=dict(orientation="h", x=0.5, y=-0.1, xanchor="center")
    )

    fig.show()

def prune_result(task, col_num, Train_loss, Val_loss, Train_accuracy, Val_accuracy, Epoch):
    """Plot every stored pruning round of `task` as a grid of subplots.

    Each subplot shows one ``rnd_dataframes["{task}_prune_{k}"]`` frame. Loss
    curves (dashed) use the left y-axis and accuracy curves (solid) the right
    y-axis, all against `Epoch`. Subplots are filled row by row, `col_num` per row.

    Args:
        task: dataset name used as the key prefix in ``rnd_dataframes``.
        col_num: number of subplot columns; as many rows as needed are created.
        Train_loss, Val_loss, Train_accuracy, Val_accuracy, Epoch: column names
            in each round's frame.

    Shows the figure and returns None. If no rounds exist for `task`, prints a
    message and returns instead of letting make_subplots fail on rows=0.
    """
    prune_plots = [key for key in rnd_dataframes if key.startswith(f"{task}_prune_")]
    if not prune_plots:
        print(f"No pruning rounds found for {task}")
        return

    rows = -(-len(prune_plots) // col_num)  # ceiling division

    fig = make_subplots(rows=rows, cols=col_num, subplot_titles=prune_plots, specs=[[{"secondary_y": True}] * col_num] * rows)

    # (column, legend label, line style, plot on secondary y-axis?)
    series = [
        (Train_loss, 'Train Loss', dict(color='blue', dash='dash'), False),
        (Val_loss, 'Val Loss', dict(color='red', dash='dash'), False),
        (Train_accuracy, 'Train Accuracy', dict(color='purple'), True),
        (Val_accuracy, 'Val Accuracy', dict(color='green'), True),
    ]

    for i, prune_key in enumerate(prune_plots):
        df = rnd_dataframes[prune_key]
        row, col = divmod(i, col_num)
        for column, label, line, secondary in series:
            fig.add_trace(
                # Group traces by label and only list the first subplot's traces in the
                # legend; otherwise every subplot adds four duplicate legend entries.
                go.Scatter(x=df[Epoch], y=df[column], mode='lines', name=label, line=line,
                           legendgroup=label, showlegend=(i == 0)),
                row=row + 1, col=col + 1, secondary_y=secondary
            )

    fig.update_layout(
        title=f"{task} Prune Stages (Epoch 1-20)",
        height=300 * rows,
        width=1000,
        legend=dict(orientation="h", x=0.5, y=-0.1, xanchor="center")
    )

    fig.show()


In [97]:

# Initial training curves (first 100 merged rows) for "fish".
train_result(
    "fish",
    col_num=1,
    Epoch="Epoch",
    Train_loss="Train_loss",
    Val_loss="Val_loss",
    Train_accuracy="Train_accuracy",
    Val_accuracy="Val_accuracy",
)

# One subplot per stored pruning round for "fish", two per row.
prune_result(
    "fish",
    col_num=2,
    Epoch="Epoch",
    Train_loss="Train_loss",
    Val_loss="Val_loss",
    Train_accuracy="Train_accuracy",
    Val_accuracy="Val_accuracy",
)
