# Create datasets

Use the pre-prepared image downloads to create datasets with train/validation/test splits for model training.

## Config

In [48]:
import pandas as pd
import sqlalchemy as sqa
from sklearn.model_selection import train_test_split
import json
from typing import Dict, Tuple
from dataclasses import dataclass, field
from enum import Enum, auto

In [2]:
!python --version

Python 3.6.9


In [13]:
# PARAMETERS
# Database parameters, listed in the order they appear in the connection URL:
#   <prefix>://<user>:<password>@<container>:<port>/<database>
db_prefix = "postgresql"
db_user = "pguser"
db_password = "pgpassword"
db_container = "metadata_db"
db_port = 5432
db_database = "metadata"

# Table names used in the metadata database
metadata_tbl = "base_images"        # per-image metadata, read when building datasets
datasets_tbl = "datasets"
dataset_img_tbl = "dataset_images"

In [18]:
# Assemble the SQLAlchemy connection URL from the database parameters,
# then create the engine shared by the rest of the notebook.
db_con_str = (
    f"{db_prefix}://{db_user}:{db_password}"
    f"@{db_container}:{db_port}/{db_database}"
)
db_engine = sqa.create_engine(db_con_str)

In [51]:
class Image_Label_Method(Enum):
    """Labelling strategies available when creating an image dataset."""
    # All labels are just positive or negative
    BINARY: 'Image_Label_Method' = auto()
    # Each datapoint is labelled by its source
    MULTI: 'Image_Label_Method' = auto()


@dataclass
class Image_Dataset_Config:
    """
    Specifies config for creating an image dataset:
     - name: name of the dataset
     - target_dir: directory for storing the dataset
     - label_counts: Dictionary of ((source_name, label_str), count) pairs - how many
        data points from each data source + label combination to use
        (if count is -1 or is greater than total available, just use all of them)
     - description: free-text description of the dataset
     - validation_frac: frac of TRAINING data to use as validation dataset
        (if 0, no validation set created)
     - test_frac: frac of ALL data to use as testing data (if 0, no test set created)
     - label_method: different labelling methods implemented:
         - BINARY: All labels just positive or negative
         - MULTI: Each datapoint is labelled by its source

    Raises:
        ValueError: if label_counts is empty, or if validation_frac / test_frac
            is not between 0 and 1 (inclusive).
    """
    name: str
    target_dir: str
    label_counts: Dict[Tuple[str, str], int] = field(default_factory=dict)
    description: str = ""
    validation_frac: float = 0.2
    test_frac: float = 0.0
    label_method: Image_Label_Method = Image_Label_Method.BINARY

    def __post_init__(self):
        # Raise explicitly rather than assert: asserts are stripped under `python -O`,
        # which would silently let an unusable config through.
        if not self.label_counts:
            raise ValueError("Must include counts for at least one source + label")
        # A fraction outside [0, 1] cannot describe a share of the data.
        for frac_name in ("validation_frac", "test_frac"):
            frac = getattr(self, frac_name)
            if not 0 <= frac <= 1:
                raise ValueError(f"{frac_name} must be between 0 and 1, got {frac!r}")

In [62]:
class Image_Dataset_Builder():
    """
    Creates a dataset from the metadata table and a dataset config.

    Constructor arguments:
     - db_engine: database engine attached to metadata database
     - config: specifies how to create dataset (see Image_Dataset_Config class)

    On construction the metadata table is loaded into `df_img`, keeping only
    the rows whose `read` column is True.
    """
    db_metadata_tbl = metadata_tbl
    db_datasets_tbl = datasets_tbl
    db_dataset_img_tbl = dataset_img_tbl

    def __init__(self,
                 db_engine: sqa.engine.Engine,
                 config: Image_Dataset_Config
                 ):
        self.engine = db_engine
        self.config = config
        with self.engine.connect() as con:
            all_images = pd.read_sql(self.db_metadata_tbl, con)
        # We only care about usable images now
        self.df_img: pd.DataFrame = all_images[all_images['read'] == True]

    def clear_db_data(self):
        """TODO: Clear data associated with this dataset from the metadata db."""

    def clear_target_dir(self):
        """TODO: Create (if not exists) or clear (if exists) data from target directory."""

    def create_image_df(self):
        """TODO: Use config to create dataframe of desired image outputs."""

    def copy_image_files(self):
        """TODO: Copy files from image dataframe from source to target directory."""

    def store_db_metadata(self):
        """
        TODO: Store associated data to the metadata db
        (dataset table and dataset images table).
        """

    def build_dataset(self):
        """TODO: Build the dataset by running the other methods."""

In [47]:
with db_engine.connect() as con:
    # Query the table once; the pairs below are derived from this frame
    # instead of issuing the identical SELECT a second time.
    cols = pd.read_sql_query(f"SELECT source, label_str FROM {metadata_tbl}", con)
# Distinct (source, label_str) combinations present in the metadata table
unique_src_label_pairs = set(cols.itertuples(index=False, name=None))
unique_src_label_pairs

{('Google Images', 'negative'),
 ('Google Images', 'positive'),
 ('Imagenet', 'negative'),
 ('Plantnet', 'negative')}

In [54]:
# Ask for 50 images from every (source, label_str) pair found above
test_label_counts = dict.fromkeys(unique_src_label_pairs, 50)
test_config = Image_Dataset_Config(
    name="test_example",
    description="First example with 50 values from each class and binary labelling",
    target_dir="",
    label_counts=test_label_counts,
    label_method=Image_Label_Method.BINARY,
    validation_frac=0.2,
    test_frac=0,
)

In [63]:
# Instantiate the builder; this loads the image metadata table into test_ds.df_img
test_ds = Image_Dataset_Builder(db_engine=db_engine, config=test_config)

In [64]:
# Sanity check for the Dataset Builder class: preview the first rows
# of the image metadata it loaded
test_ds.df_img.head(n=5)

Unnamed: 0,image_name,file_name,download_loc,final_loc,full_path,search_term,source,read,orig_width,orig_height,width,height,label,label_str,download_name
0,poison_ivy_plant_1095,poison_ivy_plant_1095.jpg,simple_images/poison ivy plant/poison ivy plan...,../datasets/pipeline_v1/downloaded_images/posi...,/home/code/datasets/pipeline_v1/downloaded_ima...,poison ivy plant,Google Images,True,1024,575,890,500,1,positive,
1,poison_ivy_plant_1556,poison_ivy_plant_1556.jpg,simple_images/poison ivy plant/poison ivy plan...,../datasets/pipeline_v1/downloaded_images/posi...,/home/code/datasets/pipeline_v1/downloaded_ima...,poison ivy plant,Google Images,True,750,500,750,500,1,positive,
2,poison_ivy_plant_1294,poison_ivy_plant_1294.jpg,simple_images/poison ivy plant/poison ivy plan...,../datasets/pipeline_v1/downloaded_images/posi...,/home/code/datasets/pipeline_v1/downloaded_ima...,poison ivy plant,Google Images,True,1800,1013,888,500,1,positive,
3,poison_ivy_plant_964,poison_ivy_plant_964.jpg,simple_images/poison ivy plant/poison ivy plan...,../datasets/pipeline_v1/downloaded_images/posi...,/home/code/datasets/pipeline_v1/downloaded_ima...,poison ivy plant,Google Images,True,480,310,480,310,1,positive,
4,poison_ivy_plant_1526,poison_ivy_plant_1526.jpg,simple_images/poison ivy plant/poison ivy plan...,../datasets/pipeline_v1/downloaded_images/posi...,/home/code/datasets/pipeline_v1/downloaded_ima...,poison ivy plant,Google Images,True,1600,1031,775,500,1,positive,
