## CSV格式的dataloader

In [2]:
import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler

### 单类别

### csv 的格式

![](./img/1.png)

In [3]:
class CSVDataset(Dataset):
    """Single-class object-detection dataset backed by a CSV annotation table.

    Each sample is one image loaded from ``<img_root>/<image_id>.jpg`` together
    with every bounding box whose ``image_id`` column matches that image.

    Args:
        csv: Annotation table (e.g. a pandas ``DataFrame``) that provides an
            ``image_id`` column and the box columns ``x``, ``y``, ``w``, ``h``.
        img_root: Directory that contains the ``.jpg`` images.
        image_ids: Ids of the images that actually take part in this split
            (useful for cross-validation folds).
        transforms: Optional augmentation callable. It is invoked with the
            keyword arguments ``image``, ``bboxes`` and ``labels`` and must
            return a mapping with at least ``image`` and ``bboxes``
            (presumably an albumentations ``Compose`` -- confirm with caller).
        test: Flag marking a test split. It is stored but not consulted by
            this class.
    """

    # How many times augmentation is retried when it removes every box.
    _MAX_AUGMENT_TRIES = 10

    def __init__(self, csv, img_root, image_ids, transforms=None, test=False):
        self._csv = csv                # labels and basic image information
        self._img_root = img_root      # image directory
        self._image_ids = image_ids    # ids of the images in this split
        self._transforms = transforms  # augmentation pipeline
        self._test = test              # test-split flag

    def __len__(self):
        """Return the number of images in this split."""
        return len(self._image_ids)

    @staticmethod
    def _xywh_to_xyxy(boxes):
        """Return a float32 copy of ``boxes`` converted from xywh to xyxy."""
        # Copy first so the array owned by the annotation table is never
        # modified (it may be shared or read-only).
        converted = np.array(boxes, dtype=np.float32).reshape(-1, 4)
        converted[:, 2] += converted[:, 0]
        converted[:, 3] += converted[:, 1]
        return converted

    def _load_image(self, image_id):
        """Read an image as RGB float32 scaled to the range [0, 1]."""
        image_path = f"{self._img_root}/{image_id}.jpg"
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Cannot read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.0
        return image

    def __getitem__(self, index):
        """Return ``(image, boxes, labels, image_id)`` for the given index.

        Without transforms (or when every augmentation attempt removes all
        boxes) ``image`` is the un-augmented RGB float32 array and ``boxes``
        is a float32 ndarray in ``[x0, y0, x1, y1]`` order.

        When augmentation succeeds, ``image`` is whatever the transform
        returns and ``boxes`` is a float32 tensor in ``[y0, x0, y1, x1]``
        order.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        image_id = self._image_ids[index]
        image = self._load_image(image_id)

        # All annotation rows that belong to this image.
        records = self._csv[self._csv['image_id'] == image_id]
        boxes = self._xywh_to_xyxy(records[['x', 'y', 'w', 'h']].values)

        # Single-class localisation: one label (class 1) per box.
        labels = torch.ones((boxes.shape[0],), dtype=torch.int64)

        if self._transforms:
            # Random augmentation can crop away every box, so retry a few
            # times and keep the first result that still contains a box.
            for _ in range(self._MAX_AUGMENT_TRIES):
                sample = self._transforms(
                    image=image,
                    bboxes=boxes,
                    labels=labels,
                )
                if len(sample['bboxes']) > 0:
                    image = sample['image']
                    xyxy = torch.as_tensor(
                        np.asarray(sample['bboxes']), dtype=torch.float32
                    )
                    # NOTE: boxes are returned as yxyx on this path only.
                    boxes = xyxy[:, [1, 0, 3, 2]]
                    # Augmentation may drop boxes; keep labels in sync.
                    labels = torch.ones((boxes.shape[0],), dtype=torch.int64)
                    break

        return image, boxes, labels, image_id

### example

In [4]:
import pandas as pd

In [5]:
# Load the training annotations; each row describes one bounding box.
csv = pd.read_csv('D:/Dataset/wheat_detection/train.csv')
# Preview the first five rows.
csv.head(5)

Unnamed: 0,image_id,width,height,bbox,source
0,b6ab77fd7,1024,1024,"[834.0, 222.0, 56.0, 36.0]",usask_1
1,b6ab77fd7,1024,1024,"[226.0, 548.0, 130.0, 58.0]",usask_1
2,b6ab77fd7,1024,1024,"[377.0, 504.0, 74.0, 160.0]",usask_1
3,b6ab77fd7,1024,1024,"[834.0, 95.0, 109.0, 107.0]",usask_1
4,b6ab77fd7,1024,1024,"[26.0, 144.0, 124.0, 117.0]",usask_1


In [None]:
# bbox 字符串转 ndarray