In [1]:
pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 25.0.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [1]:
import pandas as pd
import numpy as np
import re
import joblib
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List

# ----------------------------
# 1. INITIALIZE APP & LOAD MODELS
# ----------------------------
# Everything in this section executes once, at import time, so the model and
# encoder are loaded a single time and then shared by every request handler.
app = FastAPI(title="Network Traffic Classifier API", version="1.0")

# Load the trained classifier and the one-hot encoder used for 'Protocol'.
# Both paths are relative, so they resolve against the process's current
# working directory.
# NOTE(review): joblib.load unpickles its input, which can execute arbitrary
# code — only load artifacts from a trusted source.
model = joblib.load("Models/lightgbm_packet_model.pkl")
ohe = joblib.load("Models/ohe_protocol.pkl")
print("Model and encoder loaded successfully.")


# ----------------------------
# 2. DEFINE DATA INPUT MODEL
# ----------------------------
# This defines what a single packet's data should look like for an API request.
class Packet(BaseModel):
    """Request schema for a single packet submitted to the API."""
    # Packet timestamp; create_features sorts by it and derives 'delta_time'
    # from consecutive differences.
    # NOTE(review): the unit is not established in this file — confirm.
    Time: float
    # Source / Destination are accepted as part of the packet but are not read
    # by the feature engineering in this file.
    Source: str
    Destination: str
    # Categorical protocol name; passed through the one-hot encoder at predict time.
    Protocol: str
    # Used directly as a numeric model feature.
    Length: int
    # Free-text summary; create_features mines it for ports, flags and keywords.
    Info: str

class PacketList(BaseModel):
    """Request body for the /predict endpoint: a batch of packets."""
    # The packets to classify; the handler turns this list into one DataFrame.
    packets: List[Packet]


# ----------------------------
# 3. FEATURE ENGINEERING FUNCTION
# ----------------------------
# This function replicates the feature engineering from your notebook.
def create_features(df: pd.DataFrame) -> pd.DataFrame:
    """Build the engineered feature columns for a frame of raw packets.

    The input frame is never modified: all work happens on a copy, so the
    caller's frame keeps its original columns and row order.

    Parameters
    ----------
    df : pd.DataFrame
        Raw packets. Must contain an ``Info`` column (coerced to ``str``) and
        a ``Time`` column (used for ordering). Any other columns are carried
        through unchanged.

    Returns
    -------
    pd.DataFrame
        A new frame sorted by ``Time`` ascending with a fresh ``0..n-1`` index
        and these added columns:

        - ``info_len``: number of characters in ``Info``.
        - ``digits_in_info``: count of digit characters in ``Info``.
        - ``src_port`` / ``dst_port``: the two numbers around the first
          ``<digits> > <digits>`` occurrence in ``Info``, or NaN if absent.
        - ``flag_<F>``: 1 if the (case-sensitive) flag name occurs in ``Info``.
        - ``has_<kw>``: 1 if the keyword occurs in lower-cased ``Info``.
        - ``delta_time``: difference to the previous row's ``Time`` after
          sorting (0.0 for the first row).

    Raises
    ------
    KeyError
        If ``Info`` or ``Time`` is missing from ``df``.
    """
    missing = [c for c in ('Info', 'Time') if c not in df.columns]
    if missing:
        raise KeyError(f"create_features requires column(s): {missing}")

    # Copy first: the function returns a re-sorted frame, so mutating the
    # caller's frame as well would leave it with feature columns whose row
    # order no longer matches the returned frame.
    out = df.copy()

    # Ensure Info is a string
    info = out['Info'].astype(str)
    out['Info'] = info

    # Patterns and keywords from the notebook
    port_pattern = re.compile(r'(\d{1,5})\s*>\s*(\d{1,5})')
    flags = ['SYN', 'ACK', 'PSH', 'RST', 'FIN', 'URG', 'ECE', 'CWR']
    keywords = ['http', 'tls', 'quic', 'dns', 'get', 'post', 'rtp', 'rtcp', 'rtsp', 'ssl', 'video', 'audio', 'application data']

    # Basic text-derived features
    out['info_len'] = info.apply(len)
    out['digits_in_info'] = info.apply(lambda s: sum(ch.isdigit() for ch in s))

    # Extract ports: run the regex once per row, then pull both groups
    matches = info.apply(port_pattern.search)
    out['src_port'] = matches.apply(lambda m: int(m.group(1)) if m else np.nan)
    out['dst_port'] = matches.apply(lambda m: int(m.group(2)) if m else np.nan)

    # Flags presence (literal, case-sensitive substring match)
    for flag in flags:
        out['flag_' + flag] = info.str.contains(flag, regex=False, na=False).astype(int)

    # Keyword indicators: lower-case once, match literally (not as a regex)
    info_lower = info.str.lower()
    for kw in keywords:
        col = 'has_' + kw.replace(' ', '_')
        out[col] = info_lower.str.contains(kw, regex=False, na=False).astype(int)

    # Time delta (crucial for time-series features). mergesort is stable, so
    # packets sharing a timestamp keep their incoming relative order.
    out = out.sort_values('Time', kind='mergesort').reset_index(drop=True)
    out['delta_time'] = out['Time'].diff().fillna(0.0)

    return out

# ----------------------------
# 4. PREDICTION ENDPOINT
# ----------------------------
# This is the main endpoint that will receive data and return predictions.
@app.post("/predict")
def predict_traffic(data: PacketList):
    """
    Predicts the traffic type for a list of packets.

    - **packets**: A list of packet data objects.

    Returns `{"predictions": [...]}` with one record per submitted packet, in
    the order submitted. Each record holds the packet's original fields plus
    `predicted_label`. An empty packet list yields an empty prediction list.
    """
    # Nothing to score: avoid building a column-less DataFrame, which would
    # fail inside feature engineering.
    if not data.packets:
        return {"predictions": []}

    # 1. Convert incoming JSON data to plain dicts and a Pandas DataFrame.
    #    pydantic v2 exposes model_dump(); fall back to .dict() on v1.
    records = [
        p.model_dump() if hasattr(p, "model_dump") else p.dict()
        for p in data.packets
    ]
    input_df = pd.DataFrame(records)

    # create_features returns rows sorted by Time, so tag every row with its
    # position in the request to map predictions back afterwards. The helper
    # column rides along because create_features keeps unknown columns.
    row_id_col = "_row_id"
    input_df[row_id_col] = np.arange(len(input_df))

    # 2. Create features using the same logic as in the notebook.
    #    A copy is passed so this handler's frame is never touched.
    features_df = create_features(input_df.copy())

    # 3. Prepare the feature matrix for the model (must match training).
    #    Flag/keyword columns are discovered by prefix, in the order
    #    create_features emitted them, because its flag and keyword lists are
    #    local to that function and not visible here.
    flag_cols = [c for c in features_df.columns if str(c).startswith('flag_')]
    keyword_cols = [c for c in features_df.columns if str(c).startswith('has_')]
    candidate_features = [
        'Length', 'info_len', 'digits_in_info', 'delta_time', 'src_port', 'dst_port'
    ] + flag_cols + keyword_cols + ['Protocol']

    X_df = features_df[[c for c in candidate_features if c in features_df.columns]].copy()

    # -1 is the sentinel for "no port found in Info".
    X_df['src_port'] = X_df['src_port'].fillna(-1).astype(int)
    X_df['dst_port'] = X_df['dst_port'].fillna(-1).astype(int)
    if 'Length' in X_df:
        X_df['Length'] = X_df['Length'].fillna(X_df['Length'].median())

    # 4. Apply the OneHotEncoder and combine features robustly
    cat_cols = [c for c in ['Protocol'] if c in X_df.columns]
    num_cols = [c for c in X_df.columns if c not in cat_cols]

    # Get the numeric data and reset its index
    X_num = X_df[num_cols].astype(float).fillna(0.0).reset_index(drop=True)

    # Handle the categorical data
    if cat_cols:
        X_cat = X_df[cat_cols].fillna('missing').astype(str)
        encoded = ohe.transform(X_cat)
        # The encoder may return a sparse matrix or a dense array depending
        # on how it was configured; accept both.
        encoded_arr = encoded.toarray() if hasattr(encoded, "toarray") else np.asarray(encoded)

        # Robustly get feature names for different scikit-learn versions
        try:
            P_cols = list(ohe.get_feature_names_out(cat_cols))
        except AttributeError:
            P_cols = list(ohe.get_feature_names(cat_cols))

        # Create DataFrame from encoded data and reset its index
        P_df = pd.DataFrame(encoded_arr, columns=P_cols).reset_index(drop=True)

        # Concatenate numeric and encoded dataframes
        X_pre = pd.concat([X_num, P_df], axis=1)
    else:
        X_pre = X_num.copy()

    # Sanitize column names for LightGBM
    X_pre.columns = [re.sub(r'[^0-9a-zA-Z_]', '_', str(c)) for c in X_pre.columns]

    # 5. Make predictions (rows are in time-sorted order here)
    predictions = model.predict(X_pre)

    # 6. Return the results. Predictions follow features_df's row order, so
    #    route each one back to the request position it came from. tolist()
    #    converts numpy scalars to native Python values for JSON encoding.
    labels_in_time_order = np.asarray(predictions).tolist()
    request_positions = features_df[row_id_col].tolist()
    for position, label in zip(request_positions, labels_in_time_order):
        records[position]['predicted_label'] = label

    # Only the original packet fields plus the label are returned; the
    # engineered columns can contain NaN, which is not valid JSON.
    return {"predictions": records}

# Minimal root route: a quick way to confirm the service is up and responding.
@app.get("/")
def read_root():
    """Return a static welcome payload pointing callers at the interactive docs."""
    welcome_text = "Welcome to the Network Traffic Classifier API. Go to /docs for details."
    return {"message": welcome_text}

Model and encoder loaded successfully.
