# In [None]:
import tensorflow as tf
import numpy as np


# In [None]:
def _to_nhwc_float(image):
    """Convert ``image`` to a tensor, scale uint8 to [0, 1] float32, and pad to rank 4 (NHWC)."""
    image = tf.convert_to_tensor(image)
    if image.dtype == tf.uint8:
        image = tf.cast(image, tf.float32) / 255.0
    if len(image.shape) == 2:
        # (height, width) -> (height, width, 1): keep both spatial axes spatial
        # instead of letting the width axis be mistaken for channels.
        image = tf.expand_dims(image, axis=-1)
    while len(image.shape) < 4:
        # Any remaining missing axes are leading (batch) axes.
        image = tf.expand_dims(image, axis=0)
    return image


def _box_filter(x, size):
    """Mean of ``x`` over a ``size`` x ``size`` spatial window (stride 1, 'SAME' padding)."""
    return tf.nn.avg_pool(x, ksize=[1, size, size, 1], strides=[1, 1, 1, 1], padding='SAME')


def guided_filter(data, guide, radius, eps):
    """Filter ``data`` with a guided filter steered by ``guide``.

    Within every box window the output is modelled as ``a * guide + b``. The
    per-window coefficients are ``a = cov(guide, data) / (var(guide) + eps)``
    and ``b = mean(data) - a * mean(guide)``. They are then box-averaged and
    applied to ``guide``.

    Args:
        data: Image(s) to filter: rank 2 ``(height, width)``, rank 3, or rank 4
            ``(batch, height, width, channels)``. A rank-2 input gets a
            trailing channel axis; other missing axes are added in front.
            ``uint8`` input is scaled to ``[0, 1]`` float32.
        guide: Guidance image(s), prepared the same way as ``data`` and then
            cast to ``data``'s dtype. Its shape must broadcast against
            ``data``'s.
        radius: Side length of the square averaging window, passed directly
            as the pooling ``ksize``. This is the full window size, not a
            half-width (the classic formulation uses ``2 * r + 1``).
        eps: Regulariser added to the guide's local variance. Larger values
            smooth more strongly.

    Returns:
        A rank-4 floating-point tensor.
    """
    data = _to_nhwc_float(data)
    # Give the guide the same preparation, and align the dtypes so the
    # element-wise products below do not fail on a dtype mismatch.
    guide = tf.cast(_to_nhwc_float(guide), data.dtype)

    # Local means of the input and the guide.
    mean_data = _box_filter(data, radius)
    mean_guide = _box_filter(guide, radius)

    # Local covariance cov(I, p) = E[I * p] - E[I] * E[p].
    cov_data_guide = _box_filter(data * guide, radius) - mean_data * mean_guide

    # Local variance of the guide: var(I) = E[I^2] - E[I]^2.
    var_guide = _box_filter(guide ** 2, radius) - mean_guide ** 2

    # Per-window linear coefficients; eps keeps the division stable in flat regions.
    a = cov_data_guide / (var_guide + eps)
    b = mean_data - a * mean_guide

    # Average the coefficients of all windows covering each pixel, then apply.
    return _box_filter(a, radius) * guide + _box_filter(b, radius)