# GAF + LMDB + PyTorch Model Training Notebook
This notebook:
- Converts tabular network flows to Gramian Angular Field (GAF) images (GPU-accelerated when a GPU is available)
- Stores GAF images and labels in LMDB
- Loads data into PyTorch DataLoader
- Trains a CNN on the data
- Evaluates model and prints analysis

In [None]:

import os
import pandas as pd
import numpy as np
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import torch
import matplotlib.pyplot as plt
import io
import lmdb
import pickle


In [None]:

# ==== Set these ====
# Directory that is scanned for the input CSV flow files.
input_folder = './data'

# Directory that receives the generated artifacts; created up front if missing.
output_folder = './output/'
os.makedirs(output_folder, exist_ok=True)

# Location of the LMDB environment that stores the GAF images and labels.
lmdb_path = os.path.join(output_folder, 'gaf_images')

# Number of principal components kept per flow; this is also the side length
# of each GAF image, since the GAF matrix is (components x components).
pca_components = 32


In [None]:

def gaf_transform_torch(x, device='cpu'):
    """Convert a batch of 1-D series into Gramian Angular Field matrices.

    Every row of ``x`` is min-max rescaled to [-1, 1], turned into an angle
    with ``arccos``, and expanded into the pairwise matrix
    ``cos(phi_i + phi_j)`` (the "summation" GAF variant).

    Args:
        x: 2-D tensor with one series per row; the min/max reductions run
            over dim 1.
        device: Accepted so call sites can pass it, but not used inside the
            function; the computation stays on the device ``x`` lives on.

    Returns:
        Tensor of shape (rows, length, length) on the same device as ``x``.
    """
    row_min, _ = torch.min(x, dim=1, keepdim=True)
    row_max, _ = torch.max(x, dim=1, keepdim=True)
    # The small epsilon keeps constant rows (max == min) from dividing by zero.
    span = row_max - row_min + 1e-8
    normalized = (2 * (x - row_min) / span) - 1
    # Clamp guards arccos against values nudged outside [-1, 1] by rounding.
    angles = torch.arccos(normalized.clamp(-1, 1))
    # Broadcasting (rows, L, 1) + (rows, 1, L) yields every pairwise angle sum.
    return torch.cos(angles[:, :, None] + angles[:, None, :])


In [None]:

# Pick the compute backend in priority order: CUDA, then Apple MPS, then CPU.
if not torch.cuda.is_available():
    # Older torch builds have no `torch.backends.mps`, hence the getattr probe.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend and mps_backend.is_available():
        device = torch.device('mps')
        print("Using Apple MPS (Metal Performance Shaders)")
    else:
        device = torch.device('cpu')
        print("Using CPU")
else:
    device = torch.device('cuda')
    print("Using CUDA:", torch.cuda.get_device_name(0))


In [None]:

dos_attacks = ['DoS GoldenEye', 'DoS Hulk', 'DoS Slowhttptest', 'DoS slowloris']
brute_force_attacks = ['FTP-Patator', 'SSH-Patator']
# The separator in these labels is U+FFFD (the Unicode replacement character),
# i.e. the form the labels take when the original dash byte was mis-decoded.
# The list is kept verbatim; other decodings are handled by the prefix check
# in map_to_broader_category.
web_attacks = ['Web Attack � Brute Force', 'Web Attack � Sql Injection', 'Web Attack � XSS']

# Fine-grained labels whose broad category carries the same name.
_SELF_NAMED_CATEGORIES = ('BENIGN', 'DDoS', 'PortScan', 'Bot', 'Infiltration', 'Heartbleed')


def map_to_broader_category(label):
    """Collapse a fine-grained flow label into its broad attack category.

    Args:
        label: Raw label value taken from the ``Label`` column.

    Returns:
        One of ``'BENIGN'``, ``'DDoS'``, ``'PortScan'``, ``'DoS'``,
        ``'BruteForce'``, ``'WebAttack'``, ``'Bot'``, ``'Infiltration'``,
        ``'Heartbleed'``, or ``'Other'`` for anything unrecognised.
    """
    if label in _SELF_NAMED_CATEGORIES:
        return str(label)
    if label in dos_attacks:
        return 'DoS'
    if label in brute_force_attacks:
        return 'BruteForce'
    # Besides the exact (mojibake) spellings, accept any "Web Attack ..." label
    # so the same classes decoded with an en dash, '\x96' or a plain hyphen do
    # not silently end up in 'Other'.
    if label in web_attacks or (isinstance(label, str) and label.startswith('Web Attack')):
        return 'WebAttack'
    return 'Other'


In [None]:

# Build the LMDB store: for every CSV file, clean the flows, rebalance the
# classes, reduce the features with PCA, render one GAF image per flow and
# write {"label", "data"} records keyed by "<csv name>_<row index>".
csv_files = [f for f in os.listdir(input_folder) if f.endswith('.csv')]
scaler = StandardScaler()
pca = PCA(n_components=pca_components)

# Upper bound on how many flows are expanded into GAF matrices at once, so the
# (flows x components x components) tensor never has to hold a whole file.
gaf_chunk_size = 4096

# Using the environment as a context manager closes it once writing is done,
# so the store can be reopened cleanly for reading later in the notebook.
with lmdb.open(lmdb_path, map_size=int(1e12)) as env:
    for csv_file in csv_files:
        print(f"\n📄 Processing {csv_file}...")
        df_path = os.path.join(input_folder, csv_file)
        df = pd.read_csv(df_path)
        df.columns = df.columns.str.strip()
        if 'Label' not in df.columns:
            print(f"⚠️ Skipping {csv_file} — 'Label' column not found.")
            continue

        # Drop every row containing inf/NaN *before* splitting off the labels,
        # so features and labels stay row-aligned.
        df = df.replace([np.inf, -np.inf], np.nan).dropna()
        if df.empty:
            print(f"⚠️ Skipping {csv_file} — no rows left after dropping NaN/inf values.")
            continue
        y = df['Label'].values
        X = df.select_dtypes(include=[np.number]).values

        unique_classes = np.unique(y)
        if len(unique_classes) > 1:
            # NOTE(review): resampling happens before any train/val/test split,
            # so synthetic SMOTE samples can land in the evaluation splits.
            print(f"Before SMOTE: {dict(zip(*np.unique(y, return_counts=True)))}")
            smote = SMOTE(random_state=42)
            X_res, y_res = smote.fit_resample(X, y)
            print(f"After SMOTE: {dict(zip(*np.unique(y_res, return_counts=True)))}")

            rus = RandomUnderSampler(sampling_strategy='auto', random_state=42)
            X_res, y_res = rus.fit_resample(X_res, y_res)
            print(f"After undersampling: {dict(zip(*np.unique(y_res, return_counts=True)))}")
        else:
            # Both samplers need at least two classes; with a single class
            # there is nothing to balance, so the data is used unchanged.
            print(f"Skipping SMOTE for {csv_file} — only found one class: {unique_classes[0]}")
            X_res, y_res = X, y

        y_res = np.array([map_to_broader_category(label) for label in y_res])
        print(f"After category merging: {dict(zip(*np.unique(y_res, return_counts=True)))}")

        # NOTE(review): scaler and PCA are re-fitted for every file, so the
        # reduced feature space differs from file to file.
        features_scaled = scaler.fit_transform(X_res)
        features_reduced = pca.fit_transform(features_scaled)

        print(f"Generating GAF images and storing (label, data) in LMDB...")

        key_prefix = os.path.splitext(csv_file)[0]
        # One write transaction per file: files that finished stay committed
        # even if a later file fails.
        with env.begin(write=True) as txn:
            for start in range(0, len(features_reduced), gaf_chunk_size):
                stop = start + gaf_chunk_size
                flows_torch = torch.tensor(
                    features_reduced[start:stop], dtype=torch.float32, device=device
                )
                gaf_images = gaf_transform_torch(flows_torch, device=device).cpu().numpy()

                chunk = zip(gaf_images, y_res[start:stop])
                for idx, (gaf_image, label) in enumerate(chunk, start=start):
                    # Best effort per record: a flow that cannot be encoded is
                    # reported and skipped instead of aborting the whole file.
                    try:
                        with io.BytesIO() as buf:
                            plt.imsave(buf, gaf_image, cmap='gray', format='png')
                            image_bytes = buf.getvalue()
                        key = f"{key_prefix}_{idx:07d}".encode('utf-8')
                        record = {"label": str(label), "data": image_bytes}
                        txn.put(key, pickle.dumps(record))
                    except Exception as e:
                        print(f"⚠️ Error on flow {idx}: {e}")
                        continue

print("\n✅ GAF image generation and LMDB (label, data) storage complete.")


## Load LMDB as PyTorch Dataset

In [None]:

from torch.utils.data import Dataset
import torchvision.transforms as T
from PIL import Image

class LMDBGAFDataset(Dataset):
    """Dataset of GAF images read from an LMDB store.

    Each LMDB value is a pickled dict with a ``'label'`` string and ``'data'``
    holding the encoded image bytes. Items are returned as
    ``(image, label_index)`` where the image is decoded to single-channel
    ``'L'`` mode and passed through ``transform`` when one is given.

    The LMDB handles are process-local: they are dropped when the dataset is
    pickled and reopened when the dataset is first used in another process,
    so the dataset can be used with ``DataLoader(num_workers > 0)`` under
    both the ``fork`` and ``spawn`` start methods.

    NOTE: records are deserialised with ``pickle``; only open stores that were
    produced by trusted code.

    Args:
        lmdb_path: Path of the LMDB environment to read.
        transform: Optional callable applied to each decoded PIL image.
    """

    def __init__(self, lmdb_path, transform=None):
        self.lmdb_path = lmdb_path
        self.transform = transform
        self.env = None
        self.txn = None
        self._owner_pid = None
        self.keys = []
        self.labels = []
        self._ensure_open()
        # One pass over the store to collect keys and labels up front; the
        # labels live inside the pickled records, so each value is decoded.
        with self.env.begin() as scan_txn:
            for key, value in scan_txn.cursor():
                record = pickle.loads(value)
                self.keys.append(key)
                self.labels.append(record['label'])
        # Sorted so the label -> index mapping is stable across runs.
        self.label_to_idx = {lbl: idx for idx, lbl in enumerate(sorted(set(self.labels)))}

    def _ensure_open(self):
        """Open the environment and read transaction for the current process."""
        pid = os.getpid()
        if self.txn is None or self._owner_pid != pid:
            # Either nothing is open yet (fresh or unpickled instance) or the
            # handles were inherited from a parent process through fork.
            self.env = lmdb.open(self.lmdb_path, readonly=True, lock=False)
            self.txn = self.env.begin()
            self._owner_pid = pid

    def __getstate__(self):
        """Return picklable state without the process-local LMDB handles."""
        state = self.__dict__.copy()
        state['env'] = None
        state['txn'] = None
        state['_owner_pid'] = None
        return state

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        self._ensure_open()
        value = self.txn.get(self.keys[idx])
        record = pickle.loads(value)
        image = Image.open(io.BytesIO(record['data'])).convert('L')
        if self.transform:
            image = self.transform(image)
        return image, self.label_to_idx[record['label']]


In [None]:

from torch.utils.data import random_split, DataLoader

# Decode to a tensor in [0, 1], then shift/scale the single channel to [-1, 1].
transform = T.Compose([
    T.ToTensor(),
    T.Normalize([0.5], [0.5]),
])

# Read from the same `lmdb_path` the images were written to (set in the
# configuration cell) instead of rebuilding the path under a different name.
dataset = LMDBGAFDataset(lmdb_path, transform=transform)

# 70 / 15 / 15 split; the test split absorbs the rounding remainder.
n = len(dataset)
n_train = int(0.7 * n)
n_val = int(0.15 * n)
n_test = n - n_train - n_val
# Seeded generator so the same samples land in each split on every run.
split_generator = torch.Generator().manual_seed(42)
train_set, val_set, test_set = random_split(
    dataset, [n_train, n_val, n_test], generator=split_generator
)

batch_size = 64
# Pinned host memory only helps host-to-CUDA copies; on other devices it just
# triggers a warning, so it is enabled for CUDA only.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=2,
    pin_memory=(device.type == 'cuda'),
)
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)

num_classes = len(dataset.label_to_idx)
print("Classes:", dataset.label_to_idx)


## Define and Train Model

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

# --- Option to select model type ---
model_type = "resnet18"   # options: 'simple', 'resnet18', 'mobilenetv2'
print("Model type:", model_type)

# Side length of the (square) input images, read from the first sample.
input_size = dataset[0][0].shape[-1]

if model_type == "simple":
    class SimpleGAFNet(nn.Module):
        """Three conv + max-pool stages followed by two fully connected layers.

        Args:
            num_classes: Number of output logits.
            input_size: Side length of the square single-channel input.
        """

        def __init__(self, num_classes, input_size=32):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
            self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
            self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
            self.pool = nn.MaxPool2d(2, 2)
            # Three 2x2 poolings shrink each side by 8; 64 channels remain.
            self.flatten_size = (input_size // 8) * (input_size // 8) * 64
            self.fc1 = nn.Linear(self.flatten_size, 128)
            self.fc2 = nn.Linear(128, num_classes)

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = self.pool(F.relu(self.conv3(x)))
            x = x.view(x.size(0), -1)
            x = F.relu(self.fc1(x))
            return self.fc2(x)
    model = SimpleGAFNet(num_classes=num_classes, input_size=input_size)
elif model_type == "resnet18":
    # Randomly initialised network. The builder's default already means "no
    # pretrained weights", which avoids the deprecated `pretrained=` argument.
    model = models.resnet18()
    # Modify first conv for 1-channel input (from 3 channels)
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # Change the output layer
    model.fc = nn.Linear(model.fc.in_features, num_classes)
elif model_type == "mobilenetv2":
    # Same as above: default builder call gives an untrained network.
    model = models.mobilenet_v2()
    # Change first conv for 1-channel input (replace with 1-in channel)
    model.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
    # Change output classifier
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
else:
    raise ValueError("Unknown model_type")

model = model.to(device)
print("Model ready:", model_type)


In [None]:

import torch.optim as optim

# Loss for multi-class classification on the model's output logits.
criterion = nn.CrossEntropyLoss()
# Adam over all model parameters with a fixed learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimiser stepping the model's parameters.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen. Both values
        are ``nan`` when the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        out = model(imgs)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        batch_len = imgs.size(0)
        # The criterion returns a batch mean; weight it by the batch size so
        # a smaller final batch does not skew the epoch average.
        loss_sum += loss.item() * batch_len
        n_correct += (out.argmax(dim=1) == labels).sum().item()
        n_seen += batch_len
    if n_seen == 0:
        # Nothing to average over; report "no data" instead of dividing by 0.
        return float('nan'), float('nan')
    return loss_sum / n_seen, n_correct / n_seen

def evaluate(model, loader, criterion, device):
    """Measure loss and accuracy of ``model`` on ``loader`` without training.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen. Both values
        are ``nan`` when the loader yields no samples (e.g. an empty
        validation split).
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            out = model(imgs)
            batch_len = imgs.size(0)
            # Weight the batch-mean loss by batch size for a per-sample mean.
            loss_sum += criterion(out, labels).item() * batch_len
            n_correct += (out.argmax(dim=1) == labels).sum().item()
            n_seen += batch_len
    if n_seen == 0:
        # Nothing to average over; report "no data" instead of dividing by 0.
        return float('nan'), float('nan')
    return loss_sum / n_seen, n_correct / n_seen

# Train for a fixed number of epochs, reporting train and validation metrics
# after each one, then score the held-out test split once at the end.
num_epochs = 10
for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = train_one_epoch(
        model, train_loader, optimizer, criterion, device
    )
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)
    print(f"Epoch {epoch}: Train loss {train_loss:.4f}, acc {train_acc:.3f} | Val loss {val_loss:.4f}, acc {val_acc:.3f}")

test_loss, test_acc = evaluate(model, test_loader, criterion, device)
print(f"\nTest loss: {test_loss:.4f}, accuracy: {test_acc:.3f}")


## Model Analysis: Confusion Matrix and Classification Report

In [None]:

from sklearn.metrics import confusion_matrix, classification_report
import numpy as np

# Collect predictions on the test split and print per-class metrics plus the
# confusion matrix.
model.eval()
all_labels = []
all_preds = []

with torch.no_grad():
    for imgs, labels in test_loader:
        out = model(imgs.to(device))
        all_labels.extend(labels.tolist())
        all_preds.extend(out.argmax(dim=1).cpu().tolist())

# Pass the full class list explicitly: a class missing from the test split
# would otherwise make the report reject `target_names` (length mismatch) and
# change the shape of the confusion matrix.
class_names = sorted(dataset.label_to_idx, key=dataset.label_to_idx.get)
class_ids = [dataset.label_to_idx[name] for name in class_names]

print("Classification report:")
print(classification_report(
    all_labels,
    all_preds,
    labels=class_ids,
    target_names=class_names,
    zero_division=0,  # classes with no predictions/samples score 0 quietly
))
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
print("Confusion matrix:\n", cm)
