## Training a Naive Bayes classifier for detection of an attack

In [67]:
import pandas as pd
df = pd.read_csv("./data/sqli.csv", encoding='utf-16')

In [68]:
df['Sentence'] = df['Sentence'].apply(lambda x: str(x).lower())

In [69]:
from sklearn.feature_extraction.text import CountVectorizer

vectorizer = CountVectorizer(analyzer='word', ngram_range=(1, 2))

In [70]:
X = vectorizer.fit_transform(df['Sentence'].values.astype('U')).toarray()
y = df['Label']

In [71]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [72]:
from sklearn.naive_bayes import GaussianNB
nb_clf = GaussianNB()
nb_clf.fit(X_train, y_train)

In [73]:
from sklearn.metrics import accuracy_score
y_pred = nb_clf.predict(X_test)
accuracy_score(y_test, y_pred)

0.9773809523809524

In [None]:
def predict(query, variable, ip):
    injected_query = query.replace(variable, ip)
    features = vectorizer.transform([injected_query]).toarray()
    return nb_clf.predict(features)[0]

In [74]:
import nltk
from nltk.tokenize import RegexpTokenizer

def RETokenizer(query):
    pattern = r'\$?\w+'
    tokenizer = RegexpTokenizer(pattern)
    tokens = tokenizer.tokenize(query)
    return tokens

In [76]:
def sanitize(token):
   return token.replace("'", "''")

## Correcting the SQL query using pattern matching and tokenization

In [77]:
def find_vars(tokens):
    variables = []
    for token in tokens:
        if token.startswith("$") and query.index(token) > query.index("WHERE"):
            datatype = "int"
            if "date" in token:
                datatype = "date"
            if token.startswith("$is"):
                datatype = "bool"
            if query[query.index(token) - 1] in ["'", "\""]:
                datatype = "str"
            variables.append((token, datatype))
    return variables


In [78]:
def create_pstmt(query, variables):
    pstmt = query
    for name, datatype in variables:
        placeholder = "?"
        if datatype == "date":
            placeholder = "TO_DATE(?, 'YYYY-MM-DD')"
        pstmt = pstmt.replace(name, placeholder, 1)
    return pstmt

In [79]:
def correction(query):
    tokens = RETokenizer(query)
    print("Statement type", tokens[0])

    variables = find_vars(tokens)
    pstmt = create_pstmt(query, variables)
    print(pstmt)
    
    for i, var in enumerate(variables):
        print("bind(" + var[0] + ", " + str(i) + ")")

# File system parsing

In [121]:
import re

pattern = r"\"(SELECT|INSERT|UPDATE|DELETE)(.*?)\""

with open('sample.php', 'r') as f:
    php_code = f.read()

queries = [i + j for i, j in re.findall(pattern, php_code, re.DOTALL)]

In [80]:
print("Performing basic check for variables:")
vulnerable = []
for q in queries:
    tokens = RETokenizer(q)
    variables = [token for token in tokens if token.startswith("$") and q.index("WHERE") and q.index(token) > q.index("WHERE")]
    
    print(query)
    if len(variables) > 0:
        vulnerable.append((q, variables))
        print("Maybe vulnerable\n")
    else:
        print("Not vulnerable directly\n")

Performing basic check for variables:
SELECT * FROM orders WHERE order_date >= '$start_date' AND order_total < $max_total AND order_status = $status AND order_id IN ($order_ids) AND order_amount BETWEEN :min_amount AND $max_amount AND customer_id = $customer_id AND order_is_active = $is_active
Maybe vulnerable



In [81]:
for i in range(0, len(vulnerable)):
    print(str(i + 1) + ".", vulnerable[i][0])

1. SELECT * FROM orders WHERE order_date >= '$start_date' AND order_total < $max_total AND order_status = $status AND order_id IN ($order_ids) AND order_amount BETWEEN :min_amount AND $max_amount AND customer_id = $customer_id AND order_is_active = $is_active


In [82]:
q = int(input("Choose a query:")) - 1

In [83]:
for i in range(0, len(vulnerable[q][1])):
    print(str(i + 1) + ".", vulnerable[q][1][i][1:])

1. start_date
2. max_total
3. status
4. order_ids
5. max_amount
6. customer_id
7. is_active


In [84]:
v = int(input("Choose a variable to check:")) - 1

In [85]:
ip = input("Make an attack: ")

In [86]:
if (predict(vulnerable[q][0], vulnerable[q][1][v], ip) == 1):
    print("Possible attempt of SQL injection")
else:
    print("Not an SQL injection attack")

Not an SQL injection attack


In [87]:
c = int(input("Would you wish to correct this query? "))

if c == 1:
    correction(vulnerable[q][0])

Statement type SELECT
SELECT * FROM orders WHERE order_date >= '?' AND order_total < ? AND order_status = ? AND order_id IN (?) AND order_amount BETWEEN :min_amount AND ? AND customer_id = ? AND order_is_active = ?
bind($start_date, 0)
bind($max_total, 1)
bind($status, 2)
bind($order_ids, 3)
bind($max_amount, 4)
bind($customer_id, 5)
bind($is_active, 6)
