In [71]:
import numpy as np
import pandas as pd
import networkx as nx
import plotly.io as pio
import plotly.graph_objects as go
from IPython.display import HTML, display
import IPython 
from plotly.offline import init_notebook_mode
import sys, os, argparse

# Command-line style configuration. If CONFIG_FILE exists, its whitespace-separated
# contents become sys.argv (presumably written by a runner script -- the fallback
# argv[0] below is 'run_notebook.py'; TODO confirm). Otherwise a default input is
# used so the notebook can be run interactively.
CONFIG_FILE = '.ipynb.config'
# Default input for interactive runs. Set MLST_MASTER_TSV to point elsewhere instead of
# editing this user-specific absolute path.
DEFAULT_MLST_PATH = os.environ.get(
	'MLST_MASTER_TSV',
	'/home/john.palmer/work/vibrio/results-may5-1/mlst/master.tsv',
)
if os.path.isfile(CONFIG_FILE):
	with open(CONFIG_FILE) as f:
		# Plain whitespace split: paths containing spaces are not supported.
		sys.argv = f.read().split()
else:
	sys.argv = ['run_notebook.py', DEFAULT_MLST_PATH]

parser = argparse.ArgumentParser()
parser.add_argument("mlst",help="Input master MLST TSV file containing all samples in a run.")
args = parser.parse_args()

mlst_path = args.mlst

init_notebook_mode(connected=True)
pio.renderers.default = "notebook"

# Widen the notebook's `.container` element to the full page width so the wide figure fits.
display(HTML("<style>.container { width:100% !important; }</style>"))
def compare(mlst1, mlst2):
	"""Count the positions at which two allele profiles disagree.

	The two sequences are paired element-wise (``zip``), so comparison stops at
	the end of the shorter one; any extra trailing entries are ignored.

	Returns the number of positions where the paired values are ``!=``, as an int.
	"""
	return sum(1 for left, right in zip(mlst1, mlst2) if left != right)

In [72]:
# Headerless, tab-separated master table: sample, organism, the MLST call, then one
# allele column per locus, named '1' through '7'.
cols = ['sample', 'organism', 'mlst', *(str(locus) for locus in range(1, 8))]
df = pd.read_csv(mlst_path, names=cols, sep='\t')


In [73]:

# Number of samples per MLST value (including '-'); used later to size/colour nodes.
counts = df['mlst'].value_counts().to_dict()

# One representative allele profile per MLST value. Note that groupby().first() takes
# the first non-null value of each column. Rows whose MLST value is '-' are dropped.
uniq = df.groupby('mlst').first()
uniq = uniq.drop(uniq.index[uniq.index == '-']).reset_index()

# Extract each (ST, allele list) pair once rather than re-slicing the frame with
# .loc inside the pairwise loop.
profiles = list(zip(uniq['mlst'], uniq.loc[:, '1':'7'].to_numpy().tolist()))

# (st_a, st_b, number of differing loci) for every unordered pair of STs,
# in the same (row i, row j > i) order as a nested loop over `uniq`.
diff_df = [
	(st_a, st_b, compare(alleles_a, alleles_b))
	for i, (st_a, alleles_a) in enumerate(profiles)
	for st_b, alleles_b in profiles[i + 1:]
]


{'417': 24, '36': 19, '631': 5, '-': 4, '43': 2, '68': 1, '3': 1, '65': 1}
[('3', '36', 7), ('3', '417', 6), ('3', '43', 7), ('3', '631', 7), ('3', '65', 7), ('3', '68', 7), ('36', '417', 7), ('36', '43', 7), ('36', '631', 7), ('36', '65', 7), ('36', '68', 7), ('417', '43', 7), ('417', '631', 7), ('417', '65', 7), ('417', '68', 7), ('43', '631', 7), ('43', '65', 7), ('43', '68', 7), ('631', '65', 7), ('631', '68', 7), ('65', '68', 7)]


In [74]:
# Number of allele columns compared per ST ('1'..'7'). A pair differing at all of
# them shares no allele and gets no edge.
N_LOCI = 7

G = nx.Graph()
# Add every ST up front so STs without any edge (e.g. a run with a single distinct
# ST) still appear. The order matches the order of `uniq`.
G.add_nodes_from(uniq['mlst'])
# Edge weight = number of differing loci, stored under the 'weight' attribute.
G.add_weighted_edges_from(
	(st_a, st_b, diffs) for st_a, st_b, diffs in diff_df if diffs < N_LOCI
)

# For a disconnected graph this is a spanning forest; all nodes are kept.
mst = nx.minimum_spanning_tree(G)

In [75]:
def get_edge_trace(graph, positions):
	"""Return a Plotly line trace drawing every edge of ``graph`` with its weight.

	Each edge contributes four points: the start node, the midpoint, the end
	node and a ``None`` separator that breaks the polyline between edges. The
	edge's ``'weight'`` attribute is written as text on the midpoint.

	Parameters
	----------
	graph : networkx graph whose edges carry a ``'weight'`` attribute.
	positions : mapping node -> (x, y).

	Returns
	-------
	plotly.graph_objects.Scatter in 'lines+text' mode with hover disabled.
	"""
	xs, ys, labels = [], [], []

	for src, dst in graph.edges():
		sx, sy = positions[src]
		dx, dy = positions[dst]
		xs.extend((sx, (sx + dx) / 2, dx, None))
		ys.extend((sy, (sy + dy) / 2, dy, None))
		weight = graph.get_edge_data(src, dst)['weight']
		labels.extend((None, str(weight), None, None))

	return go.Scatter(
		x=xs,
		y=ys,
		mode='lines+text',
		text=labels,
		textposition='top center',
		hoverinfo='none',
		line=dict(width=0.8, color='#888'),
	)

def get_node_trace(graph, positions, sample_counts=None):
	"""Build the Plotly scatter trace that draws one labelled marker per node.

	Parameters
	----------
	graph : networkx.Graph
		Graph whose nodes are MLST values (sequence types).
	positions : dict
		Mapping node -> (x, y), e.g. the output of a networkx layout function.
	sample_counts : dict, optional
		Mapping node -> number of samples with that MLST value. Defaults to the
		notebook-level ``counts`` dict.

	Returns
	-------
	plotly.graph_objects.Scatter
		Markers whose size is sqrt(100 * samples) and whose colour is the raw
		sample count, so the colour bar ticks read as numbers of samples.

	Raises
	------
	KeyError
		If a node is missing from ``positions`` or ``sample_counts``.
	"""
	if sample_counts is None:
		sample_counts = counts

	nodes = list(graph.nodes)

	node_x = []
	node_y = []
	for node in nodes:
		x, y = positions[node]
		node_x.append(x)
		node_y.append(y)

	samples = [sample_counts[node] for node in nodes]
	# Marker size scales with the square root of the sample count.
	node_sizes = [np.sqrt(n_samples * 100) for n_samples in samples]
	# len(graph[node]) is the number of neighbours. It is computed per node instead of
	# rebuilding the whole adjacency dict on every iteration.
	node_text = [
		f'MLST: {node}<br># samples: {n_samples}<br># connections: {len(graph[node])} \n'
		for node, n_samples in zip(nodes, samples)
	]

	return go.Scatter(
		x=node_x, y=node_y,
		mode='markers+text',
		hoverinfo='text',
		hovertext=node_text,
		text=[str(node) for node in nodes],
		textposition='bottom center',
		marker=dict(
			showscale=True,
			colorscale='YlGnBu',
			reversescale=True,
			# Colour by the raw count so the colour bar labelled "Number of Samples" is accurate.
			color=samples,
			size=node_sizes,
			colorbar=dict(
				thickness=15,
				# title.side replaces the `titleside` attribute, which Plotly 6 removed.
				title=dict(text='Number of Samples', side='right'),
				xanchor='left',
			),
			line_width=2, line_color='black'))


In [124]:

def plot_graph(graph):
	"""Lay out ``graph`` and return a Plotly figure of its nodes and weighted edges.

	A planar layout is used when the graph is planar; a tree or forest, such as a
	minimum spanning tree, always is. Otherwise the function falls back to a
	spring layout with a fixed seed, so the picture is reproducible.

	Parameters
	----------
	graph : networkx.Graph
		Graph with a ``'weight'`` on every edge. Every node must have an entry in
		the sample-count mapping used by ``get_node_trace``.

	Returns
	-------
	plotly.graph_objects.Figure
	"""
	try:
		pos_dict = nx.planar_layout(graph)
	except nx.NetworkXException:
		# planar_layout raises for graphs that are not planar.
		pos_dict = nx.spring_layout(graph, seed=0)

	edge_trace = get_edge_trace(graph, pos_dict)
	node_trace = get_node_trace(graph, pos_dict)

	layout = go.Layout(
		# title.font replaces the `titlefont` attribute, which Plotly 6 removed.
		title=dict(text='MLST Minimum Spanning Tree', font=dict(size=16)),
		# Global font size for labels and hover text; the title keeps its own size.
		font=dict(size=15),
		height=700, width=1300,
		showlegend=False,
		hovermode='closest',
		annotations=[dict(
			text="By: John P.",
			showarrow=False,
			xref="paper", yref="paper",
			x=-0.005, y=-0.002)],
		# Layout coordinates are meaningless, so hide grid, zero lines and ticks.
		xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
		yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
	)
	return go.Figure(data=[edge_trace, node_trace], layout=layout)

In [125]:
# Build the interactive figure for the minimum spanning tree and render it inline.
fig = plot_graph(graph=mst)
fig.show()