In [None]:
from sklearn.model_selection import ParameterGrid
from utils import *
from evaluation import *
from mergers import compute_descriptors
from configparser import ConfigParser
import mysql.connector as conn
%matplotlib inline

In [None]:
# Hyper-parameter grid: every combination of color space and similarity metric.
COLOR_SPACES = ['rgb', 'yiq', '0iq', 'lab', '0ab']
SIMILARITY_METRICS = ['mse', 'inplace_mse', 'cc', 'inplace_cc']
params = ParameterGrid({
    'color_space': COLOR_SPACES,
    'similarity_metric': SIMILARITY_METRICS,
})

In [None]:
# Flatten the grid into (color_space, similarity_metric) tuples, in grid order.
all_params = []
for combo in params:
    all_params.append((combo['color_space'], combo['similarity_metric']))

In [None]:
# Load settings from config.ini. read_file() raises if the file is missing,
# unlike ConfigParser.read(), which would silently ignore it.
# The `with` block closes the file handle (the original left it open), and the
# explicit encoding makes parsing independent of the platform's locale.
config = ConfigParser()
with open('config.ini', encoding='utf-8') as config_file:
    config.read_file(config_file)

Read the database credentials from the configuration file, or enter them manually if you have set up a different user for training (e.g. one with write access).

In [None]:
# Connection settings from the [database] section; each is None if absent.
# To use a different user, replace the right-hand side with literal values.
db_config = config['database']
host_name, db_name, username, password = (
    db_config.get(option) for option in ('host', 'name', 'username', 'password')
)

In [None]:
# Settings from the [data] section; PATCHES_PATH is None if the option is absent.
# NOTE(review): PATCHES_PATH is not referenced by the cells below — presumably
# consumed by a helper or a later cell; confirm.
data_config = config['data']
PATCHES_PATH = data_config.get('patches_path', fallback=None)

In [None]:
# Open the MySQL connection and a cursor that the query cells below share.
# `password` is the canonical name of the argument (`passwd` is a legacy alias).
connection_kwargs = dict(host=host_name, database=db_name, user=username, password=password)
db = conn.connect(**connection_kwargs)
cursor = db.cursor()

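The query below ranks costumes by the smallest number of patches they have for any single feature code, counting only usable samples, with the best-covered costumes first. This is a good place to filter out costume types that don't have enough samples, e.g. by adding `having min_num_of_patches >= N` after the outer `group by`.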

In [None]:
# For each costume: count patches per feature code (usable samples only), then
# keep the minimum of those counts; costumes are sorted by it, highest first.
costume_ranking_query = '''
select a.costumeFK, min(num_of_patches) as min_num_of_patches from costumes
join (
    select samples.costumeFK, patches.feature_code, count(patches.patchID) as num_of_patches 
    from patches
    join samples on patches.sampleFK = samples.sampleID
    where samples.is_usable = 1
    group by samples.costumeFK, patches.feature_code
    ) as a on a.costumeFK = costumes.costumeID
group by a.costumeFK
order by min_num_of_patches desc
'''
cursor.execute(costume_ranking_query)

In [None]:
# Keep only the costume ID (first column), preserving the ranking order.
ranking_rows = cursor.fetchall()
all_costumes = [costume_id for costume_id, *_ in ranking_rows]

In [None]:
# Train on the top-ranked costumes, or replace with an explicit list of costume IDs.
NUM_COSTUMES = 10
costumes = all_costumes[:NUM_COSTUMES]

In [None]:
# Build {costumeID: {feature_code: [patch file names]}} for the selected costumes.
# - Only patches from usable samples are included, matching the ranking query
#   that chose these costumes.
# - One row is fetched per patch and grouped in Python. GROUP_CONCAT was avoided
#   because MySQL silently truncates its result at group_concat_max_len (1024
#   bytes by default), and splitting on ',' breaks on file names that contain commas.
data = {}
for costume in costumes:
    cursor.execute('''
    SELECT patches.feature_code, patches.file_name FROM patches
    JOIN samples ON samples.sampleID = patches.sampleFK
    WHERE samples.costumeFK = %(costumeID)s AND samples.is_usable = 1
    ORDER BY patches.feature_code, patches.patchID
    ''', {'costumeID': costume})
    features = {}
    for feature_code, file_name in cursor.fetchall():
        if file_name is None:  # GROUP_CONCAT also skipped NULL values
            continue
        features.setdefault(feature_code, []).append(file_name)
    data[costume] = features

In [None]:
# Presumably loads the patch images (in RGB, per the helper's name) for the file
# names collected in `data`.
# NOTE(review): read_patches_rgb comes from a star import and does not receive
# PATCHES_PATH here — confirm where it resolves the file paths from.
img_data = read_patches_rgb(data)

In [None]:
# Compute descriptors for the loaded patches using MseMerger.
# NOTE(review): this cell covers only the ('rgb', 'mse') combination even though
# `params` / `all_params` enumerate a grid of color spaces and metrics, which
# are never iterated over in the visible cells.
descriptors = compute_descriptors(img_data, MseMerger)

In [None]:
# Save the descriptors (presumably as .npy files, per the helper's name).
# The 'mse' / 'rgb' labels must match the merger (MseMerger) and reader
# (read_patches_rgb) that produced `descriptors`.
save_descriptors_as_npy(descriptors, 'mse', 'rgb')