In [1]:
# !pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/1.9.0/MindSpore/gpu/x86_64/cuda-11.1/mindspore_gpu-1.9.0-cp37-cp37m-linux_x86_64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple
!pip install -q -U einops datasets matplotlib tqdm -i https://pypi.tuna.tsinghua.edu.cn/simple
!pip install download -i https://pypi.tuna.tsinghua.edu.cn/simple

[0mLooking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.1.2[0m[39;49m -> [0m[32;49m22.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [1]:
import math
from inspect import isfunction
from functools import partial

%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
# from einops import rearrange

def rearrange(head, inputs):
    b, hc, x, y = inputs.shape
    c = hc // head

    return inputs.reshape((b, head, c, x*y))

import mindspore as ms
import mindspore.nn as nn
from mindspore import context, ms_function
from mindspore.common.initializer import initializer, HeUniform, Uniform, Normal, _calculate_fan_in_and_fan_out


context.set_context(device_target="GPU", mode=1)

class Conv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_mode='same', padding=0, dilation=1, group=1, has_bias=True):
        super().__init__(in_channels, out_channels, kernel_size, stride, pad_mode, padding, dilation, group, has_bias, weight_init='normal', bias_init='zeros')
        self.reset_parameters()

    def reset_parameters(self):
        self.weight.set_data(initializer(HeUniform(math.sqrt(5)), self.weight.shape))
        #self.weight = Parameter(initializer(HeUniform(math.sqrt(5)), self.weight.shape), name='weight')
        if self.has_bias:
            fan_in, _ = _calculate_fan_in_and_fan_out(self.weight.shape)
            bound = 1 / math.sqrt(fan_in)
            self.bias.set_data(initializer(Uniform(bound), [self.out_channels]))
            
            
def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

class Residual(nn.Cell):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def construct(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x
    
# """Upsample"""

from mindspore.ops import constexpr
from mindspore.ops.operations.image_ops import ResizeBilinearV2, ResizeLinear1D
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.common.initializer import initializer, HeUniform, Uniform, Normal, _calculate_fan_in_and_fan_out

def Upsample(dim, dim_out=None):
    @constexpr
    def _check_scale_factor(shape, scale_factor):
        if isinstance(scale_factor, tuple) and len(scale_factor) != len(shape[2:]):
            raise ValueError(f"the number of 'scale_fator' must match to inputs.shape[2:], "
                             f"but get scale_factor={scale_factor}, inputs.shape[2:]={shape[2:]}")

    def _interpolate_output_shape(shape, scales, sizes, mode):
        """calculate output shape"""
        if sizes is not None:
            if mode == "nearest":
                return sizes
            return Tensor(sizes)

        ret = ()        
        for i in range(len(shape[2:])):
            if isinstance(scales, float):
                out_i = int(scales * shape[i+2])
            else:
                out_i = int(scales[i] * shape[i+2])
            ret = ret + (out_i,)
        if mode == "nearest":
            return ret
        return Tensor(ret)

    class Upsample_cls(nn.Cell):
        def __init__(self, size = None, scale_factor = None,
                     mode: str = 'nearest', align_corners = False):
            super().__init__()
            if mode not in ['nearest', 'linear', 'bilinear']:
                raise ValueError(f'do not support mode :{mode}.')
            if size and scale_factor:
                raise ValueError(f"can not set 'size' and 'scale_fator' at the same time.")
            self.size = size
            if isinstance(scale_factor, tuple):
                self.scale_factor = tuple(float(factor) for factor in scale_factor)
            else:
                self.scale_factor = float(scale_factor) if scale_factor else None
            self.mode = mode
            self.align_corners = align_corners

        def construct(self, inputs):
            inputs_shape = inputs.shape
            _check_scale_factor(inputs_shape, self.scale_factor)
            sizes = _interpolate_output_shape(inputs_shape, self.scale_factor, self.size, self.mode)
            if self.mode == 'nearest':
                interpolate = _get_cache_prim(ops.ResizeNearestNeighbor)(sizes, self.align_corners)
                return interpolate(inputs)
            elif self.mode == 'linear':
                interpolate = _get_cache_prim(ResizeLinear1D)('align_corners' if self.align_corners else 'half_pixel')
                return interpolate(inputs, sizes)
            elif self.mode == 'bilinear':
                interpolate = _get_cache_prim(ResizeBilinearV2)(self.align_corners, True if self.align_corners==False else False)
                return interpolate(inputs, sizes)
            return inputs
        
    return nn.SequentialCell(
            Upsample_cls(scale_factor = 2, mode = 'nearest'),
            Conv2d(dim, default(dim_out, dim), 3, padding = 1, pad_mode='pad'))



def Downsample(dim):
    return Conv2d(dim, dim, 4, 2, pad_mode="pad", padding=1)

class SinusoidalPositionEmbeddings(nn.Cell):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = np.exp(np.arange(half_dim) * - emb)
        self.emb = Tensor(emb, mindspore.float32)

    def construct(self, x):
        emb = x[:, None] * self.emb[None, :]
        emb = ops.concat((ops.sin(emb), ops.cos(emb)), axis=-1)
        return emb
    
class Identity(nn.Cell):
    def construct(self, inputs):
        return inputs

class WeightStandardizedConv2d(Conv2d):
    """
    https://arxiv.org/abs/1903.10520
    weight standardization purportedly works synergistically with group normalization
    """
    def construct(self, x):
        eps = 1e-5

        weight = self.weight
        mean = weight.mean((1, 2, 3), keep_dims=True)
        var = weight.var((1, 2, 3), keepdims=True)
        normalized_weight = (weight - mean) * rsqrt((var + eps))

        output = self.conv2d(x, normalized_weight.astype(x.dtype))
        if self.has_bias:
            output = self.bias_add(output, self.bias)
        return output

class Block(nn.Cell):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1, pad_mode='pad')
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def construct(self, x, scale_shift = None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x

class ResnetBlock(nn.Cell):
    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        self.mlp = nn.SequentialCell(
            nn.SiLU(),
            nn.Dense(time_emb_dim, dim_out * 2)
        ) if exists(time_emb_dim) else None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        self.res_conv = Conv2d(dim, dim_out, 1, pad_mode='valid') if dim != dim_out else Identity()

    def construct(self, x, time_emb = None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = time_emb.expand_dims(-1).expand_dims(-1) 
            scale_shift = time_emb.split(axis=1, output_num=2)
        h = self.block1(x, scale_shift = scale_shift)
        h = self.block2(h)
        h = h + self.res_conv(x)
        return h
    
from mindspore import ops, Parameter
from mindspore.common.initializer import initializer, Normal


class BMM(nn.Cell):
    def __init__(self):
        super().__init__()
        self.bmm = ops.BatchMatMul()

    def construct(self, x, y):
        return self.bmm(x, y)
    


class Attention(nn.Cell):
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.to_qkv = Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias = False)
        self.to_out = Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias = True)
        self.map = ops.Map()
        self.partial = ops.Partial()
        self.bmm = BMM()
        self.is_ascend = mindspore.get_context('device_target') == 'Ascend'

    def construct(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).split(1, 3)
        q, k, v = self.map(self.partial(rearrange, self.heads), qkv)

        q = q * self.scale

        # 'b h d i, b h d j -> b h i j'
        if self.is_ascend:
            sim = (q.expand_dims(-1) * k.expand_dims(-2)).sum(2)
        else:
            sim = self.bmm(q.swapaxes(2, 3), k)
        attn = softmax(sim, axis=-1)
        # 'b h i j, b h d j -> b h i d'
        if self.is_ascend:
            out = (attn.expand_dims(3) * v.expand_dims(2)).sum(-1)
        else:
            out = self.bmm(attn, v.swapaxes(2, 3))
        # out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        out = out.swapaxes(-1, -2).reshape((b, -1, h, w))

        return self.to_out(out)

class LayerNorm(nn.Cell):
    def __init__(self, dim):
        super().__init__()
        self.g = Parameter(initializer('ones', (1, dim, 1, 1)), name='g')

    def construct(self, x):
        eps = 1e-5
        var = x.var(1, keepdims=True)
        mean = x.mean(1, keep_dims=True)
        return (x - mean) * rsqrt((var + eps)) * self.g

class LinearAttention(nn.Cell):
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias = False)

        self.to_out = nn.SequentialCell(
            Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias = True),
            LayerNorm(dim)
        )

        self.map = ops.Map()
        self.partial = ops.Partial()
        self.bmm = BMM()
        self.is_ascend = mindspore.get_context('device_target') == 'Ascend'

    def construct(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).split(1, 3)
        q, k, v = self.map(self.partial(rearrange, self.heads), qkv)

        q = softmax(q, -2)
        k = softmax(k, -1)

        q = q * self.scale
        v = v / (h * w)

        # 'b h d n, b h e n -> b h d e'
        if self.is_ascend:
            context = (k.expand_dims(3) * v.expand_dims(2)).sum(-1)
        else:
            context = self.bmm(k, v.swapaxes(2, 3))

        # 'b h d e, b h d n -> b h e n'
        if self.is_ascend:
            out = (context.expand_dims(-1) * q.expand_dims(-2)).sum(2)
        else:
            out = self.bmm(context.swapaxes(2, 3), q)

        out = out.reshape((b, -1, h, w))
        return self.to_out(out)

class PreNorm(nn.Cell):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def construct(self, x):
        x = self.norm(x)
        return self.fn(x)
    
class Unet(nn.Cell):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        use_convnext=False,
        convnext_mult=2,
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad",has_bias=True)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        
        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.SequentialCell(
                SinusoidalPositionEmbeddings(dim),
                nn.Dense(dim, time_dim),
                nn.GELU(),
                nn.Dense(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # layers
        self.downs = nn.CellList([])
        self.ups = nn.CellList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.CellList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.CellList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else Identity(),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.SequentialCell(
            block_klass(dim, dim), Conv2d(dim, out_dim, 3)
        )

    def construct(self, x, time):
        x = self.init_conv(x)
        r = x.copy()
        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        # downsample
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        len_h = len(h) - 1
        for block1, block2, attn, upsample in self.ups:
            x = ops.concat((x, h[len_h]), 1)
            len_h -= 1
            x = block1(x, t)

            x = ops.concat((x, h[len_h]), 1)
            len_h -= 1
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = ops.concat((x, r), 1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)
    
    
def cosine_beta_schedule(timesteps, s = 0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = np.linspace(0, timesteps, steps).astype(np.float32)
    alphas_cumprod = np.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return np.clip(betas, 0.0001, 0.999)

def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return np.linspace(beta_start, beta_end, timesteps).astype(np.float32)


def quadratic_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return np.linspace(beta_start**0.5, beta_end**0.5, timesteps).astype(np.float32) ** 2

def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return np.sigmoid(betas) * (beta_end - beta_start).astype(np.float32) + beta_start

import numpy as np
from mindspore import Tensor
import mindspore

timesteps = 200

# define beta schedule
betas = linear_beta_schedule(timesteps=timesteps)

# define alphas 
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.pad(alphas_cumprod[:-1], (1, 0), constant_values = 1)

sqrt_recip_alphas = Tensor(np.sqrt(1. / alphas))
sqrt_alphas_cumprod = Tensor(np.sqrt(alphas_cumprod))
sqrt_one_minus_alphas_cumprod = Tensor(np.sqrt(1. - alphas_cumprod))

# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

# def extract(a, t, x_shape):
#     return a[t, None, None, None]

def extract(a, t, x_shape):
#     batch_size = t.shape[0]
#     out = a.gather(-1, t)
#     return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
#     b = t.shape[0]
#     out = a.gather_elements(-1, t)
#     print(type(out))
#     return out.reshape(b, *((1,) * (len(x_shape) - 1)))
    return a[t, None, None, None]

from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image


from mindspore.dataset.vision import Resize, Inter, CenterCrop, ToTensor, RandomHorizontalFlip, Rescale
from download import download
from mindspore.dataset import ImageFolderDataset
from multiprocessing import cpu_count
from mindspore.dataset.vision import RandomHorizontalFlip


image_size = 128
transforms = [
    Resize(image_size, Inter.BILINEAR),
    CenterCrop(image_size),
    ToTensor()
]

# url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# path = download(url, './image_cat/jpg', replace=False)
path = './image_cat'
dataset = ImageFolderDataset(dataset_dir=path, num_parallel_workers=cpu_count(), extensions=['.jpg', '.jpeg', '.png', '.tiff'],
                             num_shards=1, shard_id=0, shuffle=False, decode=True)
dataset = dataset.project('image')
transforms.insert(1, RandomHorizontalFlip())
dataset_1 = dataset.map(transforms, 'image')
dataset_2 = dataset_1.batch(1, drop_remainder=True)
x_start = next(dataset_2.create_tuple_iterator())[0]
print(x_start.shape)


# forward diffusion
from mindspore.ops._primitive_cache import _get_cache_prim

def rsqrt(x):
    rsqrt_op = _get_cache_prim(ops.Rsqrt)()
    return rsqrt_op(x)

def randn_like(x, dtype=None):
    if dtype is None:
        dtype = x.dtype
    normal = _get_cache_prim(ops.StandardNormal)()
    return normal(x.shape).astype(dtype)

def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = randn_like(x_start)
    return (
        extract(sqrt_alphas_cumprod, t, x_start.shape) * x_start +
        extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
    )

def softmax(x, axis=-1):
    if gpu_target:
        softmax_ = _get_cache_prim(ops.Softmax)(axis=axis)
        return softmax_(x)
    exp_ = _get_cache_prim(ops.Exp)()
    reduce_sum_ = _get_cache_prim(ops.ReduceSum)(True)

    x_max = x.max(axis=axis, keepdims=True)
    x_exp = exp_(x - x_max)
    partion = reduce_sum_(x_exp, axis)
    return x_exp / partion


import mindspore.ops as ops


def p_losses(unet_model, x_start, t, noise=None):
    if noise is None:
        noise = randn_like(x_start)
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = unet_model(x_noisy, t)

    loss = nn.SmoothL1Loss()(noise, predicted_noise)

    return loss

from mindspore.dataset import FashionMnistDataset

fashion_mnist_dataset_dir = "./dataset"
# # fashion_mnist_dataset_dir = download(url, fashion_mnist_dataset_dir )
dataset = FashionMnistDataset(dataset_dir=fashion_mnist_dataset_dir, num_parallel_workers=cpu_count(), shuffle=True, 
                            num_shards=1, shard_id=0)
transfroms = [
        RandomHorizontalFlip(),
        ToTensor(),  
        lambda t: (t * 2) - 1
]

dataset = dataset.project('image')
dataset = dataset.shuffle(64)
dataset = dataset.map(transfroms, 'image')
dataset = dataset.batch(128, drop_remainder=True)

x = next(dataset.create_dict_iterator())
print(x.keys())


import numpy as np

@ms_function
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    
    def sqrt(x):
        return ops.sqrt(x.astype(mindspore.float32))

    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = randn_like(x)
        # Algorithm 2 line 4:
        return model_mean + sqrt(posterior_variance_t) * noise 

# Algorithm 2 but save all images:
def randn(shape, dtype=None):
    if dtype is None:
        dtype = mindspore.float32
    normal = _get_cache_prim(ops.StandardNormal)()
    return normal(shape).astype(dtype)

def p_sample_loop(model, shape):
    b = shape[0]
    
    # start from pure noise (for each example in the batch)
    img = randn(shape, dtype=None)
    imgs = []
    
    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, np.full((b,), i))
        imgs.append(img.asnumpy())
    return imgs

def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))


from pathlib import Path

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

results_folder = Path("./results")
results_folder.mkdir(exist_ok = True)
save_and_sample_every = 1000


from mindspore import Tensor, Parameter, context, ms_class
import mindspore.common.dtype as mstype

@ms_class
class LossScaler():
    """
    Basic LossScaler.
    """
    def __init__(self, scale_value):
        super().__init__()
        self.scale_value = Parameter(Tensor(scale_value, dtype=mstype.float32), name="scale_value")
        self.counter = Parameter(Tensor(0, dtype=mstype.int32), name="counter")

    def scale(self, inputs):
        """scale inputs tensor."""
        raise NotImplementedError

    def unscale(self, inputs):
        """unscale inputs tensor."""
        raise NotImplementedError

    def adjust(self, grads_finite):
        """adjust scale value."""
        raise NotImplementedError

class NoLossScaler(LossScaler):
    """
    No LossScaler
    """
    def __init__(self):
        super().__init__(1)

    def scale(self, inputs):
        return inputs

    def unscale(self, inputs):
        return inputs

    def adjust(self, grads_finite):
        return True
    
loss_scaler = NoLossScaler()


###### device = "cuda" if torch.cuda.is_available() else "cpu"

image_size = 24
channels = 1

model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)
print(model.parameters_and_names)


name_list = []
for (name, par) in list(model.parameters_and_names()):
    name_list.append(name)
i = 0
for item in list(model.trainable_params()):
    item.name = name_list[i]
    i+=1
    

optimizer = nn.Adam(model.trainable_params(), learning_rate=1e-3)

# For Loss Scaler
ascend_target = (context.get_context("device_target") == "Ascend")
gpu_target = (context.get_context("device_target") == "GPU")
reciprocal = ops.Reciprocal()

gpu_float_status = ops.FloatStatus()
npu_alloc_float_status = ops.NPUAllocFloatStatus()
npu_clear_float_status = ops.NPUClearFloatStatus()
npu_get_float_status = ops.NPUGetFloatStatus()
if ascend_target:
    status = npu_alloc_float_status()
    _ = npu_clear_float_status(status)
else:
    status = None

hypermap = ops.HyperMap()
partial = ops.Partial()


def grad_unscale(scale, grad):
    """grad unscale."""
    return grad * reciprocal(scale).astype(grad.dtype)

def grad_scale(scale, grad):
    """grad scale."""
    return grad * scale.astype(grad.dtype)

def is_finite(inputs):
    """whether input tensor is finite."""
    if gpu_target:
        return gpu_float_status(inputs)[0] == 0
    status = ops.isfinite(inputs)
    return status.all()

def all_finite(inputs):
    """whether all inputs tensor are finite."""
    if ascend_target:
        status = ops.depend(status, inputs)
        get_status = npu_get_float_status(status)
        status = ops.depend(status, get_status)
        status_finite = status.sum() == 0
        _ = npu_clear_float_status(status)
        return status_finite
    outputs = hypermap(partial(is_finite), inputs)
    return ops.stack(outputs).all()

from mindspore.ops import stop_gradient, GradOperation

grad_func = GradOperation(True, False, False)
grad_cell = GradOperation(False, True, False)


def value_and_grad(fn, pos=None, params=None, has_aux=False):
    if params is None:
        grad_ = grad_func
    else:
        grad_ = grad_cell

    def fn_aux(*args):
        outputs = fn(*args)
        no_grad_outputs = (outputs[0],)
        for out in outputs[1:]:
            no_grad_outputs += (stop_gradient(out),)
        return no_grad_outputs

    if has_aux:
        fn_ = fn_aux
    else:
        fn_ = fn

    def value_and_grad_f(*args):
        values = fn_(*args)
        if params is None:
            grads = grad_(fn_)(*args)
        else:
            grads = grad_(fn_, params)(*args)
        return values, grads
    return value_and_grad_f

def grad(fn, pos=None, params=None, has_aux=False):
    value_and_grad_f = value_and_grad(fn, pos, params, has_aux)
    def grad_f(*args):
        _, g = value_and_grad_f(*args)
        return g
    return grad_f



import mindspore
from mindspore import ms_class, Tensor, Parameter, ops

def clip_grad_norm(grads, max_norm: float, norm_type: float = 2.0):
    if isinstance(grads, mindspore.Tensor):
        grads = [grads]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(grads) == 0:
        return [], mindspore.Tensor(0., mindspore.float32)

    if norm_type == inf:
        norms = [grad.abs().max() for grad in grads]
        total_norm = norms[0] if len(norms) == 1 else ops.max(ops.stack(norms))
    else:
        norms = ()
        for grad in grads:
            norms += (norm(grad, norm_type),)
        total_norm = norm(ops.stack(norms), norm_type)

    clip_coef = ops.div(max_norm, (total_norm + ops.scalar_to_tensor(1e-6, mindspore.float32)))
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory.
    clip_coef_clamped = clip_coef.clip(None, 1.0)
    new_grads = ()
    for grad in grads:
        new_grads += (ops.mul(grad, clip_coef_clamped),)
    return new_grads, total_norm

@ms_class
class Accumulator():
    def __init__(self, optimizer, accumulate_step, total_step=None, clip_norm=1.0):
        # super().__init__()
        self.optimizer = optimizer
        self.clip_norm = clip_norm
        self.inner_grads = optimizer.parameters.clone(prefix="accumulate_", init='zeros')
        self.zeros = optimizer.parameters.clone(prefix="zeros_", init='zeros')
        self.counter = Parameter(Tensor(1, mindspore.int32), 'counter_')
        assert accumulate_step > 0
        self.accumulate_step = accumulate_step
        if total_step is not None:
            assert total_step > accumulate_step and total_step > 0
        self.total_step = total_step
        self.map = ops.Map()
        self.partial = ops.Partial()
    
    def __call__(self, grads):
        success = self.map(self.partial(ops.assign_add), self.inner_grads, grads)
        if self.counter % self.accumulate_step == 0:
            clip_grads, _ = clip_grad_norm(self.inner_grads, self.clip_norm)
            self.optimizer(clip_grads)
            success = self.map(self.partial(ops.assign), self.inner_grads, self.zeros)

        ops.assign_add(self.counter, Tensor(1, mindspore.int32))

        return success

    
from mindspore import Tensor, context
from mindspore.ops._primitive_cache import _get_cache_prim


def randint(low, high, size, dtype=mindspore.int32):
    uniform_int = _get_cache_prim(ops.UniformInt)()
    return uniform_int(size, Tensor(low, dtype), Tensor(high, dtype)).astype(dtype)


def forward_fn(data, t, noise=None):
    loss = p_losses(model, data, t, noise)
    return loss

grad_fn = value_and_grad(forward_fn, None, optimizer.parameters_dict())


def train_step(data, t, noise):
    loss, grads = grad_fn(data, t, noise)
    grads = ops.identity(grads)
    status = all_finite(grads)
    if status:
        loss = loss_scaler.unscale(loss)
        grads = loss_scaler.unscale(grads)
        accumulator = Accumulator(optimizer,1)
        loss = ops.depend(loss, accumulator(grads))
        # grads = ops.clip_by_global_norm(grads, 1.0)
        # loss = ops.depend(loss, optimizer(grads))
    loss_scaler.adjust(status)
    optimizer(grads)
    return loss    

train_step = ms_function(train_step)



(1, 3, 128, 128)
dict_keys(['image'])
<bound method Cell.parameters_and_names of Unet<
  (init_conv): Conv2d<input_channels=1, output_channels=16, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=1, has_bias=True, weight_init=normal, bias_init=zeros, format=NCHW>
  (time_mlp): SequentialCell<
    (0): SinusoidalPositionEmbeddings<>
    (1): Dense<input_channels=24, output_channels=96, has_bias=True>
    (2): GELU<>
    (3): Dense<input_channels=96, output_channels=96, has_bias=True>
    >
  (downs): CellList<
    (0): CellList<
      (0): ResnetBlock<
        (mlp): SequentialCell<
          (0): SiLU<>
          (1): Dense<input_channels=96, output_channels=48, has_bias=True>
          >
        (block1): Block<
          (proj): WeightStandardizedConv2d<input_channels=16, output_channels=24, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=True, weight_init=normal, bias_init=zeros, format=NCHW>
          

In [2]:
epochs = 5

for epoch in range(epochs):
    step = 0
    for _, batch in enumerate(dataset.create_tuple_iterator()):
        batch_size = batch[0].shape[0]

        # Algorithm 1 line 3: sample t uniformally for every example in the batch
        t = randint(0, timesteps, (batch_size,), dtype=mindspore.int32)
        print("epoch:",epoch,"start", "batch", type(batch[0]))
        loss = train_step(batch[0], t, noise=None)

        if step % 1 == 0:
            print("Loss:", loss.item())



#         # save generated images
#         if step != 0 and step % save_and_sample_every == 0:
#             milestone = step // save_and_sample_every
#             batches = num_to_groups(4, batch_size)
#             all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches))
#             all_images = torch.cat(all_images_list, dim=0)
#             all_images = (all_images + 1) * 0.5
#         step+=1
# #         save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = 6)

epoch: 0 start batch <class 'mindspore.common.tensor.Tensor'>


[ERROR] CORE(1060966,7f1e1e202740,python):2022-12-19-09:15:34.965.471 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_1060966/705420399.py]
[ERROR] CORE(1060966,7f1e1e202740,python):2022-12-19-09:15:34.967.763 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_1060966/705420399.py]
[ERROR] CORE(1060966,7f1e1e202740,python):2022-12-19-09:15:34.987.125 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_1060966/705420399.py]
[ERROR] CORE(1060966,7f1e1e202740,python):2022-12-19-09:15:34.989.649 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_1060966/705420399.py]
[ERROR] CORE(1060966,7f1e1e202740,python):2022-12-19-09:15:38.181.119 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_1060966/705420399.py]
[ERROR] CORE(1060966,7f1e1e202740,python):2022-12-19-09:15:38.181

RuntimeError: For 'Conv2D', 'C_in' of input 'x' shape divide by parameter 'group' must be equal to 'C_in' of input 'weight' shape: 48, but got 'C_in' of input 'x' shape: 144, and 'group': 1.

----------------------------------------------------
- The Traceback of Net Construct Code:
----------------------------------------------------
The function call stack (See file '/home/hujingsong/zhangying/diffusion/rank_0/om/analyze_fail.dat' for more details. Get instructions about `analyze_fail.dat` at https://www.mindspore.cn/search?inputValue=analyze_fail.dat):
# 0 In file /tmp/ipykernel_1060966/705420399.py:902
# 1 In file /tmp/ipykernel_1060966/705420399.py:810
# 2 In file /tmp/ipykernel_1060966/705420399.py:813
# 3 In file /tmp/ipykernel_1060966/705420399.py:895
# 4 In file /tmp/ipykernel_1060966/705420399.py:581
# 5 In file /tmp/ipykernel_1060966/705420399.py:895
# 6 In file /tmp/ipykernel_1060966/705420399.py:409
# 7 In file /tmp/ipykernel_1060966/705420399.py:895
# 8 In file /tmp/ipykernel_1060966/705420399.py:409
# 9 In file /tmp/ipykernel_1060966/705420399.py:895
# 10 In file /tmp/ipykernel_1060966/705420399.py:409
# 11 In file /tmp/ipykernel_1060966/705420399.py:895
# 12 In file /tmp/ipykernel_1060966/705420399.py:409
# 13 In file /tmp/ipykernel_1060966/705420399.py:895
# 14 In file /tmp/ipykernel_1060966/705420399.py:424
# 15 In file /tmp/ipykernel_1060966/705420399.py:431
# 16 In file /tmp/ipykernel_1060966/705420399.py:424
# 17 In file /tmp/ipykernel_1060966/705420399.py:198
# 18 In file /tmp/ipykernel_1060966/705420399.py:173
# 19 In file /tmp/ipykernel_1060966/705420399.py:170
# 20 In file /tmp/ipykernel_1060966/705420399.py:158
# 21 In file /tmp/ipykernel_1060966/705420399.py:157

----------------------------------------------------
- C++ Call Stack: (For framework developers)
----------------------------------------------------
mindspore/core/ops/conv2d.cc:214 Conv2dInferShape


<div class="output stream stdout">

    Output:
    ----------------------------------------------------------------------------------------------------
    Loss: 0.46477368474006653
    Loss: 0.12143351882696152
    Loss: 0.08106148988008499
    Loss: 0.0801810547709465
    Loss: 0.06122320517897606
    Loss: 0.06310459971427917
    Loss: 0.05681884288787842
    Loss: 0.05729678273200989
    Loss: 0.05497899278998375
    Loss: 0.04439849033951759
    Loss: 0.05415581166744232
    Loss: 0.06020551547408104
    Loss: 0.046830907464027405
    Loss: 0.051029372960329056
    Loss: 0.0478244312107563
    Loss: 0.046767622232437134
    Loss: 0.04305662214756012
    Loss: 0.05216279625892639
    Loss: 0.04748568311333656
    Loss: 0.05107741802930832
    Loss: 0.04588869959115982
    Loss: 0.043014321476221085
    Loss: 0.046371955424547195
    Loss: 0.04952816292643547
    Loss: 0.04472338408231735

</div>

## Sampling (inference)

To sample from the model, we can just use our sample function defined above:


In [None]:
# sample 64 images
samples = sample(model, image_size=image_size, batch_size=64, channels=channels)

In [None]:
# show a random one
random_index = 5
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")

<img src="https://drive.google.com/uc?id=1ytnzS7IW7ortC6ub85q7nud1IvXe2QTE" width="300" />

Seems like the model is capable of generating a nice T-shirt! Keep in mind that the dataset we trained on is pretty low-resolution (28x28).

We can also create a gif of the denoising process:

In [None]:
import matplotlib.animation as animation

random_index = 53

fig = plt.figure()
ims = []
for i in range(timesteps):
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
animate.save('diffusion.gif')
plt.show()

<img src="https://drive.google.com/uc?id=1eyonQWhfmbQsTq8ndsNjw5QSRQ9em9Au" width="500" />

# Follow-up reads

Note that the DDPM paper showed that diffusion models are a promising direction for (un)conditional image generation. This has since then (immensely) been improved, most notably for text-conditional image generation. Below, we list some important (but far from exhaustive) follow-up works:

- Improved Denoising Diffusion Probabilistic Models ([Nichol et al., 2021](https://arxiv.org/abs/2102.09672)): finds that learning the variance of the conditional distribution (besides the mean) helps in improving performance
- Cascaded Diffusion Models for High Fidelity Image Generation ([Ho et al., 2021](https://arxiv.org/abs/2106.15282)): introduce cascaded diffusion, which comprises a pipeline of multiple diffusion models that generate images of increasing resolution for high-fidelity image synthesis
- Diffusion Models Beat GANs on Image Synthesis ([Dhariwal et al., 2021](https://arxiv.org/abs/2105.05233)): show that diffusion models can achieve image sample quality superior to the current state-of-the-art generative models by improving the U-Net architecture, as well as introducing classifier guidance
- Classifier-Free Diffusion Guidance ([Ho et al., 2021](https://openreview.net/pdf?id=qw8AKxfYbI)): shows that you don't need a classifier for guiding a diffusion model by jointly training a conditional and an unconditional diffusion model with a single neural network
- Hierarchical Text-Conditional Image Generation with CLIP Latents (DALL-E 2) ([Ramesh et al., 2022](https://cdn.openai.com/papers/dall-e-2.pdf)): use a prior to turn a text caption into a CLIP image embedding, after which a diffusion model decodes it into an image
- Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding (ImageGen) ([Saharia et al., 2022](https://arxiv.org/abs/2205.11487)): shows that combining a large pre-trained language model (e.g. T5) with cascaded diffusion works well for text-to-image synthesis

Note that this list only includes important works until the time of writing, which is June 7th, 2022.

For now, it seems that the main (perhaps only) disadvantage of diffusion models is that they require multiple forward passes to generate an image (which is not the case for generative models like GANs). However, there's [research going on](https://arxiv.org/abs/2204.13902) that enables high-fidelity generation in as few as 10 denoising steps.