# In [0]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType
import hashlib, base64


class PIIEncryptor:
    """Reversibly encode selected columns of a Spark DataFrame.

    SECURITY: despite the class and method names, this is NOT encryption.
    Each value is stored as base64 of ``"<len>:<sha256(value:key)>:<value>"``.
    Anyone with a base64 decoder can therefore recover the plaintext, and the
    key is neither needed nor checked when decoding. Do not rely on this for
    confidentiality of PII; use a vetted authenticated cipher (e.g. AES-GCM /
    Fernet from an audited crypto library) for that. The encoded format is
    kept unchanged here so previously written values still round-trip.
    """

    def __init__(self, key=""):
        # Mixed into the SHA-256 digest embedded in every encoded value;
        # it plays no part in decoding.
        self.key = key

    def _encrypt_value(self, v):
        """Encode a single cell value.

        Args:
            v: Any value. ``None`` passes through. Everything else is
                converted with ``str()`` first, so ints/floats/bools come
                back from decoding as strings.

        Returns:
            base64 text of ``"<len(str(v))>:<sha256 hex of 'str(v):key'>:<str(v)>"``,
            or ``None`` when ``v`` is ``None``.
        """
        if v is None:
            return None

        text = str(v)
        digest = hashlib.sha256(f"{text}:{self.key}".encode()).hexdigest()
        payload = f"{len(text)}:{digest}:{text}"
        return base64.b64encode(payload.encode()).decode()

    def _decrypt_value(self, v):
        """Recover the original string from a value made by ``_encrypt_value``.

        Returns ``None`` for ``None`` input and for anything that is not a
        well-formed encoded value: invalid base64, a non-UTF-8 payload,
        fewer than three ``:``-separated fields, or a length prefix that does
        not match the recovered text. The embedded digest is NOT verified,
        so a value encoded under a different key still decodes.
        """
        if v is None:
            return None
        try:
            payload = base64.b64decode(str(v).encode()).decode()
            # The prefix is digits and the digest is hex, so neither contains
            # ':'. Any ':' in the original value therefore lands in ``text``.
            # A payload with fewer than two ':' raises ValueError on unpacking.
            length, _digest, text = payload.split(":", 2)
            expected_len = int(length)
        except ValueError:
            # Covers binascii.Error (bad base64), UnicodeDecodeError,
            # wrong field count and a non-numeric length prefix.
            return None
        # Reject payloads that merely happen to contain two ':' characters.
        return text if expected_len == len(text) else None

    @staticmethod
    def _target_columns(columns, available):
        """Return names from ``columns`` present in ``available``, de-duplicated, first-seen order."""
        present = set(available)
        # dict.fromkeys de-duplicates while keeping order, so a column listed
        # twice is not transformed twice.
        return [c for c in dict.fromkeys(columns) if c in present]

    def _transform_columns(self, df, columns, fn):
        """Return ``df`` with each selected column replaced by ``fn`` applied as a StringType UDF."""
        transform = udf(fn, StringType())
        out = df
        for c in self._target_columns(columns, df.columns):
            out = out.withColumn(c, transform(col(c)))
        return out

    def encrypt_dataframe(self, df: DataFrame, columns: list[str]) -> DataFrame:
        """Return a new DataFrame with the named columns encoded by ``_encrypt_value``.

        Names not present in ``df`` are skipped silently. A name listed more
        than once is encoded only once. Encoded columns are StringType.
        """
        return self._transform_columns(df, columns, self._encrypt_value)

    def decrypt_dataframe(self, df: DataFrame, columns: list[str]) -> DataFrame:
        """Return a new DataFrame with the named columns decoded by ``_decrypt_value``.

        Names not present in ``df`` are skipped silently. A name listed more
        than once is decoded only once; decoding twice would turn the
        already-decoded text into ``None``. Cells that are not valid encoded
        values become ``None``.
        """
        return self._transform_columns(df, columns, self._decrypt_value)