# In [1]:  (Jupyter cell marker left over from the notebook export)
import sys
import math
import random
import vtk
import numpy as np
import struct
from codecs import decode
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5 import Qt

from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor
class MainWindow(Qt.QMainWindow):
    """Main window for viewing a microvascular network stored in a binary ``.nwt`` file.

    The left side is an embedded VTK view; the right side is a control panel with a
    file chooser, an Open button and read-only detail fields.  Opening a file renders
    every edge of the network as a tube in the VTK view and pops up a matplotlib
    window with a pickable 2D graph of the network.  Picking two nodes in that graph
    fills the detail fields, colours the segments of edges joining them blue in 3D
    and points the 3D camera at their midpoint.
    """

    # Sample radii that are non-finite or larger than this are treated as corrupt:
    # they are left out of the average radius and drawn with that average instead.
    _MAX_VALID_RADIUS = 90.0
    # Stored radii are divided by this before being used as tube radii.
    _RADIUS_SCALE = 10.0
    # Offset (along each axis) of the camera from the midpoint of a picked node pair.
    _CAMERA_OFFSET = 30.0

    def __init__(self, parent=None):
        Qt.QMainWindow.__init__(self, parent)

        # Step 1: initialize the Qt window with a horizontal main layout.
        self.setWindowTitle("Microvascular Visualization")
        self.resize(1000, self.height())
        self.frame = Qt.QFrame()  # main frame that holds all UI widgets
        self.mainLayout = Qt.QHBoxLayout()  # lines up widgets horizontally
        self.frame.setLayout(self.mainLayout)
        self.setCentralWidget(self.frame)

        # Step 2: add the VTK widget. It is added to the QHBoxLayout first, so it
        # sits on the left.
        self.vtkWidget = QVTKRenderWindowInteractor(self.frame)
        self.mainLayout.addWidget(self.vtkWidget)
        self.init_vtk_widget()

        # Step 3: add the control panel to the right of the VTK widget.
        self.right_panel_widget = Qt.QWidget()
        self.right_panel_layout = Qt.QVBoxLayout()  # lines up the controls vertically
        self.right_panel_widget.setLayout(self.right_panel_layout)
        self.mainLayout.addWidget(self.right_panel_widget)
        self.add_controls()

    def init_vtk_widget(self):
        """Create the renderer/interactor for the embedded VTK widget and start it.

        Also shows the main window; the interactor is initialized after ``show()``.
        """
        # vtk.vtkObject.GlobalWarningDisplayOff()  # uncomment to hide VTK's warning window

        # The renderer draws into the widget's render window; the interactor turns
        # mouse events into camera/actor manipulation.
        self.ren = vtk.vtkRenderer()
        self.vtkWidget.GetRenderWindow().AddRenderer(self.ren)
        self.iren = self.vtkWidget.GetRenderWindow().GetInteractor()
        self.ren.SetBackground(0.8, 0.8, 0.8)  # light grey background

        self.ren.ResetCamera()
        self.show()
        self.iren.Initialize()
        self.iren.Start()

    def add_controls(self):
        """Populate the right-hand control panel: file chooser, Open button and detail fields."""
        group_box = Qt.QGroupBox("3D Visualization of Microvascular structure")
        self.groupBox_layout = Qt.QVBoxLayout()  # lines up the controls vertically
        group_box.setLayout(self.groupBox_layout)
        self.right_panel_layout.addWidget(group_box)

        # Path text field plus a "Browser" button that opens a file dialog.
        self.groupBox_layout.addWidget(Qt.QLabel("Choose a file ('.nwt'):"))
        file_row = Qt.QHBoxLayout()
        self.qt_file_name = Qt.QLineEdit()
        file_row.addWidget(self.qt_file_name)
        self.qt_browser_button = Qt.QPushButton('Browser')
        self.qt_browser_button.clicked.connect(self.on_file_browser_clicked)
        file_row.addWidget(self.qt_browser_button)
        file_widget = Qt.QWidget()
        file_widget.setLayout(file_row)
        self.groupBox_layout.addWidget(file_widget)

        self.qt_open_button = Qt.QPushButton('Open')
        self.qt_open_button.clicked.connect(self.open_nwt_file)
        self.groupBox_layout.addWidget(self.qt_open_button)

        # Output fields, filled in when two nodes are picked in the 2D graph.
        self.qt_details = self._add_detail_field("Radius of Node 1:")
        self.qt_details2 = self._add_detail_field("Number of connection to Node 1:")
        self.qt_details3 = self._add_detail_field("Radius of Node 2:")
        self.qt_details4 = self._add_detail_field("Number of connection to Node 2:")
        self.qt_details5 = self._add_detail_field("Number of Points in between Nodes:")

    def _add_detail_field(self, caption):
        """Append a caption label and a read-only line edit to the group box; return the edit."""
        self.groupBox_layout.addWidget(Qt.QLabel(caption))
        field = Qt.QLineEdit()
        field.setReadOnly(True)  # display-only; values come from the graph selection
        row = Qt.QHBoxLayout()
        row.addWidget(field)
        row_widget = Qt.QWidget()
        row_widget.setLayout(row)
        self.groupBox_layout.addWidget(row_widget)
        return field

    def on_file_browser_clicked(self):
        """Let the user pick an existing ``.nwt`` file and put its path in the path field."""
        dlg = Qt.QFileDialog()
        dlg.setFileMode(Qt.QFileDialog.ExistingFile)  # we open the file, so it must exist
        dlg.setNameFilter("loadable files (*.nwt)")
        if dlg.exec_():
            selected = dlg.selectedFiles()
            if selected:
                self.qt_file_name.setText(selected[0])

    def open_nwt_file(self):
        """Load the ``.nwt`` file named in the path field, render it in 3D and show its 2D graph.

        Shows a warning dialog (and changes nothing) if the path does not end in
        ``.nwt`` or the file cannot be read.  On success it rebinds the module-level
        globals ``num_edges``, ``vert_list``, ``vert_list2``, ``edge_list``,
        ``Colors`` and ``pointPolyData``.
        """
        global num_edges, vert_list, vert_list2, edge_list, Colors, pointPolyData

        input_file_name = self.qt_file_name.text()
        if not input_file_name.lower().endswith(".nwt"):
            Qt.QMessageBox.warning(self, "Open .nwt file",
                                   "Please choose a file with the '.nwt' extension.")
            return

        try:
            vert_list, vert_list2, edge_list = self._read_nwt(input_file_name)
        except (OSError, struct.error, IndexError) as err:
            Qt.QMessageBox.warning(self, "Open .nwt file",
                                   "Could not read {}:\n{}".format(input_file_name, err))
            return
        num_edges = len(edge_list)

        pointPolyData, Colors = self._render_network(edge_list)
        self._show_network_graph(edge_list, vert_list, vert_list2, pointPolyData)

    @staticmethod
    def _read_nwt(path):
        """Parse the binary ``.nwt`` network file at *path*.

        Layout as read here (little-endian throughout):

        * header: 14-byte identifier and 58-byte description (both skipped),
          uint32 vertex count, uint32 edge count;
        * per vertex: float32 x, y, z, two uint32 counts, then that many (their sum)
          uint32 edge indices, which are skipped;
        * per edge: uint32 start vertex, uint32 end vertex, uint32 point count, then
          per point float32 x, y, z, radius.

        Returns ``(vert_list, vert_list2, edge_list)``; coordinates and radii inside
        them are kept as the 1-tuples produced by ``struct.unpack``:

        * ``vert_list[v] == [(x,), (y,), (z,)]``;
        * ``vert_list2[v]`` holds the same three entries followed by plain floats: the
          radius of the first point of each edge starting at *v* and of the last point
          of each edge ending at *v*, in file order;
        * ``edge_list[e] == [start, end, num_pts, [(x,), (y,), (z,), (r,)], ...]``.

        Raises OSError if the file cannot be opened, struct.error if it ends early,
        and IndexError if an edge with points names a vertex beyond the vertex count.
        """
        with open(path, "rb") as f:
            def unpack(fmt):
                # Read exactly the bytes *fmt* needs; a short read raises struct.error.
                return struct.unpack(fmt, f.read(struct.calcsize(fmt)))

            f.read(14 + 58)  # identifier + description (unused)
            num_vert, num_edges = unpack("<2I")

            vert_list, vert_list2 = [], []
            for _ in range(num_vert):
                x, y, z, count1, count2 = unpack("<3f2I")
                vert_list.append([(x,), (y,), (z,)])
                vert_list2.append([(x,), (y,), (z,)])
                f.read((count1 + count2) * 4)  # per-vertex edge indices (not needed)

            edge_list = []
            for _ in range(num_edges):
                start, end, num_pts = unpack("<3I")
                edge = [start, end, num_pts]
                for k in range(num_pts):
                    x, y, z, r = unpack("<4f")
                    edge.append([(x,), (y,), (z,), (r,)])
                    if k == 0:
                        vert_list2[start].append(r)
                    if k == num_pts - 1:
                        vert_list2[end].append(r)
                edge_list.append(edge)
        return vert_list, vert_list2, edge_list

    @classmethod
    def _is_valid_radius(cls, radius):
        """True if *radius* is finite and not above ``_MAX_VALID_RADIUS``."""
        return math.isfinite(radius) and radius <= cls._MAX_VALID_RADIUS

    @classmethod
    def _average_radius(cls, edge_list):
        """Mean of all valid sample radii over every edge; 0.0 if there are none."""
        valid = [sample[3][0]
                 for edge in edge_list
                 for sample in edge[3:]
                 if cls._is_valid_radius(sample[3][0])]
        return sum(valid) / len(valid) if valid else 0.0

    def _render_network(self, edge_list):
        """Draw every edge as a tube in the VTK view, replacing any previously loaded network.

        Each pair of consecutive sample points of an edge becomes one line cell (both
        endpoints inserted as new points).  Tube radii come from the per-point "width"
        array.  Colours come from the per-cell "Colors" array, which starts all red.

        Returns ``(poly_data, colors)``: the line poly data feeding the tube filter
        and its cell colour array.
        """
        average_radius = self._average_radius(edge_list)

        def tube_width(radius):
            # Corrupt radii are drawn with the average radius.
            if not self._is_valid_radius(radius):
                radius = float(average_radius)
            return radius / self._RADIUS_SCALE

        points = vtk.vtkPoints()
        lines = vtk.vtkCellArray()
        widths = vtk.vtkDoubleArray()
        widths.SetName("width")
        colors = vtk.vtkUnsignedCharArray()
        colors.SetNumberOfComponents(3)
        colors.SetName("Colors")

        for edge in edge_list:
            samples = edge[3:]
            for p, q in zip(samples, samples[1:]):
                id_p = points.InsertNextPoint(p[0][0], p[1][0], p[2][0])
                id_q = points.InsertNextPoint(q[0][0], q[1][0], q[2][0])
                widths.InsertNextValue(tube_width(p[3][0]))
                widths.InsertNextValue(tube_width(q[3][0]))
                colors.InsertNextTuple3(255, 1, 1)  # exactly one colour per line cell
                lines.InsertNextCell(2, [id_p, id_q])

        poly_data = vtk.vtkPolyData()
        poly_data.SetPoints(points)
        poly_data.GetCellData().SetScalars(colors)
        poly_data.GetPointData().AddArray(widths)
        poly_data.GetPointData().SetActiveScalars("width")  # drives the tube radius
        poly_data.SetLines(lines)

        tubes = vtk.vtkTubeFilter()
        tubes.SetInputData(poly_data)
        tubes.SetNumberOfSides(10)
        tubes.CappingOff()
        tubes.SetVaryRadiusToVaryRadiusByAbsoluteScalar()
        tubes.SetRadiusFactor(0.5)
        tubes.Update()

        mapper = vtk.vtkPolyDataMapper()
        mapper.SetInputConnection(tubes.GetOutputPort())
        mapper.ScalarVisibilityOn()
        mapper.SetScalarModeToUseCellFieldData()
        mapper.SelectColorArray(colors.GetName())
        mapper.Update()

        # Drop the actor of a previously opened file so networks do not pile up.
        previous = getattr(self, "micro_actor", None)
        if previous is not None:
            self.ren.RemoveActor(previous)
        self.micro_actor = vtk.vtkActor()
        self.micro_actor.SetMapper(mapper)
        self.ren.AddActor(self.micro_actor)
        self.ren.ResetCamera()  # frame the newly loaded network
        self.vtkWidget.GetRenderWindow().Render()
        return poly_data, colors

    @staticmethod
    def _use_qt_matplotlib_backend():
        """Select matplotlib's Qt backend: ``%matplotlib qt`` under IPython, ``matplotlib.use`` otherwise."""
        try:
            shell = get_ipython()  # noqa: F821 -- builtin only when running under IPython
        except NameError:
            import matplotlib
            matplotlib.use("Qt5Agg")
        else:
            shell.run_line_magic("matplotlib", "qt")

    def _show_network_graph(self, edge_list, vert_list, vert_list2, poly_data):
        """Open a matplotlib window with a pickable 2D graph of the network.

        Nodes are vertex indices and edges are the file's edges.  Picking two nodes
        calls ``_show_selection`` for them.
        """
        self._use_qt_matplotlib_backend()
        import networkx as nx
        import matplotlib.pyplot as plt
        from grave import plot_network
        from grave.style import use_attributes

        graph = nx.DiGraph()
        for edge in edge_list:
            graph.add_edge(edge[0], edge[1], penwidth=edge[2])

        self.inter_list = []  # picked nodes, most recent first

        def highlighter(event):
            # Ignore picks that did not hit a node.
            if not hasattr(event, 'nodes') or not event.nodes:
                return
            picked_graph = event.artist.graph

            # Clear the highlight attributes of the previous pick.
            for _, attributes in picked_graph.nodes.data():
                attributes.pop('color', None)
            for _, _, attributes in picked_graph.edges.data():
                attributes.pop('width', None)

            for node in event.nodes:
                picked_graph.nodes[node]['color'] = 'C1'
                self.inter_list.insert(0, node)
                for edge_attributes in picked_graph[node].values():
                    edge_attributes['width'] = 3

            if len(self.inter_list) == 2:
                self._show_selection(self.inter_list[0], self.inter_list[1],
                                     edge_list, vert_list, vert_list2, poly_data)
                self.inter_list = []
            elif len(self.inter_list) > 2:
                # Overlapping nodes were picked together; drop the ambiguous selection.
                self.inter_list = []

            event.artist.stale = True
            event.artist.figure.canvas.draw_idle()

        fig, ax = plt.subplots()
        art = plot_network(graph, layout='spring', ax=ax,
                           node_style=use_attributes(), edge_style=use_attributes())
        art.set_picker(10)  # enable picking with a tolerance of 10
        fig.canvas.mpl_connect('pick_event', highlighter)
        fig.suptitle('Select two nodes for a corresponding 3D view')
        plt.show()

    @staticmethod
    def _connection_count(edge_list, node):
        """Return the panel's connection figure for *node*.

        The value is the number of edges touching *node*, minus one, with 0 reported as 1.
        NOTE(review): the reason for the "minus one" (possibly discounting the edge
        to the other picked node) is not evident from this file -- confirm.
        """
        count = sum(1 for edge in edge_list if node in (edge[0], edge[1])) - 1
        return 1 if count == 0 else count

    @staticmethod
    def _first_radius_text(vert_list2, node):
        """Text of the first endpoint radius recorded for *node*, or "N/A" if none was recorded."""
        radii = vert_list2[node][3:]
        return str(radii[0]) if radii else "N/A"

    def _show_selection(self, node_a, node_b, edge_list, vert_list, vert_list2, poly_data):
        """Show details for a picked node pair and focus the 3D view on it.

        *node_a* is the most recently picked node and is reported as "Node 1".
        """
        self.qt_details.setText(self._first_radius_text(vert_list2, node_a))
        self.qt_details2.setText(str(self._connection_count(edge_list, node_a)))
        self.qt_details3.setText(self._first_radius_text(vert_list2, node_b))
        self.qt_details4.setText(str(self._connection_count(edge_list, node_b)))

        # Point count of the edge joining the two nodes (last match wins).
        between = "N/A"
        for edge in edge_list:
            ends = (edge[0], edge[1])
            if node_a in ends and node_b in ends:
                between = str(edge[2])
        self.qt_details5.setText(between)

        # Recolour: segments of edges joining the two nodes blue, everything else red.
        # One tuple per line cell, in the same order as _render_network.
        selected = (node_a, node_b)
        colors = vtk.vtkUnsignedCharArray()
        colors.SetNumberOfComponents(3)
        colors.SetName("Colors")
        for edge in edge_list:
            rgb = (1, 1, 255) if edge[0] in selected and edge[1] in selected else (255, 1, 1)
            for _ in range(3, len(edge) - 1):
                colors.InsertNextTuple3(*rgb)
        poly_data.GetCellData().SetScalars(colors)

        # Look at the midpoint of the two nodes from a fixed diagonal offset.
        mid = [(a[0] + b[0]) / 2 for a, b in zip(vert_list[node_a], vert_list[node_b])]
        offset = self._CAMERA_OFFSET
        camera = self.ren.GetActiveCamera()
        camera.SetPosition(mid[0] - offset, mid[1] - offset, mid[2] - offset)
        camera.SetFocalPoint(mid[0], mid[1], mid[2])
        self.vtkWidget.GetRenderWindow().Render()
                
                
           
     

        
if __name__ == "__main__":
    # Script entry point: build the Qt application, then create the main window
    # (which shows itself while setting up its VTK view).
    app = Qt.QApplication(sys.argv)
    window = MainWindow()
    # Hand control to Qt's event loop and exit with the code it returns.
    status = app.exec_()
    sys.exit(status)

# Notebook output captured when the cell above was run:
# SystemExit: 0
#
#   warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
