In [1]:
import glob
import os
import pickle
import random
import tkinter as tk
from pathlib import Path
from tkinter import filedialog

import cv2
import lmdb
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset, DataLoader

In [2]:
# Dataset locations. Override them with the PVDN_IMAGES_PATH / PVDN_LABELS_PATH
# environment variables instead of editing the notebook; the defaults are the
# original local paths. Each value is normalized to end with exactly one "/" so an
# override has the same form as the default.
IMAGES_PATH = os.environ.get("PVDN_IMAGES_PATH", "/mnt/d/Datasets/PVDN/images/").rstrip("/") + "/"
LABELS_PATH = os.environ.get("PVDN_LABELS_PATH", "/mnt/d/Datasets/PVDN/labels/").rstrip("/") + "/"

In [3]:
# Derive the label pickle and the LMDB directory from the configured dataset folders.
# pathlib joins work whether or not LABELS_PATH / IMAGES_PATH end in "/"; the previous
# "/".join(p.split("/")[:-1]) trick silently dropped the last directory when the
# trailing slash was missing.
PKL_PATH = (Path(LABELS_PATH) / "labels_test.pkl").as_posix()

db_dir = Path(IMAGES_PATH) / "images_lmdb_test"
# NOTE(review): this creates the directory even though it is only read later, so a
# missing database surfaces as an lmdb open error rather than a missing-path error.
db_dir.mkdir(parents=True, exist_ok=True)
DB_PATH = db_dir.as_posix()

print("Reading pickel labels from %s" % PKL_PATH)
print("Reading LMDB images from %s" % DB_PATH)

Reading pickel labels from /mnt/d/Datasets/PVDN/labels/labels_test.pkl
Reading LMDB images from /mnt/d/Datasets/PVDN/images/images_lmdb_test


In [8]:
def resize(image, size):
    """Resize one image tensor with nearest-neighbour interpolation.

    Args:
        image: image tensor, presumably laid out as (C, H, W). A batch axis is added
            so ``F.interpolate`` sees (1, C, H, W), then removed again.
        size: target spatial size, forwarded to ``F.interpolate``. An int gives a
            square output; a (H, W) pair gives that exact shape.

    Returns:
        The resized tensor, with the temporary batch axis squeezed away.
    """
    batched = image.unsqueeze(0)
    resized = F.interpolate(batched, size=size, mode="nearest")
    return resized.squeeze(0)

In [9]:
class my_dataset_LMDB(Dataset):
    """Detection dataset: images from an LMDB store, labels from a pickled dict.

    Each line of ``list_path`` names an image. Its base name, up to the first ".",
    is the key into both the LMDB database at ``DB_PATH`` and the label dict
    pickled at ``PKL_PATH``.

    ``__getitem__`` returns ``(None, img, bb_targets)``, or ``None`` when the
    transform fails. ``collate_fn`` batches those tuples and, when ``multiscale``
    is on, draws a new ``img_size`` every tenth batch.

    The LMDB environment is opened lazily, once per process. This keeps the
    dataset picklable, so it can be handed to ``DataLoader`` worker processes.
    """

    def __init__(self, list_path, img_size=416, multiscale=True, transform=None):
        self.db_path = DB_PATH
        self.pkl_path = PKL_PATH
        self.key_labels = {}

        with open(list_path, "r") as file:
            self.img_files = file.readlines()
        # Blank lines (e.g. a trailing newline) would yield keys present in neither store.
        self.image_labels = [self._image_key(path) for path in self.img_files if path.strip()]

        # LMDB handles cannot be pickled and must not be reused across fork(), so the
        # environment is opened on first read by whichever process performs it.
        self.env = None
        self.txn = None
        self._db_pid = None
        self._init_pkl()

        self.img_size = img_size
        self.max_objects = 100
        self.multiscale = multiscale
        # Multi-scale range: three 32-pixel steps either side of img_size.
        self.min_size = self.img_size - 3 * 32
        self.max_size = self.img_size + 3 * 32
        self.batch_count = 0
        self.transform = transform

    @staticmethod
    def _image_key(path):
        """Return the lookup key for a list-file entry: its base name up to the first '.'."""
        return path.strip().split("/")[-1].split(".")[0]

    def _init_db(self):
        """Open the LMDB environment read-only and begin a read transaction in this process."""
        self.env = lmdb.open(self.db_path, subdir=os.path.isdir(self.db_path),
            readonly=True, lock=False,
            readahead=False, meminit=False)
        self.txn = self.env.begin()
        self._db_pid = os.getpid()

    def _init_pkl(self):
        """Load the key -> label-bytes dict from ``self.pkl_path``."""
        # pickle.load can execute arbitrary code: only point PKL_PATH at trusted files.
        with open(self.pkl_path, 'rb') as handle:
            self.key_labels = pickle.load(handle)

    def __getstate__(self):
        """Pickle support: drop the open LMDB handles; read_lmdb reopens them on demand."""
        state = self.__dict__.copy()
        state["env"] = None
        state["txn"] = None
        state["_db_pid"] = None
        return state

    def read_lmdb(self, key):
        """Decode the image stored under ``key`` and return it as an RGB uint8 array.

        Raises:
            KeyError: if ``key`` is not present in the database.
            ValueError: if the stored bytes cannot be decoded as an image.
        """
        # (Re)open when no handle exists yet, or when running in a different process
        # (e.g. a forked DataLoader worker) than the one that opened it.
        if getattr(self, "txn", None) is None or getattr(self, "_db_pid", None) != os.getpid():
            self._init_db()
        raw = self.txn.get(key.encode("ascii"))
        if raw is None:
            raise KeyError("image %r not found in LMDB at %s" % (key, self.db_path))
        bgr = cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), cv2.IMREAD_COLOR)
        if bgr is None:
            raise ValueError("could not decode image %r from LMDB" % key)
        # cv2 decodes to BGR. The former PIL .convert('RGB') did not reorder channels
        # of a 3-channel array, so the channels came back swapped.
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def read_pkl(self, key):
        """Parse the label entry for ``key`` into a float array, one row per non-empty line.

        NOTE(review): entries are assumed to be UTF-8 bytes of whitespace-separated
        numbers. An entry with no lines yields shape (0,), not (0, k); confirm the
        transform copes with that.
        """
        lines = self.key_labels[key].decode('utf-8').split('\n')
        return np.array([list(map(float, line.split())) for line in lines if line])

    def __getitem__(self, index):
        """Return ``(None, img, bb_targets)`` for ``index``, or ``None`` if the transform fails.

        Without a transform, ``img`` is the array from ``read_lmdb`` and ``bb_targets``
        is the raw label array from ``read_pkl``.
        """
        key = self.image_labels[index]
        img = self.read_lmdb(key)
        boxes = self.read_pkl(key)
        # Default to the untransformed labels; this was previously unbound when transform is None.
        bb_targets = boxes
        if self.transform:
            try:
                img, bb_targets = self.transform((img, boxes))
            except Exception as exc:
                # Best effort: collate_fn drops None samples, so a bad sample is skipped, not fatal.
                print("Could not apply transform to %s: %s" % (key, exc))
                return None
        return None, img, bb_targets

    def collate_fn(self, batch):
        """Batch ``__getitem__`` samples into ``(paths, imgs, bb_targets)``.

        ``None`` samples are dropped first. When ``multiscale`` is on, every tenth call
        draws a new ``img_size`` from ``min_size..max_size`` in steps of 32. Images are
        resized to ``img_size`` and stacked. Column 0 of each sample's targets is
        overwritten in place with the sample's index in the batch, then all targets
        are concatenated.

        NOTE(review): assumes the transform yields image tensors and 2-D target tensors.
        """
        self.batch_count += 1

        # Drop invalid images
        batch = [data for data in batch if data is not None]

        paths, imgs, bb_targets = list(zip(*batch))

        # Selects new image size every tenth batch
        if self.multiscale and self.batch_count % 10 == 0:
            self.img_size = random.choice(
                range(self.min_size, self.max_size + 1, 32))

        # Resize images to input shape
        imgs = torch.stack([resize(img, self.img_size) for img in imgs])

        # Add sample index to targets
        for i, boxes in enumerate(bb_targets):
            boxes[:, 0] = i
        bb_targets = torch.cat(bb_targets, 0)

        return paths, imgs, bb_targets

    def __len__(self):
        """Number of images listed in ``list_path`` (blank lines excluded)."""
        return len(self.image_labels)