In [1]:
#export
from IPython import display
from matplotlib import pyplot as plt
import numpy as np
import torch

def use_svg_display():
    """Use the svg format to display a plot in Jupyter.

    ``IPython.display.set_matplotlib_formats`` is deprecated (since IPython
    7.23) in favour of ``matplotlib_inline.backend_inline.set_matplotlib_formats``.
    Prefer the new API and fall back to the old one when the
    ``matplotlib-inline`` package is not available.
    """
    try:
        from matplotlib_inline import backend_inline
    except ImportError:
        display.set_matplotlib_formats('svg')
    else:
        backend_inline.set_matplotlib_formats('svg')
    
def set_figsize(figsize=(3.5, 2.5)):
    """Switch plots to SVG output and set matplotlib's default figure size.

    Args:
        figsize: (width, height) stored in ``plt.rcParams['figure.figsize']``.
    """
    use_svg_display()
    rc = plt.rcParams
    rc['figure.figsize'] = figsize
    
def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
    """Set the axes for matplotlib.

    Args:
        axes: the matplotlib Axes to configure.
        xlabel, ylabel: axis labels.
        xlim, ylim: axis limits, passed unchanged to ``set_xlim``/``set_ylim``.
        xscale, yscale: axis scales, e.g. ``'linear'`` or ``'log'``.
        legend: legend entries; a legend is only drawn when this is non-empty.
    """
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.set_xscale(xscale)
    axes.set_yscale(yscale)
    axes.set_xlim(xlim)
    axes.set_ylim(ylim)
    if legend:
        axes.legend(legend)
    # The grid does not depend on the legend: draw it unconditionally
    # (it used to be nested under `if legend:` by mistake).
    axes.grid()

def plot(X, Y=None, xlabel=None, ylabel=None, legend=None, xlim=None,
         ylim=None, xscale='linear', yscale='linear',
         fmts=('-', 'm--', 'g-.', 'r:'), figsize=(3.5, 2.5), axes=None):
    """Plot data points.

    Args:
        X: x values. Either a single series (1-D array/tensor or flat list) or
            a collection of series. When ``Y`` is None, ``X`` holds the y
            values and each series is plotted against its index.
        Y: y values, a single series or a collection of series. When there
            are fewer x series than y series, the x series are repeated.
        xlabel, ylabel, xlim, ylim, xscale, yscale, legend: forwarded to
            ``set_axes``.
        fmts: matplotlib format strings, paired with the series in order.
            NOTE: series beyond ``len(fmts)`` are not drawn (``zip`` truncates).
        figsize: figure size passed to ``set_figsize``.
        axes: Axes to draw on (cleared first); defaults to ``plt.gca()``.
    """
    if legend is None:
        legend = []

    set_figsize(figsize)
    axes = axes if axes is not None else plt.gca()

    def has_one_axis(X):
        """Return True if X is a single series: a 1-D array/tensor or a flat list."""
        if hasattr(X, "ndim"):
            return X.ndim == 1
        # An empty list counts as one (empty) series instead of raising IndexError.
        return isinstance(X, list) and (not X or not hasattr(X[0], "__len__"))

    if has_one_axis(X):
        X = [X]
    if Y is None:
        X, Y = [[]] * len(X), X
    elif has_one_axis(Y):
        Y = [Y]
    if len(X) != len(Y):
        # Repeat the x series for every y series. list() makes this a sequence
        # repetition even when X is a 2-D array/tensor, where `X * n` would
        # multiply the values element-wise instead.
        X = list(X) * len(Y)
    axes.cla()
    for x, y, fmt in zip(X, Y, fmts):
        if len(x):
            axes.plot(x, y, fmt)
        else:
            axes.plot(y, fmt)
    set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend)


In [2]:
#export
def synthetic_data(w, b, num_examples):
    """Generate y = X w + b + noise.

    Features are drawn from N(0, 1) with shape ``(num_examples, len(w))``.
    Targets get additive Gaussian noise with standard deviation 0.01. Both
    draws use numpy's global random state, features first.

    Returns:
        (features, targets) as numpy arrays.
    """
    n_features = len(w)
    features = np.random.normal(0, 1, (num_examples, n_features))
    targets = np.dot(features, w) + b
    noise = np.random.normal(0, 0.01, targets.shape)
    targets += noise
    return features, targets

def linreg(inputs, W, b):
    """Linear regression model: ``inputs @ W + b``.

    Args:
        inputs: 2-D feature tensor (``mm`` requires 2-D operands).
        W: 2-D weight tensor.
        b: bias, broadcast-added to the matrix product.
    """
    product = inputs.mm(W)
    return product + b

def squared_loss(y_hat, y):
    """Element-wise halved squared error ``(y_hat - y) ** 2 / 2``.

    ``y`` is reshaped to ``y_hat.shape`` first, so a flat label vector can be
    compared with a column of predictions. No reduction is applied.
    """
    residual = y_hat - y.reshape(y_hat.shape)
    return residual ** 2 / 2

def sgd(params, lr, batch_size):
    """Minibatch stochastic gradient descent step, applied in place.

    Each parameter is updated as ``param -= lr * param.grad / batch_size``.
    The update runs under ``torch.no_grad()`` for two reasons: modifying a leaf
    tensor that requires grad in place otherwise raises a RuntimeError, and the
    update itself must not be recorded by autograd. Gradients are NOT reset
    here, so callers should zero them before the next backward pass.

    Args:
        params: iterable of tensors whose ``.grad`` is populated.
        lr: learning rate.
        batch_size: the gradient is divided by this value.
    """
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size

In [3]:
#export
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot a list of images on a ``num_rows`` x ``num_cols`` grid.

    Args:
        imgs: iterable of images. Torch tensors are converted with
            ``.detach().cpu().numpy()``; anything else (e.g. numpy arrays) is
            passed to ``imshow`` unchanged. Images beyond
            ``num_rows * num_cols`` are ignored.
        num_rows, num_cols: grid dimensions.
        titles: optional sequence of per-image titles, indexed in parallel
            with ``imgs``. Images without a matching title are left untitled.
        scale: size of each grid cell (figure size is ``num_cols * scale`` by
            ``num_rows * scale``).

    Returns:
        1-D array of the grid's Axes, in row-major order.
    """
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False always returns a 2-D array of Axes, so flatten() also
    # works for a 1x1 grid (where a bare Axes would otherwise be returned).
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            img = img.detach().cpu().numpy()
        ax.imshow(img)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if titles is not None and i < len(titles):
            ax.set_title(titles[i])
    return axes

In [4]:
#export
def remove_sequential(network, all_layers=None):
    """Flatten nested ``nn.Sequential`` containers into a list of leaf modules.

    Walks ``network.children()`` in order. It recurses into every
    ``nn.Sequential`` (including subclasses) and collects modules that have
    no children.

    NOTE(review): composite children that are not ``nn.Sequential`` (e.g. a
    custom block with sub-modules) are neither descended into nor collected,
    so they are absent from the result.

    Args:
        network: module whose children are flattened.
        all_layers: optional list to append into; a new list is created when
            None.

    Returns:
        The list of collected leaf modules (``all_layers`` itself if given).
    """
    # Explicit import instead of relying on a module-level `nn` brought in
    # by a later star import.
    from torch import nn

    if all_layers is None:
        all_layers = []
    for layer in network.children():
        # isinstance so that subclasses of nn.Sequential are flattened too
        # (an exact type check silently dropped them).
        if isinstance(layer, nn.Sequential):
            remove_sequential(layer, all_layers)
        if not list(layer.children()):  # leaf module: keep it
            all_layers.append(layer)
    return all_layers

def layer_description(model, x):
    """Print the output shape of each flattened layer of ``model`` for input ``x``.

    Layers come from ``remove_sequential(model)``. They are applied one after
    another, each to the previous layer's output.
    """
    out = x
    for module in remove_sequential(model):
        out = module(out)
        print(type(module).__name__, 'Output shape:\t', out.shape)

In [5]:
#export
def find_modules(m, cond):
    """Return the outermost sub-modules of ``m`` (including ``m``) satisfying ``cond``.

    Search is depth-first, in ``children()`` order. When a module satisfies
    ``cond`` it is returned as-is and its own children are not searched.

    Args:
        m: root module.
        cond: predicate called with a single module.

    Returns:
        List of matching modules.
    """
    if cond(m):
        return [m]
    found = []
    for child in m.children():
        # extend() rather than sum(lists, []), which re-copies the growing
        # list on every addition (quadratic).
        found.extend(find_modules(child, cond))
    return found

In [6]:
#export
def accuracy(out, yb):
    return (torch.argmax(out, dim=1) == yb).float().mean()


In [7]:
#export
import sys; sys.path.insert(0, '../')
from exp.hook import *

def hook_lsuv_stats(h, module, input, output):
    """Hook callback: store the mean and std of a module's output on ``h``.

    Args:
        h: object receiving ``mean`` and ``std`` attributes (Python floats).
        module, input: unused here; presumably supplied by ``ForwardHook``
            (exp.hook) as part of its callback signature.
        output: the module's output tensor.
    """
    activations = output.data
    h.mean, h.std = activations.mean().item(), activations.std().item()
    
def lsuv(model, module, xb, tol=1e-3, max_attempts=10):
    """Layer-sequential unit-variance (LSUV) initialisation of one module.

    A ``ForwardHook`` on ``module`` records the mean and std of its output
    (via ``hook_lsuv_stats``) each time ``model(xb)`` runs. The function then:

    1. subtracts the output mean from ``module.bias`` until ``|mean| < tol``
       (skipped when the module has no bias);
    2. divides ``module.weight`` by the output std until ``|std - 1| < tol``.

    Each phase stops after ``max_attempts`` adjustments. The statistics from
    the last forward pass are printed and returned.

    Args:
        model: model to run forward. ``module`` must be part of it so that
            the hook fires.
        module: module to initialise; needs a ``weight`` (``bias`` optional).
        xb: input batch.
        tol: tolerance on |mean| and on |std - 1|.
        max_attempts: maximum number of adjustments per phase.

    Returns:
        (mean, std) of ``module``'s output after the last forward pass.
    """
    h = ForwardHook(module, hook_lsuv_stats)
    try:
        # Only activation statistics are needed: don't build autograd graphs.
        with torch.no_grad():
            model(xb)
            bias = getattr(module, 'bias', None)
            if bias is not None:
                attempts = 0
                while abs(h.mean) >= tol and attempts < max_attempts:
                    bias.data -= h.mean
                    model(xb)  # refresh h.mean / h.std
                    attempts += 1
            attempts = 0
            # Scale toward unit std from either side. The former check
            # `std >= 1 + tol` never scaled up a layer whose std was below 1.
            # A zero std cannot be fixed by scaling (and would yield inf).
            while abs(h.std - 1) >= tol and h.std != 0 and attempts < max_attempts:
                module.weight.data /= h.std
                model(xb)  # refresh h.mean / h.std
                attempts += 1
        mean, std = h.mean, h.std
    finally:
        # Always detach the hook, even if a forward pass fails.
        h.remove()
    print(mean, std)
    return mean, std

def lsuv_init(model, xb, cond=None):
    """Apply LSUV initialisation to every sub-module of ``model`` selected by ``cond``.

    Args:
        model: the full model; ``lsuv`` runs it forward on ``xb`` repeatedly.
        xb: input batch used to measure activation statistics.
        cond: predicate on a single module, passed to ``find_modules``, which
            stops descending at the first match. Defaults to linear and
            convolution layers.
    """
    if cond is None:
        # The former default `lambda: True` took no argument, so `cond(m)`
        # raised TypeError. A match-everything predicate would also select the
        # root model itself, which typically has no weight/bias to adjust.
        from torch import nn
        weight_layers = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)
        cond = lambda m: isinstance(m, weight_layers)
    for m in find_modules(model, cond):
        lsuv(model, m, xb)

In [8]:
# Convert this notebook into exp/nb_d2l_utils.py (see the "Converted ..." output),
# then import the generated module and list its names to check the export.
!python notebook2script.py d2l_utils.ipynb
# NOTE(review): when run in the same kernel as the In [7] cell, '../' is already on
# sys.path; this inserts a duplicate entry.
import sys; sys.path.insert(0, '../')
from exp import nb_d2l_utils

dir(nb_d2l_utils)

Converted d2l_utils.ipynb to exp/nb_d2l_utils.py


['ForwardHook',
 'Hooks',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__spec__',
 'accuracy',
 'display',
 'find_modules',
 'get_hist',
 'get_min',
 'hook_lsuv_stats',
 'hook_stats',
 'layer_description',
 'linreg',
 'lsuv',
 'lsuv_init',
 'nn',
 'np',
 'partial',
 'plot',
 'plt',
 'remove_sequential',
 'set_axes',
 'set_figsize',
 'sgd',
 'show_images',
 'squared_loss',
 'synthetic_data',
 'sys',
 'torch',
 'use_svg_display']