In [3]:
import os
import torch
import numpy as np
import argparse

from guided_diffusion.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    classifier_defaults,
    create_model_and_diffusion,
    create_classifier,
    add_dict_to_argparser,
    args_to_dict,
    ResNet18_32x32,
)


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def create_argparser():
    defaults = dict(
        clip_denoised=True,
        num_samples=10000,
        batch_size=16,
        use_ddim=False,
        use_ddjm=False,
        model_path="",
        log_dir="tmp",
        classifier_path="",
        guide_mode="None",
        classifier_scale=0.0,
        positive_label="None",
        progress=False,
        eta=0.0,
        ref_batch=None,
        test_classifier_path="",
        model_id=None,
        iteration=10,
        shrink_cond_x0=True,    # whether to shrink the score of x0 by at
        faceid_loss_type='cosine',
        face_image1_id='00000',
        face_image2_id='00000',
        face_image3_id='00000',
        plot_args=False,
        score_norm=1e09,
        plot_traj=False,
    )
    defaults.update(model_and_diffusion_defaults())
    defaults.update(classifier_defaults())
    
    return defaults
    # parser = argparse.ArgumentParser()
    # add_dict_to_argparser(parser, defaults)
    # return parser

args = create_argparser()
# change dict to args, write your own code
class Args:
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

args = Args(**args)
args.image_size = 32

In [5]:
npath = './ckpts/cifar_classifier/not_best_epoch96_acc0.9470.ckpt'
nmodel = ResNet18_32x32()
# load state dict
nmodel.load_state_dict(torch.load(npath))


class TModel(torch.nn.Module):
    def __init__(self, model):
        super(TModel, self).__init__()
        self.model = model
    
    def forward(self, x):
        return self.model(x, torch.tensor([0] * x.shape[0], device=x.device))

tpath = './ckpts/cifar_classifier/model099999.pt'
tmodel = create_classifier(**args_to_dict(args, classifier_defaults().keys()))
# load state dict
tmodel.load_state_dict(torch.load(tpath,))
tmodel = TModel(tmodel)

nmodel.eval()
tmodel.eval()
# to cuda
nmodel.cuda()
tmodel.cuda()


TModel(
  (model): EncoderUNetModel(
    (time_embed): Sequential(
      (0): Linear(in_features=128, out_features=512, bias=True)
      (1): SiLU()
      (2): Linear(in_features=512, out_features=512, bias=True)
    )
    (input_blocks): ModuleList(
      (0): TimestepEmbedSequential(
        (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (1): TimestepEmbedSequential(
        (0): ResBlock(
          (in_layers): Sequential(
            (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (h_upd): Identity()
          (x_upd): Identity()
          (emb_layers): Sequential(
            (0): SiLU()
            (1): Linear(in_features=512, out_features=256, bias=True)
          )
          (out_layers): Sequential(
            (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Dropout(

In [14]:
# load data
label = 8
# data_path = f'./evaluations/ref/cifar_test_{label}.npz'
data_path = "./working/witht_classifier_cifar/steps=200+pipe=ddim+iter=1+mode=manifold+scale=20.0+shrink=True/samples_1024x32x32x3.npz"
try:
    data = np.load(data_path)
except:
    data_path = data_path.replace("samples_1024", "samples_64")
    data = np.load(data_path)
# data.shape
data = data['arr_0']
data = data.transpose(0, 3, 1, 2) / 127.5 - 1
data.min(), data.max()

(-1.0, 1.0)

In [15]:
x = torch.tensor(data, dtype=torch.float32)
x = x.cuda()
x.shape

torch.Size([64, 3, 32, 32])

In [16]:
btz = 64
print(
    (tmodel(x[:btz]).argmax(1) == label).float().mean()
)
print(
    (nmodel(x[:btz]).argmax(1) == label).float().mean()
)

tensor(1., device='cuda:0')
tensor(0.5938, device='cuda:0')


In [45]:
import torch.nn.functional as F

def get_grad(model, x, label):
    x_in = x.detach().requires_grad_(True)
    logits = model(x_in)
    log_probs = F.log_softmax(logits, dim=-1)
    selected = log_probs[:, label]
    return torch.autograd.grad(selected.sum(), x_in)[0]



In [16]:
ngrad = get_grad(nmodel, x[:btz], label)
tgrad = get_grad(tmodel, x[:btz], label)

ngrad.shape, tgrad.shape


(torch.Size([64, 3, 32, 32]), torch.Size([64, 3, 32, 32]))

In [17]:
# cosine similarity of two tensors, where the first dimension is the batch dimension
def cosine_similarity(a, b):
    a, b = a.reshape(a.shape[0], -1), b.reshape(b.shape[0], -1)

    return (a * b).sum(dim=-1) / (a.norm(dim=-1) * b.norm(dim=-1))

In [18]:
cosine_similarity(ngrad, tgrad).mean()

tensor(0.0084, device='cuda:0')

In [21]:
ngrad.reshape(ngrad.shape[0], -1).norm(dim=-1).mean(), tgrad.reshape(tgrad.shape[0], -1).norm(dim=-1).mean()

(tensor(1.9445, device='cuda:0'), tensor(0.7266, device='cuda:0'))