In [None]:
from dataclasses import dataclass
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import col, explode, lit
from pyspark.sql.types import *


@dataclass
class ServerInfo:
    """Plain record pairing a server's name with its base URL.

    Nothing in the visible part of this file creates or reads instances of
    this class.
    """

    # Name identifying the server.
    name: str
    # Base URL of the server.
    # NOTE(review): expected format (scheme, trailing slash) is not shown here.
    base_url: str


@dataclass
class JsonDataLoader:
    """Flatten chassis and line-card attributes from JSON files into Parquet.

    For every ``*.json`` file found under ``<read_dir>/unprocessed`` the loader
    reads the document, extracts the attribute records found under a fixed set
    of ``environments.<env>.pods.<pod>`` paths, aligns them to one column
    layout, writes the result to ``target`` as Parquet and finally moves the
    source file to the matching ``/processed`` location.

    Attributes:
        spark: Session used to read the JSON files.
        dbutils: Object exposing ``fs.ls`` and ``fs.mv`` (used to list and
            move files).
        read_dirs: Directories that each contain an ``unprocessed`` folder.
        target: Parquet output path.
        write_mode: Save mode handed to ``DataFrameWriter.mode``. Defaults to
            ``"overwrite"``, which is what was previously hard-coded.
            NOTE(review): with ``"overwrite"`` every processed file replaces
            the output of the previous one, so after ``load_files`` only the
            rows of the last file remain at ``target`` -- confirm whether
            ``"append"`` is what is actually wanted.
    """

    spark: SparkSession
    dbutils: object
    read_dirs: list[str]
    target: str
    write_mode: str = "overwrite"

    # The names below are deliberately left unannotated so that @dataclass
    # treats them as plain class constants rather than constructor fields.

    # Environment and pod keys probed in every JSON document.
    _ENVIRONMENTS = ("lab", "prod")
    _PODS = ("1", "2")

    # Single source of truth for the output layout: internal column name ->
    # column name written to Parquet, in output order.
    _OUTPUT_COLUMNS = {
        "address": "Address",
        "dn": "DN",
        "fabricSt": "FabricSt",
        "id": "ID",
        "lastStateModTs": "LastStateModTs",
        "modTs": "ModTs",
        "model": "Model",
        "name": "Name",
        "role": "Role",
        "serial": "Serial",
        "vendor": "Vendor",
        "version": "Version",
        "device_type": "device_type",
        "descr": "Descr",
        "hwVer": "HwVer",
        "macB": "MacB",
        "rdSt": "RdSt",
    }

    # Source attribute names that must be renamed to an internal column name.
    _RENAMES = {"ser": "serial"}

    def read_json(self, file_path: str) -> DataFrame:
        """Read ``file_path`` as multi-line JSON with an inferred schema."""
        return self.spark.read.json(file_path, multiLine=True)

    def normalize_schema(self, df: DataFrame) -> DataFrame:
        """Return ``df`` restricted to the standard columns, in standard order.

        Columns that ``df`` lacks are added as null strings; columns outside
        the standard set are dropped. Existing columns keep their type.
        """
        present = set(df.columns)
        # One projection instead of a withColumn call per missing column keeps
        # the query plan small.
        return df.select(
            [
                col(name)
                if name in present
                else lit(None).cast(StringType()).alias(name)
                for name in self._OUTPUT_COLUMNS
            ]
        )

    def map_columns(self, df: DataFrame) -> DataFrame:
        """Rename source-specific attribute names to the internal names.

        Only ``ser`` -> ``serial`` is a real rename. A rename is skipped when
        the target column already exists: renaming anyway would leave two
        columns with the same name and make later selections ambiguous. In
        that case the existing target column is kept as is.
        """
        for source, target in self._RENAMES.items():
            existing = df.columns
            if source in existing and target not in existing:
                df = df.withColumnRenamed(source, target)
        return df

    def extract_chassis_attributes(self, df: DataFrame, path: str) -> DataFrame:
        """Explode the array at ``path`` and flatten ``fabricNode.attributes``.

        Every resulting row is tagged with ``device_type = "chassis"``.
        """
        exploded_df = df.withColumn("element", explode(col(path)))
        attributes_df = exploded_df.select(col("element.fabricNode.attributes.*"))
        attributes_df = attributes_df.withColumn("device_type", lit("chassis"))
        return self.map_columns(attributes_df)

    def extract_linecard_attributes(self, df: DataFrame, path: str) -> DataFrame:
        """Explode the leaf array at ``path`` and flatten ``eqptLC.attributes``.

        Two levels are exploded: the leaf entries at ``path`` and, inside each
        of them, ``linecard.imdata``. Every resulting row is tagged with
        ``device_type = "linecard"``.
        """
        exploded_df = df.withColumn("leaf_element", explode(col(path)))
        exploded_df = exploded_df.withColumn(
            "lc_element", explode(col("leaf_element.linecard.imdata"))
        )
        attributes_df = exploded_df.select(col("lc_element.eqptLC.attributes.*"))
        attributes_df = attributes_df.withColumn("device_type", lit("linecard"))
        return self.map_columns(attributes_df)

    def _union_path(
        self, combined: "DataFrame | None", df: DataFrame, path: str, extract
    ) -> "DataFrame | None":
        """Extract ``path`` from ``df`` with ``extract`` and add it to ``combined``.

        Best effort: any failure (for example a path that does not exist in
        this document) is printed and ``combined`` is returned unchanged.
        """
        try:
            part = self.normalize_schema(extract(df, path))
            return part if combined is None else combined.unionByName(part)
        except Exception as e:
            print(f"Error processing path {path}: {e}")
            return combined

    def process_json_file(self, file_path: str) -> None:
        """Extract all records from one JSON file, write them and archive it.

        Chassis paths are processed first, then line-card paths, each over
        every environment/pod combination. The file is moved from
        ``/unprocessed`` to ``/processed`` even when no path yielded data.
        """
        df = self.read_json(file_path)
        extractors = (
            ("nodes.imdata", self.extract_chassis_attributes),
            ("leaf_nodes_data", self.extract_linecard_attributes),
        )

        final_df = None
        for suffix, extract in extractors:
            for env in self._ENVIRONMENTS:
                for pod in self._PODS:
                    path = f"environments.{env}.pods.{pod}.{suffix}"
                    final_df = self._union_path(final_df, df, path, extract)

        # Explicit None check: a DataFrame should not be tested for truthiness.
        if final_df is not None:
            output_df = final_df.select(
                [
                    col(internal).alias(external)
                    for internal, external in self._OUTPUT_COLUMNS.items()
                ]
            )
            output_df.write.mode(self.write_mode).parquet(self.target)

        self.dbutils.fs.mv(file_path, file_path.replace("/unprocessed", "/processed"))

    def load_files(self) -> None:
        """Process every ``.json`` file in ``<read_dir>/unprocessed``."""
        for read_dir in self.read_dirs:
            for entry in self.dbutils.fs.ls(f"{read_dir}/unprocessed"):
                if entry.name.endswith(".json"):
                    self.process_json_file(entry.path)


# Common attribute structure: every listed attribute is a nullable string.
attributes_schema = StructType(
    [
        StructField(field_name, StringType(), True)
        for field_name in (
            "address",
            "dn",
            "fabricSt",
            "id",
            "lastStateModTs",
            "modTs",
            "model",
            "name",
            "role",
            "serial",
            "vendor",
            "version",
            "descr",
            "hwVer",
            "macB",
            "rdSt",
        )
    ]
)

# fabricNode schema: a struct whose single "fabricNode" field wraps the
# common attributes under "attributes".
fabric_node_schema = StructType().add(
    "fabricNode", StructType().add("attributes", attributes_schema)
)

# nodes schema: a string "totalCount" plus an "imdata" array of fabricNode
# structs.
nodes_schema = (
    StructType()
    .add("totalCount", StringType(), True)
    .add("imdata", ArrayType(fabric_node_schema))
)

# leaf_nodes_data schema: leaf identification ("leaf_dn", "role") plus a
# "linecard" struct whose "imdata" array holds eqptLC entries carrying the
# common attributes.
leaf_nodes_data_schema = (
    StructType()
    .add("leaf_dn", StringType(), True)
    .add("role", StringType(), True)
    .add(
        "linecard",
        StructType()
        .add("totalCount", StringType(), True)
        .add(
            "imdata",
            ArrayType(
                StructType().add(
                    "eqptLC", StructType().add("attributes", attributes_schema)
                )
            ),
        ),
    )
)


# Factory for the pods schema
def create_pods_schema() -> StructType:
    """Build a fresh ``pods`` struct: pods "1" and "2" plus ``leaf_nodes_data``.

    NOTE(review): JsonDataLoader reads ``pods.<pod>.nodes.imdata`` and
    ``pods.<pod>.leaf_nodes_data``, whereas this schema puts ``imdata``
    directly under each pod and ``leaf_nodes_data`` next to the pods --
    confirm which layout matches the real files.
    """
    pods = StructType()
    for pod_key in ("1", "2"):
        pods.add(pod_key, nodes_schema)
    pods.add("leaf_nodes_data", ArrayType(leaf_nodes_data_schema))
    return pods


# Top-level schema: "environments" holds one struct per environment ("lab",
# "prod"), each with its own independently built "pods" struct.
# NOTE(review): nothing in the visible code passes this schema to a reader;
# JsonDataLoader.read_json relies on schema inference.
apic_schemas = StructType().add(
    "environments",
    StructType(
        [
            StructField(env_name, StructType().add("pods", create_pods_schema()))
            for env_name in ("lab", "prod")
        ]
    ),
)

In [None]:
# Build the loader for the raw directory and process its pending JSON files.
# `spark` and `dbutils` are not defined in this file; presumably they are
# globals injected by the notebook runtime -- verify before running elsewhere.
json_data_loader = JsonDataLoader(
    spark=spark,
    dbutils=dbutils,
    read_dirs=["/mnt/edl/raw/iit_apic_raw"],
    target="/mnt/edl/raw/iit_apic_raw/apic_data.parquet",
)
json_data_loader.load_files()