In [1]:
import os
import sys
# Make the repository root importable so the `utils` package used below resolves.
# The notebook is expected to run from the `distilled_models/` sub-directory:
# when the working directory has that name, step up one level; otherwise the
# working directory itself is used. (A plain str.replace would also strip the
# substring from unrelated path components and leave a trailing separator.)
_cwd = os.getcwd()
if os.path.basename(_cwd) == "distilled_models":
    _project_root = os.path.dirname(_cwd)
else:
    _project_root = _cwd
if _project_root not in sys.path:  # keep re-runs of this cell idempotent
    sys.path.append(_project_root)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from utils.config import Config
from utils.parse_dataset import NetworkDataset, load_datasets
from utils.load_models import GRU_Models
from utils.knowledge_distillation import KnowledgeDistillation

In [2]:
# Shared run configuration and the object exposing the GRU models used below.
conf = Config()
load_models = GRU_Models()

# Teacher: the gru_4 model, restored from its saved checkpoint.
teacher_model = load_models.gru_4
teacher_model.load()

# Student: the light_gru_1 model that the distillation run will train.
student_model = load_models.light_gru_1

# Distillation wrapper around the teacher/student pair, configured for the
# "relation" distillation mode with the rnn flag enabled.
distillation_options = {"distillation": "relation", "rnn": True}
kd = KnowledgeDistillation(
    teacher=teacher_model,
    student=student_model,
    device=conf.device,
    **distillation_options,
)

Checkpoint loaded from /global/D1/homes/jorgetf/Network-Packet-ML-Model/large_models/checkpoint/gru_4.pth!


In [3]:
# Load the pre-split train / validation / test data.
# NOTE(review): the "lstm" selector is passed although the models here are GRUs —
# presumably it picks the sequence-shaped layout; confirm against load_datasets.
X_train, y_train, X_val, y_val, X_test, y_test = load_datasets(conf.datasets, "lstm")

# Training data is reshuffled on every pass over the loader.
train_dataset = NetworkDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=conf.batch_size, shuffle=True)

# Evaluation loaders keep a fixed order: shuffling adds nothing for
# validation/testing and only makes the reported metrics harder to reproduce.
val_dataset = NetworkDataset(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=conf.batch_size, shuffle=False)

test_dataset = NetworkDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=conf.batch_size, shuffle=False)

In [4]:
# Pull a single batch from the training loader to sanity-check tensor shapes.
batch_iterator = iter(train_loader)
data, labels = next(batch_iterator)
print(data.shape, labels.shape)

torch.Size([512, 513, 1]) torch.Size([512])


In [5]:
# Run knowledge distillation for conf.epochs epochs using the training and
# validation loaders; train_kd logs accuracy, train loss and val loss per epoch.
# NOTE(review): the three return values are presumably per-epoch histories —
# confirm against train_kd before plotting or comparing them.
# NOTE(review): the logged train loss (~0.001) and val loss (~0.1) differ by
# roughly two orders of magnitude, which suggests they are normalised
# differently or measure different objectives — verify before comparing.
acc, train_loss, val_loss = kd.train_kd(train_loader, val_loader, conf.epochs)

Epoch: 1/10, Accuracy: 76.49%, Train loss: 0.0077, Val loss: 0.5137
Epoch: 2/10, Accuracy: 85.43%, Train loss: 0.0036, Val loss: 0.2496
Epoch: 3/10, Accuracy: 89.88%, Train loss: 0.0024, Val loss: 0.1718
Epoch: 4/10, Accuracy: 90.99%, Train loss: 0.0019, Val loss: 0.1450
Epoch: 5/10, Accuracy: 91.73%, Train loss: 0.0016, Val loss: 0.1311
Epoch: 6/10, Accuracy: 92.40%, Train loss: 0.0015, Val loss: 0.1206
Epoch: 7/10, Accuracy: 92.57%, Train loss: 0.0014, Val loss: 0.1149
Epoch: 8/10, Accuracy: 92.82%, Train loss: 0.0013, Val loss: 0.1119
Epoch: 9/10, Accuracy: 93.09%, Train loss: 0.0012, Val loss: 0.1076
Epoch: 10/10, Accuracy: 93.39%, Train loss: 0.0012, Val loss: 0.1050


In [6]:
# Baseline: score the restored teacher on the test loader.
# evaluate() yields (loss, accuracy); accuracy is a fraction, hence the *100.
teacher_metrics = teacher_model.evaluate(test_loader)
loss, accuracy = teacher_metrics
print(f"Teacher accuracy: {100*accuracy:.2f}%")

Teacher accuracy: 96.37%


In [7]:
# Score the distilled student on the same test loader for comparison.
# Note: this rebinds the notebook-level `loss` and `accuracy` names.
student_metrics = student_model.evaluate(test_loader)
loss, accuracy = student_metrics
print(f"Student accuracy: {100*accuracy:.2f}%")

Student accuracy: 93.34%


In [8]:
# Persist the distilled student as a checkpoint; save() prints the path it wrote.
# NOTE(review): the destination is chosen inside save() (not visible here) —
# confirm it does not overwrite a checkpoint you still need.
student_model.save()

Checkpoint saved at /global/D1/homes/jorgetf/Network-Packet-ML-Model/distilled_models/checkpoint/gru_1.pth
