In [None]:
import sys
import pickle

sys.path.append("..")
from helpers import init_dir, \
                    run_analysis, unpack_stats

import torch.nn as nn

sys.path.append("../../../src")
from models import conv_bn, GAPool, FC
from data import gen_cifar10_ds

## Skip-connection model

In [None]:
class skip_block(nn.Module):
    """Stack of ``nskip`` ``conv_bn`` layers wrapped with an additive shortcut.

    ``forward(x)`` returns ``layers(x) + x``.
    """

    def __init__(self, hid, nskip, stride=1, bias=False, use_bn=False, mode="linear"):
        """Create the inner stack.

        Args:
            hid: passed to ``conv_bn`` as both its first and second argument.
            nskip: number of ``conv_bn`` layers inside the block.
            stride, bias, use_bn, mode: forwarded unchanged to every ``conv_bn``.
        """
        super().__init__()
        stack = []
        for _ in range(nskip):
            stack.append(conv_bn(hid, hid, stride=stride, bias=bias, use_bn=use_bn, mode=mode))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the inner stack and add the unmodified input back on."""
        # NOTE(review): the addition presumably needs layers(x) to keep the
        # shape of x; confirm this still holds for stride != 1.
        return self.layers(x) + x

class skip_conv_net(nn.Module):
    """Conv net in which the plain ``conv_bn`` stack is followed by one skip block.

    Forward order: ``l1 -> layers -> skip -> GAPool -> out``.
    """

    def __init__(self, inp, hid, out, nlayer, nskip,
                 bias=False, use_bn=False, mode="linear"):
        """Build the sub-modules.

        Args:
            inp: first argument of the input ``conv_bn`` layer.
            hid: width argument shared by all hidden layers and the skip block.
            out: second argument of the final ``FC`` layer.
            nlayer: overall depth budget; the plain stack gets
                ``max(0, nlayer - 2 - nskip)`` layers.
            nskip: number of layers inside the skip block.
            bias, use_bn, mode: forwarded to every ``conv_bn`` / ``skip_block``.
        """
        super().__init__()
        shared = dict(bias=bias, use_bn=use_bn, mode=mode)

        self.l1 = conv_bn(inp, hid, stride=2, **shared)

        # Presumably the "- 2" accounts for l1 and the output layer -- TODO confirm.
        n_plain = max(0, nlayer - 2 - nskip)
        self.layers = nn.Sequential(*(conv_bn(hid, hid, stride=1, **shared)
                                      for _ in range(n_plain)))
        self.skip = skip_block(hid, nskip, stride=1, **shared)

        self.GAPool = GAPool()
        # The output layer is always bias-free and linear, whatever bias/mode say.
        self.out = FC(hid, out, bias=False, mode="linear")

    def forward(self, x):
        """Run ``x`` through l1, the plain stack, the skip block, GAPool and out."""
        h = self.l1(x)
        h = self.layers(h)
        h = self.skip(h)
        h = self.GAPool(h)
        return self.out(h)

    def get_mode(self):
        """Return ``.mode`` of the first ``Activation`` found in ``self.modules()``.

        ``next`` is called without a default, so ``StopIteration`` propagates
        when the model contains no ``Activation`` module.
        """
        first = next(self._activations())
        return first.mode

    def set_mode(self, mode):
        """Call ``set_mode(mode)`` on every ``Activation`` sub-module."""
        for act in self._activations():
            act.set_mode(mode)

    def _activations(self):
        """Return an iterator over all sub-modules that are ``Activation`` instances."""
        # NOTE(review): Activation is not imported in the cells visible here;
        # confirm where it comes from (the models module?).
        return (m for m in self.modules() if isinstance(m, Activation))

def get_model_ds_loss(inp_dim, hid_dim, out_dim,
                      nlayer, nskip, bias, use_bn, mode,
                      nsamp, device, loss_mode='CrossEntropy'):
    """Build a ``skip_conv_net``, the dataset and the loss function.

    Args:
        inp_dim, hid_dim, out_dim, nlayer, nskip, bias, use_bn, mode:
            forwarded positionally to ``skip_conv_net``.
        nsamp: forwarded to ``gen_cifar10_ds``.
        device: CUDA device passed to ``.cuda()`` and to ``gen_cifar10_ds``.
        loss_mode: ``'CrossEntropy'`` or ``'Linear'``.

    Returns:
        Tuple ``(model, ds, loss_fn)``.

    Raises:
        ValueError: if ``loss_mode`` is not one of the supported names.
    """
    # Validate before any expensive work: a typo in loss_mode should fail
    # before the model is moved to the GPU and the dataset is generated.
    # An explicit raise is also not stripped by ``python -O`` the way an
    # ``assert`` would be.
    supported = ("CrossEntropy", "Linear")
    if loss_mode not in supported:
        raise ValueError(f"loss_mode must be one of {supported}, got {loss_mode!r}")

    model = skip_conv_net(inp_dim, hid_dim, out_dim, nlayer, nskip, bias, use_bn, mode).cuda(device)
    ds = gen_cifar10_ds(nsamp, device, download=False)

    if loss_mode == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()
    else:
        # NOTE(review): LinearClassification is not imported in the cells
        # visible here; confirm its origin before using loss_mode='Linear'.
        loss_fn = LinearClassification(out_dim)

    return model, ds, loss_fn


In [None]:
class skip_conv_net2(nn.Module):
    """Conv net in which the skip block comes before the plain ``conv_bn`` stack.

    Forward order: ``l1 -> skip -> layers -> GAPool -> out``. The sub-modules
    are constructed exactly as in ``skip_conv_net``; only ``forward`` applies
    ``skip`` and ``layers`` in the opposite order.
    """

    def __init__(self, inp, hid, out, nlayer, nskip,
                 bias=False, use_bn=False, mode="linear"):
        """Build the sub-modules.

        Args:
            inp: first argument of the input ``conv_bn`` layer.
            hid: width argument shared by all hidden layers and the skip block.
            out: second argument of the final ``FC`` layer.
            nlayer: overall depth budget; the plain stack gets
                ``max(0, nlayer - 2 - nskip)`` layers.
            nskip: number of layers inside the skip block.
            bias, use_bn, mode: forwarded to every ``conv_bn`` / ``skip_block``.
        """
        super().__init__()
        shared = dict(bias=bias, use_bn=use_bn, mode=mode)

        self.l1 = conv_bn(inp, hid, stride=2, **shared)

        # Presumably the "- 2" accounts for l1 and the output layer -- TODO confirm.
        n_plain = max(0, nlayer - 2 - nskip)
        self.layers = nn.Sequential(*(conv_bn(hid, hid, stride=1, **shared)
                                      for _ in range(n_plain)))
        self.skip = skip_block(hid, nskip, stride=1, **shared)

        self.GAPool = GAPool()
        # The output layer is always bias-free and linear, whatever bias/mode say.
        self.out = FC(hid, out, bias=False, mode="linear")

    def forward(self, x):
        """Run ``x`` through l1, the skip block, the plain stack, GAPool and out."""
        h = self.l1(x)
        h = self.skip(h)
        h = self.layers(h)
        h = self.GAPool(h)
        return self.out(h)

    def get_mode(self):
        """Return ``.mode`` of the first ``Activation`` found in ``self.modules()``.

        ``next`` is called without a default, so ``StopIteration`` propagates
        when the model contains no ``Activation`` module.
        """
        first = next(self._activations())
        return first.mode

    def set_mode(self, mode):
        """Call ``set_mode(mode)`` on every ``Activation`` sub-module."""
        for act in self._activations():
            act.set_mode(mode)

    def _activations(self):
        """Return an iterator over all sub-modules that are ``Activation`` instances."""
        # NOTE(review): Activation is not imported in the cells visible here;
        # confirm where it comes from (the models module?).
        return (m for m in self.modules() if isinstance(m, Activation))

def get_model_ds_loss2(inp_dim, hid_dim, out_dim,
                      nlayer, nskip, bias, use_bn, mode,
                      nsamp, device, loss_mode='CrossEntropy'):
    """Build a ``skip_conv_net2``, the dataset and the loss function.

    Args:
        inp_dim, hid_dim, out_dim, nlayer, nskip, bias, use_bn, mode:
            forwarded positionally to ``skip_conv_net2``.
        nsamp: forwarded to ``gen_cifar10_ds``.
        device: CUDA device passed to ``.cuda()`` and to ``gen_cifar10_ds``.
        loss_mode: ``'CrossEntropy'`` or ``'Linear'``.

    Returns:
        Tuple ``(model, ds, loss_fn)``.

    Raises:
        ValueError: if ``loss_mode`` is not one of the supported names.
    """
    # Validate before any expensive work: a typo in loss_mode should fail
    # before the model is moved to the GPU and the dataset is generated.
    # An explicit raise is also not stripped by ``python -O`` the way an
    # ``assert`` would be.
    supported = ("CrossEntropy", "Linear")
    if loss_mode not in supported:
        raise ValueError(f"loss_mode must be one of {supported}, got {loss_mode!r}")

    model = skip_conv_net2(inp_dim, hid_dim, out_dim, nlayer, nskip, bias, use_bn, mode).cuda(device)
    ds = gen_cifar10_ds(nsamp, device, download=False)

    if loss_mode == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()
    else:
        # NOTE(review): LinearClassification is not imported in the cells
        # visible here; confirm its origin before using loss_mode='Linear'.
        loss_fn = LinearClassification(out_dim)

    return model, ds, loss_fn


## Parameters

In [None]:
# Arguments handed to get_model_ds_loss in the analysis cell.
inp_dim, hid_dim, out_dim = 3, 64, 10
nlayer = 5
bias = False
mode = "relu"
loss_mode = 'CrossEntropy'

# use_bn is passed through to conv_bn as-is.
# NOTE(review): a non-empty tuple is always truthy -- confirm conv_bn
# interprets a tuple of flags rather than a single bool.
use_bn = (True,True,True,True)
# One digit per flag ("1" for True, "0" for False); used in folder and file names.
bn_code = "".join(f"{flag:d}" for flag in use_bn)

# Passed to .cuda(device) and to gen_cifar10_ds.
device = 0
nsamp = 500

# save_model_dir is only referenced by the commented-out run_analysis call
# in the analysis cell.
save_model_dir = "models"
datafolder = f"data{bn_code}_pre-skip"

init_dir(datafolder)

### Analysis

In [None]:
# Settings forwarded to run_analysis.
epochs, lr, valfreq = 5000, 0.01, 1
# One full analysis run is made per entry.
nskips = [1,2]

for nskip in nskips:
    model, ds, loss_fn = get_model_ds_loss(
        inp_dim, hid_dim, out_dim,
        nlayer, nskip, bias, use_bn, mode,
        nsamp, device, loss_mode,
    )

    # The original cell also kept a commented-out variant of this call that
    # passes save_model_dir as a seventh argument (presumably to save the
    # model -- confirm against helpers.run_analysis).
    val_stats, tr_stats = run_analysis(model, ds, loss_fn, lr, epochs, valfreq)

    # Validation stats unpack into six entries, training stats into two.
    H, delta, fo, ho, error, fostat = unpack_stats(val_stats)
    loss, acc = unpack_stats(tr_stats)

    stat = dict(H=H, delta=delta, fo=fo, ho=ho, error=error, fostat=fostat,
                loss=loss, acc=acc, lr=lr)

    # The file name encodes the run configuration.
    filename = f"stat_lr{lr:.0e}_nl{nlayer}_hid{hid_dim:03d}_skip{nskip}_bn{bn_code}"
    with open(f'{datafolder}/{filename}.pkl','wb') as f:
        pickle.dump(stat,f)


In [None]:
import os
import glob

In [None]:
# Copy only the "loss" and "acc" entries of every pickled stat file into out_dir.
# NOTE(review): the glob reads from "data/", while the params cell writes to
# datafolder (f"data{bn_code}_pre-skip") -- confirm this is the intended input.
in_data = "data/*"
out_dir = "loss_acc"
init_dir(out_dir)

# NOTE(review): pickle.load can execute arbitrary code; only run this on
# files produced by trusted runs.
for fname in sorted(glob.glob(in_data)):
    with open(fname,'rb') as f:
        stat = pickle.load(f)

    # Drop every key except the two training curves.
    stat = {key: stat[key] for key in ("loss", "acc")}

    # The output keeps the input's base name.
    with open(f'{out_dir}/{os.path.basename(fname)}','wb') as f:
        pickle.dump(stat,f)