# 1. Data Collection

### Import Necessary Library

In [1]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import roc_auc_score, roc_curve, precision_score, recall_score, f1_score, accuracy_score
from os import cpu_count
from math import floor
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
# from catboost import CatBoostClassifier
from sklearn.naive_bayes import GaussianNB
# import shap
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=True)
# shap.initjs()

### Load Dataset

In [1]:
#data=pd.read_csv("../dataset/NF_TON_IoT_V2/NF-ToN-IoT-v2.csv")

import pandas as pd
import random

# File path
csv_path = "../dataset/NF_TON_IoT_V2/NF-ToN-IoT-v2.csv"

# STEP 1: Count total rows (without loading the file into memory)
with open(csv_path, 'r', encoding='utf-8') as f:
    total_lines = sum(1 for _ in f)

# STEP 2: Calculate how many lines to skip
sample_size = 1_000_000
lines_to_skip = sorted(random.sample(range(1, total_lines), total_lines - sample_size))

# STEP 3: Read just 1 million rows (skip most lines except header)
data = pd.read_csv(csv_path, skiprows=lines_to_skip)

# Confirm result
print(f"Loaded sample shape: {data.shape}")


Loaded sample shape: (999999, 45)


# 2. Data Preprocessing

### Summary of Stastics

In [2]:
data.head()

Unnamed: 0,IPV4_SRC_ADDR,L4_SRC_PORT,IPV4_DST_ADDR,L4_DST_PORT,PROTOCOL,L7_PROTO,IN_BYTES,IN_PKTS,OUT_BYTES,OUT_PKTS,...,TCP_WIN_MAX_IN,TCP_WIN_MAX_OUT,ICMP_TYPE,ICMP_IPV4_TYPE,DNS_QUERY_ID,DNS_QUERY_TYPE,DNS_TTL_ANSWER,FTP_COMMAND_RET_CODE,Label,Attack
0,192.168.1.79,51466,239.255.255.250,15600,17,0.0,63,1,0,0,...,0,0,0,0,0,0,0,0,0,Benign
1,192.168.1.193,64035,8.8.8.8,53,17,5.126,58,1,90,1,...,0,0,0,0,3579,1,299,0,0,Benign
2,192.168.1.79,53929,192.168.1.255,15600,17,0.0,63,1,0,0,...,0,0,0,0,0,0,0,0,0,Benign
3,192.168.1.193,49220,192.168.1.33,4444,6,0.0,151312,200,87548,167,...,16425,2941,0,0,0,0,0,0,1,ransomware
4,192.168.1.193,49236,192.168.1.37,4444,6,0.0,168880,214,38472,151,...,16425,2898,0,0,0,0,0,0,1,ransomware


In [3]:
data.dtypes

IPV4_SRC_ADDR                   object
L4_SRC_PORT                      int64
IPV4_DST_ADDR                   object
L4_DST_PORT                      int64
PROTOCOL                         int64
L7_PROTO                       float64
IN_BYTES                         int64
IN_PKTS                          int64
OUT_BYTES                        int64
OUT_PKTS                         int64
TCP_FLAGS                        int64
CLIENT_TCP_FLAGS                 int64
SERVER_TCP_FLAGS                 int64
FLOW_DURATION_MILLISECONDS       int64
DURATION_IN                      int64
DURATION_OUT                     int64
MIN_TTL                          int64
MAX_TTL                          int64
LONGEST_FLOW_PKT                 int64
SHORTEST_FLOW_PKT                int64
MIN_IP_PKT_LEN                   int64
MAX_IP_PKT_LEN                   int64
SRC_TO_DST_SECOND_BYTES        float64
DST_TO_SRC_SECOND_BYTES        float64
RETRANSMITTED_IN_BYTES           int64
RETRANSMITTED_IN_PKTS    

In [4]:
data.Label.value_counts()

Label
1    639952
0    360047
Name: count, dtype: int64

In [5]:
data.Attack.value_counts()

Attack
Benign        360047
scanning      222925
xss           144988
ddos          119472
password       68375
dos            42024
injection      40518
backdoor         977
mitm             468
ransomware       205
Name: count, dtype: int64

In [6]:
data=data.drop(columns=['L4_SRC_PORT', 'L4_DST_PORT']) #dropping metadata

In [7]:
training_set = data.sample(frac=0.05, replace=False,random_state=42)
# 1%train, 99% test
testing_set = data.drop(index=training_set.index)

In [8]:
training_set.Attack.value_counts()

Attack
Benign        18040
scanning      11293
xss            7194
ddos           5902
password       3403
dos            2063
injection      2024
backdoor         56
mitm             13
ransomware       12
Name: count, dtype: int64

In [9]:
attacks=training_set.Attack.unique()
attacks=['Benign','Reconnaissance', 'DDoS', 'DoS', 'Theft']

In [11]:
# --- Compute correlation only for numeric features ---
corr = training_set.select_dtypes(include='number').corr()

# Find highly correlated features (> 0.9)
corr_features = {
    corr.columns[i]: corr.columns[(corr > 0.9).iloc[i]].values.tolist()
    for i in range(corr.shape[0])
}

# Group correlated features into sets (no duplicates)
corr_list = []
for key, value in corr_features.items():
    have_set = any(key in s for s in corr_list)
    if not have_set and len(value) > 1:
        corr_list.append(value)

# Output the groups of highly correlated features
for i, group in enumerate(corr_list, 1):
    print(f"Group {i}: {group}")

Group 1: ['IN_BYTES', 'IN_PKTS', 'OUT_BYTES', 'RETRANSMITTED_OUT_BYTES', 'NUM_PKTS_1024_TO_1514_BYTES']
Group 2: ['TCP_FLAGS', 'SERVER_TCP_FLAGS']
Group 3: ['MIN_TTL', 'MAX_TTL']
Group 4: ['LONGEST_FLOW_PKT', 'MAX_IP_PKT_LEN']
Group 5: ['RETRANSMITTED_OUT_BYTES', 'RETRANSMITTED_OUT_PKTS']
Group 6: ['ICMP_TYPE', 'ICMP_IPV4_TYPE']


In [12]:
corr_list

[['IN_BYTES',
  'IN_PKTS',
  'OUT_BYTES',
  'RETRANSMITTED_OUT_BYTES',
  'NUM_PKTS_1024_TO_1514_BYTES'],
 ['TCP_FLAGS', 'SERVER_TCP_FLAGS'],
 ['MIN_TTL', 'MAX_TTL'],
 ['LONGEST_FLOW_PKT', 'MAX_IP_PKT_LEN'],
 ['RETRANSMITTED_OUT_BYTES', 'RETRANSMITTED_OUT_PKTS'],
 ['ICMP_TYPE', 'ICMP_IPV4_TYPE']]

In [13]:
#correction because NUM_PKTS_1024_TO_1514_BYTES appears twice
corr_list[2]=corr_list[2][:-1]
corr_list

[['IN_BYTES',
  'IN_PKTS',
  'OUT_BYTES',
  'RETRANSMITTED_OUT_BYTES',
  'NUM_PKTS_1024_TO_1514_BYTES'],
 ['TCP_FLAGS', 'SERVER_TCP_FLAGS'],
 ['MIN_TTL'],
 ['LONGEST_FLOW_PKT', 'MAX_IP_PKT_LEN'],
 ['RETRANSMITTED_OUT_BYTES', 'RETRANSMITTED_OUT_PKTS'],
 ['ICMP_TYPE', 'ICMP_IPV4_TYPE']]

In [17]:
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import torch.nn.functional as F
import matplotlib.pyplot as plt
from collections import defaultdict
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc, classification_report
from sklearn.preprocessing import label_binarize
import random

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Load dataset with sampling
csv_path = "../dataset/NF_TON_IoT_V2/NF-ToN-IoT-v2.csv"
with open(csv_path, 'r', encoding='utf-8') as f:
    total_lines = sum(1 for _ in f) - 1  # Exclude header
sample_size = 1_000_000
lines_to_skip = sorted(random.sample(range(1, total_lines + 1), total_lines - sample_size))
data = pd.read_csv(csv_path, skiprows=lines_to_skip)
print(f"Loaded sample shape: {data.shape}")

# Handle missing values
for col in data.columns:
    if data[col].dtype in ['int64', 'float64']:
        data[col] = data[col].fillna(data[col].mean())
    else:
        data[col] = data[col].fillna(0)

# Encode categorical features
le_protocol = LabelEncoder()
data['PROTOCOL'] = le_protocol.fit_transform(data['PROTOCOL'].astype(str))
le_l7_proto = LabelEncoder()
data['L7_PROTO'] = le_l7_proto.fit_transform(data['L7_PROTO'].astype(str))

# Function to prepare data
def prepare_data(df, mode='binary'):
    if mode == 'binary':
        y = df['Label'].values
        num_classes = 2
        le_attack = None
        class_names = ['Benign', 'Attack']
    elif mode == 'multiclass':
        le_attack = LabelEncoder()
        df['Attack'] = le_attack.fit_transform(df['Attack'].astype(str))
        y = df['Attack'].values
        unique_y = np.unique(y)
        num_classes = len(unique_y)
        label_map = {old: new for new, old in enumerate(unique_y)}
        y = np.array([label_map[yi] for yi in y])
        class_names = [le_attack.classes_[unique_y[i]] for i in range(num_classes)]
    else:
        raise ValueError("Mode must be 'binary' or 'multiclass'")

    # Select numerical features (exclude Label, Attack, IP addresses, and ports)
    feature_cols = [col for col in df.columns if col not in ['Label', 'Attack', 'IPV4_SRC_ADDR', 'IPV4_DST_ADDR', 'L4_SRC_PORT', 'L4_DST_PORT']]
    X = df[feature_cols].values

    # Normalize numerical features
    scaler = StandardScaler()
    X = scaler.fit_transform(X)

    # Create preprocessed DataFrame
    df_processed = pd.DataFrame(X, columns=feature_cols)
    df_processed[mode.capitalize()] = y
    df_processed['L4_SRC_PORT'] = df['L4_SRC_PORT'].values
    df_processed['L4_DST_PORT'] = df['L4_DST_PORT'].values

    # Edge features (e.g., IN_BYTES, OUT_BYTES, IN_PKTS, OUT_PKTS)
    edge_feature_cols = ['IN_BYTES', 'OUT_BYTES', 'IN_PKTS', 'OUT_PKTS']
    edge_features = df[edge_feature_cols].values
    edge_scaler = StandardScaler()
    edge_features = edge_scaler.fit_transform(edge_features)

    return df_processed, y, num_classes, le_attack, class_names, edge_features, edge_feature_cols

# Function for graph construction with edge features
def build_graph(df_processed, y, edge_features, mode):
    src_ports = df_processed['L4_SRC_PORT'].values
    dst_ports = df_processed['L4_DST_PORT'].values
    unique_ports = np.unique(np.concatenate([src_ports, dst_ports]))
    port_to_idx = {port: idx for idx, port in enumerate(unique_ports)}

    node_features = defaultdict(list)
    node_labels = defaultdict(list)
    edge_index = []
    edge_attr = []

    for idx in range(len(df_processed)):
        row = df_processed.iloc[idx]
        src_port = row['L4_SRC_PORT']
        dst_port = row['L4_DST_PORT']
        features = row.drop([mode.capitalize(), 'L4_SRC_PORT', 'L4_DST_PORT']).values
        label = row[mode.capitalize()]
        src_idx = port_to_idx[src_port]
        dst_idx = port_to_idx[dst_port]
        node_features[src_idx].append(features)
        node_features[dst_idx].append(features)
        node_labels[src_idx].append(label)
        node_labels[dst_idx].append(label)
        edge_index.append([src_idx, dst_idx])
        edge_index.append([dst_idx, src_idx])
        edge_attr.append(edge_features[idx])
        edge_attr.append(edge_features[idx])  # Bidirectional

    x = []
    y_graph = []
    for port_idx in range(len(unique_ports)):
        if port_idx in node_features:
            x.append(np.mean(node_features[port_idx], axis=0))
            labels = node_labels[port_idx]
            y_graph.append(np.bincount(labels).argmax())
        else:
            x.append(np.zeros(df_processed.shape[1] - 3))
            y_graph.append(0)

    x = torch.tensor(x, dtype=torch.float)
    y_graph = torch.tensor(y_graph, dtype=torch.long)
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float)

    data_graph = Data(x=x, edge_index=edge_index, y=y_graph, edge_attr=edge_attr)
    print(f"{mode.capitalize()} Graph constructed: {data_graph.num_nodes} nodes, {data_graph.num_edges} edges")

    return data_graph

# E-GraphSAGE model
class EGraphSAGEModel(torch.nn.Module):
    def __init__(self, num_features, num_edge_features, num_classes):
        super(EGraphSAGEModel, self).__init__()
        self.conv1 = SAGEConv(num_features, 64, edge_dim=num_edge_features)
        self.conv2 = SAGEConv(64, 64, edge_dim=num_edge_features)
        self.fc = torch.nn.Linear(64, num_classes)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x = self.conv1(x, edge_index, edge_attr)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index, edge_attr)
        x = F.relu(x)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)

# Training and evaluation function
def train_evaluate(model, data_graph, mode, num_classes, class_names, max_epochs=100, patience=10):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []
    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    # Masks
    num_nodes = data_graph.num_nodes
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    indices = np.random.permutation(num_nodes)
    train_size = int(0.6 * num_nodes)
    val_size = int(0.2 * num_nodes)
    train_mask[indices[:train_size]] = True
    val_mask[indices[train_size:train_size + val_size]] = True
    test_mask[indices[train_size + val_size:]] = True
    data_graph.train_mask = train_mask
    data_graph.val_mask = val_mask
    data_graph.test_mask = test_mask

    for epoch in range(max_epochs):
        model.train()
        optimizer.zero_grad()
        out = model(data_graph)
        loss = F.nll_loss(out[data_graph.train_mask], data_graph.y[data_graph.train_mask])
        loss.backward()
        optimizer.step()

        pred = out.argmax(dim=1)
        train_acc = accuracy_score(data_graph.y[data_graph.train_mask].numpy(), pred[data_graph.train_mask].numpy())
        train_losses.append(loss.item())
        train_accuracies.append(train_acc)

        model.eval()
        with torch.no_grad():
            out = model(data_graph)
            val_loss = F.nll_loss(out[data_graph.val_mask], data_graph.y[data_graph.val_mask])
            val_pred = out.argmax(dim=1)
            val_acc = accuracy_score(data_graph.y[data_graph.val_mask].numpy(), val_pred[data_graph.val_mask].numpy())
            val_losses.append(val_loss.item())
            val_accuracies.append(val_acc)

        print(f'{mode.capitalize()} Epoch {epoch + 1}, Train Loss: {loss.item():.4f}, Val Loss: {val_loss.item():.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = model.state_dict()
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f'{mode.capitalize()} Early stopping at epoch {epoch + 1}')
                break

    model.load_state_dict(best_model_state)

    # Plots for accuracy and loss
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_accuracies, label='Train Acc')
    plt.plot(val_accuracies, label='Val Acc')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title(f'{mode.capitalize()} Training vs Validation Accuracy')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'{mode.capitalize()} Training vs Validation Loss')
    plt.legend()
    plt.tight_layout()
    plt.savefig(f'{mode}_acc_loss.png')
    plt.show()
    plt.close()

    # Evaluation
    model.eval()
    with torch.no_grad():
        out = model(data_graph)
        pred = out.argmax(dim=1)
        y_true = data_graph.y[data_graph.test_mask].numpy()
        y_pred = pred[data_graph.test_mask].numpy()
        y_score = out[data_graph.test_mask].numpy()

        test_classes = np.unique(y_true)
        test_class_names = [class_names[i] for i in test_classes]

        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average='weighted', labels=test_classes)
        recall = recall_score(y_true, y_pred, average='weighted', labels=test_classes)
        f1 = f1_score(y_true, y_pred, average='weighted', labels=test_classes)

        print(f"\n{mode.capitalize()} Classification Report:")
        print(classification_report(y_true, y_pred, labels=test_classes, target_names=test_class_names))

        cm = confusion_matrix(y_true, y_pred, labels=test_classes)
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=test_class_names, yticklabels=test_class_names)
        plt.title(f'{mode.capitalize()} Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.savefig(f'{mode}_confusion_matrix.png')
        plt.show()
        plt.close()

        if mode == 'binary':
            fpr, tpr, _ = roc_curve(y_true, y_score[:, 1])
            roc_auc = auc(fpr, tpr)
            plt.figure()
            plt.plot(fpr, tpr, label=f'ROC curve (AUC = {roc_auc:.2f})')
            plt.plot([0, 1], [0, 1], 'k--')
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title(f'{mode.capitalize()} ROC Curve')
            plt.legend(loc='lower right')
            plt.savefig(f'{mode}_roc_curve.png')
            plt.show()
            plt.close()
        else:
            y_true_bin = label_binarize(y_true, classes=range(num_classes))
            plt.figure()
            for i in test_classes:
                fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_score[:, i])
                roc_auc = auc(fpr, tpr)
                plt.plot(fpr, tpr, label=f'Class {test_class_names[list(test_classes).index(i)]} (AUC = {roc_auc:.2f})')
            plt.plot([0, 1], [0, 1], 'k--')
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title(f'{mode.capitalize()} ROC Curve (One-vs-Rest)')
            plt.legend(loc='lower right')
            plt.savefig(f'{mode}_roc_curve.png')
            plt.show()
            plt.close()

    print(f"\n{mode.capitalize()} Evaluation Metrics:")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision (weighted): {precision:.4f}")
    print(f"Recall (weighted): {recall:.4f}")
    print(f"F1 (weighted): {f1:.4f}")

# Binary classification
print("\n=== Binary Classification ===")
df_processed_bin, y_bin, num_classes_bin, _, class_names_bin, edge_features_bin, edge_feature_cols = prepare_data(data, mode='binary')
data_graph_bin = build_graph(df_processed_bin, y_bin, edge_features_bin, 'binary')
model_bin = EGraphSAGEModel(data_graph_bin.num_features, len(edge_feature_cols), num_classes_bin)
train_evaluate(model_bin, data_graph_bin, 'binary', num_classes_bin, class_names_bin)

# Multiclass classification
print("\n=== Multiclass Classification ===")
df_processed_multi, y_multi, num_classes_multi, le_attack_multi, class_names_multi, edge_features_multi, edge_feature_cols = prepare_data(data, mode='multiclass')
data_graph_multi = build_graph(df_processed_multi, y_multi, edge_features_multi, 'multiclass')
model_multi = EGraphSAGEModel(data_graph_multi.num_features, len(edge_feature_cols), num_classes_multi)
train_evaluate(model_multi, data_graph_multi, 'multiclass', num_classes_multi, class_names_multi)


Loaded sample shape: (1000000, 45)

=== Binary Classification ===


  x = torch.tensor(x, dtype=torch.float)


Binary Graph constructed: 65273 nodes, 2000000 edges


TypeError: MessagePassing.__init__() got an unexpected keyword argument 'edge_dim'

In [14]:
pip install torch-geometric




In [None]:
pip install torch torchvision torchaudio

Note: you may need to restart the kernel to use updated packages.


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.loader import DataLoader
import networkx as nx
import warnings
warnings.filterwarnings('ignore')

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)


AttributeError: partially initialized module 'torch_geometric' has no attribute 'typing' (most likely due to a circular import)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.loader import DataLoader
import networkx as nx
import warnings
warnings.filterwarnings('ignore')

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# 1. Data Loading and Preprocessing
def load_and_preprocess_data(file_path):
    # Load data
    data = pd.read_parquet(file_path)
    
    # Handle negative port values
    data['L4_SRC_PORT'] = data['L4_SRC_PORT'].abs()
    data['L4_DST_PORT'] = data['L4_DST_PORT'].abs()
    
    # Select numerical features
    numerical_features = [
        'L4_SRC_PORT', 'L4_DST_PORT', 'IN_BYTES', 'IN_PKTS', 
        'OUT_BYTES', 'OUT_PKTS', 'TCP_WIN_MAX_IN', 'TCP_WIN_MAX_OUT'
    ]
    
    # Handle missing values
    data[numerical_features] = data[numerical_features].fillna(0)
    
    # Encode categorical variables
    le_attack = LabelEncoder()
    data['Attack_encoded'] = le_attack.fit_transform(data['Attack'])
    
    # Standardize numerical features
    scaler = StandardScaler()
    data[numerical_features] = scaler.fit_transform(data[numerical_features])
    
    return data, numerical_features, le_attack

# 2. Exploratory Data Analysis
def perform_eda(data, numerical_features):
    print("\nEDA Summary:")
    print("\nDataset Shape:", data.shape)
    print("\nMissing Values:\n", data.isnull().sum().sum())
    print("\nLabel Distribution:\n", data['Label'].value_counts())
    print("\nAttack Type Distribution:\n", data['Attack'].value_counts())
    
    # Correlation heatmap
    plt.figure(figsize=(10, 8))
    sns.heatmap(data[numerical_features].corr(), annot=True, cmap='coolwarm')
    plt.title('Feature Correlation Heatmap')
    plt.savefig('correlation_heatmap.png')
    plt.close()
    
    # Attack type distribution plot
    plt.figure(figsize=(10, 6))
    sns.countplot(x='Attack', data=data)
    plt.xticks(rotation=45)
    plt.title('Attack Type Distribution')
    plt.savefig('attack_distribution.png')
    plt.close()

# 3. Graph Construction
def create_graph_data(data, numerical_features, target_col='Label'):
    # Create edge index based on similar source/destination ports
    edge_index = []
    for i in range(len(data)):
        for j in range(i+1, len(data)):
            if (data.iloc[i]['L4_SRC_PORT'] == data.iloc[j]['L4_SRC_PORT'] or 
                data.iloc[i]['L4_DST_PORT'] == data.iloc[j]['L4_DST_PORT']):
                edge_index.append([i, j])
                edge_index.append([j, i])
    
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    
    # Node features
    x = torch.tensor(data[numerical_features].values, dtype=torch.float)
    
    # Labels
    y = torch.tensor(data[target_col].values, dtype=torch.long)
    
    # Create PyG data object
    graph_data = Data(x=x, edge_index=edge_index, y=y)
    
    # Create train/test mask
    train_mask, test_mask = train_test_split(
        range(len(data)), test_size=0.2, random_state=42, stratify=data[target_col]
    )
    graph_data.train_mask = torch.tensor(train_mask, dtype=torch.long)
    graph_data.test_mask = torch.tensor(test_mask, dtype=torch.long)
    
    return graph_data

# 4. Attention-based GNN Model
class AttentionGNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, heads=4):
        super(AttentionGNN, self).__init__()
        self.conv1 = GATConv(input_dim, hidden_dim, heads=heads)
        self.conv2 = GATConv(hidden_dim * heads, hidden_dim, heads=heads)
        self.fc1 = nn.Linear(hidden_dim * heads, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.3)
        
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = self.dropout(x)
        
        x = self.conv2(x, edge_index)
        x = F.elu(x)
        x = self.dropout(x)
        
        x = self.fc1(x)
        x = F.elu(x)
        x = self.dropout(x)
        
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# 5. Training Function
def train_model(model, data, optimizer, criterion, epochs=100):
    model.train()
    train_losses = []
    
    for epoch in range(epochs):
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        
        train_losses.append(loss.item())
        
        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
    
    return train_losses

# 6. Evaluation Function
def evaluate_model(model, data, le_attack=None):
    model.eval()
    with torch.no_grad():
        pred = model(data).argmax(dim=1)
        
        # Get true and predicted labels
        y_true = data.y[data.test_mask].numpy()
        y_pred = pred[data.test_mask].numpy()
        
        # Classification report
        print("\nClassification Report:")
        if le_attack is not None:
            target_names = le_attack.classes_
        else:
            target_names = ['Benign', 'Malicious']
        print(classification_report(y_true, y_pred, target_names=target_names))
        
        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title('Confusion Matrix')
        plt.savefig('confusion_matrix.png')
        plt.close()
        
        # ROC AUC for binary classification
        if len(np.unique(data.y)) == 2:
            probs = torch.softmax(model(data), dim=1)[:, 1]
            roc_auc = roc_auc_score(y_true, probs[data.test_mask].numpy())
            print(f"\nROC AUC Score: {roc_auc:.4f}")

# Main Execution
if __name__ == "__main__":
    # Load and preprocess data
    file_path = "../dataset/NF-BoT-IoT-V2.parquet"
    data, numerical_features, le_attack = load_and_preprocess_data(file_path)
    
    # Perform EDA
    perform_eda(data, numerical_features)
    
    # Binary Classification
    print("\n=== Binary Classification ===")
    binary_graph = create_graph_data(data, numerical_features, 'Label')
    
    # Initialize model
    binary_model = AttentionGNN(
        input_dim=len(numerical_features),
        hidden_dim=64,
        output_dim=2
    )
    
    # Training setup
    optimizer = torch.optim.Adam(binary_model.parameters(), lr=0.01)
    criterion = nn.NLLLoss()
    
    # Train model
    binary_losses = train_model(binary_model, binary_graph, optimizer, criterion)
    
    # Evaluate model
    evaluate_model(binary_model, binary_graph)
    
    # Multiclass Classification
    print("\n=== Multiclass Classification ===")
    multiclass_graph = create_graph_data(data, numerical_features, 'Attack_encoded')
    
    # Initialize model
    multiclass_model = AttentionGNN(
        input_dim=len(numerical_features),
        hidden_dim=64,
        output_dim=len(le_attack.classes_)
    )
    
    # Training setup
    optimizer = torch.optim.Adam(multiclass_model.parameters(), lr=0.01)
    criterion = nn.NLLLoss()
    
    # Train model
    multiclass_losses = train_model(multiclass_model, multiclass_graph, optimizer, criterion)
    
    # Evaluate model
    evaluate_model(multiclass_model, multiclass_graph, le_attack)
    
    # Plot training losses
    plt.figure(figsize=(10, 6))
    plt.plot(binary_losses, label='Binary Classification')
    plt.plot(multiclass_losses, label='Multiclass Classification')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Curves')
    plt.legend()
    plt.savefig('training_loss.png')
    plt.close()

ModuleNotFoundError: No module named 'torch_geometric'