# Imports

In [6]:
import pickle as pkl
import time
import sys
import numpy as np

from bonsai.data_loaders import load_data
from bonsai.net import Net
from bonsai.trainers import *
from bonsai.helpers import *
from bonsai.ops import commons, Zero

%load_ext autoreload
%autoreload 2

mem_stats()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


'0.00B'

In [18]:
# Epoch budget of the search phases; read by the search loop further down
# ('learn_phase' / 'prune_phase' set epochs per try, 'prune_interval' is passed
# straight to full_train).
nas_schedule = dict(
    learn_phase=16,
    prune_phase=16,
    prune_interval=4,
)

# Hyper-parameters shared by the sizing test, the searched model and the
# random-search baseline. The whole dict is also splatted into sp_size_test
# via **hypers, so key names must match that function's keyword arguments.
hypers = dict(
    gpu_space=8.25,       # upper bound the estimated model size is compared against (presumably GiB -- confirm)
    dataset='CIFAR10',
    classes=10,
    batch_size=64,
    scale=5,
    nodes=4,
    # One list of cell-type codes per pattern.
    # NOTE(review): 'n' / 'r' / 'na' look like normal / reduction / normal+aux -- confirm in bonsai.net.
    patterns=[
        ['n', 'n', 'na'],
        ['r', 'n', 'n', 'na'],
        ['r', 'n', 'na'],
        ['r', 'n', 'na'],
        ['r', 'n', 'na'],
        ['r', 'na'],
        ['n', 'na'],
    ],
    half=False,
    multiplier=1,
    lr_schedule=dict(lr_max=.01, T=600),
    drop_prob=.25,
    # Per-try base values; the search loop multiplies them by the try number.
    prune_rate=dict(edge=.5, input=.5),
)

# Load the dataset, then derive how many of the patterns above are used
# (stored back into hypers so later cells can read it).
data, dim = load_data(hypers['batch_size'], hypers['dataset'])
hypers['num_patterns'] = get_n_patterns(hypers['patterns'], dim, target=5)
print(hypers['num_patterns'])

6


# Determine Height/Size Ratios
Check how a test model scales under the search params to ensure we don't overfill the GPU.

### Check out a sample model

In [19]:
# Size-test one sample configuration (last pattern count before the final one,
# edge compression e_c=.25) and print both the model table and the returned result.
print(
    sp_size_test(
        hypers['num_patterns'] - 1,
        e_c=.25,
        add_pattern=True,
        remove_prune=False,
        print_model=True,
        **hypers
    )
)

                     :     Dim      :    Params    :   Comp   
Initializer          :              :     160      :          
Cell 0  (Normal)     :   32 x 32    :    9,946     :   24.1%  
Cell 1  (Normal)     :   32 x 32    :    9,947     :   24.1%  
Cell 2  (Normal)     :   32 x 32    :    9,948     :   24.1%  
 ↳ Aux Tower         :              :   327,690    :          
Cell 3  (Reduction)  :   64 x 16    :    32,157    :   24.7%  
Cell 4  (Normal)     :   64 x 16    :    33,133    :   25.0%  
Cell 5  (Normal)     :   64 x 16    :    33,134    :   25.0%  
Cell 6  (Normal)     :   64 x 16    :    33,135    :   25.0%  
 ↳ Aux Tower         :              :   163,850    :          
Cell 7  (Reduction)  :  128 x 8     :   113,441    :   24.1%  
Cell 8  (Normal)     :  128 x 8     :   113,452    :   25.0%  
Cell 9  (Normal)     :  128 x 8     :   113,453    :   25.0%  
 ↳ Aux Tower         :              :    81,930    :          
Cell 10 (Reduction)  :  256 x 4     :   427,315    :   

### Get Sizing Ratios

In [20]:
# For every pattern count n, search e_c in [.2, 1.] with BST and record the
# largest value whose size test passed. `sizes` maps n -> that ratio (a float)
# and is read by the search loop.
sizes = {}
unsized = []  # pattern counts for which BST recorded no passing ratio
for n in range(1, hypers['num_patterns']):
    remove_prune = False
    bst = BST(.2, 1.)
    while bst.answer is None:
        # '\r' + end="" keeps the progress read-out on a single line.
        print("{}: {:.3f}\r".format(n, bst.pos), end="")
        size = sp_size_test(n, e_c=bst.pos, add_pattern=True, remove_prune=remove_prune, **hypers)
        # size[0] is compared against the GPU budget, size[1] is treated as an
        # overflow flag; the query is True when the candidate does NOT fit.
        query = not (not size[1] and (size[0]) < hypers['gpu_space'])
        bst.query(query)
    if bst.passes:
        sizes[n] = max(bst.passes)
    else:
        unsized.append(n)

# Fail with a clear message instead of leaving a `[]` placeholder in `sizes`,
# which would only surface later as a TypeError in the '{:.3f}' formatting below
# (and in the numeric comparisons of the search loop).
if unsized:
    raise RuntimeError(
        "No passing compression ratio recorded for pattern count(s) {} "
        "(gpu_space={})".format(unsized, hypers['gpu_space'])
    )

# The search starts right after the last pattern count that passed at ratio 1;
# with no such entry it starts at 1.
uncompressed = [k for (k, v) in sizes.items() if v == 1]
start_size = uncompressed[-1] + 1 if uncompressed else 1
print("Comp Ratios:", *["\n{}{}->{}: {:.3f}".format(" " if k != start_size else "*", k, k+1, v) for (k, v) in sizes.items()])

Comp Ratios: 
 1->2: 1.000 
*2->3: 0.850 
 3->4: 0.700 
 4->5: 0.500 
 5->6: 0.375


# Search

## Model Setup

In [21]:
def jn_print(x,end="\n"):
    """Print ``x`` and append the same text to ``logs/jn_out.log``.

    Args:
        x: Value to emit. It is converted with ``str`` so that non-string
            values are logged exactly as ``print`` shows them.
        end: Terminator written after ``x`` on stdout and in the log file.

    Raises:
        OSError: If the ``logs`` directory or the log file cannot be
            created or written.
    """
    import os  # local import: the notebook's import cell does not provide `os`

    text = str(x)
    print(text, end=end)
    # The log lives in a sub-directory that may not exist on a fresh checkout.
    os.makedirs("logs", exist_ok=True)
    with open("logs/jn_out.log", "a") as f:
        f.write(text + end)
          
# Build the search model at the starting pattern count chosen by the sizing cell.
model = Net(
    dim=dim,
    classes=hypers['classes'],
    scale=hypers['scale'],
    patterns=hypers['patterns'],
    num_patterns=start_size,
    nodes=hypers['nodes'],
    drop_prob=hypers['drop_prob'],
    lr_schedule=hypers['lr_schedule'],
)
model.data = data

# Estimate the model's size before committing to a search.
size, overflow = size_test(model, data)
print(model)
if overflow:
    print("Est Size: {}{:.2f}GiB {}".format(">", size, "(overflow)"))
    # The model did not fit: drop it and run the project's cleanup helper.
    del model
    clean('Search init')
else:
    print("Est Size: {}{:.2f}GiB {}".format("", size, ""))

Init: 26.00MiB
0: 1.90GiB
1: 3.62GiB
2: 5.34GiB
Tower 2: 5.34GiB
3: 5.46GiB
4: 5.97GiB
5: 6.78GiB
6: 7.60GiB
GP: 7.60GiB
Classifier: 7.60GiB
                     :     Dim      :    Params    :   Comp   
Initializer          :              :     160      :          
Cell 0  (Normal)     :   32 x 32    :    41,209    :  100.0%  
Cell 1  (Normal)     :   32 x 32    :    41,210    :  100.0%  
Cell 2  (Normal)     :   32 x 32    :    41,211    :  100.0%  
 ↳ Aux Tower         :              :   327,690    :          
Cell 3  (Reduction)  :   64 x 16    :   139,708    :  100.0%  
Cell 4  (Normal)     :   64 x 16    :   139,709    :  100.0%  
Cell 5  (Normal)     :   64 x 16    :   139,710    :  100.0%  
Cell 6  (Normal)     :   64 x 16    :   139,711    :  100.0%  
 ↳ Classifier        :              :   163,850    :          
Total                :              :  1,174,168   :  100.0%  

Est Size: 7.60GiB 


## Model Search

In [None]:
wipe_output()
search_start = time.time()

# Attempts per pattern count before the search gives up. The retry loop runs
# tries = 1..max_tries; the abort message below is derived from the same value
# so the two cannot drift apart.
max_tries = 9

# search loop: for each pattern count, train/prune until the model's hard
# compression is within the ratio measured by the sizing cell, then grow the
# model by one pattern.
for n in range(start_size, hypers['num_patterns']):
    print("===", n, "===")
    print(model)
    comp_ratio = sizes.get(n, 0)
    # Aim below the target to leave headroom (a larger margin for small ratios).
    aim = comp_ratio*.9 if comp_ratio > .35 else comp_ratio*.66
    jn_print("=== {} Patterns. Target Comp: {:.2f}, Aim: {:.2f}".format(n, comp_ratio, aim))

    for tries in range(1, max_tries + 1):
        # Only the first try includes the learn phase; retries are prune-only.
        epochs = (nas_schedule['learn_phase']*(tries==1)) + nas_schedule['prune_phase']
        # NOTE(review): the transition offset uses 'prune_phase' although the
        # extra first-try epochs come from 'learn_phase'; both are 16 today --
        # confirm which one is intended.
        comp_lambdas = {'transition': model.lr_scheduler.t + (nas_schedule['prune_phase']*(tries==1)),
                        # pruning pressure grows linearly with the try number
                        'lambdas': {k: v*tries for k, v in hypers['prune_rate'].items()}}

        # learn+prune
        full_train(model, epochs, comp_lambdas=comp_lambdas, comp_ratio=aim, prune_interval=nas_schedule['prune_interval'])
        clean(verbose=False)
        hard_comp = model.genotype_compression()[1]
        # Compare against the same target that was logged above (comp_ratio)
        # rather than re-indexing sizes[n].
        if hard_comp and hard_comp > comp_ratio:
            jn_print("Try {}. Restarting pruning at pattern {}. Target comp: {:.2f}/{:.2f}, Actual: {:.3f}".format(tries, n, comp_ratio, aim, hard_comp))
        else:
            break
    else:
        # Inner loop exhausted without a break: no try met the target, stop the search.
        print("No progress after {} tries, aborting.".format(max_tries))
        break

    # n never reaches hypers['num_patterns'] inside this range, so the next
    # pattern is always added after a successful try.
    print("Adding next pattern:", n+1)
    model.add_pattern()

clean("Search End")
print("Search Time:", show_time(time.time()-search_start))
print(model)

=== 2 ===
                     :     Dim      :    Params    :   Comp   
Initializer          :              :     160      :          
Cell 0  (Normal)     :   32 x 32    :    41,209    :  100.0%  
Cell 1  (Normal)     :   32 x 32    :    41,210    :  100.0%  
Cell 2  (Normal)     :   32 x 32    :    41,211    :  100.0%  
 ↳ Aux Tower         :              :   327,690    :          
Cell 3  (Reduction)  :   64 x 16    :   139,708    :  100.0%  
Cell 4  (Normal)     :   64 x 16    :   139,709    :  100.0%  
Cell 5  (Normal)     :   64 x 16    :   139,710    :  100.0%  
Cell 6  (Normal)     :   64 x 16    :   139,711    :  100.0%  
 ↳ Classifier        :              :   163,850    :          
Total                :              :  1,174,168   :  100.0%  

=== 2 Patterns. Target Comp: 0.85, Aim: 0.77
=== Training Oppenheimer Essex Convair ===
Starting at 2019-12-05 13:30:49.860562
12/05/2019 01:30 PM
Init: 50.00MiB
0: 1.91GiB
1: 3.63GiB
2: 5.35GiB
Tower 2: 5.35GiB
3: 5.46GiB
4: 5.97GiB

# Train

In [None]:
# Train the searched model for whatever the LR scheduler reports as remaining,
# then run the project's cleanup helper.
full_train(
    model,
    epochs=model.lr_scheduler.remaining,
)
clean()

# Random Search

In [None]:
# Values handed to Net as random_ops={'e_c': ..., 'i_c': ...} in the next cell.
# NOTE(review): presumably edge / input keep-ratios for the random baseline -- confirm.
e_c = .25
i_c = 1

In [None]:
# Reload the data and build the random baseline: the full pattern count, with
# operations chosen through random_ops and prune=False.
data, dim = load_data(hypers['batch_size'], hypers['dataset'])
model = Net(
    dim=dim,
    classes=hypers['classes'],
    scale=hypers['scale'],
    num_patterns=hypers['num_patterns'],
    patterns=hypers['patterns'],
    nodes=hypers['nodes'],
    random_ops={'e_c': e_c, 'i_c': i_c},
    drop_prob=hypers['drop_prob'],
    lr_schedule=hypers['lr_schedule'],
    prune=False,
)
model.data = data

# Save the sampled genotype, then show the model and its size estimate.
model.save_genotype()
print(model)
print(size_test(model, data))

In [None]:
# Train the random baseline for the full LR-schedule length 'T'.
full_train(
    model,
    hypers['lr_schedule']['T'],
)

# Scratch