In [1]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

# Qualitative color sequence shipped with plotly express; the plotting cell
# picks one entry per run by position so both subplot traces of a run match.
colors = px.colors.qualitative.Plotly

In [2]:
import os
import numpy as np
from pathlib import Path

In [3]:
def extract_values(p):
    """Collect the overall AP and AR numbers reported in a log file.

    The file at ``p`` is scanned line by line. Every line containing the
    "all areas, maxDets=100, IoU=0.50:0.95" Average Precision summary text
    contributes one float to the first list; every line containing the
    matching Average Recall summary text contributes one float to the second.
    The number is whatever follows the last ``"] = "`` on that line.

    Args:
        p: Path (or path string) of the log file to read.

    Returns:
        A tuple ``(avg_precision, avg_recall)`` of two lists of floats, in the
        order the lines appear in the file. Either list may be empty.

    Raises:
        ValueError: If a matching line does not end in a parseable float.
    """
    ap_marker = "Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ]"
    ar_marker = "Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ]"

    avg_precision, avg_recall = [], []
    targets = ((ap_marker, avg_precision), (ar_marker, avg_recall))

    with open(p) as log_file:
        for row in log_file:
            for marker, series in targets:
                if marker in row:
                    # float() tolerates the trailing newline left by iteration.
                    series.append(float(row.split("] = ")[-1]))

    return avg_precision, avg_recall

In [14]:
# Root folder that holds one sub-directory per job id.
# Named `results_dir` (not `dir`) so the builtin `dir()` stays usable in the kernel.
# NOTE(review): absolute, user-specific path; consider making it configurable.
results_dir = Path("/home/andliao/results")

# Job id -> label under which the run is stored in `data`.
ngc_id = {
    2838350: "T",
    2838351: "S + T",
    2838352: "S",
    2838290: "shift-30-left",
    2838292: "shift-60-left",
    2838293: "shift-30-up",
    2838294: "shift-60-up",
    2838295: "scale-20-up",
    2838296: "scale-20-down",
}

# label -> [avg_precision series, avg_recall series]; runs without a log are omitted.
data = {}
for job_id, name in ngc_id.items():
    run_dir = results_dir / str(job_id)

    logs = sorted(run_dir.glob("2022*log"))
    if logs:
        # Use the lexicographically last log name.
        # Presumably the names are timestamps, making this the newest log; TODO confirm.
        avg_precision, avg_recall = extract_values(logs[-1])
        data[name] = [avg_precision, avg_recall]

    # WARNING: destructive side effect. Every *.pth file directly inside the
    # run directory is deleted each time this cell runs.
    # Path.unlink() replaces the former `os.system("rm ...")`, so paths with
    # spaces or shell metacharacters are handled and failures are reported
    # instead of being silently ignored.
    for checkpoint in run_dir.glob("*.pth"):
        try:
            checkpoint.unlink()
        except OSError as err:
            print(f"could not remove {checkpoint}: {err}")

In [17]:
# Two side-by-side panels: AP on the left, AR on the right.
# NOTE(review): with a single row, shared_xaxes=True probably does not link the
# two panels (it shares axes within a column); shared_xaxes="all" may be what
# was intended. Confirm against the plotly docs before changing.
fig = make_subplots(rows=1, cols=2, shared_xaxes=True, subplot_titles=("Target AP", "Target AR"))

# Spacing between consecutive points on the "epoch" axis.
# Assumes one evaluation every 8 epochs; TODO confirm against the training config.
epochs_per_eval = 8

for i, (name, values) in enumerate(data.items()):

    avg_precision, avg_recall = values
    # Wrap around the palette so more runs than colors cannot raise IndexError.
    run_color = colors[i % len(colors)]

    # Column 1 carries the legend entry; column 2 reuses it through legendgroup,
    # so toggling a run in the legend hides or shows both of its curves.
    for col, series, in_legend in ((1, avg_precision, True), (2, avg_recall, False)):
        trace = go.Scatter(
            x=epochs_per_eval * np.arange(len(series)),
            y=series,
            name=name,
            legendgroup=name,
            showlegend=in_legend,
            mode="lines",
            line=dict(color=run_color),
        )
        fig.add_trace(trace, row=1, col=col)

    # Final value of each series; None when the log had no matching lines
    # (previously an IndexError on the empty list).
    last_precision = avg_precision[-1] if avg_precision else None
    last_recall = avg_recall[-1] if avg_recall else None
    print(name, last_precision, last_recall)

T 0.372 0.463
S + T 0.366 0.476
S 0.054 0.182
shift-30-left 0.391 0.515
shift-60-left 0.315 0.446
shift-30-up 0.371 0.467
shift-60-up 0.371 0.476
scale-20-up 0.429 0.528
scale-20-down 0.361 0.475


In [16]:
# Final styling pass. Each update_* call returns the figure, so the chained
# expression is still the cell's last value and the figure is displayed.
(
    fig.update_layout(
        width=700,
        height=400,
        template="plotly_white",
        # Horizontal legend positioned above the plotting area.
        legend={"orientation": "h", "x": 0.1, "y": 1.45},
        margin={"t": 50, "l": 30, "b": 30, "r": 10},
    )
    # Black axis lines on every subplot; only the x axes get a title.
    .update_xaxes(linewidth=2, linecolor="black", title="epoch")
    .update_yaxes(linewidth=2, linecolor="black")
)