In [98]:
import numpy as np
import networkx as nx
import itertools as it

def clusterDetection(G, threshold):
    """
    Clusters the variants by pruning the mapping graph and splitting it into connected pieces.

    Every edge whose 'weight' is greater than ``threshold`` or smaller than 0 is removed
    from ``G`` itself, so the caller's graph is modified in place. The pieces that remain
    are returned as subgraph views of ``G``, i.e. they are backed by ``G``'s data.

    :param G: graph representing the optimal mappings; every edge must carry a 'weight' attribute
    :param threshold: the variant threshold the algorithm should use
    :return: list of subgraphs where each subgraph represents a cluster of variants
    """

    def isOutsideRange(edgeAttrs):
        # An edge survives only if its weight lies in [0, threshold].
        edgeWeight = edgeAttrs['weight']
        return edgeWeight > threshold or edgeWeight < 0

    # Iterate over a snapshot so that edges can be removed from G inside the loop.
    for u, v, edgeAttrs in list(G.edges(data=True)):
        if isOutsideRange(edgeAttrs):
            G.remove_edge(u, v)

    # For an undirected graph, the 1-edge-connected subgraphs are its connected components.
    return [G.subgraph(members) for members in nx.k_edge_subgraphs(G, k=1)]



def horizontalRefinement(candidateLabels, graphList):
    """
    Performs horizontal relabelling of event labels within a cluster; each event whose current
    label is one of the candidate labels gets a new label that is unique per cluster.

    The i-th cluster in ``graphList`` (counting from 1) appends ``str(i)`` to the 'newLabel'
    attribute of every node whose 'curLabel' is a candidate label. Node attribute dicts are
    modified in place; for networkx subgraph views this also changes the parent graph's node data.

    :param candidateLabels: an iterable of (hashable) labels that should be refined; duplicates
                            are ignored
    :param graphList: a list of subgraphs where each subgraph represents a cluster of variants
    :return: the same ``graphList`` object, whose nodes now carry the refined 'newLabel'
    """

    # Materialize the labels once. A one-shot iterator would otherwise be exhausted after the
    # first cluster. A set also drops duplicates (which would append the suffix twice) and
    # gives O(1) membership tests, so each cluster needs a single pass over its nodes.
    labels = set(candidateLabels)
    if not labels:
        # Nothing to refine; leave node data untouched (it need not even have 'curLabel').
        return graphList

    for clusterIndex, subgraph in enumerate(graphList, start=1):
        suffix = str(clusterIndex)
        for _node, attrs in subgraph.nodes(data=True):
            if attrs['curLabel'] in labels:
                attrs['newLabel'] += suffix

    return graphList






In [99]:
# Demo: five events (node id -> label) connected by weighted mapping edges.
G = nx.Graph()
G.add_nodes_from(
    (node, {'curLabel': label, 'newLabel': label})
    for node, label in [(1, 'a'), (2, 'b'), (3, 'c'), (4, 'b'), (5, 'a')]
)
# (u, v, weight); the 0.7 edge exceeds the 0.5 threshold used below and gets pruned.
G.add_weighted_edges_from([(1, 3, 0), (4, 5, 0), (2, 4, 0.05), (1, 5, 0.7)])

subgraphs = clusterDetection(G, 0.5)
print(subgraphs[0].nodes())
print(subgraphs[1].nodes(), "\n")

graphs = horizontalRefinement(['a', 'b'], subgraphs)
print(graphs[0].nodes(data=True))
print(graphs[1].nodes(data=True))


[1, 3]
[2, 4, 5] 

[(1, {'curLabel': 'a', 'newLabel': 'a1'}), (3, {'curLabel': 'c', 'newLabel': 'c'})]
[(2, {'curLabel': 'b', 'newLabel': 'b2'}), (4, {'curLabel': 'b', 'newLabel': 'b2'}), (5, {'curLabel': 'a', 'newLabel': 'a2'})]
