# In [3]:
"""
File I/O helpers for the galaxy classification/regression project.

By James Caldon, 2021.
"""
import numpy as np
import pandas as pd
import os
import os.path
# Root directory that contains every generated data set read by the loaders
# below.  The default is the original author's machine-specific location; set
# the GALAXY_DATA_PATH environment variable to use a different root without
# editing this file.  An unset or empty variable falls back to the default.
_PARENT_PATH = os.environ.get("GALAXY_DATA_PATH") or (
    r"C:\Users\James_Dev_Account\OneDrive - The University of Western Australia\Documents\Honours - Galaxy Classification\Galaxy-Classification-Research-Project\generated_data"
)
# Registry of single-file image data sets consumed by load_data().
# Each key doubles as the directory name below _PARENT_PATH.
#   DESCRIPTION - free-text description (not filled in yet)
#   N_MODEL     - number of images read when the caller gives no count
#   N_MESH      - side length of one image; images are reshaped to
#                 (N_MESH, N_MESH, 1)
#   FILENAME    - data file inside the data set's directory
_IMAGE_DATA = {
    "DATA1": {
        "DESCRIPTION": "TODO",
        "N_MODEL": 1000,
        "N_MESH": 50,
        "FILENAME": "2dft.dat",
    },
    "DATA2": {
        "DESCRIPTION": "TODO",
        "N_MODEL": 1000,
        "N_MESH": 50,
        "FILENAME": "2dft.dat",
    },
    "DATA3": {
        "DESCRIPTION": "TODO",
        "N_MODEL": 1000,
        "N_MESH": 50,
        "FILENAME": "2dft.dat",
    },
    "DATA4": {
        "DESCRIPTION": "TODO",
        "N_MODEL": 1000,
        "N_MESH": 50,
        "FILENAME": "2dft.dat",
    },
    "NAIR_ABRAHAM_2010": {
        "DESCRIPTION": "TODO",
        "N_MODEL": 14034,
        "N_MESH": 50,
        "FILENAME": "total-list.dat",
    },
    "CG_611": {
        "DESCRIPTION": "TODO",
        "N_MODEL": 1,
        "N_MESH": 50,
        "FILENAME": "cg_611.dat",
    },
    "IC3328": {
        "DESCRIPTION": "TODO",
        "N_MODEL": 1,
        "N_MESH": 50,
        "FILENAME": "ic3328.dat",
    },
    "NGC5845": {
        "DESCRIPTION": "TODO",
        "N_MODEL": 1,
        "N_MESH": 50,
        "FILENAME": "NGC5845.dat",
    },
}

# Registry of attribute (label) files consumed by load_attribute().
# Each key doubles as the directory name below _PARENT_PATH.
#   COUNT    - number of rows read when the caller gives no count
#   FILENAME - attribute file inside the data set's directory
_ATTR = {
    "NAIR_ABRAHAM_2010": {
        "DESCRIPTION": "TODO",
        "COUNT": 14034,
        "FILENAME": "T_type.txt",
    },
}

# Registry of comma-separated metadata tables consumed by load_dataframe().
#   FILEPATH - location of the table relative to _PARENT_PATH
# The relative paths are assembled with os.path.join so that they use the
# separator of the running platform instead of hard-coded backslashes.
_DF_METADATA = {
    "DR15_DISC": {
        "FILEPATH": os.path.join("dr15", "discs", "ES_SDSS_metadata.txt"),
    },
    "DR15_NO_DISC": {
        "FILEPATH": os.path.join("dr15", "no_discs", "E_SDSS_metadata.txt"),
    },
}

# Registry of multi-directory image data sets consumed by load_data_set().
# The first-level key and the second-level key are both directory names:
# files are looked up below _PARENT_PATH/<first key>/<second key>/.
#   DESCRIPTION - free-text description (not filled in yet)
#   N_MESH      - side length of one image; images are reshaped to
#                 (N_MESH, N_MESH, 1)
#   FOLDERNAME  - regular expression; only sub-directories whose name matches
#                 it (re.match) are loaded
#   FILENAME    - data file expected inside every matching sub-directory
#   CLASS       - integer label assigned to every image of the subset
_IMAGE_DATA_SETS = {
    "DISC": {
        "MEDIUM_LARGE_DISCS": {
            "DESCRIPTION": "TODO",
            "N_MESH": 50,
            "FOLDERNAME": "^rgal.*disc_1$",
            "FILENAME": "2dft.dat",
            "CLASS": 1,
        },
        "SMALL_LARGE_DISCS": {
            "DESCRIPTION": "TODO",
            "N_MESH": 50,
            "FOLDERNAME": "^rgal.*disc_1$",
            "FILENAME": "2dft.dat",
            "CLASS": 1,
        },
    },
    "NO_DISC": {
        "MEDIUM_LARGE_DISCS": {
            "DESCRIPTION": "TODO",
            "N_MESH": 50,
            "FOLDERNAME": "^rgal.*disc_2$",
            "FILENAME": "2dft.dat",
            "CLASS": 0,
        },
        "SMALL_LARGE_DISCS": {
            "DESCRIPTION": "TODO",
            "N_MESH": 50,
            "FOLDERNAME": "^rgal.*disc_2$",
            "FILENAME": "2dft.dat",
            "CLASS": 0,
        },
    },
    "DR15": {
        "DISCS": {
            "DESCRIPTION": "TODO",
            "N_MESH": 50,
            "FOLDERNAME": "ES_SDSS_image_data",
            "FILENAME": "ES_SDSS_image_data.txt",
            "CLASS": 1,
        },
        "NO_DISCS": {
            "DESCRIPTION": "TODO",
            "N_MESH": 50,
            "FOLDERNAME": "E_SDSS_image_data",
            "FILENAME": "E_SDSS_image_data.txt",
            "CLASS": 0,
        },
    },
}

def load_data_set(data="DISC", data_param="SMALL_LARGE_DISCS", count=None, skip=0):
    """Load all image files of one data subset together with their class labels.

    Scans ``_PARENT_PATH/<data>/<data_param>`` for sub-directories whose name
    matches the subset's ``FOLDERNAME`` regular expression, reads the subset's
    ``FILENAME`` from each of them and stacks the images along axis 0.  The
    directories are visited in the order ``os.scandir`` yields them.

    Args:
        data: First-level key of ``_IMAGE_DATA_SETS`` (also a directory name).
        data_param: Second-level key below ``data`` (also a directory name).
        count: Maximum number of images to read from *each* matching file;
            ``None`` reads every file to its end.
        skip: Number of leading images to skip in *each* matching file.

    Returns:
        A tuple ``(x, Y)``.  ``x`` has shape ``(n_images, N_MESH, N_MESH, 1)``
        and ``Y`` is a 1-D integer array that holds the subset's ``CLASS``
        once per image.  Both are ``None`` when no sub-directory matched.
    """
    import re

    spec = _IMAGE_DATA_SETS[data][data_param]
    n_mesh = spec["N_MESH"]
    # Row arithmetic assumes the files store one pixel value per line, so one
    # image occupies n_mesh * n_mesh consecutive rows.
    rows_per_image = n_mesh * n_mesh
    nrows = None if count is None else count * rows_per_image
    # Compile once instead of once per directory entry.
    folder_pattern = re.compile(spec["FOLDERNAME"])

    images = []
    with os.scandir(os.path.join(_PARENT_PATH, data, data_param)) as entries:
        for entry in entries:
            if not (entry.is_dir() and folder_pattern.match(entry.name)):
                continue
            values = pd.read_csv(
                os.path.join(entry.path, spec["FILENAME"]),
                header=None,
                sep=r"\s+",
                skiprows=skip * rows_per_image,
                nrows=nrows,
            ).to_numpy()
            images.append(values.reshape(-1, n_mesh, n_mesh, 1))

    if not images:
        return None, None

    # One concatenation at the end avoids re-copying the growing array for
    # every file, which repeated np.append calls would do.
    x = np.concatenate(images, axis=0)
    # The builtin ``int`` is what the removed ``np.int`` alias used to mean.
    Y = np.full(x.shape[0], spec["CLASS"], dtype=int)
    return x, Y

def load_total_combined_data(data_param="SMALL_LARGE_DISCS", count=None, skip=0):
    """Load the "DISC" and "NO_DISC" groups of one subset as a single data set.

    Args:
        data_param: Subset key looked up under both groups.
        count: Forwarded to ``load_data_set`` for each group.
        skip: Forwarded to ``load_data_set`` for each group.

    Returns:
        A tuple ``(x, Y)`` with the "DISC" images/labels first, followed by
        the "NO_DISC" images/labels, joined along axis 0.
    """
    loaded = [
        load_data_set(data=group, data_param=data_param, count=count, skip=skip)
        for group in ("DISC", "NO_DISC")
    ]
    images, labels = zip(*loaded)
    return np.concatenate(images, axis=0), np.concatenate(labels, axis=0)

def load_total_combined_data_DR15(count=None, skip=0):
    """Load the "DISCS" and "NO_DISCS" subsets of the "DR15" group as one data set.

    Args:
        count: Forwarded to ``load_data_set`` for each subset.
        skip: Forwarded to ``load_data_set`` for each subset.

    Returns:
        A tuple ``(x, Y)`` with the "DISCS" images/labels first, followed by
        the "NO_DISCS" images/labels, joined along axis 0.
    """
    loaded = [
        load_data_set(data="DR15", data_param=subset, count=count, skip=skip)
        for subset in ("DISCS", "NO_DISCS")
    ]
    images, labels = zip(*loaded)
    return np.concatenate(images, axis=0), np.concatenate(labels, axis=0)

def load_data(data="DATA1", count=None, skip=0):
    """Load images from the single data file registered under ``data``.

    The file is ``_PARENT_PATH/<data>/<FILENAME>`` with the settings taken
    from ``_IMAGE_DATA[data]``.

    Args:
        data: Key of ``_IMAGE_DATA`` (also the directory name).
        count: Maximum number of images to read; ``None`` uses the data set's
            ``N_MODEL``.
        skip: Number of leading images to skip.

    Returns:
        Array of shape ``(n_images, N_MESH, N_MESH, 1)``.  ``n_images`` is the
        number of images actually read, which is smaller than the requested
        number when ``skip`` leaves fewer images in the file.

    Raises:
        ValueError: If the rows read do not form a whole number of images.
    """
    spec = _IMAGE_DATA[data]
    n_mesh = spec["N_MESH"]
    # Row arithmetic assumes the file stores one pixel value per line, so one
    # image occupies n_mesh * n_mesh consecutive rows.
    rows_per_image = n_mesh * n_mesh
    n_model = spec["N_MODEL"] if count is None else count

    values = pd.read_csv(
        os.path.join(_PARENT_PATH, data, spec["FILENAME"]),
        header=None,
        sep=r"\s+",
        skiprows=skip * rows_per_image,
        nrows=n_model * rows_per_image,
    ).to_numpy()
    # Let NumPy infer the image count: forcing ``n_model`` here failed whenever
    # fewer images were available, e.g. the default count with skip > 0.
    return values.reshape(-1, n_mesh, n_mesh, 1)

def load_attribute(data, count=None, skip=0):
    """Load rows from the attribute file registered under ``data``.

    The file is ``_PARENT_PATH/<data>/<FILENAME>`` with the settings taken
    from ``_ATTR[data]``; it is parsed as whitespace-separated columns
    without a header row.

    Args:
        data: Key of ``_ATTR`` (also the directory name).
        count: Maximum number of rows to read; ``None`` uses the data set's
            ``COUNT``.
        skip: Number of leading rows to skip.

    Returns:
        2-D array with one row per line read from the file.
    """
    spec = _ATTR[data]
    if count is None:
        count = spec["COUNT"]
    table = pd.read_csv(
        os.path.join(_PARENT_PATH, data, spec["FILENAME"]),
        header=None,
        sep=r"\s+",  # raw string: "\s" is not a valid str escape sequence
        skiprows=skip,
        nrows=count,
    )
    return table.to_numpy()

def load_dataframe(data):
    """Read the comma-separated metadata table registered under ``data``.

    Args:
        data: Key of ``_DF_METADATA``; its ``FILEPATH`` is resolved relative
            to ``_PARENT_PATH``.

    Returns:
        The table as a ``pandas.DataFrame``.
    """
    relative_path = _DF_METADATA[data]["FILEPATH"]
    full_path = os.path.join(_PARENT_PATH, relative_path)
    return pd.read_csv(full_path, sep=',')