Test Models

In [48]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from model.eval import confusion_matrix, get_classification_report, load_model
from preprocessing.ENV import ID_TO_APP, ID_TO_TRAFFIC

In [None]:
# Saved model files for the two classification tasks (".cnn.model" per the file names).
_model_dir = 'model'
application_model_path = f'{_model_dir}/application_classification.cnn.model'
traffic_model_path = f'{_model_dir}/traffic_classification.cnn.model'

# Held-out test splits: one parquet file per task.
_test_data_dir = 'train_test_data'
application_test_data = f'{_test_data_dir}/application_classification/test.parquet'
traffic_test_data = f'{_test_data_dir}/traffic_classification/test.parquet'

In [None]:
# Load both trained models from the paths defined above.
# Request the GPU only when CUDA is actually available, so the notebook still
# runs on CPU-only machines instead of failing inside load_model.
# NOTE(review): assumes load_model is PyTorch-based and places the model on CUDA
# when gpu=True (its implementation is not visible here) — confirm.
import torch

use_gpu = torch.cuda.is_available()
application_model = load_model(model_path=application_model_path, gpu=use_gpu)
traffic_model = load_model(model_path=traffic_model_path, gpu=use_gpu)

In [None]:
# Evaluate the application classifier on its test split: build the confusion
# matrix, then produce the per-class classification report.
app_conf_matrix = confusion_matrix(
    data_path=application_test_data,
    model=application_model,
    num_class=len(ID_TO_APP),
)

# Report/tick labels: class names in ascending class-id order.
# NOTE(review): assumes ID_TO_APP's ids are exactly 0..len-1, so label position i
# matches matrix index i — confirm.
app_labels = [ID_TO_APP[app_id] for app_id in sorted(ID_TO_APP.keys())]

# Kept as the cell's final expression so the notebook displays the report.
get_classification_report(app_conf_matrix, app_labels)

In [None]:
# Row-normalise the application confusion matrix: each row is divided by its
# total, so entries become per-row fractions. A row summing to zero yields NaN,
# which nan_to_num replaces with 0. errstate suppresses the floating-point
# warnings that division would otherwise emit.
with np.errstate(all='ignore'):
    app_row_totals = app_conf_matrix.sum(axis=1, keepdims=True)
    app_conf_matrix = np.nan_to_num(app_conf_matrix / app_row_totals)

In [None]:
# Annotated heatmap of the (row-normalised) application confusion matrix:
# rows are true labels, columns are predicted labels, values shown to 2 decimals.
fig, ax = plt.subplots(figsize=(12, 12))
sns.heatmap(
    data=app_conf_matrix, cmap='YlGnBu',
    xticklabels=app_labels, yticklabels=app_labels,
    annot=True, ax=ax, fmt='.2f'
)
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
# plt.show() works under both inline (notebook) and GUI backends. Figure.show()
# returns immediately in scripts and warns under non-GUI backends such as inline.
plt.show()

In [None]:
# Evaluate the traffic classifier on its test split: build the confusion matrix,
# then produce the per-class classification report.
# The keyword is `num_class`, the same as in the application cell; the previous
# `class_num=` keyword did not match.
traffic_conf_matrix = confusion_matrix(
    data_path=traffic_test_data,
    model=traffic_model,
    num_class=len(ID_TO_TRAFFIC),
)

# Report/tick labels: class names in ascending class-id order.
# NOTE(review): assumes ID_TO_TRAFFIC's ids are exactly 0..len-1, so label
# position i matches matrix index i — confirm.
traffic_labels = [ID_TO_TRAFFIC[traffic_id] for traffic_id in sorted(ID_TO_TRAFFIC.keys())]

# Kept as the cell's final expression so the notebook displays the report.
get_classification_report(traffic_conf_matrix, traffic_labels)

In [None]:
# Row-normalise the traffic confusion matrix: each row is divided by its total,
# so entries become per-row fractions. A row summing to zero yields NaN, which
# nan_to_num replaces with 0. errstate suppresses the floating-point warnings
# that division would otherwise emit.
with np.errstate(all='ignore'):
    traffic_row_totals = traffic_conf_matrix.sum(axis=1, keepdims=True)
    traffic_conf_matrix = np.nan_to_num(traffic_conf_matrix / traffic_row_totals)

In [None]:
# Annotated heatmap of the (row-normalised) traffic confusion matrix:
# rows are true labels, columns are predicted labels, values shown to 2 decimals.
fig, ax = plt.subplots(figsize=(12, 12))
sns.heatmap(
    data=traffic_conf_matrix, cmap='YlGnBu',
    xticklabels=traffic_labels, yticklabels=traffic_labels,
    annot=True, ax=ax, fmt='.2f'
)
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
# plt.show() works under both inline (notebook) and GUI backends. Figure.show()
# returns immediately in scripts and warns under non-GUI backends such as inline.
plt.show()