In [None]:
import pandas as pd # Importing pandas package

# Set the maximum number of columns to display to None
pd.set_option('display.max_columns', None)

import numpy as np # Importing numpy package

from typing import Dict, Tuple, List, Union # Importing specific types from typing module

import re # Importing regular expression package

from src.database_manager import DatabricksOdbcConnector # Importing DatabricksOdbcConnector class from database_manager module
from src.utils import reorder_columns # Importing reorder_columns function from utils module

from scipy.spatial.distance import cdist # Importing cdist function from scipy package

import time

import pyproj # Importing pyproj package

from custom_logger import CustomLogger # Importing CustomLogger class from custom

In [None]:
class SpacingIKPairs:
    """
    Class for identifying spacing IK pairs in a given dataset.
    """

    def __init__(self, db: DatabricksOdbcConnector, header_df: pd.DataFrame, log_dir: str = "./logs"):
        """
        Initializes the SpacingIKPairs class with a database connection and well header data.

        Args:
            db (DatabricksOdbcConnector): Database connection object, used when
                querying directional survey data.
            header_df (pd.DataFrame): Well header data. Other methods of this class
                read its "chosen_id" and "hole_direction" columns.
            log_dir (str): Directory passed to CustomLogger. Defaults to "./logs".
        """
        self.header_df = header_df  # Header DataFrame
        self.logger = CustomLogger("spacing_ik_pairs", "SpacingIKLogger", log_dir).get_logger()  # Custom logger
        self.db = db  # Database connection

        self.logger.info("SpacingIKPairs instance initialized.")

    def check_required_columns(self) -> bool:
        """
        Checks if the required columns are present in the header DataFrame.

        Returns:
            bool: True if all required columns are present, False otherwise.
        """
        required_columns = ["chosen_id", "lease_name", "well_name", "rsv_cat", "bench", "first_prod_date", "hole_direction"]
        missing_columns = [col for col in required_columns if col not in self.header_df.columns]

        if missing_columns:
            self.logger.warning(f"Missing columns: {missing_columns}")
            return False

        return True

    def get_directional_survey_data(self) -> pd.DataFrame:
        """
        Retrieves directional data from the databricks.

        Returns:
            pd.DataFrame: DataFrame containing the directional data from databricks.
        """
        # Get the unique chosen_ids for horizontal wells only
        chosen_ids = ", ".join(f"'{id}'" for id in self.header_df[self.header_df['hole_direction']=='H']['chosen_id'].unique())

        try:
            self.db.connect()

            query = f"""
            SELECT
                LEFT(uwi, 10) AS chosen_id, 
                station_md_uscust AS md, 
                station_tvd_uscust AS tvd,
                inclination, 
                azimuth, 
                latitude, 
                longitude, 
                x_offset_uscust AS `deviation_E/W`,
                ew_direction,
                y_offset_uscust AS `deviation_N/S`,
                ns_direction,
                point_type
                
            FROM ihs_sp.well.well_directional_survey_station
            WHERE LEFT(uwi, 10) IN ({chosen_ids})
            order by uwi, md;
            """

            return self.db.execute_query(query)

        except Exception as e:
            self.logger.error(f"Error retrieving directional data from databricks: {e}")
        finally:
            self.db.close()

    def determine_utm_zone(self, longitude: float) -> int:
        """
        Determines the UTM zone based on a given longitude.

        The longitude is wrapped into [-180, 180) first, so +180 (the antimeridian)
        maps to zone 1 rather than the non-existent zone 61.

        Args:
            longitude (float): Longitude in decimal degrees.

        Returns:
            int: UTM zone number between 1 and 60.
        """
        # Degrees east of the antimeridian, wrapped into [0, 360).
        degrees_from_antimeridian = (longitude + 180) % 360
        zone = int(degrees_from_antimeridian // 6) + 1
        # Float rounding of a tiny negative offset can make the modulo return exactly 360.
        return min(zone, 60)
    
    def batch_latlon_to_utm(self, lat: np.ndarray, lon: np.ndarray, utm_zone: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Converts arrays of latitudes and longitudes to UTM coordinates in meters for a given UTM zone.

        The target CRS is always built from the "EPSG:326xx" family, i.e. the
        northern-hemisphere WGS 84 / UTM codes.

        Args:
            lat (np.ndarray): Latitudes in decimal degrees.
            lon (np.ndarray): Longitudes in decimal degrees.
            utm_zone (int): UTM zone number (1-60).

        Returns:
            Tuple[np.ndarray, np.ndarray]: Easting and northing arrays.

        Raises:
            ValueError: If utm_zone is not between 1 and 60.
        """
        zone = int(utm_zone)
        if not 1 <= zone <= 60:
            raise ValueError(f"UTM zone must be between 1 and 60, got {utm_zone!r}")

        # EPSG UTM codes use a two-digit zone: zone 5 is EPSG:32605, not EPSG:3265.
        proj_utm = pyproj.Transformer.from_crs(
            "EPSG:4326", f"EPSG:326{zone:02d}", always_xy=True
        )

        # always_xy=True means the transformer takes (lon, lat) and returns (easting, northing).
        return proj_utm.transform(lon, lat)
    
    def compute_mean_elevation(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Computes the mean elevation (mean z value) for each well.

        The well identifier column may be named either 'ChosenID' or 'chosen_id'
        (the spelling used by the other methods of this class); 'ChosenID' wins
        when both are present.

        Parameters:
        - df (pd.DataFrame): DataFrame containing a well identifier column and a 'z' column.

        Returns:
        - pd.DataFrame: One row per well, with the identifier column (same name as in
          the input) and the mean 'z' value in a column named 'elevation'.

        Raises:
        - KeyError: If neither identifier column, or the 'z' column, is present.
        """
        id_column = next((name for name in ("ChosenID", "chosen_id") if name in df.columns), None)
        if id_column is None:
            raise KeyError("ChosenID")

        return (
            df.groupby(id_column, as_index=False)["z"]
            .mean()
            .rename(columns={"z": "elevation"})
        )

    def compute_utm_coordinates(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Adds UTM-based x/y coordinates (in feet), a z column and the UTM zone to a
        directional survey DataFrame.

        One UTM zone is chosen per well from the first latitude/longitude values found
        after sorting by measured depth, and every station of that well is projected
        in that zone. The input DataFrame itself is not modified.

        Parameters:
        - df (pd.DataFrame): Directional survey data with 'chosen_id', 'md',
          'latitude', 'longitude' and 'tvd' columns.

        Returns:
        - pd.DataFrame: The sorted survey rows with the original columns plus
          'utm_zone', 'x', 'y' (feet) and 'z' (negative of 'tvd').
        """
        feet_per_meter = 3.28084
        t_start = time.time()

        # Shallowest station first within each well, so the surface location leads its group.
        ordered = df.sort_values(by=["chosen_id", "md"], ascending=[True, True])

        # One zone per well, derived from the leading latitude/longitude of each group.
        surface = ordered.groupby("chosen_id").first()[["latitude", "longitude"]]
        surface["utm_zone"] = surface["longitude"].apply(self.determine_utm_zone)

        # Attach the per-well zone to every station of that well.
        located = ordered.merge(surface[["utm_zone"]], on="chosen_id", how="left")

        self.logger.info(f"✅ Determined UTM zones in {time.time() - t_start:.4f} seconds.")

        # Project all stations of a zone in a single transformer call.
        t_transform = time.time()
        zones = located["utm_zone"].unique()
        projected = {}

        for zone in zones:
            stations = located[located["utm_zone"] == zone]
            east, north = self.batch_latlon_to_utm(
                stations["latitude"].values, stations["longitude"].values, zone
            )
            projected[zone] = (east, north)

        self.logger.info(f"✅ Performed batch EPSG transformations in {time.time() - t_transform:.4f} seconds.")

        # Write the projected values back into the rows they were computed from.
        t_assign = time.time()
        located["x"] = np.zeros(len(located))
        located["y"] = np.zeros(len(located))

        for zone in zones:
            in_zone = located["utm_zone"] == zone
            east, north = projected[zone]
            located.loc[in_zone, "x"] = east
            located.loc[in_zone, "y"] = north

        self.logger.info(f"✅ Assigned transformed coordinates in {time.time() - t_assign:.4f} seconds.")

        # Meters -> feet (1 meter = 3.28084 feet).
        located["x"] = located["x"] * feet_per_meter
        located["y"] = located["y"] * feet_per_meter

        # Elevation is negative TVD
        located["z"] = -located["tvd"]

        self.logger.info(f"✅ Total execution time: {time.time() - t_start:.4f} seconds.")

        return located
    
    def filter_after_heel_point(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Keeps, for each chosen_id, the first row whose point_type contains '80' or
        'heel' (case-insensitive) and every later row of that chosen_id.

        "Later" is decided by row order in df, not by index labels, so the input may
        carry any index. Wells without a matching row, and rows whose chosen_id is
        missing, are dropped entirely.

        Parameters:
        df (pd.DataFrame): Directional survey data with 'chosen_id' and 'point_type'
            columns, ordered so that each well's rows appear in survey order.

        Returns:
        pd.DataFrame: Filtered rows with a fresh 0..n-1 index.
        """
        # Rows that mark the heel; missing point_type values never match.
        is_marker = df["point_type"].str.lower().str.contains(r"80|heel", regex=True, na=False)

        well_ids = df["chosen_id"]

        # A running maximum within each well flips to 1 at the first marker row and
        # stays 1 afterwards. Grouping by raw values avoids any index alignment.
        reached_heel = is_marker.astype("int8").groupby(well_ids.to_numpy(), sort=False).cummax()

        # Rows with a missing chosen_id belong to no group; exclude them explicitly.
        keep = reached_heel.fillna(0).to_numpy().astype(bool) & well_ids.notna().to_numpy()

        return df[keep].reset_index(drop=True)
    
    def extract_heel_toe_mid_lat_lon(self, well_trajectory: pd.DataFrame) -> pd.DataFrame:
        """
        Summarise each well's trajectory as heel, toe and mid-point latitude/longitude.

        Rows are ordered by measured depth within each chosen_id; the heel is taken
        from the start of that ordering, the toe from the end, and the mid-point is
        the arithmetic mean of the heel and toe coordinates.

        Parameters:
        well_trajectory: pd.DataFrame
            Trajectory data with 'chosen_id', 'md', 'latitude' and 'longitude' columns.

        Returns:
        pd.DataFrame
            One row per chosen_id with the columns 'chosen_id', 'heel_lat', 'heel_lon',
            'toe_lat', 'toe_lon', 'mid_Lat' and 'mid_Lon'.
        """
        # Shallowest station first within each well.
        ordered = well_trajectory.sort_values(by=["chosen_id", "md"], ascending=True)

        # Output column -> (source column, aggregation).
        end_points = {
            "heel_lat": pd.NamedAgg(column="latitude", aggfunc="first"),
            "heel_lon": pd.NamedAgg(column="longitude", aggfunc="first"),
            "toe_lat": pd.NamedAgg(column="latitude", aggfunc="last"),
            "toe_lon": pd.NamedAgg(column="longitude", aggfunc="last"),
        }
        summary = ordered.groupby("chosen_id").agg(**end_points).reset_index()

        # Mid-point between heel and toe.
        summary["mid_Lat"] = summary["heel_lat"].add(summary["toe_lat"]).div(2)
        summary["mid_Lon"] = summary["heel_lon"].add(summary["toe_lon"]).div(2)

        return summary
    
    def get_direction(self, lat1: np.ndarray, lon1: np.ndarray, lat2: np.ndarray, lon2: np.ndarray) -> np.ndarray:
        """
        Determine the relative direction of (lat2, lon2) with respect to (lat1, lon1).

        The axis with the larger absolute difference in degrees decides between a
        north/south and an east/west answer; the sign of that difference picks the
        side. Ties, identical points and NaN inputs all fall through to "W", because
        every comparison involved is then False.

        Parameters:
        lat1, lon1: np.ndarray
            Latitude and longitude of the first well.
        lat2, lon2: np.ndarray
            Latitude and longitude of the second well.

        Returns:
        np.ndarray
            Array of "N", "S", "E" or "W": the direction of well B relative to well A.
        """
        lat_diff = lat2 - lat1
        lon_diff = lon2 - lon1

        # True where the north/south separation dominates.
        lat_dominant = np.abs(lat_diff) > np.abs(lon_diff)

        north_or_south = np.where(lat_diff > 0, "N", "S")
        east_or_west = np.where(lon_diff > 0, "E", "W")

        # Nested np.where keeps every operand a string; np.select would also mix in
        # its implicit integer default (0), which NumPy 2.x is believed to reject
        # alongside string choices.
        return np.where(lat_dominant, north_or_south, east_or_west)
    
    def calculate_drill_direction_vectorized(self, well_trajectories: Dict[str, pd.DataFrame], i_indices: np.ndarray) -> np.ndarray:
        """
        Classify the drilling direction of each requested well from its median azimuth.

        The median is computed once per distinct well ID, so passing an array in which
        IDs repeat (for example the "i" side of a pair list) costs no extra medians.
        Azimuths are wrapped into [0, 360) before classification, so -90 is treated
        like 270.

        Parameters:
        well_trajectories: Dict[str, pd.DataFrame]
            Trajectory data keyed by well ID; each DataFrame needs an "azimuth" column.
        i_indices: np.ndarray
            Well IDs whose drill directions are wanted. Repeats are allowed.

        Returns:
        np.ndarray
            One entry per element of i_indices: "EW" when the median azimuth lies in
            [45, 135) or [225, 315), "NS" otherwise, and "Unknown" when the trajectory
            is empty or its median is NaN.
        """
        # Median azimuth per distinct well; an empty trajectory yields NaN.
        median_by_well = {}
        for well_id in i_indices:
            if well_id in median_by_well:
                continue
            trajectory = well_trajectories[well_id]
            median_by_well[well_id] = np.nan if trajectory.empty else trajectory["azimuth"].median()

        azimuth_values = np.array([median_by_well[well_id] for well_id in i_indices], dtype=float)
        unknown = np.isnan(azimuth_values)

        # Wrap into [0, 360) so out-of-range headings land in the correct sector.
        wrapped = np.mod(azimuth_values, 360.0)

        heading_east = (45 <= wrapped) & (wrapped < 135)
        heading_west = (225 <= wrapped) & (wrapped < 315)

        return np.where(unknown, "Unknown", np.where(heading_east | heading_west, "EW", "NS"))
    
    def calculate_3D_distance_matrix(self,
        trajectories: Dict[str, pd.DataFrame], i_indices: np.ndarray, k_indices: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Vectorized distance calculations between the centroids of well pairs.

        Each well is reduced to the mean of its "x", "y" and "tvd" columns. For every
        pair (i_indices[n], k_indices[n]) the horizontal distance is measured in the
        x/y plane, the vertical distance is the absolute difference of the mean tvd,
        and the 3D distance combines both.

        Parameters:
        trajectories: Dict[str, pd.DataFrame]
            Trajectory DataFrames; each must contain "chosen_id", "x", "y" and "tvd".
            Wells are identified by the "chosen_id" column values.
        i_indices: np.ndarray
            Well IDs representing the first well in each pair.
        k_indices: np.ndarray
            Well IDs representing the second well in each pair.

        Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]
            - Horizontal distances between the well pairs.
            - Vertical distances between the well pairs.
            - 3D distances between the well pairs.
            All three are empty when no pairs are given.

        Raises:
        KeyError
            If a requested well ID has no rows in the trajectory data.
        """
        # No pairs requested (e.g. a single well): nothing to index or measure.
        if len(i_indices) == 0:
            return np.empty(0, dtype=float), np.empty(0, dtype=float), np.empty(0, dtype=float)

        # Centroid (mean x, y, tvd) of every well, computed in one pass.
        all_stations = pd.concat(list(trajectories.values()), ignore_index=True)
        centroids_df = all_stations.groupby("chosen_id")[["x", "y", "tvd"]].mean()
        centroids = centroids_df.to_numpy()

        # Row position of each requested well in the centroid table (-1 = not found).
        rows_i = centroids_df.index.get_indexer(i_indices)
        rows_k = centroids_df.index.get_indexer(k_indices)
        for requested, rows in ((i_indices, rows_i), (k_indices, rows_k)):
            not_found = rows < 0
            if not_found.any():
                raise KeyError(np.asarray(requested)[not_found][0])

        mid_A = centroids[rows_i]
        mid_B = centroids[rows_k]

        # Compute distances
        vertical_distances = np.abs(mid_A[:, 2] - mid_B[:, 2])
        horizontal_distances = np.linalg.norm(mid_A[:, :2] - mid_B[:, :2], axis=1)
        total_3D_distances = np.sqrt(horizontal_distances**2 + vertical_distances**2)

        return horizontal_distances, vertical_distances, total_3D_distances
    
    def create_i_k_pairs(self, df: pd.DataFrame, trajectories: Union[Dict[str, pd.DataFrame], pd.DataFrame]) -> pd.DataFrame:
        """
        Generate the i_k_pairs DataFrame, computing horizontal and vertical distances, 
        3D distances, drilling directions, and relative directions between well pairs.
        
        Parameters:
        df: pd.DataFrame
            DataFrame containing well metadata with:
            - "chosen_id" (str): Unique well identifier.

        trajectories: Union[Dict[str, pd.DataFrame], pd.DataFrame]
            Either:
            - A dictionary mapping well IDs ("chosen_id") to trajectory DataFrames.
            - A single DataFrame containing all trajectory data (must have "chosen_id" column).
            
        Each trajectory DataFrame should include:
        - "md" (float): Measured depth.
        - "tvd" (float): True vertical depth.
        - "inclination" (float): Inclination angle in degrees.
        - "azimuth" (float): represents the drilling direction.
        - "latitude" (float): Latitude values, define the geographical position.
        - "longitude" (float): Longitude values, define the geographical position.
        - "x" (float): X-coordinate in a Cartesian coordinate system.
        - "y" (float): Y-coordinate in a Cartesian coordinate system.
        - "z" (float): Z-coordinate in a Cartesian coordinate system (elevation).
        
        Returns:
        pd.DataFrame
            DataFrame containing pairs of wells (`i_uwi`, `k_uwi`) with their computed distances 
            and directional relationships.
        """
        start_time = time.time()
        
        # Convert to dictionary if input is a DataFrame
        step1_start = time.time()
        if isinstance(trajectories, pd.DataFrame):
            if "chosen_id" not in trajectories.columns:
                raise ValueError("🚨 Error: Trajectory DataFrame must contain a 'chosen_id' column.")
            trajectories = {cid: group for cid, group in trajectories.groupby("chosen_id")}
        step1_end = time.time()
        self.logger.info(f"✅ Step 1: Converted trajectory DataFrame to dictionary in {step1_end - step1_start:.4f} seconds.")

        # Get unique chosen_id from df
        step2_start = time.time()
        chosen_ids = df["chosen_id"].unique()
        missing_ids = [cid for cid in chosen_ids if cid not in trajectories]

        if missing_ids:
            self.logger.info(f"⚠️ The following chosen_id do not exist in the trajectory data and will be excluded: {missing_ids}")

        df = df[df["chosen_id"].isin(trajectories)] # Filter out missing IDs in the DataFrame
        chosen_ids = df["chosen_id"].unique() # Update chosen_ids without missing IDs
        step2_end = time.time()
        self.logger.info(f"✅ Step 2: Extracted unique chosen_id in {step2_end - step2_start:.4f} seconds.")

        # Generate all possible pairs (excluding self-comparison)
        step3_start = time.time()
        i_uwi, k_uwi = np.meshgrid(chosen_ids, chosen_ids, indexing='ij')
        i_uwi, k_uwi = i_uwi.ravel(), k_uwi.ravel()

        # Remove self-comparisons
        valid_mask = i_uwi != k_uwi
        i_uwi, k_uwi = i_uwi[valid_mask], k_uwi[valid_mask]
        step3_end = time.time()
        self.logger.info(f"✅ Step 3: Generated well pairs in {step3_end - step3_start:.4f} seconds.")

        # 🚀 Optimized Heel/Toe Extraction (Vectorized)
        step4_start = time.time()
        heel_toe_df = pd.concat(
            [self.extract_heel_toe_mid_lat_lon(trajectories[cid]) for cid in chosen_ids], ignore_index=True
        )
        heel_toe_dict = heel_toe_df.set_index("chosen_id").to_dict(orient="index")
        step4_end = time.time()
        self.logger.info(f"✅ Step 4: Heel/Toe extraction took {step4_end - step4_start:.4f} seconds.")

        # Efficiently extract values using vectorized lookups
        step5_start = time.time()
        mid_lat_i = np.array([heel_toe_dict[i]["mid_Lat"] for i in i_uwi])
        mid_lon_i = np.array([heel_toe_dict[i]["mid_Lon"] for i in i_uwi])
        mid_lat_k = np.array([heel_toe_dict[k]["mid_Lat"] for k in k_uwi])
        mid_lon_k = np.array([heel_toe_dict[k]["mid_Lon"] for k in k_uwi])
        step5_end = time.time()
        self.logger.info(f"✅ Step 5: Heel/Toe dictionary lookup took {step5_end - step5_start:.4f} seconds.")

        # 🚀 Optimized Distance Calculation (Fully Vectorized)
        step6_start = time.time()
        horizontal_dist, vertical_dist, total_3D_dist = self.calculate_3D_distance_matrix(trajectories, i_uwi, k_uwi)
        step6_end = time.time()
        self.logger.info(f"✅ Step 6: Distance calculations took {step6_end - step6_start:.4f} seconds.")

        # Compute drill directions
        step7_start = time.time()
        drill_directions = self.calculate_drill_direction_vectorized(trajectories, i_uwi)
        step7_end = time.time()
        self.logger.info(f"✅ Step 7: Drill direction calculation took {step7_end - step7_start:.4f} seconds.")

        # Determine directional relationship
        step8_start = time.time()
        ward_of_i = self.get_direction(mid_lat_i, mid_lon_i, mid_lat_k, mid_lon_k)
        step8_end = time.time()
        self.logger.info(f"✅ Step 8: Directional relationship calculation took {step8_end - step8_start:.4f} seconds.")

        # Compute mean elevation
        step9_start = time.time()
        elevation_df = self.compute_mean_elevation(df)
        elevation_dict = elevation_df.set_index("ChosenID")["elevation"].to_dict()

        # Add elevation values to pairs
        elevation_i = np.array([elevation_dict.get(i, np.nan) for i in i_uwi])
        elevation_k = np.array([elevation_dict.get(k, np.nan) for k in k_uwi])
        step9_end = time.time()
        self.logger.info(f"✅ Step 9: Mean elevation calculation took {step9_end - step9_start:.4f} seconds.")

        # Create DataFrame
        step10_start = time.time()
        result_df = pd.DataFrame({
            "i_uwi": i_uwi,
            "k_uwi": k_uwi,
            "horizontal_dist": horizontal_dist,
            "vertical_dist": vertical_dist,
            "3D_ft_dist": total_3D_dist,
            "drill_direction": drill_directions,
            "ward_of_i": ward_of_i,
            "elevation_i": elevation_i,
            "elevation_k": elevation_k
        })
        step10_end = time.time()
        self.logger.info(f"✅ Step 9: Created result DataFrame in {step10_end - step10_start:.4f} seconds.")

        total_time = time.time() - start_time
        self.logger.info(f"🚀 Total Execution Time: {total_time:.4f} seconds.")

        return result_df