In [None]:
from torch.utils.data import Dataset # 导入PyTorch标准数据集基类
import cv2               
import torch
import os      
import matplotlib.pyplot as plt                   

class MyData(Dataset):
    """Image/label dataset backed by two sub-directories of ``root_dir``.

    Images are listed from ``root_dir/img_dir``. The label for an image is the
    text of the file in ``root_dir/label_dir`` that has the same base name and
    a ``.txt`` extension (e.g. ``0013.jpg`` -> ``0013.txt``).

    Args:
        root_dir: Directory containing both ``img_dir`` and ``label_dir``.
        img_dir: Sub-directory (relative to ``root_dir``) holding the images.
        label_dir: Sub-directory (relative to ``root_dir``) holding the labels.

    Each item is ``(img, label_content)``. ``img`` is the array returned by
    ``cv2.imread``, converted from BGR to RGB channel order. ``label_content``
    is the label file's text, or ``"No Label Found"`` when no label file
    exists.
    """

    def __init__(self, root_dir, img_dir, label_dir):
        self.root_dir = root_dir
        self.img_dir = img_dir
        self.label_dir = label_dir

        self.img_path = os.path.join(self.root_dir, self.img_dir)
        self.label_path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir returns entries in arbitrary order; sort so a given index
        # always refers to the same file across runs and machines.
        # NOTE(review): every directory entry is treated as an image, so stray
        # files (e.g. .DS_Store) will fail in __getitem__ — consider filtering.
        self.img_list = sorted(os.listdir(self.img_path))
        # Stored for callers; __getitem__ derives each label path from the
        # image name rather than indexing this list.
        self.label_list = sorted(os.listdir(self.label_path))

    def __getitem__(self, idx):
        """Return ``(img, label_content)`` for the ``idx``-th image.

        Raises:
            IndexError: if ``idx`` is out of range.
            OSError: if OpenCV cannot read the image file.
        """
        img_name = self.img_list[idx]
        img = self._load_rgb_image(os.path.join(self.img_path, img_name))
        return img, self._read_label(img_name)

    def __len__(self):
        """Return the number of items (one per entry in ``img_dir``)."""
        return len(self.img_list)

    @staticmethod
    def _load_rgb_image(path):
        """Read an image with OpenCV and convert it from BGR to RGB order."""
        img = cv2.imread(path)
        # cv2.imread reports failure (missing, unreadable or non-image file)
        # by returning None instead of raising. Fail here with the offending
        # path rather than letting cvtColor raise an opaque cv2.error.
        if img is None:
            raise OSError(f"cv2.imread could not read image file: {path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def _read_label(self, img_name):
        """Return the text of the label file that matches ``img_name``.

        The label name is ``img_name`` with its last extension replaced by
        ``.txt``, so names with extra dots work (``a.v2.jpg`` -> ``a.v2.txt``).
        A missing label file yields ``"No Label Found"`` so an unlabeled image
        does not abort iteration.
        """
        label_name = os.path.splitext(img_name)[0] + '.txt'
        label_item_path = os.path.join(self.label_path, label_name)
        try:
            with open(label_item_path, 'r', encoding='utf-8') as f:
                return f.read()
        except FileNotFoundError:
            return "No Label Found"


# --- Instantiation smoke test ---
# Expects ../test_jpg/robots, ../test_jpg/cars and their *_label folders.
root_dir = "../test_jpg"
robots_img_dir = "robots"
cars_img_dir = "cars"
robots_label_dir = "robots_label"
cars_label_dir = "cars_label"


def _show_sample(dataset, idx):
    """Print the label of ``dataset[idx]``, display its image, return the pair."""
    img, label = dataset[idx]
    print(f"标签内容: {label}")
    plt.imshow(img)
    plt.show()
    return img, label


# One dataset per category; both sub-directories live under root_dir.
robots_dataset = MyData(root_dir, robots_img_dir, robots_label_dir)
cars_dataset = MyData(root_dir, cars_img_dir, cars_label_dir)

print(len(robots_dataset))

img_robot, label_robot = _show_sample(robots_dataset, 0)
# Index 1 requires the cars folder to contain at least two images.
img_car, label_car = _show_sample(cars_dataset, 1)

