In [1]:
import argparse
import os

import blobfile as bf
import torch as th
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW

from guided_diffusion import dist_util, logger
from guided_diffusion.fp16_util import MixedPrecisionTrainer
from guided_diffusion.image_datasets import load_data
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
    add_dict_to_argparser,
    args_to_dict,
    classifier_and_diffusion_defaults,
    create_classifier_and_diffusion,
)
from guided_diffusion.train_util import parse_resume_step_from_filename, log_loss_dict


In [2]:
import collections
import copy
import sys
import time
from random import seed

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation
from torch import optim
from guided_diffusion.fp16_util import MixedPrecisionTrainer
import dataset
import evaluation
from GaussianDiffusion import GaussianDiffusionModel, get_beta_schedule
from helpers import *
from UNet import UNetModel, update_ema_params
from script_util import (
    add_dict_to_argparser,
    args_to_dict,
    classifier_and_diffusion_defaults,
    create_classifier_and_diffusion,
)
torch.cuda.empty_cache()

# from train_util import  log_loss_dict

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
args_dict = {
  "image_size": 256,
  "Batch_Size": 16,
  "batch_size": 16,
  "channels": 3,
  "EPOCHS": 4000,
  "iterations":40000,
  "diffusion_steps": 1000,
  "num_channels": 192,\
  "learn_sigma":True,
  "val_data": True,
  "eval_interval": 1000,
  "save_interval": 1000,
  "log_interval": 10,
  "anneal_lr": True,
  "microbatch": -1,
  "use_kl":False,
  "schedule_sampler": "uniform",
  "resume_checkpoint": False,
  "predict_xstart":False,
  "rescale_timesteps":False,
  "rescale_learned_sigmas":False,
  "timestep_respacing": "",
  "noise_schedule": "cosine",
  "channel_mults": "",
  "loss-type": "l2",
  "loss_weight": "none",
  "train_start": True,
  "lr": 1e-4,
  "learn_sigma ": True,
  "random_slice": True,
  "sample_distance": 800,
  "weight_decay": 0.0,
  "save_imgs": True,
  "save_vids": True,
  "class_cond ": True,
  "use_fp16": True,
  "use_scale_shift_norm": True,
  "dropout": 0.1,
  "attention_resolutions": "32,16,8",
  "num_res_blocks": 3,
  "resblock_updown": True,
  "classifier_use_fp16":False,
  "classifier_width":128,
  "classifier_depth": 2,
  "classifier_attention_resolutions": "32,16,8",
  "classifier_use_scale_shift_norm": True,
  "classifier_resblock_updown": True,
  "classifier_pool": "attention" ,
  "num_head_channels": 64,
  "noised": True,
  "noise_fn": "gauss",
  "dataset": "mura"
}

In [5]:
args = argparse.Namespace(**args_dict)

In [6]:
# args.

In [1]:
from tqdm import tqdm

In [8]:
ROOT_DIR= './'

In [9]:
def main():
#     args = create_argparser().parse_args()

#     dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    model, diffusion = create_classifier_and_diffusion(args_dict
    )
    model.to(device)
    if args.noised:
        schedule_sampler = create_named_schedule_sampler(
            args.schedule_sampler, diffusion
        )

    resume_step = 0
    if args.resume_checkpoint:
        resume_step = parse_resume_step_from_filename(args.resume_checkpoint)
#         if dist.get_rank() == 0:
        logger.log(
            f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step"
        )
        model.load_state_dict(
            dist_util.load_state_dict(
                args.resume_checkpoint, map_location=device
            )
        )

    # Needed for creating correct EMAs and fp16 parameters.
#     dist_util.sync_params(model.parameters())

    mp_trainer = MixedPrecisionTrainer(
        model=model, use_fp16=args.classifier_use_fp16, initial_lg_loss_scale=16.0
    )

#     model = DDP(
#         model,
#         device_ids=[dist_util.dev()],
#         output_device=device(),
#         broadcast_buffers=False,
#         bucket_cap_mb=128,
#         find_unused_parameters=False,
#     )

    logger.log("creating data loader...")
    training_dataset, testing_dataset = dataset.init_datasets(ROOT_DIR, args_dict)
    data = dataset.init_dataset_loader(training_dataset, args_dict)
    if args.val_data:
        val_data = dataset.init_dataset_loader(testing_dataset, args_dict)
    else:
        val_data = None

    logger.log(f"creating optimizer...")
    opt = AdamW(mp_trainer.master_params, lr=args.lr, weight_decay=args.weight_decay)
    if args.resume_checkpoint:
        opt_checkpoint = bf.join(
            bf.dirname(args.resume_checkpoint), f"opt{resume_step:06}.pt"
        )
        logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
        opt.load_state_dict(
            dist_util.load_state_dict(opt_checkpoint, map_location=device)
        )

    logger.log("training classifier model...")

    def forward_backward_log(data_loader, prefix="train"):
        extra = next(data_loader)
        batch = extra["image"]
        labels = extra["label"].to(device)

        batch = batch        extra = next(data_loader)
        batch = extra["image"]
        labels = extra["label"].to(device)

        # Noisy images
        if args.noised:
            t, _ = schedule_sampler.sample(batch.shape[0], device)
            batch = diffusion.q_sample(batch, t)
        else:
            t = th.zeros(batch.shape[0], dtype=th.long, device=device)

        for i, (sub_batch, sub_labels, sub_t) in enumerate(
            split_microbatches(args.microbatch, batch, labels, t)
        ):
            logits = model(sub_batch, timesteps=sub_t)
            loss = F.cross_entropy(logits, sub_labels, reduction="none")

            losses = {}
            losses[f"{prefix}_loss"] = loss.detach()
            losses[f"{prefix}_acc@1"] = compute_top_k(
                logits, sub_labels, k=1, reduction="none"
            )
            losses[f"{prefix}_acc@5"] = compute_top_k(
                logits, sub_labels, k=5, reduction="none"
            )
            log_loss_dict(diffusion, sub_t, losses)
            del losses
            loss = loss.mean()
            if loss.requires_grad:
                if i == 0:
                    mp_trainer.zero_grad()
                mp_trainer.backward(loss * len(sub_batch) / len(batch))

    for step in tqdm(range(args.iterations - resume_step)):
        logger.logkv("step", step + resume_step)
        logger.logkv(
            "samples",
            (step + resume_step + 1) * args.batch_size,
        )
        if args.anneal_lr:
            set_annealed_lr(opt, args.lr, (step + resume_step) / args.iterations)
        forward_backward_log(data)
        mp_trainer.optimize(opt)
        if val_data is not None and not step % args.eval_interval:
            with th.no_grad():
                model.eval()
                forward_backward_log(val_data, prefix="val")
                model.train()
        if not step % args.log_interval:
            logger.dumpkvs()
        if (step + resume_step) % args.save_interval:
            logger.log("saving model...")
            save_model(mp_trainer, opt, step + resume_step)

#     if dist.get_rank() == 0:
    logger.log("saving model...")
    save_model(mp_trainer, opt, step + resume_step)
#     dist.barrier()


def set_annealed_lr(opt, base_lr, frac_done):
    lr = base_lr * (1 - frac_done)
    for param_group in opt.param_groups:
        param_group["lr"] = lr


def save_model(mp_trainer, opt, step):
    th.save(
        mp_trainer.master_params_to_state_dict(mp_trainer.master_params),
        os.path.join(logger.get_dir(), f"model{step:06d}.pt"),
    )
    th.save(opt.state_dict(), os.path.join(logger.get_dir(), f"opt{step:06d}.pt"))


def compute_top_k(logits, labels, k, reduction="mean"):
    _, top_ks = th.topk(logits, k, dim=-1)
    if reduction == "mean":
        return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
    elif reduction == "none":
        return (top_ks == labels[:, None]).float().sum(dim=-1)


def split_microbatches(microbatch, *args):
    bs = len(args[0])
    if microbatch == -1 or microbatch >= bs:
        yield tuple(args)
    else:
        for i in range(0, bs, microbatch):
            yield tuple(x[i : i + microbatch] if x is not None else None for x in args)


def create_argparser():
    defaults = dict(
        data_dir="",
        val_data_dir="",
        noised=True,
        iterations=150000,
        lr=3e-4,
        weight_decay=0.0,
        anneal_lr=False,
        batch_size=4,
        microbatch=-1,
        schedule_sampler="uniform",
        resume_checkpoint="",
        log_interval=10,
        eval_interval=5,
        save_interval=10000,
    )
    defaults.update(classifier_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()

Logging to C:\Users\Admin\AppData\Local\Temp\openai-2023-02-15-20-57-48-717999
creating model and diffusion...
creating data loader...
creating optimizer...
training classifier model...




-----------------------------
| grad_norm      | 145      |
| param_norm     | 160      |
| samples        | 16       |
| step           | 0        |
| train_acc@1    | 0        |
| train_acc@1_q0 | 0        |
| train_acc@1_q1 | 0        |
| train_acc@1_q2 | 0        |
| train_acc@1_q3 | 0        |
| train_acc@5    | 0        |
| train_acc@5_q0 | 0        |
| train_acc@5_q1 | 0        |
| train_acc@5_q2 | 0        |
| train_acc@5_q3 | 0        |
| train_loss     | 6.78     |
| train_loss_q0  | 6.75     |
| train_loss_q1  | 6.77     |
| train_loss_q2  | 6.79     |
| train_loss_q3  | 6.84     |
| val_acc@1      | 0.562    |
| val_acc@1_q0   | 0.75     |
| val_acc@1_q1   | 1        |
| val_acc@1_q2   | 0.2      |
| val_acc@1_q3   | 0.5      |
| val_acc@5      | 1        |
| val_acc@5_q0   | 1        |
| val_acc@5_q1   | 1        |
| val_acc@5_q2   | 1        |
| val_acc@5_q3   | 1        |
| val_loss       | 4.89     |
| val_loss_q0    | 4.65     |
| val_loss_q1    | 4.23     |
| val_loss

saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
-----------------------------
| grad_norm      | 6.92     |
| param_norm     | 160      |
| samples        | 1.14e+03 |
| step           | 70       |
| train_acc@1    | 0.613    |
| train_acc@1_q0 | 0.667    |
| train_acc@1_q1 | 0.556    |
| train_acc@1_q2 | 0.571    |
| train_acc@1_q3 | 0.647    |
| train_acc@5    | 1        |
| train_acc@5_q0 | 1        |
| train_acc@5_q1 | 1        |
| train_acc@5_q2 | 1        |
| train_acc@5_q3 | 1        |
| train_loss     | 0.707    |
| train_loss_q0  | 0.666    |
| train_loss_q1  | 0.763    |
| train_loss_q2  | 0.738    |
| train_loss_q3  | 0.669    |
| val_acc@1      | 0.531    |
| val_acc@1_q0   | 0.5      |
| val_acc@1_q1   | 0.455    |
| val_acc@1_q2   | 0.455    |
| val_acc@1_q3   | 1        |
| val_acc@5      | 1        |
| val_acc@5_q0   | 1        |
| val_acc@5_q1   | 1        |


saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
-----------------------------
| grad_norm      | 5.42     |
| param_norm     | 160      |
| samples        | 2.26e+03 |
| step           | 140      |
| train_acc@1    | 0.6      |
| train_acc@1_q0 | 0.55     |
| train_acc@1_q1 | 0.625    |
| train_acc@1_q2 | 0.615    |
| train_acc@1_q3 | 0.609    |
| train_acc@5    | 1        |
| train_acc@5_q0 | 1        |
| train_acc@5_q1 | 1        |
| train_acc@5_q2 | 1        |
| train_acc@5_q3 | 1        |
| train_loss     | 0.706    |
| train_loss_q0  | 0.772    |
| train_loss_q1  | 0.669    |
| train_loss_q2  | 0.668    |
| train_loss_q3  | 0.708    |
| val_acc@1      | 0.406    |
| val_acc@1_q0   | 0.455    |
| val_acc@1_q1   | 0.3      |
| val_acc@1_q2   | 0.5      |
| val_acc@1_q3   | 0.429    |
| val_acc@5      | 1        |
| val_acc@5_q0   | 1        |
| val_acc@5_q1   | 1        |


saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
-----------------------------
| grad_norm      | 6.56     |
| param_norm     | 160      |
| samples        | 3.38e+03 |
| step           | 210      |
| train_acc@1    | 0.487    |
| train_acc@1_q0 | 0.436    |
| train_acc@1_q1 | 0.487    |
| train_acc@1_q2 | 0.625    |
| train_acc@1_q3 | 0.405    |
| train_acc@5    | 1        |
| train_acc@5_q0 | 1        |
| train_acc@5_q1 | 1        |
| train_acc@5_q2 | 1        |
| train_acc@5_q3 | 1        |
| train_loss     | 0.713    |
| train_loss_q0  | 0.714    |
| train_loss_q1  | 0.714    |
| train_loss_q2  | 0.681    |
| train_loss_q3  | 0.743    |
| val_acc@1      | 0.438    |
| val_acc@1_q0   | 0.167    |
| val_acc@1_q1   | 0.429    |
| val_acc@1_q2   | 0.417    |
| val_acc@1_q3   | 0.714    |
| val_acc@5      | 1        |
| val_acc@5_q0   | 1        |
| val_acc@5_q1   | 1        |


saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
saving model...
-----------------------------
| grad_norm      | 6.11     |
| param_norm     | 160      |
| samples        | 4.5e+03  |
| step           | 280      |
| train_acc@1    | 0.612    |
| train_acc@1_q0 | 0.703    |
| train_acc@1_q1 | 0.564    |
| train_acc@1_q2 | 0.548    |
| train_acc@1_q3 | 0.643    |
| train_acc@5    | 1        |
| train_acc@5_q0 | 1        |
| train_acc@5_q1 | 1        |
| train_acc@5_q2 | 1        |
| train_acc@5_q3 | 1        |
| train_loss     | 0.676    |
| train_loss_q0  | 0.586    |
| train_loss_q1  | 0.708    |
| train_loss_q2  | 0.737    |
| train_loss_q3  | 0.667    |
| val_acc@1      | 0.438    |
| val_acc@1_q0   | 0.5      |
| val_acc@1_q1   | 0.571    |
| val_acc@1_q2   | 0.3      |
| val_acc@1_q3   | 0.429    |
| val_acc@5      | 1        |
| val_acc@5_q0   | 1        |
| val_acc@5_q1   | 1        |


RuntimeError: [enforce fail at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\caffe2\serialize\inline_container.cc:325] . unexpected pos 128958528 vs 128958420