##### 构建dataset和dataloader

In [10]:
import csv
def create_palette(csv_file):
    color2class={}
    #newline控制读取内容换行符的处理方式，''表示替换为自适应系统；None表示不做处理；'\n'表示全部替换为\n
    with open(csv_file,newline='') as f: 
    #csv.DictReader表示读取每一行都为一个字典，字典的键为第一行内容
        reader=csv.DictReader(f)
        for idx,item in enumerate(reader):
            class_idx,color=idx,(int(item['r']),int(item['g']),int(item['b']))
            color2class[color]=class_idx
    return color2class

In [11]:
color2class=create_palette(r"/volume/opemmmlab/mmengine/data/CamVid/class_dict.csv")
color2class

{(64, 128, 64): 0,
 (192, 0, 128): 1,
 (0, 128, 192): 2,
 (0, 128, 64): 3,
 (128, 0, 0): 4,
 (64, 0, 128): 5,
 (64, 0, 192): 6,
 (192, 128, 64): 7,
 (192, 192, 128): 8,
 (64, 64, 128): 9,
 (128, 0, 192): 10,
 (192, 0, 64): 11,
 (128, 128, 64): 12,
 (192, 0, 192): 13,
 (128, 64, 64): 14,
 (64, 192, 128): 15,
 (64, 64, 0): 16,
 (128, 64, 128): 17,
 (128, 128, 192): 18,
 (0, 0, 192): 19,
 (192, 128, 128): 20,
 (128, 128, 128): 21,
 (64, 128, 192): 22,
 (0, 0, 64): 23,
 (0, 64, 64): 24,
 (192, 64, 128): 25,
 (128, 128, 0): 26,
 (192, 128, 192): 27,
 (64, 0, 64): 28,
 (192, 192, 0): 29,
 (0, 0, 0): 30,
 (64, 192, 0): 31}

In [12]:
from torchvision.datasets import VisionDataset
import os
from PIL import Image
import numpy as np

class CamVid(VisionDataset):
    def __init__(self,
                 root,
                 img_folder,
                 mask_folder,
                 transform=None,
                 target_transform=None):
        super().__init__(root,transform=transform,target_transform=target_transform)
        self.img_folder=img_folder
        self.mask_folder=mask_folder
        self.images=list(sorted(os.listdir(os.path.join(root,img_folder))))
        self.masks=list(sorted(os.listdir(os.path.join(root,mask_folder))))
        self.color_2_class=create_palette(os.path.join(root,'class_dict.csv'))

    def __getitem__(self,index):
        img_path=os.path.join(self.root,self.img_folder,self.images[index])
        mask_path=os.path.join(self.root,self.mask_folder,self.masks[index])

        img=Image.open(img_path).convert("RGB")
        mask=Image.open(mask_path).convert("RGB")

        if self.transform:
            img=self.transform(img)

        #生成标签图
        mask=np.array(mask)
        mask=mask[:,:,0]*256**2+mask[:,:,1]*256+mask[:,:,2]

        label=np.zeros_like(mask,np.int64)
        for color,class_idx in self.color_2_class.items():
            color=color[0]*256**2+color[1]*256+color[2]
            label[mask==color]=class_idx

        if self.target_transform:
            label=self.target_transform(label)

        data_samples=dict(labels=label,img_path=img_path,mask_path=mask_path)

        return img,data_samples
    
    def __len__(self):
        return len(self.images)


In [13]:
import torch
import torchvision.transforms as transforms

norm_cfg=dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(**norm_cfg)
])

target_transforms=transforms.Lambda(lambda x:torch.tensor(np.array(x),dtype=torch.long))

train_set=CamVid(
    './data/CamVid/',
    img_folder='train',
    mask_folder='train_labels',
    transform=transform,
    target_transform=target_transforms
)
val_set=CamVid(
    './data/CamVid/',
    img_folder='val',
    mask_folder='val_labels',
    transform=transform,
    target_transform=target_transforms
)

train_dataloader=dict(
    batch_size=3,
    dataset=train_set,
    sampler=dict(type="DefaultSampler",shuffle=True),
    collate_fn=dict(type="default_collate")
)

val_dataloader=dict(
    batch_size=1,
    dataset=val_set,
    sampler=dict(type="DefaultSampler",shuffle=False),
    collate_fn=dict(type="default_collate")
)



##### 构建分割模型

In [14]:
import torch.nn.functional as F
from torchvision.models.segmentation import deeplabv3_resnet50
from mmengine.model import BaseModel

class MMDeeplabV3(BaseModel):
    def __init__(self,num_classes):
        super().__init__()
        self.deeplab=deeplabv3_resnet50()
        self.deeplab.classifier[4]=torch.nn.Conv2d(
            256, num_classes, kernel_size=(1, 1), stride=(1, 1)
        )

    def forward(self,imgs,data_samples=None,mode='tensor'):
        x=self.deeplab(imgs)['out']
        if mode=='loss':
            return {'loss':F.cross_entropy(x,data_samples['labels'])}
        elif mode=='predict':
            return x,data_samples
        

##### 定义IOU算子

In [15]:
from mmengine.evaluator import BaseMetric

class IoU(BaseMetric):
    def process(self,data_batch,data_samples):
        #复习一下：pytorch的dataloader有如下作用，从数据集中的getitem函数出来几项，那么dataloader出来的
        #也是相同项数，且每一项里都是bs个子项，如果是字典则键名也不变
        #错啦错啦，这里算子的data_samples是模型的预测输出，只不过模型预测输出包含了gt
        #实际上data_batch才是从dataloader出来的数据源
        preds,labels=data_samples[0],data_samples[1]['labels']
        #preds:bs*c*h*w,labels:bs*h*w
        preds=torch.argmax(preds,dim=1)
        insection=(preds==labels).sum()
        union=torch.logical_or(preds,labels).sum()
        iou=(insection/union).cpu()#注意这里计算的是一个bs所有图像的交并比
        self.results.append(dict(batch_size=len(labels),iou=iou*len(labels)))
    
    #注意这里和我理解的一张图计算IoU不一样
    def compute_metrics(self,results):
        total_iou=sum([result['iou'] for result in results])
        total_imgs=sum([result['batch_size'] for result in results])
        return dict(iou=total_iou/total_imgs)

##### 定义可视化钩子

In [16]:
from mmengine.hooks import Hook
import cv2
import shutil
import os.path as osp

class SegVisHook(Hook):
    def __init__(self,data_root,vis_num=1):
        super().__init__()
        self.palette=create_palette(osp.join(data_root,'class_dict.csv'))
        self.vis_num=vis_num
        self.data_root=data_root

    def after_val_iter(self,
                       runner,
                       batch_idx:int,
                       data_batch=None,
                       outputs=None):
        if batch_idx>self.vis_num:
            return
        
        #取出预测数据
        preds,data_samples=outputs
        img_paths,mask_paths=data_samples['img_path'],data_samples['mask_path']
        _,C,H,W=preds.shape
        #将三通道的预测图转为单通道，值表示每个像素点所属类别
        preds=torch.argmax(preds,dim=1)
        for idx,(pred,img_path,mask_path) in enumerate(zip(preds,img_paths,mask_paths)):
            pred_mask=np.zeros((H,W,3),dtype=np.uint8)
            #设置pred_mask为画布
            runner.visualizer.set_image(pred_mask)
            for color,class_idx in self.palette.items():
                runner.visualizer.draw_binary_masks(
                    pred==class_idx,#在哪些像素点画,pred是H*W的，但pred_mask是H*W*3的
                    colors=[color],
                    alphas=1.0
                )
            #转换为BGR
            pred_mask=runner.visualizer.get_image()[...,::-1] #TODO 这个函数可以获得画完图之后的数组数据，值得借鉴
            #存图，一张图对应三张结果图，存一个文件夹
            save_dir=osp.join(runner.log_dir,'vis_data',str(idx))
            os.makedirs(save_dir,exist_ok=True)
            shutil.copyfile(img_path,osp.join(save_dir,osp.basename(img_path)))
            shutil.copyfile(mask_path,osp.join(save_dir,osp.basename(mask_path)))

            cv2.imwrite(osp.join(save_dir,f"pred_{osp.basename(img_path)}"),pred_mask)


##### 训练

In [17]:
from torch.optim import AdamW
from mmengine.runner import Runner
from mmengine.optim import AmpOptimWrapper

num_classes=32
torch.cuda.empty_cache()
runner=Runner(
    model=MMDeeplabV3(num_classes=num_classes),
    work_dir='./work_dir/deeplabv3',
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    train_cfg=dict(by_epoch=True,max_epochs=10,val_interval=10),
    val_cfg = dict(),
    #type后面可以是1).注册过的字符串，2).也可以直接是类，3).也可以不传入字典，直接传入对象
    optim_wrapper=dict(type=AmpOptimWrapper,optimizer=dict(type=AdamW,lr=2e-4)),
    val_evaluator=dict(type=IoU),
    #自定义Hooks，以列表直接传入对象而非类的字典
    custom_hooks=[SegVisHook(data_root='./data/CamVid/')],
    #默认Hooks，传入类的字典
    default_hooks=dict(checkpoint=dict(type='CheckpointHook', interval=10)),
    resume=True,
    load_from="/volume/opemmmlab/mmengine/work_dir/deeplabv3/epoch_10.pth",
)
runner.val()

08/14 02:04:19 - mmengine - [4m[97mINFO[0m - 
------------------------------------------------------------
System environment:
    sys.platform: linux
    Python: 3.10.4 | packaged by conda-forge | (main, Mar 24 2022, 17:39:04) [GCC 10.3.0]
    CUDA available: True
    numpy_random_seed: 363155895
    GPU 0,1: GeForce GTX 1080
    CUDA_HOME: /usr/local/cuda
    NVCC: Cuda compilation tools, release 10.2, V10.2.8
    GCC: gcc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
    PyTorch: 1.12.1+cu102
    PyTorch compiling details: PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 10.2
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm

{'iou': tensor(0.9063)}