# SQL Injection: Detection and Correction using Machine Learning and Natural Language Processing

## Training a Naive Bayes classifier for detection of an attack

In [1]:
import pandas as pd
df = pd.read_csv("../data/sqli/sqli.csv", encoding='utf-16')

In [2]:
df['Sentence'] = df['Sentence'].apply(lambda x: str(x).lower())

In [3]:
from sklearn.feature_extraction.text import CountVectorizer

vectorizer = CountVectorizer(analyzer='word', ngram_range=(1, 2))

In [4]:
X = vectorizer.fit_transform(df['Sentence'].values.astype('U')).toarray()
y = df['Label']

In [5]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [6]:
from sklearn.naive_bayes import GaussianNB
nb_clf = GaussianNB()
nb_clf.fit(X_train, y_train)

In [7]:
y_pred = nb_clf.predict(X_test)

In [8]:
from sklearn.metrics import accuracy_score, precision_score, recall_score

print("Accuracy:", accuracy_score(y_test, y_pred))
print("Precision:", precision_score(y_test, y_pred))
print("Recall:", recall_score(y_test, y_pred))

Accuracy: 0.9773809523809524
Precision: 0.9330855018587361
Recall: 0.996031746031746


In [8]:
def predict(ip):
    #injected_query = query.replace(variable, ip)
    features = vectorizer.transform([ip]).toarray()
    return nb_clf.predict(features)[0]

In [13]:
predict("hellowoeld")

1

In [10]:
import nltk
from nltk.tokenize import RegexpTokenizer

def RETokenizer(query):
    pattern = r'\$?\w+'
    tokenizer = RegexpTokenizer(pattern)
    tokens = tokenizer.tokenize(query)
    return tokens

In [11]:
def sanitize(token):
   return token.replace("'", "''")

## Correcting the SQL query using pattern matching and tokenization

In [26]:
def find_vars(tokens, query):
    variables = []
    for token in tokens:
        if token.startswith("$"): #and query.index(token) > query.index("WHERE"):
            datatype = "int"
            if "date" in token:
                datatype = "date"
            if token.startswith("$is"):
                datatype = "bool"
            if query[query.index(token) - 1] in ["'", "\""]:
                datatype = "str"
            variables.append((token, datatype))
    return variables


In [13]:
def create_pstmt(query, variables):
    pstmt = query
    for name, datatype in variables:
        placeholder = "?"
        if datatype == "date":
            placeholder = "TO_DATE(?, 'YYYY-MM-DD')"
        pstmt = pstmt.replace(name, placeholder, 1)
    return pstmt

In [25]:
def correction(query):
    tokens = RETokenizer(query)
    print("Statement type", tokens[0])

    variables = find_vars(tokens, query)
    pstmt = create_pstmt(query, variables)
    print(pstmt)
    
    for i, var in enumerate(variables):
        print("bind(" + var[0] + ", " + str(i) + ")")

## File system parsing

In [15]:
import re

pattern = r"\"(SELECT|INSERT|UPDATE|DELETE)(.*?)\""

with open('samples/sample.php', 'r') as f:
    php = f.read()

queries = [i + j for i, j in re.findall(pattern, php, re.DOTALL)]

In [16]:
print("Performing basic check for variables:")
vulnerable = []
for q in queries:
    tokens = RETokenizer(q)
    variables = [token for token in tokens if token.startswith("$")]
    
    print(q)
    if len(variables) > 0:
        vulnerable.append((q, variables))
        print("Maybe vulnerable\n")
    else:
        print("Not vulnerable directly\n")

Performing basic check for variables:
SELECT * FROM orders WHERE order_date >= '$start_date' AND order_total < $max_total AND order_status = $status AND order_id IN ($order_ids) AND order_amount BETWEEN :min_amount AND $max_amount AND customer_id = $customer_id AND order_is_active = $is_active
Maybe vulnerable

INSERT INTO MyGuests (firstname, lastname, email) VALUES ('John', 'Doe', 'john@example.com')
Not vulnerable directly

INSERT INTO tablename (firstname, lastname) VALUES ('John', 'Doe')
Not vulnerable directly

SELECT * FROM mytable WHERE name like '$name'
Maybe vulnerable

UPDATE mytable SET age = $age, isLogin = TRUE WHERE id = '$sessionId'
Maybe vulnerable



In [17]:
for i in range(0, len(vulnerable)):
    print(str(i + 1) + ".", vulnerable[i][0])

1. SELECT * FROM orders WHERE order_date >= '$start_date' AND order_total < $max_total AND order_status = $status AND order_id IN ($order_ids) AND order_amount BETWEEN :min_amount AND $max_amount AND customer_id = $customer_id AND order_is_active = $is_active
2. SELECT * FROM mytable WHERE name like '$name'
3. UPDATE mytable SET age = $age, isLogin = TRUE WHERE id = '$sessionId'


## Driver code

In [18]:
q = int(input("Choose a query:")) - 1

In [19]:
for i in range(0, len(vulnerable[q][1])):
    print(str(i + 1) + ".", vulnerable[q][1][i][1:])

1. age
2. sessionId


In [20]:
v = int(input("Choose a variable to check:")) - 1

In [21]:
ip = input("Make an attack: ")

In [22]:
if (predict(vulnerable[q][0], vulnerable[q][1][v], ip) == 1):
    print("Possible attempt of SQL injection")
else:
    print("Not an SQL injection attack")

Possible attempt of SQL injection


In [27]:
c = int(input("Would you wish to correct this query? "))

if c == 1:
    correction(vulnerable[q][0])

Statement type UPDATE
UPDATE mytable SET age = ?, isLogin = TRUE WHERE id = '?'
bind($age, 0)
bind($sessionId, 1)
