In [1]:
import argparse
import os
import glob
import numpy as np

import blobfile as bf
import torch as th
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW

from guided_diffusion import dist_util, logger
from guided_diffusion.fp16_util import MixedPrecisionTrainer
from guided_diffusion.image_datasets import load_data
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
    add_dict_to_argparser,
    args_to_dict,
    CNN_and_diffusion_defaults,
    create_CNN_and_diffusion,
)
from guided_diffusion.train_util import parse_resume_step_from_filename, log_loss_dict

In [2]:
device = "cuda" if th.cuda.is_available else "cpu"
print(device)

cuda


In [3]:
def create_argparser():
    defaults = dict(
        data_dir="",
        val_data_dir="",
        noised=True, #default: True
        iterations=150000,
        lr=3e-4,
        weight_decay=0.0,
        anneal_lr=False,
        batch_size=4,
        microbatch=-1,
        schedule_sampler="uniform",
        resume_checkpoint="",
        log_interval=10,
        eval_interval=5,
        save_interval=10000,
        doc=""
    )
    defaults.update(CNN_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser

In [4]:
# args = create_argparser().parse_args()

In [5]:
args = create_argparser().parse_args(['--data_dir', '/data/yjpak/Dataset/audio_mnist/',
                                      '--iterations', '100'])

In [6]:
args

Namespace(data_dir='/data/yjpak/Dataset/audio_mnist/', val_data_dir='', noised=True, iterations=100, lr=0.0003, weight_decay=0.0, anneal_lr=False, batch_size=4, microbatch=-1, schedule_sampler='uniform', resume_checkpoint='', log_interval=10, eval_interval=5, save_interval=10000, doc='', input_channels=1, num_classes=10, image_size=80, learn_sigma=False, diffusion_steps=1000, noise_schedule='linear', timestep_respacing='', use_kl=False, predict_xstart=False, rescale_timesteps=False, rescale_learned_sigmas=False)

In [13]:
dist_util.setup_dist()
dist_util.sync_params(model.parameters())

setup_dist start
os.environ[CUDA_VISIBLE_DEVICES]: 0
setup_dist


In [7]:
resume_step = 0
model, diffusion = create_CNN_and_diffusion(
    **args_to_dict(args, CNN_and_diffusion_defaults().keys())
)


data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=True,
    )

In [14]:
# model = model.to(device)

model = DDP(
        model,
        device_ids=[dist_util.dev()],
        output_device=dist_util.dev(),
        broadcast_buffers=False,
        bucket_cap_mb=128,
        find_unused_parameters=False,
    )

if args.noised:
    schedule_sampler = create_named_schedule_sampler(
        args.schedule_sampler, diffusion
    )

mp_trainer = MixedPrecisionTrainer(
        model=model, initial_lg_loss_scale=16.0
    )

In [15]:
def set_annealed_lr(opt, base_lr, frac_done):
    lr = base_lr * (1 - frac_done)
    for param_group in opt.param_groups:
        param_group["lr"] = lr

def split_microbatches(microbatch, *args):
    bs = len(args[0])
    if microbatch == -1 or microbatch >= bs:
        yield tuple(args)
    else:
        for i in range(0, bs, microbatch):
            yield tuple(x[i : i + microbatch] if x is not None else None for x in args)

def compute_top_k(logits, labels, k, reduction="mean"):
    _, top_ks = th.topk(logits, k, dim=-1)
    if reduction == "mean":
        return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
    elif reduction == "none":
        return (top_ks == labels[:, None]).float().sum(dim=-1)

In [17]:
logger.log(f"creating optimizer...")
opt = AdamW(mp_trainer.master_params, lr=args.lr, weight_decay=args.weight_decay)

logger.log("training classifier model...")

def forward_backward_log(data_loader, prefix="train"):
    batch, extra = next(data_loader)
    labels = extra["y"].to(dist_util.dev())
    batch = batch.to(dist_util.dev())
    
    # Noisy images
    if args.noised:
        t, _ = schedule_sampler.sample(batch.shape[0], dist_util.dev())
        print("noise t: ", t.shape)
        batch = diffusion.q_sample(batch, t)            
    else:
        t = th.zeros(batch.shape[0], dtype=th.long, device=dist_util.dev())

    for i, (sub_batch, sub_labels, sub_t) in enumerate(
        split_microbatches(args.microbatch, batch, labels, t)
    ):
        # np.save(f"/data/yjpak/guided-diffusion/logs/batch_check/sub_batch_{dist.get_rank()}", sub_batch.detach().cpu().numpy())
        # np.save(f"/data/yjpak/guided-diffusion/logs/batch_check/sub_labels_{dist.get_rank()}", sub_labels.detach().cpu().numpy())
        logits = model(sub_batch, timesteps=sub_t)
        loss = F.cross_entropy(logits, sub_labels, reduction="none")

        losses = {}
        losses[f"{prefix}_loss"] = loss.detach()
        losses[f"{prefix}_acc@1"] = compute_top_k(
            logits, sub_labels, k=1, reduction="none"
        )

        log_loss_dict(diffusion, sub_t, losses)
        del losses
        loss = loss.mean()
        if loss.requires_grad:
            if i == 0:
                mp_trainer.zero_grad()
            mp_trainer.backward(loss * len(sub_batch) / len(batch))

creating optimizer...
training classifier model...


In [19]:
for step in range(args.iterations - resume_step):
    logger.logkv("step", step + resume_step)
    logger.logkv(
        "samples",
        (step + resume_step + 1) * args.batch_size * dist.get_world_size(),
    )
    if args.anneal_lr:
        set_annealed_lr(opt, args.lr, (step + resume_step) / args.iterations)
        
    forward_backward_log(data)
    mp_trainer.optimize(opt)
    
    if not step % args.log_interval:
        logger.dumpkvs()
    if (
        step
        and dist.get_rank() == 0
        and not (step + resume_step) % args.save_interval
    ):
        logger.log("saving model...")
        save_model(mp_trainer, opt, step + resume_step)

if dist.get_rank() == 0:
    logger.log("saving model...")
    save_model(mp_trainer, opt, step + resume_step)
else:
    print("dist.get_rank = 1")
dist.barrier()

noise t:  torch.Size([4])
-----------------------------
| grad_norm      | 3.41     |
| param_norm     | 38.2     |
| samples        | 4        |
| step           | 0        |
| train_acc@1    | 0.0682   |
| train_acc@1_q0 | 0        |
| train_acc@1_q1 | 0        |
| train_acc@1_q2 | 0.2      |
| train_acc@1_q3 | 0.0625   |
| train_loss     | 2.33     |
| train_loss_q0  | 2.28     |
| train_loss_q1  | 2.34     |
| train_loss_q2  | 2.3      |
| train_loss_q3  | 2.38     |
-----------------------------
noise t:  torch.Size([4])
noise t:  torch.Size([4])
noise t:  torch.Size([4])
noise t:  torch.Size([4])
noise t:  torch.Size([4])
noise t:  torch.Size([4])
noise t:  torch.Size([4])
noise t:  torch.Size([4])
noise t:  torch.Size([4])
noise t:  torch.Size([4])
-----------------------------
| grad_norm      | 3.18     |
| param_norm     | 38.3     |
| samples        | 44       |
| step           | 10       |
| train_acc@1    | 0.075    |
| train_acc@1_q0 | 0        |
| train_acc@1_q1 | 0.076

In [11]:
model

CNN_2D(
  (layer1_conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (layer1_bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1_conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (layer1_bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer2_conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (layer2_bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer2_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (layer2_conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (layer2_bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer3_conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (layer3_bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [None]:
mpiexec -n 2 python model_analysis.py