In [1]:
%matplotlib inline

In [2]:
import torch
import torch.nn as nn
import timm
import numpy as np
import torchvision
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import robustbench

# Preferred accelerator for this analysis is GPU index 3. Fall back to the
# first GPU, or to the CPU, so the notebook still runs on hosts that do not
# expose that many devices.
PREFERRED_GPU_INDEX = 3
if torch.cuda.is_available() and torch.cuda.device_count() > PREFERRED_GPU_INDEX:
    device = torch.device('cuda', PREFERRED_GPU_INDEX)
elif torch.cuda.is_available():
    device = torch.device('cuda', 0)
else:
    device = torch.device('cpu')

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the pre-generated evaluation batch onto the CPU. Later cells read its
# 'xs' and 'ys' entries.
data = torch.load(
    'analysis_data/240206_gen_imagenet_data_10k_noattack.pth',
    map_location='cpu',
)
# Alternative source kept for reference (inputs and targets stored separately):
# xs = torch.load('/data/vision/torralba/naturally_robust_models/input_norm/outputs/inputs/train_png/inputs.pth')
# ys = torch.load('/data/vision/torralba/naturally_robust_models/input_norm/outputs/inputs/train_png/targets.pth')
# data = {'xs':xs, 'ys':ys}

In [4]:
def abs_normalize(x, q=None, start_dim=-3):
  """Rescale a signed tensor into [0, 1] so that zero maps to 0.5.

  The scale is a quantile of ``|x|`` taken over the dimensions from
  ``start_dim`` to the last one; values beyond the scale saturate at 0 or 1.

  Args:
    x: floating-point tensor (``torch.quantile`` requires a float dtype).
    q: quantile level. ``q`` and ``1 - q`` are interchangeable (the larger one
      is used). ``None`` uses the maximum of ``|x|`` as the scale.
    start_dim: first dimension of the group that shares one scale.

  Returns:
    Tensor with the same shape as ``x`` and values in [0, 1].
  """
  level = 1.0 if q is None else max(q, 1 - q)
  s = torch.quantile(x.abs().flatten(start_dim=start_dim), q=level, dim=-1, keepdim=False)
  # An all-zero slice has a zero scale; divide by 1 instead so it maps to a
  # flat 0.5 rather than NaN.
  s = torch.where(s > 0, s, torch.ones_like(s))
  # Re-append the flattened dimensions as singletons so the scale broadcasts
  # against x for any start_dim (positive or negative).
  trailing = x.dim() - s.dim()
  x = 0.5 + 0.5 * x / s[(..., ) + (None,) * trailing]

  return torch.clamp(x, 0., 1.)
  
def plot_side_by_side_normalize(*images, normalize):
  """Show several image tensors next to each other in one figure row.

  Args:
    *images: tensors to display. A 3-D tensor whose last dimension is larger
      than 3 is treated as channels-first and permuted to channels-last.
    normalize: one flag per image; when true the image is passed through
      ``abs_normalize(..., q=0.01)`` before display.

  Raises:
    AssertionError: if the number of flags differs from the number of images.
  """
  # Validate before creating the figure so a bad call leaves no stray canvas.
  assert len(images) == len(normalize), "need exactly one normalize flag per image"
  plt.figure(figsize=(20, 20))
  columns = len(images)
  for i, image in enumerate(images):
    # A single row is enough; the previous grid always reserved an empty
    # second row.
    plt.subplot(1, columns, i + 1)
    # detach() drops the autograd graph; cpu() lets imshow convert the data.
    image = image.detach().cpu()
    if image.dim() == 3 and image.shape[-1] > 3:
      image = image.permute(1, 2, 0)
    if normalize[i]:
      image = abs_normalize(image, q=0.01)
    plt.imshow(image, cmap='gray')

In [5]:
def replace_layers(model, old, new):
    """Recursively swap every submodule of type ``old`` for ``new``, in place.

    The same ``new`` object is assigned at every matching position. Containers
    are descended into first and are themselves replaced afterwards if they
    also match ``old``. Returns ``None``.
    """
    for child_name, child in model.named_children():
        is_container = any(True for _ in child.children())
        if is_container:
            # Visit nested modules before deciding about the container itself.
            replace_layers(child, old, new)

        if isinstance(child, old):
            setattr(model, child_name, new)

def load_model(path, ema=False):
    """Build a timm model chosen from ``path`` and load checkpoint weights.

    The architecture is inferred from substrings of ``path``:
    ``'_resnet_'`` -> resnet50 (ReLU swapped for GELU when ``'_gelu'`` is also
    present), ``'_swinb'`` -> swin_base_patch4_window7_224,
    ``'_swins'`` -> swin_small_patch4_window7_224.

    Args:
        path: checkpoint file; it must contain a ``'state_dict'`` entry, and a
            ``'state_dict_ema'`` entry if ``ema`` is requested.
        ema: load the ``'state_dict_ema'`` weights instead of ``'state_dict'``.

    Returns:
        The model in eval mode, or ``None`` when ``ema`` is true but the
        checkpoint has no ``'state_dict_ema'`` entry.

    Raises:
        ValueError: if ``path`` matches none of the known architecture tags.
    """
    model_kwargs = {
        'num_classes': 1000,
        'drop_rate': 0.0,
        'drop_path_rate': 0.0,
        'drop_block_rate': None,
        'global_pool': None,
        'bn_momentum': None,
        'bn_eps': None,
    }

    if '_resnet_' in path:
        arch = 'resnet50'
    elif '_swinb' in path:
        arch = 'swin_base_patch4_window7_224'
    elif '_swins' in path:
        arch = 'swin_small_patch4_window7_224'
    else:
        # Previously this fell through and failed later on an unbound `model`.
        raise ValueError(
            f"cannot infer architecture from checkpoint path {path!r}; "
            "expected one of '_resnet_', '_swinb', '_swins' in the path"
        )

    model = timm.models.create_model(arch, pretrained=False, **model_kwargs)
    if arch == 'resnet50' and '_gelu' in path:
        replace_layers(model, nn.ReLU, nn.GELU())

    # Load onto the CPU so the checkpoint does not require the device it was
    # saved from; callers move the model afterwards.
    ckpt = torch.load(path, map_location='cpu')
    if ema:
        if 'state_dict_ema' not in ckpt:
            return None
        model.load_state_dict(ckpt['state_dict_ema'])
    else:
        model.load_state_dict(ckpt['state_dict'])
    return model.eval()

def load_public_model(model_name):
  """Load a public model by name and move it to the default CUDA device.

  Names starting with an upper-case letter are loaded through robustbench
  (ImageNet, Linf threat model). Any other name is a timm model name: a
  trailing ``'_random'`` is stripped and the model is created without
  pretrained weights, otherwise pretrained weights are requested.

  Raises:
    ValueError: if ``model_name`` is empty.
  """
  if not model_name:
    raise ValueError("model_name must be a non-empty string")

  if model_name[0].isupper():
    return robustbench.utils.load_model(model_name, dataset='imagenet', threat_model='Linf').cuda()

  random_suffix = '_random'
  # Only strip the marker when it really is a suffix; a name that merely
  # contains "random" elsewhere used to lose its last characters.
  if model_name.endswith(random_suffix):
    return timm.create_model(model_name[:-len(random_suffix)], pretrained=False).cuda()
  return timm.create_model(model_name, pretrained=True).cuda()

In [121]:
# Checkpoint analysed in this run; the model is built on the CPU and then
# moved to the analysis device.
model_path = 'outputs/advtrain_swinb_orig/last.pth.tar'
model = load_model(model_path)
model = model.to(device)

# Other models that can be swapped in here:
# model_path = f'outputs/gradnorm_swinb_variant/2024-02-14_11-30-41/last.pth.tar'
# model = load_model(model_path).to(device)

# model_path = f'outputs/logitsobel_swinb/checkpoint-138.pth.tar'
# model = load_model(model_path).to(device)

# model = load_public_model('swin_base_patch4_window7_224').to(device)
# model = load_public_model('Liu2023Comprehensive_Swin-B').to(device)

In [122]:
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# Resolve the preprocessing settings from the model's pretrained config, then
# build the non-training transform pipeline from them.
data_config = resolve_data_config(model.pretrained_cfg, model=model)
val_transform = create_transform(is_training=False, **data_config)

In [123]:
# Last step of the validation pipeline. It is prepended to the model below and
# its .mean / .std are read when de-normalizing the stored inputs.
# NOTE(review): presumably the Normalize op — confirm against create_transform.
normalize_transform = val_transform.transforms[-1]

In [124]:
# Fold the input normalization into the model so it can be fed un-normalized
# images. The guard keeps the cell idempotent: re-running it must not stack a
# second normalization layer in front of the network.
already_wrapped = (
    isinstance(model, nn.Sequential)
    and len(model) > 0
    and model[0] is normalize_transform
)
if not already_wrapped:
    model = nn.Sequential(normalize_transform, model)

In [125]:
# ds = torchvision.datasets.ImageNet(
#     '/data/vision/torralba/datasets/imagenet_pytorch_new',
#     'val',
#     transform=val_transform,
# )

In [126]:
# Pull the input and label tensors out of the loaded dictionary.
xs, ys = data['xs'], data['ys']

In [127]:
# Negative values are taken as a sign that the stored inputs were already
# normalized; undo that here because the wrapped model normalizes by itself.
if xs.min() < 0:
  channel_std = normalize_transform.std[None, :, None, None]
  channel_mean = normalize_transform.mean[None, :, None, None]
  xs = xs * channel_std + channel_mean

In [128]:
# Number of examples to analyse; lower N to work on a prefix of the data.
N = xs.size(0)
xs, ys = xs[:N], ys[:N]
# Iterate in storage order so gradients line up with xs by index.
sampler_indices = range(N)
ds = torch.utils.data.TensorDataset(xs, ys)

In [129]:
# Sanity check: number of examples that will be visited by the loader.
print(len(sampler_indices))

10000


In [130]:
from tqdm.auto import tqdm

# Fixed-order loader: the sampler is the plain index range, so batches follow
# the order of the tensors in `ds`.
dataloader = torch.utils.data.DataLoader(dataset=ds, sampler=sampler_indices, batch_size=32)

In [131]:
# Reset the per-batch scratch variables used by the gradient loop.
out = logit = grad_x = None

In [132]:
outs = []
grads = []

for ii, (x, y) in enumerate(tqdm(dataloader)):
    x, y = x.to(device), y.to(device)
    x.requires_grad_(True)

    # Forward pass on the clean inputs.
    out = model(x)
    # NOTE: cross_entropy averages over the batch by default, so each
    # per-sample gradient carries a 1/batch_size factor (a smaller last batch
    # is scaled differently). This only shifts log-magnitudes per batch.
    loss = torch.nn.functional.cross_entropy(out, y)
    # Input-gradient magnitude: absolute value summed over dim 1, kept on CPU.
    grad_x = torch.autograd.grad(loss, x, create_graph=False, retain_graph=False)[0].detach().cpu().abs().sum(1, keepdim=True)
    # Alternative saliency (gradient of the true-class logit, max over dim 1):
    # logit = out[torch.arange(out.size(0)), y].sum()
    # grad_x = torch.autograd.grad(logit, x, create_graph=False, retain_graph=False)[0].detach().cpu().abs().max(1, keepdim=True).values

    # Store detached CPU copies; keeping `out` itself would hold every batch's
    # logits on the GPU with the autograd graph still attached.
    outs.append(out.detach().cpu())
    grads.append(grad_x)
    # Debug: uncomment to stop after a few batches.
    # if ii == 2:
    #     break

100%|██████████| 313/313 [01:04<00:00,  4.86it/s]


In [133]:
# Merge the per-batch lists into single tensors along the example dimension.
outs = torch.cat(outs, dim=0)
grads = torch.cat(grads, dim=0)

In [134]:
from gradient_teachers import ContourEnergy
import torch.nn.functional as F

# Edge-energy operator applied to the input images below.
# NOTE(review): the meaning of the two constructor arguments (1., 3) is defined
# in gradient_teachers.ContourEnergy and is not visible here — confirm there.
sobel = ContourEnergy(1., 3)

def laplacian(x):
  """Absolute 4-neighbour Laplacian response, averaged over channels.

  Args:
    x: tensor of shape (N, C, H, W).

  Returns:
    Tensor of shape (N, C, H, W): each channel is filtered separately with
    reflect padding, the absolute responses are averaged over channels, and
    that single map is expanded back to C (identical) channels.
  """
  # Build the kernel with the input's dtype and device; a default CPU float32
  # kernel fails for CUDA or float64 inputs.
  kernel = torch.tensor(
      [[0., -1., 0.],
       [-1., 4., -1.],
       [0., -1., 0.]],
      dtype=x.dtype, device=x.device,
  )

  _, C, _, _ = x.shape
  pad = (kernel.size(0) - 1) // 2
  x = F.pad(x, (pad,) * 4, "reflect")
  # groups=C applies the same 3x3 kernel to every channel independently.
  weight = kernel[None, None, :, :].expand((C, -1, -1, -1))
  x = F.conv2d(x, weight, padding='valid', groups=C)
  return x.abs().mean(1, keepdim=True).expand((-1, C, -1, -1))

In [135]:
# Edge map for exactly those inputs that have a stored gradient.
edges_xs = sobel(xs[:len(grads)])
# Alternative edge operator:
# edges_xs = laplacian(xs[:grads.size(0)])
edges_xs.shape

torch.Size([10000, 3, 224, 224])

In [136]:
# Shape check: one gradient-magnitude map per example.
grads.shape

torch.Size([10000, 1, 224, 224])

In [145]:
# Per-image log edge energy (averaged over dim 1, floored at 1e-3 before the
# log) and log gradient magnitude, each flattened to one row per image.
# NOTE(review): v has no floor, so an exactly-zero gradient would give -inf.
u = torch.log(torch.clamp(edges_xs.mean(dim=1, keepdim=True).flatten(start_dim=1), min=1e-3))
v = torch.log(grads.flatten(start_dim=1))
print(u.shape, v.shape)

torch.Size([10000, 50176]) torch.Size([10000, 50176])


In [146]:
# Pearson correlation between log edge energy and log gradient magnitude,
# computed separately for every image.
corrs = np.array([
    torch.corrcoef(torch.stack((ui, vi)))[0, 1]
    for ui, vi in zip(u, v)
])
corrs

array([0.6910233, 0.6488463, 0.5819105, ..., 0.8243225, 0.771029 ,
       0.6496512], dtype=float32)

In [147]:
# Summary of the per-image correlations: count, mean, std, min, max.
corrs.shape, corrs.mean(), corrs.std(), corrs.min(), corrs.max()

((10000,), 0.56919587, 0.15685329, -0.1884277, 0.9261033)

In [148]:
# Deliberate stop for "Run All": the cells below draw one figure per image and
# are only meant for small subsets.
assert xs.size(0) <= 100, (
    f"refusing to plot {xs.size(0)} images; reduce N to at most 100 "
    "before running the per-image plotting cells"
)

AssertionError: 

In [None]:
# Show input, edge map and gradient magnitude for one example; the two signed
# or unbounded maps are normalized for display, the input is shown as-is.
idx = 0
grad_map = grads[idx].abs().mean(0, keepdim=True)
plot_side_by_side_normalize(
    xs[idx],
    edges_xs[idx],
    grad_map,
    normalize=(False, True, True),
)

In [None]:
# Log edge energy vs. log gradient magnitude for the selected image.
# The edge energy is floored at 1e-3 before the log, matching the batch
# computation above; without the floor, flat regions give -inf and both the
# correlation and the line fit turn into NaN.
u = edges_xs[idx].flatten().clamp(min=1e-3).log()
v = grads[idx].abs().mean(0, keepdim=True).expand((3, -1, -1)).flatten().log()
print(u.shape, v.shape)

corr = torch.corrcoef(torch.stack([u, v]))[0, 1]
print(corr)

# Hand plain NumPy arrays to seaborn / polyfit.
u_np = u.detach().cpu().numpy()
v_np = v.detach().cpu().numpy()

ax = sns.scatterplot(x=u_np, y=v_np)
ax.set(xlabel=r'$\log(|g_x * x|^2 + |g_y * x|^2)$', ylabel=r'$log(|\nabla_xf(x;\theta)_y|)$')

# Least-squares line drawn over the observed range of u rather than a
# hard-coded interval.
x = np.linspace(u_np.min(), u_np.max(), num=100)
m, b = np.polyfit(u_np, v_np, 1)
ax.plot(x, m*x + b, color='blue')

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from flexitext import flexitext

from matplotlib import lines
from matplotlib import patches
from matplotlib.patheffects import withStroke

# Colour palette for the styled figure.
BROWN = "#AD8C97"
BROWN_DARKER = "#7d3a46"
GREEN = "#2FC1D3"
BLUE = "#076FA1"
GREY = "#C7C9CB"
GREY_DARKER = "#5C5B5D"
RED = "#E3120B"
BLACK = "#000000"

COLORS = [RED, BLUE]

# Initialize plot ------------------------------------------
fig, ax = plt.subplots(figsize=(6, 6))

# Customize axis -------------------------------------------
# Make gridlines be below most artists.
ax.set_axisbelow(True)

# Keep all four spines visible.
ax.spines["right"].set_visible(True)
ax.spines["top"].set_visible(True)
ax.spines["left"].set_visible(True)
ax.spines["bottom"].set_visible(True)

# Customize bottom spine
ax.spines["bottom"].set_lw(1.2)
ax.spines["bottom"].set_capstyle("butt")

# Axis titles are placed as figure-level text. "sans-serif" is the generic
# family name Matplotlib recognises; "sans serif" is not, so it only triggers
# a font-lookup fallback.
x_axis_title = r'$\log(|g_x * x|^2 + |g_y * x|^2)$'
fig.text(
    0.4, 0.04, x_axis_title, color=BLACK,
    fontsize=12, fontfamily="sans-serif"
)

y_axis_title=r'$log(|\nabla_xf(x;\theta)_y|)$'
fig.text(
    0.02, 0.4, y_axis_title, color=BLACK,
    fontsize=12, fontfamily="sans-serif", rotation=90,
)

# Least-squares line through the (u, v) points computed in the previous cell.
x = np.linspace(-6, 0, num=100)
m, b = np.polyfit(u, v, 1)
ax.plot(x, m*x + b, color='black', zorder=1)

# Scatter of the individual pixels; zorder=0 keeps the dots underneath the
# fitted line (zorder=1).
ax.scatter(u, v, fc=RED, s=0.001, lw=1.5, ec="red", marker='o', zorder=0)

# Format the correlation as a number; interpolating the 0-dim tensor directly
# would print its "tensor(...)" repr in the title.
ax.set_title(f'R={float(corr):.3f}')

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from flexitext import flexitext

from matplotlib import lines
from matplotlib import patches
from matplotlib.patheffects import withStroke

def save_plot(gradient, input, path):
    """Save a scatter plot of log edge energy vs. log gradient magnitude.

    Args:
        gradient: input-gradient tensor for one image; its absolute value is
            averaged over dim 0 and broadcast to the edge map's channels.
        input: the image tensor; a batch dimension is added before it is
            passed to the module-level ``sobel`` operator.
        path: destination passed to ``fig.savefig``.

    The correlation in the title and the fitted line are computed from this
    image's own points. The figure is closed after saving so that calling
    this in a loop does not accumulate open figures.
    """
    edges_input = sobel(input[None])[0]
    # Floor the edge energy before the log (same 1e-3 floor as the batch
    # analysis) so flat regions do not become -inf.
    u = edges_input.flatten().clamp(min=1e-3).log()
    v = gradient.abs().mean(0, keepdim=True).expand_as(edges_input).flatten().log()

    # Drop any remaining non-finite pairs (e.g. exactly-zero gradients), which
    # would otherwise poison both the correlation and the line fit.
    finite = torch.isfinite(u) & torch.isfinite(v)
    u = u[finite].detach().cpu().numpy()
    v = v[finite].detach().cpu().numpy()

    RED = "#E3120B"
    BLACK = "#000000"

    # Initialize plot ------------------------------------------
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        # Make gridlines be below most artists.
        ax.set_axisbelow(True)

        # Keep all four spines visible.
        for side in ("right", "top", "left", "bottom"):
            ax.spines[side].set_visible(True)

        # Customize bottom spine
        ax.spines["bottom"].set_lw(1.2)
        ax.spines["bottom"].set_capstyle("butt")

        # "sans-serif" is the generic family name Matplotlib recognises.
        x_axis_title = r'$\log(|g_x * x|^2 + |g_y * x|^2)$'
        fig.text(
            0.4, 0.04, x_axis_title, color=BLACK,
            fontsize=12, fontfamily="sans-serif"
        )

        y_axis_title = r'$log(|\nabla_xf(x;\theta)_y|)$'
        fig.text(
            0.02, 0.4, y_axis_title, color=BLACK,
            fontsize=12, fontfamily="sans-serif", rotation=90,
        )

        if u.size >= 2:
            # Correlation and fit for THIS image (not a leftover global).
            corr = float(np.corrcoef(u, v)[0, 1])
            x = np.linspace(round(float(u.min()) - 0.5), round(float(u.max())), num=100)
            m, b = np.polyfit(u, v, 1)
            ax.plot(x, m*x + b, color='black', zorder=1)
            title = f'R={corr:.3f}'
        else:
            # Not enough finite points to fit a line or define a correlation.
            title = 'R=n/a'

        # zorder=0 keeps the dots underneath the fitted line (zorder=1).
        ax.scatter(u, v, fc=RED, s=0.001, lw=1.5, ec="red", marker='o', zorder=0)

        ax.set_title(title)

        fig.savefig(path)
    finally:
        plt.close(fig)

In [None]:
# Save one figure per image. Each image gets its own file name; writing every
# plot to the same path would keep only the last one. zip() stops at the
# shorter of grads / xs, so a truncated gradient run is handled as well.
for i, (grad_i, x_i) in enumerate(zip(grads, xs)):
  save_plot(grad_i, x_i, f'test_{i:05d}.png')