In [1]:
from get_args import get_args
# Hyper-parameters are written as a CLI-style token list and parsed by
# get_args(), so the notebook shares one argument parser with the scripts.
# Token order is kept exactly as in the original list.
args = [
    # Dataset / experiment identity.
    "--dataset", "afhq_cat", "--image_size", "256", "--exp", "afhq-cat-kl-f4-256-3",
    # Generator / latent sizes.
    "--num_channels", "3", "--num_channels_dae", "128",
    "--num_timesteps", "4", "--num_res_blocks", "2",
    "--batch_size", "32", "--num_epoch", "500",
    "--ngf", "64", "--nz", "100", "--z_emb_dim", "256", "--n_mlp", "4",
    "--embedding_type", "positional",
    # Optimisation.
    "--use_ema", "--ema_decay", "0.999",
    "--r1_gamma", "0.02", "--lr_d", "5.0e-5", "--lr_g", "2.0e-4", "--lazy_reg", "10",
    "--ch_mult", "1", "2", "2", "2",
    # Checkpointing / process setup.
    "--save_content", "--datadir", "data/afhq",
    "--master_port", "6087", "--num_process_per_node", "1",
    "--save_content_every", "1",
    # train() overwrites args.image_size with this value before building the nets.
    "--current_resolution", "64",
    "--attn_resolutions", "16", "--num_disc_layers", "4",
    # Encoded latents are divided by this factor in train() and multiplied back before decoding.
    "--scale_factor", "6.0",
    "--no_lr_decay",
    # Pretrained autoencoder used to encode images / decode samples.
    "--AutoEncoder_config", "autoencoder/config/afhq-cat-kl-f4.yaml",
    "--AutoEncoder_ckpt", "autoencoder/weight/afhq-cat-kl-f4.ckpt",
    # Loss switches read in train(): L1 reconstruction term, weighted by the sigmoid schedule.
    "--rec_loss", "--sigmoid_learning",
]

# Replace the token list with the parsed namespace consumed by the later cells.
args = get_args(args)

In [2]:
import argparse
import os
import shutil

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from datasets_prep.dataset import create_dataset
from diffusion import sample_from_model, sample_posterior, \
    q_sample_pairs, get_time_schedule, \
    Posterior_Coefficients, Diffusion_Coefficients
#from DWT_IDWT.DWT_IDWT_layer import DWT_2D, IDWT_2D
#from pytorch_wavelets import DWTForward, DWTInverse
from torch.multiprocessing import Process
from utils import init_processes, copy_source, broadcast_params
import yaml

from ldm.util import instantiate_from_config
from omegaconf import OmegaConf
import wandb

def load_model_from_config(config_path, ckpt, device=None):
    """Instantiate the model described by a config and return its first-stage model.

    The full model is built from ``config.model``, the checkpoint's
    ``state_dict`` is loaded non-strictly, and only the
    ``first_stage_model`` attribute is kept, switched to eval mode.

    Args:
        config_path: Path to a config file readable by ``OmegaConf.load``;
            it must expose a ``model`` entry.
        ckpt: Path to a checkpoint readable by ``torch.load`` that contains a
            ``"state_dict"`` entry.
        device: Optional target device. When ``None`` (the default) the model
            is moved with ``.cuda()``, i.e. to the current CUDA device, exactly
            as before this parameter existed.

    Returns:
        The ``first_stage_model`` sub-module, on the requested device, in
        eval mode.
    """
    print(f"Loading model from {ckpt}")
    config = OmegaConf.load(config_path)
    # Load onto the CPU first; the model is moved to its device once built.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)

    # strict=False tolerates key mismatches, so report them instead of
    # discarding the result: a checkpoint that does not match the config
    # would otherwise leave weights at their initial values without notice.
    missing_keys, unexpected_keys = model.load_state_dict(sd, strict=False)
    if missing_keys:
        print(f"Warning: {len(missing_keys)} missing keys when loading {ckpt}")
    if unexpected_keys:
        print(f"Warning: {len(unexpected_keys)} unexpected keys when loading {ckpt}")

    model = model.first_stage_model
    if device is None:
        model.cuda()
    else:
        model.to(device)
    model.eval()
    # Drop the reference to the raw checkpoint so it can be garbage-collected.
    del pl_sd
    return model

def grad_penalty_call(args, D_real, x_t):
    """Back-propagate a gradient penalty on the discriminator's real-sample output.

    The penalty is ``args.r1_gamma / 2`` times the batch mean of the squared
    L2 norm of ``d(sum(D_real)) / d(x_t)``. Nothing is returned; the call's
    effect is the ``backward()`` pass, which accumulates into ``.grad``.

    Args:
        args: Object exposing the penalty weight ``r1_gamma``.
        D_real: Discriminator output computed from ``x_t``.
        x_t: Input tensor that requires grad and that ``D_real`` depends on.
    """
    # create_graph=True keeps the gradient differentiable so the penalty
    # itself can be back-propagated.
    (input_grad,) = torch.autograd.grad(
        outputs=D_real.sum(), inputs=x_t, create_graph=True
    )
    flat_grad = input_grad.view(input_grad.size(0), -1)
    squared_norms = flat_grad.norm(2, dim=1) ** 2
    penalty = squared_norms.mean()

    penalty = args.r1_gamma / 2 * penalty
    penalty.backward()


# %%
def train(rank, gpu, args):
    """Run adversarial training of the latent-space generator on one process.

    Images are encoded by a pretrained autoencoder, the latents are divided by
    ``args.scale_factor``, and a generator/discriminator pair is trained on
    diffused latent pairs. At the end of every epoch rank 0 logs to wandb,
    decodes and saves a sample grid, and writes checkpoints under
    ``./saved_info/<dataset>/<exp>``.

    Args:
        rank: Index of this process. Offsets the random seed, selects the
            sampler shard, and gates logging/checkpointing (rank 0 only).
        gpu: CUDA device index; all tensors and models use ``cuda:<gpu>``.
        args: Parsed hyper-parameter namespace. It is mutated:
            ``ori_image_size`` is set to the incoming ``image_size`` and
            ``image_size`` is overwritten with ``current_resolution``.

    Note:
        ``wandb.log`` is called on rank 0, so a wandb run presumably has to be
        initialised before this function runs.
    """
    from EMA import EMA
    from score_sde.models.discriminator import Discriminator_large, Discriminator_small
    from score_sde.models.ncsnpp_generator_adagn import NCSNpp, WaveletNCSNpp

    # Offset the seed by rank so every process draws a different random stream.
    torch.manual_seed(args.seed + rank)
    torch.cuda.manual_seed(args.seed + rank)
    torch.cuda.manual_seed_all(args.seed + rank)
    device = torch.device('cuda:{}'.format(gpu))

    batch_size = args.batch_size

    nz = args.nz  # latent dimension

    dataset = create_dataset(args)
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset,
                                                                    num_replicas=args.world_size,
                                                                    rank=rank)
    # Shuffling is delegated to the sampler; drop_last keeps every batch at
    # exactly `batch_size` samples.
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers,
                                              pin_memory=True,
                                              sampler=train_sampler,
                                              drop_last=True)
    # The dataset was built with the pixel-space size; from here on
    # args.image_size holds the resolution given by --current_resolution.
    args.ori_image_size = args.image_size
    args.image_size = args.current_resolution
    G_NET_ZOO = {"normal": NCSNpp, "wavelet": WaveletNCSNpp}
    gen_net = G_NET_ZOO[args.net_type]
    disc_net = [Discriminator_small, Discriminator_large]
    print("GEN: {}, DISC: {}".format(gen_net, disc_net))
    netG = gen_net(args).to(device)

    # The discriminator sees two tensors concatenated, hence 2 * num_channels.
    if args.dataset in ['cifar10', 'stl10']:
        netD = disc_net[0](nc=2 * args.num_channels, ngf=args.ngf,
                           t_emb_dim=args.t_emb_dim,
                           act=nn.LeakyReLU(0.2), num_layers=args.num_disc_layers).to(device)
    else:
        netD = disc_net[1](nc=2 * args.num_channels, ngf=args.ngf,
                           t_emb_dim=args.t_emb_dim,
                           act=nn.LeakyReLU(0.2), num_layers=args.num_disc_layers).to(device)

    broadcast_params(netG.parameters())
    broadcast_params(netD.parameters())

    optimizerD = optim.Adam(filter(lambda p: p.requires_grad, netD.parameters(
    )), lr=args.lr_d, betas=(args.beta1, args.beta2))
    optimizerG = optim.Adam(filter(lambda p: p.requires_grad, netG.parameters(
    )), lr=args.lr_g, betas=(args.beta1, args.beta2))

    if args.use_ema:
        optimizerG = EMA(optimizerG, ema_decay=args.ema_decay)

    schedulerG = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizerG, args.num_epoch, eta_min=1e-5)
    schedulerD = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizerD, args.num_epoch, eta_min=1e-5)

    # ddp
    netG = nn.parallel.DistributedDataParallel(
        netG, device_ids=[gpu], find_unused_parameters=True)
    netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])

    # Load the pretrained autoencoder (encoder + decoder). It is only used
    # under torch.no_grad() below and is kept in eval mode.
    config_path = args.AutoEncoder_config
    ckpt_path = args.AutoEncoder_ckpt

    if args.dataset in ['cifar10', 'stl10', 'afhq_cat']:

        with open(config_path, 'r') as file:
            config = yaml.safe_load(file)

        AutoEncoder = instantiate_from_config(config['model'])

        checkpoint = torch.load(ckpt_path, map_location=device)
        AutoEncoder.load_state_dict(checkpoint['state_dict'])
        AutoEncoder.eval()
        AutoEncoder.to(device)

    else:
        # The helper places the model on the default CUDA device; move it to
        # this process's device so it matches the input batches.
        AutoEncoder = load_model_from_config(config_path, ckpt_path).to(device)

    exp = args.exp
    parent_dir = "./saved_info/{}".format(args.dataset)

    exp_path = os.path.join(parent_dir, exp)
    if rank == 0:
        if not os.path.exists(exp_path):
            os.makedirs(exp_path)
            # __file__ is not defined when this code runs from a notebook
            # kernel; snapshot the source file only when there is one.
            source_file = globals().get('__file__')
            if source_file is not None:
                copy_source(source_file, exp_path)
            shutil.copytree('score_sde/models',
                            os.path.join(exp_path, 'score_sde/models'))

    coeff = Diffusion_Coefficients(args, device)
    pos_coeff = Posterior_Coefficients(args, device)
    T = get_time_schedule(args, device)

    # Resume when asked to, or whenever a previous content.pth exists.
    if args.resume or os.path.exists(os.path.join(exp_path, 'content.pth')):
        checkpoint_file = os.path.join(exp_path, 'content.pth')
        checkpoint = torch.load(checkpoint_file, map_location=device)
        init_epoch = checkpoint['epoch']
        epoch = init_epoch
        # NOTE(review): optimizer states are saved in content.pth but are not
        # restored here (the load calls were disabled by the original author).
        # If the EMA wrapper keeps its averaged weights in the optimizer state,
        # they restart from the raw weights on resume - confirm this is intended.
        # load G
        netG.load_state_dict(checkpoint['netG_dict'])
        schedulerG.load_state_dict(checkpoint['schedulerG'])
        # load D
        netD.load_state_dict(checkpoint['netD_dict'])
        schedulerD.load_state_dict(checkpoint['schedulerD'])

        global_step = checkpoint['global_step']
        print("=> loaded checkpoint (epoch {})"
              .format(checkpoint['epoch']))
    else:
        global_step, epoch, init_epoch = 0, 0, 0

    # Sigmoid schedule for the reconstruction-loss weight: alpha[epoch] is
    # 1 - sigmoid(beta) with beta spaced evenly over [-gamma, gamma], so it
    # decreases monotonically from epoch 0 to epoch num_epoch.
    gamma = 6
    beta = np.linspace(-gamma, gamma, args.num_epoch+1)
    alpha = 1 - 1 / (1+np.exp(-beta))

    for epoch in range(init_epoch, args.num_epoch + 1):
        train_sampler.set_epoch(epoch)

        for iteration, (x, y) in enumerate(data_loader):
            # ---- Discriminator update: D trainable, G frozen ----
            for p in netD.parameters():
                p.requires_grad = True
            netD.zero_grad()

            for p in netG.parameters():
                p.requires_grad = False

            # sample from p(x_0)
            x0 = x.to(device, non_blocking=True)

            # Encode images to latents; no gradient flows into the autoencoder.
            with torch.no_grad():
                posterior = AutoEncoder.encode(x0)
                real_data = posterior.sample().detach()
            # Rescale latents; the original author's note gives [-1, 1] as the
            # intended range.
            real_data = real_data / args.scale_factor

            # sample t
            t = torch.randint(0, args.num_timesteps,
                              (real_data.size(0),), device=device)

            x_t, x_tp1 = q_sample_pairs(coeff, real_data, t)
            # Needed so the gradient penalty can differentiate D w.r.t. x_t.
            x_t.requires_grad = True

            # train with real
            D_real = netD(x_t, t, x_tp1.detach()).view(-1)
            errD_real = F.softplus(-D_real).mean()

            # retain_graph: the penalty below differentiates through D_real again.
            errD_real.backward(retain_graph=True)

            # Apply the gradient penalty every step, or only every
            # `lazy_reg` steps when lazy regularisation is configured.
            if args.lazy_reg is None:
                grad_penalty_call(args, D_real, x_t)
            else:
                if global_step % args.lazy_reg == 0:
                    grad_penalty_call(args, D_real, x_t)

            # train with fake
            # Size the noise from the actual batch, like `t` above.
            latent_z = torch.randn(real_data.size(0), nz, device=device)
            x_0_predict = netG(x_tp1.detach(), t, latent_z)
            x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)

            output = netD(x_pos_sample, t, x_tp1.detach()).view(-1)
            errD_fake = F.softplus(output).mean()

            errD_fake.backward()

            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            # ---- Generator update: G trainable, D frozen ----
            for p in netD.parameters():
                p.requires_grad = False

            for p in netG.parameters():
                p.requires_grad = True
            netG.zero_grad()

            t = torch.randint(0, args.num_timesteps,
                              (real_data.size(0),), device=device)
            x_t, x_tp1 = q_sample_pairs(coeff, real_data, t)

            latent_z = torch.randn(real_data.size(0), nz, device=device)
            x_0_predict = netG(x_tp1.detach(), t, latent_z)
            x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)

            output = netD(x_pos_sample, t, x_tp1.detach()).view(-1)
            errG = F.softplus(-output).mean()

            # Optional L1 reconstruction loss between the predicted and the
            # real latent, weighted by alpha[epoch] under sigmoid learning.
            if args.sigmoid_learning and args.rec_loss:
                rec_loss = F.l1_loss(x_0_predict, real_data)
                errG = errG + alpha[epoch]*rec_loss

            elif args.rec_loss and not args.sigmoid_learning:
                rec_loss = F.l1_loss(x_0_predict, real_data)
                errG = errG + rec_loss

            errG.backward()
            optimizerG.step()

            global_step += 1
            if iteration % 100 == 0:
                if rank == 0:
                    if args.sigmoid_learning:
                        print('epoch {} iteration{}, G Loss: {}, D Loss: {}, alpha: {}'.format(
                            epoch, iteration, errG.item(), errD.item(), alpha[epoch]))
                    elif args.rec_loss:
                        print('epoch {} iteration{}, G Loss: {}, D Loss: {}, rec_loss: {}'.format(
                            epoch, iteration, errG.item(), errD.item(), rec_loss.item()))
                    else:
                        print('epoch {} iteration{}, G Loss: {}, D Loss: {}'.format(
                            epoch, iteration, errG.item(), errD.item()))

        if not args.no_lr_decay:

            schedulerG.step()
            schedulerD.step()

        if rank == 0:
            # Logged values come from the last batch of the epoch.
            wandb.log({"G_loss": errG.item(), "D_loss": errD.item(), "alpha": alpha[epoch]})
            # Start sampling from Gaussian noise shaped like one latent batch.
            x_t_1 = torch.randn_like(posterior.sample())
            fake_sample = sample_from_model(
                pos_coeff, netG, args.num_timesteps, x_t_1, T, args)

            # Undo the latent scaling, then decode back to image space.
            fake_sample *= args.scale_factor
            real_data *= args.scale_factor
            with torch.no_grad():
                fake_sample = AutoEncoder.decode(fake_sample)
                real_data = AutoEncoder.decode(real_data)

            fake_sample = (torch.clamp(fake_sample, -1, 1) + 1) / 2  # 0-1
            real_data = (torch.clamp(real_data, -1, 1) + 1) / 2  # 0-1

            torchvision.utils.save_image(fake_sample, os.path.join(
                exp_path, 'sample_discrete_epoch_{}.png'.format(epoch)))
            torchvision.utils.save_image(
                real_data, os.path.join(exp_path, 'real_data.png'))

            if args.save_content:
                if epoch % args.save_content_every == 0:
                    print('Saving content.')
                    # 'epoch' stores the next epoch to run, so resuming
                    # continues after the one that just finished.
                    content = {'epoch': epoch + 1, 'global_step': global_step, 'args': args,
                               'netG_dict': netG.state_dict(), 'optimizerG': optimizerG.state_dict(),
                               'schedulerG': schedulerG.state_dict(), 'netD_dict': netD.state_dict(),
                               'optimizerD': optimizerD.state_dict(), 'schedulerD': schedulerD.state_dict()}
                    torch.save(content, os.path.join(exp_path, 'content.pth'))

            if epoch % args.save_ckpt_every == 0:
                # Swap before saving and swap again afterwards, so the file is
                # written while the swapped-in parameters are active.
                if args.use_ema:
                    optimizerG.swap_parameters_with_ema(
                        store_params_in_ema=True)

                torch.save(netG.state_dict(), os.path.join(
                    exp_path, 'netG_{}.pth'.format(epoch)))
                if args.use_ema:
                    optimizerG.swap_parameters_with_ema(
                        store_params_in_ema=True)

In [3]:
# Start the Weights & Biases run and record the hyper-parameters.
# This runs before train(), so "image_size" is the value parsed from the
# argument list; train() later overwrites args.image_size with
# args.current_resolution.
wandb.init(
    project="TEST",
    name=args.exp,
    config={
        "dataset": args.dataset,
        "image_size": args.image_size,
        "channels": args.num_channels,
        "channels_dae": args.num_channels_dae,
        # Key was misspelled "ch_nult", which logged the channel multipliers
        # under a name that does not match the --ch_mult flag.
        "ch_mult": args.ch_mult,
        "timesteps": args.num_timesteps,
        "res_blocks": args.num_res_blocks,
        "nz": args.nz,
        "epochs": args.num_epoch,
        "ngf": args.ngf,
        "lr_g": args.lr_g,
        "lr_d": args.lr_d,
        "batch_size": args.batch_size,
        "r1_gamma": args.r1_gamma,
        "lazy_reg": args.lazy_reg,
        "embedding_type": args.embedding_type,
        "use_ema": args.use_ema,
        "ema_decay": args.ema_decay,
        "no_lr_decay": args.no_lr_decay,
        "z_emb_dim": args.z_emb_dim,
        "attn_resolutions": args.attn_resolutions,
        "use_pytorch_wavelet": args.use_pytorch_wavelet,
        "rec_loss": args.rec_loss,
        "net_type": args.net_type,
        "num_disc_layers": args.num_disc_layers,
        "no_use_fbn": args.no_use_fbn,
        "no_use_freq": args.no_use_freq,
        "no_use_residual": args.no_use_residual,
        "scale_factor": args.scale_factor,
        "AutoEncoder_config": args.AutoEncoder_config,
        "AutoEncoder_ckpt": args.AutoEncoder_ckpt,
        "sigmoid_learning": args.sigmoid_learning,
    },
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkotomiya07[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112280376255512, max=1.0…

[34m[1mwandb[0m: Network error resolved after 0:02:17.888630, resuming normal operation.


In [None]:
# Launch training from the notebook. init_processes is called with 0 and 1 as
# its first two arguments - presumably (rank, world size); confirm against
# utils.init_processes - followed by the train function and the parsed args.
print('starting in debug mode')
init_processes(0, 1, train, args)

starting in debug mode
GEN: <class 'score_sde.models.ncsnpp_generator_adagn.NCSNpp'>, DISC: [<class 'score_sde.models.discriminator.Discriminator_small'>, <class 'score_sde.models.discriminator.Discriminator_large'>]
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 3, 64, 64) = 12288 dimensions.
making attention of type 'vanilla' with 512 in_channels




loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth




epoch 0 iteration0, G Loss: 0.5996959805488586, D Loss: 1.5086525678634644, alpha: 0.9975273768433652
epoch 0 iteration100, G Loss: 1.0599350929260254, D Loss: 1.414835810661316, alpha: 0.9975273768433652
Saving content.
epoch 1 iteration0, G Loss: 1.1730022430419922, D Loss: 1.3320236206054688, alpha: 0.9974674681467623
epoch 1 iteration100, G Loss: 3.565249443054199, D Loss: 1.5183981657028198, alpha: 0.9974674681467623
Saving content.
epoch 2 iteration0, G Loss: 1.5091724395751953, D Loss: 1.0101717710494995, alpha: 0.997406111708621
epoch 2 iteration100, G Loss: 2.933070659637451, D Loss: 1.7361359596252441, alpha: 0.997406111708621
Saving content.
epoch 3 iteration0, G Loss: 2.40285587310791, D Loss: 0.9868963956832886, alpha: 0.9973432727281952
epoch 3 iteration100, G Loss: 5.039502143859863, D Loss: 0.42012345790863037, alpha: 0.9973432727281952
Saving content.
epoch 4 iteration0, G Loss: 1.3162994384765625, D Loss: 1.4002819061279297, alpha: 0.9972789155772751
epoch 4 iteration

wandb: Network error (ConnectionError), entering retry loop.


Saving content.
epoch 407 iteration0, G Loss: 11.239027976989746, D Loss: 0.0008663555490784347, alpha: 0.022576731385895554
epoch 407 iteration100, G Loss: 8.773289680480957, D Loss: 0.09364686161279678, alpha: 0.022576731385895554
Saving content.
epoch 408 iteration0, G Loss: 7.972092628479004, D Loss: 0.16478705406188965, alpha: 0.022053147285217123
epoch 408 iteration100, G Loss: 7.336418628692627, D Loss: 0.12640242278575897, alpha: 0.022053147285217123
Saving content.
epoch 409 iteration0, G Loss: 10.279129028320312, D Loss: 0.003770428244024515, alpha: 0.02154143817767651
epoch 409 iteration100, G Loss: 3.88248872756958, D Loss: 0.43178337812423706, alpha: 0.02154143817767651
Saving content.
epoch 410 iteration0, G Loss: 6.786442279815674, D Loss: 0.11506091058254242, alpha: 0.021041347020468337
epoch 410 iteration100, G Loss: 10.80633544921875, D Loss: 0.013594258576631546, alpha: 0.021041347020468337
Saving content.
epoch 411 iteration0, G Loss: 8.899704933166504, D Loss: 0.13