In [1]:
import polars as pl
import numpy as np
from datetime import datetime
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, balanced_accuracy_score

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

In [2]:
df_polars = pl.read_parquet('dataset.parquet')

In [3]:
df_polars = df_polars.sample(fraction=0.01, seed=42)

In [4]:
df_polars = df_polars.with_columns([
    pl.col('duration').fill_null(0),
    pl.col('orig_bytes').fill_null(0),
    pl.col('resp_bytes').fill_null(0)
])

In [5]:
df_polars = df_polars.drop(["ts", "uid", "id.orig_h", "id.resp_h", "local_orig", "local_resp", "missed_bytes" , "tunnel_parents", "detailed-label", "__index_level_0__"])

In [6]:
X = df_polars.drop('label')
y = df_polars['label']       

In [7]:
scaler = MinMaxScaler()
X = scaler.fit_transform(X)

In [8]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)

In [9]:
y_train_np = y_train.to_numpy()
y_test_np = y_test.to_numpy()

X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train_np, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test_np, dtype=torch.float32)

# Treinamento

In [10]:
class GRUClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout_rate):
        super(GRUClassifier, self).__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True, dropout=dropout_rate, bidirectional=False)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(hidden_dim, output_dim)
    def forward(self, x):
        _, hidden = self.gru(x)
        hidden = self.relu(hidden[-1])
        hidden = self.dropout(hidden)
        output = self.fc(hidden)
        return output


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

input_dim = X_train.shape[1]  
hidden_dim = 100  
dropout_rate = 0.2  
output_dim = 1  

model = GRUClassifier(input_dim, hidden_dim, output_dim, dropout_rate).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

batch_size = 5000
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

criterion = nn.BCEWithLogitsLoss() 
optimizer = optim.Adam(model.parameters(), lr=0.001)

results = []
epochs = 10
print(datetime.now)
for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    for data in train_loader:
        inputs, targets = data
        inputs, targets = inputs.float().to(device), targets.float().to(device)
        inputs = inputs.unsqueeze(1) 
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.squeeze(), targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() * inputs.size(0)
    print(f'Epoch {epoch+1}/{epochs}, fim em: {datetime.now()}')
    
    model.eval()
    with torch.no_grad():
        all_outputs = []
        all_targets = []
        for data in test_loader:
            inputs, targets = data
            inputs, targets = inputs.float().to(device), targets.float().to(device)
            inputs = inputs.unsqueeze(1)  
            outputs = model(inputs)
            all_outputs.append(outputs.cpu())
            all_targets.append(targets.cpu())
        
        all_outputs = torch.cat(all_outputs)
        all_targets = torch.cat(all_targets)

        y_pred = (all_outputs > 0.5).float().numpy()
        y_true = all_targets.numpy()
        print(f'Epoch {epoch+1}/{epochs}, avaliada em: {datetime.now()}')

        confusion = confusion_matrix(y_true, y_pred)
        tn, fp, fn, tp = confusion.ravel()
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred)
        recall = recall_score(y_true, y_pred)
        specificity = tn / (tn + fp)
        f1 = f1_score(y_true, y_pred)
        balanced_accuracy = balanced_accuracy_score(y_true, y_pred)
        false_alarm_rate = fp / (fp + tn) if (fp + tn) > 0 else 0
        

    results.append([epoch+1, accuracy, balanced_accuracy, precision, recall, specificity, f1, false_alarm_rate, tn, fp, fn, tp])




<built-in method now of type object at 0x00007FFA9A814FB0>
Epoch 1/10, fim em: 2024-09-16 10:43:14.452698
Epoch 1/10, avaliada em: 2024-09-16 10:43:17.096477
Epoch 2/10, fim em: 2024-09-16 10:43:24.300931
Epoch 2/10, avaliada em: 2024-09-16 10:43:26.748514
Epoch 3/10, fim em: 2024-09-16 10:43:33.541749
Epoch 3/10, avaliada em: 2024-09-16 10:43:35.947888
Epoch 4/10, fim em: 2024-09-16 10:43:42.778098
Epoch 4/10, avaliada em: 2024-09-16 10:43:45.103757
Epoch 5/10, fim em: 2024-09-16 10:43:52.122991
Epoch 5/10, avaliada em: 2024-09-16 10:43:54.466289
Epoch 6/10, fim em: 2024-09-16 10:44:01.297637
Epoch 6/10, avaliada em: 2024-09-16 10:44:03.805894
Epoch 7/10, fim em: 2024-09-16 10:44:10.850208
Epoch 7/10, avaliada em: 2024-09-16 10:44:13.410384
Epoch 8/10, fim em: 2024-09-16 10:44:21.694678
Epoch 8/10, avaliada em: 2024-09-16 10:44:24.134181
Epoch 9/10, fim em: 2024-09-16 10:44:31.879470
Epoch 9/10, avaliada em: 2024-09-16 10:44:34.330458
Epoch 10/10, fim em: 2024-09-16 10:44:41.378076
Ep

In [11]:
metrics_df = pl.DataFrame(
    results,
    schema=['Epoch', 'Accuracy', 'Balanced Accuracy', 'Precision', 'Recall', 'Specificity', 'F1-score', 'False Alarm Rate', 'tn', 'fp', 'fn', 'tp']
)
metrics_df

  return dispatch(args[0].__class__)(*args, **kw)


Epoch,Accuracy,Balanced Accuracy,Precision,Recall,Specificity,F1-score,False Alarm Rate,tn,fp,fn,tp
i64,f64,f64,f64,f64,f64,f64,f64,i64,i64,i64,i64
1,0.839677,0.5,0.839677,1.0,0.0,0.912853,1.0,0,26705,0,139865
2,0.839677,0.5,0.839677,1.0,0.0,0.912853,1.0,0,26705,0,139865
3,0.934388,0.809266,0.932787,0.993444,0.625089,0.96216,0.374911,16693,10012,917,138948
4,0.977877,0.960196,0.987416,0.986222,0.93417,0.986819,0.06583,24947,1758,1927,137938
5,0.977247,0.959972,0.987476,0.9854,0.934544,0.986437,0.065456,24957,1748,2042,137823
6,0.977235,0.960026,0.987504,0.985357,0.934694,0.986429,0.065306,24961,1744,2048,137817
7,0.977775,0.961589,0.988085,0.985415,0.937764,0.986748,0.062236,25043,1662,2040,137825
8,0.97846,0.963678,0.988872,0.985436,0.941921,0.987151,0.058079,25154,1551,2037,137828
9,0.979156,0.965744,0.989647,0.985486,0.946003,0.987562,0.053997,25263,1442,2030,137835
10,0.979846,0.967776,0.990408,0.985543,0.950009,0.98797,0.049991,25370,1335,2022,137843
