In [50]:
import glob
import math
import os
import random
import shutil
from pathlib import Path
import torchvision.transforms as transforms

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

In [17]:
# for training and testing

class TrainImgLabels(Dataset):
    """Training/testing dataset pairing images with preloaded bounding-box labels.

    ``img_file`` is a text file with one image path per line (blank lines are
    skipped). Each image's label path is derived by ``_label_path``: every
    occurrence of ``'image'`` in the path becomes ``'bbox'`` and the file
    extension becomes ``.txt``. A label file holds one row per box with exactly
    5 whitespace-separated numbers, all >= 0, with columns 1-4 <= 1
    (normalized coordinates). Column 0 is presumably the class id — TODO confirm.

    Args:
        img_file: path to the text file listing image paths.
        batch_size: only used to precompute ``self.batch``, the batch index of
            every sample.
        augment: stored as ``self.augment``; not used by this class yet.
        class_weights: if True, ``__getitem__`` maps ``index`` through
            ``self.indices`` before loading (currently the identity mapping).

    Raises:
        AssertionError: if a label file is malformed or no labels are found at all.
    """

    def __init__(self, img_file, batch_size=16, augment=False, class_weights=False):
        with open(img_file, "r") as fin:
            # Skip blank lines: an empty path can never be read as an image.
            self.img_paths = [line for line in fin.read().splitlines() if line.strip()]
        n = len(self.img_paths)

        self.n = n
        # Batch index of each sample: 0,0,...,0,1,1,... Integer division replaces
        # np.floor(...).astype(np.int); the np.int alias was removed in NumPy 1.24.
        self.batch = np.arange(n) // batch_size
        self.augment = augment
        self.class_weights = class_weights
        # Index remapping used when class_weights=True. It was previously never
        # defined, so class_weights=True crashed in __getitem__.
        # TODO: replace with class-weighted resampling; identity for now.
        self.indices = np.arange(n)

        # Labels mirror the image tree, e.g. '.../image/a.jpg' -> '.../bbox/a.txt'.
        self.label_files = [self._label_path(p) for p in self.img_paths]

        # Labels are preloaded here; images are read lazily in __getitem__.
        self.imgs = [None] * n
        self.labels = [self._read_labels(f)
                       for f in tqdm(self.label_files, desc='Reading labels')]
        assert sum(len(l) for l in self.labels) > 0, 'No labels found. Incorrect label paths provided.'

    @staticmethod
    def _label_path(img_path):
        """Map an image path to its label path ('image' -> 'bbox', extension -> '.txt')."""
        stem, _ = os.path.splitext(img_path.replace('image', 'bbox'))
        return stem + '.txt'

    @staticmethod
    def _read_labels(label_path):
        """Load one label file as a float32 (k, 5) array; a missing or empty file gives k == 0."""
        try:
            with open(label_path, 'r') as f:
                # Ignore blank lines (e.g. trailing newlines) so rows stay rectangular.
                rows = [line.split() for line in f.read().splitlines() if line.strip()]
        except FileNotFoundError:
            # Deliberate: an image without a label file is treated as having no boxes.
            return np.zeros((0, 5), dtype=np.float32)
        if not rows:
            return np.zeros((0, 5), dtype=np.float32)
        labels = np.array(rows, dtype=np.float32)
        assert labels.shape[1] == 5, 'expected 5 label columns: %s' % label_path
        assert (labels >= 0).all(), 'negative labels: %s' % label_path
        assert (labels[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % label_path
        return labels

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        """Return ``(image, labels)`` for sample ``index``.

        image: float32 tensor of shape (3, H, W), RGB, scaled to [0, 1] by /255.
        labels: float32 tensor of shape (k, 5), a copy of the preloaded labels
            (k may be 0).

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        if self.class_weights:
            index = self.indices[index]

        img_path = self.img_paths[index]
        img = cv2.imread(img_path)  # BGR image, or None when the file cannot be read
        # Check before indexing; previously the check came after img[...] and never fired.
        if img is None:
            raise FileNotFoundError('Image not found: ' + img_path)
        # BGR -> RGB and HWC -> CHW; ascontiguousarray also drops the negative stride of ::-1.
        img = np.ascontiguousarray(img[:, :, ::-1].transpose(2, 0, 1), dtype=np.float32)
        # Fixed /255 scaling (not per-image min-max) applies the same scale to every image.
        img /= 255.0

        # Copy so callers (e.g. future augmentation) cannot mutate the preloaded cache.
        labels = self.labels[index].copy()
        return torch.from_numpy(img), torch.from_numpy(labels)
        

In [24]:
# Build the dataset from the image list; all label files are read here,
# images are loaded lazily on indexing.
tmp = TrainImgLabels(img_file='./data/img.txt')



Reading labels:   0%|          | 0/5 [00:00<?, ?it/s][A[A

Reading labels: 100%|██████████| 5/5 [00:00<00:00, 4385.51it/s][A[A

In [29]:
# Image tensor of the first sample: (channels, height, width)
tmp[0][0].size()

torch.Size([3, 360, 640])

In [62]:
class Image(Dataset):
    """Dataset of images listed one path per line in a text file.

    Each item is ``(img_path, img)`` where ``img`` is a float32 numpy array of
    shape (3, H, W), RGB, scaled to [0, 1] by dividing by 255. Indices wrap
    around modulo the number of listed images.
    """

    def __init__(self, img_txt):
        with open(img_txt, "r") as fin:
            # Skip blank lines: an empty path can never be read as an image.
            self.img_paths = [line for line in fin.read().splitlines() if line.strip()]

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        """Load one image.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        img_path = self.img_paths[index % len(self.img_paths)]
        img = cv2.imread(img_path)  # BGR image, or None when the file cannot be read
        if img is None:
            raise FileNotFoundError('Image not found: ' + img_path)
        # BGR -> RGB and HWC -> CHW, as a contiguous float32 copy.
        img = np.ascontiguousarray(img[:, :, ::-1].transpose(2, 0, 1), dtype=np.float32)
        # Fixed /255 scaling for now (see TODO below for geometry preprocessing).
        img /= 255.0
        # TODO: pad to square and resize to a target size (helpers not defined yet).
        return img_path, img

class Label(Dataset):
    """Dataset of label files listed one path per line in a text file.

    Each item is ``(label_path, labels)`` where ``labels`` is a float32 array
    of shape (k, 5), one row per non-blank line of the file; a missing or empty
    file gives k == 0. Indices wrap around modulo the number of listed files.
    """

    def __init__(self, label_txt):
        with open(label_txt, "r") as fin:
            # Skip blank lines so every entry is a real path.
            self.label_paths = [line for line in fin.read().splitlines() if line.strip()]

    def __len__(self):
        return len(self.label_paths)

    def __getitem__(self, index):
        """Load one label file.

        Raises:
            ValueError: if the rows are ragged or do not have exactly 5 columns.
        """
        label_path = self.label_paths[index % len(self.label_paths)]
        try:
            with open(label_path, 'r') as f:
                rows = [line.split() for line in f.read().splitlines() if line.strip()]
        except FileNotFoundError:
            # A missing label file is treated as "no boxes", not as an error.
            return label_path, np.zeros((0, 5), dtype=np.float32)
        if not rows:
            return label_path, np.zeros((0, 5), dtype=np.float32)
        labels = np.array(rows, dtype=np.float32)
        if labels.ndim != 2 or labels.shape[1] != 5:
            raise ValueError('expected 5 label columns: %s' % label_path)
        return label_path, labels
        

In [63]:
# Image-only dataset over the same list of image paths.
result = Image(img_txt='./data/img.txt')


In [None]:
# Inference (placeholder — not implemented yet)

In [None]:
# Webcam input (placeholder — not implemented yet)