In [None]:
import pandas as pd
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
import re


In [None]:

class DefineLimits(BaseEstimator, TransformerMixin):
    """Turn columns holding exact or censored values into interval columns.

    Each selected column may contain exact numbers (``3.2``), "less than"
    values (``"<5"``) and "greater than" values (``">10"``). ``transform``
    replaces every selected column with two numeric columns,
    ``<col>_min`` and ``<col>_max``, holding the interval each observation
    is known to lie in:

    * exact ``v``  -> ``[v, v]``
    * ``"<v"``     -> ``[global_min, v]`` if ``v > global_min``,
                      else ``[v - margin * |v|, v]``
    * ``">v"``     -> ``[v, global_max]`` if ``v < global_max``,
                      else ``[v, v + margin * |v|]``
    * missing      -> ``[NaN, NaN]``

    ``fit`` learns ``global_min`` / ``global_max`` per column from the
    numeric part of every value. If the column contains any ``"<"`` (``">"``)
    value, the learned minimum (maximum) is pushed outward by ``margin``,
    relative to its magnitude. ``transform`` never modifies the fitted
    limits, so repeated calls on the same data give the same result.

    Only the derived ``_min``/``_max`` columns are returned; the other
    columns of the input are not carried over.

    Parameters
    ----------
    columns : list of str or None, default None
        Columns to process. ``None`` processes every column of the frame
        passed to ``fit``.
    margin : float, default 0.0
        Relative expansion applied to open-ended bounds, e.g. with
        ``margin=0.1`` a lower limit of ``2`` becomes ``2 - 0.1 * |2| = 1.8``.

    Attributes
    ----------
    columns_ : list
        Columns processed, resolved in ``fit``.
    col_global_min_, col_global_max_ : dict
        Per-column learned lower / upper limits.
    col_has_expanded_min_, col_has_expanded_max_ : dict
        Per-column flags: whether ``fit`` saw any ``"<"`` / ``">"`` value
        (and therefore expanded the corresponding limit by ``margin``).
    """

    def __init__(self, columns=None, margin=0.0):
        # sklearn convention: __init__ only stores hyper-parameters; all
        # learned state is (re)created in fit() so a refit starts clean.
        self.columns = columns
        self.margin = margin

    # ---------------------------------
    #              HELPERS
    # ---------------------------------

    @staticmethod
    def _extract_numeric(x):
        """Return the numeric part of a single cell as a float.

        ``"<"`` and ``">"`` characters are removed before conversion, so
        ``"<5"`` -> ``5.0`` and ``">1.5"`` -> ``1.5``. Missing values
        (``None``/``NaN``) give ``NaN``.

        Raises
        ------
        ValueError
            If what remains after removing the markers is not a number.
        """
        if pd.isna(x):
            return np.nan
        return float(re.sub(r"[<>]", "", str(x)))

    def _expand_down(self, value):
        """Move ``value`` downward by ``margin`` relative to its magnitude."""
        # abs() keeps the direction correct for negative values.
        return value - self.margin * abs(value)

    def _expand_up(self, value):
        """Move ``value`` upward by ``margin`` relative to its magnitude."""
        return value + self.margin * abs(value)

    def _bounds(self, raw, global_min, global_max):
        """Return the ``(lower, upper)`` interval for one raw cell."""
        if pd.isna(raw):
            return np.nan, np.nan

        value = self._extract_numeric(raw)
        text = str(raw).strip()

        if text.startswith("<"):
            # Known upper bound; lower bound is the learned limit unless this
            # value already sits at/below it (or the limit is NaN), in which
            # case the bound is extended below the value itself.
            lower = global_min if value > global_min else self._expand_down(value)
            return lower, value

        if text.startswith(">"):
            upper = global_max if value < global_max else self._expand_up(value)
            return value, upper

        # Punctual value: degenerate interval.
        return value, value

    # ---------------------------------
    #              FIT
    # ---------------------------------

    def fit(self, X, y=None):
        """Learn the global lower/upper limit of every selected column.

        Parameters
        ----------
        X : pandas.DataFrame
            Frame containing the columns to process.
        y : ignored
            Present for scikit-learn API compatibility.

        Returns
        -------
        self
        """
        self.columns_ = list(X.columns) if self.columns is None else list(self.columns)

        self.col_global_min_ = {}
        self.col_global_max_ = {}
        self.col_has_expanded_min_ = {}
        self.col_has_expanded_max_ = {}

        for col in self.columns_:
            raw = X[col]

            # Missing values stringify to "nan"/"None", which never start
            # with a marker, so they do not affect these flags.
            text = raw.astype(str).str.strip()
            has_lower = bool(text.str.startswith("<").any())
            has_upper = bool(text.str.startswith(">").any())

            # Parse the original cells (not their str() form) so that None
            # becomes NaN instead of failing on float("None").
            numeric = raw.apply(self._extract_numeric)
            col_min = numeric.min()  # NaN-skipping
            col_max = numeric.max()

            # Open-ended values at the extremes: widen the learned limit.
            if has_lower:
                col_min = self._expand_down(col_min)
            if has_upper:
                col_max = self._expand_up(col_max)

            self.col_has_expanded_min_[col] = has_lower
            self.col_has_expanded_max_[col] = has_upper
            self.col_global_min_[col] = col_min
            self.col_global_max_[col] = col_max

        return self

    # ---------------------------------
    #              TRANSFORM
    # ---------------------------------

    def transform(self, X, y=None):
        """Return a frame with ``<col>_min`` / ``<col>_max`` per fitted column.

        The fitted limits are read but never modified.

        Parameters
        ----------
        X : pandas.DataFrame
            Frame containing the fitted columns.
        y : ignored
            Present for scikit-learn API compatibility.

        Returns
        -------
        pandas.DataFrame
            Indexed like ``X``, holding only the derived interval columns.
        """
        result = pd.DataFrame(index=X.index)

        for col in self.columns_:
            global_min = self.col_global_min_[col]
            global_max = self.col_global_max_[col]

            bounds = [self._bounds(raw, global_min, global_max) for raw in X[col]]

            result[f"{col}_min"] = [lower for lower, _ in bounds]
            result[f"{col}_max"] = [upper for _, upper in bounds]

        return result