In [None]:
#!/usr/bin/env python3
"""
Knowledge Distillation for IoT Intrusion Detection - PyTorch Implementation
Teacher: Large LSTM Model
Student: Lightweight LSTM Model (10x smaller)
60-20-20 Train-Val-Test Split
Max 5 files loaded at once
GPU Accelerated with PyTorch
"""

import os
import sys
import gc
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import IncrementalPCA
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import kagglehub
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend for servers
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import json
import warnings
warnings.filterwarnings('ignore')

# ==========================================================
# üéÆ GPU CONFIGURATION
# ==========================================================

def setup_gpu():
    """Choose the compute device and print a short hardware report.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` is true,
        otherwise ``cpu``. On CUDA this also sets
        ``torch.backends.cudnn.benchmark = True`` as a side effect.
    """
    rule = "=" * 80
    print(rule)
    print("üéÆ GPU Configuration")
    print(rule)

    if not torch.cuda.is_available():
        chosen = torch.device('cpu')
        print("‚ö†Ô∏è  No GPU detected, running on CPU")
    else:
        chosen = torch.device('cuda')
        report = (
            f"‚úÖ GPU detected: {torch.cuda.get_device_name(0)}",
            f"‚úÖ CUDA Version: {torch.version.cuda}",
            f"‚úÖ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB",
            f"‚úÖ Number of GPUs: {torch.cuda.device_count()}",
        )
        for line in report:
            print(line)

        # Let cuDNN benchmark candidate kernels and cache the fastest per input shape.
        torch.backends.cudnn.benchmark = True
        print("‚úÖ cuDNN autotuner enabled")

    print(rule + "\n")
    return chosen


device = setup_gpu()

# ==========================================================
# üßπ HELPER FUNCTIONS
# ==========================================================

def load_and_clean(path, label_col=None):
    """Read one CSV, drop rows with NaNs and duplicate rows, and split off the target.

    Args:
        path: CSV file path, passed straight to ``pd.read_csv``.
        label_col: Name of the target column. When ``None``, ``"Label"`` is used
            if the file has it, otherwise the last column.

    Returns:
        tuple: ``(X, y)`` -- the feature DataFrame without the target column,
        and the target Series.
    """
    frame = pd.read_csv(path).dropna().drop_duplicates()

    if label_col is None:
        columns = frame.columns
        label_col = "Label" if "Label" in columns else columns[-1]

    return frame.drop(columns=[label_col]), frame[label_col]


def encode_objects(X):
    """Integer-encode every ``object``-dtype column of *X* and return its values.

    *X* is modified in place. Every call fits a brand-new ``LabelEncoder`` per
    column, so the integer codes are only meaningful within this call.
    NOTE(review): the same category string can therefore get different codes
    in different files.

    Returns:
        numpy.ndarray: ``X.values`` after encoding.
    """
    object_columns = list(X.select_dtypes(include=["object"]).columns)
    for name in object_columns:
        encoder = LabelEncoder()
        X[name] = encoder.fit_transform(X[name])
    return X.values


def process_files_generator(file_list, scaler, pca, label_encoder, batch_size=5):
    """Stream groups of CSV files through the fitted preprocessing pipeline.

    Files are taken ``batch_size`` at a time. Each file is cleaned
    (``load_and_clean``), categorical-encoded (``encode_objects``), scaled,
    PCA-projected, and its labels integer-encoded. All files of one group are
    stacked and yielded together.

    A file that raises anywhere in that pipeline is reported on stdout and
    skipped. One example is a label that ``label_encoder`` was not fitted on.
    A group in which every file failed yields nothing.

    Yields:
        tuple: ``(X, y)`` with ``X`` of shape ``(rows, n_components)`` and
        ``y`` a 1-D array of encoded labels with the same number of rows.
    """
    for start in range(0, len(file_list), batch_size):
        batch_files = file_list[start:start + batch_size]
        X_parts, y_parts = [], []

        for f in batch_files:
            try:
                X, y = load_and_clean(f)
                X_reduced = pca.transform(scaler.transform(encode_objects(X)))
                y_encoded = label_encoder.transform(y.astype(str))
            except Exception as e:
                # Best-effort: report and skip the file. Both arrays are built
                # before either is appended. Previously X was appended first,
                # so a failing label transform left X and y with different
                # row counts.
                print(f"Error processing {f}: {e}")
                continue
            X_parts.append(X_reduced)
            y_parts.append(y_encoded)

        if not X_parts:
            continue

        X_combined = np.vstack(X_parts)
        y_combined = np.hstack(y_parts)
        del X_parts, y_parts
        gc.collect()

        yield X_combined, y_combined


# ==========================================================
# üéì PYTORCH MODELS
# ==========================================================

class TeacherLSTM(nn.Module):
    """High-capacity teacher: three stacked LSTMs followed by a 3-layer MLP head.

    Pipeline: LSTM(input_size->h0) -> Dropout -> LSTM(h0->h1) -> Dropout ->
    LSTM(h1->h2) -> Dropout -> last time step -> Linear(h2->128) -> ReLU ->
    Linear(128->64) -> ReLU -> Linear(64->num_classes). Returns raw logits.

    The parameter count depends on the constructor arguments (the script
    prints the real number). With 30 inputs, hidden sizes [256, 128, 64] and
    34 classes it is about 561K, not the ~200K quoted before.
    """

    def __init__(self, input_size, hidden_sizes, num_classes, dropout=0.3):
        super().__init__()
        h0, h1, h2 = hidden_sizes[0], hidden_sizes[1], hidden_sizes[2]

        # Recurrent trunk. Attribute names (and creation order) are kept
        # because they determine state_dict keys and parameter init order.
        self.lstm1 = nn.LSTM(input_size, h0, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.lstm2 = nn.LSTM(h0, h1, batch_first=True)
        self.dropout2 = nn.Dropout(dropout)
        self.lstm3 = nn.LSTM(h1, h2, batch_first=True)
        self.dropout3 = nn.Dropout(dropout)

        # Classifier head applied to the final time step.
        self.fc1 = nn.Linear(h2, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, num_classes)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, features) to (batch, num_classes) logits."""
        h = x
        for lstm, drop in ((self.lstm1, self.dropout1),
                           (self.lstm2, self.dropout2),
                           (self.lstm3, self.dropout3)):
            h = drop(lstm(h)[0])

        last_step = h[:, -1, :]
        hidden = self.relu1(self.fc1(last_step))
        hidden = self.relu2(self.fc2(hidden))
        return self.fc3(hidden)


class StudentLSTM(nn.Module):
    """Compact student: two stacked LSTMs followed by a 2-layer MLP head.

    Pipeline: LSTM(input_size->h0) -> Dropout -> LSTM(h0->h1) -> Dropout ->
    last time step -> Linear(h1->32) -> ReLU -> Linear(32->num_classes).
    Returns raw logits.

    The size depends on the constructor arguments. With 30 inputs, hidden
    sizes [32, 16] and 34 classes it is about 13K parameters, roughly 43x
    smaller than the teacher configured in this script (not "10x / ~20K").
    """

    def __init__(self, input_size, hidden_sizes, num_classes, dropout=0.2):
        super().__init__()
        h0, h1 = hidden_sizes[0], hidden_sizes[1]

        # Attribute names and order fix the state_dict keys and init order.
        self.lstm1 = nn.LSTM(input_size, h0, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.lstm2 = nn.LSTM(h0, h1, batch_first=True)
        self.dropout2 = nn.Dropout(dropout)

        self.fc1 = nn.Linear(h1, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, features) to (batch, num_classes) logits."""
        seq = self.dropout1(self.lstm1(x)[0])
        seq = self.dropout2(self.lstm2(seq)[0])
        final = seq[:, -1, :]
        return self.fc2(self.relu(self.fc1(final)))


# ==========================================================
# üéì KNOWLEDGE DISTILLATION LOSS
# ==========================================================

class DistillationLoss(nn.Module):
    """Hinton-style distillation loss: blend of soft (teacher) and hard (label) terms.

    ``total = alpha * T^2 * KL(softmax(t/T) || softmax(s/T)) + (1 - alpha) * CE(s, y)``
    where the KL term is averaged over the batch (``reduction='batchmean'``).
    The T^2 factor keeps the soft-term gradient scale comparable across
    temperatures.
    """

    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, labels):
        """Return ``(total_loss, hard_loss, soft_loss)`` for one batch of logits."""
        T = self.temperature
        a = self.alpha

        hard_loss = self.ce_loss(student_logits, labels)

        # KLDivLoss expects log-probabilities as input and probabilities as target.
        log_p_student = torch.log_softmax(student_logits / T, dim=1)
        p_teacher = torch.softmax(teacher_logits / T, dim=1)
        soft_loss = self.kl_loss(log_p_student, p_teacher) * (T ** 2)

        total_loss = a * soft_loss + (1 - a) * hard_loss
        return total_loss, hard_loss, soft_loss


# ==========================================================
# üèãÔ∏è TRAINING FUNCTIONS
# ==========================================================

def train_epoch(model, data_generator, optimizer, criterion, device, is_distillation=False,
                teacher_model=None, mini_batch_size=1024, shuffle=True):
    """Train ``model`` for one pass over everything ``data_generator`` yields.

    Each yielded ``(X, y)`` chunk is split into mini-batches of
    ``mini_batch_size`` rows, optionally shuffled within the chunk, and the
    optimizer steps once per mini-batch. From ``process_files_generator`` a
    chunk is several whole CSV files stacked together. Previously the whole
    chunk was a single batch: one gradient step per chunk, and one forward
    pass over every row at once.

    Args:
        model: network being trained; it receives float tensors of shape
            ``(batch, 1, n_features)`` (a length-1 sequence axis is added).
        data_generator: iterable of ``(X, y)`` arrays; X 2-D, y 1-D class ids.
        optimizer: optimizer over ``model``'s parameters.
        criterion: ``criterion(logits, labels)``; in distillation mode it is
            called as ``criterion(student_logits, teacher_logits, labels)``
            and the first element of its returned tuple is the loss.
        device: device the mini-batches are moved to.
        is_distillation: use ``teacher_model`` outputs as soft targets.
        teacher_model: frozen teacher, run in eval mode under ``no_grad``.
        mini_batch_size: rows per optimizer step.
        shuffle: shuffle rows within each chunk before batching.

    Returns:
        float: sample-weighted mean training loss, or ``nan`` if no rows were seen.
    """
    model.train()
    if teacher_model is not None:
        teacher_model.eval()
    use_teacher = is_distillation and teacher_model is not None

    total_loss = 0.0
    total_samples = 0

    for X_batch, y_batch in data_generator:
        n_rows = len(y_batch)
        if n_rows == 0:
            continue

        # Keep the full chunk on the CPU; only one mini-batch at a time goes to `device`.
        X_all = torch.as_tensor(X_batch, dtype=torch.float32).unsqueeze(1)  # add sequence dimension
        y_all = torch.as_tensor(y_batch, dtype=torch.long)
        order = torch.randperm(n_rows) if shuffle else torch.arange(n_rows)

        for start in range(0, n_rows, mini_batch_size):
            idx = order[start:start + mini_batch_size]
            X_tensor = X_all[idx].to(device)
            y_tensor = y_all[idx].to(device)

            optimizer.zero_grad()
            outputs = model(X_tensor)

            if use_teacher:
                with torch.no_grad():
                    teacher_outputs = teacher_model(X_tensor)
                loss = criterion(outputs, teacher_outputs, y_tensor)[0]
            else:
                loss = criterion(outputs, y_tensor)

            loss.backward()
            optimizer.step()

            batch_rows = len(idx)
            total_loss += loss.item() * batch_rows
            total_samples += batch_rows

        # Free the chunk before the generator materializes the next one.
        del X_all, y_all, order
        torch.cuda.empty_cache()

    return total_loss / total_samples if total_samples else float('nan')


def evaluate(model, data_generator, criterion, device, mini_batch_size=4096):
    """Compute mean loss and accuracy of ``model`` over everything ``data_generator`` yields.

    Each ``(X, y)`` chunk is scored in slices of ``mini_batch_size`` rows.
    The metrics match a single full-chunk pass, but device memory stays
    bounded instead of growing with the number of rows in the chunk.

    Args:
        model: network taking ``(batch, 1, n_features)`` float tensors.
        data_generator: iterable of ``(X, y)`` arrays; X 2-D, y 1-D class ids.
        criterion: ``criterion(logits, labels)`` returning a mean loss.
        device: device the slices are moved to.
        mini_batch_size: rows per forward pass.

    Returns:
        tuple: ``(avg_loss, accuracy)`` weighted by sample count, or
        ``(nan, nan)`` when the generator yielded no rows.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total_samples = 0

    with torch.no_grad():
        for X_batch, y_batch in data_generator:
            X_all = torch.as_tensor(X_batch, dtype=torch.float32).unsqueeze(1)
            y_all = torch.as_tensor(y_batch, dtype=torch.long)

            for start in range(0, len(y_all), mini_batch_size):
                X_tensor = X_all[start:start + mini_batch_size].to(device)
                y_tensor = y_all[start:start + mini_batch_size].to(device)

                outputs = model(X_tensor)
                loss = criterion(outputs, y_tensor)

                _, predicted = torch.max(outputs, 1)
                n_rows = y_tensor.size(0)
                correct += (predicted == y_tensor).sum().item()
                total_loss += loss.item() * n_rows
                total_samples += n_rows

            del X_all, y_all
            torch.cuda.empty_cache()

    if total_samples == 0:
        # Nothing to score (e.g. every file failed preprocessing).
        return float('nan'), float('nan')

    return total_loss / total_samples, correct / total_samples


# ==========================================================
# üìÇ DOWNLOAD & SPLIT DATASET
# ==========================================================

# Download (or reuse the kagglehub cache of) the dataset and split it *by file*
# into 60% train / 20% validation / 20% test.
print("=" * 80)
print("üì• Downloading CIC-IoT-2023 Dataset from Kaggle...")
print("=" * 80)

dataset_dir = kagglehub.dataset_download("akashdogra/cic-iot-2023")
print(f"‚úÖ Dataset downloaded to: {dataset_dir}")

# Only top-level *.csv files are used; sorting makes the split reproducible.
csv_files = sorted(
    os.path.join(dataset_dir, f)
    for f in os.listdir(dataset_dir)
    if f.endswith(".csv")
)
if not csv_files:
    raise FileNotFoundError(f"No .csv files found in {dataset_dir}")

print(f"üìÇ Found {len(csv_files)} CSV files.")

# 60-20-20 split
n_files = len(csv_files)
train_idx = int(n_files * 0.60)
val_idx = int(n_files * 0.80)

train_files = csv_files[:train_idx]
val_files = csv_files[train_idx:val_idx]
test_files = csv_files[val_idx:]

print("\nüìä Dataset Split:")
print(f"   Training:   {len(train_files)} files")
print(f"   Validation: {len(val_files)} files")
print(f"   Testing:    {len(test_files)} files")

# Later stages read train_files[0] and average over validation/test samples,
# so an empty split would otherwise fail much later with an unclear error.
if not (train_files and val_files and test_files):
    raise ValueError(
        f"Need at least one file per split, got train={len(train_files)}, "
        f"val={len(val_files)}, test={len(test_files)} from {n_files} CSV files"
    )

# ==========================================================
# üè∑Ô∏è FIT LABEL ENCODER
# ==========================================================

print("\n" + "=" * 80)
print("üè∑Ô∏è  Fitting Label Encoder...")
print("=" * 80)

# Collect only the *distinct* label strings. LabelEncoder.fit keeps just the
# sorted unique values, so a set yields the same classes_ as the previous list
# that held one Python string per training row, at a fraction of the memory.
all_labels = set()
max_batch = 5  # files per progress step; also used by the scaler/PCA passes
n_label_batches = (len(train_files) - 1) // max_batch + 1

for i in range(0, len(train_files), max_batch):
    batch_files = train_files[i:i+max_batch]
    print(f"Processing batch {i//max_batch + 1}/{n_label_batches}")

    for f in batch_files:
        _, y = load_and_clean(f)
        all_labels.update(y.astype(str).unique())
        del y

    if i % (max_batch * 4) == 0:
        gc.collect()

label_encoder = LabelEncoder()
label_encoder.fit(sorted(all_labels))
del all_labels
gc.collect()

print(f"‚úÖ LabelEncoder fitted with {len(label_encoder.classes_)} classes")

# ==========================================================
# üèóÔ∏è FIT SCALER & PCA
# ==========================================================

print("\n" + "=" * 80)
print("üèóÔ∏è  Fitting Scaler & PCA...")
print("=" * 80)

scaler = StandardScaler()
n_file_batches = (len(train_files) - 1) // max_batch + 1

# Pass 1: running mean/variance via StandardScaler.partial_fit, one file at a time.
print("Pass 1: Fitting Scaler...")
for i in range(0, len(train_files), max_batch):
    batch_files = train_files[i:i+max_batch]
    print(f"  Scaler batch {i//max_batch + 1}/{n_file_batches}")

    for f in batch_files:
        X, _ = load_and_clean(f)
        X = encode_objects(X)
        scaler.partial_fit(X)
        del X
        gc.collect()

print("‚úÖ Scaler fitted")

# The fitted scaler records the feature count, so no training file has to be
# loaded an extra time just to size the PCA. The old code read all of
# train_files[0] for this.
n_features = scaler.n_features_in_
n_components = min(30, n_features)
print(f"PCA will use {n_components} components (dataset has {n_features} features)")

pca = IncrementalPCA(n_components=n_components)

# Pass 2: IncrementalPCA on the standardized data, one file per partial_fit call.
# NOTE(review): partial_fit rejects a call with fewer rows than n_components,
# so a very small (post-cleaning) file would raise here.
print("\nPass 2: Fitting PCA...")
for i in range(0, len(train_files), max_batch):
    batch_files = train_files[i:i+max_batch]
    print(f"  PCA batch {i//max_batch + 1}/{n_file_batches}")

    for f in batch_files:
        X, _ = load_and_clean(f)
        X = encode_objects(X)
        X_scaled = scaler.transform(X)
        pca.partial_fit(X_scaled)
        del X, X_scaled
        gc.collect()

print(f"‚úÖ PCA fitted with {pca.n_components_} components")
gc.collect()

# ==========================================================
# üéì STAGE 1: TRAIN TEACHER MODEL
# ==========================================================

print("\n" + "=" * 80)
print("üéì STAGE 1: Training Teacher Model (Large)")
print("=" * 80)

n_classes = len(label_encoder.classes_)

# Initialize teacher model
teacher_model = TeacherLSTM(
    input_size=n_components,
    hidden_sizes=[256, 128, 64],
    num_classes=n_classes,
    dropout=0.3
).to(device)

# Count parameters
teacher_params = sum(p.numel() for p in teacher_model.parameters())
print(f"\nüèóÔ∏è  Teacher Model: {teacher_params:,} parameters")

# Optimizer and criterion
teacher_optimizer = optim.Adam(teacher_model.parameters(), lr=0.001)
teacher_criterion = nn.CrossEntropyLoss()

# Training settings. Each "epoch" trains on a rotating window of
# `files_per_epoch` training files (wrapping around the list), not on the
# whole training split. `files_per_epoch` and `patience` are also used by the
# student stage.
epochs_teacher = 20
files_per_epoch = 3

# Start at -inf so the first epoch with a finite accuracy always writes a
# checkpoint. With 0, a run whose validation accuracy never exceeded 0.0
# saved nothing. The torch.load below then failed, or silently loaded a
# stale file left by an earlier run.
best_teacher_acc = float('-inf')
patience_counter = 0
patience = 3

print("\nüöÄ Training Teacher Model...")

for epoch in range(epochs_teacher):
    print(f"\n{'='*80}")
    print(f"TEACHER EPOCH {epoch+1}/{epochs_teacher}")
    print(f"{'='*80}")

    # Select training files (wrap around to the start of the list if needed)
    start = (epoch * files_per_epoch) % len(train_files)
    selected_files = train_files[start:start + files_per_epoch]

    if len(selected_files) < files_per_epoch:
        selected_files += train_files[:files_per_epoch - len(selected_files)]

    print(f"Training on {len(selected_files)} files")

    # Train
    train_gen = process_files_generator(selected_files, scaler, pca, label_encoder, batch_size=files_per_epoch)
    train_loss = train_epoch(teacher_model, train_gen, teacher_optimizer, teacher_criterion, device)

    # Validate on the first 5 validation files only (keeps epochs short)
    val_gen = process_files_generator(val_files[:5], scaler, pca, label_encoder, batch_size=5)
    val_loss, val_acc = evaluate(teacher_model, val_gen, teacher_criterion, device)

    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    # Save best model; early-stop after `patience` epochs without improvement
    if val_acc > best_teacher_acc:
        best_teacher_acc = val_acc
        torch.save(teacher_model.state_dict(), 'teacher_model.pth')
        print(f"‚úÖ Best teacher model saved! Val Acc: {val_acc:.4f}")
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"\n‚ö†Ô∏è  Early stopping triggered at epoch {epoch+1}")
        break

    gc.collect()
    torch.cuda.empty_cache()

print("\n‚úÖ Teacher Model Training Complete!")
# map_location restores the checkpoint onto the current device.
teacher_model.load_state_dict(torch.load('teacher_model.pth', map_location=device))

# ==========================================================
# üéí STAGE 2: KNOWLEDGE DISTILLATION - TRAIN STUDENT
# ==========================================================

print("\n" + "=" * 80)
print("üéí STAGE 2: Knowledge Distillation - Training Student Model")
print("=" * 80)

# Initialize student model
student_model = StudentLSTM(
    input_size=n_components,
    hidden_sizes=[32, 16],
    num_classes=n_classes,
    dropout=0.2
).to(device)

student_params = sum(p.numel() for p in student_model.parameters())
reduction_ratio = teacher_params / student_params

print(f"\nüèóÔ∏è  Student Model: {student_params:,} parameters")
print(f"\nüìä Model Comparison:")
print(f"   Teacher Parameters: {teacher_params:,}")
print(f"   Student Parameters: {student_params:,}")
print(f"   Size Reduction:     {reduction_ratio:.1f}x smaller")

# Optimizer and distillation loss
student_optimizer = optim.Adam(student_model.parameters(), lr=0.001)
distillation_criterion = DistillationLoss(temperature=4.0, alpha=0.7)
# Validation scores the student on hard labels only. The criterion holds no
# state, so one instance serves every epoch.
val_criterion = nn.CrossEntropyLoss()

epochs_student = 25
# Start at -inf so the first epoch with a finite accuracy always writes a
# checkpoint. With 0, an epoch scoring exactly 0.0 was never saved and the
# torch.load below could fail or pick up a stale file.
best_student_acc = float('-inf')
patience_counter = 0  # `patience` and `files_per_epoch` come from the teacher stage

print(f"\nüöÄ Training Student with Knowledge Distillation...")
print(f"   Temperature: {distillation_criterion.temperature}")
print(f"   Alpha (soft target weight): {distillation_criterion.alpha}")

for epoch in range(epochs_student):
    print(f"\n{'='*80}")
    print(f"STUDENT EPOCH {epoch+1}/{epochs_student}")
    print(f"{'='*80}")

    # Select training files (wrap around to the start of the list if needed)
    start = (epoch * files_per_epoch) % len(train_files)
    selected_files = train_files[start:start + files_per_epoch]

    if len(selected_files) < files_per_epoch:
        selected_files += train_files[:files_per_epoch - len(selected_files)]

    print(f"Training on {len(selected_files)} files")

    # Train with distillation (teacher logits as soft targets)
    train_gen = process_files_generator(selected_files, scaler, pca, label_encoder, batch_size=files_per_epoch)
    train_loss = train_epoch(student_model, train_gen, student_optimizer, distillation_criterion,
                             device, is_distillation=True, teacher_model=teacher_model)

    # Validate on the first 5 validation files only
    val_gen = process_files_generator(val_files[:5], scaler, pca, label_encoder, batch_size=5)
    val_loss, val_acc = evaluate(student_model, val_gen, val_criterion, device)

    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    # Save best model; early-stop after `patience` epochs without improvement
    if val_acc > best_student_acc:
        best_student_acc = val_acc
        torch.save(student_model.state_dict(), 'student_model.pth')
        print(f"‚úÖ Best student model saved! Val Acc: {val_acc:.4f}")
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"\n‚ö†Ô∏è  Early stopping triggered at epoch {epoch+1}")
        break

    gc.collect()
    torch.cuda.empty_cache()

print("\n‚úÖ Student Model Training Complete!")
# map_location restores the checkpoint onto the current device.
student_model.load_state_dict(torch.load('student_model.pth', map_location=device))

# ==========================================================
# üìà STAGE 3: EVALUATION - COMPARE TEACHER VS STUDENT
# ==========================================================

print("\n" + "=" * 80)
print("üìà STAGE 3: Final Evaluation - Teacher vs Student")
print("=" * 80)


def evaluate_model_detailed(model, model_name, mini_batch_size=4096):
    """Score ``model`` on every test file and print support-weighted metrics.

    Reads the module-level ``test_files``, ``scaler``, ``pca``,
    ``label_encoder`` and ``device``. Each 5-file group from
    ``process_files_generator`` is pushed through the model in slices of
    ``mini_batch_size`` rows rather than in one forward pass. Labels and
    predictions are gathered as NumPy arrays rather than element-by-element
    Python lists.

    Args:
        model: trained network taking ``(batch, 1, n_components)`` float tensors.
        model_name: label used in the printed report.
        mini_batch_size: rows per forward pass.

    Returns:
        tuple: ``(y_true, y_pred, accuracy, precision, recall, f1)``; precision,
        recall and F1 use ``average='weighted'``.

    Raises:
        RuntimeError: if no test file could be processed.
    """
    print(f"\n{'='*60}")
    print(f"Evaluating {model_name}...")
    print(f"{'='*60}")

    model.eval()
    y_true_parts = []
    y_pred_parts = []
    n_test_batches = (len(test_files) - 1) // 5 + 1

    test_gen = process_files_generator(test_files, scaler, pca, label_encoder, batch_size=5)

    with torch.no_grad():
        for batch_num, (X_test, y_test) in enumerate(test_gen):
            print(f"Test batch {batch_num + 1}/{n_test_batches}")

            X_all = torch.as_tensor(X_test, dtype=torch.float32).unsqueeze(1)
            for start in range(0, len(X_all), mini_batch_size):
                outputs = model(X_all[start:start + mini_batch_size].to(device))
                _, predicted = torch.max(outputs, 1)
                y_pred_parts.append(predicted.cpu().numpy())
            y_true_parts.append(np.asarray(y_test))

            del X_all
            torch.cuda.empty_cache()

    if not y_true_parts:
        raise RuntimeError(f"No test data could be processed for {model_name}")

    y_true_all = np.concatenate(y_true_parts)
    y_pred_all = np.concatenate(y_pred_parts)

    accuracy = accuracy_score(y_true_all, y_pred_all)
    precision = precision_score(y_true_all, y_pred_all, average='weighted', zero_division=0)
    recall = recall_score(y_true_all, y_pred_all, average='weighted', zero_division=0)
    f1 = f1_score(y_true_all, y_pred_all, average='weighted', zero_division=0)

    print(f"\nüìä {model_name} Performance:")
    print(f"   Accuracy:  {accuracy:.4f}")
    print(f"   Precision: {precision:.4f}")
    print(f"   Recall:    {recall:.4f}")
    print(f"   F1-Score:  {f1:.4f}")

    return y_true_all, y_pred_all, accuracy, precision, recall, f1

# Evaluate both models on the same held-out test files
teacher_results = evaluate_model_detailed(teacher_model, "TEACHER MODEL")
student_results = evaluate_model_detailed(student_model, "STUDENT MODEL (Distilled)")

# ==========================================================
# üìä CONFUSION MATRIX & COMPARISON
# ==========================================================

print("\n" + "=" * 80)
print("üìä Generating Comparison Report...")
print("=" * 80)

y_true, y_pred, s_acc, s_prec, s_rec, s_f1 = student_results
_, _, t_acc, t_prec, t_rec, t_f1 = teacher_results

# Pin the matrix to every encoded class. Without `labels`, sklearn keeps only
# the classes present in y_true or y_pred. A class absent from the test files
# then shrinks the matrix, and the class-name tick labels below no longer
# line up with its rows and columns.
cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(label_encoder.classes_)))

# Plot confusion matrix
plt.figure(figsize=(20, 16))
sns.heatmap(cm, annot=True, fmt='d', cmap='Greens',
            xticklabels=label_encoder.classes_,
            yticklabels=label_encoder.classes_,
            cbar_kws={'label': 'Count'})
plt.title('Student Model Confusion Matrix (Knowledge Distillation - PyTorch)', fontsize=16, pad=20)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.xticks(rotation=45, ha='right', fontsize=8)
plt.yticks(rotation=0, fontsize=8)
plt.tight_layout()
plt.savefig('student_confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.close()  # release the large figure; the Agg backend never displays it
print("‚úÖ Student confusion matrix saved as 'student_confusion_matrix.png'")

# Comparison Summary (a teacher accuracy of 0 would otherwise divide by zero)
performance_retention = (s_acc / t_acc) * 100 if t_acc else float('nan')

print("\n" + "=" * 80)
print("üìä FINAL COMPARISON: TEACHER vs STUDENT")
print("=" * 80)
print(f"\n{'Metric':<15} {'Teacher':<15} {'Student':<15} {'Difference':<15}")
print("=" * 80)
print(f"{'Accuracy':<15} {t_acc:<15.4f} {s_acc:<15.4f} {(s_acc-t_acc):<15.4f}")
print(f"{'Precision':<15} {t_prec:<15.4f} {s_prec:<15.4f} {(s_prec-t_prec):<15.4f}")
print(f"{'Recall':<15} {t_rec:<15.4f} {s_rec:<15.4f} {(s_rec-t_rec):<15.4f}")
print(f"{'F1-Score':<15} {t_f1:<15.4f} {s_f1:<15.4f} {(s_f1-t_f1):<15.4f}")
print(f"{'Parameters':<15} {teacher_params:<15,} {student_params:<15,} {'-':<15}")
print(f"{'Model Size':<15} {'1.0x':<15} {f'{1/reduction_ratio:.2f}x':<15} {f'{reduction_ratio:.1f}x smaller':<15}")
print("=" * 80)

print(f"\nüéØ Performance Retention: {performance_retention:.2f}%")
print(f"üéØ Model Size Reduction: {reduction_ratio:.1f}x smaller")
print(f"üéØ Parameter Reduction: {((teacher_params - student_params) / teacher_params * 100):.1f}% fewer parameters")

# ==========================================================
# üíæ SAVE MODELS AND PREPROCESSING OBJECTS
# ==========================================================

rule = "=" * 80
print("\n" + rule)
print("üíæ Saving Models and Preprocessing Objects")
print(rule)

# Save PyTorch models. Each checkpoint bundles the weights with the
# constructor arguments (input size, hidden sizes, class count) needed to
# rebuild the network, plus its test accuracy and parameter count.
for ckpt_path, net, sizes, acc, n_params in (
    ('teacher_model_complete.pth', teacher_model, [256, 128, 64], t_acc, teacher_params),
    ('student_model_complete.pth', student_model, [32, 16], s_acc, student_params),
):
    torch.save(
        {
            'model_state_dict': net.state_dict(),
            'input_size': n_components,
            'hidden_sizes': sizes,
            'num_classes': n_classes,
            'accuracy': acc,
            'params': n_params,
        },
        ckpt_path,
    )
    print(f"‚úÖ Saved: {ckpt_path}")

# Save preprocessing objects (the fitted scaler, PCA and label encoder)
preprocessing_objects = dict(scaler=scaler, pca=pca, label_encoder=label_encoder)

with open('preprocessing.pkl', 'wb') as f:
    pickle.dump(preprocessing_objects, f)
print("‚úÖ Saved: preprocessing.pkl")

# Save metadata as plain JSON-serializable Python types
metadata = dict(
    n_classes=int(n_classes),
    n_features=int(n_features),
    n_components=int(n_components),
    teacher_params=int(teacher_params),
    student_params=int(student_params),
    teacher_accuracy=float(t_acc),
    student_accuracy=float(s_acc),
    size_reduction=float(reduction_ratio),
    performance_retention=float(performance_retention),
    classes=label_encoder.classes_.tolist(),
)

with open('model_metadata.json', 'w') as f:
    json.dump(metadata, f, indent=4)
print("‚úÖ Saved: model_metadata.json")

# Create a plain-text summary of both models and the compression achieved
param_reduction_pct = (teacher_params - student_params) / teacher_params * 100

summary_lines = [
    "=" * 80,
    "KNOWLEDGE DISTILLATION - PYTORCH MODEL SUMMARY",
    "=" * 80,
    "",
    "TEACHER MODEL:",
    f"  Parameters: {teacher_params:,}",
    f"  Accuracy: {t_acc:.4f}",
    f"  Precision: {t_prec:.4f}",
    f"  Recall: {t_rec:.4f}",
    f"  F1-Score: {t_f1:.4f}",
    "",
    "STUDENT MODEL (DISTILLED):",
    f"  Parameters: {student_params:,}",
    f"  Accuracy: {s_acc:.4f}",
    f"  Precision: {s_prec:.4f}",
    f"  Recall: {s_rec:.4f}",
    f"  F1-Score: {s_f1:.4f}",
    "",
    "COMPRESSION METRICS:",
    f"  Size Reduction: {reduction_ratio:.1f}x smaller",
    f"  Performance Retention: {performance_retention:.2f}%",
    f"  Parameter Reduction: {param_reduction_pct:.1f}%",
    "",
    "FILES GENERATED:",
    "  - teacher_model_complete.pth (Teacher model with metadata)",
    "  - student_model_complete.pth (Student model with metadata)",
    "  - preprocessing.pkl (Scaler, PCA, Label Encoder)",
    "  - model_metadata.json (Model specifications)",
    "  - student_confusion_matrix.png (Confusion matrix visualization)",
]

with open('model_summary.txt', 'w') as f:
    f.write("\n".join(summary_lines) + "\n")

print("‚úÖ Saved: model_summary.txt")

# List every generated artifact that exists on disk, with its size
rule = "=" * 80
print("\n" + rule)
print("üìÅ Generated Files:")
print(rule)

files_to_check = [
    'teacher_model_complete.pth',
    'student_model_complete.pth',
    'preprocessing.pkl',
    'model_metadata.json',
    'model_summary.txt',
    'student_confusion_matrix.png'
]

bytes_per_mb = 1024 * 1024
total_size = 0
for filename in filter(os.path.exists, files_to_check):
    size = os.path.getsize(filename)
    total_size += size
    print(f"‚úÖ {filename:<40} {size / bytes_per_mb:>10.2f} MB")

print(rule)
print(f"üìä Total Size: {total_size / bytes_per_mb:.2f} MB")

print("\n" + rule)
print("üéâ KNOWLEDGE DISTILLATION COMPLETE!")
print(rule)
print("\nYour lightweight student model is ready for deployment!")
print(f"Model size reduced by {reduction_ratio:.1f}x with {performance_retention:.1f}% performance retention")
print("\nAll model files have been saved and are ready for download.")
print(rule)

üéÆ GPU Configuration
‚ö†Ô∏è  No GPU detected, running on CPU

üì• Downloading CIC-IoT-2023 Dataset from Kaggle...
Downloading from https://www.kaggle.com/api/v1/datasets/download/akashdogra/cic-iot-2023?dataset_version_number=1...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2.77G/2.77G [00:25<00:00, 117MB/s]

Extracting files...





‚úÖ Dataset downloaded to: /root/.cache/kagglehub/datasets/akashdogra/cic-iot-2023/versions/1
üìÇ Found 169 CSV files.

üìä Dataset Split:
   Training:   101 files
   Validation: 34 files
   Testing:    34 files

üè∑Ô∏è  Fitting Label Encoder...
Processing batch 1/21
Processing batch 2/21
Processing batch 3/21
Processing batch 4/21
Processing batch 5/21
Processing batch 6/21
Processing batch 7/21
Processing batch 8/21
Processing batch 9/21
Processing batch 10/21
Processing batch 11/21
Processing batch 12/21
Processing batch 13/21
Processing batch 14/21
Processing batch 15/21
Processing batch 16/21
Processing batch 17/21
Processing batch 18/21
Processing batch 19/21
Processing batch 20/21
Processing batch 21/21
‚úÖ LabelEncoder fitted with 34 classes

üèóÔ∏è  Fitting Scaler & PCA...
PCA will use 30 components (dataset has 46 features)
Pass 1: Fitting Scaler...
  Scaler batch 1/21
  Scaler batch 2/21
  Scaler batch 3/21
  Scaler batch 4/21
  Scaler batch 5/21
  Scaler batch 6/21
  Sc

In [None]:
#!/usr/bin/env python3
"""
RAM-EFFICIENT Knowledge Distillation for IoT Intrusion Detection
Key Fixes:
- NEVER loads all data into RAM
- Streaming data processing
- Train on one file at a time
- Incremental learning approach
"""

import os
import gc
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import IncrementalPCA
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import kagglehub
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import pickle
import json
import warnings
warnings.filterwarnings('ignore')

# ==========================================================
# üéÆ GPU CONFIGURATION
# ==========================================================

def setup_gpu():
    """Select the compute device and print which one was chosen.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` is true
        (also enabling ``torch.backends.cudnn.benchmark``), otherwise ``cpu``.
    """
    rule = "=" * 80
    print(rule)
    print("üéÆ GPU Configuration")
    print(rule)

    use_cuda = torch.cuda.is_available()
    chosen = torch.device('cuda' if use_cuda else 'cpu')
    if use_cuda:
        print(f"‚úÖ GPU: {torch.cuda.get_device_name(0)}")
        # Autotune cuDNN kernels per input shape.
        torch.backends.cudnn.benchmark = True
    else:
        print("‚ö†Ô∏è  No GPU detected")

    print(rule + "\n")
    return chosen


device = setup_gpu()

# ==========================================================
# üßπ AGGRESSIVE RAM MANAGEMENT
# ==========================================================

def clear_memory():
    """Run a full garbage-collection pass; on CUDA machines also release cached GPU blocks."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

def print_ram_usage():
    """Print this process's resident set size (RSS) in MB, using ``psutil``.

    Best-effort diagnostic. It does nothing if ``psutil`` is not installed or
    the query fails. The previous bare ``except:`` also swallowed
    ``KeyboardInterrupt``/``SystemExit``; those now propagate.
    """
    try:
        import psutil
    except ImportError:
        return

    try:
        ram_mb = psutil.Process().memory_info().rss / 1024 / 1024
    except Exception:  # e.g. psutil.Error; a diagnostic must never crash training
        return

    print(f"üíæ RAM Usage: {ram_mb:.0f} MB")

# ==========================================================
# üì¶ STREAMING DATASET (NO RAM LOADING)
# ==========================================================

class StreamingIoTDataset(Dataset):
    """
    Map-style dataset that reads one CSV row from disk per ``__getitem__``.

    Nothing but the row count is kept in memory. Each access re-parses the
    file up to the requested row (``skiprows``), so random access is O(n) per
    item and very slow; ``ChunkedFileLoader`` is the practical alternative.
    Unlike the chunked path, rows are not cleaned (no NaN or duplicate
    removal), and ``chunk_size`` is stored but not used.
    """

    def __init__(self, csv_path, scaler, pca, label_encoder, chunk_size=5000):
        self.csv_path = csv_path
        self.scaler = scaler
        self.pca = pca
        self.label_encoder = label_encoder
        self.chunk_size = chunk_size

        # Count data rows without parsing them. The `with` block closes the
        # handle (the previous bare open() leaked one descriptor per dataset).
        # max(0, ...) keeps an empty file at length 0 instead of -1, which
        # would make len() raise.
        with open(csv_path) as fh:
            self.length = max(0, sum(1 for _ in fh) - 1)  # exclude header

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """
        Return ``(features, label)`` for data row ``idx``.

        ``features`` is a 1-D float tensor of PCA components and ``label`` a
        0-d long tensor. Very slow, since the file is scanned up to ``idx`` on
        every call.
        """
        # Skip data rows 1..idx (row 0 is the header) and read exactly one row.
        df = pd.read_csv(self.csv_path, skiprows=range(1, idx+1), nrows=1)

        label_col = "Label" if "Label" in df.columns else df.columns[-1]
        X = df.drop(columns=[label_col])
        y = df[label_col].values[0]

        # NOTE(review): fitting a LabelEncoder on a single row always yields 0,
        # so categorical features carry no information here.
        for col in X.select_dtypes(include=["object"]).columns:
            try:
                X[col] = LabelEncoder().fit_transform(X[col])
            except (TypeError, ValueError):
                X[col] = 0

        X = X.values.astype(np.float32)

        X_scaled = self.scaler.transform(X)
        X_reduced = self.pca.transform(X_scaled)
        y_encoded = self.label_encoder.transform([str(y)])[0]

        return torch.FloatTensor(X_reduced[0]), torch.LongTensor([y_encoded])[0]


class ChunkedFileLoader:
    """
    Stream one CSV file as preprocessed ``(features, labels)`` chunks.

    This is the loader the training/evaluation loops use. Only one
    ``chunk_size``-row slice of the file is resident in memory at a time.
    """

    def __init__(self, csv_path, scaler, pca, label_encoder, chunk_size=5000):
        self.csv_path = csv_path
        self.scaler = scaler
        self.pca = pca
        self.label_encoder = label_encoder
        self.chunk_size = chunk_size

    def get_chunks(self):
        """Yield ``(X_reduced, y_encoded)`` numpy arrays, one pair per CSV chunk.

        Per chunk:
        1. Drop rows with NaNs and duplicate rows (within the chunk only).
        2. Drop rows whose label the fitted ``label_encoder`` does not know.
        3. Label-encode object feature columns.
        4. Scale, PCA-project, and encode labels.

        Chunks that end up empty or fail to transform are skipped; failures
        are printed. The number of rows dropped for unseen labels is reported
        once at the end of the file.
        """
        known_labels = set(self.label_encoder.classes_)
        n_unseen = 0

        chunks = pd.read_csv(self.csv_path, chunksize=self.chunk_size, low_memory=False)
        for chunk_df in chunks:
            chunk_df = chunk_df.dropna().drop_duplicates()
            if len(chunk_df) == 0:
                continue

            label_col = "Label" if "Label" in chunk_df.columns else chunk_df.columns[-1]
            y = chunk_df[label_col].astype(str)

            # LabelEncoder.transform() raises on ANY unseen label, which used
            # to throw away the whole chunk. Drop only the offending rows.
            known = y.isin(known_labels)
            if not known.all():
                n_unseen += int((~known).sum())
                chunk_df, y = chunk_df.loc[known], y.loc[known]
                if len(chunk_df) == 0:
                    continue

            X = chunk_df.drop(columns=[label_col])
            del chunk_df

            # Encode objects (per chunk, so codes are not shared across chunks)
            for col in X.select_dtypes(include=["object"]).columns:
                try:
                    X[col] = LabelEncoder().fit_transform(X[col].astype(str))
                except Exception:
                    X[col] = 0

            X = X.values.astype(np.float32)

            try:
                X_reduced = self.pca.transform(self.scaler.transform(X))
                y_encoded = self.label_encoder.transform(y)
            except Exception as e:
                print(f"Error processing chunk: {e}")
                continue
            finally:
                # Runs on both the success and the error path (the old
                # `continue` skipped cleanup).
                del X, y
                gc.collect()

            yield X_reduced, y_encoded

        if n_unseen:
            print(f"Warning: {os.path.basename(self.csv_path)}: skipped {n_unseen} "
                  f"rows with labels unseen when the LabelEncoder was fitted")

# ==========================================================
# üßπ HELPER FUNCTIONS
# ==========================================================

def load_and_clean_sample(path, max_rows=1000):
    """Load only the first ``max_rows`` rows of a CSV for fitting preprocessing.

    Rows with any missing value are dropped. The label column is "Label" when
    present, otherwise the last column. Object feature columns are
    label-encoded per call. A column that cannot be encoded is zeroed, with
    a printed warning.

    Returns:
        ``(X, y)``: float32 feature matrix and the labels as a numpy array of str.
    """
    df = pd.read_csv(path, nrows=max_rows, low_memory=False).dropna()

    label_col = "Label" if "Label" in df.columns else df.columns[-1]
    X = df.drop(columns=[label_col])
    y = df[label_col]

    # Encode objects
    for col in X.select_dtypes(include=["object"]).columns:
        try:
            X[col] = LabelEncoder().fit_transform(X[col].astype(str))
        except Exception as exc:
            # Was a bare ``except:`` that hid the failure (and Ctrl-C).
            print(f"Warning: could not encode column {col!r} ({exc}); using 0")
            X[col] = 0

    return X.values.astype(np.float32), y.astype(str).values

# ==========================================================
# üéì SMALLER PYTORCH MODELS
# ==========================================================

class TeacherLSTM(nn.Module):
    """Compact two-layer LSTM teacher.

    Pipeline: LSTM(input_size -> hidden_size) -> Dropout
    -> LSTM(hidden_size -> hidden_size // 2) -> Dropout -> last time step
    -> Linear(-> 64) -> ReLU -> Linear(-> num_classes). Returns raw logits.
    """

    def __init__(self, input_size, hidden_size, num_classes, dropout=0.3):
        super().__init__()
        half = hidden_size // 2
        # Submodule names double as state_dict keys of saved checkpoints,
        # and construction order fixes the random-init sequence.
        self.lstm1 = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.lstm2 = nn.LSTM(hidden_size, half, batch_first=True)
        self.dropout2 = nn.Dropout(dropout)
        self.fc1 = nn.Linear(half, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x):
        # batch_first=True: x is (batch, seq_len, input_size).
        seq = self.dropout1(self.lstm1(x)[0])
        seq = self.dropout2(self.lstm2(seq)[0])
        last_step = seq[:, -1, :]
        hidden = self.relu(self.fc1(last_step))
        return self.fc2(hidden)


class StudentLSTM(nn.Module):
    """Tiny student: one LSTM layer followed by a 32-unit ReLU head.

    Produces raw logits of shape (batch, num_classes) from the last time step
    of the batch-first LSTM output.
    """

    def __init__(self, input_size, hidden_size, num_classes, dropout=0.2):
        super().__init__()
        # Attribute names define the state_dict keys of saved checkpoints.
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_size, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x):
        lstm_out = self.lstm(x)[0]
        # Dropout is applied to the whole sequence before slicing, as before.
        last_step = self.dropout(lstm_out)[:, -1, :]
        hidden = self.relu(self.fc1(last_step))
        return self.fc2(hidden)

# ==========================================================
# üéì DISTILLATION LOSS
# ==========================================================

class DistillationLoss(nn.Module):
    """Knowledge-distillation loss mixing soft and hard targets.

    loss = alpha * T**2 * KL(softmax(teacher / T) || softmax(student / T))
           + (1 - alpha) * CrossEntropy(student, labels)
    """

    def __init__(self, temperature=3.0, alpha=0.7):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, labels):
        T = self.temperature
        # KLDivLoss takes log-probabilities as input and probabilities as target.
        student_log_probs = torch.log_softmax(student_logits / T, dim=1)
        teacher_probs = torch.softmax(teacher_logits / T, dim=1)
        soft_loss = self.kl_loss(student_log_probs, teacher_probs) * (T ** 2)
        hard_loss = self.ce_loss(student_logits, labels)
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

# ==========================================================
# üèãÔ∏è TRAINING WITH STREAMING DATA
# ==========================================================

def train_on_file(model, csv_path, scaler, pca, label_encoder, optimizer,
                  criterion, device, batch_size=32, is_distillation=False,
                  teacher_model=None):
    """Run one training pass of ``model`` over a single CSV, streamed in chunks.

    Args:
        criterion: ``criterion(outputs, labels)``, or, when distilling,
            ``criterion(student_logits, teacher_logits, labels)``.
        is_distillation: use ``teacher_model`` logits as soft targets.
        teacher_model: frozen teacher (put in eval mode, run under no_grad).

    Returns:
        Mean per-sample loss over every row trained on, or 0 if no row could
        be processed (a warning is printed in that case).
    """
    model.train()
    if teacher_model is not None:
        teacher_model.eval()
    use_teacher = is_distillation and teacher_model is not None

    loader = ChunkedFileLoader(csv_path, scaler, pca, label_encoder, chunk_size=5000)

    total_loss = 0
    total_samples = 0

    for X_chunk, y_chunk in loader.get_chunks():
        # Create mini-batches from chunk
        for start in range(0, len(X_chunk), batch_size):
            stop = start + batch_size
            # unsqueeze(1) adds a length-1 sequence dimension for the LSTMs.
            X_tensor = torch.FloatTensor(X_chunk[start:stop]).unsqueeze(1).to(device)
            y_tensor = torch.LongTensor(y_chunk[start:stop]).to(device)

            optimizer.zero_grad()
            outputs = model(X_tensor)

            if use_teacher:
                with torch.no_grad():
                    teacher_outputs = teacher_model(X_tensor)
                loss = criterion(outputs, teacher_outputs, y_tensor)
            else:
                loss = criterion(outputs, y_tensor)

            loss.backward()
            optimizer.step()

            n = len(y_tensor)
            total_loss += loss.item() * n
            total_samples += n

        # Free memory once per chunk. The former per-batch clear_memory()
        # (gc.collect + torch.cuda.empty_cache) ran every 32 rows, dominated
        # runtime, and did not lower peak memory.
        del X_chunk, y_chunk
        clear_memory()

    if total_samples == 0:
        print(f"Warning: no usable rows in {os.path.basename(csv_path)}")
        return 0
    return total_loss / total_samples


def evaluate_on_file(model, csv_path, scaler, pca, label_encoder,
                     criterion, device, batch_size=32):
    """Score ``model`` on one CSV file, streamed chunk by chunk.

    Args:
        criterion: loss taking ``(logits, targets)``, e.g. CrossEntropyLoss.
        batch_size: rows per forward pass inside each chunk.

    Returns:
        ``(avg_loss, accuracy)`` over all processed rows, both 0 when the file
        yields no usable rows.
    """
    model.eval()
    stream = ChunkedFileLoader(csv_path, scaler, pca, label_encoder, chunk_size=5000)

    loss_sum, n_correct, n_rows = 0, 0, 0

    with torch.no_grad():
        for X_chunk, y_chunk in stream.get_chunks():
            for lo in range(0, len(X_chunk), batch_size):
                hi = lo + batch_size
                inputs = torch.FloatTensor(X_chunk[lo:hi]).unsqueeze(1).to(device)
                targets = torch.LongTensor(y_chunk[lo:hi]).to(device)

                logits = model(inputs)
                batch_loss = criterion(logits, targets)
                _, predicted = torch.max(logits, 1)

                rows = len(y_chunk[lo:hi])
                n_correct += (predicted == targets).sum().item()
                loss_sum += batch_loss.item() * rows
                n_rows += rows

                del inputs, targets, logits

            del X_chunk, y_chunk
            clear_memory()

    if n_rows > 0:
        return loss_sum / n_rows, n_correct / n_rows
    return 0, 0

# ==========================================================
# üìÇ DATASET DOWNLOAD & SPLIT
# ==========================================================

print("=" * 80)
print("üì• Downloading Dataset...")
print("=" * 80)

dataset_dir = kagglehub.dataset_download("akashdogra/cic-iot-2023")
csv_files = sorted(
    os.path.join(dataset_dir, name)
    for name in os.listdir(dataset_dir)
    if name.endswith(".csv")
)

print(f"üìÇ Found {len(csv_files)} CSV files")

# Split by FILE (not by row) in sorted order: first 60% train,
# next 20% validation, last 20% test.
n_files = len(csv_files)
train_idx, val_idx = int(n_files * 0.60), int(n_files * 0.80)

train_files, val_files, test_files = (
    csv_files[:train_idx],
    csv_files[train_idx:val_idx],
    csv_files[val_idx:],
)

print(f"üìä Split: {len(train_files)} train, {len(val_files)} val, {len(test_files)} test files")

# ==========================================================
# üè∑Ô∏è FIT PREPROCESSING (SAMPLING ONLY)
# ==========================================================

print("\n" + "=" * 80)
print("üè∑Ô∏è  Fitting Preprocessing (Memory-Safe)...")
print("=" * 80)

# ---- Label vocabulary: label column of EVERY file --------------------------
# Fitting the LabelEncoder on 3 x 1000 sampled rows missed rarer attack
# classes. label_encoder.transform() then raised "y contains previously unseen
# labels" and whole chunks were discarded during training.
# Reading only the label column costs one extra parse per file (minutes for
# the full dataset) but yields the complete class list. Only class names are
# collected (no feature statistics), so no val/test feature information
# reaches the scaler, PCA or models.
print(f"Scanning labels in {len(csv_files)} files...")
label_vocab = set()
for f in csv_files:
    header = pd.read_csv(f, nrows=0).columns
    file_label_col = "Label" if "Label" in header else header[-1]
    for label_chunk in pd.read_csv(f, usecols=[file_label_col], chunksize=500_000):
        label_vocab.update(label_chunk[file_label_col].dropna().astype(str).unique())

# ---- Feature samples for the scaler / PCA (small) --------------------------
all_labels = []
sample_X_list = []

for f in train_files[:3]:  # Only 3 files
    X_sample, y_sample = load_and_clean_sample(f, max_rows=1000)
    all_labels.extend(y_sample)
    sample_X_list.append(X_sample)
    print(f"Sampled {len(y_sample)} rows from {os.path.basename(f)}")

# Fit label encoder on the full vocabulary
label_encoder = LabelEncoder()
label_encoder.fit(sorted(label_vocab.union(all_labels)))
n_classes = len(label_encoder.classes_)

print(f"‚úÖ Found {n_classes} classes")

# Fit scaler
scaler = StandardScaler()
for X in sample_X_list:
    scaler.partial_fit(X)

print("‚úÖ Scaler fitted")

# Fit PCA (PCA cannot have more components than input features)
n_components = min(20, sample_X_list[0].shape[1])
pca = IncrementalPCA(n_components=n_components)

for X in sample_X_list:
    X_scaled = scaler.transform(X)
    pca.partial_fit(X_scaled)

print(f"‚úÖ PCA fitted with {n_components} components")

del all_labels, sample_X_list, label_vocab
clear_memory()
print_ram_usage()

# ==========================================================
# üéì TRAIN TEACHER MODEL (FILE-BY-FILE)
# ==========================================================

print("\n" + "=" * 80)
print("üéì Training Teacher Model (Streaming)...")
print("=" * 80)

teacher_model = TeacherLSTM(
    input_size=n_components,
    hidden_size=64,  # Small
    num_classes=n_classes,
    dropout=0.3
).to(device)

teacher_params = sum(p.numel() for p in teacher_model.parameters())
print(f"Teacher: {teacher_params:,} parameters")

teacher_optimizer = optim.Adam(teacher_model.parameters(), lr=0.001)
teacher_criterion = nn.CrossEntropyLoss()

epochs = 5  # Each epoch trains on the next `files_per_epoch` training files
files_per_epoch = 2
# Start below any reachable accuracy so the first epoch always writes
# 'teacher_model.pth'. Starting at 0 meant no checkpoint was written when every
# epoch scored 0.0, and the torch.load() below raised FileNotFoundError.
best_teacher_acc = -1.0

for epoch in range(epochs):
    print(f"\n{'='*80}")
    print(f"TEACHER EPOCH {epoch+1}/{epochs}")
    print(f"{'='*80}")

    # Rotate through the training files with wrap-around. The former slice
    # train_files[a % n:b % n] came back empty whenever b wrapped past a, and
    # then silently fell back to the first two files.
    files_to_use = [
        train_files[(epoch * files_per_epoch + k) % len(train_files)]
        for k in range(min(files_per_epoch, len(train_files)))
    ]

    for i, train_file in enumerate(files_to_use):
        print(f"\nüìÇ Training on file {i+1}/{len(files_to_use)}: {os.path.basename(train_file)}")

        train_loss = train_on_file(
            teacher_model, train_file, scaler, pca, label_encoder,
            teacher_optimizer, teacher_criterion, device, batch_size=32
        )

        print(f"   Loss: {train_loss:.4f}")
        print_ram_usage()

    # Validate on first val file only
    print(f"\nüìä Validating...")
    val_loss, val_acc = evaluate_on_file(
        teacher_model, val_files[0], scaler, pca, label_encoder,
        teacher_criterion, device, batch_size=32
    )

    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    if val_acc > best_teacher_acc:
        best_teacher_acc = val_acc
        torch.save(teacher_model.state_dict(), 'teacher_model.pth')
        print(f"‚úÖ Saved best teacher: {val_acc:.4f}")

print(f"\n‚úÖ Teacher complete! Best: {best_teacher_acc:.4f}")
teacher_model.load_state_dict(torch.load('teacher_model.pth'))

# ==========================================================
# üéí TRAIN STUDENT WITH DISTILLATION (FILE-BY-FILE)
# ==========================================================

print("\n" + "=" * 80)
print("üéí Training Student with Knowledge Distillation...")
print("=" * 80)

student_model = StudentLSTM(
    input_size=n_components,
    hidden_size=24,  # Tiny
    num_classes=n_classes,
    dropout=0.2
).to(device)

student_params = sum(p.numel() for p in student_model.parameters())
print(f"Student: {student_params:,} parameters")
print(f"Compression: {teacher_params/student_params:.1f}x")

student_optimizer = optim.Adam(student_model.parameters(), lr=0.001)
distillation_criterion = DistillationLoss(temperature=3.0, alpha=0.7)
# Validation scores plain cross-entropy against the true labels. Built once,
# not once per epoch.
val_criterion = nn.CrossEntropyLoss()

epochs = 8
files_per_epoch = 2
# Below any reachable accuracy so the first epoch always writes
# 'student_model.pth'. With 0 the later torch.load() could fail with
# FileNotFoundError when every epoch scored 0.0.
best_student_acc = -1.0

for epoch in range(epochs):
    print(f"\n{'='*80}")
    print(f"STUDENT EPOCH {epoch+1}/{epochs}")
    print(f"{'='*80}")

    # Wrap-around rotation; the old modulo slice could be empty and fall
    # back to the first two files.
    files_to_use = [
        train_files[(epoch * files_per_epoch + k) % len(train_files)]
        for k in range(min(files_per_epoch, len(train_files)))
    ]

    for i, train_file in enumerate(files_to_use):
        print(f"\nüìÇ Training on file {i+1}/{len(files_to_use)}: {os.path.basename(train_file)}")

        train_loss = train_on_file(
            student_model, train_file, scaler, pca, label_encoder,
            student_optimizer, distillation_criterion, device, batch_size=32,
            is_distillation=True, teacher_model=teacher_model
        )

        print(f"   Loss: {train_loss:.4f}")
        print_ram_usage()

    print(f"\nüìä Validating...")
    val_loss, val_acc = evaluate_on_file(
        student_model, val_files[0], scaler, pca, label_encoder,
        val_criterion, device, batch_size=32
    )

    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    if val_acc > best_student_acc:
        best_student_acc = val_acc
        torch.save(student_model.state_dict(), 'student_model.pth')
        print(f"‚úÖ Saved best student: {val_acc:.4f}")

print(f"\n‚úÖ Student complete! Best: {best_student_acc:.4f}")

# ==========================================================
# üìà FINAL EVALUATION
# ==========================================================

print("\n" + "=" * 80)
print("üìà Final Evaluation on Test Set...")
print("=" * 80)

student_model.load_state_dict(torch.load('student_model.pth'))

# Score BOTH models on the same held-out test file. The summary previously
# printed the teacher's best *validation* accuracy next to the student's
# *test* accuracy, which is not a like-for-like comparison.
# teacher_model already holds the best teacher checkpoint.
test_criterion = nn.CrossEntropyLoss()
test_loss, test_acc = evaluate_on_file(
    student_model, test_files[0], scaler, pca, label_encoder,
    test_criterion, device, batch_size=32
)
teacher_test_loss, teacher_test_acc = evaluate_on_file(
    teacher_model, test_files[0], scaler, pca, label_encoder,
    test_criterion, device, batch_size=32
)

print(f"\nüìä FINAL RESULTS:")
print(f"   Teacher Val Accuracy:  {best_teacher_acc:.4f}")
print(f"   Student Val Accuracy:  {best_student_acc:.4f}")
print(f"   Teacher Test Accuracy: {teacher_test_acc:.4f}")
print(f"   Student Test Accuracy: {test_acc:.4f}")
print(f"   Compression:           {teacher_params/student_params:.1f}x")

# ==========================================================
# üíæ SAVE EVERYTHING
# ==========================================================

print("\n" + "=" * 80)
print("üíæ Saving Models...")
print("=" * 80)

torch.save(teacher_model.state_dict(), 'teacher_final.pth')
torch.save(student_model.state_dict(), 'student_final.pth')

with open('preprocessing.pkl', 'wb') as f:
    pickle.dump({
        'scaler': scaler,
        'pca': pca,
        'label_encoder': label_encoder
    }, f)

metadata = {
    'teacher_params': int(teacher_params),
    'student_params': int(student_params),
    'compression_ratio': float(teacher_params / student_params),
    'n_classes': int(n_classes),
    'teacher_accuracy': float(best_teacher_acc),  # best validation accuracy
    'student_accuracy': float(best_student_acc),  # best validation accuracy
    'test_accuracy': float(test_acc),  # student, on test_files[0]
    'teacher_test_accuracy': float(teacher_test_acc),  # teacher, on test_files[0]
}

with open('metadata.json', 'w') as f:
    json.dump(metadata, f, indent=2)

print("‚úÖ All models saved!")
print(f"\nüéâ SUCCESS!")
print(f"   Compression: {teacher_params/student_params:.1f}x smaller")
print(f"   Performance: {test_acc:.4f}")
print("=" * 80)

üéÆ GPU Configuration
‚ö†Ô∏è  No GPU detected

üì• Downloading Dataset...
Using Colab cache for faster access to the 'cic-iot-2023' dataset.
üìÇ Found 169 CSV files
üìä Split: 101 train, 34 val, 34 test files

üè∑Ô∏è  Fitting Preprocessing (Memory-Safe)...
Sampled 1000 rows from part-00000-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv
Sampled 1000 rows from part-00001-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv
Sampled 1000 rows from part-00002-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv
‚úÖ Found 28 classes
‚úÖ Scaler fitted
‚úÖ PCA fitted with 20 components
üíæ RAM Usage: 652 MB

üéì Training Teacher Model (Streaming)...
Teacher: 38,492 parameters

TEACHER EPOCH 1/5

üìÇ Training on file 1/2: part-00000-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv
Error processing chunk: y contains previously unseen labels: np.str_('CommandInjection')
Error processing chunk: y contains previously unseen labels: np.str_('Backdoor_Malware')
Error processing chunk: y contains previously uns

In [None]:
#!/usr/bin/env python3
"""
Knowledge Distillation for IoT Intrusion Detection - Full File Processing
- Reduced model sizes for memory efficiency
- Process entire files at once (no chunking within files)
- Handle all 169 files properly
- Stream one file at a time to avoid RAM overflow
"""

import os
import gc
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import IncrementalPCA
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import kagglehub
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import json
import warnings
warnings.filterwarnings('ignore')

# ==========================================================
# üéÆ GPU CONFIGURATION
# ==========================================================

def setup_gpu():
    """Select the torch device for this run and print a hardware banner.

    With CUDA available:
    - reports device 0's name, the CUDA version and total GPU memory;
    - enables the cuDNN autotuner;
    - returns ``torch.device('cuda')``.

    Otherwise returns ``torch.device('cpu')``.
    """
    print("=" * 80)
    print("üéÆ GPU Configuration")
    print("=" * 80)

    has_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if has_cuda else 'cpu')

    if has_cuda:
        print(f"‚úÖ GPU detected: {torch.cuda.get_device_name(0)}")
        print(f"‚úÖ CUDA Version: {torch.version.cuda}")
        print(f"‚úÖ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
        # Let cuDNN benchmark and cache the fastest kernels per input shape.
        torch.backends.cudnn.benchmark = True
        print("‚úÖ cuDNN autotuner enabled")
    else:
        print("‚ö†Ô∏è  No GPU detected, running on CPU")

    print("=" * 80 + "\n")
    return device

device = setup_gpu()

# ==========================================================
# üßπ MEMORY MANAGEMENT
# ==========================================================

def clear_memory():
    """Run Python's garbage collector, then release cached CUDA memory if a GPU exists."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

def print_memory_stats():
    """Print this process's RAM and, on CUDA machines, allocated GPU memory.

    RAM is the resident set size reported by ``psutil``. It is best-effort:
    if psutil is unavailable or the query fails, only the GPU part (or just
    a newline) is printed.
    """
    try:
        import psutil
        ram_gb = psutil.Process().memory_info().rss / 1e9
        print(f"üíæ RAM Usage: {ram_gb:.2f} GB", end="")
    except Exception:
        # Deliberately best-effort. Unlike the former bare ``except:`` this
        # no longer swallows KeyboardInterrupt/SystemExit mid-training.
        pass

    if torch.cuda.is_available():
        gpu_gb = torch.cuda.memory_allocated() / 1e9
        print(f" | GPU: {gpu_gb:.2f} GB")
    else:
        print()

# ==========================================================
# üßπ HELPER FUNCTIONS
# ==========================================================

def load_and_clean(path, label_col=None):
    """Read a whole CSV and split it into a feature frame and a label series.

    Rows containing any missing value are removed, then exact duplicate rows.
    Without ``label_col``, the column named "Label" is used if present,
    otherwise the last column.

    Returns:
        ``(X, y)``: the feature DataFrame (every column except the label)
        and the label Series.
    """
    frame = pd.read_csv(path, low_memory=False).dropna().drop_duplicates()

    if label_col is None:
        label_col = "Label" if "Label" in frame.columns else frame.columns[-1]

    features, labels = frame.drop(columns=[label_col]), frame[label_col]

    # Drop our reference to the full frame before returning.
    del frame
    gc.collect()

    return features, labels

def encode_objects(X):
    """Label-encode non-numeric columns of ``X`` in place and return a float32 array.

    Columns with object, pandas ``string`` or ``category`` dtype are cast to
    str and integer-encoded with a fresh LabelEncoder, so codes are per call,
    not shared across files. Including ``string``/``category`` prevents the
    final float32 cast from failing on those dtypes. A column that cannot be
    encoded is replaced by zeros, with a printed warning (previously a bare
    ``except:`` hid this silently).

    Args:
        X: pandas DataFrame of features. NOTE: it is modified in place.

    Returns:
        numpy float32 array with the same shape as ``X``.
    """
    non_numeric = X.select_dtypes(include=["object", "string", "category"]).columns
    for col in non_numeric:
        try:
            X[col] = LabelEncoder().fit_transform(X[col].astype(str))
        except Exception as exc:
            print(f"Warning: could not encode column {col!r} ({exc}); using 0")
            X[col] = 0
    return X.values.astype(np.float32)

def load_and_process_file(filepath, scaler, pca, label_encoder):
    """Load one CSV completely and turn it into model-ready arrays.

    Steps:
    1. Clean and split the file (``load_and_clean``).
    2. Drop rows whose label was not seen when ``label_encoder`` was fitted.
    3. Encode object columns.
    4. Scale, project with PCA, and encode the labels.

    Returns:
        ``(X_reduced, y_encoded)`` on success, or ``(None, None)`` if the file
        fails to process or has no rows with known labels (a message is
        printed in both cases).
    """
    try:
        X, y = load_and_clean(filepath)
        y = y.astype(str)

        # LabelEncoder.transform() raises on ANY unseen label, which used to
        # discard the entire file. Drop just the offending rows instead.
        known = y.isin(label_encoder.classes_)
        n_unseen = int((~known).sum())
        if n_unseen:
            examples = sorted(y[~known].unique())[:5]
            print(f"Warning: {os.path.basename(filepath)}: dropping {n_unseen} rows "
                  f"with labels unseen during fitting (e.g. {examples})")
            X, y = X.loc[known], y.loc[known]
        if len(y) == 0:
            print(f"Warning: {os.path.basename(filepath)}: no rows with known labels, skipping")
            return None, None

        X = encode_objects(X)

        X_scaled = scaler.transform(X)
        X_reduced = pca.transform(X_scaled)
        y_encoded = label_encoder.transform(y)

        del X, y, X_scaled
        gc.collect()

        return X_reduced, y_encoded
    except Exception as e:
        print(f"‚ùå Error processing {os.path.basename(filepath)}: {e}")
        return None, None

# ==========================================================
# üì¶ FULL FILE DATASET
# ==========================================================

class FullFileDataset(Dataset):
    """Map-style dataset over one fully loaded file.

    ``X`` (features) and ``y`` (encoded labels) are copied once into float32
    and int64 tensors, so ``__getitem__`` is just a tensor index.
    """

    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx):
        features, label = self.X[idx], self.y[idx]
        return features, label

# ==========================================================
# üéì REDUCED PYTORCH MODELS
# ==========================================================

class TeacherLSTM(nn.Module):
    """Teacher Model - two stacked LSTMs plus a 64 -> 32 ReLU MLP head.

    ``hidden_sizes[0]`` and ``hidden_sizes[1]`` are the widths of the two
    LSTM layers (e.g. [128, 64], reduced from [256, 128, 64]). Returns raw
    class logits computed from the last time step.
    """

    def __init__(self, input_size, hidden_sizes, num_classes, dropout=0.3):
        super().__init__()
        first, second = hidden_sizes[0], hidden_sizes[1]

        # Attribute names are state_dict keys and creation order fixes the
        # random-init sequence, so both mirror the original definition.
        self.lstm1 = nn.LSTM(input_size, first, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.lstm2 = nn.LSTM(first, second, batch_first=True)
        self.dropout2 = nn.Dropout(dropout)
        self.fc1 = nn.Linear(second, 64)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(64, 32)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(32, num_classes)

    def forward(self, x):
        # batch_first=True: x is (batch, seq_len, features).
        h, _ = self.lstm1(x)
        h, _ = self.lstm2(self.dropout1(h))
        h = self.dropout2(h)[:, -1, :]  # keep only the final time step
        h = self.relu1(self.fc1(h))
        h = self.relu2(self.fc2(h))
        return self.fc3(h)


class StudentLSTM(nn.Module):
    """Student Model - a single LSTM layer (reduced from [32, 16]) and a 32-unit head.

    Emits raw logits of shape (batch, num_classes) from the last time step.
    """

    def __init__(self, input_size, hidden_size, num_classes, dropout=0.2):
        super().__init__()
        # Names match saved checkpoints' state_dict keys.
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_size, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x):
        sequence, _state = self.lstm(x)
        # Dropout over the full sequence, then take the final step.
        final = self.dropout(sequence)[:, -1, :]
        return self.fc2(self.relu(self.fc1(final)))

# ==========================================================
# üéì KNOWLEDGE DISTILLATION LOSS
# ==========================================================

class DistillationLoss(nn.Module):
    """Combined loss for knowledge distillation.

    total = alpha * T**2 * KL(softmax(teacher / T) || softmax(student / T))
            + (1 - alpha) * CrossEntropy(student, labels)
    """

    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, labels):
        temp = self.temperature

        # Soft targets: KLDivLoss wants log-probs as input, probs as target.
        log_p_student = torch.log_softmax(student_logits / temp, dim=1)
        p_teacher = torch.softmax(teacher_logits / temp, dim=1)
        soft_loss = self.kl_loss(log_p_student, p_teacher) * (temp ** 2)

        # Hard targets: ordinary cross-entropy with the true labels.
        hard_loss = self.ce_loss(student_logits, labels)

        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

# ==========================================================
# üèãÔ∏è TRAINING FUNCTIONS (FULL FILE AT ONCE)
# ==========================================================

def train_on_file(model, filepath, scaler, pca, label_encoder, optimizer,
                  criterion, device, batch_size=512, is_distillation=False,
                  teacher_model=None):
    """Train ``model`` for one shuffled pass over a single, fully loaded file.

    Args:
        criterion: ``criterion(outputs, labels)``, or, when distilling,
            ``criterion(student_logits, teacher_logits, labels)``.
        is_distillation: use ``teacher_model`` logits as soft targets.
        teacher_model: frozen teacher (put in eval mode, run under no_grad).

    Returns:
        Mean per-sample training loss, or 0 if the file could not be loaded.
    """
    # Load and process entire file
    X_file, y_file = load_and_process_file(filepath, scaler, pca, label_encoder)

    if X_file is None:
        return 0

    dataset = FullFileDataset(X_file, y_file)
    # The dataset now holds its own tensors; release the numpy arrays early
    # to lower peak RAM.
    del X_file, y_file
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    model.train()
    if teacher_model is not None:
        teacher_model.eval()
    use_teacher = is_distillation and teacher_model is not None

    total_loss = 0
    total_samples = 0

    for X_batch, y_batch in dataloader:
        X_batch = X_batch.unsqueeze(1).to(device)  # add a length-1 sequence dimension
        y_batch = y_batch.to(device)

        optimizer.zero_grad()
        outputs = model(X_batch)

        if use_teacher:
            with torch.no_grad():
                teacher_outputs = teacher_model(X_batch)
            loss = criterion(outputs, teacher_outputs, y_batch)
        else:
            loss = criterion(outputs, y_batch)

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * len(y_batch)
        total_samples += len(y_batch)

    # Release per-file memory once. The former per-batch clear_memory()
    # (gc.collect + torch.cuda.empty_cache) stalled every step without
    # lowering peak memory.
    del dataset, dataloader
    clear_memory()

    return total_loss / total_samples if total_samples > 0 else 0


def evaluate_on_files(model, file_list, scaler, pca, label_encoder,
                      criterion, device, batch_size=512):
    """Evaluate ``model`` over several files, loading one file at a time.

    Files that fail to load are skipped (``load_and_process_file`` prints why).

    Returns:
        ``(avg_loss, accuracy)`` over every scored row, both 0 if no file
        produced any rows.
    """
    model.eval()
    total_loss = 0
    correct = 0
    total_samples = 0

    for filepath in file_list:
        X_file, y_file = load_and_process_file(filepath, scaler, pca, label_encoder)

        if X_file is None:
            continue

        dataset = FullFileDataset(X_file, y_file)
        del X_file, y_file  # dataset holds its own tensors
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

        with torch.no_grad():
            for X_batch, y_batch in dataloader:
                X_batch = X_batch.unsqueeze(1).to(device)
                y_batch = y_batch.to(device)

                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)

                _, predicted = torch.max(outputs, 1)
                correct += (predicted == y_batch).sum().item()
                total_loss += loss.item() * len(y_batch)
                total_samples += len(y_batch)

        # Clean up once per file. Per-batch clear_memory() only slowed
        # evaluation down.
        del dataset, dataloader
        clear_memory()

    accuracy = correct / total_samples if total_samples > 0 else 0
    avg_loss = total_loss / total_samples if total_samples > 0 else 0

    return avg_loss, accuracy

# ==========================================================
# üìÇ DOWNLOAD & SPLIT DATASET (169 FILES)
# ==========================================================

print("=" * 80)
print("üì• Downloading CIC-IoT-2023 Dataset from Kaggle...")
print("=" * 80)

dataset_dir = kagglehub.dataset_download("akashdogra/cic-iot-2023")
print(f"‚úÖ Dataset downloaded to: {dataset_dir}")

csv_files = sorted(
    os.path.join(dataset_dir, name)
    for name in os.listdir(dataset_dir)
    if name.endswith(".csv")
)

print(f"üìÇ Found {len(csv_files)} CSV files.")

# 60-20-20 split by file (sorted order): train | validation | test
n_files = len(csv_files)
train_idx, val_idx = int(n_files * 0.60), int(n_files * 0.80)

train_files, val_files, test_files = (
    csv_files[:train_idx],
    csv_files[train_idx:val_idx],
    csv_files[val_idx:],
)

print(f"\nüìä Dataset Split (from {n_files} files):")
print(f"   Training:   {len(train_files)} files")
print(f"   Validation: {len(val_files)} files")
print(f"   Testing:    {len(test_files)} files")

# ==========================================================
# üè∑Ô∏è FIT PREPROCESSING (SAMPLING FROM MULTIPLE FILES)
# ==========================================================

print("\n" + "=" * 80)
print("üè∑Ô∏è  Fitting Preprocessing...")
print("=" * 80)

# ---- Label vocabulary: label column of EVERY file --------------------------
# Fitting the LabelEncoder on 1000-row samples of 5 files misses rarer attack
# classes. label_encoder.transform() then raises "y contains previously unseen
# labels" and load_and_process_file() throws away the whole file.
# Reading only the label column costs one extra parse per file but yields the
# complete class list. Only class names are collected (no feature statistics),
# so no val/test feature information reaches the scaler, PCA or models.
print(f"Scanning labels in {len(csv_files)} files...")
label_vocab = set()
for filepath in csv_files:
    header = pd.read_csv(filepath, nrows=0).columns
    file_label_col = "Label" if "Label" in header else header[-1]
    for label_chunk in pd.read_csv(filepath, usecols=[file_label_col], chunksize=500_000):
        label_vocab.update(label_chunk[file_label_col].dropna().astype(str).unique())

# ---- Feature samples for the scaler / PCA (first 5 training files) ---------
sample_files = train_files[:5]
all_labels = []
sample_data = []

for i, filepath in enumerate(sample_files):
    print(f"Sampling file {i+1}/{len(sample_files)}: {os.path.basename(filepath)}")

    # Load small sample
    df = pd.read_csv(filepath, nrows=1000, low_memory=False)
    df = df.dropna()

    label_col = "Label" if "Label" in df.columns else df.columns[-1]
    X = df.drop(columns=[label_col])
    y = df[label_col]

    all_labels.extend(list(y.astype(str)))

    # Encode objects
    for col in X.select_dtypes(include=["object"]).columns:
        try:
            X[col] = LabelEncoder().fit_transform(X[col].astype(str))
        except Exception:
            # Was a bare ``except:``; still falls back to a zero column.
            X[col] = 0

    sample_data.append(X.values.astype(np.float32))

    del df, X, y
    gc.collect()

# Fit label encoder on the full vocabulary
label_encoder = LabelEncoder()
label_encoder.fit(sorted(label_vocab.union(all_labels)))
n_classes = len(label_encoder.classes_)

print(f"‚úÖ LabelEncoder fitted with {n_classes} classes")

# Fit scaler
scaler = StandardScaler()
for data in sample_data:
    scaler.partial_fit(data)

print(f"‚úÖ Scaler fitted")

# Fit PCA
n_features = sample_data[0].shape[1]
n_components = min(30, n_features)

pca = IncrementalPCA(n_components=n_components)
for data in sample_data:
    X_scaled = scaler.transform(data)
    pca.partial_fit(X_scaled)

print(f"‚úÖ PCA fitted with {n_components} components (from {n_features} features)")

del all_labels, sample_data, label_vocab
clear_memory()
print_memory_stats()

# ==========================================================
# üéì STAGE 1: TRAIN TEACHER MODEL
# ==========================================================

print("\n" + "=" * 80)
print("üéì STAGE 1: Training Teacher Model")
print("=" * 80)

# Initialize teacher model with REDUCED sizes
teacher_model = TeacherLSTM(
    input_size=n_components,
    hidden_sizes=[128, 64],  # Reduced from [256, 128, 64]
    num_classes=n_classes,
    dropout=0.3
).to(device)

teacher_params = sum(p.numel() for p in teacher_model.parameters())
print(f"\nüèóÔ∏è  Teacher Model: {teacher_params:,} parameters")
print(f"   Architecture: Input({n_components}) ‚Üí LSTM(128) ‚Üí LSTM(64) ‚Üí FC(64) ‚Üí FC(32) ‚Üí Output({n_classes})")

# Optimizer and criterion
teacher_optimizer = optim.Adam(teacher_model.parameters(), lr=0.001)
teacher_criterion = nn.CrossEntropyLoss()

# Training settings
epochs_teacher = 3  # Train over all files 3 times
batch_size = 512  # Large batch size allowed
files_per_epoch = 5  # Process 5 files per epoch cycle

best_teacher_acc = 0
patience_counter = 0
patience = 3

print("\nüöÄ Training Teacher Model...")
print(f"   Batch Size: {batch_size}")
print(f"   Files per Epoch Cycle: {files_per_epoch}")

# Teacher training loop: each epoch trains on a rotating window of
# `files_per_epoch` training files, validates on the first 3 validation files,
# checkpoints on improved validation accuracy, and applies patience-based early
# stopping.
for epoch in range(epochs_teacher):
    print(f"\n{'='*80}")
    print(f"TEACHER EPOCH {epoch+1}/{epochs_teacher}")
    print(f"{'='*80}")

    # Select rotating files
    # The window starts at epoch * files_per_epoch (mod number of files); when it
    # runs off the end of train_files it is topped up from the start of the list.
    start_idx = (epoch * files_per_epoch) % len(train_files)
    end_idx = min(start_idx + files_per_epoch, len(train_files))
    # Slicing yields a new list, so the += below does not mutate train_files.
    selected_files = train_files[start_idx:end_idx]

    if len(selected_files) < files_per_epoch and len(train_files) > files_per_epoch:
        remaining = files_per_epoch - len(selected_files)
        selected_files += train_files[:remaining]

    # end_idx is exclusive and does not reflect wrapped-around files.
    print(f"Training on {len(selected_files)} files (indices {start_idx} to {end_idx})")

    # Train on each file
    epoch_losses = []
    for i, filepath in enumerate(selected_files):
        print(f"\n  üìÇ File {i+1}/{len(selected_files)}: {os.path.basename(filepath)}")

        train_loss = train_on_file(
            teacher_model, filepath, scaler, pca, label_encoder,
            teacher_optimizer, teacher_criterion, device, batch_size=batch_size
        )

        # NOTE(review): if train_on_file reports a failed file as loss 0 (as the
        # later cell's version does), that 0 is averaged in below.
        epoch_losses.append(train_loss)
        print(f"     Loss: {train_loss:.4f}")
        print_memory_stats()

    avg_train_loss = np.mean(epoch_losses)

    # Validate on subset of validation files
    # Only the first 3 validation files are used to keep validation cheap.
    print(f"\n  üìä Validating...")
    val_loss, val_acc = evaluate_on_files(
        teacher_model, val_files[:3], scaler, pca, label_encoder,
        teacher_criterion, device, batch_size=batch_size
    )

    print(f"\n  üìà Epoch Summary:")
    print(f"     Avg Train Loss: {avg_train_loss:.4f}")
    print(f"     Val Loss: {val_loss:.4f}")
    print(f"     Val Accuracy: {val_acc:.4f}")

    # Save best model
    # Strict ">" — nothing is saved unless val_acc beats best_teacher_acc.
    if val_acc > best_teacher_acc:
        best_teacher_acc = val_acc
        torch.save(teacher_model.state_dict(), 'teacher_model.pth')
        print(f"  ‚úÖ Best teacher model saved! Val Acc: {val_acc:.4f}")
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"\n‚ö†Ô∏è  Early stopping triggered at epoch {epoch+1}")
        break

    clear_memory()

print("\n‚úÖ Teacher Model Training Complete!")
print(f"   Best Validation Accuracy: {best_teacher_acc:.4f}")

# Restore the best checkpoint when one exists; otherwise keep the in-memory
# final-epoch weights instead of crashing with FileNotFoundError (which is what
# the recorded run of this cell did). map_location keeps the load working when
# the checkpoint was written on a different device.
# NOTE(review): if this run never saved, a stale 'teacher_model.pth' from an
# earlier run would be picked up here.
if os.path.exists('teacher_model.pth'):
    teacher_model.load_state_dict(torch.load('teacher_model.pth', map_location=device))
    print("   Loaded best teacher model from disk")
else:
    print("   ‚ö†Ô∏è  No saved model found, using final epoch weights")

# ==========================================================
# üéí STAGE 2: KNOWLEDGE DISTILLATION - TRAIN STUDENT
# ==========================================================

print("\n" + "=" * 80)
print("üéí STAGE 2: Knowledge Distillation - Training Student Model")
print("=" * 80)

# Student network: a single 32-unit LSTM layer (previously [32, 16]).
student_model = StudentLSTM(
    input_size=n_components,
    hidden_size=32,
    num_classes=n_classes,
    dropout=0.2,
).to(device)

student_params = sum(tensor.numel() for tensor in student_model.parameters())
# How many times more parameters the teacher has than the student.
reduction_ratio = teacher_params / student_params

print(f"\nüèóÔ∏è  Student Model: {student_params:,} parameters")
print(f"   Architecture: Input({n_components}) ‚Üí LSTM(32) ‚Üí FC(32) ‚Üí Output({n_classes})")
print(f"\nüìä Model Comparison:")
print(f"   Teacher Parameters: {teacher_params:,}")
print(f"   Student Parameters: {student_params:,}")
print(f"   Size Reduction:     {reduction_ratio:.1f}x smaller")

# Optimizer and distillation loss
student_optimizer = optim.Adam(student_model.parameters(), lr=0.001)
distillation_criterion = DistillationLoss(temperature=4.0, alpha=0.7)

epochs_student = 4
# Start below any reachable accuracy so the first epoch always writes
# 'student_model.pth'. With 0 and the strict ">" used in the loop below, a run
# whose validation accuracy stays at 0.0 never saves and the final
# torch.load('student_model.pth') raises FileNotFoundError.
best_student_acc = float('-inf')
# `patience` is reused from the teacher settings.
patience_counter = 0

print(f"\nüöÄ Training Student with Knowledge Distillation...")
print(f"   Temperature: {distillation_criterion.temperature}")
print(f"   Alpha (soft target weight): {distillation_criterion.alpha}")
print(f"   Batch Size: {batch_size}")

# Student distillation loop: same rotating file windows as the teacher loop,
# but each file is trained with the distillation criterion and the teacher
# model is passed along. Validation uses plain cross-entropy on the first 3
# validation files, so its loss is comparable to the teacher's.
for epoch in range(epochs_student):
    print(f"\n{'='*80}")
    print(f"STUDENT EPOCH {epoch+1}/{epochs_student}")
    print(f"{'='*80}")

    # Select rotating files
    # Window of files_per_epoch files starting at epoch * files_per_epoch (mod
    # the number of files), topped up from the start of train_files if short.
    start_idx = (epoch * files_per_epoch) % len(train_files)
    end_idx = min(start_idx + files_per_epoch, len(train_files))
    selected_files = train_files[start_idx:end_idx]

    if len(selected_files) < files_per_epoch and len(train_files) > files_per_epoch:
        remaining = files_per_epoch - len(selected_files)
        selected_files += train_files[:remaining]

    print(f"Training on {len(selected_files)} files (indices {start_idx} to {end_idx})")

    # Train with distillation
    epoch_losses = []
    for i, filepath in enumerate(selected_files):
        print(f"\n  üìÇ File {i+1}/{len(selected_files)}: {os.path.basename(filepath)}")

        # is_distillation/teacher_model switch train_on_file to the
        # (student_logits, teacher_logits, labels) criterion form — presumably,
        # matching the later cell's train_on_file; verify against this cell's one.
        train_loss = train_on_file(
            student_model, filepath, scaler, pca, label_encoder,
            student_optimizer, distillation_criterion, device,
            batch_size=batch_size, is_distillation=True, teacher_model=teacher_model
        )

        epoch_losses.append(train_loss)
        print(f"     Loss: {train_loss:.4f}")
        print_memory_stats()

    avg_train_loss = np.mean(epoch_losses)

    # Validate
    # Hard-label cross-entropy, because the distillation loss needs teacher logits.
    print(f"\n  üìä Validating...")
    val_criterion = nn.CrossEntropyLoss()
    val_loss, val_acc = evaluate_on_files(
        student_model, val_files[:3], scaler, pca, label_encoder,
        val_criterion, device, batch_size=batch_size
    )

    print(f"\n  üìà Epoch Summary:")
    print(f"     Avg Train Loss: {avg_train_loss:.4f}")
    print(f"     Val Loss: {val_loss:.4f}")
    print(f"     Val Accuracy: {val_acc:.4f}")

    # Save best model
    # Strict ">" — nothing is saved unless val_acc beats best_student_acc.
    if val_acc > best_student_acc:
        best_student_acc = val_acc
        torch.save(student_model.state_dict(), 'student_model.pth')
        print(f"  ‚úÖ Best student model saved! Val Acc: {val_acc:.4f}")
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"\n‚ö†Ô∏è  Early stopping triggered at epoch {epoch+1}")
        break

    clear_memory()

print("\n‚úÖ Student Model Training Complete!")
print(f"   Best Validation Accuracy: {best_student_acc:.4f}")

# Restore the best checkpoint when one exists; otherwise keep the in-memory
# final-epoch weights rather than crashing with FileNotFoundError (the
# equivalent unguarded teacher load crashed this way in the recorded run).
# map_location keeps the load working across devices.
# NOTE(review): if this run never saved, a stale 'student_model.pth' from an
# earlier run would be picked up here.
if os.path.exists('student_model.pth'):
    student_model.load_state_dict(torch.load('student_model.pth', map_location=device))
    print("   Loaded best student model from disk")
else:
    print("   ‚ö†Ô∏è  No saved model found, using final epoch weights")

# ==========================================================
# üìà STAGE 3: FINAL EVALUATION
# ==========================================================

print("\n" + "=" * 80)
print("üìà STAGE 3: Final Evaluation on Test Set")
print("=" * 80)

def evaluate_model_detailed(model, model_name, file_list, batch_size=512):
    """Evaluate ``model`` on every file in ``file_list`` and print weighted metrics.

    Files are loaded and preprocessed one at a time with the module-level
    ``scaler``, ``pca`` and ``label_encoder``, and batches are moved to the
    module-level ``device`` (these are globals, not parameters). Files that fail
    to load are skipped.

    Args:
        model: classifier returning logits of shape (batch, num_classes) for
            inputs of shape (batch, 1, n_components).
        model_name: label used in the printed report.
        file_list: CSV paths to evaluate.
        batch_size: inference batch size (default 512, as before).

    Returns:
        ``(y_true, y_pred, accuracy, precision, recall, f1)``. ``y_true`` and
        ``y_pred`` are 1-D arrays of encoded labels. Precision, recall and F1 are
        support-weighted averages. If no sample could be evaluated, empty arrays
        and 0.0 metrics are returned.
    """
    print(f"\n{'='*60}")
    print(f"Evaluating {model_name}...")
    print(f"{'='*60}")

    model.eval()
    # Collect one array per batch and concatenate once at the end; extending a
    # list element by element creates a Python object per sample.
    true_chunks = []
    pred_chunks = []

    for i, filepath in enumerate(file_list):
        print(f"Processing file {i+1}/{len(file_list)}: {os.path.basename(filepath)}")

        X_file, y_file = load_and_process_file(filepath, scaler, pca, label_encoder)

        if X_file is None:
            continue

        dataset = FullFileDataset(X_file, y_file)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

        with torch.no_grad():
            for X_batch, y_batch in dataloader:
                # Add the length-1 sequence axis expected by the LSTM models.
                outputs = model(X_batch.unsqueeze(1).to(device))
                _, predicted = torch.max(outputs, 1)

                true_chunks.append(y_batch.numpy())
                pred_chunks.append(predicted.cpu().numpy())

        # Free the file's data once per file. clear_memory() used to run after
        # every batch; gc.collect() plus torch.cuda.empty_cache() per batch is
        # very slow and only hands cached GPU blocks back to be re-allocated.
        del X_file, y_file, dataset, dataloader
        clear_memory()

    if not true_chunks:
        # sklearn metrics on empty inputs are undefined; report zeros instead.
        print(f"‚ö†Ô∏è  No samples could be evaluated for {model_name}")
        empty = np.array([], dtype=np.int64)
        return empty, empty.copy(), 0.0, 0.0, 0.0, 0.0

    y_true_all = np.concatenate(true_chunks)
    y_pred_all = np.concatenate(pred_chunks)

    accuracy = accuracy_score(y_true_all, y_pred_all)
    precision = precision_score(y_true_all, y_pred_all, average='weighted', zero_division=0)
    recall = recall_score(y_true_all, y_pred_all, average='weighted', zero_division=0)
    f1 = f1_score(y_true_all, y_pred_all, average='weighted', zero_division=0)

    print(f"\nüìä {model_name} Performance:")
    print(f"   Accuracy:  {accuracy:.4f}")
    print(f"   Precision: {precision:.4f}")
    print(f"   Recall:    {recall:.4f}")
    print(f"   F1-Score:  {f1:.4f}")

    return y_true_all, y_pred_all, accuracy, precision, recall, f1

# Evaluate both models on the held-out test files (teacher first, then student).
teacher_results = evaluate_model_detailed(teacher_model, "TEACHER MODEL", test_files)
student_results = evaluate_model_detailed(student_model, "STUDENT MODEL (Distilled)", test_files)

# ==========================================================
# üìä GENERATE REPORTS
# ==========================================================

print("\n" + "=" * 80)
print("üìä Generating Final Report...")
print("=" * 80)

# Only the student's label arrays are kept: the confusion matrix below is
# drawn for the student.
(y_true, y_pred,
 s_acc, s_prec, s_rec, s_f1) = student_results
(_, _,
 t_acc, t_prec, t_rec, t_f1) = teacher_results

# Confusion matrix
# labels= pins the matrix to every encoded class (0 .. n_classes-1) so its
# rows/columns line up with label_encoder.classes_, which is used as the tick
# labels below. Without it, classes absent from both y_true and y_pred are
# dropped, and the matrix no longer matches the tick labels.
cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(label_encoder.classes_)))

plt.figure(figsize=(20, 16))
sns.heatmap(cm, annot=True, fmt='d', cmap='Greens',
            xticklabels=label_encoder.classes_,
            yticklabels=label_encoder.classes_,
            cbar_kws={'label': 'Count'})
plt.title('Student Model Confusion Matrix (Knowledge Distillation)', fontsize=16, pad=20)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.xticks(rotation=45, ha='right', fontsize=8)
plt.yticks(rotation=0, fontsize=8)
plt.tight_layout()
plt.savefig('student_confusion_matrix.png', dpi=300, bbox_inches='tight')
# Release the large figure; under the non-interactive Agg backend selected at
# the top of the file it is never displayed anyway.
plt.close()
print("‚úÖ Confusion matrix saved as 'student_confusion_matrix.png'")

# Performance comparison
# Student accuracy as a percentage of teacher accuracy; 0 when the teacher
# scored 0, which avoids division by zero.
performance_retention = (s_acc / t_acc) * 100 if t_acc > 0 else 0

print("\n" + "=" * 80)
print("üìä FINAL COMPARISON: TEACHER vs STUDENT")
print("=" * 80)
# Fixed-width table: each column left-aligned in 15 characters.
print(f"\n{'Metric':<15} {'Teacher':<15} {'Student':<15} {'Difference':<15}")
print("=" * 80)
# "Difference" is student minus teacher (negative = student is worse).
print(f"{'Accuracy':<15} {t_acc:<15.4f} {s_acc:<15.4f} {(s_acc-t_acc):<15.4f}")
print(f"{'Precision':<15} {t_prec:<15.4f} {s_prec:<15.4f} {(s_prec-t_prec):<15.4f}")
print(f"{'Recall':<15} {t_rec:<15.4f} {s_rec:<15.4f} {(s_rec-t_rec):<15.4f}")
print(f"{'F1-Score':<15} {t_f1:<15.4f} {s_f1:<15.4f} {(s_f1-t_f1):<15.4f}")
print(f"{'Parameters':<15} {teacher_params:<15,} {student_params:<15,} {'-':<15}")
# Student size relative to the teacher (teacher = 1.0x).
print(f"{'Model Size':<15} {'1.0x':<15} {f'{1/reduction_ratio:.2f}x':<15} {f'{reduction_ratio:.1f}x smaller':<15}")
print("=" * 80)

print(f"\nüéØ Performance Retention: {performance_retention:.2f}%")
print(f"üéØ Model Size Reduction: {reduction_ratio:.1f}x smaller")
print(f"üéØ Parameter Reduction: {((teacher_params - student_params) / teacher_params * 100):.1f}% fewer parameters")

# ==========================================================
# üíæ SAVE MODELS AND PREPROCESSING OBJECTS
# ==========================================================

print("\n" + "=" * 80)
print("üíæ Saving Models and Preprocessing Objects")
print("=" * 80)

# Save PyTorch models
# Each checkpoint bundles the state_dict with the constructor arguments needed
# to rebuild the model, plus its test accuracy and parameter count.
# Keep 'hidden_sizes' in sync with the TeacherLSTM constructor call.
torch.save({
    'model_state_dict': teacher_model.state_dict(),
    'input_size': n_components,
    'hidden_sizes': [128, 64],
    'num_classes': n_classes,
    'accuracy': t_acc,
    'params': teacher_params
}, 'teacher_model_complete.pth')
print("‚úÖ Saved: teacher_model_complete.pth")

# Keep 'hidden_size' in sync with the StudentLSTM constructor call.
torch.save({
    'model_state_dict': student_model.state_dict(),
    'input_size': n_components,
    'hidden_size': 32,
    'num_classes': n_classes,
    'accuracy': s_acc,
    'params': student_params
}, 'student_model_complete.pth')
print("‚úÖ Saved: student_model_complete.pth")

# Save preprocessing objects
# Inference must apply the same scaler -> PCA transform and label encoding.
# NOTE: unpickling runs code embedded in the file; only load preprocessing.pkl
# from a trusted source.
preprocessing_objects = {
    'scaler': scaler,
    'pca': pca,
    'label_encoder': label_encoder
}

with open('preprocessing.pkl', 'wb') as f:
    pickle.dump(preprocessing_objects, f)
print("‚úÖ Saved: preprocessing.pkl")

# Save metadata
# Explicit int()/float() casts turn numpy scalars into JSON-serialisable types.
metadata = {
    'n_classes': int(n_classes),
    'n_features': int(n_features),
    'n_components': int(n_components),
    'teacher_params': int(teacher_params),
    'student_params': int(student_params),
    'teacher_accuracy': float(t_acc),
    'student_accuracy': float(s_acc),
    'size_reduction': float(reduction_ratio),
    'performance_retention': float(performance_retention),
    'total_files': len(csv_files),
    'train_files': len(train_files),
    'val_files': len(val_files),
    'test_files': len(test_files),
    'classes': label_encoder.classes_.tolist()
}

with open('model_metadata.json', 'w') as f:
    json.dump(metadata, f, indent=4)
print("‚úÖ Saved: model_metadata.json")

# Create summary
# Human-readable text report mirroring the metrics printed above.
with open('model_summary.txt', 'w') as f:
    f.write("=" * 80 + "\n")
    f.write("KNOWLEDGE DISTILLATION - PYTORCH MODEL SUMMARY\n")
    f.write("=" * 80 + "\n\n")

    # These are split sizes, not the number of files actually trained on.
    f.write("DATASET INFORMATION:\n")
    f.write(f"  Total Files: {len(csv_files)}\n")
    f.write(f"  Training Files: {len(train_files)}\n")
    f.write(f"  Validation Files: {len(val_files)}\n")
    f.write(f"  Test Files: {len(test_files)}\n\n")

    f.write("TEACHER MODEL:\n")
    f.write(f"  Architecture: LSTM [128, 64]\n")
    f.write(f"  Parameters: {teacher_params:,}\n")
    f.write(f"  Accuracy: {t_acc:.4f}\n")
    f.write(f"  Precision: {t_prec:.4f}\n")
    f.write(f"  Recall: {t_rec:.4f}\n")
    f.write(f"  F1-Score: {t_f1:.4f}\n\n")

    f.write("STUDENT MODEL (DISTILLED):\n")
    f.write(f"  Architecture: LSTM [32]\n")
    f.write(f"  Parameters: {student_params:,}\n")
    f.write(f"  Accuracy: {s_acc:.4f}\n")
    f.write(f"  Precision: {s_prec:.4f}\n")
    f.write(f"  Recall: {s_rec:.4f}\n")
    f.write(f"  F1-Score: {s_f1:.4f}\n\n")

    f.write("COMPRESSION METRICS:\n")
    f.write(f"  Size Reduction: {reduction_ratio:.1f}x smaller\n")
    f.write(f"  Performance Retention: {performance_retention:.2f}%\n")
    f.write(f"  Parameter Reduction: {((teacher_params - student_params) / teacher_params * 100):.1f}%\n\n")

    f.write("FILES GENERATED:\n")
    f.write("  - teacher_model_complete.pth (Teacher model with metadata)\n")
    f.write("  - student_model_complete.pth (Student model with metadata)\n")
    f.write("  - preprocessing.pkl (Scaler, PCA, Label Encoder)\n")
    f.write("  - model_metadata.json (Model specifications)\n")
    f.write("  - student_confusion_matrix.png (Confusion matrix visualization)\n")

print("‚úÖ Saved: model_summary.txt")

print("\n" + "=" * 80)
print("üéâ KNOWLEDGE DISTILLATION COMPLETE!")
print("=" * 80)
# NOTE(review): this message overstates coverage. Training only visits
# files_per_epoch-sized windows of train_files, and validation uses val_files[:3].
print(f"\n‚ú® Successfully processed all {len(csv_files)} files!")
print(f"‚ú® Teacher Model: {teacher_params:,} parameters ‚Üí Accuracy: {t_acc:.4f}")
print(f"‚ú® Student Model: {student_params:,} parameters ‚Üí Accuracy: {s_acc:.4f}")
print(f"‚ú® Compression: {reduction_ratio:.1f}x smaller with {performance_retention:.1f}% performance retention")
print("\nüì¶ All models saved and ready for deployment!")
print("=" * 80)

üéÆ GPU Configuration
‚úÖ GPU detected: Tesla T4
‚úÖ CUDA Version: 12.6
‚úÖ GPU Memory: 15.83 GB
‚úÖ cuDNN autotuner enabled

üì• Downloading CIC-IoT-2023 Dataset from Kaggle...
Using Colab cache for faster access to the 'cic-iot-2023' dataset.
‚úÖ Dataset downloaded to: /kaggle/input/cic-iot-2023
üìÇ Found 169 CSV files.

üìä Dataset Split (from 169 files):
   Training:   101 files
   Validation: 34 files
   Testing:    34 files

üè∑Ô∏è  Fitting Preprocessing...
Sampling file 1/5: part-00000-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv
Sampling file 2/5: part-00001-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv
Sampling file 3/5: part-00002-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv
Sampling file 4/5: part-00003-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv
Sampling file 5/5: part-00004-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv
‚úÖ LabelEncoder fitted with 29 classes
‚úÖ Scaler fitted
‚úÖ PCA fitted with 30 components (from 46 features)
üíæ RAM Usage: 0.72 GB | GPU: 0.00 

FileNotFoundError: [Errno 2] No such file or directory: 'teacher_model.pth'

In [None]:
#!/usr/bin/env python3
"""
Knowledge Distillation for IoT Intrusion Detection - Full File Processing
- Reduced model sizes for memory efficiency
- Process entire files at once (no chunking within files)
- Handle all 169 files properly
- Stream one file at a time to avoid RAM overflow
"""

import os
import gc
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import IncrementalPCA
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import kagglehub
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import json
import warnings
warnings.filterwarnings('ignore')

# ==========================================================
# üéÆ GPU CONFIGURATION
# ==========================================================

def setup_gpu():
    """Pick the compute device and print a short hardware report.

    Returns ``torch.device('cuda')`` when CUDA is available, after printing the
    name, CUDA version and memory of GPU 0 and enabling
    ``torch.backends.cudnn.benchmark``. Otherwise it returns
    ``torch.device('cpu')``.
    """
    rule = "=" * 80
    print(rule)
    print("üéÆ GPU Configuration")
    print(rule)

    if not torch.cuda.is_available():
        device = torch.device('cpu')
        print("‚ö†Ô∏è  No GPU detected, running on CPU")
    else:
        device = torch.device('cuda')
        print(f"‚úÖ GPU detected: {torch.cuda.get_device_name(0)}")
        print(f"‚úÖ CUDA Version: {torch.version.cuda}")
        gpu_props = torch.cuda.get_device_properties(0)
        print(f"‚úÖ GPU Memory: {gpu_props.total_memory / 1e9:.2f} GB")
        # Let cuDNN benchmark candidate algorithms for the (fixed) input shapes.
        torch.backends.cudnn.benchmark = True
        print("‚úÖ cuDNN autotuner enabled")

    print(rule + "\n")
    return device


device = setup_gpu()

# ==========================================================
# üßπ MEMORY MANAGEMENT
# ==========================================================

def clear_memory():
    """Run Python's garbage collector and, if CUDA is present, release cached GPU memory."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

def print_memory_stats():
    """Print one line with process RAM (via psutil) and CUDA allocated memory.

    Each figure is optional. The RAM part is skipped when psutil is not installed
    or cannot read the process info (best effort). The GPU part is skipped when
    CUDA is unavailable. Both parts are in GB and joined by " | "; a lone newline
    is printed when neither is available.
    """
    parts = []

    try:
        import psutil  # optional dependency
        ram_gb = psutil.Process().memory_info().rss / 1e9
    except Exception:  # was a bare except; keep best-effort but don't trap SystemExit/KeyboardInterrupt
        ram_gb = None
    if ram_gb is not None:
        parts.append(f"üíæ RAM Usage: {ram_gb:.2f} GB")

    if torch.cuda.is_available():
        gpu_gb = torch.cuda.memory_allocated() / 1e9
        parts.append(f"GPU: {gpu_gb:.2f} GB")

    # Joining avoids a dangling " | " prefix when the RAM figure is missing.
    print(" | ".join(parts))

# ==========================================================
# üßπ HELPER FUNCTIONS
# ==========================================================

def load_and_clean(path, label_col=None):
    """Read a CSV, drop rows with missing values and exact duplicates, split off labels.

    Args:
        path: CSV file path.
        label_col: name of the label column. Defaults to ``"Label"`` when that
            column exists, otherwise the last column.

    Returns:
        ``(features_df, labels_series)``.
    """
    frame = pd.read_csv(path, low_memory=False).dropna().drop_duplicates()

    if label_col is None:
        label_col = "Label" if "Label" in frame.columns else frame.columns[-1]

    features = frame.drop(columns=[label_col])
    labels = frame[label_col]

    # Drop the combined frame early; large CSVs make this noticeable.
    del frame
    gc.collect()

    return features, labels

def encode_objects(X):
    """Integer-encode object (string) columns of ``X`` and return a float32 array.

    ``X`` (a DataFrame) is modified in place. A column whose encoding fails is
    replaced with zeros.

    NOTE(review): a fresh LabelEncoder is fitted on every call, i.e. per file, so
    the integer code of a given category can differ between files. A fixed
    mapping fitted once would make these features consistent.
    """
    for col in X.select_dtypes(include=["object"]).columns:
        try:
            X[col] = LabelEncoder().fit_transform(X[col].astype(str))
        except Exception:  # was a bare except; don't swallow KeyboardInterrupt/SystemExit
            X[col] = 0
    return X.values.astype(np.float32)

def load_and_process_file(filepath, scaler, pca, label_encoder):
    """Load one CSV file and return model-ready features and encoded labels.

    Pipeline: ``load_and_clean`` -> ``encode_objects`` -> ``scaler.transform``
    -> ``pca.transform``. Labels are encoded with ``label_encoder``.

    ``LabelEncoder.transform`` raises ``ValueError`` on any label it was not
    fitted on, which used to discard the whole file. Rows with unseen labels are
    now dropped individually, with a printed count, and the rest of the file is
    kept.

    Returns:
        ``(X_reduced, y_encoded)`` on success, or ``(None, None)`` if the file
        could not be processed or has no usable rows.
    """
    try:
        X, y = load_and_clean(filepath)
        X = encode_objects(X)
        y = y.astype(str)

        # Keep only rows whose label the encoder knows. Filtering after
        # encode_objects leaves the feature encoding of the kept rows unchanged.
        known = y.isin(label_encoder.classes_).to_numpy()
        n_unknown = int((~known).sum())
        if n_unknown:
            print(f"‚ö†Ô∏è  {os.path.basename(filepath)}: dropping {n_unknown} row(s) with labels unseen during fitting")
            X = X[known]
            y = y[known]

        if len(y) == 0:
            print(f"‚ö†Ô∏è  {os.path.basename(filepath)}: no usable rows, skipping")
            return None, None

        X_scaled = scaler.transform(X)
        X_reduced = pca.transform(X_scaled)
        y_encoded = label_encoder.transform(y)

        del X, y, X_scaled
        gc.collect()

        return X_reduced, y_encoded
    except Exception as e:
        print(f"‚ùå Error processing {os.path.basename(filepath)}: {e}")
        return None, None

# ==========================================================
# üì¶ FULL FILE DATASET
# ==========================================================

class FullFileDataset(Dataset):
    """Map-style dataset over one fully loaded file, held in memory as tensors.

    ``X`` is stored as a float32 tensor and ``y`` as an int64 tensor. Item ``idx``
    is the pair ``(X[idx], y[idx])``.
    """

    def __init__(self, X, y):
        self.X, self.y = torch.FloatTensor(X), torch.LongTensor(y)

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx):
        features, target = self.X[idx], self.y[idx]
        return features, target

# ==========================================================
# üéì REDUCED PYTORCH MODELS
# ==========================================================

class TeacherLSTM(nn.Module):
    """Teacher classifier: two stacked LSTMs followed by a three-layer MLP head.

    Layout: LSTM(hidden_sizes[0]) -> Dropout -> LSTM(hidden_sizes[1]) -> Dropout
    -> last timestep -> Linear(64) -> ReLU -> Linear(32) -> ReLU ->
    Linear(num_classes). Reduced from an earlier [256, 128, 64] configuration.

    Input is ``(batch, seq_len, input_size)`` (``batch_first=True``) and the output
    is logits of shape ``(batch, num_classes)``. The training code in this file
    feeds ``seq_len == 1``.

    The submodule attribute names are the state_dict keys of saved checkpoints,
    so they must not be renamed.
    """

    def __init__(self, input_size, hidden_sizes, num_classes, dropout=0.3):
        super().__init__()
        first_hidden, second_hidden = hidden_sizes[0], hidden_sizes[1]

        self.lstm1 = nn.LSTM(input_size, first_hidden, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)

        self.lstm2 = nn.LSTM(first_hidden, second_hidden, batch_first=True)
        self.dropout2 = nn.Dropout(dropout)

        self.fc1 = nn.Linear(second_hidden, 64)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(64, 32)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(32, num_classes)

    def forward(self, x):
        # Dropout is applied to the full output sequence, then the last step is taken.
        sequence = self.dropout1(self.lstm1(x)[0])
        sequence = self.dropout2(self.lstm2(sequence)[0])
        final_step = sequence[:, -1, :]

        hidden = self.relu1(self.fc1(final_step))
        hidden = self.relu2(self.fc2(hidden))
        return self.fc3(hidden)


class StudentLSTM(nn.Module):
    """Compact student classifier: one LSTM layer and a two-layer MLP head.

    Layout: LSTM(hidden_size) -> Dropout -> last timestep -> Linear(32) -> ReLU
    -> Linear(num_classes). Reduced from an earlier [32, 16] configuration.

    Input is ``(batch, seq_len, input_size)`` (``batch_first=True``) and the output
    is logits of shape ``(batch, num_classes)``. The training code in this file
    feeds ``seq_len == 1``.

    Attribute names are the saved state_dict keys; do not rename them.
    """

    def __init__(self, input_size, hidden_size, num_classes, dropout=0.2):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_size, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x):
        sequence, _state = self.lstm(x)
        # Dropout over the whole output sequence, then keep the final timestep.
        final_step = self.dropout(sequence)[:, -1, :]
        return self.fc2(self.relu(self.fc1(final_step)))

# ==========================================================
# üéì KNOWLEDGE DISTILLATION LOSS
# ==========================================================

class DistillationLoss(nn.Module):
    """Knowledge-distillation loss mixing soft (teacher) and hard (label) targets.

    loss = alpha * T**2 * KL(softmax(teacher/T) || softmax(student/T))
           + (1 - alpha) * CrossEntropy(student, labels)

    The KL term uses ``reduction='batchmean'``. Scaling by ``T**2`` keeps its
    gradient magnitude comparable to the hard-label term.
    """

    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, labels):
        T = self.temperature

        # KLDivLoss expects log-probabilities for the input, probabilities for the target.
        log_p_student = torch.log_softmax(student_logits / T, dim=1)
        p_teacher = torch.softmax(teacher_logits / T, dim=1)
        distill_term = self.kl_loss(log_p_student, p_teacher) * (T ** 2)

        label_term = self.ce_loss(student_logits, labels)

        return self.alpha * distill_term + (1 - self.alpha) * label_term

# ==========================================================
# üèãÔ∏è TRAINING FUNCTIONS (FULL FILE AT ONCE)
# ==========================================================

def train_on_file(model, filepath, scaler, pca, label_encoder, optimizer,
                  criterion, device, batch_size=512, is_distillation=False,
                  teacher_model=None):
    """Run one shuffled mini-batch training pass over a single CSV file.

    The whole file is loaded and preprocessed via ``load_and_process_file``,
    wrapped in a ``FullFileDataset`` and iterated in ``batch_size`` batches. A
    length-1 sequence axis is added to each batch (``unsqueeze(1)``) before the
    forward pass.

    Args:
        model: network being trained; switched to train mode.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: ``criterion(logits, labels)`` for normal training, or
            ``criterion(student_logits, teacher_logits, labels)`` when distilling.
        device: device the batches are moved to.
        batch_size: mini-batch size.
        is_distillation: when True and ``teacher_model`` is given, use the
            three-argument criterion with teacher logits as soft targets.
        teacher_model: teacher network. It is put in eval mode and run under
            ``torch.no_grad()``.

    Returns:
        Sample-weighted mean loss over the file (assuming ``criterion`` returns a
        per-batch mean), or ``0`` if the file could not be loaded or yielded no
        batches.
    """
    X_file, y_file = load_and_process_file(filepath, scaler, pca, label_encoder)

    if X_file is None:
        return 0

    dataset = FullFileDataset(X_file, y_file)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    model.train()
    if teacher_model is not None:
        teacher_model.eval()

    use_teacher = is_distillation and teacher_model is not None
    total_loss = 0
    total_samples = 0

    for X_batch, y_batch in dataloader:
        X_batch = X_batch.unsqueeze(1).to(device)  # Add sequence dimension
        y_batch = y_batch.to(device)

        optimizer.zero_grad()
        outputs = model(X_batch)

        if use_teacher:
            with torch.no_grad():
                teacher_outputs = teacher_model(X_batch)
            loss = criterion(outputs, teacher_outputs, y_batch)
        else:
            loss = criterion(outputs, y_batch)

        loss.backward()
        optimizer.step()

        n_batch = len(y_batch)
        total_loss += loss.item() * n_batch
        total_samples += n_batch

    # Release the file's data once, after the full pass. clear_memory() used to
    # run after every batch; gc.collect() plus torch.cuda.empty_cache() per batch
    # is very slow and only returns cached GPU blocks that the next batch
    # immediately re-allocates.
    del X_file, y_file, dataset, dataloader
    clear_memory()

    return total_loss / total_samples if total_samples > 0 else 0


def evaluate_on_files(model, file_list, scaler, pca, label_encoder,
                      criterion, device, batch_size=512):
    """Compute sample-weighted mean loss and accuracy of ``model`` over several files.

    Files are loaded and evaluated one at a time, so only one file is in memory
    at once, under ``torch.no_grad()`` in ``batch_size`` batches. Files that fail
    to load are skipped.

    Args:
        model: classifier returning logits of shape (batch, num_classes) for
            inputs of shape (batch, 1, n_components); switched to eval mode.
        criterion: ``criterion(logits, labels)``; assumed to return a per-batch
            mean, since it is re-weighted by batch size here.

    Returns:
        ``(avg_loss, accuracy)``; both ``0`` when no samples were evaluated.
    """
    model.eval()
    total_loss = 0
    correct = 0
    total_samples = 0

    for filepath in file_list:
        X_file, y_file = load_and_process_file(filepath, scaler, pca, label_encoder)

        if X_file is None:
            continue

        dataset = FullFileDataset(X_file, y_file)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

        with torch.no_grad():
            for X_batch, y_batch in dataloader:
                X_batch = X_batch.unsqueeze(1).to(device)
                y_batch = y_batch.to(device)

                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)

                _, predicted = torch.max(outputs, 1)
                n_batch = len(y_batch)
                correct += (predicted == y_batch).sum().item()
                total_loss += loss.item() * n_batch
                total_samples += n_batch

        # Free the file's data once per file instead of running gc.collect() and
        # torch.cuda.empty_cache() after every batch, which was very slow.
        del X_file, y_file, dataset, dataloader
        clear_memory()

    accuracy = correct / total_samples if total_samples > 0 else 0
    avg_loss = total_loss / total_samples if total_samples > 0 else 0

    return avg_loss, accuracy

# ==========================================================
# üìÇ DOWNLOAD & SPLIT DATASET (169 FILES)
# ==========================================================

print("=" * 80)
print("üì• Downloading CIC-IoT-2023 Dataset from Kaggle...")
print("=" * 80)

# kagglehub returns the local directory of the (possibly cached) dataset copy.
dataset_dir = kagglehub.dataset_download("akashdogra/cic-iot-2023")
print(f"‚úÖ Dataset downloaded to: {dataset_dir}")

# Top-level CSV files only (os.listdir is not recursive). Sorting makes the file
# order, and therefore the split below, deterministic.
csv_files = sorted([
    os.path.join(dataset_dir, f)
    for f in os.listdir(dataset_dir)
    if f.endswith(".csv")
])

print(f"üìÇ Found {len(csv_files)} CSV files.")

# 60-20-20 split
# The split is by FILE (contiguous ranges of the sorted list), not by row.
# int() truncation at both boundaries means test absorbs the rounding
# (e.g. 169 files -> 101 / 34 / 34).
n_files = len(csv_files)
train_idx = int(n_files * 0.60)
val_idx = int(n_files * 0.80)

train_files = csv_files[:train_idx]
val_files = csv_files[train_idx:val_idx]
test_files = csv_files[val_idx:]

print(f"\nüìä Dataset Split (from {n_files} files):")
print(f"   Training:   {len(train_files)} files")
print(f"   Validation: {len(val_files)} files")
print(f"   Testing:    {len(test_files)} files")

# ==========================================================
# üè∑Ô∏è FIT PREPROCESSING (SCAN ALL TRAINING FILES FOR LABELS)
# ==========================================================

print("\n" + "=" * 80)
print("üè∑Ô∏è  Fitting Preprocessing - Scanning ALL Training Files...")
print("=" * 80)

# Collect the label vocabulary from EVERY training file, so the LabelEncoder
# knows every training label. Also keep a small feature sample from the first
# 10 files to fit the scaler and PCA.
all_labels = set()
sample_data = []

print(f"Scanning {len(train_files)} training files for all unique labels...")
for i, filepath in enumerate(train_files):
    try:
        if i < 10:
            # Features are sampled from this file, so parse every column.
            df = pd.read_csv(filepath, low_memory=False)
            label_col = "Label" if "Label" in df.columns else df.columns[-1]
        else:
            # Only labels are needed: read the header, then just the label column.
            # Previously every column of every training file was parsed here.
            header = pd.read_csv(filepath, nrows=0).columns
            label_col = "Label" if "Label" in header else header[-1]
            df = pd.read_csv(filepath, usecols=[label_col], low_memory=False)

        # Collect all unique labels from this file
        unique_labels = df[label_col].dropna().astype(str).unique()
        all_labels.update(unique_labels)

        print(f"  File {i+1}/{len(train_files)}: {os.path.basename(filepath)} - Found {len(unique_labels)} unique labels (Total: {len(all_labels)})")

        # Sample features from first 10 files only
        if i < 10:
            # NOTE(review): head(1000) takes each file's leading rows, which may
            # not be representative; random sampling would be more robust.
            df_sample = df.head(1000).dropna()
            X = df_sample.drop(columns=[label_col])

            # Encode objects (same per-sample encoding as encode_objects)
            for col in X.select_dtypes(include=["object"]).columns:
                try:
                    X[col] = LabelEncoder().fit_transform(X[col].astype(str))
                except Exception:
                    X[col] = 0

            sample_data.append(X.values.astype(np.float32))

        del df
        gc.collect()

    except Exception as e:
        print(f"  ‚ö†Ô∏è  Error reading {os.path.basename(filepath)}: {e}")
        continue

# Fail with a clear message instead of an IndexError at sample_data[0] below.
if not sample_data:
    raise RuntimeError("No training file could be sampled to fit the scaler/PCA.")

# Convert set to sorted list for consistent encoding
all_labels = sorted(list(all_labels))

# Fit label encoder with ALL labels
label_encoder = LabelEncoder()
label_encoder.fit(all_labels)
n_classes = len(label_encoder.classes_)

print(f"\n‚úÖ LabelEncoder fitted with {n_classes} classes")
print(f"   Classes found: {', '.join(label_encoder.classes_[:10])}{'...' if n_classes > 10 else ''}")

# Fit scaler
scaler = StandardScaler()
for data in sample_data:
    scaler.partial_fit(data)

print(f"‚úÖ Scaler fitted on {len(sample_data)} file samples")

# Fit PCA
# IncrementalPCA.partial_fit needs at least n_components rows per call; each
# sample holds up to 1000 rows.
n_features = sample_data[0].shape[1]
n_components = min(30, n_features)

pca = IncrementalPCA(n_components=n_components)
for data in sample_data:
    X_scaled = scaler.transform(data)
    pca.partial_fit(X_scaled)

print(f"‚úÖ PCA fitted with {n_components} components (from {n_features} features)")

del all_labels, sample_data
clear_memory()
print_memory_stats()

# ==========================================================
# üéì STAGE 1: TRAIN TEACHER MODEL
# ==========================================================

print("\n" + "=" * 80)
print("üéì STAGE 1: Training Teacher Model")
print("=" * 80)

# Teacher network: LSTM(128) -> LSTM(64) -> FC(64) -> FC(32) -> classes,
# shrunk from an earlier [256, 128, 64] configuration.
teacher_model = TeacherLSTM(
    input_size=n_components,
    hidden_sizes=[128, 64],
    num_classes=n_classes,
    dropout=0.3,
).to(device)

teacher_params = sum(tensor.numel() for tensor in teacher_model.parameters())
print(f"\nüèóÔ∏è  Teacher Model: {teacher_params:,} parameters")
print(f"   Architecture: Input({n_components}) ‚Üí LSTM(128) ‚Üí LSTM(64) ‚Üí FC(64) ‚Üí FC(32) ‚Üí Output({n_classes})")

# The teacher is trained with plain cross-entropy and Adam.
teacher_criterion = nn.CrossEntropyLoss()
teacher_optimizer = optim.Adam(teacher_model.parameters(), lr=0.001)

# Training settings
# NOTE: each "epoch" trains on one rotating window of `files_per_epoch` training
# files, so 3 epochs x 20 files cover 60 training files (of 101 in the recorded
# run), not every file.
epochs_teacher = 3  # Number of rotating file-window cycles
batch_size = 512  # Large batch size allowed
files_per_epoch = 20  # Files trained on per epoch; windows rotate through train_files

# Start below any reachable accuracy so the first epoch always writes
# 'teacher_model.pth'. With 0 and the strict ">" in the loop, a run whose
# validation accuracy stays at 0.0 never saves, and the os.path.exists()
# fallback after training could then load a stale checkpoint from an earlier run.
best_teacher_acc = float('-inf')
patience_counter = 0
# Larger than epochs_teacher, so early stopping cannot trigger in this run.
patience = 5  # Increased patience

print("\nüöÄ Training Teacher Model...")
print(f"   Batch Size: {batch_size}")
print(f"   Files per Epoch Cycle: {files_per_epoch}")
print(f"   Total Training Files: {len(train_files)}")
print(f"   Epochs: {epochs_teacher}")

# Teacher training loop: each epoch trains on a rotating window of
# `files_per_epoch` training files, validates on the first 5 validation files,
# checkpoints on improved validation accuracy, and applies patience-based early
# stopping.
for epoch in range(epochs_teacher):
    print(f"\n{'='*80}")
    print(f"TEACHER EPOCH {epoch+1}/{epochs_teacher}")
    print(f"{'='*80}")

    # Select rotating files
    # Window starts at epoch * files_per_epoch (mod number of files); if it runs
    # off the end it is topped up from the start of train_files.
    start_idx = (epoch * files_per_epoch) % len(train_files)
    end_idx = min(start_idx + files_per_epoch, len(train_files))
    # Slicing creates a new list, so the += below does not mutate train_files.
    selected_files = train_files[start_idx:end_idx]

    if len(selected_files) < files_per_epoch and len(train_files) > files_per_epoch:
        remaining = files_per_epoch - len(selected_files)
        selected_files += train_files[:remaining]

    # end_idx is exclusive and does not reflect wrapped-around files.
    print(f"Training on {len(selected_files)} files (indices {start_idx} to {end_idx})")

    # Train on each file
    epoch_losses = []
    for i, filepath in enumerate(selected_files):
        print(f"\n  üìÇ File {i+1}/{len(selected_files)}: {os.path.basename(filepath)}")

        train_loss = train_on_file(
            teacher_model, filepath, scaler, pca, label_encoder,
            teacher_optimizer, teacher_criterion, device, batch_size=batch_size
        )

        # train_on_file returns 0 for a file it could not load; that 0 is still
        # averaged into avg_train_loss below.
        epoch_losses.append(train_loss)
        print(f"     Loss: {train_loss:.4f}")
        print_memory_stats()

    avg_train_loss = np.mean(epoch_losses)

    # Validate on subset of validation files
    # Only the first 5 validation files are used, to keep validation cheap.
    print(f"\n  üìä Validating...")
    val_loss, val_acc = evaluate_on_files(
        teacher_model, val_files[:5], scaler, pca, label_encoder,
        teacher_criterion, device, batch_size=batch_size
    )

    print(f"\n  üìà Epoch Summary:")
    print(f"     Avg Train Loss: {avg_train_loss:.4f}")
    print(f"     Val Loss: {val_loss:.4f}")
    print(f"     Val Accuracy: {val_acc:.4f}")

    # Save best model
    # Strict ">" — nothing is saved unless val_acc beats best_teacher_acc.
    if val_acc > best_teacher_acc:
        best_teacher_acc = val_acc
        torch.save(teacher_model.state_dict(), 'teacher_model.pth')
        print(f"  ‚úÖ Best teacher model saved! Val Acc: {val_acc:.4f}")
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"\n‚ö†Ô∏è  Early stopping triggered at epoch {epoch+1}")
        break

    clear_memory()

print("\n‚úÖ Teacher Model Training Complete!")
print(f"   Best Validation Accuracy: {best_teacher_acc:.4f}")

# Load best model if it was saved, otherwise keep current
if os.path.exists('teacher_model.pth'):
    teacher_model.load_state_dict(torch.load('teacher_model.pth'))
    print("   Loaded best teacher model from disk")
else:
    print("   ‚ö†Ô∏è  No saved model found, using final epoch weights")

# ==========================================================
# üéí STAGE 2: KNOWLEDGE DISTILLATION - TRAIN STUDENT
# ==========================================================

print("\n" + "=" * 80)
print("üéí STAGE 2: Knowledge Distillation - Training Student Model")
print("=" * 80)

# Initialize student model with REDUCED size
student_model = StudentLSTM(
    input_size=n_components,
    hidden_size=32,  # Single layer, reduced from [32, 16]
    num_classes=n_classes,
    dropout=0.2
).to(device)

student_params = sum(p.numel() for p in student_model.parameters())
reduction_ratio = teacher_params / student_params

print(f"\nüèóÔ∏è  Student Model: {student_params:,} parameters")
print(f"   Architecture: Input({n_components}) ‚Üí LSTM(32) ‚Üí FC(32) ‚Üí Output({n_classes})")
print(f"\nüìä Model Comparison:")
print(f"   Teacher Parameters: {teacher_params:,}")
print(f"   Student Parameters: {student_params:,}")
print(f"   Size Reduction:     {reduction_ratio:.1f}x smaller")

# Optimizer and distillation loss
student_optimizer = optim.Adam(student_model.parameters(), lr=0.001)
distillation_criterion = DistillationLoss(temperature=4.0, alpha=0.7)
# Validation scores the student with plain cross-entropy on the hard labels; the
# criterion holds no state, so it is created once instead of once per epoch.
val_criterion = nn.CrossEntropyLoss()

epochs_student = 4
best_student_acc = 0
patience_counter = 0
# Only reload 'student_model.pth' if THIS run wrote it; os.path.exists() alone
# would also pick up a stale checkpoint from a previous run.
student_checkpoint_saved = False

print(f"\nüöÄ Training Student with Knowledge Distillation...")
print(f"   Temperature: {distillation_criterion.temperature}")
print(f"   Alpha (soft target weight): {distillation_criterion.alpha}")
print(f"   Batch Size: {batch_size}")
print(f"   Files per Epoch: {files_per_epoch}")

for epoch in range(epochs_student):
    print(f"\n{'='*80}")
    print(f"STUDENT EPOCH {epoch+1}/{epochs_student}")
    print(f"{'='*80}")

    # Rotating window of `files_per_epoch` files, wrapping past the end of the list;
    # an epoch is one such window, not a full pass over the training files.
    n_train = len(train_files)
    start_idx = (epoch * files_per_epoch) % n_train
    n_selected = min(files_per_epoch, n_train)
    selected_files = [train_files[(start_idx + k) % n_train] for k in range(n_selected)]
    last_idx = (start_idx + n_selected - 1) % n_train

    print(f"Training on {len(selected_files)} files (indices {start_idx} to {last_idx} inclusive, modulo {n_train})")

    # Train with distillation against the frozen teacher's logits
    epoch_losses = []
    for i, filepath in enumerate(selected_files):
        print(f"\n  üìÇ File {i+1}/{len(selected_files)}: {os.path.basename(filepath)}")

        train_loss = train_on_file(
            student_model, filepath, scaler, pca, label_encoder,
            student_optimizer, distillation_criterion, device,
            batch_size=batch_size, is_distillation=True, teacher_model=teacher_model
        )

        epoch_losses.append(train_loss)
        print(f"     Loss: {train_loss:.4f}")
        print_memory_stats()

    avg_train_loss = np.mean(epoch_losses)

    # Validate
    print(f"\n  üìä Validating...")
    val_loss, val_acc = evaluate_on_files(
        student_model, val_files[:5], scaler, pca, label_encoder,
        val_criterion, device, batch_size=batch_size
    )

    print(f"\n  üìà Epoch Summary:")
    print(f"     Avg Train Loss: {avg_train_loss:.4f}")
    print(f"     Val Loss: {val_loss:.4f}")
    print(f"     Val Accuracy: {val_acc:.4f}")

    # Save best model
    if val_acc > best_student_acc:
        best_student_acc = val_acc
        torch.save(student_model.state_dict(), 'student_model.pth')
        student_checkpoint_saved = True
        print(f"  ‚úÖ Best student model saved! Val Acc: {val_acc:.4f}")
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"\n‚ö†Ô∏è  Early stopping triggered at epoch {epoch+1}")
        break

    clear_memory()

print("\n‚úÖ Student Model Training Complete!")
print(f"   Best Validation Accuracy: {best_student_acc:.4f}")

# Restore the best weights saved during THIS run; otherwise keep the final-epoch weights.
if student_checkpoint_saved:
    student_model.load_state_dict(torch.load('student_model.pth', map_location=device))
    print("   Loaded best student model from disk")
else:
    print("   ‚ö†Ô∏è  No saved model found, using final epoch weights")

# ==========================================================
# üìà STAGE 3: FINAL EVALUATION
# ==========================================================

print("\n" + "=" * 80)
print("üìà STAGE 3: Final Evaluation on Test Set")
print("=" * 80)

def evaluate_model_detailed(model, model_name, file_list, batch_size=512):
    """Evaluate ``model`` on every file in ``file_list`` and print weighted metrics.

    Each file is preprocessed with the module-level ``scaler``, ``pca`` and
    ``label_encoder`` and scored on the module-level ``device`` without gradients.
    Files that fail to load (``load_and_process_file`` returns ``None``) are skipped.

    Args:
        model: classifier whose output is one logit row per sample.
        model_name: label used in the printed report.
        file_list: CSV paths to evaluate on.
        batch_size: inference batch size (512 by default, as before).

    Returns:
        ``(y_true, y_pred, accuracy, precision, recall, f1)``: 1-D numpy arrays of
        encoded labels / predictions plus weighted-average metrics.

    Raises:
        ValueError: if no file in ``file_list`` could be processed.
    """
    print(f"\n{'='*60}")
    print(f"Evaluating {model_name}...")
    print(f"{'='*60}")

    model.eval()
    # Collect one array per batch and concatenate once at the end; extending a
    # Python list with every individual numpy scalar costs far more memory/time.
    true_chunks = []
    pred_chunks = []

    for i, filepath in enumerate(file_list):
        print(f"Processing file {i+1}/{len(file_list)}: {os.path.basename(filepath)}")

        X_file, y_file = load_and_process_file(filepath, scaler, pca, label_encoder)

        if X_file is None:
            continue

        dataset = FullFileDataset(X_file, y_file)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

        with torch.no_grad():
            for X_batch, y_batch in dataloader:
                # Add the length-1 sequence dimension expected by the LSTM models.
                outputs = model(X_batch.unsqueeze(1).to(device))
                true_chunks.append(y_batch.numpy())
                pred_chunks.append(outputs.argmax(dim=1).cpu().numpy())

        # Release the file's data once per file; calling gc.collect() and
        # torch.cuda.empty_cache() after every batch only slowed inference down.
        del X_file, y_file, dataset, dataloader
        clear_memory()

    if not true_chunks:
        raise ValueError(f"{model_name}: none of the {len(file_list)} files could be evaluated")

    y_true_all = np.concatenate(true_chunks)
    y_pred_all = np.concatenate(pred_chunks)

    accuracy = accuracy_score(y_true_all, y_pred_all)
    precision = precision_score(y_true_all, y_pred_all, average='weighted', zero_division=0)
    recall = recall_score(y_true_all, y_pred_all, average='weighted', zero_division=0)
    f1 = f1_score(y_true_all, y_pred_all, average='weighted', zero_division=0)

    print(f"\nüìä {model_name} Performance:")
    print(f"   Accuracy:  {accuracy:.4f}")
    print(f"   Precision: {precision:.4f}")
    print(f"   Recall:    {recall:.4f}")
    print(f"   F1-Score:  {f1:.4f}")

    return y_true_all, y_pred_all, accuracy, precision, recall, f1

# Evaluate both models
teacher_results = evaluate_model_detailed(teacher_model, "TEACHER MODEL", test_files)
student_results = evaluate_model_detailed(student_model, "STUDENT MODEL (Distilled)", test_files)

# ==========================================================
# üìä GENERATE REPORTS
# ==========================================================

print("\n" + "=" * 80)
print("üìä Generating Final Report...")
print("=" * 80)

y_true, y_pred, s_acc, s_prec, s_rec, s_f1 = student_results
_, _, t_acc, t_prec, t_rec, t_f1 = teacher_results

# Confusion matrix over EVERY encoded class. Without `labels=`, sklearn keeps only
# the classes that occur in y_true / y_pred, so a class missing from the test
# split shrinks the matrix and the tick labels below no longer line up.
class_names = label_encoder.classes_
cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))

plt.figure(figsize=(20, 16))
sns.heatmap(cm, annot=True, fmt='d', cmap='Greens',
            xticklabels=class_names,
            yticklabels=class_names,
            cbar_kws={'label': 'Count'})
plt.title('Student Model Confusion Matrix (Knowledge Distillation)', fontsize=16, pad=20)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.xticks(rotation=45, ha='right', fontsize=8)
plt.yticks(rotation=0, fontsize=8)
plt.tight_layout()
plt.savefig('student_confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.close()  # the large figure is only written to disk; free it
print("‚úÖ Confusion matrix saved as 'student_confusion_matrix.png'")

# Performance comparison
performance_retention = (s_acc / t_acc) * 100 if t_acc > 0 else 0

print("\n" + "=" * 80)
print("üìä FINAL COMPARISON: TEACHER vs STUDENT")
print("=" * 80)
print(f"\n{'Metric':<15} {'Teacher':<15} {'Student':<15} {'Difference':<15}")
print("=" * 80)
print(f"{'Accuracy':<15} {t_acc:<15.4f} {s_acc:<15.4f} {(s_acc-t_acc):<15.4f}")
print(f"{'Precision':<15} {t_prec:<15.4f} {s_prec:<15.4f} {(s_prec-t_prec):<15.4f}")
print(f"{'Recall':<15} {t_rec:<15.4f} {s_rec:<15.4f} {(s_rec-t_rec):<15.4f}")
print(f"{'F1-Score':<15} {t_f1:<15.4f} {s_f1:<15.4f} {(s_f1-t_f1):<15.4f}")
print(f"{'Parameters':<15} {teacher_params:<15,} {student_params:<15,} {'-':<15}")
print(f"{'Model Size':<15} {'1.0x':<15} {f'{1/reduction_ratio:.2f}x':<15} {f'{reduction_ratio:.1f}x smaller':<15}")
print("=" * 80)

print(f"\nüéØ Performance Retention: {performance_retention:.2f}%")
print(f"üéØ Model Size Reduction: {reduction_ratio:.1f}x smaller")
print(f"üéØ Parameter Reduction: {((teacher_params - student_params) / teacher_params * 100):.1f}% fewer parameters")

# ==========================================================
# üíæ SAVE MODELS AND PREPROCESSING OBJECTS
# ==========================================================

print("\n" + "=" * 80)
print("üíæ Saving Models and Preprocessing Objects")
print("=" * 80)

# Each checkpoint bundles the weights with the constructor arguments needed to
# rebuild the model, plus its test accuracy and parameter count for reference.
model_checkpoints = (
    ('teacher_model_complete.pth', {
        'model_state_dict': teacher_model.state_dict(),
        'input_size': n_components,
        'hidden_sizes': [128, 64],
        'num_classes': n_classes,
        'accuracy': t_acc,
        'params': teacher_params,
    }),
    ('student_model_complete.pth', {
        'model_state_dict': student_model.state_dict(),
        'input_size': n_components,
        'hidden_size': 32,
        'num_classes': n_classes,
        'accuracy': s_acc,
        'params': student_params,
    }),
)
for checkpoint_path, checkpoint in model_checkpoints:
    torch.save(checkpoint, checkpoint_path)
    print(f"‚úÖ Saved: {checkpoint_path}")

# Fitted preprocessing pipeline, pickled (only ever unpickle files you trust).
preprocessing_objects = dict(scaler=scaler, pca=pca, label_encoder=label_encoder)
with open('preprocessing.pkl', 'wb') as pkl_file:
    pickle.dump(preprocessing_objects, pkl_file)
print("‚úÖ Saved: preprocessing.pkl")

# Run summary; numpy scalars are cast to built-in int/float so json can serialize them.
metadata = dict(
    n_classes=int(n_classes),
    n_features=int(n_features),
    n_components=int(n_components),
    teacher_params=int(teacher_params),
    student_params=int(student_params),
    teacher_accuracy=float(t_acc),
    student_accuracy=float(s_acc),
    size_reduction=float(reduction_ratio),
    performance_retention=float(performance_retention),
    total_files=len(csv_files),
    train_files=len(train_files),
    val_files=len(val_files),
    test_files=len(test_files),
    classes=label_encoder.classes_.tolist(),
)
with open('model_metadata.json', 'w') as json_file:
    json.dump(metadata, json_file, indent=4)
print("‚úÖ Saved: model_metadata.json")

# Create summary
# Plain-text report of the run: dataset split, per-model metrics, compression.
param_reduction_pct = (teacher_params - student_params) / teacher_params * 100
summary_lines = [
    "=" * 80 + "\n",
    "KNOWLEDGE DISTILLATION - PYTORCH MODEL SUMMARY\n",
    "=" * 80 + "\n\n",
    "DATASET INFORMATION:\n",
    f"  Total Files: {len(csv_files)}\n",
    f"  Training Files: {len(train_files)}\n",
    f"  Validation Files: {len(val_files)}\n",
    f"  Test Files: {len(test_files)}\n\n",
    "TEACHER MODEL:\n",
    "  Architecture: LSTM [128, 64]\n",
    f"  Parameters: {teacher_params:,}\n",
    f"  Accuracy: {t_acc:.4f}\n",
    f"  Precision: {t_prec:.4f}\n",
    f"  Recall: {t_rec:.4f}\n",
    f"  F1-Score: {t_f1:.4f}\n\n",
    "STUDENT MODEL (DISTILLED):\n",
    "  Architecture: LSTM [32]\n",
    f"  Parameters: {student_params:,}\n",
    f"  Accuracy: {s_acc:.4f}\n",
    f"  Precision: {s_prec:.4f}\n",
    f"  Recall: {s_rec:.4f}\n",
    f"  F1-Score: {s_f1:.4f}\n\n",
    "COMPRESSION METRICS:\n",
    f"  Size Reduction: {reduction_ratio:.1f}x smaller\n",
    f"  Performance Retention: {performance_retention:.2f}%\n",
    f"  Parameter Reduction: {param_reduction_pct:.1f}%\n\n",
    "FILES GENERATED:\n",
    "  - teacher_model_complete.pth (Teacher model with metadata)\n",
    "  - student_model_complete.pth (Student model with metadata)\n",
    "  - preprocessing.pkl (Scaler, PCA, Label Encoder)\n",
    "  - model_metadata.json (Model specifications)\n",
    "  - student_confusion_matrix.png (Confusion matrix visualization)\n",
]
with open('model_summary.txt', 'w') as summary_file:
    summary_file.write("".join(summary_lines))

print("‚úÖ Saved: model_summary.txt")

closing_rule = "=" * 80
print("\n" + closing_rule)
print("üéâ KNOWLEDGE DISTILLATION COMPLETE!")
print(closing_rule)
for closing_line in (
    f"\n‚ú® Successfully processed all {len(csv_files)} files!",
    f"‚ú® Teacher Model: {teacher_params:,} parameters ‚Üí Accuracy: {t_acc:.4f}",
    f"‚ú® Student Model: {student_params:,} parameters ‚Üí Accuracy: {s_acc:.4f}",
    f"‚ú® Compression: {reduction_ratio:.1f}x smaller with {performance_retention:.1f}% performance retention",
    "\nüì¶ All models saved and ready for deployment!",
):
    print(closing_line)
print(closing_rule)

🎮 GPU Configuration
✅ GPU detected: Tesla T4
✅ CUDA Version: 12.6
✅ GPU Memory: 15.83 GB
✅ cuDNN autotuner enabled

📥 Downloading CIC-IoT-2023 Dataset from Kaggle...
Downloading from https://www.kaggle.com/api/v1/datasets/download/akashdogra/cic-iot-2023?dataset_version_number=1...


100%|██████████| 2.77G/2.77G [00:38<00:00, 77.7MB/s]

Extracting files...





✅ Dataset downloaded to: /root/.cache/kagglehub/datasets/akashdogra/cic-iot-2023/versions/1
📂 Found 169 CSV files.

📊 Dataset Split (from 169 files):
   Training:   101 files
   Validation: 34 files
   Testing:    34 files

🏷️  Fitting Preprocessing - Scanning ALL Training Files...
Scanning 101 training files for all unique labels...
  File 1/101: part-00000-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv - Found 34 unique labels (Total: 34)
  File 2/101: part-00001-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv - Found 34 unique labels (Total: 34)
  File 3/101: part-00002-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv - Found 34 unique labels (Total: 34)
  File 4/101: part-00003-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv - Found 34 unique labels (Total: 34)
  File 5/101: part-00004-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv - Found 34 unique labels (Total: 34)
  File 6/101: part-00005-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv - Found 34 unique labels (Total: 34)
  F

In [None]:
#!/usr/bin/env python3
"""
Knowledge Distillation for IoT Intrusion Detection - Full File Processing
- Reduced model sizes for memory efficiency
- Process entire files at once (no chunking within files)
- Handle all 169 files properly
- Stream one file at a time to avoid RAM overflow
"""

import os
import gc
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import IncrementalPCA
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import kagglehub
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import json
import warnings
warnings.filterwarnings('ignore')

# ==========================================================
# üéÆ GPU CONFIGURATION
# ==========================================================

def setup_gpu():
    """Select the torch device for this run and print what was detected.

    On a CUDA machine, reports the name, CUDA version and total memory of GPU 0
    and switches on the cuDNN autotuner (``torch.backends.cudnn.benchmark``).

    Returns:
        torch.device: ``cuda`` when CUDA is available, otherwise ``cpu``.
    """
    rule = "=" * 80
    print(rule)
    print("üéÆ GPU Configuration")
    print(rule)

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    if not use_cuda:
        print("‚ö†Ô∏è  No GPU detected, running on CPU")
    else:
        print(f"‚úÖ GPU detected: {torch.cuda.get_device_name(0)}")
        print(f"‚úÖ CUDA Version: {torch.version.cuda}")
        print(f"‚úÖ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
        torch.backends.cudnn.benchmark = True
        print("‚úÖ cuDNN autotuner enabled")

    print(rule + "\n")
    return device


device = setup_gpu()

# ==========================================================
# üßπ MEMORY MANAGEMENT
# ==========================================================

def clear_memory():
    """Run Python's garbage collector, then drop PyTorch's cached CUDA blocks if CUDA exists."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

def print_memory_stats():
    """Print one line with process RAM usage and currently allocated CUDA memory.

    Both figures are best-effort. RAM (resident set size, via psutil) is omitted
    when psutil is missing or the query fails. GPU usage is omitted when CUDA is
    unavailable. Present figures are joined with " | ", so a missing RAM figure no
    longer leaves a dangling leading separator on the line.
    """
    parts = []

    try:
        import psutil  # optional dependency, imported lazily
        ram_gb = psutil.Process().memory_info().rss / 1e9
    except Exception:  # ImportError or a psutil error: skip the RAM figure
        ram_gb = None
    if ram_gb is not None:
        parts.append(f"üíæ RAM Usage: {ram_gb:.2f} GB")

    if torch.cuda.is_available():
        parts.append(f"GPU: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

    print(" | ".join(parts))

# ==========================================================
# üßπ HELPER FUNCTIONS
# ==========================================================

def load_and_clean(path, label_col=None):
    """Read a CSV, drop rows with missing values and duplicate rows, split off the label.

    Args:
        path: CSV file to read.
        label_col: name of the label column; when omitted, "Label" is used if the
            file has it, otherwise the last column.

    Returns:
        (features DataFrame without the label column, label Series)
    """
    frame = pd.read_csv(path, low_memory=False).dropna().drop_duplicates()

    target = label_col
    if target is None:
        target = "Label" if "Label" in frame.columns else frame.columns[-1]

    features, labels = frame.drop(columns=[target]), frame[target]

    del frame
    gc.collect()

    return features, labels

def encode_objects(X):
    """Integer-encode the object (string) columns of ``X`` and return a float32 array.

    Each object column is cast to ``str`` and replaced in place by LabelEncoder
    codes. A column that cannot be encoded is replaced by 0 (best-effort).

    Note: a fresh LabelEncoder is fitted on every call, so the same category may
    receive different codes in different files. TODO(review): fit the encoders once
    on training data if any feature column is genuinely categorical.

    Args:
        X: pandas DataFrame of features; modified in place.

    Returns:
        numpy float32 array with the same shape as ``X``.
    """
    for col in X.select_dtypes(include=["object"]).columns:
        try:
            X[col] = LabelEncoder().fit_transform(X[col].astype(str))
        except (TypeError, ValueError):
            # Narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate.
            X[col] = 0
    return X.values.astype(np.float32)

def load_and_process_file(filepath, scaler, pca, label_encoder):
    """Load one CSV and run it through the fitted preprocessing pipeline.

    Steps: ``load_and_clean`` -> ``encode_objects`` -> ``scaler.transform`` ->
    ``pca.transform`` for the features, and ``label_encoder.transform`` on the
    string-cast labels.

    Returns:
        ``(reduced_features, encoded_labels)``, or ``(None, None)`` after printing
        the error when any step fails (e.g. a label the encoder has never seen).
    """
    try:
        features, labels = load_and_clean(filepath)
        features = encode_objects(features)
        reduced = pca.transform(scaler.transform(features))
        encoded_labels = label_encoder.transform(labels.astype(str))
    except Exception as e:
        print(f"‚ùå Error processing {os.path.basename(filepath)}: {e}")
        return None, None

    del features, labels
    gc.collect()

    return reduced, encoded_labels

# ==========================================================
# üì¶ FULL FILE DATASET
# ==========================================================

class FullFileDataset(Dataset):
    """Map-style dataset over one entire preprocessed file held in memory.

    Features are stored as a float32 tensor ``X`` and labels as an int64 tensor
    ``y``; item ``idx`` is the ``(features_row, label)`` pair.
    """

    def __init__(self, X, y):
        features = torch.FloatTensor(X)
        targets = torch.LongTensor(y)
        self.X = features
        self.y = targets

    def __len__(self):
        return self.y.size(0)

    def __getitem__(self, idx):
        row, label = self.X[idx], self.y[idx]
        return row, label

# ==========================================================
# üéì REDUCED PYTORCH MODELS
# ==========================================================

class TeacherLSTM(nn.Module):
    """Two-layer LSTM classifier used as the distillation teacher ([128, 64] in this script).

    Layout: LSTM(input_size -> hidden_sizes[0]) -> Dropout -> LSTM(-> hidden_sizes[1])
    -> Dropout -> last time step -> Linear(64) -> ReLU -> Linear(32) -> ReLU
    -> Linear(num_classes). The output is raw logits.

    Submodule attribute names form the ``state_dict`` keys of saved checkpoints,
    so they must stay as they are.
    """

    def __init__(self, input_size, hidden_sizes, num_classes, dropout=0.3):
        super().__init__()
        first_hidden, second_hidden = hidden_sizes[0], hidden_sizes[1]

        # Recurrent stack. Creation order also fixes parameter-initialisation order.
        self.lstm1 = nn.LSTM(input_size, first_hidden, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.lstm2 = nn.LSTM(first_hidden, second_hidden, batch_first=True)
        self.dropout2 = nn.Dropout(dropout)

        # Classification head.
        self.fc1 = nn.Linear(second_hidden, 64)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(64, 32)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(32, num_classes)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, input_size) to (batch, num_classes) logits."""
        seq, _ = self.lstm1(x)
        seq, _ = self.lstm2(self.dropout1(seq))
        final_step = self.dropout2(seq)[:, -1, :]
        hidden = self.relu1(self.fc1(final_step))
        hidden = self.relu2(self.fc2(hidden))
        return self.fc3(hidden)


class StudentLSTM(nn.Module):
    """Single-layer LSTM classifier trained by distillation from the teacher.

    Layout: LSTM(input_size -> hidden_size) -> Dropout -> last time step
    -> Linear(32) -> ReLU -> Linear(num_classes). The output is raw logits.
    Submodule attribute names form the ``state_dict`` keys of saved checkpoints.
    """

    def __init__(self, input_size, hidden_size, num_classes, dropout=0.2):
        super().__init__()
        # Creation order also fixes parameter-initialisation order.
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_size, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, input_size) to (batch, num_classes) logits."""
        sequence_out = self.dropout(self.lstm(x)[0])
        last_step = sequence_out[:, -1, :]
        return self.fc2(self.relu(self.fc1(last_step)))

# ==========================================================
# üéì KNOWLEDGE DISTILLATION LOSS
# ==========================================================

class DistillationLoss(nn.Module):
    """Blend of a temperature-softened teacher-matching term and hard-label cross-entropy.

    loss = alpha * T**2 * KL(softmax(teacher / T) || softmax(student / T))
           + (1 - alpha) * CE(student, labels)

    The KL term is averaged over the batch (``reduction='batchmean'``).
    """

    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()
        self.temperature = temperature  # softening temperature T
        self.alpha = alpha              # weight of the soft (teacher) term
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, labels):
        T = self.temperature
        student_log_probs = torch.log_softmax(student_logits / T, dim=1)
        teacher_probs = torch.softmax(teacher_logits / T, dim=1)
        # T**2 factor as in standard distillation (presumably to keep the soft-term
        # gradient scale comparable across temperatures).
        soft_term = self.kl_loss(student_log_probs, teacher_probs) * (T ** 2)
        hard_term = self.ce_loss(student_logits, labels)
        return self.alpha * soft_term + (1 - self.alpha) * hard_term

# ==========================================================
# üèãÔ∏è TRAINING FUNCTIONS (FULL FILE AT ONCE)
# ==========================================================

def train_on_file(model, filepath, scaler, pca, label_encoder, optimizer,
                  criterion, device, batch_size=512, is_distillation=False,
                  teacher_model=None):
    """Run one training pass over a single CSV file, loaded entirely into memory.

    The file is preprocessed with ``load_and_process_file`` and iterated in shuffled
    mini-batches; each row is fed to the model as a length-1 sequence.

    Args:
        model: model being trained (switched to train mode).
        filepath: CSV file to train on.
        scaler, pca, label_encoder: fitted preprocessing objects.
        optimizer: optimizer over ``model``'s parameters.
        criterion: called as ``criterion(logits, labels)``, or as
            ``criterion(student_logits, teacher_logits, labels)`` when distilling.
        device: device each batch is moved to.
        batch_size: mini-batch size.
        is_distillation: when True and ``teacher_model`` is given, also pass the
            teacher's logits to ``criterion``.
        teacher_model: teacher network; put in eval mode and run without gradients.

    Returns:
        Sample-weighted mean loss over the file, or 0 when the file could not be
        processed or yielded no rows.
    """
    X_file, y_file = load_and_process_file(filepath, scaler, pca, label_encoder)

    if X_file is None:
        return 0

    dataset = FullFileDataset(X_file, y_file)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    model.train()
    if teacher_model is not None:
        teacher_model.eval()

    distilling = is_distillation and teacher_model is not None
    total_loss = 0
    total_samples = 0

    for X_batch, y_batch in dataloader:
        X_batch = X_batch.unsqueeze(1).to(device)  # add the length-1 sequence dimension
        y_batch = y_batch.to(device)

        optimizer.zero_grad()
        outputs = model(X_batch)

        if distilling:
            with torch.no_grad():
                teacher_outputs = teacher_model(X_batch)
            loss = criterion(outputs, teacher_outputs, y_batch)
        else:
            loss = criterion(outputs, y_batch)

        loss.backward()
        optimizer.step()

        batch_n = len(y_batch)
        total_loss += loss.item() * batch_n
        total_samples += batch_n

    # Clean up once per file. The previous per-batch clear_memory() forced a full
    # gc.collect() plus torch.cuda.empty_cache() on every step, which is slow,
    # while this cleanup already releases the file's data.
    del X_file, y_file, dataset, dataloader
    clear_memory()

    return total_loss / total_samples if total_samples > 0 else 0


def evaluate_on_files(model, file_list, scaler, pca, label_encoder,
                      criterion, device, batch_size=512):
    """Compute the sample-weighted mean loss and accuracy of ``model`` over several files.

    Each file is loaded and preprocessed in full, then scored in order (no
    shuffling) without gradients. Files that fail to load are skipped.

    Args:
        model: model to evaluate (switched to eval mode).
        file_list: CSV paths to evaluate on.
        scaler, pca, label_encoder: fitted preprocessing objects.
        criterion: called as ``criterion(logits, labels)``.
        device: device each batch is moved to.
        batch_size: inference batch size.

    Returns:
        ``(avg_loss, accuracy)``; both are 0 when no rows were evaluated.
    """
    model.eval()
    total_loss = 0
    correct = 0
    total_samples = 0

    for filepath in file_list:
        X_file, y_file = load_and_process_file(filepath, scaler, pca, label_encoder)

        if X_file is None:
            continue

        dataset = FullFileDataset(X_file, y_file)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

        with torch.no_grad():
            for X_batch, y_batch in dataloader:
                X_batch = X_batch.unsqueeze(1).to(device)  # length-1 sequence dimension
                y_batch = y_batch.to(device)

                outputs = model(X_batch)
                batch_n = len(y_batch)
                total_loss += criterion(outputs, y_batch).item() * batch_n
                correct += (outputs.argmax(dim=1) == y_batch).sum().item()
                total_samples += batch_n

        # Per-file cleanup only; a gc.collect()/empty_cache() per batch just slowed evaluation.
        del X_file, y_file, dataset, dataloader
        clear_memory()

    accuracy = correct / total_samples if total_samples > 0 else 0
    avg_loss = total_loss / total_samples if total_samples > 0 else 0

    return avg_loss, accuracy

# ==========================================================
# üìÇ DOWNLOAD & SPLIT DATASET (169 FILES)
# ==========================================================

print("=" * 80)
print("üì• Downloading CIC-IoT-2023 Dataset from Kaggle...")
print("=" * 80)

dataset_dir = kagglehub.dataset_download("akashdogra/cic-iot-2023")
print(f"‚úÖ Dataset downloaded to: {dataset_dir}")

# Full paths of the CSV files directly inside the dataset directory, sorted so the
# split below is deterministic.
csv_files = sorted(
    os.path.join(dataset_dir, name)
    for name in os.listdir(dataset_dir)
    if name.endswith(".csv")
)

print(f"üìÇ Found {len(csv_files)} CSV files.")

# Split the ordered file list 60 / 20 / 20 into train / validation / test.
n_files = len(csv_files)
train_idx, val_idx = int(n_files * 0.60), int(n_files * 0.80)
train_files, val_files, test_files = (
    csv_files[:train_idx],
    csv_files[train_idx:val_idx],
    csv_files[val_idx:],
)

print(f"\nüìä Dataset Split (from {n_files} files):")
for split_label, split_list in (
    ("Training:", train_files),
    ("Validation:", val_files),
    ("Testing:", test_files),
):
    print(f"   {split_label:<11} {len(split_list)} files")

# ==========================================================
# üè∑Ô∏è FIT PREPROCESSING (SCAN ALL TRAINING FILES FOR LABELS)
# ==========================================================

print("\n" + "=" * 80)
print("üè∑Ô∏è  Fitting Preprocessing - Scanning ALL Training Files...")
print("=" * 80)

# Labels are collected from ALL training files so the LabelEncoder knows every
# class. The scaler and PCA are fitted on a small feature sample only.
n_sample_files = 10    # fit scaler/PCA on the first N training files...
n_sample_rows = 1000   # ...using only the first M rows of each

all_labels = set()
sample_data = []

print(f"Scanning {len(train_files)} training files for all unique labels...")
for i, filepath in enumerate(train_files):
    try:
        # Header only (no data rows) to locate the label column cheaply.
        header = pd.read_csv(filepath, nrows=0).columns
        label_col = "Label" if "Label" in header else header[-1]

        # Read only the label column to save memory (previously the whole file was read).
        labels = pd.read_csv(filepath, usecols=[label_col], low_memory=False)[label_col]
        unique_labels = labels.dropna().astype(str).unique()
        all_labels.update(unique_labels)
        del labels

        print(f"  File {i+1}/{len(train_files)}: {os.path.basename(filepath)} - Found {len(unique_labels)} unique labels (Total: {len(all_labels)})")

        # Sample features from the first few files, reading only the rows kept.
        if i < n_sample_files:
            df_sample = pd.read_csv(filepath, nrows=n_sample_rows, low_memory=False).dropna()
            sample_data.append(encode_objects(df_sample.drop(columns=[label_col])))
            del df_sample

        gc.collect()

    except Exception as e:
        print(f"  ‚ö†Ô∏è  Error reading {os.path.basename(filepath)}: {e}")
        continue

if not sample_data:
    raise RuntimeError("No feature sample could be read from the training files; "
                       "cannot fit the scaler/PCA.")

# Sorted list so the encoding is reproducible
all_labels = sorted(all_labels)

# Fit label encoder with ALL labels
label_encoder = LabelEncoder()
label_encoder.fit(all_labels)
n_classes = len(label_encoder.classes_)

print(f"\n‚úÖ LabelEncoder fitted with {n_classes} classes")
print(f"   Classes found: {', '.join(label_encoder.classes_[:10])}{'...' if n_classes > 10 else ''}")

# Fit scaler incrementally over the per-file samples
scaler = StandardScaler()
for data in sample_data:
    scaler.partial_fit(data)

print(f"‚úÖ Scaler fitted on {len(sample_data)} file samples")

# Fit PCA incrementally on the scaled samples
n_features = sample_data[0].shape[1]
n_components = min(30, n_features)

pca = IncrementalPCA(n_components=n_components)
for data in sample_data:
    pca.partial_fit(scaler.transform(data))

print(f"‚úÖ PCA fitted with {n_components} components (from {n_features} features)")

del all_labels, sample_data
clear_memory()
print_memory_stats()

# ==========================================================
# üéì STAGE 1: TRAIN TEACHER MODEL
# ==========================================================

print("\n" + "=" * 80)
print("üéì STAGE 1: Training Teacher Model")
print("=" * 80)

# Initialize teacher model with REDUCED sizes
teacher_model = TeacherLSTM(
    input_size=n_components,
    hidden_sizes=[128, 64],  # Reduced from [256, 128, 64]
    num_classes=n_classes,
    dropout=0.3
).to(device)

teacher_params = sum(p.numel() for p in teacher_model.parameters())
print(f"\nüèóÔ∏è  Teacher Model: {teacher_params:,} parameters")
print(f"   Architecture: Input({n_components}) ‚Üí LSTM(128) ‚Üí LSTM(64) ‚Üí FC(64) ‚Üí FC(32) ‚Üí Output({n_classes})")

# Optimizer and criterion
teacher_optimizer = optim.Adam(teacher_model.parameters(), lr=0.001)
teacher_criterion = nn.CrossEntropyLoss()

# Training settings.
# NOTE: an "epoch" here is ONE rotating window of `files_per_epoch` files, not a
# full pass: 3 epochs x 20 files = 60 files in total, so not every training file
# is seen. Use at least ceil(len(train_files) / files_per_epoch) epochs to cover all.
epochs_teacher = 3
batch_size = 512
files_per_epoch = 20

best_teacher_acc = 0
patience_counter = 0
patience = 5  # early-stopping patience; cannot trigger while epochs_teacher < patience
# Only reload 'teacher_model.pth' if THIS run wrote it: os.path.exists() alone would
# also pick up a stale checkpoint from a previous run or an earlier cell.
teacher_checkpoint_saved = False

print("\nüöÄ Training Teacher Model...")
print(f"   Batch Size: {batch_size}")
print(f"   Files per Epoch Cycle: {files_per_epoch}")
print(f"   Total Training Files: {len(train_files)}")
print(f"   Epochs: {epochs_teacher}")

for epoch in range(epochs_teacher):
    print(f"\n{'='*80}")
    print(f"TEACHER EPOCH {epoch+1}/{epochs_teacher}")
    print(f"{'='*80}")

    # Rotating window of `files_per_epoch` files, wrapping past the end of the list.
    n_train = len(train_files)
    start_idx = (epoch * files_per_epoch) % n_train
    n_selected = min(files_per_epoch, n_train)
    selected_files = [train_files[(start_idx + k) % n_train] for k in range(n_selected)]
    last_idx = (start_idx + n_selected - 1) % n_train

    print(f"Training on {len(selected_files)} files (indices {start_idx} to {last_idx} inclusive, modulo {n_train})")

    # Train on each file
    epoch_losses = []
    for i, filepath in enumerate(selected_files):
        print(f"\n  üìÇ File {i+1}/{len(selected_files)}: {os.path.basename(filepath)}")

        train_loss = train_on_file(
            teacher_model, filepath, scaler, pca, label_encoder,
            teacher_optimizer, teacher_criterion, device, batch_size=batch_size
        )

        epoch_losses.append(train_loss)
        print(f"     Loss: {train_loss:.4f}")
        print_memory_stats()

    avg_train_loss = np.mean(epoch_losses)

    # Validate on subset of validation files
    print(f"\n  üìä Validating...")
    val_loss, val_acc = evaluate_on_files(
        teacher_model, val_files[:5], scaler, pca, label_encoder,
        teacher_criterion, device, batch_size=batch_size
    )

    print(f"\n  üìà Epoch Summary:")
    print(f"     Avg Train Loss: {avg_train_loss:.4f}")
    print(f"     Val Loss: {val_loss:.4f}")
    print(f"     Val Accuracy: {val_acc:.4f}")

    # Save best model
    if val_acc > best_teacher_acc:
        best_teacher_acc = val_acc
        torch.save(teacher_model.state_dict(), 'teacher_model.pth')
        teacher_checkpoint_saved = True
        print(f"  ‚úÖ Best teacher model saved! Val Acc: {val_acc:.4f}")
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"\n‚ö†Ô∏è  Early stopping triggered at epoch {epoch+1}")
        break

    clear_memory()

print("\n‚úÖ Teacher Model Training Complete!")
print(f"   Best Validation Accuracy: {best_teacher_acc:.4f}")

# Restore the best weights saved during THIS run; otherwise keep the final-epoch weights.
if teacher_checkpoint_saved:
    teacher_model.load_state_dict(torch.load('teacher_model.pth', map_location=device))
    print("   Loaded best teacher model from disk")
else:
    print("   ‚ö†Ô∏è  No saved model found, using final epoch weights")

# ==========================================================
# üéí STAGE 2: KNOWLEDGE DISTILLATION - TRAIN STUDENT
# ==========================================================

print("\n" + "=" * 80)
print("üéí STAGE 2: Knowledge Distillation - Training Student Model")
print("=" * 80)

# From here on the teacher is no longer optimized; it is only handed to
# train_on_file as the distillation source. Put it in eval mode explicitly so
# any train-only layers it may contain (e.g. dropout) do not perturb its
# outputs, regardless of which mode the last train/eval helper left it in.
teacher_model.eval()

# Initialize student model with REDUCED size
student_model = StudentLSTM(
    input_size=n_components,
    hidden_size=32,  # Single layer, reduced from [32, 16]
    num_classes=n_classes,
    dropout=0.2
).to(device)

# Compression ratio reported throughout the rest of the script.
student_params = sum(p.numel() for p in student_model.parameters())
reduction_ratio = teacher_params / student_params

print(f"\nüèóÔ∏è  Student Model: {student_params:,} parameters")
print(f"   Architecture: Input({n_components}) ‚Üí LSTM(32) ‚Üí FC(32) ‚Üí Output({n_classes})")
print(f"\nüìä Model Comparison:")
print(f"   Teacher Parameters: {teacher_params:,}")
print(f"   Student Parameters: {student_params:,}")
print(f"   Size Reduction:     {reduction_ratio:.1f}x smaller")

# Optimizer and distillation loss (temperature softens the logits; alpha is
# the weight given to the soft-target term, as printed below).
student_optimizer = optim.Adam(student_model.parameters(), lr=0.001)
distillation_criterion = DistillationLoss(temperature=4.0, alpha=0.7)

# Student-stage training state (early stopping reuses the shared `patience`).
epochs_student = 4
best_student_acc = 0
patience_counter = 0

print(f"\nüöÄ Training Student with Knowledge Distillation...")
print(f"   Temperature: {distillation_criterion.temperature}")
print(f"   Alpha (soft target weight): {distillation_criterion.alpha}")
print(f"   Batch Size: {batch_size}")
print(f"   Files per Epoch: {files_per_epoch}")

# Validation scores the student with plain cross-entropy on the hard labels
# (not the distillation criterion). The loss is stateless, so build it once
# instead of re-creating it every epoch.
val_criterion = nn.CrossEntropyLoss()

for epoch in range(epochs_student):
    print(f"\n{'='*80}")
    print(f"STUDENT EPOCH {epoch+1}/{epochs_student}")
    print(f"{'='*80}")

    # Select a rotating window of `files_per_epoch` training files. If the
    # window runs off the end of train_files, it wraps around to the start
    # (only when there are more files than a single window holds).
    start_idx = (epoch * files_per_epoch) % len(train_files)
    end_idx = min(start_idx + files_per_epoch, len(train_files))
    selected_files = train_files[start_idx:end_idx]

    wrapped = 0
    if len(selected_files) < files_per_epoch and len(train_files) > files_per_epoch:
        wrapped = files_per_epoch - len(selected_files)
        selected_files = selected_files + train_files[:wrapped]

    # end_idx is exclusive. Report the wrapped-around head too, so the log
    # lists the files that are actually trained on.
    if wrapped:
        print(f"Training on {len(selected_files)} files "
              f"(indices {start_idx} to {end_idx}, then 0 to {wrapped})")
    else:
        print(f"Training on {len(selected_files)} files (indices {start_idx} to {end_idx})")

    # Train with distillation; the teacher is passed along (presumably to
    # provide the soft targets consumed by distillation_criterion).
    epoch_losses = []
    for i, filepath in enumerate(selected_files):
        print(f"\n  üìÇ File {i+1}/{len(selected_files)}: {os.path.basename(filepath)}")

        train_loss = train_on_file(
            student_model, filepath, scaler, pca, label_encoder,
            student_optimizer, distillation_criterion, device,
            batch_size=batch_size, is_distillation=True, teacher_model=teacher_model
        )

        epoch_losses.append(train_loss)
        print(f"     Loss: {train_loss:.4f}")
        print_memory_stats()

    avg_train_loss = np.mean(epoch_losses)

    # Validate on the same first-5 validation files used for the teacher.
    print(f"\n  üìä Validating...")
    val_loss, val_acc = evaluate_on_files(
        student_model, val_files[:5], scaler, pca, label_encoder,
        val_criterion, device, batch_size=batch_size
    )

    print(f"\n  üìà Epoch Summary:")
    print(f"     Avg Train Loss: {avg_train_loss:.4f}")
    print(f"     Val Loss: {val_loss:.4f}")
    print(f"     Val Accuracy: {val_acc:.4f}")

    # Save best model; best_student_acc is raised only together with the save.
    if val_acc > best_student_acc:
        best_student_acc = val_acc
        torch.save(student_model.state_dict(), 'student_model.pth')
        print(f"  ‚úÖ Best student model saved! Val Acc: {val_acc:.4f}")
        patience_counter = 0
    else:
        patience_counter += 1

    # `patience` is the same early-stopping budget used for the teacher stage.
    if patience_counter >= patience:
        print(f"\n‚ö†Ô∏è  Early stopping triggered at epoch {epoch+1}")
        break

    clear_memory()

print("\n‚úÖ Student Model Training Complete!")
print(f"   Best Validation Accuracy: {best_student_acc:.4f}")

# Load the best checkpoint from THIS run. best_student_acc starts at 0 and is
# raised (strictly) only at the same moment student_model.pth is written, so a
# positive value means the file on disk came from the loop above. A bare
# os.path.exists() check would also pick up a stale file from an earlier run,
# possibly with a different architecture. map_location restores the tensors
# onto the active `device`.
if best_student_acc > 0 and os.path.exists('student_model.pth'):
    student_model.load_state_dict(torch.load('student_model.pth', map_location=device))
    print("   Loaded best student model from disk")
else:
    print("   ‚ö†Ô∏è  No saved model found, using final epoch weights")

# ==========================================================
# üìà STAGE 3: FINAL EVALUATION
# ==========================================================

print("\n" + "=" * 80)
print("üìà STAGE 3: Final Evaluation on Test Set")
print("=" * 80)

def evaluate_model_detailed(model, model_name, file_list, batch_size=512):
    """Evaluate a model on a list of files and print weighted classification metrics.

    Each file is loaded and transformed by ``load_and_process_file`` with the
    module-level ``scaler``, ``pca`` and ``label_encoder``. Files for which it
    returns ``None`` features are skipped. Inference runs under
    ``torch.no_grad()`` on the module-level ``device``, and the predicted class
    is the arg-max over the model outputs.

    Args:
        model: Network to evaluate. It is switched to eval mode and left there.
        model_name: Label used in the printed headers and summary.
        file_list: Paths of the files to evaluate on.
        batch_size: Inference batch size (defaults to the previously fixed 512).

    Returns:
        ``(y_true, y_pred, accuracy, precision, recall, f1)``. ``y_pred`` holds
        the arg-max class indices. ``y_true`` holds the labels yielded by the
        dataset (presumably label_encoder indices). Both are NumPy arrays. The
        metrics are weighted averages over classes with ``zero_division=0``.

    Raises:
        RuntimeError: If no sample could be loaded from any file in ``file_list``.
    """
    print(f"\n{'='*60}")
    print(f"Evaluating {model_name}...")
    print(f"{'='*60}")

    model.eval()
    # Keep one array per batch and concatenate once at the end. Extending a
    # Python list element by element with NumPy scalars is slow on large sets.
    true_chunks = []
    pred_chunks = []

    for i, filepath in enumerate(file_list):
        print(f"Processing file {i+1}/{len(file_list)}: {os.path.basename(filepath)}")

        X_file, y_file = load_and_process_file(filepath, scaler, pca, label_encoder)

        if X_file is None:
            continue

        dataset = FullFileDataset(X_file, y_file)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

        with torch.no_grad():
            for X_batch, y_batch in dataloader:
                # Insert a singleton dimension at axis 1 before the forward
                # pass (presumably a length-1 time axis for the LSTM; confirm
                # against the model definitions).
                outputs = model(X_batch.unsqueeze(1).to(device))
                predicted = outputs.argmax(dim=1)

                true_chunks.append(y_batch.cpu().numpy())
                pred_chunks.append(predicted.cpu().numpy())

        # Release per-file data once per file. clear_memory() used to run after
        # every batch; if it performs gc / cache flushing (it is defined
        # elsewhere), calling it per batch mostly costs time.
        del X_file, y_file, dataset, dataloader
        clear_memory()

    if not true_chunks:
        # Warnings are suppressed globally in this script, so the sklearn
        # metrics would otherwise silently degrade to nan/0 on empty input.
        raise RuntimeError(
            f"{model_name}: no samples could be loaded from {len(file_list)} file(s)"
        )

    y_true_all = np.concatenate(true_chunks)
    y_pred_all = np.concatenate(pred_chunks)

    accuracy = accuracy_score(y_true_all, y_pred_all)
    precision = precision_score(y_true_all, y_pred_all, average='weighted', zero_division=0)
    recall = recall_score(y_true_all, y_pred_all, average='weighted', zero_division=0)
    f1 = f1_score(y_true_all, y_pred_all, average='weighted', zero_division=0)

    print(f"\nüìä {model_name} Performance:")
    print(f"   Accuracy:  {accuracy:.4f}")
    print(f"   Precision: {precision:.4f}")
    print(f"   Recall:    {recall:.4f}")
    print(f"   F1-Score:  {f1:.4f}")

    return y_true_all, y_pred_all, accuracy, precision, recall, f1

# Score teacher and distilled student on the identical held-out test files so
# their metrics can be compared side by side below.
teacher_results = evaluate_model_detailed(
    teacher_model,
    model_name="TEACHER MODEL",
    file_list=test_files,
)
student_results = evaluate_model_detailed(
    student_model,
    model_name="STUDENT MODEL (Distilled)",
    file_list=test_files,
)

# ==========================================================
# üìä GENERATE REPORTS
# ==========================================================

print("\n" + "=" * 80)
print("üìä Generating Final Report...")
print("=" * 80)

# evaluate_model_detailed returns (y_true, y_pred, acc, prec, rec, f1); the
# confusion matrix below is drawn for the student's predictions.
y_true, y_pred, s_acc, s_prec, s_rec, s_f1 = student_results
_, _, t_acc, t_prec, t_rec, t_f1 = teacher_results

# Confusion matrix. Pass the full label range explicitly so the matrix is
# always len(classes_) x len(classes_). Without `labels=`, sklearn keeps only
# the classes that occur in y_true/y_pred, and the tick labels below would no
# longer match the rows/columns (recent matplotlib raises on the count
# mismatch). This assumes targets are label_encoder indices 0..n-1, which the
# tick labels already rely on.
class_indices = np.arange(len(label_encoder.classes_))
cm = confusion_matrix(y_true, y_pred, labels=class_indices)

plt.figure(figsize=(20, 16))
sns.heatmap(cm, annot=True, fmt='d', cmap='Greens',
            xticklabels=label_encoder.classes_,
            yticklabels=label_encoder.classes_,
            cbar_kws={'label': 'Count'})
plt.title('Student Model Confusion Matrix (Knowledge Distillation)', fontsize=16, pad=20)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.xticks(rotation=45, ha='right', fontsize=8)
plt.yticks(rotation=0, fontsize=8)
plt.tight_layout()
plt.savefig('student_confusion_matrix.png', dpi=300, bbox_inches='tight')
# pyplot keeps figures alive until closed; release this large one now.
plt.close()
print("‚úÖ Confusion matrix saved as 'student_confusion_matrix.png'")

# Performance comparison: student accuracy as a percentage of the teacher's
# (0 when the teacher accuracy is not positive, to avoid dividing by zero).
performance_retention = (s_acc / t_acc) * 100 if t_acc > 0 else 0

print("\n" + "=" * 80)
print("üìä FINAL COMPARISON: TEACHER vs STUDENT")
print("=" * 80)
print(f"\n{'Metric':<15} {'Teacher':<15} {'Student':<15} {'Difference':<15}")
print("=" * 80)
print(f"{'Accuracy':<15} {t_acc:<15.4f} {s_acc:<15.4f} {(s_acc-t_acc):<15.4f}")
print(f"{'Precision':<15} {t_prec:<15.4f} {s_prec:<15.4f} {(s_prec-t_prec):<15.4f}")
print(f"{'Recall':<15} {t_rec:<15.4f} {s_rec:<15.4f} {(s_rec-t_rec):<15.4f}")
print(f"{'F1-Score':<15} {t_f1:<15.4f} {s_f1:<15.4f} {(s_f1-t_f1):<15.4f}")
print(f"{'Parameters':<15} {teacher_params:<15,} {student_params:<15,} {'-':<15}")
print(f"{'Model Size':<15} {'1.0x':<15} {f'{1/reduction_ratio:.2f}x':<15} {f'{reduction_ratio:.1f}x smaller':<15}")
print("=" * 80)

print(f"\nüéØ Performance Retention: {performance_retention:.2f}%")
print(f"üéØ Model Size Reduction: {reduction_ratio:.1f}x smaller")
print(f"üéØ Parameter Reduction: {((teacher_params - student_params) / teacher_params * 100):.1f}% fewer parameters")

# ==========================================================
# üíæ SAVE MODELS AND PREPROCESSING OBJECTS
# ==========================================================

print("\n" + "=" * 80)
print("üíæ Saving Models and Preprocessing Objects")
print("=" * 80)

# Self-describing PyTorch checkpoints: weights plus the constructor arguments
# and headline test metrics for each model. Dict order fixes the write order
# (teacher first, then student).
model_checkpoints = {
    'teacher_model_complete.pth': {
        'model_state_dict': teacher_model.state_dict(),
        'input_size': n_components,
        'hidden_sizes': [128, 64],
        'num_classes': n_classes,
        'accuracy': t_acc,
        'params': teacher_params,
    },
    'student_model_complete.pth': {
        'model_state_dict': student_model.state_dict(),
        'input_size': n_components,
        'hidden_size': 32,
        'num_classes': n_classes,
        'accuracy': s_acc,
        'params': student_params,
    },
}
for checkpoint_path, checkpoint in model_checkpoints.items():
    torch.save(checkpoint, checkpoint_path)
    print(f"‚úÖ Saved: {checkpoint_path}")

# Fitted preprocessing pipeline, pickled together so inference can apply the
# same scaling / PCA / label mapping.
preprocessing_objects = dict(scaler=scaler, pca=pca, label_encoder=label_encoder)
with open('preprocessing.pkl', 'wb') as f:
    pickle.dump(preprocessing_objects, f)
print("‚úÖ Saved: preprocessing.pkl")

# JSON metadata; values are cast to builtin int/float so json can encode them.
split_sizes = {
    'total_files': len(csv_files),
    'train_files': len(train_files),
    'val_files': len(val_files),
    'test_files': len(test_files),
}
metadata = {
    'n_classes': int(n_classes),
    'n_features': int(n_features),
    'n_components': int(n_components),
    'teacher_params': int(teacher_params),
    'student_params': int(student_params),
    'teacher_accuracy': float(t_acc),
    'student_accuracy': float(s_acc),
    'size_reduction': float(reduction_ratio),
    'performance_retention': float(performance_retention),
    **split_sizes,
    'classes': label_encoder.classes_.tolist(),
}
with open('model_metadata.json', 'w') as f:
    json.dump(metadata, f, indent=4)
print("‚úÖ Saved: model_metadata.json")

# Create summary: a plain-text report of the dataset split, both models'
# test metrics and the compression achieved, written in one call.
with open('model_summary.txt', 'w') as f:
    f.writelines([
        "=" * 80 + "\n",
        "KNOWLEDGE DISTILLATION - PYTORCH MODEL SUMMARY\n",
        "=" * 80 + "\n\n",

        "DATASET INFORMATION:\n",
        f"  Total Files: {len(csv_files)}\n",
        f"  Training Files: {len(train_files)}\n",
        f"  Validation Files: {len(val_files)}\n",
        f"  Test Files: {len(test_files)}\n\n",

        "TEACHER MODEL:\n",
        "  Architecture: LSTM [128, 64]\n",
        f"  Parameters: {teacher_params:,}\n",
        f"  Accuracy: {t_acc:.4f}\n",
        f"  Precision: {t_prec:.4f}\n",
        f"  Recall: {t_rec:.4f}\n",
        f"  F1-Score: {t_f1:.4f}\n\n",

        "STUDENT MODEL (DISTILLED):\n",
        "  Architecture: LSTM [32]\n",
        f"  Parameters: {student_params:,}\n",
        f"  Accuracy: {s_acc:.4f}\n",
        f"  Precision: {s_prec:.4f}\n",
        f"  Recall: {s_rec:.4f}\n",
        f"  F1-Score: {s_f1:.4f}\n\n",

        "COMPRESSION METRICS:\n",
        f"  Size Reduction: {reduction_ratio:.1f}x smaller\n",
        f"  Performance Retention: {performance_retention:.2f}%\n",
        f"  Parameter Reduction: {((teacher_params - student_params) / teacher_params * 100):.1f}%\n\n",

        "FILES GENERATED:\n",
        "  - teacher_model_complete.pth (Teacher model with metadata)\n",
        "  - student_model_complete.pth (Student model with metadata)\n",
        "  - preprocessing.pkl (Scaler, PCA, Label Encoder)\n",
        "  - model_metadata.json (Model specifications)\n",
        "  - student_confusion_matrix.png (Confusion matrix visualization)\n",
    ])

print("‚úÖ Saved: model_summary.txt")

# Closing banner echoing the headline numbers to the console.
_banner_rule = "=" * 80
print("\n" + _banner_rule)
print("üéâ KNOWLEDGE DISTILLATION COMPLETE!")
print(_banner_rule)
print(f"\n‚ú® Successfully processed all {len(csv_files)} files!")
print(f"‚ú® Teacher Model: {teacher_params:,} parameters ‚Üí Accuracy: {t_acc:.4f}")
print(f"‚ú® Student Model: {student_params:,} parameters ‚Üí Accuracy: {s_acc:.4f}")
print(f"‚ú® Compression: {reduction_ratio:.1f}x smaller with {performance_retention:.1f}% performance retention")
print("\nüì¶ All models saved and ready for deployment!")
print(_banner_rule)