# AutoML for Model Compression

This notebook will help us visualize and review the results of the DDPG agent's sub-space search.
It contains two visualizations of the process of discovering networks during the exploration and exploitation phases.  Each discovered network is projected on a 2D subspace that maps the network's compute complexity (normalized to a percentage of the dense-network's compute budget) against its Top1 accuracy.

The Top1 value is either the Test dataset Top1 measured without any fine-tuning, or after one epoch of fine-tuning (this depends on how the AMC algorithm is configured).

In [None]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import matplotlib 
import csv
from matplotlib.ticker import FuncFormatter
#from matplotlib.animation import FuncAnimation


import matplotlib.pylab as pylab
# Notebook-wide plotting defaults: a wide default canvas, plus enlarged fonts
# for the legend, axis labels, titles and tick labels.
params = dict([
    ('legend.fontsize', 'x-large'),
    ('figure.figsize', (15, 7)),
    ('axes.labelsize', 'x-large'),
    ('axes.titlesize', 'xx-large'),
    ('xtick.labelsize', 'x-large'),
    ('ytick.labelsize', 'x-large'),
])
pylab.rcParams.update(params)


def to_percent(y, position):
    """Format a tick value as a percentage label (for use with FuncFormatter).

    Tick values whose magnitude is below 1 are treated as fractions and are
    scaled by 100; any other value is assumed to already be expressed in
    percent and is left unscaled.

    Args:
        y: the numeric tick value.
        position: the tick position passed in by matplotlib; ignored.

    Returns:
        The label as a string ending in a percent sign (escaped for LaTeX when
        the 'text.usetex' rcParam is enabled).
    """
    # Compare the magnitude so that negative percent-valued ticks (e.g. -2)
    # are not mistaken for fractions and scaled to -200.
    value = 100 * y if abs(y) < 1 else y

    # Limit the significant digits so that float round-off does not leak into
    # the label (100 * 0.57 would otherwise print as 56.99999999999999).
    s = '{:.10g}'.format(value)

    # The percent symbol needs escaping in latex
    if matplotlib.rcParams['text.usetex']:
        return s + r'$\%$'
    return s + '%'

## Static diagram

In [None]:
import os

import pandas as pd

# Directory holding the AMC experiment logs.  The default is the workspace the
# notebook was originally authored in; set the AMC_LOGS_DIR environment
# variable to load logs from a different location without editing this cell.
logs_dir = os.environ.get(
    'AMC_LOGS_DIR',
    '/home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/')

# Experiment directory whose amc.csv is loaded.  Other recorded runs:
#   'AMC-Plain20_T4___2018.12.23-014309'
#   'AMC-Plain20_T5___2018.12.23-030950'
#   'AMC-Plain20_T6___2018.12.23-043924'
#   'AMC-Plain20_T3___2018.12.23-000709'
experiment = 'AMC-Plain20_T1___2018.12.22-213316'

fname = os.path.join(logs_dir, experiment, 'amc.csv')
df = pd.read_csv(fname)

# Display the row of the network with the highest reward.
df.loc[df['reward'].idxmax()]

In [None]:
#fname = '/home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/'
#fname += 'AMC-Plain20_T4___2018.12.23-014309/amc.csv'

# Per-network metrics, in the order the rows appear in the log.
top1 = df['top1']
normalized_macs = df['normalized_macs']
num_networks = len(top1)

fig, ax = plt.subplots(figsize=(15, 7))
ax.set_title('Projection of Discovered Networks ({})'.format(num_networks))
ax.set_xlabel('Normalized MACs')
ax.set_ylabel('Top1 Accuracy')

# Render the tick labels of both axes through to_percent, so that they are
# displayed as percentages.
formatter = FuncFormatter(to_percent)
ax.yaxis.set_major_formatter(formatter)
ax.xaxis.set_major_formatter(formatter)

# Encode the "age" of each network as a gray level: the first row is white
# ('1.0') and later rows get progressively darker, so lighter networks were
# discovered earlier than darker ones.
color_grad = [str(1 - rank/num_networks) for rank in range(num_networks)]
ax.scatter(normalized_macs, top1, color=color_grad, s=80, edgecolors='gray');

#plt.hlines(90, 1.5*10**8, 2.5*10**8, color='b')


In [None]:
BIN_SIZE = 2  # Bin width (use e.g. 0.5 for a finer contour)
# NOTE(review): the bin count assumes the binned metric spans 0..100 (percent) - confirm.
NUM_BINS = int(100 / BIN_SIZE)
compute_bins = [None] * NUM_BINS   # Best value found so far in each bin (None: empty bin)
idx_bins = [-1] * NUM_BINS         # Position of the network holding that best value (-1: empty bin)
color_grad_ = ["1" for _ in top1]  # Every network is white unless it lies on the contour

# Select the contour to highlight:
#   "accuracy contour": for each Top1 bin, the network with the lowest normalized MACs
#   "mac contour":      for each normalized-MACs bin, the network with the highest Top1
draw_what = "mac contour"

if draw_what == "accuracy contour":
    bin_keys, candidates, lower_is_better = top1, normalized_macs, True
else:
    bin_keys, candidates, lower_is_better = normalized_macs, top1, False

# Iterate by position (not by index label) so that `i` always lines up with
# the entries of color_grad_.
for i, (key, candidate) in enumerate(zip(bin_keys, candidates)):
    # Clamp to the valid range: a value of exactly 100 would otherwise index one
    # past the last bin (IndexError), and a negative value would silently wrap
    # around to a bin at the other end.
    bin_id = min(max(int(key // BIN_SIZE), 0), NUM_BINS - 1)
    best = compute_bins[bin_id]
    if best is None or (candidate < best if lower_is_better else candidate > best):
        compute_bins[bin_id] = candidate
        idx_bins[bin_id] = i

# Paint the best network of every non-empty bin in red.
for i in idx_bins:
    if i != -1:
        color_grad_[i] = "red"
plt.scatter(normalized_macs, top1, color=color_grad_, s=80, edgecolors='gray')
plt.xlabel('Normalized MACs')
plt.ylabel('Top1 Accuracy');

## Video animation

In [None]:
# Based on these two helpful code examples:
# https://stackoverflow.com/questions/9401658/how-to-animate-a-scatter-plot
# http://louistiao.me/posts/notebooks/embedding-matplotlib-animations-in-jupyter-notebooks/.
# Specifically, the use of IPython.display is missing from the first example, but most of the animation code
# leverages code from there.
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

from matplotlib import animation, rc
from IPython.display import HTML

INTERVAL = 100 # Animation speed: passed to FuncAnimation as the delay between frames
WINDOW = 20    # Maximum number of most-recent points drawn in a single frame

# Style of the frame-counter text drawn inside the plot.
font = {'family': 'serif',
        'color':  'darkred',
        'weight': 'normal',
        'alpha': 0.50,
        'size': 32,
        }

class AnimatedScatter(object):
    """An animated scatter plot using matplotlib.animation.FuncAnimation.

    Frame ``i`` shows point ``i`` together with up to ``WINDOW - 1`` points that
    precede it.  Points are shaded with a gray gradient: the newest point in the
    window is the darkest and older points fade towards white.  The frame number
    is drawn as a text annotation inside the axes.

    Args:
        xdata: sequence of x coordinates, in display order.
        ydata: sequence of y coordinates; must have the same length as xdata.

    Raises:
        ValueError: if the sequences are empty.
    """
    def __init__(self, xdata, ydata):
        assert len(xdata) == len(ydata)
        # Keep plain float arrays so that slicing below is always positional,
        # whatever index the inputs carry (e.g. a pandas Series).
        self.xdata = np.asarray(xdata, dtype=float)
        self.ydata = np.asarray(ydata, dtype=float)
        self.numpoints = len(self.xdata)
        if self.numpoints == 0:
            raise ValueError("AnimatedScatter requires at least one data point")
        self.stream = self.data_stream()
        # The artists are created lazily by setup_plot.
        self.scat = None
        self.annot = None

        # Setup the figure and axes...
        self.fig, self.ax = plt.subplots(figsize=(15,7))
        # Then setup FuncAnimation: one frame per data point.
        self.ani = animation.FuncAnimation(self.fig, self.update, interval=INTERVAL,
                                           frames=self.numpoints,
                                           init_func=self.setup_plot, blit=True)

    def _window(self, numpoints):
        """Return the (4, win_len) frame data for the first `numpoints` points.

        Rows are: x, y, marker size, color value.  Only the last
        min(WINDOW, numpoints) of those points are included, and `numpoints`
        is clamped to the range [1, len(data)].
        """
        numpoints = min(max(numpoints, 1), self.numpoints)
        win_len = min(WINDOW, numpoints)
        # The window ends at the newest point (numpoints - 1), inclusive.
        start = numpoints - win_len
        data = np.empty((4, win_len))
        data[0, :] = self.xdata[start:numpoints]
        data[1, :] = self.ydata[start:numpoints]
        data[2, :] = 70  # point size
        # The color of the points is a gradient with larger (lighter) values for
        # "older" points: the oldest point in the window gets 1 and each newer
        # point gets a smaller (darker) value.
        data[3, :] = 1 - np.arange(win_len) / (win_len + 1)
        return data

    def setup_plot(self):
        """Initialize drawing of the scatter plot.

        Creates the artists on the first call; later calls only reset them to
        the first frame, so repeated initialization does not stack duplicate
        artists on the axes.
        """
        data = self._window(1)
        if self.scat is None:
            # Fix the color limits to [0, 1]: autoscaling on the single initial
            # point would collapse the norm and hide the age gradient.
            self.scat = self.ax.scatter(data[0], data[1], c=data[3], s=data[2],
                                        cmap='gray', vmin=0, vmax=1, edgecolors='gray')
            xmin, xmax = self.xdata.min(), self.xdata.max()
            ymin, ymax = self.ydata.min(), self.ydata.max()
            self.width = xmax - xmin + 4
            self.height = ymax - ymin + 4
            self.ax.axis([xmin-2, xmax+2, ymin-2, ymax+2])
            # The vertical position is derived from the y data (not the x data).
            self.annot = self.ax.text(xmin + self.width/2,
                                      ymin + self.height/2,
                                      "", fontdict=font)
        else:
            self._apply(data)
            self.annot.set_text("")
        # For FuncAnimation's sake, we need to return the artists we'll be using.
        return self.scat, self.annot

    def data_stream(self):
        """Endlessly yield frame data, growing the history by one point per item.

        After the last point has been reached the stream starts over from the
        first point instead of running past the end of the data.
        """
        numpoints = 0
        while True:
            numpoints = numpoints % self.numpoints + 1
            yield self._window(numpoints)

    def _apply(self, data):
        """Push one frame of (x, y, size, color) data into the scatter artist."""
        self.scat.set_offsets(np.column_stack((data[0], data[1])))
        self.scat.set_sizes(data[2])
        self.scat.set_array(data[3])

    def update(self, i):
        """Update the scatter plot to show frame `i` (0-based)."""
        # Derive the frame from its index rather than from shared generator
        # state, so repeats and re-initialization cannot desynchronize it.
        self._apply(self._window(i + 1))
        self.annot.set_text(str(i))

        # We need to return the updated artists for FuncAnimation to draw.
        return self.scat, self.annot

    def show(self):
        """Display the figure through pyplot."""
        plt.show()

# Build the animation over the discovered networks and label its axes.
a = AnimatedScatter(normalized_macs, top1)
a.ax.set_title('Projection of Discovered Networks ({})'.format(len(top1)))
a.ax.set_xlabel('Normalized MACs')
a.ax.set_ylabel('Top1 Accuracy')
# To export the animation to a file instead (fps sets the speed, dpi the quality):
#a.ani.save('amc_vgg16.mp4', fps=10, dpi=80)
# Make the notebook render animations as HTML5 video, then display this one.
rc('animation', html='html5')
a.ani