# Enhanced Distillation: Simplified and Tuned Student Model
This notebook applies the following improvements:
1. Teacher-student distillation using only final reconstructions
2. Reduced distillation weight γ to 0.1
3. Increased student capacity to [48,24]
4. EarlyStopping and ReduceLROnPlateau callbacks
5. Threshold re-tuning via F1-score

In [1]:
# Step 1: Imports
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix, precision_recall_curve
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

In [2]:
# Step 2: Load and split data
df = pd.read_csv("UNSW-NB15P-MM-SAMPLE.csv")
Dn = df[df['Class'] == 0].drop(columns=['Class'])
Da = df[df['Class'] == 1].drop(columns=['Class'])
Dntr, Dnts = train_test_split(Dn, test_size=0.2, random_state=42)
Dts = pd.concat([Dnts, Da], ignore_index=True)
y_test = np.array([0]*len(Dnts) + [1]*len(Da))

In [3]:
# Step 3: Normalize
scaler = StandardScaler()
X_train = scaler.fit_transform(Dntr)
X_test = scaler.transform(Dts)

In [4]:
# Step 4: Train Teacher Model
input_dim = X_train.shape[1]
inp = Input(shape=(input_dim,))
x = Dense(64, activation='relu')(inp)
x = Dropout(0.2)(x)
x = Dense(32, activation='relu')(x)
x = Dropout(0.2)(x)
encoded = Dense(16, activation='relu')(x)
x = Dense(32, activation='relu')(encoded)
x = Dropout(0.2)(x)
x = Dense(64, activation='relu')(x)
teacher_out = Dense(input_dim, activation='linear')(x)
teacher = Model(inp, teacher_out)
teacher.compile(optimizer=Adam(0.001), loss='mse')
es = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
rlr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3)
teacher.fit(X_train, X_train, epochs=50, batch_size=256, validation_split=0.1, callbacks=[es, rlr], verbose=1)

Epoch 1/50
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1ms/step - loss: 0.6446 - val_loss: 0.2782 - learning_rate: 0.0010
Epoch 2/50
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 839us/step - loss: 0.3296 - val_loss: 0.1917 - learning_rate: 0.0010
Epoch 3/50
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 710us/step - loss: 0.2839 - val_loss: 0.1687 - learning_rate: 0.0010
Epoch 4/50
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 860us/step - loss: 0.2451 - val_loss: 0.1627 - learning_rate: 0.0010
Epoch 5/50
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 820us/step - loss: 0.2308 - val_loss: 0.1581 - learning_rate: 0.0010
Epoch 6/50
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 927us/step - loss: 0.2320 - val_loss: 0.1437 - learning_rate: 0.0010
Epoch 7/50
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 904us/step - loss: 0.2192 - val_loss: 0.1400 - le

<keras.src.callbacks.history.History at 0x317705700>

In [5]:
# Step 5: Generate Teacher Reconstructions
T_train = teacher.predict(X_train)
T_test = teacher.predict(X_test)

[1m8570/8570[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 166us/step
[1m2837/2837[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 180us/step


In [6]:
# Step 6: Build Larger Student Model
def build_student():
    inp_s = Input(shape=(input_dim,))
    x = Dense(48, activation='relu')(inp_s)
    x = Dense(24, activation='relu')(x)
    x = Dense(48, activation='relu')(x)
    out_s = Dense(input_dim, activation='linear')(x)
    return Model(inp_s, out_s)
student = build_student()

In [7]:
# Step 7: Distillation with Simplified Targets
gamma = 0.1  # lower distillation weight
distill = Model(student.input, [student.output, student.output])
distill.compile(
    optimizer=Adam(5e-4),
    loss=['mse', 'mse'],
    loss_weights=[1.0, gamma]
)
es2 = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
rlr2 = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3)
distill.fit(
    X_train, [X_train, T_train],
    epochs=50, batch_size=256, validation_split=0.1,
    callbacks=[es2, rlr2], verbose=1
)

Epoch 1/50
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 602us/step - dense_9_loss: 0.3780 - loss: 0.5740 - val_dense_9_loss: 0.0911 - val_loss: 0.1421 - learning_rate: 5.0000e-04
Epoch 2/50
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 495us/step - dense_9_loss: 0.0873 - loss: 0.1213 - val_dense_9_loss: 0.0798 - val_loss: 0.0761 - learning_rate: 5.0000e-04
Epoch 3/50
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 528us/step - dense_9_loss: 0.0817 - loss: 0.0700 - val_dense_9_loss: 0.0809 - val_loss: 0.0550 - learning_rate: 5.0000e-04
Epoch 4/50
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 496us/step - dense_9_loss: 0.0829 - loss: 0.0507 - val_dense_9_loss: 0.0843 - val_loss: 0.0427 - learning_rate: 5.0000e-04
Epoch 5/50
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 505us/step - dense_9_loss: 0.0864 - loss: 0.0404 - val_dense_9_loss: 0.0881 - val_loss: 0.0346 - learning_rate: 5.0000e-04


<keras.src.callbacks.history.History at 0x3279f6720>

In [8]:
# Step 8: Evaluate Student Model
S_pred = student.predict(X_test)
errors_s = np.mean((X_test - S_pred)**2, axis=1)
prec, rec, thr = precision_recall_curve(y_test, errors_s)
f1_scores = 2*(prec*rec)/(prec+rec+1e-8)
best_thr = thr[np.argmax(f1_scores)]
y_pred = (errors_s > best_thr).astype(int)
cm = confusion_matrix(y_test, y_pred)
print('Confusion Matrix:', cm)
print(classification_report(y_test, y_pred, target_names=['Normal','Attack']))
tn, fp, fn, tp = cm.ravel()
print(f"FPR: {fp/(fp+tn):.4f}, FNR: {fn/(fn+tp):.4f}")
print(f"ROC-AUC: {roc_auc_score(y_test, errors_s):.4f}")

[1m2837/2837[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 167us/step
Confusion Matrix: [[64606  3951]
 [  445 21770]]
              precision    recall  f1-score   support

      Normal       0.99      0.94      0.97     68557
      Attack       0.85      0.98      0.91     22215

    accuracy                           0.95     90772
   macro avg       0.92      0.96      0.94     90772
weighted avg       0.96      0.95      0.95     90772

FPR: 0.0576, FNR: 0.0200
ROC-AUC: 0.9868
