In [1]:
import pickle 
import os
import pandas as pd
import torch
import plotly.graph_objects as go
from io import BytesIO

In [2]:
def fix(map_loc):
    """Build a bytes -> object loader that always uses ``map_loc``.

    The returned callable wraps its ``bytes`` argument in a ``BytesIO`` and
    passes it to ``torch.load`` with ``map_location=map_loc``. A factory
    function is used so that ``map_loc`` is captured by the closure.
    """
    def load_from_bytes(b):
        return torch.load(BytesIO(b), map_location=map_loc)

    return load_from_bytes

class MappedUnpickler(pickle.Unpickler):
    """Unpickler that redirects torch's byte-loading hook through ``fix``.

    Whenever the pickle stream asks for ``torch.storage._load_from_bytes``,
    the loader returned by ``fix(map_location)`` is handed back instead, so
    the chosen ``map_location`` is applied. Every other global is resolved by
    the standard ``pickle.Unpickler`` lookup.

    Background: https://github.com/pytorch/pytorch/issues/16797#issuecomment-633423219
    """

    # (module, name) pair that is intercepted in find_class.
    _TORCH_BYTES_LOADER = ('torch.storage', '_load_from_bytes')

    def __init__(self, *args, map_location='cpu', **kwargs):
        """Forward all arguments to ``pickle.Unpickler`` and remember ``map_location``."""
        super().__init__(*args, **kwargs)
        self._map_location = map_location

    def find_class(self, module, name):
        """Return the mapped loader for torch's byte hook, else defer to the base class."""
        if (module, name) == self._TORCH_BYTES_LOADER:
            return fix(self._map_location)
        return super().find_class(module, name)
        
def mapped_loads(s, map_location='cpu'):
    """Unpickle the bytes ``s`` with a ``MappedUnpickler``.

    ``map_location`` is forwarded to the unpickler (default ``'cpu'``).
    Returns whatever object the pickle stream describes.
    """
    return MappedUnpickler(BytesIO(s), map_location=map_location).load()

def load_all_outputs(output_path, subdir):
    """Load every pickled output file found directly inside ``output_path/subdir``.

    Args:
        output_path: root directory of a run.
        subdir: name of the sub-directory under ``output_path`` to read.

    Returns:
        A list with one unpickled object per regular file, ordered by file
        name. An empty list is returned when the directory holds no files.

    Raises:
        FileNotFoundError: if ``output_path/subdir`` does not exist.

    NOTE(security): the files are unpickled, which can execute arbitrary
    code. Only point this at directories whose contents you trust.
    """
    data_dir = os.path.join(output_path, subdir)
    datas = []
    # os.listdir gives names in arbitrary order; sort so repeated runs
    # return the outputs in the same order.
    for file_name in sorted(os.listdir(data_dir)):
        file_path = os.path.join(data_dir, file_name)
        # Skip sub-directories (e.g. .ipynb_checkpoints) and other
        # non-regular entries, which cannot be opened as a pickle file.
        if not os.path.isfile(file_path):
            continue
        with open(file_path, 'rb') as f:
            data_dict = mapped_loads(f.read())
        datas.append(data_dict)
    return datas

In [3]:
def return_accuracy(confusion_matrix, data):
    """Return the diagonal sum of ``confusion_matrix`` divided by its total sum.

    ``data`` is not used; it is accepted so the function can be passed as a
    ``process_y(y_value, data)`` callback. The result is a Python scalar
    obtained through ``.item()``.
    """
    diagonal_total = confusion_matrix.diag().sum()
    overall_total = confusion_matrix.sum()
    ratio = diagonal_total / overall_total
    return ratio.item()

def wrap_mean_sim(sim_key):
    """Create a ``process_y``-style callback bound to ``sim_key``.

    The returned function takes ``(sim_matrix, data)`` and returns the sum of
    the diagonal of ``sim_matrix`` divided by the sum of all entries of
    ``data[sim_key]``, as a Python scalar.
    """
    def return_mean_sim(sim_matrix, data):
        count_total = data[sim_key].sum()
        diagonal_total = sim_matrix.diag().sum()
        mean_value = diagonal_total / count_total
        return mean_value.item()

    return return_mean_sim

def wrap_mean_dissim(sim_key):
    """Create a ``process_y``-style callback bound to ``sim_key``.

    The returned function takes ``(sim_matrix, data)``, moves both
    ``sim_matrix`` and ``data[sim_key]`` to the CPU, and returns

        (sum(sim_matrix) - sum(diag(sim_matrix))) / (N - 1) / sum(data[sim_key])

    as a Python scalar, where ``N`` is the first dimension of ``sim_matrix``.
    """
    def return_mean_dissim(sim_matrix, data):
        n_rows = sim_matrix.shape[0]
        sims_on_cpu = sim_matrix.cpu()
        count_total = data[sim_key].cpu().sum()
        # Everything that is not on the diagonal, averaged over the N - 1
        # other rows.
        off_diagonal_total = sims_on_cpu.sum() - sims_on_cpu.diag().sum()
        off_diagonal_mean = off_diagonal_total / (n_rows - 1)
        return (off_diagonal_mean / count_total).item()

    return return_mean_dissim


In [4]:
def return_plot_line(data_list, x_var, y_var, filters, process_y=None):
    """Turn a list of records into ``(x, y)`` points for one plot line.

    Args:
        data_list: iterable of records that can be indexed by key.
        x_var: sequence of keys; the x value is ``data[x_var[0]][x_var[1]]...``.
        y_var: sequence of keys, resolved the same way as ``x_var``.
        filters: mapping ``key -> value``; a record is kept only when
            ``data[key] == value`` holds for every entry.
        process_y: optional callback ``process_y(y_value, data)`` whose
            return value replaces the resolved y value.

    Returns:
        A list of ``(x_value, y_value)`` tuples in the order of ``data_list``.

    While walking a key path, any pandas Series reached after the first key
    is replaced by its mean.
    """
    def _follow(record, key_path):
        # Descend through nested containers, collapsing Series to their mean.
        current = record[key_path[0]]
        for step in key_path[1:]:
            current = current[step]
            if isinstance(current, pd.Series):
                current = current.mean()
        return current

    points = []
    for record in data_list:
        # Drop the record as soon as one filter entry does not match.
        if any(record[name] != wanted for name, wanted in filters.items()):
            continue

        x_coord = _follow(record, x_var)
        y_coord = _follow(record, y_var)
        if process_y is not None:
            y_coord = process_y(y_coord, record)
        points.append((x_coord, y_coord))
    return points

In [7]:
def make_graph(all_lines, run_title, compute_type, save=True, x_range=(0, 50)):
    """Plot one scatter trace per entry of ``all_lines`` and optionally save it.

    Args:
        all_lines: mapping ``trace name -> list of (x, y) points``. The lists
            are not modified; a sorted copy is plotted.
        run_title: name of the run, used in the figure title and file name.
        compute_type: name of the plotted quantity, used in the title, the
            y-axis label and the file name.
        save: when true, print ``"{run_title}_{compute_type}"`` and write the
            figure to ``graphs/{run_title}_{compute_type}.html``.
        x_range: ``(low, high)`` initial range of the x axis, or ``None`` to
            leave the axis range at plotly's default. Defaults to ``(0, 50)``.
    """
    fig = go.Figure()

    # Create and style traces
    for key, value in all_lines.items():
        # Sort a copy by x so the caller's list keeps its original order.
        points = sorted(value, key=lambda point: point[0])
        columns = list(zip(*points))
        # An empty line has no columns to unpack; plot it as an empty trace
        # so its name still appears in the legend.
        xs, ys = (columns[0], columns[1]) if columns else ((), ())
        fig.add_trace(go.Scatter(x=xs, y=ys, name=key))

    fig.update_layout(title=f'{run_title} - {compute_type}',
                       xaxis_title='epoch',
                       yaxis_title=compute_type)
    if x_range is not None:
        fig.update_xaxes(range=list(x_range))

    fig.show()
    if save:
        title = f"{run_title}_{compute_type}"
        print(title)
        # write_html does not create missing parent directories.
        os.makedirs("graphs", exist_ok=True)
        fig.write_html(f"graphs/{title}.html")

In [8]:
# (output directory of a run, title used for its graphs)
all_params = [
    ("/mnt/nfs/home/yunxingl/self-supervised-learning/outputs/2022-02-01/14-08-35", "simclr_no_crop_2"), 
    ("/mnt/nfs/home/yunxingl/self-supervised-learning/outputs/2022-02-01/14-44-38", "simclr_default_with_crop_2"), 
    ("/mnt/nfs/home/yunxingl/self-supervised-learning/outputs/2021-11-15/10-24-32", "byol_weak_colorjitter_nocrop"), 
    ("/mnt/nfs/home/yunxingl/self-supervised-learning/outputs/2021-11-15/10-20-24", "byol_crop_and_strong_colorjitter")
]
for path, run_title in all_params: 
    outs = load_all_outputs(path, "confusion")
    # Filter each curve on the split it is labelled with, so the "train"
    # trace shows train records and the "val" trace shows val records.
    data_train = return_plot_line(outs, ['epoch'], ['rep_confusion_sim'], {'split': 'train'}, process_y=return_accuracy)
    data_val = return_plot_line(outs, ['epoch'], ['rep_confusion_sim'], {'split': 'val'}, process_y=return_accuracy)
    make_graph({"train": data_train, "val": data_val}, run_title, "10-shot accuracy")

simclr_no_crop_2_10-shot accuracy


simclr_default_with_crop_2_10-shot accuracy


byol_weak_colorjitter_nocrop_10-shot accuracy


byol_crop_and_strong_colorjitter_10-shot accuracy


[(899, 0.40458598732948303),
 (749, 0.38675159215927124),
 (599, 0.37248408794403076),
 (549, 0.3923566937446594),
 (799, 0.39057326316833496),
 (849, 0.3880254924297333),
 (199, 0.35949045419692993),
 (349, 0.362293004989624),
 (399, 0.37248408794403076),
 (149, 0.37834393978118896),
 (249, 0.37808915972709656),
 (299, 0.36585986614227295),
 (99, 0.36560508608818054),
 (499, 0.3549044728279114),
 (649, 0.39057326316833496),
 (999, 0.40764331817626953),
 (949, 0.41477707028388977),
 (699, 0.3824203908443451),
 (449, 0.3678980767726898),
 (49, 0.3156687915325165)]

In [10]:
# For every run, plot the same-class and different-class mean similarity
# computed from the 'val' records of the "confusion" outputs.
for path, run_title in all_params:
    outs = load_all_outputs(path, "confusion")
    # One line per reducer; all three read 'sum_sim_rep' against 'epoch'.
    raw_similarity, raw_dissim, relative_sim = (
        return_plot_line(outs, ['epoch'], ['sum_sim_rep'], {'split': 'val'},
                         process_y=reducer)
        for reducer in (
            wrap_mean_sim('rep_confusion_sim'),
            wrap_mean_dissim('rep_confusion_sim'),
            return_accuracy,
        )
    )
    make_graph(
        {
            "similarity, same class": raw_similarity,
            "similarity, different classes": raw_dissim,
        },
        run_title,
        "mean_similarity_raw",
        save=False,
    )
    # relative_sim is computed but its plot is currently disabled:
    #make_graph({"relative_similarity": relative_sim}, run_title, "relative_mean_similarity", save=False)


