# Enhanced Distillation: Simplified and Tuned Student Model
This notebook applies the following improvements:
1. Teacher-student distillation using only final reconstructions
2. Reduced distillation weight γ to 0.1
3. Increased student capacity to [48,24]
4. EarlyStopping and ReduceLROnPlateau callbacks
5. Threshold re-tuning via F1-score

In [9]:
# Step 1: Imports
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix, precision_recall_curve
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

In [10]:
# Step 2: Load and split data
df = pd.read_csv("UNSW-NB15P-MM-SAMPLE.csv")
Dn = df[df['Class'] == 0].drop(columns=['Class'])
Da = df[df['Class'] == 1].drop(columns=['Class'])
Dntr, Dnts = train_test_split(Dn, test_size=0.2, random_state=42)
Dts = pd.concat([Dnts, Da], ignore_index=True)
y_test = np.array([0]*len(Dnts) + [1]*len(Da))

In [11]:
# Step 3: Normalize
scaler = StandardScaler()
X_train = scaler.fit_transform(Dntr)
X_test = scaler.transform(Dts)

In [12]:
# Step 4: Train Teacher Model
input_dim = X_train.shape[1]
inp = Input(shape=(input_dim,))
x = Dense(64, activation='relu')(inp)
x = Dropout(0.2)(x)
x = Dense(32, activation='relu')(x)
x = Dropout(0.2)(x)
encoded = Dense(16, activation='relu')(x)
x = Dense(32, activation='relu')(encoded)
x = Dropout(0.2)(x)
x = Dense(64, activation='relu')(x)
teacher_out = Dense(input_dim, activation='linear')(x)
teacher = Model(inp, teacher_out)
teacher.compile(optimizer=Adam(0.001), loss='mse')
es = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
rlr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3)
teacher.fit(X_train, X_train, epochs=30, batch_size=256, validation_split=0.1, callbacks=[es, rlr], verbose=1)

Epoch 1/30
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 911us/step - loss: 0.6125 - val_loss: 0.2442 - learning_rate: 0.0010
Epoch 2/30
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 841us/step - loss: 0.3098 - val_loss: 0.2134 - learning_rate: 0.0010
Epoch 3/30
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 839us/step - loss: 0.2599 - val_loss: 0.1686 - learning_rate: 0.0010
Epoch 4/30
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 811us/step - loss: 0.2381 - val_loss: 0.1552 - learning_rate: 0.0010
Epoch 5/30
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 801us/step - loss: 0.2258 - val_loss: 0.1517 - learning_rate: 0.0010
Epoch 6/30
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 798us/step - loss: 0.2176 - val_loss: 0.1443 - learning_rate: 0.0010
Epoch 7/30
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 814us/step - loss: 0.2124 - val_loss: 0.1438 - 

<keras.src.callbacks.history.History at 0x3731c9a30>

In [13]:
# Step 5: Generate Teacher Reconstructions
T_train = teacher.predict(X_train)
T_test = teacher.predict(X_test)

[1m8570/8570[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 197us/step
[1m2837/2837[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 200us/step


In [14]:
# Step 6: Build Larger Student Model
def build_student():
    inp_s = Input(shape=(input_dim,))
    x = Dense(48, activation='relu')(inp_s)
    x = Dense(24, activation='relu')(x)
    x = Dense(48, activation='relu')(x)
    out_s = Dense(input_dim, activation='linear')(x)
    return Model(inp_s, out_s)
student = build_student()

In [15]:
# Step 7: Distillation with Simplified Targets
gamma = 0.1  # lower distillation weight
distill = Model(student.input, [student.output, student.output])
distill.compile(
    optimizer=Adam(5e-4),
    loss=['mse', 'mse'],
    loss_weights=[1.0, gamma]
)
es2 = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
rlr2 = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3)
distill.fit(
    X_train, [X_train, T_train],
    epochs=30, batch_size=256, validation_split=0.1,
    callbacks=[es2, rlr2], verbose=1
)

Epoch 1/30
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 696us/step - dense_19_loss: 0.3757 - loss: 0.5638 - val_dense_19_loss: 0.0954 - val_loss: 0.1258 - learning_rate: 5.0000e-04
Epoch 2/30
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 594us/step - dense_19_loss: 0.0918 - loss: 0.1060 - val_dense_19_loss: 0.0837 - val_loss: 0.0658 - learning_rate: 5.0000e-04
Epoch 3/30
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 610us/step - dense_19_loss: 0.0855 - loss: 0.0595 - val_dense_19_loss: 0.0867 - val_loss: 0.0438 - learning_rate: 5.0000e-04
Epoch 4/30
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 637us/step - dense_19_loss: 0.0874 - loss: 0.0411 - val_dense_19_loss: 0.0902 - val_loss: 0.0320 - learning_rate: 5.0000e-04
Epoch 5/30
[1m965/965[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 582us/step - dense_19_loss: 0.0902 - loss: 0.0310 - val_dense_19_loss: 0.0927 - val_loss: 0.0266 - learning_rate: 5

<keras.src.callbacks.history.History at 0x383f34bc0>

In [16]:
# Step 8: Evaluate Student Model
S_pred = student.predict(X_test)
errors_s = np.mean((X_test - S_pred)**2, axis=1)
prec, rec, thr = precision_recall_curve(y_test, errors_s)
f1_scores = 2*(prec*rec)/(prec+rec+1e-8)
best_thr = thr[np.argmax(f1_scores)]
y_pred = (errors_s > best_thr).astype(int)
cm = confusion_matrix(y_test, y_pred)
print('Confusion Matrix:', cm)
print(classification_report(y_test, y_pred, target_names=['Normal','Attack']))
tn, fp, fn, tp = cm.ravel()
print(f"FPR: {fp/(fp+tn):.4f}, FNR: {fn/(fn+tp):.4f}")
print(f"ROC-AUC: {roc_auc_score(y_test, errors_s):.4f}")

[1m2837/2837[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 204us/step
Confusion Matrix: [[65424  3133]
 [  301 21914]]
              precision    recall  f1-score   support

      Normal       1.00      0.95      0.97     68557
      Attack       0.87      0.99      0.93     22215

    accuracy                           0.96     90772
   macro avg       0.94      0.97      0.95     90772
weighted avg       0.97      0.96      0.96     90772

FPR: 0.0457, FNR: 0.0135
ROC-AUC: 0.9923
