In [None]:
import itertools
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from imblearn.under_sampling import RandomUnderSampler
from sklearn.metrics import confusion_matrix, f1_score, recall_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.utils import shuffle
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import TensorDataset, DataLoader, ConcatDataset

In [None]:
from dataloader_creator import CreatorDL
# Single shared factory used by every dataset cell below, so all three datasets
# are read, split and turned into loaders with the same settings.
# NOTE(review): CreatorDL is project-local; `seed` and `bs` are presumably the
# RNG seed and the DataLoader batch size — confirm in dataloader_creator.
creator = CreatorDL(seed=42, bs=32)

In [None]:
# NF-UNSW-NB15-v3: read the dataset, split it into train/test/val frames, and
# build one loader per split.
# NOTE(review): what `balancer` balances is not visible from this notebook
# (presumably the class distribution) — confirm in dataloader_creator.
df_UNSW = creator.reader("NF-UNSW-NB15-v3")

# The splitter's return order, as unpacked here, is (train, test, val).
df_train_UNSW, df_test_UNSW, df_val_UNSW = creator.splitter(df_UNSW)

train_loader_UNSW, test_loader_UNSW, val_loader_UNSW = creator.balancer(df_train_UNSW, df_test_UNSW, df_val_UNSW)

In [None]:
# NF-BoT-IoT-v3: read the dataset, split it into train/test/val frames, and
# build one loader per split.
# NOTE(review): what `balancer` balances is not visible from this notebook
# (presumably the class distribution) — confirm in dataloader_creator.
df_BOT= creator.reader("NF-BoT-IoT-v3")

# The splitter's return order, as unpacked here, is (train, test, val).
df_train_BOT, df_test_BOT, df_val_BOT = creator.splitter(df_BOT)

train_loader_BOT, test_loader_BOT, val_loader_BOT = creator.balancer(df_train_BOT, df_test_BOT, df_val_BOT)

In [None]:
# NF-CICIDS2018-v3: read the dataset, split it into train/test/val frames, and
# build one loader per split.
# NOTE(review): what `balancer` balances is not visible from this notebook
# (presumably the class distribution) — confirm in dataloader_creator.
df_CIC= creator.reader("NF-CICIDS2018-v3")

# The splitter's return order, as unpacked here, is (train, test, val).
df_train_CIC, df_test_CIC, df_val_CIC = creator.splitter(df_CIC)

train_loader_CIC, test_loader_CIC, val_loader_CIC = creator.balancer(df_train_CIC, df_test_CIC, df_val_CIC)

In [None]:
# Per-split loader lists in a fixed dataset order: UNSW, BoT-IoT, CIC-IDS2018.
# The order matters: train_model reads positions 0/1/2 and logs them as
# datasets "a"/"b"/"c" (loss1_a, loss1_b, loss1_c, ...).
train_loaders = [train_loader_UNSW, train_loader_BOT, train_loader_CIC]
test_loaders = [test_loader_UNSW, test_loader_BOT, test_loader_CIC]
val_loaders = [val_loader_UNSW, val_loader_BOT, val_loader_CIC]

In [None]:
# Default number of input features per sample.
INPUT_DIM = 32


class IDSBranchyNet(nn.Module):
    """Two-exit MLP classifier: a shared trunk feeding a light and a deep head.

    Exit 1 is a single linear layer on top of the trunk features; exit 2 is a
    deeper stack (1024 -> 2048 -> 2048 -> 1024) on the same features. Both
    heads return raw logits of shape ``(batch, num_classes)``.

    Args:
        input_dim: Number of input features per sample.
        num_classes: Number of output classes for both exits.
    """

    def __init__(self, input_dim=INPUT_DIM, num_classes=2):
        super().__init__()

        hidden_dim = input_dim * 2

        # Layer order (and therefore state-dict keys and init order) is kept
        # exactly as before so existing checkpoints keep loading.
        self.shared_layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        self.exit1_layers = nn.Sequential(
            nn.Linear(hidden_dim, num_classes)
        )

        self.exit2_layers = nn.Sequential(
            nn.Linear(hidden_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 2048),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(2048, 2048),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, num_classes)
        )

    def forward_exit1(self, x):
        """Return exit-1 logits for ``x`` (trunk + light head)."""
        features = self.shared_layers(x)
        return self.exit1_layers(features)

    def forward_exit2(self, x):
        """Return exit-2 logits for ``x`` (trunk + deep head)."""
        features = self.shared_layers(x)
        return self.exit2_layers(features)

    def forward(self, x):
        """Return ``(exit1_logits, exit2_logits)`` from a single trunk pass.

        Makes ``model(x)`` usable (the class previously had no ``forward``).
        In eval mode the result equals calling ``forward_exit1`` and
        ``forward_exit2`` separately; in train mode both heads here share one
        dropout mask on the trunk, whereas separate calls draw two.
        """
        features = self.shared_layers(x)
        return self.exit1_layers(features), self.exit2_layers(features)


model = IDSBranchyNet()

In [None]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU; the
# chosen device is passed explicitly to the training and evaluation helpers.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
def exit_soft_count(logits, thresholds, device, sharpness=50.0):
    """Differentiable estimate of how many samples would leave at this exit.

    Each sample contributes ``sigmoid(sharpness * (confidence - threshold))``,
    where ``confidence`` is its top softmax probability and ``threshold`` is
    the entry of ``thresholds`` indexed by its predicted class.

    Args:
        logits: Raw scores; softmax is taken over dim 1.
        thresholds: Per-class confidence thresholds, indexable by class id.
            Either a tensor or anything ``torch.tensor`` accepts.
        device: Device the thresholds are moved to before indexing.
        sharpness: Slope of the sigmoid; larger values approach a hard count.

    Returns:
        0-dim tensor holding the sum of the per-sample soft decisions.
    """
    if isinstance(thresholds, torch.Tensor):
        per_class_thresholds = thresholds.to(device)
    else:
        per_class_thresholds = torch.tensor(thresholds, dtype=torch.float32).to(device)

    top_prob, top_class = torch.max(F.softmax(logits, dim=1), dim=1)

    # Positive margin -> sample is confident enough for its predicted class.
    margin = top_prob - per_class_thresholds[top_class]

    return torch.sigmoid(sharpness * margin).sum()

In [None]:
def _cycle_batches(loader):
    """Yield batches from ``loader`` endlessly, starting a new pass when it runs out.

    Unlike ``itertools.cycle`` nothing is cached: every pass re-iterates the
    loader, so no batch is kept alive after use and a loader that reorders its
    data on each pass keeps doing so.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            # An empty loader would otherwise spin forever.
            return


def _joint_step(model, batches, criterion, class_thresholds, desired_exit_ex1, device):
    """Forward the three dataset batches through both exits and build every loss term.

    Returns a dict mapping each logged metric name to a 0-dim tensor;
    ``'total_loss'`` is the value to back-propagate.
    """
    step = {}
    exit1_logits = []
    total_samples = 0

    # Positions 0/1/2 of `batches` are logged as datasets a/b/c.
    for position, tag in enumerate(('a', 'b', 'c')):
        inputs, labels = batches[position]
        inputs, labels = inputs.to(device), labels.to(device)
        total_samples += inputs.size(0)

        logits_exit1 = model.forward_exit1(inputs)
        step[f'loss1_{tag}'] = criterion(logits_exit1, labels)
        step[f'loss2_{tag}'] = criterion(model.forward_exit2(inputs), labels)
        exit1_logits.append(logits_exit1)

    # Soft (differentiable) number of samples that would leave at exit 1.
    count_ex1 = sum(exit_soft_count(logits, class_thresholds, device) for logits in exit1_logits)

    step['loss_ex1_avg'] = (step['loss1_a'] + step['loss1_b'] + step['loss1_c']) / 3
    step['loss_ex2_avg'] = (step['loss2_a'] + step['loss2_b'] + step['loss2_c']) / 3
    step['joint_loss'] = step['loss_ex1_avg'] + step['loss_ex2_avg']

    step['ex1_pct'] = count_ex1 / total_samples
    # Penalise the distance between the soft exit-1 rate and the desired rate.
    step['ex1_diff'] = abs(desired_exit_ex1 - step['ex1_pct'])
    step['total_loss'] = step['joint_loss'] + step['ex1_diff']
    return step


def _run_epoch(model, loaders, num_batches, metric_names, criterion,
               class_thresholds, desired_exit_ex1, device, optimizer=None):
    """Run one pass of ``num_batches`` joint steps and return the mean of each metric.

    With an ``optimizer`` the model is put in train mode and updated after
    every step; without one the model is put in eval mode and gradients are
    disabled.
    """
    training = optimizer is not None
    model.train(training)

    totals = {name: 0.0 for name in metric_names}
    steps_done = 0

    # Loaders shorter than the longest one are repeated so every step sees one
    # batch per dataset; fresh iterators are built for every epoch.
    iterators = [
        _cycle_batches(loader) if len(loader) < num_batches else iter(loader)
        for loader in loaders
    ]

    with torch.set_grad_enabled(training):
        for _ in range(num_batches):
            try:
                batches = [next(it) for it in iterators]
            except StopIteration:
                break

            if training:
                optimizer.zero_grad()

            step = _joint_step(model, batches, criterion, class_thresholds, desired_exit_ex1, device)

            if training:
                step['total_loss'].backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            for name in metric_names:
                totals[name] += float(step[name])
            steps_done += 1

    if steps_done == 0:
        raise ValueError("No batches were produced; every loader must yield at least one batch.")

    return {name: totals[name] / steps_done for name in metric_names}


def _plot_training_history(history, desired_exit_ex1):
    """Draw the five training-curve panels (per-exit losses, totals, exit-1 rate and error)."""
    train_hist, val_hist = history['train'], history['val']
    epochs_range = range(1, len(train_hist['total_loss']) + 1)

    fig, axs = plt.subplots(1, 5, figsize=(45, 7))

    # Panels 0 and 1: per-dataset and averaged loss of each exit.
    for ax, exit_id in zip(axs[:2], (1, 2)):
        ax.set_title(f"Exit {exit_id}")
        for tag in ('a', 'b', 'c'):
            ax.plot(epochs_range, train_hist[f'loss{exit_id}_{tag}'], label=f'Tr {tag.upper()}', alpha=0.6)
        ax.plot(epochs_range, train_hist[f'loss_ex{exit_id}_avg'], label='Tr Avg', linewidth=2)
        for tag in ('a', 'b', 'c'):
            ax.plot(epochs_range, val_hist[f'loss{exit_id}_{tag}'], label=f'Val {tag.upper()}', linestyle='--')
        ax.plot(epochs_range, val_hist[f'loss_ex{exit_id}_avg'], label='Val Avg', linestyle='--', linewidth=2)
        ax.set_xlabel('Epochs')
        if exit_id == 1:
            ax.set_ylabel('Loss')
        ax.legend()
        ax.grid(True)

    ax = axs[2]
    ax.set_title("Global Optimization")
    ax.plot(epochs_range, train_hist['total_loss'], label='Tr Total', color='orange')
    ax.plot(epochs_range, val_hist['total_loss'], label='Val Total', color='orange', linestyle='--')
    ax.plot(epochs_range, train_hist['joint_loss'], label='Tr Joint', color='blue', alpha=0.7)
    ax.plot(epochs_range, val_hist['joint_loss'], label='Val Joint', color='blue', linestyle='--', alpha=0.7)
    ax.set_xlabel('Epochs')
    ax.legend()
    ax.grid(True)

    ax = axs[3]
    ax.set_title("Exit 1 Rate (Percentage)")
    ax.plot(epochs_range, train_hist['ex1_pct'], label='Tr Rate', color='green')
    ax.plot(epochs_range, val_hist['ex1_pct'], label='Val Rate', color='green', linestyle='--')
    ax.axhline(y=desired_exit_ex1, color='red', linestyle=':', label='Desired')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Ratio (0-1)')
    ax.legend()
    ax.grid(True)

    ax = axs[4]
    ax.set_title("Exit 1 Rate Error (|Desired - Actual|)")
    ax.plot(epochs_range, train_hist['ex1_diff'], label='Tr Diff', color='purple')
    ax.plot(epochs_range, val_hist['ex1_diff'], label='Val Diff', color='purple', linestyle='--')
    ax.set_xlabel('Epochs')
    ax.legend()
    ax.grid(True)

    plt.tight_layout()
    plt.show()


def train_model(model, train_loaders, val_loaders, epochs, lr, attack_threshold, normal_threshold, desired_exit_ex1, device, patience=15):
    """Jointly train both exits on three datasets and plot the learning curves.

    Every step draws one batch from each of the three loaders, computes the
    cross-entropy of both exits on each batch, and minimises

        mean(exit-1 losses) + mean(exit-2 losses) + |desired_exit_ex1 - soft exit-1 rate|

    where the soft exit-1 rate comes from ``exit_soft_count``. Loaders shorter
    than the longest one are repeated within an epoch.

    Args:
        model: Network exposing ``forward_exit1`` and ``forward_exit2``.
        train_loaders: Sequence whose first three loaders are logged as
            datasets a, b and c.
        val_loaders: Validation loaders, same layout as ``train_loaders``.
        epochs: Maximum number of epochs.
        lr: Adam learning rate.
        attack_threshold: Exit-1 confidence threshold used for predicted class 1.
        normal_threshold: Exit-1 confidence threshold used for predicted class 0.
        desired_exit_ex1: Target fraction (0-1) of samples leaving at exit 1.
        device: Device used for the model and the batches.
        patience: Epochs without validation improvement before early stopping.

    Returns:
        ``{'train': {...}, 'val': {...}}`` mapping each metric name to a list
        with one epoch-mean value per completed epoch.

    Raises:
        ValueError: If an epoch produces no batches at all.

    Note:
        The best weights are restored only when early stopping triggers; if
        all ``epochs`` run, the model keeps the weights of the last epoch.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    # NOTE(review): factor=0.001 shrinks the learning rate 1000x on the first
    # plateau — confirm this is intended (0.1 is the more common choice).
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.001, patience=7)

    model.to(device)

    criterion = nn.CrossEntropyLoss()

    metrics = [
        'loss1_a', 'loss1_b', 'loss1_c', 'loss_ex1_avg',
        'loss2_a', 'loss2_b', 'loss2_c', 'loss_ex2_avg',
        'total_loss', 'joint_loss', 'ex1_pct', 'ex1_diff'
    ]

    history = {
        'train': {k: [] for k in metrics},
        'val': {k: [] for k in metrics}
    }

    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_model_state = None

    max_train_batches = max(len(l) for l in train_loaders)
    max_val_batches = max(len(l) for l in val_loaders)

    # Indexed by predicted class id: 0 = normal, 1 = attack.
    class_thresholds = [normal_threshold, attack_threshold]

    for epoch in range(epochs):
        train_means = _run_epoch(
            model, train_loaders, max_train_batches, metrics, criterion,
            class_thresholds, desired_exit_ex1, device, optimizer=optimizer,
        )
        for key in metrics:
            history['train'][key].append(train_means[key])

        val_means = _run_epoch(
            model, val_loaders, max_val_batches, metrics, criterion,
            class_thresholds, desired_exit_ex1, device,
        )
        for key in metrics:
            history['val'][key].append(val_means[key])

        epoch_train_loss = train_means['total_loss']
        epoch_train_loss1 = train_means['loss_ex1_avg']
        epoch_train_loss2 = train_means['loss_ex2_avg']
        epoch_val_loss = val_means['total_loss']
        epoch_val_loss1 = val_means['loss_ex1_avg']
        epoch_val_loss2 = val_means['loss_ex2_avg']

        print(f'Epoch [{epoch+1}/{epochs}] | Train Total Loss: {epoch_train_loss:.4f} | Val Total Loss: {epoch_val_loss:.4f}\nTrain Loss Exit 1: {epoch_train_loss1:.4f} | Train Loss Exit 2: {epoch_train_loss2:.4f}\nVal Loss Exit 1: {epoch_val_loss1:.4f} | Val Loss Exit 2: {epoch_val_loss2:.4f}\n\n')

        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            epochs_no_improve = 0
            # state_dict() returns references to the live parameters, which the
            # optimizer keeps updating in place; clone to get a real snapshot.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print("Early stopping...")
                if best_model_state is not None:
                    model.load_state_dict(best_model_state)
                break

        scheduler.step(epoch_val_loss)

    _plot_training_history(history, desired_exit_ex1)

    return history

In [None]:
def precompute_outputs(model, loader, device):
    """Run both exits over ``loader`` once and collect their softmax outputs.

    The model is moved to ``device`` and switched to eval mode; gradients are
    disabled while iterating.

    Args:
        model: Network exposing ``forward_exit1`` and ``forward_exit2``.
        loader: Iterable of ``(samples, labels)`` batches.
        device: Device used for the forward passes.

    Returns:
        Tuple ``(probs_exit1, probs_exit2, targets)`` of CPU tensors, each the
        concatenation of the per-batch results in loader order.
    """
    model.to(device)
    model.eval()

    exit1_chunks, exit2_chunks, label_chunks = [], [], []

    with torch.no_grad():
        for batch_inputs, batch_labels in loader:
            batch_inputs = batch_inputs.to(device)

            exit1_probs = F.softmax(model.forward_exit1(batch_inputs), dim=1)
            exit2_probs = F.softmax(model.forward_exit2(batch_inputs), dim=1)

            # Keep results on the CPU so the full set does not sit on the device.
            exit1_chunks.append(exit1_probs.cpu())
            exit2_chunks.append(exit2_probs.cpu())
            label_chunks.append(batch_labels.cpu())

    return torch.cat(exit1_chunks), torch.cat(exit2_chunks), torch.cat(label_chunks)

In [None]:
def grid_search_rejection(probs_exit1, probs_exit2, targets, max_rejection_rate=0.11):
    """Grid-search per-class confidence thresholds for both exits.

    For every combination of (attack, normal) thresholds at exit 1 and exit 2,
    a sample leaves at exit 1 when its exit-1 confidence exceeds the threshold
    of its predicted class, otherwise at exit 2 under the same rule, otherwise
    it is rejected. The error of a combination is ``1 - F1`` computed on the
    accepted samples only.

    Args:
        probs_exit1: CPU tensor of exit-1 class probabilities, one row per sample.
        probs_exit2: CPU tensor of exit-2 class probabilities, one row per sample.
        targets: CPU tensor of true labels (binary, as required by ``f1_score``).
        max_rejection_rate: Upper bound (fraction, 0-1, exclusive) on the
            rejection rate for a combination to be eligible as "best". The
            previous hard-coded check compared the 0-1 fraction against 11.0
            and therefore never filtered anything.

    Returns:
        Tuple ``(results_rejection, results_error, results_exit1_rate,
        results_exit2_rate, best_cm, best_params, best_exit1, best_exit2,
        best_rej, best_error)``. The four lists hold one fraction per evaluated
        combination; ``best_params`` is ``(t_atk1, t_norm1, t_atk2, t_norm2)``
        or ``None`` when no combination is eligible.

    Raises:
        ValueError: If ``targets`` is empty.
    """
    threshold_values = [round(x * 0.05, 2) for x in range(10, 21)]

    results_rejection = []
    results_error = []
    results_exit1_rate = []
    results_exit2_rate = []

    best_error = float('inf')
    best_cm = None
    best_params = None
    best_exit1 = 0.0
    best_exit2 = 0.0
    best_rej = 0.0

    confs1, preds1 = torch.max(probs_exit1, dim=1)
    confs2, preds2 = torch.max(probs_exit2, dim=1)
    total_samples = len(targets)
    if total_samples == 0:
        raise ValueError("grid_search_rejection needs at least one sample.")

    is_attack1 = preds1 == 1
    is_attack2 = preds2 == 1

    print(f"Starting Grid Search with {len(threshold_values)**4} combinations...")

    count = 0
    # Same visiting order as itertools.product(threshold_values, repeat=4);
    # split in two so the exit-1 work is done once per exit-1 threshold pair.
    for t_atk1, t_norm1 in itertools.product(threshold_values, repeat=2):
        mask_exit1 = torch.where(is_attack1, confs1 > t_atk1, confs1 > t_norm1)
        cnt_exit1 = mask_exit1.sum().item()
        rate_exit1 = cnt_exit1 / total_samples

        # Exit-1 prediction where exit 1 fires, exit-2 prediction elsewhere.
        final_preds = torch.where(mask_exit1, preds1, preds2)

        for t_atk2, t_norm2 in itertools.product(threshold_values, repeat=2):
            passes_exit2 = torch.where(is_attack2, confs2 > t_atk2, confs2 > t_norm2)
            mask_exit2 = (~mask_exit1) & passes_exit2
            mask_accepted = mask_exit1 | mask_exit2

            cnt_exit2 = mask_exit2.sum().item()
            # The two exit masks are disjoint, so the rest is rejected.
            cnt_rejected = total_samples - cnt_exit1 - cnt_exit2

            if cnt_rejected == total_samples:
                continue

            rate_exit2 = cnt_exit2 / total_samples
            rate_rejection = cnt_rejected / total_samples

            y_true_accepted = targets[mask_accepted].numpy()
            y_pred_accepted = final_preds[mask_accepted].numpy()

            current_f1 = f1_score(y_true_accepted, y_pred_accepted, zero_division=0)
            error_rate_f1 = 1 - current_f1

            results_rejection.append(rate_rejection)
            results_error.append(error_rate_f1)
            results_exit1_rate.append(rate_exit1)
            results_exit2_rate.append(rate_exit2)

            if rate_rejection < max_rejection_rate and error_rate_f1 < best_error:
                best_error = error_rate_f1
                # labels=[0, 1] keeps the matrix 2x2 even when the accepted
                # subset happens to contain a single class.
                best_cm = confusion_matrix(y_true_accepted, y_pred_accepted, labels=[0, 1])
                best_params = (t_atk1, t_norm1, t_atk2, t_norm2)
                best_exit1 = rate_exit1
                best_exit2 = rate_exit2
                best_rej = rate_rejection

            count += 1
            if count % 5000 == 0:
                print(f"{count} combinations processed...")

    return (results_rejection, results_error, results_exit1_rate, results_exit2_rate,
            best_cm, best_params, best_exit1, best_exit2, best_rej, best_error)

In [None]:
def evaluate_model_with_grid_search(model, loader, device):
    """Run the threshold grid search on ``loader`` and plot rejection vs. error.

    Collects both exits' probabilities with ``precompute_outputs``, searches
    the thresholds with ``grid_search_rejection``, and draws a scatter plot of
    rejection rate against ``1 - F1`` coloured by the exit-1 usage rate.

    Args:
        model: Network exposing ``forward_exit1`` and ``forward_exit2``.
        loader: Iterable of ``(samples, labels)`` batches to evaluate on.
        device: Device used for the forward passes.

    Returns:
        Dict with the raw grid-search results. All rates are fractions (0-1),
        exactly as returned by ``grid_search_rejection``; only the plot is
        scaled to percent.
    """
    probs1, probs2, targets = precompute_outputs(model, loader, device)

    start_time = time.time()
    (rejection, error, exit1_rates, exit2_rates,
     best_cm, best_params, best_exit1, best_exit2, best_rej, best_error) = grid_search_rejection(probs1, probs2, targets)

    print(f"Grid Search concluído em {time.time() - start_time:.2f} segundos.")

    # The grid search reports fractions, while the axis labels and limits below
    # are in percent; scale for plotting so the points land inside the window.
    rejection_pct = [100.0 * value for value in rejection]
    error_pct = [100.0 * value for value in error]
    exit1_pct = [100.0 * value for value in exit1_rates]

    plt.figure(figsize=(10, 8))

    sc = plt.scatter(rejection_pct, error_pct,
                     c=exit1_pct,
                     cmap='viridis',
                     s=10, alpha=0.7)

    cbar = plt.colorbar(sc)
    cbar.set_label('Taxa de Uso da Saída 1 (%)')

    plt.xlabel('Taxa de Rejeição (%)')
    plt.ylabel('Erro (1 - F1 Score) (%)')
    plt.xlim(0, 20)
    plt.ylim(0, 10)

    plt.show()

    return {
        'rejection': rejection,
        'error': error,
        'exit1_rates': exit1_rates,
        'exit2_rates': exit2_rates,
        'best_cm': best_cm,
        'best_params': best_params,
        'best_exit1': best_exit1,
        'best_exit2': best_exit2,
        'best_rej': best_rej,
        'best_error': best_error
    }

In [None]:
# Run identifier; it names the saved weights (models/<modelname>.pth) and the
# training history CSV (logs/<modelname>_history.csv) written after training.
modelname = 'arq_v2_abor_v3'
modelname

In [None]:
epochs = 500

# Create the output folders before the long training run: torch.save and
# DataFrame.to_csv both fail on a missing parent directory, which would
# otherwise only surface after training has finished.
os.makedirs('models', exist_ok=True)
os.makedirs('logs', exist_ok=True)

history = train_model(
    model.to(device), 
    train_loaders, 
    val_loaders, 
    epochs,
    lr=0.0001,
    attack_threshold=0.85,
    normal_threshold=0.85,
    desired_exit_ex1=0.4,
    device=device,
)

torch.save(model.state_dict(), f'models/{modelname}.pth')
print(f"\nModelo treinado e salvo em 'models/{modelname}.pth'")

# One row per epoch; train and validation metrics side by side, told apart by
# their column prefix.
df_train = pd.DataFrame(history['train'])
df_val = pd.DataFrame(history['val'])

df_train = df_train.add_prefix('train_')
df_val = df_val.add_prefix('val_')

df_history = pd.concat([df_train, df_val], axis=1)

# 1-based epoch index as the first column.
df_history.insert(0, 'epoch', range(1, len(df_history) + 1))

csv_filename = f'logs/{modelname}_history.csv'
df_history.to_csv(csv_filename, index=False)
print(f"Curvas de aprendizado salvas em '{csv_filename}'")

In [None]:
# Merge the three test sets into one loader so evaluation can run over all
# datasets at once.
loaders_para_concatenar = [test_loader_UNSW, test_loader_BOT, test_loader_CIC]

# Reuse the datasets already wrapped by the per-dataset test loaders.
datasets = [loader.dataset for loader in loaders_para_concatenar]

combined_dataset = ConcatDataset(datasets)

# Same batch size as the first test loader; shuffle=False keeps the sample
# order fixed (UNSW, then BoT-IoT, then CIC-IDS2018).
bs = loaders_para_concatenar[0].batch_size
combined_loader = DataLoader(combined_dataset, batch_size=bs, shuffle=False, num_workers=4)

print(f"Total de amostras combinadas: {len(combined_dataset)}")
print(f"Total de batches: {len(combined_loader)}")