In [1]:
from graph2net.trainers import gen_and_validate,generate,model_validate,full_model_run
from graph2net.data_loaders import *
from graph2net.graph_generators import *
from graph2net.archetypes import resNeXt
from graph2net.helpers import *
from graph2net.notifier import notify_me
import logging
import numpy as np
import pandas as pd
import pickle as pkl
import time
import os

import random
import matplotlib.pyplot as plt


from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)
logging.basicConfig(filename='logs/model_testbed.log', level=logging.INFO)

%load_ext autoreload
%autoreload 2

In [2]:
data = load_data(batch_size=256)

In [3]:
lr_schedule512 = {
    'type': 'cosine',
    'lr_min': 1e-9,
    'lr_max': 1e-2,
    't_0': 1,
    't_mult': 2
}

In [21]:
def micro_net(cell):
    print("== micro ==")
    t_start = time.time()
    model, valid, reason = gen_and_validate([cell],data,scale=2,cell_types=[1,1])
    if valid:
        loss, correct, preds,acc_preds,confs  = full_model_run(model,
                                                      data=data,
                                                      epochs=8,
                                                      lr=.01,
                                                      momentum=.9,
                                                      weight_decay=1e-4,
                                                      lr_schedule=lr_schedule512,
                                                      drop_path=False,
                                                      log=True,
                                                      track_progress=True,
                                                      prefix="Micro",
                                                      verbose=False)
        del model
        torch.cuda.empty_cache()
        return correct,time.time()-t_start
    else:
        del model
        torch.cuda.empty_cache()
        return np.zeros(8),time.time()-t_start

def macro_net(cell,verbose=False):
    print("\n== macro ==")
    t_start = time.time()
    
    print("Getting max model size: ",end="")
    model_hyper = max_model_size(cell,data)
    print("Scale: {}, Spacing: {}, Parallel: {}, Params: {:,}".format(*model_hyper))
    scale,spacing,parallel,params = model_hyper
    cell_list = [cell,cell] if parallel else [cell]
    model,valid,reason = gen_and_validate(cell_list,data,scale=scale,cell_types=cell_space(5,spacing),auxillaries=[2])
    
    if not valid:
        del model
        torch.cuda.empty_cache()
        return np.zeros(512),0
    else:
        try:
            loss, correct, preds,acc_preds,confs = full_model_run(model,
                                                      data=data,
                                                      epochs=512,
                                                      lr=.01,
                                                      momentum=.9,
                                                      weight_decay=1e-4,
                                                      lr_schedule=lr_schedule512,
                                                      drop_path=True,
                                                      log=True,
                                                      track_progress=True,
                                                      prefix="Macro",
                                                      verbose=verbose)
        except Exception as e:
            print("\n{}".format(e))
            del model
            torch.cuda.empty_cache()
            return np.zeros(512),0
        
        del model
        torch.cuda.empty_cache()
    
        return correct,time.time()-t_start

In [22]:
try:
    i = 0
    while 1:
        print("\n===== CELL {} ===== ".format(i))
        #cell = gen_cell(np.random.randint(3,10),connectivity=np.random.uniform(.3,.7),concat=.5)        
        #micro_scores,micro_time = micro_net(cell)     
        macro_scores,macro_time = macro_net(cell)
        notify_me("Test {} successful, moving on. Micro: {}, Macro: {}".format(i,max(micro_scores),max(macro_scores)))
        i+=1
except KeyboardInterrupt as e:
    raise e
except Exception as e:
    notify_me("Micro/Macro test errored. {}".format(str(e)))
    raise e


===== CELL 0 ===== 

== macro ==
Getting max model size: [[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.

[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1. 

[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1. 

[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1.  1.  0.  0.  0.]
 [ 1.  0.  0. 12. 14.  0.]
 [ 1.  0.  0.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  1.]] 0
[[ 0.  1. 

Validating model...[SUCCESS]
Estimated Timelines:
	1 epoch   : 1 min, 32 s
	64 epochs : 1 hrs, 38 mins, 21 s
	128 epochs: 3 hrs, 16 mins, 43 s
	512 epochs: 13 hrs, 6 mins, 53 s
Number of parameters: 3,233,546


KeyboardInterrupt: 