In [None]:
# Import modules.
import sys
import pandas as pd
sys.path.append("../knowledge-net")
import utils
from networks import KnowledgeNet
from training import train_network, prune_network, check_network, report_metrics

In [None]:
# Import the experiment settings (the upper-case constants used below) from the config module.
from config import *

In [None]:
# Get the data: generate it when BUILD_MODE is set, otherwise read DATA_FILE.
if BUILD_MODE:
    X, y = utils.generate_data(
            FUNC, NOISE_SD, DATA_SIZE, INPUT_DIM, LOWER, UPPER)
    # Only the file-backed branch produces `classes`; bind the name here too so
    # it exists whichever branch ran.
    classes = None
else:
    # NOTE(review): this call used to be `temp.get_data(...)`, but no cell in
    # this notebook imports a `temp` module, so on a fresh kernel the name is
    # unbound (unless `config` happens to export it). Routed through `utils`
    # like the branch above -- confirm that `utils.get_data` is the loader.
    X, y, classes = utils.get_data(DATA_FILE, OUTPUT_ACT)

In [None]:
# Wrap X and y in TensorFlow datasets. TEST_SIZE and BATCH_SIZE are forwarded
# to the helper, which hands back the train and the test dataset.
train_dataset, test_dataset = utils.create_dataset(
    X,
    y,
    TEST_SIZE,
    BATCH_SIZE,
)

In [None]:
# Load the ontology.
# The feature-name mapping is read first because load_ontology takes it as an
# argument.
# NOTE(review): the ontology path previously pointed at "Data/" while the
# feature map used "data/"; the two only resolve to the same directory on a
# case-insensitive filesystem. Unified on lowercase -- confirm against the
# directory name on disk.
feature_id_map = utils.load_mapping(f"{EXP_DIR}/data/features.tsv")

dG, root, term_size_map, term_direct_feature_map = utils.load_ontology(
        f"{EXP_DIR}/data/ontology.tsv",
        feature_id_map)

In [None]:
# Set the optimizer and load the model.
# NOTE(review): `tf` is not imported in the import cell at the top of this
# notebook; presumably it arrives through `from config import *` -- confirm.
# The Adam hyper-parameters are spelled out explicitly; they appear to match
# the documented Keras defaults.
optimizer = tf.keras.optimizers.Adam(
    learning_rate=1e-3,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-7,
    amsgrad=False,
    name="Adam",
)

# Build the network from the ontology graph (dG, root) and the maps returned
# by the ontology loader; all remaining settings come from the configuration.
model = KnowledgeNet(
    output_dim=OUTPUT_DIM,
    output_act=OUTPUT_ACT,
    module_act=MODULE_ACT,
    input_act=INPUT_ACT,
    root=root,
    dG=dG,
    module_neurons_func=MODULE_NEURONS_FUNC,
    input_dim=INPUT_DIM,
    term_direct_input_map=term_direct_feature_map,
    mod_size_map=term_size_map,
    initializer=WEIGHTS_INIT,
    input_regularizer=INPUT_REG,
    module_regularizer=MODULE_REG,
    loss_fn=LOSS_FN,
    aux=AUX,
    batchnorm=BATCHNORM,
)

# Compile with the same optimizer object that is later handed to the training
# helpers, build for a (BATCH_SIZE, INPUT_DIM) input, then print the summary.
model.compile(optimizer=optimizer, loss=LOSS_FN)
model.build(input_shape=(BATCH_SIZE, INPUT_DIM))
model.summary()

In [None]:
# Fit the model on the training dataset for TRAIN_EPOCHS epochs, reusing the
# optimizer created together with the model.
train_network(
    model,
    train_dataset,
    epochs=TRAIN_EPOCHS,
    optimizer=optimizer,
    classification=CLASSIFICATION,
)

In [None]:
# Evaluate the trained model on the train and test datasets.
# report_metrics returns seven values; drop_cols and dG_prune are not printed
# here (dG_prune is read again by the ontology-change check further down).
(
    train_loss, train_acc,
    test_loss, test_acc,
    sparsity, drop_cols, dG_prune,
) = report_metrics(model, train_dataset, test_dataset, optimizer, CLASSIFICATION)

# Losses and sparsity are shown with three decimals; accuracies are printed
# as returned.
print(f"Train loss: {train_loss:.3f}\tTrain accuracy: {train_acc}")
print(f"Test loss: {test_loss:.3f}\tTest accuracy: {test_acc}")
print(f"Sparsity: {sparsity:.3f}")

In [None]:
# Prune the network weights on the training dataset.
# A single pruning epoch is hard-coded here; the two pairs of penalty values
# (GL_PEN*/L0_PEN*) come from the configuration.
prune_network(
    model,
    train_dataset,
    prune_epochs=1,
    gl_pen1=GL_PEN1,
    l0_pen1=L0_PEN1,
    gl_pen2=GL_PEN2,
    l0_pen2=L0_PEN2,
)

In [None]:
# Update the graphs and retrain if the ontology has changed.
# The ontology counts as changed when dG_prune (returned by report_metrics)
# differs from dG_current in node count or in edge count.
# NOTE(review): dG_current is not assigned in any cell above this one --
# confirm where it is defined before relying on Restart & Run All.
# NOTE(review): update / retrain are only assigned when the check passes; when
# the graphs match, this cell leaves both names untouched.
if (dG_current.number_of_nodes() != dG_prune.number_of_nodes()
    or dG_current.number_of_edges() != dG_prune.number_of_edges()):
        update = True
        retrain = True