In [1]:
import ipyvolume as ipv
import numpy as np
import bqplot as bq
import ipywidgets as ipw
import traitlets as tr
import time
import copy

In [2]:
class TransferFunctionEditor(ipw.VBox):
    """Interactive editor for the transfer function of an ipyvolume figure.

    Shows one bqplot line per RGBA channel of ``fig.tf.rgba`` next to the
    (shrunk) ipyvolume figure. The channel chosen in the selector can be
    redrawn freehand; every edit is written back into the transfer function.

    Parameters
    ----------
    fig : ipyvolume figure
        Figure whose ``tf`` attribute holds the transfer function to edit.
        It is resized in place (not copied).
    alpha_max : float, optional
        Upper bound of the alpha axis (default 0.2).

    Raises
    ------
    ValueError
        If ``fig.tf.rgba`` is still unavailable after waiting (see
        ``init_ipv_fig``).
    """

    def __init__(self, fig, alpha_max=.2):
        super().__init__()

        # Keep a reference (not a copy) to the figure; init_ipv_fig shrinks
        # it in place so it fits next to the editor plot.
        self.ipv_fig = fig
        self.alpha_max = alpha_max
        # True while update_lines pushes tf values into the lines, so the line
        # observers do not write half-refreshed lines back into the tf.
        self._syncing = False

        self.init_ipv_fig()
        self.init_vals()
        self.init_elements()
        self.init_logic()
        self.init_style()
        self.init_layout()

    def init_vals(self):
        "Snapshot the initial rgba table and build the shared x samples."
        self.original_rgba = np.copy(self.tf.rgba)
        # One x sample per row of the rgba table, over the normalized range.
        self.x = np.linspace(0, 1, len(self.original_rgba))

    def init_elements(self):
        "Create scales, axes, one line + HandDraw per channel, and controls."
        self.xscale = bq.LinearScale(min=0, max=1)
        self.yscale = bq.LinearScale(min=0, max=1)
        self.alpha_scale = bq.LinearScale(min=0, max=self.alpha_max)

        self.xax = bq.Axis(
            scale=self.xscale,
            label='Data Value (%)',
            grid_lines='none'
        )
        self.yax = bq.Axis(
            scale=self.yscale,
            label='RGB Values',
            orientation='vertical',
            grid_lines='none'
        )
        self.alpha_ax = bq.Axis(
            scale=self.alpha_scale,
            label='Alpha Values',
            orientation='vertical',
            grid_lines='none',
            side='right'
        )

        # Attributes to change; the list position is the rgba column index.
        self.attribute_list = ['red', 'green', 'blue', 'alpha']

        self.lines = {
            attribute: self.gen_rgba_line(index, attribute)
            for index, attribute in enumerate(self.attribute_list)
        }

        self.handdraws = {
            attribute: bq.interacts.HandDraw(
                lines=self.lines[attribute]
            )
            for attribute in self.attribute_list
        }

        self.attribute_selector = ipw.Select(
            options=self.attribute_list,
        )

        self.reset_button = ipw.Button(
            description='Reset'
        )

        self.reload_button = ipw.Button(
            description='Reload TF'
        )

        self.bq_fig = bq.Figure(
            marks=list(self.lines.values()),
            axes=[self.xax, self.yax, self.alpha_ax],
            interaction=self.handdraws[self.attribute_selector.value]
        )

    def init_logic(self):
        "Wire widget events to the editor's handlers."
        # Update which line is being edited
        self.attribute_selector.observe(self.select_attribute, names='value')

        # Reset TF to original state
        self.reset_button.on_click(self.reset_tf)

        # Pull the current rgba values into the lines, in case the TF was
        # changed by other means. (update_lines ignores the Button argument.)
        self.reload_button.on_click(self.update_lines)

        # Automatically update RGBA array on line changes
        for line in self.lines.values():
            line.observe(self.reload_tf, names='y')

    def init_style(self):
        "Size the bqplot figure."
        self.bq_fig.layout = ipw.Layout(
            width='500px',
            height='400px'
        )

    def init_layout(self):
        "Arrange plot + volume figure on top, controls underneath."
        self.children = [
            ipw.HBox([
                self.bq_fig,
                self.ipv_fig
            ]),
            ipw.HBox([
                self.attribute_selector,
                ipw.VBox([
                    self.reset_button,
                    self.reload_button,
                ])
            ])
        ]

    def _tf_ready(self):
        "Return True once the figure has a tf whose rgba is not None."
        tf = getattr(self.ipv_fig, 'tf', None)
        return tf is not None and tf.rgba is not None

    def init_ipv_fig(self, retries=5, delay=1):
        """Shrink the figure in place and wait for its transfer function.

        Checks up to ``retries`` times, sleeping ``delay`` seconds after each
        failed check, then checks once more.

        Raises
        ------
        ValueError
            If ``fig.tf`` or ``fig.tf.rgba`` is still None after waiting.
        """
        self.ipv_fig.width = 350
        self.ipv_fig.height = 350

        # It seems that it takes a moment for the tf to be initialized;
        # until then tf (or tf.rgba) is None.
        # NOTE(review): time.sleep blocks the kernel, which presumably also
        # delays incoming widget messages; if rgba is filled in by the
        # frontend, displaying the figure in an earlier cell is the reliable
        # fix — confirm.
        for _ in range(retries):
            if self._tf_ready():
                break
            time.sleep(delay)
        if not self._tf_ready():
            raise ValueError('TransferFunction not loaded :(')

        self.tf = self.ipv_fig.tf

    def select_attribute(self, *args):
        "Choose to edit red, green, blue, or alpha"
        attribute = self.attribute_selector.value
        self.bq_fig.interaction = self.handdraws[attribute]

    def _attribute_of(self, change):
        "Return the attribute whose line emitted `change`, else the selected one."
        owner = change.get('owner') if isinstance(change, dict) else None
        for attribute, line in self.lines.items():
            if line is owner:
                return attribute
        return self.attribute_selector.value

    def reload_tf(self, *args):
        """Write an edited line's values into the TransferFunction.

        Registered as the ``y`` observer of every line. When ``args[0]`` is a
        traitlets change dict, the channel written is the one whose line
        emitted it; otherwise the selected channel is used. All lines are then
        refreshed from the resulting rgba array. Does nothing while
        ``update_lines`` is running.
        """
        if self._syncing:
            # The change came from update_lines itself; the tf already holds it.
            return
        attribute = self._attribute_of(args[0] if args else None)
        column = self.attribute_list.index(attribute)

        # tf.rgba is "read only", so make a copy and assign that
        rgba_copy = np.copy(self.tf.rgba)
        rgba_copy[:, column] = self.lines[attribute].y

        self.tf.rgba = rgba_copy
        self.update_lines()

    def reset_tf(self, *args):
        "Restore the rgba array captured at construction and refresh the lines."
        # Assign a copy so the saved original can never be aliased by the trait.
        self.tf.rgba = np.copy(self.original_rgba)
        self.update_lines()

    def update_lines(self, *args):
        """Copy each rgba column of the TransferFunction into its line.

        Extra positional args (e.g. the Button passed by ``on_click``) are
        ignored. Line observers are suppressed while this runs; otherwise the
        first refreshed line would trigger ``reload_tf``, which would write a
        not-yet-refreshed line back into the tf and undo part of the update.
        """
        previous = self._syncing
        self._syncing = True
        try:
            for index, attribute in enumerate(self.attribute_list):
                self.lines[attribute].y = self.tf.rgba[:, index]
        finally:
            self._syncing = previous

    def gen_rgba_line(self, column, attribute):
        "Create bqplot line from column of tf.rgba with specified attribute"
        # Alpha is drawn in black against its own (right-hand) scale.
        if attribute == 'alpha':
            color = 'black'
            yscale = self.alpha_scale
        else:
            color = attribute
            yscale = self.yscale
        return bq.Lines(
            x=self.x,
            y=self.tf.rgba[:, column],
            scales={'x': self.xscale, 'y': yscale},
            colors=[color],
            interpolation='cardinal',
            display_legend=True,
            labels=[attribute]
        )

In [6]:
# Load the aquariusA2 sample dataset and volume-render it in a new figure.
ds = ipv.datasets.aquariusA2.fetch()
fig = ipv.figure()
# fig is not passed explicitly; volshow presumably draws into the current
# figure, i.e. the one created just above.
ipv.volshow(ds.data)
# Display the figure before building the editor: TransferFunctionEditor needs
# fig.tf.rgba, which presumably is only populated once the figure has been
# rendered (TODO confirm).
fig

  gradient = gradient / np.sqrt(gradient[0]**2 + gradient[1]**2 + gradient[2]**2)


Figure(camera_center=[0.0, 0.0, 0.0], data_max=255.0, height=500, matrix_projection=[0.0, 0.0, 0.0, 0.0, 0.0, …

In [7]:
# Build the editor around the figure; this resizes fig in place and shows it
# next to the RGBA line plot and its controls.
tfe = TransferFunctionEditor(fig)
tfe

TransferFunctionEditor(children=(HBox(children=(Figure(axes=[Axis(grid_lines='none', label='Data Value (%)', s…

In [None]:
# NOTE(review): fig is already displayed above and embedded in tfe; this
# presumably redisplays the current ipyvolume figure again and looks like a
# leftover that could be removed — confirm.
ipv.show()