In [None]:
import os
import toml
import requests
from datetime import datetime
import pandas as pd
import warnings
from requests.exceptions import SSLError
from requests.packages.urllib3.exceptions import InsecureRequestWarning
from dotenv import load_dotenv
from pyspark.sql import SparkSession

# Populate os.environ from a local .env file (e.g. API_USERNAME / API_PASSWORD)
# before any client reads its credentials.
load_dotenv()

# Requests below are sent with verify=False, so silence urllib3's
# InsecureRequestWarning (and the matching "Unverified HTTPS" message filter);
# every other warning category is left alone.
requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
warnings.filterwarnings("ignore", ".*Unverified HTTPS.*")


class APICClient:
    """Minimal client for the APIC REST API.

    Per-environment base URLs come from the ``[api]`` table of a TOML config
    file and endpoint paths from its ``[endpoint]`` table. Credentials come
    from the ``API_USERNAME`` / ``API_PASSWORD`` environment variables.

    Every request is sent with ``verify=False`` (TLS certificate verification
    disabled) and with ``self.timeout`` seconds as the requests timeout.
    """

    DEFAULT_TIMEOUT = 60
    # Class-level fallback so the attribute exists even on instances that
    # did not run __init__; __init__ overrides it per instance.
    timeout = DEFAULT_TIMEOUT

    def __init__(self, config_path, timeout=DEFAULT_TIMEOUT):
        """Load configuration and credentials.

        Args:
            config_path: Path to the TOML configuration file.
            timeout: Per-request timeout in seconds passed to ``requests``.
                Without one, a stalled APIC would block the run indefinitely.
        """
        self.config = self.load_config(config_path)
        self.username, self.password = self.load_api_credentials()
        self.timeout = timeout

    @staticmethod
    def load_config(config_path):
        """Parse and return the TOML file at *config_path* as a dict."""
        return toml.load(config_path)

    @staticmethod
    def load_api_credentials():
        """Return ``(username, password)`` from the environment (None if unset)."""
        username = os.getenv("API_USERNAME")
        password = os.getenv("API_PASSWORD")
        return username, password

    def login(self, environment):
        """Authenticate against *environment* and return a logged-in session.

        Args:
            environment: Key into the config's ``[api]`` table.

        Returns:
            ``(session, {"APIC-cookie": token})`` on success, or
            ``(None, error_message)`` on any request/response failure. The
            session is closed on every failure path. Unknown config keys still
            raise ``KeyError``.
        """
        api_base_url = self.config["api"][environment]
        login_endpoint = self.config["endpoint"]["login"]
        login_url = f"{api_base_url}{login_endpoint}"
        payload = {
            "aaaUser": {"attributes": {"name": self.username, "pwd": self.password}}
        }

        session = requests.Session()

        try:
            response = session.post(
                login_url, json=payload, verify=False, timeout=self.timeout
            )
        except requests.exceptions.RequestException as e:
            # SSLError, ConnectionError, Timeout, ... all derive from this.
            session.close()
            return None, f"Login error: {e}"

        if response.status_code != 200:
            session.close()
            return (
                None,
                f"Failed to authenticate with the APIC. Status code: {response.status_code}",
            )

        try:
            login_data = response.json()
            token = login_data["imdata"][0]["aaaLogin"]["attributes"]["token"]
        except (ValueError, KeyError, IndexError, TypeError) as e:
            # A 200 without the expected aaaLogin payload is still a failed login.
            session.close()
            return None, f"Login error: unexpected login response ({e!r})"

        session.cookies.set("APIC-cookie", token)

        return session, {"APIC-cookie": token}

    def fetch_data(self, session, cookie, environment, endpoint_key, **url_params):
        """GET a configured endpoint and return its ``imdata`` list.

        Args:
            session: Session returned by ``login``. If None, [] is returned
                without making a request.
            cookie: Cookie dict sent with the request.
            environment: Key into the config's ``[api]`` table.
            endpoint_key: Key into the config's ``[endpoint]`` table. Its value
                is a ``str.format`` template filled from ``**url_params``.

        Returns:
            The response's ``"imdata"`` list. Returns [] when there is no
            session, the request fails, or the body is not a JSON object.
            Non-200 responses are logged, but their JSON body is still parsed
            and returned.
        """
        if session is None:
            return []

        api_base_url = self.config["api"][environment]
        api_endpoint = self.config["endpoint"][endpoint_key]
        url = f"{api_base_url}{api_endpoint.format(**url_params)}"

        try:
            response = session.get(
                url, cookies=cookie, verify=False, timeout=self.timeout
            )
        except requests.exceptions.RequestException as e:
            print(f"Request to {url} failed: {e}")
            return []

        if response.status_code != 200:
            print(f"Request to {url} returned status code {response.status_code}")

        try:
            response_json = response.json()
        except ValueError as e:
            # requests' JSON decode errors subclass ValueError.
            print(f"Request to {url} returned a non-JSON body: {e}")
            return []

        if not isinstance(response_json, dict):
            print(f"Request to {url} returned unexpected JSON: {type(response_json).__name__}")
            return []

        return response_json.get("imdata", [])


class DataCollector:
    """Collect per-node metrics (memory, CPU, TCAM, LPM) through an API client.

    Each ``get_*`` metric method returns a flat dict of fixed keys. Missing or
    malformed data yields placeholder values instead of raising, so one bad
    node does not abort a whole collection run. Memory uses 0 as its
    placeholder; the other metrics use "".
    """

    def __init__(self, client):
        # client must expose
        # fetch_data(session, cookie, environment, endpoint_key, **url_params)
        # returning a list (e.g. APICClient).
        self.client = client

    @staticmethod
    def _blank(*keys):
        """Return a fresh dict mapping each of *keys* to an empty-string placeholder."""
        return dict.fromkeys(keys, "")

    def get_active_nodes(self, session, cookie, environment, pod):
        """Return the fabric nodes in *pod* whose ``adSt`` attribute is ``"on"``.

        Returns:
            ``(node_list, node_name_dict)``. ``node_list`` holds the node ``id``
            strings, sorted as strings. ``node_name_dict`` maps each id to its
            ``name``.
        """
        nodes_json = self.client.fetch_data(
            session, cookie, environment, "node", pod=pod
        )

        node_list = []
        node_name_dict = {}

        for node_item in nodes_json:
            fabric_node = node_item.get("fabricNode")
            if not fabric_node:
                continue  # entries of any other class are ignored
            attributes = fabric_node["attributes"]
            if attributes.get("adSt") == "on":
                leaf_dn = attributes["id"]
                node_name_dict[leaf_dn] = attributes["name"]
                node_list.append(leaf_dn)

        node_list.sort()
        return node_list, node_name_dict

    def get_memory(self, session, cookie, environment, pod, leaf_dn):
        """Return 5-minute memory averages for one node.

        Returns:
            Dict with ``used_avg``, ``total_avg`` (ints) and
            ``total_capacity_percentage``. All three are 0 when the data is
            missing, empty or not integer-valued.
        """
        memory_data = self.client.fetch_data(
            session, cookie, environment, "memory", pod=pod, leaf_dn=leaf_dn
        )

        try:
            attributes = (
                memory_data[0].get("procSysMemHist5min", {}).get("attributes", {})
            )
            used_avg = int(attributes.get("usedAvg", 0))
            total_avg = int(attributes.get("totalAvg", 0))
        except (IndexError, TypeError, ValueError) as e:
            # IndexError: empty imdata; TypeError/ValueError: non-integer values.
            print(f"No usable memory data for pod {pod}, leaf_dn {leaf_dn}: {e}")
            used_avg, total_avg = 0, 0

        # Guard against division by zero when totals are missing.
        if total_avg == 0:
            total_capacity_percentage = 0
        else:
            total_capacity_percentage = (used_avg / total_avg) * 100

        return {
            "used_avg": used_avg,
            "total_avg": total_avg,
            "total_capacity_percentage": total_capacity_percentage,
        }

    def get_cpu(self, session, cookie, environment, pod, leaf_dn):
        """Return raw 5-minute kernel/user CPU averages for one node.

        Values are passed through as received. Both are "" when no data exists.
        """
        blank = self._blank("cpu_kernel", "cpu_user")
        cpu_data = self.client.fetch_data(
            session, cookie, environment, "cpus", pod=pod, leaf_dn=leaf_dn
        )

        if not cpu_data:
            return blank

        attributes = cpu_data[0].get("procSysCPUHist5min", {}).get("attributes", {})
        if not attributes:
            print(f"No CPU data available for pod {pod}, leaf_dn {leaf_dn}")
            return blank

        return {
            "cpu_kernel": attributes.get("kernelAvg", ""),
            "cpu_user": attributes.get("userAvg", ""),
        }

    def get_tcam(self, session, cookie, environment, pod, leaf_dn):
        """Return TCAM policy usage for one node.

        The percentage is ``(polUsageCum - polUsageBase) / polUsageCapCum * 100``
        (0 when the capacity is not positive). ``tcam_current_util`` is the raw
        ``polUsageCum`` value. NOTE(review): confirm whether it should instead
        report the base-adjusted usage the percentage is computed from.
        All fields are "" when data is missing or malformed.
        """
        blank = self._blank(
            "tcam_current_util", "tcam_max_capacity", "total_tcam_percentage"
        )
        tcam_data = self.client.fetch_data(
            session, cookie, environment, "tcam", pod=pod, leaf_dn=leaf_dn
        )

        if not tcam_data:
            return blank

        attributes = (
            tcam_data[0].get("eqptcapacityPolUsage5min", {}).get("attributes", {})
        )
        if not attributes:
            print(f"No TCAM data available for pod {pod}, leaf_dn {leaf_dn}")
            return blank

        try:
            tcam_usage = int(attributes["polUsageCum"]) - int(
                attributes["polUsageBase"]
            )
            pol_usage_cap_cum = int(attributes["polUsageCapCum"])
        except (KeyError, TypeError, ValueError) as e:
            print(f"Error retrieving TCAM data for pod {pod}, leaf_dn {leaf_dn}: {e}")
            return blank

        if pol_usage_cap_cum > 0:
            total_capacity_percentage = (tcam_usage / pol_usage_cap_cum) * 100
        else:
            total_capacity_percentage = 0

        return {
            "tcam_current_util": attributes["polUsageCum"],
            "tcam_max_capacity": pol_usage_cap_cum,
            "total_tcam_percentage": total_capacity_percentage,
        }

    def get_lpm(self, session, cookie, environment, pod, leaf_dn):
        """Return LPM prefix-entry usage (``extNormalizedAvg`` / ``extNormalizedMax``).

        Both values are ints, or "" when data is missing or malformed.
        """
        blank = self._blank("lpm_current_util", "lpm_max_capacity")
        lpm_data = self.client.fetch_data(
            session, cookie, environment, "lpm", pod=pod, leaf_dn=leaf_dn
        )

        if not lpm_data:
            return blank

        try:
            attributes = lpm_data[0]["eqptcapacityPrefixEntries5min"]["attributes"]
            return {
                "lpm_current_util": int(attributes["extNormalizedAvg"]),
                "lpm_max_capacity": int(attributes["extNormalizedMax"]),
            }
        except (KeyError, TypeError, ValueError) as e:
            # KeyError covers an imdata entry of a different class.
            print(f"Error retrieving LPM data for pod {pod}, leaf_dn {leaf_dn}: {e}")
            return blank


# Metric columns coerced to float before export so the Parquet schema stays
# stable even when some rows carry blank-string placeholders.
_NUMERIC_COLUMNS = [
    "used_avg",
    "total_avg",
    "total_capacity_percentage",
    "tcam_current_util",
    "tcam_max_capacity",
    "total_tcam_percentage",
    "cpu_kernel",
    "cpu_user",
    "lpm_current_util",
    "lpm_max_capacity",
]

# Spark setting disabled while converting the pandas DataFrame to Spark.
_ARROW_CONF_KEY = "spark.sql.execution.arrow.pyspark.enabled"


def _login(client, environment):
    """Log in to one environment without raising.

    Returns:
        ``(session, cookie, error_msg)``. On failure ``session`` and
        ``cookie`` are None and ``error_msg`` describes the problem; on
        success ``error_msg`` is "".
    """
    try:
        session, result = client.login(environment)
    except Exception as e:  # one failing environment must not abort the run
        print(f"Login error: {e}")
        return None, None, str(e)

    if not session:
        # On failure, login() puts the error message where the cookie would be.
        return None, None, result
    return session, result, ""


def _collect_environment(collector, session, cookie, environment, pods, collected_at):
    """Return one metrics row per active node of every pod in *environment*."""
    rows = []
    for pod, location in pods.items():
        active_nodes, node_name_dict = collector.get_active_nodes(
            session, cookie, environment, pod
        )
        for leaf_dn in active_nodes:
            rows.append(
                {
                    "date_time": collected_at,
                    "environment": environment,
                    "location": location,
                    "node_name": node_name_dict[leaf_dn],
                    **collector.get_tcam(session, cookie, environment, pod, leaf_dn),
                    **collector.get_cpu(session, cookie, environment, pod, leaf_dn),
                    **collector.get_memory(session, cookie, environment, pod, leaf_dn),
                    **collector.get_lpm(session, cookie, environment, pod, leaf_dn),
                }
            )
    return rows


def _write_parquet(df, output_parquet_path):
    """Convert *df* to Spark with Arrow disabled and overwrite the Parquet output."""
    spark = SparkSession.builder.getOrCreate()
    previous_arrow = spark.conf.get(_ARROW_CONF_KEY)
    spark.conf.set(_ARROW_CONF_KEY, "false")
    try:
        spark_df = spark.createDataFrame(df)
    finally:
        # Restore the setting that was in effect (even if conversion failed)
        # instead of assuming what its default is.
        spark.conf.set(_ARROW_CONF_KEY, previous_arrow)

    spark_df.write.mode("overwrite").parquet(output_parquet_path)


def get_switch_data(
    config_path="api.toml",
    output_parquet_path="/mnt/edl/raw/iit_apic_raw/TCAM.parquet",
):
    """Collect leaf-switch metrics from every environment and export them.

    Args:
        config_path: TOML configuration passed to ``APICClient``.
        output_parquet_path: Parquet destination, overwritten on each run.
            Nothing is written when no rows were collected.

    Returns:
        ``(df, login_errors)``. ``df`` is the collected pandas DataFrame.
        ``login_errors`` is a list of ``(environment, failed, error_msg)``
        tuples, one per environment including successful ones.
    """
    client = APICClient(config_path)
    collector = DataCollector(client)

    environments = ["lab", "prod"]
    pods = {1: "19thST", 2: "ESF"}

    login_errors = []
    current_datetime = datetime.now()  # one timestamp shared by all rows of a run
    apic_data = []

    for environment in environments:
        session, cookie, error_msg = _login(client, environment)
        login_errors.append((environment, session is None, error_msg))
        if session is None:
            continue

        try:
            apic_data.extend(
                _collect_environment(
                    collector, session, cookie, environment, pods, current_datetime
                )
            )
        finally:
            # Close the session even if collection raises part-way through.
            session.close()

    df = pd.DataFrame(apic_data)

    if df.empty:
        # Nothing collected (e.g. every login failed): there are no metric
        # columns to coerce, and overwriting the previous export with an empty
        # dataset would destroy good data.
        print(f"No switch data collected; skipping Parquet export. Logins: {login_errors}")
        return df, login_errors

    for column in _NUMERIC_COLUMNS:
        df[column] = pd.to_numeric(df[column], errors="coerce").astype(float)

    print(df)

    _write_parquet(df, output_parquet_path)

    # login_errors holds one entry per environment, so it is only empty (and
    # None returned) if the environment list itself is empty.
    return df, login_errors or None


if __name__ == "__main__":
    # Run one collection/export and surface any environment whose login failed;
    # get_switch_data() only reports that through its return value.
    _, login_results = get_switch_data()
    for environment, failed, message in login_results or []:
        if failed:
            print(f"Login failed for {environment}: {message}")