In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image

class CustomSSDDataset(Dataset):
    def __init__(self, annotations_file, transform=None):
        self.annotations_file = annotations_file
        self.transform = transform
        self.data = []
        
        with open(annotations_file, 'r') as f:
            for line in f:
                parts = line.strip().split()
                img_path = parts[0]
                boxes = []
                labels = []
                distances = []
                
                for box_info in parts[1:]:
                    x1, y1, x2, y2, label, distance = map(int, box_info.split(','))
                    boxes.append([x1, y1, x2, y2])
                    labels.append(label)
                    distances.append(distance)
                
                self.data.append((img_path, boxes, labels, distances))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img_path, boxes, labels, distances = self.data[idx]
        image = Image.open(img_path).convert("RGB")
        
        boxes = torch.tensor(boxes, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.int64)
        distances = torch.tensor(distances, dtype=torch.float32)
        
        if self.transform:
            image, boxes = self.transform(image, boxes)
        
        return image, boxes, labels, distances



In [2]:
# 示例用法
transform = None  # 你可以根據需要添加轉換
dataset = CustomSSDDataset("with_dist/car_truck_train.txt", transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))

# 測試數據加載器
dataset[2]

(<PIL.Image.Image image mode=RGB size=1238x374>,
 tensor([[548., 171., 572., 194.],
         [505., 168., 575., 209.],
         [ 49., 185., 227., 246.],
         [328., 170., 397., 204.]]),
 tensor([0, 0, 0, 0]),
 tensor([48., 31., 19., 38.]))

In [9]:
def collate_fn(batch):
    images = []
    boxes = []
    labels = []
    distances = []
    
    for sample in batch:
        images.append(sample[0])
        boxes.append(torch.tensor(sample[1], dtype=torch.float32))
        labels.append(torch.tensor(sample[2], dtype=torch.int64))
        distances.append(torch.tensor(sample[3], dtype=torch.float32))
    
    images = torch.stack(images, dim=0)
    
    return images, boxes, labels, distances

In [10]:
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

# 測試數據加載器
for images, boxes, labels, distances in dataloader:
    print(images.shape)
    print(boxes)
    print(labels)
    print(distances)
    break

  boxes.append(torch.tensor(sample[1], dtype=torch.float32))
  labels.append(torch.tensor(sample[2], dtype=torch.int64))
  distances.append(torch.tensor(sample[3], dtype=torch.float32))


TypeError: expected Tensor as element 0 in argument 0, but got Image