# Final Notebook

From the modeling notebook:
- The BERT model has the highest F1-score, 0.9977, at a threshold of 0.3
- The model is stored as `model.pkl`

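In this notebook, we create two functions: `final_fun1`, which predicts whether a single user-supplied query is an SQL injection attack (SQLIA), and `final_fun2`, which computes the F1-score of the model's predictions for a list of queries and their labels.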

## Importing libraries

In [None]:
!pip install tensorflow-text

Collecting tensorflow-text
  Downloading tensorflow_text-2.8.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (4.9 MB)
Collecting tf-estimator-nightly==2.8.0.dev2021122109
  Downloading tf_estimator_nightly-2.8.0.dev2021122109-py2.py3-none-any.whl (462 kB)
Installing collected packages: tf-estimator-nightly, tensorflow-text
Successfully installed tensorflow-text-2.8.1 tf-estimator-nightly-2.8.0.dev2021122109


In [None]:
from sklearn.metrics import f1_score
import numpy as np
import pandas as pd
import re
import tensorflow_text as text
import joblib

In [None]:
text.__version__

'2.8.1'

In [None]:
joblib.__version__

'1.1.0'

## Preprocessing Function

The EDA notebook showed that numbers and hexadecimal literals in the queries should be replaced with the placeholder tokens `<num>` and `<hex>`, which the function below does.

In [None]:
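num_reg = re.compile(r'[0-9]+\b')  # regex to match decimal numbers
hex_reg = re.compile(r'0x[A-Fa-f0-9]+\b')  # regex to match hexadecimal numbers

def replace_num_hex(query):
    '''
    This function replaces numbers with
    <num> and hexadecimals with <hex>
    tokens respectively
    :param query: (str) SQL query
    :returns: (str) processed query
    '''
    # Numbers are replaced first, so a hexadecimal literal ending in a
    # digit (e.g. 0x12) becomes '0x<num>' rather than '<hex>'.
    # Keep this order consistent with the preprocessing used in training.
    q = num_reg.sub('<num>', query)
    q = hex_reg.sub('<hex>', q)
    return q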
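The order of the two substitutions matters. Because `num_reg` runs first and has no leading word boundary, a hexadecimal literal that ends in a digit (such as `0x12`) loses its trailing digits before `hex_reg` sees it. The sketch below compares the notebook's order with a hex-first variant; the names `NUM_RE`, `HEX_RE`, `num_first` and `hex_first` are hypothetical and not part of the original notebook. Whichever order is used at inference should match the one used when the model was trained, so treat this as a diagnostic rather than a drop-in replacement.

```python
import re

# Same patterns as the preprocessing cell, under separate names
NUM_RE = re.compile(r'[0-9]+\b')
HEX_RE = re.compile(r'0x[A-Fa-f0-9]+\b')

def num_first(query):
    # Notebook's order: numbers first, then hexadecimals
    return HEX_RE.sub('<hex>', NUM_RE.sub('<num>', query))

def hex_first(query):
    # Hypothetical variant: hexadecimals first, then numbers
    return NUM_RE.sub('<num>', HEX_RE.sub('<hex>', query))

query = "SELECT * FROM t WHERE id = 42 OR flag = 0x1F OR key = 0x12"
print(num_first(query))  # SELECT * FROM t WHERE id = <num> OR flag = <hex> OR key = 0x<num>
print(hex_first(query))  # SELECT * FROM t WHERE id = <num> OR flag = <hex> OR key = <hex>
```

`0x1F` is handled the same way by both orders because no run of digits inside it is followed by a word boundary; only literals ending in a digit differ.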

## Loading the model

Load the trained model from the pickle file saved by the modeling notebook.

In [None]:
model = joblib.load('model.pkl')

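## Setting the optimum threshold

The modeling notebook found 0.3 to be the threshold with the highest F1-score, so a query is classified as an SQLIA when the model's score exceeds 0.3.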

In [None]:
OPTIMUM_THRESHOLD = 0.3
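The modeling notebook's threshold search is not reproduced here. As a sketch, assuming a simple grid sweep over candidate thresholds on held-out scores, the threshold with the highest F1-score could be chosen as follows. The scores, labels and candidate list are invented for illustration, and `best_threshold` is a hypothetical helper.

```python
import numpy as np
from sklearn.metrics import f1_score

def best_threshold(scores, labels, candidates):
    """Return (threshold, f1) for the candidate with the highest F1-score.

    A query counts as an SQLIA when its score is strictly greater than the
    threshold, matching the `>` comparison used in this notebook.
    On ties, np.argmax picks the first candidate in the list.
    """
    f1s = [f1_score(labels, scores > t) for t in candidates]
    best = int(np.argmax(f1s))
    return candidates[best], f1s[best]

# Hypothetical held-out scores and labels (1 = SQLIA, 0 = benign)
scores = np.array([0.05, 0.20, 0.35, 0.40, 0.60, 0.90])
labels = np.array([0, 0, 1, 0, 1, 1])

threshold, f1 = best_threshold(scores, labels, [0.1, 0.3, 0.5, 0.7])
print(threshold, round(f1, 3))  # 0.3 0.857
```

On this toy data, 0.1 admits two false positives (F1 0.75), 0.3 admits one (F1 6/7), and 0.5 and 0.7 start missing attacks (F1 0.8 and 0.5).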

## Function to predict whether a query is an SQLIA

In [None]:
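def final_fun1(query):
  """
  Predicts whether the query is an SQLIA.
  Inputs:
  query(str): SQL query
  Returns:
  prediction(bool): True if the query is an SQLIA, None if the input is not a str
  """
  # Validate with isinstance rather than assert: asserts are skipped under python -O
  if not isinstance(query, str):
    print("Enter str input. Current input is of type {}.".format(type(query)))
    return None
  preprocessed_input = replace_num_hex(query)
  # Wrap the single query in a 1-D array so the model receives a batch of one
  model_input = np.ravel(preprocessed_input)
  model_output = model.predict(model_input)
  # Take the score of the single query and convert numpy.bool_ to a Python bool
  prediction = bool(model_output[0][0] > OPTIMUM_THRESHOLD)
  return prediction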

In [None]:
final_fun1("WHERE password='abc123' OR '1'='1'")

True

In [None]:
final_fun1(123)

Enter str input. Current input is of type <class 'int'>.
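`final_fun1` relies on two shape conventions: `np.ravel` turns a single string into a one-element 1-D array (a batch of one), and the model returns a 2-D array with one score per row, hence `model_output[0][0]`. The sketch below checks these conventions without loading BERT, using a hypothetical `StubModel` whose quote-based scoring rule is purely illustrative. The `(n, 1)` output shape is inferred from the indexing in `final_fun1`, not confirmed against the saved model.

```python
import numpy as np

threshold = 0.3  # same value as OPTIMUM_THRESHOLD

class StubModel:
    """Hypothetical stand-in for the BERT model.

    Mimics only the output shape assumed by final_fun1: one score per
    query, returned as an array of shape (n_queries, 1).
    """
    def predict(self, inputs):
        # Toy rule: score 0.9 if the query contains a quote, else 0.1
        return np.array([[0.9 if "'" in q else 0.1] for q in inputs])

stub_model = StubModel()

query = "WHERE password='abc123' OR '1'='1'"
model_input = np.ravel(query)           # 0-d string array flattened to shape (1,)
model_output = stub_model.predict(model_input)
print(model_input.shape)                # (1,)
print(model_output.shape)               # (1, 1)
print(bool(model_output[0][0] > threshold))  # True
```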


## Function to predict and calculate the F1-score for given queries and labels

In [None]:
def final_fun2(queries, labels):
  """
  Returns the F1-score of the predictions for the given queries.
  Inputs:
  queries(list): list of SQL queries of type str
  labels(list): list of integer labels for the queries.
  Each label should be either 0 or 1.
  Returns:
  f_score(float): F1-score of the predictions, or None if the inputs are invalid
  """
  # Validate explicitly; assert statements are skipped under python -O
  if not isinstance(queries, list):
    print("Wrong input for queries: Enter list of string inputs. Current input is of type {}.".format(type(queries)))
    return None
  if not isinstance(labels, list):
    print("Wrong input for labels: Enter list of integer inputs. Current input is of type {}.".format(type(labels)))
    return None
  preprocessed_queries = [replace_num_hex(q) for q in queries]
  model_outputs = model.predict(np.array(preprocessed_queries))
  # Flatten the per-query scores to a 1-D array before thresholding
  predictions = np.ravel(model_outputs) > OPTIMUM_THRESHOLD
  f_score = f1_score(labels, predictions)
  return f_score

In [None]:
final_fun2(["SELECT employee_name FROM employees", "or pg_sleep ( __TIME__ )"],
           [0, 1])

1.0

In [None]:
final_fun2("WHERE password='abc123' OR '1'='1'", 1)

Wrong input for queries: Enter list of string inputs. Current input is of type <class 'str'>.


In [None]:
final_fun2(["SELECT employee_name FROM employees", "or pg_sleep ( __TIME__ )"],
           1)

Wrong input for labels: Enter list of integer inputs. Current input is of type <class 'int'>.
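The F1-score reported by `final_fun2` is the harmonic mean of precision and recall, which for binary labels reduces to 2TP / (2TP + FP + FN). The sketch below computes it by hand from scores shaped `(n, 1)`, which can help sanity-check the value returned by scikit-learn's `f1_score`. The scores are invented, not outputs of the BERT model, and `f1_from_scores` is a hypothetical helper.

```python
import numpy as np

def f1_from_scores(model_outputs, labels, threshold=0.3):
    """F1-score for scores shaped (n, 1), thresholded as in final_fun2."""
    # Flatten (n, 1) -> (n,) and apply the strict `>` threshold
    preds = np.ravel(model_outputs) > threshold
    labels = np.asarray(labels)
    tp = int(np.sum(preds & (labels == 1)))   # attacks flagged as attacks
    fp = int(np.sum(preds & (labels == 0)))   # benign queries flagged
    fn = int(np.sum(~preds & (labels == 1)))  # attacks missed
    denom = 2 * tp + fp + fn
    # Defined here as 0.0 when there are no positives at all
    return 2 * tp / denom if denom else 0.0

# Both queries classified correctly: no FP or FN, so F1 = 1.0
print(f1_from_scores(np.array([[0.02], [0.97]]), [0, 1]))  # 1.0

# One TP, one FP, one FN: F1 = 2 / (2 + 1 + 1) = 0.5
print(f1_from_scores(np.array([[0.4], [0.1], [0.8], [0.2]]), [0, 1, 1, 0]))  # 0.5
```

A perfect classification of both queries gives 1.0, the same value as the first `final_fun2` example.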
