# In [1]:
import os

# import cfg
import numpy as np
import pandas as pd
import torch as t
import torchvision.transforms as transforms
import torchvision.transforms.functional as ff
from PIL import Image
from torch.utils.data import Dataset

# In [None]:
class LabelProcessor:
    """Map colour-coded segmentation label images to integer class indices.

    The class colours are read from a CSV file and stored in a flat lookup
    table indexed by the packed 24-bit colour ``(r * 256 + g) * 256 + b``, so
    a whole label image is encoded with a single array lookup instead of a
    per-pixel search.
    """

    def __init__(self, file_path):
        """Load the colour map from ``file_path`` and build the lookup table."""
        self.colormap = self.read_color_map(file_path)
        self.cm2lbl = self.encode_label_pix(self.colormap)

    # Static helpers: they need no instance state, so they can also be called
    # directly on the class.
    @staticmethod
    def read_color_map(file_path):
        """Read the class colours from a comma-separated file.

        Args:
            file_path: Anything ``pandas.read_csv`` accepts (path or file-like
                object). The table must contain ``r``, ``g`` and ``b`` columns.

        Returns:
            A list of ``[r, g, b]`` integer lists, one per row, in file order.
            The position of a colour in this list is its class index.
        """
        pd_label_color = pd.read_csv(file_path, sep=',')
        # Select the three channels at once instead of walking the rows.
        return pd_label_color[['r', 'g', 'b']].astype(int).values.tolist()

    @staticmethod
    def encode_label_pix(colormap):
        """Build the colour -> class-index lookup table.

        Args:
            colormap: Sequence of ``[r, g, b]`` colours; the index of each
                colour in the sequence is its class index.

        Returns:
            A 1-D int64 array of length ``256 ** 3``. Entry
            ``(r * 256 + g) * 256 + b`` holds the class index of colour
            ``(r, g, b)``; colours absent from ``colormap`` map to 0.
        """
        cm2lbl = np.zeros(256 ** 3, dtype='int64')
        for i, cm in enumerate(colormap):
            # Plain Python ints so narrow dtypes (e.g. uint8) cannot overflow
            # while packing the three channels into one key.
            r, g, b = int(cm[0]), int(cm[1]), int(cm[2])
            cm2lbl[(r * 256 + g) * 256 + b] = i
        return cm2lbl

    def encode_label_img(self, img):
        """Convert an RGB label image into a 2-D array of class indices.

        Args:
            img: Array-like whose ``np.array`` form is indexed as
                ``[row, col, channel]`` with channels 0, 1, 2 read as R, G, B.

        Returns:
            An int64 array with one class index per pixel.
        """
        # Widen to int32 so the packed key does not overflow 8-bit pixels.
        data = np.array(img, dtype='int32')
        idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2]
        return np.array(self.cm2lbl[idx], dtype='int64')
        

class CamVidDataset(Dataset):
    """Dataset pairing images with colour-coded segmentation label images.

    Images and labels live in two separate folders; they are paired by the
    position of each file in the sorted listing of its folder.
    """

    def __init__(self, file_path=(), crop_size=None):
        """Collect the image and label file paths.

        Args:
            file_path: Two-element sequence ``[image_dir, label_dir]``.
            crop_size: Size passed to the centre crop applied to every sample;
                ``None`` disables cropping.

        Raises:
            ValueError: If ``file_path`` does not hold exactly two entries.
        """
        # 1. Validate that both the data and the label folders were given.
        if len(file_path) != 2:
            raise ValueError('需要输入数据和标签文件夹路径列表')
        self.img_path = file_path[0]
        self.label_path = file_path[1]

        # 2. List every file in both folders.
        self.imgs = self.read_file(self.img_path)
        self.labels = self.read_file(self.label_path)

        # 3. Pre-processing settings.
        self.crop_size = crop_size

    def __getitem__(self, index):
        """Load, crop and transform the sample at ``index``.

        Returns:
            A dict ``{'img': ..., 'label': ...}`` holding the outputs of
            ``img_transform``.
        """
        img_file = self.imgs[index]
        label_file = self.labels[index]

        # Context managers release the file handles as soon as the pixel data
        # has been read; converting to RGB guarantees the three channels that
        # the normalisation in img_transform expects.
        with Image.open(img_file) as raw_img:
            img = raw_img.convert('RGB')
        with Image.open(label_file) as raw_label:
            label = raw_label.convert('RGB')

        # Crop first so the sample matches the size the network expects.
        img, label = self.center_crop(img, label, self.crop_size)

        # Normalise the image and encode the label colours as class indices.
        img, label = self.img_transform(img, label)

        return {'img': img, 'label': label}

    def __len__(self):
        """Return the number of image files found."""
        return len(self.imgs)

    def read_file(self, path):
        """Return the sorted full paths of all entries inside ``path``."""
        files_list = os.listdir(path)
        return sorted(os.path.join(path, name) for name in files_list)

    def center_crop(self, data, label, crop_size=None):
        """Centre-crop the image and its label to the same size.

        Args:
            data: Input image.
            label: Label image belonging to ``data``.
            crop_size: Target size; falls back to ``self.crop_size`` when
                omitted. If both are ``None`` the inputs are returned as-is.

        Returns:
            The ``(data, label)`` pair after cropping.
        """
        if crop_size is None:
            crop_size = self.crop_size
        if crop_size is None:
            return data, label
        data = ff.center_crop(data, crop_size)
        label = ff.center_crop(label, crop_size)
        return data, label

    def img_transform(self, img, label):
        """Turn the image into a normalised tensor and the label into indices.

        Relies on the module-level ``label_processor`` object to map label
        colours to class indices.

        Returns:
            ``(img, label)``: the normalised image tensor and the label tensor
            built from the int64 class-index array.
        """
        transform_img = transforms.Compose(
            [
                transforms.ToTensor(),
                # Conventional ImageNet per-channel mean / std.
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        img = transform_img(img)
        label = np.array(label)  # accept inputs that are not already ndarrays
        label = Image.fromarray(label.astype('uint8'))  # force 8-bit pixels
        label = label_processor.encode_label_img(label)  # colour -> class index
        label = t.from_numpy(label)
        return img, label

# Module-level LabelProcessor instance; CamVidDataset.img_transform looks it up
# by this global name when encoding labels.
# NOTE(review): `cfg` is not imported in the visible code (the only `import cfg`
# is commented out), so this line raises NameError unless `cfg` is brought into
# scope elsewhere — confirm and restore the import if needed.
label_processor = LabelProcessor(cfg.class_dict_path)