In [None]:
from pyspark import SparkContext
from pyspark.sql import SparkSession
import pydeequ
from pydeequ import Check, CheckLevel, AnalysisRunner
from pydeequ.analyzers import *
from pydeequ.suggestions import *
from pydeequ.repository import FileSystemMetricsRepository, ResultKey
from pydeequ.verification import VerificationSuite, VerificationResult
import time
import pytest
import pytest_check as pc
from pytest_check import check_func

# Global variables: running tallies of passed / failed Deequ constraints.
# check_results() increments these through ``global`` declarations.
tpass_count, tfail_count = 0, 0

#. Methods
def create_session():
    """Return a SparkSession configured for PyDeequ.

    Sets ``spark.jars.packages`` to ``pydeequ.deequ_maven_coord`` and
    ``spark.jars.excludes`` to ``pydeequ.f2j_maven_coord``, then calls
    ``getOrCreate()`` on the builder.
    """
    print("Creating Spark Session")
    session_builder = SparkSession.builder
    session_builder = session_builder.config(
        "spark.jars.packages", pydeequ.deequ_maven_coord)
    session_builder = session_builder.config(
        "spark.jars.excludes", pydeequ.f2j_maven_coord)
    return session_builder.getOrCreate()

def stop_spark_session(session=None):
    """Stop a SparkSession (and, through it, its underlying SparkContext).

    Args:
        session: The SparkSession to stop. When omitted, the module-level
            ``spark`` global is used; that global is assumed to be created
            elsewhere (e.g. ``spark = create_session()``) -- confirm.
    """
    if session is None:
        session = spark  # global defined outside this function
    # SparkSession.stop() already stops the underlying SparkContext, so a
    # separate session.sparkContext.stop() call beforehand is redundant.
    session.stop()
    
@check_func
def check_results(checkResult):
    """Print, tally, and soft-assert the status of each Deequ constraint.

    Args:
        checkResult: An iterable of dicts, each with at least the keys
            ``'constraint'`` and ``'constraint_status'`` (presumably the
            output of ``VerificationResult.checkResultsAsJson`` -- confirm
            at the call site).

    Every entry whose status is ``"Success"`` increments the global
    ``tpass_count``; any other status increments ``tfail_count``.

    Raises:
        AssertionError: after all entries are processed, if any of them did
            not succeed. The ``@check_func`` decorator from pytest_check
            turns this into a recorded soft failure instead of the check
            passing silently.
    """
    global tpass_count, tfail_count
    failed_here = 0
    for check_json in checkResult:
        if check_json['constraint_status'] == "Success":
            print(f"Passed - \t{check_json['constraint']} passed")
            tpass_count += 1
        else:
            print(f"Failed - \t{check_json['constraint']} failed")
            tfail_count += 1
            failed_here += 1
    # Assert once after the loop (not per entry) so every constraint is
    # still printed and counted before the soft failure is recorded.
    assert failed_here == 0, f"{failed_here} constraint(s) failed"

            
def analyze_file(files):
    """Run basic Deequ analyzers on each parquet file described in ``files``.

    Each entry of ``files`` is indexed as ``[path, column_a, column_b, ...]``.
    For every entry the parquet at ``path`` is read with the global ``spark``
    session and its schema is printed. Then ``Size`` and the ``Completeness``
    of ``column_a`` and ``column_b`` are computed, and the successful metrics
    are shown as a DataFrame.
    """
    for set_no, file_data in enumerate(files, start=1):
        parquet_path = file_data[0]
        print(f"test set {set_no}")
        print("Reading " + parquet_path + " Data")
        frame = spark.read.parquet(parquet_path)
        frame.printSchema()
        # Analysis
        print("NVD Data Analysis")
        runner = AnalysisRunner(spark).onData(frame)
        runner = runner.addAnalyzer(Size())
        for column_pos in (1, 2):
            runner = runner.addAnalyzer(Completeness(file_data[column_pos]))
        metrics = runner.run()
        AnalyzerContext.successMetricsAsDataFrame(spark, metrics).show()
        print("********************************************************")

        
def generate_suggestions(files):
    """Print Deequ constraint suggestions for each parquet file in ``files``.

    Each entry of ``files`` is indexed as ``[path, ...]``; only ``path`` is
    used here. The parquet is read with the global ``spark`` session, the
    ``DEFAULT()`` suggestion rule set is run over it, and the result is
    printed as indented JSON.
    """
    # ``json`` is not imported at the top of this file; import it here so
    # the function does not depend on a star import or another notebook
    # cell happening to provide it.
    import json

    for idx, file_data in enumerate(files):
        print("test suggestion set " + str(idx + 1))
        print("Reading " + file_data[0] + " Data")
        df = spark.read.parquet(file_data[0])
        # Suggestions
        suggestionResult = ConstraintSuggestionRunner(spark) \
            .onData(df) \
            .addConstraintRule(DEFAULT()) \
            .run()

        # Constraint Suggestions in JSON format
        print(json.dumps(suggestionResult, indent=2))
        print("********************************************************")
    
def check_constraints(files):
    """Verify completeness, uniqueness and non-negativity of a key column.

    Each entry of ``files`` is indexed as ``[path, key_column, ...]``. For
    every entry the parquet at ``path`` is read with the global ``spark``
    session. A Warning-level Deequ check (``isComplete``, ``isUnique``,
    ``isNonNegative`` on ``key_column``) is verified, and the check results
    are shown as a DataFrame.
    """
    for idx, file_data in enumerate(files):
        print("test set " + str(idx + 1))
        print("Reading " + file_data[0] + " Data")
        df = spark.read.parquet(file_data[0])
        # Build a fresh Check for every file. PyDeequ's Check builder methods
        # add the constraint to the Check they are called on and return that
        # same object. A single Check reused across iterations would therefore
        # also apply earlier files' column constraints to later files.
        check = Check(spark, CheckLevel.Warning, "NVD Intrim Parquet Data Check")
        key_column = file_data[1]
        checkResult = VerificationSuite(spark) \
            .onData(df) \
            .addCheck(
                check.isComplete(key_column)
                .isUnique(key_column)
                .isNonNegative(key_column)) \
            .run()

        checkResult_df = VerificationResult.checkResultsAsDataFrame(spark, checkResult)
        checkResult_df.show()
