In [1]:
print("Jai Shree Ganesha")

Jai Shree Ganesha


In [2]:
import torch 
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


**IMPORTING LIBRARIES**

In [3]:
from abc import abstractmethod
import math
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import functools

**CHECKPOINT** 
`Trades extra computation for lower memory use`

In [None]:
"""
Various utilities for neural networks.
"""

import math

import torch as th
import torch.nn as nn


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    """Sigmoid-weighted linear unit: scales the input by its own sigmoid."""

    def forward(self, x):
        gate = th.sigmoid(x)
        return x * gate


class GroupNorm32(nn.GroupNorm):
    """Group normalization computed in float32 and cast back to the input dtype."""

    def forward(self, x):
        original_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(original_dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Build a convolution layer for 1D, 2D, or 3D inputs.

    All remaining positional and keyword arguments are forwarded to the
    matching ``nn.ConvNd`` constructor. A ValueError is raised for any other
    value of ``dims``.
    """
    for rank, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Build an ``nn.Linear`` layer, forwarding every argument unchanged.
    """
    layer = nn.Linear(*args, **kwargs)
    return layer


def avg_pool_nd(dims, *args, **kwargs):
    """
    Build an average-pooling layer for 1D, 2D, or 3D inputs.

    All remaining positional and keyword arguments are forwarded to the
    matching ``nn.AvgPoolNd`` constructor. A ValueError is raised for any
    other value of ``dims``.
    """
    for rank, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def update_ema(target_params, source_params, rate=0.99):
    """
    Move each target parameter toward its source parameter in place using an
    exponential moving average: ``target = rate * target + (1 - rate) * source``.

    :param target_params: the sequence of parameters holding the average.
    :param source_params: the sequence of parameters being tracked.
    :param rate: the EMA decay (values closer to 1 update more slowly).
    """
    blend = 1 - rate
    for ema_param, live_param in zip(target_params, source_params):
        # detach() shares storage, so the in-place ops update the target
        # without being recorded by autograd.
        ema_view = ema_param.detach()
        ema_view.mul_(rate)
        ema_view.add_(live_param, alpha=blend)


def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and return the module.
    """
    params = module.parameters()
    for weight in params:
        weight.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Multiply every parameter of ``module`` by ``scale`` in place and return
    the module.
    """
    params = module.parameters()
    for weight in params:
        weight.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Average a tensor over every dimension except the first (batch) one.
    """
    non_batch_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=non_batch_dims)


def normalization(channels):
    """
    Build the normalization layer used throughout the model: a GroupNorm32
    with 32 groups.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(num_groups=32, num_channels=channels)


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    The first ``dim // 2`` columns hold cosines and the next ``dim // 2``
    columns hold sines of the timesteps scaled by geometrically spaced
    frequencies. When ``dim`` is odd, one zero column is appended so the
    output always has exactly ``dim`` columns (including ``dim == 1``).

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = th.exp(
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        # Build the pad column explicitly: slicing it out of `embedding`
        # yields zero columns when half == 0 (dim == 1).
        pad = embedding.new_zeros((embedding.shape[0], 1))
        embedding = th.cat([embedding, pad], dim=-1)
    return embedding


def checkpoint(func, inputs, params, flag):
    """
    Call ``func(*inputs)``, optionally through gradient checkpointing so that
    intermediate activations are not cached (less memory, more compute in the
    backward pass).

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if not flag:
        return func(*inputs)
    # Inputs come first, then parameters; the count tells the autograd
    # function where to split them again.
    packed = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *packed)


class CheckpointFunction(th.autograd.Function):
    """
    Autograd function that runs `run_function` without recording a graph in
    the forward pass and re-runs it with gradients enabled in the backward
    pass to obtain gradients for the inputs and the extra parameters.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        """
        Run `run_function` on the first `length` entries of `args` under
        `no_grad`.

        :param ctx: the autograd context; the function, its inputs, and the
                    remaining parameters are stashed on it for `backward`.
        :param run_function: the callable to evaluate.
        :param length: how many leading entries of `args` are inputs to
                       `run_function`; the rest are treated as parameters.
        :param args: the inputs followed by the parameters.
        :return: whatever `run_function` returns.
        """
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with th.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        """
        Recompute the forward pass with gradients enabled and differentiate it.

        :param ctx: the autograd context populated by `forward`.
        :param output_grads: gradients with respect to the forward outputs.
        :return: `(None, None)` for `run_function` and `length`, followed by
                 the gradients for the inputs and then the parameters.
        """
        # Detach the stored inputs so the recomputed graph starts at them.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with th.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        # allow_unused=True: entries that do not affect the outputs get a
        # None gradient instead of raising.
        input_grads = th.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        # Drop the references held for recomputation.
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads

**CONV ND**

In [5]:
import torch.nn as nn

def conv_nd(dims, *args, **kwargs):
    """
    Build a convolution layer of the requested dimensionality.

    Args:
        dims (int): Dimensionality of the convolution (1, 2, or 3).
        *args: Positional arguments forwarded to the layer constructor.
        **kwargs: Keyword arguments forwarded to the layer constructor.

    Returns:
        nn.Module: An ``nn.Conv1d``, ``nn.Conv2d``, or ``nn.Conv3d`` instance.

    Raises:
        ValueError: If `dims` is not one of 1, 2, or 3.
    """
    layer_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    matches = [cls for rank, cls in enumerate(layer_types, start=1) if dims == rank]
    if not matches:
        raise ValueError(f"unsupported dimensions: {dims}")
    return matches[0](*args, **kwargs)

**AVG_POOL_ND**

In [6]:
import torch.nn as nn

def avg_pool_nd(dims: int, *args, **kwargs) -> nn.Module:
    """
    Build an average-pooling layer of the requested dimensionality.

    Args:
        dims: Dimensionality of the pooling layer (1, 2, or 3).
        *args: Positional arguments forwarded to the layer constructor.
        **kwargs: Keyword arguments forwarded to the layer constructor.

    Returns:
        nn.Module: An ``nn.AvgPool1d``, ``nn.AvgPool2d``, or ``nn.AvgPool3d``
        instance.

    Raises:
        ValueError: If `dims` is not one of 1, 2, or 3.
    """
    layer_types = (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d)
    for rank, pool_cls in enumerate(layer_types, start=1):
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

def max_pool_nd(dims: int, *args, **kwargs) -> nn.Module:
    """
    Create a 1D, 2D, or 3D max pooling module.

    All arguments after `dims` are forwarded unchanged to the matching
    ``nn.MaxPoolNd`` constructor, so `padding` may be given positionally or
    by keyword for every dimensionality (it defaults to 0 in PyTorch).

    Args:
        dims: The number of dimensions of the max pooling module.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        nn.Module: The max pooling module.

    Raises:
        ValueError: If `dims` is not one of 1, 2, or 3.
    """
    if dims == 1:
        return nn.MaxPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.MaxPool2d(*args, **kwargs)
    elif dims == 3:
        # Plain pass-through: re-injecting `padding` as a keyword collided
        # with a positionally supplied padding and raised TypeError.
        return nn.MaxPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

**LINEAR**

In [7]:
import torch.nn as nn
def linear(*args, **kwargs):
    """
    Construct a fully connected (``nn.Linear``) layer from the given arguments.
    """
    fc_layer = nn.Linear(*args, **kwargs)
    return fc_layer

**NORMALIZATION**

In [8]:
import torch.nn as nn

class GroupNorm32(nn.GroupNorm):
    """GroupNorm that upcasts its input to float32 and restores the input dtype."""

    def forward(self, x):
        input_dtype = x.dtype
        result = super().forward(x.float())
        return result.type(input_dtype)
    
def normalization(channels):
    """
    Return the model's standard normalization layer (GroupNorm32, 32 groups).

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    group_count = 32
    return GroupNorm32(group_count, channels)

**TIMESTEP EMBEDDING**

In [None]:
import torch.nn as nn
import torch as th
import math

def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    Cosine features fill the first ``dim // 2`` columns and sine features the
    next ``dim // 2``. For odd ``dim`` a single zero column is appended so the
    result always has ``dim`` columns, including the ``dim == 1`` case.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = th.exp(
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        # Allocate the pad column directly; taking `embedding[:, :1]` gives
        # an empty slice when half == 0, which left dim == 1 with no columns.
        zero_column = embedding.new_zeros((embedding.shape[0], 1))
        embedding = th.cat([embedding, zero_column], dim=-1)
    return embedding

**ZERO MODULE**

In [9]:
def zero_module(module):
    """
    Reset all parameters of ``module`` to zero in place, then return it.
    """
    for tensor in module.parameters():
        detached = tensor.detach()
        detached.zero_()
    return module

**CHANGE_I/P---O/P**

In [10]:
import torch.nn as nn

def change_input_output_unet(model, in_channels=4, out_channels=8):
    """
    Replace the first and last convolutions of a UNet so that it accepts
    `in_channels` input channels and produces `out_channels` output channels.

    The replacement layers are freshly initialized ``nn.Conv2d`` modules that
    reuse the kernel size, stride, and padding of the layers they replace.

    :param model: unet model from guided diffusion code, for 256x256 image input
    :param in_channels: number of channels the new input convolution accepts.
    :param out_channels: number of channels the new output convolution emits.
    :return: the model with the change
    """
    # Swap the input convolution, keeping its output width.
    first_conv = model.input_blocks[0][0]
    model.input_blocks[0][0] = nn.Conv2d(
        in_channels,
        first_conv.out_channels,
        first_conv.kernel_size,
        first_conv.stride,
        first_conv.padding,
    )

    # Swap the output convolution, keeping its input width.
    last_conv = model.out[-1]
    model.out[-1] = nn.Conv2d(
        last_conv.in_channels,
        out_channels,
        last_conv.kernel_size,
        last_conv.stride,
        last_conv.padding,
    )

    return model

**CONVERSION_MODULE**

In [13]:
import numpy as np
import torch as th
import torch.nn as nn

def convert_module_to_f16(l):
    """
    Cast the weight (and bias, if present) of a Conv1d/Conv2d/Conv3d module to
    float16 in place. Any other module type is left untouched.
    """
    if not isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return
    l.weight.data = l.weight.data.half()
    bias = l.bias
    if bias is not None:
        bias.data = bias.data.half()


def convert_module_to_f32(l):
    """
    Cast the weight (and bias, if present) of a Conv1d/Conv2d/Conv3d module to
    float32 in place, reversing convert_module_to_f16(). Any other module type
    is left untouched.
    """
    if not isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return
    l.weight.data = l.weight.data.float()
    bias = l.bias
    if bias is not None:
        bias.data = bias.data.float()

**UNET**

In [14]:
NUM_CLASSES = 1000  # number of labels used when class_cond is enabled


def _parse_channel_mult(spec):
    """Parse a comma-separated multiplier string such as "1,2,4" or "0.5,1,2"."""
    parsed = []
    for item in spec.split(","):
        try:
            parsed.append(int(item))
        except ValueError:
            # Fractional multipliers are legal: the built-in 512px schedule
            # itself starts with 0.5.
            parsed.append(float(item))
    return tuple(parsed)


def create_model(
        image_size,
        num_channels,
        num_res_blocks,
        channel_mult="",
        learn_sigma=False,
        class_cond=False,
        use_checkpoint=False,
        attention_resolutions="16",
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        dropout=0,
        resblock_updown=False,
        use_fp16=False,
        use_new_attention_order=False,
        model_path='',
        pretrain_model='',
):
    """
    Build a UNetModel and try to load pretrained weights into it.

    :param image_size: input resolution; selects the default channel
        multipliers (512, 256, 128, or 64) when `channel_mult` is "".
    :param num_channels: base channel count of the UNet.
    :param num_res_blocks: number of residual blocks per resolution level.
    :param channel_mult: "" for the per-resolution default, a comma-separated
        string of multipliers (integers or fractions, e.g. "0.5,1,2"), or a
        sequence of numbers.
    :param learn_sigma: if True, the model emits 6 output channels instead of 3.
    :param class_cond: if True, the model is conditioned on NUM_CLASSES labels.
    :param use_checkpoint: forwarded to UNetModel.
    :param attention_resolutions: an int or a comma-separated string of
        resolutions; each is converted to the downsample rate
        `image_size // resolution`.
    :param num_heads: forwarded to UNetModel.
    :param num_head_channels: forwarded to UNetModel.
    :param num_heads_upsample: forwarded to UNetModel.
    :param use_scale_shift_norm: forwarded to UNetModel.
    :param dropout: forwarded to UNetModel.
    :param resblock_updown: forwarded to UNetModel.
    :param use_fp16: forwarded to UNetModel.
    :param use_new_attention_order: forwarded to UNetModel.
    :param model_path: path of a state dict to load; any failure to load is
        reported and the randomly initialized model is returned.
    :param pretrain_model: when "osmosis", the first and last convolutions are
        replaced to use 4 input and 8 output channels before loading.
    :return: the constructed model.
    :raises ValueError: if `channel_mult` is "" and `image_size` has no default.
    :raises NotImplementedError: if `attention_resolutions` is neither an int
        nor a str.
    """
    if isinstance(channel_mult, str):
        if channel_mult == "":
            if image_size == 512:
                channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
            elif image_size == 256:
                channel_mult = (1, 1, 2, 2, 4, 4)
            elif image_size == 128:
                channel_mult = (1, 1, 2, 3, 4)
            elif image_size == 64:
                channel_mult = (1, 2, 3, 4)
            else:
                raise ValueError(f"unsupported image size: {image_size}")
        else:
            channel_mult = _parse_channel_mult(channel_mult)
    else:
        # Already a sequence of multipliers; normalize to a tuple.
        channel_mult = tuple(channel_mult)

    # Convert attention resolutions into downsample rates.
    if isinstance(attention_resolutions, int):
        attention_ds = [image_size // attention_resolutions]
    elif isinstance(attention_resolutions, str):
        attention_ds = [
            image_size // int(res) for res in attention_resolutions.split(",")
        ]
    else:
        raise NotImplementedError

    model = UNetModel(
        image_size=image_size,
        in_channels=3,
        model_channels=num_channels,
        out_channels=(3 if not learn_sigma else 6),
        num_res_blocks=num_res_blocks,
        attention_resolutions=tuple(attention_ds),
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=(NUM_CLASSES if class_cond else None),
        use_checkpoint=use_checkpoint,
        use_fp16=use_fp16,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        resblock_updown=resblock_updown,
        use_new_attention_order=use_new_attention_order,
    )

    # update number of channels according the pretrained model
    if pretrain_model == "osmosis":
        model = change_input_output_unet(model, in_channels=4, out_channels=8)

    # Loading is deliberately best-effort: on any failure the error is
    # reported and the randomly initialized model is returned.
    # NOTE(review): th.load unpickles the file; only load trusted checkpoints.
    try:
        model.load_state_dict(th.load(model_path, map_location='cpu'))
    except Exception as e:
        print(f"Got exception: {e} / Randomly initialize")
    return model


class AttentionPool2d(nn.Module):
    """
    Attention-based pooling over a 2D feature map: the spatial mean is
    prepended as an extra token, QKV attention runs over all tokens, and the
    output at the mean token is returned.

    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
            self,
            spacial_dim: int,
            embed_dim: int,
            num_heads_channels: int,
            output_dim: int = None,
    ):
        super().__init__()
        # One embedding column per spatial position plus one for the mean token.
        token_count = spacial_dim ** 2 + 1
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, token_count) / embed_dim ** 0.5
        )
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        batch, channels = x.shape[:2]
        tokens = x.reshape(batch, channels, -1)  # NC(HW)
        mean_token = tokens.mean(dim=-1, keepdim=True)
        tokens = th.cat([mean_token, tokens], dim=-1)  # NC(HW+1)
        pos = self.positional_embedding[None, :, :].to(tokens.dtype)
        tokens = tokens + pos  # NC(HW+1)
        pooled = self.c_proj(self.attention(self.qkv_proj(tokens)))
        # Position 0 holds the mean token.
        return pooled[:, :, 0]


class TimestepBlock(nn.Module):
    """
    Marker base class for modules whose forward() accepts the timestep
    embedding as its second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Run the module on `x`, conditioned on the timestep embeddings `emb`.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential container that hands the timestep embedding to every child
    that is a TimestepBlock and calls the remaining children with `x` only.
    """

    def forward(self, x, emb):
        for child in self:
            x = child(x, emb) if isinstance(child, TimestepBlock) else child(x)
        return x


class Upsample(nn.Module):
    """
    Nearest-neighbour 2x upsampling, optionally followed by a 3x3 convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the convolution; defaults to
                         `channels`.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims != 3:
            upsampled = F.interpolate(x, scale_factor=2, mode="nearest")
        else:
            # Leave the first spatial axis alone; double only the last two.
            target_size = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            upsampled = F.interpolate(x, target_size, mode="nearest")
        if not self.use_conv:
            return upsampled
        return self.conv(upsampled)


class Downsample(nn.Module):
    """
    2x downsampling done either by a strided 3x3 convolution or by average
    pooling.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the convolution; defaults to
                         `channels` and must equal it when pooling is used.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # For 3D signals the first spatial axis keeps its size.
        stride = (1, 2, 2) if dims == 3 else 2
        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=1
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        downsampled = self.op(x)
        return downsampled


class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: if True, the embedding projection emits
        2 * out_channels values that are split into a scale and a shift
        applied after the output normalization; otherwise the projected
        embedding is added to the features.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
            self,
            channels,
            emb_channels,
            dropout,
            out_channels=None,
            use_conv=False,
            use_scale_shift_norm=False,
            dims=2,
            use_checkpoint=False,
            up=False,
            down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        # Input path: normalization -> SiLU -> 3x3 conv to out_channels.
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # The same (convolution-free) resampling is applied to both the
        # residual branch (h_upd) and the skip branch (x_upd).
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # Projects the timestep embedding to per-channel conditioning values.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        # Output path; the final conv is zero-initialized via zero_module.
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        # Skip connection: identity when the channel count is unchanged,
        # otherwise a 3x3 (use_conv) or 1x1 convolution.
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        # Delegates to _forward, checkpointed when use_checkpoint is set.
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        """Compute the block output; see forward() for the arguments."""
        if self.updown:
            # Resample between the norm/activation and the input conv, and
            # resample the skip input the same way.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append trailing singleton dims so emb_out broadcasts against h.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.

    :param channels: the number of input (and output) channels.
    :param num_heads: the number of attention heads; used only when
        num_head_channels is -1.
    :param num_head_channels: if not -1, the channel width of each head; the
        head count becomes channels // num_head_channels.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param use_new_attention_order: if True, use QKVAttention (split qkv
        before heads); otherwise use QKVAttentionLegacy.
    """

    def __init__(
            self,
            channels,
            num_heads=1,
            num_head_channels=-1,
            use_checkpoint=False,
            use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                    channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        # Zero-initialized output projection.
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        """
        Apply self-attention over the spatial positions of `x`.

        :param x: an [N x C x ...] Tensor of features.
        :return: a Tensor with the same shape as `x`.
        """
        # Honour the constructor flag; the checkpoint flag used to be
        # hard-coded to True, which silently ignored use_checkpoint.
        return checkpoint(
            self._forward, (x,), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x):
        """Compute the attention output; see forward() for the arguments."""
        b, c, *spatial = x.shape
        # Flatten all spatial dims into one sequence axis.
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        # Residual connection, then restore the spatial layout.
        return (x + h).reshape(b, c, *spatial)


def count_flops_attn(model, _x, y):
    """
    Operation counter for attention layers, for use with the `thop` package.
    Adds the matmul operation count for the output `y[0]` to
    `model.total_ops`.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    batch, channels, *spatial = y[0].shape
    positions = int(np.prod(spatial))
    # Attention does two matmuls of equal cost: one builds the weight
    # matrix, the other combines the value vectors.
    matmul_ops = 2 * batch * positions * positions * channels
    model.total_ops += th.DoubleTensor([matmul_ops])


class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention, splitting heads first and q/k/v second.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        heads = self.n_heads
        assert width % (3 * heads) == 0
        ch = width // (3 * heads)
        per_head = qkv.reshape(bs * heads, ch * 3, length)
        q, k, v = per_head.split(ch, dim=1)
        # Scale q and k separately: more stable with f16 than dividing the
        # product afterwards.
        scale = 1 / math.sqrt(math.sqrt(ch))
        logits = th.einsum("bct,bcs->bts", q * scale, k * scale)
        probs = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        attended = th.einsum("bts,bcs->bct", probs, v)
        return attended.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention, splitting q/k/v first and heads
    second.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        heads = self.n_heads
        assert width % (3 * heads) == 0
        ch = width // (3 * heads)
        q, k, v = qkv.chunk(3, dim=1)
        head_shape = (bs * heads, ch, length)
        # Scale q and k separately: more stable with f16 than dividing the
        # product afterwards.
        scale = 1 / math.sqrt(math.sqrt(ch))
        scaled_q = (q * scale).view(*head_shape)
        scaled_k = (k * scale).view(*head_shape)
        logits = th.einsum("bct,bcs->bts", scaled_q, scaled_k)
        probs = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        attended = th.einsum("bts,bcs->bct", probs, v.reshape(*head_shape))
        return attended.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
     a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
                                    increased efficiency.
    """

    def __init__(
            self,
            image_size,
            in_channels,
            model_channels,
            out_channels,
            num_res_blocks,
            attention_resolutions,
            dropout=0,
            channel_mult=(1, 2, 4, 8),
            conv_resample=True,
            dims=2,
            num_classes=None,
            use_checkpoint=False,
            use_fp16=False,
            num_heads=1,
            num_head_channels=-1,
            num_heads_upsample=-1,
            use_scale_shift_norm=False,
            resblock_updown=False,
            use_new_attention_order=False,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        ch = input_ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
        )
        self._feature_size = ch
        input_block_chans = [ch]
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=int(mult * model_channels),
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(mult * model_channels)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=int(model_channels * mult),
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(model_channels * mult)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
        )

    def convert_to_fp16(self):
        """
        Cast the torso of the model to float16.

        ``convert_module_to_f16`` is applied to every submodule of the input,
        middle and output blocks; no other attribute of the model is touched
        by this method.
        """
        for torso_part in (self.input_blocks, self.middle_block, self.output_blocks):
            torso_part.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Cast the torso of the model to float32.

        ``convert_module_to_f32`` is applied to every submodule of the input,
        middle and output blocks; no other attribute of the model is touched
        by this method.
        """
        for torso_part in (self.input_blocks, self.middle_block, self.output_blocks):
            torso_part.apply(convert_module_to_f32)

    def forward(self, x, timesteps, y=None):
        """
        Run the full UNet (down path, middle block, up path) on a batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels; must be given exactly when the
                  model is class-conditional (``num_classes`` is set).
        :return: an [N x C x ...] Tensor of outputs.
        """
        class_conditional = self.num_classes is not None
        assert (y is not None) == class_conditional, \
            "must specify y if and only if the model is class-conditional"

        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        if class_conditional:
            # One label per batch element; its embedding is added to the
            # timestep embedding.
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        # Down path: keep every intermediate activation as a skip connection.
        skips = []
        h = x.type(self.dtype)
        for block in self.input_blocks:
            h = block(h, emb)
            skips.append(h)

        h = self.middle_block(h, emb)

        # Up path: consume the skips in reverse order, concatenating on the
        # channel axis before each output block.
        for block in self.output_blocks:
            h = block(th.cat([h, skips.pop()], dim=1), emb)

        # The head runs in the caller's dtype, not the torso dtype.
        return self.out(h.type(x.dtype))


class SuperResModel(UNetModel):
    """
    UNetModel variant for super-resolution.

    The conditioning image arrives through the `low_res` keyword of
    `forward`. It is bilinearly resized to the spatial size of `x` and
    concatenated with `x` on the channel axis, which is why the parent model
    is constructed with twice the requested number of input channels.
    """

    def __init__(self, image_size, in_channels, *args, **kwargs):
        doubled_channels = in_channels * 2
        super().__init__(image_size, doubled_channels, *args, **kwargs)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        # x is unpacked as a 4-D tensor; only its last two sizes are used.
        _, _, target_h, target_w = x.shape
        resized = F.interpolate(low_res, (target_h, target_w), mode="bilinear")
        stacked = th.cat([x, resized], dim=1)
        return super().forward(stacked, timesteps, **kwargs)


class EncoderUNetModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding.

    Only the downsampling path (``input_blocks``) and ``middle_block`` are
    built; the resulting features are reduced to an [N x out_channels] output
    by the head selected with `pool`.

    For usage, see UNet.

    :param image_size: only read when ``pool == "attention"``, where
        ``image_size // ds`` sizes the AttentionPool2d head.
    :param in_channels: channels of the input consumed by the first conv.
    :param model_channels: base channel count; also the width of the
        timestep embedding fed to ``time_embed``.
    :param out_channels: size of the output produced by the head.
    :param num_res_blocks: number of ResBlocks per resolution level.
    :param attention_resolutions: collection of downsample factors (`ds`) at
        which an AttentionBlock follows each ResBlock.
    :param dropout: forwarded to every ResBlock.
    :param channel_mult: per-level multipliers applied to `model_channels`.
    :param conv_resample: forwarded to Downsample.
    :param dims: forwarded to conv_nd / ResBlock / Downsample.
    :param use_checkpoint: forwarded to ResBlock and AttentionBlock.
    :param use_fp16: selects ``self.dtype`` (float16 or float32), the dtype
        the input is cast to in `forward`.
    :param num_heads: forwarded to every AttentionBlock.
    :param num_head_channels: forwarded to every AttentionBlock; must not be
        -1 when ``pool == "attention"``.
    :param num_heads_upsample: stored on the instance (defaulting to
        `num_heads` when -1) but not used to build any layer in this class.
    :param use_scale_shift_norm: forwarded to every ResBlock.
    :param resblock_updown: use a ResBlock with ``down=True`` instead of
        Downsample between levels.
    :param use_new_attention_order: forwarded to every AttentionBlock.
    :param pool: one of "adaptive", "attention", "spatial", "spatial_v2";
        anything else raises NotImplementedError.
    """

    def __init__(
            self,
            image_size,
            in_channels,
            model_channels,
            out_channels,
            num_res_blocks,
            attention_resolutions,
            dropout=0,
            channel_mult=(1, 2, 4, 8),
            conv_resample=True,
            dims=2,
            use_checkpoint=False,
            use_fp16=False,
            num_heads=1,
            num_head_channels=-1,
            num_heads_upsample=-1,
            use_scale_shift_norm=False,
            resblock_updown=False,
            use_new_attention_order=False,
            pool="adaptive",
    ):
        super().__init__()

        # -1 means "same as num_heads".
        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        # Two-layer MLP lifting the timestep embedding from model_channels
        # to 4 * model_channels.
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # Stem: a single 3x3 conv into the first level's channel count.
        ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
        )
        # Running sum of the channel counts of every block output; it is the
        # input width of the "spatial*" heads below.
        self._feature_size = ch
        input_block_chans = [ch]
        # Current downsample factor; doubled after each downsampling block.
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=int(mult * model_channels),
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(mult * model_channels)
                # Attention follows the ResBlock only at the configured
                # downsample factors.
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # Every level except the last ends with a downsampling block that
            # keeps the channel count.
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # Bottleneck: ResBlock -> AttentionBlock -> ResBlock at the final
        # channel count.
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch
        self.pool = pool
        if pool == "adaptive":
            # NOTE(review): AdaptiveAvgPool2d is used regardless of `dims`,
            # so this head presumably only suits dims == 2 - confirm.
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                zero_module(conv_nd(dims, ch, out_channels, 1)),
                nn.Flatten(),
            )
        elif pool == "attention":
            assert num_head_channels != -1
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                AttentionPool2d(
                    (image_size // ds), ch, num_head_channels, out_channels
                ),
            )
        elif pool == "spatial":
            # Consumes the concatenated per-block spatial means built in
            # forward(), whose total width is _feature_size.
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                nn.ReLU(),
                nn.Linear(2048, self.out_channels),
            )
        elif pool == "spatial_v2":
            # Same input as "spatial", with normalization + SiLU in place of
            # the ReLU.
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                normalization(2048),
                nn.SiLU(),
                nn.Linear(2048, self.out_channels),
            )
        else:
            raise NotImplementedError(f"Unexpected {pool} pooling")

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.

        Only ``input_blocks`` and ``middle_block`` are converted here.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.

        Only ``input_blocks`` and ``middle_block`` are converted here.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x, timesteps):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        results = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            # For the "spatial*" heads, collect the mean over axes 2 and 3 of
            # every block output (this indexing assumes 4-D activations).
            if self.pool.startswith("spatial"):
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb)
        if self.pool.startswith("spatial"):
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
            # Concatenate all per-block means on the feature axis.
            h = th.cat(results, axis=-1)
            return self.out(h)
        else:
            # Other heads only see the middle-block output, cast back to the
            # caller's dtype.
            h = h.type(x.dtype)
            return self.out(h)


class NLayerDiscriminator(nn.Module):
    """Convolutional discriminator built from stride-2 4x4 convolutions.

    Layout: Conv + LeakyReLU, then ``n_layers`` blocks of
    Conv + norm + LeakyReLU (channel multiplier doubling per block, capped at
    8), then a 1-channel Conv followed by Dropout(0.5) and, optionally, a
    Sigmoid.

    Parameters:
        input_nc (int)   -- number of channels of the input images
        ndf (int)        -- number of filters of the first conv layer
        n_layers (int)   -- number of normalised conv blocks
        norm_layer       -- normalisation layer class, or a
                            ``functools.partial`` wrapping one
        use_sigmoid (bool) -- append a Sigmoid to the output
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super().__init__()
        # Convolutions that feed a norm layer carry a bias only when that
        # norm is InstanceNorm2d. isinstance (rather than an exact type
        # comparison) also recognises subclasses of functools.partial.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        # Blocks 1 .. n_layers - 1: double the filter count, capped at 8x ndf.
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        # Last normalised block; it is added even when the loop above did not
        # run (n_layers <= 1).
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=2, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        # 1-channel prediction map followed by dropout.
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=2, padding=padw)] + [nn.Dropout(0.5)]
        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Run `input` through the sequential stack and return its output."""
        return self.model(input)


class GANLoss(nn.Module):
    """Selectable GAN objective.

    The class builds the target label tensor itself (expanded to the shape of
    the prediction), so callers only say whether the target is "real" or
    "fake".
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Set up the objective.

        Parameters:
            gan_mode (str) - - objective type: 'vanilla', 'lsgan' or 'wgangp'.
            target_real_label (float) - - label value for a real image
            target_fake_label (float) - - label value for a fake image

        Raises:
            NotImplementedError for any other `gan_mode`.

        Note: the discriminator should not end with a sigmoid. 'lsgan' uses
        MSELoss and 'vanilla' uses BCEWithLogitsLoss, which applies the
        sigmoid itself.
        """
        super().__init__()
        # Buffers, so the labels follow the module across devices.
        self.register_buffer('real_label', th.tensor(target_real_label))
        self.register_buffer('fake_label', th.tensor(target_fake_label))
        self.gan_mode = gan_mode

        criterion_factories = {'lsgan': nn.MSELoss, 'vanilla': nn.BCEWithLogitsLoss}
        if gan_mode in criterion_factories:
            self.loss = criterion_factories[gan_mode]()
        elif gan_mode in ['wgangp']:
            # The WGAN-GP loss is computed directly from the prediction mean.
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Return the label tensor expanded to the shape of `prediction`.

        Parameters:
            prediction (tensor) - - typically the output of a discriminator
            target_is_real (bool) - - whether the ground truth is "real"

        Returns:
            The real or fake label, expanded (not copied) to the size of
            `prediction`.
        """
        label = self.real_label if target_is_real else self.fake_label
        return label.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Compute the loss for a discriminator output and a ground-truth flag.

        Parameters:
            prediction (tensor) - - typically the output of a discriminator
            target_is_real (bool) - - whether the ground truth is "real"

        Returns:
            the calculated loss.
        """
        if self.gan_mode == 'wgangp':
            mean_score = prediction.mean()
            loss = -mean_score if target_is_real else mean_score
        elif self.gan_mode in ['lsgan', 'vanilla']:
            labels = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, labels)
        return loss


def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
    """Gradient penalty loss from the WGAN-GP paper https://arxiv.org/abs/1704.00028

    Arguments:
        netD (network)              -- discriminator network
        real_data (tensor array)    -- real images
        fake_data (tensor array)    -- generated images from the generator
        device (str)                -- device used for the random mixing weights and the grad_outputs tensor
        type (str)                  -- which samples the penalty is evaluated on [real | fake | mixed].
        constant (float)            -- the constant used in formula ( | |gradient||_2 - constant)^2
        lambda_gp (float)           -- weight for this loss

    Returns a pair (gradient penalty loss, flattened gradients), or
    (0.0, None) when lambda_gp is not positive. Raises NotImplementedError
    for an unknown `type`.
    """
    # Penalty disabled: nothing to compute.
    if not lambda_gp > 0.0:
        return 0.0, None

    if type == 'mixed':
        # One random mixing weight per sample, broadcast over all its elements.
        batch = real_data.shape[0]
        alpha = th.rand(batch, 1, device=device)
        alpha = alpha.expand(batch, real_data.nelement() // batch).contiguous().view(*real_data.shape)
        interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
    elif type == 'real':
        interpolatesv = real_data
    elif type == 'fake':
        interpolatesv = fake_data
    else:
        raise NotImplementedError('{} not implemented'.format(type))

    # For 'real' / 'fake' this flags the caller's own tensor as requiring grad.
    interpolatesv.requires_grad_(True)
    critic_out = netD(interpolatesv)
    ones = th.ones(critic_out.size()).to(device)
    grads = th.autograd.grad(outputs=critic_out, inputs=interpolatesv,
                             grad_outputs=ones,
                             create_graph=True, retain_graph=True, only_inputs=True)
    gradients = grads[0].view(real_data.size(0), -1)  # one row per sample
    # The small epsilon is added to the gradients before taking the norm.
    grad_norm = (gradients + 1e-16).norm(2, dim=1)
    gradient_penalty = ((grad_norm - constant) ** 2).mean() * lambda_gp
    return gradient_penalty, gradients

**DYNAMIC THRESHOLDING**

In [15]:
import torch 
"""
Helper functions for new types of inverse problems
"""
def normalize(img, s=0.95):
    """Rescale `img` so that the `s`-quantile of its absolute values becomes 1.

    The quantile is taken over all elements of `img`. The image is divided by
    it (the original code multiplied, which amplified images whose values
    were already too large instead of shrinking them).

    :param img: floating-point tensor to rescale.
    :param s: quantile in [0, 1] of ``img.abs()`` used as the scale.
    :return: a new tensor, ``img`` divided by the quantile.
    """
    scaling = torch.quantile(img.abs(), s)
    # Guard against a zero quantile (e.g. an all-zero image) so the division
    # cannot produce NaN.
    scaling = scaling.clamp(min=torch.finfo(scaling.dtype).eps)
    return img / scaling

def dynamic_thresholding(img, s=0.95):
    """Rescale `img` with `normalize` (quantile `s`) and clip it to [-1, 1]."""
    rescaled = normalize(img, s=s)
    return torch.clamp(rescaled, min=-1., max=1.)

**POSTERIOR MEAN VARIANCE**

In [16]:
from abc import ABC, abstractmethod

import numpy as np
import torch


# ====================
# Model Mean Processor
# ====================

# Registry mapping a processor name to its class.
__MODEL_MEAN_PROCESSOR__ = {}


def register_mean_processor(name: str):
    """Class decorator that stores the decorated class under `name`.

    Raises NameError if `name` already maps to a (truthy) entry.
    """
    def wrapper(cls):
        already_there = __MODEL_MEAN_PROCESSOR__.get(name)
        if already_there:
            raise NameError(f"Name {name} is already registerd.")
        __MODEL_MEAN_PROCESSOR__[name] = cls
        return cls

    return wrapper


def get_mean_processor(name: str, **kwargs):
    """Instantiate the class registered under `name` with `kwargs`.

    Raises NameError if nothing is registered under `name`.
    """
    processor_cls = __MODEL_MEAN_PROCESSOR__.get(name)
    if processor_cls is None:
        raise NameError(f"Name {name} is not defined.")
    return processor_cls(**kwargs)


class MeanProcessor(ABC):
    """Base class: predict x_start and compute the mean from a model output."""

    @abstractmethod
    def __init__(self, betas, dynamic_threshold, clip_denoised):
        # `betas` is part of the shared signature but is not stored here.
        self.clip_denoised = clip_denoised
        self.dynamic_threshold = dynamic_threshold

    @abstractmethod
    def get_mean_and_xstart(self, x, t, model_output):
        """Return the pair (mean, predicted x_start); implemented by subclasses."""

    def process_xstart(self, x):
        """Post-process a predicted x_start.

        Applies `dynamic_thresholding` (with s=0.98) when
        ``self.dynamic_threshold`` is set, then clamps to [-1, 1] when
        ``self.clip_denoised`` is set.
        """
        processed = dynamic_thresholding(x, s=0.98) if self.dynamic_threshold else x
        return processed.clamp(-1, 1) if self.clip_denoised else processed


@register_mean_processor(name='previous_x')
class PreviousXMeanProcessor(MeanProcessor):
    """Mean processor for models whose output is used directly as the mean."""

    def __init__(self, betas, dynamic_threshold, clip_denoised):
        super().__init__(betas, dynamic_threshold, clip_denoised)
        alphas = 1.0 - betas
        cumprod = np.cumprod(alphas, axis=0)
        # Cumulative product shifted by one step, with 1.0 in front.
        cumprod_prev = np.append(1.0, cumprod[:-1])
        one_minus_cumprod = 1.0 - cumprod

        self.posterior_mean_coef1 = betas * np.sqrt(cumprod_prev) / one_minus_cumprod
        self.posterior_mean_coef2 = (1.0 - cumprod_prev) * np.sqrt(alphas) / one_minus_cumprod

    def predict_xstart(self, x_t, t, x_prev):
        """Solve ``x_prev = coef1 * x_start + coef2 * x_t`` for x_start."""
        inv_coef1 = 1.0 / self.posterior_mean_coef1
        coef_ratio = self.posterior_mean_coef2 / self.posterior_mean_coef1
        scale_prev = extract_and_expand(inv_coef1, t, x_t)
        scale_cur = extract_and_expand(coef_ratio, t, x_t)
        return scale_prev * x_prev - scale_cur * x_t

    def get_mean_and_xstart(self, x, t, model_output):
        """Return (model_output, processed x_start recovered from it)."""
        recovered = self.predict_xstart(x, t, model_output)
        return model_output, self.process_xstart(recovered)


@register_mean_processor(name='start_x')
class StartXMeanProcessor(MeanProcessor):
    """Mean processor for models whose output is treated as x_start."""

    def __init__(self, betas, dynamic_threshold, clip_denoised):
        super().__init__(betas, dynamic_threshold, clip_denoised)
        alphas = 1.0 - betas
        cumprod = np.cumprod(alphas, axis=0)
        # Cumulative product shifted by one step, with 1.0 in front.
        cumprod_prev = np.append(1.0, cumprod[:-1])
        one_minus_cumprod = 1.0 - cumprod

        self.posterior_mean_coef1 = betas * np.sqrt(cumprod_prev) / one_minus_cumprod
        self.posterior_mean_coef2 = (1.0 - cumprod_prev) * np.sqrt(alphas) / one_minus_cumprod

    def q_posterior_mean(self, x_start, x_t, t):
        """
        Compute the mean of the diffusion posterior:
            q(x_{t-1} | x_t, x_0)
        as ``coef1[t] * x_start + coef2[t] * x_t``.
        """
        assert x_start.shape == x_t.shape
        weight_start = extract_and_expand(self.posterior_mean_coef1, t, x_start)
        weight_cur = extract_and_expand(self.posterior_mean_coef2, t, x_t)
        return weight_start * x_start + weight_cur * x_t

    def get_mean_and_xstart(self, x, t, model_output):
        """Return (posterior mean, processed x_start) for the model output."""
        pred_xstart = self.process_xstart(model_output)
        return self.q_posterior_mean(x_start=pred_xstart, x_t=x, t=t), pred_xstart


@register_mean_processor(name='epsilon')
class EpsilonXMeanProcessor(MeanProcessor):
    """Mean processor for models whose output is treated as the noise eps."""

    def __init__(self, betas, dynamic_threshold, clip_denoised):
        super().__init__(betas, dynamic_threshold, clip_denoised)
        alphas = 1.0 - betas
        cumprod = np.cumprod(alphas, axis=0)
        # Cumulative product shifted by one step, with 1.0 in front.
        cumprod_prev = np.append(1.0, cumprod[:-1])
        recip_cumprod = 1.0 / cumprod
        one_minus_cumprod = 1.0 - cumprod

        self.sqrt_recip_alphas_cumprod = np.sqrt(recip_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(recip_cumprod - 1)
        self.posterior_mean_coef1 = betas * np.sqrt(cumprod_prev) / one_minus_cumprod
        self.posterior_mean_coef2 = (1.0 - cumprod_prev) * np.sqrt(alphas) / one_minus_cumprod

    def q_posterior_mean(self, x_start, x_t, t):
        """
        Compute the mean of the diffusion posterior:
            q(x_{t-1} | x_t, x_0)
        as ``coef1[t] * x_start + coef2[t] * x_t``.
        """
        assert x_start.shape == x_t.shape
        weight_start = extract_and_expand(self.posterior_mean_coef1, t, x_start)
        weight_cur = extract_and_expand(self.posterior_mean_coef2, t, x_t)
        return weight_start * x_start + weight_cur * x_t

    def predict_xstart(self, x_t, t, eps):
        """Recover x_start from x_t and the predicted noise `eps`."""
        scale_x = extract_and_expand(self.sqrt_recip_alphas_cumprod, t, x_t)
        scale_eps = extract_and_expand(self.sqrt_recipm1_alphas_cumprod, t, eps)
        return scale_x * x_t - scale_eps * eps

    def get_mean_and_xstart(self, x, t, model_output):
        """Return (posterior mean, processed x_start) for the model output."""
        raw_xstart = self.predict_xstart(x, t, model_output)
        pred_xstart = self.process_xstart(raw_xstart)
        return self.q_posterior_mean(pred_xstart, x, t), pred_xstart


# =========================
# Model Variance Processor
# =========================

# Registry mapping a processor name to its class.
__MODEL_VAR_PROCESSOR__ = {}


def register_var_processor(name: str):
    """Class decorator that stores the decorated class under `name`.

    Raises NameError if `name` already maps to a (truthy) entry.
    """
    def wrapper(cls):
        already_there = __MODEL_VAR_PROCESSOR__.get(name)
        if already_there:
            raise NameError(f"Name {name} is already registerd.")
        __MODEL_VAR_PROCESSOR__[name] = cls
        return cls

    return wrapper


def get_var_processor(name: str, **kwargs):
    """Instantiate the class registered under `name` with `kwargs`.

    Raises NameError if nothing is registered under `name`.
    """
    processor_cls = __MODEL_VAR_PROCESSOR__.get(name)
    if processor_cls is None:
        raise NameError(f"Name {name} is not defined.")
    return processor_cls(**kwargs)


class VarianceProcessor(ABC):
    """Interface of the variance processors: built from `betas`, queried with (x, t)."""

    @abstractmethod
    def __init__(self, betas):
        """Subclasses derive whatever schedule quantities they need from `betas`."""

    @abstractmethod
    def get_variance(self, x, t):
        """Return the pair (variance, log-variance); implemented by subclasses."""


@register_var_processor(name='fixed_small')
class FixedSmallVarianceProcessor(VarianceProcessor):
    """Fixed variance taken from the posterior q(x_{t-1} | x_t, x_0)."""

    def __init__(self, betas):
        alphas = 1.0 - betas
        cumprod = np.cumprod(alphas, axis=0)
        # Cumulative product shifted by one step, with 1.0 in front.
        cumprod_prev = np.append(1.0, cumprod[:-1])
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = betas * (1.0 - cumprod_prev) / (1.0 - cumprod)

    def get_variance(self, x, t):
        """Return (variance, log-variance) at timesteps `t`, expanded like `x`."""
        variance_table = self.posterior_variance
        # NOTE(review): entry 0 of the table is 0 (cumprod_prev[0] is 1.0), so
        # its log is -inf - confirm callers never use the log at t == 0.
        log_variance_table = np.log(variance_table)

        return (
            extract_and_expand(variance_table, t, x),
            extract_and_expand(log_variance_table, t, x),
        )


@register_var_processor(name='fixed_large')
class FixedLargeVarianceProcessor(VarianceProcessor):
    """Fixed variance equal to `betas`, except for the first timestep."""

    def __init__(self, betas):
        self.betas = betas

        alphas = 1.0 - betas
        cumprod = np.cumprod(alphas, axis=0)
        # Cumulative product shifted by one step, with 1.0 in front.
        cumprod_prev = np.append(1.0, cumprod[:-1])
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = betas * (1.0 - cumprod_prev) / (1.0 - cumprod)

    def get_variance(self, x, t):
        """Return (variance, log-variance) at timesteps `t`, expanded like `x`."""
        # betas[1:] with posterior_variance[1] standing in for the first entry.
        variance_table = np.append(self.posterior_variance[1], self.betas[1:])
        log_variance_table = np.log(variance_table)

        return (
            extract_and_expand(variance_table, t, x),
            extract_and_expand(log_variance_table, t, x),
        )


@register_var_processor(name='learned')
class LearnedVarianceProcessor(VarianceProcessor):
    """Variance processor that reads `x` directly as the log-variance."""

    def __init__(self, betas):
        """`betas` is accepted for interface parity and ignored."""

    def get_variance(self, x, t):
        """Return (exp(x), x); `t` is not used."""
        return torch.exp(x), x


@register_var_processor(name='learned_range')
class LearnedRangeVarianceProcessor(VarianceProcessor):
    """Log-variance interpolated between the posterior variance and `betas`."""

    def __init__(self, betas):
        self.betas = betas

        alphas = 1.0 - betas
        cumprod = np.cumprod(alphas, axis=0)
        # Cumulative product shifted by one step, with 1.0 in front.
        cumprod_prev = np.append(1.0, cumprod[:-1])

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - cumprod_prev) / (1.0 - cumprod)
        # The posterior variance is 0 at the beginning of the diffusion chain,
        # so entry 1 is reused for entry 0 before taking the log.
        clipped = np.append(posterior_variance[1], posterior_variance[1:])
        self.posterior_log_variance_clipped = np.log(clipped)

    def get_variance(self, x, t):
        """Return (variance, log-variance) obtained by blending with `x`.

        `x` holds the model's variance values, mapped from [-1, 1] to
        [min log-variance, max log-variance].
        """
        lower = extract_and_expand(self.posterior_log_variance_clipped, t, x)
        upper = extract_and_expand(np.log(self.betas), t, x)

        # The model_var_values is [-1, 1] for [min_var, max_var]
        frac = (x + 1.0) / 2.0
        model_log_variance = frac * upper + (1 - frac) * lower
        return torch.exp(model_log_variance), model_log_variance


# ================
# Helper function
# ================

def extract_and_expand(array, time, target):
    """Index a NumPy table with `time` and expand the result to `target`'s shape.

    The table is moved to `target.device`, indexed, cast to float32, padded
    with trailing singleton axes up to `target.ndim`, then expanded.
    """
    table = torch.from_numpy(array).to(target.device)
    picked = table[time].float()
    # range() of a non-positive count is empty, so nothing is added when
    # `picked` already has enough dimensions.
    for _ in range(target.ndim - picked.ndim):
        picked = picked.unsqueeze(-1)
    return picked.expand_as(target)


def expand_as(array, target):
    """Expand `array` to the shape of `target` and move it to its device.

    :param array: a NumPy array, a float scalar (Python or NumPy), or a
        tensor. Arrays and scalars are converted to tensors first.
    :param target: tensor whose shape and device the result takes.
    :return: `array` padded with trailing singleton axes up to
        ``target.ndim`` and expanded to ``target``'s shape.
    """
    if isinstance(array, np.ndarray):
        array = torch.from_numpy(array)
    # `np.float` was only an alias of the builtin float and no longer exists
    # in current NumPy; checking it raised AttributeError for every
    # non-ndarray input, tensors included.
    elif isinstance(array, (float, np.floating)):
        array = torch.tensor([array])

    while array.ndim < target.ndim:
        array = array.unsqueeze(-1)

    return array.expand_as(target).to(target.device)

**utils**

In [18]:

import sys
import os
from os.path import join as pjoin
import numpy as np
import yaml
import argparse
import datetime
import re
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
import torch.optim as optim
import torchvision.transforms.functional as tvtf


# %% image functions

def min_max_norm(img, global_norm=True, is_uint8=True):
    """
    Min-max normalize an image tensor to [0, 1] (or [0, 255] as uint8).

    assume input is a torch tensor [ch,h,w]

    :param img: image tensor; the first axis is the channel axis.
    :param global_norm: if True use the min/max of the whole tensor,
        otherwise normalize every channel with its own min/max (any number
        of channels is supported).
    :param is_uint8: if True scale by 255 and convert to torch.uint8.
    :return: a new tensor; `img` itself is not modified. A constant image
        (or constant channel) maps to zeros instead of NaN.
    """
    if global_norm:
        low, high = img.min(), img.max()
    else:
        # Per-channel statistics, shaped to broadcast against img.
        per_channel = img.reshape(img.shape[0], -1)
        stat_shape = (-1,) + (1,) * (img.dim() - 1)
        low = per_channel.min(dim=1)[0].view(stat_shape)
        high = per_channel.max(dim=1)[0].view(stat_shape)

    span = high - low
    # Where the range is empty, divide by 1: (img - low) is 0 there, so the
    # result is 0 rather than 0 / 0.
    safe_span = torch.where(span == 0, torch.ones_like(span), span)
    img_norm = (img - low) / safe_span

    if is_uint8:
        img_norm = (img_norm * 255).to(torch.uint8)

    return img_norm


def min_max_norm_range(img, vmin=0, vmax=1, is_uint8=False):
    """
    Min-max normalize an image (or each image of a batch) into [vmin, vmax].

    assume input is a torch tensor [3/1,h,w] or [Batch,3/1,h,w]

    :param img: 3-D image or 4-D batch; for a batch the min/max are taken
        per image.
    :param vmin: lower bound of the output range.
    :param vmax: upper bound of the output range.
    :param is_uint8: if True multiply the result by 255 and convert to
        torch.uint8.
    :return: the normalized tensor. A constant image maps to zeros.
    :raises NotImplementedError: if `img` is neither 3-D nor 4-D.
    """

    vmin = float(vmin)
    vmax = float(vmax)

    # Compute the minimum and maximum values for each image in the batch separately
    if len(img.shape) == 4:
        # support a batch
        per_image = img.reshape(img.size(0), -1)
        img_min = per_image.min(dim=1, keepdim=True)[0].view(-1, 1, 1, 1)
        img_max = per_image.max(dim=1, keepdim=True)[0].view(-1, 1, 1, 1)

    elif len(img.shape) == 3:
        img_max = img.max()
        img_min = img.min()

    else:
        raise NotImplementedError

    # Decide "constant image" element-wise: a plain `if img_min == img_max`
    # is ambiguous (and raises) for a batch with more than one image.
    is_constant = img_max == img_min
    span = img_max - img_min
    safe_span = torch.where(is_constant, torch.ones_like(span), span)
    scale = (vmax - vmin) / safe_span
    img_norm = (img - img_min) * scale + vmin
    # Constant images are mapped to zeros.
    img_norm = torch.where(is_constant, torch.zeros_like(img_norm), img_norm)

    if is_uint8:
        img_norm = (255 * img_norm).to(torch.uint8)

    return img_norm


def min_max_norm_range_percentile(img, vmin=0, vmax=1, percent_low=0., percent_high=1., is_uint8=False):
    """
    Clip an image to a quantile range, then min-max normalize into [vmin, vmax].

    assume input is a torch tensor [3/1,h,w]; a 4-D batch is also handled,
    with the quantiles taken over the whole tensor and the min/max per image.

    :param img: 3-D image or 4-D batch.
    :param vmin: lower bound of the output range.
    :param vmax: upper bound of the output range.
    :param percent_low: lower quantile (in [0, 1]) used for clipping.
    :param percent_high: upper quantile (in [0, 1]) used for clipping.
    :param is_uint8: if True multiply the result by 255 and convert to
        torch.uint8.
    :return: the normalized tensor. A constant (clipped) image maps to zeros.
    :raises NotImplementedError: if `img` is neither 3-D nor 4-D.
    """

    # first clip into percentile values
    clip_low = torch.quantile(img, q=percent_low)
    clip_high = torch.quantile(img, q=percent_high)
    img_clip = torch.clamp(img, clip_low, clip_high)

    vmin = float(vmin)
    vmax = float(vmax)

    # Compute the minimum and maximum values for each image in the batch separately
    if len(img_clip.shape) == 4:
        # support a batch
        per_image = img_clip.reshape(img_clip.size(0), -1)
        img_min = per_image.min(dim=1, keepdim=True)[0].view(-1, 1, 1, 1)
        img_max = per_image.max(dim=1, keepdim=True)[0].view(-1, 1, 1, 1)

    elif len(img_clip.shape) == 3:
        img_max = img_clip.max()
        img_min = img_clip.min()

    else:
        raise NotImplementedError

    # Decide "constant image" element-wise: a plain `if img_min == img_max`
    # is ambiguous (and raises) for a batch with more than one image.
    is_constant = img_max == img_min
    span = img_max - img_min
    safe_span = torch.where(is_constant, torch.ones_like(span), span)
    scale = (vmax - vmin) / safe_span
    img_norm = (img_clip - img_min) * scale + vmin
    # Constant images are mapped to zeros.
    img_norm = torch.where(is_constant, torch.zeros_like(img_norm), img_norm)

    if is_uint8:
        img_norm = (255 * img_norm).to(torch.uint8)

    return img_norm


def max_norm(img, global_norm=True, is_uint8=True):
    """
    Divide an image by its maximum value.

    :param img: torch tensor [3,h,w]
    :param global_norm: True - divide the whole image by its single global maximum,
                        False - divide each of the first three channels by its own maximum
    :param is_uint8: when True the result is scaled by 255 and cast to uint8
    :return: the normalized image
    """

    if global_norm:
        normalized = img / img.max()
    else:
        normalized = torch.zeros_like(img)
        # channel-wise normalization of the three color channels
        for ch in range(3):
            normalized[ch, :, :] = img[ch, :, :] / img[ch, :, :].max()

    if is_uint8:
        normalized *= 255
        normalized = normalized.to(torch.uint8)

    return normalized


def clip_image(img, scale=True, move=True, is_uint8=True):
    """
    Optionally shift and scale an image, then clip it to the displayable range.

    The input tensor is never modified; a new tensor is returned.

    :param img: torch tensor [ch,h,w] (ch can be 3/1) or [h,w]; a 2-D input gets
                a leading channel dimension added
    :param scale: when True multiply the image by 0.5 (applied after ``move``)
    :param move: when True add 1 to the image
    :param is_uint8: True - multiply by 255, clamp to [0, 255] and cast to uint8,
                     False - clamp to [0, 1]
    :return: the clipped image
    """

    # fix in case the image is only [imagesize, imagesize]
    if len(img.shape) == 2:
        img = img.unsqueeze(0)

    if move:
        img = img + 1
    if scale:
        img = 0.5 * img

    if is_uint8:
        # out-of-place multiplication: with move=False and scale=False `img` is still
        # the caller's tensor (or a view of it), so `img *= 255` would corrupt it
        return (img * 255).clamp(0, 255).to(torch.uint8)

    return img.clamp(0, 1)


def gaussian_kernel(kernel_size, sigma=1., muu=0.):
    """
    Build an un-normalized 2-D radial Gaussian-shaped kernel.

    The sample coordinates are ``kernel_size`` evenly spaced points on
    [0, kernel_size], shifted by ``kernel_size // 2``. The returned values are
    ``exp(-(r - muu)^2 / sigma^2)`` with ``r`` the distance from the shifted origin
    (no ``1 / (2 * pi * sigma^2)`` factor and no factor 2 in the exponent).

    :param kernel_size: number of samples along each axis
    :param sigma: width of the Gaussian
    :param muu: radial offset of the peak
    :return: numpy array of shape [kernel_size, kernel_size]
    """
    # shifted 1-D sample positions, shared by both axes
    axis = np.linspace(0, kernel_size, kernel_size) - kernel_size // 2
    grid_x, grid_y = np.meshgrid(axis, axis)

    # radial distance of every grid point from the shifted origin
    radius = np.sqrt(grid_x ** 2 + grid_y ** 2)

    return np.exp(-((radius - muu) ** 2 / (sigma ** 2)))


def create_image_text_to_grid(image, image_size=[256, 256], info_str="light factor", norm=True):
    """
    Turn an image or a scalar into a displayable uint8 image plus a text summary.

    input is an image (1 or 3 channels) or a scalar (1 or 3 channels) as pytorch tensor
    outputs are:
    1. a tensor image [3, image_size, image_size]
    2. a text for logger

    :param image: torch tensor; recognized shapes are [1] (scalar), [3, a, b] with
                  [a, b] != image_size (3 channels scalar), [h, w] == image_size
                  (1 channel image) and [3, h, w] with [h, w] == image_size
    :param image_size: [height, width] of the output image (list or tuple)
    :param info_str: name used as a prefix in the text
    :param norm: for image inputs - True applies ``min_max_norm`` before the uint8
                 cast, False multiplies by 255 and casts directly
    :return: tuple (out_image, text)
    :raises ValueError: if the shape of ``image`` is not one of the recognized ones
    """
    image = image.detach().cpu()
    shape = list(image.shape)
    # a tuple would never compare equal to the list-valued shape and cannot be
    # concatenated to a list below
    image_size = list(image_size)

    # 1 channel scalar
    if len(shape) == 1:
        text = f"{info_str} = {image.item():.3f}"
        out_image = image * torch.ones(size=[3] + image_size)
        # cast to uint8
        out_image = (255 * out_image).to(torch.uint8)

    # 3 channels scalar
    elif len(shape) == 3 and shape[-3] == 3 and (not shape[-2::] == image_size):
        channel_values = image.squeeze().numpy()
        text = f"{info_str} = " \
               f"[{channel_values[0]:.3f}, " \
               f"{channel_values[1]:.3f}, " \
               f"{channel_values[2]:.3f}]"
        out_image = torch.zeros(size=[3, image_size[0], image_size[1]], dtype=torch.float32)
        for channel in range(3):
            out_image[channel] = image[channel] * torch.ones(size=image_size)
        # cast to uint8
        out_image = (255 * out_image).to(torch.uint8)

    # 1 channel image
    elif len(shape) == 2 and shape[-2::] == image_size:
        text = f"{info_str} mean = {image.mean():.3f}\n" \
               f"{info_str} std = {image.std():.3f}\n" \
               f"{info_str} min = {image.min():.3f}\n" \
               f"{info_str} max = {image.max():.3f}"
        out_image = image.unsqueeze(0).repeat(3, 1, 1)
        out_image = min_max_norm(out_image, is_uint8=True) if norm else (255 * out_image).to(torch.uint8)

    # 3 channels image
    elif len(shape) == 3 and shape[-3] == 3 and shape[-2::] == image_size:
        text = f"Red mean={image[0].mean():.3f}, std={image[0].std():.3f}, min={image[0].min():.3f}, max={image[0].max():.3f}\n" \
               f"Green mean={image[1].mean():.3f}, std={image[1].std():.3f}, min={image[1].min():.3f}, max={image[1].max():.3f}\n" \
               f"Blue mean={image[2].mean():.3f}, std={image[2].std():.3f}, min={image[2].min():.3f}, max={image[2].max():.3f}\n"
        out_image = min_max_norm(image, is_uint8=True) if norm else (255 * image).to(torch.uint8)

    else:
        # the exception used to be constructed without `raise`, which ended in an
        # UnboundLocalError at the return statement
        raise ValueError(f"Image dimensions are not recognized - shape={shape}")

    return out_image, text


def add_text_torch_img(img, text, font_size=15):
    """
    Draw black multi-line text on a torch image.

    The font file is chosen according to ``sys.platform``; only Windows
    ("arial.ttf") and Linux (FreeMono) are handled.

    :param img: torch image shape [3,h,w]
    :param text: text to insert
    :param font_size: size passed to the TrueType font loader
    :return: torch image shape [3,h,w]
    :raises NotImplementedError: on platforms other than Windows / Linux
    """

    pil_image = tvtf.to_pil_image(img)
    drawer = ImageDraw.Draw(pil_image)

    platform_name = sys.platform
    if platform_name.startswith("win"):
        font = ImageFont.truetype("arial.ttf", font_size)
    elif platform_name.startswith("linux"):
        font = ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeMono.ttf", font_size,
                                  encoding="unic")
    else:
        raise NotImplementedError

    # the drawing context uses its `font` attribute as the default font
    drawer.font = font
    drawer.multiline_text((5, 30), text, fill=(0, 0, 0))

    return tvtf.to_tensor(pil_image)


# %% change input and outputs of the unet

def change_input_output_unet(model, in_channels=4, out_channels=8):
    """
    Replace the first and the last convolution of a unet so it accepts / emits a
    different number of channels. The model is modified in place.

    The new layers reuse kernel size, stride and padding of the layers they
    replace; their weights are freshly initialized ``nn.Conv2d`` weights.

    :param model: unet model from guided diffusion code, for 256x256 image input
    :param in_channels: number of input channels of the new first convolution
    :param out_channels: number of output channels of the new last convolution
    :return: the model with the change
    """

    # swap the input convolution, keeping its output width and geometry
    old_first = model.input_blocks[0][0]
    model.input_blocks[0][0] = nn.Conv2d(in_channels,
                                         old_first.out_channels,
                                         old_first.kernel_size,
                                         old_first.stride,
                                         old_first.padding)

    # swap the output convolution, keeping its input width and geometry
    old_last = model.out[-1]
    model.out[-1] = nn.Conv2d(old_last.in_channels,
                              out_channels,
                              old_last.kernel_size,
                              old_last.stride,
                              old_last.padding)

    return model


# %% masked mse loss

class MaskedMSELoss(_Loss):
    """

    masked version of MSE loss

    The element-wise squared error is multiplied by ``mask`` before reduction.
    With ``reduction='mean'`` the sum is divided by
    ``input.shape[1] * mask.sum()`` (per the original comment the mask is expected
    to have a single channel, hence the multiplication by the channel count).
    Note that an all-zero mask makes that denominator zero.

    """

    __constants__ = ['reduction']

    def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None:
        super(MaskedMSELoss, self).__init__(size_average, reduce, reduction)

    def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        :param input: prediction tensor, channel dimension at index 1
        :param target: ground truth tensor
        :param mask: weights multiplied element-wise with the squared error
        :return: the reduced loss ('sum' / 'mean') or the masked error map ('none')
        :raises ValueError: if ``self.reduction`` is not 'sum', 'mean' or 'none'
        """
        masked_base_loss = mask * F.mse_loss(input, target, reduction='none')

        if self.reduction == 'sum':
            return masked_base_loss.sum()

        if self.reduction == 'mean':
            # the number of channel of mask is 1, and the image the RGB/RGBD therefore a multiplication is required
            num_channels = input.shape[1]
            num_non_zero_elements = num_channels * mask.sum()
            return masked_base_loss.sum() / num_non_zero_elements

        if self.reduction == 'none':
            return masked_base_loss

        # previously the exception was created but never raised
        raise ValueError(f"Unknown reduction input: {self.reduction}")


# %% masked L1 loss

class MaskedL1Loss(_Loss):
    """

    masked version of L1 loss

    The element-wise absolute error is multiplied by ``mask`` before reduction.
    With ``reduction='mean'`` the sum is divided by
    ``input.shape[1] * mask.sum()`` (per the original comment the mask is expected
    to have a single channel, hence the multiplication by the channel count).
    Note that an all-zero mask makes that denominator zero.

    """

    __constants__ = ['reduction']

    def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None:
        super(MaskedL1Loss, self).__init__(size_average, reduce, reduction)

    def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        :param input: prediction tensor, channel dimension at index 1
        :param target: ground truth tensor
        :param mask: weights multiplied element-wise with the absolute error
        :return: the reduced loss ('sum' / 'mean') or the masked error map ('none')
        :raises ValueError: if ``self.reduction`` is not 'sum', 'mean' or 'none'
        """
        masked_base_loss = mask * F.l1_loss(input, target, reduction='none')

        if self.reduction == 'sum':
            return masked_base_loss.sum()

        if self.reduction == 'mean':
            # the number of channel of mask is 1, and the image the RGB/RGBD therefore a multiplication is required
            num_channels = input.shape[1]
            num_non_zero_elements = num_channels * mask.sum()
            return masked_base_loss.sum() / num_non_zero_elements

        if self.reduction == 'none':
            return masked_base_loss

        # previously the exception was created but never raised
        raise ValueError(f"Unknown reduction input: {self.reduction}")


# %% read yaml config file and parser functions

def load_yaml(file_path: str) -> dict:
    """
    Parse a YAML file and return its content.

    NOTE(review): ``yaml.FullLoader`` is used - only load files from trusted sources.

    :param file_path: path of the YAML file
    :return: the parsed content
    """
    with open(file_path) as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)


# read a yaml config file and write its content into a txt file

def yaml_to_txt(yaml_file_path, txt_file_path):
    """
    Load a YAML file and write its re-serialized (block style) content to a text file.

    :param yaml_file_path: path of the YAML file to read
    :param txt_file_path: path of the text file to write (overwritten)
    """
    with open(yaml_file_path, 'r') as source:
        parsed = yaml.load(source, Loader=yaml.FullLoader)

    # serialize before opening the target so the file is only touched on success
    serialized = yaml.dump(parsed, default_flow_style=False)

    with open(txt_file_path, 'w') as target:
        target.write(serialized)


# dictionary and argparser functions

def args_to_dict(args, keys):
    """
    Collect the attributes named in ``keys`` from ``args`` into a dictionary.

    :param args: object holding the attributes (e.g. an argparse namespace)
    :param keys: iterable of attribute names
    :return: dict mapping each name to ``getattr(args, name)``
    """
    selected = {}
    for name in keys:
        selected[name] = getattr(args, name)
    return selected


def str2bool(v):
    """
    Convert a command line string into a boolean (booleans pass through unchanged).

    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse

    :param v: a bool, or a string such as "yes"/"no", "true"/"false", "1"/"0"
    :return: the boolean value
    :raises argparse.ArgumentTypeError: if the string is not a recognized boolean
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False

    raise argparse.ArgumentTypeError("boolean value expected")


def add_dict_to_argparser(parser, default_dict):
    """
    Register one ``--<key>`` option per dictionary entry (function from guided
    diffusion code).

    The option type is the type of the default value, except that ``None``
    defaults are parsed as ``str`` and boolean defaults through ``str2bool``.

    :param parser: the argparse parser to extend
    :param default_dict: mapping of option name -> default value
    """

    for name, default in default_dict.items():
        if default is None:
            caster = str
        elif isinstance(default, bool):
            caster = str2bool
        else:
            caster = type(default)
        parser.add_argument(f"--{name}", default=default, type=caster)


def add_dict_to_namespace(namespace, args_dict):
    """
    Set every key of ``args_dict`` as an attribute on ``namespace`` (in place).

    :param namespace: object receiving the attributes
    :param args_dict: mapping of attribute name -> value
    """
    for name in args_dict:
        setattr(namespace, name, args_dict[name])


# save directory using date
def update_save_dir_date(arguments_save_dir: str) -> str:
    """
    Create and return a fresh ``<root>/<day>-<month>-<yy>/run<N>`` directory.

    Starting from ``run1``, the trailing number is incremented until a path that
    does not exist yet is found; that directory is then created.

    :param arguments_save_dir: root directory for the results
    :return: path of the created run directory
    """
    current_date = datetime.date.today()
    date_tag = f"{current_date.day}-{current_date.month}-{current_date.year % 2000}"
    candidate = pjoin(arguments_save_dir, f"{date_tag}", "run1")

    # bump the trailing run index while the path is already taken
    while os.path.exists(candidate):
        run_index = re.findall(r'\d+$', candidate)[0]
        candidate = f"{candidate[0:-len(run_index)]}{int(run_index) + 1}"

    os.makedirs(candidate, exist_ok=True)

    return candidate


# checkpoint path update

def update_checkpoint_path(save_dir_path: str) -> str:
    """
    Make sure ``<save_dir_path>/checkpoint`` exists and return the checkpoint file path.

    :param save_dir_path: directory of the current run
    :return: ``<save_dir_path>/checkpoint/checkpoint.pt``
    """
    checkpoint_dir = os.path.join(save_dir_path, "checkpoint")
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_file = os.path.join(checkpoint_dir, "checkpoint.pt")
    return checkpoint_file


# update_relevant_arguments function

def update_relevant_arguments(args, save_dir_path: str):
    """
    Post-process the parsed arguments in place.

    Casts ``lr`` and ``fp16_scale_growth`` to float, creates the dated save
    directory from ``args.save_dir_main``, sets the checkpoint path (empty string
    when ``args.save_checkpoint`` is falsy) and derives the unet channel counts.

    :param args: arguments namespace, updated in place
    :param save_dir_path: not used by this function
    :return: the updated ``args``
    """
    # cast to float relevant inputs
    learning_rate, scale_growth = float(args.lr), float(args.fp16_scale_growth)
    args.lr = learning_rate
    args.fp16_scale_growth = scale_growth

    args.save_dir = update_save_dir_date(args.save_dir_main)
    if args.save_checkpoint:
        args.checkpoint_path = update_checkpoint_path(args.save_dir)
    else:
        args.checkpoint_path = ""

    # specify number of input and output channels according to pretrained model;
    # learning sigma doubles the number of output channels
    base_channels = 4 if args.pretrain_model == "debka" else 3
    args.unet_in_channels = base_channels
    args.unet_out_channels = 2 * base_channels if args.learn_sigma else base_channels

    return args


# arguments parser functions
def arguments_from_file(config_file_path: str) -> argparse.Namespace:
    """
    Load a YAML config file into an ``argparse.Namespace``.

    :param config_file_path: path of the YAML config file
    :return: namespace with one attribute per top-level config key
    """
    # read the config first, then copy its entries onto a fresh namespace
    config = load_yaml(config_file_path)
    namespace = argparse.Namespace()
    add_dict_to_namespace(namespace, config)
    return namespace


# os run

def get_os():
    """
    Return a short name of the operating system the code runs on.

    :return: 'linux' or 'win' for those platforms; on any other platform a
             notice is printed and the raw ``sys.platform`` string is returned
             (previously this case crashed with an UnboundLocalError)
    """
    if sys.platform.startswith('linux'):
        return 'linux'
    if sys.platform.startswith('win'):
        return 'win'

    print("Running on a different platform")
    return sys.platform


# %% return torch optimizer by name

def get_optimizer(optimizer_name, model_parameters, **kwargs):
    """
    Build a ``torch.optim`` optimizer from its (case-insensitive) name.

    :param optimizer_name: name of the optimizer; ``None``, "" and "gd" mean
                           "no optimizer"
    :param model_parameters: parameters handed to the optimizer constructor
    :param kwargs: forwarded to the optimizer constructor (e.g. ``lr``)
    :return: the optimizer instance, or ``None`` when no optimizer is requested
    :raises ValueError: if the name is not one of the supported optimizers
    """
    # the None check must come before `.lower()`, otherwise None raises AttributeError
    if optimizer_name is None:
        return None

    optimizer_name = optimizer_name.lower()
    if optimizer_name in ("", "gd"):
        return None

    # lower-case name -> class name inside torch.optim
    supported = {
        'adam': 'Adam',
        'sgd': 'SGD',
        'rmsprop': 'RMSprop',
        'adagrad': 'Adagrad',
        'adadelta': 'Adadelta',
        'adamw': 'AdamW',
        'sparseadam': 'SparseAdam',
        'adamax': 'Adamax',
        'asgd': 'ASGD',
        'lbfgs': 'LBFGS',
        'rprop': 'Rprop',
    }

    class_name = supported.get(optimizer_name)
    if class_name is None:
        raise ValueError(f"Optimizer '{optimizer_name}' is not supported.")

    return getattr(optim, class_name)(model_parameters, **kwargs)


# %% change depth function according to the input of depth type

def get_depth_value(value_raw, **kwargs):
    """
    Normalize the raw "value" setting of a depth conversion.

    :param value_raw: float (returned as is), int (cast to float), comma separated
                      string (parsed into a float numpy array) or numpy array /
                      numpy scalar (returned as is)
    :param kwargs: accepted but not used
    :return: the normalized value
    :raises NotImplementedError: for any other input type (including ``None``)
    """
    if isinstance(value_raw, float):
        return value_raw
    if isinstance(value_raw, int):
        return float(value_raw)
    if isinstance(value_raw, str):
        # "a,b,c" -> array([a, b, c])
        return np.fromstring(value_raw, dtype=float, sep=',')
    if isinstance(value_raw, (np.ndarray, np.generic)):
        return value_raw

    raise NotImplementedError


def convert_depth(depth, depth_type, **kwargs):
    """
    Convert the depth predicted by the unet according to ``depth_type``.

    Supported conversions:
    - "move":     ``depth + value``
    - "gamma":    ``((depth + value[0]) * value[1]) ** value[2]``
    - "original" / None: ``0.5 * (depth + 1.0)`` (no value required)

    :param depth: expected to get the depth as it gets out from the unet model
    :param depth_type: the type of conversion
    :param kwargs: ``value`` - the conversion parameter(s), parsed by
                   ``get_depth_value``; only read for "move" and "gamma"
    :return: converted out depth
    :raises NotImplementedError: for an unknown ``depth_type`` or an unsupported
                                 ``value`` type
    """
    # the "original" conversion has no parameters, so the value must not be parsed
    # here - parsing a missing value (None) raises NotImplementedError
    if depth_type is None or depth_type == "original":
        return 0.5 * (depth + 1.0)

    if depth_type == "move":
        value = get_depth_value(kwargs.get("value", None))
        return depth + value

    if depth_type == "gamma":
        value = get_depth_value(kwargs.get("value", None))
        return torch.pow((depth + value[0]) * value[1], value[2])

    raise NotImplementedError


# %% when pattern sampling - check if freezing phi is required

def is_freeze_phi(sample_pattern, time_index, num_timesteps):
    """
    Decide whether phi is frozen at the given time index.

    phi is not frozen for original sampling (no pattern, or pattern "original").
    Otherwise it is frozen whenever the time index lies outside the guidance window
    (``start_guidance`` .. ``stop_guidance``) or outside the update window
    (``update_start`` .. ``update_end``); all bounds are fractions of
    ``num_timesteps``.

    :param sample_pattern: pattern dictionary or None
    :param time_index: current time index
    :param num_timesteps: total number of diffusion steps
    :return: True when phi should be frozen
    """
    # original sampling (no freezing phi required at all)
    if (sample_pattern is None) or (sample_pattern["pattern"] == "original"):
        return False

    # in case of non guidance for that time index, no alternating happens
    if time_index > sample_pattern['start_guidance'] * num_timesteps or \
            time_index < sample_pattern['stop_guidance'] * num_timesteps:
        return True

    # gibbsDDRM pattern sampling - but before starting update phi
    if time_index > sample_pattern["update_start"] * num_timesteps or \
            time_index < sample_pattern["update_end"] * num_timesteps:
        return True

    # otherwise not freezing phi
    return False


# %% when pattern sampling - set alternating length

def set_alternate_length(sample_pattern, time_index, num_timesteps):
    """
    Return the number of alternating iterations for the given time index.

    The result is 1 for original sampling (no pattern, or pattern "original") and
    whenever the time index lies outside the guidance window, the update window or
    the [s_end, s_start] window (all bounds are fractions of ``num_timesteps``).
    Inside all windows it is ``sample_pattern["local_M"]``.

    :param sample_pattern: pattern dictionary or None
    :param time_index: current time index
    :param num_timesteps: total number of diffusion steps
    :return: the alternate length
    :raises AssertionError: if the window bounds of the pattern are inconsistent
    """
    # this is the original - non pattern case. The None test has to come first:
    # subscripting None (as the validation used to do) raises a TypeError.
    if (sample_pattern is None) or (sample_pattern["pattern"] == "original"):
        return 1

    # check correction of the values
    assert sample_pattern["update_start"] > sample_pattern["update_end"]
    assert sample_pattern["s_start"] > sample_pattern["s_end"]

    if sample_pattern['local_M'] > 1:
        assert sample_pattern["update_start"] >= sample_pattern["s_start"]
        assert sample_pattern["s_end"] >= sample_pattern["update_end"]

    # in case of non guidance for that time index, no alternating happens
    if time_index > sample_pattern['start_guidance'] * num_timesteps or \
            time_index < sample_pattern['stop_guidance'] * num_timesteps:
        return 1

    # Until start update there is no optimization of phi's - This is mentioned in the gibbsDDRM paper
    if time_index > sample_pattern["update_start"] * num_timesteps or \
            time_index < sample_pattern["update_end"] * num_timesteps:
        return 1

    # PGDiff paper - S_start and S_end - time indices which the alternate optimization is happened
    # s_start should be smaller than update_start and s_end larger than update_end
    if time_index > sample_pattern["s_start"] * num_timesteps or \
            time_index < sample_pattern["s_end"] * num_timesteps:
        return 1

    return sample_pattern["local_M"]


# %% logging text

def log_text(args):
    """
    Build the multi-line text summary of a run configuration.

    Reads the guidance / loss settings from ``args.conditioning['params']``, the
    operator and noise settings from ``args.measurement``, the auxiliary loss, the
    manual seed and the sampling pattern.

    :param args: arguments namespace holding the configuration dictionaries
    :return: the summary string
    """
    cond_params = args.conditioning['params']
    operator_cfg = args.measurement['operator']
    noise_cfg = args.measurement['noise']
    pattern_cfg = args.sample_pattern

    optimize_target = 'x_prev' if cond_params['gradient_x_prev'] else 'x0'
    optimizer_model = operator_cfg['optimizer'] if 'optimizer' in operator_cfg else 'none'

    # general settings
    summary = (
        f"\n\nGuidance Scale: {cond_params['scale']}"
        f"\nLoss Function: {cond_params['loss_function']}"
        f"\nweight: {cond_params['loss_weight']}, "
        f"weight_function: {cond_params['weight_function']}"
        f"\nAuxiliary Loss: {args.aux_loss['aux_loss']}"
        f"\nUnderwater model: {operator_cfg['name']}"
        f"\nOptimize w.r.t: {optimize_target}"
        f"\nOptimizer model: {optimizer_model}, "
        f"\nManual seed: {args.manual_seed}"
        f"\nDepth type: {operator_cfg['depth_type']}, value: {operator_cfg['value']}"
    )

    # noise settings (sigma is optional)
    summary += f"\nNoise: {noise_cfg['name']}"
    if 'sigma' in noise_cfg:
        summary += f", sigma: {noise_cfg['sigma']}"

    # gradient clipping is configured as "<bool>,<limit>"
    clip_fields = cond_params['gradient_clip'].split(',')
    summary += f"\nGradient Clipping: {clip_fields[0]}"
    if str2bool(clip_fields[0]):
        summary += f", min value: -{clip_fields[1]}, max value: {clip_fields[1]}"

    # sampling pattern
    if pattern_cfg['pattern'] == 'original':
        summary += "\nSample Pattern: original"
    else:
        summary += (
            f"\nSample Pattern: {pattern_cfg['pattern']}, "
            f"\n     Guidance start: {pattern_cfg['start_guidance']} ,end: {pattern_cfg['stop_guidance']}"
            f"\n     Optimizations iters: {pattern_cfg['n_iter']}, "
            f"\n     Update start from: {pattern_cfg['update_start']}, end: {pattern_cfg['update_end']}"
            f"\n     M: {pattern_cfg['local_M']}, start: {pattern_cfg['s_start']}, end: {pattern_cfg['s_end']}"
        )

    return summary


# %% loss_weight - factor the difference between the  measurement to the degraded image

def set_loss_weight(loss_weight_type, weight_function=None, degraded_image=None, x_0_hat=None):
    """
    Compute the factor that weights the difference between the measurement and the
    degraded image.

    :param loss_weight_type: 'none' / None - constant weight 1,
                             'depth' - weight derived from channel 3 of ``x_0_hat``
                             through ``convert_depth``
    :param weight_function: string "function,value0,value1,..." naming the depth
                            conversion and its parameters; anything that is not a
                            string selects the function name 'none'
    :param degraded_image: not used by this function
    :param x_0_hat: prediction tensor, read as ``x_0_hat[:, 3, :, :]`` for the
                    'depth' type
    :return: the loss weight (1 or a tensor)
    :raises NotImplementedError: for an unknown ``loss_weight_type``
    """
    # `value` used to stay unbound when no numbers were supplied, which crashed the
    # 'depth' branch with an UnboundLocalError
    value = None

    # weight function is a string divided into "function,value0,value1,..."
    if isinstance(weight_function, str):
        str_parts = weight_function.split(",")
        function_str = str_parts[0]

        if len(str_parts) > 1:
            value = np.asarray(str_parts[1:]).astype(float)
            # a single parameter is passed on as a plain float
            value = value.item() if value.shape[0] == 1 else value

    else:
        function_str = 'none'

    if loss_weight_type == 'none' or loss_weight_type is None:
        return 1

    # try to multiply by the depth, the reason is to make the gradients of the far area larger since the
    # prediction from the u-net got close to zero at those areas
    if loss_weight_type == 'depth':
        depth_tmp = x_0_hat.detach()[:, 3, :, :].unsqueeze(1)
        return convert_depth(depth=depth_tmp, depth_type=function_str, value=value)

    raise NotImplementedError


# %% create histogram image

def color_histogram(img, title=None):
    """
    Render the per-channel (red, green, blue) histogram of an image as an image.

    :param img: image should be tensor (c, h, w) between values [0.,1.]
                (values outside that range are clamped)
    :param title: optional title of the plot
    :return: tensor image of histogram (c, h, w) between values [0.,1.]
    """

    # detach / move to cpu so tensors on an accelerator or with grad can be converted
    img = torch.clamp(img.detach().cpu(), min=0., max=1.)

    colors = ("red", "green", "blue")
    img_np = (img * 255).to(torch.uint8).permute(1, 2, 0).numpy()
    # get the dimensions
    ypixels, xpixels = img_np.shape[:2]

    # get the size in inches, so the figure has the pixel size of the input image
    dpi = plt.rcParams['figure.dpi']
    fig = plt.figure(figsize=(xpixels / dpi, ypixels / dpi))

    try:
        plt.xlim([-5, 260])

        for channel_id, color in enumerate(colors):
            histogram, bin_edges = np.histogram(img_np[:, :, channel_id], bins=256, range=(0, 256))
            plt.plot(bin_edges[0:-1], histogram, color=color)

        plt.grid()
        plt.yticks(rotation=45, ha='right', fontsize=7)
        if title is not None:
            plt.title(str(title))

        canvas = fig.canvas
        canvas.draw()  # Draw the canvas, cache the renderer
        # `tostring_rgb` is no longer available in recent Matplotlib releases;
        # read the RGBA buffer (H, W, 4) instead and drop the alpha channel
        rgba = np.asarray(canvas.buffer_rgba())
        hist_np = np.ascontiguousarray(rgba[:, :, :3])  # (H, W, 3)
        hist_tensor = tvtf.to_tensor(Image.fromarray(hist_np))
    finally:
        # always release the figure, also when rendering fails
        plt.close(fig)

    return hist_tensor


# %% save depth tensor into rgb with colormap (instead of grayscale)

def depth_tensor_to_color_image(tensor_image, colormap='viridis'):
    """
    Color a single-channel depth tensor with a matplotlib colormap.

    A 4-D input is squeezed first; from a 3-D tensor only the first entry along
    dimension 0 is used.

    :param tensor_image: depth tensor of 2, 3 or 4 dimensions
    :param colormap: name of the matplotlib colormap
    :return: tensor [3, h, w] holding the RGB channels of the colored image
    """
    color_mapper = plt.get_cmap(colormap)

    depth = tensor_image.squeeze() if len(tensor_image.shape) == 4 else tensor_image
    if len(depth.shape) == 3:
        depth = depth[0]

    assert len(depth.shape) == 2

    # color the gray scale image; the colormap returns RGBA, keep only RGB
    rgba = color_mapper(depth.numpy())
    return torch.tensor(rgba[:, :, 0:3]).permute(2, 0, 1)


**GAUSSIAN DIFFUSION**

In [19]:
import math
import os
from os.path import join as pjoin
from functools import partial
import matplotlib.pyplot as plt

import numpy as np
from PIL import Image
from tqdm.auto import tqdm

import torch
from torchvision.utils import make_grid
import torchvision.transforms.functional as tvtf

# registry: sampler name -> sampler class
__SAMPLER__ = {}


def register_sampler(name: str):
    """
    Class decorator that stores the decorated class in the sampler registry.

    :param name: key under which the class is registered
    :return: decorator returning the class unchanged
    :raises NameError: (when the decorator is applied) if the name is already registered
    """

    def wrapper(cls):
        already_registered = __SAMPLER__.get(name, None)
        if already_registered:
            raise NameError(f"Name {name} is already registered!")
        __SAMPLER__[name] = cls
        return cls

    return wrapper


def get_sampler(name: str):
    """
    Look up a sampler class in the registry.

    :param name: registered name of the sampler
    :return: the registered class
    :raises NameError: if nothing is registered under ``name``
    """
    sampler_cls = __SAMPLER__.get(name, None)
    if sampler_cls is None:
        raise NameError(f"Name {name} is not defined!")
    return sampler_cls


def create_sampler(sampler,
                   steps,
                   noise_schedule,
                   model_mean_type,
                   model_var_type,
                   dynamic_threshold,
                   clip_denoised,
                   rescale_timesteps,
                   timestep_respacing="",
                   **kwargs):
    """
    Instantiate a registered sampler.

    :param sampler: registered name of the sampler class
    :param steps: number of diffusion steps, used for the beta schedule
    :param noise_schedule: name of the beta schedule
    :param model_mean_type: forwarded to the sampler
    :param model_var_type: forwarded to the sampler
    :param dynamic_threshold: forwarded to the sampler
    :param clip_denoised: forwarded to the sampler
    :param rescale_timesteps: forwarded to the sampler
    :param timestep_respacing: respacing specification; an empty value means
                               ``[steps]``
    :param kwargs: ``annealing_time`` (default False) is forwarded to the sampler
    :return: the sampler instance
    """
    sampler_cls = get_sampler(name=sampler)

    annealing_time = kwargs.get('annealing_time', False)
    betas = get_named_beta_schedule(noise_schedule, steps)
    respacing = timestep_respacing if timestep_respacing else [steps]

    sampler_kwargs = dict(use_timesteps=space_timesteps(steps, respacing),
                          betas=betas,
                          model_mean_type=model_mean_type,
                          model_var_type=model_var_type,
                          dynamic_threshold=dynamic_threshold,
                          clip_denoised=clip_denoised,
                          rescale_timesteps=rescale_timesteps,
                          annealing_time=annealing_time)

    return sampler_cls(**sampler_kwargs)


class GaussianDiffusion:
    def __init__(self,
                 betas,
                 model_mean_type,
                 model_var_type,
                 dynamic_threshold,
                 clip_denoised,
                 rescale_timesteps,
                 **kwargs):
        """
        Pre-compute the per-timestep coefficients of the diffusion process.

        :param betas: 1-D sequence of noise variances, each in (0, 1]
        :param model_mean_type: selects the mean processor
        :param model_var_type: selects the variance processor
        :param dynamic_threshold: forwarded to the mean processor
        :param clip_denoised: forwarded to the mean processor
        :param rescale_timesteps: stored on the instance
        :param kwargs: accepted but not used here
        """

        # use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert self.betas.ndim == 1, "betas must be 1-D"
        assert (0 < self.betas).all() and (self.betas <= 1).all(), "betas must be in (0..1]"

        self.num_timesteps = int(self.betas.shape[0])
        self.rescale_timesteps = rescale_timesteps

        alphas = 1.0 - self.betas
        cumprod = np.cumprod(alphas, axis=0)
        cumprod_prev = np.append(1.0, cumprod[:-1])
        self.alphas_cumprod = cumprod
        self.alphas_cumprod_prev = cumprod_prev
        self.alphas_cumprod_next = np.append(cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        one_minus_cumprod = 1.0 - cumprod
        one_minus_cumprod_prev = 1.0 - cumprod_prev

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(one_minus_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(one_minus_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = betas * one_minus_cumprod_prev / one_minus_cumprod
        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain: entry 0 is replaced by entry 1.
        clipped_variance = np.append(self.posterior_variance[1], self.posterior_variance[1:])
        self.posterior_log_variance_clipped = np.log(clipped_variance)
        self.posterior_mean_coef1 = betas * np.sqrt(cumprod_prev) / one_minus_cumprod
        self.posterior_mean_coef2 = one_minus_cumprod_prev * np.sqrt(alphas) / one_minus_cumprod

        self.mean_processor = get_mean_processor(model_mean_type,
                                                 betas=betas,
                                                 dynamic_threshold=dynamic_threshold,
                                                 clip_denoised=clip_denoised)

        self.var_processor = get_var_processor(model_var_type,
                                               betas=betas)

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).

        The per-timestep coefficients are looked up with ``extract_and_expand``
        (presumably gathered at ``t`` and expanded to broadcast against
        ``x_start`` - the helper is defined outside this view).

        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean_scale = extract_and_expand(self.sqrt_alphas_cumprod, t, x_start)
        variance = extract_and_expand(1.0 - self.alphas_cumprod, t, x_start)
        log_variance = extract_and_expand(self.log_one_minus_alphas_cumprod, t, x_start)

        return mean_scale * x_start, variance, log_variance

    def q_sample(self, x_start, t):
        """
        Diffuse the data for a given number of diffusion steps.

        In other words, sample from q(x_t | x_0). The Gaussian noise is drawn
        inside the function with ``torch.randn_like``.

        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A noisy version of x_start.
        """
        gaussian_noise = torch.randn_like(x_start)

        signal_scale = extract_and_expand(self.sqrt_alphas_cumprod, t, x_start)
        noise_scale = extract_and_expand(self.sqrt_one_minus_alphas_cumprod, t, x_start)

        noisy = signal_scale * x_start + noise_scale * gaussian_noise
        return noisy

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:

            q(x_{t-1} | x_t, x_0)

        :param x_start: the (predicted) clean sample x_0, same shape as ``x_t``
        :param x_t: the noisy sample at step ``t``
        :param t: the timestep indices
        :return: A tuple (posterior_mean, posterior_variance,
                 posterior_log_variance_clipped)
        """
        assert x_start.shape == x_t.shape

        weight_x_start = extract_and_expand(self.posterior_mean_coef1, t, x_start)
        weight_x_t = extract_and_expand(self.posterior_mean_coef2, t, x_t)
        mean = weight_x_start * x_start + weight_x_t * x_t

        variance = extract_and_expand(self.posterior_variance, t, x_t)
        log_variance_clipped = extract_and_expand(self.posterior_log_variance_clipped, t, x_t)

        # all outputs must share the batch dimension of the input
        batch_size = x_start.shape[0]
        assert (
                mean.shape[0]
                == variance.shape[0]
                == log_variance_clipped.shape[0]
                == batch_size
        )
        return mean, variance, log_variance_clipped

    def p_sample_loop(self,
                      model,
                      x_start,
                      measurement,
                      measurement_cond_fn,
                      record,
                      save_root,
                      pretrain_model=None,
                      image_idx=None,
                      record_every=150,
                      rgb_guidance=False,
                      sample_pattern=None,
                      **kwargs):
        """
        The function used for sampling from noise.
        """

        img = x_start
        device = x_start.device
        global_iteration = kwargs.get("global_iteration", False)
        original_file_name = kwargs.get("original_file_name", "image_0")
        save_grids_path = kwargs.get("save_grids_path", None)

        time_val_list = []
        loss_process = []

        if record:
            rgb_record_list = []
            depth_record_list = []

        total_steps = self.num_timesteps
        pbar = tqdm(list(range(total_steps))[::-1])

        # loop over the timestep
        for idx in pbar:

            time = torch.tensor([idx] * img.shape[0], device=device)
            time_val_list.append(time.cpu().item())

            # flag (bool) for non guidance
            guidance_flag = (sample_pattern['pattern'] == 'original') or \
                            (sample_pattern['pattern'] is None) or \
                            (sample_pattern['start_guidance'] * self.num_timesteps >= time >= sample_pattern[
                                'stop_guidance'] * self.num_timesteps)

            # setting the alternate len (M from the gibbsDDRM paper)
            alternate_len = set_alternate_length(sample_pattern, idx, self.num_timesteps)

            # for osmosis use alternate_len=1, means - no alternating
            for alternate_ii in range(alternate_len):

                img.requires_grad = True if guidance_flag else False

                if rgb_guidance:
                    out = self.p_sample(x=img, t=time, model=model)

                else:
                    # "clean" the noise with the unet
                    out = self.p_mean_variance(model=model, x=img, t=time)
                    out['sample'] = out['mean']

                # there is no use of the noisy measurement, do we need it? I don't know yet
                noisy_measurement = self.q_sample(measurement, t=time)

                # Give condition. -> guiding
                if pretrain_model == 'osmosis' and not rgb_guidance:

                    # check if there is a sampling method and check the idx to check if to freeze phis
                    freeze_phi = is_freeze_phi(sample_pattern, idx, self.num_timesteps)

                    if guidance_flag:

                        # conditioning function (guidance)
                        img, loss, variable_dict, gradients, aux_loss = \
                            measurement_cond_fn(x_t=out['sample'],
                                                measurement=measurement,
                                                noisy_measurement=noisy_measurement,
                                                x_prev=img,
                                                x_0_hat=out['pred_xstart'],
                                                freeze_phi=freeze_phi,
                                                time_index=float(idx) / self.num_timesteps)

                    else:
                        # no guidance
                        img = out['sample']

                    # sampling new img after guidance
                    noise = torch.randn_like(img, device=img.device)
                    if time != 0:  # no noise when t == 0
                        img += torch.exp(0.5 * out['log_variance']) * noise

                    # detach result from graph, for the next iteration
                    img.detach_()

                    # update pbar for the last alternating process
                    if alternate_ii == (alternate_len - 1):

                        loss_process.append(loss[0].item())
                        # print and log values
                        pbar_print_dictionary = {}
                        pbar_print_dictionary['time'] = time.cpu().tolist()
                        pbar_print_dictionary['loss'] = loss
                        # print auxiliary loss to the pbar
                        if aux_loss is not None:
                            pbar_print_dictionary['aux'] = np.round(
                                [ii.item() for ii in list(aux_loss.values())], decimals=4)

                        # print variables to pbar
                        for key_ii, value_ii in variable_dict.items():
                            current_var_value = np.round(value_ii.cpu().detach().squeeze().tolist(), decimals=3)
                            # in case the variable is a matrix
                            if len(current_var_value.shape) > 1:
                                current_var_value = \
                                    np.round([current_var_value.mean(), current_var_value.std()], decimals=3)
                            pbar_print_dictionary[key_ii] = current_var_value

                        # print the pbar
                        pbar.set_postfix(pbar_print_dictionary, refresh=False)

                # almost original dps code - rgb_guidance
                else:
                    img, loss = measurement_cond_fn(x_t=out['sample'],
                                                    measurement=measurement,
                                                    noisy_measurement=noisy_measurement,
                                                    x_prev=img,
                                                    x_0_hat=out['pred_xstart'])
                    img = img.detach_()
                    pbar.set_postfix({'loss': loss.detach().cpu().item()}, refresh=False)

                # save the images during the diffusion process
                if record and (alternate_ii == (alternate_len - 1)) and \
                        ((idx % record_every == 0) or (idx == 0) or (idx == 999)):
                    # the RGBD image
                    mid_x_0_pred_tmp = out['pred_xstart'].detach().cpu()

                    # split into RGB and Depth images
                    rgb_record_tmp = 0.5 * (mid_x_0_pred_tmp[0, 0:3, :, :] + 1)
                    rgb_record_tmp_clip = torch.clamp(rgb_record_tmp, 0, 1)

                    # Depth
                    depth_record_tmp = mid_x_0_pred_tmp[:, 3, :, :]
                    # percentile + min max norm for the depth image
                    depth_record_tmp_pmm = min_max_norm_range_percentile(depth_record_tmp, percent_low=0.05,
                                                                                percent_high=0.99)
                    depth_record_tmp_pmm_color = depth_tensor_to_color_image(depth_record_tmp_pmm)

                    rgb_record_list.append(rgb_record_tmp_clip)
                    depth_record_list.append(depth_record_tmp_pmm_color)

        # save the recorded images
        if record and (save_grids_path is not None):
            # save rgb and depth information - images are clipped, depth is percentiled + min-max normalized
            mid_grid = make_grid(rgb_record_list + depth_record_list, nrow=len(rgb_record_list))
            mid_grid_pil = tvtf.to_pil_image(mid_grid)
            mid_grid_pil.save(pjoin(save_grids_path, f'{original_file_name}_process.png'))

        # return the relevant things
        if pretrain_model == 'osmosis' and not rgb_guidance:
            return img, variable_dict, loss, out['pred_xstart'].detach().cpu()

        else:
            return img

    def p_sample(self, model, x, t):
        """
        Placeholder for a single reverse-diffusion step; this version always raises.

        The ``DDPM`` and ``DDIM`` classes in this file define their own
        ``p_sample`` that returns a dict with 'sample' and 'pred_xstart'.

        :raises NotImplementedError: unconditionally.
        """
        raise NotImplementedError

    def p_mean_variance(self, model, x, t):
        """
        Run the network on ``x`` at timestep ``t`` and derive the posterior terms.

        :param model: callable invoked as ``model(x, scaled_t)``.
        :param x: the current tensor; its channel count is read from dim 1.
        :param t: the timesteps; they go through ``_scale_timesteps`` before
                  reaching the model and are passed unscaled to the processors.
        :return: a dict with the keys 'mean', 'variance', 'log_variance' and
                 'pred_xstart', all shaped like ``x`` (asserted below).
        """
        channels = x.shape[1]
        raw_output = model(x, self._scale_timesteps(t))

        if raw_output.shape[1] == 2 * channels:
            # A "learned variance" network emits twice the channels: the first
            # half drives the mean, the second half drives the variance.
            mean_input, var_input = torch.split(raw_output, channels, dim=1)
        else:
            # No dedicated variance channels. Per the original author's note the
            # variance processor only needs this tensor for its shape.
            mean_input = var_input = raw_output

        mean, x0_estimate = self.mean_processor.get_mean_and_xstart(x, t, mean_input)
        variance, log_variance = self.var_processor.get_variance(var_input, t)

        assert mean.shape == log_variance.shape == x0_estimate.shape == x.shape

        result = {}
        result['mean'] = mean
        result['variance'] = variance
        result['log_variance'] = log_variance
        result['pred_xstart'] = x0_estimate
        return result

    def _scale_timesteps(self, t):
        """
        Map timesteps onto a 0..1000 scale when ``rescale_timesteps`` is set.

        :param t: the timestep tensor.
        :return: ``t`` itself when rescaling is off, otherwise a float tensor
                 equal to ``t * 1000 / num_timesteps``.
        """
        if not self.rescale_timesteps:
            return t
        factor = 1000.0 / self.num_timesteps
        return t.float() * factor


def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.
    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.
    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.
    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, a single int, or a string
                           containing comma-separated numbers, indicating the
                           step count per section. As a special case, use
                           "ddimN" where N is a number of steps to use the
                           striding from the DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    :raises ValueError: if no integer stride yields exactly N steps for
                        "ddimN", or if a section holds fewer steps than the
                        count requested for it.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            for stride in range(1, num_timesteps):
                candidate = range(0, num_timesteps, stride)
                if len(candidate) == desired_count:
                    return set(candidate)
            # Report the count that was asked for, not the size of the
            # original process.
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    elif isinstance(section_counts, int):
        section_counts = [section_counts]

    # Split the process into len(section_counts) contiguous sections; the
    # first `extra` sections get one additional step each.
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        # Fractional stride so the first and last step of the section are hit.
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)


class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process that visits only a chosen subset of the timesteps of
    a base diffusion process.

    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process; must
                   contain "betas".
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        # Full-length process, built only to read its cumulative alphas.
        full_process = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa

        respaced_betas = []
        prev_cumprod = 1.0
        for step, cumprod in enumerate(full_process.alphas_cumprod):
            if step not in self.use_timesteps:
                continue
            # Beta that reproduces the same cumulative alpha at each kept step.
            respaced_betas.append(1 - cumprod / prev_cumprod)
            prev_cumprod = cumprod
            self.timestep_map.append(step)

        kwargs["betas"] = np.array(respaced_betas)
        super().__init__(**kwargs)

    def p_mean_variance(self, model, *args, **kwargs):  # pylint: disable=signature-differs
        """Delegate to the base class with the model wrapped for respaced timesteps."""
        wrapped = self._wrap_model(model)
        return super().p_mean_variance(wrapped, *args, **kwargs)

    def training_losses(self, model, *args, **kwargs):  # pylint: disable=signature-differs
        """Delegate to the base class with the model wrapped for respaced timesteps."""
        wrapped = self._wrap_model(model)
        return super().training_losses(wrapped, *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        """Delegate to the base class with ``cond_fn`` wrapped for respaced timesteps."""
        wrapped = self._wrap_model(cond_fn)
        return super().condition_mean(wrapped, *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        """Delegate to the base class with ``cond_fn`` wrapped for respaced timesteps."""
        wrapped = self._wrap_model(cond_fn)
        return super().condition_score(wrapped, *args, **kwargs)

    def _wrap_model(self, model):
        """Return ``model`` wrapped in ``_WrappedModel`` unless it already is one."""
        if not isinstance(model, _WrappedModel):
            model = _WrappedModel(
                model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
            )
        return model

    def _scale_timesteps(self, t):
        """Return ``t`` untouched; the wrapped model does the scaling."""
        return t


class _WrappedModel:
    """
    Callable adapter that translates respaced timestep indices back to the
    indices of the original process before calling the wrapped model.
    """

    def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
        self.model, self.timestep_map = model, timestep_map
        self.rescale_timesteps, self.original_num_steps = rescale_timesteps, original_num_steps

    def __call__(self, x, ts, **kwargs):
        """
        :param x: forwarded to the wrapped model unchanged.
        :param ts: tensor of indices into ``timestep_map``.
        :param kwargs: forwarded to the wrapped model unchanged.
        :return: whatever the wrapped model returns.
        """
        lookup = torch.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        original_ts = lookup[ts]
        if self.rescale_timesteps:
            scale = 1000.0 / self.original_num_steps
            original_ts = original_ts.float() * scale
        return self.model(x, original_ts, **kwargs)


@register_sampler(name='ddpm')
class DDPM(SpacedDiffusion):
    """Sampler registered as 'ddpm': predicted mean plus scaled Gaussian noise."""

    def p_sample(self, model, x, t):
        """
        Take one reverse step from ``x`` at timestep ``t``.

        :return: a dict with 'sample' and 'pred_xstart'.
        """
        stats = self.p_mean_variance(model, x, t)
        sample = stats['mean']

        # The noise is drawn on every step, including the last one.
        noise = torch.randn_like(x)
        if t[0] != 0:  # no noise when t == 0
            std = torch.exp(0.5 * stats['log_variance'])
            sample += std * noise

        return dict(sample=sample, pred_xstart=stats['pred_xstart'])


@register_sampler(name='ddim')
class DDIM(SpacedDiffusion):
    """Sampler registered as 'ddim'; ``eta`` scales the stochastic part of each step."""

    def p_sample(self, model, x, t, eta=0.0):
        """
        Take one reverse step from ``x`` at timestep ``t``.

        :param model: the model passed on to ``p_mean_variance``.
        :param x: the current tensor, batch on dim 0.
        :param t: tensor of timestep indices, one per batch entry.
        :param eta: multiplier on the noise scale ``sigma``; with the default
                    0.0 no noise is added.
        :return: a dict with 'sample' and 'pred_xstart'.
        """
        out = self.p_mean_variance(model, x, t)

        eps = self.predict_eps_from_x_start(x, t, out['pred_xstart'])

        alpha_bar = extract_and_expand(self.alphas_cumprod, t, x)
        alpha_bar_prev = extract_and_expand(self.alphas_cumprod_prev, t, x)
        sigma = (
                eta
                * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
                * torch.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12.
        noise = torch.randn_like(x)
        mean_pred = (
                out["pred_xstart"] * torch.sqrt(alpha_bar_prev)
                + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )

        # No noise where t == 0. A per-sample mask is used instead of
        # `if t != 0`, which raises for a timestep tensor with more than one
        # element.
        nonzero_mask = (t != 0).to(sigma.dtype).view(-1, *([1] * (x.ndim - 1)))
        sample = mean_pred + nonzero_mask * sigma * noise

        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def predict_eps_from_x_start(self, x_t, t, pred_xstart):
        """
        Recover the noise estimate implied by ``x_t`` and ``pred_xstart``.

        :return: ``(sqrt_recip_alphas_cumprod[t] * x_t - pred_xstart)
                 / sqrt_recipm1_alphas_cumprod[t]``, broadcast to ``x_t``.
        """
        coef1 = extract_and_expand(self.sqrt_recip_alphas_cumprod, t, x_t)
        coef2 = extract_and_expand(self.sqrt_recipm1_alphas_cumprod, t, x_t)
        return (coef1 * x_t - pred_xstart) / coef2


# =================
# Helper functions
# =================

def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Build one of the pre-defined beta schedules.

    :param schedule_name: "linear" or "cosine".
    :param num_diffusion_timesteps: the number of betas to produce.
    :return: a numpy array of betas.
    :raises NotImplementedError: for any other schedule name.

    Schedules in this library should stay similar in the limit of
    num_diffusion_timesteps. New ones may be added, but committed ones should
    not be removed or changed, to maintain backwards compatibility.
    """
    if schedule_name == "cosine":
        def cosine_alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        return betas_for_alpha_bar(num_diffusion_timesteps, cosine_alpha_bar)

    if schedule_name != "linear":
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")

    # Linear schedule from Ho et al.; the end points are scaled by
    # 1000 / num_diffusion_timesteps so it works for any number of steps.
    scale = 1000 / num_diffusion_timesteps
    first_beta, last_beta = scale * 0.0001, scale * 0.02
    return np.linspace(first_beta, last_beta, num_diffusion_timesteps, dtype=np.float64)


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize a continuous alpha_bar function into a beta schedule.

    ``alpha_bar`` defines the cumulative product of (1 - beta) over t in [0, 1].
    Beta number i is ``1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)``, capped
    at ``max_beta``.

    :param num_diffusion_timesteps: the number of betas to produce (N).
    :param alpha_bar: a callable taking t from 0 to 1 and returning the
                      cumulative product of (1 - beta) up to that point.
    :param max_beta: upper bound on each beta; values below 1 prevent
                     singularities.
    :return: a numpy array of N betas.
    """
    total = num_diffusion_timesteps
    return np.array([
        min(1 - alpha_bar((step + 1) / total) / alpha_bar(step / total), max_beta)
        for step in range(total)
    ])


# ================
# Helper function
# ================

def extract_and_expand(array, time, target):
    """
    Look up ``array[time]`` and broadcast the result to the shape of ``target``.

    :param array: a numpy array of per-timestep values.
    :param time: the index (or tensor of indices) to read.
    :param target: the tensor whose device and shape the result should match.
    :return: a float tensor expanded to ``target``'s shape.
    """
    picked = torch.from_numpy(array).to(target.device)[time].float()
    missing_dims = target.ndim - picked.ndim
    if missing_dims > 0:
        # Append trailing singleton axes so expand_as can broadcast.
        picked = picked[(...,) + (None,) * missing_dims]
    return picked.expand_as(target)


def expand_as(array, target):
    """
    Broadcast ``array`` to the shape and device of ``target``.

    :param array: a numpy array, a Python or numpy float scalar, or a tensor.
    :param target: the tensor whose shape and device the result should match.
    :return: a tensor expanded to ``target``'s shape, on ``target``'s device.
    """
    if isinstance(array, np.ndarray):
        array = torch.from_numpy(array)
    # `np.float` was only an alias for the builtin float and no longer exists
    # in current NumPy, so the builtin and numpy's float scalars are checked
    # directly.
    elif isinstance(array, (float, np.floating)):
        array = torch.tensor([array])

    # Append trailing singleton axes so expand_as can broadcast.
    while array.ndim < target.ndim:
        array = array.unsqueeze(-1)

    return array.expand_as(target).to(target.device)


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Gather entries of a 1-D numpy array for a batch of indices and broadcast
    them to a larger shape.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a float tensor expanded to ``broadcast_shape``.
    """
    values = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    # One trailing singleton axis per missing dimension.
    for _ in range(len(broadcast_shape) - values.dim()):
        values = values.unsqueeze(-1)
    return values.expand(broadcast_shape)

In [20]:
"""
This module handles task-dependent operations
"""

from abc import ABC, abstractmethod

import numpy as np
from torchvision import torch

# =================
# Operation classes
# =================

# Registry filled by @register_operator: operator name -> operator class.
__OPERATOR__ = {}


def register_operator(name: str):
    """
    Class decorator that stores the decorated class in the operator registry.

    :param name: the key to register the class under.
    :return: a decorator that returns the class unchanged.
    :raises NameError: (from the decorator) if ``name`` is already taken.
    """

    def wrapper(cls):
        existing = __OPERATOR__.get(name, None)
        if existing:
            raise NameError(f"Name {name} is already registered!")
        __OPERATOR__[name] = cls
        return cls

    return wrapper


def get_operator(name: str, **kwargs):
    """
    Instantiate the operator class registered under ``name``.

    :param name: the registry key.
    :param kwargs: forwarded to the operator's constructor.
    :return: the new instance, with its ``__name__`` attribute set to ``name``.
    :raises NameError: if nothing is registered under ``name``.
    """
    operator_cls = __OPERATOR__.get(name, None)
    if operator_cls is None:
        raise NameError(f"Name {name} is not defined.")

    instance = operator_cls(**kwargs)
    instance.__name__ = name
    return instance


class LinearOperator(ABC):
    """Interface for a linear operator A, with projections built from it."""

    @abstractmethod
    def forward(self, data, **kwargs):
        """Calculate A * X."""

    @abstractmethod
    def transpose(self, data, **kwargs):
        """Calculate A^T * X."""

    def ortho_project(self, data, **kwargs):
        """Calculate (I - A^T * A)X."""
        measured = self.forward(data, **kwargs)
        back_projected = self.transpose(measured, **kwargs)
        return data - back_projected

    def project(self, data, measurement, **kwargs):
        """Calculate (I - A^T * A)Y - AX, with Y = measurement and X = data."""
        measurement_part = self.ortho_project(measurement, **kwargs)
        return measurement_part - self.forward(data, **kwargs)


@register_operator(name='noise')
class DenoiseOperator(LinearOperator):
    """
    Identity operator registered as 'noise': every method returns ``data``.

    The method signatures accept the same arguments as ``LinearOperator`` so
    the operator can be called through the base-class interface.
    """

    def __init__(self, device, batch_size=1, **kwargs):
        """
        :param device: stored on the instance.
        :param batch_size: stored on the instance.
        :param kwargs: accepted and ignored.
        """
        self.device = device
        self.batch_size = batch_size

    def forward(self, data, **kwargs):
        """Return ``data`` unchanged."""
        return data

    def transpose(self, data, **kwargs):
        """Return ``data`` unchanged."""
        return data

    def ortho_project(self, data, **kwargs):
        """Return ``data`` unchanged."""
        return data

    def project(self, data, measurement=None, **kwargs):
        """Return ``data`` unchanged; ``measurement`` is accepted and ignored."""
        return data


@register_operator(name='rgb_guidance')
class RGBGuidanceOperator(LinearOperator):
    """
    Identity operator registered as 'rgb_guidance': every method returns ``data``.

    The method signatures accept the same arguments as ``LinearOperator`` so
    the operator can be called through the base-class interface.
    """

    def __init__(self, device, batch_size=1, **kwargs):
        """
        :param device: stored on the instance.
        :param batch_size: stored on the instance.
        :param kwargs: accepted and ignored.
        """
        self.device = device
        self.batch_size = batch_size

    def forward(self, data, **kwargs):
        """Return ``data`` unchanged."""
        return data

    def transpose(self, data, **kwargs):
        """Return ``data`` unchanged."""
        return data

    def ortho_project(self, data, **kwargs):
        """Return ``data`` unchanged."""
        return data

    def project(self, data, measurement=None, **kwargs):
        """Return ``data`` unchanged; ``measurement`` is accepted and ignored."""
        return data


# osmosis - learnable Operator
class LearnableOperator(ABC):
    """Abstract base for operators; subclasses must implement ``forward``."""

    @abstractmethod
    def forward(self, data, **kwargs):
        """Apply the operator to ``data``."""


@register_operator(name='haze_physical')
class HazePhysicalOperator(LearnableOperator):
    """
    Operator registered as 'haze_physical' with two adjustable tensors,
    ``phi_ab`` and ``phi_inf``.

    ``forward`` renders
    ``rgb * exp(-phi_ab * depth) + phi_inf * (1 - exp(-phi_ab * depth))``.
    """

    def __init__(self, device, phi_ab, phi_inf, phi_ab_eta=1e-5, phi_inf_eta=1e-5,
                 phi_ab_learn_flag=True, phi_inf_learn_flag=True,
                 batch_size=1, **kwargs):
        """
        :param device: the device the phi tensors are placed on.
        :param phi_ab: initial value, converted with ``float``.
        :param phi_inf: initial values as a comma-separated string.
        :param phi_ab_eta: step size for phi_ab (forced to 0 if not learned).
        :param phi_inf_eta: step size for phi_inf (forced to 0 if not learned).
        :param phi_ab_learn_flag: whether phi_ab keeps a non-zero step size.
        :param phi_inf_learn_flag: whether phi_inf keeps a non-zero step size.
        :param batch_size: number of times each phi is repeated along dim 0.
        :param kwargs: optional "depth_type", "value" and "optimizer".
        """
        self.device = device
        self.depth_type = kwargs.get("depth_type", None)
        self.value = get_depth_value(kwargs.get("value", None))

        # initialization values
        # phi_ab is a single scalar -> shape (batch_size, 1, 1, 1)
        scalar_ab = torch.tensor(float(phi_ab)).to(device)
        self.phi_ab = scalar_ab.repeat(batch_size, 1).unsqueeze(-1).unsqueeze(-1)

        # phi_inf has one entry per comma-separated value -> (batch_size, n, 1, 1)
        inf_values = np.fromstring(phi_inf, dtype=float, sep=',')
        inf_tensor = torch.tensor(inf_values, dtype=torch.float, device=device)
        self.phi_inf = inf_tensor.repeat(batch_size, 1).unsqueeze(-1).unsqueeze(-1)

        self.phi_ab_learn_flag = phi_ab_learn_flag
        self.phi_inf_learn_flag = phi_inf_learn_flag

        # gradient-descent step sizes; a variable that is not learned gets 0
        self.phi_ab_eta = float(phi_ab_eta) if phi_ab_learn_flag else 0.0
        self.phi_inf_eta = float(phi_inf_eta) if phi_inf_learn_flag else 0.0

        # set optimizer
        param_groups = [{'params': self.phi_ab, "lr": self.phi_ab_eta},
                        {'params': self.phi_inf, "lr": self.phi_inf_eta}]
        self.optimizer = get_optimizer(optimizer_name=kwargs.get("optimizer", None),
                                       model_parameters=param_groups)

    def forward(self, data, **kwargs):
        """
        Render the degraded image from ``data``.

        :param data: tensor indexed as [batch, channel, :, :]; the last channel
                     is treated as depth and the others as RGB in [-1, 1].
        :return: the rendered image.
        """
        # split into rgb (rescaled from [-1, 1] to [0, 1]) and depth
        rgb_unit = 0.5 * (data[:, 0:-1, :, :] + 1)
        raw_depth = data[:, -1, :, :].unsqueeze(1)

        # convert depth to relevant coordinates
        depth = convert_depth(depth=raw_depth, depth_type=self.depth_type, value=self.value)

        # the image formation model
        attenuated = rgb_unit * torch.exp(-self.phi_ab * depth)
        ambient = self.phi_inf * (1 - torch.exp(-self.phi_ab * depth))
        return attenuated + ambient

    def optimize(self, **kwargs):
        """
        Apply one update to the phi tensors unless ``freeze_phi`` is set.

        :param kwargs: optional "freeze_phi" (default False) skips the update.
        :return: a dict with detached 'phi_ab' and 'phi_inf'.
        """
        if not kwargs.get("freeze_phi", False):
            # no optimizer was specified - GD is the default
            plain_gd = self.optimizer is None or self.optimizer == "GD" or self.optimizer == ""

            if plain_gd:
                # only the variables that currently require grad are touched
                candidates = ((self.phi_ab, self.phi_ab_eta), (self.phi_inf, self.phi_inf_eta))
                active = [(param, eta) for param, eta in candidates if param.requires_grad]

                # classic gradient descent
                with torch.no_grad():
                    for param, eta in active:
                        param.add_(param.grad, alpha=-eta)

                # zero the gradients so they will not accumulate
                for param, _ in active:
                    param.grad.zero_()
            else:
                # optimizer was specified
                self.optimizer.step()
                self.optimizer.zero_grad()

        return {'phi_ab': self.phi_ab.detach(), 'phi_inf': self.phi_inf.detach()}

    def get_variable_gradients(self, **kwargs):
        """Return the ``requires_grad`` flag of each phi tensor, keyed by name."""
        return dict(phi_ab=self.phi_ab.requires_grad,
                    phi_inf=self.phi_inf.requires_grad)

    def set_variable_gradients(self, value=None, **kwargs):
        """
        Set ``requires_grad`` on the phi tensors.

        :param value: a single flag for both tensors, or a dict with the keys
                      "phi_ab" and "phi_inf".
        :raises ValueError: if ``value`` is None.
        """
        if value is None:
            raise ValueError("A value should be specified (True or False for general or dictionary)")

        per_variable = isinstance(value, dict)
        self.phi_ab.requires_grad_(value["phi_ab"] if per_variable else value)
        self.phi_inf.requires_grad_(value["phi_inf"] if per_variable else value)

    def get_variable_list(self, **kwargs):
        """Return the phi tensors as ``[phi_ab, phi_inf]``."""
        return [self.phi_ab, self.phi_inf]


@register_operator(name='underwater_physical_revised')
class UnderWaterPhysicalRevisedOperator(LearnableOperator):
    """
    Operator registered as 'underwater_physical_revised' with three adjustable
    tensors, ``phi_a``, ``phi_b`` and ``phi_inf``.

    ``forward`` renders
    ``rgb * exp(-phi_a * depth) + phi_inf * (1 - exp(-phi_b * depth))``.
    """

    def __init__(self, device, phi_a, phi_b, phi_inf,
                 phi_a_eta=1e-5, phi_b_eta=1e-5, phi_inf_eta=1e-5,
                 phi_a_learn_flag=True, phi_b_learn_flag=True, phi_inf_learn_flag=True,
                 batch_size=1, **kwargs):
        """
        :param device: the device the phi tensors are created on.
        :param phi_a: initial values as a comma-separated string.
        :param phi_b: initial values as a comma-separated string.
        :param phi_inf: initial values as a comma-separated string.
        :param phi_a_eta: step size for phi_a (forced to 0 if not learned).
        :param phi_b_eta: step size for phi_b (forced to 0 if not learned).
        :param phi_inf_eta: step size for phi_inf (forced to 0 if not learned).
        :param phi_a_learn_flag: whether phi_a keeps a non-zero step size.
        :param phi_b_learn_flag: whether phi_b keeps a non-zero step size.
        :param phi_inf_learn_flag: whether phi_inf keeps a non-zero step size.
        :param batch_size: number of times each phi is repeated along dim 0.
        :param kwargs: optional "depth_type", "value" and "optimizer".
        """

        def as_batched(text):
            # comma-separated string -> tensor of shape (batch_size, n, 1, 1)
            flat = torch.tensor(np.fromstring(text, dtype=float, sep=','), dtype=torch.float, device=device)
            return flat.repeat(batch_size, 1).unsqueeze(-1).unsqueeze(-1)

        self.device = device

        self.depth_type = kwargs.get("depth_type", None)
        self.value = get_depth_value(kwargs.get("value", None))

        # initialization values
        self.phi_a = as_batched(phi_a)
        self.phi_b = as_batched(phi_b)
        self.phi_inf = as_batched(phi_inf)

        # learning flags
        self.phi_a_learn_flag = phi_a_learn_flag
        self.phi_b_learn_flag = phi_b_learn_flag
        self.phi_inf_learn_flag = phi_inf_learn_flag

        # gradient-descent step sizes; a variable that is not learned gets 0
        self.phi_a_eta = float(phi_a_eta) if phi_a_learn_flag else 0.0
        self.phi_b_eta = float(phi_b_eta) if phi_b_learn_flag else 0.0
        self.phi_inf_eta = float(phi_inf_eta) if phi_inf_learn_flag else 0.0

        # set optimizer
        param_groups = [{'params': self.phi_a, "lr": self.phi_a_eta},
                        {'params': self.phi_b, "lr": self.phi_b_eta},
                        {'params': self.phi_inf, "lr": self.phi_inf_eta}]
        self.optimizer = get_optimizer(optimizer_name=kwargs.get("optimizer", None),
                                       model_parameters=param_groups)

    def forward(self, data, **kwargs):
        """
        Render the degraded image from ``data``.

        :param data: tensor indexed as [batch, channel, :, :]; the last channel
                     is treated as depth and the others as RGB in [-1, 1].
        :return: the rendered image.
        """
        # split into rgb (rescaled from [-1, 1] to [0, 1]) and depth
        rgb_unit = 0.5 * (data[:, 0:-1, :, :] + 1)
        raw_depth = data[:, -1, :, :].unsqueeze(1)

        # convert depth to relevant coordinates
        depth = convert_depth(depth=raw_depth, depth_type=self.depth_type, value=self.value)

        # the underwater image formation model
        attenuated = rgb_unit * torch.exp(-self.phi_a * depth)
        ambient = self.phi_inf * (1 - torch.exp(-self.phi_b * depth))
        return attenuated + ambient

    def optimize(self, **kwargs):
        """
        Apply one update to the phi tensors unless ``freeze_phi`` is set.

        :param kwargs: optional "freeze_phi" (default False) skips the update.
        :return: a dict with detached 'phi_a', 'phi_b' and 'phi_inf'.
        """
        if not kwargs.get("freeze_phi", False):
            # no optimizer was specified - GD is the default
            plain_gd = self.optimizer is None or self.optimizer == "GD" or self.optimizer == ""

            if plain_gd:
                # only the variables that currently require grad are touched
                candidates = ((self.phi_a, self.phi_a_eta),
                              (self.phi_b, self.phi_b_eta),
                              (self.phi_inf, self.phi_inf_eta))
                active = [(param, eta) for param, eta in candidates if param.requires_grad]

                # classic gradient descent
                with torch.no_grad():
                    for param, eta in active:
                        param.add_(param.grad, alpha=-eta)

                # zero the gradients so they will not accumulate
                for param, _ in active:
                    param.grad.zero_()
            else:
                self.optimizer.step()
                self.optimizer.zero_grad()

        return {'phi_a': self.phi_a.detach(), 'phi_b': self.phi_b.detach(), 'phi_inf': self.phi_inf.detach()}

    def get_variable_gradients(self, **kwargs):
        """Return the ``requires_grad`` flag of each phi tensor, keyed by name."""
        return dict(phi_a=self.phi_a.requires_grad,
                    phi_b=self.phi_b.requires_grad,
                    phi_inf=self.phi_inf.requires_grad)

    def set_variable_gradients(self, value=None, **kwargs):
        """
        Set ``requires_grad`` on the phi tensors.

        :param value: a single flag for all tensors, or a dict with the keys
                      "phi_a", "phi_b" and "phi_inf".
        :raises ValueError: if ``value`` is None.
        """
        if value is None:
            raise ValueError("A value should be specified (True or False for general or dictionary)")

        per_variable = isinstance(value, dict)
        self.phi_a.requires_grad_(value["phi_a"] if per_variable else value)
        self.phi_b.requires_grad_(value["phi_b"] if per_variable else value)
        self.phi_inf.requires_grad_(value["phi_inf"] if per_variable else value)

    def get_variable_list(self, **kwargs):
        """Return the phi tensors as ``[phi_a, phi_b, phi_inf]``."""
        return [self.phi_a, self.phi_b, self.phi_inf]


@register_operator(name='underwater_physical')
class UnderWaterPhysicalOperator(LearnableOperator):
    """
    Operator registered as 'underwater_physical' with two adjustable tensors,
    ``phi_ab`` and ``phi_inf``.

    ``forward`` renders
    ``rgb * exp(-phi_ab * depth) + phi_inf * (1 - exp(-phi_ab * depth))``.
    """

    def __init__(self, device, phi_ab, phi_inf, phi_ab_eta=1e-5, phi_inf_eta=1e-5,
                 phi_ab_learn_flag=True, phi_inf_learn_flag=True,
                 batch_size=1, **kwargs):
        """
        :param device: the device the phi tensors are created on.
        :param phi_ab: initial values as a comma-separated string.
        :param phi_inf: initial values as a comma-separated string.
        :param phi_ab_eta: step size for phi_ab (forced to 0 if not learned).
        :param phi_inf_eta: step size for phi_inf (forced to 0 if not learned).
        :param phi_ab_learn_flag: whether phi_ab keeps a non-zero step size.
        :param phi_inf_learn_flag: whether phi_inf keeps a non-zero step size.
        :param batch_size: number of times each phi is repeated along dim 0.
        :param kwargs: optional "depth_type", "value" and "optimizer".
        """

        def as_batched(text):
            # comma-separated string -> tensor of shape (batch_size, n, 1, 1)
            flat = torch.tensor(np.fromstring(text, dtype=float, sep=','), dtype=torch.float, device=device)
            return flat.repeat(batch_size, 1).unsqueeze(-1).unsqueeze(-1)

        self.device = device
        self.depth_type = kwargs.get("depth_type", None)
        self.value = get_depth_value(kwargs.get("value", None))

        # initialization values
        self.phi_ab = as_batched(phi_ab)
        self.phi_inf = as_batched(phi_inf)

        self.phi_ab_learn_flag = phi_ab_learn_flag
        self.phi_inf_learn_flag = phi_inf_learn_flag

        # gradient-descent step sizes; a variable that is not learned gets 0
        self.phi_ab_eta = float(phi_ab_eta) if phi_ab_learn_flag else 0.0
        self.phi_inf_eta = float(phi_inf_eta) if phi_inf_learn_flag else 0.0

        # set optimizer
        param_groups = [{'params': self.phi_ab, "lr": self.phi_ab_eta},
                        {'params': self.phi_inf, "lr": self.phi_inf_eta}]
        self.optimizer = get_optimizer(optimizer_name=kwargs.get("optimizer", None),
                                       model_parameters=param_groups)

    def forward(self, data, **kwargs):
        """
        Render the degraded image from ``data``.

        :param data: tensor indexed as [batch, channel, :, :]; the last channel
                     is treated as depth and the others as RGB in [-1, 1].
        :return: the rendered image.
        """
        # split into rgb (rescaled from [-1, 1] to [0, 1]) and depth
        rgb_unit = 0.5 * (data[:, 0:-1, :, :] + 1)
        raw_depth = data[:, -1, :, :].unsqueeze(1)

        # convert depth to relevant coordinates
        depth = convert_depth(depth=raw_depth, depth_type=self.depth_type, value=self.value)

        # the underwater image formation model
        attenuated = rgb_unit * torch.exp(-self.phi_ab * depth)
        ambient = self.phi_inf * (1 - torch.exp(-self.phi_ab * depth))
        return attenuated + ambient

    def optimize(self, **kwargs):
        """
        Apply one update to the phi tensors unless ``freeze_phi`` is set.

        :param kwargs: optional "freeze_phi" (default False) skips the update.
        :return: a dict with detached 'phi_ab' and 'phi_inf'.
        """
        if not kwargs.get("freeze_phi", False):
            # no optimizer was specified - GD is the default
            plain_gd = self.optimizer is None or self.optimizer == "GD" or self.optimizer == ""

            if plain_gd:
                # only the variables that currently require grad are touched
                candidates = ((self.phi_ab, self.phi_ab_eta), (self.phi_inf, self.phi_inf_eta))
                active = [(param, eta) for param, eta in candidates if param.requires_grad]

                # classic gradient descent
                with torch.no_grad():
                    for param, eta in active:
                        param.add_(param.grad, alpha=-eta)

                # zero the gradients so they will not accumulate
                for param, _ in active:
                    param.grad.zero_()
            else:
                # optimizer was specified
                self.optimizer.step()
                self.optimizer.zero_grad()

        return {'phi_ab': self.phi_ab.detach(), 'phi_inf': self.phi_inf.detach()}

    def get_variable_gradients(self, **kwargs):
        """Return the ``requires_grad`` flag of each phi tensor, keyed by name."""
        return dict(phi_ab=self.phi_ab.requires_grad,
                    phi_inf=self.phi_inf.requires_grad)

    def set_variable_gradients(self, value=None, **kwargs):
        """
        Set ``requires_grad`` on the phi tensors.

        :param value: a single flag for both tensors, or a dict with the keys
                      "phi_ab" and "phi_inf".
        :raises ValueError: if ``value`` is None.
        """
        if value is None:
            raise ValueError("A value should be specified (True or False for general or dictionary)")

        per_variable = isinstance(value, dict)
        self.phi_ab.requires_grad_(value["phi_ab"] if per_variable else value)
        self.phi_inf.requires_grad_(value["phi_inf"] if per_variable else value)

    def get_variable_list(self, **kwargs):
        """Return the phi tensors as ``[phi_ab, phi_inf]``."""
        return [self.phi_ab, self.phi_inf]


# =============
# Noise classes
# =============


__NOISE__ = {}


def register_noise(name: str):
    """
    Class decorator factory that stores the decorated class in ``__NOISE__`` under ``name``.

    :param name: registry key of the noise class.
    :raises NameError: (at decoration time) when the name is already taken.
    """

    def wrapper(cls):
        already_taken = bool(__NOISE__.get(name, None))
        if already_taken:
            raise NameError(f"Name {name} is already defined!")
        __NOISE__[name] = cls
        return cls

    return wrapper


def get_noise(name: str, **kwargs):
    """
    Instantiate the noise class registered under ``name``.

    :param name: registry key used with ``register_noise``.
    :param kwargs: forwarded to the class constructor.
    :return: the instance, with its ``__name__`` attribute set to ``name``.
    :raises NameError: when nothing is registered under ``name``.
    """
    noise_cls = __NOISE__.get(name, None)
    if noise_cls is None:
        raise NameError(f"Name {name} is not defined.")

    noiser = noise_cls(**kwargs)
    # the registry name is attached to the instance (ConditioningMethod reads it)
    noiser.__name__ = name
    return noiser


class Noise(ABC):
    """
    Abstract base of the noise models: calling an instance applies ``forward`` to the data.
    """

    def __call__(self, data):
        # makes the instance usable as a function: noiser(data)
        return self.forward(data)

    @abstractmethod
    def forward(self, data):
        """Return the noisy version of ``data``; must be implemented by subclasses."""
        pass


@register_noise(name='clean')
class Clean(Noise):
    """Noise-free model, registered as 'clean': the data is returned untouched."""

    def forward(self, data):
        # identity - no noise is added
        return data


@register_noise(name='gaussian')
class GaussianNoise(Noise):
    """Additive zero-mean Gaussian noise, registered as 'gaussian'."""

    def __init__(self, sigma):
        # standard deviation of the added noise
        self.sigma = sigma

    def forward(self, data):
        """Return ``data`` plus standard normal noise scaled by ``sigma``."""
        # randn_like already samples on the device (and with the dtype) of data
        unit_noise = torch.randn_like(data)
        return data + unit_noise * self.sigma


@register_noise(name='poisson')
class PoissonNoise(Noise):
    """Poisson (shot) noise, registered as 'poisson'."""

    def __init__(self, rate):
        # multiplies the expected photon count; a larger rate gives relatively less noise
        self.rate = rate

    def forward(self, data):
        """
        Apply Poisson noise to data given in the range [-1, 1].

        The data is mapped to [0, 1], scaled to ``255 * rate`` expected counts, sampled with
        numpy's Poisson generator (on the CPU), scaled back and mapped to [-1, 1] again.

        :param data: tensor with values in [-1, 1] (values outside are clamped).
        :return: noisy tensor in [-1, 1] on the device and with the dtype of ``data``.
        """

        # TODO: fix the additional Poisson noise - osmosis_utils - adaption for debka

        device, dtype = data.device, data.dtype

        # [-1, 1] -> [0, 1]
        data = ((data + 1.0) / 2.0).clamp(0, 1)

        # numpy samples on the CPU; the result of the division is a float64 array
        counts = np.random.poisson(data.detach().cpu() * 255.0 * self.rate)
        noisy = torch.from_numpy(counts / 255.0 / self.rate)

        # [0, 1] -> [-1, 1]
        noisy = (noisy * 2.0 - 1.0).clamp(-1, 1)

        # restore the dtype as well as the device - otherwise a float64 tensor is returned
        return noisy.to(device=device, dtype=dtype)

In [21]:
import os
import numpy as np
import torch
import torch.nn as nn

# %% base functions for getting loss

# registry of loss classes: name -> class, filled by the register_loss decorator
__LOSS__ = dict()


def register_loss(name: str):
    """
    Class decorator factory that stores the decorated class in ``__LOSS__`` under ``name``.

    :param name: registry key of the loss class.
    :raises NameError: (at decoration time) when the name is already taken.
    """

    def wrapper(cls):
        already_taken = bool(__LOSS__.get(name, None))
        if already_taken:
            raise NameError(f"Name {name} is already registered!")
        __LOSS__[name] = cls
        return cls

    return wrapper


def get_loss(name: str, **kwargs):
    """
    Instantiate the loss class registered under ``name``.

    :param name: registry key used with ``register_loss``.
    :param kwargs: forwarded to the class constructor.
    :raises NameError: when nothing is registered under ``name``.
    """
    loss_cls = __LOSS__.get(name, None)
    if loss_cls is None:
        raise NameError(f"Name {name} is not defined.")
    return loss_cls(**kwargs)


# %% global exposure loss

@register_loss(name='avrg_loss')
class Average_Loss(nn.Module):
    """
    Global Exposure Control Loss

    Sum, over batch and color channels, of the absolute value of the spatial mean.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """
        :param x: 4D tensor; only the first three channels (rgb) are used, depth is ignored.
                  The values are expected in [-1,1].
        :return: scalar loss tensor.
        """
        color = x[:, 0:3, :, :]
        # mean over the two last (spatial) dimensions, one value per image and channel
        spatial_mean = color.mean(dim=(2, 3))
        return spatial_mean.abs().sum()


# %% Value loss

@register_loss(name='val_loss')
class Value_Loss(nn.Module):
    """
    Penalizes color values whose magnitude exceeds a threshold (mean squared excess).
    """

    def __init__(self, device=torch.device("cuda:0"), **kwargs):
        super().__init__()
        # stored only; forward does not read it
        self.device = torch.device(device)

    def forward(self, rgbd, **kwargs):
        """
        :param rgbd: 4D tensor; only the first three channels (rgb) are used.
        :param kwargs: optional "value" - the magnitude threshold (default 0.7).
        :return: scalar loss tensor.
        """
        threshold = kwargs.get("value", 0.7)
        color = rgbd[:, 0:3, :, :]
        # how far |color| is above the threshold, zero where it is below
        excess = torch.maximum(color.abs() - threshold, torch.zeros_like(color))
        return (excess ** 2).mean()


# %% Auxiliary loss class which includes all the quality losses and their coefficients

class AuxiliaryLoss(nn.Module):
    """
    Weighted sum of registered losses.

    :param losses_dictionary: maps a registered loss name to its coefficient (gamma).
    """

    def __init__(self, losses_dictionary):
        super().__init__()

        self.losses_dictionary = losses_dictionary
        # one loss instance and one coefficient tensor per dictionary entry, in dictionary order
        self.losses_list = list(map(get_loss, losses_dictionary))
        self.loss_gammas = list(map(torch.tensor, losses_dictionary.values()))

    def forward(self, x):
        """
        :param x: input passed to every loss.
        :return: (weighted sum of the losses, dictionary of the detached un-weighted losses on cpu).
        """
        aux_loss = 0
        aux_loss_dict = {}
        entries = zip(self.losses_dictionary, self.loss_gammas, self.losses_list)
        for loss_name, gamma, criterion in entries:
            loss_value = criterion.forward(x)
            # the coefficient is moved to the device of the input before weighting
            aux_loss += gamma.to(x.device) * loss_value
            aux_loss_dict[loss_name] = loss_value.detach().cpu()
        return aux_loss, aux_loss_dict

In [22]:
from abc import ABC, abstractmethod
import torch
import numpy as np
# import losses as losseso
# import utils as utilso
import copy

# registry of conditioning classes: name -> class, filled by register_conditioning_method
__CONDITIONING_METHOD__ = dict()


def register_conditioning_method(name: str):
    """
    Class decorator factory that stores the decorated class in
    ``__CONDITIONING_METHOD__`` under ``name``.

    :param name: registry key of the conditioning class.
    :raises NameError: (at decoration time) when the name is already taken.
    """

    def wrapper(cls):
        already_taken = bool(__CONDITIONING_METHOD__.get(name, None))
        if already_taken:
            raise NameError(f"Name {name} is already registered!")
        __CONDITIONING_METHOD__[name] = cls
        return cls

    return wrapper


def get_conditioning_method(name: str, operator, noiser, **kwargs):
    """
    Instantiate the conditioning class registered under ``name``.

    :param name: registry key used with ``register_conditioning_method``.
    :param operator: passed to the constructor as ``operator``.
    :param noiser: passed to the constructor as ``noiser``.
    :param kwargs: forwarded to the constructor.
    :raises NameError: when nothing is registered under ``name``.
    """
    method_cls = __CONDITIONING_METHOD__.get(name, None)
    if method_cls is None:
        raise NameError(f"Name {name} is not defined!")
    return method_cls(operator=operator, noiser=noiser, **kwargs)


class ConditioningMethod(ABC):
    """
    Base class of the conditioning (guidance) methods.

    :param operator: measurement operator; ``forward`` and ``project`` are called on it.
    :param noiser: noise model; its ``__name__`` selects the loss in ``grad_and_value``.
    """

    def __init__(self, operator, noiser, **kwargs):
        self.operator = operator
        self.noiser = noiser

    def project(self, data, noisy_measurement, **kwargs):
        """Delegate to ``operator.project`` with the noisy measurement."""
        return self.operator.project(data=data, measurement=noisy_measurement, **kwargs)

    def grad_and_value(self, x_prev, x_0_hat, measurement, **kwargs):
        """
        Compute the measurement loss and its gradient with respect to ``x_prev``.

        :param x_prev: tensor the gradient is taken with respect to.
        :param x_0_hat: prediction the operator is applied to.
        :param measurement: the observed measurement.
        :return: (gradient of the loss w.r.t. x_prev, loss).
        :raises NotImplementedError: for noise names other than 'gaussian' and 'poisson'.
        """
        noise_name = self.noiser.__name__
        if noise_name not in ('gaussian', 'poisson'):
            raise NotImplementedError

        if noise_name == 'gaussian':
            # only the first three channels of the prediction go through the operator
            residual = measurement - self.operator.forward(x_0_hat[:, 0:3], **kwargs)
            loss = torch.linalg.norm(residual)
        else:
            # poisson: the residual norm is normalized by the measurement magnitude
            residual = measurement - self.operator.forward(x_0_hat, **kwargs)
            loss = (torch.linalg.norm(residual) / measurement.abs()).mean()

        loss_grad = torch.autograd.grad(outputs=loss, inputs=x_prev)[0]
        return loss_grad, loss

    @abstractmethod
    def conditioning(self, x_prev, x_t, x_0_hat, measurement, **kwargs):
        """Apply the guidance step; must be implemented by subclasses."""
        pass


@register_conditioning_method(name='osmosis')
class PosteriorSamplingOsmosis(ConditioningMethod):
    """
    Posterior sampling guidance, registered as 'osmosis'.

    The guidance loss compares the measurement with the operator applied to the predicted
    x_0; the operator's own variables (obtained from ``operator.get_variable_list()``) are
    optimized in the same pass unless ``freeze_phi`` is requested.

    Recognized keyword arguments (all optional):
        scale            - guidance scale: a number, or a comma separated string with one value per channel
        gradient_x_prev  - True: guidance gradient w.r.t. x_prev, False (default): w.r.t. the x_0 prediction
        pattern, global_N, local_M, update_start - stored only; not read in this class
        n_iter           - number of inner optimization steps of the operator variables
        aux_loss         - dictionary {registered loss name: coefficient}
        loss_function    - "norm" (default) or "mse"
        loss_weight, weight_function - forwarded to ``set_loss_weight``
        gradient_clip    - "False" or "True,<value>" (a plain bool is accepted as well)
    """

    def __init__(self, operator, noiser, **kwargs):
        super().__init__(operator, noiser)
        input_scale_str = kwargs.get('scale', 1.0)

        # in case scale is single for all channels
        try:
            self.scale = torch.tensor([float(input_scale_str)])

        # in case there is a scale value for each channel
        except ValueError:
            self.scale = torch.tensor([float(num_str.strip()) for num_str in input_scale_str.split(',')])

        self.gradient_x_prev = kwargs.get('gradient_x_prev', False)

        # sample pattern parameters
        self.pattern_name = kwargs.get('pattern', 'original')
        self.global_N = kwargs.get('global_N', 1)
        self.local_M = kwargs.get('local_M', 1)
        self.n_iter = kwargs.get('n_iter', 1)
        self.update_start = kwargs.get('update_start', 1.0)

        # Auxiliary loss information
        aux_loss_dict = kwargs.get("aux_loss", None)
        if aux_loss_dict is not None:
            aux_loss_dict = {key_ii: float(value_ii) for key_ii, value_ii in aux_loss_dict.items()}
            # Quality loss object
            self.aux_loss = AuxiliaryLoss(aux_loss_dict)
        else:
            self.aux_loss = None

        # guiding loss function, loss weight (depth or none), is depth - what function and values
        self.loss_function = kwargs.get("loss_function", "norm")
        self.loss_weight = kwargs.get("loss_weight", None)
        self.weight_function = kwargs.get("weight_function", None)

        # use gradient clipping (or image clipping); str() lets a real bool from the
        # configuration go through the same "flag[,value]" parsing as a string
        gradient_clip_tmp = str(kwargs.get("gradient_clip", "False")).split(',')
        self.gradient_clip = str2bool(gradient_clip_tmp[0].strip())

        # if true - what values
        if self.gradient_clip:
            if len(gradient_clip_tmp) < 2:
                raise ValueError("gradient_clip is enabled but no clip value was given, "
                                 "expected 'True,<value>'")
            self.gradient_clip_value = float(gradient_clip_tmp[1].strip())
        else:
            self.gradient_clip_value = None

    def grad_and_value(self, x_prev, x_0_hat, measurement, **kwargs):
        """
        Compute the guidance loss (no gradient is computed here, despite the inherited name).

        :param x_prev: not used here; kept for the signature of the base class.
        :param x_0_hat: predicted x_0 the operator is applied to.
        :param measurement: the observed measurement.
        :param kwargs: forwarded to ``operator.forward``.
        :return: (per-image loss as a numpy array, loss tensor to differentiate,
                  detached operator output).
        :raises NotImplementedError: for loss functions other than "norm" and "mse".
        """

        # compute the degraded image on the unet prediction (operator) - in measurement file
        degraded_image_tmp = self.operator.forward(x_0_hat, **kwargs)

        # back to [-1,1] - assumes the operator output is in [0,1], TODO confirm
        degraded_image = 2 * degraded_image_tmp - 1

        difference = (measurement - degraded_image)

        # create the loss weights - multiply the differences
        loss_weight = set_loss_weight(loss_weight_type=self.loss_weight,
                                      weight_function=self.weight_function,
                                      degraded_image=degraded_image_tmp.detach(),
                                      x_0_hat=x_0_hat.detach())
        difference = difference * loss_weight

        # loss function - norm2
        if self.loss_function == 'norm':
            loss = torch.linalg.norm(difference)
            # calculated for visualization
            sep_loss = torch.norm(difference.detach().cpu(), p=2, dim=[1, 2, 3]).numpy()

        # Mean square error
        elif self.loss_function == "mse":
            mse = difference ** 2
            mse = mse.mean(dim=(1, 2, 3))
            loss = mse.sum()
            # calculated for visualization
            sep_loss = mse.detach().cpu().numpy()

        # No other loss
        else:
            raise NotImplementedError

        return sep_loss, loss, degraded_image_tmp.detach()

    def conditioning(self, x_prev, x_t, x_0_hat, measurement, **kwargs):
        """
        Run the inner optimization of the operator variables and apply the guidance step to x_t.

        :param x_prev: previous sample; the guidance gradient is taken w.r.t. it when
                       ``gradient_x_prev`` is True.
        :param x_t: current sample, updated in place.
        :param x_0_hat: predicted x_0; the guidance gradient is taken w.r.t. it when
                        ``gradient_x_prev`` is False.
        :param measurement: the observed measurement.
        :param kwargs: "freeze_phi" (default False) skips the optimization of the operator
                       variables, "time_index" is forwarded to the operator.
        :return: (x_t, per-image loss, operator variables, un-clipped guidance gradient on cpu,
                  auxiliary losses dictionary or None).
        """

        freeze_phi = kwargs.get("freeze_phi", False)
        time_index = kwargs.get("time_index", None)

        # when the gradient is w.r.t x0, the x_prev gradients and history of the x0 prediction are not required
        if not self.gradient_x_prev:
            x_0_hat = x_0_hat.detach().requires_grad_(True)
            x_prev.requires_grad_(False)

        # The tensor the guidance gradient is accumulated into. It must be the one read
        # after the backward pass: backward(inputs=...) fills the .grad of the listed
        # tensors only, and rejects tensors which do not require gradients.
        grad_target = x_prev if self.gradient_x_prev else x_0_hat

        # calculate the losses
        with torch.set_grad_enabled(True):

            # phi's require gradients when we update them, hence when freeze_phi is False
            self.operator.set_variable_gradients(value=not freeze_phi)

            # the number of inner optimization num of steps should be 1 if freezing phi,
            # since there is no optimizing at all in this case
            inner_optimize_length = 1 if freeze_phi else self.n_iter

            for optimize_ii in range(inner_optimize_length):

                # compute the loss after applying the operator, sep_loss is relevant for multiple images
                sep_loss, loss, degraded_image_01 = self.grad_and_value(x_prev=x_prev,
                                                                        x_0_hat=x_0_hat,
                                                                        measurement=measurement,
                                                                        time_index=time_index)

                # total loss refers to the original loss or to the loss of the x
                if self.aux_loss is not None:
                    aux_loss, aux_loss_dict = self.aux_loss.forward(x_0_hat)
                    total_loss = loss + aux_loss
                else:
                    aux_loss_dict = None
                    total_loss = loss

                # calculate the backward graph
                if optimize_ii == (inner_optimize_length - 1):
                    if freeze_phi:
                        # only the guidance gradient is required
                        total_loss.backward(inputs=[grad_target])
                    else:
                        # guidance gradient and the gradients of the phi's
                        total_loss.backward(inputs=[grad_target] + self.operator.get_variable_list())
                else:
                    # when optimize only the phi's, we specify it for faster run time
                    total_loss.backward(inputs=self.operator.get_variable_list())

                # optimize phi's, in case of freeze phi true - optimization is not done
                variables_dict = self.operator.optimize(freeze_phi=freeze_phi)

            # update x_t
            with torch.no_grad():

                # reshape the scale according to [b,c,h,w]
                guidance_scale = self.scale[None, ..., None, None].to(x_prev.device)

                raw_grads = grad_target.grad

                # clipping is applied only to the gradient w.r.t x_prev
                if self.gradient_x_prev and self.gradient_clip:
                    grads = torch.clamp(raw_grads,
                                        min=-self.gradient_clip_value,
                                        max=self.gradient_clip_value)
                else:
                    grads = raw_grads

                x_t -= guidance_scale * grads
                # the returned gradient is the un-clipped one
                gradients = raw_grads.cpu()

        return x_t, sep_loss, variables_dict, gradients, aux_loss_dict


@register_conditioning_method(name='ps')
class PosteriorSampling(ConditioningMethod):
    """
    Posterior sampling guidance, registered as 'ps': x_t is moved against the gradient
    of the measurement loss computed by the base class.

    Keyword "scale": a number, or a comma separated string with one value per channel.
    """

    def __init__(self, operator, noiser, **kwargs):
        super().__init__(operator, noiser)

        raw_scale = kwargs.get('scale', 1.0)
        try:
            # a single scale shared by all channels
            scale_values = [float(raw_scale)]
        except ValueError:
            # one scale per channel
            scale_values = [float(token.strip()) for token in raw_scale.split(',')]
        self.scale = torch.tensor(scale_values)

    def conditioning(self, x_prev, x_t, x_0_hat, measurement, **kwargs):
        """
        :return: (x_t updated in place, loss value).
        """
        norm_grad, norm = self.grad_and_value(x_prev=x_prev, x_0_hat=x_0_hat, measurement=measurement, **kwargs)
        # the scale is reshaped to broadcast over [b,c,h,w]
        scale = self.scale[None, ..., None, None].to(x_prev.device)
        x_t -= norm_grad * scale

        return x_t, norm

In [23]:
"""
Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
"""

import os
import sys
import shutil
import os.path as osp
import json
import time
import datetime
import tempfile
import warnings
from collections import defaultdict
from contextlib import contextmanager

# Log levels: Logger.log writes a message when the logger's level is <= the message level.
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40

# Higher than every message level above, so a logger set to it writes no messages.
DISABLED = 50


class KVWriter(object):
    """Interface of the output formats which write key-value dictionaries."""

    def writekvs(self, kvs):
        # kvs: mapping of diagnostic name -> value; implemented by subclasses
        raise NotImplementedError


class SeqWriter(object):
    """Interface of the output formats which write sequences of strings (log messages)."""

    def writeseq(self, seq):
        # seq: iterable of strings; implemented by subclasses
        raise NotImplementedError


class HumanOutputFormat(KVWriter, SeqWriter):
    """
    Human readable output: key-value tables and plain message lines.

    :param filename_or_file: path of a file to create (opened and owned by this object),
                             or an already open, writable file-like object (not closed here).
    """

    def __init__(self, filename_or_file):
        if isinstance(filename_or_file, str):
            self.file = open(filename_or_file, "wt")
            self.own_file = True
        else:
            # this object only writes, so a writable object is what is required
            assert hasattr(filename_or_file, "write"), (
                    "expected file or str, got %s" % filename_or_file
            )
            self.file = filename_or_file
            self.own_file = False

    def writekvs(self, kvs):
        """Write the dictionary as a two column table framed by dashed lines."""
        # Create strings for printing
        key2str = {}
        for (key, val) in sorted(kvs.items()):
            if hasattr(val, "__float__"):
                valstr = "%-8.3g" % val
            else:
                valstr = str(val)
            key2str[self._truncate(key)] = self._truncate(valstr)

        # Find max widths
        if len(key2str) == 0:
            print("WARNING: tried to write empty key-value dict")
            return
        else:
            keywidth = max(map(len, key2str.keys()))
            valwidth = max(map(len, key2str.values()))

        # Write out the data
        dashes = "-" * (keywidth + valwidth + 7)
        lines = [dashes]
        for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
            lines.append(
                "| %s%s | %s%s |"
                % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
            )
        lines.append(dashes)
        self.file.write("\n".join(lines) + "\n")

        # Flush the output to the file
        self.file.flush()

    def _truncate(self, s):
        """Shorten strings longer than 30 characters to 27 characters plus '...'."""
        maxlen = 30
        return s[: maxlen - 3] + "..." if len(s) > maxlen else s

    def writeseq(self, seq):
        """Write the strings of ``seq`` on one line, separated by single spaces."""
        seq = list(seq)
        for (i, elem) in enumerate(seq):
            self.file.write(elem)
            if i < len(seq) - 1:  # add space unless this is the last one
                self.file.write(" ")
        self.file.write("\n")
        self.file.flush()

    def close(self):
        """Close the file, only when it was opened by this object."""
        if self.own_file:
            self.file.close()


class JSONOutputFormat(KVWriter):
    """
    Writes every key-value dictionary as one JSON object per line.

    :param filename: path of the file to create.
    """

    def __init__(self, filename):
        self.file = open(filename, "wt")

    def writekvs(self, kvs):
        """
        Append ``kvs`` as a JSON line; values which have a ``dtype`` attribute
        (e.g. numpy scalars) are converted to float first.

        The conversion is done on a copy: the same dictionary is handed by the logger
        to every output format, so it must not be modified here.
        """
        serializable = {
            key: float(val) if hasattr(val, "dtype") else val
            for key, val in kvs.items()
        }
        self.file.write(json.dumps(serializable) + "\n")
        self.file.flush()

    def close(self):
        self.file.close()


class CSVOutputFormat(KVWriter):
    """
    Writes key-value dictionaries as rows of a CSV file whose header grows when
    new keys appear (the file is then rewritten in place).
    """

    def __init__(self, filename):
        # "w+t": the file is also read back when the header has to be extended
        self.file = open(filename, "w+t")
        # column names, in the order they appear in the header
        self.keys = []
        self.sep = ","

    def writekvs(self, kvs):
        """Append one row with the values of ``kvs``; unknown keys become new columns."""
        # Add our current row to the history
        extra_keys = list(kvs.keys() - self.keys)
        extra_keys.sort()
        if extra_keys:
            # new columns: rewrite the whole file with the extended header
            self.keys.extend(extra_keys)
            self.file.seek(0)
            lines = self.file.readlines()
            self.file.seek(0)
            for (i, k) in enumerate(self.keys):
                if i > 0:
                    self.file.write(",")
                self.file.write(k)
            self.file.write("\n")
            # previous rows (the old header is skipped) get one empty cell per new column
            for line in lines[1:]:
                self.file.write(line[:-1])
                self.file.write(self.sep * len(extra_keys))
                self.file.write("\n")
        # the current row; keys missing from kvs (or None values) leave an empty cell
        for (i, k) in enumerate(self.keys):
            if i > 0:
                self.file.write(",")
            v = kvs.get(k)
            if v is not None:
                self.file.write(str(v))
        self.file.write("\n")
        self.file.flush()

    def close(self):
        self.file.close()


class TensorBoardOutputFormat(KVWriter):
    """
    Dumps key/value pairs into TensorBoard's numeric format.

    tensorflow is imported lazily in the constructor, so it is required only when
    this format is actually used.
    NOTE(review): ``tf.Summary`` and ``pywrap_tensorflow.EventsWriter`` look like
    TensorFlow 1.x APIs - confirm they exist in the installed version.
    """

    def __init__(self, dir):
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        # step number attached to the next written event
        self.step = 1
        prefix = "events"
        path = osp.join(osp.abspath(dir), prefix)
        import tensorflow as tf
        from tensorflow.python import pywrap_tensorflow
        from tensorflow.core.util import event_pb2
        from tensorflow.python.util import compat

        self.tf = tf
        self.event_pb2 = event_pb2
        self.pywrap_tensorflow = pywrap_tensorflow
        self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))

    def writekvs(self, kvs):
        """Write all the values of ``kvs`` (converted to float) as one event."""
        def summary_val(k, v):
            kwargs = {"tag": k, "simple_value": float(v)}
            return self.tf.Summary.Value(**kwargs)

        summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
        event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
        event.step = (
            self.step
        )  # is there any reason why you'd want to specify the step?
        self.writer.WriteEvent(event)
        self.writer.Flush()
        # every call is one step
        self.step += 1

    def close(self):
        # the writer is dropped after closing, so a second close() does nothing
        if self.writer:
            self.writer.Close()
            self.writer = None


def make_output_format(format, ev_dir, log_suffix=""):
    """
    Create the output format object selected by ``format``.

    :param format: one of "stdout", "log", "json", "csv", "tensorboard".
    :param ev_dir: output directory, created when missing.
    :param log_suffix: inserted into the output file (or directory) name.
    :raises ValueError: for an unknown format.
    """
    os.makedirs(ev_dir, exist_ok=True)

    if format == "stdout":
        return HumanOutputFormat(sys.stdout)
    if format == "log":
        log_path = osp.join(ev_dir, "log%s.txt" % log_suffix)
        return HumanOutputFormat(log_path)
    if format == "json":
        json_path = osp.join(ev_dir, "progress%s.json" % log_suffix)
        return JSONOutputFormat(json_path)
    if format == "csv":
        csv_path = osp.join(ev_dir, "progress%s.csv" % log_suffix)
        return CSVOutputFormat(csv_path)
    if format == "tensorboard":
        tb_dir = osp.join(ev_dir, "tb%s" % log_suffix)
        return TensorBoardOutputFormat(tb_dir)

    raise ValueError("Unknown format specified: %s" % (format,))


# ================================================================
# API
# ================================================================


def logkv(key, val):
    """
    Log a value of some diagnostic on the current logger.
    Call this once for each diagnostic quantity, each iteration.
    If called many times, last value will be used.
    """
    get_current().logkv(key, val)


def logkv_mean(key, val):
    """
    The same as logkv(), but if called many times, the values are averaged
    (running mean kept by the current logger).
    """
    get_current().logkv_mean(key, val)


def logkvs(d):
    """
    Log a dictionary of key-value pairs, one logkv() call per item.
    """
    for (k, v) in d.items():
        logkv(k, v)


def dumpkvs():
    """
    Write all of the diagnostics from the current iteration.
    Returns the value of dumpkvs() of the current logger.
    """
    return get_current().dumpkvs()


def getkvs():
    """
    Return the name -> value mapping of the current logger (the object itself, not a copy).
    """
    return get_current().name2val


def log(*args, level=INFO):
    """
    Write the sequence of args as one message to the console and output files
    (if you've configured an output file). The args are converted to strings;
    HumanOutputFormat writes them separated by single spaces.

    The message is written only when the level of the current logger is <= level.
    """
    get_current().log(*args, level=level)


def debug(*args):
    """Shortcut of log() with level=DEBUG."""
    log(*args, level=DEBUG)


def info(*args):
    """Shortcut of log() with level=INFO."""
    log(*args, level=INFO)


def warn(*args):
    """Shortcut of log() with level=WARN."""
    log(*args, level=WARN)


def error(*args):
    """Shortcut of log() with level=ERROR."""
    log(*args, level=ERROR)


def set_level(level):
    """
    Set logging threshold on current logger (messages below it are not written).
    """
    get_current().set_level(level)


def set_comm(comm):
    """
    Set the communicator of the current logger; when it is not None,
    dumpkvs() averages the values over it with mpi_weighted_mean().
    """
    get_current().set_comm(comm)


def get_dir():
    """
    Get directory that log files are being written to (the dir of the current logger).
    NOTE(review): a logger created by configure() always has a directory; None is
    possible only for a Logger constructed directly with dir=None.
    """
    return get_current().get_dir()


# Aliases of logkv / dumpkvs (presumably kept for code written against the older names).
record_tabular = logkv
dump_tabular = dumpkvs


@contextmanager
def profile_kv(scopename):
    """
    Context manager which adds the wall-clock time spent inside the ``with`` body
    to the diagnostic "wait_<scopename>" of the current logger (also when the body raises).
    """
    logkey = "wait_" + scopename
    started_at = time.time()
    try:
        yield
    finally:
        elapsed = time.time() - started_at
        get_current().name2val[logkey] += elapsed


def profile(n):
    """
    Decorator factory: the decorated function runs inside ``profile_kv(n)``,
    so its run time is accumulated under the diagnostic "wait_<n>".

    Usage:
    @profile("my_func")
    def my_func(): code
    """
    # local import: this cell does not import functools itself
    import functools

    def decorator_with_name(func):
        # wraps keeps the name, docstring and signature metadata of the decorated function
        @functools.wraps(func)
        def func_wrapper(*args, **kwargs):
            with profile_kv(n):
                return func(*args, **kwargs)

        return func_wrapper

    return decorator_with_name


# ================================================================
# Backend
# ================================================================


def get_current():
    """
    Return the logger used by the module-level functions; a default logger is
    configured on first use when none exists yet.
    """
    if Logger.CURRENT is None:
        _configure_default_logger()

    return Logger.CURRENT


class Logger(object):
    """
    Collects diagnostics (key-value pairs) and messages and forwards them to output formats.

    :param dir: directory the outputs are written to (only stored here).
    :param output_formats: list of KVWriter / SeqWriter objects.
    :param comm: optional communicator; when given, dumped values are averaged over it.
    """

    DEFAULT = None  # A logger with no output files. (See right below class definition)
    # So that you can still log to the terminal without setting up any output files
    CURRENT = None  # Current logger being used by the free functions above

    def __init__(self, dir, output_formats, comm=None):
        self.name2val = defaultdict(float)  # values this iteration
        self.name2cnt = defaultdict(int)  # number of logkv_mean calls per key, this iteration
        self.level = INFO
        self.dir = dir
        self.output_formats = output_formats
        self.comm = comm

    # Logging API, forwarded
    # ----------------------------------------
    def logkv(self, key, val):
        """Store ``val`` under ``key``, replacing a previous value."""
        self.name2val[key] = val

    def logkv_mean(self, key, val):
        """Update the running mean stored under ``key`` with ``val``."""
        previous_mean = self.name2val[key]
        seen = self.name2cnt[key]
        updated_count = seen + 1
        self.name2val[key] = previous_mean * seen / updated_count + val / updated_count
        self.name2cnt[key] = updated_count

    def dumpkvs(self):
        """
        Write the collected values to every KVWriter, then clear them.

        :return: copy of the written dictionary.
        """
        if self.comm is None:
            d = self.name2val
        else:
            local_stats = {
                name: (val, self.name2cnt.get(name, 1))
                for (name, val) in self.name2val.items()
            }
            d = mpi_weighted_mean(self.comm, local_stats)
            if self.comm.rank != 0:
                d["dummy"] = 1  # so we don't get a warning about empty dict

        out = d.copy()  # Return the dict for unit testing purposes
        for writer in self.output_formats:
            if not isinstance(writer, KVWriter):
                continue
            writer.writekvs(d)

        self.name2val.clear()
        self.name2cnt.clear()
        return out

    def log(self, *args, level=INFO):
        """Write the message when the threshold of this logger is <= ``level``."""
        if self.level > level:
            return
        self._do_log(args)

    # Configuration
    # ----------------------------------------
    def set_level(self, level):
        self.level = level

    def set_comm(self, comm):
        self.comm = comm

    def get_dir(self):
        return self.dir

    def close(self):
        """Close all the output formats."""
        for writer in self.output_formats:
            writer.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args):
        """Send the args, converted to strings, to every SeqWriter."""
        for writer in self.output_formats:
            if not isinstance(writer, SeqWriter):
                continue
            writer.writeseq(map(str, args))


def get_rank_without_mpi_import():
    """
    Return the MPI rank of this process, 0 when no rank variable is set.

    The rank is read from the environment variables PMI_RANK and
    OMPI_COMM_WORLD_RANK (in this order) instead of importing mpi4py,
    to avoid calling MPI_Init() when this module is imported.
    """
    rank_variables = ("PMI_RANK", "OMPI_COMM_WORLD_RANK")
    ranks_found = (int(os.environ[varname]) for varname in rank_variables if varname in os.environ)
    return next(ranks_found, 0)


def mpi_weighted_mean(comm, local_name2valcount):
    """
    Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
    Perform a weighted average over dicts that are each on a different node
    Input: local_name2valcount: dict mapping key -> (value, count)
    Returns: key -> mean (on the node of rank 0; an empty dict on all the other nodes)

    Values which cannot be converted to float are skipped with a warning.
    """
    # gather is called on every node; only rank 0 computes the means
    all_name2valcount = comm.gather(local_name2valcount)
    if comm.rank != 0:
        return {}

    name2sum = defaultdict(float)
    name2count = defaultdict(float)
    for n2vc in all_name2valcount:
        for (name, (val, count)) in n2vc.items():
            try:
                val = float(val)
            # float() raises TypeError (not ValueError) for None, lists and alike
            except (TypeError, ValueError):
                warnings.warn(
                    "WARNING: tried to compute mean on non-float {}={}".format(
                        name, val
                    )
                )
                continue
            name2sum[name] += val * count
            name2count[name] += count
    return {name: name2sum[name] / name2count[name] for name in name2sum}


def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
    """
    Create a new logger and make it the current one (Logger.CURRENT).

    :param dir: output directory; defaults to $OPENAI_LOGDIR, then to a time-stamped
                directory inside the system temporary directory. Created when missing.
    :param format_strs: list of output format names; defaults to $OPENAI_LOG_FORMAT
                        ("stdout,log") on rank 0 and $OPENAI_LOG_FORMAT_MPI ("log") otherwise.
    :param comm: If comm is provided, average all numerical stats across that comm
    :param log_suffix: inserted into the output file names; "-rank%03i" is appended for rank > 0.
    """
    if dir is None:
        dir = os.getenv("OPENAI_LOGDIR")
    if dir is None:
        timestamp = datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f")
        dir = osp.join(tempfile.gettempdir(), timestamp)
    assert isinstance(dir, str)
    dir = os.path.expanduser(dir)
    os.makedirs(os.path.expanduser(dir), exist_ok=True)

    rank = get_rank_without_mpi_import()
    is_root = rank == 0
    if rank > 0:
        log_suffix = log_suffix + "-rank%03i" % rank

    if format_strs is None:
        # format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
        if is_root:
            env_name, fallback = "OPENAI_LOG_FORMAT", "stdout,log"
        else:
            env_name, fallback = "OPENAI_LOG_FORMAT_MPI", "log"
        format_strs = os.getenv(env_name, fallback).split(",")

    # empty names (e.g. from an empty environment variable) are skipped
    output_formats = [make_output_format(name, dir, log_suffix) for name in format_strs if name]

    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
    if output_formats:
        log("Logging to %s" % dir)


def _configure_default_logger():
    """Configure a logger with the default settings and remember it as Logger.DEFAULT."""
    configure()
    # configure() has just set Logger.CURRENT
    Logger.DEFAULT = Logger.CURRENT


def reset():
    """
    Close the current logger and go back to the default one.
    Nothing is done when the default logger is already the current one.
    """
    if Logger.CURRENT is Logger.DEFAULT:
        return

    Logger.CURRENT.close()
    Logger.CURRENT = Logger.DEFAULT
    log("Reset logger")


@contextmanager
def scoped_configure(dir=None, format_strs=None, comm=None):
    """
    Context manager which configures a temporary logger for the ``with`` body.

    On exit (also when the body raises) the temporary logger is closed and the
    previous current logger is restored. The arguments are forwarded to configure().
    """
    # remembered before configure() replaces Logger.CURRENT
    prevlogger = Logger.CURRENT
    configure(dir=dir, format_strs=format_strs, comm=comm)
    try:
        yield
    finally:
        Logger.CURRENT.close()
        Logger.CURRENT = prevlogger

In [36]:
import cv2
import os
from os.path import join as pjoin
import numpy as np
import glob
from PIL import Image
from natsort import natsorted

import torch
from torch.utils.data import Dataset


# %% ImageFolder Dataset

class ImagesFolder(Dataset):
    """
    Dataset of all the files of one directory, in natural sort order.

    :param root_dir: directory of the images; every entry of it is treated as an image.
    :param transform: optional callable applied to every loaded image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.images_list = natsorted(os.listdir(root_dir))
        self.transform = transform

    def __len__(self):
        return len(self.images_list)

    def __getitem__(self, idx):
        """
        :return: (image, file name). The image is loaded with PIL; when PIL fails,
                 OpenCV is used as a fallback.
        :raises OSError: when neither PIL nor OpenCV can read the file.
        """
        image_name = self.images_list[idx]
        image_path = os.path.join(self.root_dir, image_name)

        try:
            image = Image.open(image_path)
        # Exception (not a bare except), so KeyboardInterrupt / SystemExit are not swallowed
        except Exception as pil_error:
            print("\n**************\nexpect\n**************\n")
            image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
            # cv2.imread returns None instead of raising when the file cannot be read
            if image is None:
                raise OSError(f"Could not read image file: {image_path}") from pil_error
            # NOTE(review): integer division by 255 - the intended value range is unclear, confirm
            image = image // 255

        if self.transform is not None:
            image = self.transform(image)

        return image, image_name


# %% ImageFolder Dataset with gt (simulation)

class ImagesFolder_GT_results(Dataset):
    """
    Dataset pairing ground truth images with the result images of a simulation.

    :param gt_dir: directory of the ground truth images (every "*.*" file).
    :param results_dir: directory holding the "*ref.png", "*rgb.png" and "*depth.png" files.
    :param transform: optional callable applied to every loaded image.

    The four file lists are sorted naturally and matched by position.
    """

    def __init__(self, gt_dir, results_dir, transform=None):
        self.gt_dir = gt_dir
        self.results_dir = results_dir

        def sorted_matches(folder, pattern):
            return natsorted(glob.glob(pjoin(folder, pattern)))

        self.gt_list = sorted_matches(gt_dir, "*.*")
        self.simulate_list = sorted_matches(results_dir, "*ref.png")
        self.rgb_list = sorted_matches(results_dir, "*rgb.png")
        self.depth_list = sorted_matches(results_dir, "*depth.png")
        self.transform = transform

    def __len__(self):
        # the length is defined by the ground truth images
        return len(self.gt_list)

    def __getitem__(self, idx):
        """
        :return: (gt, simulate, rgb, depth, name of the ground truth file without extension).
        """
        image_name = os.path.splitext(os.path.basename(self.gt_list[idx]))[0]

        path_lists = (self.gt_list, self.simulate_list, self.rgb_list, self.depth_list)
        images = [Image.open(paths[idx]) for paths in path_lists]

        if self.transform is not None:
            images = [self.transform(image) for image in images]

        gt, simulate, rgb, depth = images
        return gt, simulate, rgb, depth, image_name


# %% ImageFolder Dataset with gt
class ImagesFolder_GT(Dataset):
    """
    Dataset of input images with matching ground-truth RGB and depth images.

    The three directories are globbed with ``*.*``, naturally sorted, and
    indexed in parallel.
    """

    def __init__(self, root_dir, gt_rgb_dir, gt_depth_dir, transform=None):
        """
        :param root_dir: directory with the input images.
        :param gt_rgb_dir: directory with the ground-truth RGB images.
        :param gt_depth_dir: directory with the ground-truth depth images.
        :param transform: optional callable applied to every returned image.
        """
        self.gt_rgb_dir = gt_rgb_dir
        self.gt_depth_dir = gt_depth_dir
        self.root_dir = root_dir

        self.gt_rgb_list = natsorted(glob.glob(pjoin(gt_rgb_dir, "*.*")))
        self.gt_depth_list = natsorted(glob.glob(pjoin(gt_depth_dir, "*.*")))
        self.images_list = natsorted(glob.glob(pjoin(root_dir, "*.*")))
        self.transform = transform

    def __len__(self):
        """Return the number of ground-truth RGB files."""
        return len(self.gt_rgb_list)

    def __getitem__(self, idx):
        """
        :param idx: position shared by the three file lists.
        :return: ``([image, gt_rgb_image, gt_depth_image], image_name)`` where
                 ``image_name`` is the input file's base name.
        :raises OSError: if OpenCV cannot read the depth file.
        """
        image_path = self.images_list[idx]
        image_name = os.path.basename(image_path)
        image = Image.open(image_path)
        gt_rgb_image = Image.open(self.gt_rgb_list[idx])

        depth_path = self.gt_depth_list[idx]
        depth_array = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
        if depth_array is None:
            # cv2.imread returns None on failure; fail with the offending path
            # instead of an AttributeError on `.dtype`.
            raise OSError(f"Could not read ground-truth depth image: {depth_path}")

        if depth_array.dtype == np.uint16:
            # Reduce 16-bit depth to 8 bits before handing it to PIL.
            depth_array = (depth_array // 256).astype(np.uint8)
        gt_depth_image = Image.fromarray(depth_array)

        if self.transform is not None:
            image = self.transform(image)
            gt_rgb_image = self.transform(gt_rgb_image)

            # Depth is converted to RGB mode so it goes through the same
            # transform as the colour images.
            gt_depth_image = gt_depth_image.convert(mode="RGB")
            gt_depth_image = self.transform(gt_depth_image)

        return [image, gt_rgb_image, gt_depth_image], image_name

In [38]:
"""
Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
"""

import os
import sys
import shutil
import os.path as osp
import json
import time
import datetime
import tempfile
import warnings
from collections import defaultdict
from contextlib import contextmanager

# Message severity levels. Logger.log() emits a message only when the
# logger's threshold is less than or equal to the message level.
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40

# Threshold above every level defined here; using it as the logger level
# suppresses all DEBUG/INFO/WARN/ERROR messages.
DISABLED = 50


class KVWriter:
    """Interface for output formats that can write a mapping of diagnostics."""

    def writekvs(self, kvs):
        """Write the key/value mapping ``kvs``; concrete writers override this."""
        raise NotImplementedError()


class SeqWriter:
    """Interface for output formats that can write a sequence of strings."""

    def writeseq(self, seq):
        """Write the items of ``seq``; concrete writers override this."""
        raise NotImplementedError()


class HumanOutputFormat(KVWriter, SeqWriter):
    """
    Human-readable writer: diagnostics become an ASCII table, log sequences
    become single text lines.
    """

    def __init__(self, filename_or_file):
        """
        :param filename_or_file: a path (opened here in text-write mode and
            closed by ``close()``) or an already-open file-like object, which
            is left open by ``close()``.
        """
        given_path = isinstance(filename_or_file, str)
        if not given_path:
            assert hasattr(filename_or_file, "read"), (
                    "expected file or str, got %s" % filename_or_file
            )
        self.file = open(filename_or_file, "wt") if given_path else filename_or_file
        self.own_file = given_path

    def writekvs(self, kvs):
        """
        Render ``kvs`` as a two-column table and flush it to the file.

        Values exposing ``__float__`` are formatted with ``%-8.3g``; everything
        else goes through ``str``. Keys and values are truncated to 30 chars.
        """
        rendered = {}
        for name, value in sorted(kvs.items()):
            text = "%-8.3g" % value if hasattr(value, "__float__") else str(value)
            rendered[self._truncate(name)] = self._truncate(text)

        if not rendered:
            print("WARNING: tried to write empty key-value dict")
            return

        # Column widths are set by the longest key and longest value.
        key_width = max(len(k) for k in rendered)
        val_width = max(len(v) for v in rendered.values())

        rule = "-" * (key_width + val_width + 7)
        rows = [
            "| %s%s | %s%s |"
            % (k, " " * (key_width - len(k)), v, " " * (val_width - len(v)))
            for k, v in sorted(rendered.items(), key=lambda kv: kv[0].lower())
        ]
        self.file.write("\n".join([rule, *rows, rule]) + "\n")

        # Push the table out immediately.
        self.file.flush()

    def _truncate(self, s):
        """Return ``s`` cut to 30 characters, ending in ``...`` when shortened."""
        limit = 30
        if len(s) <= limit:
            return s
        return s[: limit - 3] + "..."

    def writeseq(self, seq):
        """Write the items of ``seq`` separated by single spaces, then a newline."""
        pieces = list(seq)
        last = len(pieces) - 1
        for position, piece in enumerate(pieces):
            self.file.write(piece)
            if position != last:  # no trailing space after the final item
                self.file.write(" ")
        self.file.write("\n")
        self.file.flush()

    def close(self):
        """Close the file only when this writer opened it itself."""
        if not self.own_file:
            return
        self.file.close()


class JSONOutputFormat(KVWriter):
    """Writer that appends one JSON object per call, one object per line."""

    def __init__(self, filename):
        """
        :param filename: path of the file to create, opened in text-write mode.
        """
        self.file = open(filename, "wt")

    def writekvs(self, kvs):
        """
        Serialise ``kvs`` as a single JSON line and flush it.

        Values exposing a ``dtype`` attribute are converted with ``float()``
        so that ``json.dumps`` can encode them. The conversion is done on a
        copy: ``kvs`` itself is not modified, because the caller passes the
        same mapping on to other writers.

        :param kvs: mapping of diagnostic name to value.
        """
        serialisable = {
            key: float(value) if hasattr(value, "dtype") else value
            for key, value in kvs.items()
        }
        self.file.write(json.dumps(serialisable) + "\n")
        self.file.flush()

    def close(self):
        """Close the underlying file."""
        self.file.close()


class CSVOutputFormat(KVWriter):
    """
    Writer that keeps a CSV file whose header grows as new keys appear.
    """

    def __init__(self, filename):
        """
        :param filename: path of the CSV file, opened in read/write text mode
            so earlier rows can be re-read when the header changes.
        """
        self.file = open(filename, "w+t")
        self.keys = []
        self.sep = ","

    def writekvs(self, kvs):
        """
        Append one row for ``kvs``.

        When ``kvs`` has keys not seen before, the file is rewritten from the
        start with the widened header and the old rows padded with empty
        cells. Keys missing from ``kvs`` (or mapped to ``None``) produce
        empty cells.
        """
        new_columns = sorted(kvs.keys() - self.keys)
        if new_columns:
            self.keys.extend(new_columns)
            self._rewrite_with_wider_header(len(new_columns))

        cells = []
        for column in self.keys:
            value = kvs.get(column)
            cells.append("" if value is None else str(value))
        self.file.write(",".join(cells))
        self.file.write("\n")
        self.file.flush()

    def _rewrite_with_wider_header(self, added):
        """Rewrite the header and pad every earlier data row with ``added`` empty cells."""
        handle = self.file
        handle.seek(0)
        previous_lines = handle.readlines()
        handle.seek(0)

        handle.write(",".join(self.keys))
        handle.write("\n")

        padding = self.sep * added
        # previous_lines[0] is the old header, which was just replaced.
        for old_row in previous_lines[1:]:
            handle.write(old_row[:-1])  # drop the trailing newline
            handle.write(padding)
            handle.write("\n")

    def close(self):
        """Close the underlying file."""
        self.file.close()


class TensorBoardOutputFormat(KVWriter):
    """
    Writes key/value diagnostics as TensorBoard scalar events.
    """

    def __init__(self, dir):
        """
        :param dir: directory for the event file; created when missing.
        """
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self.step = 1

        # TensorFlow is imported here so it is only required when this
        # writer is actually instantiated.
        import tensorflow as tf
        from tensorflow.python import pywrap_tensorflow
        from tensorflow.core.util import event_pb2
        from tensorflow.python.util import compat

        events_path = osp.join(osp.abspath(dir), "events")

        self.tf = tf
        self.event_pb2 = event_pb2
        self.pywrap_tensorflow = pywrap_tensorflow
        self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(events_path))

    def writekvs(self, kvs):
        """
        Emit one event holding every entry of ``kvs`` as a scalar, stamped
        with the current wall time and an internal step counter that is
        incremented after each call.
        """
        scalars = [
            self.tf.Summary.Value(tag=name, simple_value=float(value))
            for name, value in kvs.items()
        ]
        summary = self.tf.Summary(value=scalars)
        event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
        event.step = self.step
        self.writer.WriteEvent(event)
        self.writer.Flush()
        self.step += 1

    def close(self):
        """Close the event writer; later calls are no-ops."""
        if not self.writer:
            return
        self.writer.Close()
        self.writer = None


def make_output_format(format, ev_dir, log_suffix=""):
    """
    Build the writer object for one format name.

    :param format: one of "stdout", "log", "json", "csv", "tensorboard".
    :param ev_dir: directory for file-based writers; created when missing.
    :param log_suffix: text inserted into the generated file/directory names.
    :return: the writer instance for ``format``.
    :raises ValueError: if ``format`` is not a known name.
    """
    os.makedirs(ev_dir, exist_ok=True)

    # Lazy factories: only the selected writer is constructed.
    factories = {
        "stdout": lambda: HumanOutputFormat(sys.stdout),
        "log": lambda: HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)),
        "json": lambda: JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)),
        "csv": lambda: CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)),
        "tensorboard": lambda: TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)),
    }

    factory = factories.get(format) if isinstance(format, str) else None
    if factory is None:
        raise ValueError("Unknown format specified: %s" % (format,))
    return factory()


# ================================================================
# API
# ================================================================


def logkv(key, val):
    """
    Record one diagnostic value on the current logger.

    Call this once per diagnostic quantity per iteration. If it is called
    several times with the same key before the next dump, the last value
    is the one kept.

    :param key: name of the diagnostic.
    :param val: value to store under ``key``.
    """
    get_current().logkv(key, val)


def logkv_mean(key, val):
    """
    Like logkv(), but repeated calls with the same key keep a running
    average of the values instead of the last one.

    :param key: name of the diagnostic.
    :param val: value to fold into the average for ``key``.
    """
    get_current().logkv_mean(key, val)


def logkvs(d):
    """
    Log every key-value pair of a dictionary via logkv().

    :param d: mapping of diagnostic name to value.
    """
    for (k, v) in d.items():
        logkv(k, v)


def dumpkvs():
    """
    Write all diagnostics recorded since the previous dump to the current
    logger's key/value writers and clear them.

    :return: whatever the current logger's dumpkvs() returns (a copy of the
             dumped values).
    """
    return get_current().dumpkvs()


def getkvs():
    """
    Return the current logger's ``name2val`` mapping.

    This is the logger's own dict object, not a copy.
    """
    return get_current().name2val


def log(*args, level=INFO):
    """
    Write the sequence of args as one line to every sequence-capable output
    of the current logger (console and/or files, depending on configuration).

    Each argument is converted with ``str``. The message is only written when
    the logger's threshold is less than or equal to ``level``.

    :param args: items making up the message.
    :param level: severity of the message (DEBUG, INFO, WARN or ERROR).
    """
    get_current().log(*args, level=level)


def debug(*args):
    """Log ``args`` at DEBUG level."""
    log(*args, level=DEBUG)


def info(*args):
    """Log ``args`` at INFO level."""
    log(*args, level=INFO)


def warn(*args):
    """Log ``args`` at WARN level."""
    log(*args, level=WARN)


def error(*args):
    """Log ``args`` at ERROR level."""
    log(*args, level=ERROR)


def set_level(level):
    """
    Set the logging threshold on the current logger.

    :param level: new threshold; messages logged below it are dropped.
    """
    get_current().set_level(level)


def set_comm(comm):
    """
    Set the communicator on the current logger.

    When a communicator is set, dumping diagnostics averages them across it.

    :param comm: communicator object; the logger uses its ``gather`` method
                 and ``rank`` attribute.
    """
    get_current().set_comm(comm)


def get_dir():
    """
    Get the directory that log files are being written to.

    Returns the ``dir`` the current logger was created with. If no logger
    has been configured yet, a default one is configured first.
    """
    return get_current().get_dir()


# Alternative names for logkv / dumpkvs.
# NOTE(review): presumably kept for compatibility with older callers - confirm.
record_tabular = logkv
dump_tabular = dumpkvs


@contextmanager
def profile_kv(scopename):
    """
    Context manager that adds the wall-clock time spent inside the ``with``
    body to the current logger's ``"wait_" + scopename`` diagnostic.

    The time is recorded even when the body raises.

    :param scopename: suffix of the diagnostic key.
    """
    key = "wait_" + scopename
    began = time.time()
    try:
        yield
    finally:
        timings = get_current().name2val
        timings[key] += time.time() - began


def profile(n):
    """
    Decorator factory that times every call of the decorated function with
    profile_kv(n).

    Usage:
    @profile("my_func")
    def my_func(): code

    :param n: scope name passed to profile_kv for each call.
    :return: a decorator; the wrapped function keeps the original's
             ``__name__``, ``__doc__`` and other metadata.
    """
    # Imported locally so the decorator does not depend on a file-level import.
    import functools

    def decorator_with_name(func):
        @functools.wraps(func)  # preserve metadata for introspection/debugging
        def func_wrapper(*args, **kwargs):
            with profile_kv(n):
                return func(*args, **kwargs)

        return func_wrapper

    return decorator_with_name


# ================================================================
# Backend
# ================================================================


def get_current():
    """
    Return the logger used by the module-level logging functions.

    If no logger has been set up yet, a default one is configured first.
    """
    if Logger.CURRENT is None:
        _configure_default_logger()

    return Logger.CURRENT


class Logger(object):
    """
    Collects diagnostics and text messages and forwards them to a list of
    output-format writers.
    """

    # Logger installed by the default configuration, used as the fallback
    # when the current logger is reset.
    DEFAULT = None
    # Logger that the module-level free functions operate on.
    CURRENT = None

    def __init__(self, dir, output_formats, comm=None):
        """
        :param dir: directory reported by get_dir().
        :param output_formats: list of writer objects to forward to.
        :param comm: optional communicator used to average diagnostics.
        """
        self.name2val = defaultdict(float)  # diagnostics of the current iteration
        self.name2cnt = defaultdict(int)  # sample counts for averaged keys
        self.level = INFO
        self.dir = dir
        self.output_formats = output_formats
        self.comm = comm

    # ----------------------------------------
    # Logging API (targets of the free functions)
    # ----------------------------------------
    def logkv(self, key, val):
        """Store ``val`` under ``key``, replacing any earlier value."""
        self.name2val[key] = val

    def logkv_mean(self, key, val):
        """Fold ``val`` into the running average kept under ``key``."""
        previous = self.name2val[key]
        seen = self.name2cnt[key]
        total = seen + 1
        self.name2val[key] = previous * seen / total + val / total
        self.name2cnt[key] = total

    def dumpkvs(self):
        """
        Send the collected diagnostics to every key/value writer, then clear
        them.

        With a communicator set, values are first combined through
        mpi_weighted_mean, weighted by their sample counts.

        :return: a copy of the mapping that was written.
        """
        if self.comm is None:
            d = self.name2val
        else:
            weighted = {
                name: (val, self.name2cnt.get(name, 1))
                for name, val in self.name2val.items()
            }
            d = mpi_weighted_mean(self.comm, weighted)
            if self.comm.rank != 0:
                d["dummy"] = 1  # avoids the empty-dict warning on non-root ranks
        snapshot = d.copy()  # returned so callers/tests can inspect the dump
        for writer in self.output_formats:
            if not isinstance(writer, KVWriter):
                continue
            writer.writekvs(d)
        self.name2val.clear()
        self.name2cnt.clear()
        return snapshot

    def log(self, *args, level=INFO):
        """Write ``args`` as one message unless ``level`` is below the threshold."""
        if not (self.level <= level):
            return
        self._do_log(args)

    # ----------------------------------------
    # Configuration
    # ----------------------------------------
    def set_level(self, level):
        """Set the threshold below which messages are dropped."""
        self.level = level

    def set_comm(self, comm):
        """Set the communicator used by dumpkvs()."""
        self.comm = comm

    def get_dir(self):
        """Return the directory this logger was created with."""
        return self.dir

    def close(self):
        """Close every output format."""
        for writer in self.output_formats:
            writer.close()

    # ----------------------------------------
    # Misc
    # ----------------------------------------
    def _do_log(self, args):
        """Forward ``args``, converted to strings, to every sequence writer."""
        for writer in self.output_formats:
            if not isinstance(writer, SeqWriter):
                continue
            writer.writeseq(map(str, args))


def get_rank_without_mpi_import():
    """
    Return the process rank read from the environment, or 0 if unset.

    Only environment variables are inspected (``PMI_RANK`` first, then
    ``OMPI_COMM_WORLD_RANK``) so that mpi4py never has to be imported, which
    avoids calling MPI_Init() when this module is imported.
    """
    env = os.environ
    rank_variables = ("PMI_RANK", "OMPI_COMM_WORLD_RANK")
    raw_rank = next((env[name] for name in rank_variables if name in env), None)
    return 0 if raw_rank is None else int(raw_rank)


def mpi_weighted_mean(comm, local_name2valcount):
    """
    Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
    Perform a weighted average over dicts that are each on a different node.

    Values that cannot be converted with ``float()`` are skipped with a
    warning, whether the conversion fails with ValueError (e.g. a
    non-numeric string) or TypeError (e.g. ``None``).

    :param comm: communicator providing ``gather`` and ``rank``.
    :param local_name2valcount: dict mapping key -> (value, count).
    :return: on rank 0, a dict mapping key -> weighted mean; on every other
             rank, an empty dict.
    """
    all_name2valcount = comm.gather(local_name2valcount)
    if comm.rank != 0:
        return {}

    name2sum = defaultdict(float)
    name2count = defaultdict(float)
    for n2vc in all_name2valcount:
        for (name, (val, count)) in n2vc.items():
            try:
                val = float(val)
            except (TypeError, ValueError):
                warnings.warn(
                    "WARNING: tried to compute mean on non-float {}={}".format(
                        name, val
                    )
                )
                continue
            name2sum[name] += val * count
            name2count[name] += count
    return {name: name2sum[name] / name2count[name] for name in name2sum}


def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
    """
    Create a new Logger and install it as the current one.

    :param dir: log directory. Falls back to the ``OPENAI_LOGDIR``
        environment variable, then to a timestamped folder under the system
        temp directory. ``~`` is expanded and the directory is created.
    :param format_strs: list of output format names. When omitted it is read
        from ``OPENAI_LOG_FORMAT`` (rank 0, default "stdout,log") or
        ``OPENAI_LOG_FORMAT_MPI`` (other ranks, default "log").
    :param comm: if provided, average all numerical stats across that comm.
    :param log_suffix: suffix for log file names; ranks above 0 get
        ``-rankNNN`` appended.
    """
    dir = os.getenv("OPENAI_LOGDIR") if dir is None else dir
    if dir is None:
        stamp = datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f")
        dir = osp.join(tempfile.gettempdir(), stamp)
    assert isinstance(dir, str)
    dir = os.path.expanduser(dir)
    os.makedirs(os.path.expanduser(dir), exist_ok=True)

    rank = get_rank_without_mpi_import()
    if rank > 0:
        log_suffix = log_suffix + "-rank%03i" % rank

    if format_strs is None:
        # Rank 0 and the other ranks read different environment variables.
        if rank == 0:
            env_name, fallback = "OPENAI_LOG_FORMAT", "stdout,log"
        else:
            env_name, fallback = "OPENAI_LOG_FORMAT_MPI", "log"
        format_strs = os.getenv(env_name, fallback).split(",")

    # Empty names (e.g. from an empty environment value) are ignored.
    output_formats = [
        make_output_format(name, dir, log_suffix) for name in format_strs if name
    ]

    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
    if output_formats:
        log("Logging to %s" % dir)


def _configure_default_logger():
    """Configure a logger with default settings and remember it as Logger.DEFAULT."""
    configure()
    Logger.DEFAULT = Logger.CURRENT


def reset():
    """
    Close the current logger and switch back to the default one.

    Does nothing when the current logger already is the default.
    """
    if Logger.CURRENT is not Logger.DEFAULT:
        Logger.CURRENT.close()
        Logger.CURRENT = Logger.DEFAULT
        log("Reset logger")


@contextmanager
def scoped_configure(dir=None, format_strs=None, comm=None):
    """
    Context manager that installs a freshly configured logger for the
    duration of the ``with`` body.

    On exit (including when the body raises) the temporary logger is closed
    and the previously current logger is restored.

    :param dir: forwarded to configure().
    :param format_strs: forwarded to configure().
    :param comm: forwarded to configure().
    """
    # Remember the logger in use so it can be restored afterwards.
    prevlogger = Logger.CURRENT
    configure(dir=dir, format_strs=format_strs, comm=comm)
    try:
        yield
    finally:
        Logger.CURRENT.close()
        Logger.CURRENT = prevlogger

In [43]:
import torch.nn as nn
import torch as th
import math

def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Build sinusoidal embeddings for a batch of timesteps.

    The first ``dim // 2`` columns hold cosines and the next ``dim // 2``
    hold sines; for odd ``dim`` a final zero column is appended.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    exponents = (
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    )
    freqs = th.exp(exponents).to(device=timesteps.device)

    # Outer product: one row per timestep, one column per frequency.
    angles = timesteps[:, None].float() * freqs[None]

    columns = [th.cos(angles), th.sin(angles)]
    if dim % 2:
        # Odd dim: pad with a zero column to reach exactly `dim` columns.
        columns.append(th.zeros_like(angles[:, :1]))
    return th.cat(columns, dim=-1)

In [46]:
import sys

import numpy as np
from functools import partial
import os
from os.path import join as pjoin
from argparse import ArgumentParser
from PIL import Image
import datetime

import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.transforms.functional as tvtf
from torchvision.utils import make_grid

# from noise import get_noise, get_operator
# from condition import get_conditioning_method   
# from unet import create_model
# from gaussian_diffusion import create_sampler
import logger
# import utils as utilso
# import data as datao


def main():
    args = arguments_from_file(CONFIG_FILE)
    args.image_size = args.unet_model['image_size']
    args.unet_model['model_path'] = os.path.abspath("/kaggle/input/image-enhancer-rgbd/osmosis_outdoor.pt")
    # print(f"\nArguments from inside main:\n{args}\n")
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu')
    # print(args.unet_model)
    # Prepare dataloader
    data_config = args.data
    # resize small side to be 256px, center cropping 256x256, normalizing to [-1,1]
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Resize(size=256),
                                    transforms.CenterCrop(size=[256, 256]),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # For the case of any data with ground truth (simulation in our case)
    if data_config['ground_truth']:
        gt_flag = True
        dataset =ImagesFolder_GT(root_dir=data_config['root'], gt_rgb_dir=data_config['gt_rgb'],
                                        gt_depth_dir=data_config['gt_depth'], transform=transform)
        loader = DataLoader(dataset, batch_size=data_config['batch_size'], shuffle=False)

    # for non ground truth dataset (underwater and haze for our case)
    else:
        gt_flag = False
        data_config['root']="/kaggle/input/uw-image/data"
        dataset =ImagesFolder(data_config['root'], transform)
        loader = DataLoader(dataset, batch_size=data_config['batch_size'], shuffle=False)

    print(f"\nDataset size: {len(dataset)}\n")

    #View content of Dataset like view image from dataloader
    # for i in range(1):
    #     sample = dataset[i]
    #     print(sample)
    #     image = sample[0]
    #     print(image.shape)
    #     image = tvtf.to_pil_image(image)
    #     image.show()

    model = create_model(**args.unet_model)
    model = model.to(device)
    model.eval()
    measure_config = args.measurement
    cond_config = args.conditioning
    diffusion_config = args.diffusion
    sample_pattern_config = args.sample_pattern
    aux_loss_config = args.aux_loss


    measurement_name = measure_config['operator']['name']
    out_path = os.path.abspath(pjoin(args.save_dir, measurement_name, args.data['name']))
    out_path = update_save_dir_date(out_path)

    # create txt file with the configurations
    yaml_to_txt(CONFIG_FILE, pjoin(out_path, f"configurations.txt"))


    # directory for saving single results
    if args.save_singles:
        save_singles_path = pjoin(out_path, f"single_images")
        os.makedirs(save_singles_path)

        save_input_path = pjoin(save_singles_path, "input")
        os.makedirs(save_input_path)
        save_rgb_path = pjoin(save_singles_path, "rgb")
        os.makedirs(save_rgb_path)
        save_depth_pmm_color_path = pjoin(save_singles_path, "depth_color")
        os.makedirs(save_depth_pmm_color_path)
        save_depth_mm_path = pjoin(save_singles_path, "depth_raw")
        os.makedirs(save_depth_mm_path)
    else:
        save_singles_path = None

    # directory for the results a grid
    if args.save_grids:
        save_grids_path = pjoin(out_path, f"grid_results")
        os.makedirs(save_grids_path)
    else:
        save_grids_path = None
    
    #Logging
    logger.configure(dir=out_path)
    logger.log(f"pretrained model file: {args.unet_model['model_path']}")
    
    if (not args.rgb_guidance):
        log_txt_tmp = log_text(args=args)
        logger.log(log_txt_tmp)

    
    for i, (ref_img, ref_img_name) in enumerate(loader):
        # in case there is a GT image (if ground truth is used)
        if gt_flag:
            gt_rgb_img = ref_img[1].squeeze()
            gt_rgb_img_01 = 0.5 * (gt_rgb_img + 1)

            gt_depth_img = ref_img[2].squeeze()
            gt_depth_img_01 = 0.5 * (gt_depth_img + 1)
            gt_depth_img_01 = depth_tensor_to_color_image(gt_depth_img_01)

            ref_img = ref_img[0]

        start_run_time_ii = datetime.datetime.now()

        # prepare reference image for visualization
        ref_img_01 = 0.5 * (ref_img.detach().cpu()[0] + 1)
        ref_img_name = ref_img_name[0]
        orig_file_name = os.path.splitext(ref_img_name)[0]

        # stop the run before getting to the last image
        if i == args.data['stop_after']:
            break
        
        # prepare operator for noise 
        measure_config['operator']['batch_size'] = args.data['batch_size']
        operator = get_operator(device=device, **measure_config['operator'])
        noiser = get_noise(**measure_config['noise'])


        # Prepare conditioning - guidance method
        cond_method = get_conditioning_method(cond_config['method'], operator, noiser, **cond_config['params'],
                                              **sample_pattern_config, **aux_loss_config)
        measurement_cond_fn = cond_method.conditioning

        # Load diffusion sampler and pass the required arguments
        sampler = create_sampler(**diffusion_config)
        # passing the "stable" arguments with the partial method
        sample_fn = partial(sampler.p_sample_loop, model=model, measurement_cond_fn=measurement_cond_fn,
                            pretrain_model=args.unet_model['pretrain_model'], rgb_guidance=args.rgb_guidance,
                            sample_pattern=args.sample_pattern,
                            record=args.record_process,
                            save_root=out_path, image_idx=i,
                            record_every=args.record_every,
                            original_file_name=orig_file_name,
                            save_grids_path=save_grids_path)
        
        logger.log(f"\nInference image {i}: {ref_img_name}\n")
        ref_img = ref_img.to(device)

        # add noise to the image
        y_n = noiser(ref_img)

        # degamma the input image - use it for haze
        if args.degamma_input:
            y_n_tmp = 0.5 * (y_n + 1)
            y_n = 2 * torch.pow(y_n_tmp, 2.2) - 1

        # Sampling
        x_start_shape = list(ref_img.shape)
        # in case of sampling for osmosis the input model channel is 4 (RGBD)
        x_start_shape[1] = 4 if (args.unet_model["pretrain_model"] == 'osmosis') else x_start_shape[1]

        # sampling noise for the beginning of the diffusion model
        if args.sample_pattern['pattern'] == "original":
            global_N = 1
        elif args.sample_pattern['pattern'] == "pcgs":
            global_N = args.sample_pattern['global_N']
        else:
            raise ValueError(f"Unrecognized sample pattern: {args.sample_pattern['pattern']}")

        # loop according the value of global N (from gibbsDDRM)
        for global_ii in range(global_N):

            logger.log(f"global iteration: {global_ii}\n")
            torch.manual_seed(args.manual_seed)

            # the x_T - Gaussian Noise
            x_start = torch.randn(x_start_shape, device=device).requires_grad_()

            # this is the osmosis project additional code
            if args.unet_model["pretrain_model"] == 'osmosis' and not args.rgb_guidance:

                # sampling function which adapted to osmosis project

                sample, variable_dict, loss, out_xstart = sample_fn(x_start=x_start, measurement=y_n,
                                                                    global_iteration=global_ii)

                # output from the network without guidance - split into rgb and depth image
                sample_rgb = out_xstart[0, 0:-1, :, :]
                sample_depth_tmp = out_xstart[0, -1, :, :].unsqueeze(0)
                sample_depth_tmp_rep = sample_depth_tmp.repeat(3, 1, 1)

                # "move" the rgb predicted image to start from 0
                sample_rgb_01 = 0.5 * (sample_rgb + 1)
                sample_rgb_01_clip = torch.clamp(sample_rgb_01, min=0, max=1)

                # "move" the depth predicted image to start from 0
                sample_depth_mm = min_max_norm_range(sample_depth_tmp[0].unsqueeze(0))
                sample_depth_vis_pmm = min_max_norm_range_percentile(sample_depth_tmp,
                                                                            vmin=0, vmax=1,
                                                                            percent_low=0.03,
                                                                            percent_high=0.99,
                                                                            is_uint8=False)
                sample_depth_vis_pmm_color = depth_tensor_to_color_image(sample_depth_vis_pmm)

                # depth for calculations
                sample_depth_calc = convert_depth(sample_depth_tmp_rep,
                                                         depth_type=args.measurement['operator']['depth_type'],
                                                         value=args.measurement['operator']['value'])

                # phi inf image - relevant for both underwater and haze
                phi_inf = variable_dict['phi_inf'].cpu().squeeze(0)
                phi_inf_image = phi_inf * torch.ones_like(sample_rgb, device=torch.device('cpu'))

                # underwater model
                if 'underwater_physical_revised' in args.measurement['operator']['name']:

                    # create the ingredients for the underwater image
                    phi_a = variable_dict['phi_a'].cpu().squeeze(0)
                    phi_a_image = phi_a * torch.ones_like(sample_rgb, device=torch.device('cpu'))
                    phi_b = variable_dict['phi_b'].cpu().squeeze(0)
                    phi_b_image = phi_b * torch.ones_like(sample_rgb, device=torch.device('cpu'))

                    # calculate the underwater parts
                    backscatter_image = phi_inf_image * (1 - torch.exp(-phi_b_image * sample_depth_calc))
                    attenuation_image = torch.exp(-phi_a_image * sample_depth_calc)
                    forward_predicted_image = sample_rgb_01 * attenuation_image + backscatter_image

                    # calculate norm loss for visualization - degraded_images and ref_img values should be [-1,1]
                    degraded_image = 2 * forward_predicted_image - 1
                    norm_loss_final = np.round([torch.linalg.norm(
                        degraded_image - ref_img.detach().cpu()).numpy()], decimals=3)

                    # calculate the "clean" image from the predicted phi's and ref image
                    attenuation_flip_image = torch.exp(phi_a_image * sample_depth_calc)
                    sample_rgb_recon = attenuation_flip_image * (ref_img_01 - backscatter_image)

                    # logging values of phi's
                    print_phi_a = [np.round(i, decimals=3) for i in phi_a.cpu().squeeze().tolist()]
                    print_phi_b = [np.round(i, decimals=3) for i in phi_b.cpu().squeeze().tolist()]
                    print_phi_inf = [np.round(i, decimals=3) for i in phi_inf.cpu().squeeze().tolist()]

                    log_value_txt = f"\nInitialized values: " \
                                    f"\nphi_a: [{measure_config['operator']['phi_a']}], lr: {measure_config['operator']['phi_a_eta']}" \
                                    f"\nphi_b: [{measure_config['operator']['phi_b']}], lr: {measure_config['operator']['phi_b_eta']}" \
                                    f"\nphi_inf: [{measure_config['operator']['phi_inf']}], lr: {measure_config['operator']['phi_inf_eta']}" \
                                    f"\n\nResults values: " \
                                    f"\nphi_a: {print_phi_a}" \
                                    f"\nphi_b: {print_phi_b}" \
                                    f"\nphi_inf: {print_phi_inf}" \
                                    f"\n\nNorm loss: {norm_loss_final}" \
                                    f"\nFinal loss: {np.round(np.array(loss), decimals=3)}"

                    # log results for parameters
                    logger.log(log_value_txt)

                # haze model
                elif ('haze' in args.measurement['operator']['name']) or (
                        'underwater_physical' in args.measurement['operator']['name']):

                    # create the ingredients for the hazed image
                    phi_ab = variable_dict['phi_ab'].cpu().squeeze(0)
                    phi_ab_image = phi_ab * torch.ones_like(sample_rgb, device=torch.device('cpu'))
                    backscatter_image = phi_inf_image * (1 - torch.exp(-phi_ab_image * sample_depth_calc))
                    attenuation_image = torch.exp(-phi_ab_image * sample_depth_calc)
                    forward_predicted_image = sample_rgb_01 * attenuation_image + backscatter_image

                    # calculate the "clean" image from the predicted phis, phi_inf and ref image
                    attenuation_flip_image = torch.exp(phi_ab_image * sample_depth_calc)
                    sample_rgb_recon = attenuation_flip_image * (ref_img_01 - backscatter_image)

                    # calculate norm lost for visualization - both degraded_images and ref_img values should be [-1,1]
                    degraded_image = 2 * forward_predicted_image - 1
                    norm_loss_final = np.round(
                        [torch.linalg.norm(degraded_image.cpu() - ref_img.detach().cpu()).numpy()],
                        decimals=3)

                    # logging values of phi and phi_inf
                    print_phi_ab = np.round(phi_ab.cpu().squeeze(), decimals=3)
                    print_phi_inf = np.round(phi_inf.cpu().squeeze(), decimals=3)
                    log_value_txt = f"\nInitialized values: " \
                                    f"\nphi_ab: [{measure_config['operator']['phi_ab']}], lr: {measure_config['operator']['phi_ab_eta']}" \
                                    f"\nphi_inf: [{measure_config['operator']['phi_inf']}], lr: {measure_config['operator']['phi_inf_eta']}" \
                                    f"\n\nResults values: " \
                                    f"\nphi_ab: {print_phi_ab}" \
                                    f"\nphi_inf: {print_phi_inf}" \
                                    f"\n\nNorm loss: {norm_loss_final}" \
                                    f"\nFinal loss: {np.round(np.array(loss), decimals=5)}"

                    # log results for parameters
                    logger.log(log_value_txt)

                else:
                    raise NotImplementedError("Operator can be for 'underwater' or 'haze' ")

                # saving single images (reference (input), rgb (restored image), depth (depth estimation))
                if args.save_singles:
                    # input - reference image
                    ref_im_pil = tvtf.to_pil_image(ref_img_01)
                    # ref_im_pil.save(pjoin(save_singles_path, f'{orig_file_name}_g{global_ii}_ref.png'))
                    ref_im_pil.save(pjoin(save_input_path, f'{orig_file_name}.png'))

                    # rgb clip - sample_rgb_01_clip
                    sample_rgb_01_clip_pil = tvtf.to_pil_image(sample_rgb_01_clip)
                    # sample_rgb_01_clip_pil.save(pjoin(save_singles_path, f'{orig_file_name}_g{global_ii}_rgb.png'))
                    sample_rgb_01_clip_pil.save(pjoin(save_rgb_path, f'{orig_file_name}.png'))

                    # depth percentile min-max - sample_depth_vis_percentile_norm
                    sample_depth_vis_pmm_color_pil = tvtf.to_pil_image(sample_depth_vis_pmm_color)
                    # sample_depth_vis_pmm_color_pil.save(pjoin(save_singles_path, f'{orig_file_name}_g{global_ii}_depth.png'))
                    sample_depth_vis_pmm_color_pil.save(pjoin(save_depth_pmm_color_path, f'{orig_file_name}.png'))

                    # depth min-max (no percentile clipping) - sample_depth_mm
                    sample_depth_vis_mm_pil = tvtf.to_pil_image(sample_depth_mm)
                    # sample_depth_vis_mm_pil.save(pjoin(save_singles_path, f'{orig_file_name}_g{global_ii}_depth_raw.png'))
                    sample_depth_vis_mm_pil.save(pjoin(save_depth_mm_path, f'{orig_file_name}.png'))

                # save extended results in the grid
                if args.save_grids:

                    grid_list = [ref_img_01, sample_rgb_01_clip, sample_depth_vis_pmm_color]

                    # there is ground truth in the case of simulation
                    if gt_flag:
                        grid_list += [torch.zeros_like(sample_rgb_01, device=torch.device('cpu')),
                                      gt_rgb_img_01, gt_depth_img_01]

                    results_grid = make_grid(grid_list, nrow=3, pad_value=1.)
                    results_grid = clip_image(results_grid, scale=False, move=False, is_uint8=True) \
                        .permute(1, 2, 0).numpy()
                    results_pil = Image.fromarray(results_grid, mode="RGB")

                    # save the image
                    results_pil.save(pjoin(save_grids_path, f'{orig_file_name}_g{global_ii}_grid.png'))

                if args.save_singles or args.save_grids:
                    logger.log(f"result images was saved into: {out_path}")

                logger.log(f"Run time: {datetime.datetime.now() - start_run_time_ii}")

            # no osmosis - rgb guidance
            else:

                sample = sample_fn(x_start=x_start, measurement=y_n)

                # split into rgb and depth image - not handling results save for a batch of images
                sample_rgb = sample.cpu()[0, 0:-1, :, :]
                sample_depth_tmp = sample.cpu()[0, -1, :, :].repeat(3, 1, 1)

                # "move" the rgb predicted image to start from 0 (the values "sample_rgb" should be between [-1, 1])
                sample_rgb_01 = 0.5 * (sample_rgb + 1)
                sample_rgb_01_clip = torch.clamp(sample_rgb_01, 0., 1.)

                # used for visualization
                sample_depth_mm = min_max_norm_range(sample_depth_tmp, vmin=0, vmax=1, is_uint8=False)
                sample_depth_vis_pmm = min_max_norm_range_percentile(sample_depth_tmp,
                                                                            percent_low=0.05, percent_high=0.99)
                sample_depth_vis_pmm_color = depth_tensor_to_color_image(sample_depth_vis_pmm)

                # saving separate images
                if args.save_singles:
                    ref_im_pil = tvtf.to_pil_image(ref_img_01)
                    ref_im_pil.save(pjoin(save_input_path, f'{orig_file_name}.png'))

                    sample_rgb_pil = tvtf.to_pil_image(sample_rgb_01_clip)
                    sample_rgb_pil.save(pjoin(save_rgb_path, f'{orig_file_name}.png'))

                    sample_depth_vis_pil = tvtf.to_pil_image(sample_depth_vis_pmm_color)
                    sample_depth_vis_pil.save(pjoin(save_depth_pmm_color_path, f'{orig_file_name}.png'))

                    sample_depth_mm_pil = tvtf.to_pil_image(sample_depth_mm)
                    sample_depth_mm_pil.save(pjoin(save_depth_mm_path, f'{orig_file_name}.png'))

                # create images grid
                if args.save_grids:
                    grid_list = [ref_img_01, sample_rgb_01_clip, sample_depth_vis_pmm_color]
                    results_grid = make_grid(grid_list, nrow=3, pad_value=1.)
                    results_grid = clip_image(results_grid, scale=False, move=False, is_uint8=True)
                    results_pil = tvtf.to_pil_image(results_grid)

                    # save the image
                    results_pil.save(pjoin(save_grids_path, f'{orig_file_name}.png'))

                if args.save_singles or args.save_grids:
                    logger.log(f"result images was saved into: {out_path}")

                logger.log(f"Run time: {datetime.datetime.now() - start_run_time_ii}")

    # close the logger txt file
    logger.get_current().close()
        

if __name__ == "__main__":
    # Notebook entry point. The argparse-based command line ("--config_file",
    # "--device") is not used here; the same two settings are supplied through
    # a hard-coded mapping so the cell can run without command-line arguments.
    args = {"config_file": "/kaggle/input/yaml-file/osmosis_sample.yaml", "device": 0}

    # Module-level settings. main() is called without arguments, so it
    # presumably reads these as globals - verify against main() before renaming.
    CONFIG_FILE = os.path.abspath(args["config_file"])  # absolute path of the YAML configuration file
    DEVICE = args["device"]  # GPU device index

    main()
    print("\nFINISH!")
    # NOTE(review): when this cell runs inside a Jupyter kernel, sys.exit(0)
    # surfaces as "SystemExit: 0" instead of terminating a process. It is kept
    # so that running the file as a plain script still stops here with status 0.
    sys.exit(0)


Dataset size: 1



  model.load_state_dict(th.load(model_path, map_location='cpu'))


Logging to /kaggle/working/results/underwater_physical_revised/osmosis/13-11-24/run5
pretrained model file: /kaggle/input/image-enhancer-rgbd/osmosis_outdoor.pt


Guidance Scale: 7,7,7,0.9
Loss Function: norm
weight: depth, weight_function: gamma,1.4,1.4,1
Auxiliary Loss: {'avrg_loss': 0.5, 'val_loss': 20}
Underwater model: underwater_physical_revised
Optimize w.r.t: x_prev
Optimizer model: sgd, 
Manual seed: 0
Depth type: gamma, value: 1.4,1.4,1
Noise: clean
Gradient Clipping: True, min value: -0.005, max value: 0.005
Sample Pattern: pcgs, 
     Guidance start: 1 ,end: 0
     Optimizations iters: 20, 
     Update start from: 0.7, end: 0
     M: 1, start: 1, end: 0

Inference image 0: UW_Image.png

global iteration: 0



  0%|          | 0/1000 [00:00<?, ?it/s]


Initialized values: 
phi_a: [1.1,0.95,0.95], lr: 1e-5
phi_b: [0.95, 0.8, 0.8], lr: 1e-5
phi_inf: [0.14, 0.29, 0.49], lr: 1e-5

Results values: 
phi_a: [1.109, 0.968, 0.738]
phi_b: [0.946, 0.774, 0.693]
phi_inf: [0.066, 0.202, 0.366]

Norm loss: [7.019]
Final loss: [3.113]
result images was saved into: /kaggle/working/results/underwater_physical_revised/osmosis/13-11-24/run5
Run time: 0:18:30.570036

FINISH!


SystemExit: 0

In [40]:
import sys
sys.path.append('/kaggle/input/logger')

In [41]:
import logger 

In [45]:
"""
Various utilities for neural networks.
"""

import math

import torch as th
import torch.nn as nn


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    """Sigmoid Linear Unit activation: ``x * sigmoid(x)``, applied element-wise."""

    def forward(self, x):
        """Return ``x`` multiplied element-wise by its own sigmoid."""
        gate = th.sigmoid(x)
        return x * gate


class GroupNorm32(nn.GroupNorm):
    """Group normalization that always computes in float32.

    The input is cast to float32 for the normalization itself and the result
    is cast back to the dtype the input arrived in.
    """

    def forward(self, x):
        """Normalize ``x`` in float32 and return it in ``x``'s original dtype."""
        original_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(original_dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Build a convolution layer of the requested dimensionality.

    :param dims: 1, 2 or 3, selecting nn.Conv1d, nn.Conv2d or nn.Conv3d.
    :param args: positional arguments forwarded to the layer constructor.
    :param kwargs: keyword arguments forwarded to the layer constructor.
    :return: the constructed convolution module.
    :raises ValueError: if ``dims`` is not 1, 2 or 3.
    """
    for rank, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Build a fully connected layer.

    All positional and keyword arguments are forwarded unchanged to nn.Linear.

    :return: the constructed nn.Linear module.
    """
    layer = nn.Linear(*args, **kwargs)
    return layer


def avg_pool_nd(dims, *args, **kwargs):
    """
    Build an average pooling layer of the requested dimensionality.

    :param dims: 1, 2 or 3, selecting nn.AvgPool1d, nn.AvgPool2d or nn.AvgPool3d.
    :param args: positional arguments forwarded to the layer constructor.
    :param kwargs: keyword arguments forwarded to the layer constructor.
    :return: the constructed pooling module.
    :raises ValueError: if ``dims`` is not 1, 2 or 3.
    """
    for rank, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def update_ema(target_params, source_params, rate=0.99):
    """
    Move the target parameters towards the source parameters with an
    exponential moving average, modifying the targets in place.

    Each target becomes ``rate * target + (1 - rate) * source``.

    :param target_params: the target parameter sequence (updated in place).
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    source_weight = 1 - rate
    for ema_param, live_param in zip(target_params, source_params):
        # detach() shares storage with the parameter, so the in-place ops
        # below update the parameter's data.
        ema_data = ema_param.detach()
        ema_data.mul_(rate)
        ema_data.add_(live_param, alpha=source_weight)


def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place.

    :param module: the module whose parameters are zeroed.
    :return: the same ``module`` object, for call chaining.
    """
    for weight in module.parameters():
        # detach() shares storage, so zeroing the view zeroes the parameter.
        weight_data = weight.detach()
        weight_data.zero_()
    return module


def scale_module(module, scale):
    """
    Multiply every parameter of ``module`` by ``scale`` in place.

    :param module: the module whose parameters are scaled.
    :param scale: the multiplicative factor.
    :return: the same ``module`` object, for call chaining.
    """
    for weight in module.parameters():
        # detach() shares storage, so scaling the view scales the parameter.
        weight_data = weight.detach()
        weight_data.mul_(scale)
    return module


def mean_flat(tensor):
    """
    Average a tensor over every dimension except the first (batch) one.

    :param tensor: the tensor to reduce.
    :return: the mean taken over dimensions 1 and up.
    """
    non_batch_axes = list(range(1, tensor.ndim))
    return tensor.mean(dim=non_batch_axes)


def normalization(channels):
    """
    Build the standard normalization layer: a GroupNorm32 with 32 groups.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(num_groups=32, num_channels=channels)


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    The first ``dim // 2`` columns hold cosines and the next ``dim // 2``
    columns hold sines of ``timestep * frequency``. When ``dim`` is odd, one
    trailing zero column is appended so the output width is exactly ``dim``.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    # Frequencies are max_period ** (-i / half) for i in [0, half): they start
    # at 1 and decay geometrically. They are computed first and then moved to
    # the device of `timesteps`.
    freqs = th.exp(
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        # Build the padding column explicitly. Slicing `embedding[:, :1]` is
        # empty when dim == 1 (no cos/sin columns exist), which would leave the
        # result at width 0 instead of the documented width `dim`.
        zero_column = embedding.new_zeros((embedding.shape[0], 1))
        embedding = th.cat([embedding, zero_column], dim=-1)
    return embedding


def checkpoint(func, inputs, params, flag):
    """
    Call ``func(*inputs)``, optionally through gradient checkpointing so that
    intermediate activations are not cached, reducing memory at the expense of
    extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    :return: whatever ``func`` returns.
    """
    if not flag:
        # Checkpointing disabled: plain call, `params` is not used.
        return func(*inputs)

    # Inputs come first, then params; the count of inputs tells
    # CheckpointFunction where to split the flat argument list.
    flat_args = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *flat_args)


class CheckpointFunction(th.autograd.Function):
    """
    Autograd function behind `checkpoint`.

    The forward pass runs the wrapped function under ``th.no_grad()``; the
    backward pass runs it again with gradients enabled and differentiates that
    second run with respect to both the inputs and the extra parameters.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        """
        Run ``run_function`` on the first ``length`` entries of ``args``.

        :param ctx: autograd context; the function, inputs and params are
                    stored on it for use in `backward`.
        :param run_function: the function to evaluate.
        :param length: how many leading entries of ``args`` are inputs; the
                       remaining entries are parameters.
        :param args: inputs followed by parameters.
        :return: the output of ``run_function``.
        """
        ctx.run_function = run_function
        flat_args = list(args)
        ctx.input_tensors = flat_args[:length]
        ctx.input_params = flat_args[length:]
        with th.no_grad():
            return ctx.run_function(*ctx.input_tensors)

    @staticmethod
    def backward(ctx, *output_grads):
        """
        Re-run the function with gradients enabled and return the gradients.

        :param ctx: autograd context populated by `forward`.
        :param output_grads: gradients with respect to the forward outputs.
        :return: ``(None, None)`` for ``run_function`` and ``length``,
                 followed by one gradient per input and per parameter.
        """
        # Replace the stored inputs with detached copies that require grad.
        ctx.input_tensors = [
            stored.detach().requires_grad_(True) for stored in ctx.input_tensors
        ]
        with th.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            replay_inputs = [leaf.view_as(leaf) for leaf in ctx.input_tensors]
            replayed_outputs = ctx.run_function(*replay_inputs)
        grads = th.autograd.grad(
            outputs=replayed_outputs,
            inputs=ctx.input_tensors + ctx.input_params,
            grad_outputs=output_grads,
            allow_unused=True,
        )
        # Drop the references held on the context and the recomputed outputs.
        del ctx.input_tensors
        del ctx.input_params
        del replayed_outputs
        return (None, None, *grads)