# Create datasets

Use the pre-prepared image downloads to create datasets with train/validation/test splits for model training.

## Config

In [365]:
import pandas as pd
import sqlalchemy as sqa
from sklearn.model_selection import train_test_split
import json
from typing import Dict, Tuple
from dataclasses import dataclass, field, asdict
from enum import Enum, auto
from numbers import Number
from pathlib import Path
import os
import shutil
import logging

# Configure the root logger (getLogger() without a name returns it) so that
# INFO-level progress messages emitted while building datasets are shown.
logger = logging.getLogger()
logger.propagate = False
logger.setLevel(logging.INFO)

In [2]:
!python --version

Python 3.6.9


In [13]:
# PARAMETERS
# Database parameters
# Connection settings can be overridden through environment variables so that
# real credentials do not have to be committed with the notebook; the literals
# are the defaults used when a variable is not set.
db_container = os.environ.get("METADATA_DB_HOST", "metadata_db")
db_user = os.environ.get("METADATA_DB_USER", "pguser")
db_password = os.environ.get("METADATA_DB_PASSWORD", "pgpassword")
db_port = int(os.environ.get("METADATA_DB_PORT", 5432))
db_database = os.environ.get("METADATA_DB_NAME", "metadata")
db_prefix = "postgresql"

# Table names: image metadata (input), datasets and dataset images (outputs)
metadata_tbl = "base_images"
datasets_tbl = "datasets"
dataset_img_tbl = "dataset_images"

In [18]:
from urllib.parse import quote

# Percent-encode the user and password: characters such as "@", ":" or "/" in a
# credential would otherwise be parsed as URL delimiters and corrupt the
# connection string. Plain alphanumeric credentials are left unchanged.
db_con_str = (
    f"{db_prefix}://{quote(db_user, safe='')}:{quote(db_password, safe='')}"
    f"@{db_container}:{db_port}/{db_database}"
)
db_engine = sqa.create_engine(db_con_str)

In [145]:
class Image_Label_Method(Enum):
    """
    Labelling methods available when building a dataset:
     - BINARY: all labels just positive or negative
     - MULTI: each datapoint is labelled by its source
    """
    BINARY = auto()
    MULTI = auto()

@dataclass
class Image_Dataset_Config:
    """
    Specifies config for creating an image dataset:
     - name: name of the dataset
     - target_dir: directory for storing the dataset
     - label_counts: Dictionary of ((label str, source_name), count) pairs - how many data points from each data source to use
        (if count is -1 or is greater than total available, just use all of them)
     - class_names: Dictionary of ((label str, source_name), class name) pairs - must have exactly the same keys
        as label_counts
     - description
     - validation_frac: frac of TRAINING data to use as validation dataset (if 0, no validation set created)
     - test_frac: frac of ALL data to use as testing data (if 0, no test set created)
     - label_method: different labelling methods implemented:
         - BINARY: All labels just positive or negative
         - MULTI: Each datapoint is labelled by its source

    Raises AssertionError on construction if label_counts is empty or if the keys of
    label_counts and class_names differ.
    """
    name: str
    target_dir: str
    label_counts: Dict[Tuple[str, str], int] = field(default_factory=dict)
    class_names: Dict[Tuple[str, str], str] = field(default_factory=dict)
    description: str = ""
    validation_frac: float = 0.2
    test_frac: float = 0.0
    label_method: Image_Label_Method = Image_Label_Method.BINARY

    def __post_init__(self):
        # Raised explicitly instead of using `assert` so the checks still run under
        # `python -O`; the exception type is kept as AssertionError for backward
        # compatibility with existing callers.
        if len(self.label_counts) == 0:
            raise AssertionError("Must include counts for at least one source + label")
        if self.label_counts.keys() != self.class_names.keys():
            raise AssertionError("Must include class name to match each class label count")

In [372]:
class Image_Dataset_Builder():
    """
    Class for creating a dataset from a metadata table and configs
    db_engine: database engine attached to metadata database
    config: specifies how to create dataset (see Image_Dataset_Config class)
    random_state: seed passed to the per-group sampling and to the train/test splitting
    db_metadata_tbl: name of the table holding the image metadata (only read)
    db_datasets_tbl: name of the table that receives one row per dataset
    db_dataset_img_tbl: name of the table that receives one row per dataset image

    Constructing the builder reads the metadata table and samples the images to use;
    the database and the target directory are only modified by build_dataset().
    """

    def __init__(self,
                 db_engine: sqa.engine.Engine,
                 config: Image_Dataset_Config,
                 random_state: int = 101,
                 db_metadata_tbl: str = metadata_tbl,
                 db_datasets_tbl: str = datasets_tbl,
                 db_dataset_img_tbl: str = dataset_img_tbl,
                 ):
        self.db_metadata_tbl = db_metadata_tbl
        self.db_datasets_tbl = db_datasets_tbl
        self.db_dataset_img_tbl = db_dataset_img_tbl
        self.config = config
        self.engine = db_engine
        self.random_state = random_state   # For sampling
        with self.engine.connect() as con:
            self.df_img: pd.DataFrame = pd.read_sql(self.db_metadata_tbl, con, index_col='image_name')
        # We only care about usable images now
        self.df_img = self.df_img[self.df_img['read'] == True]
        self._create_dataset_dfs()

    # Number of images to sample from one (label_str, source) group
    def _sample_size(self, group_key, n_available: int) -> int:
        if group_key not in self.config.label_counts:
            return 0                      # Group not requested by the config
        requested = self.config.label_counts[group_key]
        if requested == -1:
            return n_available            # -1 means "use everything"
        return min(n_available, requested)

    # Use config to create dataframe of desired image outputs
    def _create_dataset_dfs(self) -> None:
        # Datasets dataframe: one row describing this dataset (the two dict fields are not stored)
        self.df_ds = pd.DataFrame([
            asdict(self.config)
        ]).set_index("name").drop(['label_counts', 'class_names'], axis=1)
        self.df_ds['label_method'] = self.df_ds['label_method'].map(lambda lm: lm.name)

        # Dataset Images dataframe: sample each (label_str, source) group without replacement.
        # `[[]]` drops every metadata column, so only the index levels survive the reset_index.
        self.df_ds_img = self.df_img \
            .groupby(['label_str', 'source']) \
            .apply(lambda grp: grp.sample(self._sample_size(grp.name, len(grp.index)),
                                          replace=False, random_state=self.random_state)) \
            [[]].reset_index().set_index('image_name')
        self.df_ds_img['class_name'] = self.df_ds_img.apply(
            lambda row: self.config.class_names[(row['label_str'], row['source'])], axis=1)

        # MULTI labels by configured class name; BINARY (and any other method) by label_str
        if self.config.label_method == Image_Label_Method.MULTI:
            self.df_ds_img['class_label'] = self.df_ds_img['class_name']
        else:
            self.df_ds_img['class_label'] = self.df_ds_img['label_str']

        # Integer value per label, numbered in order of first appearance
        label_value_dict = {label: i for i, label in enumerate(self.df_ds_img['class_label'].unique())}
        self.df_ds_img['class_value'] = self.df_ds_img['class_label'].map(label_value_dict)
        self.df_ds_img['dataset_name'] = self.config.name
        self.df_ds_img['dataset_img_path'] = None   # Filled in by _assign_splits
        self.df_ds_img['split'] = None              # Filled in by _assign_splits
        self.df_ds_img = self.df_ds_img.drop(['label_str', 'source'], axis=1)

    # Assign each data point to a train/test/validation split
    def _assign_splits(self) -> None:
        labels = self.df_ds_img['class_label']
        train_X = self.df_ds_img.index
        val_X = []
        test_X = []

        # Test set: stratified fraction of ALL data
        if self.config.test_frac > 0:
            train_X, test_X = train_test_split(train_X,
                                               test_size=self.config.test_frac,
                                               stratify=labels.loc[train_X],
                                               random_state=self.random_state)
        # Validation set: stratified fraction of the remaining TRAINING data. It is drawn from
        # train_X (not from the full index) so that it can never overlap with the test set.
        if self.config.validation_frac > 0:
            train_X, val_X = train_test_split(train_X,
                                              test_size=self.config.validation_frac,
                                              stratify=labels.loc[train_X],
                                              random_state=self.random_state)

        self.df_ds_img.loc[train_X, 'split'] = 'train'
        self.df_ds_img.loc[test_X, 'split'] = 'test'
        self.df_ds_img.loc[val_X, 'split'] = 'validation'

        # Destination of each image: <target_dir>/<split>/<class_label>/<image_name>.jpg
        self.df_ds_img['dataset_img_path'] = self.df_ds_img.apply(
            lambda row: str(Path(self.config.target_dir,
                                 row['split'],
                                 row['class_label'],
                                 row.name + '.jpg')),
            axis=1)

    # Clear data associated with this dataset from database (or create tables if they don't exist)
    def _clear_db_data(self) -> None:
        with self.engine.connect() as con:
            db_ds_tbl_exists = con.execute(self._query_datasets_exists).fetchall()[0][0]
            db_ds_img_tbl_exists = con.execute(self._query_dataset_imgs_exists).fetchall()[0][0]

            if db_ds_tbl_exists:
                # Image rows of a table created by _query_create_dataset_img_tbl are removed
                # with the dataset row through its ON DELETE CASCADE foreign key
                con.execute(self._query_drop_datasets_rows)
            else:
                con.execute(self._query_create_dataset_tbl)
            if not db_ds_img_tbl_exists:
                con.execute(self._query_create_dataset_img_tbl)

    # Create (if not exists) or clear (if exists) data from target directory
    def _clear_target_dir(self) -> None:
        target = Path(self.config.target_dir)
        if target.exists():
            # Only the top-level entries need visiting: rmtree removes whole subtrees
            for entry in target.iterdir():
                if entry.is_dir() and not entry.is_symlink():
                    shutil.rmtree(str(entry))
                else:
                    entry.unlink()
        else:
            target.mkdir(parents=True)

        # One directory per (split, class label) combination
        class_labels = self.df_ds_img['class_label'].unique()
        for split in self.df_ds_img['split'].unique():
            for class_label in class_labels:
                Path(self.config.target_dir, split, class_label).mkdir(exist_ok=True, parents=True)

    # Copy files from image dataframe from source to target directory
    def _copy_image_files(self) -> None:
        for img_name, row in self.df_ds_img.iterrows():
            shutil.copyfile(self.df_img.loc[img_name, 'final_loc'],
                            row['dataset_img_path'])

    # Store associated data to the metadata db (dataset table and dataset images table)
    def _store_db_metadata(self) -> None:
        with self.engine.connect() as con:
            self.df_ds.to_sql(self.db_datasets_tbl, con, if_exists='append')
            self.df_ds_img.to_sql(self.db_dataset_img_tbl, con, if_exists='append')

    # Builds the dataset by running the other methods
    def build_dataset(self) -> bool:
        """
        Assign splits, clear previous data for this dataset, copy the images and store the metadata.
        Returns True on success; if copying or storing fails, the db rows and the directory contents
        are cleared again and False is returned.
        """
        success_flag = True
        self._assign_splits()
        logging.info("Splits assigned")
        self._clear_db_data()
        logging.info("Database cleared")
        self._clear_target_dir()
        logging.info("Directory cleared")
        try:
            self._copy_image_files()
            logging.info("Images copied")
            self._store_db_metadata()
            logging.info("Data stored to db")
        except Exception:
            # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit propagate;
            # logging.exception records the traceback so the failure can be diagnosed.
            logging.exception("Dataset copy unsuccessful, removing data from db and directory")
            success_flag = False
            self._clear_db_data()
            self._clear_target_dir()
        return success_flag

    # QUERY HELPERS
    @property
    def _query_create_dataset_tbl(self) -> str:
        return f"""
               CREATE TABLE {self.db_datasets_tbl} (
                   name VARCHAR(100) UNIQUE NOT NULL,
                   target_dir VARCHAR(300) NOT NULL,
                   description TEXT,
                   validation_frac NUMERIC NOT NULL,
                   test_frac NUMERIC NOT NULL,
                   label_method VARCHAR(50) NOT NULL,
                   PRIMARY KEY(name)
               )
               """

    @property
    def _query_create_dataset_img_tbl(self) -> str:
        return f"""
               CREATE TABLE {self.db_dataset_img_tbl} (
                   image_name VARCHAR(300) NOT NULL,
                   dataset_name VARCHAR(100) NOT NULL,
                   class_name VARCHAR(100) NOT NULL,
                   class_label VARCHAR(100) NOT NULL,
                   class_value INT NOT NULL,
                   split VARCHAR(30) NOT NULL,
                   dataset_img_path VARCHAR(500) NOT NULL,
                   PRIMARY KEY (image_name, dataset_name),
                   FOREIGN KEY (image_name) REFERENCES {self.db_metadata_tbl}(image_name),
                   FOREIGN KEY (dataset_name) REFERENCES {self.db_datasets_tbl}(name) ON DELETE CASCADE
               )
               """

    @staticmethod
    def _query_check_tbl_exists(tbl_name) -> str:
        return f"""
                SELECT EXISTS(
                    SELECT * FROM information_schema.tables
                    WHERE table_name = '{tbl_name}'
                )
               """

    @property
    def _query_datasets_exists(self) -> str:
        return self._query_check_tbl_exists(self.db_datasets_tbl)

    @property
    def _query_dataset_imgs_exists(self) -> str:
        return self._query_check_tbl_exists(self.db_dataset_img_tbl)

    @staticmethod
    def _query_drop_col_values_from_tbl(tbl_name, col_name, col_value) -> str:
        # String values are quoted, with embedded single quotes doubled so that a value
        # containing "'" cannot terminate the SQL literal early
        if isinstance(col_value, str):
            sql_value = "'" + col_value.replace("'", "''") + "'"
        else:
            sql_value = col_value
        return f"""
               DELETE FROM {tbl_name}
               WHERE {col_name} = {sql_value}
               """

    @property
    def _query_drop_datasets_rows(self) -> str:
        return self._query_drop_col_values_from_tbl(self.db_datasets_tbl, "name", self.config.name)

    @property
    def _query_drop_dataset_imgs_rows(self) -> str:
        # Rows of a dataset are identified by dataset_name (image_name holds the image, not the dataset)
        return self._query_drop_col_values_from_tbl(self.db_dataset_img_tbl,
                                                    "dataset_name",
                                                    self.config.name)
    
        

In [373]:
# Scratch check: inside DataFrame.apply(axis=1) each row's `.name` is its index value
df = pd.DataFrame({'a': [1, 4], 'b': [2, 5], 'c': [3, 6]}).set_index('a')
df.apply(lambda rec: (rec.name, rec['c']), axis=1)

a
1    (1, 3)
4    (4, 6)
dtype: object

In [374]:
# Fetch the source / label columns once and derive the distinct pairs from that
# frame, instead of querying the whole metadata table a second time.
with db_engine.connect() as con:
    cols = pd.read_sql_query(f"SELECT source, label_str FROM {metadata_tbl}", con)
# Pairs are ordered (label_str, source) to match the keys used by the dataset configs
unique_src_label_pairs = set(zip(cols['label_str'], cols['source']))
unique_src_label_pairs

{('negative', 'Google Images'),
 ('negative', 'Imagenet'),
 ('negative', 'Plantnet'),
 ('positive', 'Google Images')}

In [383]:
# Small binary-labelled example: request 50 images for every (label, source) pair
test_bin_label_names = {
    ('negative', 'Google Images'): "negative_similar_plant",
    ('negative', 'Imagenet'): "negative_random_picture",
    ('negative', 'Plantnet'): "negative_general_plant",
    ('positive', 'Google Images'): "positive",
}
test_bin_label_counts = dict.fromkeys(unique_src_label_pairs, 50)

test_bin_config = Image_Dataset_Config(
    name="test_bin_example",
    description="First example with 50 values from each class and binary labeling",
    target_dir="../datasets/pipeline_v1/test_bin_dataset",
    label_method=Image_Label_Method.BINARY,
    label_counts=test_bin_label_counts,
    class_names=test_bin_label_names,
    test_frac=0,
    validation_frac=0.2,
)

# Build the dataset; the cell displays the success flag returned by build_dataset()
test_bin_ds = Image_Dataset_Builder(db_engine=db_engine, config=test_bin_config)
test_bin_ds.build_dataset()

INFO:root:Splits assigned
INFO:root:Database cleared
INFO:root:Directory cleared
INFO:root:Images copied
INFO:root:Data stored to db


True

In [385]:
# Small multiclass example: request 50 images for every (label, source) pair
test_multi_label_names = {
    ('negative', 'Google Images'): "negative_similar_plant",
    ('negative', 'Imagenet'): "negative_random_picture",
    ('negative', 'Plantnet'): "negative_general_plant",
    ('positive', 'Google Images'): "positive",
}
test_multi_label_counts = dict.fromkeys(unique_src_label_pairs, 50)

test_multi_config = Image_Dataset_Config(
    name="test_multi_example",
    description="First example with 50 values from each class and multiclass labeling",
    target_dir="../datasets/pipeline_v1/test_multi_dataset",
    label_method=Image_Label_Method.MULTI,
    label_counts=test_multi_label_counts,
    class_names=test_multi_label_names,
    test_frac=0,
    validation_frac=0.2,
)

# Build the dataset; the cell displays the success flag returned by build_dataset()
test_multi_ds = Image_Dataset_Builder(db_engine=db_engine, config=test_multi_config)
test_multi_ds.build_dataset()

INFO:root:Splits assigned
INFO:root:Database cleared
INFO:root:Directory cleared
INFO:root:Images copied
INFO:root:Data stored to db


True

In [377]:
# Full multiclass dataset: a count of -1 requests every image of each (label, source) pair
all_multi_label_names = {
    ('negative', 'Google Images'): "negative_similar_plant",
    ('negative', 'Imagenet'): "negative_random_picture",
    ('negative', 'Plantnet'): "negative_general_plant",
    ('positive', 'Google Images'): "positive",
}
all_multi_label_counts = dict.fromkeys(unique_src_label_pairs, -1)

all_multi_config = Image_Dataset_Config(
    name="all_v1_multiclass",
    description="All images in v1 dataset with multiclass labels",
    target_dir="../datasets/pipeline_v1/all_v1_multiclass",
    label_method=Image_Label_Method.MULTI,
    label_counts=all_multi_label_counts,
    class_names=all_multi_label_names,
    test_frac=0,
    validation_frac=0.2,
)

# Build the dataset; the cell displays the success flag returned by build_dataset()
all_multi_ds = Image_Dataset_Builder(db_engine=db_engine, config=all_multi_config)
all_multi_ds.build_dataset()

INFO:root:Splits assigned
INFO:root:Database cleared
INFO:root:Directory cleared
INFO:root:Images copied
INFO:root:Data stored to db


True

In [381]:
# Sanity check: number of files written per split and class for the multiclass dataset
for split_dir in ("train", "validation"):
    for class_dir in ("positive",
                      "negative_general_plant",
                      "negative_similar_plant",
                      "negative_random_picture"):
        n_files = len(os.listdir(f"../datasets/pipeline_v1/all_v1_multiclass/{split_dir}/{class_dir}"))
        print(n_files)

1764
1934
1762
3138
441
483
441
785


In [378]:
# Full binary dataset: a count of -1 requests every image of each (label, source) pair
all_binary_label_names = {
    ('negative', 'Google Images'): "negative_similar_plant",
    ('negative', 'Imagenet'): "negative_random_picture",
    ('negative', 'Plantnet'): "negative_general_plant",
    ('positive', 'Google Images'): "positive",
}
all_binary_label_counts = dict.fromkeys(unique_src_label_pairs, -1)

all_binary_config = Image_Dataset_Config(
    name="all_v1_binary",
    description="All images in v1 dataset with binary",
    target_dir="../datasets/pipeline_v1/all_v1_binary",
    label_method=Image_Label_Method.BINARY,
    label_counts=all_binary_label_counts,
    class_names=all_binary_label_names,
    test_frac=0,
    validation_frac=0.2,
)

# Build the dataset; the cell displays the success flag returned by build_dataset()
all_binary_ds = Image_Dataset_Builder(db_engine=db_engine, config=all_binary_config)
all_binary_ds.build_dataset()

INFO:root:Splits assigned
INFO:root:Database cleared
INFO:root:Directory cleared
INFO:root:Images copied
INFO:root:Data stored to db


True

In [380]:
# Sanity check: number of files written per split and label for the binary dataset
for split_dir in ("train", "validation"):
    for class_dir in ("positive", "negative"):
        n_files = len(os.listdir(f"../datasets/pipeline_v1/all_v1_binary/{split_dir}/{class_dir}"))
        print(n_files)

1764
6834
441
1709
