In [1]:
import polars as pl
import numpy as np
from datetime import datetime
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, balanced_accuracy_score

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

In [2]:
df_polars = pl.read_parquet('dataset.parquet')

In [3]:
df_polars = df_polars.sample(fraction=0.01, seed=42)

In [4]:
df_polars = df_polars.with_columns([
    pl.col('duration').fill_null(0),
    pl.col('orig_bytes').fill_null(0),
    pl.col('resp_bytes').fill_null(0)
])

In [5]:
df_polars = df_polars.drop(["ts", "uid", "id.orig_h", "id.resp_h", "local_orig", "local_resp", "missed_bytes" , "tunnel_parents", "detailed-label", "__index_level_0__"])

In [6]:
X = df_polars.drop('label')
y = df_polars['label']       

In [7]:
scaler = MinMaxScaler()
X = scaler.fit_transform(X)

In [8]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)

In [9]:
y_train_np = y_train.to_numpy()
y_test_np = y_test.to_numpy()

X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train_np, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test_np, dtype=torch.float32)

# Treinamento

In [10]:
class LSTMClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout_rate):
        super(LSTMClassifier, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, dropout=dropout_rate, bidirectional=False)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        lstm_out, (hidden, _) = self.lstm(x)
        hidden = self.relu(hidden[-1])
        hidden = self.dropout(hidden)
        output = self.fc(hidden)
        return output

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

input_dim = X_train.shape[1]  
hidden_dim = 100  
dropout_rate = 0.2  
output_dim = 1  

model = LSTMClassifier(input_dim, hidden_dim, output_dim, dropout_rate).to(device)

batch_size = 5000
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

criterion = nn.BCEWithLogitsLoss() 
optimizer = optim.Adam(model.parameters(), lr=0.001)

results = []
epochs = 10
print(datetime.now)
for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    for data in train_loader:
        inputs, targets = data
        inputs, targets = inputs.float().to(device), targets.float().to(device)
        inputs = inputs.unsqueeze(1) 
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.squeeze(), targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() * inputs.size(0)
    print(f'Epoch {epoch+1}/{epochs}, fim em: {datetime.now()}')
    
    model.eval()
    with torch.no_grad():
        all_outputs = []
        all_targets = []
        for data in test_loader:
            inputs, targets = data
            inputs, targets = inputs.float().to(device), targets.float().to(device)
            inputs = inputs.unsqueeze(1)  
            outputs = model(inputs)
            all_outputs.append(outputs.cpu())
            all_targets.append(targets.cpu())
        
        all_outputs = torch.cat(all_outputs)
        all_targets = torch.cat(all_targets)

        y_pred = (all_outputs > 0.5).float().numpy()
        y_true = all_targets.numpy()
        print(f'Epoch {epoch+1}/{epochs}, avaliada em: {datetime.now()}')

        confusion = confusion_matrix(y_true, y_pred)
        tn, fp, fn, tp = confusion.ravel()
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred)
        recall = recall_score(y_true, y_pred)
        specificity = tn / (tn + fp)
        f1 = f1_score(y_true, y_pred)
        balanced_accuracy = balanced_accuracy_score(y_true, y_pred)
        false_alarm_rate = fp / (fp + tn) if (fp + tn) > 0 else 0
        

    results.append([epoch+1, accuracy, balanced_accuracy, precision, recall, specificity, f1, false_alarm_rate, tn, fp, fn, tp])



<built-in method now of type object at 0x00007FFA9A814FB0>
Epoch 1/10, fim em: 2024-09-16 09:42:47.411553
Epoch 1/10, avaliada em: 2024-09-16 09:42:50.061152
Epoch 2/10, fim em: 2024-09-16 09:42:57.719669
Epoch 2/10, avaliada em: 2024-09-16 09:43:00.256513
Epoch 3/10, fim em: 2024-09-16 09:43:08.088709
Epoch 3/10, avaliada em: 2024-09-16 09:43:10.783507
Epoch 4/10, fim em: 2024-09-16 09:43:19.027756
Epoch 4/10, avaliada em: 2024-09-16 09:43:21.749582
Epoch 5/10, fim em: 2024-09-16 09:43:29.144207
Epoch 5/10, avaliada em: 2024-09-16 09:43:31.543609
Epoch 6/10, fim em: 2024-09-16 09:43:38.554801
Epoch 6/10, avaliada em: 2024-09-16 09:43:40.949802
Epoch 7/10, fim em: 2024-09-16 09:43:47.869846
Epoch 7/10, avaliada em: 2024-09-16 09:43:50.137648
Epoch 8/10, fim em: 2024-09-16 09:43:57.139823
Epoch 8/10, avaliada em: 2024-09-16 09:43:59.461845
Epoch 9/10, fim em: 2024-09-16 09:44:06.465869
Epoch 9/10, avaliada em: 2024-09-16 09:44:08.876059
Epoch 10/10, fim em: 2024-09-16 09:44:15.958672
Ep

In [11]:
metrics_df = pl.DataFrame(
    results,
    schema=['Epoch', 'Accuracy', 'Balanced Accuracy', 'Precision', 'Recall', 'Specificity', 'F1-score', 'False Alarm Rate', 'tn', 'fp', 'fn', 'tp']
)
metrics_df

  return dispatch(args[0].__class__)(*args, **kw)


Epoch,Accuracy,Balanced Accuracy,Precision,Recall,Specificity,F1-score,False Alarm Rate,tn,fp,fn,tp
i64,f64,f64,f64,f64,f64,f64,f64,i64,i64,i64,i64
1,0.839677,0.5,0.839677,1.0,0.0,0.912853,1.0,0,26705,0,139865
2,0.839677,0.5,0.839677,1.0,0.0,0.912853,1.0,0,26705,0,139865
3,0.883821,0.643548,0.880309,0.997226,0.289871,0.935127,0.710129,7741,18964,388,139477
4,0.977955,0.960106,0.987354,0.98638,0.933833,0.986867,0.066167,24938,1767,1905,137960
5,0.977241,0.959953,0.987469,0.9854,0.934507,0.986433,0.065493,24956,1749,2042,137823
6,0.977223,0.960018,0.987503,0.985343,0.934694,0.986422,0.065306,24961,1744,2050,137815
7,0.977331,0.960234,0.987575,0.9854,0.935068,0.986486,0.064932,24971,1734,2042,137823
8,0.977709,0.961338,0.987986,0.985436,0.93724,0.986709,0.06276,25029,1676,2037,137828
9,0.978273,0.962962,0.988589,0.9855,0.940423,0.987042,0.059577,25114,1591,2028,137837
10,0.978832,0.964612,0.989207,0.985543,0.943681,0.987372,0.056319,25201,1504,2022,137843
