In [2]:
import torch
import torch.nn as nn

In [22]:
def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        print('pao')
        return val
    return d() if callable(d) else d

In [23]:
class CrossEmbedLayer(nn.Module):
    def __init__(self, dim_in, kernel_sizes, dim_out = None, stride = 2):
        super().__init__()
        assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # calculate the dimension at each scale
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]

        self.convs = nn.ModuleList([])
        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
            self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))

    def forward(self, x):
        fmaps = tuple(map(lambda conv: conv(x), self.convs))
        return torch.cat(fmaps, dim = 1)

In [24]:
CrossEmbedLayer(dim_in=30, kernel_sizes=[8, 6, 10], dim_out = None, stride = 2)

CrossEmbedLayer(
  (convs): ModuleList(
    (0): Conv2d(30, 15, kernel_size=(6, 6), stride=(2, 2), padding=(2, 2))
    (1): Conv2d(30, 7, kernel_size=(8, 8), stride=(2, 2), padding=(3, 3))
    (2): Conv2d(30, 8, kernel_size=(10, 10), stride=(2, 2), padding=(4, 4))
  )
)

In [25]:
cel = CrossEmbedLayer(dim_in=3, kernel_sizes=[6, 8, 10], dim_out = 5, stride = 4)
cel

pao


CrossEmbedLayer(
  (convs): ModuleList(
    (0): Conv2d(3, 2, kernel_size=(6, 6), stride=(4, 4), padding=(1, 1))
    (1): Conv2d(3, 1, kernel_size=(8, 8), stride=(4, 4), padding=(2, 2))
    (2): Conv2d(3, 2, kernel_size=(10, 10), stride=(4, 4), padding=(3, 3))
  )
)

In [21]:
x = torch.rand((1, 3, 250, 250))
y = cel(x)

y.shape

torch.Size([1, 5, 62, 62])

In [26]:
p = [10, 20]
callable(p)

False

In [None]:
class Unet(nn.Module):
    def __init__(
        self,
        *,
        dim,
        image_embed_dim = 1024,
        text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME),
        num_resnet_blocks = 1,
        cond_dim = None,
        num_image_tokens = 4,
        num_time_tokens = 2,
        learned_sinu_pos_emb = True,
        learned_sinu_pos_emb_dim = 16,
        out_dim = None,
        dim_mults=(1, 2, 4, 8),
        cond_images_channels = 0,
        channels = 3,
        channels_out = None,
        attn_dim_head = 64,
        attn_heads = 8,
        ff_mult = 2.,
        lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
        layer_attns = True,
        attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
        layer_cross_attns = True,
        use_linear_attn = False,
        use_linear_cross_attn = False,
        cond_on_text = True,
        max_text_len = 256,
        init_dim = None,
        init_conv_kernel_size = 7,
        resnet_groups = 8,
        init_cross_embed_kernel_sizes = (3, 7, 15),
        cross_embed_downsample = False,
        cross_embed_downsample_kernel_sizes = (2, 4),
        attn_pool_text = True,
        attn_pool_num_latents = 32,
        dropout = 0.,
        memory_efficient = False,
        init_conv_to_final_conv_residual = False,
        use_global_context_attn = True,
        scale_skip_connection = True,
        final_resnet_block = True,
        final_conv_kernel_size = 3,
        bilinear_upsample = False,                   # for debugging checkboard artifacts
        antialias_downsample = False,                # for debugging checkboard artifacts
        downsample_concat_hiddens_earlier = False    # for debugging artifacts in memory efficient unet (allows for one to concat the hiddens a bit earlier, right after the downsample)
    ):
        super().__init__()

        # guide researchers

        assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8'

        if dim < 128:
            print('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/')

        # save locals to take care of some hyperparameters for cascading DDPM

        self._locals = locals()
        self._locals.pop('self', None)
        self._locals.pop('__class__', None)

        # for eventual cascading diffusion

        self.lowres_cond = lowres_cond


        # determine dimensions

        self.channels = channels
        self.channels_out = default(channels_out, channels)

        init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
        init_dim = default(init_dim, dim)

        # optional image conditioning

        self.has_cond_image = cond_images_channels > 0
        self.cond_images_channels = cond_images_channels

        init_channels += cond_images_channels

        # initial convolution

        self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # time conditioning

        cond_dim = default(cond_dim, dim)
        time_cond_dim = dim * 4 * (2 if lowres_cond else 1)

        # embedding time for discrete gaussian diffusion or log(snr) noise for continuous version

        self.learned_sinu_pos_emb = learned_sinu_pos_emb

        if learned_sinu_pos_emb:
            sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim)
            sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1
        else:
            sinu_pos_emb = SinusoidalPosEmb(dim)
            sinu_pos_emb_input_dim = dim

        self.to_time_hiddens = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(sinu_pos_emb_input_dim, time_cond_dim),
            nn.SiLU()
        )

        self.to_time_cond = nn.Sequential(
            nn.Linear(time_cond_dim, time_cond_dim)
        )

        # project to time tokens as well as time hiddens

        self.to_time_tokens = nn.Sequential(
            nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
            Rearrange('b (r d) -> b r d', r = num_time_tokens)
        )

        # low res aug noise conditioning

        self.lowres_cond = lowres_cond

        if lowres_cond:
            self.to_lowres_time_hiddens = nn.Sequential(
                LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim),
                nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim),
                nn.SiLU()
            )

            self.to_lowres_time_cond = nn.Sequential(
                nn.Linear(time_cond_dim, time_cond_dim)
            )

            self.to_lowres_time_tokens = nn.Sequential(
                nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
                Rearrange('b (r d) -> b r d', r = num_time_tokens)
            )

        # normalizations

        self.norm_cond = nn.LayerNorm(cond_dim)

        # text encoding conditioning (optional)

        self.text_to_cond = None

        if cond_on_text:
            assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True'
            self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)

        # finer control over whether to condition on text encodings

        self.cond_on_text = cond_on_text

        # attention pooling

        self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, dim_head = attn_dim_head, heads = attn_heads, num_latents = attn_pool_num_latents) if attn_pool_text else None

        # for classifier free guidance

        self.max_text_len = max_text_len

        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
        self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim))

        # for non-attention based text conditioning at all points in the network where time is also conditioned

        self.to_text_non_attn_cond = None

        if cond_on_text:
            self.to_text_non_attn_cond = nn.Sequential(
                nn.LayerNorm(cond_dim),
                nn.Linear(cond_dim, time_cond_dim),
                nn.SiLU(),
                nn.Linear(time_cond_dim, time_cond_dim)
            )

        # attention related params

        attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)

        num_layers = len(in_out)

        # resnet block klass

        num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
        resnet_groups = cast_tuple(resnet_groups, num_layers)

        layer_attns = cast_tuple(layer_attns, num_layers)
        layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)

        assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])

        # downsample klass

        downsample_klass = DownsampleWithBlur if antialias_downsample else Downsample

        if cross_embed_downsample:
            downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)

        # scale for resnet skip connections

        self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)

        # for debugging purposes

        self.downsample_concat_hiddens_earlier = downsample_concat_hiddens_earlier

        # layers

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_cross_attns]
        reversed_layer_params = list(map(reversed, layer_params))

        # downsampling layers

        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_cross_attn) in enumerate(zip(in_out, *layer_params)):
            is_last = ind >= (num_resolutions - 1)

            layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
            layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None

            transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else nn.Identity)

            self.downs.append(nn.ModuleList([
                downsample_klass(dim_in) if memory_efficient else None,
                ResnetBlock(dim_in, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
                nn.ModuleList([ResnetBlock(dim_out, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                transformer_block_klass(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, ff_mult = ff_mult),
                downsample_klass(dim_out) if not memory_efficient and not is_last else None,
            ]))

        # middle layers

        mid_dim = dims[-1]

        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])

        # upsampling klass

        upsample_klass = BilinearUpsample if bilinear_upsample else Upsample

        # upsampling layers

        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
            is_last = ind == (len(in_out) - 1)
            layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
            layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
            transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else nn.Identity)

            self.ups.append(nn.ModuleList([
                ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
                nn.ModuleList([ResnetBlock(dim_in, dim_in, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                transformer_block_klass(dim = dim_in, heads = attn_heads, dim_head = attn_dim_head, ff_mult = ff_mult),
                upsample_klass(dim_in) if not is_last or memory_efficient else nn.Identity()
            ]))

        # whether to do a final residual from initial conv to the final resnet block out

        self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual
        final_conv_dim = dim * (2 if init_conv_to_final_conv_residual else 1)

        # final optional resnet block and convolution out

        self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None

        final_conv_dim_in = dim if final_resnet_block else final_conv_dim
        self.final_conv = nn.Conv2d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2)

    # if the current settings for the unet are not correct
    # for cascading DDPM, then reinit the unet with the right settings

In [None]:

    def cast_model_parameters(
        self,
        *,
        lowres_cond,
        text_embed_dim,
        channels,
        channels_out,
        cond_on_text,
        learned_sinu_pos_emb
    ):
        if lowres_cond == self.lowres_cond and \
            channels == self.channels and \
            cond_on_text == self.cond_on_text and \
            text_embed_dim == self._locals['text_embed_dim'] and \
            learned_sinu_pos_emb == self.learned_sinu_pos_emb and \
            channels_out == self.channels_out:
            return self

        updated_kwargs = dict(
            lowres_cond = lowres_cond,
            text_embed_dim = text_embed_dim,
            channels = channels,
            channels_out = channels_out,
            cond_on_text = cond_on_text,
            learned_sinu_pos_emb = learned_sinu_pos_emb
        )

        return self.__class__(**{**self._locals, **updated_kwargs})

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        logits = self.forward(*args, **kwargs)

        if cond_scale == 1:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

In [None]:
    def forward(
        self,
        x,
        time,
        *,
        lowres_cond_img = None,
        lowres_noise_times = None,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        cond_drop_prob = 0.
    ):
        batch_size, device = x.shape[0], x.device

        # add low resolution conditioning, if present

        assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
        assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present'

        if exists(lowres_cond_img):
            x = torch.cat((x, lowres_cond_img), dim = 1)

        # condition on input image

        assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa'

        if exists(cond_images):
            assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet'
            cond_images = resize_image_to(cond_images, x.shape[-1])
            x = torch.cat((cond_images, x), dim = 1)

        # initial convolution

        x = self.init_conv(x)

        # init conv residual

        if self.init_conv_to_final_conv_residual:
            init_conv_residual = x.clone()

        # time conditioning

        time_hiddens = self.to_time_hiddens(time)

        # derive time tokens

        time_tokens = self.to_time_tokens(time_hiddens)
        t = self.to_time_cond(time_hiddens)

        # add lowres time conditioning to time hiddens
        # and add lowres time tokens along sequence dimension for attention

        if self.lowres_cond:
            lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times)
            lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens)
            lowres_t = self.to_lowres_time_cond(lowres_time_hiddens)

            t = t + lowres_t
            time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2)

        # text conditioning

        text_tokens = None

        if exists(text_embeds) and self.cond_on_text:

            # conditional dropout

            text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)

            text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1')
            text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1')

            # calculate text embeds

            text_tokens = self.text_to_cond(text_embeds)

            text_tokens = text_tokens[:, :self.max_text_len]

            text_tokens_len = text_tokens.shape[1]
            remainder = self.max_text_len - text_tokens_len

            if remainder > 0:
                text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))

            if exists(text_mask):
                if remainder > 0:
                    text_mask = F.pad(text_mask, (0, remainder), value = False)

                text_mask = rearrange(text_mask, 'b n -> b n 1')
                text_keep_mask_embed = text_mask & text_keep_mask_embed

            null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working

            text_tokens = torch.where(
                text_keep_mask_embed,
                text_tokens,
                null_text_embed
            )

            if exists(self.attn_pool):
                text_tokens = self.attn_pool(text_tokens)

            # extra non-attention conditioning by projecting and then summing text embeddings to time
            # termed as text hiddens

            mean_pooled_text_tokens = text_tokens.mean(dim = -2)

            text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens)

            null_text_hidden = self.null_text_hidden.to(t.dtype)

            text_hiddens = torch.where(
                text_keep_mask_hidden,
                text_hiddens,
                null_text_hidden
            )

            t = t + text_hiddens

        # main conditioning tokens (c)

        c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2)

        # normalize conditioning tokens

        c = self.norm_cond(c)

        # go through the layers of the unet, down and up

        hiddens = []

        for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs:
            if exists(pre_downsample):
                x = pre_downsample(x)

            x = init_block(x, t, c)

            if self.downsample_concat_hiddens_earlier:
                hiddens.append(x)

            for resnet_block in resnet_blocks:
                x = resnet_block(x, t)

            x = attn_block(x)

            if not self.downsample_concat_hiddens_earlier:
                hiddens.append(x)

            if exists(post_downsample):
                x = post_downsample(x)

        x = self.mid_block1(x, t, c)

        if exists(self.mid_attn):
            x = self.mid_attn(x)

        x = self.mid_block2(x, t, c)

        for init_block, resnet_blocks, attn_block, upsample in self.ups:

            skip_connection = hiddens.pop() * self.skip_connect_scale

            x = torch.cat((x, skip_connection), dim = 1)
            x = init_block(x, t, c)

            for resnet_block in resnet_blocks:
                x = resnet_block(x, t)

            x = attn_block(x)
            x = upsample(x)

        # final top-most residual if needed

        if self.init_conv_to_final_conv_residual:
            x = torch.cat((x, init_conv_residual), dim = 1)

        if exists(self.final_res_block):
            x = self.final_res_block(x, t)

        return self.final_conv(x)

# predefined unets, with configs lining up with hyperparameters in appendix of paper

In [37]:
BS = BetaSchedule(10)
BS.cosine()

tensor([0.0279, 0.0755, 0.1244, 0.1772, 0.2373, 0.3099, 0.4040, 0.5370, 0.7438,
        0.9990], dtype=torch.float64)

In [39]:
BS.linear()

tensor([0.0100, 0.2311, 0.4522, 0.6733, 0.8944, 1.1156, 1.3367, 1.5578, 1.7789,
        2.0000], dtype=torch.float64)

In [40]:
linear_beta_schedule(10)

tensor([0.0100, 0.2311, 0.4522, 0.6733, 0.8944, 1.1156, 1.3367, 1.5578, 1.7789,
        2.0000], dtype=torch.float64)

In [10]:
import torch
def cosine_beta_schedule(timesteps, s = 0.008, thres = 0.999):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    
    x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
    
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    
    return torch.clip(betas, 0, thres)

def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)

In [14]:
torch.linspace(0, 10, 11, dtype = torch.float64)

tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.],
       dtype=torch.float64)

In [9]:
cosine_beta_schedule(100, s = 0.008, thres = 0.999)

tensor([6.3128e-04, 1.1169e-03, 1.6029e-03, 2.0894e-03, 2.5767e-03, 3.0650e-03,
        3.5546e-03, 4.0456e-03, 4.5385e-03, 5.0333e-03, 5.5304e-03, 6.0300e-03,
        6.5323e-03, 7.0378e-03, 7.5465e-03, 8.0589e-03, 8.5751e-03, 9.0955e-03,
        9.6205e-03, 1.0150e-02, 1.0685e-02, 1.1226e-02, 1.1772e-02, 1.2324e-02,
        1.2883e-02, 1.3449e-02, 1.4023e-02, 1.4604e-02, 1.5193e-02, 1.5791e-02,
        1.6399e-02, 1.7016e-02, 1.7643e-02, 1.8282e-02, 1.8931e-02, 1.9593e-02,
        2.0268e-02, 2.0956e-02, 2.1658e-02, 2.2376e-02, 2.3109e-02, 2.3859e-02,
        2.4627e-02, 2.5413e-02, 2.6219e-02, 2.7047e-02, 2.7897e-02, 2.8770e-02,
        2.9668e-02, 3.0593e-02, 3.1546e-02, 3.2530e-02, 3.3545e-02, 3.4594e-02,
        3.5680e-02, 3.6805e-02, 3.7971e-02, 3.9182e-02, 4.0441e-02, 4.1751e-02,
        4.3116e-02, 4.4541e-02, 4.6030e-02, 4.7588e-02, 4.9221e-02, 5.0936e-02,
        5.2740e-02, 5.4640e-02, 5.6646e-02, 5.8768e-02, 6.1018e-02, 6.3407e-02,
        6.5952e-02, 6.8669e-02, 7.1578e-

In [32]:
class BetaSchedule():
    
    def __init__(self, timestep):
        self.timestep = timestep
    
    def cosine(self, s = 0.008, threshold = 0.999): # https://openreview.net/forum?id=-NEXDKk8gZ
                
        x = torch.linspace(0, self.timestep, self.timestep+1, dtype=torch.float64)
        f = ((x / self.timestep) + s) / (1 + s)
        w = 0.5 * torch.pi * f
        
        a = torch.cos(w)**2
        a = a / a[0]
        
        b = 1 - (a[1:]/a[:-1])
        
        return torch.clip(b, 0, threshold)
        
    def linear(self):

        b_start = 0.1  / self.timestep
        b_end   = 20.0 / self.timestep
    
        return torch.linspace(b_start, b_end, self.timestep, dtype=torch.float64)

In [13]:
linear_beta_schedule(4)

tensor([0.0250, 1.6833, 3.3417, 5.0000], dtype=torch.float64)

In [45]:
GaussianDiffusion(noise_type='linear', timesteps=4)

tensor([0.0250, 1.6833, 3.3417, 5.0000], dtype=torch.float64)


GaussianDiffusion()

In [38]:
cosine_beta_schedule(10)

tensor([0.0279, 0.0755, 0.1244, 0.1772, 0.2373, 0.3099, 0.4040, 0.5370, 0.7438,
        0.9990], dtype=torch.float64)

In [46]:
GaussianDiffusion(timesteps=10)

tensor([0.0279, 0.0755, 0.1244, 0.1772, 0.2373, 0.3099, 0.4040, 0.5370, 0.7438,
        0.9990], dtype=torch.float64)


GaussianDiffusion()

In [72]:
batch_size = 10
(batch_size)

10

In [69]:
import torch.nn as nn
import torch.nn.functional as F

class GaussianDiffusion(nn.Module):
    
    def __init__(self, noise_type='cosine', timesteps=1000):
        
        super(GaussianDiffusion, self).__init__()
        
        self.timesteps = timesteps
        
        BS = BetaSchedule(timesteps)
        
        betas  = BS.cosine() if 'cosine'==noise_type else BS.linear()        
        alphas = 1 - betas
        
        alphas_cumprod      = torch.cumprod(alphas, axis = 0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1)
        
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32), persistent = False)
        
    

In [None]:
class GaussianDiffusionContinuousTimes(nn.Module):
    def __init__(self, *, noise_schedule, timesteps):
        super().__init__()
        if noise_schedule == 'linear':
            self.log_snr = beta_linear_log_snr
        elif noise_schedule == "cosine":
            self.log_snr = alpha_cosine_log_snr
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        self.num_timesteps = timesteps

    def get_times(self, batch_size, noise_level, *, device):
        return torch.full((batch_size,), noise_level, device = device, dtype = torch.long)

    def sample_random_times(self, batch_size, max_thres = 0.999, *, device):
        return torch.zeros((batch_size,), device = device).float().uniform_(0, max_thres)

    def get_condition(self, times):
        return maybe(self.log_snr)(times)

    def get_sampling_timesteps(self, batch, *, device):
        times = torch.linspace(1., 0., self.num_timesteps + 1, device = device)
        times = repeat(times, 't -> b t', b = batch)
        times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
        times = times.unbind(dim = -1)
        return times

    def q_posterior(self, x_start, x_t, t, *, t_next = None):
        t_next = default(t_next, lambda: (t - 1. / self.num_timesteps).clamp(min = 0.))

        """ https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material """
        log_snr = self.log_snr(t)
        log_snr_next = self.log_snr(t_next)
        log_snr, log_snr_next = map(partial(right_pad_dims_to, x_t), (log_snr, log_snr_next))

        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

        # c - as defined near eq 33
        c = -expm1(log_snr - log_snr_next)
        posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)

        # following (eq. 33)
        posterior_variance = (sigma_next ** 2) * c
        posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def q_sample(self, x_start, t, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        log_snr = self.log_snr(t)
        log_snr_padded_dim = right_pad_dims_to(x_start, log_snr)
        alpha, sigma =  log_snr_to_alpha_sigma(log_snr_padded_dim)
        return alpha * x_start + sigma * noise, log_snr

    def predict_start_from_noise(self, x_t, t, noise):
        log_snr = self.log_snr(t)
        log_snr = right_pad_dims_to(x_t, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        return (x_t - sigma * noise) / alpha.clamp(min = 1e-8)

In [70]:
GaussianDiffusion(timesteps=10)

10
10


GaussianDiffusion()

In [68]:
GaussianDiffusion(timesteps=10)

tensor([1.0000, 0.9721, 0.8987, 0.7869, 0.6475, 0.4938, 0.3408, 0.2031, 0.0940,
        0.0241], dtype=torch.float64)


GaussianDiffusion()

In [50]:
register_buffer = lambda name, val: register_buffer(name, val.to(torch.float32), persistent = False)

register_buffer('betas', torch.tensor(41))
register_buffer('alphas_cumprod', torch.tensor(51))

TypeError: <lambda>() got an unexpected keyword argument 'persistent'

In [None]:
class GaussianDiffusion(nn.Module):
    def __init__(self, *, noise_schedule, timesteps):
        super().__init__()

        if noise_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        elif noise_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        else:
            raise NotImplementedError()

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis = 0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # register buffer helper function to cast double back to float

        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32), persistent = False)

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others

        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

        register_buffer('posterior_log_variance_clipped', log(posterior_variance, eps = 1e-20))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

    def get_times(self, batch_size, noise_level, *, device):
        return torch.full((batch_size,), int(self.num_timesteps * noise_level), device = device, dtype = torch.long)

    def sample_random_times(self, batch_size, *, device):
        return torch.randint(0, self.num_timesteps, (batch_size,), device = device, dtype = torch.long)

    def get_condition(self, times):
        return times

    def get_sampling_timesteps(self, batch, *, device):
        time_transitions = []

        for i in reversed(range(self.num_timesteps)):
            time_transitions.append((torch.full((batch,), i, device = device, dtype = torch.long), None))

        return time_transitions

    def q_posterior(self, x_start, x_t, t, **kwargs):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def q_sample(self, x_start, t, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        noised = (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

        alphas_cumprod = extract(self.alphas_cumprod, t, t.shape)
        log_snr = -log(1. / alphas_cumprod - 1)

        return noised, log_snr

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )


In [3]:
def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def cast_tuple(val, length = None):
    if isinstance(val, list):
        val = tuple(val)

    output = val if isinstance(val, tuple) else ((val,) * default(length, 1))

    if exists(length):
        assert len(output) == length

    return output

In [4]:
cast_tuple(1000, 2)

(1000, 1000)

In [None]:
class GaussianDiffusionContinuousTimes(nn.Module):
    def __init__(self, *, noise_schedule, timesteps):
        super().__init__()
        if noise_schedule == 'linear':
            self.log_snr = beta_linear_log_snr
        elif noise_schedule == "cosine":
            self.log_snr = alpha_cosine_log_snr
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        self.num_timesteps = timesteps

    def get_times(self, batch_size, noise_level, *, device):
        return torch.full((batch_size,), noise_level, device = device, dtype = torch.long)

    def sample_random_times(self, batch_size, max_thres = 0.999, *, device):
        return torch.zeros((batch_size,), device = device).float().uniform_(0, max_thres)

    def get_condition(self, times):
        return maybe(self.log_snr)(times)

    def get_sampling_timesteps(self, batch, *, device):
        times = torch.linspace(1., 0., self.num_timesteps + 1, device = device)
        times = repeat(times, 't -> b t', b = batch)
        times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
        times = times.unbind(dim = -1)
        return times

    def q_posterior(self, x_start, x_t, t, *, t_next = None):
        t_next = default(t_next, lambda: (t - 1. / self.num_timesteps).clamp(min = 0.))

        """ https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material """
        log_snr = self.log_snr(t)
        log_snr_next = self.log_snr(t_next)
        log_snr, log_snr_next = map(partial(right_pad_dims_to, x_t), (log_snr, log_snr_next))

        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

        # c - as defined near eq 33
        c = -expm1(log_snr - log_snr_next)
        posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)

        # following (eq. 33)
        posterior_variance = (sigma_next ** 2) * c
        posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def q_sample(self, x_start, t, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        log_snr = self.log_snr(t)
        log_snr_padded_dim = right_pad_dims_to(x_start, log_snr)
        alpha, sigma =  log_snr_to_alpha_sigma(log_snr_padded_dim)
        return alpha * x_start + sigma * noise, log_snr

    def predict_start_from_noise(self, x_t, t, noise):
        log_snr = self.log_snr(t)
        log_snr = right_pad_dims_to(x_t, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        return (x_t - sigma * noise) / alpha.clamp(min = 1e-8)

In [4]:
import torch
from torch.special import expm1
import math

In [46]:
class BetaSchedule():
    
    def __init__(self):
        pass
    
    def cosine(self, t, s: float = 0.008):
                
        f = (t + s) / (1 + s)
        w = f * math.pi * 0.5
        c = (torch.cos(w) ** -2) - 1

        return -torch.log(c.clamp(min = 1e-5))
        
    def linear(self, t):
        
        xi = 1e-4 + 10 * (t ** 2)
        y  = torch.exp(xi) - 1
    
        return -torch.log(y)

In [49]:
BS = BetaSchedule()
BS.linear(torch.tensor(0.1))

tensor(2.2511)

In [51]:
BS = BetaSchedule()
BS.cosine(torch.tensor(0.1))

tensor(3.5450)

In [6]:
@torch.jit.script
def beta_linear_log_snr(t):
    return -torch.log(expm1())

In [8]:
def log(t, eps: float = 1e-12):
    return torch.log(t.clamp(min = eps))

@torch.jit.script
def alpha_cosine_log_snr(t, s: float = 0.008):
    return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) # not sure if this accounts for beta being clipped to 0.999 in discrete version

In [50]:
beta_linear_log_snr(torch.tensor(0.1))

tensor(2.2511)

In [52]:
alpha_cosine_log_snr(torch.tensor(0.1))

tensor(3.5450)

In [21]:
alpha_cosine_log_snr(torch.tensor(11))

tensor(-4.1539)

In [44]:
def teste():
    
    

teste(torch.tensor(11))

tensor(-4.1539)

In [None]:
def q_posterior(self, x_start, x_t, t, *, t_next = None):
        t_next = default(t_next, lambda: (t - 1. / self.num_timesteps).clamp(min = 0.))

        """ https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material """
        log_snr = self.log_snr(t)
        log_snr_next = self.log_snr(t_next)
        log_snr, log_snr_next = map(partial(right_pad_dims_to, x_t), (log_snr, log_snr_next))

        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

        # c - as defined near eq 33
        c = -expm1(log_snr - log_snr_next)
        posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)

        # following (eq. 33)
        posterior_variance = (sigma_next ** 2) * c
        posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

In [25]:
from tqdm import tqdm
timesteps = [1000, 20, 30]

for times, times_next in tqdm(timesteps, desc = 'sampling loop time step', total = len(timesteps)):
    print(times, times_next)
    break

sampling loop time step:   0%|          | 0/3 [00:00<?, ?it/s]


TypeError: cannot unpack non-iterable int object

In [65]:
torch.full((4,), 3)

tensor([3, 3, 3, 3])

In [79]:
times = torch.linspace(1., 0., 10 + 1)

for i in range(len(times)-1):
    
    j = torch.stack((torch.full((3,), times[i]),torch.full((3,), times[i+1])), dim = 0)
    # print(j)
    
j = tuple([torch.stack((torch.full((3,), times[i]),torch.full((3,), times[i+1])), dim = 0) for i in range(len(times)-1)])
j

(tensor([[1.0000, 1.0000, 1.0000],
         [0.9000, 0.9000, 0.9000]]),
 tensor([[0.9000, 0.9000, 0.9000],
         [0.8000, 0.8000, 0.8000]]),
 tensor([[0.8000, 0.8000, 0.8000],
         [0.7000, 0.7000, 0.7000]]),
 tensor([[0.7000, 0.7000, 0.7000],
         [0.6000, 0.6000, 0.6000]]),
 tensor([[0.6000, 0.6000, 0.6000],
         [0.5000, 0.5000, 0.5000]]),
 tensor([[0.5000, 0.5000, 0.5000],
         [0.4000, 0.4000, 0.4000]]),
 tensor([[0.4000, 0.4000, 0.4000],
         [0.3000, 0.3000, 0.3000]]),
 tensor([[0.3000, 0.3000, 0.3000],
         [0.2000, 0.2000, 0.2000]]),
 tensor([[0.2000, 0.2000, 0.2000],
         [0.1000, 0.1000, 0.1000]]),
 tensor([[0.1000, 0.1000, 0.1000],
         [0.0000, 0.0000, 0.0000]]))

In [82]:
for i in range(10):
    print(torch.eq(j[i], t[i]))

tensor([[True, True, True],
        [True, True, True]])
tensor([[True, True, True],
        [True, True, True]])
tensor([[True, True, True],
        [True, True, True]])
tensor([[True, True, True],
        [True, True, True]])
tensor([[True, True, True],
        [True, True, True]])
tensor([[True, True, True],
        [True, True, True]])
tensor([[True, True, True],
        [True, True, True]])
tensor([[True, True, True],
        [True, True, True]])
tensor([[True, True, True],
        [True, True, True]])
tensor([[True, True, True],
        [True, True, True]])


In [66]:
from einops import rearrange, repeat, reduce
def get_sampling_timesteps(batch):
    times = torch.linspace(1., 0., 10 + 1)
    print(times, '\n')
    times = repeat(times, 't -> b t', b = batch)
    print(times, '\n')
    times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
    print(times, '\n')
    times = times.unbind(dim = -1)
    return times

t = get_sampling_timesteps(3)
t

tensor([1.0000, 0.9000, 0.8000, 0.7000, 0.6000, 0.5000, 0.4000, 0.3000, 0.2000,
        0.1000, 0.0000]) 

tensor([[1.0000, 0.9000, 0.8000, 0.7000, 0.6000, 0.5000, 0.4000, 0.3000, 0.2000,
         0.1000, 0.0000],
        [1.0000, 0.9000, 0.8000, 0.7000, 0.6000, 0.5000, 0.4000, 0.3000, 0.2000,
         0.1000, 0.0000],
        [1.0000, 0.9000, 0.8000, 0.7000, 0.6000, 0.5000, 0.4000, 0.3000, 0.2000,
         0.1000, 0.0000]]) 

tensor([[[1.0000, 0.9000, 0.8000, 0.7000, 0.6000, 0.5000, 0.4000, 0.3000,
          0.2000, 0.1000],
         [1.0000, 0.9000, 0.8000, 0.7000, 0.6000, 0.5000, 0.4000, 0.3000,
          0.2000, 0.1000],
         [1.0000, 0.9000, 0.8000, 0.7000, 0.6000, 0.5000, 0.4000, 0.3000,
          0.2000, 0.1000]],

        [[0.9000, 0.8000, 0.7000, 0.6000, 0.5000, 0.4000, 0.3000, 0.2000,
          0.1000, 0.0000],
         [0.9000, 0.8000, 0.7000, 0.6000, 0.5000, 0.4000, 0.3000, 0.2000,
          0.1000, 0.0000],
         [0.9000, 0.8000, 0.7000, 0.6000, 0.5000, 0.4000, 0.30

(tensor([[1.0000, 1.0000, 1.0000],
         [0.9000, 0.9000, 0.9000]]),
 tensor([[0.9000, 0.9000, 0.9000],
         [0.8000, 0.8000, 0.8000]]),
 tensor([[0.8000, 0.8000, 0.8000],
         [0.7000, 0.7000, 0.7000]]),
 tensor([[0.7000, 0.7000, 0.7000],
         [0.6000, 0.6000, 0.6000]]),
 tensor([[0.6000, 0.6000, 0.6000],
         [0.5000, 0.5000, 0.5000]]),
 tensor([[0.5000, 0.5000, 0.5000],
         [0.4000, 0.4000, 0.4000]]),
 tensor([[0.4000, 0.4000, 0.4000],
         [0.3000, 0.3000, 0.3000]]),
 tensor([[0.3000, 0.3000, 0.3000],
         [0.2000, 0.2000, 0.2000]]),
 tensor([[0.2000, 0.2000, 0.2000],
         [0.1000, 0.1000, 0.1000]]),
 tensor([[0.1000, 0.1000, 0.1000],
         [0.0000, 0.0000, 0.0000]]))

In [42]:
t[0]

tensor([[1.0000, 1.0000, 1.0000],
        [0.9000, 0.9000, 0.9000]])

In [35]:
t[1]

tensor([[0.9990, 0.9990, 0.9990, 0.9990, 0.9990, 0.9990, 0.9990, 0.9990, 0.9990,
         0.9990],
        [0.9980, 0.9980, 0.9980, 0.9980, 0.9980, 0.9980, 0.9980, 0.9980, 0.9980,
         0.9980]])

In [84]:
import torch.nn as nn
import torch.nn.functional as F

def t_equal_x_dim(x, t):
    padding_dims = x.ndim - t.ndim
    
    if padding_dims <= 0:
        return t
    return t.view(*t.shape, *((1,) * padding_dims))

class GaussianDiffusion(nn.Module):
    
    def __init__(self, noise_type='cosine', timesteps=1000, device='cpu'):
        
        super(GaussianDiffusion, self).__init__()
        
        BS = BetaSchedule()   
        
        self.timesteps = timesteps      
        self.snr_func  = BS.cosine if 'cosine'==noise_type else BS.linear    
        self.device    = device
        
    def get_times(self, batch_size, noise_level):
        return torch.full((batch_size,), noise_level, device=self.device, dtype=torch.long)
    
    def sample_random_times(self, batch_size, max_threshold = 0.999):
        return torch.zeros((batch_size,), device=self.device).float().uniform_(0, max_threshold)
    
    def get_condition(self, times):
        return self.beta(times)
    
    def get_sampling_timesteps(self, batch):
        times = torch.linspace(1., 0., self.timesteps + 1, device = self.device)
        return tuple([torch.stack((torch.full((batch,), times[i]),torch.full((batch,), times[i+1])), dim = 0) for i in range(len(times)-1)])
    
    def snr_to_alpha_sigma(self, log_snr):
        return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr))

    def q_posterior(self, x_start, x_t, t, t_next):

        snr      = self.snr_func(t)
        snr_next = self.snr_func(t_next)

        log_snr, log_snr_next = map(partial(t_equal_x_dim, x_t), (log_snr, log_snr_next))

        alpha, sigma           = log_snr_to_alpha_sigma(log_snr)
        alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

        c      = 1 - torch.exp(log_snr - log_snr_next)
        p_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)

        p_variance     = (sigma_next ** 2) * c
        p_log_variance = log(p_variance, eps = 1e-20)
        p_log_variance = torch.log(p_variance.clamp(min = 1e-20))

        return p_mean, p_variance, p_log_variance
    
    def q_sample(self, x_start, t, noise):
        
        snr = self.snr_func(t)
        
        snr          = t_equal_x_dim(x_start, snr)
        alpha, sigma = log_snr_to_alpha_sigma(snr)
        
        return alpha * x_start + sigma * noise, snr
    
    def predict_start_from_noise(self, x_t, t, noise):
        
        snr = self.snr_func(t)
        
        snr          = t_equal_x_dim(x_t, snr)
        alpha, sigma = log_snr_to_alpha_sigma(snr)
        
        return (x_t - sigma * noise) / alpha.clamp(min = 1e-8)

In [85]:
g = GaussianDiffusion(noise_type='cosine', timesteps=10, device='cpu')
g.get_sampling_timesteps(3)

(tensor([[1.0000, 1.0000, 1.0000],
         [0.9000, 0.9000, 0.9000]]),
 tensor([[0.9000, 0.9000, 0.9000],
         [0.8000, 0.8000, 0.8000]]),
 tensor([[0.8000, 0.8000, 0.8000],
         [0.7000, 0.7000, 0.7000]]),
 tensor([[0.7000, 0.7000, 0.7000],
         [0.6000, 0.6000, 0.6000]]),
 tensor([[0.6000, 0.6000, 0.6000],
         [0.5000, 0.5000, 0.5000]]),
 tensor([[0.5000, 0.5000, 0.5000],
         [0.4000, 0.4000, 0.4000]]),
 tensor([[0.4000, 0.4000, 0.4000],
         [0.3000, 0.3000, 0.3000]]),
 tensor([[0.3000, 0.3000, 0.3000],
         [0.2000, 0.2000, 0.2000]]),
 tensor([[0.2000, 0.2000, 0.2000],
         [0.1000, 0.1000, 0.1000]]),
 tensor([[0.1000, 0.1000, 0.1000],
         [0.0000, 0.0000, 0.0000]]))

In [55]:


sample_random_times(10)

tensor([0.0015, 0.7292, 0.6209, 0.3540, 0.8060, 0.5618, 0.0910, 0.8585, 0.0087,
        0.6487])

In [63]:
def exists(val):
    return val is not None

def maybe(fn):
    @wraps(fn)
    def inner(x):
        if not exists(x):
            return x
        return fn(x)
    return inner

In [57]:
import torch.nn as nn
class GaussianDiffusionContinuousTimes(nn.Module):
    def __init__(self, *, noise_schedule, timesteps):
        super().__init__()
        if noise_schedule == 'linear':
            self.log_snr = beta_linear_log_snr
        elif noise_schedule == "cosine":
            self.log_snr = alpha_cosine_log_snr
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        self.num_timesteps = timesteps

    def get_times(self, batch_size, noise_level, *, device):
        return torch.full((batch_size,), noise_level, device = device, dtype = torch.long)

    def sample_random_times(self, batch_size, max_thres = 0.999, *, device):
        return torch.zeros((batch_size,), device = device).float().uniform_(0, max_thres)

    def get_condition(self, times):
        return maybe(self.log_snr)(times)

    def get_sampling_timesteps(self, batch, *, device):
        times = torch.linspace(1., 0., self.num_timesteps + 1, device = device)
        times = repeat(times, 't -> b t', b = batch)
        times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
        times = times.unbind(dim = -1)
        return times

    def q_posterior(self, x_start, x_t, t, *, t_next = None):
        t_next = default(t_next, lambda: (t - 1. / self.num_timesteps).clamp(min = 0.))

        """ https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material """
        log_snr = self.log_snr(t)
        log_snr_next = self.log_snr(t_next)
        log_snr, log_snr_next = map(partial(right_pad_dims_to, x_t), (log_snr, log_snr_next))

        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

        # c - as defined near eq 33
        c = -expm1(log_snr - log_snr_next)
        posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)

        # following (eq. 33)
        posterior_variance = (sigma_next ** 2) * c
        posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def q_sample(self, x_start, t, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        log_snr = self.log_snr(t)
        log_snr_padded_dim = right_pad_dims_to(x_start, log_snr)
        alpha, sigma =  log_snr_to_alpha_sigma(log_snr_padded_dim)
        return alpha * x_start + sigma * noise, log_snr

    def predict_start_from_noise(self, x_t, t, noise):
        log_snr = self.log_snr(t)
        log_snr = right_pad_dims_to(x_t, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        return (x_t - sigma * noise) / alpha.clamp(min = 1e-8)

In [86]:
from functools import partial, wraps
b = GaussianDiffusionContinuousTimes(noise_schedule = 'linear', timesteps = 10)
b.get_condition(torch.tensor(0.1))
b.get_sampling_timesteps(3, device='cpu')

(tensor([[1.0000, 1.0000, 1.0000],
         [0.9000, 0.9000, 0.9000]]),
 tensor([[0.9000, 0.9000, 0.9000],
         [0.8000, 0.8000, 0.8000]]),
 tensor([[0.8000, 0.8000, 0.8000],
         [0.7000, 0.7000, 0.7000]]),
 tensor([[0.7000, 0.7000, 0.7000],
         [0.6000, 0.6000, 0.6000]]),
 tensor([[0.6000, 0.6000, 0.6000],
         [0.5000, 0.5000, 0.5000]]),
 tensor([[0.5000, 0.5000, 0.5000],
         [0.4000, 0.4000, 0.4000]]),
 tensor([[0.4000, 0.4000, 0.4000],
         [0.3000, 0.3000, 0.3000]]),
 tensor([[0.3000, 0.3000, 0.3000],
         [0.2000, 0.2000, 0.2000]]),
 tensor([[0.2000, 0.2000, 0.2000],
         [0.1000, 0.1000, 0.1000]]),
 tensor([[0.1000, 0.1000, 0.1000],
         [0.0000, 0.0000, 0.0000]]))

In [102]:
def right_pad_dims_to(x, t):
    padding_dims = x.ndim - t.ndim
    print(x.ndim)
    print(padding_dims)
    if padding_dims <= 0:
        return t
    return t.view(*t.shape, *((1,) * padding_dims))

hj = right_pad_dims_to(torch.full((3,4,5), 6), torch.full((3,), 5))
hj.shape

3
2


torch.Size([3, 1, 1])

In [None]:
def t_equal_x_dim(x, t):
    padding_dims = x.ndim - t.ndim
    
    if padding_dims <= 0:
        return t
    return t.view(*t.shape, *((1,) * padding_dims))

In [None]:
def q_posterior(self, x_start, x_t, t, *, t_next = None):
        t_next = default(t_next, lambda: (t - 1. / self.num_timesteps).clamp(min = 0.))

        """ https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material """
        log_snr = self.log_snr(t)
        log_snr_next = self.log_snr(t_next)
        log_snr, log_snr_next = map(partial(right_pad_dims_to, x_t), (log_snr, log_snr_next))

        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

        # c - as defined near eq 33
        c = -expm1(log_snr - log_snr_next)
        posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)

        # following (eq. 33)
        posterior_variance = (sigma_next ** 2) * c
        posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

In [None]:
def log_snr_to_alpha_sigma(self, log_snr):
    return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr))

def q_posterior(x_start, x_t, t, t_next):
    
    snr      = self.snr_func(t)
    snr_next = self.snr_func(t_next)
    
    log_snr, log_snr_next = map(partial(t_equal_x_dim, x_t), (log_snr, log_snr_next))
    
    alpha, sigma           = log_snr_to_alpha_sigma(log_snr)
    alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)
    
    c      = 1 - torch.exp(log_snr - log_snr_next)
    p_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)
    
    p_variance     = (sigma_next ** 2) * c
    p_log_variance = log(p_variance, eps = 1e-20)
    p_log_variance = torch.log(p_variance.clamp(min = 1e-20))
    
    return p_mean, p_variance, p_log_variance
    

In [105]:
c = 1 - torch.exp(torch.tensor(4))
c

tensor(-53.5981)

In [107]:
-expm1(torch.tensor(4))

tensor(-53.5981)

In [88]:
tt = b.get_sampling_timesteps(3, device='cpu')
tt

(tensor([[1.0000, 1.0000, 1.0000],
         [0.9000, 0.9000, 0.9000]]),
 tensor([[0.9000, 0.9000, 0.9000],
         [0.8000, 0.8000, 0.8000]]),
 tensor([[0.8000, 0.8000, 0.8000],
         [0.7000, 0.7000, 0.7000]]),
 tensor([[0.7000, 0.7000, 0.7000],
         [0.6000, 0.6000, 0.6000]]),
 tensor([[0.6000, 0.6000, 0.6000],
         [0.5000, 0.5000, 0.5000]]),
 tensor([[0.5000, 0.5000, 0.5000],
         [0.4000, 0.4000, 0.4000]]),
 tensor([[0.4000, 0.4000, 0.4000],
         [0.3000, 0.3000, 0.3000]]),
 tensor([[0.3000, 0.3000, 0.3000],
         [0.2000, 0.2000, 0.2000]]),
 tensor([[0.2000, 0.2000, 0.2000],
         [0.1000, 0.1000, 0.1000]]),
 tensor([[0.1000, 0.1000, 0.1000],
         [0.0000, 0.0000, 0.0000]]))

In [90]:
for times, times_next in tqdm(tt, desc = 'sampling loop time step', total = len(tt)):
    print(times, times_next)

sampling loop time step: 100%|██████████| 10/10 [00:00<00:00, 434.97it/s]

tensor([1., 1., 1.]) tensor([0.9000, 0.9000, 0.9000])
tensor([0.9000, 0.9000, 0.9000]) tensor([0.8000, 0.8000, 0.8000])
tensor([0.8000, 0.8000, 0.8000]) tensor([0.7000, 0.7000, 0.7000])
tensor([0.7000, 0.7000, 0.7000]) tensor([0.6000, 0.6000, 0.6000])
tensor([0.6000, 0.6000, 0.6000]) tensor([0.5000, 0.5000, 0.5000])
tensor([0.5000, 0.5000, 0.5000]) tensor([0.4000, 0.4000, 0.4000])
tensor([0.4000, 0.4000, 0.4000]) tensor([0.3000, 0.3000, 0.3000])
tensor([0.3000, 0.3000, 0.3000]) tensor([0.2000, 0.2000, 0.2000])
tensor([0.2000, 0.2000, 0.2000]) tensor([0.1000, 0.1000, 0.1000])
tensor([0.1000, 0.1000, 0.1000]) tensor([0., 0., 0.])



