In [1]:
import argparse
import os

import numpy as np
import torch as th
import torch.distributed as dist
import torch.nn.functional as F
from collections import OrderedDict

from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    CNN_defaults,
    create_model_and_diffusion,
    create_CNN,
    add_dict_to_argparser,
    args_to_dict,
)

In [2]:
def create_argparser():
    defaults = dict(
        clip_denoised=True,
        num_samples=10000,
        batch_size=16,
        use_ddim=False,
        model_path="",
        classifier_path="",
        classifier_scale=1.0,
        doc = ""
    )
    defaults.update(model_and_diffusion_defaults())
    defaults.update(CNN_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser

In [27]:
# args = create_argparser().parse_args()
args = create_argparser().parse_args(['--model_path', '/data/yjpak/guided-diffusion/logs/diffusion01/model045000.pt',
                                      '--classifier_path', '/data/yjpak/guided-diffusion/logs/CNN_classifier02/classifier_model065000.pt'
                                     ])
dist_util.setup_dist()
log_dir="/data/yjpak/guided-diffusion/logs"
logger.configure(dir=os.path.join(log_dir, args.doc))
# logger.configure()

logger.log("creating model and diffusion...")
model, diffusion = create_model_and_diffusion(
    **args_to_dict(args, model_and_diffusion_defaults().keys())
)
model.load_state_dict(
    dist_util.load_state_dict(args.model_path, map_location="cpu")
)
model.to(dist_util.dev())
if args.use_fp16:
    model.convert_to_fp16()
model.eval()

setup_dist start
Logging to /data/yjpak/guided-diffusion/logs/
creating model and diffusion...


UNetModel(
  (time_embed): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): SiLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
  )
  (label_emb): Embedding(10, 512)
  (input_blocks): ModuleList(
    (0): TimestepEmbedSequential(
      (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1-2): 2 x TimestepEmbedSequential(
      (0): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (h_upd): Identity()
        (x_upd): Identity()
        (emb_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=512, out_features=256, bias=True)
        )
        (out_layers): Sequential(
          (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Dropout(p=0.0, inplace=False)
          (3): C

In [28]:
args

Namespace(clip_denoised=True, num_samples=10000, batch_size=16, use_ddim=False, model_path='/data/yjpak/guided-diffusion/logs/diffusion01/model045000.pt', classifier_path='/data/yjpak/guided-diffusion/logs/CNN_classifier02/classifier_model065000.pt', classifier_scale=1.0, doc='', image_size=80, num_channels=128, num_res_blocks=2, num_heads=4, num_heads_upsample=-1, num_head_channels=-1, attention_resolutions='16,8', channel_mult='', dropout=0.0, class_cond=True, use_checkpoint=False, use_scale_shift_norm=True, resblock_updown=False, use_fp16=False, use_new_attention_order=False, learn_sigma=False, diffusion_steps=1000, noise_schedule='linear', timestep_respacing='', use_kl=False, predict_xstart=False, rescale_timesteps=False, rescale_learned_sigmas=False, classifier_use_fp16=False, input_channels=1, num_classes=10)

In [29]:
classifier = create_CNN(**args_to_dict(args, CNN_defaults().keys()))

In [30]:
classifier

CNN_2D(
  (layer1_conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (layer1_bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1_conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (layer1_bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer2_conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (layer2_bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer2_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (layer2_conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (layer2_bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer3_conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (layer3_bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [31]:
model_state_dict = th.load(args.classifier_path)
model_state_dict

OrderedDict([('module.layer1_conv1.weight',
              Parameter containing:
              tensor([[[[-0.2334, -0.0625,  0.2735],
                        [-0.1974,  0.0436,  0.3831],
                        [-0.0957,  0.2618,  0.8061]]],
              
              
                      [[[ 0.3001,  0.2927,  0.3141],
                        [ 0.2917,  0.2931,  0.3247],
                        [ 0.3013,  0.2982,  0.3497]]],
              
              
                      [[[ 0.2620,  0.2600,  0.3262],
                        [ 0.2818,  0.2869,  0.3640],
                        [ 0.2788,  0.3036,  0.3963]]],
              
              
                      [[[ 0.0023, -0.1310, -0.1273],
                        [-0.0075, -0.0825,  0.2194],
                        [ 0.6314,  0.4004,  0.7965]]],
              
              
                      [[[-0.2209, -0.1791,  0.1665],
                        [-0.1433,  0.0375,  0.4468],
                        [ 0.1704,  0.3654,  0.9158

In [32]:
logger.log("loading classifier...")
classifier = create_CNN(**args_to_dict(args, CNN_defaults().keys()))

loaded_state_dict = th.load(args.classifier_path)
new_state_dict = OrderedDict()
for n, v in loaded_state_dict.items():
    name = n.replace("module.","") # .module이 중간에 포함된 형태라면 (".module","")로 치환
    new_state_dict[name] = v

classifier.load_state_dict(new_state_dict)

classifier.to(dist_util.dev())
if args.classifier_use_fp16:
    classifier.convert_to_fp16()
classifier.eval()

loading classifier...


CNN_2D(
  (layer1_conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (layer1_bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1_conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (layer1_bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer2_conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (layer2_bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer2_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (layer2_conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (layer2_bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer3_conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (layer3_bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [25]:
def cond_fn(x, t, y=None):
    assert y is not None
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t)
        log_probs = F.log_softmax(logits, dim=-1)
        selected = log_probs[range(len(logits)), y.view(-1)]
        return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale

def model_fn(x, t, y=None):
    assert y is not None
    return model(x, t, y if args.class_cond else None)

logger.log("sampling...")
all_images = []
all_labels = []
while len(all_images) * args.batch_size < args.num_samples:
    model_kwargs = {}
    # classes = th.randint(
    #     low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
    # )
    classes = th.full((args.batch_size,), 5, device=dist_util.dev())
    model_kwargs["y"] = classes
    sample_fn = (
        diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
    )
    sample = sample_fn(
        model_fn,
        (args.batch_size, 1, args.image_size, args.image_size),
        clip_denoised=args.clip_denoised,
        model_kwargs=model_kwargs,
        cond_fn=cond_fn,
        device=dist_util.dev(),
    )

    sample = sample.contiguous()

    gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered_samples, sample)  # gather not supported with NCCL
    all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
    gathered_labels = [th.zeros_like(classes) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered_labels, classes)
    all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
    logger.log(f"created {len(all_images) * args.batch_size} samples")

arr = np.concatenate(all_images, axis=0)
arr = arr[: args.num_samples]
label_arr = np.concatenate(all_labels, axis=0)
label_arr = label_arr[: args.num_samples]
if dist.get_rank() == 0:
    shape_str = "x".join([str(x) for x in arr.shape])
    out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")
    logger.log(f"saving to {out_path}")
    np.savez(out_path, arr, label_arr)

dist.barrier()
logger.log("sampling complete")

sampling...


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7fc8e8c91ff0>>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


KeyboardInterrupt: 