In [1]:
"""
Train a noised image classifier on ImageNet.
"""

'\nTrain a noised image classifier on ImageNet.\n'

In [2]:


import argparse
import os
import sys
from torch.autograd import Variable
sys.path.append("..")
sys.path.append(".")
from guided_diffusion.bratsloader import BRATSDataset
import blobfile as bf
import torch as th
os.environ['OMP_NUM_THREADS'] = '8'
os.environ["PL_TORCH_DISTRIBUTED_BACKEND"] = "gloo"
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW
from visdom import Visdom
import numpy as np
# Visdom client (connects to a visdom server on port 8097) plus the three line
# plots used during training; each plot starts out with a single (0, 0) point.
viz = Visdom(port=8097)


def _empty_line_window(ylabel, title):
    """Create a Visdom line plot seeded with one (0, 0) point and return its window handle."""
    return viz.line(
        Y=th.zeros(1).cpu(),
        X=th.zeros(1).cpu(),
        opts=dict(xlabel='epoch', ylabel=ylabel, title=title),
    )


loss_window = _empty_line_window('Loss', 'classification loss')
val_window = _empty_line_window('Loss', 'validation loss')
acc_window = _empty_line_window('acc', 'accuracy')

from guided_diffusion import dist_util, logger
from guided_diffusion.fp16_util import MixedPrecisionTrainer
from guided_diffusion.image_datasets import load_data
from guided_diffusion.train_util import visualize
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
    add_dict_to_argparser,
    args_to_dict,
    classifier_and_diffusion_defaults,
    create_classifier_and_diffusion,
)
from guided_diffusion.train_util import parse_resume_step_from_filename, log_loss_dict



Setting up a new session...
Setting up a new session...


In [3]:
# Run configuration. The plain dict is handed to the project `dataset` helpers,
# and an argparse.Namespace built from it (`args`) is read by main(); keys must
# be valid identifiers (no spaces/hyphens) to be reachable as `args.<key>`.
args_dict = {
  "image_size": 256,
  "Batch_Size": 8,  # NOTE(review): presumably the key the `dataset` helpers read — confirm
  "batch_size": 8,  # used by main() for the "samples" log counter
  "channels": 3,
  "EPOCHS": 4000,
  "iterations": 40000,  # number of optimisation steps run by main()
  "diffusion_steps": 1000,
  "num_channels": 192,
  "learn_sigma": True,
  "val_data": True,  # build a validation loader and run the saliency pass
  "eval_interval": 1000,
  "save_interval": 1000,
  "log_interval": 10,
  "anneal_lr": True,  # linearly decay lr to 0 over `iterations`
  "microbatch": -1,  # -1 disables microbatching
  "use_kl": False,
  "schedule_sampler": "uniform",
  "resume_checkpoint": False,  # False, or a checkpoint path to resume from
  "predict_xstart": False,
  "rescale_timesteps": False,
  "rescale_learned_sigmas": False,
  "timestep_respacing": "",
  "noise_schedule": "cosine",
  "channel_mults": "",
  "loss-type": "l2",  # NOTE(review): hyphenated, so not reachable as args.loss_type
  "loss_weight": "none",
  "train_start": True,
  "lr": 1e-4,
  "random_slice": True,
  "sample_distance": 800,
  "weight_decay": 0.0,
  "save_imgs": True,
  "save_vids": True,
  "class_cond": True,
  "use_fp16": True,
  "use_scale_shift_norm": True,
  "dropout": 0.1,
  "attention_resolutions": "32,16,8",
  "num_res_blocks": 3,
  "resblock_updown": True,
  "classifier_use_fp16": False,  # passed by main() to MixedPrecisionTrainer
  "classifier_width": 128,
  "classifier_depth": 2,
  "classifier_attention_resolutions": "32,16,8",
  "classifier_use_scale_shift_norm": True,
  "classifier_resblock_updown": True,
  "classifier_pool": "attention",
  "num_head_channels": 64,
  "noised": True,  # train on diffusion-noised inputs (q_sample at a sampled t)
  "noise_fn": "gauss",
  "dataset": "brats",
}

In [5]:
import torch
import dataset

In [None]:
# Attribute-style view of the configuration dict (args.lr, args.batch_size, ...).
args = argparse.Namespace()
vars(args).update(args_dict)

In [None]:
# Root directory handed to dataset.init_datasets (relative to the notebook's CWD).
ROOT_DIR = "./"


In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
from tqdm import tqdm

In [None]:
# Build the train/test datasets and the training loader at notebook level.
# NOTE(review): main() calls dataset.init_datasets / init_dataset_loader again and
# binds its own local copies, so the objects created here are not used by main().
training_dataset, testing_dataset = dataset.init_datasets(ROOT_DIR, args_dict)
datal = dataset.init_dataset_loader(training_dataset, args_dict)

In [None]:
def main():
    """Train the classifier configured by the module-level ``args``.

    Steps:
      1. Build the classifier and diffusion process and move the model to ``device``.
      2. Optionally restore model and optimizer state from ``args.resume_checkpoint``.
      3. Build the train loader, and the validation loader if ``args.val_data`` is
         set, with the project ``dataset`` helpers.
      4. Run ``args.iterations - resume_step`` AdamW steps. Along the way it:
         - plots the training loss (``loss_window``) and running top-1 accuracy
           (``acc_window``) to Visdom;
         - shows an input-saliency map for a validation batch every
           ``args.eval_interval`` steps;
         - writes a checkpoint every ``args.save_interval`` steps and once at the end.
    """
    # NOTE(review): upstream guided_diffusion's create_classifier_and_diffusion takes
    # keyword arguments (typically **args_to_dict(args, classifier_and_diffusion_defaults().keys()));
    # confirm this project's version accepts the Namespace positionally.
    model, diffusion = create_classifier_and_diffusion(args)
    model.to(device)
    if args.noised:
        # Draws the diffusion timestep at which each training batch is noised.
        # NOTE(review): maxt=1000 matches args_dict["diffusion_steps"] here — confirm intent.
        schedule_sampler = create_named_schedule_sampler(
            args.schedule_sampler, diffusion, maxt=1000
        )

    resume_step = 0
    if args.resume_checkpoint:
        resume_step = parse_resume_step_from_filename(args.resume_checkpoint)
        model.load_state_dict(
            dist_util.load_state_dict(args.resume_checkpoint, map_location=device)
        )

    mp_trainer = MixedPrecisionTrainer(
        model=model, use_fp16=args.classifier_use_fp16, initial_lg_loss_scale=16.0
    )

    training_dataset, testing_dataset = dataset.init_datasets(ROOT_DIR, args_dict)
    datal = dataset.init_dataset_loader(training_dataset, args_dict)
    if args.val_data:
        val_data = dataset.init_dataset_loader(testing_dataset, args_dict)
    else:
        val_data = None

    # NOTE(review): logger.configure() is never called in this notebook, so log
    # output and checkpoints go to the logger's default directory — confirm.
    logger.log("creating optimizer...")
    opt = AdamW(mp_trainer.master_params, lr=args.lr, weight_decay=args.weight_decay)
    if args.resume_checkpoint:
        opt_checkpoint = bf.join(
            bf.dirname(args.resume_checkpoint), f"opt{resume_step:06}.pt"
        )
        logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
        opt.load_state_dict(
            dist_util.load_state_dict(opt_checkpoint, map_location=device)
        )

    logger.log("training classifier model...")

    def forward_backward_log(data_loader, step, prefix="train"):
        """Run one batch drawn from the iterator ``data_loader`` through the model.

        With ``prefix == "train"``:
          - the per-microbatch loss is back-propagated through ``mp_trainer``,
            with gradients accumulated over microbatches;
          - the mean loss is plotted to ``loss_window``.

        With any other prefix:
          - the gradient of the first sample's top logit w.r.t. the input is
            shown as a saliency heatmap, together with the first two slices of
            that input.

        Returns the loss/accuracy dict of the last microbatch.
        Raises StopIteration (from ``next``) when ``data_loader`` is exhausted.
        """
        extra = next(data_loader)
        batch = extra["image"].to(device)
        labels = extra["label"].to(device)

        if args.noised:
            t, _ = schedule_sampler.sample(batch.shape[0], device)
            batch = diffusion.q_sample(batch, t)
        else:
            t = th.zeros(batch.shape[0], dtype=th.long, device=device)

        for i, (sub_batch, sub_labels, sub_t) in enumerate(
            split_microbatches(args.microbatch, batch, labels, t)
        ):
            if prefix != "train":
                # Input gradients are only needed for the validation saliency map.
                sub_batch = sub_batch.detach().requires_grad_(True)
            logits = model(sub_batch, timesteps=sub_t)

            loss = F.cross_entropy(logits, sub_labels, reduction="none")
            losses = {
                f"{prefix}_loss": loss.detach(),
                f"{prefix}_acc@1": compute_top_k(logits, sub_labels, k=1, reduction="none"),
                f"{prefix}_acc@2": compute_top_k(logits, sub_labels, k=2, reduction="none"),
            }
            log_loss_dict(diffusion, sub_t, losses)

            loss = loss.mean()
            if prefix == "train":
                viz.line(
                    X=th.ones((1, 1)) * step,
                    Y=th.tensor([[loss.item()]]),
                    win=loss_window,
                    name='loss_cls',
                    update='append',
                )
            else:
                output_idx = logits[0].argmax()
                output_max = logits[0, output_idx]
                # Differentiate w.r.t. the input only, so the model parameters'
                # .grad buffers are not modified by the validation pass.
                (input_grad,) = th.autograd.grad(output_max, sub_batch)
                # Max |gradient| over dim 1 (presumably channels — confirm).
                saliency, _ = th.max(input_grad.abs(), dim=1)
                viz.heatmap(visualize(saliency[0, ...]))
                viz.image(visualize(sub_batch[0, 0, ...].detach()))
                viz.image(visualize(sub_batch[0, 1, ...].detach()))
                th.cuda.empty_cache()

            if loss.requires_grad and prefix == "train":
                if i == 0:
                    mp_trainer.zero_grad()
                # Scale so the accumulated gradient equals the full-batch mean.
                mp_trainer.backward(loss * len(sub_batch) / len(batch))

        return losses

    data = iter(datal)
    val_iter = iter(val_data) if val_data is not None else None
    correct = 0.0
    total = 0
    step = 0  # keeps the final save well-defined if the loop runs zero times
    for step in tqdm(range(args.iterations - resume_step)):
        global_step = step + resume_step
        logger.logkv("step", global_step)
        logger.logkv("samples", (global_step + 1) * args.batch_size)
        if args.anneal_lr:
            set_annealed_lr(opt, args.lr, global_step / args.iterations)

        try:
            losses = forward_backward_log(data, global_step)
        except StopIteration:
            # Training loader exhausted: start a new pass over it.
            data = iter(datal)
            losses = forward_backward_log(data, global_step)

        # Running top-1 accuracy over the samples scored so far (the returned
        # dict covers the last microbatch, so count exactly those samples).
        train_hits = losses["train_acc@1"]
        correct += train_hits.sum().item()
        total += train_hits.numel()
        acctrain = correct / total if total else 0.0

        mp_trainer.optimize(opt)

        if val_iter is not None and not step % args.eval_interval:
            model.eval()
            try:
                forward_backward_log(val_iter, global_step, prefix="val")
            except StopIteration:
                val_iter = iter(val_data)
                forward_backward_log(val_iter, global_step, prefix="val")
            model.train()
        if not step % args.log_interval:
            logger.dumpkvs()
            viz.line(
                X=th.ones((1, 1)) * global_step,
                Y=th.tensor([[acctrain]]),
                win=acc_window,
                name='acc_train',
                update='append',
            )
        # Checkpoint on multiples of save_interval, skipping the first step of the run.
        if step and not global_step % args.save_interval:
            logger.log("saving model...")
            save_model(mp_trainer, opt, global_step)

    logger.log("saving model...")
    save_model(mp_trainer, opt, step + resume_step)


def set_annealed_lr(opt, base_lr, frac_done):
    """Apply a linearly annealed learning rate to every param group of ``opt``.

    The rate is ``base_lr * (1 - frac_done)``: ``base_lr`` when ``frac_done`` is 0,
    falling to 0 when ``frac_done`` reaches 1.
    """
    annealed_lr = base_lr * (1 - frac_done)
    for group in opt.param_groups:
        group.update(lr=annealed_lr)


def save_model(mp_trainer, opt, step):
    """Write model and optimizer checkpoints for ``step`` into ``logger.get_dir()``.

    Files are named ``model{step:06d}.pt`` (the trainer's master parameters as a
    state dict) and ``opt{step:06d}.pt`` (the optimizer state dict).
    """
    ckpt_dir = logger.get_dir()
    model_state = mp_trainer.master_params_to_state_dict(mp_trainer.master_params)
    th.save(model_state, os.path.join(ckpt_dir, f"model{step:06d}.pt"))
    th.save(opt.state_dict(), os.path.join(ckpt_dir, f"opt{step:06d}.pt"))


def compute_top_k(logits, labels, k, reduction="mean"):
    """Top-k accuracy of ``logits`` against integer class ``labels``.

    Args:
        logits: class scores with classes on the last dimension; rows are matched
            to ``labels`` via the ``labels[:, None]`` broadcast below.
        labels: ground-truth class index for each row of ``logits``.
        k: how many of the highest-scoring classes count as a hit.
        reduction: ``"mean"`` returns the batch accuracy as a Python float;
            ``"none"`` returns a float tensor holding 1.0 (hit) or 0.0 per row.

    Raises:
        ValueError: if ``reduction`` is neither ``"mean"`` nor ``"none"``.
    """
    _, top_ks = th.topk(logits, k, dim=-1)
    # topk indices are distinct, so each row matches its label at most once.
    hits = (top_ks == labels[:, None]).float().sum(dim=-1)
    if reduction == "none":
        return hits
    if reduction == "mean":
        return hits.mean().item()
    raise ValueError(f"unsupported reduction {reduction!r}; expected 'mean' or 'none'")


def split_microbatches(microbatch, *args):
    """Yield aligned slices of ``args`` with at most ``microbatch`` items each.

    The length of the first argument defines the batch size. When ``microbatch``
    is -1 or not smaller than that size, a single tuple of the unsliced arguments
    is yielded. ``None`` entries pass through as ``None`` in every slice.
    """
    total = len(args[0])
    if microbatch != -1 and microbatch < total:
        for start in range(0, total, microbatch):
            stop = start + microbatch
            yield tuple(None if item is None else item[start:stop] for item in args)
        return
    yield tuple(args)

# Entry point. In a notebook kernel __name__ is normally "__main__", so running
# this cell starts training immediately after the definitions above.
if __name__ == "__main__":
    main()