In [1]:
from graph2net.trainers import gen_and_validate,generate,model_validate,full_model_run
from graph2net.data_loaders import *
from graph2net.graph_generators import *
from graph2net.archetypes import inception
from graph2net.helpers import *
from graph2net.notifier import notify_me
import logging
import numpy as np
import pandas as pd
import pickle as pkl
import time
import os
import random
import matplotlib.pyplot as plt


from IPython.core.display import display, HTML
# UI-only tweak: stretch the Jupyter `.container` to the full browser width so
# long progress lines and wide arrays are not wrapped. Has no effect on results.
display(HTML("<style>.container { width:100% !important; }</style>"))

# Reset the root logger before configuring it: logging.basicConfig() does nothing
# when the root logger already has handlers (e.g. left over from a previous run
# of this cell), so strip them first. Iterate over a copy because removeHandler
# mutates `logging.root.handlers`.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

# basicConfig opens the log file immediately and raises FileNotFoundError if its
# directory is missing (e.g. on a fresh checkout), so make sure it exists.
LOG_PATH = 'logs/model_testbed.log'
os.makedirs(os.path.dirname(LOG_PATH), exist_ok=True)
logging.basicConfig(filename=LOG_PATH, level=logging.INFO)

%load_ext autoreload
%autoreload 2

In [2]:
# Load the dataset once; `data` is read as a global by micro_net / macro_net
# (passed to gen_and_validate and full_model_run).
# NOTE(review): presumably train/test loaders batched at 256 — confirm in
# graph2net.data_loaders.load_data.
data = load_data(batch_size=256)

In [3]:
# Learning-rate schedule handed to full_model_run by both micro_net and macro_net.
# Despite the "512" in its name it is also used for the 8-epoch micro run.
# NOTE(review): the keys look like cosine annealing with warm restarts (first
# cycle t_0 epochs long, each later cycle t_mult times longer, lr swept between
# lr_max and lr_min) — confirm against graph2net.trainers.
lr_schedule512 = dict(
    type='cosine',
    lr_min=1e-5,
    lr_max=1e-3,
    t_0=1,
    t_mult=2,
)

In [9]:
def _train_cell_model(cells, scale, cell_types, epochs, drop_path, prefix):
    """Build a network from `cells` with gen_and_validate, train it, and time it.

    Shared implementation of micro_net / macro_net, which differ only in the
    network shape, the epoch budget, drop-path and the progress prefix.

    Args:
        cells: list of cell specs forwarded to gen_and_validate.
        scale: `scale` argument for gen_and_validate.
        cell_types: `cell_types` argument for gen_and_validate.
        epochs: number of epochs for full_model_run; also the length of the
            zero-score fallback.
        drop_path: forwarded to full_model_run.
        prefix: progress-line prefix forwarded to full_model_run.

    Returns:
        (correct, seconds): `correct` is the second value returned by
        full_model_run (the notebook output shows one count per epoch) and
        `seconds` is wall time for build + training. If gen_and_validate
        returns a falsy model, or training raises MemoryError, returns
        (np.zeros(epochs), 0) so callers can treat the cell as scoring zero.
    """
    t_start = time.time()
    model = gen_and_validate(cells, data, scale=scale, cell_types=cell_types)
    if not model:
        # Falsy model == architecture rejected (judging by the logged output:
        # size bounds or GPU-memory threshold). Never hand it to the trainer.
        return np.zeros(epochs), 0
    try:
        _, correct, _, _, _ = full_model_run(model,
                                             data=data,
                                             epochs=epochs,
                                             lr=.01,
                                             momentum=.9,
                                             weight_decay=1e-4,
                                             lr_schedule=lr_schedule512,
                                             drop_path=drop_path,
                                             log=True,
                                             track_progress=True,
                                             prefix=prefix,
                                             verbose=False)
    except MemoryError as e:
        print("\n{}".format(e))
        return np.zeros(epochs), 0
    return correct, time.time() - t_start


def micro_net(cell, epochs=8):
    """Train the cheap "micro" proxy: a single `cell` at scale 2, two cell slots.

    Args:
        cell: cell spec, e.g. from gen_cell (a square np.ndarray in this notebook).
        epochs: training epochs (default 8).

    Returns:
        (correct, seconds); (np.zeros(epochs), 0) if the model is rejected by
        gen_and_validate or training runs out of memory.
    """
    print("== micro ==")
    return _train_cell_model([cell], scale=2, cell_types=[1, 1], epochs=epochs,
                             drop_path=False, prefix="Micro")


def macro_net(cell, epochs=512, parallel=False):
    """Train the full-size "macro" network built from `cell` (scale 5, 13 slots).

    Args:
        cell: cell spec, e.g. from gen_cell.
        epochs: training epochs (default 512).
        parallel: if True, build the network from two copies of `cell`
            ("Parallel"); default False builds from one ("Large-Scale").

    Returns:
        (correct, seconds); (np.zeros(epochs), 0) if the model is rejected by
        gen_and_validate or training runs out of memory.
    """
    print("\n== macro ==")
    if parallel:
        print("Parallel")
        cells = [cell, cell]
    else:
        print("Large-Scale")
        cells = [cell]
    return _train_cell_model(cells, scale=5,
                             cell_types=[1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1],
                             epochs=epochs, drop_path=True, prefix="Macro")

In [16]:
# Search loop: sample random cells until interrupted (Kernel → Interrupt). Each
# cell is trained at macro scale first; only cells whose best `correct` count
# exceeds MACRO_THRESHOLD are also trained as the micro proxy. Every such pair
# of results is appended to `results` and checkpointed to RESULTS_PATH, which is
# what `load=True` reads back on a later run.
RESULTS_PATH = 'micro_macro_results.pkl'
MACRO_THRESHOLD = 1000  # best macro `correct` count a cell must beat to be kept

load = False
if load:
    # pickle.load can execute arbitrary code — only load files this notebook wrote.
    with open(RESULTS_PATH, 'rb') as f:
        results = pkl.load(f)
else:
    results = []

try:
    i = 0  # index of the next *successful* test; failed cells reuse it
    while True:
        print("\n===== CELL {} ===== ".format(i))
        cell = gen_cell(np.random.randint(5, 15), connectivity=np.random.uniform(.1, .7), concat=.5)
        macro_scores, macro_time = macro_net(cell)
        if max(macro_scores) > MACRO_THRESHOLD:
            micro_scores, micro_time = micro_net(cell)
            results.append({
                'cell': cell,
                'micro_scores': micro_scores,
                'micro_time': micro_time,
                'macro_scores': macro_scores,
                'macro_time': macro_time,
            })
            # Checkpoint after every success so an interrupt loses nothing.
            with open(RESULTS_PATH, 'wb') as f:
                pkl.dump(results, f)
            notify_me("Test {} successful, moving on. Micro: {}, Macro: {}".format(i, max(micro_scores), max(macro_scores)))
            i += 1
except Exception as e:
    # KeyboardInterrupt derives from BaseException, not Exception, so a manual
    # interrupt bypasses this handler and stops the loop without a notification.
    notify_me("Micro/Macro test errored. {}".format(str(e)))
    raise


===== CELL 0 ===== 

== macro ==
Large-Scale
539,530 params
Model outside size boundaries, terminating.

===== CELL 0 ===== 

== macro ==
Large-Scale
1,341,002 params
Model outside size boundaries, terminating.

===== CELL 0 ===== 

== macro ==
Large-Scale
2,299,178 params
Model outside size boundaries, terminating.

===== CELL 0 ===== 

== macro ==
Large-Scale
5,820,938 params
Validating model...
GPU Memory past safe thresh: 6.88 GB, 1.72 GB/layer at layer 4

===== CELL 0 ===== 

== macro ==
Large-Scale
9,975,818 params
Validating model...
GPU Memory past safe thresh: 7.25 GB, 1.45 GB/layer at layer 5

===== CELL 0 ===== 

== macro ==
Large-Scale
20,511,370 params
Validating model...
GPU Memory past safe thresh: 4.37 GB, 4.37 GB/layer at layer 1

===== CELL 0 ===== 

== macro ==
Large-Scale
28,239,178 params
Model outside size boundaries, terminating.

===== CELL 0 ===== 

== macro ==
Large-Scale
17,045,546 params
Validating model...
GPU Memory past safe thresh: 6.63 GB, 2.21 GB/laye

KeyboardInterrupt: 

In [None]:
# Sample one random cell (5–14 nodes, connectivity drawn from [0.2, 0.7)) and
# train only the cheap micro proxy on it.
cell = gen_cell(
    np.random.randint(5, 15),
    connectivity=np.random.uniform(.2, .7),
    concat=.5,
)
micro_scores, micro_time = micro_net(cell)

In [25]:
# A fixed 11-node cell (presumably captured from an earlier random sample),
# re-run through the micro proxy. The (correct, seconds) tuple returned by
# micro_net is the cell's displayed output.
cell = np.array([
    [0, 1, 1, 13, 13, 1, 1, 12, 15,  0,  0],
    [1, 0, 0,  0,  4, 0, 0,  0,  0,  0,  5],
    [1, 0, 0,  0,  0, 0, 0,  0,  0, 11,  0],
    [1, 0, 0,  0,  0, 0, 0,  0,  0,  0,  1],
    [1, 0, 0,  0,  0, 0, 0,  0,  0,  0,  1],
    [1, 0, 0,  0,  0, 0, 0,  0,  0,  0,  1],
    [1, 0, 0,  0,  0, 0, 0,  7,  0,  0, 15],
    [1, 0, 0,  0,  0, 0, 0,  0,  0,  0,  9],
    [1, 0, 0,  0,  0, 0, 0,  0,  0,  0,  1],
    [0, 0, 0,  0,  0, 0, 0,  0,  0,  0,  1],
    [0, 0, 0,  0,  0, 0, 0,  0,  0,  0,  1],
], dtype=float)  # float64, same as the float literals this matrix was captured with
micro_net(cell)

== micro ==
Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 11.33 s
	64 epochs : 12 min, 5 s
	128 epochs: 24 min, 10 s
	512 epochs: 1 hrs, 36 mins, 40 s
Number of parameters: 11,962
Micro   8/8: 43.56%, Predicted: 4360 (±36.26))

(array([1411, 3437, 3818, 3927, 4131, 4343, 4266, 4356]), 105.285315990448)