# 检测 CUDA 版本


In [None]:
# Show the NVIDIA driver / GPU status table; its header line reports the driver and CUDA versions.
!nvidia-smi

# 0. 相关库安装

### 1. 使用 Python3.7

### 2. 仅需要在新环境下运行此 cell


In [None]:
# Install the notebook's dependencies into the running kernel's environment (%pip, not !pip).
%pip install pandas
%pip install tqdm
%pip install opencv-python
%pip install openpyxl
%pip install matplotlib
%pip install mxnet
%pip install gluoncv
# Installed last so it replaces whatever numpy the packages above pulled in:
# the notes in section 2.1 require numpy <= 1.17.0 for this old mxnet stack.
# NOTE(review): recent pandas releases need a newer numpy (pandas 1.3 requires >= 1.17.3);
# confirm the pandas version resolved above still imports with 1.17.0 and consider pinning every package.
%pip install numpy==1.17.0 --force-reinstall

# 1. 准备工作

## 1.1 导入相关库


In [1]:
import os
import json

import numpy as np
import pandas as pd
from PIL import Image

from pathlib import Path
from tqdm import tqdm

import mxnet as mx
from mxnet import image
import gluoncv
from gluoncv.data.transforms.presets.segmentation import test_transform
from gluoncv.utils.viz import get_color_pallete

import matplotlib.pyplot as plt
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],  # render Chinese text with the SimHei font
    'axes.unicode_minus': False,    # keep the minus sign '-' from showing up as a box
})

## 1.2 设置主要参数


In [2]:
# Locations to read from and write to
PATH_LOCS = dict(
    INPUT=Path("input"),                 # raw input images (not yet processed)
    ADE_SPLIT=Path("ade-split-output"),  # where ADE semantic-segmentation results are saved
)
# Create any of these folders that do not exist yet
for folder in PATH_LOCS.values():
    folder.mkdir(parents=True, exist_ok=True)

# 2. ADE 语义分割

## 2.1 准备工作

由于 mxnet 较为老旧，若使用 mxnet 进行语义分割，请使用 mxnet_ade.ipynb，并使用 Python 3.7 版本，numpy 版本不超过 1.17.0。


### 2.1.1 读取类别映射


In [10]:
# Load the class-id -> label mapping used to name the segmentation classes.
type_f_path = Path("type_f.json")
# Always define col_map: later cells read it, and without a default a missing
# file only surfaced later as a NameError. Unmapped classes fall back to 'Unknown'.
col_map = {}

if type_f_path.exists():
    with open(type_f_path, 'r', encoding="utf8") as file:
        # `or {}` guards against a file containing only `null`.
        col_map = json.load(file) or {}
    if col_map:
        print(f"读取 type_f.json 成功！获取了 {len(col_map)} 条映射数据.")
    else:
        print("type_f.json 为空！类别名称将显示为 Unknown.")
else:
    print("type_f.json 不存在！请检查文件结构.")

读取 type_f.json 成功！获取了 150 条映射数据.


### 2.1.2 定义运算的相关函数


In [11]:
def overlay_mask_on_image(original_img_path, mask, output_folder, filename, alpha=0.5):
    """Blend a colour segmentation mask over the original image and save it as PNG.

    Parameters
    ----------
    original_img_path : str or Path
        Image the mask was predicted from.
    mask : PIL.Image.Image
        Colour mask. ``Image.blend`` requires it to have the same size as the
        original image once both are converted to RGBA.
    output_folder : str or Path
        Directory that receives the overlay.
    filename : str or Path
        Source filename; its stem names the output file.
    alpha : float, default 0.5
        Mask opacity: 0 keeps only the original, 1 shows only the mask.

    Returns
    -------
    str
        Name (not full path) of the written ``<stem>_overlay.png`` file.
    """
    overlay_filename = Path(filename).stem + '_overlay.png'
    # The context manager releases the source file as soon as its pixels are read.
    with Image.open(original_img_path) as original_img:
        base_rgba = original_img.convert("RGBA")
    overlayed_img = Image.blend(base_rgba, mask.convert("RGBA"), alpha)
    overlayed_img.save(Path(output_folder) / overlay_filename)
    return overlay_filename


def calculate_percentage(counts, total_pixels):
    """Express each pixel count as a percentage of ``total_pixels``.

    Returns a list with one percentage per entry of ``counts``, in the same order.
    """
    return list(map(lambda count: count / total_pixels * 100, counts))


def _lookup_class_label(col_map, class_id):
    """Return the label for ``class_id`` from ``col_map``, or ``'Unknown'``.

    A map read with ``json.load`` is keyed by strings such as ``"1"``, while the
    prediction array may hold float ids such as ``1.0``; try both spellings.
    """
    keys = [class_id]
    if float(class_id).is_integer():
        keys = [str(int(class_id)), int(class_id), class_id]
    for key in keys:
        if key in col_map:
            return col_map[key]
    return 'Unknown'


def collect_segmentation_data(predict, col_map):
    """Compute the share (in percent) of pixels assigned to each class.

    Parameters
    ----------
    predict : numpy.ndarray
        Per-pixel class-id map (float or integer ids).
    col_map : dict
        Class id -> label. String keys (as loaded from JSON) and int keys both work.

    Returns
    -------
    dict
        ``{label: percentage}``. Ids missing from ``col_map`` are pooled under
        ``'Unknown'``, and ids sharing a label are summed.
    """
    class_ids, pixel_counts = np.unique(predict, return_counts=True)
    total_pixels = pixel_counts.sum()
    data = {}
    for class_id, count in zip(class_ids, pixel_counts):
        label = _lookup_class_label(col_map, class_id)
        # Accumulate rather than assign: several ids may map to the same label
        # (e.g. every unmapped id becomes 'Unknown') and must not overwrite each other.
        data[label] = data.get(label, 0) + count / total_pixels * 100
    return data


def save_segmentation_stats_to_csv(data, output_folder, filename):
    """Write one image's ``{label: percentage}`` stats as a single-row CSV.

    The file is named ``<stem>_segmentation_stats.csv`` after ``filename``
    (a Path) inside ``output_folder``; that bare file name is returned.
    """
    csv_name = f"{filename.stem}_segmentation_stats.csv"
    pd.DataFrame([data]).to_csv(output_folder / csv_name, index=False)
    return csv_name


def auto_adjust_chart_size(df):
    """Pick a chart width: half an inch per column, but never below 10."""
    half_width = len(df.columns) / 2
    return max(10, half_width)


def save_visualization_chart(data, output_folder, filename, min_percent=1):
    """Save a bar chart of one image's per-class pixel percentages.

    Classes whose share is not above ``min_percent`` are folded into a single
    '其他' ("other") bar so the chart stays readable.

    Parameters
    ----------
    data : dict
        ``{label: percentage}`` for one image.
    output_folder : str or Path
        Directory that receives the chart.
    filename : str or Path
        Source image; its name goes in the title and its stem names the file.
    min_percent : float, default 1
        A class must exceed this percentage to get its own bar.

    Returns
    -------
    str
        Name of the written ``<stem>_visualization_chart.png`` file.
    """
    filename = Path(filename)
    # Uses the notebook-level `col_map` to translate any columns that are still raw class ids.
    shares = pd.DataFrame([data]).rename(columns=col_map)
    # .copy() makes the column assignment below act on a real frame, not a slice.
    shares = shares.loc[:, shares.sum() > min_percent].copy()
    shares['其他'] = 100 - shares.sum(axis=1)
    chart_size = auto_adjust_chart_size(shares)
    # Pass the axes to pandas explicitly: DataFrame.plot() without `ax` opens a
    # new figure, which silently ignored the requested figsize.
    fig, ax = plt.subplots(figsize=(chart_size, 6))
    shares.T.plot(kind='bar', legend=False, width=0.8, rot=0, ax=ax)  # rot=0: horizontal x labels
    ax.set_ylabel('Percentage (%)')
    ax.set_title(f'Segmentation Percentages for {filename.name}')
    fig.tight_layout()
    chart_filename = filename.stem + '_visualization_chart.png'
    fig.savefig(Path(output_folder) / chart_filename)
    # Close only this figure so figures don't pile up over thousands of images.
    plt.close(fig)
    return chart_filename


def compile_and_plot_segmentation_trends(segmentation_data, output_folder, min_mean_percent=1, max_width=600):
    """Plot how each major class's pixel share varies across all processed images.

    Parameters
    ----------
    segmentation_data : dict
        ``{image filename: {label: percentage}}``.
    output_folder : str or Path
        Directory that receives ``segmentation_trends.png``.
    min_mean_percent : float, default 1
        Only classes whose mean share across images exceeds this are plotted.
    max_width : float, default 600
        Upper bound on the figure width in inches. The width otherwise grows by
        0.5 in per image, and matplotlib refuses to save images of 2**16 px or
        more per side (reached at ~655 in, i.e. ~1300 images, at 100 dpi).

    Returns
    -------
    str or None
        ``'segmentation_trends.png'``, or ``None`` when there was nothing to plot.
    """
    if not segmentation_data:
        return None
    # Rows = images, columns = class labels; a class absent from an image counts as 0 %.
    df_compiled = pd.DataFrame(segmentation_data).T.rename(columns=col_map).fillna(0)
    mean_percentages = df_compiled.mean()
    df_compiled = df_compiled[mean_percentages[mean_percentages > min_mean_percent].index]
    width = min(max(12, len(segmentation_data) * 0.5), max_width)
    fig, ax = plt.subplots(figsize=(width, 8))
    for column in df_compiled.columns:
        ax.plot(df_compiled.index, df_compiled[column], marker='o', label=column)
    ax.set_xlabel('图片索引')
    ax.set_ylabel('像素百分比')
    ax.set_title('Segmentation Trends Across Images')
    # The x ticks are filenames; rotate them so neighbouring labels don't overlap.
    ax.tick_params(axis='x', labelrotation=90)
    if len(df_compiled.columns):
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(True)
    fig.tight_layout(rect=[0, 0, 0.85, 1])  # leave the right 15 % for the legend
    trends_filename = 'segmentation_trends.png'
    fig.savefig(Path(output_folder) / trends_filename)
    plt.close(fig)
    return trends_filename


# File suffixes (compared case-insensitively) that are treated as input images.
_IMAGE_SUFFIXES = ('.jpg', '.png', '.jpeg')
# Pixels predicted as "house" (id 25) are relabelled "building" (id 1) so both
# are counted as one category.
_ADE_HOUSE_ID = 25
_ADE_BUILDING_ID = 1


def _segment_single_image(img_path, output_folder, model, ctx):
    """Segment one image, write its mask/overlay/CSV/chart outputs and return its stats."""
    # Read the image and apply gluoncv's segmentation test-time transform.
    img = test_transform(image.imread(str(img_path)), ctx=ctx)
    output = model.predict(img)
    # argmax over axis 1 (presumably the class axis of an NCHW output) -> per-pixel class id.
    predict = mx.nd.squeeze(mx.nd.argmax(output, 1)).asnumpy()
    predict[predict == _ADE_HOUSE_ID] = _ADE_BUILDING_ID
    mask = get_color_pallete(predict, 'ade')
    # Path.stem (used by the other output helpers too) keeps dotted names like
    # "a.b.jpg" intact instead of truncating them at the first dot.
    mask.save(output_folder / (img_path.stem + '_seg_ade_101.png'))
    overlay_mask_on_image(img_path, mask, output_folder, img_path)
    data = collect_segmentation_data(predict, col_map)
    save_segmentation_stats_to_csv(data, output_folder, img_path)
    save_visualization_chart(data, output_folder, img_path)
    return data


def segment_images(input_folder : Path, output_folder, model, ctx) -> int:
    """Segment every image in ``input_folder`` and write the results to ``output_folder``.

    For each .jpg/.png/.jpeg file (processed in filename order) this writes the
    colour mask (``*_seg_ade_101.png``), an overlay, a CSV of class percentages
    and a bar chart; afterwards one trend chart across all images is written.

    Parameters
    ----------
    input_folder : str or Path
    output_folder : str or Path
    model : segmentation model exposing ``predict``
    ctx : mxnet context used for the input tensors

    Returns
    -------
    int
        Number of images segmented.
    """
    input_folder = Path(input_folder)
    output_folder = Path(output_folder)
    # Filter first so the progress bar counts only images; sort by name to keep
    # a deterministic processing order.
    image_paths = sorted(
        (p for p in input_folder.iterdir()
         if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES),
        key=lambda p: p.name,
    )
    segmentation_data = {}
    for img_path in tqdm(image_paths):
        segmentation_data[img_path.name] = _segment_single_image(img_path, output_folder, model, ctx)
    # Cross-image pixel-percentage trend chart
    compile_and_plot_segmentation_trends(segmentation_data, output_folder)
    return len(segmentation_data)

### 2.1.3 载入模型


In [12]:
# Use the first GPU when this MXNet build can see one (the notebook checks the GPU
# with nvidia-smi at the top); otherwise fall back to the CPU. Running on the CPU
# took ~23 s per image, so a CUDA-enabled mxnet build is strongly preferable.
ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu(0)
ade_model = gluoncv.model_zoo.get_model('deeplab_resnest101_ade', pretrained=True)
ade_model.collect_params().reset_ctx(ctx)

### 2.1.4 使用 ADE 数据集对图片进行分割并处理相关数据


In [13]:
print("使用 MXNET 进行图片语义分割.")
count = segment_images(
    input_folder=PATH_LOCS["INPUT"],
    output_folder=PATH_LOCS["ADE_SPLIT"],
    model=ade_model,
    ctx=ctx,
)
print(f"图片语义分割完成! 共分割 {count} 张图片.")

使用 MXNET 进行图片语义分割.


  0%|          | 3/1243 [01:08<7:54:29, 22.96s/it]


KeyboardInterrupt: 