In [121]:
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import glob
from os.path import join as pjoin
import cv2

In [123]:
class QRCodeDataset(Dataset):
    """Dataset pairing images with their YOLO-format label files.

    Args:
        annotations_dir: Directory containing one label file per image.
        img_dir: Directory containing the images.
        predefined_class_file: Text file with one class name per line; the
            zero-based line number is used as the class id.

    Images and label files are paired by position after sorting both
    directory listings, so the pairing is deterministic.
    NOTE(review): this assumes an image and its label file share the same
    base name (so both listings sort into the same order) - confirm.

    Raises:
        AssertionError: If the number of images and label files differ.
    """

    def __init__(self, annotations_dir, img_dir, predefined_class_file):
        self.annotations_dir = annotations_dir
        self.img_dir = img_dir
        self.predefined_class_file = predefined_class_file

        self.defined_class = self.read_predefine_class()

        # glob returns entries in arbitrary order; sort so that index i of
        # the images lines up with index i of the labels.
        self.img_paths = sorted(glob.glob(pjoin(img_dir, '*')))
        self.img_labels = self.read_yolo_labels()

        assert len(self.img_paths) == len(self.img_labels), (
            f"found {len(self.img_paths)} images in {img_dir!r} but "
            f"{len(self.img_labels)} label files in {annotations_dir!r}"
        )

    def __len__(self):
        """Return the number of (labels, image) samples."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(labels, image)`` for sample ``idx``.

        ``labels`` is a list of ``(c, x, y, w, h)`` float tuples and
        ``image`` is the image loaded as a single-channel grayscale array.

        Raises:
            OSError: If the image file cannot be read or decoded.
        """
        path = self.img_paths[idx]
        # IMREAD_GRAYSCALE is the imread flag for grayscale loading;
        # COLOR_BGR2GRAY is a cvtColor conversion code, not an imread flag.
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"could not read image: {path!r}")
        return self.img_labels[idx], image

    @staticmethod
    def _read_label_file(fpth):
        """Parse a single YOLO label file into a list of 5-float tuples."""
        boxes = []
        with open(fpth) as f:
            for line_no, line in enumerate(f, start=1):
                # split() tolerates repeated spaces, tabs and trailing
                # whitespace; blank lines produce no fields and are skipped.
                fields = line.split()
                if not fields:
                    continue
                if len(fields) != 5:
                    raise ValueError(
                        f"{fpth}:{line_no}: expected 5 values "
                        f"(class x y w h), got {len(fields)}"
                    )
                boxes.append(tuple(float(v) for v in fields))
        return boxes

    def read_yolo_labels(self):
        """Load the labels of every file in ``annotations_dir``.

        Returns:
            A list with one entry per label file (in sorted path order);
            each entry is a list of ``(c, x, y, w, h)`` float tuples.
        """
        label_paths = sorted(glob.glob(pjoin(self.annotations_dir, '*')))
        return [self._read_label_file(p) for p in label_paths]

    def read_predefine_class(self):
        """Read ``predefined_class_file`` into a ``{class_id: class_name}`` dict.

        The class id is the zero-based line number; trailing whitespace is
        stripped from each name.
        """
        with open(self.predefined_class_file) as f:
            return dict(enumerate(line.rstrip() for line in f))

In [124]:
class QRCodePatchesDataset(Dataset):
    """Placeholder dataset for QR-code patches; not implemented yet.

    Until it is implemented it behaves as an empty dataset, so ``len()``
    and iteration work instead of failing or never terminating.
    """

    def __init__(self):
        pass

    def __len__(self):
        # __len__ must return an int; returning None makes len() raise
        # TypeError.
        return 0

    def __getitem__(self, idx):
        # Python falls back to calling __getitem__ with 0, 1, 2, ... when
        # iterating an object without __iter__ and stops only on IndexError,
        # so an empty dataset has to raise it for every index.
        raise IndexError(f"QRCodePatchesDataset is empty (index {idx})")

In [125]:
# Read one YOLO-format label file and print each bounding box.
# The names c, x, y, w, h stay bound to the last box after the loop.
with open('./data/paper_qr_label_yolo/File 015.txt') as f:
    lines = [raw.rstrip() for raw in f.readlines()]

    for bbox in lines:
        c, x, y, w, h = map(float, bbox.split(' '))
        print((c, x, y, w, h))

(0.0, 0.3769946808510638, 0.58125, 0.22739361702127658, 0.35)
(0.0, 0.6555851063829787, 0.3333333333333333, 0.22872340425531915, 0.35833333333333334)
(1.0, 0.14893617021276595, 0.7833333333333333, 0.19148936170212766, 0.325)


In [None]:
# YOLO box (center x, center y, width, height) -> corner coordinates
# (x1, y1) top-left and (x2, y2) bottom-right.
# Uses whatever x, y, w, h are currently bound in the kernel.
half_w, half_h = w/2, h/2
x1, y1 = x-half_w, y-half_h
x2, y2 = x+half_w, y+half_h
    

In [126]:
# Read the predefined-classes file into {line index: class name}.

cdict = {}
with open('./predefined_classes.txt') as f:
    lines = [raw.rstrip() for raw in f.readlines()]

for idx, cname in enumerate(lines):
    cdict[idx] = cname

In [127]:
# Build the dataset from the label directory, image directory and class list.
qr_code_dataset = QRCodeDataset(
    annotations_dir="./data/paper_qr_label_yolo",
    img_dir="./data/paper_qr",
    predefined_class_file="./predefined_classes.txt",
)

In [101]:
# Count the entries in the annotation directory.
annotations_dir = "./data/paper_qr_label_yolo"
label_files = glob.glob(pjoin(annotations_dir, '*'))

len(label_files)

124

In [116]:
# 
# Scratch check: concatenating lists yields one flat list of tuples.
a = [(1, 1, 2, 3, 4)]
b = [(2, 1, 2, 3, 4), (3, 1, 2, 3, 4)]

# Start from an empty list and concatenate b: a new list equal to b.
image_labels = list(b)

In [117]:
# Bare last expression: the notebook displays the current value of image_labels.
image_labels

[(2, 1, 2, 3, 4), (3, 1, 2, 3, 4)]

In [118]:
# Rebind image_labels to a new list with the items of `a` appended.
# Re-running this cell appends them again.
image_labels = [*image_labels, *a]

In [119]:
# Bare last expression: the notebook displays the current value of image_labels.
image_labels

[(2, 1, 2, 3, 4), (3, 1, 2, 3, 4), (1, 1, 2, 3, 4)]