In [None]:
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Layer
from tensorflow.keras import Sequential
import tensorflow.keras.layers as nn

from tensorflow import einsum
from einops import rearrange
from einops.layers.tensorflow import Rearrange

import math
from inspect import isfunction

from functools import partial
from tqdm import tqdm

import numpy as np
import os
import cv2
import pathlib
from glob import glob

import time
from tensorflow.python.data.experimental import AUTOTUNE
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import pandas as pd
import skimage
from skimage.metrics import structural_similarity
import time
from skimage import filters, img_as_ubyte,morphology,measure
from scipy.ndimage import label
import copy
import random
from skimage import data,filters

In [None]:
########################ops#########################
# helpers functions
def exists(x):
    """Return True when ``x`` is anything other than ``None``."""
    is_missing = x is None
    return not is_missing

def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    When ``d`` is a plain Python function (per ``inspect.isfunction``) it is
    called with no arguments and its result is returned; any other ``d`` is
    returned as-is.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val

def cycle(dl):
    """Yield the items of ``dl`` over and over, re-iterating it each pass."""
    while True:
        yield from dl

def num_to_groups(num, divisor):
    """Split ``num`` into chunks of ``divisor`` plus a trailing remainder.

    Returns a list of chunk sizes, e.g. ``num_to_groups(10, 4) -> [4, 4, 2]``.
    """
    full_chunks, leftover = divmod(num, divisor)
    sizes = [divisor for _ in range(full_chunks)]
    if leftover > 0:
        sizes.append(leftover)
    return sizes

def normalize_to_neg_one_to_one(img):
    """Map values affinely via ``img * 2 - 1`` (so 0 -> -1 and 1 -> 1)."""
    doubled = img * 2
    return doubled - 1

def unnormalize_to_zero_to_one(t):
    """Map values affinely via ``(t + 1) * 0.5`` (so -1 -> 0 and 1 -> 1)."""
    shifted = t + 1
    return shifted * 0.5

# small helper modules
class Identity(Layer):
    """Pass-through layer: returns ``tf.identity`` of its input."""

    def __init__(self):
        super().__init__()

    def call(self, x, training=True):
        # `training` is accepted for signature parity with other layers; it is unused.
        passthrough = tf.identity(x)
        return passthrough

class EMA(Layer):
    """Exponential-moving-average helper for model weights.

    Args:
        beta: weight given to the old value in each update.
    """

    def __init__(self, beta=0.995):
        super().__init__()
        self.beta = beta

    @tf.function
    def update_model_average(self, old_model, new_model):
        """Blend ``new_model.weights`` into ``old_model.weights`` in place.

        Weights are paired positionally; each pair must have the same shape.
        """
        for target, source in zip(old_model.weights, new_model.weights):
            assert target.shape == source.shape

            blended = self.update_average(target, source)
            target.assign(blended)

    def update_average(self, old, new):
        """Return ``old * beta + (1 - beta) * new``, or ``new`` if ``old`` is None."""
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new

class Residual(Layer):
    """Wrap ``fn`` with a skip connection: output is ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def call(self, x, training=True):
        branch = self.fn(x, training=training)
        return branch + x

class SinusoidalPosEmb(Layer):
    """Sinusoidal embedding of a 1-D batch of positions/timesteps.

    Args:
        dim: embedding size; ``dim // 2`` sine and ``dim // 2`` cosine features
            are concatenated along the last axis.
        max_positions: its logarithm sets the frequency spacing.
    """

    def __init__(self, dim, max_positions=10000):
        super().__init__()
        self.dim = dim
        self.max_positions = max_positions

    def call(self, x, training=True):
        positions = tf.cast(x, tf.float32)
        half_dim = self.dim // 2

        # Geometric frequency ladder: exp(-k * log(max_positions) / (half_dim - 1)).
        log_step = math.log(self.max_positions) / (half_dim - 1)
        freqs = tf.exp(tf.range(half_dim, dtype=tf.float32) * -log_step)

        # Outer product: one row per position, one column per frequency.
        angles = positions[:, None] * freqs[None, :]

        return tf.concat([tf.sin(angles), tf.cos(angles)], axis=-1)

def Upsample(dim):
    """Build a stride-2, kernel-4 transposed convolution with ``dim`` filters."""
    layer = nn.Conv2DTranspose(filters=dim, kernel_size=4, strides=2, padding='SAME')
    return layer

def Downsample(dim):
    """Build a stride-2, kernel-4 convolution with ``dim`` filters."""
    layer = nn.Conv2D(filters=dim, kernel_size=4, strides=2, padding='SAME')
    return layer

class LayerNorm(Layer):
    """Normalise over the last axis, then apply a learned gain and bias.

    Args:
        dim: size of the last axis; gain ``g`` and bias ``b`` have shape
            ``[1, 1, 1, dim]``.
        eps: added to the variance before the square root.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps

        param_shape = [1, 1, 1, dim]
        self.g = tf.Variable(tf.ones(param_shape))
        self.b = tf.Variable(tf.zeros(param_shape))

    def call(self, x, training=True):
        var = tf.math.reduce_variance(x, axis=-1, keepdims=True)
        mean = tf.reduce_mean(x, axis=-1, keepdims=True)

        centered = x - mean
        normed = centered / tf.sqrt(var + self.eps)
        return normed * self.g + self.b

class PreNorm(Layer):
    """Apply ``LayerNorm(dim)`` to the input and then call ``fn`` on the result."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def call(self, x, training=True):
        # `training` is not forwarded explicitly to either sub-call.
        normed = self.norm(x)
        return self.fn(normed)

class SiLU(Layer):
    """SiLU / swish activation: ``x * sigmoid(x)``."""

    def __init__(self):
        super().__init__()

    def call(self, x, training=True):
        gate = tf.nn.sigmoid(x)
        return x * gate

def gelu(x, approximate=False):
    """GELU activation.

    Args:
        x: input tensor.
        approximate: when true use the tanh approximation, otherwise the
            erf-based form.
    """
    if not approximate:
        return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype)))

    coeff = tf.cast(0.044715, x.dtype)
    inner = 0.7978845608028654 * (x + coeff * tf.pow(x, 3))
    return 0.5 * x * (1.0 + tf.tanh(inner))

class GELU(Layer):
    """Layer wrapper around the module-level ``gelu`` function."""

    def __init__(self, approximate=False):
        super().__init__()
        self.approximate = approximate

    def call(self, x, training=True):
        activated = gelu(x, self.approximate)
        return activated

def linear_beta_schedule(timesteps):
    """Return ``timesteps`` linearly spaced betas as a float32 tensor.

    The end points 0.0001 and 0.02 are multiplied by ``1000 / timesteps``.
    """
    scale = 1000 / timesteps
    start, stop = scale * 0.0001, scale * 0.02

    betas = tf.linspace(start, stop, timesteps)
    return tf.cast(betas, tf.float32)

def cosine_beta_schedule(timesteps, s=0.008):
    """
    Cosine beta schedule,
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    Builds a squared-cosine cumulative-alpha curve over ``timesteps + 1``
    points, normalises it by its first value, and converts consecutive ratios
    into betas clipped to ``[0, 0.999]``.
    """
    steps = timesteps + 1
    grid = tf.cast(tf.linspace(0, timesteps, steps), tf.float32)

    phase = ((grid / timesteps) + s) / (1 + s) * math.pi * 0.5
    cumulative = tf.cos(phase) ** 2
    cumulative = cumulative / cumulative[0]

    ratios = cumulative[1:] / cumulative[:-1]
    betas = 1 - ratios

    return tf.clip_by_value(betas, 0, 0.999)

def extract(x, t):
    """Gather ``x`` at indices ``t`` and append three singleton axes."""
    picked = tf.gather(x, t)
    return picked[:, None, None, None]

def Filter(image, model="BLUR"):
    """Blur ``image`` with one of three OpenCV filters.

    Args:
        image: image array passed straight to the OpenCV call.
        model: which filter to apply:
            ``"conv2D"`` - ``cv2.filter2D`` with a 3x3 kernel of 1/9 (mean blur);
            ``"BLUR"``   - ``cv2.blur`` with a 10x10 window;
            ``"Guass"``  - ``cv2.GaussianBlur`` with sigma 2 (spelling kept
            for backward compatibility).

    Returns:
        The filtered image.

    Raises:
        ValueError: if ``model`` is not one of the names above. (Previously
            this fell through and failed with an UnboundLocalError.)
    """
    if model == "conv2D":
        kernel = np.ones((3, 3)) / 9
        return cv2.filter2D(image, -1, kernel)
    if model == "BLUR":
        return cv2.blur(image, (10, 10))
    if model == "Guass":
        return cv2.GaussianBlur(image, (0, 0), 2)
    raise ValueError(
        "Unknown filter model %r; expected 'conv2D', 'BLUR' or 'Guass'" % (model,)
    )

In [None]:
"""VisionMambaBlock module."""
from einops import rearrange, repeat
from math import ceil
class RMSNorm(tf.keras.layers.Layer):
    """Root-mean-square normalisation over the last axis with a learned scale.

    Args:
        eps: lower bound applied to the RMS before dividing.
        **kwargs: standard Keras ``Layer`` arguments (``name``, ``dtype``, ...).
            Accepting them is required for ``from_config`` to work, because
            ``get_config`` includes the base-layer keys.
    """

    def __init__(self, eps=1e-5, **kwargs):
        super(RMSNorm, self).__init__(**kwargs)
        self.eps = eps

    def build(self, input_shape):
        # One scale per feature of the last axis, initialised to 1.
        self.weight = self.add_weight(
            shape=(input_shape[-1],),
            dtype=tf.float32,
            trainable=True,
            initializer=tf.keras.initializers.Constant(1.),
            name='weight',
        )

    def compute_output_shape(self, input_shape):
        return input_shape

    def call(self, inputs):
        # RMS over the last axis, floored at eps to avoid dividing by ~0.
        rms = tf.math.maximum(
            tf.math.sqrt(tf.math.reduce_mean(inputs ** 2, axis=-1, keepdims=True)),
            self.eps,
        )
        results = inputs / rms
        results = results * self.weight
        return results

    def get_config(self):
        config = super(RMSNorm, self).get_config()
        config['eps'] = self.eps
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)
    
def selective_scan(u, delta, A, B, C, D):
    """Compute the state-space output for a whole sequence with cumulative sums.

    The einsum signatures below fix the expected ranks: ``u`` and ``delta`` are
    indexed as ``(b, l, d)``, ``A`` as ``(d, n)`` and ``B`` / ``C`` as
    ``(b, l, n)``. ``D`` is multiplied elementwise with ``u`` for the skip term.

    Returns:
        ``y + u * D`` where ``y`` is indexed as ``(b, l, d)``.
    """
    dA = tf.einsum('bld,dn->bldn', delta, A) # first step of A_bar = exp(ΔA), i.e., ΔA
    # Input contribution per step: delta * u * B, broadcast to (b, l, d, n).
    dB_u = tf.einsum('bld,bld,bln->bldn', delta, u, B)
    # Shift dA one step earlier along axis 1: drop the first step, pad one zero
    # step on each side, then slice off the leading pad (a zero step ends up last).
    dA_cumsum = tf.pad(
        dA[:, 1:], [[0, 0], [1, 1], [0, 0], [0, 0]])[:, 1:, :, :]
    dA_cumsum = tf.reverse(dA_cumsum, axis=[1])  # Flip along axis 1
    # Cumulative sum along all the input tokens, parallel prefix sum, calculates dA for all the input tokens parallely
    # (reverse -> cumsum -> reverse gives, at each position, the sum over that position and all later ones)
    dA_cumsum = tf.math.cumsum(dA_cumsum, axis=1)  
    dA_cumsum = tf.exp(dA_cumsum)  # second step of A_bar = exp(ΔA), i.e., exp(ΔA)
    dA_cumsum = tf.reverse(dA_cumsum, axis=[1])  # Flip back along axis 1
    # Weight each input by its decay factor, accumulate along the sequence,
    # then divide the same factor back out.
    x = dB_u * dA_cumsum
    x = tf.math.cumsum(x, axis=1)/(dA_cumsum + 1e-12) # 1e-12 to avoid division by 0
    # Project the state axis n away with C.
    y = tf.einsum('bldn,bln->bld', x, C)
    return y + u * D 

class SSM(tf.keras.layers.Layer):
    """Selective state-space layer: projects the input to (delta, B, C) and
    runs ``selective_scan``.

    Args:
        d_model: base model width; the layer operates on ``expand * d_model``
            features.
        expand: expansion factor applied to ``d_model``.
        d_state: size of the state axis.
        bias: whether to add a bias to the ``x`` projection.
        **kwargs: standard Keras ``Layer`` arguments (``name``, ``dtype``, ...).
            Accepting them is required for ``from_config`` to work, because
            ``get_config`` includes the base-layer keys.
    """

    def __init__(self, d_model, expand=2, d_state=16, bias=False, **kwargs):
        super(SSM, self).__init__(**kwargs)
        self.d_model = d_model
        self.expand = expand
        self.d_state = d_state
        self.bias = bias
        self.dt_rank = ceil(self.d_model / 16)

    def build(self, input_shape):
        inner = self.expand * self.d_model
        proj_out = self.dt_rank + 2 * self.d_state

        self.x_proj_weight = self.add_weight(shape=(inner, proj_out), dtype=tf.float32, trainable=True, name='x_proj_weight')
        if self.bias:
            # Shape must be a 1-tuple; the bare int used previously is not a shape tuple.
            self.x_proj_bias = self.add_weight(shape=(proj_out,), dtype=tf.float32, trainable=True, name='x_proj_bias')
        # NOTE: the variable name below contains a typo ('wei9ght'); it is kept
        # unchanged so that existing saved weights keep matching by name.
        self.dt_proj_weight = self.add_weight(shape=(self.dt_rank, inner), dtype=tf.float32, trainable=True, name='dt_proj_wei9ght')
        self.dt_proj_bias = self.add_weight(shape=(inner,), dtype=tf.float32, trainable=True, name='dt_proj_bias')
        self.A_log = self.add_weight(shape=(inner, self.d_state), dtype=tf.float32, trainable=True, name='A_log')
        # Initialise every row of A_log to log(1), log(2), ..., log(d_state).
        self.A_log.assign(tf.math.log(tf.tile(tf.expand_dims(tf.range(1, self.d_state + 1, dtype=tf.float32), axis=0), (inner, 1))))
        self.D = self.add_weight(shape=(inner,), dtype=tf.float32, trainable=True, initializer=tf.keras.initializers.Constant(1.), name='D')

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[1], self.d_model * self.expand)

    def call(self, x):
        # x.shape = (batch, seq_len, d_model * expand)
        x_dbl = tf.linalg.matmul(x, self.x_proj_weight)  # (batch, seq_len, dt_rank + 2 * d_state)
        if self.bias:
            x_dbl = x_dbl + self.x_proj_bias
        # delta: (batch, seq_len, dt_rank); B, C: (batch, seq_len, d_state)
        delta, B, C = tf.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], axis=-1)
        # Project delta up to (batch, seq_len, expand * d_model) and keep it positive.
        delta = tf.math.softplus(tf.linalg.matmul(delta, self.dt_proj_weight) + self.dt_proj_bias)
        # selective scan
        # state(t+1) = A state(t) + B x(t) # B is input gate
        # y(t)   = C state(t) + D x(t) # C is output gate
        A = -tf.exp(self.A_log)  # (expand * d_model, d_state), strictly negative
        y = selective_scan(x, delta, A, B, C, self.D)
        return y

    def get_config(self):
        config = super(SSM, self).get_config()
        config['d_model'] = self.d_model
        config['expand'] = self.expand
        config['d_state'] = self.d_state
        config['bias'] = self.bias
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)
        
class MambaBlock(tf.keras.layers.Layer):
    """Gated block: input projection -> Conv1D (swish) -> SSM, gated by a
    swish branch, then projected back to ``d_model``.

    All sub-layers are created once in ``__init__`` so that Keras tracks their
    weights and they are trained. (Creating them inside ``call`` builds fresh,
    untracked, randomly initialised layers on every forward pass.)

    Args:
        d_model: output width and base width of the block.
        expand: expansion factor for the inner width (``expand * d_model``).
        bias: whether the Dense projections and the SSM use a bias.
        d_conv: kernel size of the Conv1D.
        conv_bias: whether the Conv1D uses a bias.
        d_state: state size passed to ``SSM``.
    """

    def __init__(self, d_model, expand=2, bias=False, d_conv=4, conv_bias=True, d_state=16):
        super(MambaBlock, self).__init__()
        self.d_model = d_model
        self.expand = expand
        self.d_state = d_state
        self.bias = bias
        self.dt_rank = ceil(self.d_model / 16)
        self.d_conv = d_conv
        self.conv_bias = conv_bias
        self.fliter = d_model * expand  # attribute name kept for compatibility

        # Produces both the SSM branch and the gate branch in one projection.
        self.in_proj = tf.keras.layers.Dense(2 * self.expand * self.d_model, use_bias=self.bias)
        # spatial & channel mixing
        self.conv1d = tf.keras.layers.Conv1D(
            self.fliter,
            kernel_size=self.d_conv,
            padding='same',
            use_bias=self.conv_bias,
            activation=tf.keras.activations.swish,
        )
        # selective state space model
        self.ssm = SSM(self.d_model, self.expand, self.d_state, self.bias)
        self.out_proj = tf.keras.layers.Dense(self.d_model, use_bias=self.bias)

    def call(self, x):
        x_and_res = self.in_proj(x)  # (batch, seq_len, 2 * expand * d_model)
        x, res = tf.split(x_and_res, 2, axis=-1)  # each (batch, seq_len, expand * d_model)
        x = self.conv1d(x)  # (batch, seq_len, expand * d_model)
        y = self.ssm(x)  # (batch, seq_len, expand * d_model)
        # NOTE: borrowing the idea of the Swish gated linear unit (SwiGLU):
        # the SSM output is gated by swish(res).
        y = y * tf.nn.silu(res)
        outputs = self.out_proj(y)
        return outputs
# Pair
def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

class VisionEncoderMambaBlock(tf.keras.layers.Layer):
    """
    Patch-wise Mamba block inspired by the paper
    Vision Mamba: Efficient Visual Representation Learning with Bidirectional
    State Space Model

    The input image tensor is cut into ``patch_size_H x patch_size_L`` patches,
    each patch is flattened into a token, the token sequence goes through a
    pointwise SeparableConv1D and a ``MambaBlock`` (with a SiLU + skip), and a
    Dense layer maps tokens back to patch pixels before the patches are
    reassembled into the original spatial layout.

    Args:
        dim (int): Token width used by the Conv1D and the Mamba block.
        dim_inner (int): Expansion factor passed to ``MambaBlock`` as ``expand``.
        d_state (int): The dimension of the state space model.
        patch_size_H (int): Patch height; must divide the input height.
        patch_size_L (int): Patch width; must divide the input width.

    """
    def __init__(self,
        dim: int,
        dim_inner: int = 4,
        d_state: int = 16,
        patch_size_H: int = 16,
        patch_size_L: int = 16,
    ):
        super(VisionEncoderMambaBlock, self).__init__()
        self.dim = dim
        self.dim_inner = dim_inner
        self.d_state = d_state

        self.patch_height = patch_size_H
        self.patch_width  = patch_size_L
        self.x0_conv1d = tf.keras.layers.SeparableConv1D(dim, kernel_size=1)
        self.silu = SiLU()
        self.ssmx =  MambaBlock(d_model = dim, expand = dim_inner, d_state = d_state)
        # Output projection depends on the input channel count, so it is
        # created in build() rather than on every call (which would create a
        # fresh, untracked and untrained Dense layer each forward pass).
        self.out_proj = None

    def build(self, input_shape):
        channels = input_shape[-1]
        if channels is None:
            raise ValueError("VisionEncoderMambaBlock needs a known channel dimension")
        self.out_proj = tf.keras.layers.Dense(
            self.patch_height * self.patch_width * int(channels), activation=None)
        super(VisionEncoderMambaBlock, self).build(input_shape)

    def call(self, x, training=True):
        # Patch embedding: (b, H, W, c) -> (b, num_patches, patch_pixels * c)
        H = x.shape[1]
        x = rearrange(x, "b (h p1) (w p2) c -> b (h w) (p1 p2 c)", p1=self.patch_height, p2 =self.patch_width)
        x0 = self.x0_conv1d(x)
        x = self.ssmx(x0)
        x = self.silu(x)+x0
        # Back to flattened patches, then to the original spatial layout.
        x = self.out_proj(x)
        x = rearrange(x, "b (h w) (p1 p2 c)->b (h p1) (w p2) c", h = H//self.patch_height, p1=self.patch_height, p2 =self.patch_width)
        return x

In [None]:
########################util#########################
class Image_data:
    """Holds a dataset folder and a target size, and decodes/resizes files.

    Args:
        img_size: side length the images are resized to (square output).
        dataset_path: directory searched for ``*.png`` and ``*.jpg`` files.
    """

    def __init__(self, img_size, dataset_path):
        self.dataset_path = dataset_path
        self.img_size = img_size

    def image_processing(self, filename):
        """Read ``filename``, decode to 3 channels, resize and preprocess it."""
        raw = tf.io.read_file(filename)
        decoded = tf.image.decode_jpeg(raw, channels=3, dct_method='INTEGER_ACCURATE')

        target = [self.img_size, self.img_size]
        resized = tf.image.resize(decoded, target, antialias=True, method=tf.image.ResizeMethod.BICUBIC)

        return preprocess_fit_train_image(resized)

    def preprocess(self):
        """Collect the image paths into ``self.train_images`` (png first, then jpg)."""
        found = []
        for pattern in ('*.png', '*.jpg'):
            found += glob(os.path.join(self.dataset_path, pattern))
        self.train_images = found

def adjust_dynamic_range(images, range_in, range_out, out_dtype):
    """Affinely map ``images`` from ``range_in`` to ``range_out``.

    The result is clipped to ``range_out`` and cast to ``out_dtype``.
    """
    in_lo, in_hi = range_in[0], range_in[1]
    out_lo, out_hi = range_out[0], range_out[1]

    scale = (out_hi - out_lo) / (in_hi - in_lo)
    bias = out_lo - in_lo * scale

    mapped = images * scale + bias
    mapped = tf.clip_by_value(mapped, out_lo, out_hi)
    return tf.cast(mapped, dtype=out_dtype)

def random_flip_left_right(images):
    """Randomly (p = 0.5) reverse a rank-3 image tensor along axis 1.

    A single uniform sample is tiled to the image shape so the whole image is
    either kept or flipped.
    """
    shape = tf.shape(images)
    coin = tf.random.uniform([1, 1, 1], 0.0, 1.0)
    coin = tf.tile(coin, [shape[0], shape[1], shape[2]])  # [h, w, c]

    flipped = tf.reverse(images, axis=[1])
    return tf.where(coin < 0.5, images, flipped)

def preprocess_fit_train_image(images):
    """Training preprocessing: rescale [0, 255] -> [-1, 1] (float32), then random flip."""
    scaled = adjust_dynamic_range(
        images,
        range_in=(0.0, 255.0),
        range_out=(-1.0, 1.0),
        out_dtype=tf.dtypes.float32,
    )
    return random_flip_left_right(scaled)

def preprocess_image(images):
    """Rescale images from [0, 255] to [-1, 1] as float32 (no augmentation)."""
    scaled = adjust_dynamic_range(
        images,
        range_in=(0.0, 255.0),
        range_out=(-1.0, 1.0),
        out_dtype=tf.dtypes.float32,
    )
    return scaled

def postprocess_images(images):
    """Rescale images from [-1, 1] to [0, 255] and cast them to uint8."""
    as_float = adjust_dynamic_range(
        images,
        range_in=(-1.0, 1.0),
        range_out=(0.0, 255.0),
        out_dtype=tf.dtypes.float32,
    )
    return tf.cast(as_float, dtype=tf.dtypes.uint8)

def load_images(image_path, img_width, img_height, img_channel):
    """Load one image from disk as a preprocessed batch of size 1.

    Args:
        image_path: path handed to ``cv2.imread``.
        img_width: output width.
        img_height: output height.
        img_channel: ``1`` loads the file as grayscale; any other value loads
            it as colour and converts BGR -> RGB.

    Returns:
        numpy array with a leading batch axis of size 1 and a trailing channel
        axis, with values rescaled by ``preprocess_image``.

    Raises:
        OSError: if ``cv2.imread`` could not read the file (it returns None).
    """
    grayscale = img_channel == 1
    flags = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    img = cv2.imread(image_path, flags=flags)
    if img is None:
        raise OSError("cv2.imread could not read image: %r" % (image_path,))

    if grayscale:
        # tf.image.resize needs a 3-D (h, w, c) or 4-D input, so add the
        # channel axis before resizing rather than after.
        img = img[:, :, np.newaxis]
    else:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    img = tf.image.resize(img, [img_height, img_width], antialias=True, method=tf.image.ResizeMethod.BICUBIC)
    img = preprocess_image(img)

    # Add the batch axis.
    return np.expand_dims(img, axis=0)

def save_images(images, size, image_path):
    """Post-process ``images`` to uint8 and write them as one grid image.

    ``size`` is ``[rows, cols]`` of the grid; returns the result of ``imsave``.
    """
    as_uint8 = postprocess_images(images)
    return imsave(as_uint8, size, image_path)

def imsave(images, size, path):
    """Tile ``images`` into a grid, convert RGB -> BGR and write it to ``path``.

    Returns whatever ``cv2.imwrite`` returns.
    """
    grid = merge(images, size)
    grid_bgr = cv2.cvtColor(grid.astype('uint8'), cv2.COLOR_RGB2BGR)

    return cv2.imwrite(path, grid_bgr)

def merge(images, size):
    """Tile a batch of images into a ``size[0] x size[1]`` grid (row-major).

    Args:
        images: batch whose axes 1 and 2 are image height and width.
        size: ``(rows, cols)`` of the grid.

    Returns:
        float array of shape ``(h * rows, w * cols, 3)``.
    """
    tile_h, tile_w = images.shape[1], images.shape[2]
    n_rows, n_cols = size[0], size[1]
    canvas = np.zeros((tile_h * n_rows, tile_w * n_cols, 3))

    for idx, tile in enumerate(images):
        row, col = divmod(idx, n_cols)
        top, left = tile_h * row, tile_w * col
        canvas[top:top + tile_h, left:left + tile_w, :] = tile

    return canvas

def str2bool(x):
    """Return True only when ``x`` equals ``"true"``, ignoring case.

    The previous ``x.lower() in ('true')`` was a substring test against the
    string ``'true'`` (the parentheses do not make a tuple), so inputs such as
    ``''``, ``'t'`` or ``'ru'`` were also treated as true.
    """
    return x.lower() == 'true'

def check_folder(log_dir):
    """Create ``log_dir`` (with parents) if the path does not exist; return it."""
    already_there = os.path.exists(log_dir)
    if already_there:
        return log_dir
    os.makedirs(log_dir)
    return log_dir

def automatic_gpu_usage():
    """Enable memory growth on every visible GPU and print the GPU counts.

    Does nothing when no GPU is listed. A ``RuntimeError`` from TensorFlow is
    printed rather than raised.
    """
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if not gpus:
        return

    try:
        # Currently, memory growth needs to be the same across GPUs
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        # Memory growth must be set before GPUs have been initialized
        print(err)

def multi_gpu_loss(x, global_batch_size):
    """Mean over all non-batch axes, then sum over the batch / ``global_batch_size``."""
    rank = len(x.shape)
    per_sample = tf.reduce_mean(x, axis=list(range(1, rank)))
    return tf.reduce_sum(per_sample) / global_batch_size

In [None]:
##########################layers############################
# building block modules
class Block(Layer):
    """3x3 convolution -> group normalisation -> optional scale/shift -> SiLU.

    Args:
        dim: number of convolution filters.
        groups: number of groups for the group normalisation.
    """

    def __init__(self, dim, groups=8):
        super().__init__()
        self.proj = nn.Conv2D(dim, kernel_size=3, strides=1, padding='SAME')
        self.norm = tfa.layers.GroupNormalization(groups, epsilon=1e-05)
        self.act = SiLU()  # x * sigmoid(x)

    def call(self, x, scale_shift=None, training=True):
        h = self.norm(self.proj(x), training=training)

        if exists(scale_shift):
            # Feature-wise modulation supplied by the caller.
            scale, shift = scale_shift
            h = h * (scale + 1) + shift

        return self.act(h)

class ResnetBlock(Layer):
    """Two ``Block``s with a residual connection and optional time conditioning.

    Args:
        dim: input channel count (only used to decide whether the skip path
            needs a 1x1 convolution).
        dim_out: output channel count.
        time_emb_dim: when not None, an MLP (SiLU -> Dense) maps the time
            embedding to a scale/shift pair for the first block.
        groups: group count forwarded to both blocks.
    """

    def __init__(self, dim, dim_out, time_emb_dim=None, groups=8):
        super().__init__()

        if exists(time_emb_dim):
            self.mlp = Sequential([
                SiLU(),
                nn.Dense(units=dim_out * 2)
            ])
        else:
            self.mlp = None

        self.block1 = Block(dim_out, groups=groups)
        self.block2 = Block(dim_out, groups=groups)

        if dim != dim_out:
            self.res_conv = nn.Conv2D(filters=dim_out, kernel_size=1, strides=1)
        else:
            self.res_conv = Identity()

    def call(self, x, time_emb=None, training=True):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            cond = self.mlp(time_emb)
            cond = rearrange(cond, 'b c -> b 1 1 c')
            # First half is the scale, second half the shift.
            scale_shift = tf.split(cond, num_or_size_splits=2, axis=-1)

        h = self.block1(x, scale_shift=scale_shift, training=training)
        h = self.block2(h, training=training)

        skip = self.res_conv(x)
        return h + skip

class LinearAttention(Layer):
    """Linear attention over the spatial positions of a feature map.

    Args:
        dim: output channel count.
        heads: number of attention heads.
        dim_head: channels per head.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.hidden_dim = dim_head * heads

        self.attend = nn.Softmax()
        self.to_qkv = nn.Conv2D(filters=self.hidden_dim * 3, kernel_size=1, strides=1, use_bias=False)

        out_conv = nn.Conv2D(filters=dim, kernel_size=1, strides=1)
        out_norm = LayerNorm(dim)
        self.to_out = Sequential([out_conv, out_norm])

    def call(self, x, training=True):
        _, height, width, _ = x.shape

        # One 1x1 conv produces q, k and v; split and fold heads out of the channels.
        qkv = tf.split(self.to_qkv(x), num_or_size_splits=3, axis=-1)
        q, k, v = [
            rearrange(part, 'b x y (h c) -> b h c (x y)', h=self.heads)
            for part in qkv
        ]

        # q is normalised over the channel axis, k over the position axis.
        q = tf.nn.softmax(q, axis=-2)
        k = tf.nn.softmax(k, axis=-1)

        q = q * self.scale
        context = einsum('b h d n, b h e n -> b h d e', k, v)

        out = einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b x y (h c)', h=self.heads, x=height, y=width)

        return self.to_out(out, training=training)

class Attention(Layer):
    """Multi-head softmax attention over the spatial positions of a feature map.

    Args:
        dim: output channel count.
        heads: number of attention heads.
        dim_head: channels per head.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super(Attention, self).__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.hidden_dim = dim_head * heads

        self.to_qkv = nn.Conv2D(filters=self.hidden_dim * 3, kernel_size=1, strides=1, use_bias=False)
        self.to_out = nn.Conv2D(filters=dim, kernel_size=1, strides=1)

    def call(self, x, training=True):
        h, w = x.shape[1], x.shape[2]
        qkv = self.to_qkv(x)
        qkv = tf.split(qkv, num_or_size_splits=3, axis=-1)
        q, k, v = map(lambda t: rearrange(t, 'b x y (h c) -> b h c (x y)', h=self.heads), qkv)
        q = q * self.scale

        sim = einsum('b h d i, b h d j -> b h i j', q, k)
        # Subtract the per-row maximum for numerical stability. The shift must
        # be the max *value* (reduce_max); the previous argmax subtracted the
        # position index of the maximum instead.
        sim_max = tf.stop_gradient(tf.reduce_max(sim, axis=-1, keepdims=True))
        sim = sim - sim_max
        attn = tf.nn.softmax(sim, axis=-1)

        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b x y (h d)', x = h, y = w)
        out = self.to_out(out, training=training)

        return out

class MLP(Layer):
    """Appends a trailing axis, then applies two Dense -> GELU -> LayerNorm
    stages and a final Dense, all of width ``hidden_dim``."""

    def __init__(self, hidden_dim):
        super().__init__()
        stack = [Rearrange('... -> ... 1')]  # expand_dims(axis=-1)
        for _ in range(2):
            stack += [
                nn.Dense(units=hidden_dim),
                GELU(),
                LayerNorm(hidden_dim),
            ]
        stack.append(nn.Dense(units=hidden_dim))
        self.net = Sequential(stack)

    def call(self, x, training=True):
        out = self.net(x, training=training)
        return out

In [None]:
#######################network#####################################
class Unet(Model):
    """Time-conditioned U-Net whose per-level mixing layer is a Mamba block.

    Args:
        dim: base channel width.
        init_dim: width of the first 7x7 convolution; defaults to ``dim // 3 * 2``.
        out_dim: output channels; defaults to ``channels`` (doubled when
            ``learned_variance`` is set).
        dim_mults: per-level multipliers of ``dim``.
        channels: stored on the model and used for the default ``out_dim``.
        resnet_block_groups: group count for every ``ResnetBlock``.
        learned_variance: doubles the default output channel count.
        sinusoidal_cond_mlp: choose the sinusoidal time MLP over the plain ``MLP``.
    """

    def __init__(self,
                 dim=64,
                 init_dim=None,
                 out_dim=None,
                 dim_mults=(1, 2, 4, 8),
                 channels=1,
                 resnet_block_groups=8,
                 learned_variance=False,
                 sinusoidal_cond_mlp=True
                 ):
        super().__init__()

        # ---- dimensions ----
        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2D(filters=init_dim, kernel_size=7, strides=1, padding='SAME')

        dims = [init_dim] + [dim * mult for mult in dim_mults]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # ---- time embedding ----
        time_dim = dim * 4
        self.sinusoidal_cond_mlp = sinusoidal_cond_mlp

        if sinusoidal_cond_mlp:
            self.time_mlp = Sequential([
                SinusoidalPosEmb(dim),
                nn.Dense(units=time_dim),
                GELU(),
                nn.Dense(units=time_dim)
            ])
        else:
            self.time_mlp = MLP(time_dim)

        # ---- encoder / decoder stages ----
        self.downs = []
        self.ups = []
        num_resolutions = len(in_out)

        for level, (dim_in, dim_out) in enumerate(in_out):
            at_bottom = level >= (num_resolutions - 1)

            stage = [
                block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                Residual(PreNorm(dim_out, VisionEncoderMambaBlock(dim_out, 4, 16, patch_size_H=5, patch_size_L=4))),
                Identity() if at_bottom else Downsample(dim_out),
            ]
            self.downs.append(stage)

        # ---- bottleneck ----
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for level, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            # This loop runs num_resolutions - 1 times, so the flag never
            # becomes true and every decoder stage ends with an Upsample.
            at_top = level >= (num_resolutions - 1)

            stage = [
                block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                Residual(PreNorm(dim_in, VisionEncoderMambaBlock(dim_in, 4, 16, patch_size_H=5, patch_size_L=4))),
                Identity() if at_top else Upsample(dim_in),
            ]
            self.ups.append(stage)

        # ---- output head ----
        default_out_dim = channels * (2 if learned_variance else 1)
        self.out_dim = default(out_dim, default_out_dim)

        self.final_conv = Sequential([
            block_klass(dim * 2, dim),
            nn.Conv2D(filters=self.out_dim, kernel_size=1, strides=1)
        ])

    def call(self, x, time=None, training=True, **kwargs):
        x = self.init_conv(x)
        t = self.time_mlp(time)

        skips = []

        # Encoder: keep each level's output for the matching decoder level.
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            skips.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # Decoder: concatenate the most recent skip on the channel axis.
        for block1, block2, attn, upsample in self.ups:
            x = tf.concat([x, skips.pop()], axis=-1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        # The first encoder output is consumed by the output head.
        x = tf.concat([x, skips.pop()], axis=-1)
        return self.final_conv(x)
    
    
class InpRailDiffusion(Model):
    """Masking-based "diffusion": the forward process zeroes out square grid
    cells of the image one after another; the reverse process restores them
    step by step from the network's prediction of the clean image.

    Args:
        image_size: ``(height, width)`` of the images. The grid of cells is
            derived from it (it was previously hard-coded to 200 x 160).
    """

    def __init__(self, image_size):
        super(InpRailDiffusion, self).__init__()
        self.image_size = image_size

    def _resolve_size(self, size):
        """Return the cell size to use, falling back to ``self.small_size``."""
        if size is not None:
            return size
        small_size = getattr(self, 'small_size', None)
        if small_size is None:
            raise ValueError("a cell size is required: pass `size` or set `small_size` on the instance")
        return small_size

    def _grid_shape(self, small_size):
        """Number of grid cells along (height, width) for cells of ``small_size``."""
        rows = int(self.image_size[0] / small_size)
        cols = int(self.image_size[1] / small_size)
        return rows, cols

    def _num_timesteps(self, small_size, step):
        """Number of timesteps: half of the grid cells, ``step`` cells per timestep."""
        rows, cols = self._grid_shape(small_size)
        return int(rows * cols / 2 // step)

    def sample_timesteps(self, n, small_size, step=1, t_cut=None):
        """Draw ``n`` random int32 timesteps.

        Values lie in ``[1, timesteps]``, or in ``[1, t_cut)`` when ``t_cut``
        is given.
        """
        timesteps = self._num_timesteps(small_size, step)
        upper = timesteps + 1 if t_cut is None else t_cut
        return tf.random.uniform(shape=[n], minval=1, maxval=upper, dtype=tf.int32)

    def noise_images(self, x, t, step=1, size=None, Seed=None, path_number=None, show=False):
        """Forward process q: zero out the first ``t * step`` cells of a path.

        All grid cells are shuffled (reproducibly when ``Seed`` is given; note
        this reseeds NumPy's global RNG) and split into two halves ("paths").
        One half is selected by ``path_number`` (0 or 1) or at random.

        Args:
            x: batch of images indexed as ``x[batch]`` with layout (h, w, c).
            t: per-sample timesteps.
            step: number of cells masked per timestep.
            size: cell side length; falls back to ``self.small_size``.
            Seed: optional seed for the cell shuffle.
            path_number: which half of the shuffled cells to use.
            show: plot the first sample after every masked cell.

        Returns:
            numpy array with the masked images (the image at time t).
        """
        small_size = self._resolve_size(size)
        rows, cols = self._grid_shape(small_size)

        # Cells are numbered 1..rows*cols in row-major order.
        all_cells = list(range(1, rows * cols + 1))
        if Seed is not None:
            np.random.seed(Seed)
        np.random.shuffle(all_cells)

        half = int(len(all_cells) / 2)
        noise_paths = [all_cells[half * i:half * (i + 1)] for i in range(2)]
        if path_number is None:
            path = noise_paths[random.choice([0, 1])]
        else:
            path = noise_paths[path_number]

        xt = []
        for batch in range(len(x)):
            # Work on a writable copy so the caller's batch is left untouched.
            masked = np.array(x[batch])
            # Never mask more cells than the path contains.
            n_cells = min(int(np.max(t[batch])) * step, len(path))
            n_cells = (n_cells // step) * step  # whole timesteps only
            for cell in path[:n_cells]:
                row, col = divmod(cell - 1, cols)
                masked[row * small_size:(row + 1) * small_size,
                       col * small_size:(col + 1) * small_size, :] = 0
                if show and batch == 0:
                    plt.imshow(masked)
                    plt.show()
            xt.append(masked)
        return np.array(xt)

    def sample(self, model, n, step=1, size=None, Seed=None, start_x=None, start_t=None, path_number=None, show=False):
        """Reverse process p: restore masked cells step by step.

        Starts from ``start_x`` (or Gaussian noise) at ``start_t`` (or the
        last timestep for this cell size) and walks down to timestep 1. At
        each step the model predicts the clean image; the difference between
        that prediction masked at ``i - 1`` and at ``i`` is added to ``x``.

        ``Seed`` and ``path_number`` should be given so that both maskings in a
        step follow the same cell trajectory.

        Previously this method read ``self.timesteps``, which is never set,
        and failed when ``size`` was left as None.
        """
        small_size = self._resolve_size(size)

        if start_x is None:
            x = tf.random.normal(shape=[n, self.image_size[0], self.image_size[1], 1])
        else:
            x = start_x

        timesteps = self._num_timesteps(small_size, step)
        first_t = timesteps if start_t is None else start_t

        for i in tqdm(reversed(range(1, first_t + 1)), desc='sampling loop time step', total=first_t):
            # No randomness here: the mask follows the designed trajectory.
            t = tf.ones(n, dtype=tf.int32) * i  # walk backwards from the last timestep
            x0_predicted = np.array(model(x, t))  # U-Net prediction of the clean image
            Dt = self.noise_images(x0_predicted, t, step, small_size, Seed, path_number=path_number)
            if i == 1:
                Dt_1 = x0_predicted
            else:
                Dt_1 = self.noise_images(x0_predicted, tf.ones(n, dtype=tf.int32) * (i - 1), step, small_size, Seed, path_number=path_number)
            # Only the cells unmasked between step i and i - 1 are non-zero here.
            x = x + (Dt_1 - Dt)
            if show:
                plt.imshow(x[0])
                plt.show()
        return x

In [None]:
automatic_gpu_usage()  # configure GPU memory growth

# ---------- load data ----------
Type_I_path_Img = r'D:\AI in NTU\Rail data\RSDDs\Type-I RSDDs dataset\Train\Rail surface images'

# Full path of every file in the training folder.
paths = [Type_I_path_Img + '\\' + p for p in os.listdir(Type_I_path_Img)]

BATCH_SIZE = 16

# Dataset of image paths: shuffled with a small fixed-seed buffer, then batched.
db_train = (
    tf.data.Dataset.from_tensor_slices(paths)
    .shuffle(buffer_size=8, seed=2023)
    .batch(BATCH_SIZE)
)

def load_image(path):
    path = str(path)[12:-26].replace("\\\\","/")
    image = cv2.imdecode(np.fromfile(path,dtype=np.uint8),-1)
    image =np.array(image)[:,:,np.newaxis]
    return image

""" Network """
unet = Unet(dim=32)
diffusion = InpRailDiffusion((200, 160))

# Finalize model (build): one dummy forward pass creates all the weights.
test_images = np.ones([1, 200, 160, 1])
small_size = random.choice([10,20,40])
Step=5
test_t = diffusion.sample_timesteps(n=test_images.shape[0],small_size=small_size,step=Step)
_ = unet(test_images, test_t)

# Optimizer
optimizer = tf.keras.optimizers.Adam(1e-4)
#unet.load_weights("./Diffusion-Paints-blur-Mamba-N2")

# Shuffle the batches. This is done once: the dataset reshuffles on every
# iteration by itself, and re-wrapping it each epoch would stack one more
# shuffle stage onto the pipeline per epoch.
db_train = db_train.shuffle(16)

History=[]
for epoch in range(0,400,1):
    count=0
    Average_loss=0
    for batch_paths in db_train:
        count+=1
        # Decode every image of the batch exactly once.
        train_image = np.array([load_image(p) for p in batch_paths], dtype=float)

        # Random cell size and timesteps; the masking is plain NumPy work, so
        # it does not need to be recorded by the tape.
        small_size = random.choice([10,20,40])
        t = diffusion.sample_timesteps(n=train_image.shape[0],small_size=small_size,step=Step)
        x_t = diffusion.noise_images(train_image, t, step=Step, size=small_size)  # masked image at time t

        with tf.GradientTape() as tape:
            predicted_noise_image = unet(x_t, t)  # network output, trained against the clean image

            #loss = multi_gpu_loss(loss, global_batch_size=BATCH_SIZE)
            loss = tf.keras.losses.mean_absolute_error(train_image,predicted_noise_image)
            loss = tf.reduce_mean(loss)

        # Gradient computation and the update happen outside the tape context.
        gradients = tape.gradient(loss, unet.trainable_variables)
        optimizer.apply_gradients(zip(gradients, unet.trainable_variables))
        #tf.print("Loss:%4.2f" %(loss))
        Average_loss=Average_loss + loss

    if count == 0:
        raise RuntimeError("the training dataset produced no batches")

    ###valid###
    # Visual check on the last training batch every 20 epochs.
    if epoch%20==0:
        small_size = random.choice([10,20,40])
        t = diffusion.sample_timesteps(n=train_image.shape[0],small_size=small_size,step=Step)
        x_t = diffusion.noise_images(train_image, t, step=Step, size=small_size)  # masked image at time t
        plt.imshow(x_t[0])
        plt.show()
        predicted_noise_image = unet(x_t, t)  # network output for the masked image
        plt.imshow(predicted_noise_image[0])
        plt.show()

    Average_loss = Average_loss/count
    History.append([epoch, Average_loss])
    tf.print("=>Epoch%4d  Averageloss:%4.2f" %(epoch, Average_loss))   

    unet.save_weights("./Diffusion-Paints-blur-Mamba-N2")