# Use a pre-trained neural network to reproduce published results

Paper: *"Fine-grained TLS services classification with reject option"*

In [None]:
import sys
# Install PyTorch with the interpreter that runs this kernel (sys.executable), so the
# package lands in the kernel's own environment. Wheels are taken from the PyTorch
# index passed via --index-url (the cu118 builds).
!{sys.executable} -m pip install "torch>=1.10" --index-url https://download.pytorch.org/whl/cu118
# Install the remaining packages imported by the cells below (dataset and model
# libraries, plus tqdm for the progress bar).
!{sys.executable} -m pip install cesnet_datazoo cesnet_models tqdm

Load a pre-trained neural network. We use the mm-CESNET-v1, which is the first version of the multi-modal CESNET architecture. The selected weights were trained on the 40th week of the CESNET-TLS22 dataset.

In [1]:
from cesnet_models.models import MM_CESNET_V1_Weights, mm_cesnet_v1

# Weights entry for the model trained on week 40 of CESNET-TLS22. Its `transforms`
# and `meta` attributes are reused later to configure the dataset.
pretrained_weights = MM_CESNET_V1_Weights.CESNET_TLS22_WEEK40
# Build the network with these weights; the weight file is downloaded into "models/".
model = mm_cesnet_v1(weights=pretrained_weights, model_dir="models/")
# Switch to evaluation mode. The trailing semicolon suppresses the cell's output.
model.eval();

Downloading: "https://liberouter.org/datazoo/download?bucket=cesnet-models&file=mmv1_CESNET_TLS22_Week40.pth" to models/mmv1_CESNET_TLS22_Week40.pth
100%|██████████| 4.70M/4.70M [00:00<00:00, 6.86MB/s]


Download and initialize a dataset class of the CESNET-TLS22 dataset.

Prepare dataset configuration:

- Select the test period. Samples from this week will be used to test the model.
- Use data transforms provided in the pre-trained model.
- Select the same application classes on which the model was trained.

In [2]:
import logging
from cesnet_datazoo.config import AppSelection, DatasetConfig
from cesnet_datazoo.datasets import CESNET_TLS22

# Show INFO-level log records (timestamp, logger name, level, message) in the cell output.
logging.basicConfig(
    format="[%(asctime)s][%(name)s][%(levelname)s] - %(message)s",
    level=logging.INFO,
)

DATASET_SIZE = "XS" # Use the "ORIG" size for exact reproduction of the results
dataset = CESNET_TLS22(size=DATASET_SIZE, data_root="data/CESNET-TLS22/")

# Everything model-specific (input transforms, feature switches, class list) is
# taken from the pre-trained weights so the test data matches the model's inputs.
weights_transforms = pretrained_weights.transforms
weights_meta = pretrained_weights.meta

dataset_config = DatasetConfig(
    dataset=dataset,
    # Evaluation only: no train or validation split is requested.
    need_train_set=False,
    need_val_set=False,
    # Test split and its dataloader settings.
    test_period_name="W-2021-41",
    test_batch_size=16384,
    test_workers=2,
    return_tensors=True,
    # Same application classes as the ones the model was trained on.
    apps_selection=AppSelection.FIXED,
    apps_selection_fixed_known=weights_meta["classes"],
    # Feature switches and transforms shipped with the weights.
    use_packet_histograms=weights_meta["use_packet_histograms"],
    use_tcp_features=weights_meta["use_tcp_features"],
    ppi_transform=weights_transforms["ppi_transform"],
    flowstats_transform=weights_transforms["flowstats_transform"],
    flowstats_phist_transform=weights_transforms["flowstats_phist_transform"],
)

# Sanity checks: the dataset configuration must provide exactly the flow statistics
# features and the number of PPI channels that the pre-trained model expects.
expected_flowstats_features = weights_meta["flowstats_features"]
expected_ppi_channels = weights_meta["ppi_input_channels"]
assert dataset_config.get_flowstats_feature_names_expanded() == expected_flowstats_features
assert len(dataset_config.get_ppi_channels()) == expected_ppi_channels

# Attach the configuration to the dataset, then obtain the test dataloader.
dataset.set_dataset_config_and_initialize(dataset_config)
test_dataloader = dataset.get_test_dataloader()

[2024-04-08 12:20:44,569][cesnet_datazoo.pytables_data.indices_setup][INFO] - Processing test indices
[2024-04-08 12:20:45,170][cesnet_datazoo.pytables_data.pytables_dataset][INFO] - Reading app column for table /flows/D20211011 took 0.54 seconds
[2024-04-08 12:20:45,757][cesnet_datazoo.pytables_data.pytables_dataset][INFO] - Reading app column for table /flows/D20211012 took 0.59 seconds
[2024-04-08 12:20:46,197][cesnet_datazoo.pytables_data.pytables_dataset][INFO] - Reading app column for table /flows/D20211013 took 0.44 seconds
[2024-04-08 12:20:46,738][cesnet_datazoo.pytables_data.pytables_dataset][INFO] - Reading app column for table /flows/D20211014 took 0.54 seconds
[2024-04-08 12:20:47,245][cesnet_datazoo.pytables_data.pytables_dataset][INFO] - Reading app column for table /flows/D20211015 took 0.51 seconds
[2024-04-08 12:20:47,466][cesnet_datazoo.pytables_data.pytables_dataset][INFO] - Reading app column for table /flows/D20211016 took 0.22 seconds
[2024-04-08 12:20:47,701][ce

Iterate over the test dataloader and use the model to compute predictions. Use a GPU if available.

In [3]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

def compute_model_predictions(model: nn.Module, test_dataloader: DataLoader, device) -> tuple[np.ndarray, np.ndarray]:
    """Run the model over the test dataloader and collect labels and predicted classes.

    The model is put into evaluation mode (and left in it) and gradients are
    disabled for the whole pass.

    Args:
        model: Network called with a ``(batch_ppi, batch_flowstats)`` tuple. The
            predicted class is the argmax over dimension 1 of its output.
        test_dataloader: Sized iterable yielding 4-tuples
            ``(_, batch_ppi, batch_flowstats, batch_labels)``; the first item is ignored.
        device: Device the input batches are moved to before the forward pass.
            The model is expected to already live on this device.

    Returns:
        A ``(test_labels, preds)`` pair of NumPy arrays, concatenated over all
        batches in iteration order. Both arrays are empty (``int64``) when the
        dataloader yields no batches.
    """
    model.eval()
    label_batches = []
    pred_batches = []
    with torch.no_grad():
        for _, batch_ppi, batch_flowstats, batch_labels in tqdm(test_dataloader, total=len(test_dataloader)):
            # Only the model inputs are needed on the device; the labels are never
            # used there, so they skip the host -> device -> host round trip.
            batch_ppi, batch_flowstats = batch_ppi.to(device), batch_flowstats.to(device)
            out = model((batch_ppi, batch_flowstats))
            label_batches.append(batch_labels)
            pred_batches.append(out.argmax(dim=1))
    if not pred_batches:
        # torch.cat rejects an empty list, so an empty dataloader is handled explicitly.
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64)
    test_labels = torch.cat(label_batches).cpu().numpy()
    preds = torch.cat(pred_batches).cpu().numpy()
    return test_labels, preds

# Prefer the first CUDA device; fall back to the CPU when no GPU is available.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# The model must be on the same device as the batches fed to it.
model = model.to(device)
test_labels, preds = compute_model_predictions(
    model=model,
    test_dataloader=test_dataloader,
    device=device,
)

100%|██████████| 295/295 [00:33<00:00,  8.93it/s]


Finally, compute the classification accuracy and the macro-averaged recall.

In [4]:
import numpy as np
from sklearn.metrics import accuracy_score, recall_score

# Overall fraction of correctly classified test samples.
accuracy = accuracy_score(y_true=test_labels, y_pred=preds)
# Recall averaged over classes with equal weight ("macro"); zero_division=np.nan is
# passed through to scikit-learn for classes where recall is undefined.
recall = recall_score(
    y_true=test_labels,
    y_pred=preds,
    average="macro",
    zero_division=np.nan,  # type: ignore
)
print(f"The pre-trained model achieved an accuracy of {accuracy:.5f} and a recall of {recall:.5f} on the test period {dataset_config.test_period_name} of the {dataset.name} dataset.")

The pre-trained model achieved an accuracy of 0.97087 and a recall of 0.94707 on the test period W-2021-41 of the CESNET-TLS22-XS dataset.
