In [1]:
# !pip3 install cx_Oracle --user

In [2]:
# Oracle Database access in this notebook uses the cx_Oracle Python library
# Query examples: https://developer.oracle.com/dsl/prez-python-queries.html

In [3]:
import pandas as pd
import numpy as np
import cx_Oracle
from sqlalchemy import types, create_engine
from types import *
from pprint import pprint
import cx_Oracle
import sqlalchemy as sa
from datetime import datetime
import base64
import hashlib
from Crypto import Random
from Crypto.Cipher import AES
import json
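# NOTE: PYTHONIOENCODING is read at interpreter start-up, so setting it here only affects child processes launched later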
import os
os.environ["PYTHONIOENCODING"] = "utf-8"

In [4]:
# Locations of the encryption key, the encrypted connection string and the
# validation config workbook. Paths are relative to the notebook's working
# directory and are assumed to be readable by this process.
# Change baseDirectory to relocate every input at once.
baseDirectory = "workspace/de/drm-de"

keyLocation = baseDirectory + "/connectionStrings/key.txt"
connectionStringLocation = baseDirectory + "/connectionStrings/connectionString.txt"

# Workbook holding the health-check rules, the parameters and the data-quality rules
configFile = baseDirectory + "/dataValidation/config_validation.xlsx"
healthCheckSheet = "Health_Check"
parametersSheet = "Parameters"
# Each sheet listed here is run through dataQualityCheck, in order
dataQualityCheckList = ["Data_Ingestion", "Datetime_Check", "Condition_Check"]

In [5]:
class Encoder(object):
    """AES-CBC helper used to decrypt the stored database connection string.

    The AES key is the SHA-256 digest of the first line of the key file.
    Ciphertext is base64(iv + AES-CBC(padded plaintext)); each padding unit
    holds the padding length, and padding is always 1..block_size units.
    """

    def __init__(self, keyLocation, stringLocation):
        """Derive the AES key and remember where the encrypted string lives.

        Args:
            keyLocation: path of a text file whose first line is the passphrase.
            stringLocation: path of a text file whose first line is the
                base64-encoded ciphertext.
        """
        self.bs = AES.block_size
        with open(keyLocation, "r") as keyFile:
            # readline() keeps a trailing newline if present; it is hashed as-is
            # so previously produced ciphertexts still decrypt.
            keyValue = keyFile.readline()
        self.key = hashlib.sha256(keyValue.encode()).digest()
        self.stringLocation = stringLocation

    def encrypt(self, raw):
        """Encrypt the text *raw*; return base64(iv + ciphertext) as bytes."""
        # Pad the UTF-8 bytes rather than the str: a non-ASCII character encodes
        # to several bytes, so padding the str would leave the data misaligned
        # with the AES block size. For ASCII input the result is unchanged.
        data = self._pad(raw.encode("utf-8"))
        iv = Random.new().read(AES.block_size)
        cipher = AES.new(self.key, AES.MODE_CBC, iv)
        return base64.b64encode(iv + cipher.encrypt(data))

    def decrypt(self, enc):
        """Decode base64 *enc*, decrypt it using its leading IV, return UTF-8 text."""
        enc = base64.b64decode(enc)
        iv = enc[:AES.block_size]
        cipher = AES.new(self.key, AES.MODE_CBC, iv)
        return self._unpad(cipher.decrypt(enc[AES.block_size:])).decode('utf-8')

    def getConnectionString(self):
        """Read the first line of the connection-string file and return it decrypted."""
        with open(self.stringLocation, "r") as stringFile:
            connectionString = stringFile.readline()
        return self.decrypt(connectionString)

    def _pad(self, s):
        """Pad *s* (str or bytes) to a multiple of the block size.

        Every padding unit holds the padding length, so 1..bs units are appended.
        """
        padLength = self.bs - len(s) % self.bs
        if isinstance(s, (bytes, bytearray)):
            return bytes(s) + bytes([padLength]) * padLength
        return s + padLength * chr(padLength)

    def _unpad(self, s):
        """Strip the padding added by _pad; the last element encodes its length."""
        return s[:-ord(s[len(s)-1:])]

In [6]:
# Decrypt the stored connection settings and create the shared SQLAlchemy engine
from urllib.parse import quote

encoder = Encoder(keyLocation, connectionStringLocation)
encodedConnStr = encoder.getConnectionString()
# The decrypted payload is parsed as JSON; the keys read below must be present
dictConnStr = json.loads(encodedConnStr)

dbType = dictConnStr['dbType']
dbUser = dictConnStr['dbUser']
dbPassword = dictConnStr['dbPassword']
dbHost = dictConnStr['dbHost']
dbName = dictConnStr['dbName']
dbPort = dictConnStr['dbPort']

# Percent-encode the credentials (safe='' also encodes '/', '@', ':' and '+') so a
# password containing URL delimiters cannot corrupt the URL; str() tolerates a
# numeric port in the JSON.
oracleConnStr = 'oracle+cx_oracle://{user}:{password}@{host}:{port}/{name}'.format(
    user=quote(dbUser, safe=''),
    password=quote(dbPassword, safe=''),
    host=dbHost,
    port=str(dbPort),
    name=dbName,
)
oracleDatabaseEngine = sa.create_engine(oracleConnStr)
# Connecting here surfaces credential/host problems immediately
oracleDbConnection = oracleDatabaseEngine.connect()

In [7]:
def runOracleQuery(sqlQuery):
    ##################################################################################
    # Function to run an Oracle query and print its output
    # Takes the SQL query text as input; DB configuration (dbHost, dbPort, dbName,
    # dbUser, dbPassword) is expected to be pre-declared in scope of the call
    # The statement is committed after execution
    # Errors are printed, not raised
    ##################################################################################
    dbConnection = None
    cursor = None
    try:
        oracleTns = cx_Oracle.makedsn(dbHost, dbPort, dbName)
        dbConnection = cx_Oracle.connect(dbUser, dbPassword, oracleTns)

        cursor = dbConnection.cursor()
        cursor.execute(sqlQuery)
        dbConnection.commit()

        # Per DB-API 2.0, description is None for statements that return no rows;
        # this replaces a bare except that also hid genuine fetch errors
        if cursor.description is None:
            print("Not a data query, no data to print")
        else:
            data = cursor.fetchall()
            pprint(cursor.description)
            pprint(data)

        print("Query executed")

    except Exception as e:
        print("ERROR: Something went wrong executing the query")
        print(type(e).__name__)
        print(e)

    finally:
        # Release DB resources on every path, including failures
        if cursor is not None:
            cursor.close()
        if dbConnection is not None:
            dbConnection.close()

In [8]:
def checkIfTableExists(tableName):
    ##################################################################################
    # Function to check if table exists in Oracle DB
    # Takes tablename as the input
    # return True if a SELECT against the table succeeds, False otherwise
    # (any failure, including a connection error, is reported as a missing table,
    # with the underlying error printed)
    ##################################################################################
    dbConnection = None
    cursor = None
    try:
        oracleTns = cx_Oracle.makedsn(dbHost, dbPort, dbName)
        dbConnection = cx_Oracle.connect(dbUser, dbPassword, oracleTns)

        cursor = dbConnection.cursor()
        cursor.execute("SELECT * FROM " + tableName.strip() + " FETCH FIRST 10 ROWS ONLY")

        return True

    except Exception as e:
        print("ERROR: Table '" + tableName.strip() + "' does not exist!")
        print("ERROR: " + type(e).__name__ + ": " + str(e))
        return False

    finally:
        # Always release DB resources, including on the successful return path
        if cursor is not None:
            cursor.close()
        if dbConnection is not None:
            dbConnection.close()

In [9]:
def checkIfTableHasData(tableName):
    ##################################################################################
    # Function to check if table has data
    # Takes tablename as the input
    # return True if the table has at least one row, False if it is empty or the
    # count query fails (the failure is printed)
    ##################################################################################
    query = "SELECT COUNT(*) FROM " + tableName.strip()

    try:
        # The context manager closes the connection on every path, including the
        # early returns and failures
        with oracleDatabaseEngine.connect() as oracleDbConnection:
            tableSize = pd.read_sql(query, con=oracleDbConnection)

        # Read the single result cell positionally instead of depending on how the
        # driver labels the "COUNT(*)" column
        return int(tableSize.iloc[0, 0]) > 0

    except Exception as e:
        print("ERROR: Table '" + tableName.strip() + "' does not have data!")
        print("ERROR: " + type(e).__name__ + ": " + str(e))
        return False

In [10]:
def checkIfColumnsExist(tableName, columnNames):
    ##################################################################################
    # Function to check if columns exist in the table
    # Takes tablename and a comma-separated string of column names as the input
    # return the list of requested columns that do not exist in the table
    # Blank entries and "nan" (str() of an empty spreadsheet cell) are ignored
    # If the table's columns cannot be read, the error is printed and every
    # requested column is returned as missing, so callers always get a list
    ##################################################################################
    requestedColumns = [name.strip() for name in columnNames.split(',')]
    requestedColumns = [name for name in requestedColumns if name != "" and name != "nan"]

    dbConnection = None
    cursor = None
    try:
        oracleTns = cx_Oracle.makedsn(dbHost, dbPort, dbName)
        dbConnection = cx_Oracle.connect(dbUser, dbPassword, oracleTns)

        cursor = dbConnection.cursor()
        cursor.execute("SELECT * FROM " + tableName.strip() + " FETCH FIRST 10 ROWS ONLY")

        # The first element of each cursor.description entry is the column name
        existingColumns = {description[0] for description in cursor.description}

    except Exception as e:
        print(type(e).__name__)
        print(e)
        return requestedColumns

    finally:
        # Release DB resources on success and failure alike
        if cursor is not None:
            cursor.close()
        if dbConnection is not None:
            dbConnection.close()

    missingColumnList = []
    for columnName in requestedColumns:
        if columnName not in existingColumns:
            print("ERROR: Column '" + columnName + "' does not exist!")
            missingColumnList.append(columnName)

    return missingColumnList

In [11]:
def _oracleTypeName(dbType):
    # Turn a cursor.description type object into a comparable name. A repr such as
    # "<class 'cx_Oracle.STRING'>" yields "STRING"; reprs without that quoted form
    # fall back to the object's .name attribute, else the plain str() (assumption:
    # newer drivers expose .name on their type objects — verify against the driver)
    typeText = str(dbType)
    try:
        return typeText.split("'")[1].split('.')[1]
    except IndexError:
        return str(getattr(dbType, "name", typeText))


def checkColumnTypes(tableName, columnTypePairs):
    ##################################################################################
    # Function to check if columns have the expected data type in the table
    # Takes tablename and a comma-separated string of columnName:columnType pairs
    # return the list of requested columns that are missing or have a wrong type
    # A type matches when the expected type is a case-insensitive substring of the
    # driver type name produced by _oracleTypeName
    # Pairs whose column is blank/"nan" are skipped; a blank/"nan" type only checks
    # that the column exists
    # If the table cannot be read, the error is printed and every requested column
    # is returned, so callers always get a list
    ##################################################################################
    requestedPairs = []
    for columnPair in columnTypePairs.split(','):
        if columnPair == "" or columnPair == "nan":
            continue
        parts = columnPair.split(':')
        columnName = parts[0].strip()
        columnType = parts[1].strip() if len(parts) > 1 else ""
        if columnName == "" or columnName == "nan":
            continue
        requestedPairs.append((columnName, columnType))

    dbConnection = None
    cursor = None
    try:
        oracleTns = cx_Oracle.makedsn(dbHost, dbPort, dbName)
        dbConnection = cx_Oracle.connect(dbUser, dbPassword, oracleTns)

        cursor = dbConnection.cursor()
        cursor.execute("SELECT * FROM " + tableName.strip() + " FETCH FIRST 10 ROWS ONLY")

        # Entries of cursor.description are (name, type, ...)
        columnTypeDict = {
            description[0]: _oracleTypeName(description[1])
            for description in cursor.description
        }

    except Exception as e:
        print(type(e).__name__)
        print(e)
        return [columnName for columnName, _ in requestedPairs]

    finally:
        # Release DB resources on success and failure alike
        if cursor is not None:
            cursor.close()
        if dbConnection is not None:
            dbConnection.close()

    incorrectColumnTypeList = []
    for columnName, columnType in requestedPairs:
        if columnName not in columnTypeDict:
            print("ERROR: Column '" + columnName + "' does not exist!")
            incorrectColumnTypeList.append(columnName)
        elif (columnType != "" and columnType != "nan"
              and columnType.lower() not in columnTypeDict[columnName].lower()):
            print("ERROR: Column '" + columnName + "' does not have the correct data type!")
            incorrectColumnTypeList.append(columnName)

    return incorrectColumnTypeList

In [12]:
def checkCondition(tableName, condition, filterCondition, debugMode = True):
    ##################################################################################
    # Function to perform a data quality check
    # Takes tablename, a SQL condition and a filter condition ("" or "nan" means no
    # filter); debugMode is accepted but not used
    # return the percentage of in-scope rows (rows matching the filter, or all rows)
    # that also satisfy the condition
    # Raises the underlying error if a count query fails (after printing it), and
    # ZeroDivisionError if no rows are in scope
    ##################################################################################
    hasFilter = filterCondition != "" and filterCondition != "nan"

    # Parenthesise both predicates: otherwise a condition containing OR binds looser
    # than the appended AND, and the filter would not apply to part of it
    allRows_query = "SELECT COUNT(*) FROM " + tableName
    passRows_query = "SELECT COUNT(*) FROM " + tableName + " WHERE (" + condition + ")"

    if hasFilter:
        allRows_query = allRows_query + " WHERE (" + filterCondition + ")"
        passRows_query = passRows_query + " AND (" + filterCondition + ")"

    try:
        # The context manager closes the connection whether or not the queries succeed
        with oracleDatabaseEngine.connect() as oracleDbConnection:
            tableSize = pd.read_sql(allRows_query, con=oracleDbConnection)
            filteredTableSize = pd.read_sql(passRows_query, con=oracleDbConnection)

        # Read the single result cell positionally rather than by column label
        tableSize = int(tableSize.iloc[0, 0])
        filteredTableSize = int(filteredTableSize.iloc[0, 0])

    except Exception as e:
        print("ERROR: Something went wrong while fetching the table")
        print("ERROR: " + str(type(e).__name__))
        print("ERROR: " + str(e))
        # Re-raise the real cause instead of failing later on undefined counts
        raise

    if hasFilter:
        print("Checking if " + condition + " for " + filterCondition + " on " + tableName)
    else:
        print("Checking if " + condition + " on " + tableName)

    if tableSize == 0:
        raise ZeroDivisionError("No rows in scope on " + tableName + "; cannot compute a percentage")

    successPercentage = (filteredTableSize / tableSize) * 100

    return successPercentage

In [13]:
def readParameters(Df):
    ##################################################################################
    # Build a {key: value} dictionary from the rows of the Parameters sheet
    # Keys and values are both passed through str() and stripped; when a key
    # appears more than once, the last row wins
    ##################################################################################
    return {
        str(row['key']).strip(): str(row['value']).strip()
        for _, row in Df.iterrows()
    }
    

In [14]:
def _isBlankCell(value):
    # True for an empty string or "nan" (what str() gives for an empty spreadsheet cell)
    return value == "" or value == "nan"


def _collectHealthCheckTargets(configDf):
    # Group the configured rows as {tableName: {columnName: columnType}}, printing
    # an error for (and skipping) rows whose tableName is blank
    targets = dict()
    for index, row in configDf.iterrows():
        tableName = str(row['tableName']).strip()
        if _isBlankCell(tableName):
            print("ERROR: 'tableName' should not be empty string. Incorrect format for index " + str(index))
            continue
        columnName = str(row['columnName']).strip()
        columnType = str(row['columnType']).strip()
        targets.setdefault(tableName, dict())[columnName] = columnType
    return targets


def _validateHealthCheckTable(tableName, columns):
    # Run the existence / has-data / column / type checks for one table and print
    # the verdict. Returns (tableExists, tableHasData, missingColumns, incorrectColumnTypes)
    print("-----------------------------------------------------------")
    print("Validating Table " + tableName + "......\n")

    tableExists = checkIfTableExists(tableName)
    tableHasData = tableExists and checkIfTableHasData(tableName)
    missingColumns = []
    incorrectColumnTypes = []

    if tableHasData:
        columnNames = ",".join(columns)
        columnTypes = ",".join(name + ":" + columnType for name, columnType in columns.items())

        missingColumns = checkIfColumnsExist(tableName, columnNames)
        incorrectColumnTypes = checkColumnTypes(tableName, columnTypes)

        # Treat a None result (the helper hit an error) as "every column failed"
        # rather than crashing on len(None)
        if missingColumns is None:
            missingColumns = list(columns)
        if incorrectColumnTypes is None:
            incorrectColumnTypes = list(columns)

    if tableExists and tableHasData and len(missingColumns) == 0 and len(incorrectColumnTypes) == 0:
        print("\nPASS!\n")
    else:
        print("\nFAIL!\n")

    return tableExists, tableHasData, missingColumns, incorrectColumnTypes


def healthCheck(configDf):
    ##################################################################################
    # Function to perform health check on the tables/columns listed in the configuration
    # Health check include:
    # 1. If table exists
    # 2. If table has data
    # 3. If column exists in the table
    # 4. If column has the correct data type
    # Column checks run only for tables that exist and have data
    # Writes PASS/FAIL into the columns tableExistence, tableHasData, columnExistence
    # and columnDatatype of configDf (in place) and returns configDf; rows with a
    # blank tableName cannot be validated and are marked FAIL
    ##################################################################################
    print("##########################################################")
    print("#                                          ")
    print("# Performing health check on tables     ")
    print("#                                          ")
    print("##########################################################\n")

    statusColumns = ['tableExistence', 'tableHasData', 'columnExistence', 'columnDatatype']
    for statusColumn in statusColumns:
        configDf[statusColumn] = 'PASS'

    targets = _collectHealthCheckTargets(configDf)

    tableResults = dict()
    for tableName, columns in targets.items():
        tableResults[tableName] = _validateHealthCheckTable(tableName, columns)

    # Write results back with .loc: the previous iloc[index][col] = ... assigned into
    # a temporary row copy (and mixed labels with positions), so nothing was stored
    for index, row in configDf.iterrows():
        tableName = str(row['tableName']).strip()
        if tableName not in tableResults:
            for statusColumn in statusColumns:
                configDf.loc[index, statusColumn] = 'FAIL'
            continue

        columnName = str(row['columnName']).strip()
        tableExists, tableHasData, missingColumns, incorrectColumnTypes = tableResults[tableName]

        configDf.loc[index, 'tableExistence'] = 'PASS' if tableExists else 'FAIL'
        configDf.loc[index, 'tableHasData'] = 'PASS' if tableHasData else 'FAIL'
        configDf.loc[index, 'columnExistence'] = 'FAIL' if columnName in missingColumns else 'PASS'
        configDf.loc[index, 'columnDatatype'] = 'FAIL' if columnName in incorrectColumnTypes else 'PASS'

    return configDf

In [None]:
def _passesBound(success, bound, boundName, violates):
    # Compare a success percentage with one configured bound (minPCT or maxPCT),
    # print the outcome and return False when violates(success, bound) is true.
    # A blank or "nan" bound is not checked and counts as passed.
    if bound == "" or bound == "nan":
        return True
    if violates(success, float(bound)):
        print("Test Failed for " + boundName + " " + bound + "%")
        return False
    print("Test Passed for " + boundName + " " + bound + "%")
    return True


def dataQualityCheck(sheet, configDf, dictParameters):
    ##################################################################################
    # Function to run the data quality checks defined on one configuration sheet
    # Each row names a table, a SQL condition, an optional filter and a passing
    # range [minPCT, maxPCT] for the percentage of rows meeting the condition
    # includeBoundary == "false" makes the range exclusive; anything else inclusive
    # Parameter keys occurring in the condition are replaced by their values
    # Sets configDf['Status'] (in place) to PASS / FAIL, or SKIPPED for rows without
    # an id, and returns configDf
    ##################################################################################
    import operator

    print("##########################################################")
    print("#                                          ")
    print("# Performing data quality check on "+sheet)
    print("#                                          ")
    print("##########################################################\n")

    finalFlag = True

    configDf['Status'] = 'PASS'

    # Substitute longer keys first so a key that is a prefix of another key cannot
    # corrupt it; ignore keys coming from blank rows
    parameterKeys = sorted(
        (key for key in dictParameters if key != "" and key != "nan"),
        key=len,
        reverse=True,
    )

    for index, row in configDf.iterrows():

        # str() of an empty id cell is "nan", so test for it as well as ""
        taskId = str(row['id']).strip()
        if taskId == "" or taskId == "nan":
            configDf.loc[index, 'Status'] = 'SKIPPED'
            continue

        tableName = str(row['tableName']).strip()
        condition = str(row['condition']).strip()
        filterCondition = str(row['filterCondition']).strip()
        minPCT = str(row['minPCT']).strip()
        maxPCT = str(row['maxPCT']).strip()
        includeBoundary = str(row['includeBoundary']).strip()

        for key in parameterKeys:
            condition = condition.replace(key, dictParameters[key])

        print()

        includeBoundary = includeBoundary.lower() != "false"

        try:
            success = checkCondition(tableName, condition, filterCondition)
        except Exception as e:
            # One broken rule must not abort the remaining checks; record it as failed
            print("ERROR: Check " + taskId + " could not be evaluated: "
                  + type(e).__name__ + ": " + str(e))
            configDf.loc[index, 'Status'] = 'FAIL'
            finalFlag = False
            continue

        print(str(success) + "%")

        # Compare the success rate with the passing range defined in the configuration file
        if includeBoundary:
            belowMin, aboveMax = operator.lt, operator.gt
        else:
            belowMin, aboveMax = operator.le, operator.ge

        # Evaluate both bounds (no short-circuit) so both outcomes are printed
        minOk = _passesBound(success, minPCT, "minPCT", belowMin)
        maxOk = _passesBound(success, maxPCT, "maxPCT", aboveMax)
        flag = minOk and maxOk

        if not flag:
            finalFlag = False
        configDf.loc[index, 'Status'] = 'PASS' if flag else 'FAIL'

    if finalFlag:
        print("\nPASS!\n")
    else:
        print("\nFAIL!\n")

    return configDf
            

In [None]:
# Main Function
# Runs the health check and every data-quality sheet, then writes all results into
# one workbook (one sheet per check)

resultFile = 'workspace/de/drm-de/dataValidation/checkResult.xlsx'

# The context manager saves and closes the workbook; ExcelWriter.save() was
# deprecated and then removed in pandas 2.0
with pd.ExcelWriter(resultFile, engine='xlsxwriter') as writer:

    # Passing the path lets pandas open and close the file itself, instead of
    # leaking the handles returned by open(configFile, 'rb')
    headCheckDf = pd.read_excel(configFile, sheet_name=healthCheckSheet)
    headCheckDf = headCheckDf.dropna(how='all')
    headCheckDf = healthCheck(headCheckDf)
    headCheckDf.to_excel(writer, sheet_name=healthCheckSheet, index=False)
    print()

    # Drop fully empty rows so they do not turn into a "nan" -> "nan" parameter
    parameterDf = pd.read_excel(configFile, sheet_name=parametersSheet)
    parameterDf = parameterDf.dropna(how='all')
    dictParameters = readParameters(parameterDf)

    for sheet in dataQualityCheckList:
        dataQualityCheckDf = pd.read_excel(configFile, sheet_name=sheet)
        dataQualityCheckDf = dataQualityCheckDf.dropna(how='all')
        dataQualityCheckDf = dataQualityCheck(sheet, dataQualityCheckDf, dictParameters)
        print()

        dataQualityCheckDf.to_excel(writer, sheet_name=sheet, index=False)

##########################################################
#                                          
# Performing health check on tables     
#                                          
##########################################################

-----------------------------------------------------------
Validating Table DP_INVENTORY......


PASS!

-----------------------------------------------------------
Validating Table DP_PROD_ACCURACY......


PASS!

-----------------------------------------------------------
Validating Table DP_MACH1_MASTER_NEW......


PASS!

-----------------------------------------------------------
Validating Table DP_MACH1_CUR_MONTH_YMSC......


PASS!

-----------------------------------------------------------
Validating Table DP_MACH1_PE_SYN_FACT......


PASS!

-----------------------------------------------------------
Validating Table DP_MACH1_SHIPTO......


PASS!

-----------------------------------------------------------
Validating Table DP_MACH1_PRO