# Imports

In [185]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline


import os
import sys
import torch
from tqdm import tqdm
import pandas as pd
# import copy
# import h5py
# import numpy as np
# import pandas as pd
from IPython.display import display
# import random
# import matplotlib.ticker as ticker
# import matplotlib.pyplot as plt
import numpy as np
# import itertools
# import sklearn


# local imports
sys.path.insert(
    0,
    os.path.join(
        os.path.dirname(os.path.abspath(''))
    )
)
from dataloaders.cifar10 import (
    load_cifar10
)
from dataloaders.cifar100 import (
    load_cifar100
)
from models import models_dict as MODELS_DICT
# models_dict = {
#     "resnet18_cifar": resnet18_cifar,
#     "resnet18_torch": torchvision.models.resnet18,
#     "resnet50_torch": torchvision.models.resnet50,
#     "densenet_cifar": densenet_cifar
# }
sys.path.pop(0)


'/scratch_local/arubinstein17-163577/starlight'

# Functions

In [206]:
def get_link(raw_link):
    id = raw_link.split("/d/")[-1].split("/")[0]
    return f"https://drive.google.com/uc?id={id}"


def get_model(chkpt_path, model_name, model_settings={}):
    model = MODELS_DICT[model_name](**model_settings)
    model.load_state_dict(torch.load(chkpt_path)["state_dict"])
    return model


def get_resnet18_cifar(chkpt_path, num_classes=10):
    return get_model(
        chkpt_path,
        "resnet18_cifar",
        model_settings={"num_classes": num_classes}
)


def get_resnet18_cifar100(chkpt_path):
    return get_resnet18_cifar(chkpt_path, num_classes=100)



def get_paths(root_path):
    stars_path = os.path.join(root_path, "stars")
    anchors_path = os.path.join(root_path, "anchors")
    # anchors = os.listdir("./cifar10_resnet18/anchors")
    anchors = [
        os.path.join(anchors_path, name) for name in os.listdir(anchors_path)
    ]
    stars = [
        os.path.join(stars_path, name) for name in os.listdir(stars_path)
    ]
    return anchors, stars


@torch.no_grad
def compute_margin_for_batch(model, x, device="cuda:0"):

    x.requires_grad = True

    with torch.enable_grad():
        output = model(x)

    output = output.to(device)

    num_classes = output.shape[-1]
    margins_for_all_class_pairs = []
    used_classes = range(num_classes)

    model_grads = {}

    with torch.enable_grad():
        for class_i in used_classes:
            model_grads[class_i] = get_grad(
                output[..., class_i],
                x,
                retain_graph=True
            )

    # cartesian_product = list(itertools.product(used_classes, repeat=2))
    # for i, (class_i, class_j) in (enumerate(cartesian_product)):
    for i in range(len(used_classes)):
        for j in range(i + 1, len(used_classes)):

            class_i = used_classes[i]
            class_j = used_classes[j]

            output_i = output[..., class_i]
            output_j = output[..., class_j]

            grad_i = model_grads[class_i]
            grad_j = model_grads[class_j]
            batch_margins_for_class_pair = compute_margin(
                output_i,
                output_j,
                grad_i,
                grad_j,
                norm="l1"
            )

            margins_for_all_class_pairs.append(batch_margins_for_class_pair.mean())
    return torch.Tensor(margins_for_all_class_pairs).mean()


def get_grad(output_i, x, retain_graph):
    assert x.requires_grad == True
    differentiable_value = output_i.sum()
    differentiable_value.backward(inputs=x, retain_graph=retain_graph)
    grad = x.grad
    x.grad = None
    return grad


def compute_margin(output_i, output_j, grad_i, grad_j, norm="l1"):
    # eq. 7 here: https://proceedings.neurips.cc/paper_files/paper/2018/file/42998cf32d552343bc8e460416382dca-Paper.pdf
    # d = |f_i(x) - f_j(x)| / |grad_x f_i(x) - grad_x f_j (x)|_q
    assert norm == "l1"
    if norm == "l1":
        return (
                (output_i - output_j).abs()
            /
                (grad_i - grad_j).abs().sum()
        )
    else:
        raise NotImplementedError("Only L1 norm is supported")


def margin_on_dataloader(
    model,
    dataloader,
    percent=0.01,
    device="cuda:0"
):

    model.eval()
    model.to(device)

    dataloader_len = len(dataloader)
    if percent == 1:
        active_batches = list(range(dataloader_len))
    else:
        total_batches = int(percent * dataloader_len)
        active_batches = np.linspace(0, dataloader_len, total_batches)
        active_batches = active_batches.astype(int)

    res = []
    pb = tqdm(total=dataloader_len, desc='batch', position=0)
    for i, batch in (enumerate(dataloader)):
        pb.update(1)
        if i not in active_batches:
            continue
        x, y = batch
        x = x.to(device)
        batch_margin = compute_margin_for_batch(
            model,
            x,
            device=device
        )
        res.append(batch_margin)
    return torch.Tensor(res).mean().item()


def compute_results(
    model_paths,
    dataloader,
    cache_path,
    percent=1,
    load_model=get_resnet18_cifar,
    recompute_substring=None
):
    res = {}
    if os.path.exists(cache_path):
        print(f"Loading existing results from {cache_path}")
        res = torch.load(cache_path)
    else:
        res = {}

    dataset_root = dataloader.dataset.dataset.__dict__["root"]
    dataset_name = dataset_root.split(os.sep)[-1]
    for model_path in model_paths:

        path_split = model_path.split(os.sep)[-2:]
        assert len(path_split) == 2
        model_type = path_split[0]
        model_id = path_split[1].split(".")[0]
        df_id = f"{model_type}_{model_id}_{dataset_name}_p{percent}"
        to_recompute = (not df_id in res)
        if not to_recompute and recompute_substring is not None:
            to_recompute = recompute_substring in df_id

        if to_recompute:
            print(f"Computing results for {df_id}")
            res[df_id] = {}
            res[df_id]["margin"] = margin_on_dataloader(
                load_model(model_path),
                dataloader,
                percent=percent,
                device="cuda:0"
            )
            res[df_id]["model_type"] = model_type
            res[df_id]["dataset_name"] = dataset_name
            res[df_id]["percent"] = percent
            res[df_id]["model_id"] = model_id
        else:
            print(f"Reusing results for {df_id}")
    torch.save(res, cache_path)
    return res


def plot_dict(
    res,
    to_show=True,
    filter_dict={},
    drop_columns=[],
    group_aggregate_list=[],
    query=None
):

    df = pd.DataFrame.from_dict(res, orient='index')

    for column_name, allowed_values in filter_dict.items():
        df = df[df[column_name] in allowed_values]

    if query is not None:
        df = df.query(query)

    for column_name in drop_columns:
        if column_name == "index":
            df.reset_index(drop=True, inplace=True)
        else:
            df = df.drop(column_name, axis=1)

    for new_index in group_aggregate_list:
        df = df.groupby(new_index).agg([(f'mean', 'mean'), (f'std', 'std')])

    if to_show:
        display(df)
    return df

# Prepare cifar100

In [177]:
get_link("https://drive.google.com/file/d/1zTH5_YBM9h4AYfeQWhkJyqXW_14xS38B/view?usp=drive_link")

'https://drive.google.com/uc?id=1zTH5_YBM9h4AYfeQWhkJyqXW_14xS38B'

In [178]:
!gdown https://drive.google.com/uc?id=1zTH5_YBM9h4AYfeQWhkJyqXW_14xS38B

Downloading...
From (original): https://drive.google.com/uc?id=1zTH5_YBM9h4AYfeQWhkJyqXW_14xS38B
From (redirected): https://drive.google.com/uc?id=1zTH5_YBM9h4AYfeQWhkJyqXW_14xS38B&confirm=t&uuid=51724791-d6b0-4e41-9a83-bae8a9e60ecc
To: /scratch_local/arubinstein17-163577/starlight/notebooks/cifar100_resnet18.zip
100%|██████████████████████████████████████| 4.38G/4.38G [00:57<00:00, 76.1MB/s]


In [182]:
!unzip cifar100_resnet18.zip -d ./

Archive:  cifar100_resnet18.zip
   creating: ./cifar100_resnet18/
   creating: ./cifar100_resnet18/anchors/
  inflating: ./cifar100_resnet18/anchors/0eqk0gla.pt  
  inflating: ./cifar100_resnet18/anchors/384qnuxm.pt  
  inflating: ./cifar100_resnet18/anchors/3n8ftowe.pt  
  inflating: ./cifar100_resnet18/anchors/4808726s.pt  
  inflating: ./cifar100_resnet18/anchors/4rjdi4pa.pt  
  inflating: ./cifar100_resnet18/anchors/5bfkmro6.pt  
  inflating: ./cifar100_resnet18/anchors/65cr1zyr.pt  
  inflating: ./cifar100_resnet18/anchors/6hapim4a.pt  
  inflating: ./cifar100_resnet18/anchors/73s47o0m.pt  
  inflating: ./cifar100_resnet18/anchors/75eeo9bw.pt  
  inflating: ./cifar100_resnet18/anchors/7wgab2a1.pt  
  inflating: ./cifar100_resnet18/anchors/82yrfx4p.pt  
  inflating: ./cifar100_resnet18/anchors/85izljav.pt  
  inflating: ./cifar100_resnet18/anchors/8pppxxoh.pt  
  inflating: ./cifar100_resnet18/anchors/9ansq593.pt  
  inflating: ./cifar100_resnet18/anchors/9lzove03.pt  
  inflating:

In [186]:
anchors_cifar100_resnet18, stars_cifar100_resnet18 = get_paths("./cifar100_resnet18/")

cifar100_train_dataloader, _, _ = load_cifar100(
    batch_size=256,
    num_workers=8,
    img_size=32,
    normalize=True,
    resize=False,
    horizontal_flip=False,
    vertical_flip=False,
    random_crop_resize=False,
    random_resize_crop=False,
    color_jitter=False,
    rotation_range=0,
    pad_random_crop=False,
    random_one_aug=False,
    train_set_fraction=1.0,
    return_ds=False,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to /mnt/qb/work/oh/arubinstein17/datasets/cifar100/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:07<00:00, 23051054.87it/s]


Extracting /mnt/qb/work/oh/arubinstein17/datasets/cifar100/cifar-100-python.tar.gz to /mnt/qb/work/oh/arubinstein17/datasets/cifar100
Files already downloaded and verified
Files already downloaded and verified
Using 45000 images for training
Using 5000 images for validation
Using train transform Compose(
    IdentityTransform()
    ToTensor()
    Normalize(mean=[0.485 0.456 0.406], std=[0.229 0.224 0.225])
)


# Prepare cifar10

In [None]:
get_link("https://drive.google.com/file/d/1g-TxEGbORtHmxVEefoJtk2yxSf_mHL28/view?usp=drive_link")

In [None]:
!gdown https://drive.google.com/uc?id=1g-TxEGbORtHmxVEefoJtk2yxSf_mHL28

In [None]:
!unzip cifar10_resnet18.zip -d ./

In [16]:
anchors_resnet18, stars_resnet18 = get_paths("./cifar10_resnet18/")

cifar10_train_dataloader, _, _ = load_cifar10(
    batch_size=256,
    num_workers=8,
    img_size=32,
    normalize=True,
    resize=False,
    horizontal_flip=False,
    vertical_flip=False,
    random_crop_resize=False,
    random_resize_crop=False,
    color_jitter=False,
    rotation_range=0,
    pad_random_crop=False,
    random_one_aug=False,
    train_set_fraction=1.0,
    return_ds=False,
)

# Compute margin

## Resnet18 Cifar100

In [190]:
cifar100_resnet18_res_path = "./cifar100_resnet18_results.pt"

cifar100_resnet18_results = compute_results(
    anchors_cifar100_resnet18 + stars_cifar100_resnet18,
    cifar100_train_dataloader,
    cifar100_resnet18_res_path,
    percent=0.05,
    load_model=get_resnet18_cifar100,
    # recompute_substring="p0"
    recompute_substring=None
)

Computing results for anchors_0eqk0gla_cifar100_p0.05


batch: 100%|██████████| 176/176 [00:22<00:00,  7.76it/s]


Computing results for anchors_384qnuxm_cifar100_p0.05


batch: 100%|██████████| 176/176 [00:22<00:00,  7.66it/s]


Computing results for stars_01fj1cav_cifar100_p0.05


batch: 100%|██████████| 176/176 [00:22<00:00,  7.69it/s]


Computing results for stars_212q5op8_cifar100_p0.05


batch: 100%|██████████| 176/176 [00:22<00:00,  7.68it/s]


In [192]:
df = plot_dict(cifar100_resnet18_results, to_show=True)

                                   margin model_type dataset_name  percent
anchors_0eqk0gla_cifar100_p0.05  0.000035    anchors     cifar100     0.05
anchors_384qnuxm_cifar100_p0.05  0.000035    anchors     cifar100     0.05
stars_01fj1cav_cifar100_p0.05    0.000034      stars     cifar100     0.05
stars_212q5op8_cifar100_p0.05    0.000033      stars     cifar100     0.05


Unnamed: 0,margin,model_type,dataset_name,percent
anchors_0eqk0gla_cifar100_p0.05,3.5e-05,anchors,cifar100,0.05
anchors_384qnuxm_cifar100_p0.05,3.5e-05,anchors,cifar100,0.05
stars_01fj1cav_cifar100_p0.05,3.4e-05,stars,cifar100,0.05
stars_212q5op8_cifar100_p0.05,3.3e-05,stars,cifar100,0.05


In [207]:
df = plot_dict(cifar100_resnet18_results, to_show=True, filter_dict={}, drop_columns=["index", "dataset_name"], group_aggregate_list=["model_type"])

Unnamed: 0_level_0,margin,margin,percent,percent
Unnamed: 0_level_1,mean,std,mean,std
model_type,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
anchors,3.5e-05,2.278152e-07,0.05,0.0
stars,3.4e-05,6.369489e-07,0.05,0.0


## Resnet18 Cifar10

In [None]:
margin_on_dataloader(get_resnet18_cifar(anchors_resnet18[0]), get_resnet18_cifar(anchors_resnet18[1]), cifar10_train_dataloader, percent=0.10, device="cuda:0")

In [None]:
cifar10_resnet18_res_path = "./cifar10_resnet18_results.pt"

cifar10_resnet18_results = compute_results(
    anchors_resnet18 + stars_resnet18,
    cifar10_train_dataloader,
    cifar10_resnet18_res_path,
    percent=1,
    load_model=get_resnet18_cifar,
    # recompute_substring="p0"
    recompute_substring=None
)

In [205]:
df = plot_dict(cifar10_resnet18_results, to_show=True)

Unnamed: 0,margin,model_type,dataset_name,percent,model_id
anchors_2oo0upw1_cifar10_p0.05,8.1e-05,anchors,cifar10,0.05,2oo0upw1
anchors_3m2c3guk_cifar10_p0.05,8e-05,anchors,cifar10,0.05,3m2c3guk
stars_0eiwx0xx_cifar10_p0.05,7.1e-05,stars,cifar10,0.05,0eiwx0xx
stars_4s24tjof_cifar10_p0.05,7e-05,stars,cifar10,0.05,4s24tjof


In [209]:
df = plot_dict(cifar10_resnet18_results, to_show=True, drop_columns=["index", "dataset_name", "model_id"], group_aggregate_list=["model_type"], query="model_type == 'anchors' | model_id in ['0eiwx0xx']")

Unnamed: 0_level_0,margin,margin,percent,percent
Unnamed: 0_level_1,mean,std,mean,std
model_type,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
anchors,8.1e-05,7.595591e-07,0.05,0.0
stars,7.1e-05,,0.05,


# Save df

In [None]:
df.to_csv("tmp.csv")