In [23]:
%matplotlib notebook
import vtk
from Parser import Parser
from GeneralPlots import GeneralPlot
from ipycanvas import Canvas
from typing import Dict, Callable, Tuple, List
from ipywidgets import Image, Output
import ipywidgets as widgets
import numpy as np
import math
import asyncio
from time import time
from vtk.util import numpy_support
import matplotlib.pyplot as plt
import constants

# Output widget used by the @output.capture() decorators in this file; shown by Interactor.show().
output = Output()

# Incremented by Interactor.show() each time it creates a cell-attribute figure.
figureNumber = 0

# Canvas mouse tools; subclasses pass a subset of these to the "Mouse Action" radio buttons.
MOVE_ACTION, ZOOM_ACTION, SELECT_ACTION, ROTATE_ACTION = 'move', 'zoom', 'select', 'rotate'

# Option in the "Visible" selector that toggles cell drawing; the other options are environment attributes.
VISIBLE_CELL = "cell"

##################################################################### UTILS ####################################

def defaultColorMap(cells, variables, selectedCell):
    """
    Default 2D cell coloring function, keyed on each cell's cycle_model:

        live: green
        apoptosis_death: blue
        necrosis_death: red
        autophagy_death: orange

        selected_cell: white

        cells - cell data, one column per cell, rows indexed through `variables`
        variables - map from cell variable name to row index in `cells`
        selectedCell - the ID of the selected cell (or None)

    Returns a list of (fill color, stroke color, boolean mask over cells) tuples. The selected cell
    is excluded from every cycle-model mask and gets its own final entry.
    """
    isSelected = cells[variables['ID']] == selectedCell
    unselected = ~isSelected
    cycleModel = cells[variables['cycle_model']]

    palette = (
        ('#004400', 'green', cycleModel < constants.apoptosis_death_model),
        ('#000099', '#0000CC', cycleModel == constants.apoptosis_death_model),
        ('#440000', 'red', cycleModel == constants.necrosis_death_model),
        ('#CC8500', '#FFA500', cycleModel == constants.autophagy_death_model),
    )

    entries = [(fill, stroke, np.logical_and(mask, unselected)) for fill, stroke, mask in palette]
    entries.append(('#AAAAAA', 'white', isSelected))
    return entries

def noFilter(cells, variables):
    """
    Filter function that keeps every cell, so all cells are sent to the renderer.

        cells - cell data, one column per cell
        variables - map from cell variable name to row index (unused)

    Returns a boolean mask with one True entry per cell column.
    """
    cellCount = cells.shape[1]
    return np.ones(cellCount, dtype=bool)

class Timer:
    """
    Utility class that runs a callback once, after a fixed delay, on the asyncio event loop.
    From the Jupyter documentation https://ipywidgets.readthedocs.io/en/latest/examples/Widget%20Events.html#Debouncing

        timeout - the delay in seconds before the callback runs
        callback - a zero-argument callable run once the delay has elapsed

    """

    def __init__(self, timeout, callback):
        self._timeout = timeout
        self._callback = callback
        # Set by start(); None until then so cancel() is safe on a timer that was never started.
        self._task = None

    async def _job(self):
        await asyncio.sleep(self._timeout)
        self._callback()

    def start(self):
        """Schedule the delayed callback as an asyncio task."""
        self._task = asyncio.ensure_future(self._job())

    def cancel(self):
        """Cancel the pending callback, if any. Safe before start() and after the callback has run."""
        if self._task is not None:
            self._task.cancel()

def debounce(wait):
    """
    Decorator factory that rate-limits calls to the decorated function.
    Based on the Jupyter documentation https://ipywidgets.readthedocs.io/en/latest/examples/Widget%20Events.html#Debouncing

    A call made at least `wait` seconds after the last immediate call runs right away. A call made
    sooner is deferred by `wait` seconds using Timer. Every new call cancels any deferred call still
    pending, so within a burst only the most recent deferred call runs.

        wait - seconds required between immediate calls, and the delay applied to deferred calls

    NOTE: the timer and last-call time live in the decorator's closure. When applied to a method,
    this state is shared by every instance of the class.
    """
    import functools

    def decorator(fn):
        timer = None
        lastCall = time()

        # wraps() keeps fn's __name__ and __doc__ on the returned wrapper.
        @functools.wraps(fn)
        def debounced(*args, **kwargs):
            nonlocal timer, lastCall
            # Drop any deferred call left pending by an earlier invocation.
            if timer is not None:
                timer.cancel()
                timer = None
            if time() - lastCall >= wait:
                lastCall = time()
                fn(*args, **kwargs)
            else:
                timer = Timer(wait, lambda: fn(*args, **kwargs))
                timer.start()

        return debounced

    return decorator

@output.capture()
def defaultColorMap3d(cells, variables, selectedCellId):
    """
    Default 3D cell coloring function. Colors by cycle_model:

        live: green
        apoptosis_death: blue
        necrosis_death: red
        autophagy_death: orange

        selected_cell: white

    A cell whose cycle_model matches none of the above gets tag 0, which maps to black.

        cells - cell data, one column per cell, rows indexed through `variables`
        variables - map from cell variable name to row index in `cells`
        selectedCellId - ID of the selected cell, or None

    Returns (tags, colorTransferFunction): a vtkFloatArray with one tag (0-5) per cell, and a
    vtkColorTransferFunction mapping each tag to an RGB color.
    """
    cycleModels = cells[variables['cycle_model']]
    colors = (np.zeros(cells.shape[1])
              + (cycleModels < constants.apoptosis_death_model)
              + 2 * (cycleModels == constants.apoptosis_death_model)
              + 3 * (cycleModels == constants.necrosis_death_model)
              + 4 * (cycleModels == constants.autophagy_death_model))
    if selectedCellId is not None:
        colors[cells[variables['ID']] == selectedCellId] = 5

    tags = vtk.vtkFloatArray()
    for tag in colors:
        tags.InsertNextValue(float(tag))

    # RGB components must lie in [0, 1]. The previous 6.0 components were typos: (1.0, 6.0, 1.0)
    # is not orange and, once clamped, cannot be told apart from the white selection color.
    colorTransferFunction = vtk.vtkColorTransferFunction()
    colorTransferFunction.AddRGBPoint(0.0, 0.0, 0.0, 0.0)   # no matching cycle model: black
    colorTransferFunction.AddRGBPoint(1.0, 0.0, 1.0, 0.0)   # live: green
    colorTransferFunction.AddRGBPoint(2.0, 0.2, 0.2, 1.0)   # apoptosis: blue
    colorTransferFunction.AddRGBPoint(3.0, 1.0, 0.0, 0.0)   # necrosis: red
    colorTransferFunction.AddRGBPoint(4.0, 1.0, 0.65, 0.0)  # autophagy: orange (#FFA500)
    colorTransferFunction.AddRGBPoint(5.0, 1.0, 1.0, 1.0)   # selected cell: white

    return tags, colorTransferFunction

def defaultEnvironment(environment, attribute):
    """
    Default 3D environment renderer. Builds a maximum-intensity volume of one environment attribute
    in which the lowest value in the frame is black (fully transparent) and the highest is white
    (opacity 0.8).

        environment - the frame's environment (reads environment.mesh and environment.current)
        attribute - name of the environment attribute to render

    Returns a vtkVolume positioned and scaled to the mesh bounds.
    """
    mesh = environment.mesh
    xbounds = mesh.boundsX
    ybounds = mesh.boundsY
    zbounds = mesh.boundsZ

    data = environment.current.data[environment.current.variables[attribute]]

    # The number of distinct voxel coordinates along each axis gives the grid dimensions.
    xCount = np.unique(mesh.voxels[mesh.variables['x']]).shape[0]
    yCount = np.unique(mesh.voxels[mesh.variables['y']]).shape[0]
    zCount = np.unique(mesh.voxels[mesh.variables['z']]).shape[0]

    # Min-max normalize to [0, 1] so the color and opacity ramps below span the frame's full range.
    # A constant field keeps the previous convention (non-zero -> 1, zero -> 0) instead of dividing by zero.
    minimum = np.min(data)
    maximum = np.max(data)
    if maximum > minimum:
        data = (data - minimum) / (maximum - minimum)
    else:
        data = np.full(np.shape(data), 1.0 if maximum != 0 else 0.0)

    data = np.reshape(data, (xCount, yCount, zCount))

    imdata = vtk.vtkImageData()
    depthArray = numpy_support.numpy_to_vtk(data.ravel(), deep=True, array_type=vtk.VTK_DOUBLE)

    imdata.SetDimensions(data.shape)
    imdata.SetSpacing([(xbounds[1] - xbounds[0]) / xCount, (ybounds[1] - ybounds[0]) / yCount, (zbounds[1] - zbounds[0]) / zCount])
    imdata.SetOrigin([xbounds[0], ybounds[0], zbounds[0]])
    imdata.GetPointData().SetScalars(depthArray)

    colorFunc = vtk.vtkColorTransferFunction()
    colorFunc.AddRGBPoint(0.0, 0.0, 0.0, 0.0)
    colorFunc.AddRGBPoint(1.0, 1.0, 1.0, 1.0)

    opacity = vtk.vtkPiecewiseFunction()
    opacity.AddPoint(0.0, 0.0)
    opacity.AddPoint(1, 0.8)

    volumeProperty = vtk.vtkVolumeProperty()
    volumeProperty.SetColor(colorFunc)
    volumeProperty.SetScalarOpacity(opacity)
    volumeProperty.SetInterpolationTypeToLinear()
    # Same effect as the former SetIndependentComponents(2), which VTK clamps to 1 (on).
    volumeProperty.IndependentComponentsOn()

    volumeMapper = vtk.vtkFixedPointVolumeRayCastMapper()
    volumeMapper.SetInputData(imdata)
    volumeMapper.SetBlendModeToMaximumIntensity()

    volume = vtk.vtkVolume()
    volume.SetMapper(volumeMapper)
    volume.SetProperty(volumeProperty)

    return volume

def defaultRender2DEnvironment(attribute, environment, rRange=(0,255), gRange=(0,255), bRange=(0,255), aRange=(0, 255)):
    """
    Default 2D environment renderer. Maps one environment attribute to an RGBA image. The lowest value
    in the frame maps to the low end of each channel range and the highest value to the high end. With
    the default ranges this is transparent black for the lowest value and opaque white for the highest.

        attribute - name of the environment attribute to render
        environment - the frame's environment (reads environment.mesh and environment.current)
        rRange, gRange, bRange, aRange - (low, high) output values of the red, green, blue and alpha channels

    Returns an int32 array of shape (xCount, yCount, 4).
    """
    mesh = environment.mesh
    xCount = np.unique(mesh.voxels[mesh.variables['x']]).shape[0]
    yCount = np.unique(mesh.voxels[mesh.variables['y']]).shape[0]

    data = environment.current.data[environment.current.variables[attribute]]

    # Min-max normalize to [0, 1]. A constant field keeps the previous convention (non-zero -> 1,
    # zero -> 0); an all-zero field used to divide 0/0 and produce NaN before the int cast.
    minimum = np.min(data)
    maximum = np.max(data)
    if maximum > minimum:
        normalized = (data - minimum) / (maximum - minimum)
    else:
        normalized = np.full(np.shape(data), 1.0 if maximum != 0 else 0.0)

    shapedData = np.reshape(normalized, (xCount, yCount))

    channels = [shapedData * (high - low) + low for low, high in (rRange, gRange, bRange, aRange)]
    return np.stack(channels, axis=2).astype(dtype=np.int32)

##################################################################### BASE INTERACTOR ####################################

class Interactor:
    """
    Base interactor class. Is an abstract class. Handles shared functionality such as canvases and graphs.

    Subclasses implement selectCell, may override the onMouseDown/onMouseMove/onMouseUp/onMouseOut
    hooks, and override update (calling this class's update after their own drawing).

    """
    def __init__(self, parser: Parser, width: int, height: int, colorMap, filterFunction, availableActions):
        """
        Setup the interactor variables, should be called first in a sub class.
        
            parser - a simulation output parser
            width - the width of the canvas to draw to
            height - the height of the canvas to draw to
            colorMap - a function resulting in the cell colors (distinct results based on subclass, see subclass documentation for necessary return values)
            filterFunction - a function resulting in an array of the indices of the cells to draw
            availableActions - a list of all the actions that can be performed on the canvas; the first one is selected initially
        
        """
        self._currentFrame = parser.getFrameRange()[0]
        self._canvas = Canvas(width=width, height=height)
        self._parser = parser
        self._colorMap = colorMap
        self._height = height
        self._width = width
        self._filterFunction = filterFunction
        
        # The cell-attribute figure and its axes are created lazily by show().
        self._cellFigure = None
        self._cellAx = None
        self._generalGraph = GeneralPlot(parser)
        
        frame = parser.getFrame(self._currentFrame)
        self._availableEnvironments = [*frame.environment.current.attributes]
        
        self._canvas.on_mouse_down(self._onMouseDown)
        self._canvas.on_mouse_move(self._onMouseMove)
        self._canvas.on_mouse_up(self._onMouseUp)
        self._canvas.on_mouse_out(self._onMouseOut)
        
        self.action = availableActions[0]
        
        self._buttons = widgets.RadioButtons(
            options=availableActions,
            value=self.action,
            description='Mouse Action:',
            disabled=False,
        )
        self._buttons.observe(self.onToolChange, names='value')
        
        self._visible = (VISIBLE_CELL,)
        self._environmentButtons = widgets.SelectMultiple(
            options=[VISIBLE_CELL, *self._availableEnvironments],
            description='Visible:',
            disabled=False,
            value=[VISIBLE_CELL],
        )
        self._environmentButtons.observe(self.onVisibilityChange, names='value')
        
        self._availableAttributes = [*frame.cells.variables.keys()]
        self._selectedAttribute = self._availableAttributes[0]
        # Attribute/cell pair last drawn in the cell plot, used to skip redundant redraws.
        self._previousSelectedAttribute = None
        self._previousSelectedAttributeCell = None
        self._attributes = widgets.RadioButtons(
            options=self._availableAttributes,
            value=self._selectedAttribute,
            description='Generate Attribute Graph:\n',
            disabled=False,
        )
        self._attributes.observe(self.onAttriChange, names='value')

        self._frameSelector = widgets.IntSlider(
            value=self._currentFrame,
            min=parser.getFrameRange()[0],
            max=parser.getFrameRange()[1] - 1,
            step=1,
            description='Frame:',
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='d'
        )
        self._frameSelector.observe(self.onFrameChange, names='value')
        
        self._clicking = False
        self._dragStartX = 0
        self._dragStartY = 0
        self._actionOriginX = 0
        self._actionOriginY = 0
        
        self._selectedCell = None
    
    def setColorMap(self, colorMap):
        """
        Set the cell coloring function and redraw.
        
            colorMap: the cell coloring function
        
        """
        self._colorMap = colorMap
        self.update()
    
    def setFilterFunction(self, filterFunction):
        """
        Set the cell filtering function and redraw.
        
            filterFunction: the filtering function
        
        """
        self._filterFunction = filterFunction
        self.update()
    
    def onToolChange(self, action):
        """
        Handle a user action which changes the canvas tool
        
            action: an ipyWidget radio button selection action
        
        """
        self.action = action.new
    
    @output.capture()
    def onVisibilityChange(self, action):
        """
        Handle a user action which changes the elements which are visible on the canvas
        
            action: an ipyWidget multi selection action
        
        """
        self._visible = action.new
        self.update()
        
    @output.capture()    
    def onAttriChange(self, action):
        """
        Handle a user action which changes the attribute to display for the selected cell
        
            action: an ipyWidget radio button selection action
        
        """
        self._selectedAttribute = action.new
        self._updateCellPlot()
        
    @output.capture()
    def _updateCellPlot(self):
        """
        Update the plot which shows how the currently selected attribute changes over time for the selected cell.

        Does nothing when no cell is selected, when the figure has not been created yet (see show),
        or when the attribute/cell pair has not changed since the last redraw. Frames in which the
        selected cell is not present are plotted as NaN, which leaves a gap in the line.
        
        """
        # Return before recording the pair, so a selection made before show() is still drawn later.
        if self._selectedCell is None or self._cellFigure is None:
            return
        if self._previousSelectedAttribute == self._selectedAttribute and self._previousSelectedAttributeCell == self._selectedCell:
            return
        
        self._previousSelectedAttribute = self._selectedAttribute
        self._previousSelectedAttributeCell = self._selectedCell
        
        data = []
        timestamps = []
        for timestamp, frameNumber in enumerate(range(*self._parser.getFrameRange())):
            cell = self.getSelectedCell(frameNumber)
            # getCell returns an empty array when the cell ID is not present in that frame.
            if cell.size == 0:
                data.append(np.nan)
            else:
                variables = self.getCellVariables(frameNumber)
                data.append(cell[variables[self._selectedAttribute]])
            timestamps.append(timestamp)
            
        # Select the figure by number: figures from plt.subplots() have an empty label, so looking up
        # by label could activate a different unlabeled figure.
        plt.figure(self._cellFigure.number)
        self._cellAx.clear()
        ln, = self._cellAx.plot(timestamps, data)
        self._cellAx.set_xlabel('time')
        self._cellAx.set_ylabel(self._selectedAttribute)
        self._cellAx.set_title(self._selectedAttribute + " for cell " + str(int(self._selectedCell)))
        ln.set_color('orange')
        # Ask interactive backends to repaint, since this runs from a widget callback.
        self._cellFigure.canvas.draw_idle()
        
    def onFrameChange(self, action):
        """
        Handle a user action which changes the frame to view
        
            action: an ipyWidget slider action
        
        """
        self._currentFrame = action.new
        self.update()
    

    def getCellVariables(self, frame: int = None):
        """
        Get the cell variables (name -> row index) for the frame
        
            frame: int = currentFrame - the frame number to get the variables from
        
        """
        if frame is None:
            frame = self._currentFrame
        return self._parser.getFrame(frame).cells.variables
    
    def getCell(self, cellId: int, frame: int = None):
        """
        Get the cell from the frame (uses the current frame if not specified).
        Returns an empty array if no cell with this ID exists in the frame.
        
            cellId: int - the id of the cell to get
            frame: int = currentFrame - the frame number to get the cell from
        
        """
        if frame is None:
            frame = self._currentFrame
        cells = self.getCells(frame)
        # Use the same frame's variable layout as the data being indexed.
        variables = self.getCellVariables(frame)
        cellColumn = cells[variables['ID']] == cellId
        return cells[:,cellColumn].reshape(-1)
    
    def getCells(self, frame: int = None):
        """
        Get the cells from the frame (uses the current frame if not specified)
        
            frame: int = currentFrame - the frame number to get the cells from
        
        """
        if frame is None:
            frame = self._currentFrame
        return self._parser.getFrame(frame).cells.data
    
    def getCellsToRender(self, frame: int = None):
        """
        Get the cells to render from the given frame (uses the current frame if not specified). Applies
        the current cell filtering function on the result.
        
            frame: int = currentFrame - the frame number to get the cells from
        
        """
        cells = self.getCells(frame)
        
        if self._filterFunction:
            cellVariables = self.getCellVariables(frame)
            cells = cells[:,self._filterFunction(cells, cellVariables)]
        
        return cells
    
    def getSelectedCell(self, frame: int = None):
        """
        Get the cell data for the cell which is currently selected on the specified frame (uses the current frame if not specified).
        Will be None if no cell is selected, and an empty array if the selected cell is not present in that frame.
        
            frame: int = currentFrame - the frame number to get the cells from
        
        """
        if self._selectedCell is None:
            return None
        
        return self.getCell(self._selectedCell, frame)
    
    @output.capture()
    def _onMouseDown(self, x: int, y: int):
        """
        Handle user mouse down action from ipyCanvas canvas
        
            x: int - the x position the mouse event occurred based on the top left corner of the canvas
            y: int - the y position the mouse event occurred based on the top left corner of the canvas
        
        """
        self._dragStartX = x
        self._dragStartY = y
        self._actionOriginX = x
        self._actionOriginY = y
        
        self.onMouseDown(x, y)
        
        if self.action == SELECT_ACTION:
            self.selectCell(x, y)
        else:
            self._clicking = True
    
    def selectCell(self, x: int, y: int):
        """
        Select the cell at the given x and y coordinates.
        Abstract function, should be implemented by subclass.
        
            x: int - the x position the mouse event occurred based on the top left corner of the canvas
            y: int - the y position the mouse event occurred based on the top left corner of the canvas
        
        """
        pass
    
    def _onMouseUp(self, x: int, y: int):
        """
        Handle user mouse up action from ipyCanvas canvas
        
            x: int - the x position the mouse event occurred based on the top left corner of the canvas
            y: int - the y position the mouse event occurred based on the top left corner of the canvas
        
        """
        self._clicking = False
        self.onMouseUp(x, y)
    
    def _onMouseOut(self, x: int, y: int):
        """
        Handle user mouse out (mouse moves off the canvas) action from ipyCanvas canvas
        
            x: int - the x position the mouse event occurred based on the top left corner of the canvas
            y: int - the y position the mouse event occurred based on the top left corner of the canvas
        
        """
        self._clicking = False
        self.onMouseOut(x, y)
    
    @output.capture()
    def _onMouseMove(self, x: int, y: int):
        """
        Handle user mouse move action from ipyCanvas canvas
        
            x: int - the current x position of the mouse, based on the top left corner of the canvas
            y: int - the current y position of the mouse, based on the top left corner of the canvas
        
        """
        self.onMouseMove(x, y)
    
    def onMouseUp(self, x: int, y: int):
        """
        Handle a mouse up event at the given x and y coordinates.
        Abstract function, can be implemented by subclass.
        
            x: int - the x position the mouse event occurred based on the top left corner of the canvas
            y: int - the y position the mouse event occurred based on the top left corner of the canvas
        
        """
        pass

    def onMouseOut(self, x: int, y: int):
        """
        Handle a mouse out event at the given x and y coordinates.
        Abstract function, can be implemented by subclass.
        
            x: int - the x position the mouse event occurred based on the top left corner of the canvas
            y: int - the y position the mouse event occurred based on the top left corner of the canvas
        
        """
        pass
    
    def onMouseMove(self, x: int, y: int):
        """
        Handle a mouse move event at the given x and y coordinates.
        Abstract function, can be implemented by subclass.
        
            x: int - the x position the mouse event occurred based on the top left corner of the canvas
            y: int - the y position the mouse event occurred based on the top left corner of the canvas
        
        """
        pass
    
    def onMouseDown(self, x: int, y: int):
        """
        Handle a mouse down event at the given x and y coordinates.
        Abstract function, can be implemented by subclass.
        
            x: int - the x position the mouse event occurred based on the top left corner of the canvas
            y: int - the y position the mouse event occurred based on the top left corner of the canvas
        
        """
        pass
        
    def update(self):
        """
        Update UI, should be called after subclass updates. Writes the selected cell's ID, volume and
        phase onto the canvas (when that cell is present in the current frame) and refreshes the cell plot.
        
        """
        canvas = self._canvas
        
        selectedCell = self.getSelectedCell()
        # The selected cell may be absent from the current frame, in which case getCell returned an empty array.
        if selectedCell is not None and selectedCell.size > 0:
            cellVariables = self.getCellVariables()
            
            canvas.fill_style = '#A0A0A0'
            canvas.font = '10px serif'
            
            identity = selectedCell[cellVariables["ID"]]
            volume = selectedCell[cellVariables["total_volume"]]
            phase = selectedCell[cellVariables["current_phase"]]
            
            canvas.fill_text(f"(ID: {identity}, Volume: {volume}, Phase: {phase})", 10, self._height - 10)
            
        self._updateCellPlot()
    
    def radiusOfCells(self, cells, variables):
        """
        Determines the radius of the cells from their (3D) volume, treating each cell as a sphere:
        r = (3V / (4 pi)) ** (1/3). Results in a numpy array.
        
            cells - the cells to get the radius of
            variables - the variables of the cells
        
        """
        return (cells[variables['total_volume']] * (3 / ( 4 * math.pi))) ** (1 / 3)
    
    def show(self):
        """
        Displays the widgets to the screen, creates the cell-attribute figure the first time it is
        called, and plots the population graph in a new figure.
        
        """
        display(self._canvas, self._buttons, self._frameSelector, self._attributes, self._environmentButtons, output)
        if self._cellFigure is None:
            global figureNumber
            fig, ax = plt.subplots()
            self._cellAx = ax
            self._cellFigure = fig
            figureNumber += 1
        plt.figure()
        self._generalGraph.plotPop()

##################################################################### 2D INTERACTOR ####################################

class Interactor2D(Interactor):
    """
    An interactive visualization of a 2D PhysiCell simulation
    
    """
    def __init__(self, parser: Parser, width: int = 500, height: int = 400, colorMap = defaultColorMap, filterFunction = None, backgroundColorRGB = "black", environmentRenderer = defaultRender2DEnvironment):
        """
        Visualize a 2D PhysiCell simulation
        
            parser - a simulation output parser
            width: int = 500 - the width of the canvas to draw to
            height: int = 400 - the height of the canvas to draw to
            colorMap: (cells: numpyArray, variables: Map<string, indices>, selectedCellID: ?string) => Array<Tuple<fillColor: string, strokeColor: string, cellIndices: numpyArray>> = defaultColorMap - a function resulting in the cell colors
            filterFunction: (cells: numpyArray, variables: Map<string, indices>) => numpyArray (boolean/indices) = None - a function resulting in an array of the indices of the cells to draw
            backgroundColorRGB: string = "black" - the color to paint the background
            environmentRenderer: (attribute: string, environment: Parser.Environment) => numpyArray<x, y, 1/2/3/4> = defaultRender2DEnvironment - given an environment and the attribute to render, create an image (numpy array) to render (will span the bounds of the attribute)
        
        """
        super().__init__(parser, width, height, colorMap, filterFunction, (MOVE_ACTION, ZOOM_ACTION,SELECT_ACTION))

        frame = parser.getFrame(self._currentFrame)
        mesh = frame.environment.mesh
        # Environment units per canvas pixel, chosen so the whole mesh fits on the canvas.
        self._zoom = max((mesh.boundsX[1] - mesh.boundsX[0]) / width, (mesh.boundsY[1] - mesh.boundsY[0]) / height)
        
        # Environment coordinates of the canvas's top-left corner.
        self._xOffset = mesh.boundsX[0]
        self._yOffset = mesh.boundsY[0]
        
        self._environmentRenderer = environmentRenderer
        
        # Off-screen canvas that update() draws to before copying it onto the visible canvas.
        self._buffer = Canvas(width=width, height=height)
        
        self._backgroundColor = backgroundColorRGB
        
        self.update()
    
    @output.capture()
    def drawEnvironment(self, attribute: str):
        """
        draw the environment's attribute to the canvas, scaled and positioned to the mesh bounds
        
            attribute: string - the attribute of the environment to draw
        
        """
        if attribute is None:
            return
        canvas = self._canvas
        
        environment = self._parser.getFrame(self._currentFrame).environment

        xbounds = environment.mesh.boundsX
        ybounds = environment.mesh.boundsY

        image_data = self._environmentRenderer(attribute, environment)
        xCount = image_data.shape[0]
        
        # Scale so that xCount image pixels span the mesh's x extent at the current zoom.
        scale = (xbounds[1] - xbounds[0]) / ( xCount * self._zoom )
        # Mesh origin expressed in the scaled canvas coordinate system.
        x = (xbounds[0] - self._xOffset) / (scale * self._zoom)
        y = (ybounds[0] - self._yOffset) / (scale * self._zoom)

        # Exactly one save per restore. Previously save() was called twice, so every redraw left an
        # extra entry on the canvas state stack.
        canvas.save()
        try:
            canvas.scale( scale )
            canvas.put_image_data(image_data, x,  y)
        finally:
            canvas.restore()
    
    def drawCells(self):
        """
        draw the cells of the current frame (after filtering) to the canvas, one fill/stroke pass per color group.
        
        """
        canvas = self._canvas
    
        cells = self.getCellsToRender()
        cellVariables = self.getCellVariables()

        x = (cells[cellVariables['position.x']] - self._xOffset) / self._zoom
        y = (cells[cellVariables['position.y']] - self._yOffset) / self._zoom
        r = self.radiusOfCells(cells, cellVariables) / self._zoom
        
        # Rows are x, y, r in canvas pixels; columns are cells.
        combined = np.array([x, y, r])
        
        for fill, stroke, indices in self._colorMap(cells, cellVariables, self._selectedCell):
            split = combined[:,indices]
            # Skip color groups with no cells (the column count; the row count is always 3).
            if split.shape[1] == 0:
                continue
            x, y, r = split
            canvas.fill_style = fill
            canvas.fill_circles(x, y, r)
            canvas.stroke_style = stroke
            canvas.stroke_circles(x, y, r)
        
    
    def update(self):
        """
        Redraw the view. The background, visible environment attributes, cells and selected-cell
        label are drawn onto the off-screen buffer, which is then copied onto the visible canvas
        with a single draw_image call.
        
        """
        actualCanvas = self._canvas
        buffer = self._buffer
        
        # The drawing helpers and the base update() draw onto self._canvas, so point it at the buffer.
        self._canvas = buffer
        try:
            buffer.fill_style = self._backgroundColor
            buffer.fill_rect(0, 0, self._width, self._height)
            
            shouldDrawCells = False
            for element in self._visible:
                if element == VISIBLE_CELL:
                    shouldDrawCells = True
                    continue
                self.drawEnvironment(element)
            
            # Cells are drawn after every environment layer so they stay on top.
            if shouldDrawCells:
                self.drawCells()
            
            super().update()
            
            actualCanvas.draw_image(buffer)
        finally:
            # Always point self._canvas back at the visible canvas, even if drawing failed.
            self._canvas = actualCanvas
    
        
    def onMouseDown(self, x: int, y: int):
        """
        Handle a mouse down event. Records the pressed point in environment coordinates and the
        zoom at that moment, as the anchor for zoom drags.
        
        """
        self._envActionOriginX = x * self._zoom + self._xOffset
        self._envActionOriginY = y * self._zoom + self._yOffset
        self._actionOriginZoom = self._zoom
    
    # NOTE: debounce keeps its state in the decorator closure, so this rate limit is shared by all
    # Interactor2D instances.
    @debounce(0.05)
    def onMouseMove(self, x: int, y: int):
        """
        Handle a mouse move event. While dragging, pans the view (move tool) or zooms around the
        pressed point (zoom tool; dragging up by 25 pixels halves the zoom value), then redraws.
        
        """
        if self._clicking:
            if self.action == MOVE_ACTION:
                self._xOffset -= (x - self._dragStartX) * self._zoom
                self._yOffset -= (y - self._dragStartY) * self._zoom
            elif self.action == ZOOM_ACTION:
                self._zoom = max(self._actionOriginZoom * 2 ** ((self._actionOriginY - y) / 25), 0.0001)
                # Keep the environment point under the original mouse-down position fixed.
                self._xOffset = self._envActionOriginX - self._actionOriginX * self._zoom
                self._yOffset = self._envActionOriginY - self._actionOriginY * self._zoom
                
            self._dragStartX = x
            self._dragStartY = y
            self.update()
    
    def selectCell(self, x: float, y: float):
        """
        Select the rendered cell containing the given canvas point. When several cells contain it,
        pick the one for which (centre distance - radius) is smallest. Clears the selection when the
        point lies outside every rendered cell (or none are rendered), then redraws.

            x: float - the x position on the canvas
            y: float - the y position on the canvas
        
        """
        # Convert canvas coordinates to environment coordinates (same transform as onMouseDown).
        envX = x * self._zoom + self._xOffset
        envY = y * self._zoom + self._yOffset
        
        cells = self.getCellsToRender()
        cellVariables = self.getCellVariables()
        
        self._selectedCell = None
        # argmin fails on an empty array, e.g. when the filter removed every cell.
        if cells.shape[1] > 0:
            distances = np.sqrt(np.square(envX - cells[cellVariables['position.x']]) + np.square(envY - cells[cellVariables['position.y']])) - self.radiusOfCells(cells, cellVariables)
            minIndex = np.argmin(distances)
            if distances[minIndex] <= 0:
                self._selectedCell = cells[cellVariables['ID'], minIndex]
        
        self.update()

##################################################################### 3D INTERACTOR ####################################
    

class Interactor3D(Interactor):
    """
    Interactive view of a 3D PhysiCell simulation.

    The scene is built with VTK and rendered off-screen. Each update encodes it as a PNG and
    draws it onto the canvas. Supported mouse actions: move (pan), zoom, rotate, select.
    """

    def __init__(self, parser: Parser, width: int = 500, height: int = 400, colorMap = defaultColorMap3d, filterFunction = None, environmentRenderer = defaultEnvironment):
        """
        Visualize a 3D PhysiCell simulation

            parser: Parser - a simulation output parser
            width: int = 500 - the width of the canvas to draw to
            height: int = 400 - the height of the canvas to draw to
            colorMap: (cells: numpyArray, variables: Map<string, indices>, selectedCellID: ?string) => Tuple<vtkFloatArray, vtkColorTransferFunction> = defaultColorMap3d - a function resulting in the cell colors and the lookup table used to map them
            filterFunction: (cells: numpyArray, variables: Map<string, indices>) => numpyArray (boolean/indices) = None - a function resulting in an array of the indices of the cells to draw
            environmentRenderer: (environment: Parser.Environment, attribute: string) => vtkVolume = defaultEnvironment - given a frame's environment and the name of the attribute to render, create the vtkVolume that is added to the scene

        """
        super().__init__(parser, width, height, colorMap, filterFunction, (MOVE_ACTION, ZOOM_ACTION, ROTATE_ACTION, SELECT_ACTION))

        # Inputs of the last scene built by create(); used to skip rebuilding an unchanged scene.
        self._previousFrameNumber = None
        self._previouslySelectedCell = None
        self._previousFilterFunction = None
        self._previouslyVisible = None
        # Accumulated dolly factor applied by zoom().
        self._CurrentZoom = 1.0
        self._environmentRenderer = environmentRenderer
        self._colorMap = colorMap

        self.renderer = vtk.vtkRenderer()

        # Build the initial scene before the first Render() of the window.
        self.create(self._currentFrame)

        renderWindow = vtk.vtkRenderWindow()
        renderWindow.SetOffScreenRendering(1)
        renderWindow.AddRenderer(self.renderer)
        renderWindow.SetSize(width, height)
        renderWindow.Render()

        self.renderWindow = renderWindow

        self.update()

    def drawCells(self, frame=None):
        """
        Add one actor to the renderer that draws every cell of the frame as a sphere,
        scaled by the cell radius and colored by the colorMap.

            frame - the frame (Frame) to draw the cells for; defaults to the current frame
        """
        if frame is None:
            frame = self._parser.getFrame(self._currentFrame)

        cells = self.getCellsToRender(frame.frameNumber)
        variables = self.getCellVariables(frame.frameNumber)

        # row indices of the coordinates inside the cells matrix
        xRow = variables["position.x"]
        yRow = variables["position.y"]
        zRow = variables["position.z"]
        r = self.radiusOfCells(cells, variables)

        data = vtk.vtkPolyData()
        points = vtk.vtkPoints()
        radii = vtk.vtkFloatArray()
        radii.SetName("radius")

        colors, lookupTable = self._colorMap(cells, variables, self._selectedCell)
        colors.SetName("color")

        for cell in range(cells.shape[1]):
            points.InsertNextPoint(cells[xRow, cell], cells[yRow, cell], cells[zRow, cell])
            radii.InsertNextValue(float(r[cell]))

        data.SetPoints(points)
        data.GetPointData().AddArray(radii)
        data.GetPointData().AddArray(colors)
        data.GetPointData().SetActiveScalars("color")

        # Source - unit sphere, scaled per cell by the glyph filter
        ball = vtk.vtkSphereSource()
        ball.SetRadius(1)
        ball.SetThetaResolution(8)
        ball.SetPhiResolution(8)

        # Glyph - one sphere per point; array 0 ('radius') scales, array 3 ('color') colors
        ballGlyph = vtk.vtkGlyph3D()
        ballGlyph.SetInputData(data)
        ballGlyph.SetScaleFactor(1)
        ballGlyph.ClampingOff()
        ballGlyph.SetColorModeToColorByScalar()
        ballGlyph.SetSourceConnection(ball.GetOutputPort())
        ballGlyph.SetInputArrayToProcess(0,0,0,0,'radius')
        ballGlyph.SetInputArrayToProcess(3,0,0,0,'color')

        # Mapper - fed by the glyph output (a prior SetInputData here was immediately replaced)
        ballMapper = vtk.vtkPolyDataMapper()
        ballMapper.SetInputConnection(ballGlyph.GetOutputPort())
        ballMapper.ScalarVisibilityOn()
        ballMapper.SetScalarModeToUsePointData()
        ballMapper.SelectColorArray("color")
        ballMapper.SetLookupTable(lookupTable)

        # Actor - ball
        ballActor = vtk.vtkActor()
        ballActor.SetMapper(ballMapper)

        self.renderer.AddActor(ballActor)

    def create(self, frameNumber: int):
        """
        Rebuild the scene for the given frame. Each visible environment attribute gets a
        volume, and the cells are drawn when VISIBLE_CELL is visible. Does nothing when the
        frame, selected cell, filter function and visible elements are unchanged since the
        last call.

            frameNumber - the frame to render
        """
        if (frameNumber == self._previousFrameNumber
                and self._selectedCell == self._previouslySelectedCell
                and self._previousFilterFunction == self._filterFunction
                and self._previouslyVisible == self._visible):
            return

        self._previouslySelectedCell = self._selectedCell
        self._previousFrameNumber = frameNumber
        self._previousFilterFunction = self._filterFunction
        self._previouslyVisible = self._visible

        frame = self._parser.getFrame(frameNumber)

        self.renderer.RemoveAllViewProps()

        shouldDrawCells = False
        for element in self._visible:
            if element == VISIBLE_CELL:
                shouldDrawCells = True
                continue
            # every other visible element names an environment attribute
            self.renderer.AddVolume(self._environmentRenderer(frame.environment, element))

        if shouldDrawCells:
            self.drawCells(frame)

        self.renderer.ResetCameraClippingRange()

    def update(self):
        """
        Rebuild the scene if needed and render it off-screen.
        Encode the render as a PNG, draw it onto the canvas, then run the base class update.
        """
        canvas = self._canvas

        self.create(self._currentFrame)

        self.renderWindow.Render()

        windowToImageFilter = vtk.vtkWindowToImageFilter()
        windowToImageFilter.SetInput(self.renderWindow)
        windowToImageFilter.Update()

        writer = vtk.vtkPNGWriter()
        writer.SetWriteToMemory(1)
        writer.SetInputConnection(windowToImageFilter.GetOutputPort())
        writer.Write()

        data = memoryview(writer.GetResult()).tobytes()

        image = Image(value=data)

        canvas.draw_image(image)

        super().update()

    @output.capture()
    @debounce(0.1)
    def onMouseMove(self, x: int, y: int):
        """
        Handle a mouse move event while a button is held. Depending on the current action,
        the scene is panned, zoomed or rotated by the movement since the previous event.

        """
        if self._clicking:
            if self.action == MOVE_ACTION:
                self.pan(x, y)
            elif self.action == ZOOM_ACTION:
                # vtkCamera.Zoom is relative, so apply only the movement since the previous
                # event. Using the distance from the action origin compounded on every event,
                # even for purely horizontal motion. The exponential form keeps the factor > 0.
                self.renderer.GetActiveCamera().Zoom(1.01 ** (self._dragStartY - y))
            elif self.action == ROTATE_ACTION:
                self.rotate(x, y)

            self._dragStartX = x
            self._dragStartY = y
            self.update()

    def pan(self, x: float, y: float):
        """
        Handle a pan action. Pans by the pointer movement from the last drag position
        (self._dragStartX/Y) to the current x, y coordinates.
        Adapted from:

           https://compucell3d.org/BinDoc/cc3d_binaries/dependencies/windows/MinGW/dependencies_qt_4.8.4_pyqt_4.9.6_vtk_5.10.1_python27/Player/vtk/wx/wxVTKRenderWindow.py

        """
        renderer = self.renderer
        camera = renderer.GetActiveCamera()
        (pPoint0,pPoint1,pPoint2) = camera.GetPosition()
        (fPoint0,fPoint1,fPoint2) = camera.GetFocalPoint()

        # Convert the focal point to display coordinates to get its depth
        renderer.SetWorldPoint(fPoint0,fPoint1,fPoint2,1.0)
        renderer.WorldToDisplay()
        dPoint = renderer.GetDisplayPoint()
        focalDepth = dPoint[2]

        # canvas y grows downwards, display y grows upwards
        aPoint0 = self._width / 2 + (x - self._dragStartX)
        aPoint1 = self._height / 2 - (y - self._dragStartY)

        renderer.SetDisplayPoint(aPoint0,aPoint1,focalDepth)
        renderer.DisplayToWorld()

        (rPoint0,rPoint1,rPoint2,rPoint3) = renderer.GetWorldPoint()
        if (rPoint3 != 0.0):
            rPoint0 = rPoint0/rPoint3
            rPoint1 = rPoint1/rPoint3
            rPoint2 = rPoint2/rPoint3

        # shift focal point and position by the same world-space offset
        camera.SetFocalPoint((fPoint0 - rPoint0) + fPoint0,
                             (fPoint1 - rPoint1) + fPoint1,
                             (fPoint2 - rPoint2) + fPoint2)

        camera.SetPosition((fPoint0 - rPoint0) + pPoint0,
                           (fPoint1 - rPoint1) + pPoint1,
                           (fPoint2 - rPoint2) + pPoint2)

    def rotate(self, x: float, y: float):
        """
        Handle a rotate action. Rotates by the pointer movement from the last drag position
        (self._dragStartX/Y) to the current x, y coordinates.
        Adapted from:

           https://compucell3d.org/BinDoc/cc3d_binaries/dependencies/windows/MinGW/dependencies_qt_4.8.4_pyqt_4.9.6_vtk_5.10.1_python27/Player/vtk/wx/wxVTKRenderWindow.py

        """
        renderer = self.renderer
        camera = renderer.GetActiveCamera()
        camera.Azimuth(self._dragStartX - x)
        camera.Elevation(y - self._dragStartY)
        camera.OrthogonalizeViewUp()

        renderer.ResetCameraClippingRange()

    def zoom(self, x: float, y: float):
        """
        Handle a zoom action by dollying the camera (or shrinking the parallel scale).
        The amount is derived from the vertical movement since the last drag position.
        Updates the drag anchor and redraws the canvas.
        Adapted from:

           https://compucell3d.org/BinDoc/cc3d_binaries/dependencies/windows/MinGW/dependencies_qt_4.8.4_pyqt_4.9.6_vtk_5.10.1_python27/Player/vtk/wx/wxVTKRenderWindow.py

        """

        renderer = self.renderer
        camera = renderer.GetActiveCamera()

        zoomFactor = math.pow(1.02,(0.5*(self._dragStartY - y)))
        # initialized in __init__; the wx original keeps the same running total
        self._CurrentZoom = self._CurrentZoom * zoomFactor

        if camera.GetParallelProjection():
            parallelScale = camera.GetParallelScale()/zoomFactor
            camera.SetParallelScale(parallelScale)
        else:
            camera.Dolly(zoomFactor)
            renderer.ResetCameraClippingRange()

        self._dragStartX = x
        self._dragStartY = y

        # this class has no Render(); update() renders and copies the image to the canvas
        self.update()

    @output.capture()
    def selectCell(self, x: float, y: float):
        """
        Select a cell based on the x, y canvas position clicked.

        A vtkPropPicker is used at that position. When no actor is hit, the selection is
        cleared. Otherwise the rendered cell containing the picked world point is selected,
        or the selection is cleared when no cell contains it. The canvas is redrawn.

        """
        picker = vtk.vtkPropPicker()
        # canvas y grows downwards while VTK display y grows upwards
        picker.Pick(x, self._height - y, 0, self.renderer)

        # last actor hit by the picker (None when nothing was hit)
        self.NewPickedActor = picker.GetActor()

        if not self.NewPickedActor:
            self._selectedCell = None
            self.update()
            return

        pickX, pickY, pickZ = picker.GetPickPosition()

        cells = self.getCellsToRender()
        cellVariables = self.getCellVariables()

        # np.argmin raises ValueError on an empty array
        if cells.shape[1] == 0:
            self._selectedCell = None
            self.update()
            return

        dx = pickX - cells[cellVariables['position.x']]
        dy = pickY - cells[cellVariables['position.y']]
        dz = pickZ - cells[cellVariables['position.z']]
        distances = np.sqrt(np.square(dx) + np.square(dy) + np.square(dz)) - self.radiusOfCells(cells, cellVariables)
        minIndex = np.argmin(distances)

        # same rule as the 2D selection: select if inside, otherwise clear
        if distances[minIndex] <= 0:
            self._selectedCell = cells[cellVariables['ID'], minIndex]
        else:
            self._selectedCell = None

        self.update()
            

def viewSimulation(outputPath: str, width: int = 500, height: int = 400, force3d: bool = False, force2d: bool = False, **kwargs):
    """
    Build an interactive viewer for the PhysiCell simulation output stored at outputPath.

    Returns an Interactor2D when force2d is set. It also returns one when the environment of
    the first frame reports is2D and force3d is not set. Otherwise it returns an
    Interactor3D. force2d wins when both flags are given.

        outputPath: str - path to the folder containing the simulation output
        width: int = 500 - width of the canvas the visualization is drawn to
        height: int = 400 - height of the canvas the visualization is drawn to
        force3d: bool = False - use the 3D viewer even for a 2D simulation
        force2d: bool = False - use the 2D viewer regardless of the simulation's dimensionality
        **kwargs - forwarded to the chosen interactor; see Interactor2D and Interactor3D for accepted names

    """
    parser = Parser(outputPath)

    # Only inspect the first frame when the caller has not already forced the 2D viewer.
    use2d = force2d
    if not use2d:
        firstFrameNumber = parser.getFrameRange()[0]
        use2d = parser.getFrame(firstFrameNumber).environment.is2D and not force3d

    viewerClass = Interactor2D if use2d else Interactor3D
    return viewerClass(parser, width, height, **kwargs)

In [33]:
def filterFunction(cells, variables):
    """Return a boolean mask of the cells whose cycle_model value equals 5."""
    cycleModelRow = cells[variables['cycle_model']]
    return cycleModelRow == 5

# NOTE(review): filterFunction defined above is not passed here, so it has no effect on this view;
# pass filterFunction=filterFunction (forwarded through **kwargs to the interactor) to apply it.
env = viewSimulation('./sample-output', width=900, height=800, force3d=True)
env.show()

Canvas(height=800, width=900)

RadioButtons(description='Mouse Action:', options=('move', 'zoom', 'rotate', 'select'), value='move')

IntSlider(value=0, continuous_update=False, description='Frame:', max=9)

RadioButtons(description='Generate Attribute Graph:\n', options=('ID', 'position.x', 'position.y', 'position.z…

SelectMultiple(description='Visible:', index=(0,), options=('cell', 'oxygen'), value=('cell',))

Output()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [32]:
def filterFunction(cells, variables):
    """Return a boolean mask of the cells whose cycle_model value is anything other than 5."""
    cycleModelRow = cells[variables['cycle_model']]
    return cycleModelRow != 5

# NOTE(review): filterFunction defined above is not passed here, so it has no effect on this view;
# pass filterFunction=filterFunction (forwarded through **kwargs to the interactor) to apply it.
env = viewSimulation('./sample-output-3d', width=800, height=800)
env.show()

Canvas(height=800, width=800)

RadioButtons(description='Mouse Action:', options=('move', 'zoom', 'rotate', 'select'), value='move')

IntSlider(value=0, continuous_update=False, description='Frame:', max=5)

RadioButtons(description='Generate Attribute Graph:\n', options=('ID', 'position.x', 'position.y', 'position.z…

SelectMultiple(description='Visible:', index=(0,), options=('cell', 'oxygen', 'immunostimulatory factor'), val…

Output()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>