## TODO:

- Falta mover el código de line detection a prod.
- Falta mover el código de OCR a prod.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import sys

sys.path.append("../")

from src.detection.prediction_utils import choose_model, filter_predictions, visualize_boxes
from src.slides_utils.slides_utils import predict_tiles
from src.line_detection.hough import get_pairs, apply_hough, lines_to_points, hough_detecting, nearest_tabla_from_cardinalidad, clean_cardinalidades, reverse_dict, bbox_sum_n, sep_line, any_points_inside, point_inside
from src.ocr_utils.ocr import get_ocr_model, get_lemmatizer, predict_ocr, generate_db, extract_candidate_keys, \
pairs_to_names, initial_guess_primary_keys, get_unchosen, get_foreign_keys

In [None]:
from PIL import Image
from torchvision import transforms as T

In [None]:
# Input ER diagram (the commented-out path is an alternative test image).
# img_path = '/home/nacho/TFI-Cazcarra/data/imagenes_diagramas/ERDiagramsMySQL-9.png'
img_path = '/home/nacho/TFI-Cazcarra/data/images_testing/test1.png'
img = Image.open(img_path).convert("RGB")

# NOTE(review): these two values are not used anywhere in the visible notebook;
# presumably leftovers from a resize step — confirm before removing.
min_size = 600
max_size = 1333

# Convert the PIL image into a tensor for the detection models.
transform = T.Compose([T.ToTensor()])
img_tensor = transform(img)

# Preview the image at 75% scale (cell display only; `img` itself is unchanged).
# PIL documents `size` as a (width, height) 2-tuple, so build a real tuple
# instead of passing a generator.
preview_size = tuple(int(s * 0.75) for s in img.size)
img.resize(preview_size)

## Predicciones sobre los objetos

In [None]:
# RetinaNet detector for the tables of the diagram.
model_tablas = choose_model(
    object_to_predict="tablas",
    model_name="retinanet",
)

In [None]:
# RetinaNet detector for the cardinality markers.
model_cardinalidades = choose_model(
    object_to_predict="cardinalidades",
    model_name="retinanet",
)

In [None]:
# Run the table detector on the full image.
# NOTE(review): `[1][0]` picks element 1 of the model output and then the first
# (only) image — confirm this layout against `choose_model`.
tablas_pred = model_tablas([img_tensor])[1][0]
# Non-maximum suppression over the raw table detections.
tablas_boxes, tablas_scores = filter_predictions(
    tablas_pred,
    nms_threshold=0.5,
)

In [None]:
# Cardinality detections are produced tile by tile over the image.
cardinalidades_pred = predict_tiles(
    img,
    model=model_cardinalidades,
    transform=transform,
    is_yolo=False,
)
# Keep detections above the score threshold and suppress overlaps.
cardinalidades_boxes, cardinalidades_scores = filter_predictions(
    cardinalidades_pred,
    score_threshold=0.25,
    nms_threshold=0.5,
)

In [None]:
# Overlay the filtered cardinality boxes on the input image.
visualize_boxes(
    img,
    cardinalidades_boxes,
)

## Predicciones sobre las líneas

In [None]:
import numpy as np

In [None]:
def _match_cardinalidad_to_line(c, dict_lines, max_augment=None):
    """Grow box ``c`` in steps of 2 until an endpoint of some line falls inside it.

    Returns ``(augment, line_name)`` for the first line found, or ``None`` when
    ``max_augment`` is set and has been reached without a match.
    """
    augment = 0
    while True:
        c_offset = bbox_sum_n(c, augment).tolist()
        for line_name, line in dict_lines.items():
            # Only the two endpoints of a line can touch a cardinality box.
            if point_inside(c_offset, line[0]) or point_inside(c_offset, line[-1]):
                return augment, line_name
        if max_augment is not None and augment >= max_augment:
            return None
        augment += 2
        print(f"Increasing offset to {augment}")


def unify_cardinalidades(img, lines, cardinalidades, plot=False, max_augment=None):
    """Associate every cardinality box with the detected line that touches it.

    Each box is grown (via ``bbox_sum_n``) until one of the endpoints of a
    detected line falls inside it. The first line found is recorded for that box.

    Args:
        img: Image the lines were detected on (only used when ``plot`` is True).
        lines: Detected lines as point sequences. Lines with fewer than two
            points are ignored.
        cardinalidades: Cardinality boxes; ``.tolist()`` is called on each.
        plot: If True, display the matches.
        max_augment: Optional upper bound on the growth tried per box. ``None``
            (default) keeps growing until a line is reached.

    Returns:
        ``reverse_dict`` applied to ``{str((box, augment)): line_name}``.
    """
    dict_cardinalidades = {}
    dict_lines = {f"line_{i}": l for i, l in enumerate(lines) if len(l) > 1}
    if not dict_lines:
        # No line can ever touch a box; growing it would loop forever.
        print("No lines detected; cardinalidades cannot be matched.")
    else:
        for c in cardinalidades:
            match = _match_cardinalidad_to_line(c, dict_lines, max_augment)
            if match is None:
                print(f"No line found for cardinalidad {c.tolist()}. Skipping..")
                continue
            augment, line_name = match
            # The key stores the growth needed to reach the line; `sep_line`
            # keeps the two boxes with the smallest growth per line.
            dict_cardinalidades[str((c.tolist(), augment))] = line_name
    if plot:
        # NOTE(review): `plot_results` is not defined or imported in the visible
        # notebook; `display` is the IPython builtin.
        display(plot_results(img, dict_cardinalidades, dict_lines))
    return reverse_dict(dict_cardinalidades)

In [None]:
def find_lines(tablas, cardinalidades, img, offset_tablas=5, **kwargs):
    """Detect relationship lines and attach them to cardinality boxes.

    Args:
        tablas: Integer table boxes, one row of 4 values per table (presumably
            ``[x1, y1, x2, y2]`` — TODO confirm).
        cardinalidades: Integer cardinality boxes.
        img: Image to run the Hough line detection on.
        offset_tablas: Margin applied to every table box before line detection
            and cardinality cleaning.
        **kwargs: Forwarded to ``unify_cardinalidades`` (e.g. ``plot``).

    Returns:
        The mapping produced by ``unify_cardinalidades``.
    """
    # Subtract the margin from the first two coordinates and add it to the last
    # two; the (1, 4) offset broadcasts over every row. The former
    # `np.sum([tablas, offset])` was not an element-wise add: with one table it
    # reduced everything to a single scalar, and with several tables it built
    # a ragged array (an error on recent NumPy).
    offset = np.array([-offset_tablas, -offset_tablas, offset_tablas, offset_tablas]).reshape(1, 4)
    tablas = np.asarray(tablas) + offset
    img, all_lines = apply_hough(img, tablas, [])
    all_points = lines_to_points(all_lines)
    lines = hough_detecting(all_points)
    cardinalidades = clean_cardinalidades(cardinalidades, tablas)
    return unify_cardinalidades(img, lines, cardinalidades, **kwargs)

In [None]:
def sep_line(line, tablas):
    """Return the two tables joined by a detected line.

    ``line`` is a ``"|"``-separated string whose parts are literals of the form
    ``(box, augment)``. Presumably these come from ``reverse_dict`` over the
    keys built in ``unify_cardinalidades`` — verify. The two boxes with the
    smallest ``augment`` are mapped to their nearest table.

    Returns:
        ``(tabla_a, tabla_b)``. An element stays ``None`` if it could not be
        resolved; the error is printed and the line is skipped.
    """
    # Local import: `literal_eval` was otherwise only imported in a later cell.
    from ast import literal_eval

    tabla_a = None
    tabla_b = None
    try:
        cardinalidades = [literal_eval(c) for c in line.split("|")]
        cardinalidades_dist = {str(c[0]): c[1] for c in cardinalidades}
        # Top 2 boxes with the smallest augment.
        cardinalidades = sorted(cardinalidades_dist, key=cardinalidades_dist.get)[:2]
        # NOTE(review): the boxes are passed on in their string form (`str(c[0])`);
        # confirm `nearest_tabla_from_cardinalidad` expects that.
        tabla_a = nearest_tabla_from_cardinalidad(cardinalidades[0], tablas)
        tabla_b = nearest_tabla_from_cardinalidad(cardinalidades[1], tablas)
    except Exception as e:
        print(f"Error al separar tablas! {e}. Chequear las bounding boxes pasadas. Salteando..")
    # Returning after the try (not inside `finally`) lets KeyboardInterrupt and
    # other BaseExceptions propagate instead of being silently swallowed.
    return (tabla_a, tabla_b)

In [None]:
def get_pairs(boxes_tablas, boxes_cardinalidades, img, **kwargs):
    """Return the ``(tabla_a, tabla_b)`` pairs connected by a detected line.

    Args:
        boxes_tablas: Table boxes (``.detach().numpy()`` is called on them).
        boxes_cardinalidades: Cardinality boxes, same container type.
        img: Image to detect the lines on.
        **kwargs: Forwarded to ``find_lines`` (e.g. ``plot``).
    """
    tablas = boxes_tablas.detach().numpy().astype(int)
    cardinalidades = boxes_cardinalidades.detach().numpy().astype(int)
    detected_lines = find_lines(img=img, tablas=tablas, cardinalidades=cardinalidades, **kwargs)

    separated = (sep_line(line, tablas) for line in detected_lines.values())
    # NOTE(review): truthiness test — fails on `None`, but would also drop a
    # falsy table value and is ambiguous for multi-element NumPy arrays;
    # confirm what `nearest_tabla_from_cardinalidad` returns.
    return [(a, b) for a, b in separated if a and b]

In [None]:
from ast import literal_eval

In [None]:
# Table pairs linked by a detected relationship line (no plotting).
conexiones = get_pairs(
    tablas_boxes,
    cardinalidades_boxes,
    img=img,
    plot=False,
)

In [None]:
# Inspect the detected connections (shown as the cell output).
conexiones

In [None]:
# NOTE(review): the global `pairs` is only assigned further down, in the
# key-resolution cell (`pairs = pairs_to_names(...)`); running the notebook
# top to bottom makes this cell raise NameError.
pairs

## OCR

In [None]:
# English OCR model (detection algorithm "db", recognition algorithm "svtr").
ocr = get_ocr_model(
    det_algo="db",
    rec_algo="svtr",
    lang="en",
)
# English lemmatizer; `is_many_to_many` reads this global.
lemmatizer = get_lemmatizer(lang="en")

In [None]:
# Integer NumPy copies of the table boxes for the OCR step.
tablas_boxes_int = tablas_boxes.detach().numpy().astype(int)
all_tables, tables_names = predict_ocr(
    img=img,
    tablas=tablas_boxes_int,
    ocr_model=ocr,
    scale_percent=100,
)

In [None]:
# print(generate_db(pairs=conexiones, all_tables=all_tables, tables_names=tables_names))

In [None]:
# !pip install lemminflect

In [None]:
import lemminflect

In [None]:
def get_plural(word, lemmatizer):
    """Return the plural form (Penn tag ``NNS``) of ``word``.

    Only the first token produced by ``lemmatizer`` is inflected. The
    ``._.inflect`` token extension comes from the ``lemminflect`` package.
    """
    first_token = lemmatizer(word)[0]
    plural = first_token._.inflect('NNS')
    return plural

In [None]:
def is_many_to_many(table, table_candidates, tables_names, pairs):
    """Return True if ``table`` looks like a many-to-many (junction) table.

    A table qualifies when at least two *distinct* other tables meet both
    conditions:

    - their name, or its plural from ``get_plural``, is contained in
      ``table``'s name;
    - they are connected to ``table`` in ``pairs``.

    Args:
        table: Name of the table to check.
        table_candidates: Unused; kept for interface compatibility.
        tables_names: Names of all tables in the diagram.
        pairs: ``(name_a, name_b)`` relationships between tables.

    Note:
        Uses the module-level ``lemmatizer`` global for pluralization.
    """
    # Other tables whose (possibly pluralized) name appears inside this name.
    # Normally there are exactly two, but more are tolerated.
    matches = {
        t for t in tables_names
        if t != table and ((t in table) or (get_plural(t, lemmatizer) in table))
    }

    # A set, so the same relationship detected twice counts only once.
    confirmed_matches = set()
    for pair in pairs:
        if table not in pair:
            continue
        # A table named inside the m2m name that is also related to it.
        if pair[0] in matches:
            confirmed_matches.add(pair[0])
        elif pair[1] in matches:
            confirmed_matches.add(pair[1])
        # Two distinct related tables named in this one confirm the m2m.
        if len(confirmed_matches) >= 2:
            return True
    return False

In [None]:
def generate_valid_combs_fk(table, lemmatizer):
    '''
    Return the foreign-key column names that may refer to ``table``.

    The bases are the table name, its lemma, and its plural (from
    ``get_plural``). Each base is combined with the suffixes ``"id"`` and
    ``"_id"``. This also covers a singular table whose key uses the plural
    form, and the reverse.

    Returns:
        The distinct combinations in a fixed order: name, then lemma, then
        plural, each with ``"id"`` before ``"_id"``. The order matters because
        ``match_fk`` and ``match_autofk`` take the first combination that matches.
    '''
    table_lemmatized = lemmatizer(table)[0].lemma_
    table_unlemmatized = get_plural(table, lemmatizer)
    valid_combs = [
        base + suffix
        for base in (table, table_lemmatized, table_unlemmatized)
        for suffix in ("id", "_id")
    ]
    # `list(set(...))` produced a hash-dependent order (string hashing is
    # randomized per process), so which key matched first could change between
    # runs. `dict.fromkeys` removes duplicates while keeping insertion order.
    return list(dict.fromkeys(valid_combs))

In [None]:
def match_fk(table_pair, table_candidates, table_pair_candidates, lemmatizer):
    '''
    Plain match between two tables using the name variants of ``table_pair``.

    Looks for a column of this table named after ``table_pair`` (e.g.
    ``users_id``). The referenced column in ``table_pair`` is either one of
    those same variants or plain ``"id"``.

    Returns:
        ``(True, fk_column, referenced_column)`` on the first match, otherwise
        ``(False, "", "")``. On a match the FK column is removed from
        ``table_candidates`` in place.
    '''
    fk_options = generate_valid_combs_fk(table_pair, lemmatizer)
    referenced_options = fk_options + ["id"]

    for fk_name in fk_options:
        if fk_name not in table_candidates:
            continue
        for referenced_name in referenced_options:
            if referenced_name in table_pair_candidates:
                # The column is a FK, so drop it from this table's candidates.
                table_candidates.remove(fk_name)
                return (True, fk_name, referenced_name)
    return (False, "", "")

In [None]:
def match_autofk(table, table_candidates, lemmatizer):
    '''
    Look for a self-referencing foreign key inside ``table``.

    The FK is the first attribute that mentions the table name (original,
    lemma or plural) without being one of the table's own key names. The
    referenced key is the first key-name combination (or ``"id"``) present in
    the attributes. If no attribute mentions the table, the single leftover
    non-key attribute is accepted as the FK.

    Returns:
        ``(True, fk, pk)`` when a self reference is found, otherwise
        ``(False, "", "")``.
    '''
    valid_combs = generate_valid_combs_fk(table, lemmatizer) + ["id"]
    table_lemmatized = lemmatizer(table)[0].lemma_
    table_unlemmatized = get_plural(table, lemmatizer)
    name_variants = (table, table_lemmatized, table_unlemmatized)

    fk = next(
        (candidate for candidate in table_candidates
         if any(variant in candidate for variant in name_variants)
         and candidate not in valid_combs),
        None,
    )
    pk = next((comb for comb in valid_combs if comb in table_candidates), None)

    if not pk:
        return (False, "", "")
    if fk:
        return (True, fk, pk)
    left_candidates = get_unchosen(table_candidates, valid_combs)
    if len(left_candidates) == 1:
        return (True, left_candidates[0], pk)
    return (False, "", "")

In [None]:
def match_m2m(table, table_pair, table_candidates, table_pair_candidates, lemmatizer):
    '''
    Match a many-to-many table against a regular table.

    If exactly one attribute is shared between both tables, that attribute is
    the link. Otherwise fall back to the regular name-based rule (``match_fk``).
    '''
    shared = [attr for attr in table_candidates if attr in table_pair_candidates]
    if len(shared) != 1:
        # Zero or several shared attributes: use the valid name combinations.
        return match_fk(table_pair, table_candidates, table_pair_candidates, lemmatizer)
    # Exactly one shared attribute. It is not removed from the m2m table's
    # candidates because it is also part of its primary key.
    only_shared = shared[0]
    return (True, only_shared, only_shared)

In [None]:
def is_foreign_key(table, table_pair, table_candidates, table_pair_candidates, lemmatizer, m2m_tables,
                   is_auto_fk=False):
    '''
    Decide whether ``table`` holds a foreign key towards ``table_pair``.

    Handles three cases:

    - self references (``is_auto_fk``);
    - relationships where ``table`` is a many-to-many table;
    - plain relationships between two regular tables.

    Returns:
        ``(found, table_attribute, referenced_attribute)``.
    '''
    if is_auto_fk:
        return match_autofk(table, table_candidates, lemmatizer)
    if table_pair == table:
        # Without the auto-FK flag a table is never matched against itself.
        return (False, "", "")
    if table in m2m_tables:
        return match_m2m(table, table_pair, table_candidates, table_pair_candidates, lemmatizer)
    return match_fk(table_pair, table_candidates, table_pair_candidates, lemmatizer)

In [None]:
def get_foreign_keys(table, all_candidates, pairs, m2m_tables, lemmatizer=None, check_auto_fks=False):
    """
    Resolve the foreign keys of ``table`` from the relationships in ``pairs``.

    Example:
    table -> poems
    candidates -> ['poems_id', 'users_id', 'categories_id']
    pairs -> [('tokens', 'users'), ('poems', 'users'), ('poems', 'categories')]

    Returns:
        ``(fks, completed_pairs)``. ``fks`` maps ``(table_attribute,
        referenced_attribute)`` to the referenced table. ``completed_pairs``
        lists the pairs that were resolved. When no ``lemmatizer`` is given, a
        default one is loaded with ``get_lemmatizer()``.
    """
    if not lemmatizer:
        lemmatizer = get_lemmatizer()
    fks = {}
    completed_pairs = []
    table_candidates = all_candidates[table]

    for pair in filter(lambda p: table in p, pairs):
        # Self-loops are only treated as auto-FKs when explicitly requested.
        is_auto_fk = pair[0] == pair[1] and bool(check_auto_fks)
        sides = (pair[0], pair[1])
        results = [
            is_foreign_key(table=table, table_pair=side,
                           table_candidates=table_candidates,
                           table_pair_candidates=all_candidates[side],
                           lemmatizer=lemmatizer, is_auto_fk=is_auto_fk,
                           m2m_tables=m2m_tables)
            for side in sides
        ]
        # The first side that yields a FK wins.
        for side, (is_fk, table_att, pair_att) in zip(sides, results):
            if is_fk:
                fks[(table_att, pair_att)] = side
                completed_pairs.append(pair)
                break
    return fks, completed_pairs

In [None]:
# Step-by-step key resolution (the same three passes as `generate_db`).
pairs = pairs_to_names(conexiones, tables_names)
# Load the English lemmatizer once. It used to be re-created for every table,
# and `get_foreign_keys` loaded another one per call because none was passed.
lemmatizer_en = get_lemmatizer(lang="en")

# First pass: extract the key candidates and flag many-to-many tables.
all_candidates = {}
m2m_tables = []
for k, dict_attributes in all_tables.items():
    candidates = extract_candidate_keys(dict_attributes.keys(), [t for t in tables_names if t != k], lang="en")
    all_candidates[k] = candidates
    if is_many_to_many(k, candidates, tables_names, pairs):
        m2m_tables.append(k)

all_tables_pks = {}
all_tables_fks = {}

# Second pass: resolve every relationship except self-referencing (auto) FKs.
for k in all_tables.keys():
    pks = {pk: k for pk in initial_guess_primary_keys(k, all_candidates[k], lemmatizer_en)}
    fks, completed_pairs = get_foreign_keys(table=k, all_candidates=all_candidates, pairs=pairs,
                                            m2m_tables=m2m_tables, lemmatizer=lemmatizer_en,
                                            check_auto_fks=False)
    # Presumably drops the resolved pairs so later tables do not reuse them.
    pairs = get_unchosen(pairs, completed_pairs)
    pks = {**pks, **{pk: k for pk in get_unchosen(all_candidates[k], fks.keys())}}
    all_tables_pks[k] = pks
    all_tables_fks[k] = fks

# Third pass: fill in the self-referencing FKs.
for k in all_tables.keys():
    fks, completed_pairs = get_foreign_keys(table=k, all_candidates=all_candidates, pairs=pairs,
                                            m2m_tables=m2m_tables, lemmatizer=lemmatizer_en,
                                            check_auto_fks=True)
    pairs = get_unchosen(pairs, completed_pairs)
    if fks:
        all_tables_fks[k] = {**all_tables_fks[k], **fks}

In [None]:
# Report the relationships whose keys could not be resolved.
for unresolved in pairs:
    print(
        f"Relationship between {unresolved} could not be established.",
        "Please check that the attributes are in the correct format.\n",
        sep="\n",
    )

In [None]:
def generate_pks_code(pks):
    """Return the ``PRIMARY KEY (...)`` clause for the given key columns.

    Args:
        pks: Mapping whose keys are the primary-key column names.

    Returns:
        ``"PRIMARY KEY (a, b)"``, or an empty string when there are no keys.
        ``create_code`` only appends a truthy clause, so this omits the invalid
        ``PRIMARY KEY ()`` that used to be emitted.
    """
    if not pks:
        return ""
    return f"PRIMARY KEY ({', '.join(pks)})"


def generate_fks_code(table, fks):
    """Return one ``ALTER TABLE ... ADD FOREIGN KEY`` statement per foreign key.

    Args:
        table: Table that owns the foreign-key columns.
        fks: Mapping ``(fk_column, referenced_column) -> referenced_table``.
    """
    statements = [
        f"ALTER TABLE {table} ADD FOREIGN KEY ({key[0]}) REFERENCES {referenced}({key[1]}); \n"
        for key, referenced in fks.items()
    ]
    return "".join(statements)


def create_code(table, dict_attributes, primary_keys, foreign_keys):
    '''
    Build the MySQL ``CREATE TABLE`` statement for one table.

    Args:
        table: Table name.
        dict_attributes: ``{column_name: column_type}`` strings in column order
            (presumably SQL types from the OCR step).
        primary_keys: Mapping whose keys are the primary-key columns.
        foreign_keys: Mapping ``(fk_column, referenced_column) -> referenced_table``.

    Returns:
        ``(create_table_sql, foreign_keys_sql)``. The FK statements are returned
        separately so the caller can emit them after all tables.
    '''
    columns = [name + " " + column_type for name, column_type in dict_attributes.items()]
    pks_code = generate_pks_code(primary_keys)
    fks_code = generate_fks_code(table, foreign_keys)
    if pks_code:
        columns.append(pks_code)
    # Joining (rather than appending ",\n   " after every column) avoids a
    # trailing comma, which is invalid SQL when there is no PRIMARY KEY clause.
    attributes_code = "  " + ",\n   ".join(columns)
    code = f" CREATE TABLE {table} ( \n {attributes_code} \n ); \n"
    return code, fks_code

In [None]:
def generate_db(pairs, all_tables, tables_names, lang):
    """Generate the MySQL DDL for the detected tables and relationships.

    Args:
        pairs: Connected table pairs as returned by ``get_pairs``; converted
            to table names with ``pairs_to_names``.
        all_tables: ``{table_name: {attribute: type}}`` from the OCR step.
        tables_names: Names of all tables.
        lang: Language used for candidate extraction and the lemmatizer.
            It was previously ignored, with ``"en"`` hard-coded.

    Returns:
        All ``CREATE TABLE`` statements, a newline, then all
        ``ALTER TABLE ... ADD FOREIGN KEY`` statements.
    """
    pairs = pairs_to_names(pairs, tables_names)
    # Load the lemmatizer once. It used to be created for every table, and
    # `get_foreign_keys` loaded another one per call because none was passed.
    lang_lemmatizer = get_lemmatizer(lang=lang)

    # First pass: extract the key candidates and flag many-to-many tables.
    # NOTE(review): `is_many_to_many` uses the module-level `lemmatizer`.
    all_candidates = {}
    m2m_tables = []
    for k, dict_attributes in all_tables.items():
        other_tables = [t for t in tables_names if t != k]
        candidates = extract_candidate_keys(dict_attributes.keys(), other_tables, lang=lang)
        all_candidates[k] = candidates
        if is_many_to_many(k, candidates, tables_names, pairs):
            m2m_tables.append(k)

    # Second pass: resolve every relationship except self-referencing FKs.
    all_tables_pks = {}
    all_tables_fks = {}
    for k in all_tables:
        pks = {pk: k for pk in initial_guess_primary_keys(k, all_candidates[k], lang_lemmatizer)}
        fks, completed_pairs = get_foreign_keys(table=k, all_candidates=all_candidates, pairs=pairs,
                                                m2m_tables=m2m_tables, lemmatizer=lang_lemmatizer,
                                                check_auto_fks=False)
        pairs = get_unchosen(pairs, completed_pairs)
        pks = {**pks, **{pk: k for pk in get_unchosen(all_candidates[k], fks.keys())}}
        all_tables_pks[k] = pks
        all_tables_fks[k] = fks

    # Third pass: add self-referencing FKs and emit the SQL.
    tables_code = []
    fks_code = []
    for k, dict_attributes in all_tables.items():
        fks, completed_pairs = get_foreign_keys(table=k, all_candidates=all_candidates, pairs=pairs,
                                                m2m_tables=m2m_tables, lemmatizer=lang_lemmatizer,
                                                check_auto_fks=True)
        pairs = get_unchosen(pairs, completed_pairs)
        if fks:
            all_tables_fks[k] = {**all_tables_fks[k], **fks}

        code, fk_code = create_code(k, dict_attributes,
                                    primary_keys=all_tables_pks[k],
                                    foreign_keys=all_tables_fks[k])
        tables_code.append(code)
        fks_code.append(fk_code)
    return "".join(tables_code) + "\n" + "".join(fks_code)

In [None]:
# Final output: the generated MySQL script, shown as the cell result.
generate_db(
    conexiones,
    all_tables,
    tables_names,
    lang="en",
)