In [None]:
import mdtraj
import numpy as np
import ipynb_importer
import visualize, compute, parsers
import xml.etree.ElementTree as et
import dynophores as dyno
from matplotlib import colors


def split_trajectory(pam, pdb_path, dcd_path, n_drop = 0, cluster_frames_map = None, directory = "./output/"):
    '''Write the frames belonging to each cluster into a separate DCD file.

    Parameters
    ----------
    pam : fitted clustering object
        Only ``pam.labels_`` is read here; the number of clusters is taken as
        ``max(labels_) + 1`` (assumes labels are 0..k-1 — TODO confirm).
    pdb_path : str
        Topology passed as ``top`` to ``mdtraj.load_dcd``.
    dcd_path : str
        Trajectory file to split.
    n_drop : int, optional
        Forwarded as ``shift`` to ``compute.get_frames_each_cluster`` when
        ``cluster_frames_map`` is not supplied.
    cluster_frames_map : mapping, optional
        Cluster index -> frame indices into the loaded trajectory. Computed
        from ``pam`` when omitted.
    directory : str, optional
        Directory the ``cluster_{k}.dcd`` files are written to (must exist).
        Defaults to ``./output/``.
    '''
    import os

    state_nr = np.max(pam.labels_) + 1
    # `is None` rather than `== None`: an array-like map would be compared
    # element-wise and make the `if` ambiguous.
    if cluster_frames_map is None:
        cluster_frames_map = compute.get_frames_each_cluster(pam, shift = n_drop)

    trajectory = mdtraj.load_dcd(dcd_path, top = pdb_path)

    # Slice and save cluster by cluster instead of materialising every
    # per-cluster sub-trajectory up front.
    for k in range(state_nr):
        trajectory[cluster_frames_map[k]].save_dcd(os.path.join(directory, f"cluster_{k}.dcd"))
        print(f"Successfully write cluster_{k}.dcd")
        
        
def get_atoms_coord(pdb_path):
    """Read the Cartesian coordinates of every ``ATOM`` record in a PDB file.

    Coordinates come from the fixed-width PDB columns 31-38 (x), 39-46 (y)
    and 47-54 (z). Other record types (``HETATM``, ``REMARK``, ...) are ignored.

    Parameters
    ----------
    pdb_path : str
        Path to the PDB file.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n_atoms, 3)``; shape ``(0, 3)`` when the file
        contains no ``ATOM`` record.
    """
    coords = []
    with open(pdb_path) as pdb_file:
        for line in pdb_file:
            if line.startswith("ATOM"):
                # Slice the raw line (0-based [30:38], [38:46], [46:54]). The y
                # and z fields begin at indices 38 and 46; starting one column
                # later would drop the sign of 8-character values like "-123.456".
                coords.append((float(line[30:38]), float(line[38:46]), float(line[46:54])))
    # Build the array once instead of re-stacking it for every atom.
    return np.asarray(coords, dtype=float).reshape(-1, 3)


def extract_superfeature(data, dynophore_dict, pdb_path, pml_path, dyno_path, include_time = False, frequency_cutoff = None):
    '''Turn clustered superfeature points into pharmacophore feature records.

    For each superfeature, every cluster whose point count reaches
    ``max_frame * frequency_cutoff[feature_type]`` becomes one feature record
    centred on the cluster's geometric centre (``compute.get_geo_center``).

    Parameters
    ----------
    data : sequence of 2-D numpy arrays
        Wrapped data (as produced by ``parsers.get_wrap_data``), one array per
        superfeature, in the same order as ``dynophore_dict``'s keys. Column
        ``col_nr`` (3, or 4 with ``include_time``) holds the cluster label;
        labels are assumed to be 0..k-1 — TODO confirm.
    dynophore_dict : dict
        Only its keys (the superfeature identifiers) are used.
    pdb_path : str
        Forwarded to ``parsers.get_env_partner_coord``.
    pml_path : str
        Unused; kept for interface compatibility.
    dyno_path : str
        Directory loaded with ``dyno.Dynophore.from_dir``.
    include_time : bool
        When True the label column is index 4 instead of 3 (presumably an
        extra time/frame column precedes it).
    frequency_cutoff : dict, optional
        Feature type -> minimum fraction of ``max_frame`` a cluster must reach.

    Returns
    -------
    list
        One ``[number, feature_type, "M", center_xyz, radius, env_partner_coord, 0., 1.]``
        entry per kept cluster. The ``pml_feature_*`` writers read index 3 as
        position, 4 and 6 as tolerances, 5 as partner coordinates and 7 as weight.
        Empty when ``data`` is empty.

    Raises
    ------
    KeyError
        If a feature type is missing from ``frequency_cutoff``.
    '''
    import warnings

    if frequency_cutoff is None:
        frequency_cutoff = {'H': 0.2, 'AR': 0.03, 'HA2': 0.2, 'HD2': 0.2, 'HDA': 0.2,
                            'HBD': 0.2, 'HBA': 0.2, 'HA': 0.2, 'HD': 0.2, 'HI': 0.2, 'PI': 0.2, 'NI': 0.2}
    if len(data) == 0:
        # Nothing to convert; max() below would raise on an empty sequence.
        return []

    col_nr = 4 if include_time else 3
    hbond_types = {'HD', 'HD2', 'HA', 'HA2', 'HBA', 'HBD'}

    dynophore = dyno.Dynophore.from_dir(dyno_path)
    superfeatures = list(dynophore_dict.keys())
    # NOTE(review): with include_time=False this reads column 2; confirm that
    # the column at col_nr - 1 really holds frame indices in that layout.
    max_frame = max(np.max(sf_data[:, col_nr - 1]) for sf_data in data)

    superfeature = []
    number = 1  # feature id; advances for every cluster, whether kept or not
    for i, superfeature_data in enumerate(data):
        cluster_nr = len(np.unique(superfeature_data[:, col_nr]))  # clusters in this superfeature
        feature = parsers.get_feature_name(superfeatures[i])
        env_coord = parsers.get_env_partner_coord(superfeatures[i], dynophore, pdb_path)
        # H-bond features get a tight tolerance; every other type uses 1.5.
        radius = 0.2 if feature in hbond_types else 1.5

        # Labels are visited from highest to lowest; the ids depend on this order.
        for label in reversed(range(cluster_nr)):
            cluster_data = superfeature_data[superfeature_data[:, col_nr] == label]
            if len(cluster_data) >= max_frame * frequency_cutoff[feature]:
                try:
                    center_coord = np.array(compute.get_geo_center(cluster_data))
                except Exception as err:
                    # Best effort: skip a cluster whose centre cannot be computed,
                    # but say so. Catching Exception (not a bare except) lets
                    # KeyboardInterrupt/SystemExit through.
                    warnings.warn(f"Skipping cluster {label} of superfeature {superfeatures[i]!r}: {err!r}")
                else:
                    superfeature.append([number, feature, "M", center_coord, radius, np.array(env_coord), 0., 1.])
            number += 1
    return superfeature

def pml_feature_point(pharmacophore, feature):
    """Attach a ``<point>`` child (used for H, NI and PI features) to *pharmacophore*.

    *feature* is read by position: index 0 is the id number, 1 the feature
    name, 3 the xyz position, 4 the position tolerance and 7 the weight.
    The point is always emitted as non-optional and enabled. Returns None.
    """
    feature_name = feature[1]
    point_attrib = {
        'name': feature_name,
        'featureId': f'{feature_name}_{feature[0]}',
        'optional': 'false',
        'disabled': 'false',
        'weight': str(feature[7]),
    }
    coords = [str(feature[3][axis]) for axis in range(3)]
    position_attrib = dict(zip(('x3', 'y3', 'z3'), coords))
    position_attrib['tolerance'] = str(feature[4])

    point = et.SubElement(pharmacophore, 'point', attrib=point_attrib)
    et.SubElement(point, 'position', attrib=position_attrib)


def pml_feature_plane(pharmacophore, feature):
    """Attach a ``<plane>`` child for an aromatic (AR) feature to *pharmacophore*.

    *feature* is read by position: index 0 is the id number, 3 the plane
    centre (xyz), 4 the position tolerance, 5 the environment-partner xyz,
    6 the normal tolerance and 7 the weight. ``feature[1]`` is not read: the
    plane is always named ``AR`` with featureId ``ai_<id>``.

    The normal is written as ``centre - partner``, component-wise and without
    normalisation.
    NOTE(review): this assumes ``feature[5]`` is a single xyz point; if it held
    several partner atoms the components would be arrays — confirm.
    Returns None.
    """
    centre = feature[3]
    partner = feature[5]

    plane_attributes = {
        'name': 'AR',
        'featureId': f'ai_{feature[0]}',
        'optional': 'false',
        'disabled': 'false',
        'weight': str(feature[7]),
    }
    position_attributes = {
        'x3': str(centre[0]),
        'y3': str(centre[1]),
        'z3': str(centre[2]),
        'tolerance': str(feature[4]),
    }
    normal_attributes = {
        'x3': str(centre[0] - partner[0]),
        'y3': str(centre[1] - partner[1]),
        'z3': str(centre[2] - partner[2]),
        'tolerance': str(feature[6]),
    }

    plane = et.SubElement(pharmacophore, 'plane', attrib=plane_attributes)
    et.SubElement(plane, 'position', attrib=position_attributes)
    et.SubElement(plane, 'normal', attrib=normal_attributes)



def pml_feature_vector(pharmacophore, feature):
    """Attach a ``<vector>`` child for a hydrogen-bond feature to *pharmacophore*.

    *feature* is read by position: index 0 is the id number, 1 the feature
    name, 3 the feature centre (xyz), 4 its tolerance, 5 the
    environment-partner xyz, 6 its tolerance and 7 the weight.

    Acceptor types (``HBA``, ``HA``, ``HA2``, ``HDA``; matched
    case-insensitively) are written as ``HBA`` with ``pointsToLigand="true"``
    and with origin/target (and their tolerances) swapped, so the vector
    starts at the partner. Every other type is written as an ``HBD`` vector
    from the feature centre to the partner. Returns None.
    """
    acceptor_types = {'HBA', 'HA', 'HA2', 'HDA'}
    feature_name = feature[1]

    # (xyz strings, tolerance string) for each end of the vector.
    centre = ([str(feature[3][axis]) for axis in range(3)], str(feature[4]))
    partner = ([str(feature[5][axis]) for axis in range(3)], str(feature[6]))

    # Feature names arrive upper-case (cf. the type lists in pml_feature), so
    # the acceptor test must ignore case; a lower-case-only comparison would
    # send 'HA'/'HA2' acceptors down the donor branch.
    if str(feature_name).upper() in acceptor_types:
        vector_name, points_to_ligand = 'HBA', 'true'
        (origin, origin_tolerance), (target, target_tolerance) = partner, centre
    else:
        vector_name, points_to_ligand = 'HBD', 'false'
        (origin, origin_tolerance), (target, target_tolerance) = centre, partner

    vector_attributes = {
        'name': vector_name,
        'featureId': f'{feature_name}_{feature[0]}',
        'pointsToLigand': points_to_ligand,
        'hasSyntheticProjectedPoint': 'false',
        'optional': 'false',
        'disabled': 'false',
        'weight': str(feature[7]),
    }
    vector = et.SubElement(pharmacophore, 'vector', attrib=vector_attributes)
    et.SubElement(vector, 'origin', attrib={'x3': origin[0], 'y3': origin[1], 'z3': origin[2],
                                            'tolerance': origin_tolerance})
    et.SubElement(vector, 'target', attrib={'x3': target[0], 'y3': target[1], 'z3': target[2],
                                            'tolerance': target_tolerance})


def pml_feature(pharmacophore, feature):
    """Hand *feature* to the XML writer matching its type name ``feature[1]``.

    H/NI/PI go to ``pml_feature_point``; HD/HD2/HA/HA2/HBA/HBD go to
    ``pml_feature_vector``; AR goes to ``pml_feature_plane``. Any other type
    is skipped without appending anything. Returns None.
    """
    routes = (
        (('H', 'NI', 'PI'), pml_feature_point),
        (('HD', 'HD2', 'HA', 'HA2', 'HBA', 'HBD'), pml_feature_vector),
        (('AR',), pml_feature_plane),
    )
    for type_names, writer in routes:
        if feature[1] in type_names:
            writer(pharmacophore, feature)
            break


def indent_xml(element, level=0):
    """Recursively set ``text``/``tail`` whitespace so *element* pretty-prints.

    Uses two spaces per nesting level. Existing non-whitespace text or tails
    are left untouched. The last child's tail is set to the parent's indent so
    the closing tag lines up with the opening one. Mutates the tree in place.
    """
    newline_indent = "\n" + "  " * level
    children = list(element)

    if not children:
        if level and not (element.tail and element.tail.strip()):
            element.tail = newline_indent
        return

    if not (element.text and element.text.strip()):
        element.text = newline_indent + "  "
    if not (element.tail and element.tail.strip()):
        element.tail = newline_indent
    for child in children:
        indent_xml(child, level + 1)
    # The recursion gave the last child a one-level-deeper tail; pull it back.
    last_child = children[-1]
    if not (last_child.tail and last_child.tail.strip()):
        last_child.tail = newline_indent


def pml_pharmacophore(features, directory, name):
    """Serialise *features* as a LigandScout pharmacophore (``.pml``) XML file.

    Each feature is dispatched through ``pml_feature``. The tree is then
    indented with ``indent_xml`` and written UTF-8 encoded, with an XML
    declaration, to ``os.path.join(directory, name)``. *name* is also used as
    the root's ``name`` attribute. *directory* is not created here. Returns None.
    """
    import os

    pharmacophore = et.Element('pharmacophore', attrib={'name': name, 'pharmacophoreType': 'LIGAND_SCOUT'})
    for feature in features:
        pml_feature(pharmacophore, feature)
    indent_xml(pharmacophore)
    # os.path.join avoids the doubled separator that '{}/{}' produced for a
    # directory ending in '/' (e.g. the caller's default "./output/").
    et.ElementTree(pharmacophore).write(os.path.join(directory, name), encoding="UTF-8", xml_declaration=True)


def pharmacophore_writer(data, dynophore_dict, pdb_path, pml_path, dyno_path, include_time = False, name = "cluster_pharmacophore", directory = "./output/"):
    """Build pharmacophore features from clustered data and save them as ``<name>.pml``.

    Pipeline: ``parsers.get_wrap_data(data)`` -> ``extract_superfeature`` ->
    ``pml_pharmacophore`` into *directory*. Prints a confirmation message
    afterwards. Returns None.
    """
    features = extract_superfeature(
        parsers.get_wrap_data(data),
        dynophore_dict,
        pdb_path,
        pml_path,
        dyno_path,
        include_time = include_time,
    )
    pml_pharmacophore(features, directory, f"{name}.pml")
    print("Pharmacophore successfully written into pml file.")