In [None]:
import sys
import pandas as pd
sys.path.append("../src")
import temp
import utils
from networks import KnowledgeNet
from training import train_network, prune_network, check_network

In [None]:
# NOTE(review): star import — the upper-case experiment constants used below
# (BUILD_MODE, EXP_DIR, BATCH_SIZE, ...) are not defined in this notebook and
# presumably come from here; consider importing them explicitly.
from config import *
# Notebook-local override of the training length (replaces any TRAIN_EPOCHS
# that config may define).
TRAIN_EPOCHS = 20

In [None]:
# Load features X, targets y and the class labels: real data from DATA_FILE
# normally, or dummy data generated from FUNC/NOISE_SD in build mode.
if not BUILD_MODE:
    X, y, classes = temp.get_data(DATA_FILE, OUTPUT_ACT)
else:
    X, y, classes = temp.get_dummy_data(FUNC, NOISE_SD)

In [None]:
# Split into train/test arrays and build the training dataset that is later
# fed to train_network() and prune_network().
X_train, X_test, y_train, y_test, train_dataset = temp.create_dataset(
    X,
    y,
    TEST_SIZE,
    BATCH_SIZE,
)

In [None]:
# Load the input-id mapping and the ontology graph built on top of it.
# Both files live in the same experiment data directory. Build that path once
# so the two loads cannot drift apart: previously the ontology was read from
# "Data/" and the features from "data/", which fails on case-sensitive file
# systems (and features.tsv loading already proves "data/" is the real one).
exp_data_dir = f"{EXP_DIR}/data"
input_id_map = utils.load_mapping(f"{exp_data_dir}/features.tsv")
dG, root, term_size_map, term_direct_input_map = utils.load_ontology(
    f"{exp_data_dir}/ontology.txt",
    input_id_map,
)

In [None]:
# Adam optimizer, shared by model.compile(), train_network() and prune_network().
# NOTE(review): `tf` is not imported in this notebook's import cell; it
# presumably arrives via `from config import *` — confirm.
optimizer = tf.keras.optimizers.Adam(
    name='Adam',
    learning_rate=1e-3,
    epsilon=1e-7,
    beta_1=0.9,
    beta_2=0.999,
    amsgrad=False,
)

# Build the KnowledgeNet from the loaded ontology (dG, root, term maps) and the
# config hyper-parameters, then compile it and create its weights.
model = KnowledgeNet(
    # Ontology structure.
    dG=dG,
    root=root,
    term_direct_input_map=term_direct_input_map,
    mod_size_map=term_size_map,
    module_neurons_func=MODULE_NEURONS_FUNC,
    # Layer sizes and activations.
    input_dim=INPUT_DIM,
    input_act=INPUT_ACT,
    module_act=MODULE_ACT,
    output_dim=OUTPUT_DIM,
    output_act=OUTPUT_ACT,
    # Initialization, regularization and training options.
    initializer=WEIGHTS_INIT,
    input_regularizer=INPUT_REG,
    module_regularizer=MODULE_REG,
    batchnorm=BATCHNORM,
    aux=AUX,
    loss_fn=LOSS_FN,
)

model.compile(optimizer=optimizer, loss=LOSS_FN)
# Create the weights for a (BATCH_SIZE, INPUT_DIM) input before printing the summary.
model.build(input_shape=(BATCH_SIZE, INPUT_DIM))
model.summary()

In [None]:
# Train for TRAIN_EPOCHS epochs; the returned model replaces `model`.
model = train_network(
    model,
    train_dataset,
    optimizer=optimizer,
    classification=CLASSIFICATION,
    train_epochs=TRAIN_EPOCHS,
)

In [None]:
# Evaluate the trained model on both the train and the test split.
train_metric, test_metric = temp.report_metrics(
    model,
    X_train, X_test,
    y_train, y_test,
    CLASSIFICATION,
)
print(f"Train metric: {train_metric}  |  Test metric: {test_metric}")

In [None]:
# Remember the pre-pruning graph; the retrain cell compares it against the
# graph returned by check_network() to detect structural changes.
# NOTE(review): this is a reference to model.dG, not a copy. If pruning or
# check_network mutates the graph in place, the comparison can never differ —
# confirm (networkx graphs offer .copy()).
dG_current = model.dG
# One (initially empty) drop_cols list per module; passed to and returned by
# check_network().
drop_cols = {mod: [] for mod in model.mod_size_map}

In [None]:
# Run one pruning epoch using the gl_* (presumably group-lasso) and l0_*
# penalty strengths from config.
# NOTE(review): prune_network's return value is discarded, so this assumes it
# updates `model` in place — confirm.
prune_network(
    model,
    X_train,
    y_train,
    train_dataset,
    optimizer=optimizer,
    prune_epochs=1,
    gl_pen1=GL_PEN1,
    gl_pen2=GL_PEN2,
    l0_pen1=L0_PEN1,
    l0_pen2=L0_PEN2,
)

In [None]:
# Inspect the pruned structure: check_network hands back the (possibly updated)
# model, the pruned graph, the updated drop_cols and a sparsity value to report.
model, dG_prune, drop_cols, sparsity = check_network(model, dG_current, drop_cols)
print(sparsity)

In [None]:
# Did pruning change the ontology graph (node or edge count differs)?
# `update` is always bound to a bool here. Previously it was assigned only
# inside the if-branch, so any later use raised NameError when the graph
# was unchanged.
update = (
    dG_current.number_of_nodes() != dG_prune.number_of_nodes()
    or dG_current.number_of_edges() != dG_prune.number_of_edges()
)

# Retrain for RETRAIN_EPOCHS only when the structure changed.
# NOTE(review): earlier comments referred to a RETRAIN flag and to "updating the
# graphs"; neither happens here (no flag is consulted, and dG_current is not
# replaced by dG_prune) — confirm whether that is intended.
retrain = update

if retrain:
    model = train_network(
        model,
        train_dataset,
        train_epochs=RETRAIN_EPOCHS,
        optimizer=optimizer,
        classification=CLASSIFICATION,
    )