# Results

In [None]:
from pathlib import Path
import re

# Root of the experiment logs; each run directory below it holds a `checkpoint.log`.
# NOTE(review): machine-specific absolute path — adjust to your local mount.
LOG_ROOT = Path('/Users/marwei/AFS/home/urz/m/marwei/Logs/IDEAL')

# (sub-directory of LOG_ROOT, glob pattern selecting the run directories inside it)
RUN_GLOBS = [
    ('cifar10_fixed_budget', 'cifar10_*'),
    ('cifar100_fixed_budget', 'cifar100_*'),
]

# Sorted so that the record order (and everything derived from it downstream)
# does not depend on the filesystem's directory-listing order.
log_paths = [
    run_dir / 'checkpoint.log'
    for subdir, pattern in RUN_GLOBS
    for run_dir in sorted((LOG_ROOT / subdir).glob(pattern))
]

# Maps the `strategy=` value logged for quantization runs to the compressor label used in the records.
QUANTIZATION_LABELS = {
    'tiny_imagenet_transfer': 'quantization transfer',
    'local': 'quantization local',
}


def _first_group(pattern, text, what):
    """Return capture group 1 of the first match of `pattern` in `text`.

    Raises:
        ValueError: if `pattern` does not occur in `text`; the message names `what`.
    """
    match = re.search(pattern, text)
    if match is None:
        raise ValueError(f'Could not find {what} in log line: {text!r}')
    return match.group(1)


def parse_checkpoint_log(loglines):
    """Extract one result record from the lines of a `checkpoint.log`.

    The run arguments are read from the first line, the final accuracy from the
    `Acc: [...]` entry on the last line.

    Args:
        loglines: the log file contents split into lines.

    Returns:
        dict with keys 'mem_size', 'final_acc', 'encoder', 'compressor',
        'compressor_param' and 'dataset'.

    Raises:
        ValueError: if the log is empty, a required field is missing, or the
            compressor / quantization strategy is unknown.
    """
    if not loglines:
        raise ValueError('Log file is empty')
    header, last_line = loglines[0], loglines[-1]

    mem_size = int(_first_group(r"memory_size=(\d+)", header, 'memory_size'))
    final_acc = float(_first_group(r"Acc: \[(.*?)\]", last_line, 'final accuracy'))
    compressor = _first_group(r"compressor=\'(.*?)\'", header, 'compressor')
    encoder = _first_group(r"encoder=\'(.*?)\'", header, 'encoder')
    dataset = _first_group(r"dataset=\'(.*?)\'", header, 'dataset')

    # Normalise the 'autoencoder' name to 'convae' so both share one label.
    if compressor == 'autoencoder':
        compressor = 'convae'

    if compressor == 'thinning':
        # Stop at ',' or ')' so the value also parses when it is the last argument.
        compressor_param = float(_first_group(r"compression_factor=([^,)]+)", header, 'compression_factor'))
    elif compressor == 'quantization':
        compressor_param = int(_first_group(r"n_states=(\d+)", header, 'n_states'))
        strategy_match = re.search(r"strategy=\'(.*?)\'", header)
        # Logs without a `strategy=` entry are treated as local quantization.
        strategy = strategy_match.group(1) if strategy_match else 'local'
        if strategy not in QUANTIZATION_LABELS:
            raise ValueError('Unknown Quantization Strategy: ' + strategy)
        compressor = QUANTIZATION_LABELS[strategy]
    elif compressor == 'convae':
        compressor_param = int(_first_group(r"latent_channels=(\d+)", header, 'latent_channels'))
    elif compressor == 'fcae':
        compressor_param = int(_first_group(r"bottleneck_neurons=(\d+)", header, 'bottleneck_neurons'))
    elif compressor == 'none':
        compressor_param = 0
    else:
        raise ValueError(f'Unknown Compressor: {compressor}')

    return {
        'mem_size': mem_size,
        'final_acc': final_acc,
        'encoder': 'cutr' if encoder == 'cutr34' else encoder,
        'compressor': compressor,
        'compressor_param': compressor_param,
        'dataset': dataset,
    }


# One record per run, consumed by the tabulation cell below.
c = []
for exp in log_paths:
    with open(exp) as infile:
        loglines = infile.read().splitlines()
    try:
        c.append(parse_checkpoint_log(loglines))
    except ValueError as err:
        # Re-raise with the offending file so a bad run directory is easy to locate.
        raise ValueError(f'{exp}: {err}') from err

In [None]:
import pandas as pd

# Compressor variants left out of the figure.
EXCLUDED_COMPRESSORS = ['fcae', 'quantization local']

# Human-readable labels for the plots.
compressor_names = {
    'none': 'No Compressor',
    'thinning': 'Thinning',
    'quantization transfer': 'Quantization',
    'convae': 'Autoencoder'
}

encoder_names = {
    'cutr': 'FETCH',
    'none': 'no Fixed Encoder'
}

# Build the filtered, labelled frame in one non-mutating chain: no `inplace=True`,
# and `.assign` adds the label columns without writing into a filtered slice.
df = (
    pd.DataFrame.from_records(c)
    .loc[lambda d: ~d['compressor'].isin(EXCLUDED_COMPRESSORS)]
    .assign(
        compressor_name=lambda d: d['compressor'].map(compressor_names),
        encoder_name=lambda d: d['encoder'].map(encoder_names),
    )
)


In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from plot_utils import science_template

# One colour per compressor, taken in order from the template's colorway; the key
# order is also the order in which traces are drawn within each panel.
com_dict = {k: v for k, v in zip(['none', 'quantization transfer', 'thinning', 'convae'], science_template['layout']['colorway'])}

# Panel order must agree with `subplot_titles` below. It is fixed explicitly
# instead of relying on the row order of `df` (which follows the log-directory
# listing), so titles always sit above the matching data.
ENCODER_ORDER = ['no Fixed Encoder', 'FETCH']
# NOTE(review): presumably the dataset values are 'cifar10'/'cifar100'; lexical
# order then puts CIFAR10 first, as the titles expect — confirm.
dataset_order = sorted(df['dataset'].unique())
assert len(dataset_order) * len(ENCODER_ORDER) == 4, \
    f'Figure has 4 panels for 2 datasets x 2 encoders, got datasets {dataset_order}'

fig = make_subplots(
    rows=1,
    cols=4,
    shared_yaxes=True,
    subplot_titles=['CIFAR10,<br>no fixed Encoder', 'CIFAR10,<br>FETCH', 'CIFAR100,<br>no fixed Encoder', 'CIFAR100,<br>FETCH']
)

# Legend names already shown, so each name gets a single legend entry.
names = set()

for dataset_idx, this_dataset in enumerate(dataset_order):
    for encoder_idx, this_encoder in enumerate(ENCODER_ORDER):
        col = dataset_idx * len(ENCODER_ORDER) + encoder_idx + 1
        subfig_view = df.loc[(df['dataset'] == this_dataset) & (df['encoder_name'] == this_encoder)]
        for this_compressor in com_dict:
            line_view = subfig_view.loc[subfig_view['compressor'] == this_compressor].sort_values('mem_size')
            if line_view.empty:
                continue
            assert len(line_view.compressor_name.unique()) == 1
            if this_compressor == 'none':
                # Uncompressed runs are drawn as isolated square markers; without a
                # fixed encoder they are labelled "GDumb" and get a black outline.
                mode = 'markers'
                symbol = 'square'
                is_gdumb = this_encoder == 'no Fixed Encoder'
            else:
                mode = 'lines+markers'
                symbol = 'circle'
                is_gdumb = False
            outline = {'width': 2 if is_gdumb else 0, 'color': 'black'}
            this_name = 'GDumb' if is_gdumb else line_view.compressor_name.unique().item()
            if this_name in names:
                show_legend = False
            else:
                names.add(this_name)
                show_legend = True
            fig.add_trace(
                go.Scatter(
                    x=line_view['mem_size'],
                    y=line_view['final_acc'],
                    mode=mode,
                    line={
                        'color': com_dict[this_compressor],
                    },
                    marker={
                        'color': com_dict[this_compressor],
                        'symbol': symbol,
                        'line': outline
                    },
                    name=this_name,
                    showlegend=show_legend,
                ),
                col=col,
                row=1,
            )


for i in range(1, 5):
    fig.update_xaxes(
        type='log',
        title='N',
        title_standoff=5,
        row=1,
        col=i,
    )
    fig.update_yaxes(
        range=[0, 0.8],
        row=1,
        col=i
    )

fig.update_layout(
    template=science_template,
    font_size=16,
    legend=dict(
        orientation="h",  # Set the orientation to horizontal
        yanchor="top",
        y=-0.3,
        xanchor="left",
        x=0,
    ),
    # Narrow gap inside each dataset pair, wider gap between the two datasets.
    xaxis=dict(domain=[0, 0.23]),
    xaxis2=dict(domain=[0.25, 0.48]),
    xaxis3=dict(domain=[0.52, 0.75]),
    xaxis4=dict(domain=[0.77, 1]),
)

fig.update_yaxes(
    title='Accuracy',
    row=1,
    col=1,
)
# Log-axis ranges are given in log10 units: N from 10^3 to 10^4.9.
fig.update_xaxes(range=[3, 4.9], row=1, col=3)
fig.update_annotations(font_size=16)
fig.update_yaxes(rangemode="tozero")
# Share x-ranges within each dataset pair (panels 1+2 and 3+4).
fig.layout.xaxis2.matches = 'x'
fig.layout.xaxis4.matches = 'x3'

fig.show()

In [None]:
# Export size in pixels, shared by the PDF export and the interactive download button.
width = 1000
height = 380

config = {
    'displaylogo': False,
    'toImageButtonOptions': {
        'format': 'svg',  # one of png, svg, jpeg, webp
        'filename': 'plot',
        'height': height,
        'width': width,
        'scale': 1  # Multiply title/legend/axis/canvas sizes by this factor
    }
}
# Optional: open the figure in a browser, with the download settings above.
# fig.show(renderer='browser', config=config)

output_path = '../plots/fixb.pdf'
# Create the target directory so the export does not fail on a fresh checkout.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
fig.write_image(output_path, width=width, height=height)
