In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the pcam helpers used below) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import glob
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np

from pcam.utils import parse_supervised_metrics, parse_eval_metrics

In [3]:

def parse_metrics_file(glob_str: str, parse_function):
  """Parse every metrics file matching ``glob_str`` and rank the runs.

  Each matched path is expected to end in ``<a>/<b>/<c>/<d>/<e>/<file>``.
  The directories ``a-b-c`` form the run's base name and ``d-e`` its version
  key. Components are counted from the end of the normalised path, so the
  pattern may be relative or absolute and works with the platform's path
  separator.

  Args:
    glob_str: Glob pattern selecting the metrics files.
    parse_function: Callable that receives one file path and returns a
      mapping containing a ``"val_accuracy"`` sequence of rows. Element
      ``[2]`` of the last row is used as that version's score.

  Returns:
    A ``(stats, sorted_list)`` tuple. ``stats[basename][version]`` holds the
    value returned by ``parse_function``. ``sorted_list`` contains the base
    names ordered by their best score over all versions, highest first
    (scores are floored at 0; ties are ordered by name, descending).

  Raises:
    IndexError: If a matched path has fewer than six components.
  """
  stats = defaultdict(dict)
  for path in sorted(glob.glob(glob_str)):
    # Normalise first so mixed or platform-specific separators split cleanly.
    parts = os.path.normpath(path).split(os.sep)
    basename = "-".join((parts[-6], parts[-5], parts[-4]))
    version = "-".join((parts[-3], parts[-2]))
    stats[basename][version] = parse_function(path)

  ranked = []
  for basename, versions in stats.items():
    best = 0
    for metrics in versions.values():
      val_stats = metrics["val_accuracy"]
      # A version without any validation rows cannot contribute a score.
      if len(val_stats) > 0 and val_stats[-1][2] > best:
        best = val_stats[-1][2]
    ranked.append((best, basename))

  sorted_list = [basename for _, basename in sorted(ranked, reverse=True)]
  return stats, sorted_list

In [4]:
def calculate_average(model, stats):
  """Return the mean and standard deviation of a model's test accuracy.

  For every version stored under ``stats[model]``, element ``[2]`` of the
  first ``"test_accuracy"`` row is collected; the result is
  ``(np.mean(values), np.std(values))``.
  """
  per_version_scores = []
  for run_metrics in stats[model].values():
    first_test_row = run_metrics["test_accuracy"][0]
    per_version_scores.append(first_test_row[2])

  mean_score = np.mean(per_version_scores)
  spread = np.std(per_version_scores)
  return mean_score, spread

In [5]:
# Parse the metrics.csv of every supervised (Downstream) run; the second value
# returned is the list of run names ranked by validation accuracy.
supervised_stats, supervised_list = parse_metrics_file(
  glob_str="logs/PCam/Downstream/*/*/*_steps/v_*/*/metrics.csv",
  parse_function=parse_supervised_metrics,
)

In [6]:

from pcam.utils import plot_results_plotly

# Groups of ImageNet-pretrained runs to plot together; the names are
# presumably looked up in supervised_stats by plot_results_plotly.
pretrained_long_vs_short = [
  'Pretrained_ImageNet-aug-128000_spc-15000_steps',
  'Pretrained_ImageNet-aug-51200_spc-2000_steps',
]
pretrained_aug_runs = [
  'Pretrained_ImageNet-aug-51200_spc-2000_steps',
  'Pretrained_ImageNet-aug-25600_spc-2000_steps',
  'Pretrained_ImageNet-aug-12800_spc-2000_steps',
]
pretrained_notr_runs = [
  'Pretrained_ImageNet-notr-51200_spc-2000_steps',
  'Pretrained_ImageNet-notr-25600_spc-2000_steps',
  'Pretrained_ImageNet-notr-12800_spc-2000_steps',
]

plot_results_plotly(stats=supervised_stats, sorted_list=pretrained_long_vs_short)
plot_results_plotly(stats=supervised_stats, sorted_list=pretrained_aug_runs)
plot_results_plotly(stats=supervised_stats, sorted_list=pretrained_notr_runs)
                          

In [7]:
# Groups of from-scratch runs to plot together; the names are presumably
# looked up in supervised_stats by plot_results_plotly.
scratch_long_vs_short = [
  'From_Scratch-aug-128000_spc-15000_steps',
  'From_Scratch-aug-51200_spc-2000_steps',
]
scratch_aug_runs = [
  'From_Scratch-aug-51200_spc-2000_steps',
  'From_Scratch-aug-25600_spc-2000_steps',
  'From_Scratch-aug-12800_spc-2000_steps',
]
scratch_notr_runs = [
  'From_Scratch-notr-51200_spc-2000_steps',
  'From_Scratch-notr-25600_spc-2000_steps',
  'From_Scratch-notr-12800_spc-2000_steps',
]

plot_results_plotly(stats=supervised_stats, sorted_list=scratch_long_vs_short)
plot_results_plotly(stats=supervised_stats, sorted_list=scratch_aug_runs)
plot_results_plotly(stats=supervised_stats, sorted_list=scratch_notr_runs)
                          

In [8]:
# Parse the metrics.csv of every 2000-step evaluation run. The pattern has no
# placeholders, so it is a plain string literal rather than an f-string.
eval_stats, eval_list = parse_metrics_file(
  "logs/PCam/Eval/*/*_spc/2000_steps/*/*/metrics.csv",
  parse_eval_metrics,
)

In [9]:
# FullFineTune evaluation runs, listed from the smallest to the largest "spc" value.
full_finetune_runs = [
  'PCam-FullFineTune-1250_spc-2000_steps',
  'PCam-FullFineTune-6400_spc-2000_steps',
  'PCam-FullFineTune-12800_spc-2000_steps',
  'PCam-FullFineTune-25600_spc-2000_steps',
  'PCam-FullFineTune-51200_spc-2000_steps',
]
plot_results_plotly(stats=eval_stats, sorted_list=full_finetune_runs)

In [10]:
# FreezeBackbone evaluation runs, listed from the smallest to the largest "spc" value.
freeze_backbone_runs = [
  'PCam-FreezeBackbone-1250_spc-2000_steps',
  'PCam-FreezeBackbone-6400_spc-2000_steps',
  'PCam-FreezeBackbone-12800_spc-2000_steps',
  'PCam-FreezeBackbone-25600_spc-2000_steps',
  'PCam-FreezeBackbone-51200_spc-2000_steps',
  'PCam-FreezeBackbone-128000_spc-2000_steps',
]
plot_results_plotly(stats=eval_stats, sorted_list=freeze_backbone_runs)

In [38]:
import plotly.graph_objects as go

categories = ["From_Scratch-notr", "From_Scratch-aug", "Pretrained_ImageNet-notr", "Pretrained_ImageNet-aug"]
samples = ["12800", "25600", "51200"]

# Reference levels: element [2] of the first test_accuracy row of the
# 15000-step runs, taken from the single version "v_0-version_1".
full_dataset_accuracy_scratch = (
  supervised_stats["From_Scratch-aug-128000_spc-15000_steps"]["v_0-version_1"]["test_accuracy"][0][2]
)
full_dataset_accuracy_pretrained = (
  supervised_stats["Pretrained_ImageNet-aug-128000_spc-15000_steps"]["v_0-version_1"]["test_accuracy"][0][2]
)

fig = go.Figure()

# One grouped bar trace per category: bar height is the mean test accuracy
# over versions and the error bar is the standard deviation.
for category in categories:
  mean_values = []
  std_values = []
  for sample in samples:
    mean_acc, std_acc = calculate_average(f"{category}-{sample}_spc-2000_steps", supervised_stats)
    mean_values.append(mean_acc)
    std_values.append(std_acc)
  bar_trace = go.Bar(
    name=f"{category}",
    x=samples,
    y=tuple(mean_values),
    error_y=dict(type='data', array=tuple(std_values)),
  )
  fig.add_trace(bar_trace)

# Dashed horizontal reference lines: (level, colour, label, label x, label y).
# The label positions are hand-tuned for this figure size.
reference_lines = (
  (full_dataset_accuracy_scratch, "gray", "Full Dataset From Scratch 15k Steps", 0.342, 0.805),
  (full_dataset_accuracy_pretrained, "magenta", "Full Dataset Pretrained 15k Steps", 0.755, 0.815),
)
for level, color, label, label_x, label_y in reference_lines:
  label_style = dict(
    text=label,
    yref="y",
    x=label_x,
    y=label_y,
    showarrow=True,
    arrowcolor=color,
    arrowhead=2,
    font=dict(size=12, color=color),
  )
  fig.add_hline(level, line=dict(color=color, dash='dash'), annotation=label_style)

legend_style = dict(
  yanchor="top",
  y=0.5,
  xanchor="left",
  x=0.71,
  bgcolor="rgba(255, 255, 255, 0.75)",
)
fig.update_layout(
  barmode='group',
  font=dict(size=16),
  width=1200,
  height=500,
  yaxis=dict(range=(0.5, 0.85)),
  legend=legend_style,
  margin=dict(t=0, b=0),
)
fig.show()

In [66]:
import plotly.graph_objects as go

categories = ["FreezeBackbone", "FullFineTune"]
samples = ["1250", "6400", "12800", "25600", "51200"]

fig = go.Figure()
# One grouped bar trace per sample count, with one bar per fine-tuning category.
# The loop variable must not be called `samples`: that would rebind the list
# to a string and leave `samples` corrupted after the cell has run.
for sample in samples:
  # Only the mean is plotted; the standard deviation is discarded. A named
  # throwaway is used because `_` is IPython's last-output variable.
  y, _std = zip(*[calculate_average(f"PCam-{category}-{sample}_spc-2000_steps", eval_stats) for category in categories])
  fig.add_trace(go.Bar(
      name=f"{sample} SPC",
      x=categories,
      y=y
  ))

# Reference level: mean test accuracy of the supervised from-scratch run
# at 51200 spc / 2000 steps.
accuracy_target, _target_std = calculate_average("From_Scratch-aug-51200_spc-2000_steps", supervised_stats)
fig.add_hline(
  accuracy_target,
  line=dict(color="gray", dash='dash'),
  annotation=dict(
    text="51200 SPC Supervised Performance",
    yref="y",
    x=0.31,
    y=0.785,
    showarrow=True,
    arrowcolor="gray",
    arrowhead=2,
    font=dict(size=12, color="gray")
  )
)

fig.update_layout(barmode='group', font=dict(size=16), width=1200, height=500, yaxis=dict(range=(0.5, 0.82)), legend=dict(
    yanchor="top",
    y=0.6,
    xanchor="left",
    x=0.72,
    bgcolor="rgba(255, 255, 255, 0.75)"
), margin=dict(t=0, b=0))
fig.show()