# In [None]:
from tqdm import tqdm
import numpy as np
from cnnclustering import cluster
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn_extra import cluster as skecluster
import ipynb_importer
import visualize, compute

def _check_fit_quality(clustering, min_cluster_n):
    '''Inspect the most recent fit of `clustering` and return why it needs a redo.

    Reads the last row of clustering.summary.to_DataFrame() and returns a list
    containing "n_clusters" when fewer than `min_cluster_n` clusters were found
    and/or "noise" when the noise ratio lies strictly between 0.05 and 1.
    An empty list means the fit is accepted.
    '''
    summary = clustering.summary.to_DataFrame()
    n_cluster_real = list(summary["n_clusters"])[-1]
    ratio_noise = list(summary["ratio_noise"])[-1]

    reasons = []
    if n_cluster_real < min_cluster_n:
        print(f"Fewer clusters than expected! Now: {n_cluster_real}  Expected: {min_cluster_n}")
        reasons.append("n_clusters")
    # a noise ratio of exactly 1 is not flagged
    if 0.05 < ratio_noise < 1:
        print("Noise over 5%!")
        reasons.append("noise")
    return reasons


def _print_redo_table(need_redo):
    '''Print a table (idx, key, reason) of the superfeatures flagged by do_cluster.'''
    if not need_redo:
        print("From computer's view, no cluster need manual parameter adjustment")
        return
    print("Following feature will be reclustered")
    # the key column is at least 35 characters wide and grows to fit the longest key
    key_width = max(35, max(len(key_) for key_ in need_redo))
    print(f"{'idx':<4} {'key':<{key_width}} {'reason':<15}")
    for key_, value_ in need_redo.items():
        print(f"{value_['idx']:<4} {key_:<{key_width}} {str(value_['reasons']):<15}")


def do_cluster(data, include_time = False, redo = None, plot = True, v = True):
    '''Cluster the data within each superfeature sequentially and flag doubtful fits.

    input:
        data: dict, superfeature key -> dict. A cnnclustering Clustering is stored
              under "clustering" (built from "distances" when include_time else from
              "points" if not present yet). "params" (optional, kwargs for fit) and
              "min_cluster_n" are read for every key that is (re)fitted.
        include_time: build the clustering from precomputed "distances" and use
              visualize.visualize_clustering instead of the 2-D overview grid
        redo: keys to (re)fit; None or empty means all keys
        plot: draw figures; when False nothing is plotted
        v: forwarded to Clustering.fit as `v`
    output: dict, dict
        data with fitted clustering; need_redo mapping key -> {"reasons": [...], "idx": i}
    '''
    # setup visualization params
    mpl.rcParams["figure.dpi"] = 120

    default_cluster_params = {
        "radius_cutoff": 0.2,
        "cnn_cutoff": 5,
        "member_cutoff": 10
    }
    need_redo = {}

    if not redo:
        redo = data.keys()

    # One row of 4 panels per superfeature: panels 0-1 show the original data,
    # panels 2-3 the evaluated clustering. Only drawn for point-based clustering
    # when plotting is requested; squeeze=False keeps Ax 2-D for a single row.
    draw_grid = plot and not include_time and len(data) > 0
    if draw_grid:
        fig_w, fig_h = mpl.rcParams["figure.figsize"]
        _, Ax = plt.subplots(len(data), 4,
                             figsize=(fig_w * 3, fig_h * len(data) * 1.2),
                             squeeze=False)

    for i, (fkey, data_) in enumerate(data.items()):
        if "clustering" not in data_:
            if include_time:
                data_["clustering"] = cluster.Clustering(data_["distances"], registered_recipe_key="distances")
            else:
                data_["clustering"] = cluster.Clustering(data_["points"])
        clustering = data_["clustering"]

        if draw_grid:
            for axi, d in enumerate([0, 1]):
                clustering.evaluate(ax=Ax[i, axi], dim=(d, d + 1), original=True)

        if fkey in redo:
            print(i, fkey)
            clustering.fit(**data_.get("params", default_cluster_params), v=v)
            if include_time and plot:
                visualize.visualize_clustering(data_, i, fkey)
            reasons = _check_fit_quality(clustering, data_["min_cluster_n"])
            if reasons:
                need_redo[fkey] = {"reasons": reasons, "idx": i}

        if draw_grid:
            for axi, d in enumerate([0, 1], 2):
                clustering.evaluate(ax=Ax[i, axi], dim=(d, d + 1))
            Ax[i, 0].annotate(f"{i}: {fkey}", (0.05, 0.95), xycoords="axes fraction", fontsize=10)

    _print_redo_table(need_redo)
    return data, need_redo

# def do_cluster(data, include_time = False, redo = None, plot = True, v = True):
#     '''cluster data within each superfeature sequently
#        output: dict, list
#                data with fitted cluster, superfeatures needed redo'''
    
#     default_cluster_params = {
#         "radius_cutoff": 0.2,
#         "cnn_cutoff": 5,
#         "member_cutoff": 10
#     }
#     need_redo = {}
    
#     if not redo:
#         redo = data.keys()
        
#     for i, (fkey, data_) in enumerate(data.items()):
#         if "clustering" not in data[fkey]:
#             if include_time:
#                 data[fkey]["clustering"] = cluster.Clustering(data_["distances"], registered_recipe_key="distances")
#             else:
#                 data[fkey]["clustering"] = cluster.Clustering(data_["points"])
        
#         if fkey in redo:
#             print(i, fkey)
#             data[fkey]["clustering"].fit(**data_.get("params", default_cluster_params), v=v)
        
#         # plot
#             if plot:
#                 visualize.visualize_clustering(data_, i, fkey)
#             n_cluster_real = list(data[fkey]["clustering"].summary.to_DataFrame()["n_clusters"])[-1]
#             ratio_noise = list(data[fkey]["clustering"].summary.to_DataFrame()["ratio_noise"])[-1]
        
#             reasons = []
#             if n_cluster_real < data[fkey]['min_cluster_n']:
#                 print(f"Fewer clusters than expected! Now: {n_cluster_real}  Expected: {data[fkey]['min_cluster_n']}")
#                 reasons.append("n_clusters")
#             if ratio_noise > 0.05 and ratio_noise < 1:
#                 print("Noise over 5%!")
#                 reasons.append("noise")
#             if len(reasons):
#                 need_redo[fkey] = {"reasons": reasons, "idx": i}
        
#     return data, need_redo


def parameter_scan(data, key, r_start = 0.05, r_end = 0.3, r_step = 0.05, c_statr = 0, c_end = 50, c_step = 2, include_time = False):
    '''Grid-search radius_cutoff / cnn_cutoff for the data of one superfeature.

    Every combination of r in np.arange(r_start, r_end, r_step) and
    c in np.arange(c_statr, c_end, c_step) is fitted, keeping member_cutoff at
    the feature's current value. Among all fits recorded in the clustering
    summary, those with exactly data[key]["min_cluster_n"] clusters and a
    non-zero noise ratio are ranked by (ratio_noise, radius_cutoff). The best
    one is accepted if its noise ratio is below 0.05.

    input:
        data: dict, superfeature key -> dict with "params", "min_cluster_n" and
              "clustering" (created from "distances"/"points" if missing)
        key: superfeature to scan
        c_statr: start of the cnn_cutoff range (name kept for compatibility)
        include_time: build a missing clustering from "distances" instead of "points"
    output: tuple (radius_cutoff, cnn_cutoff, member_cutoff)
        new parameters if a solution is found, else the current parameters.
        On success data[key]["params"] is updated in place and the clustering
        is refitted with the new parameters.
    '''
    feature = data[key]
    cluster_ = feature.get("clustering")
    if cluster_ is None:
        if include_time:
            cluster_ = cluster.Clustering(feature["distances"], registered_recipe_key="distances")
        else:
            cluster_ = cluster.Clustering(feature["points"])
        feature["clustering"] = cluster_

    params = feature["params"]
    min_cluster_n = feature["min_cluster_n"]
    orig_radius_cutoff = params["radius_cutoff"]
    orig_cnn_cutoff = params["cnn_cutoff"]
    orig_member_cutoff = params["member_cutoff"]

    for r in tqdm(np.arange(r_start, r_end, r_step)):
        for c in np.arange(c_statr, c_end, c_step):
            # every fit adds a row to cluster_.summary, which is evaluated below;
            # use the member_cutoff that is returned/kept so the result is reproducible
            cluster_.fit(r, c, member_cutoff=orig_member_cutoff, v=False)

    # candidates with the expected number of clusters, least noise first
    df = cluster_.summary.to_DataFrame()
    mask = (df.n_clusters == min_cluster_n) & (df.ratio_noise > 0)
    candidates = df[mask].sort_values(["ratio_noise", "radius_cutoff"])
    if candidates.empty:
        print(f"Cannot find solution for {key}")
        return (orig_radius_cutoff, orig_cnn_cutoff, orig_member_cutoff)

    best = candidates.iloc[0]
    radius_cutoff = float(best["radius_cutoff"])
    # row access on a mixed numeric frame can upcast cnn_cutoff to float
    cnn_cutoff = int(best["cnn_cutoff"])
    ratio_noise = best["ratio_noise"]

    # save the corrected parameters and refit the cluster object
    if ratio_noise < 0.05:
        params["radius_cutoff"], params["cnn_cutoff"] = radius_cutoff, cnn_cutoff
        cluster_.fit(radius_cutoff, cnn_cutoff, member_cutoff=orig_member_cutoff, v=True)
        print("Solution found with parameter scan")
        return (radius_cutoff, cnn_cutoff, orig_member_cutoff)

    print(f"Cannot find solution for {key}")
    return (orig_radius_cutoff, orig_cnn_cutoff, orig_member_cutoff)
    

def params_adjust(data, need_redo, include_time = False, repeat = 10, plot = True, v = True):
    '''Adjust clustering parameters for the features in need_redo.

    Features flagged only for "noise" first get their cnn_cutoff lowered in
    steps of 2 (at most `repeat` times, stopping once cnn_cutoff < 2), re-running
    do_cluster for that feature after each step. If that does not help, and for
    every feature flagged for "n_clusters", parameter_scan is used instead.

    Side effects: data[key]["params"] is updated in place and keys that were
    fixed are removed from need_redo.

    output: tuple (list, dict)
        feature names which still need manual parameters, and data
    '''
    please_manual = []

    # iterate over a snapshot because fixed keys are deleted from need_redo
    for key, info in list(need_redo.items()):
        print("*"*100)
        print(f"Start working on {key}")

        params = data[key]["params"]
        # read by name: do not rely on the insertion order of the params dict
        radius_cutoff = params["radius_cutoff"]
        cnn_cutoff = params["cnn_cutoff"]
        member_cutoff = params["member_cutoff"]
        fix_status = False

        if info["reasons"] == ["noise"]:
            # only the noise problem: relax cnn_cutoff step by step
            attempt = 0
            print("Start adjusting parameters")
            while attempt < repeat and cnn_cutoff >= 2 and not fix_status:
                attempt += 1
                cnn_cutoff -= 2
                params["cnn_cutoff"] = cnn_cutoff
                _, need_redo_temp = do_cluster(data, include_time = include_time, plot = plot, redo = [key], v = v)
                if key not in need_redo_temp:
                    fix_status = True
                    new_radius_cutoff, new_cnn_cutoff = radius_cutoff, cnn_cutoff
            if not fix_status:
                print("Parameter adjustment failed. Start parameter scan")
                new_radius_cutoff, new_cnn_cutoff, _ = parameter_scan(data, key, include_time = include_time)
                # parameter_scan returns the current parameters unchanged when it fails
                fix_status = (new_radius_cutoff, new_cnn_cutoff) != (radius_cutoff, cnn_cutoff)
        else:
            # the number of clusters is unsatisfying: go straight to the parameter scan
            new_radius_cutoff, new_cnn_cutoff, new_member_cutoff = parameter_scan(data, key, include_time = include_time)
            fix_status = (new_radius_cutoff, new_cnn_cutoff, new_member_cutoff) != (radius_cutoff, cnn_cutoff, member_cutoff)

        if fix_status:
            del need_redo[key]
            params["radius_cutoff"], params["cnn_cutoff"] = new_radius_cutoff, new_cnn_cutoff
            print(f"Solution found for {key}")
        else:
            please_manual.append(key)
            print(f"Failed for {key}")
    return please_manual, data


def auto_cluster(data, dynophore3d_dict, include_time = False, only_result = True, info_table = True,
                 frequency_cutoff = 0.06, redo = None, plot_search_parameter = True,
                 plot_clustering = True, plot_params_adjust = False, repeat = 10, v = True):
    '''Run the full automatic clustering pipeline and report what is still unresolved.

    Steps: compute.add_auto_param -> do_cluster -> params_adjust (only if some
    feature was flagged) -> a final do_cluster over all features with plotting
    enabled to show the result.

    only_result: if True, switches off info_table, all intermediate plots and
                 verbose fitting, so only the final result is shown
    output: list
        feature names for which no parameters were found automatically
    '''
    if only_result:
        info_table = False
        plot_search_parameter = False
        plot_clustering = False
        plot_params_adjust = False
        v = False

    data = compute.add_auto_param(data, dynophore3d_dict, include_time = include_time,
                                  info_table = info_table, frequency_cutoff = frequency_cutoff, plot = plot_search_parameter)

    data, need_redo = do_cluster(data, include_time = include_time, redo = redo, plot = plot_clustering, v = v)
    print("1. Clustering attemp finished")

    # stays empty when the first attempt needs no adjustment
    please_manual = []
    if need_redo:
        please_manual, data = params_adjust(data, need_redo = need_redo, include_time = include_time,
                                            repeat = repeat, plot = plot_params_adjust, v=v)

    # use adjusted parameter to do the clustering -> show result
    # (same include_time mode as the clustering objects were built with)
    print("Result")
    data, need_redo = do_cluster(data, include_time = include_time, redo = None, plot = True, v = True)

    if please_manual:
        print("Please sepecify parameters for following features. No parameter found automatically.")
        print(please_manual)
    else:
        print("Clustering finished! You may adjust parameters for better result manually.")

    return please_manual


def get_binding_pose_cluster_inertia(one_hot_matrix, min_cluster = 2, max_cluster = 7):
    '''Plot k-medoids inertia against the number of clusters for one_hot_matrix.

    A PAM KMedoids model with Manhattan metric is fitted for every cluster
    number in range(min_cluster, max_cluster) (max_cluster itself is excluded),
    and the fitted models' inertia_ is plotted against the cluster number.

    input: ndarray
        frame | existence of state 0 in superfeature 1 | existence of state 1 in superfeature 1...
        1     | 1 (means exist)                        | 0 (means absence)
        2     ...
        3     ...
    '''
    cluster_numbers = range(min_cluster, max_cluster)

    def _fit_inertia(k):
        model = skecluster.KMedoids(n_clusters = k, metric = "manhattan", method = "pam")
        model.fit(one_hot_matrix)
        return model.inertia_

    inertia = [_fit_inertia(k) for k in tqdm(cluster_numbers)]

    _, axis = plt.subplots()
    axis.plot(cluster_numbers, inertia)
    axis.set_xlabel("#clusters")
    axis.set_ylabel("inertia: difference within each cluster")
    plt.show()
    
    
def get_feature_state_from_onehot_position(pos, states_per_interaction):
    '''Map a column position of the one-hot matrix back to (feature index, state).

    states_per_interaction[i] is the number of one-hot columns that feature i
    occupies; columns are laid out feature after feature. The first feature
    whose cumulative column range ends after `pos` is returned together with
    the offset of `pos` inside that range.
    NOTE(review): a position beyond the last column falls through and returns
    (last feature index, pos) unchanged; an empty list raises UnboundLocalError.
    '''
    lower = 0
    for feature_idx, n_states in enumerate(states_per_interaction):
        upper = lower + n_states
        if pos < upper:
            return (feature_idx, pos - lower)
        lower = upper
    return (feature_idx, pos)


def get_states_per_interaction(state_matrix):
    '''Return, per column of state_matrix, the column maximum plus one.

    With states encoded as 0..k-1 (assumed, TODO confirm against
    compute.get_state_matrix) this is the number of states per interaction.
    '''
    counts = []
    for column in state_matrix.T:
        counts.append(max(column) + 1)
    return counts


def get_center_features(data, state_matrix, pam):
    '''Get prominent features for cluster centers and print them.

       input: dict, ndarray, fitted KMedoids
           data: superfeature key -> dict; the key order is used to name the
                 state_matrix columns (assumed to match, TODO confirm)
           state_matrix: matrix whose columns give the states per superfeature
           pam: fitted model; every row of pam.cluster_centers_ is a one-hot
                vector, and positions equal to 1 are reported
        output: dict
            {binding_state_nr: {feature_name: {'state': state_within_feature, 'idx': superfeature_idx}}}
            e.g.
            {0: {'H[3187,3181,3178,3179,3185,3183]': {'state': 0, 'idx': 0}}}
    '''
    features = list(data.keys())
    # depends only on state_matrix, so compute it once instead of per center
    states_per_interaction = get_states_per_interaction(state_matrix)

    center_features = {}
    for i, center in enumerate(pam.cluster_centers_):
        print(i, ":")
        features_tmp = {}
        for pos in np.where(center == 1)[0]:
            feature_index, state = get_feature_state_from_onehot_position(pos, states_per_interaction)
            features_tmp[features[feature_index]] = {"state": state, "idx": feature_index}
            print(f"    {features[feature_index]:>40} state {state:<10}")
        center_features[i] = features_tmp
    return center_features


def binding_state_cluster(data, n_clusters = None):
    '''Interactive function to do one-hot clustering for recognizing binding poses.

       The state matrix of `data` is one-hot encoded and clustered with a PAM
       KMedoids model (Manhattan metric). If n_clusters is None, the inertia
       plot is shown and the cluster number is read from stdin. The prominent
       features of every cluster center are printed.

       input: dict, int or None
       output: fitted KMedoids object
    '''
    state_matrix = compute.get_state_matrix(data)
    one_hot_matrix = compute.get_one_hot_encoding(state_matrix)
    if n_clusters is None:
        get_binding_pose_cluster_inertia(one_hot_matrix)
        n_clusters = int(input("Please give cluster number: "))
    pam = skecluster.KMedoids(n_clusters = n_clusters, metric = "manhattan", method = "pam")
    pam.fit(one_hot_matrix)

    # called for its printed report; the returned dict is not part of this function's output
    get_center_features(data, state_matrix, pam)
    return pam