In [9]:
import os
import sys
# Make the repository root importable for the project-local imports below
# (config, data.dataset, model_config, compact). This assumes the kernel's
# working directory is <root>/notebooks/distilled_models. Only that trailing
# suffix is stripped; str.replace would also mangle a match elsewhere in the
# path. If the suffix is absent, the working directory itself is used.
_NOTEBOOK_SUBDIR = "notebooks/distilled_models"
_cwd = os.getcwd()
PROJECT_ROOT = _cwd[: -len(_NOTEBOOK_SUBDIR)] if _cwd.endswith(_NOTEBOOK_SUBDIR) else _cwd
# Guard so that re-running this cell does not keep appending duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from config import Config
from data.dataset import NetworkDataset, load_datasets
from model_config import MLP_Models
from compact.knowledge_distillation import KnowledgeDistillation

In [10]:
# Seed torch's global RNG before any model or DataLoader is created. This makes
# student weight initialisation and training-set shuffling reproducible across
# Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

conf = Config()
load_models = MLP_Models()

# Teacher: the larger pre-trained MLP, restored from its saved checkpoint
# (see the "Checkpoint loaded from .../large_model/mlp_4.pth" output).
teacher_model = load_models.mlp_4
teacher_model.load()

# Student: the lighter MLP that is trained by distillation from the teacher.
student_model = load_models.light_mlp_4

# NOTE(review): distillation="feature" presumably selects feature-based KD
# (matching intermediate representations) rather than logit-only KD.
# Confirm against compact.knowledge_distillation.
kd = KnowledgeDistillation(
    teacher=teacher_model,
    student=student_model,
    device=conf.device,
    distillation="feature",
)

Checkpoint loaded from /global/D1/homes/jorgetf/Network-Packet-ML-Model/checkpoint/large_model/mlp_4.pth!


In [11]:
X_train, y_train, X_val, y_val, X_test, y_test = load_datasets(conf.datasets, "mlp")

# Wrap each split in a Dataset/DataLoader. Only the training split is shuffled.
# Aggregate evaluation metrics do not benefit from random order, and shuffling
# val/test only adds run-to-run noise (e.g. which samples land in the smaller
# final batch when per-batch losses are averaged).
train_dataset = NetworkDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=conf.batch_size, shuffle=True)

val_dataset = NetworkDataset(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=conf.batch_size, shuffle=False)

test_dataset = NetworkDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=conf.batch_size, shuffle=False)

In [12]:
# Sanity-check one training batch. The output shows features as
# (batch, n_features) and labels as (batch,), so the leading dimensions must agree.
data, labels = next(iter(train_loader))
assert data.shape[0] == labels.shape[0], "feature/label batch sizes differ"
print(data.shape, labels.shape)

torch.Size([512, 513]) torch.Size([512])


In [13]:
# Run the distillation loop for conf.epochs epochs, keeping the returned
# accuracy and loss histories.
# NOTE(review): the logged train loss is orders of magnitude larger than the
# val loss. This is presumably because it includes the unscaled feature-matching
# term while val loss is the plain task loss. Confirm in KnowledgeDistillation
# before comparing the two curves.
acc, train_loss, val_loss = kd.train_kd(
    train_loader,
    val_loader,
    conf.epochs,
)

Epoch: 1/20, Accuracy: 44.93%, Train loss: 12879899.6741, Val loss: 2.3862
Epoch: 2/20, Accuracy: 44.67%, Train loss: 705009.7769, Val loss: 1.9141
Epoch: 3/20, Accuracy: 55.62%, Train loss: 424349.9346, Val loss: 1.2561
Epoch: 4/20, Accuracy: 56.31%, Train loss: 322949.1099, Val loss: 1.0311
Epoch: 5/20, Accuracy: 61.93%, Train loss: 286476.9748, Val loss: 0.9207
Epoch: 6/20, Accuracy: 66.66%, Train loss: 255789.1848, Val loss: 0.8550
Epoch: 7/20, Accuracy: 67.23%, Train loss: 219582.9649, Val loss: 0.8192
Epoch: 8/20, Accuracy: 66.61%, Train loss: 218549.9248, Val loss: 0.8063
Epoch: 9/20, Accuracy: 68.58%, Train loss: 220204.1114, Val loss: 0.7548
Epoch: 10/20, Accuracy: 67.24%, Train loss: 205696.5419, Val loss: 0.7731
Epoch: 11/20, Accuracy: 69.09%, Train loss: 200044.4829, Val loss: 0.7285
Epoch: 12/20, Accuracy: 69.60%, Train loss: 224953.0791, Val loss: 0.7158
Epoch: 13/20, Accuracy: 69.76%, Train loss: 197587.3993, Val loss: 0.6842
Epoch: 14/20, Accuracy: 70.29%, Train loss: 2

In [14]:
# Baseline: accuracy of the pre-trained teacher on the held-out test split.
loss, accuracy = teacher_model.evaluate(test_loader)
print("Teacher accuracy: {:.2f}%".format(accuracy * 100))

Teacher accuracy: 91.51%


In [15]:
# Accuracy of the distilled student on the same held-out test split,
# for direct comparison with the teacher baseline.
loss, accuracy = student_model.evaluate(test_loader)
print("Student accuracy: {:.2f}%".format(accuracy * 100))

Student accuracy: 71.04%


In [17]:
# Persist the distilled student's weights. The destination is chosen inside the
# model's save() method; the output shows a checkpoint under
# checkpoint/compressed_model/light_mlp_4.pth.
student_model.save()

Checkpoint saved at /global/D1/homes/jorgetf/Network-Packet-ML-Model/checkpoint/compressed_model/light_mlp_4.pth
