In [None]:
__author__ = 'kgeorge2@gmail.com'

In [1]:

import ipywidgets as widgets
from traitlets import Unicode, validate
from IPython.display import display


class ProgressImageWidget(widgets.DOMWidget):
    """
      ipywidget class to display incremental progress of training as an image
    """
    _view_name = Unicode('ProgressImageView').tag(sync=True)
    _view_module = Unicode('progress_image').tag(sync=True)
    value = Unicode().tag(sync=True)

In [2]:
%%javascript
require.undef('progress_image');

define('progress_image', ["jupyter-js-widgets"], function(widgets) {

    // Define the HelloView
    var ProgressImageView = widgets.DOMWidgetView.extend({
        // Render the view.
        render: function() {
            this.$img = $('<img />')
                .appendTo(this.$el);
        },
        
        update: function() {
            this.$img.attr('src', this.model.get('value'));
            return ProgressImageView.__super__.update.apply(this);
        },
        events: {"change": "handle_value_change"},
        
        handle_value_change: function(event) {
            this.model.set('value', this.$img.src);
            this.touch();
        },
        
    });

    return {
        ProgressImageView : ProgressImageView 
    }
});


'\n%%javascript\nrequire.undef(\'progress_image\');\n\ndefine(\'progress_image\', ["jupyter-js-widgets"], function(widgets) {\n\n    // Define the HelloView\n    var ProgressImageView = widgets.DOMWidgetView.extend({\n        // Render the view.\n        render: function() {\n            this.$img = $(\'<img />\')\n                .appendTo(this.$el);\n        },\n        \n        update: function() {\n            this.$img.attr(\'src\', this.model.get(\'value\'));\n            return ProgressImageView.__super__.update.apply(this);\n        },\n        events: {"change": "handle_value_change"},\n        \n        handle_value_change: function(event) {\n            this.model.set(\'value\', this.$img.src);\n            this.touch();\n        },\n        \n    });\n\n    return {\n        ProgressImageView : ProgressImageView \n    }\n});\n'

In [None]:
import numpy as np
import io, base64
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties



class Plotter(object):
    """
      A utility class to plot training/test data
      add_channel: Add as many channels as you want
      add_sample: Add as many samples to any channel
      plot: will return a dataurl containing a single plot 
    """
    def __init__(self,  **kwds):
        #need to have these keywords for initialization
        assert(kwds.get('xlabel'))
        assert(kwds.get('ylabel'))
        assert(kwds.get('title'))
        self.__dict__.update(kwds)
        #initialize empty extents
        self.extents=[np.inf, -np.inf, np.inf, -np.inf]
        self.channels={}
        pass
    
    #num_samples == upper bound on the number of samples that can be added for this channel    
    def add_channel(self, num_samples=-1, **kwds):
        assert(kwds.get('channel_name'))
        assert(kwds.get('legend'))
        channel = self.channels.setdefault(kwds['channel_name'], {})
        channel['plot_x'] = np.zeros(num_samples, dtype=np.float32)
        channel['plot_y'] = np.zeros_like( channel['plot_x']  )
        channel['legend'] = kwds['legend']
        channel['next_sample_index'] = 0
    
    #add a sample to a channel
    def add_sample(self, x, y, channel_name=''):
        assert(channel_name)
        assert(self.channels.get(channel_name))
        channel = self.channels[channel_name]
        next_index = channel['next_sample_index']
        channel['plot_x'][next_index] = x
        channel['plot_y'][next_index] = y
        channel['next_sample_index'] += 1
        self.update_extents_(x, y)

        
    #internal routine to keep track of extents
    def update_extents_(self, x, y):
        self.extents[0 ] = 0 # min(x, self.extents[0])  
        self.extents[1 ] = max(x, self.extents[1])  
        self.extents[2 ] = 0 # min(y, self.extents[2])  
        self.extents[3 ] = 1 # max(y, self.extents[3])  

    #plot routune
    def plot(self):
        fontP = FontProperties()
        fontP.set_size('small')
        format='PNG'
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        #plot each channel
        for k,v in self.channels.iteritems():
            next_sample_index = v['next_sample_index']
            ax.plot(v['plot_x'][0:next_sample_index], v['plot_y'][0:next_sample_index], label=k)
        plt.legend( loc='lower left', prop=fontP)
        ax.set_title(self.title)
        ax.set_xlabel(self.ylabel)
        ax.set_xlabel(self.xlabel)
        #return the plot as a dataurl
        buf = io.BytesIO()    
        fig.savefig(buf, format=format)
        buf.seek(0)
        dataurl = "data:image/" + format + ";base64," + base64.b64encode(buf.read())
        fig.clear()
        plt.close(fig)
        return dataurl
        