## TxT文件的DataLoader

### txt 文件格式

![](./img/2.png)

每行格式（以空白字符分隔，每个目标框占 5 个字段，可重复多个框）：`image_id xmin ymin xmax ymax class [xmin ymin xmax ymax class ...]`

In [1]:
import torch
import torch.utils.data as data

In [2]:
import os

import cv2
import numpy as np

class TxTDataset(data.Dataset):
    '''Object-detection dataset described by a plain-text annotation file.

    Every non-empty line of the annotation file describes one image:

        image_id xmin ymin xmax ymax class [xmin ymin xmax ymax class ...]

    Fields are separated by whitespace and each box contributes five fields.
    All annotations are parsed eagerly when the dataset is constructed; the
    images themselves are read lazily in ``__getitem__``.
    '''

    # How many times the transform is retried when it returns no boxes.
    _MAX_TRANSFORM_TRIES = 10

    def __init__(self, img_root, txt_file, train, transform, input_size):
        '''
        Args:
          img_root: (str) directory containing the images.
          txt_file: (str) path to the annotation (index) file.
          train: (boolean) train or test.
          transform: (callable or None) called as
            ``transform(image=..., bboxes=..., labels=...)`` and expected to
            return a mapping with at least ``'image'`` and ``'bboxes'``.
          input_size: (int) model input size; ``collate_fn`` allocates
            batches of shape ``(N, 3, input_size, input_size)``.
        '''
        super().__init__()
        self._img_root = img_root
        self._txt_file = txt_file
        self._train = train
        self._transform = transform
        self._input_size = input_size

        # Per-image data, index-aligned with each other.
        self._fnames = []  # image ids (file names relative to img_root)
        self._boxes = []   # one float tensor of shape (num_boxes, 4) per image
        self._labels = []  # one long tensor of shape (num_boxes,) per image
        self._num_samples = 0

        self._data_accumulate()

    @property
    def num_samples(self):
        '''Number of images listed in the annotation file.'''
        return self._num_samples

    def _data_accumulate(self):
        '''Parse the annotation file into ``_fnames``, ``_boxes`` and ``_labels``.'''
        with open(self._txt_file) as f:
            lines = f.readlines()

        for line in lines:
            fields = line.strip().split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            self._fnames.append(fields[0])

            # Everything after the image id comes in groups of five:
            # xmin, ymin, xmax, ymax, class.
            num_boxes = (len(fields) - 1) // 5
            box = []
            label = []
            for i in range(num_boxes):
                xmin, ymin, xmax, ymax, c = fields[1 + 5 * i:6 + 5 * i]
                box.append([float(xmin), float(ymin), float(xmax), float(ymax)])
                label.append(int(c))

            # reshape keeps a well-formed (0, 4) tensor for images without boxes.
            self._boxes.append(
                torch.tensor(box, dtype=torch.float32).reshape(-1, 4))
            self._labels.append(torch.tensor(label, dtype=torch.long))

        self._num_samples = len(self._fnames)

    def __getitem__(self, idx):
        '''Load one image together with its boxes and labels.

        Args:
          idx: (int) image index.

        Returns:
          image: the image. Without a transform this is the RGB float32
            array produced by OpenCV; with a transform it is whatever the
            transform returns under ``'image'``.
          boxes: (tensor) boxes of this image, one row per box.
          labels: (tensor) class label of each box.

        Raises:
          FileNotFoundError: if the image cannot be read from disk.
        '''
        # Load image and boxes.
        fname = self._fnames[idx]
        path = os.path.join(self._img_root, fname)
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError('Cannot read image: {}'.format(path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        boxes = self._boxes[idx].clone()  # clone so the cached tensor is never mutated
        labels = self._labels[idx]

        # Data augmentation.
        if self._transform is not None:
            # Retry a bounded number of times until at least one box survives
            # the transform; if none ever does, the untransformed sample is
            # returned as-is.
            for _ in range(self._MAX_TRANSFORM_TRIES):
                sample = self._transform(**{
                    'image': image,
                    'bboxes': boxes,
                    'labels': labels
                })
                if len(sample['bboxes']) > 0:
                    image = sample['image']
                    boxes = torch.stack(tuple(map(torch.tensor, zip(*sample['bboxes'])))).permute(1, 0)
                    # WARNING: the first four columns are reordered
                    # [0, 1, 2, 3] -> [1, 0, 3, 2], i.e. to yxyx if the
                    # transform returned xyxy. This happens only on this
                    # (transformed) path.
                    boxes[:, [0, 1, 2, 3]] = boxes[:, [1, 0, 3, 2]]
                    # NOTE(review): labels are not refreshed from the
                    # transform output; if the transform can drop boxes,
                    # labels may no longer line up with boxes - confirm.
                    break

        return image, boxes, labels

    def __len__(self):
        '''Return the number of images in the dataset.'''
        return self._num_samples

    def collate_fn(self, batch):
        '''Assemble a list of samples into one batch.

        Args:
          batch: (list) of ``(img, boxes, labels)`` tuples as returned by
            ``__getitem__``.

        Returns:
          inputs: (tensor) images, sized [N, 3, input_size, input_size].
          loc_targets: (tensor) stacked boxes of all samples.
          cls_targets: (tensor) stacked labels of all samples.

        Note:
          Each image must be assignable to a ``(3, input_size, input_size)``
          slot, and ``torch.stack`` requires every sample to carry the same
          number of boxes.
        '''
        imgs = [x[0] for x in batch]
        boxes = [x[1] for x in batch]
        labels = [x[2] for x in batch]

        h = w = self._input_size
        inputs = torch.zeros(len(imgs), 3, h, w)  # holds the images of one batch
        for i, img in enumerate(imgs):
            inputs[i] = img

        return inputs, torch.stack(boxes), torch.stack(labels)