In [2]:
import os
import psycopg2
import shutil

import configparser

from pathlib import Path

In [3]:
# SQL script (path relative to the notebook's working directory) whose result
# rows are used below as (image path, aircraft variant name).
SELECT_IMAGE_SCRIPT_PATH = "database/select_aircraft_images.sql"

# Output roots for the train/validation split; folderize_file creates one
# sub-directory per variant under each of them.
TRAINING_STORE_PATH, VALIDATION_STORE_PATH = (
    "training_data/training/",
    "training_data/validation/",
)

In [4]:
class DBConfig:
    """Plain container for PostgreSQL connection settings.

    Attributes mirror the constructor arguments one-to-one:
      filename -- the config file the settings came from (load_config passes
                  its own ``filename`` argument here)
      host, port, db_name -- where the server and database are
      user, pwd -- login name and password handed to psycopg2.connect
    """

    def __init__(self, filename, host, port, db_name, user, pwd):
        (self.filename, self.host, self.port,
         self.db_name, self.user, self.pwd) = (filename, host, port, db_name, user, pwd)

In [5]:
def load_sql(filename):
    """Return the entire text of the SQL script stored at ``filename``.

    The file is opened in text mode with Python's default encoding.
    """
    with open(filename, mode="r") as script:
        script_text = script.read()
    return script_text

In [6]:
def load_config(filename = 'database.ini', section = 'aircraft_postgres_db'):
    """Read PostgreSQL connection settings from an INI file.

    Args:
        filename: Path of the INI file to read.
        section: Section holding the ``Host``, ``Port``, ``Db_name``, ``User``
            and ``Pass`` keys.

    Returns:
        A DBConfig built from that section, with ``Port`` converted to int.

    Raises:
        FileNotFoundError: if ``filename`` could not be opened or read.
        KeyError: if ``section`` or one of the required keys is missing.
        ValueError: if ``Port`` is not an integer.
    """
    parser = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it did read; without this check a missing file surfaces
    # later as a confusing KeyError on the section name.
    if not parser.read(filename):
        raise FileNotFoundError(
            "Database config file not found or unreadable: {!r}".format(filename))
    if section not in parser:
        raise KeyError("Section {!r} not found in {!r}".format(section, filename))
    cfg = parser[section]
    return DBConfig(filename, cfg['Host'], cfg.getint('Port'), cfg['Db_name'], cfg['User'], cfg['Pass'])

In [21]:
# Images seen so far per cleaned variant name; drives the every-n-th split.
variant_dic = {}


def folderize_file(path, variant_name, validation_every=5,
                   training_root=None, validation_root=None, counts=None):
    """Copy one image into its variant's training or validation folder.

    Every ``validation_every``-th image seen for a variant (5th, 10th, ... with
    the default) is copied into the validation tree; all others go into the
    training tree. The variant's sub-directory is created under *both* roots,
    so a variant with too few images still gets an (empty) validation folder.

    Args:
        path: Source image file to copy.
        variant_name: Aircraft variant label; used as the sub-directory name.
            Path separators ("/" and backslash) are replaced by "_" so a name
            cannot create nested folders.
        validation_every: Send every n-th image of a variant to validation.
        training_root: Training output root; defaults to TRAINING_STORE_PATH.
        validation_root: Validation output root; defaults to VALIDATION_STORE_PATH.
        counts: Dict counting images seen per cleaned variant name; defaults to
            the module-level ``variant_dic`` and is updated in place.

    Returns:
        The destination path the image was copied to.

    Raises:
        ValueError: if ``validation_every`` < 1, or the cleaned variant name is
            empty, "." or ".." (it would alias or escape the output root).
    """
    if validation_every < 1:
        raise ValueError("validation_every must be >= 1, got {!r}".format(validation_every))
    if training_root is None:
        training_root = TRAINING_STORE_PATH
    if validation_root is None:
        validation_root = VALIDATION_STORE_PATH
    if counts is None:
        counts = variant_dic

    # Naive sanitisation: only path separators are replaced. Backslash matters
    # because on Windows it would otherwise silently create nested class folders.
    naive_cleaned_name = variant_name.replace("/", "_").replace("\\", "_")
    if naive_cleaned_name in ("", ".", ".."):
        raise ValueError(
            "Unusable variant name {!r} for image {!r}".format(variant_name, path))

    training_dir_path = os.path.join(training_root, naive_cleaned_name)
    validation_dir_path = os.path.join(validation_root, naive_cleaned_name)
    # Create both so every class exists in both trees (presumably so the
    # downstream loader sees the same class set in each - confirm).
    Path(training_dir_path).mkdir(parents=True, exist_ok=True)
    Path(validation_dir_path).mkdir(parents=True, exist_ok=True)

    counts[naive_cleaned_name] = counts.get(naive_cleaned_name, 0) + 1
    # NOTE(review): which images land in validation depends on the order the
    # SQL query returns rows; confirm the query has a stable ORDER BY.
    is_validation = counts[naive_cleaned_name] % validation_every == 0
    target_dir = validation_dir_path if is_validation else training_dir_path

    # NOTE(review): destinations are keyed by basename only, so two source images
    # with the same file name and variant overwrite each other - confirm unique.
    new_path = os.path.join(target_dir, os.path.basename(path))
    return shutil.copy(path, new_path)

In [22]:
# Connection settings come from the "aircraft_postgres_db" section of database.ini.
db_cfg = load_config(filename='database.ini', section='aircraft_postgres_db')

In [23]:
# Open the PostgreSQL connection using the settings held in db_cfg.
conn = psycopg2.connect(
    dbname=db_cfg.db_name,
    user=db_cfg.user,
    password=db_cfg.pwd,
    host=db_cfg.host,
    port=db_cfg.port,
)

In [24]:
cur = conn.cursor()
# Quick round-trip to the server to confirm the connection is usable.
cur.execute('SELECT version()')
version_row = cur.fetchone()
print("Connected to database, version: " + str(version_row))

Connected to database, version: ('PostgreSQL 12.3, compiled by Visual C++ build 1914, 64-bit',)


In [25]:
# Read the image-selection query text from disk.
select_image_sql = load_sql(filename=SELECT_IMAGE_SCRIPT_PATH)

In [26]:
# Run the image-selection query and pull every row into memory; the next cell
# uses column 0 of each row as the image path and column 1 as the variant name.
cur.execute(select_image_sql)
training_data = cur.fetchall()

In [27]:
# Restart the per-variant counters so re-running this cell reproduces the same
# train/validation split. Otherwise the counts carry over from the previous run,
# the every-5th phase shifts, and images already copied into training can also
# be copied into validation (train/validation leakage).
variant_dic.clear()

for img in training_data:
    # Each row is used as (image path, variant name); extra columns are ignored.
    path, variant_name = img[0], img[1]
    folderize_file(path, variant_name)