# Mosquito Sinks and Sources Detector

# Setup

In [1]:
import numpy as np
import pandas as pd
import networkx as nx
from collections import defaultdict
from sklearn.cluster import KMeans, AgglomerativeClustering

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cmx

# The Detector Class

In [2]:
class Detector:
    """
    Classifies communities of locations as mosquito sinks, sources, or bridges.

    Locations are grouped into communities by clustering their 'lon'/'lat'
    columns. Populations are propagated with the transition matrix, and each
    community is labelled from the balance between the population its
    locations gained and the population they lost between two time steps.
    """

    def __init__(self, transitions, locations, b_tol=1e-08, s_tol=1e-08, ss_vals=100000, method='kmeans', as_df=False, *args, **kwargs):
        """
        Inputs:
            transitions : A dataframe with transition probabilities
            locations   : A dataframe with an initial population column, 'pop'
                          (clustering additionally reads 'lon' and 'lat')
            b_tol       : The tolerance for bridge detection: a community is a
                          bridge when |prp_in - prp_out| <= b_tol
            s_tol       : The tolerance for steady state detection
            ss_vals     : The number of possible steady-state steps to consider
            method      : The clustering method, which is either 'kmeans' | 'agglomerative'
            as_df       : Whether or not to return the output as a dataframe
            *args       : Extra arguments for the clustering method
            **kwargs    : Extra key word arguments for the clustering method
        """
        self.tmtx     = transitions
        self.locs     = locations
        self.b_tol    = b_tol
        self.s_tol    = s_tol
        self.ss_vals  = ss_vals
        self.method   = method
        self.as_df    = as_df
        self.args     = args
        self.kwargs   = kwargs
        # Computed eagerly so that `run` can default to the steady state.
        self.ss_step  = self.compute_sss(self.ss_vals)

    ####################
    # Private Methods: #
    ####################
    def compute_sss(self, vals):
        """
        Computes the steady state step using binary search over the
        candidate steps 0 .. vals - 1.

        A step `m` counts as steady when no location's population changes
        by more than `s_tol` between step `m` and step `m + 1`. The smallest
        such step is returned; if no candidate qualifies, `vals` is returned.

        NOTE: binary search presumes that every step after a steady step is
        also steady; that is not checked here.
        """
        lo = 0
        hi = vals - 1

        while lo <= hi:
            m = (hi + lo) // 2
            prev = self.migrate(m)
            curr = self.migrate(m + 1)

            if np.all(np.abs(curr - prev) <= self.s_tol):
                hi = m - 1
            else:
                lo = m + 1

        return lo

    def communities(self):
        """
        Returns a dictionary where the key is the community id and
        the value is a list of location positions grouped together by
        the clustering algorithm selected through `self.method`.

        Raises a ValueError when `self.method` is not supported.
        """
        coords = self.locs[['lon', 'lat']]
        if self.method == 'kmeans':
            cluster = KMeans(*self.args, **self.kwargs).fit(coords)
        elif self.method == 'agglomerative':
            cluster = AgglomerativeClustering(*self.args, **self.kwargs).fit(coords)
        else:
            raise ValueError("Method '{}' is not supported".format(self.method))

        communities = defaultdict(list)
        for loc, cid in enumerate(cluster.labels_):
            communities[cid].append(loc)
        return communities

    def proportions(self, communities, start_step, final_step):
        """
        Returns a dictionary such that each key is a community ID
        and each value is a dictionary. In the second-level
        dictionary, there are 6 keys which are explained in depth
        below:
            1. num_in  -> total population gained between `start_step`
                          and `final_step` by the locations of this
                          community whose population grew
            2. num_out -> total population lost between `start_step`
                          and `final_step` by the locations of this
                          community whose population shrank
            3. prp_in  -> num_in  / (num_in + num_out)
            4. prp_out -> num_out / (num_in + num_out)
            5. type    -> either source, sink, or bridge
            6. com     -> a list of the locations within this
                          community

        When a community shows no change at all (num_in + num_out == 0)
        both proportions are reported as 0.0, which classifies it as a
        bridge, instead of dividing by zero.
        """
        # The populations do not depend on the community, so compute them once.
        start_pops = self.migrate(start_step)
        final_pops = self.migrate(final_step)

        data = defaultdict(dict)
        for cid, community in communities.items():
            start = start_pops[community]
            final = final_pops[community]
            gained = final > start
            lost = final < start

            num_in = np.sum(final[gained] - start[gained])
            num_out = np.sum(start[lost] - final[lost])
            total = num_in + num_out
            if total > 0:
                prp_in = num_in / total
                prp_out = num_out / total
            else:
                prp_in = prp_out = 0.0

            data[cid]['num_in']  = num_in
            data[cid]['num_out'] = num_out
            data[cid]['prp_in']  = prp_in
            data[cid]['prp_out'] = prp_out
            data[cid]['type']    = self.classify(prp_in, prp_out)
            data[cid]['com']     = community
        return data

    def classify(self, prop_in, prop_out):
        """
        Returns whether a community is a sink, source, or bridge.

        The community is a bridge when the two proportions differ by at
        most `b_tol`, a source when more leaves than enters, and a sink
        when more enters than leaves. Values that cannot be ordered (NaN)
        fall back to 'bridge' so that a label is always returned.
        """
        diff = prop_in - prop_out
        if np.abs(diff) <= self.b_tol: return 'bridge'
        if diff < 0:                   return 'source'
        if diff > 0:                   return 'sink'
        return 'bridge'

    def pipeline(self, start, final, communities=None):
        """
        Performs all steps of sink/source detection.

        `final` may be float('inf'), meaning the steady-state step.
        `communities` optionally supplies an existing clustering; when
        omitted a new one is computed.
        """
        # Computes the result
        if final == float('inf'):
            final = self.ss_step
        if communities is None:
            communities = self.communities()
        result = self.proportions(communities, start, final)

        # Formats result
        if self.as_df:
            return pd.DataFrame(result).transpose().sort_index()
        else:
            return result

    def get_cids(self, name='cid', communities=None):
        """
        Maps each location to its community ID.

        Returns a copy of the locations dataframe with an extra column
        `name` when `as_df` is set, otherwise a {location: cid} dictionary.
        `communities` optionally supplies an existing clustering; when
        omitted a new one is computed.
        """
        coms = self.communities() if communities is None else communities
        lookup = { loc : cid for cid, members in coms.items() for loc in members }
        if not self.as_df:
            return lookup

        # Community members are row positions, so labels are assigned by
        # position rather than through the dataframe's index labels.
        labelled = self.locs.copy()
        labelled[name] = [lookup[pos] for pos in range(len(labelled))]
        self.cids = labelled
        return labelled.copy()

    ###################
    # Public Methods: #
    ###################

    def set_params(self, b_tol=None, s_tol=None, ss_vals=None, method=None, as_df=None, *args, **kwargs):
        """
        Resets the parameters of this instance. Parameters left as None
        (or empty, for *args / **kwargs) keep their current value.

        The steady-state step is recomputed whenever `s_tol` or `ss_vals`
        is given, since it depends on both.
        """
        if b_tol is not None: self.b_tol = b_tol
        if s_tol is not None: self.s_tol = s_tol
        if ss_vals is not None: self.ss_vals = ss_vals
        if method: self.method = method
        if  as_df  is not None: self.as_df = as_df
        if   args: self.args   = args
        if kwargs: self.kwargs = kwargs

        if s_tol is not None or ss_vals is not None:
            self.ss_step = self.compute_sss(self.ss_vals)

    def run(self, start=0, final=float('inf')):
        """
        Runs sink source detection from time step `start` to time step
        `final`. By default, this runs from step 0 to steady-state.

        The clustering is computed once and shared by the results and the
        location labels, so both always refer to the same communities.
        """
        coms = self.communities()
        self.data = self.pipeline(start, final, communities=coms)
        self.cids = self.get_cids(communities=coms)
        return self

    def clabels(self):
        """
        Returns each location and its corresponding community id/label.
        Only available after `run` has been called.
        """
        return self.cids.copy()

    def results(self):
        """
        Returns the in/out proportions of each community along with
        its class (sink/source/bridge).
        Only available after `run` has been called.
        """
        return self.data.copy()

    def migrate(self, k):
        """
        Returns the populations at each location after `k` time steps,
        i.e. (transition matrix)^k applied to the initial 'pop' column,
        as a NumPy array ordered like the rows of the locations dataframe.
        """
        step_matrix = np.linalg.matrix_power(np.asarray(self.tmtx), k)
        return step_matrix @ np.asarray(self.locs['pop'])

# Visualization

In [3]:
def plot_data(tmtx_pdf,
              locs_pdf,
              coms_pdf,
              
              nodes_cm=cmx.cool,
              nodes_fn=lambda n: n,
              min_popl=0,
              
              bordr_cm=cmx.RdGy,
              bordr_mu=2,
              bordr_op=1,
              
              edges_co=None,
              edges_cm=cmx.magma,
              edges_fn=lambda w: w,
              min_prob=0,
              edges_mu=1,
              edges_op=0.75,
              
              solid_co=None,
              solid_cm=None,
              solid_fn=None,
              solid_mu=None,
              solid_op=None,
              
              dottd_co=None,
              dottd_cm=None,
              dottd_fn=None,
              dottd_mu=None,
              dottd_op=None,
              
              bgrd_crd=None,
              bgrd_clr='#6699cc',
              bgrd_opc=0.5,
              
              bgbd_lwd=1,
              bgbd_clr='#000000',
              bgbd_opc=0.5,
              
              fig_size=(10,10),
              axis_arg='auto',
              plt_bbar=False,
              plt_pbar=False,
              plt_dbar=False,
              plt_sbar=False,
              save_fig=False
):
    """
    Draws the locations as a network: nodes are locations, edges are
    transition probabilities. Edges inside one community are solid, edges
    between communities are dashed, and each node gets a border coloured by
    its community type. The input dataframes are not modified.
    
    Inputs:
    
        Required Parameters:
            tmtx_pdf : the probability transitions as a Pandas dataframe
            locs_pdf : the locations and cids as a Pandas dataframe
                       (columns 'pop', 'lon', 'lat' and 'cid' are read)
            coms_pdf : the community info as a Pandas dataframe
                       (indexed by cid, column 'type' is read)
            
        Node Settings:
            nodes_cm : the colormap for the nodes
            nodes_fn : a function to apply to the node sizes
            min_popl : if a location's population is very small use this default value
            
        Node Border Settings:
            bordr_cm : the colormap for the node borders
            bordr_mu : multiplies this value by the node sizes to come up with border sizes
            bordr_op : opacity of the node borders (applied to every node border)

        Edge Settings:
            edges_co : an RGB string to set the color of all edges (this will override the colormap)
            edges_cm : the colormap for the edges
            edges_fn : a function to apply to the edge weights
            min_prob : draw an edge if it is at least this value
            edges_mu : multiplies this value by the edge weights (applied after edges_fn is called)
            edges_op : opacity of edges (applied to every edge)
            
        Solid Edge Settings (if none of these are specified, the edge settings will be used):
            solid_co : an RGB string to set the color of all solid edges (this will override the colormap)
            solid_cm : the colormap for the solid edges
            solid_fn : a function to apply to the solid edge weights
            solid_mu : multiplies this value by the solid edge weights
            solid_op : sets the opacity of all the solid edges
        
        Dotted Edge Settings (if none of these are specified, the edge settings will be used):
            dottd_co : an RGB string to set the color of all dotted edges (this will override the colormap)
            dottd_cm : the colormap for the dotted edges
            dottd_fn : a function to apply to the dotted edge weights
            dottd_mu : multiplies this value by the dotted edge weights
            dottd_op : sets the opacity of all the dotted edges
        
        Background Settings:
            bgrd_crd : a dictionary compatible with descartes' PolygonPatch object
            bgrd_clr : the background color as an RGB string
            bgrd_opc : the opacity of the background
            
        Background Border Settings:
            bgbd_lwd : the line width of the background border 
            bgbd_clr : the background border color as an RGB string
            bgbd_opc : the opacity of the background border
        
        Miscellaneous Settings:        
            fig_size : tuple of figure dimensions
            axis_arg : argument for matplotlib.axes.Axes.axis
            plt_bbar : plots the color bar for the borders (border values are arranged such that sink < bridge < source)
            plt_pbar : plots the color bar for the population sizes
            plt_dbar : plots the color bar for the dotted edges
            plt_sbar : plots the color bar for the solid edges
            save_fig : if true, saves the figure as a PNG file ("figure.png")
    
    Notes:
        Populations are scaled up to the default value then nodes_fn is called
        If an edge color is specified, it overrides the color maps
    """
    def edge_norm(weights):
        # Colour scale spanning the plotted weights, so that a colour bar
        # shows the weight range instead of a bare 0-1 scale.
        if weights.size == 0:
            return None
        return colors.Normalize(vmin=weights.min(), vmax=weights.max())

    # Adjust parameters: unspecified solid/dotted settings inherit the edge settings
    if solid_co is None: solid_co = edges_co
    if dottd_co is None: dottd_co = edges_co
    if solid_cm is None: solid_cm = edges_cm
    if dottd_cm is None: dottd_cm = edges_cm
    if solid_fn is None: solid_fn = edges_fn
    if dottd_fn is None: dottd_fn = edges_fn
    if solid_mu is None: solid_mu = edges_mu
    if dottd_mu is None: dottd_mu = edges_mu
    if solid_op is None: solid_op = edges_op
    if dottd_op is None: dottd_op = edges_op
    
    # Ensures node numberings are consistent with transition matrix numberings
    locs_pdf = locs_pdf.reset_index()
    
    # Ensures all population sizes are greater than or equal to min_popl then apply function
    population_sizes = locs_pdf['pop'].apply(lambda p: max(p, min_popl)).apply(nodes_fn)
    
    # Removes self transitions and filter out any edges below the probability threshold.
    # Works on a copy: `.values` may share memory with the caller's dataframe,
    # and zeroing the diagonal in place would corrupt the transition matrix.
    tmtx = tmtx_pdf.values.copy()
    np.fill_diagonal(tmtx, 0)
    # `from_numpy_matrix` no longer exists in networkx 3.x; prefer its replacement.
    build_graph = getattr(nx, 'from_numpy_array', None) or nx.from_numpy_matrix
    G = build_graph(tmtx)
    G.remove_edges_from([(n1, n2) for n1, n2, w in G.edges.data('weight') if w < min_prob])
    
    # Constructs a dictionary of nodes for plotting
    nodes = { i : (r['lon'], r['lat']) for i, r in locs_pdf.iterrows() }
    
    # Separates the dotted (between communities) and solid (within a community) edges
    node_cids = locs_pdf['cid'].values
    d_edges, d_vals = [], []
    s_edges, s_vals = [], []
    for n1, n2, w in G.edges.data('weight'):
        if node_cids[n1] != node_cids[n2]:
            d_edges.append((n1, n2))
            d_vals.append(dottd_fn(w) * dottd_mu)
        else:
            s_edges.append((n1, n2))
            s_vals.append(solid_fn(w) * solid_mu)
    d_wghts = np.array(d_vals, dtype=float)
    s_wghts = np.array(s_vals, dtype=float)
        
    # Sets up the figure
    fig = plt.figure(figsize=fig_size)
    
    # Plots the background
    ax = fig.gca()
    if bgrd_crd is not None:
        # NOTE(review): PolygonPatch is not imported in the visible part of this
        # file; presumably it is descartes.PolygonPatch and must be in scope.
        ax.add_patch(PolygonPatch(bgrd_crd, fc=bgrd_clr, ec=bgbd_clr, alpha=bgrd_opc, zorder=0))
        ax.add_patch(PolygonPatch(bgrd_crd, fill=False, ec=bgbd_clr, alpha=bgbd_opc, linewidth=bgbd_lwd, zorder=0))
    
    # Plots dotted edges
    dotted_edges = nx.draw_networkx_edges(G, nodes,
                           edgelist=d_edges,
                           style='dashed',
                           edge_color=dottd_co if dottd_co is not None else d_wghts,
                           edge_cmap=dottd_cm,
                           alpha=dottd_op,
                           width=abs(d_wghts),
                           ax=ax)
    
    # Plots solid edges
    solid_edges = nx.draw_networkx_edges(G, nodes,
                           edgelist=s_edges,
                           style='solid',
                           edge_color=solid_co if solid_co is not None else s_wghts,
                           edge_cmap=solid_cm,
                           alpha=solid_op,
                           width=abs(s_wghts),
                           ax=ax)
    
    # Maps each node to its type: sink, source, bridge (looked up through the node's cid)
    types = locs_pdf['cid'].map(coms_pdf['type'])\
                           .map({'sink' : 1, 'bridge' : 2, 'source' : 3})
    
    # Plots borders (drawn first and larger, so they show around the nodes)
    borderdata_nodes = nx.draw_networkx_nodes(G, nodes,
                                     node_color=types.values,
                                     cmap=bordr_cm,
                                     alpha=bordr_op,
                                     node_size=abs(population_sizes*bordr_mu),
                                     ax=ax)
        
    # Plots nodes
    population_nodes = nx.draw_networkx_nodes(G, nodes,
                                     node_color=locs_pdf['cid'].values,
                                     cmap=nodes_cm,
                                     node_size=abs(population_sizes),
                                     ax=ax)
    
    # Plots a colorbar for the border
    if plt_bbar:
        brd_cbar = plt.colorbar(borderdata_nodes, ax=ax)
        brd_cbar.ax.set_ylabel('source / bridge / sink',labelpad=15,rotation=270)
    
    # Plots a colorbar for the population
    if plt_pbar:
        pop_cbar = plt.colorbar(population_nodes, ax=ax)
        pop_cbar.ax.set_ylabel('population',labelpad=15,rotation=270)
        
    # Plots a colorbar for the dotted edges.
    # The mappable is not attached to any axes, so the axes is passed explicitly.
    if dottd_co is None and plt_dbar:
        s = 'edges' if dottd_cm == solid_cm else 'dotted edges'
        dot_map = plt.cm.ScalarMappable(norm=edge_norm(d_wghts), cmap=dottd_cm)
        dot_cbar = plt.colorbar(dot_map, ax=ax)
        dot_cbar.ax.set_ylabel(s,labelpad=15,rotation=270)
        
    # Plots a colorbar for the solid edges
    if solid_co is None and plt_sbar:
        s = 'edges' if dottd_cm == solid_cm else 'solid edges'
        sol_map = plt.cm.ScalarMappable(norm=edge_norm(s_wghts), cmap=solid_cm)
        sol_cbar = plt.colorbar(sol_map, ax=ax)
        sol_cbar.ax.set_ylabel(s,labelpad=15,rotation=270)
    
    # Applies the axis setting before saving so the saved file matches what is shown
    ax.axis(axis_arg)
    
    # Saves the figure if desired
    if save_fig: 
        fig.savefig("figure.png")
    
    plt.show()