In [1]:
# Switch the process working directory to the path exposed by `workspace_path`
# so that the relative paths used by later cells ("mmseg/...", "configs/...")
# resolve against it.
import os
import workspace_path  # presumably a project-local helper module; only its `path` attribute is read here
os.chdir(workspace_path.path)

In [2]:
# Display the current working directory as the cell output.
os.getcwd()

'e:\\Code\\openmmlab\\mmsegmentation'

## 一：数据集配置文件

定义数据集类（各类别名称及配色）

In [3]:
import os
import urllib.request

# Where the dataset class definition is hosted, and where it lives locally.
# `save_path` is read again by the next cell, so both names are kept.
url = "https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/configs/ZihaoDataset.py"
file_path = "mmseg/datasets/ZihaoDataset.py"
save_path = file_path

# Drop a copy left over from a previous run before fetching a fresh one.
if os.path.exists(file_path):
    os.remove(file_path)

# Download the file. The call is deliberately the last expression of the cell
# so that its (local filename, response headers) result is shown as output.
urllib.request.urlretrieve(url, save_path)

('mmseg/datasets/ZihaoDataset.py', <http.client.HTTPMessage at 0x228aa6e5280>)

In [4]:
# Print the downloaded dataset class so the reader can inspect it.
# The encoding is given explicitly: the file contains non-ASCII comments, and
# without `encoding=` Python falls back to the locale's preferred encoding,
# which on Windows is usually not UTF-8 and can garble or fail to decode them.
with open(save_path, 'r', encoding='utf-8') as file:
    content = file.read()
print(content)

# 同济子豪兄 2023-6-25
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset

@DATASETS.register_module()
class ZihaoDataset(BaseSegDataset):
    # 类别和对应的 RGB配色
    METAINFO = {
        'classes':['background', 'red', 'green', 'white', 'seed-black', 'seed-white'],
        'palette':[[127,127,127], [200,0,0], [0,200,0], [144,238,144], [30,30,30], [251,189,8]]
    }
    
    # 指定图像扩展名、标注扩展名
    def __init__(self,
                 seg_map_suffix='.png',   # 标注mask图像的格式
                 reduce_zero_label=False, # 类别ID为0的类别是否需要除去
                 **kwargs) -> None:
        super().__init__(
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)


## 二：注册数据集类

在`mmseg/datasets/__init__.py`中注册刚刚定义的`ZihaoDataset`数据集类

In [5]:
# Register ZihaoDataset in mmseg/datasets/__init__.py by adding an import line
# and appending the class name to `__all__`.
import shutil

init_path = "mmseg/datasets/__init__.py"
backup_path = "mmseg/datasets/__init__bak.py"

# Keep the cell re-runnable: the first run saves a pristine backup, every
# later run restores that backup before patching, so edits never accumulate.
if os.path.exists(backup_path):
    shutil.copyfile(backup_path, init_path)
else:
    shutil.copyfile(init_path, backup_path)

# Read the file as a list of lines. The encoding is explicit so the result
# does not depend on the machine's locale.
with open(init_path, 'r', encoding='utf-8') as file:
    content = file.readlines()

# Insert the new import two lines above `__all__ = [`.
# The index is clamped at 0: a negative index would make list.insert count
# from the end of the file instead of the start.
new_line = 'from .ZihaoDataset import ZihaoDataset\n'
all_index = content.index('__all__ = [\n')  # ValueError if the marker line is missing
index = max(all_index - 2, 0)
content.insert(index, new_line)
all_index += 1  # the inserted import pushed `__all__` down by one line

# Locate the `]` that closes `__all__`. The search starts at the `__all__`
# line so that an earlier line consisting only of `]` cannot be picked up.
index = content.index(']\n', all_index)

# Only emit a leading comma when the preceding line does not already end with
# one (or is the opening bracket); otherwise the list would contain `,,`.
previous_entry = content[index - 1].rstrip()
separator = '' if previous_entry.endswith((',', '[')) else ','
new_element = "    " + separator + "'ZihaoDataset'\n"
content.insert(index, new_element)

# Write the patched lines back using the same explicit encoding.
with open(init_path, 'w', encoding='utf-8') as file:
    file.writelines(content)

## 三：pipeline配置文件

数据集路径、预处理、后处理、DataLoader、测试集评估指标

In [6]:
# Generate configs/_base_/datasets/ZihaoDataset_pipeline.py: dataset paths,
# train/test/TTA pipelines, dataloaders and evaluators for ZihaoDataset.
file_path = "configs/_base_/datasets/ZihaoDataset_pipeline.py"

# Remove a copy left over from a previous run.
if os.path.exists(file_path):
    os.remove(file_path)

# Source text of the config file, written to disk verbatim below.
content = '''
# 数据处理 pipeline
# 同济子豪兄 2023-6-28

# 数据集路径
dataset_type = 'ZihaoDataset' # 数据集类名
data_root = 'data/Watermelon87_Semantic_Seg_Mask/' # 数据集路径（相对于mmsegmentation主目录）

# 输入模型的图像裁剪尺寸，一般是 128 的倍数，越小显存开销越少
crop_size = (512, 512)

# 训练预处理
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomResize',
        scale=(2048, 1024),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]

# 测试预处理
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

# TTA后处理
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]

# 训练 Dataloader
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='img_dir/train', seg_map_path='ann_dir/train'),
        pipeline=train_pipeline))

# 验证 Dataloader
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='img_dir/val', seg_map_path='ann_dir/val'),
        pipeline=test_pipeline))

# 测试 Dataloader
test_dataloader = val_dataloader

# 验证 Evaluator
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice', 'mFscore'])

# 测试 Evaluator
test_evaluator = val_evaluator
'''

# Create the file and write the text as UTF-8. The text contains non-ASCII
# comments, and Python source files are decoded as UTF-8 by default; without
# an explicit encoding the locale codec would be used (often not UTF-8 on
# Windows), producing a file that may fail to load later.
with open(file_path, 'w', encoding='utf-8') as file:
    file.write(content)
