# Imports

In [134]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline


import os
import sys
import torch
from tqdm import tqdm
import pandas as pd
# import copy
# import h5py
# import numpy as np
# import pandas as pd
from IPython.display import display
# import random
# import matplotlib.ticker as ticker
# import matplotlib.pyplot as plt
import numpy as np
# import itertools
# import sklearn


# local imports
sys.path.insert(
    0,
    os.path.join(
        os.path.dirname(os.path.abspath(''))
    )
)
from dataloaders.cifar10 import (
    load_cifar10
)
from models import models_dict as MODELS_DICT
# models_dict = {
#     "resnet18_cifar": resnet18_cifar,
#     "resnet18_torch": torchvision.models.resnet18,
#     "resnet50_torch": torchvision.models.resnet50,
#     "densenet_cifar": densenet_cifar
# }
sys.path.pop(0)


'/scratch_local/arubinstein17-163577/starlight'

# Functions

In [145]:
def get_link(raw_link):
    id = raw_link.split("/d/")[-1].split("/")[0]
    return f"https://drive.google.com/uc?id={id}"


def get_model(chkpt_path, model_name, model_settings={}):
    model = MODELS_DICT[model_name](**model_settings)
    model.load_state_dict(torch.load(chkpt_path)["state_dict"])
    return model


def get_resnet18_cifar(chkpt_path):
    return get_model(
        chkpt_path,
        "resnet18_cifar",
        model_settings={"num_classes": 10}
)

# @torch.no_grad
# def compute_margin_for_batch(model_1, model_2, x, device="cuda:0"):

#     x.requires_grad = True

#     with torch.enable_grad():
#         output_1 = model_1(x)
#         output_2 = model_2(x)

#     output_1 = output_1.to(device)
#     output_2 = output_2.to(device)

#     num_classes = output_1.shape[-1]
#     margins_for_all_class_pairs = []
#     used_classes = range(num_classes)

#     first_model_grads = {}
#     second_model_grads = {}

#     with torch.enable_grad():
#         for class_i in used_classes:
#             first_model_grads[class_i] = get_grad(
#                 output_1[..., class_i],
#                 x,
#                 retain_graph=True
#             )
#             second_model_grads[class_i] = get_grad(
#                 output_2[..., class_i],
#                 x,
#                 retain_graph=(class_i != num_classes - 1)
#             )

#     cartesian_product = list(itertools.product(used_classes, repeat=2))
#     for i, (class_i, class_j) in (enumerate(cartesian_product)):

#         if class_i == class_j:
#             continue

#         output_1_i = output_1[..., class_i]
#         output_2_j = output_2[..., class_j]

#         grad_1_i = first_model_grads[class_i]
#         grad_2_j = second_model_grads[class_j]
#         batch_margins_for_class_pair = compute_margin(
#             output_1_i,
#             output_2_j,
#             grad_1_i,
#             grad_2_j,
#             norm="l1"
#         )

#         margins_for_all_class_pairs.append(batch_margins_for_class_pair.mean())
#     return torch.Tensor(margins_for_all_class_pairs).mean()


@torch.no_grad
def compute_margin_for_batch(model, x, device="cuda:0"):

    x.requires_grad = True

    with torch.enable_grad():
        output = model(x)

    output = output.to(device)

    num_classes = output.shape[-1]
    margins_for_all_class_pairs = []
    used_classes = range(num_classes)

    model_grads = {}

    with torch.enable_grad():
        for class_i in used_classes:
            model_grads[class_i] = get_grad(
                output[..., class_i],
                x,
                retain_graph=True
            )

    # cartesian_product = list(itertools.product(used_classes, repeat=2))
    # for i, (class_i, class_j) in (enumerate(cartesian_product)):
    for i in range(len(used_classes)):
        for j in range(i + 1, len(used_classes)):

            class_i = used_classes[i]
            class_j = used_classes[j]

            output_i = output[..., class_i]
            output_j = output[..., class_j]

            grad_i = model_grads[class_i]
            grad_j = model_grads[class_j]
            batch_margins_for_class_pair = compute_margin(
                output_i,
                output_j,
                grad_i,
                grad_j,
                norm="l1"
            )

            margins_for_all_class_pairs.append(batch_margins_for_class_pair.mean())
    return torch.Tensor(margins_for_all_class_pairs).mean()


def get_grad(output_i, x, retain_graph):
    assert x.requires_grad == True
    differentiable_value = output_i.sum()
    differentiable_value.backward(inputs=x, retain_graph=retain_graph)
    grad = x.grad
    x.grad = None
    return grad


def compute_margin(output_i, output_j, grad_i, grad_j, norm="l1"):
    # eq. 7 here: https://proceedings.neurips.cc/paper_files/paper/2018/file/42998cf32d552343bc8e460416382dca-Paper.pdf
    # d = |f_i(x) - f_j(x)| / |grad_x f_i(x) - grad_x f_j (x)|_q
    assert norm == "l1"
    if norm == "l1":
        return (
                (output_i - output_j).abs()
            /
                (grad_i - grad_j).abs().sum()
        )
    else:
        raise NotImplementedError("Only L1 norm is supported")


# def margin_on_dataloader(
#     model_1,
#     model_2,
#     dataloader,
#     percent=0.01,
#     device="cuda:0"
# ):

#     model_1.eval()
#     model_2.eval()
#     model_1.to(device)
#     model_2.to(device)
#     dataloader_len = len(dataloader)
#     if percent == 1:
#         active_batches = list(range(dataloader_len))
#     else:
#         total_batches = int(percent * dataloader_len)
#         active_batches = np.linspace(0, dataloader_len, total_batches)
#         active_batches = active_batches.astype(int)

#     res = []
#     pb = tqdm(total=dataloader_len, desc='batch', position=0)
#     for i, batch in (enumerate(dataloader)):
#         pb.update(1)
#         if i not in active_batches:
#             continue
#         x, y = batch
#         x = x.to(device)
#         batch_margin = compute_margin_for_batch(
#             model_1,
#             model_2,
#             x,
#             device=device
#         )
#         res.append(batch_margin)
#     return torch.Tensor(res).mean().item()


def margin_on_dataloader(
    model,
    dataloader,
    percent=0.01,
    device="cuda:0"
):

    model.eval()
    model.to(device)

    dataloader_len = len(dataloader)
    if percent == 1:
        active_batches = list(range(dataloader_len))
    else:
        total_batches = int(percent * dataloader_len)
        active_batches = np.linspace(0, dataloader_len, total_batches)
        active_batches = active_batches.astype(int)

    res = []
    pb = tqdm(total=dataloader_len, desc='batch', position=0)
    for i, batch in (enumerate(dataloader)):
        pb.update(1)
        if i not in active_batches:
            continue
        x, y = batch
        x = x.to(device)
        batch_margin = compute_margin_for_batch(
            model,
            x,
            device=device
        )
        res.append(batch_margin)
    return torch.Tensor(res).mean().item()


# def compute_results(
#     anchors,
#     stars,
#     dataloader,
#     cache_path,
#     percent=1,
#     load_model=get_resnet18_cifar,
#     recompute_substring=None
# ):
#     res = {}
#     if os.path.exists(cache_path):
#         print(f"Loading existing results from {cache_path}")
#         res = torch.load(cache_path)
#     else:
#         res = {}
#     all_model_pairs = list(itertools.product(anchors + stars, repeat=2))
#     dataset_root = dataloader.dataset.dataset.__dict__["root"]
#     for model_1, model_2 in all_model_pairs:
#         if model_1 == model_2:
#             continue

#         df_id = f"{model_1}_{model_2}_{dataset_root}_p{percent}"
#         to_recompute = (not df_id in res)
#         if not to_recompute and recompute_substring is not None:
#             to_recompute = recompute_substring in df_id

#         if to_recompute:
#             print(f"Computing results for {df_id}")
#             res[df_id] = margin_on_dataloader(
#                 load_model(model_1),
#                 load_model(model_2),
#                 dataloader,
#                 percent=percent,
#                 device="cuda:0"
#             )
#         else:
#             print(f"Reusing results for {df_id}")
#     torch.save(res, cache_path)
#     return res

def compute_results(
    model_paths,
    dataloader,
    cache_path,
    percent=1,
    load_model=get_resnet18_cifar,
    recompute_substring=None
):
    res = {}
    if os.path.exists(cache_path):
        print(f"Loading existing results from {cache_path}")
        res = torch.load(cache_path)
    else:
        res = {}
    # all_model_pairs = list(itertools.product(anchors + stars, repeat=2))

    dataset_root = dataloader.dataset.dataset.__dict__["root"]
    dataset_name = dataset_root.split(os.sep)[-1]
    # for model_1, model_2 in all_model_pairs:
    for model_path in model_paths:

        path_split = model_path.split(os.sep)[-2:]
        assert len(path_split) == 2
        model_type = path_split[0]
        model_id = path_split[1].split(".")[0]
        df_id = f"{model_type}_{model_id}_{dataset_name}_p{percent}"
        to_recompute = (not df_id in res)
        if not to_recompute and recompute_substring is not None:
            to_recompute = recompute_substring in df_id

        if to_recompute:
            print(f"Computing results for {df_id}")
            res[df_id] = {}
            res[df_id]["margin"] = margin_on_dataloader(
                load_model(model_path),
                dataloader,
                percent=percent,
                device="cuda:0"
            )
            res[df_id]["model_type"] = model_type
            res[df_id]["dataset_name"] = dataset_name
            res[df_id]["percent"] = percent
        else:
            print(f"Reusing results for {df_id}")
    torch.save(res, cache_path)
    return res

In [13]:
get_link("https://drive.google.com/file/d/1g-TxEGbORtHmxVEefoJtk2yxSf_mHL28/view?usp=drive_link")

'https://drive.google.com/uc?id=1g-TxEGbORtHmxVEefoJtk2yxSf_mHL28'

In [8]:
!gdown https://drive.google.com/uc?id=1g-TxEGbORtHmxVEefoJtk2yxSf_mHL28

Downloading...
From (original): https://drive.google.com/uc?id=1g-TxEGbORtHmxVEefoJtk2yxSf_mHL28
From (redirected): https://drive.google.com/uc?id=1g-TxEGbORtHmxVEefoJtk2yxSf_mHL28&confirm=t&uuid=98788ac5-cd93-4eba-bc9b-97a69c1064f5
To: /scratch_local/arubinstein17-163577/starlight/notebooks/cifar10_resnet18.zip
100%|███████████████████████████████████████| 4.37G/4.37G [00:42<00:00, 104MB/s]


In [10]:
# !tar -xzf /scratch_local/arubinstein17-163577/starlight/notebooks/cifar10_resnet18.zip
!unzip cifar10_resnet18.zip -d ./

Archive:  cifar10_resnet18.zip
   creating: ./cifar10_resnet18/
   creating: ./cifar10_resnet18/anchors/
  inflating: ./cifar10_resnet18/anchors/2oo0upw1.pt  
  inflating: ./cifar10_resnet18/anchors/3m2c3guk.pt  
  inflating: ./cifar10_resnet18/anchors/448kfne2.pt  
  inflating: ./cifar10_resnet18/anchors/52l22gc1.pt  
  inflating: ./cifar10_resnet18/anchors/5791j5is.pt  
  inflating: ./cifar10_resnet18/anchors/5hrwt8sd.pt  
  inflating: ./cifar10_resnet18/anchors/6ids7w54.pt  
  inflating: ./cifar10_resnet18/anchors/6kxd228s.pt  
  inflating: ./cifar10_resnet18/anchors/7sckybft.pt  
  inflating: ./cifar10_resnet18/anchors/7y4wdzel.pt  
  inflating: ./cifar10_resnet18/anchors/8wmcs3t5.pt  
  inflating: ./cifar10_resnet18/anchors/ahszhrke.pt  
  inflating: ./cifar10_resnet18/anchors/atn3ky2k.pt  
  inflating: ./cifar10_resnet18/anchors/ce0idfac.pt  
  inflating: ./cifar10_resnet18/anchors/e4fr64ya.pt  
  inflating: ./cifar10_resnet18/anchors/eq157ykm.pt  
  inflating: ./cifar10_resnet18

In [16]:
cifar_path = "./cifar10_resnet18/"
stars_path = os.path.join(cifar_path, "stars")
anchors_path = os.path.join(cifar_path, "anchors")
# anchors = os.listdir("./cifar10_resnet18/anchors")
anchors = [os.path.join(anchors_path, name) for name in os.listdir(anchors_path)]
stars = [os.path.join(stars_path, name) for name in os.listdir(stars_path)]

# Compute margin

In [19]:
anchor_1 = anchors[0]
anchor_2 = anchors[1]
star_1 = stars[0]
star_2 = stars[1]

In [25]:
cifar10_train_dataloader, _, _ = load_cifar10(
    batch_size=256,
    num_workers=8,
    img_size=32,
    normalize=True,
    resize=False,
    horizontal_flip=False,
    vertical_flip=False,
    random_crop_resize=False,
    random_resize_crop=False,
    color_jitter=False,
    rotation_range=0,
    pad_random_crop=False,
    random_one_aug=False,
    train_set_fraction=1.0,
    return_ds=False,
)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Using 45000 images for training
Using 5000 images for validation
Using train transform Compose(
    IdentityTransform()
    ToTensor()
    Normalize(mean=[0.485 0.456 0.406], std=[0.229 0.224 0.225])
)


In [29]:
for batch in cifar10_train_dataloader:
    x, y = batch
    break

In [86]:
len(cifar10_train_dataloader)

176

In [43]:
model_1.keys()

dict_keys(['state_dict'])

In [45]:
model_1 = get_resnet18_cifar(anchor_1, "resnet18_cifar", model_settings={"num_classes": 10})

In [46]:
output = model_1(x)

## Resnet18 Cifar10

In [58]:
class_i = 0
class_j = 1
model_1 = get_resnet18_cifar(anchor_1)
model_2 = get_resnet18_cifar(anchor_2)
margin = compute_margin(model_1, model_2, class_i, class_j, x, norm="l1")

In [None]:
margin_on_dataloader(get_resnet18_cifar(anchors[0]), get_resnet18_cifar(anchors[1]), cifar10_train_dataloader, percent=0.10, device="cuda:0")

In [122]:
anchors[:2]

['./cifar10_resnet18/anchors/2oo0upw1.pt',
 './cifar10_resnet18/anchors/3m2c3guk.pt']

In [173]:
def plot_dict(res, to_show=True, filter_dict={}, drop_columns=[], group_aggregate_list=[]):

    df = pd.DataFrame.from_dict(res, orient='index')

    for column_name, value in filter_dict.items():
        df = df[df[column_name] == value]

    for column_name in drop_columns:
        if column_name == "index":
            df.reset_index(drop=True, inplace=True)
        else:
            df = df.drop(column_name, axis=1)
    print(df)
    for new_index in group_aggregate_list:
        df = df.groupby(new_index).agg([(f'mean', 'mean'), (f'std', 'std')])

    if to_show:
        display(df)
    return df

In [158]:
plot_dict(cifar10_resnet18_results, to_show=True)

Unnamed: 0,margin,model_type,dataset_root,percent
anchors_2oo0upw1_cifar10_p0.05,8.3e-05,anchors,/mnt/qb/work/oh/arubinstein17/datasets/cifar10,0.05
anchors_3m2c3guk_cifar10_p0.05,8e-05,anchors,/mnt/qb/work/oh/arubinstein17/datasets/cifar10,0.05
stars_0eiwx0xx_cifar10_p0.05,7e-05,stars,/mnt/qb/work/oh/arubinstein17/datasets/cifar10,0.05
stars_4s24tjof_cifar10_p0.05,7.1e-05,stars,/mnt/qb/work/oh/arubinstein17/datasets/cifar10,0.05


Unnamed: 0,margin,model_type,dataset_root,percent
anchors_2oo0upw1_cifar10_p0.05,8.3e-05,anchors,/mnt/qb/work/oh/arubinstein17/datasets/cifar10,0.05
anchors_3m2c3guk_cifar10_p0.05,8e-05,anchors,/mnt/qb/work/oh/arubinstein17/datasets/cifar10,0.05
stars_0eiwx0xx_cifar10_p0.05,7e-05,stars,/mnt/qb/work/oh/arubinstein17/datasets/cifar10,0.05
stars_4s24tjof_cifar10_p0.05,7.1e-05,stars,/mnt/qb/work/oh/arubinstein17/datasets/cifar10,0.05


In [174]:
df = plot_dict(cifar10_resnet18_results, to_show=True, filter_dict={}, drop_columns=["index", "dataset_root"], group_aggregate_list=["model_type"])

     margin model_type  percent
0  0.000083    anchors     0.05
1  0.000080    anchors     0.05
2  0.000070      stars     0.05
3  0.000071      stars     0.05


Unnamed: 0_level_0,margin,margin,percent,percent
Unnamed: 0_level_1,mean,std,mean,std
model_type,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
anchors,8.2e-05,2.162845e-06,0.05,0.0
stars,7e-05,7.760381e-07,0.05,0.0


In [171]:
# plot_dict(res, to_show=True, group_by="model_type", filter_by={"percent": 0.05})
# df = plot_dict(cifar10_resnet18_results, to_show=True, filter_dict={}, group_aggregate_dict={"model_type": "mean"})
df = plot_dict(cifar10_resnet18_results, to_show=True, filter_dict={}, drop_columns=["index", "dataset_root", "percent"], group_aggregate_dict={"model_type": "margin"})

     margin model_type
0  0.000083    anchors
1  0.000080    anchors
2  0.000070      stars
3  0.000071      stars


Unnamed: 0_level_0,margin,margin
Unnamed: 0_level_1,margin mean,margin std
model_type,Unnamed: 1_level_2,Unnamed: 2_level_2
anchors,8.2e-05,2.162845e-06
stars,7e-05,7.760381e-07


In [None]:
df.to_csv("tmp.csv")

In [141]:
cifar10_resnet18_results = compute_results(
    anchors[:2] + stars[:2],
    cifar10_train_dataloader,
    "./cifar10_resnet18_results.pt",
    percent=0.05,
    load_model=get_resnet18_cifar,
    # recompute_substring="p0"
    recompute_substring=None
)

Loading existing results from ./cifar10_resnet18_results.pt
Computing results for anchors_2oo0upw1_cifar10_p0.05


batch: 100%|██████████| 176/176 [00:03<00:00, 52.28it/s]


Computing results for anchors_3m2c3guk_cifar10_p0.05


batch: 100%|██████████| 176/176 [00:03<00:00, 50.29it/s]


Computing results for stars_0eiwx0xx_cifar10_p0.05


batch: 100%|██████████| 176/176 [00:03<00:00, 51.10it/s]


Computing results for stars_4s24tjof_cifar10_p0.05


batch: 100%|██████████| 176/176 [00:03<00:00, 50.75it/s]


In [124]:
compute_results(
    anchors[:2],
    stars[:2],
    cifar10_train_dataloader,
    "./cifar10_resnet18_results.pt",
    percent=0.05,
    load_model=get_resnet18_cifar,
    recompute_substring=None
)

Computing results for ./cifar10_resnet18/anchors/2oo0upw1.pt_./cifar10_resnet18/anchors/3m2c3guk.pt_/mnt/qb/work/oh/arubinstein17/datasets/cifar10_p0.05


batch: 100%|██████████| 176/176 [00:05<00:00, 30.66it/s]


Computing results for ./cifar10_resnet18/anchors/2oo0upw1.pt_./cifar10_resnet18/stars/0eiwx0xx.pt_/mnt/qb/work/oh/arubinstein17/datasets/cifar10_p0.05


batch: 100%|██████████| 176/176 [00:05<00:00, 32.08it/s]


Computing results for ./cifar10_resnet18/anchors/2oo0upw1.pt_./cifar10_resnet18/stars/4s24tjof.pt_/mnt/qb/work/oh/arubinstein17/datasets/cifar10_p0.05


batch: 100%|██████████| 176/176 [00:05<00:00, 32.20it/s]


Computing results for ./cifar10_resnet18/anchors/3m2c3guk.pt_./cifar10_resnet18/anchors/2oo0upw1.pt_/mnt/qb/work/oh/arubinstein17/datasets/cifar10_p0.05


batch: 100%|██████████| 176/176 [00:05<00:00, 31.94it/s]


Computing results for ./cifar10_resnet18/anchors/3m2c3guk.pt_./cifar10_resnet18/stars/0eiwx0xx.pt_/mnt/qb/work/oh/arubinstein17/datasets/cifar10_p0.05


batch: 100%|██████████| 176/176 [00:05<00:00, 31.57it/s]


Computing results for ./cifar10_resnet18/anchors/3m2c3guk.pt_./cifar10_resnet18/stars/4s24tjof.pt_/mnt/qb/work/oh/arubinstein17/datasets/cifar10_p0.05


batch: 100%|██████████| 176/176 [00:05<00:00, 31.35it/s]


Computing results for ./cifar10_resnet18/stars/0eiwx0xx.pt_./cifar10_resnet18/anchors/2oo0upw1.pt_/mnt/qb/work/oh/arubinstein17/datasets/cifar10_p0.05


batch: 100%|██████████| 176/176 [00:05<00:00, 32.52it/s]


Computing results for ./cifar10_resnet18/stars/0eiwx0xx.pt_./cifar10_resnet18/anchors/3m2c3guk.pt_/mnt/qb/work/oh/arubinstein17/datasets/cifar10_p0.05


batch: 100%|██████████| 176/176 [00:05<00:00, 31.92it/s]


Computing results for ./cifar10_resnet18/stars/0eiwx0xx.pt_./cifar10_resnet18/stars/4s24tjof.pt_/mnt/qb/work/oh/arubinstein17/datasets/cifar10_p0.05


batch: 100%|██████████| 176/176 [00:05<00:00, 31.17it/s]


Computing results for ./cifar10_resnet18/stars/4s24tjof.pt_./cifar10_resnet18/anchors/2oo0upw1.pt_/mnt/qb/work/oh/arubinstein17/datasets/cifar10_p0.05


batch: 100%|██████████| 176/176 [00:05<00:00, 31.63it/s]


Computing results for ./cifar10_resnet18/stars/4s24tjof.pt_./cifar10_resnet18/anchors/3m2c3guk.pt_/mnt/qb/work/oh/arubinstein17/datasets/cifar10_p0.05


batch: 100%|██████████| 176/176 [00:05<00:00, 31.77it/s]


Computing results for ./cifar10_resnet18/stars/4s24tjof.pt_./cifar10_resnet18/stars/0eiwx0xx.pt_/mnt/qb/work/oh/arubinstein17/datasets/cifar10_p0.05


batch: 100%|██████████| 176/176 [00:05<00:00, 31.61it/s]


{'./cifar10_resnet18/anchors/2oo0upw1.pt_./cifar10_resnet18/anchors/3m2c3guk.pt_/mnt/qb/work/oh/arubinstein17/datasets/cifar10_p0.05': tensor(8.3112e-05),
 './cifar10_resnet18/anchors/2oo0upw1.pt_./cifar10_resnet18/stars/0eiwx0xx.pt_/mnt/qb/work/oh/arubinstein17/datasets/cifar10_p0.05': tensor(7.4942e-05),
 './cifar10_resnet18/anchors/2oo0upw1.pt_./cifar10_resnet18/stars/4s24tjof.pt_/mnt/qb/work/oh/arubinstein17/datasets/cifar10_p0.05': tensor(7.3900e-05),
 './cifar10_resnet18/anchors/3m2c3guk.pt_./cifar10_resnet18/anchors/2oo0upw1.pt_/mnt/qb/work/oh/arubinstein17/datasets/cifar10_p0.05': tensor(8.1999e-05),
 './cifar10_resnet18/anchors/3m2c3guk.pt_./cifar10_resnet18/stars/0eiwx0xx.pt_/mnt/qb/work/oh/arubinstein17/datasets/cifar10_p0.05': tensor(7.6259e-05),
 './cifar10_resnet18/anchors/3m2c3guk.pt_./cifar10_resnet18/stars/4s24tjof.pt_/mnt/qb/work/oh/arubinstein17/datasets/cifar10_p0.05': tensor(7.6018e-05),
 './cifar10_resnet18/stars/0eiwx0xx.pt_./cifar10_resnet18/anchors/2oo0upw1.pt_

In [106]:
margin_on_dataloader(get_resnet18_cifar(anchors[0]), get_resnet18_cifar(anchors[1]), cifar10_train_dataloader, percent=1.00, device="cuda:0")

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175]


batch: 100%|██████████| 176/176 [01:42<00:00,  1.72it/s]


tensor(8.2205e-05)

In [84]:
%%prun
margin = compute_margin_for_batch(model_1, model_2, x)

100%|██████████| 10/10 [00:15<00:00,  1.50s/it]
100it [00:00, 2172.88it/s]

 




         6317 function calls (6057 primitive calls) in 16.220 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       20   14.986    0.749   14.986    0.749 {method 'run_backward' of 'torch._C._EngineBase' objects}
       40    0.828    0.021    0.828    0.021 {built-in method torch.conv2d}
       40    0.195    0.005    0.195    0.005 {built-in method torch.batch_norm}
       34    0.069    0.002    0.069    0.002 {built-in method torch.relu}
       90    0.026    0.000    0.041    0.000 297611172.py:82(compute_margin)
       16    0.026    0.002    1.090    0.068 resnet_cifar_std.py:36(forward)
        2    0.016    0.008    0.016    0.008 {method 'to' of 'torch._C.TensorBase' objects}
      142    0.015    0.000    0.015    0.000 {method 'acquire' of '_thread.lock' objects}
        1    0.013    0.013   16.219   16.219 <string>:1(<module>)
      180    0.011    0.000    0.011    0.000 {method 'abs' of 'torch._C.TensorBase

In [85]:
print(margin)

tensor(0.0003)
