In [7]:
# Re-import edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
import os
# Expose only physical GPU 3 to this process. This must run before torch initialises CUDA;
# the selected GPU is then addressed as cuda:0 (the device shown in the config output below).
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from torchinfo import summary

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
# NOTE(review): wildcard import. Of its names, only NCSNpp is used in the cells shown here.
# Switch to `from models.ncsnpp import NCSNpp` once no other cell relies on the star import.
from models.ncsnpp import *
from configs.subvp import cifar10_ncsnpp_continuous as configs
# Sub-VP / NCSN++ (continuous) config for CIFAR-10; displayed as the cell output.
config = configs.get_config()
config

data:
  centered: true
  dataset: CIFAR10
  image_size: 32
  num_channels: 3
  random_flip: true
  uniform_dequantization: false
device: !!python/object/apply:torch.device
- cuda
- 0
eval:
  batch_size: 1024
  begin_ckpt: 9
  bpd_dataset: test
  enable_bpd: false
  enable_loss: true
  enable_sampling: false
  end_ckpt: 26
  num_samples: 50000
model:
  attention_type: ddpm
  attn_resolutions: !!python/tuple
  - 16
  beta_max: 20.0
  beta_min: 0.1
  ch_mult: !!python/tuple
  - 1
  - 2
  - 2
  - 2
  conditional: true
  conv_size: 3
  dropout: 0.1
  ema_rate: 0.9999
  embedding_type: positional
  fir: true
  fir_kernel:
  - 1
  - 3
  - 3
  - 1
  fourier_scale: 16
  init_scale: 0.0
  name: ncsnpp
  nf: 128
  nonlinearity: swish
  normalization: GroupNorm
  num_res_blocks: 4
  num_scales: 1000
  progressive: none
  progressive_combine: sum
  progressive_input: residual
  resamp_with_conv: true
  resblock_type: biggan
  scale_by_sigma: false
  sigma_max: 50
  sigma_min: 0.01
  skip_rescale: t

In [12]:
import losses, sde_lib

# Forward (noising) process: a sub-VP SDE whose beta range and number of
# discretisation steps are taken from the model section of the config.
sde = sde_lib.subVPSDE(
    N=config.model.num_scales,
    beta_min=config.model.beta_min,
    beta_max=config.model.beta_max,
)

# Library training loss for that SDE (continuous time, mean reduction,
# no likelihood weighting, no masked marginals).
loss_fn = losses.get_sde_loss_fn(
    sde,
    eps=1e-5,
    train=True,
    continuous=True,
    reduce_mean=True,
    masked_marginals=False,
    likelihood_weighting=False,
)

In [8]:
# Instantiate NCSN++ from the config and print its per-layer parameter table.
# NOTE(review): no `.to(config.device)` appears in the cells shown, so the model presumably stays on CPU.
model = NCSNpp(config)
summary(model)

Layer (type:depth-idx)                   Param #
NCSNpp                                   --
├─SiLU: 1-1                              --
├─ModuleList: 1-2                        --
│    └─Linear: 2-1                       66,048
│    └─Linear: 2-2                       262,656
│    └─Conv2d: 2-3                       3,584
│    └─ResnetBlockBigGANpp: 2-4          --
│    │    └─GroupNorm: 3-1               256
│    │    └─Conv2d: 3-2                  147,584
│    │    └─Linear: 3-3                  65,664
│    │    └─GroupNorm: 3-4               256
│    │    └─Dropout: 3-5                 --
│    │    └─Conv2d: 3-6                  147,584
│    │    └─SiLU: 3-7                    --
│    └─ResnetBlockBigGANpp: 2-5          --
│    │    └─GroupNorm: 3-8               256
│    │    └─Conv2d: 3-9                  147,584
│    │    └─Linear: 3-10                 65,664
│    │    └─GroupNorm: 3-11              256
│    │    └─Dropout: 3-12                --
│    │    └─Conv2d: 3-13        

In [9]:
# Smoke-test input shaped like the configured data (channels and spatial size from
# config.data) rather than a hard-coded 8x8. attn_resolutions includes 16, so an 8x8
# input can never produce a 16x16 feature map and presumably skipped the attention level.
DUMMY_BATCH_SIZE = 5
x_dummy = torch.zeros(
    DUMMY_BATCH_SIZE,
    config.data.num_channels,
    config.data.image_size,
    config.data.image_size,
)
# One diffusion time per example, all set to 1.0.
t = torch.ones(x_dummy.shape[0])

In [11]:
# NOTE(review): this raises "TypeError: forward() takes 2 positional arguments but 3 were
# given", so the forward() reached here accepts only one input besides self. Confirm the
# signature of NCSNpp.forward in models/ncsnpp.py. With config.model.conditional = true it is
# presumably meant to take (x, t).
model(x_dummy, t)

TypeError: forward() takes 2 positional arguments but 3 were given

In [10]:
# Evaluate the loss on the dummy batch.
# NOTE(review): the execution counts show this cell (In [10]) ran *before* the cell that
# builds loss_fn (In [12]), so it used a loss_fn left over from earlier kernel state.
# It fails with the same TypeError as the direct model(x_dummy, t) call, so the failure is
# presumably in the model's forward rather than in the loss. Fix that first.
loss = loss_fn(model, x_dummy)
loss

TypeError: forward() takes 2 positional arguments but 3 were given

In [None]:
from models import utils as mutils
# Options for the hand-written loss_fn defined below, which reads them as globals.
# NOTE(review): likelihood_weighting is True here but False in the get_sde_loss_fn call
# earlier in the notebook, so the two loss_fn definitions are not like-for-like. Confirm intended.
train, reduce_mean, continuous = True, True, True
likelihood_weighting = True
eps = 1e-5

# Per-example reduction over the flattened pixels: either the mean or half the sum.
if reduce_mean:
    reduce_op = torch.mean
else:
    reduce_op = lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)

def loss_fn(model, batch, t=None, z=None):
    """Compute the continuous-time score-matching loss for one mini-batch.

    Reads its configuration from notebook globals: ``sde``, ``mutils``, ``train``,
    ``continuous``, ``likelihood_weighting``, ``eps`` and ``reduce_op``.
    Presumably mirrors ``losses.get_sde_loss_fn`` (TODO confirm against that module).

    Args:
      model: A score model.
      batch: A mini-batch of training data.
      t: Optional diffusion times, one per example. Defaults to samples drawn
        uniformly from [eps, sde.T]. Pass a fixed tensor for reproducible debugging.
      z: Optional noise with the same shape as ``batch``. Defaults to standard
        Gaussian noise. ``torch.zeros_like(batch)`` reproduces the old
        noise-free debugging behaviour.

    Returns:
      loss: A scalar that represents the average loss value across the mini-batch.
    """
    score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)

    if t is None:
        # Uniform in [eps, T]. eps keeps t away from 0, presumably where std(t)
        # vanishes and the z / std term below would blow up.
        t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
    if z is None:
        z = torch.randn_like(batch)

    mean, std = sde.marginal_prob(batch, t)
    std_b = sde._unsqueeze(std)  # std reshaped by the SDE helper to broadcast against batch
    perturbed_data = mean + std_b * z

    score = score_fn(perturbed_data, t)

    if not likelihood_weighting:
        per_example = torch.square(score * std_b + z)
        per_example = reduce_op(per_example.reshape(per_example.shape[0], -1), dim=-1)
    else:
        # Weight each example by the squared diffusion coefficient g(t)^2.
        g2 = sde.sde(torch.zeros_like(batch), t)[1] ** 2
        per_example = torch.square(score + z / std_b)
        per_example = reduce_op(per_example.reshape(per_example.shape[0], -1), dim=-1) * g2

    return torch.mean(per_example)

In [52]:
# NOTE(review): mg_model and unParalled_state_dict are not defined in any cell shown here
# (hidden kernel state; this is In [52]), so this cell fails on Restart & Run All.
# With strict=False only matching keys are loaded. The output below lists every init_conv,
# init_pool, *.dense and out_tr.final_conv_pp parameter as missing (left at their current
# values) and the checkpoint's out_tr.final_conv weights as unused. That presumably means a
# final_conv -> final_conv_pp rename plus new dense layers. Confirm before trusting this model.
mg_model.load_state_dict(unParalled_state_dict, strict=False)

_IncompatibleKeys(missing_keys=['init_conv.weight', 'init_conv.bias', 'init_pool.weight', 'init_pool.bias', 'down_tr64.dense.weight', 'down_tr64.dense.bias', 'down_tr128.dense.weight', 'down_tr128.dense.bias', 'down_tr256.dense.weight', 'down_tr256.dense.bias', 'down_tr512.dense.weight', 'down_tr512.dense.bias', 'up_tr256.dense.weight', 'up_tr256.dense.bias', 'up_tr128.dense.weight', 'up_tr128.dense.bias', 'up_tr64.dense.weight', 'up_tr64.dense.bias', 'out_tr.final_conv_pp.weight', 'out_tr.final_conv_pp.bias'], unexpected_keys=['out_tr.final_conv.weight', 'out_tr.final_conv.bias'])