# Depth Analysis - Batch Image Processing
# 深度分析 - 批量图像处理

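This notebook processes the raw PNG images in `0_raw_png/` to generate:
1. Heatmaps of green-channel intensity (inferno colormap, zero values transparent) saved to `1_red_png/`
2. Depth line plots of the top and bottom boundaries with fitted polynomial curves saved to `2_depth_line/`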

本笔记本处理原始PNG图像，生成：
1. 热力图可视化（色图）保存到 `1_red_png/`
2. 深度曲线图（带拟合曲线）保存到 `2_depth_line/`


In [14]:
# Import required libraries / 导入所需库
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.ticker as ticker


In [15]:
# Global matplotlib settings applied to every figure produced below.
# All plot labels are kept in English.
plt.rcParams.update({
    'font.size': 18,
    # Draw minus signs with the ASCII hyphen-minus instead of the Unicode minus glyph.
    'axes.unicode_minus': False,
})


In [16]:
# ============================================================================
# Function 1: Generate Heatmap / 函数1: 生成热力图
# ============================================================================
# Function: Extract green channel information and generate heatmap
# 功能: 提取图像绿色通道信息，生成热力图并保存
# ============================================================================

def _green_intensity(rgba):
    """
    Compute the alpha-weighted green intensity of green-dominant pixels.

    Parameters:
        rgba: (H, W, 4) array of R, G, B, A values.

    Returns:
        (H, W) float64 array: ``G * (A / 255)`` where G is strictly greater than
        both R and B, and 0 everywhere else (ties are not green-dominant).
    """
    # Work in float64 so the comparisons and products cannot overflow uint8.
    rgba = np.asarray(rgba, dtype=np.float64)
    r, g, b, a = (rgba[..., k] for k in range(4))
    green_dominant = (g > r) & (g > b)
    return np.where(green_dominant, g * (a / 255.0), 0.0)


def _transparent_inferno_cmap():
    """Build the 'inferno' colormap with its lowest entry replaced by fully transparent."""
    colors = [(0, 0, 0, 0)] + [plt.cm.inferno(i) for i in range(1, 256)]
    return mcolors.LinearSegmentedColormap.from_list('custom_cmap', colors, N=256)


def generate_heatmap(image_path, output_path, threshold=50):
    """
    Generate a heatmap from the green channel of an image and save it.

    Pixels whose green value is strictly greater than both red and blue
    contribute ``G * A / 255``. Every other pixel, and every pixel whose
    intensity is below ``threshold``, is drawn transparent.

    Parameters:
        image_path: Input image path. Any Pillow mode is accepted; it is
                    converted to RGBA before processing.
        output_path: Output image path.
        threshold: Minimum green intensity to keep; lower values become 0.
    """
    # Converting to RGBA guarantees four channels, even for RGB, grayscale or
    # palette PNGs. The context manager closes the underlying file.
    with Image.open(image_path) as img:
        rgba = np.asarray(img.convert('RGBA'))

    green_intensity = _green_intensity(rgba)

    # Apply threshold filtering.
    filtered_intensity = np.where(green_intensity >= threshold, green_intensity, 0)

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # Pin vmin at 0 so only true zeros map to the transparent colormap entry.
        # Otherwise, an image in which every pixel passes the threshold would
        # render its weakest pixels transparent.
        ax.imshow(filtered_intensity, cmap=_transparent_inferno_cmap(),
                  interpolation='nearest', vmin=0)
        ax.axis('off')
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if saving fails (avoids leaks in batch runs).
        plt.close(fig)

    print(f"Heatmap saved: {os.path.basename(output_path)}")


In [17]:
# ============================================================================
# Function 2: Generate Depth Line Plot / 函数2: 生成深度曲线图
# ============================================================================
# Function: Extract top and bottom boundaries, fit curves and plot depth analysis
# 功能: 提取图像顶部和底部边界，拟合曲线并绘制深度分析图
# ============================================================================

def format_func(value, tick_number, scale=5):
    """
    Tick formatter: label a pixel-coordinate tick with its scaled value in bold mathtext.

    Intended for ``matplotlib.ticker.FuncFormatter``, which calls it as
    ``format_func(value, tick_number)``.

    Parameters:
        value: Tick position in pixel coordinates.
        tick_number: Tick index supplied by matplotlib (unused).
        scale: Multiplier from pixels to the displayed unit (default 5).
               NOTE(review): presumably the physical length of one pixel; confirm.

    Returns:
        A mathtext string showing the scaled value in bold; for example, value=40 gives 200.
    """
    scaled = value * scale
    nearest = round(scaled)
    # Scale before rounding. Truncating first (the old ``int(value) * 5``) turned
    # 39.9999999 into 195 and dropped fractions from non-integer ticks.
    if abs(scaled - nearest) < 1e-9:
        text = str(int(nearest))
    else:
        text = f"{scaled:g}"
    return r"$\bf{" + text + "}$"

def _extract_boundaries(green, threshold):
    """
    Find the first and last qualifying row in every column of a 2-D green channel.

    A pixel qualifies when its value is >= ``threshold`` and non-zero.

    Returns:
        (top_points, bottom_points): two (N, 2) integer arrays of (column, row)
        pairs, one entry per column that contains a qualifying pixel, ordered by
        column. Both have shape (0, 2) when nothing qualifies.
    """
    green = np.asarray(green)
    mask = (green >= threshold) & (green != 0)
    columns = np.flatnonzero(mask.any(axis=0))
    # argmax on a boolean column returns the first True. On the flipped column it
    # returns the distance of the last True from the bottom edge.
    top_rows = mask.argmax(axis=0)[columns]
    bottom_rows = mask.shape[0] - 1 - mask[::-1].argmax(axis=0)[columns]
    return np.column_stack((columns, top_rows)), np.column_stack((columns, bottom_rows))


def _auto_x_range(columns):
    """Pick a fixed x-axis range from the horizontal extent of the data columns."""
    # Heuristic buckets for the known image sizes: 400 -> 80 px, 600 -> 120 px,
    # 800 -> 160 px.
    data_width = columns.max() - columns.min()
    if data_width < 100:
        return (0, 80)
    if data_width < 140:
        return (0, 120)
    return (0, 160)


def _auto_y_range(top_points, bottom_points):
    """Span all boundary rows, plus a margin of 5 px or 5 % of the span, whichever is larger."""
    min_y = min(top_points[:, 1].min(), bottom_points[:, 1].min())
    max_y = max(top_points[:, 1].max(), bottom_points[:, 1].max())
    span = max(1, max_y - min_y)
    margin = max(5, int(0.05 * span))
    return (min_y - margin, max_y + margin)


def _ticks(lo, hi, step):
    """Return ticks from ``lo`` in steps of ``step``, always ending exactly at ``hi``."""
    ticks = np.arange(lo, hi + 1, step)
    if ticks.size == 0 or ticks[-1] != hi:
        ticks = np.append(ticks, hi)
    return ticks


def _fit_boundary(points, degree):
    """
    Least-squares polynomial fit of row versus column for one boundary.

    The degree is capped at ``len(points) - 1``. With fewer points than
    coefficients the problem is underdetermined and ``np.polyfit`` emits a
    RankWarning.
    """
    deg = max(0, min(degree, len(points) - 1))
    return np.poly1d(np.polyfit(points[:, 0], points[:, 1], deg))


def generate_depth_line(image_path, output_path, threshold=50, x_range=None, y_range=None,
                        poly_degree=10):
    """
    Generate a depth line plot for one image and save it.

    For every image column, this takes the first (top) and last (bottom) row
    whose green value is >= ``threshold`` (and non-zero). It fits a polynomial
    to each boundary, then plots the fitted curves together with the raw points.
    The Y axis is inverted so row 0 is at the top. Tick labels are pixel
    coordinates multiplied by 5 (see ``format_func``).

    Parameters:
        image_path: Input image path. It is converted to RGBA, so channel 1 is green.
        output_path: Output image path.
        threshold: Minimum green value (0-255) for a pixel to count.
        x_range: Optional (min, max) x-axis range in pixels. When None, it is
                 chosen from the data width.
        y_range: Optional (min, max) y-axis range in pixels. When None, it is
                 computed from the boundary rows plus a margin.
        poly_degree: Maximum degree of the boundary polynomials (default 10).
                     It is reduced automatically when too few columns contain data.

    Raises:
        ValueError: If no pixel passes the threshold.
    """
    # Converting to RGBA keeps channel 1 as green for every PNG mode (grayscale
    # and palette images previously crashed). The context manager closes the file.
    with Image.open(image_path) as img:
        rgba = np.asarray(img.convert('RGBA'))

    top_points, bottom_points = _extract_boundaries(rgba[:, :, 1], threshold)
    if len(top_points) == 0:
        raise ValueError("No valid green-channel pixels found for depth extraction.")

    if x_range is None:
        x_range = _auto_x_range(top_points[:, 0])
    if y_range is None:
        y_range = _auto_y_range(top_points, bottom_points)

    poly_top_func = _fit_boundary(top_points, poly_degree)
    poly_bottom_func = _fit_boundary(bottom_points, poly_degree)

    # Figure size scales with the plotted ranges (at least 4 inches per side).
    fig_width = max(4, abs(x_range[1] - x_range[0]) / 20)
    fig_height = max(4, abs(y_range[1] - y_range[0]) / 20)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    try:
        # Fitted curves, evaluated across the whole x range.
        x = np.linspace(x_range[0], x_range[1], 1000)
        ax.plot(x, poly_top_func(x), 'r-', linewidth=2, label='Top boundary')
        ax.plot(x, poly_bottom_func(x), 'b-', linewidth=2, label='Bottom boundary')

        # Raw boundary points.
        ax.scatter(top_points[:, 0], top_points[:, 1], color='red', s=3)
        ax.scatter(bottom_points[:, 0], bottom_points[:, 1], color='blue', s=3)

        # Reference line at y = 0.
        ax.axhline(0, color='gray', linestyle='--', linewidth=1.5)

        ax.axis('equal')
        ax.set_xlim(x_range)
        ax.set_ylim(y_range[1], y_range[0])  # Inverted so image row 0 is at the top.

        # Adaptive tick spacing.
        x_span = max(1, x_range[1] - x_range[0])
        y_span = max(1, y_range[1] - y_range[0])
        x_step = 40 if x_span >= 100 else max(20, int(x_span / 4))
        y_step = 20 if y_span >= 80 else max(10, int(y_span / 5))
        ax.set_xticks(_ticks(x_range[0], x_range[1], x_step))
        ax.set_yticks(_ticks(y_range[0], y_range[1], y_step))

        # Tick labels show pixel coordinates multiplied by 5.
        formatter = ticker.FuncFormatter(format_func)
        ax.xaxis.set_major_formatter(formatter)
        ax.yaxis.set_major_formatter(formatter)

        # Bold tick marks and frame.
        ax.xaxis.set_tick_params(width=2)
        ax.yaxis.set_tick_params(width=2)
        for axis_name in ['top', 'bottom', 'left', 'right']:
            ax.spines[axis_name].set_linewidth(2)

        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    print(f"Depth line plot saved: {os.path.basename(output_path)}")


In [18]:
# ============================================================================
# Function 3: Batch Processing Main Function / 函数3: 批量处理主函数
# ============================================================================
# Function: Batch process all raw images to generate heatmaps and depth line plots
# 功能: 批量处理所有原始图像，生成热力图和深度曲线图
# ============================================================================

def batch_process_images(input_dir, heatmap_dir, depth_line_dir, threshold=50):
    """
    Batch process images / 批量处理图像
    
    Parameters / 参数:
        input_dir: Input directory path / 输入图像文件夹路径
        heatmap_dir: Heatmap output directory path / 热力图输出文件夹路径
        depth_line_dir: Depth line output directory path / 深度曲线图输出文件夹路径
        threshold: Green channel threshold / 绿色通道阈值
    """
    # Create output directories if not exist / 创建输出文件夹
    os.makedirs(heatmap_dir, exist_ok=True)
    os.makedirs(depth_line_dir, exist_ok=True)
    
    # Get all PNG files / 获取所有PNG文件
    image_files = [f for f in os.listdir(input_dir) if f.endswith('.png')]
    image_files.sort()
    
    print(f"Found {len(image_files)} images to process")
    print("=" * 60)
    
    # Process each image / 遍历处理每个图像
    for idx, filename in enumerate(image_files, 1):
        print(f"\n[{idx}/{len(image_files)}] Processing: {filename}")
        
        input_path = os.path.join(input_dir, filename)
        heatmap_path = os.path.join(heatmap_dir, filename)
        depth_line_path = os.path.join(depth_line_dir, filename)
        
        try:
            # Generate heatmap / 生成热力图
            generate_heatmap(input_path, heatmap_path, threshold=threshold)
            
            # Generate depth line plot / 生成深度曲线图
            generate_depth_line(input_path, depth_line_path, threshold=threshold)
            
        except Exception as e:
            print(f"Error processing {filename}: {str(e)}")
    
    print("\n" + "=" * 60)
    print("Batch processing completed!")
    print(f"Heatmaps saved to: {heatmap_dir}")
    print(f"Depth line plots saved to: {depth_line_dir}")


In [19]:
# ============================================================================
# Execute Batch Processing / 执行批量处理
# ============================================================================
# Set paths / 设置路径
# ============================================================================

# Input/output locations: raw PNGs are read from 0_raw_png/; heatmaps go to
# 1_red_png/ and depth line plots to 2_depth_line/.
# NOTE: machine-specific absolute path; edit before running on another machine.
base_dir = "/Users/yanyu/Desktop/课题组文章/金晓强老师/AS审稿意见/20251102代码/_others_image_processing/depth_analysis"
input_dir, heatmap_dir, depth_line_dir = (
    os.path.join(base_dir, subdir)
    for subdir in ("0_raw_png", "1_red_png", "2_depth_line")
)

# Run the batch.
batch_process_images(input_dir, heatmap_dir, depth_line_dir, threshold=50)


Found 6 images to process

[1/6] Processing: 400-1.png
Heatmap saved: 400-1.png
Depth line plot saved: 400-1.png

[2/6] Processing: 400-2.png
Heatmap saved: 400-2.png
Depth line plot saved: 400-2.png

[3/6] Processing: 600-1.png
Heatmap saved: 600-1.png
Depth line plot saved: 600-1.png

[4/6] Processing: 600-2.png
Heatmap saved: 600-2.png
Depth line plot saved: 600-2.png

[5/6] Processing: 800-1.png
Heatmap saved: 800-1.png
Depth line plot saved: 800-1.png

[6/6] Processing: 800-2.png
Heatmap saved: 800-2.png
Depth line plot saved: 800-2.png

Batch processing completed!
Heatmaps saved to: /Users/yanyu/Desktop/课题组文章/金晓强老师/AS审稿意见/20251102代码/_others_image_processing/depth_analysis/1_red_png
Depth line plots saved to: /Users/yanyu/Desktop/课题组文章/金晓强老师/AS审稿意见/20251102代码/_others_image_processing/depth_analysis/2_depth_line
