In [159]:
from torch.utils.data import Dataset

In [273]:
def flatten(func):
    """flats a nested list to a normal list if the sublists have the lenght of 1"""
    def wrapper(*args):
        nested_list = func(*args)
        if isinstance(nested_list[0], list) and len(nested_list[0]) == 1:
            return [elm for sublist in nested_list for elm in sublist]
        return nested_list
    return wrapper



class CUBDataset(Dataset):
    def __init__(self, seq_size: int, root: str = "datasets/CUB_200_2011/"):
        self.seq_size = seq_size
        self.root = root
        images = self._read_file("images.txt")
        labels = self._read_file("image_class_labels.txt", True)
        train_test = self._read_file("train_test_split.txt", True)
        bounding_boxes = self._read_file("bounding_boxes.txt", True)
        
        # check if all lists have the same lenght.
        assert all(len(images) == len(l) for l in [labels, train_test, bounding_boxes])
        
        self.train = [(img, bb) for img, x, bb in zip(images, train_test, bounding_boxes) if x]
        self.test = [(img, bb) for img, x, bb in zip(images, train_test, bounding_boxes) if not x]
        
        # check if training and test data is a smaller subset of the dataset.
        assert all(len(images) > len(l) for l in [self.train, self.test])
        
        #TODO add image usage
        
        
    @flatten
    def _read_file(self, file: str, as_int: bool = False):
        data = [line.split()[1:] for line in open(self.root + file)]
        if as_int:
            data = [[int(float(elm)) for elm in sublist] for sublist in data]
        return data
    
    
    def __len__(self):
        return len(self.train) + len(self.test)
    
    
    def __getitem__(self):
        pass

In [274]:
cub = CUBDataset(50)

In [None]:
filepath = "datasets/CUB_200_2011/"
paths_with_col_names = {
    "images.txt": ("image_id", "image_name"),
    "train_test_split.txt": ("image_id", "is_training_image"),
    "image_class_labels.txt": ("image_id", "class_id"),
    "classes.txt": ("class_id", "class_name"),
    "bounding_boxes.txt": ("image_id", "bb_x", "bb_y", "bb_width", "bb_height"),
    "parts/part_locs.txt": ("image_id", "part_id", "p_x", "p_y", "p_visible"),
    "parts/parts.txt": ("part_id", "part_name"),
    "parts/part_click_locs.txt": ("image_id", "part_id", "p_mturk_x", "p_mturk_y", "p_mturk_visible", "p_mturk_time"),
    #"attributes/attributes.txt": ("attribute_id", "attribute_name"),
    #"attributes/certainties.txt": ("certainty_id", "certainty_name"),
    #"attributes/image_attribute_labels.txt": ("image_id", "attribute_id", "att_is_present", "certainty_id", "att_time", "_1", "_2"),
    #"attributes/class_attribute_labels_continuous.txt": ()
}

In [167]:
import pandas as pd
from collections import defaultdict
from functools import reduce

In [183]:
with open(filepath + "attributes/image_attribute_labels.txt", "r") as f:
    x = f.readlines()

In [184]:
pd.DataFrame([i.split() for i in x]).loc[:,5].unique()

KeyboardInterrupt: 

In [None]:
pd.read_csv(filepath + "attributes/image_attribute_labels.txt", sep=r"(?<=\d)\s")

In [187]:
dataframes = [
    pd.read_csv(filepath + path, sep=r"(?<=\d)\s", names=col_names)
        .set_index([col for col in col_names if col.endswith("_id")])
    for path, col_names in paths.items()]

  after removing the cwd from sys.path.


In [None]:
df = reduce(lambda df1, df2: pd.merge(df1, df2, left_index=True, right_index=True), dataframes)

In [None]:
df