In [None]:
import sys
import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm import tqdm

sys.path.append("../src")
from losses import LinearClassification
from block_analysis import block_hessian, curvature_effects, eval_loss, update_params
from lr_tools import lr_calibrate
from models import conv_net
from data import gen_cifar10_ds

## helpers

### Investigation plan

- Compare different learning rates
- Collect more granular observations
- Compare full-batch GD with SGD

## Parameters

In [None]:
# Model architecture: input channels, hidden width, number of classes, depth.
inp_dim, hid_dim, out_dim = 3, 64, 10
nlayer = 5

# Conv layers are bias-free. `use_bn` flags are (center, standardize, scale, shift)
# as read by BN2d_ctrl. NOTE(review): conv_bn currently ignores `use_bn` and
# always applies torch's nn.BatchNorm2d.
bias = False
use_bn = (False,) * 4
mode = "relu"                  # activation mode for the conv blocks ("relu" or "linear")
loss_mode = 'CrossEntropy'     # "CrossEntropy" or "Linear" (see get_model_ds_loss)
device = 0                     # CUDA device index

# Data / persistence
nsamp = 500                    # sample count passed to gen_cifar10_ds
save_model_dir = "models"      # root directory for checkpoints

### Analysis

In [None]:
from utils import zero_grad
from activation_stats import first_order_analysis, plot_param, plot_activ
from block_analysis import clone_model, eval_loss

In [None]:
import math

class Activation(nn.Module):
    """Element-wise activation with a switchable mode.

    Modes:
      * "relu"   -- x * 1[x > 0]; the (detached) mask is cached in ``last_msk``.
      * "replay" -- multiply by the mask cached by the most recent "relu" pass,
                    so the gating stays fixed while x changes.
      * "linear" -- identity.
    """

    def __init__(self, mode="linear"):
        """Create the activation.

        Only "relu" and "linear" are valid initial modes; "replay" needs a mask
        from a previous "relu" forward pass.
        """
        super().__init__()
        assert mode in ["relu", "linear"]
        self.mode = mode
        self.last_msk = None

    def set_mode(self, mode):
        """Switch to ``mode`` ("relu", "replay" or "linear")."""
        assert mode in ["relu", "replay", "linear"]
        self.mode = mode

    def forward(self, x):
        """Apply the activation selected by ``self.mode`` to ``x``."""
        if self.mode == "relu":
            # Detached so the gate carries no gradient; only x does.
            msk = (x > 0).detach().to(x.dtype)
            self.last_msk = msk
            return x * msk

        elif self.mode == "replay":
            # The previous assertion also demanded self.mode == "relu", which can
            # never hold inside this branch, so replay mode always failed.
            assert self.last_msk is not None, "replay mode requires a prior relu forward pass"
            return x * self.last_msk
        else:
            return x

class FC(nn.Module):
    """Bias-optional linear layer followed by an ``Activation``.

    NOTE(review): ``init_weights`` scales by ``weight.shape[0]``, which for
    ``nn.Linear`` is ``out_features`` (fan-out), whereas conv_bn's
    ``kaiming_normal_`` call uses its default fan-in mode -- confirm which is intended.
    """

    def __init__(self, inp, out, bias=False, mode="linear"):
        """Build the affine map inp -> out, attach the activation, then re-draw the weights."""
        super().__init__()
        self.fc = nn.Linear(inp, out, bias=bias)
        self.act = Activation(mode)
        self.init_weights()

    def forward(self, x):
        """Affine transform followed by the activation."""
        z = self.fc(x)
        return self.act(z)

    def init_weights(self, init_type="variance"):
        """Re-draw ``fc.weight`` from N(0, gain / weight.shape[0]).

        gain is 2 for a "relu" activation and 1 for "linear". Any ``init_type``
        other than "variance" raises NotImplementedError.
        """
        if init_type != "variance":
            raise NotImplementedError
        gain = {"relu": 2, "linear": 1}[self.act.mode]
        weight = self.fc.weight
        with torch.no_grad():
            weight.normal_(std=math.sqrt(gain / weight.shape[0]))

In [None]:
class var_center(nn.Module):
    """Subtract the mean taken over dim 1 (kept for broadcasting) from the input."""

    def forward(self, x):
        mu = torch.mean(x, dim=1, keepdim=True)
        return x - mu

class BN2d_ctrl(nn.Module):
    """Batch norm over (N, H, W) with each stage individually switchable.

    ``use_bn`` is a 4-tuple of flags (center, standardize, scale, shift):
      center      -- subtract the per-channel batch mean over dims (0, 2, 3)
      standardize -- divide by sqrt(per-channel biased variance + eps)
      scale       -- multiply by the learnable ``bn_weight`` (initialized to 1)
      shift       -- add the learnable ``bn_bias`` (initialized to 0)
    With all four enabled the output matches ``nn.BatchNorm2d`` in training mode
    with weight=1 and bias=0. Batch statistics are always used; no running
    statistics are tracked.
    """

    def __init__(self, num_features, use_bn=(False, False, False, False), eps=1e-5):
        super().__init__()
        self.ctr = use_bn[0]
        self.std = use_bn[1]
        self.scl = use_bn[2]
        self.bias = use_bn[3]
        # Placeholders; forward() does not populate them.
        self.mean = None
        self.bn_std = None
        self.eps = eps
        if self.scl:
            self.bn_weight = nn.Parameter(torch.ones(num_features))
        if self.bias:
            self.bn_bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        """Apply the enabled stages to ``x`` (assumed NCHW -- channels on dim 1)."""
        dims = (0, 2, 3)
        if self.ctr:
            x = x - x.mean(dim=dims, keepdim=True)
        if self.std:
            # Biased variance + eps, as nn.BatchNorm2d uses in training mode. The
            # previous x.std(...) was the unbiased estimator and ignored eps, so it
            # never matched torch and gave NaN for constant channels.
            x = x / torch.sqrt(x.var(dim=dims, unbiased=False, keepdim=True) + self.eps)
        if self.scl:
            x = x * self.bn_weight.view(1, -1, 1, 1)
        if self.bias:
            x = x + self.bn_bias.view(1, -1, 1, 1)

        return x

class Flatten(nn.Module):
    """Flatten every dim from ``start_dim`` onward (default 1 keeps the batch dim).

    The previous ``torch.flatten(x)`` also collapsed the batch dimension, which
    breaks any per-sample layer placed after it. Pass ``start_dim=0`` to flatten
    everything.
    """

    def __init__(self, start_dim=1):
        super().__init__()
        self.start_dim = start_dim

    def forward(self, x):
        return torch.flatten(x, start_dim=self.start_dim)

class GAPool(nn.Module):
    """Global average pooling: mean over dims (2, 3) (the spatial dims, assuming NCHW)."""

    def forward(self, x):
        spatial_dims = (2, 3)
        return x.mean(dim=spatial_dims)

class conv_bn(nn.Module):
    """3x3 convolution (padding 1) -> nn.BatchNorm2d -> Activation.

    NOTE(review): ``use_bn`` is accepted for interface compatibility but unused;
    the BN2d_ctrl / var_center alternatives are disabled and torch's
    nn.BatchNorm2d is always applied.
    """

    def __init__(self, inp, out, stride=1, bias=False, use_bn=(False, False, False, False), mode="linear"):
        super().__init__()
        self.conv = nn.Conv2d(inp, out, kernel_size=3, stride=stride, padding=1, bias=bias)
        self.bn = nn.BatchNorm2d(out)
        with torch.no_grad():
            self.bn.weight.fill_(1)  # explicit unit BN scale
        self.act = Activation(mode)
        # kaiming_normal_ uses the leaky-ReLU gain sqrt(2 / (1 + a^2)):
        # a=0 -> sqrt(2) (ReLU), a=1 -> 1 (linear).
        leaky_slope = {"relu": 0, "linear": 1}[mode]
        torch.nn.init.kaiming_normal_(self.conv.weight, a=leaky_slope)

    def forward(self, x):
        """conv -> batch norm -> activation."""
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)
    
class conv_net(nn.Module):
    """Stride-2 conv_bn stack -> global average pool -> linear FC head.

    Layout: ``l1`` plus ``max(0, nlayer - 2)`` hidden conv_bn blocks, then the
    FC output layer. NOTE: this notebook-local class shadows the ``conv_net``
    imported from ``models`` at the top of the notebook.
    """

    def __init__(self, inp, hid, out, nlayer, bias=False, use_bn=(False, False, False, False), mode="linear"):
        super().__init__()
        self.l1 = conv_bn(inp, hid, stride=2, bias=bias, use_bn=use_bn, mode=mode)
        hidden_blocks = []
        for _ in range(max(0, nlayer - 2)):
            hidden_blocks.append(conv_bn(hid, hid, stride=2, bias=bias, use_bn=use_bn, mode=mode))
        self.layers = nn.Sequential(*hidden_blocks)
        self.GAPool = GAPool()
        self.out = FC(hid, out, bias=False, mode="linear")

    def forward(self, x):
        """l1 -> hidden blocks -> global average pool -> FC."""
        h = self.l1(x)
        h = self.layers(h)
        h = self.GAPool(h)
        return self.out(h)

    def get_mode(self):
        """Mode of the first Activation found in ``self.modules()`` order."""
        return next(self._activations()).mode

    def set_mode(self, mode):
        """Set ``mode`` on every Activation in the network."""
        for act in self._activations():
            act.set_mode(mode)

    def _activations(self):
        """Iterator over all Activation submodules."""
        return (m for m in self.modules() if isinstance(m, Activation))

In [None]:
def get_model_ds_loss(loss_mode='CrossEntropy'):
    """Build the model, dataset and loss function from the notebook config.

    Reads the config globals inp_dim, hid_dim, out_dim, nlayer, bias, use_bn,
    mode, nsamp and device.

    Args:
        loss_mode: "CrossEntropy" -> nn.CrossEntropyLoss(); "Linear" ->
            LinearClassification(out_dim).

    Returns:
        (model, ds, loss_fn), with the model moved to CUDA device ``device``.

    Raises:
        ValueError: for an unknown ``loss_mode``. This is checked before the model
            and dataset are built, so a typo fails fast (the old ``assert`` ran
            only after both were created, and is skipped under ``python -O``).
    """
    loss_factories = {
        "CrossEntropy": lambda: nn.CrossEntropyLoss(),
        "Linear": lambda: LinearClassification(out_dim),
    }
    if loss_mode not in loss_factories:
        raise ValueError(f"loss_mode must be one of {sorted(loss_factories)}, got {loss_mode!r}")

    model = conv_net(inp_dim, hid_dim, out_dim, nlayer, bias, use_bn, mode).cuda(device)
    ds = gen_cifar10_ds(nsamp, device)
    loss_fn = loss_factories[loss_mode]()
    return model, ds, loss_fn

def correct(classifier, target):
    """Boolean tensor: True where the arg-max over dim 1 of ``classifier`` equals ``target``."""
    _, predicted = torch.max(classifier, dim=1)
    return predicted == target

def train_epoch(model, ds, loss_fn, lr):
    """Run one gradient step on ``ds`` and report (loss, accuracy).

    ``loss`` is whatever ``eval_loss`` returns for the parameters *before* the
    step. The accuracy is measured on ``ds[0]`` *after* the step.
    NOTE(review): the two numbers therefore refer to different parameter values;
    confirm this is intended.

    Assumes ``eval_loss`` populates ``.grad`` on the parameters, since the
    gradients are read right after it. TODO: confirm against block_analysis.
    """
    zero_grad(model)
    loss = eval_loss(model, ds, loss_fn)

    params = list(model.parameters())
    grads = [p.grad for p in params]
    update_params(params, grads, lr)
    model.zero_grad()

    inputs, targets = ds[0][0], ds[0][1]
    # The accuracy pass needs no autograd graph; the old code built (and then
    # discarded) one every epoch.
    with torch.no_grad():
        hits = correct(model(inputs), targets)

    # .float() keeps the reduction on the tensor's device; the old
    # .type(torch.FloatTensor) forced a CPU copy first.
    return loss, hits.float().mean().item()

In [None]:
import pandas as pd

def relative_error(a, b, eps=1e-6):
    """Relative difference |a - b| / min(|a|, |b|).

    The denominator is floored at ``eps``. The parameter was previously accepted
    but unused, so a zero operand raised ZeroDivisionError.
    """
    denom = max(min(abs(a), abs(b)), eps)
    return abs((a - b) / denom)

def init_dir(path, name="default"):
    """Create the directory ``<path>/<name>`` if it does not already exist.

    Previously ``path`` was immediately overwritten with the global
    ``save_model_dir``, so the argument was ignored. The notebook's call
    ``init_dir(save_model_dir)`` creates the same directory as before.

    Returns:
        The directory path. Callers that ignored the old ``None`` are unaffected.
    """
    target = os.path.join(path, name)
    os.makedirs(target, exist_ok=True)
    return target
        
def save_model(model, epoch, name="default", save_dir=None):
    """Save ``model.state_dict()`` to ``<save_dir>/<name>/<epoch>.pth``.

    ``save_dir`` defaults to the notebook-level ``save_model_dir``. The target
    directory must already exist (``init_dir`` creates it).

    The path used to be an f-string followed by a redundant ``.format(...)``
    call, which re-interpreted any braces in ``name`` or the directory.
    """
    if save_dir is None:
        save_dir = save_model_dir
    torch.save(model.state_dict(), os.path.join(save_dir, name, f"{epoch}.pth"))

def select(stats, column):
    """Stack ``stat[column]`` from every entry of ``stats`` into one DataFrame.

    Row i holds entry i's values (index 0..len(stats)-1). Assuming each
    ``stat[column]`` is a Series, the columns are that Series' index.
    """
    picked = [entry[column] for entry in stats]
    stacked = pd.concat(picked, axis=1).T
    return stacked.set_index(np.arange(len(picked)))

def plot_stats(stats, column, *args, ax=None, **kwargs):
    """Plot ``select(stats, column)`` with DataFrame.plot on ``ax``.

    A new 12x8 figure is created when ``ax`` is None. Extra positional and
    keyword arguments are forwarded to ``DataFrame.plot``.
    """
    target_ax = plt.subplots(figsize=(12,8))[1] if ax is None else ax
    select(stats, column).plot(*args, ax=target_ax, **kwargs)
    
def summarize_stats(stats):
    """2x2 overview: a_l_std | a_l_m on the top row, W_g_std | W_std on the bottom row."""
    _, axes = plt.subplots(2,2, figsize=(20, 20))
    panels = ('a_l_std', 'a_l_m', 'W_g_std', 'W_std')
    # axes.flat walks row-major: [0,0], [0,1], [1,0], [1,1].
    for column, panel_ax in zip(panels, axes.flat):
        plot_stats(stats, column, ax=panel_ax)
    
def run_analysis(model, ds, loss_fn, lr, epochs, valfreq=40):
    """Train for ``epochs`` steps, periodically snapshotting curvature statistics.

    On every epoch with ``(epoch + 1) % valfreq == 0`` -- before that epoch's
    training step -- this appends ``(H, delta, fo, ho, error, fostat)``:
      * ``H`` comes from ``block_hessian``;
      * ``delta, fo, ho, fostat`` come from ``first_order_analysis``;
      * ``error`` is ``relative_error(H.sum(), ho)``.
    It also checkpoints the model via ``save_model(model, epoch)``.

    Returns:
        (val_stats, tr_stats): the snapshot tuples, and the per-epoch
        ``(loss, acc)`` pairs from ``train_epoch``.
    """
    init_dir(save_model_dir)
    val_stats, tr_stats = [], []

    for epoch in tqdm(range(epochs)):
        is_val_epoch = (epoch + 1) % valfreq == 0
        if is_val_epoch:
            H = block_hessian(model, ds, loss_fn, lr)
            delta, fo, ho, fostat = first_order_analysis(model, ds, loss_fn, lr)
            error = relative_error(H.sum().item(), ho)
            val_stats.append((H, delta, fo, ho, error, fostat))
            save_model(model, epoch)

        step_loss, step_acc = train_epoch(model, ds, loss_fn, lr)
        tr_stats.append((step_loss, step_acc))

    return val_stats, tr_stats
    
def unpack_stats(stats):
    """Transpose a sequence of equal-length tuples into a list of per-field tuples."""
    return [*zip(*stats)]
    
def plot_acc_loss(acc, loss, ax=None):
    """Plot accuracy on ``ax[0]`` and loss on ``ax[1]``.

    A new 1x2 figure is created when ``ax`` is None.
    """
    axes = plt.subplots(1,2, figsize=(12,5))[1] if ax is None else ax
    acc_ax, loss_ax = axes[0], axes[1]
    loss_ax.plot(loss)
    acc_ax.plot(acc)

In [None]:
# Run configuration; valfreq=1 snapshots curvature statistics every epoch.
epochs, lr, valfreq = 2000, 0.01, 1

model, ds, loss_fn = get_model_ds_loss(loss_mode=loss_mode)

In [None]:
val_stats, tr_stats = run_analysis(model, ds, loss_fn, lr, epochs, valfreq=valfreq)

In [None]:
# One sequence per field: six for the snapshots, two (loss, acc) per epoch.
(H, delta, fo, ho, error, fostat), (loss, acc) = unpack_stats(val_stats), unpack_stats(tr_stats)

In [None]:
plot_acc_loss(acc=acc, loss=loss)

In [None]:
# Sanity check: torch's BatchNorm2d vs BN2d_ctrl (all stages enabled) on random NCHW input.
bs, ch, hw = 16, 20, 64

bn1 = nn.BatchNorm2d(ch).cuda()
with torch.no_grad():
    bn1.weight.fill_(1)  # unit scale, like BN2d_ctrl's initial bn_weight

bn2 = BN2d_ctrl(ch, use_bn=(True,) * 4).cuda()
x = torch.randn(bs, ch, hw, hw).cuda()

In [None]:
o1, o2 = bn1(x), bn2(x)

In [None]:
def diff(a, b):
    """Return (max |a - b|, sum |a - b| / sum |a|), both as Python floats.

    The relative term used to be returned as a 0-d tensor while the max term
    was a float.
    """
    gap = (a - b).abs()
    return gap.max().item(), (gap.sum() / a.abs().sum()).item()

diff(a=o1, b=o2)

In [None]:
def pp(o2):
    """Print the mean and std of ``o2`` reduced over dims (0, 2, 3)."""
    dims = (0, 2, 3)
    channel_mean, channel_std = o2.mean(dims), o2.std(dims)
    print(channel_mean, channel_std)

In [None]:
(o1.mean(dim=(0, 2, 3)), o1.std(dim=(0, 2, 3)))

In [None]:
summarize_stats(stats=fostat)

### Visualization Tools

In [None]:
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
%matplotlib inline

def vals2anime(vals):
    """Animate a 1-D series: the full curve plus a red marker stepping through it.

    Returns:
        IPython ``HTML`` wrapping the JavaScript animation from ``to_jshtml``.
    """
    fig, ax = plt.subplots()
    xdata, ydata = range(len(vals)), vals
    ln, = ax.plot([], [])
    ln2, = ax.plot([], [], 'ro')

    def init():
        ax.set_xlim(0, len(vals))
        ax.set_ylim(min(vals), max(vals))
        ax.set_title("i=0")
        ln.set_data(xdata, ydata)
        return ln, ln2

    def update(i):
        # set_data needs sequences; bare scalars are deprecated since
        # Matplotlib 3.7 and rejected by later releases.
        ln2.set_data([i], [vals[i]])
        ax.set_title(f"i={i}")
        return ln, ln2

    ani = FuncAnimation(fig, update, frames=len(vals),
                        init_func=init, blit=True, interval=30)

    html = ani.to_jshtml()
    plt.close(fig)
    return HTML(html)

def normalize_imgs(imgs, time_norm=False):
    """Take |img| of each image and choose a shared colour range.

    Args:
        imgs: sequence of tensors (``.abs()``, ``.min()`` and ``.max()`` are used).
        time_norm: if True, keep raw magnitudes and return the global min/max
            across all images, so brightness is comparable over time. If False,
            scale each image by its own peak magnitude and use the range [0, 1].

    Returns:
        (imgs, vmin, vmax)
    """
    mags = [img.abs() for img in imgs]
    if time_norm:
        vmin = min(m.min() for m in mags)
        vmax = max(m.max() for m in mags)
        return mags, vmin, vmax

    # An all-zero image has peak 0; leave it at zero instead of producing NaNs.
    scaled = [m / m.max() if m.max() > 0 else m for m in mags]
    return scaled, 0, 1

def imgs2anime(imgs, time_norm=False):
    """Animate a sequence of images, with magnitudes normalized by normalize_imgs.

    Returns:
        IPython ``HTML`` wrapping the JavaScript animation.
    """
    frames_data, vmin, vmax = normalize_imgs(imgs, time_norm)

    fig, ax = plt.subplots()
    im = ax.imshow(frames_data[0], vmin=vmin, vmax=vmax)

    def show(i):
        im.set_data(frames_data[i])
        ax.set_title(f"i={i}")

    ani = FuncAnimation(fig, show, init_func=lambda: show(0), frames=len(frames_data),
                        interval=30)
    html = ani.to_jshtml()
    plt.close()
    return HTML(html)

def bars2anime(data):
    """Animate stacked bars from a 3-D array indexed as data[frame, bar, segment].

    Segment s of every bar sits on top of segments 0..s-1.

    Returns:
        IPython ``HTML`` wrapping the JavaScript animation.
    """
    fig, ax = plt.subplots()

    positions = np.arange(1, data.shape[1]+1)
    rectss = [
        ax.bar(positions, data[0,:,0]) if seg == 0
        else ax.bar(positions, data[0,:,seg], bottom=data[0,:,:seg].sum(axis=1))
        for seg in range(data.shape[2])
    ]

    def set_rectss(containers, frame):
        # Re-stack each segment on the cumulative height of the segments below it.
        for seg, container in enumerate(containers):
            for bar_idx, rect in enumerate(container):
                rect.set_y(frame[bar_idx,:seg].sum())
                rect.set_height(frame[bar_idx,seg])

    def draw(i):
        set_rectss(rectss, data[i,:,:])
        ax.set_title(f"i={i}")
        return rectss[0]

    ani = FuncAnimation(fig, draw, frames=data.shape[0],
                        init_func=lambda: draw(0), interval=50)

    html = ani.to_jshtml()
    plt.close()
    return HTML(html)

In [None]:
def visualize_data(H, data, loss, acc):
    """Build a 2x2 animation of a training run, one frame per snapshot in ``H``.

    Panels:
      (0,0) normalized block-Hessian image ``H[i]``;
      (0,1) grouped bars of ``data[i]`` (legend: delta, fo, ho);
      (1,0) loss curve and (1,1) accuracy curve, each with a red marker.

    The marker sits at epoch ``(i + 1) * valfreq - 1``, using the notebook-level
    ``valfreq``.

    Returns:
        The animation as a JavaScript/HTML string (wrap it in ``HTML`` to display).
    """
    fig, ax = plt.subplots(2, 2, figsize=(12,10))

    # BH
    H, vmin, vmax = normalize_imgs(H)
    im = ax[0,0].imshow(H[0], vmin=vmin, vmax=vmax)
    ax[0,0].set_title("BH")

    # layerwise delta_loss, fo, ho
    ax[0,1].set_xlim(0, data.shape[1]+1)
    margin = (data.max()-data.min())*1.08
    ax[0,1].set_ylim(data.min()-margin, data.max()+margin)
    width = 0.3
    x = np.arange(1, data.shape[1]+1)
    rectss = []
    for i in range(data.shape[2]):
        if i==0:
            rectss.append(ax[0,1].bar(x, data[0,:,0], width=width))
        else:
            rectss.append(ax[0,1].bar(x+width*i, data[0,:,i], width=width))
    ax[0,1].legend(["delta","fo","ho"])
    ax[0,1].set_title("layerwise contrib")

    # loss
    ax[1,0].set_xlim(0, len(loss))
    ax[1,0].set_ylim(min(loss), max(loss))
    xdata = range(len(loss))
    ln_l, = ax[1,0].plot([], [])
    ln_l2, = ax[1,0].plot([], [], 'ro')
    ax[1,0].set_title("loss")

    # acc
    ax[1,1].set_xlim(0, len(acc))
    ax[1,1].set_ylim(min(acc), max(acc))
    ln_a, = ax[1,1].plot([], [])
    ln_a2, = ax[1,1].plot([], [], 'ro')
    ax[1,1].set_title("acc")

    def set_rectss(rectss, d):
        for i, rects in enumerate(rectss):
            for j, rect in enumerate(rects):
                rect.set_height(d[j,i])

    def init():
        im.set_data(H[0])

        set_rectss(rectss, data[0,:,:])

        ln_l.set_data(xdata, loss)
        ln_a.set_data(xdata, acc)

        fig.suptitle("i=0")
        return ln_l, ln_l2, ln_a, ln_a2

    def update(i):
        im.set_data(H[i])

        set_rectss(rectss, data[i,:,:])
        margin = (data[i,:,:].max()-data[i,:,:].min())*0.1
        ax[0,1].set_ylim(data[i,:,:].min()-margin, data[i,:,:].max()+margin)

        # Assuming the snapshots come from run_analysis, snapshot i was taken at
        # epoch (i + 1) * valfreq - 1, before that epoch's step. The old
        # i * valfreq was off by valfreq - 1 whenever valfreq > 1.
        t = (i + 1) * valfreq - 1
        # set_data needs sequences; bare scalars are deprecated since Matplotlib 3.7.
        ln_l2.set_data([t], [loss[t]])
        ln_a2.set_data([t], [acc[t]])

        fig.suptitle(f"i={i}")
        return ln_l, ln_l2, ln_a, ln_a2

    ani = FuncAnimation(fig, update, init_func=init, frames=len(H),
                        interval=100, blit=True)
    ani = ani.to_jshtml()
    plt.close()
    return ani


In [None]:
# Keep rows/columns 0, 3, 6, 9, 12 of each block Hessian.
# NOTE(review): presumably these select one entry per layer -- confirm against block_hessian's layout.
BH = [h.cpu()[0:13:3, 0:13:3] for h in H]

# First-order term per entry: W_g_sqr from each snapshot's stats, times lr.
lw_fo = [np.array(snap.W_g_sqr) * lr for snap in fostat]
# Higher-order term per entry: column sums of the subsampled block Hessian.
lw_ho = [block.numpy().sum(axis=0) for block in BH]
# "delta" = first-order + higher-order term.
lw_delta = [first + higher for first, higher in zip(lw_fo, lw_ho)]

In [None]:
# Stack to shape (n_snapshots, n_entries, 3); the last axis is (delta, fo, ho).
# `np.float` was an alias of the builtin float and was removed in NumPy 1.24.
delta_fo_ho = np.array((lw_delta, lw_fo, lw_ho), dtype=float).transpose(1, 2, 0)

In [None]:
import matplotlib
# Size cap (in MB) for animations embedded as HTML; frames beyond it are dropped.
matplotlib.rcParams.update({'animation.embed_limit': 100})

In [None]:
ani = visualize_data(H=BH, data=delta_fo_ho, loss=loss, acc=acc)

In [None]:
html_path = "stat_torchbn_init1.html"
with open(html_path, "w") as out_file:
    out_file.write(f"{ani}\n")  # same text print(ani, file=...) would write

In [None]:
HTML(data=ani)

### Visualization Tools (Refactored)

In [None]:
class vals_anime():
    """Full line plot of ``vals`` plus a red marker that follows the current frame.

    Frame i places the marker at (i, vals[i]).
    """

    def __init__(self, vals, ax, title):
        self.ax = ax
        self.vals = vals
        ax.set_title(title)
        ax.set_xlim(0, len(vals))
        ax.set_ylim(min(vals), max(vals))
        ax.plot(range(len(vals)), vals)
        self.ln, = ax.plot([], [], 'ro')

    def init(self):
        """Initial frame: nothing to reset; returns the marker artist."""
        return self.ln

    def update(self, i):
        """Move the marker to point i and return it."""
        # set_data needs sequences; bare scalars are deprecated since
        # Matplotlib 3.7 and rejected by later releases.
        self.ln.set_data([i], [self.vals[i]])
        return self.ln

class imgs_anime():
    """Image panel showing normalized |imgs[i]| at frame i with a fixed colour range."""

    def __init__(self, imgs, ax, title, time_norm=False):
        self.ax = ax
        self.imgs, vmin, vmax = self.normalize_imgs(imgs, time_norm)
        self.im = ax.imshow(self.imgs[0], vmin=vmin, vmax=vmax)
        ax.set_title(title)

    def init(self):
        return

    def update(self, i):
        self.im.set_data(self.imgs[i])
        return

    def normalize_imgs(self, imgs, time_norm=False):
        """Return (|imgs|, vmin, vmax).

        With ``time_norm`` the raw magnitudes and their global min/max are used.
        Otherwise each image is scaled by its own peak into [0, 1].
        """
        mags = [img.abs() for img in imgs]
        if time_norm:
            return mags, min(m.min() for m in mags), max(m.max() for m in mags)

        # An all-zero image has peak 0; keep it at zero instead of producing NaNs.
        return [m / m.max() if m.max() > 0 else m for m in mags], 0, 1
    
class bars_anime():
    """Grouped bar panel for a 3-D array.

    There is one bar group per index on axis 1 and one bar per series on axis 2.
    Frame i shows data[i]; bar positions stay fixed and only heights change.
    """

    def __init__(self, data, ax, title):
        self.ax = ax
        ax.set_title(title)
        self.data = data
        n_groups, n_series = data.shape[1], data.shape[2]
        ax.set_xlim(0, n_groups + 1)
        ax.set_ylim(data.min(), data.max())
        bar_width = 0.1

        centers = np.arange(1, n_groups + 1)
        self.rectss = []
        for s in range(n_series):
            offsets = centers if s == 0 else centers + bar_width * s
            self.rectss.append(ax.bar(offsets, data[0, :, s], width=bar_width))

    def init(self):
        return

    def update(self, i):
        self.set_rectss(self.rectss, self.data[i, :, :])
        return

    def set_rectss(self, rectss, d):
        """Set each bar's height from d[group, series]."""
        for s, container in enumerate(rectss):
            for g, rect in enumerate(container):
                rect.set_height(d[g, s])


In [None]:
def vis_data(H, data, loss, acc, valfreq=1):
    """2x2 animation assembled from the panel helpers, one frame per snapshot.

    Panels: block-Hessian image, layer-wise bars (delta/fo/ho), and the loss and
    accuracy curves with a marker at the epoch of the current snapshot.

    Args:
        valfreq: validation frequency used when the snapshots were recorded.
            Snapshot i is mapped to epoch ``(i + 1) * valfreq - 1``; this is
            where run_analysis takes it, assuming the snapshots come from there.
            The default of 1 reproduces the previous i -> i mapping.

    Returns:
        The animation as a JavaScript/HTML string (wrap it in ``HTML`` to display).
    """
    fig, ax = plt.subplots(2, 2, figsize=(12,10))

    snapshot_animes = [
        imgs_anime(H, ax[0,0], "BH"),
        bars_anime(data, ax[0,1], "LW contrib"),
    ]
    ax[0,1].legend(["delta","fo","ho"])
    curve_animes = [
        vals_anime(loss, ax[1,0], "loss"),
        vals_anime(acc, ax[1,1], "acc"),
    ]

    def init():
        for anime in snapshot_animes + curve_animes:
            anime.init()
        fig.suptitle("i=0")

    def update(i):
        # Per-epoch curves are indexed by the snapshot's epoch, not by i itself.
        t = (i + 1) * valfreq - 1
        for anime in snapshot_animes:
            anime.update(i)
        for anime in curve_animes:
            anime.update(t)
        fig.suptitle(f"i={i}")

    ani = FuncAnimation(fig, update, init_func=init, frames=len(H),
                        interval=100)
    html = ani.to_jshtml()
    plt.close()
    return html


In [None]:
anim_html = vis_data(BH, delta_fo_ho, loss, acc)
HTML(anim_html)