# In [None]:
from sklearn.metrics import pairwise_distances
from scipy.signal import argrelmax, argrelmin
from sklearn.preprocessing import OneHotEncoder
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import MDAnalysis as mda

import ipynb_importer
import visualize, parsers


def add_distance_mat(data, dyno_dict, include_time = False):
    '''Attach a pairwise distance matrix to every entry of data listed in dyno_dict

       input:
           data: dict; each value is read through its "frames" entry, and the
                 entries named in dyno_dict also through "points"
           dyno_dict: dict whose keys select the entries of data to process
           include_time: if True, a frame-based distance is added to the
                 spatial distance and the sum is rescaled by its maximum
       output: the same dict, with data[key]["distances"] set for each key
       '''
    max_frame = max([x["frames"][-1] for x in data.values()])
    # avoid dividing by zero when the last frame index is 0
    frame_scale = max_frame if max_frame > 0 else 1

    for key in dyno_dict.keys():
        data_tmp = data[key]["points"]
        if len(data_tmp.shape) == 1:
            # reshape returns a new array, so the result has to be kept;
            # pairwise_distances needs a 2D (n_samples, n_features) input
            data_tmp = data_tmp.reshape(-1, 1)
        distances = pairwise_distances(data_tmp)

        # add time component in distance matrix
        if include_time:
            # why sqrt: flexible conformation shows displacement within local
            # time, not through the whole trajectory
            frames_norm = (np.sqrt(data[key]["frames"])/frame_scale).reshape(-1,1)
            frames_distance = pairwise_distances(frames_norm)
            distances = distances + frames_distance
            largest = np.max(distances)
            # an all-zero matrix (e.g. a single point) stays zero instead of NaN
            if largest > 0:
                distances = distances/largest

        data[key]["distances"] = distances
    return data


def _predict_cluster_params(distances, radius_multi):
    '''Estimate clustering parameters from the distribution of distances

       output: (params, min_freq, n_cluster); params is None when no
       prediction is made (no interior peak, or more than 4 peaks)
       '''
    frame_nr = len(distances[0])
    n, _ = np.histogram(distances, 20)
    peaks = argrelmax(n)[0]
    valleys = argrelmin(n)[0]

    # predict min cluster number based on distance distribution
    n_cluster = len(peaks)
    # cluster > 4 could result from noisy data, so not predicted;
    # without any peak the mean peak width would be NaN
    if n_cluster == 0 or n_cluster > 4:
        return None, 0, 0

    interval = (np.max(distances) - np.min(distances))/20

    # relative frequency on a 5x finer grid (100 bins), read at the valleys
    fine_freq = np.histogram(distances.flatten(), bins=100)[0]/distances.size
    valley_freq = fine_freq[valleys*5]
    min_freq = round(min(valley_freq), 3) if len(valley_freq) else 0

    max_ = peaks*interval
    min_ = np.insert(valleys*interval, 0, 0)

    # predict radius based on average peak width/2
    min_len = min(len(max_), len(min_))
    radius = round(np.mean(max_[:min_len] - min_[:min_len])*radius_multi, 3)

    # predict cnn_cutoff based on points number: 5%*points number
    if min_freq != 0:
        cnn_cutoff = int(round(min(frame_nr*min_freq, frame_nr*0.05, 15)))
    else:
        cnn_cutoff = int(round(min(frame_nr*0.05, 15)))

    params = {
        "radius_cutoff": radius,
        "cnn_cutoff": cnn_cutoff,
        "member_cutoff": 10
    }
    return params, min_freq, n_cluster


def add_auto_param(data, dyno_dict, include_time = False, info_table = True, 
                   frequency_cutoff = 0.06, plot = True):
    '''Preliminary prediction on radius_cutoff and cnn_cutoff
       based on data distribution

       input:
           data: dict; every value needs "frames", and each processed entry
                 needs "points" (or "distances" when include_time is True)
           dyno_dict: dict whose keys select the entries of data to process
           include_time: read the precomputed data[key]["distances"] instead
                 of computing pairwise distances from "points"
           info_table: print a small summary table per key
           frequency_cutoff: entries with fewer points than
                 frequency_cutoff*max_frame get the default parameters
           plot: show the distance histogram with the chosen radius
       output: dict with new keys
           ["params"] = {
                        "radius_cutoff": radius,
                        "cnn_cutoff": cnn_cutoff,
                        "member_cutoff": member_cutoff
                    }
           ["min_cluster_n"] = number of histogram peaks (0 if not predicted)

       When no prediction is made for a frequent entry, existing "params"
       are kept; if there are none, the defaults are used.
       '''
    default_cluster_params = {
        "radius_cutoff": 0.2,
        "cnn_cutoff": 5,
        "member_cutoff": 10
    }

    if not dyno_dict:
        return data

    # loop invariants
    max_frame = max([x["frames"][-1] for x in data.values()])
    radius_multi = 0.6 if include_time else 0.8

    for i, key in enumerate(dyno_dict.keys()):
        if include_time:
            distances = data[key]["distances"]
        else:
            distances = pairwise_distances(data[key]["points"])

        min_n_cluster = 0
        min_freq = 0
        frame_nr = len(distances[0])

        if frame_nr >= frequency_cutoff*max_frame:  # TODO: depends on frequency
            predicted, min_freq, min_n_cluster = _predict_cluster_params(distances, radius_multi)
            if predicted is not None:
                data[key]["params"] = predicted
            else:
                # each entry gets its own copy so later edits stay local
                data[key].setdefault("params", dict(default_cluster_params))
        else:
            data[key]["params"] = dict(default_cluster_params)

        data[key]["min_cluster_n"] = min_n_cluster
        params = data[key]["params"]

        if plot:
            value = params["radius_cutoff"]
            plt.figure(figsize=(4,2))
            plt.hist(distances.flatten(), bins=100, color='y')
            plt.axvline(value)

            plt.title(f"{i}: {key}", fontsize = 6)
            plt.annotate(f"r = {round(value, 2)}", (0.05, 0.95), xycoords="axes fraction", fontsize=6)
            plt.show()

        # print info table with the parameters actually stored for this key
        if info_table:
            print(i, key)
            print ("{:<15} {:<15} {:<8} {:<15}".format('min_frequency','min_cluster','radius','cnn_cutoff'))
            print ("{:<15} {:<15} {:<8} {:<15}".format(min_freq, min_n_cluster, params["radius_cutoff"], params["cnn_cutoff"]))
            print("-"*50)

    return data

def get_state_matrix(data):
    '''Collect the cluster labels of every superfeature in a single ndarray

       Each entry of data contributes one column: the label recorded for a
       frame is placed at that frame's index, every other position up to the
       largest last-frame index found in data is filled with 0.
       output: ndarray with one row per frame and one column per entry of data
       '''
    last_frame = max(entry["frames"][-1] for entry in data.values())

    columns = []
    for entry in data.values():
        column = []
        for frame_id, label in zip(entry["frames"], entry["clustering"].labels):
            # fill the gap in front of this frame with zeros
            while len(column) < frame_id:
                column.append(0)
            column.append(label)
        # fill the tail up to and including the last frame
        while len(column) <= last_frame:
            column.append(0)
        columns.append(column)

    return np.asarray(columns).T


def get_one_hot_encoding(state_matrix):
    '''Transform state_matrix into one hot key matrix
       input: state_matrix, passed unchanged to OneHotEncoder.fit_transform
       output: dense ndarray
           frame | existence of state 0 in superfeature 1 | existence of state 1 in superfeature 1...
           1     | 1 (means exist)                        | 0 (means absence)
           2     ...
           3     ...
    '''
    try:
        # scikit-learn >= 1.2 names the keyword sparse_output
        encoder = OneHotEncoder(sparse_output = False)
    except TypeError:
        # older scikit-learn releases only know the keyword sparse
        encoder = OneHotEncoder(sparse = False)
    one_hot_matrix = encoder.fit_transform(state_matrix)
    return one_hot_matrix


def get_frames_each_cluster(pam, shift = 0):
    '''Group frame indices by the label assigned to them

       input: object exposing a labels_ array (e.g. a KMedoids object);
              shift is accepted but not used
       output: dict
           {binding state 0: ndarray of frame indices,
            binding state 1: ...}'''
    labels = pam.labels_
    n_states = np.max(labels) + 1

    frames_by_state = {}
    for state in range(n_states):
        frames_by_state[state] = np.where(labels == state)[0]
    return frames_by_state


def get_state_statistis(pam, data, dynophore_dict):
    '''Count, for every binding pose, how many frames each superfeature spends in each state
       output: dict
           e.g. {binding pose 0: {superfeature index 0: {feature state 0: count of frames, feature state 1: count of frames},
                                  superfeature index 1: {feature state 0: count of frames},
                                  superfeature index 2: ...}}'''
    matrix = get_state_matrix(data)
    frames_by_pose = get_frames_each_cluster(pam)

    statistics = {}
    for pose_idx, pose_frames in frames_by_pose.items():
        # rows of the state matrix that belong to this binding pose
        pose_rows = matrix[pose_frames]
        per_feature = {}
        for column_idx, _ in enumerate(dynophore_dict):
            column = pose_rows[:, column_idx]
            per_feature[column_idx] = {
                label: len(column[column == label])
                for label in np.unique(column)
            }
        statistics[pose_idx] = per_feature

    return statistics
    
    
def get_feature_freq_per_state(state_statistis, data):
    '''Turn per-pose state counts into a per-pose frequency for each superfeature

       For every superfeature the counts of all states greater than 0 are
       summed and divided by the largest last-frame index found in data.
       output: dict
           e.g. {binding pose 0: {superfeature 0: frequency,
                                  superfeature 1: frequency,
                                  superfeature 2: ...}}'''
    last_frame = max(entry["frames"][-1] for entry in data.values())

    return {
        pose_idx: {
            feature: sum(count for label, count in counts.items() if label > 0) / last_frame
            for feature, counts in per_feature.items()
        }
        for pose_idx, per_feature in state_statistis.items()
    }


def get_interact_summary(data, pam, dynophore_dict):
    '''Build a table of superfeature frequencies per binding pose

       input:
           data, pam, dynophore_dict: forwarded to get_state_statistis
           (data also to get_feature_freq_per_state)
       output: DataFrame with one column per binding pose plus a "feature"
               column holding the keys of dynophore_dict
       e.g.
                0         1         2         3  feature
    0    0.077521  0.130609  0.067970  0.290982  H[3187,3181,3178,3179,3185,3183]
    1    0.000000  0.000000  0.000000  0.000000  H[3146]
    2    0.249000  0.117503  0.166148  0.432474  H[3150,3152,3158,3156,3154,3149]'''
    # get_state_statistis builds the state matrix itself, so it is not
    # computed a second time here
    state_statistis = get_state_statistis(pam, data, dynophore_dict)
    feature_per_state = get_feature_freq_per_state(state_statistis, data)

    interact_summary = pd.DataFrame(feature_per_state)
    # hand pandas a plain list rather than a dict_keys view
    interact_summary["feature"] = list(dynophore_dict.keys())
    return interact_summary


def reduce_frames(pdb_path, dcd_path, select = "protein or chainID X", out_path = "output/reduced.dcd", final_n_frame = 500):
    '''Write a thinned-out copy of a trajectory to out_path

       input:
           pdb_path, dcd_path: topology and trajectory passed to mda.Universe
           select: selection string for the atoms to write
           out_path: output trajectory file
           final_n_frame: approximate number of frames to keep; every
                 (n_frames // final_n_frame)-th frame is written, so the
                 result can hold slightly more frames than requested
       output: None (the file is written as a side effect)
       '''
    u = mda.Universe(pdb_path, dcd_path)
    selection = u.select_atoms(select)

    # a trajectory shorter than final_n_frame is copied frame by frame
    # instead of producing a zero step
    step = max(1, len(u.trajectory)//final_n_frame)

    # the writer must be sized for the atoms that are actually written
    with mda.Writer(out_path, n_atoms = selection.n_atoms) as W:
        for ts in u.trajectory[::step]:
            W.write(selection)
            
            
def get_geo_center(data):
    '''Return the midpoint between the smallest and largest value of the first three columns

       input: 2D-indexable array, columns 0, 1 and 2 are read
       output: tuple (x, y, z)
       '''
    return tuple(
        (max(data[:, axis]) + min(data[:, axis]))/2
        for axis in range(3)
    )