# Create datasets

Use the previously downloaded and prepared images to create datasets with train/validation/test splits for model training.

## Config

In [239]:
import pandas as pd
import sqlalchemy as sqa
from sklearn.model_selection import train_test_split
import json
from typing import Dict, Tuple
from dataclasses import dataclass, field, asdict
from enum import Enum, auto
from numbers import Number
from pathlib import Path

In [2]:
!python --version

Python 3.6.9


In [13]:
# PARAMETERS
# Database parameters
# These values are combined below into the connection URL
# prefix://user:password@container:port/database
db_container = "metadata_db"  # host part of the connection URL
db_user = "pguser"
db_password = "pgpassword"
db_port = 5432
db_database = "metadata"
db_prefix = "postgresql"  # URL scheme (database dialect)

# Table names; used as the default table names of Image_Dataset_Builder
metadata_tbl = "base_images"  # image metadata, read when a builder is constructed
datasets_tbl = "datasets"  # one row per dataset
dataset_img_tbl = "dataset_images"  # one row per (image, dataset) pair

In [18]:
# Assemble the connection URL from the database parameters and create the engine.
# NOTE(review): the values are interpolated as-is; a password containing
# URL-special characters (e.g. '@' or '/') would presumably need escaping - confirm.
db_con_str = f"{db_prefix}://{db_user}:{db_password}@{db_container}:{db_port}/{db_database}"
db_engine = sqa.create_engine(db_con_str)

In [145]:
class Image_Label_Method(Enum):
    """
    Labelling methods available when building an image dataset:
     - BINARY: all labels are just positive or negative
     - MULTI: each datapoint is labelled by its class name
    """
    # Values are assigned automatically in declaration order (1, 2)
    BINARY = auto()
    MULTI = auto()

@dataclass
class Image_Dataset_Config:
    """
    Specifies config for creating an image dataset:
     - name: name of the dataset
     - target_dir: directory for storing the dataset
     - label_counts: Dictionary of ((label str, source_name), count) pairs - how many data points
        from each data source to use
        (if count is -1 or is greater than total available, just use all of them)
     - class_names: Dictionary of ((label str, source_name), class name) pairs - the class name given
        to the data points of each data source (must have exactly the same keys as label_counts)
     - description: free-text description of the dataset
     - validation_frac: frac of TRAINING data to use as validation dataset (if 0, no validation set created)
     - test_frac: frac of ALL data to use as testing data (if 0, no test set created)
     - label_method: different labelling methods implemented:
         - BINARY: All labels just positive or negative
         - MULTI: Each datapoint is labelled by its source

    Raises:
        ValueError: if label_counts is empty, or if label_counts and class_names
            do not have exactly the same keys.
    """
    name: str
    target_dir: str
    label_counts: Dict[Tuple[str, str], int] = field(default_factory=dict)
    class_names: Dict[Tuple[str, str], str] = field(default_factory=dict)
    description: str = ""
    validation_frac: float = 0.2
    test_frac: float = 0.0
    label_method: Image_Label_Method = Image_Label_Method.BINARY

    def __post_init__(self):
        # Explicit raises (not assert) so the checks still run under `python -O`
        if not self.label_counts:
            raise ValueError("Must include counts for at least one source + label")
        if self.label_counts.keys() != self.class_names.keys():
            raise ValueError("Must include class name to match each class label count")

## TEMP DS CLASS

In [249]:
class Image_Dataset_Builder():
    """
    Class for creating a dataset from a metadata table and configs.

    On construction the metadata table is read from the database, reduced to
    the rows whose 'read' column is True, and two dataframes are derived:
     - df_ds: a single row describing the dataset itself (indexed by name)
     - df_ds_img: one row per sampled image (indexed by image_name)

    db_engine: database engine attached to metadata database
    config: specifies how to create dataset (see Image_Dataset_Config class)
    random_state: seed used when sampling images from each (label_str, source) group
    db_metadata_tbl: name of the table holding the image metadata
    db_datasets_tbl: name of the table holding one row per dataset
    db_dataset_img_tbl: name of the table holding one row per (image, dataset) pair
    """

    def __init__(self,
                 db_engine: sqa.engine.Engine,
                 config: Image_Dataset_Config,
                 random_state: int = 101,
                 db_metadata_tbl: str = metadata_tbl,
                 db_datasets_tbl: str = datasets_tbl,
                 db_dataset_img_tbl: str = dataset_img_tbl,
                 ):
        self.db_metadata_tbl = db_metadata_tbl
        self.db_datasets_tbl = db_datasets_tbl
        self.db_dataset_img_tbl = db_dataset_img_tbl
        self.config = config
        self.engine = db_engine
        self.random_state = random_state   # For sampling
        with self.engine.connect() as con:
            self.df_img: pd.DataFrame = pd.read_sql(self.db_metadata_tbl, con, index_col='image_name')
            self.df_img = self.df_img[self.df_img['read'] == True]  # We only care about usable images now
        self._create_dataset_dfs()

    def _create_dataset_dfs(self):
        """
        Use config to create the dataframes of desired outputs:
         - self.df_ds: the config as a single row (without label_counts / class_names),
           with label_method stored by its name
         - self.df_ds_img: the images sampled from each (label_str, source) group, with
           their class_name, class_label, class_value, dataset_name and an (empty for now)
           dataset_img_path
        """
        # Datasets dataframe
        self.df_ds = pd.DataFrame([
            asdict(self.config)
        ]).set_index("name").drop(['label_counts', 'class_names'], axis=1)
        self.df_ds['label_method'] = self.df_ds['label_method'].map(lambda lm: lm.name)

        label_counts = self.config.label_counts

        def sample_group(group: pd.DataFrame) -> pd.DataFrame:
            # group.name is the (label_str, source) key of the group being processed;
            # groups that are not configured contribute no rows
            requested = label_counts.get(group.name, 0)
            available = len(group.index)
            # A negative count (conventionally -1) means "use everything available";
            # otherwise never ask for more rows than the group holds
            n_rows = available if requested < 0 else min(available, requested)
            return group.sample(n_rows, replace=False, random_state=self.random_state)

        # Dataset Images dataframe
        # [[]] drops every data column, so only the index levels
        # (label_str, source, image_name) are left after reset_index
        self.df_ds_img = self.df_img \
            .groupby(['label_str', 'source']) \
            .apply(sample_group) \
            [[]].reset_index().set_index('image_name')
        self.df_ds_img['class_name'] = self.df_ds_img.apply(
            lambda row: self.config.class_names[(row['label_str'], row['source'])], axis=1)

        # MULTI labels each image by its class name; BINARY (and anything else)
        # falls back to the plain positive/negative label string
        if self.config.label_method == Image_Label_Method.MULTI:
            self.df_ds_img['class_label'] = self.df_ds_img['class_name']
        else:
            self.df_ds_img['class_label'] = self.df_ds_img['label_str']

        # Number the labels in order of first appearance
        label_value_dict = {label: i for i, label in enumerate(self.df_ds_img['class_label'].unique())}
        self.df_ds_img['class_value'] = self.df_ds_img['class_label'].map(label_value_dict)
        self.df_ds_img['dataset_name'] = self.config.name
        # NOTE(review): placeholder; presumably filled in once the image files are copied - confirm
        self.df_ds_img['dataset_img_path'] = None
        self.df_ds_img = self.df_ds_img.drop(['label_str', 'source'], axis=1)

    def _clear_db_data(self):
        """
        Clear data associated with this dataset from the database
        (or create the tables if they don't exist).
        """
        with self.engine.connect() as con:
            db_ds_tbl_exists = con.execute(self._query_datasets_exists).fetchall()[0][0]
            db_ds_img_tbl_exists = con.execute(self._query_dataset_imgs_exists).fetchall()[0][0]

            # The dataset images table references the datasets table, so its rows
            # have to be removed before the dataset row they point at
            if db_ds_img_tbl_exists:
                con.execute(self._query_drop_dataset_imgs_rows)

            if db_ds_tbl_exists:
                con.execute(self._query_drop_datasets_rows)
            else:
                con.execute(self._query_create_dataset_tbl)

            # Created last: its foreign key needs the datasets table to exist
            if not db_ds_img_tbl_exists:
                con.execute(self._query_create_dataset_img_tbl)

    # TODO
    def _clear_target_dir(self):
        """Create (if not exists) or clear (if exists) data from target directory. Not implemented yet."""
        pass

    # TODO
    def _copy_image_files(self):
        """Copy files from image dataframe from source to target directory. Not implemented yet."""
        pass

    # TODO
    def _store_db_metadata(self):
        """Store associated data to the metadata db (dataset table and dataset images table). Not implemented yet."""
        pass

    # TODO
    def build_dataset(self):
        """Builds the dataset by running the other methods. Not implemented yet."""
        pass

    # QUERY HELPERS
    # NOTE: table and column names are interpolated into the SQL text as-is, so they
    # must come from trusted configuration; only values are quoted (see _sql_literal).
    @staticmethod
    def _sql_literal(value):
        """
        Render value for use as a literal in a SQL statement: strings are wrapped in
        single quotes with embedded single quotes doubled, anything else is returned unchanged.
        """
        if isinstance(value, str):
            return "'" + value.replace("'", "''") + "'"
        return value

    @property
    def _query_create_dataset_tbl(self):
        """SQL creating the datasets table (one row per dataset, keyed by name)."""
        return f"""
               CREATE TABLE {self.db_datasets_tbl} (
                   name VARCHAR(100) UNIQUE NOT NULL,
                   target_dir VARCHAR(300) NOT NULL,
                   description TEXT,
                   validation_frac NUMERIC NOT NULL,
                   test_frac NUMERIC NOT NULL,
                   label_method VARCHAR(50) NOT NULL,
                   PRIMARY KEY(name)
               )
               """

    @property
    def _query_create_dataset_img_tbl(self):
        """
        SQL creating the dataset images table (one row per (image, dataset) pair),
        referencing both the metadata table and the datasets table.
        """
        # NOTE(review): a foreign key requires the referenced column to be unique;
        # confirm that image_name is a primary key / unique in the metadata table.
        return f"""
               CREATE TABLE {self.db_dataset_img_tbl} (
                   image_name VARCHAR(300) NOT NULL,
                   dataset_name VARCHAR(100) NOT NULL,
                   class_name VARCHAR(100) NOT NULL,
                   class_label VARCHAR(100) NOT NULL,
                   class_value INT NOT NULL,
                   dataset_img_path VARCHAR(500) NOT NULL,
                   PRIMARY KEY (image_name, dataset_name),
                   FOREIGN KEY (image_name) REFERENCES {self.db_metadata_tbl} (image_name),
                   FOREIGN KEY (dataset_name) REFERENCES {self.db_datasets_tbl} (name)
               )
               """

    @staticmethod
    def _query_check_tbl_exists(tbl_name):
        """SQL returning a single boolean: whether a table called tbl_name exists."""
        return f"""
                SELECT EXISTS(
                    SELECT * FROM information_schema.tables
                    WHERE table_name = {Image_Dataset_Builder._sql_literal(tbl_name)}
                )
               """

    @property
    def _query_datasets_exists(self):
        """SQL checking whether the datasets table exists."""
        return self._query_check_tbl_exists(self.db_datasets_tbl)

    @property
    def _query_dataset_imgs_exists(self):
        """SQL checking whether the dataset images table exists."""
        return self._query_check_tbl_exists(self.db_dataset_img_tbl)

    @staticmethod
    def _query_drop_col_values_from_tbl(tbl_name, col_name, col_value):
        """
        SQL deleting every row of tbl_name whose col_name equals col_value
        (string values are quoted, other values are inserted as they are).
        """
        return f"""
               DELETE FROM {tbl_name}
               WHERE {col_name} = {Image_Dataset_Builder._sql_literal(col_value)}
               """

    @property
    def _query_drop_datasets_rows(self):
        """SQL deleting this dataset's row from the datasets table."""
        return self._query_drop_col_values_from_tbl(self.db_datasets_tbl, "name", self.config.name)

    @property
    def _query_drop_dataset_imgs_rows(self):
        """SQL deleting this dataset's rows from the dataset images table."""
        return self._query_drop_col_values_from_tbl(self.db_dataset_img_tbl,
                                                    "dataset_name",
                                                    self.config.name)
    
        

In [250]:
# Scratch check: with axis=1, `row.name` inside apply is the row's index value
df = pd.DataFrame({'a': [1, 4], 'b': [2, 5], 'c': [3, 6]})
df = df.set_index('a')
df.apply(lambda record: (record.name, record['c']), axis=1)

a
1    (1, 3)
4    (4, 6)
dtype: object

# Collect every distinct (label_str, source) pair present in the metadata table
with db_engine.connect() as con:
    cols = pd.read_sql_query(f"SELECT source, label_str FROM {metadata_tbl}", con)
# Both columns are already in `cols`, so the pairs are built from it instead of
# querying the whole table a second time
unique_src_label_pairs = set(zip(cols['label_str'], cols['source']))
unique_src_label_pairs

In [251]:
# Request 50 images from every (label_str, source) pair
test_label_counts = dict.fromkeys(unique_src_label_pairs, 50)

# Class name for each (label_str, source) pair; the config requires these keys
# to match the keys of the label counts exactly
test_label_names = {
    ('negative', 'Google Images'): "negative_similar_plant",
    ('negative', 'Imagenet'): "negative_random_picture",
    ('negative', 'Plantnet'): "negative_general_plant",
    ('positive', 'Google Images'): "positive",
}

# Example config: binary labels, 20% validation split, no test split
test_config = Image_Dataset_Config(
    name="test_example",
    target_dir="",
    description="First example with 50 values from each class and binary labeling",
    label_method=Image_Label_Method.BINARY,
    label_counts=test_label_counts,
    class_names=test_label_names,
    validation_frac=0.2,
    test_frac=0,
)

## TEMP DS TESTS MARKER

In [252]:
# Build the dataset dataframes for the test config; the constructor reads the
# metadata table from the database and samples the images
test_ds = Image_Dataset_Builder(db_engine, test_config)

# Used for validating the Dataset Builder class
# Preview of the usable image metadata (rows whose 'read' column is True)
test_ds.df_img.head()

In [253]:
# Preview the dataset-level dataframe (the config as a single row, indexed by name)
test_ds.df_ds.head()

Unnamed: 0_level_0,target_dir,description,validation_frac,test_frac,label_method
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
test_example,,First example with 50 values from each class a...,0.2,0,BINARY


In [254]:
# Show the sampled images with their class name, label, value and dataset name
test_ds.df_ds_img

Unnamed: 0_level_0,class_name,class_label,class_value,dataset_name,dataset_img_path
image_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Poison_sumac_plant_12,negative_similar_plant,negative,0,test_example,
Box_elder_plant_54,negative_similar_plant,negative,0,test_example,
Virginia_creeper_plant_93,negative_similar_plant,negative,0,test_example,
Blackberries_plant_249,negative_similar_plant,negative,0,test_example,
Hoptree_plant_49,negative_similar_plant,negative,0,test_example,
...,...,...,...,...,...
poison_ivy_plant_2348,positive,positive,1,test_example,
poison_ivy_plant_2475,positive,positive,1,test_example,
poison_ivy_plant_784,positive,positive,1,test_example,
poison_ivy_plant_1934,positive,positive,1,test_example,


In [177]:
# Work on a copy so the experiments below cannot modify the builder's dataframe
df_test = test_ds.df_img.copy()

In [130]:
# Sample sizes per (label_str, source) group for the sampling experiment;
# only the two 'Google Images' groups are requested, 50 rows each
just_imgs = dict.fromkeys(
    [
        ('negative', 'Google Images'),
        ('positive', 'Google Images'),
    ],
    50,
)

In [135]:
# Group the images by (label_str, source); each group's key is available as `.name` inside apply
dfg = df_test.groupby(['label_str', 'source'])

In [136]:
# Print each group key and whether a sample size was requested for it
dfg.apply(lambda grp: print(grp.name, grp.name in just_imgs))

('negative', 'Google Images') True
('negative', 'Imagenet') False
('negative', 'Plantnet') False
('positive', 'Google Images') True


In [143]:
# Sample the requested number of rows from each group (capped at the group size);
# groups without a requested size contribute no rows
dfg.apply(
    lambda grp: grp.sample(
        min(just_imgs[grp.name], len(grp.index)) if grp.name in just_imgs else 0
    )
)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,file_name,download_loc,final_loc,full_path,search_term,source,read,orig_width,orig_height,width,height,label,label_str,download_name
label_str,source,image_name,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
negative,Google Images,Fragrant_sumac_plant_74,Fragrant_sumac_plant_74.jpg,simple_images/Fragrant sumac plant/Fragrant su...,../datasets/pipeline_v1/downloaded_images/nega...,/home/code/datasets/pipeline_v1/downloaded_ima...,Fragrant sumac plant,Google Images,True,500,375,500,375,0,negative,
negative,Google Images,Kudzu_plant_67,Kudzu_plant_67.jpg,simple_images/Kudzu plant/Kudzu plant_67.jpg,../datasets/pipeline_v1/downloaded_images/nega...,/home/code/datasets/pipeline_v1/downloaded_ima...,Kudzu plant,Google Images,True,500,334,500,334,0,negative,
negative,Google Images,Virgin's_bower_plant_21,Virgin's_bower_plant_21.jpg,simple_images/Virgin's bower plant/Virgin's bo...,../datasets/pipeline_v1/downloaded_images/nega...,/home/code/datasets/pipeline_v1/downloaded_ima...,Virgin's bower plant,Google Images,True,510,510,500,500,0,negative,
negative,Google Images,Western_poison_oak_plant_224,Western_poison_oak_plant_224.jpg,simple_images/Western poison oak plant/Western...,../datasets/pipeline_v1/downloaded_images/nega...,/home/code/datasets/pipeline_v1/downloaded_ima...,Western poison oak plant,Google Images,True,1024,679,754,500,0,negative,
negative,Google Images,Virgin's_bower_plant_126,Virgin's_bower_plant_126.jpg,simple_images/Virgin's bower plant/Virgin's bo...,../datasets/pipeline_v1/downloaded_images/nega...,/home/code/datasets/pipeline_v1/downloaded_ima...,Virgin's bower plant,Google Images,True,1300,957,679,500,0,negative,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
positive,Google Images,poison_ivy_plant_1330,poison_ivy_plant_1330.jpg,simple_images/poison ivy plant/poison ivy plan...,../datasets/pipeline_v1/downloaded_images/posi...,/home/code/datasets/pipeline_v1/downloaded_ima...,poison ivy plant,Google Images,True,1662,1246,666,500,1,positive,
positive,Google Images,poison_ivy_plant_79,poison_ivy_plant_79.jpg,simple_images/poison ivy plant/poison ivy plan...,../datasets/pipeline_v1/downloaded_images/posi...,/home/code/datasets/pipeline_v1/downloaded_ima...,poison ivy plant,Google Images,True,1200,630,952,500,1,positive,
positive,Google Images,poison_ivy_plant_730,poison_ivy_plant_730.jpg,simple_images/poison ivy plant/poison ivy plan...,../datasets/pipeline_v1/downloaded_images/posi...,/home/code/datasets/pipeline_v1/downloaded_ima...,poison ivy plant,Google Images,True,1254,836,750,500,1,positive,
positive,Google Images,poison_ivy_plant_499,poison_ivy_plant_499.jpg,simple_images/poison ivy plant/poison ivy plan...,../datasets/pipeline_v1/downloaded_images/posi...,/home/code/datasets/pipeline_v1/downloaded_ima...,poison ivy plant,Google Images,True,732,549,666,500,1,positive,
