In [None]:
%run oeai_py

In [None]:
# Create an instance of the OEAI helper class (presumably defined by `%run oeai_py` above).
# NOTE(review): no platform ("Synapse" or "Fabric") is passed to the constructor here;
# confirm whether OEAI() needs one when running on Fabric.
oeai = OEAI()

In [None]:
from pyspark.sql.functions import udf
import itertools
# Shared running counter read by get_replace_function's closures via next(replace_counter).
# NOTE(review): when those closures run inside a Spark UDF, this counter is presumably
# serialised to each task, so numbering may restart per partition and values may repeat.
replace_counter = itertools.count(start=1)

In [None]:
# CHANGE VALUES FOR YOUR KEY VAULT
keyvault = "INSERT_KEYVAULT_NAME"  # Fabric requires full URL eg "https://key_vault_name.vault.azure.net/"
keyvault_linked_service = "INSERT_LINKED_SERVICE_NAME"  # Not required for Fabric.


def _read_secret(secret_name):
    """Fetch `secret_name` from the configured key vault through the OEAI helper."""
    return oeai.get_secret(spark, secret_name, keyvault_linked_service, keyvault)


# Synapse OEA environment: source/destination paths and storage credentials
silver_path = _read_secret("wonde-silver")
pseudo_path = _read_secret("pseudo-path")
storage_account_name = _read_secret("storage-account")
storage_account_access_key = _read_secret("storage-accesskey")

In [None]:
# Pseudonymization rules, keyed by table name and then by column name.
# Rules understood by pseudonymize_df:
#   "pseudoID"          - replace with a generated ID and save an original -> ID mapping file
#   "mask"              - replace every value with a fixed mask string
#   "mask_date"         - replace every value with a fixed date string
#   {"string": prefix}  - replace with `prefix` followed by a running number
# No "hash" rule is implemented. Tables not listed here are written without pseudonymization.
# Column order within each table is the order in which the rules are applied.
pseudo_mapping = {
    "dim_Organisation": {
        "URN": "pseudoID",
        "Establishment_Number": "mask",
        "LA_Code": "mask",
        "Organisation_Name": {"string": "Academy "},
    },
    "dim_Student": {
        "UPN": "pseudoID",
        "Forename": "mask",
        "Legal_Forename": "mask",
        "Legal_Surname": "mask",
        "Middle_Names": "mask",
        "Surname": {"string": "Student "},
        "Date_Of_Birth": "mask_date",
    },
    "dim_StudentExtended": {"First_Language": "mask"},
    "fact_Achievement": {"Comment": "mask"},
    "fact_Behaviour": {"Comment": "mask"},
    "fact_Exclusion": {"Comments": "mask"},
}

In [None]:
# Pseudonymization functions
def mask_function(col_value):
    """Ignore ``col_value`` (None included) and return the constant mask token."""
    mask_token = "XXXXXX"
    return mask_token

def get_replace_function(prefix):
    """Build a callable that ignores its argument and returns ``prefix`` plus the next counter value.

    The number comes from the module-level ``replace_counter``. Every call, from every
    function built here, advances that same shared counter.
    NOTE(review): inside a Spark UDF the counter is presumably copied to each task, so
    values may repeat across partitions. Confirm before relying on uniqueness.
    """
    def replace_function(col_value):
        sequence_number = next(replace_counter)
        return "{}{}".format(prefix, sequence_number)

    return replace_function

# Spark UDF wrapping mask_function, declared to return strings.
mask_udf = udf(f=mask_function, returnType=StringType())

# Value written into every column configured with the "mask_date" rule.
fixed_date_str = "2000-01-01"


In [None]:
def _replace_with_numbered_prefix(df, col_name, prefix):
    """Overwrite `col_name` with `prefix` followed by a row number 1..n that is unique across `df`."""
    from pyspark.sql import Window
    from pyspark.sql.functions import concat, lit, monotonically_increasing_id, row_number

    # A Python counter advanced inside a UDF is shipped to every Spark task, so its
    # numbering restarts per partition and yields duplicates. A single window over the
    # whole DataFrame numbers rows globally instead.
    # NOTE: a window without partitionBy moves all rows to one partition. This is fine
    # for dimension tables but should be revisited before applying it to large fact tables.
    order_col = "__pseudo_row_order"
    row_num = row_number().over(Window.orderBy(order_col)).cast("string")
    return (
        df.withColumn(order_col, monotonically_increasing_id())  # preserve current row order
        .withColumn(col_name, concat(lit(str(prefix)), row_num))
        .drop(order_col)
    )


def pseudonymize_df(df, table_name, config, pseudo_path, spark, replace_counter=None):
    """Apply the configured pseudonymization rule to each listed column of `df`.

    Args:
        df: Spark DataFrame to transform.
        table_name: Table name, used to name the pseudoID mapping files.
        config: Mapping of column name -> rule, applied in iteration order:
            "pseudoID"           -> generated IDs via save_mapping_for_pseudoID (mapping saved under `pseudo_path`)
            "mask"               -> every value replaced through `mask_udf`
            {"string": prefix}   -> `prefix` followed by a row number 1..n, unique within the table
            "mask_date"          -> every value replaced with `fixed_date_str`
        pseudo_path: Folder where pseudoID mapping files are written.
        spark: Active SparkSession (passed through to save_mapping_for_pseudoID).
        replace_counter: Unused; kept so existing positional callers keep working.

    Returns:
        The transformed DataFrame. An empty `config` returns `df` unchanged.

    Raises:
        ValueError: if a column has an unrecognised rule. Failing loudly stops a typo in
            the config from silently leaving an identifying column in clear text.
    """
    for col_name, pseudo_type in config.items():
        if pseudo_type == "pseudoID":
            df = save_mapping_for_pseudoID(df, table_name, col_name, pseudo_path, spark)
        elif pseudo_type == "mask":
            df = df.withColumn(col_name, mask_udf(col_name))
        elif isinstance(pseudo_type, dict) and "string" in pseudo_type:
            df = _replace_with_numbered_prefix(df, col_name, pseudo_type["string"])
        elif pseudo_type == "mask_date":
            df = df.withColumn(col_name, lit(fixed_date_str))
        else:
            raise ValueError(
                f"Unsupported pseudonymization rule {pseudo_type!r} "
                f"for column {col_name!r} in table {table_name!r}"
            )
    return df

In [None]:
def get_table_name_from_path(table_path):
    """Return the last path segment of `table_path`, used as the table name.

    Trailing slashes are ignored, so ".../silver/dim_Student/" and
    ".../silver/dim_Student" both yield "dim_Student".
    """
    return table_path.rstrip("/").rsplit("/", 1)[-1]

In [None]:
def save_mapping_for_pseudoID(df, table_name, col_name, gold_path, spark):
    """Replace `col_name` with generated pseudo IDs and persist the original -> ID mapping.

    One ID is generated per distinct non-null value of `col_name`, so repeated values
    share the same pseudo ID. The mapping (columns `original_value`, `new_id`) is written
    to `{gold_path}/pseudo_map_{table_name}_{col_name}.parquet`, overwriting any previous
    file. It is then read back and joined onto `df`.

    Args:
        df: Spark DataFrame containing `col_name`.
        table_name: Table name, used in the mapping file name.
        col_name: Column holding the identifier to pseudonymize.
        gold_path: Folder in which the mapping file is written.
        spark: Active SparkSession, used to read the saved mapping back.

    Returns:
        `df` with the original `col_name` removed and the pseudo ID appended as a new
        `col_name` column (last position). Rows whose value is null get a null ID.
    """
    from pyspark.sql.functions import col, monotonically_increasing_id

    mapping_file_path = f"{gold_path}/pseudo_map_{table_name}_{col_name}.parquet"

    mapping_df = (
        df.select(col(col_name).alias("original_value"))
        .dropna()
        .distinct()
        .withColumn("new_id", monotonically_increasing_id())
    )
    # NOTE(review): this file holds the original identifiers; keep access to it restricted.
    mapping_df.write.mode("overwrite").parquet(mapping_file_path)

    # monotonically_increasing_id() is non-deterministic (it depends on partitioning), so
    # recomputing mapping_df could produce IDs that differ from those just written.
    # Reading the saved file back pins the output IDs to the persisted mapping.
    saved_mapping = spark.read.parquet(mapping_file_path)

    return (
        df.join(saved_mapping, df[col_name] == saved_mapping["original_value"], "left")
        .drop(col_name)
        .drop("original_value")
        .withColumnRenamed("new_id", col_name)
    )

In [None]:
def process_delta_tables_to_parquet(spark, storage_account_name, storage_account_access_key, silver_path, pseudo_path):
    """
    Pseudonymize every Delta table directly under `silver_path` and write it as Parquet under `pseudo_path`.

    Steps:
      1. Register the storage account key for the `dfs.core.windows.net` endpoint (Spark conf)
         and the `blob.core.windows.net` endpoint (Hadoop conf).
      2. List the immediate sub-directories of `silver_path`; each one is loaded as a Delta table.
      3. Apply the rules from the global `pseudo_mapping` entry matching the folder name,
         via `pseudonymize_df`. Tables without an entry are reported and written unchanged.
      4. Copy `organisationkey` into a `partitionkey` column and write the table to
         `{pseudo_path}/{table_name}`, partitioned by `partitionkey`, in overwrite mode.

    Args:
        spark (SparkSession): Active Spark session.
        storage_account_name (str): Azure storage account name.
        storage_account_access_key (str): Access key for the Azure storage account.
        silver_path (str): Path to the silver layer directory (source Delta tables).
        pseudo_path (str): Destination directory for the pseudonymized Parquet output.
            The pseudoID mapping files are also written under this directory.

    A table that raises `AnalysisException` while being read, transformed or written
    (for example, one without an `organisationkey` column) is reported and skipped.
    """
    import os
    from urllib.parse import urlparse

    from pyspark.sql.functions import col

    # Storage credentials for Spark reads/writes via the dfs endpoint
    spark.conf.set(f"fs.azure.account.key.{storage_account_name}.dfs.core.windows.net", storage_account_access_key)

    # ...and on the Hadoop configuration used by the JVM FileSystem API below (blob endpoint)
    sc = spark.sparkContext
    hadoop_conf = sc._jsc.hadoopConfiguration()
    hadoop_conf.set("fs.azure", "org.apache.hadoop.fs.azure.NativeAzureFileSystem")
    hadoop_conf.set(f"fs.azure.account.key.{storage_account_name}.blob.core.windows.net", storage_account_access_key)

    # List the table folders through the Hadoop FileSystem API (via the Py4J gateway)
    jvm = sc._gateway.jvm
    fs = jvm.org.apache.hadoop.fs.FileSystem.get(jvm.java.net.URI(silver_path), hadoop_conf)
    status = fs.listStatus(jvm.org.apache.hadoop.fs.Path(silver_path))
    delta_table_paths = [file_status.getPath().toString() for file_status in status if file_status.isDirectory()]

    for table_path in delta_table_paths:
        try:
            df = spark.read.format("delta").load(table_path)
            table_name = os.path.basename(urlparse(table_path).path)

            # Pseudonymization rules for this table (empty dict when not configured)
            config = pseudo_mapping.get(table_name, {})
            if config:
                df = pseudonymize_df(df, table_name, config, pseudo_path, spark, replace_counter)
            else:
                # NOTE(review): unconfigured tables are still copied to pseudo_path as-is;
                # confirm none of them carry identifying columns.
                print(f"No pseudonymization configuration found for {table_name}")

            parquet_output_folder_path = os.path.join(pseudo_path, table_name)
            df = df.withColumn("partitionkey", col("organisationkey"))
            df.write.partitionBy("partitionkey").mode("overwrite").format("parquet").save(parquet_output_folder_path)

        except AnalysisException as e:
            print(f"Error reading Delta table at {table_path}: ", e)

In [None]:
# Pseudonymize every Delta table under silver_path and write the results as Parquet under pseudo_path
process_delta_tables_to_parquet(
    spark=spark,
    storage_account_name=storage_account_name,
    storage_account_access_key=storage_account_access_key,
    silver_path=silver_path,
    pseudo_path=pseudo_path,
)