In [1]:
from graph2net.trainers import gen_and_validate,generate,model_validate,full_model_run
from graph2net.data_loaders import *
from graph2net.graph_generators import *
from graph2net.archetypes import resNeXt
from graph2net.helpers import *
from graph2net.notifier import notify_me
import logging
import numpy as np
import pandas as pd
import pickle as pkl
import time
import os

import random
import matplotlib.pyplot as plt


from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)
logging.basicConfig(filename='logs/model_testbed.log', level=logging.INFO)

%load_ext autoreload
%autoreload 2

In [2]:
# Dataset handle used as a notebook global by micro_net5, micro_net and macro_net.
# NOTE(review): load_data arrives through a star import; what it returns is not
# visible from this notebook -- confirm before relying on its structure.
data = load_data(batch_size=256)

In [3]:
# Learning-rate schedule settings passed to full_model_run as lr_schedule=.
# NOTE(review): the keys look like cosine annealing with warm restarts
# (t_0 / t_mult) between lr_min and lr_max -- confirm against the trainer.
lr_schedule512 = dict(
    type='cosine',
    lr_min=1e-9,
    lr_max=1e-2,
    t_0=1,
    t_mult=2,
)

In [25]:
def micro_net5(cell):
    """Build a model from ``cell`` with ``cell_types=[1,1,1,1,1]``, run it for 8 epochs, and time the call.

    Reads the notebook globals ``data`` and ``lr_schedule512``.

    Args:
        cell: cell description, forwarded to ``gen_and_validate`` as ``[cell]``.

    Returns:
        tuple: ``(scores, seconds)``. ``scores`` is the ``correct`` value returned
        by ``full_model_run``, or ``np.zeros(8)`` when ``gen_and_validate``
        reports the model as invalid. ``seconds`` is the wall-clock duration of
        this call, cleanup included.

    Raises:
        Whatever ``gen_and_validate`` or ``full_model_run`` raise; the model is
        still released before the exception propagates.
    """
    n_epochs = 8  # single source of truth for the run length and the placeholder size

    print("== micro ==")
    t_start = time.time()
    model, valid, _reason = gen_and_validate([cell],data,scale=2,cell_types=[1,1,1,1,1])
    try:
        if valid:
            _loss, scores, _preds, _acc_preds, _confs = full_model_run(
                model,
                data=data,
                epochs=n_epochs,
                lr=.01,
                momentum=.9,
                weight_decay=1e-4,
                lr_schedule=lr_schedule512,
                drop_path=False,
                log=True,
                track_progress=True,
                prefix="Micro5",
                verbose=False)
        else:
            scores = np.zeros(n_epochs)
    finally:
        # Runs on every exit path. If the run raises, the traceback would
        # otherwise keep this frame's `model` reference alive.
        del model
        torch.cuda.empty_cache()
    return scores, time.time()-t_start

def micro_net(cell):
    """Build a model from ``cell`` with ``cell_types=[1,1]``, run it for 32 epochs, and time the call.

    Reads the notebook globals ``data`` and ``lr_schedule512``.

    Args:
        cell: cell description, forwarded to ``gen_and_validate`` as ``[cell]``.

    Returns:
        tuple: ``(scores, seconds)``. ``scores`` is the ``correct`` value returned
        by ``full_model_run``, or an all-zero placeholder with one entry per
        epoch when ``gen_and_validate`` reports the model as invalid.
        ``seconds`` is the wall-clock duration of this call, cleanup included.

    Raises:
        Whatever ``gen_and_validate`` or ``full_model_run`` raise; the model is
        still released before the exception propagates.
    """
    n_epochs = 32  # single source of truth for the run length and the placeholder size

    print("== micro ==")
    t_start = time.time()
    model, valid, _reason = gen_and_validate([cell],data,scale=2,cell_types=[1,1])
    try:
        if valid:
            _loss, scores, _preds, _acc_preds, _confs = full_model_run(
                model,
                data=data,
                epochs=n_epochs,
                lr=.01,
                momentum=.9,
                weight_decay=1e-4,
                lr_schedule=lr_schedule512,
                drop_path=False,
                log=True,
                track_progress=True,
                prefix="Micro",
                verbose=False)
        else:
            # The placeholder was previously np.zeros(8) although this run uses
            # 32 epochs. NOTE(review): assumes `correct` holds one entry per
            # epoch -- confirm against full_model_run.
            scores = np.zeros(n_epochs)
    finally:
        # Runs on every exit path. If the run raises, the traceback would
        # otherwise keep this frame's `model` reference alive.
        del model
        torch.cuda.empty_cache()
    return scores, time.time()-t_start
    
def macro_net(cell,epochs,verbose=False):
    """Build the largest model ``max_model_size`` allows for ``cell``, run it, and time the call.

    Reads the notebook globals ``data`` and ``lr_schedule512``.

    Args:
        cell: cell description; used once, or twice when ``max_model_size``
            reports ``parallel`` as truthy.
        epochs: number of epochs passed to ``full_model_run``.
        verbose: forwarded to ``full_model_run``.

    Returns:
        tuple: ``(correct, seconds)`` on success, where ``correct`` comes from
        ``full_model_run`` and ``seconds`` is the wall-clock duration of this
        call. When the model is invalid, or ``full_model_run`` raises an
        ``Exception`` (which is printed, not re-raised), the result is
        ``(np.zeros(epochs), 0)``.

    Raises:
        Whatever ``max_model_size`` or ``gen_and_validate`` raise. Non-``Exception``
        interrupts from ``full_model_run`` (e.g. ``KeyboardInterrupt``) propagate
        after the model has been released.
    """
    print("\n== macro ==")
    t_start = time.time()

    print("Getting max model size: ",end="")
    model_hyper = max_model_size(cell,data)
    print("Scale: {}, Spacing: {}, Parallel: {}, Params: {:,}".format(*model_hyper))
    scale,spacing,parallel,params = model_hyper
    cell_list = [cell,cell] if parallel else [cell]
    model,valid,reason = gen_and_validate(cell_list,data,scale=scale,cell_types=cell_space(5,spacing),auxillaries=[2])

    scores = None  # stays None when the model is invalid or the run fails
    try:
        if valid:
            _loss, scores, _preds, _acc_preds, _confs = full_model_run(
                model,
                data=data,
                epochs=epochs,
                lr=.01,
                momentum=.9,
                weight_decay=1e-4,
                lr_schedule=lr_schedule512,
                drop_path=True,
                log=True,
                track_progress=True,
                prefix="Macro",
                verbose=verbose)
    except Exception as e:
        # Best effort: report the failure and fall through to the placeholder result.
        print("\n{}".format(e))
        scores = None
    finally:
        # One cleanup for every exit path, including KeyboardInterrupt.
        del model
        torch.cuda.empty_cache()

    if scores is None:
        # The placeholder was previously a fixed np.zeros(512) regardless of
        # `epochs`. NOTE(review): assumes `correct` holds one entry per epoch.
        return np.zeros(epochs),0
    return scores,time.time()-t_start

In [26]:
# SECURITY NOTE: unpickling can execute arbitrary code -- only load this file
# if it comes from a trusted source.
# The context manager closes the file handle, which the bare open() call leaked.
with open('pickle_jar/all_macro_cells.pkl','rb') as cell_file:
    cells = pkl.load(cell_file)
# Entries with fewer than three items contribute their first item; longer
# entries are converted to an array as a whole.
cells = [x[0] if len(x)<3 else np.array(x) for x in cells]

# Time both micro configurations on every stored cell and print running means.
micro_times,micro5_times = [],[]
for i,cell in enumerate(cells):
    print("{:>3} of {:<3}".format(i,len(cells)))
    try:
        _,micro_time = micro_net(cell)
        _,micro5_time = micro_net5(cell)
    except Exception as e:
        # KeyboardInterrupt is not an Exception subclass, so an interrupt still
        # propagates without a dedicated clause.
        if "Input Mismatch for Summation Node" not in str(e):
            raise
        print("Skipping broken cell")
    else:
        # Record timings only when both runs finished.
        micro_times.append(micro_time)
        micro5_times.append(micro5_time)
        print("Mean Micro Time:",np.mean(micro_times))
        print("Mean Micro5 Time:",np.mean(micro5_times))
        
        

  0 of 56 
== micro ==
Skipping broken cell
  1 of 56 
== micro ==
Skipping broken cell
  2 of 56 
== micro ==
Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 11.68 s
	64 epochs : 12 min, 27 s
	128 epochs: 24 min, 54 s
	512 epochs: 1 hrs, 39 mins, 38 s
Number of parameters: 6,602
== micro ==2: 57.75%, Predicted: 5775 (±   0)
Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 15.88 s
	64 epochs : 16 min, 56 s
	128 epochs: 33 min, 52 s
	512 epochs: 2 hrs, 15 mins, 31 s
Number of parameters: 39,082
Mean Micro Time: 317.3352653980255096 (± 478))
Mean Micro5 Time: 114.67548251152039
  3 of 56 
== micro ==
Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 10.50 s
	64 epochs : 11 min, 12 s
	128 epochs: 22 min, 24 s
	512 epochs: 1 hrs, 29 mins, 37 s
Number of parameters: 5,510
== micro ==2: 52.75%, Predicted: 5294 (±   0)
Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 15.32 s
	64 epochs : 16 min, 20 s
	128 epochs: 32 min, 40 s
	512 epochs

Mean Micro Time: 285.3238915354013466 (± 478))
Mean Micro5 Time: 90.59956376254559
 18 of 56 
== micro ==
Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 9.68 s
	64 epochs : 10 min, 19 s
	128 epochs: 20 min, 38 s
	512 epochs: 1 hrs, 22 mins, 35 s
Number of parameters: 5,866
== micro ==2: 56.23%, Predicted: 5623 (±   0)
Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 16.17 s
	64 epochs : 17 min, 15 s
	128 epochs: 34 min, 30 s
	512 epochs: 2 hrs, 18 mins, 0 s
Number of parameters: 18,746
Mean Micro Time: 288.0440858251908507 (± 478))
Mean Micro5 Time: 91.71087566544028
 19 of 56 
== micro ==
Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 12.19 s
	64 epochs : 13 min, 0 s
	128 epochs: 26 min, 0 s
	512 epochs: 1 hrs, 44 mins, 1 s
Number of parameters: 6,282
== micro ==2: 55.23%, Predicted: 5523 (±   0)
Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 17.91 s
	64 epochs : 19 min, 6 s
	128 epochs: 38 min, 12 s
	512 epochs: 2 hrs, 32 

== micro ==2: 69.44%, Predicted: 6947 (±   0)
Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 1 min, 13 s
	64 epochs : 1 hrs, 18 mins, 52 s
	128 epochs: 2 hrs, 37 mins, 44 s
	512 epochs: 10 hrs, 30 mins, 56 s
Number of parameters: 606,382
Mean Micro Time: 383.57322810590270434 (± 478)
Mean Micro5 Time: 129.95137609541416
 34 of 56 
== micro ==
Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 8.10 s
	64 epochs : 8 min, 38 s
	128 epochs: 17 min, 16 s
	512 epochs: 1 hrs, 9 mins, 5 s
Number of parameters: 5,518
== micro ==2: 55.77%, Predicted: 5587 (±   0)
Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 10.36 s
	64 epochs : 11 min, 3 s
	128 epochs: 22 min, 6 s
	512 epochs: 1 hrs, 28 mins, 25 s
Number of parameters: 25,566
Mean Micro Time: 379.8944504333265835 (± 478))
Mean Micro5 Time: 128.45925256700227
 35 of 56 
== micro ==
Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 8.15 s
	64 epochs : 8 min, 41 s
	128 epochs: 17 min, 22 s


== micro ==2: 63.60%, Predicted: 6360 (±   0)
Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 19.61 s
	64 epochs : 20 min, 55 s
	128 epochs: 41 min, 50 s
	512 epochs: 2 hrs, 47 mins, 21 s
Number of parameters: 137,290
Mean Micro Time: 365.6322070111831262 (± 478))
Mean Micro5 Time: 124.83369986712933
 50 of 56 
== micro ==
Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 9.05 s
	64 epochs : 9 min, 39 s
	128 epochs: 19 min, 18 s
	512 epochs: 1 hrs, 17 mins, 12 s
Number of parameters: 6,242
== micro ==2: 54.53%, Predicted: 5473 (±   0)
Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 13.25 s
	64 epochs : 14 min, 7 s
	128 epochs: 28 min, 15 s
	512 epochs: 1 hrs, 53 mins, 2 s
Number of parameters: 21,474
Mean Micro Time: 364.3864685710595466 (± 478))
Mean Micro5 Time: 124.4069944595804
 51 of 56 
== micro ==
Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 11.09 s
	64 epochs : 11 min, 49 s
	128 epochs: 23 min, 39 s
	512 epochs: 1 hrs

In [29]:
# Open-ended random search: generate a cell, run the two micro configurations
# and the macro configuration on it, send a notification, repeat.
# The loop only ends through a kernel interrupt or an error.
i = 0
try:
    while True:
        print("\n===== CELL {} ===== ".format(i))
        # Draw the random arguments in the original order so the RNG stream is unchanged.
        cell_size = np.random.randint(3,10)
        cell_connectivity = np.random.uniform(.3,.7)
        cell = gen_cell(cell_size,connectivity=cell_connectivity,concat=.5)

        micro_scores,micro_time = micro_net(cell)
        micro5_scores,micro5_time = micro_net5(cell)
        macro_scores,macro_time = macro_net(cell,epochs=128)

        summary = "Test {} successful, moving on. Micro: {}, Macro: {}".format(i,max(micro_scores),max(macro_scores))
        notify_me(summary)
        i += 1
except KeyboardInterrupt:
    # A manual stop goes straight through without a notification.
    raise
except Exception as e:
    # Report the failure, then let it surface in the notebook.
    notify_me("Micro/Macro test errored. {}".format(str(e)))
    raise


===== CELL 0 ===== 
== micro ==
Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 8.43 s
	64 epochs : 8 min, 59 s
	128 epochs: 17 min, 59 s
	512 epochs: 1 hrs, 11 mins, 56 s
Number of parameters: 8,322
== micro ==2: 63.79%, Predicted: 6379 (±   0)
Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 12.27 s
	64 epochs : 13 min, 5 s
	128 epochs: 26 min, 10 s
	512 epochs: 1 hrs, 44 mins, 42 s
Number of parameters: 177,218
Micro5   8/8: 59.74%, Predicted: 9446 (± 478))
== macro ==
Getting max model size: 

AttributeError: 'bool' object has no attribute 'get_num_params'

In [None]:
# Hand-written 4x4 cell fed directly to max_model_size.
# NOTE(review): presumably a minimal reproduction for the AttributeError raised
# during "Getting max model size" in the previous cell -- confirm.
cell = np.array([
    [0, 1, 1,  0],
    [1, 0, 0, 12],
    [1, 0, 0, 12],
    [0, 0, 0,  1],
], dtype=float)
model_hyper = max_model_size(cell,data)
model_hyper