In [None]:
import cv2
import pandas as pd
import os
from torch.utils import data

In [None]:
# Root of the extracted Market-1501 archive. Set the MARKET1501_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
# os.path.join(..., '') guarantees a trailing separator, so appending a
# sub-folder name with plain string concatenation also yields a valid path.
DATASET_FOLDER = os.path.join(
    os.environ.get(
        'MARKET1501_DIR',
        '/Users/DrMatters/Documents/git/vision/data/datasets/market1501/Market-1501-v15.09.15/',
    ),
    '',
)

In [None]:
import os
from typing import FrozenSet

import albumentations as albu
import pandas as pd
from torch.utils import data


class MarketDataset(data.Dataset):
    """Market-1501 image folder exposed as a PyTorch map-style dataset.

    Item ``i`` is ``(image, person_id)``:

    * ``image`` is the file named in the ``filename`` column of row ``i``,
      resolved relative to ``folder``. It is read with OpenCV, converted from
      BGR to RGB, passed through ``transforms`` and then, if given,
      ``preprocessing``.
    * ``person_id`` is the row's ``pers_id`` value, returned unchanged.

    Args:
        folder: Directory that contains the image files.
        index_df: Table with at least the columns in ``REQUIRED_COLUMNS``.
            Extra columns are allowed and ignored. Rows are addressed by
            position, so the frame's index labels do not matter.
        transforms: Callable invoked as ``transforms(image=img)['image']``
            (albumentations calling convention). ``None`` (the default) builds
            a fresh ``albu.Compose([albu.HorizontalFlip()])`` for this instance.
        preprocessing: Optional callable with the same calling convention,
            applied after ``transforms``.

    Raises:
        ValueError: If ``index_df`` is not given or lacks a required column.
    """

    # Columns that __getitem__ reads from every row.
    REQUIRED_COLUMNS: FrozenSet[str] = frozenset({'filename', 'pers_id'})

    def __init__(self, folder: str, index_df: pd.DataFrame = None,
                 transforms=None,
                 preprocessing=None):
        if index_df is None:
            raise ValueError(
                'index_df is required: pass a DataFrame with columns'
                f' {self.REQUIRED_COLUMNS}')
        # Every required column must be present; extra columns are fine.
        missing = self.REQUIRED_COLUMNS - set(index_df.columns)
        if missing:
            raise ValueError(
                'Required columns are not present in dataframe. Expected:'
                f' {self.REQUIRED_COLUMNS}. Got: {set(index_df.columns)}.'
                f' Missing: {sorted(missing)}')
        if transforms is None:
            # Built here rather than as a default argument so each instance
            # gets its own pipeline instead of one object created when the
            # class was defined.
            transforms = albu.Compose([albu.HorizontalFlip()])
        self.folder = folder
        self.index_df = index_df
        self.transforms = transforms
        self.preprocessing = preprocessing

    def __getitem__(self, idx):
        # Positional lookup: DataLoader requests 0..len(self)-1, which need not
        # coincide with the frame's index labels (e.g. after row filtering).
        row = self.index_df.iloc[idx]
        full_path = os.path.join(self.folder, row['filename'])

        img = cv2.imread(full_path)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # rather than raising; fail here with the offending path.
            raise FileNotFoundError(f'Could not read image: {full_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transforms(image=img)['image']
        if self.preprocessing is not None:
            img = self.preprocessing(image=img)['image']
        return img, row['pers_id']

    def __len__(self):
        return len(self.index_df)


In [None]:
def identity_collate(data):
    """Collate function that hands the batch back unchanged.

    Passing this as ``collate_fn`` makes a DataLoader yield the plain list of
    samples it fetched, without merging them into combined batch objects.
    """
    batch = data
    return batch

def get_filenames_in_current(folder: str):
    """Return the sorted names of the non-directory entries directly in ``folder``.

    Subdirectories are neither listed nor descended into.

    Args:
        folder: Directory to list.

    Returns:
        Sorted list of entry names (names only, not full paths).

    Raises:
        OSError: e.g. ``FileNotFoundError`` / ``NotADirectoryError`` /
            ``PermissionError`` when ``folder`` cannot be listed, instead of
            silently producing no result.
    """
    with os.scandir(folder) as entries:
        # Same split as os.walk: anything that is not a directory (symlinks
        # followed) counts as a file.
        return sorted(entry.name for entry in entries if not entry.is_dir())

def create_index_df(folder: str, all_columns: bool = False):
    """Build the index table for a folder of Market-1501 images.

    Lists the ``.jpg`` files directly inside ``folder`` (sorted by name) and
    splits each file name on ``'_'`` at most twice to recover its fields.

    Args:
        folder: Directory to scan (not recursive).
        all_columns: If True, also keep the second field as ``env_descr`` and
            the remainder of the name (extension included) as ``orig_id``.
            The default False keeps only ``pers_id`` and ``filename``.

    Returns:
        DataFrame with columns ``['pers_id', 'filename']`` (or
        ``['pers_id', 'env_descr', 'orig_id', 'filename']`` when
        ``all_columns``). Fields are strings (NaN where a name has too few
        ``'_'``-separated parts). The index is a fresh 0..n-1 range, so row
        labels equal row positions.

    Raises:
        FileNotFoundError: If the folder listing comes back as ``None``.
    """
    names = get_filenames_in_current(folder)
    if names is None:
        # Fail loudly instead of building an empty index when the listing
        # helper reports an unreadable folder as None rather than raising.
        raise FileNotFoundError(f'Could not list image folder: {folder}')

    # dtype=object keeps the .str accessor usable even for an empty listing.
    filenames = pd.Series(names, dtype=object)
    # Keep only images and renumber 0..n-1 so labels match positions even when
    # non-image files were dropped.
    filenames = filenames[filenames.str.endswith('.jpg')].reset_index(drop=True)

    field_names = {0: 'pers_id', 1: 'env_descr', 2: 'orig_id'}
    fields = (
        filenames.str.split('_', n=2, expand=True)
        # Guarantee all three field columns exist (empty folder, or names with
        # fewer than two underscores); missing values become NaN.
        .reindex(columns=list(field_names))
        .rename(columns=field_names)
    )
    if not all_columns:
        fields = fields[['pers_id']]
    return fields.assign(filename=filenames)

In [None]:
# Index of every image in the bounding_box_test folder of the dataset.
# os.path.join works whether or not DATASET_FOLDER ends with a separator.
index_df = create_index_df(os.path.join(DATASET_FOLDER, 'bounding_box_test'))
index_df.head()

In [None]:
# Same folder the index was built from; os.path.join does not rely on a
# trailing separator in DATASET_FOLDER.
ds = MarketDataset(os.path.join(DATASET_FOLDER, 'bounding_box_test'), index_df)

In [None]:
# identity_collate keeps each batch as the raw list of samples.
# batch_size=1 is DataLoader's default, spelled out here for the reader.
ld = data.DataLoader(
    ds,
    batch_size=1,
    collate_fn=identity_collate,
    drop_last=True,
)

In [None]:
# Sanity check: pull a single batch and show the shape of its first image.
# With identity_collate a batch is the raw list of (image, pers_id) samples.
first_batch = next(iter(ld))
first_image, first_pers_id = first_batch[0]
first_image.shape