In [1]:
import os
os.environ['MPLCONFIGDIR'] = '/data/vision/torralba/naturally_robust_models/matplotlib'

%matplotlib inline

The history saving thread hit an unexpected error (DatabaseError('database disk image is malformed')).History will not be written to the database.


In [2]:
import torch
import torch.nn as nn
# import timm
import numpy as np
import torchvision
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
# import robustbench
from tqdm.auto import tqdm
from pathlib import Path
import os
import plotly.express as px

from zz_gradnorm_statistics_utils import *

from swin_transformer_timm_version import *

device = torch.device('cuda', 5)
# device = 'cpu'

  from .autonotebook import tqdm as notebook_tqdm
  warn(f"Failed to load image Python extension: {e}")


In [3]:
# Evaluation data used by the Tab. 1 and curvature cells below (N=10000, batches of 32).
# NOTE(review): `get_data` comes from the star-import of zz_gradnorm_statistics_utils;
# the exact meaning of the four return values should be confirmed there.
xs, ys, ds, dataloader = get_data(N=10000, batch_size=32)

In [4]:
# Small visualization set used by the Fig. 1 / Fig. 3 cells (indexed per image there).
# NOTE(review): `get_viz_fig_data` is provided by the star-import; confirm what it selects.
xs_viz, ys_viz, ds_viz, dataloader_viz = get_viz_fig_data(batch_size=32)



In [5]:
def get_model_statistics_viz(model_path, ema, dataloader_viz, device='cpu'):
  """Load a checkpoint and compute its statistics over the visualization loader.

  Args:
    model_path: Checkpoint path handed to ``load_model``.
    ema: Forwarded to ``load_model`` as ``ema=``.
    dataloader_viz: Loader passed to ``get_statistics``.
    device: Device forwarded to ``get_statistics``.

  Returns:
    Whatever ``get_statistics`` returns for the loaded model.
  """
  return get_statistics(
    load_model(model_path, ema=ema),
    dataloader_viz,
    device=device,
  )

In [6]:
def get_model_attack_statistics_viz(model_path, ema, dataloader_viz, device='cpu'):
  """Load a checkpoint and compute its attack statistics over the visualization loader.

  Args:
    model_path: Checkpoint path handed to ``load_model``.
    ema: Forwarded to ``load_model`` as ``ema=``.
    dataloader_viz: Loader passed to ``get_attack_statistics``.
    device: Device forwarded to ``get_attack_statistics``.

  Returns:
    Whatever ``get_attack_statistics`` returns for the loaded model.
  """
  return get_attack_statistics(
    load_model(model_path, ema=ema),
    dataloader_viz,
    device=device,
  )

## Fig 1: Loss-input gradients

In [7]:
fig1 = False

In [8]:
# Fig. 1: compute statistics on the visualization set for six checkpoints
# (ResNet / Swin-B pairs; the variable suffixes nat / at / gn name the three groups).
# NOTE(review): `swinb_nat` is loaded from an `advtrain_swinb` run directory
# (snapshot-0-100) — confirm that this snapshot really is the natural model.
if fig1:
  resnet_nat_relu = get_model_statistics_viz(f'outputs/nattrain_resnet_relu/2024-02-16_16-59-02/snapshots/snapshot-0-100.pth.tar', True, dataloader_viz, device=device)
  swinb_nat = get_model_statistics_viz(f'outputs/advtrain_swinb/2024-02-18_14-42-14/snapshots/snapshot-0-100.pth.tar', True, dataloader_viz, device=device)

  resnet_at_relu = get_model_statistics_viz(f'outputs/advtrain_resnet_relu/2024-02-18_20-06-25/last.pth.tar', True, dataloader_viz, device=device)
  swinb_at = get_model_statistics_viz(f'outputs/advtrain_swinb_orig/last.pth.tar', True, dataloader_viz, device=device)

  resnet_gn_gelu = get_model_statistics_viz(f'outputs/gradnorm_resnet_gelu/2024-02-03_22-07-28/last.pth.tar', True, dataloader_viz, device=device)
  swinb_gn = get_model_statistics_viz(f'outputs/gradnorm_swinb_variant/2024-02-14_11-30-41/last.pth.tar', True, dataloader_viz, device=device)

In [9]:
# Column order of the Fig. 1 panels: one entry per model, plotted left to right.
if fig1:
  fig1_stats = [resnet_nat_relu, swinb_nat, resnet_at_relu, swinb_at, resnet_gn_gelu, swinb_gn]

In [10]:
# Fig. 1: for each selected image, print the L1 mass of every model's
# `grad_class_x` map and plot the image next to the (normalized) maps.
if fig1:
  # Hand-picked sample indices shown in the figure. (A dead `range(10)`
  # assignment that was immediately overwritten has been removed.)
  xs_viz_idx = [2, 4, 5, 6, 7]
  # Hide grid, ticks and tick labels. These are global rcParams, so the
  # change also affects every figure drawn later in this kernel session.
  plt.rcParams['axes.grid'] = False
  plt.rcParams['xtick.bottom'] = False
  plt.rcParams['xtick.labelbottom'] = False
  plt.rcParams['ytick.left'] = False
  plt.rcParams['ytick.labelleft'] = False
  for idx in xs_viz_idx:
    # One number per model, in `fig1_stats` order (this line was previously
    # duplicated, printing the same values twice).
    print(*[stats['grad_class_x'][idx].abs().sum().item() for stats in fig1_stats])
    # The raw image is shown as-is; only the gradient maps are normalized.
    plot_side_by_side_normalize(xs_viz[idx], *[stats['grad_class_x'][idx] for stats in fig1_stats], normalize=(False,)+(True,)*len(fig1_stats))

## Tab1: Public model stats

In [11]:
tab1 = False

In [12]:
import robustbench

def load_public_model(model_name):
  """Load a public ImageNet model by name, in eval mode on the CPU.

  Names starting with an upper-case letter are looked up in RobustBench
  (ImageNet, Linf threat model). Any other name is treated as a timm
  architecture: a trailing ``_random`` requests randomly initialised weights
  of the architecture named by the prefix, otherwise pretrained weights are
  loaded. timm models are wrapped with ``add_imagenet_normalization``.

  Args:
    model_name: RobustBench model id or timm architecture name
      (optionally suffixed with ``_random``).

  Returns:
    The model after ``.eval().cpu()``. The model is also printed.
  """
  random_suffix = '_random'
  if model_name[0].isupper():
    model = robustbench.utils.load_model(model_name, dataset='imagenet', threat_model='Linf')
  else:
    # `import timm` is commented out in the notebook's import cell, so import
    # it here to avoid depending on a star-import happening to expose it.
    import timm
    # Only strip the marker when it really is a suffix; the previous
    # `'random' in model_name` test sliced off the last 7 characters of any
    # name that merely contained "random".
    if model_name.endswith(random_suffix):
      model = timm.create_model(model_name[:-len(random_suffix)], pretrained=False)
    else:
      model = timm.create_model(model_name, pretrained=True)
    model = add_imagenet_normalization(model)
  print(model)
  return model.eval().cpu()

In [13]:
def return_tab1_metrics(xs, ys, stats, atkstats):
  """Summarise robust accuracy and per-image gradient L1 norms for Tab. 1.

  Args:
    xs: Unused; kept so existing call sites keep working.
    ys: Ground-truth labels, compared with the argmax of ``atkstats['out']``.
    stats: Mapping with a ``'grad_loss_x'`` tensor (first dim = samples).
    atkstats: Mapping with an ``'out'`` tensor of model outputs under attack.

  Returns:
    Tuple ``(robust_accuracy_percent, mean_l1, mean_l1_robust, mean_l1_vulnerable)``
    where "robust" means the attacked prediction still equals the label.
    A mean over an empty subset is NaN.
  """
  still_correct = atkstats['out'].argmax(-1) == ys
  grads = stats['grad_loss_x']

  def mean_l1(batch):
    # L1 norm per sample (all non-batch dims flattened), averaged over samples.
    return batch.flatten(1).abs().sum(1).mean().item()

  robacc_pgd10 = 100 * still_correct.float().mean().item()
  l1_all = mean_l1(grads)
  l1_vulnerable = mean_l1(grads[~still_correct])
  l1_robust = mean_l1(grads[still_correct])

  return robacc_pgd10, l1_all, l1_robust, l1_vulnerable

### Resnet 50 Nat

In [14]:
# Tab. 1 / ResNet-50 (natural): gradient statistics of the timm `resnet50` model.
if tab1:
  stats_public_resnet_nat = get_statistics(load_public_model('resnet50'), dataloader, device=device)

In [15]:
# Tab. 1 / ResNet-50 (natural): attack statistics. The model is loaded a second time here.
if tab1:
  atkstats_public_resnet_nat = get_attack_statistics(load_public_model('resnet50'), dataloader, device=device)

In [16]:
# Prints (robust acc %, mean L1, mean L1 | robust, mean L1 | vulnerable).
if tab1:
  print(return_tab1_metrics(xs, ys, stats_public_resnet_nat, atkstats_public_resnet_nat))

(0.24999999441206455, 3013.835693359375, 0.03041829541325569, 3021.388916015625)

### Resnet 50 AdvTrain

In [17]:
# Tab. 1 / ResNet-50 (adversarially trained): RobustBench entry `Salman2020Do_R50`.
if tab1:
  stats_public_resnet_at = get_statistics(load_public_model('Salman2020Do_R50'), dataloader, device=device)

In [18]:
# Attack statistics for the same RobustBench ResNet-50 (model loaded again).
if tab1:
  atkstats_public_resnet_at = get_attack_statistics(load_public_model('Salman2020Do_R50'), dataloader, device=device)

In [19]:
# Prints (robust acc %, mean L1, mean L1 | robust, mean L1 | vulnerable).
if tab1:
  print(return_tab1_metrics(xs, ys, stats_public_resnet_at, atkstats_public_resnet_at))

(39.980000257492065, 56.819786071777344, 11.783616065979004, 86.81889343261719)

### SwinB Nat

In [20]:
# Tab. 1 / Swin-B (natural): timm `swin_base_patch4_window7_224`, pretrained weights.
if tab1:
  stats_public_swinb_nat = get_statistics(load_public_model('swin_base_patch4_window7_224'), dataloader, device=device)

In [21]:
# Attack statistics for the same timm Swin-B (model loaded again).
if tab1:
  atkstats_public_swinb_nat = get_attack_statistics(load_public_model('swin_base_patch4_window7_224'), dataloader, device=device)

In [22]:
# Prints (robust acc %, mean L1, mean L1 | robust, mean L1 | vulnerable).
if tab1:
  print(return_tab1_metrics(xs, ys, stats_public_swinb_nat, atkstats_public_swinb_nat))

(2.6499999687075615, 3891.702392578125, 546.509033203125, 3982.7626953125)

### SwinB AT - Liu2023Comprehensive_Swin-B

In [23]:
# Tab. 1 / Swin-B (adversarially trained): RobustBench entry `Liu2023Comprehensive_Swin-B`.
if tab1:
  stats_public_swinb_at = get_statistics(load_public_model('Liu2023Comprehensive_Swin-B'), dataloader, device=device)

In [24]:
# Attack statistics for the same RobustBench Swin-B (model loaded again).
if tab1:
  atkstats_public_swinb_at = get_attack_statistics(load_public_model('Liu2023Comprehensive_Swin-B'), dataloader, device=device)

In [25]:
# Prints (robust acc %, mean L1, mean L1 | robust, mean L1 | vulnerable).
if tab1:
  print(return_tab1_metrics(xs, ys, stats_public_swinb_at, atkstats_public_swinb_at))

(59.24999713897705, 37.427303314208984, 10.967606544494629, 75.89938354492188)

### SwinL Nat

In [26]:
# Tab. 1 / Swin-L (natural): timm `swin_large_patch4_window7_224`, pretrained weights.
if tab1:
  stats_public_swinl_nat = get_statistics(load_public_model('swin_large_patch4_window7_224'), dataloader, device=device)

In [27]:
# Attack statistics for the same timm Swin-L (model loaded again).
if tab1:
  atkstats_public_swinl_nat = get_attack_statistics(load_public_model('swin_large_patch4_window7_224'), dataloader, device=device)

In [28]:
# Prints (robust acc %, mean L1, mean L1 | robust, mean L1 | vulnerable).
if tab1:
  print(return_tab1_metrics(xs, ys, stats_public_swinl_nat, atkstats_public_swinl_nat))

(.070000022649765, 2408.36279296875, 285.92877197265625, 2453.225830078125)

### SwinL AT - Liu2023Comprehensive_Swin-L

In [29]:
# Tab. 1 / Swin-L (adversarially trained): RobustBench entry `Liu2023Comprehensive_Swin-L`.
if tab1:
  stats_public_swinl_at = get_statistics(load_public_model('Liu2023Comprehensive_Swin-L'), dataloader, device=device)

In [30]:
# Attack statistics for the same RobustBench Swin-L (model loaded again).
if tab1:
  atkstats_public_swinl_at = get_attack_statistics(load_public_model('Liu2023Comprehensive_Swin-L'), dataloader, device=device)

In [31]:
# Prints (robust acc %, mean L1, mean L1 | robust, mean L1 | vulnerable).
if tab1:
  print(return_tab1_metrics(xs, ys, stats_public_swinl_at, atkstats_public_swinl_at))

(61.330002546310425, 33.73237228393555, 7.308244228363037, 75.6406021118164)

## Fig 2: PGD100 robust accuracy vs. epsilon

In [32]:
fig2 = False

In [33]:
import numpy as np
import pandas as pd

def get_robustness_curve(dirname):
  """Build an accuracy-vs-epsilon curve from the ``.npy`` files in a directory.

  Every ``*.npy`` file directly inside ``dirname`` is loaded and the arrays are
  concatenated; the number of files found is printed. For each of 100 evenly
  spaced ``eps`` values in [0, 32], the accuracy is the percentage of entries
  strictly greater than ``eps / 255``.
  NOTE(review): the arrays presumably hold a per-sample robustness radius in
  [0, 1] pixel units — confirm against the script that wrote them.

  Args:
    dirname: Directory containing the ``.npy`` result shards.

  Returns:
    List of ``(eps, accuracy_percent)`` tuples, one per epsilon.
  """
  shards = [
    np.load(os.path.join(dirname, entry))
    for entry in os.listdir(dirname)
    if entry.endswith(".npy")
  ]
  print(len(shards))
  values = np.concatenate(shards)

  return [
    (eps, 100 * np.greater(values, eps/255.).mean())
    for eps in np.linspace(0, 32, 100)
  ]

In [34]:
# Fig. 2: PGD100 robustness curve of the adversarially trained Swin-B run.
# NOTE: `swinb_at` is also used for a stats dict in the Fig. 1 / Fig. 3 cells;
# here it is rebound to a DataFrame with columns eps / acc / model.
if fig2:
  swinb_at = get_robustness_curve('./outputs/advtrain_swinb_orig/robust_curve/pgd100')
  swinb_at = pd.DataFrame(swinb_at, columns=['eps', 'acc'])
  swinb_at['model'] = 'AdvTrain'

In [35]:
# Fig. 2: same curve for the GradNorm Swin-B variant, tagged with model='GN'.
if fig2:
  swinb_gn_variant = get_robustness_curve('outputs/gradnorm_swinb_variant/2024-02-14_11-30-41/robust_curve/pgd100')
  swinb_gn_variant = pd.DataFrame(swinb_gn_variant, columns=['eps', 'acc'])
  swinb_gn_variant['model'] = 'GN'

In [36]:
# Long-format frame for seaborn: both curves stacked, distinguished by `model`.
if fig2:
  fig2_data = pd.concat([swinb_at, swinb_gn_variant])

In [37]:
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

# Fig. 2: robust accuracy vs. epsilon, one line per model, saved as a PDF.
if fig2:
  fig = plt.figure(figsize=(12, 6))
  # mpl.pyplot.figure(figsize=(250, 350))
  sns.set(font_scale=1.3)
  ax = sns.lineplot(data=fig2_data, x="eps", y="acc", hue='model')
  ax.set(xlabel=r'Adversarial strength $\epsilon$ (in units of 1/255)', ylabel='PGD100 robust accuracy (%)')
  # Recolour by line position: first drawn line red, second the rcParams default.
  # NOTE(review): this relies on the hue order matching the concat order
  # (AdvTrain first, GN second) — confirm if `fig2_data` changes.
  ax.lines[0].set_color('red')
  ax.lines[1].set_color(mpl.rcParams['lines.color'])
  # Replace the automatic legend with an empty, frameless one.
  plt.legend([],[], frameon=False)

  fig.savefig("zzz_arxiv_figures/pgd100_epsilon.pdf")

## Fig 3: Perturbations

In [38]:
fig3 = False

In [39]:
# Fig. 3: attack statistics on the visualization set for the same six checkpoints
# as Fig. 1. NOTE: these rebind the Fig. 1 variable names (resnet_nat_relu, ...),
# so after this cell they hold attack statistics rather than gradient statistics.
# NOTE(review): `swinb_nat` is loaded from an `advtrain_swinb` run directory — confirm.
if fig3:
  resnet_nat_relu = get_model_attack_statistics_viz(f'outputs/nattrain_resnet_relu/2024-02-16_16-59-02/snapshots/snapshot-0-100.pth.tar', True, dataloader_viz, device=device)
  swinb_nat = get_model_attack_statistics_viz(f'outputs/advtrain_swinb/2024-02-18_14-42-14/snapshots/snapshot-0-100.pth.tar', True, dataloader_viz, device=device)

  resnet_at_relu = get_model_attack_statistics_viz(f'outputs/advtrain_resnet_relu/2024-02-18_20-06-25/last.pth.tar', True, dataloader_viz, device=device)
  swinb_at = get_model_attack_statistics_viz(f'outputs/advtrain_swinb_orig/last.pth.tar', True, dataloader_viz, device=device)

  resnet_gn_gelu = get_model_attack_statistics_viz(f'outputs/gradnorm_resnet_gelu/2024-02-03_22-07-28/last.pth.tar', True, dataloader_viz, device=device)
  swinb_gn = get_model_attack_statistics_viz(f'outputs/gradnorm_swinb_variant/2024-02-14_11-30-41/last.pth.tar', True, dataloader_viz, device=device)

In [40]:
# Column order of the Fig. 3 panels: one entry per model, plotted left to right.
if fig3:
  fig3_stats = [resnet_nat_relu, swinb_nat, resnet_at_relu, swinb_at, resnet_gn_gelu, swinb_gn]

In [41]:
# NOTE(review): leftover inspection cell. The expression is nested inside an `if`,
# so Jupyter does not display it; wrap it in print() or delete the cell.
if fig3:
  resnet_nat_relu.keys()

In [42]:
# Fig. 3: for each selected image, print whether every model's attacked output
# still matches the label, then plot the image next to each model's perturbation.
if fig3:
  # xs_viz_idx = range(10)
  xs_viz_idx = [2, 4, 5, 6, 7]
  # xs_viz_idx = [0, 1, 3, 8, 9]
  # Global rcParams: hide grid, ticks and tick labels for all later figures too.
  plt.rcParams['axes.grid'] = False
  plt.rcParams['xtick.bottom'] = False
  plt.rcParams['xtick.labelbottom'] = False
  plt.rcParams['ytick.left'] = False
  plt.rcParams['ytick.labelleft'] = False
  for idx in xs_viz_idx:
    print(*[(stats['out'][idx].argmax(-1) == ys_viz[idx]).item() for stats in fig3_stats])
    # Perturbation = attacked input minus clean input; only perturbations are normalized.
    plot_side_by_side_normalize(xs_viz[idx], *[stats['atk'][idx] - xs_viz[idx] for stats in fig3_stats], normalize=(False,)+(True,)*len(fig3_stats))

## Fig 5 GradNorm lambda

In [43]:
fig5 = False

In [44]:
# Hand-transcribed results of the GradNorm loss-weight sweep. Each row follows
# `columns_gradnorm`: (weight_ce, weight_gradnorm, pgd10_clean, pgd10_robust,
# autoattack_clean, autoattack_robust). Trailing comments name the source log.
# NOTE(review): the 1.7 row is flagged as a placeholder, and the first row has no source log.
data_gradnorm = [
  (1.2, 0.8, 79.9564, 54.9949, 79.8600, 46.2238),
  (1.0, 1.0, 79.0584, 56.6449, 78.9011, 49.6703), # arxiv_outputs/gradnorm_swinb_finetuning_pareto_10/2024-04-03_18-17-15/log_rank0.txt
  (0.9, 1.1, 78.5004, 57.4469, 78.5200, 50.7600), # arxiv_outputs/gradnorm_swinb_finetuning_pareto_11/2024-04-04_16-10-23/log_rank0.txt
  (0.8, 1.2, 77.9424, 57.9408, 77.9400, 51.4875), # arxiv_outputs/gradnorm_swinb_finetuning_control/2024-04-02_19-15-31/log_rank0.txt
  (0.7, 1.3, 77.2365, 58.3288, 77.1000, 51.9600), # arxiv_outputs/gradnorm_swinb_finetuning_pareto_13/2024-04-03_18-59-19/log_rank0.txt
  (0.6, 1.4, 76.0825, 58.7368, 75.8200, 52.5200), # arxiv_outputs/gradnorm_swinb_finetuning_pareto_14/2024-04-03_18-58-08/log_rank0.txt
  (0.5, 1.5, 74.3765, 58.7568, 74.2000, 52.4200), # arxiv_outputs/gradnorm_swinb_finetuning_pareto_15/2024-04-03_18-28-59/log_rank0.txt
  (0.4, 1.6, 71.4926, 58.1588, 70.5600, 51.1800), # arxiv_outputs/gradnorm_swinb_finetuning_pareto_16/2024-04-03_18-14-56/log_rank0.txt
  # PLACEHOLDER for 1.7
  (0.3, 1.7, 63.6167, 54.7469, 62.9200, 46.2540), # PLACE HOLDER arxiv_outputs/gradnorm_swinb_finetuning_pareto_17/2024-04-09_23-09-16/log_rank0.txt
  (0.2, 1.8, 40.9452, 39.0212, 40.0200, 29.0200), # arxiv_outputs/gradnorm_swinb_finetuning_pareto_18/2024-04-09_21-57-15/log_rank0.txt
  (0.1, 1.9, 0.1000, 0.1000, 0.1, 0.1) # arxiv_outputs/gradnorm_swinb_finetuning_pareto_19/2024-04-14_21-44-28/log_rank0.txt
]
columns_gradnorm = ['weight_ce', 'weight_gradnorm', 'pgd10_clean', 'pgd10_robust', 'autoattack_clean', 'autoattack_robust']

df_gradnorm = pd.DataFrame(data=data_gradnorm, columns=columns_gradnorm)

In [45]:
# Reference numbers for the adversarially trained baseline (single row), drawn
# as red horizontal lines in Fig. 5. An alternative row is kept commented out.
data_at = [
  # (False, 77.7384, 57.8328, 77.7600, 51.6600),
  (True, 77.2260, 59.3940, 77.2028, 56.1239)
]
columns_at = ['AT', 'pgd10_clean', 'pgd10_robust', 'autoattack_clean', 'autoattack_robust']

df_at = pd.DataFrame(data=data_at, columns=columns_at)

In [46]:
# Reference numbers for the clean (non-robust) baseline (single row), drawn as
# orange horizontal lines in Fig. 5.
data_clean = [
  (84.19, 1.7800, 84.19, 00.0000),
]
columns_clean = ['pgd10_clean', 'pgd10_robust', 'autoattack_clean', 'autoattack_robust']

df_clean = pd.DataFrame(data=data_clean, columns=columns_clean)

In [47]:
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

# Fig. 5: AutoAttack robust (dashed) and clean (solid) accuracy as a function of
# the GradNorm weight, with horizontal reference lines for the AT baseline (red)
# and the clean baseline (orange).
if fig5:
  fig = plt.figure(figsize=(12, 6))
  sns.set(font_scale=1.3)
  # mpl.pyplot.figure(figsize=(250, 350))
  ax = sns.lineplot(data=df_gradnorm, x="weight_gradnorm", y="autoattack_robust")
  # Lines are restyled by draw order: index 0 = robust curve, index 1 = clean curve.
  ax.lines[0].set_linestyle("--")
  ax.lines[0].set_color(mpl.rcParams['lines.color'])
  sns.lineplot(data=df_gradnorm, x="weight_gradnorm", y="autoattack_clean", ax=ax)
  ax.lines[1].set_color(mpl.rcParams['lines.color'])

  ax.axhline(df_at.loc[0, 'autoattack_robust'], ls='--', color='red')
  ax.axhline(df_at.loc[0, 'autoattack_clean'], ls='-', color='red')

  ax.axhline(df_clean.loc[0, 'autoattack_robust'], ls='--', color='orange')
  ax.axhline(df_clean.loc[0, 'autoattack_clean'], ls='-', color='orange')

  fig.savefig("zzz_arxiv_figures/pareto_lines.pdf")

## Fig 6 GradNorm distribution

In [48]:
fig6 = False

In [49]:
# Fig. 6: load precomputed statistics for the AT and GradNorm(0.6, 1.4) models.
# NOTE(review): absolute, machine-specific paths; torch.load unpickles the file,
# so only point this at trusted checkpoints.
if fig6:
  stats_at = torch.load('/var/datasets/adrianr/input_norm/outputs/advtrain_swinb_orig/last_arxiv_stats/stats.pth.tar')
  stats_gradnorm_14 = torch.load('/var/datasets/adrianr/input_norm/arxiv_outputs/gradnorm_swinb_finetuning_pareto_14/2024-04-03_18-58-08/last_arxiv_stats/stats.pth.tar')

In [50]:
# Legend label -> statistics; dict order sets the hue order in the plots below.
if fig6:
  stats_dict = {
    'Adversarial':stats_at,
    'GradNorm(0.6,1.4)':stats_gradnorm_14,
  }

In [51]:
def plot_fn(name, stats_dict, field, fn, *, index=Ellipsis, split=None, kde=False, **kwargs):
    """Plot the distribution of a derived statistic, one hue per stats entry.

    For every entry of ``stats_dict`` the tensor ``stats[field][index]`` is cut
    into chunks of ``split`` rows, ``fn`` is applied to each chunk and the
    results are concatenated. All values go into one ``sns.displot`` with the
    entry label (annotated with mean and std) as hue. The shapes of the
    transformed tensors are printed.

    Args:
        name: Column name of the plotted values (passed to seaborn as ``x``).
        stats_dict: Mapping ``label -> stats``; ``stats[field]`` must be a tensor.
        field: Key selecting the tensor inside each stats entry.
        fn: Callable mapping a chunk of that tensor to a 1-D tensor.
        index: Index applied to the tensor before chunking (default: everything).
        split: Rows per chunk; defaults to the whole tensor in a single chunk.
        kde: Forwarded to ``sns.displot``.
        **kwargs: Further keyword arguments forwarded to ``sns.displot``.

    Returns:
        The grid object returned by ``sns.displot``.
    """
    # No plt.figure() here: displot is figure-level and creates its own figure,
    # so a pre-created figure would only be left behind empty.
    grads_dict = {k:v[field][index] for k,v in stats_dict.items()}
    if split is None:
        split = list(grads_dict.values())[0].shape[0]
    grads_fn_v = {k:torch.cat([fn(_x) for _x in v.split(split, dim=0)], 0) for k,v in grads_dict.items()}

    grads_fn_v_shapes = [v.shape for v in grads_fn_v.values()]
    print(grads_fn_v_shapes)

    # One hue label per plotted value, repeated to line up with the concatenation.
    columns = [[f'{k} (mean={v.mean():.2g}, std={v.std():.2g})'] * v.shape[0] for k,v in grads_fn_v.items()]
    ax = sns.displot(data=pd.DataFrame(
        {
            name: torch.cat(list(grads_fn_v.values()), 0),
            "Training": sum(columns, []),
        }), kde=kde, x=name, hue='Training', **kwargs)
    plt.gca().set_ylabel('Count')

    # Previously both ncol=3 and ncols=1 were passed. Matplotlib >= 3.6 treats
    # ncols=1 as "unset" and uses ncol, while older versions reject `ncols`
    # outright, so only the effective ncol=3 is kept.
    sns.move_legend(
        ax, "lower center",
        bbox_to_anchor=(.5, 1), ncol=3, title=None, frameon=False,
    )

    return ax

sns.set(font_scale=1.1)

In [52]:
def plot_condhist_fn(name, stats_dict, field_v, fn_v, field_c, fn_c, *, index=Ellipsis, split=None, kde=False, **kwargs):
    """Plot a statistic's distribution separately for condition False and True.

    ``fn_v`` turns ``stats[field_v][index]`` into the values to plot and
    ``fn_c`` turns ``stats[field_c][index]`` into a boolean mask of the same
    length; both are applied chunk-wise (``split`` rows at a time). Two
    ``sns.displot`` figures are produced: the first with the values where the
    mask is False, the second where it is True. The shapes of the full,
    False and True value tensors are printed, in that order.

    Args:
        name: Column name of the plotted values (passed to seaborn as ``x``).
        stats_dict: Mapping ``label -> stats`` holding both fields as tensors.
        field_v: Key of the value tensor.
        fn_v: Callable mapping a value chunk to a 1-D tensor.
        field_c: Key of the condition tensor.
        fn_c: Callable mapping a condition chunk to a 1-D boolean tensor.
        index: Index applied to both tensors before chunking.
        split: Rows per chunk; defaults to the whole value tensor at once.
        kde: Forwarded to ``sns.displot``.
        **kwargs: Further keyword arguments forwarded to ``sns.displot``.

    Returns:
        ``(grid_false, grid_true)``: the two grids returned by ``sns.displot``.
    """
    # No plt.figure() here: displot is figure-level and creates its own figures,
    # so a pre-created figure would only be left behind empty.
    grads_dict_v = {k:v[field_v][index] for k,v in stats_dict.items()}
    grads_dict_c = {k:v[field_c][index] for k,v in stats_dict.items()}
    if split is None:
        split = list(grads_dict_v.values())[0].shape[0]
    grads_fn_v = {k:torch.cat([fn_v(_x) for _x in v.split(split, dim=0)], 0) for k,v in grads_dict_v.items()}
    grads_fn_c = {k:torch.cat([fn_c(_x) for _x in v.split(split, dim=0)], 0) for k,v in grads_dict_c.items()}

    # Both dicts are built from stats_dict, so their items line up by position.
    grads_fn_v0 = {k:v[~c] for (k,v),(_,c) in zip(grads_fn_v.items(), grads_fn_c.items())}
    grads_fn_v1 = {k:v[c] for (k,v),(_,c) in zip(grads_fn_v.items(), grads_fn_c.items())}

    for group in (grads_fn_v, grads_fn_v0, grads_fn_v1):
        print([v.shape for v in group.values()])

    ax0 = _condhist_panel(name, grads_fn_v0, kde, kwargs)
    ax1 = _condhist_panel(name, grads_fn_v1, kde, kwargs)

    return ax0, ax1


def _condhist_panel(name, values_by_label, kde, displot_kwargs):
    """Draw one displot for ``values_by_label`` and move its legend above the axes."""
    hue_labels = [[f'{k} (mean={v.mean():.2g},std={v.std():.2g})'] * v.shape[0] for k,v in values_by_label.items()]
    grid = sns.displot(data=pd.DataFrame(
        {
            name: torch.cat(list(values_by_label.values()), 0),
            "Training": sum(hue_labels, []),
        }), kde=kde, x=name, hue='Training', **displot_kwargs)
    plt.gca().set_ylabel('Count')

    sns.move_legend(
        grid, "lower center",
        bbox_to_anchor=(.5, 1), ncol=3, title=None, frameon=False,
    )
    return grid

sns.set(font_scale=1.1)

In [53]:
import seaborn as sns
import matplotlib as mpl

# Fig. 6: histogram (log x-axis, with KDE) of the per-image L1 norm of
# `grad_loss_x` over the first 10000 samples, AT in red vs. GradNorm.
# `fig` is the seaborn grid returned by plot_fn; the next cell reuses its axis limits.
if fig6:
        fig = plot_fn(r'Distribution of $|\nabla_x \mathcal{L}|_1$ per image', stats_dict, 'grad_loss_x', lambda x: x.abs().flatten(1).sum(1), 
                index=torch.arange(10000), kde=True, log_scale=(True, False), palette=sns.color_palette(palette=['red', mpl.rcParams['lines.color']], n_colors=2))
        # fig.savefig("arxiv_figures/transformers/l1_norm_distribution_per_image_grad_loss_x_10k.pdf") 
        fig.fig.savefig(f"zzz_arxiv_figures/l1_norm_distribution_per_image_grad_loss_x_10k.pdf")

In [54]:
# Fig. 6 (conditional): the same per-image L1-norm histogram, split by whether
# the model output's argmax equals the label (`ys`). Axis limits are copied
# from `fig`, the unconditional plot created in the previous cell.
if fig6:
        # The grids were previously bound to `fig0, fig1`, which overwrote the
        # boolean `fig1` section flag defined at the top of the notebook; use
        # names that do not collide with any flag.
        grid_cond0, grid_cond1 = plot_condhist_fn(r'Distribution of $|\nabla_x \mathcal{L}|_1$ per image', stats_dict,
                'grad_loss_x', lambda x: x.abs().flatten(1).sum(1),
                'out', lambda x: x.argmax(-1) == ys,
                index=torch.arange(10000), kde=True, log_scale=(True, False), palette=sns.color_palette(palette=['red', mpl.rcParams['lines.color']], n_colors=2))
        # grid_cond0: prediction != label; grid_cond1: prediction == label.
        grid_cond0.set(xlim=fig.ax.get_xlim(),ylim=fig.ax.get_ylim())
        grid_cond1.set(xlim=fig.ax.get_xlim(),ylim=fig.ax.get_ylim())
        grid_cond0.fig.savefig(f"zzz_arxiv_figures/class_conditional_0_l1_norm_distribution_per_image_grad_loss_x_10k.pdf")
        grid_cond1.fig.savefig(f"zzz_arxiv_figures/class_conditional_1_l1_norm_distribution_per_image_grad_loss_x_10k.pdf")

## Fig 7: Absolute value distribution

In [55]:
fig7 = False

In [56]:
# Fig. 7: loads the same two statistics files as the Fig. 6 section, so this
# section can run on its own.
# NOTE(review): absolute, machine-specific paths; torch.load unpickles the file,
# so only point this at trusted checkpoints.
if fig7:
  stats_at = torch.load('/var/datasets/adrianr/input_norm/outputs/advtrain_swinb_orig/last_arxiv_stats/stats.pth.tar')
  stats_gradnorm_14 = torch.load('/var/datasets/adrianr/input_norm/arxiv_outputs/gradnorm_swinb_finetuning_pareto_14/2024-04-03_18-58-08/last_arxiv_stats/stats.pth.tar')

In [57]:
# Legend label -> statistics; dict order sets the hue order in the plots below.
if fig7:
  stats_dict = {
    'Adversarial':stats_at,
    'GradNorm(0.6,1.4)':stats_gradnorm_14,
  }

In [58]:
# Fig. 7: histogram (log x-axis) of every absolute entry of `grad_loss_x` over
# the first 1000 samples. `fig` is the seaborn grid; its x-limits are reused by
# the red and blue channel cells below.
if fig7:
        fig = plot_fn(r'distribution of $|\nabla_x \mathcal{L}|$', stats_dict, 'grad_loss_x', lambda x: x.abs().flatten(), 
                index=torch.arange(1000), log_scale=(True, False), rug=False, palette=sns.color_palette(palette=['red', mpl.rcParams['lines.color']], n_colors=2))
        fig.savefig(f"zzz_arxiv_figures/abs_distribution_per_pixel_grad_loss_x_advtrain_gradnorm_1k.pdf")

In [59]:
# Fig. 7 (green): same histogram restricted to index 1 along dim 1.
# NOTE(review): assumes dim 1 is the RGB channel axis — confirm tensor layout.
if fig7:
        fig_green = plot_fn(r'distribution of $|\nabla_x \mathcal{L}|_{green}$', stats_dict, 'grad_loss_x', lambda x: x[:, 1].abs().flatten(), 
                index=torch.arange(1000), log_scale=(True, False), rug=False, palette=sns.color_palette(palette=['red', mpl.rcParams['lines.color']], n_colors=2))
        fig_green.savefig(f"zzz_arxiv_figures/abs_distribution_per_pixel_green_grad_loss_x_advtrain_gradnorm_1k.pdf")

In [60]:
# Fig. 7 (red): same histogram restricted to index 0 along dim 1.
# NOTE(review): x-limits come from the all-channel plot (`fig`) but y-limits
# from the green plot (`fig_green`) — confirm this mix is intentional.
if fig7:
        fig_red = plot_fn(r'distribution of $|\nabla_x \mathcal{L}|_{red}$', stats_dict, 'grad_loss_x', lambda x: x[:, 0].abs().flatten(), 
                index=torch.arange(1000), log_scale=(True, False), rug=False, palette=sns.color_palette(palette=['red', mpl.rcParams['lines.color']], n_colors=2))
        fig_red.set(xlim=fig.ax.get_xlim(),ylim=fig_green.ax.get_ylim())
        fig_red.savefig(f"zzz_arxiv_figures/abs_distribution_per_pixel_red_grad_loss_x_advtrain_gradnorm_1k.pdf")

In [61]:
# Fig. 7 (blue): same histogram restricted to index 2 along dim 1.
# NOTE(review): x-limits come from `fig`, y-limits from `fig_green` — confirm.
# The axis title here is capitalised ("Distribution") unlike the other channels.
if fig7:
        fig_blue = plot_fn(r'Distribution of $|\nabla_x \mathcal{L}|_{blue}$', stats_dict, 'grad_loss_x', lambda x: x[:, 2].abs().flatten(), 
                index=torch.arange(1000), log_scale=(True, False), rug=False, palette=sns.color_palette(palette=['red', mpl.rcParams['lines.color']], n_colors=2))
        fig_blue.set(xlim=fig.ax.get_xlim(),ylim=fig_green.ax.get_ylim())
        fig_blue.savefig(f"zzz_arxiv_figures/abs_distribution_per_pixel_blue_grad_loss_x_advtrain_gradnorm_1k.pdf")

## Fig 10 Normalized curvature

In [62]:
fig10 = False

In [63]:
from typing import Tuple
import torch.nn.functional as F

def curvature_hessian_estimator(model: torch.nn.Module,
                        image: torch.Tensor,
                        target: torch.Tensor,
                        num_power_iter: int=20) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Estimate per-sample input-space curvature of the cross-entropy loss.

    Runs power iteration on Hessian-vector products of the loss w.r.t. the
    input to approximate the dominant Hessian eigen/singular value, then
    divides it by the L2 norm of the input gradient.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        image: Input batch. The norms reduce over dims (1, 2, 3), so a 4-D
            batch is expected. ``requires_grad_()`` is set on it in place.
        target: Class indices for ``F.nll_loss``.
        num_power_iter: Number of power-iteration steps.

    Returns:
        Tuple ``(curvatures, hess, grad, grad_1)`` with one value per sample:
        ``|u^T H u| / (||g||_2 + 1e-6)``, ``|u^T H u|``, ``||g||_2`` and ``||g||_1``.
    """

    model.eval()
    # Random start direction, normalized to unit L2 norm per sample.
    u = torch.randn_like(image)
    u /= torch.norm(u, p=2, dim=(1, 2, 3), keepdim=True)

    # Re-enable autograd explicitly so this also works inside torch.no_grad().
    with torch.enable_grad():
        image = image.requires_grad_()
        out = model(image)
        # print((out.argmax(-1) == target).float().mean(), target)
        # Per-sample cross-entropy (log-softmax followed by NLL, no reduction).
        y = F.log_softmax(out, 1)
        output = F.nll_loss(y, target, reduction='none')
        model.zero_grad()
        # Gradients w.r.t. input
        # create_graph=True keeps the graph so the gradient can be differentiated again.
        gradients = torch.autograd.grad(outputs=output.sum(),
                                        inputs=image, create_graph=True)[0]
        gnorm = torch.norm(gradients, p=2, dim=(1, 2, 3))
        gnorm_1 = gradients.abs().sum((1, 2, 3))
        assert not gradients.isnan().any()

        # Power method to find singular value of Hessian
        # Each step computes H u as d(g . u)/dx and renormalizes it.
        # NOTE(review): the dot product is summed over the whole batch, which
        # yields per-sample H u only if samples do not interact in the model.
        for _ in range(num_power_iter):
            grad_vector_prod = (gradients * u.detach_()).sum()
            hessian_vector_prod = torch.autograd.grad(outputs=grad_vector_prod, inputs=image, retain_graph=True)[0]
            assert not hessian_vector_prod.isnan().any()

            hvp_norm = torch.norm(hessian_vector_prod, p=2, dim=(1, 2, 3), keepdim=True)
            u = hessian_vector_prod.div(hvp_norm + 1e-6) #1e-6 for numerical stability

        # Final Hessian-vector product; the graph is not retained after this call.
        grad_vector_prod = (gradients * u.detach_()).sum()
        hessian_vector_prod = torch.autograd.grad(outputs=grad_vector_prod, inputs=image)[0]
        # Rayleigh quotient u^T H u per sample (u has unit norm up to the 1e-6 term).
        hessian_singular_value = (hessian_vector_prod * u.detach_()).sum((1, 2, 3))
    
    # curvature = hessian_singular_value / (grad_norm + epsilon) by definition
    curvatures = hessian_singular_value.abs().div(gnorm + 1e-6)
    hess = hessian_singular_value.abs()
    grad = gnorm
    grad_1 = gnorm_1
    
    return curvatures, hess, grad, grad_1


def measure_curvature(model: torch.nn.Module,
                      dataloader: torch.utils.data.DataLoader,
                      data_fraction: float=0.1,
                      batch_size: int=64,
                      num_power_iter: int=20,
                      device: torch.device='cpu') -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

    """
    Compute curvature, hessian norm and gradient norms of a subset of the data given by the dataloader.
    These values are computed using the power method, which requires setting the number of power iterations (num_power_iter).

    The first ``int(data_fraction * len(dataloader.dataset))`` samples yielded by
    the dataloader are measured. Summary statistics are printed once at the end.

    Args:
        model: Network to evaluate; moved to ``device`` and put in eval mode.
        dataloader: Yields ``(data, target)`` batches.
        data_fraction: Fraction of the dataset to measure.
        batch_size: Kept for backward compatibility. Offsets now follow the
            actual size of each batch, so a short or mismatched batch no
            longer corrupts the buffers.
        num_power_iter: Power-iteration steps per batch.
        device: Device used for the model and the batches.

    Returns:
        Tuple ``(curvature, hessian_norm, grad_l2_norm, grad_l1_norm)`` of 1-D CPU
        tensors, one entry per measured sample. If the dataloader runs out
        before the requested number of samples, the tensors are truncated to
        the samples actually measured (previously: zero padding that biased
        the statistics, or an implicit ``None`` return).
    """

    model = model.eval().to(device)
    datasize = int(data_fraction * len(dataloader.dataset))
    curvature_agg = torch.zeros(size=(datasize,))
    grad_agg = torch.zeros(size=(datasize,))
    hess_agg = torch.zeros(size=(datasize,))
    grad_1_agg = torch.zeros(size=(datasize,))

    filled = 0  # number of buffer entries written so far
    if datasize > 0:
        for data, target in tqdm(dataloader):
            data, target = data.to(device).requires_grad_(), target.to(device)
            # The estimator re-enables autograd internally; no_grad only keeps
            # everything outside of it from building graphs.
            with torch.no_grad():
                curvatures, hess, grad, grad_1 = curvature_hessian_estimator(model, data, target, num_power_iter=num_power_iter)

            # Use only as many samples from this batch as still fit.
            take = min(curvatures.shape[0], datasize - filled)
            window = slice(filled, filled + take)
            curvature_agg[window] = curvatures.detach()[:take]
            hess_agg[window] = hess.detach()[:take]
            grad_agg[window] = grad.detach()[:take]
            grad_1_agg[window] = grad_1.detach()[:take]
            filled += take

            if filled >= datasize:
                break

    # Drop never-written entries so they cannot leak into the statistics.
    curvature_agg = curvature_agg[:filled]
    hess_agg = hess_agg[:filled]
    grad_agg = grad_agg[:filled]
    grad_1_agg = grad_1_agg[:filled]

    avg_curvature, std_curvature = curvature_agg.mean().item(), curvature_agg.std().item()
    avg_hessian, std_hessian = hess_agg.mean().item(), hess_agg.std().item()
    avg_grad, std_grad = grad_agg.mean().item(), grad_agg.std().item()
    avg_grad_1, std_grad_1 = grad_1_agg.mean().item(), grad_1_agg.std().item()

    print('Average Curvature: {:.6f} +/- {:.2f} '.format(avg_curvature, std_curvature))
    print('Average Hessian Spectral Norm: {:.6f} +/- {:.2f} '.format(avg_hessian, std_hessian))
    print('Average Gradient Norm: {:.6f} +/- {:.2f}'.format(avg_grad, std_grad))
    print('Average Gradient L1 Norm: {:.6f} +/- {:.2f}'.format(avg_grad_1, std_grad_1))
    return curvature_agg, hess_agg, grad_agg, grad_1_agg

In [64]:
def launch_measure_curvature(model_path, ema, dataloader, device):
  """Measure curvature for a local checkpoint on 32% (3200/10000) of the dataset.

  Args:
    model_path: Checkpoint path handed to ``load_model``.
    ema: Forwarded to ``load_model`` as ``ema=``.
    dataloader: Data source; its ``batch_size`` is forwarded as well.
    device: Device forwarded to ``measure_curvature``.

  Returns:
    The result of ``measure_curvature``.
  """
  return measure_curvature(
    load_model(model_path, ema=ema),
    dataloader,
    data_fraction=3200./10000.,
    batch_size=dataloader.batch_size,
    device=device,
  )

In [65]:
def launch_measure_curvature_public(model_name, dataloader, device):
  """Measure curvature for a public model on 32% (3200/10000) of the dataset.

  Args:
    model_name: Name handed to ``load_public_model``.
    dataloader: Data source; its ``batch_size`` is forwarded as well.
    device: Device forwarded to ``measure_curvature``.

  Returns:
    The result of ``measure_curvature``.
  """
  return measure_curvature(
    load_public_model(model_name),
    dataloader,
    data_fraction=3200./10000.,
    batch_size=dataloader.batch_size,
    device=device,
  )

In [66]:
# Curvature measurements for the adversarially trained Swin-B checkpoint.
if fig10:
  curvature_at = launch_measure_curvature('outputs/advtrain_swinb_orig/last.pth.tar', True, dataloader, device)

In [67]:
# Curvature measurements for the GradNorm Swin-B variant checkpoint.
if fig10:
  curvature_gn = launch_measure_curvature('outputs/gradnorm_swinb_variant/2024-02-14_11-30-41/last.pth.tar', True, dataloader, device)

In [68]:
# Curvature measurements for the naturally trained public model.
# NOTE(review): this loads Swin-*Large* while the other two curves use Swin-B
# checkpoints — confirm the comparison is meant to mix model sizes.
if fig10:
  curvature_nat = launch_measure_curvature_public('swin_large_patch4_window7_224', dataloader, device=device)

In [69]:
# Global lookup read by plot_curvature_fn; each value is the 4-tuple returned by
# measure_curvature (index 0 = curvature).
if fig10:
  fig10_results = {
    'natural':curvature_nat,
    'gradnorm':curvature_gn,
    'advtrain':curvature_at
  }

In [70]:
def plot_curvature_fn(name, fn, *, index=Ellipsis, split=None, **kwargs):
    """Plot curvature distributions of the three training methods in one displot.

    Reads element 0 of each entry of the global ``fig10_results`` (keys
    ``'advtrain'``, ``'gradnorm'``, ``'natural'``), applies ``index``, then
    applies ``fn`` chunk-wise (``split`` rows at a time). Shapes and minima of
    the indexed tensors are printed before plotting.

    Args:
        name: Column name of the plotted values (passed to seaborn as ``x``).
        fn: Callable mapping a chunk of curvatures to the values to plot.
        index: Index applied to each curvature tensor (default: everything).
        split: Rows per chunk; defaults to the adversarial tensor's length.
        **kwargs: Further keyword arguments forwarded to ``sns.displot``.

    Returns:
        The matplotlib figure of the resulting seaborn grid (``grid.fig``).
    """
    # (results key, legend prefix, colour) in plotting order.
    series = (
        ('advtrain', 'Adversarial Training (PGD-3)', 'red'),
        ('gradnorm', 'Gradient Norm Regularization', 'blue'),
        ('natural', 'Natural Training', 'orange'),
    )
    raw = {key: fig10_results[key][0][index].detach().cpu() for key, _, _ in series}
    print(*(t.shape for t in raw.values()))
    print(*(t.min() for t in raw.values()))

    n_reference = raw['advtrain'].shape[0]
    if split is None:
        split = n_reference
    values = {
        key: torch.cat([fn(chunk) for chunk in t.split(split, dim=0)], 0)
        for key, t in raw.items()
    }

    # Legend text embeds the mean of the transformed values.
    legend = {key: f'{title} (mean={values[key].mean():.2e})' for key, title, _ in series}

    hue_column = []
    for key, _, _ in series:
        hue_column = hue_column + [legend[key]] * raw[key].shape[0]

    frame = pd.DataFrame(
        {
            name: torch.cat([values[key] for key, _, _ in series], 0),
            "Training": hue_column,
        })
    grid = sns.displot(
        data=frame, kde=True, x=name, hue='Training',
        palette={legend[key]: colour for key, _, colour in series},
        **kwargs)
    plt.gca().set_ylabel(f'Count (n={n_reference})')
    return grid.fig

In [71]:
# Normalized-curvature histogram (log x-axis) over the first 3200 samples.
# The 1e-16 offset keeps exact zeros representable on the log scale.
if fig10:
        fig = plot_curvature_fn(r'Normalized curvature', lambda x: x + 1e-16, 
                index=torch.arange(3200), log_scale=(True, False), alpha=0.4)
        fig.savefig('./zzz_arxiv_figures/curvature.pdf', facecolor='white')

## Fig10 - FGSM Training

In [74]:
import re

def log_to_acc_df(filename):
  """Parse a training log into a long-format table of evaluation metrics.

  Lines containing one of the end-of-evaluation markers are kept. Their
  parenthesised groups are read as loss, acc1, acc5, advloss, advacc1 and
  advacc5, starting at group 2 (or group 3 when the line contains ``'EMA'``).
  The starting epoch comes from the first line containing a
  start-of-training marker; the epoch counter advances after every EMA line.

  Args:
    filename: Path of the log file.

  Returns:
    ``(log_lines, acc_df)``: the kept evaluation lines, and a DataFrame with
    columns ``epoch``, ``EMA``, ``metric``, ``value`` (five rows per line;
    advloss is parsed but not stored).
  """
  eval_markers = (': [ 195/195]', '[  97/97]', '[ 187/187]')
  train_start_markers = ('[   0/2502 (  0%)]', '[   0/5004 (  0%)]', '[   0/2408 (  0%)]')

  log_lines = []
  start_epoch = None
  with open(f'{filename}', 'r') as handle:
    for raw in handle:
      if any(marker in raw for marker in eval_markers):
        log_lines.append(raw)
      if start_epoch is None and any(marker in raw for marker in train_start_markers):
        # Eighth space-separated token looks like "[<epoch>/<total>]".
        epoch_token = raw.split(' ')[7]
        start_epoch = int(epoch_token[1:].split('/')[0])

  records = []
  epoch = start_epoch
  for raw in log_lines:
    groups = re.findall(r'\((.*?)\)', raw)

    # EMA lines carry one extra leading group, shifting the metrics by one.
    ema = 'EMA' in raw
    first = 3 if ema else 2
    loss, acc1, acc5, advloss, advacc1, advacc5 = [float(groups[first + k]) for k in range(6)]

    for metric, value in (('loss', loss), ('acc1', acc1), ('acc5', acc5),
                          ('advacc1', advacc1), ('advacc5', advacc5)):
      records.append((epoch, ema, metric, value))

    if ema:
      epoch += 1

  acc_df = pd.DataFrame(records,  columns=['epoch', 'EMA', 'metric', 'value'])

  return log_lines, acc_df

def logs_to_acc_df(*filenames):
  """Merge the metric tables of several log files into one sorted table.

  Logs are parsed with ``log_to_acc_df`` and concatenated in the given order.
  When several logs report the same (epoch, EMA, metric), the value from the
  later file wins. The result is sorted by those three columns.

  Args:
    *filenames: Log file paths, earliest run first.

  Returns:
    DataFrame with columns ``epoch``, ``EMA``, ``metric``, ``value``.
  """
  key_columns = ['epoch', 'EMA', 'metric']
  merged = pd.concat([log_to_acc_df(path)[1] for path in filenames])
  deduplicated = merged.drop_duplicates(key_columns, keep='last')
  return deduplicated.sort_values(by=key_columns)

In [75]:
# Parse the training logs of the three methods. The GradNorm run is spread over
# three log files; on overlapping epochs the later file takes precedence.
# NOTE: unlike the other sections this one has no on/off flag, so it always
# runs and requires the log files to exist.
advtrain = logs_to_acc_df('outputs/advtrain_swinb_orig/log_rank0.txt')
fgsmtrain = logs_to_acc_df('outputs/advtrain1_swinb/2024-02-23_11-21-46/log_rank0.txt')
gradnorm = logs_to_acc_df('outputs/gradnorm_swinb_variant/2024-02-05_20-47-43/log_rank0.txt', 'outputs/gradnorm_swinb_variant/2024-02-07_16-31-19/log_rank0.txt', 'outputs/gradnorm_swinb_variant/2024-02-14_11-30-41/log_rank0.txt')

In [76]:
# Tag each table with its method name; this column drives the line colours below.
advtrain['method'] = 'Adversarial Training'
gradnorm['method'] = 'Gradient Norm'
fgsmtrain['method'] = 'FGSM Training'

In [77]:
# Stack the three tables; the concat order fixes the trace/colour order in the plot.
fig10_df = pd.concat([advtrain, gradnorm, fgsmtrain])
fig10_df

Unnamed: 0,epoch,EMA,metric,value,method
1,0,False,acc1,84.8940,Adversarial Training
2,0,False,acc5,97.4440,Adversarial Training
3,0,False,advacc1,1.9500,Adversarial Training
4,0,False,advacc5,6.0700,Adversarial Training
0,0,False,loss,0.6873,Adversarial Training
...,...,...,...,...,...
996,99,True,acc1,72.8860,FGSM Training
997,99,True,acc5,91.4080,FGSM Training
998,99,True,advacc1,9.7300,FGSM Training
999,99,True,advacc5,15.4740,FGSM Training


In [78]:
# Keep only the EMA evaluations of the top-1 adversarial accuracy ('advacc1').
fig10_plot_df = fig10_df.copy()
fig10_plot_df = fig10_plot_df[fig10_plot_df.EMA == True]
fig10_plot_df = fig10_plot_df[fig10_plot_df.metric.apply(lambda x: x in ['advacc1'])]

In [81]:
# Robust accuracy per epoch, one line per training method, saved as a PDF.
# Colours follow the order in which methods appear in the data: red, blue, green.
# NOTE(review): only the "epoch" and "value" entries of `labels` correspond to
# plotted columns; the other keys look like leftovers from an earlier version.
fig = px.line(fig10_plot_df, 
    x='epoch', y='value', color='method', markers=True,
    labels={
        "metric":"Accuracy",
        "acc1":"Clean Acc",
        "advacc1":"Robust Acc",
        "epoch":"Epoch",
        "value":"PGD10 robust accuracy (%)"
    },
    color_discrete_sequence=["red", "blue", "green"],
)
# The legend is hidden, so the legend placement further down has no visible effect.
fig.update_layout(showlegend=False)

# Strips a "method=" prefix from annotations, if any exist.
fig.for_each_annotation(lambda a: a.update(text=a.text.replace("method=", "")))
# fig.update_xaxes(
#     dtick=1  # Change this value according to your desired tick step
# )

# newnames = {'acc1':'Clean', 'advacc1': 'PGD10(eps=4)'}
# fig.for_each_trace(lambda t: t.update(name = newnames[t.name],
#                                       legendgroup = newnames[t.name],
#                                       hovertemplate = t.hovertemplate.replace(t.name, newnames[t.name])))

fig.update_layout(legend=dict(
    yanchor="top",
    y=1.2,
    xanchor="left",
    x=0
))

# Remove all outer margins so the exported PDF is tightly cropped.
fig.update_layout(
    margin=dict(l=0, r=0, t=00, b=0),
)

fig.update_layout(
    width=1200,  # Set the width of the plot
    height=600,  # Set the height of the plot
)

fig.update_layout(
    font_size=18,
    title_font_size=22,
)

fig.show()
# NOTE(review): static image export needs an extra plotly export backend
# (e.g. kaleido) installed — confirm the environment provides one.
fig.write_image("zzz_arxiv_figures/wong_comparison.pdf")
