In [1]:
"""
Contains parsers for
- Dynophore JSON file, which contains statistics on the occurrences and distances of superfeatures
  and environmental partners
- Dynophore PML file, which contains 3D coordinates for points in superfeature
  point clouds
"""

import xml.etree.ElementTree as ET
import numpy as np
import ipynb_importer
import compute
import dynophores as dyno
from matplotlib import colors
import ipynb_importer
import write


def extract_coordinates(dyno_dict, feature_key):
    """Return the xyz trajectory of one superfeature as a 2D NumPy array.

    Points are read in order from ``dyno_dict[feature_key]["points"]``. Whenever
    several consecutive points carry the same ``frame_ix``, only the first of
    them is kept, so each row corresponds to one run of a frame index.

    An empty point list yields an empty array of shape ``(0,)``.
    """
    rows = []
    previous_frame = None
    for point in dyno_dict[feature_key]["points"]:
        frame = point["frame_ix"]
        if frame != previous_frame:
            rows.append([point["x"], point["y"], point["z"]])
            previous_frame = frame
    return np.asarray(rows)


def extract_time(dyno_dict, feature_key):
    """Return the frame indices of one superfeature as a 1D NumPy array.

    A point is kept only when its ``frame_ix`` differs from that of the point
    immediately before it (the very first point is compared against ``None``).
    Consecutive duplicates are therefore collapsed to one entry, the same rule
    ``extract_coordinates`` applies, so both arrays line up row by row.
    """
    frame_ids = [point["frame_ix"] for point in dyno_dict[feature_key]["points"]]
    preceding = [None] + frame_ids[:-1]
    kept = [frame for frame, before in zip(frame_ids, preceding) if frame != before]
    return np.asarray(kept)


def extract_norm(dyno_dict):
    '''Normalize xyz coordinates and add frame information.

    For every superfeature in ``dyno_dict``, the de-duplicated coordinates from
    ``extract_coordinates`` are min-max scaled using a single global minimum and
    maximum taken over all three axes together (not per axis).

    Parameters
    ----------
    dyno_dict : dict
        Superfeature dictionary as produced by ``pml_to_dict`` (its first
        return value).

    Returns
    -------
    dict
        Maps each superfeature key to a dict with
        - "points": normalized coordinates. Every value is 0 when all
          coordinates are identical.
        - "frames": frame indices from ``extract_time``
        - "non_norm": the raw coordinates
        Superfeatures without any points get a (3, 1) zero array as a
        placeholder for both "points" and "non_norm".
    '''
    data = {}
    for key in dyno_dict:
        points_tmp = extract_coordinates(dyno_dict, key)
        frames = extract_time(dyno_dict, key)
        # Use .size rather than comparing the array with []: that comparison
        # broadcasts elementwise and raises on current NumPy.
        if points_tmp.size:
            points_min, points_max = np.min(points_tmp), np.max(points_tmp)
            span = points_max - points_min
            if span:
                norm_points = (points_tmp - points_min) / span
            else:
                # All coordinates identical (e.g. a single point): avoid 0/0 -> NaN.
                norm_points = np.zeros_like(points_tmp, dtype=float)
            data[key] = {
                "points": norm_points,
                "frames": frames,
                "non_norm": points_tmp,
            }
        else:
            # Two distinct placeholder arrays, as before, so in-place changes
            # to one do not leak into the other.
            data[key] = {
                "points": np.array([0, 0, 0]).reshape(-1, 1),
                "frames": frames,
                "non_norm": np.array([0, 0, 0]).reshape(-1, 1),
            }
    return data


def pml_to_dict(pml_path, n_drop=0):
    """
    Parse the superfeature point clouds of a dynophore PML (XML) file.

    Parameters
    ----------
    pml_path : str or pathlib.Path
        Path to PML file. Anything accepted by ``xml.etree.ElementTree.parse``
        works, including open file objects.
    n_drop : int
        Number of leading frames to drop. Points whose ``frameIndex`` is below
        ``n_drop`` are discarded. The remaining frame indices are shifted down
        by ``n_drop``, so the first kept frame becomes frame 0.

    Returns
    -------
    tuple of (dict, xml.etree.ElementTree.ElementTree)
        The superfeature dictionary and the parsed XML tree.

        The dictionary maps each superfeature ID,
        ``"<name>[<involvedAtomSerials>]"``, to (key : value data type):
        - id : str
        - color : str  (the ``featureColor`` attribute)
        - center : numpy.array  (x3, y3, z3 of the ``position`` element)
        - points : list of dict, one per kept ``additionalPoint``
          - x : float
          - y : float
          - z : float
          - frame_ix : int  (original frameIndex minus n_drop)
          - weight : float
    """
    dynophore3d_xml = ET.parse(pml_path)
    dynophore3d_dict = {}

    # ElementTree.findall matches direct children of the root element only.
    for feature_cloud in dynophore3d_xml.findall("featureCloud"):
        superfeature_feature_name = feature_cloud.get("name")
        superfeature_atom_numbers = feature_cloud.get("involvedAtomSerials")
        superfeature_id = f"{superfeature_feature_name}[{superfeature_atom_numbers}]"

        center = feature_cloud.find("position")
        center_data = np.array(
            [
                float(center.get("x3")),
                float(center.get("y3")),
                float(center.get("z3")),
            ]
        )

        additional_points_data = []
        for additional_point in feature_cloud.findall("additionalPoint"):
            frame_ix = int(additional_point.get("frameIndex"))
            if frame_ix < n_drop:
                # The point belongs to one of the dropped leading frames.
                continue
            additional_points_data.append(
                {
                    "x": float(additional_point.get("x3")),
                    "y": float(additional_point.get("y3")),
                    "z": float(additional_point.get("z3")),
                    "frame_ix": frame_ix - n_drop,
                    "weight": float(additional_point.get("weight")),
                }
            )

        dynophore3d_dict[superfeature_id] = {
            "id": superfeature_id,
            "color": feature_cloud.get("featureColor"),
            "center": center_data,
            "points": additional_points_data,
        }

    return dynophore3d_dict, dynophore3d_xml


def pre_process(pml_path, n_drop=0, include_time=False):
    '''Data pre-processing work flow.

    The steps are:
    1. Parse the PML file with ``pml_to_dict``.
    2. Normalize each superfeature's coordinates with ``extract_norm``.
    3. Pass the result through ``compute.add_distance_mat``.
    4. Report the trajectory length.

    Parameters
    ----------
    pml_path : str or pathlib.Path
        Path to the dynophore PML file.
    n_drop : int
        Number of leading frames to drop (forwarded to ``pml_to_dict``).
    include_time : bool
        Forwarded unchanged to ``compute.add_distance_mat``.

    Returns
    -------
    tuple of (dict, dict)
        The pre-processed per-superfeature data and the raw superfeature
        dictionary returned by ``pml_to_dict``.
    '''
    dynophore_dict, _ = pml_to_dict(pml_path, n_drop=n_drop)
    data = extract_norm(dynophore_dict)
    data = compute.add_distance_mat(data, dynophore_dict, include_time=include_time)

    # Superfeatures with no remaining points (e.g. all dropped via n_drop) have
    # an empty frame array. Skip them instead of failing on frames[-1].
    last_frames = [max(entry["frames"]) for entry in data.values() if len(entry["frames"])]
    # Frame indices start at 0, so the trajectory spans (last index + 1) frames.
    n_frames = max(last_frames, default=-1) + 1
    print(f"Data pre-processed: {n_frames} frames in trajectory")

    return data, dynophore_dict


def get_wrap_data(data):
    '''
    Collect the clustered points of every superfeature for visualization.

    For each superfeature, in the iteration order of ``data``, the raw
    coordinates ``entry["non_norm"]`` get two extra columns appended:
    - the cluster labels from ``entry["clustering"]._labels.labels``
    - the superfeature's position in ``data``

    output:
     list with one array per superfeature, columns
     [x, y, z, (frame), label in each superfeature, superfeature_nr]
     NOTE(review): "(frame)" is present only if "non_norm" already carries a
     time column. As built by extract_norm it holds xyz only. Confirm.
    '''
    wrapped = []
    for feature_nr, entry in enumerate(data.values()):
        coords = entry["non_norm"]
        labels = entry["clustering"]._labels.labels
        feature_column = np.array([feature_nr] * len(coords))
        wrapped.append(np.column_stack((coords, labels, feature_column)))
    return wrapped


def get_feature_name(feature):
    """Return the part of a superfeature ID that precedes its first ``[``.

    For example, ``"HBA[1,2]"`` gives ``"HBA"``. Returns ``None`` when
    ``feature`` contains no ``[`` at all.
    """
    name, bracket, _ = feature.partition("[")
    return name if bracket else None
        

def get_env_partner_coord(feature_name, dynophore, pdb_path):
    """Return the mean (x, y, z) of all atoms referenced by a superfeature's environmental partners.

    Parsing steps:
    - Each environmental-partner key is expected to look like
      ``"<label>[<id>,<id>,...]"``.
    - The comma-separated integers between the first ``[`` and the final
      character are collected.
    - Each id is looked up in the structure returned by
      ``write.get_atoms_coord(pdb_path)``.

    NOTE(review): presumably those ids are atom serials that index directly
    into that structure. Confirm against ``write.get_atoms_coord``.

    Parameters
    ----------
    feature_name : str
        Key into ``dynophore.superfeatures``.
    dynophore : object
        Object exposing ``superfeatures[feature_name].envpartners`` (a mapping
        whose keys are the partner strings).
    pdb_path : str or pathlib.Path
        Structure file passed to ``write.get_atoms_coord``.

    Returns
    -------
    tuple of float
        Mean x, y and z. All three are NaN (with a NumPy RuntimeWarning) when
        no atom ids are found.
    """
    env_partners = dynophore.superfeatures[feature_name].envpartners

    atom_ids = []
    for partner in env_partners:
        _, bracket, serials = partner.partition("[")
        if bracket:
            # serials still ends with the closing "]"; drop it before splitting.
            atom_ids.extend(int(idx) for idx in serials[:-1].split(","))

    pdb_array = write.get_atoms_coord(pdb_path)
    # Stack once instead of re-allocating via np.vstack on every iteration.
    # The (0, 3) seed keeps the shape valid when atom_ids is empty.
    partner_coords = np.vstack([np.empty(shape=(0, 3))] + [pdb_array[idx] for idx in atom_ids])

    x, y, z = partner_coords.mean(axis=0)
    return (x, y, z)