In [0]:
import warnings

import numpy as np
import pandas as pd

In [0]:
# Load one row per flightkey for the last 365 days of 'Ticket' charge-product
# records, enriched with flight capacity (flight table) and route type/region
# (routemap), and convert the result to a pandas DataFrame via toPandas().
#   total_pax = sum(rt.unt_net), total_rev = sum(rt.rev_net) per group.
# NOTE(review): "one row per flightkey" assumes route/flight_dt/capacity/type/
# region are unique per flightkey -- confirm.
# NOTE(review): `spark` is not created in this notebook; presumably it is the
# session injected by the environment (e.g. Databricks).
# NOTE(review): both joins are inner joins, so routes missing from routemap and
# flights missing from the flight table are dropped silently; if either table
# has more than one row per join key the sums are multiplied -- verify keys are unique.
df = spark.sql(
    """
    select
        rt.flightkey,
        rt.route,
        rt.flight_dt,
        rt.chargeproduct,
        flt.capacity,

        -- route characteristics
        rtmap.type, 
        rtmap.region,

        -- statistics & metrics
        sum(rt.unt_net) as total_pax,
        sum(rt.rev_net) as total_rev
        
    from 
        data_experience_commercial.cbt_1423_rtsuite.master rt
    join 
        data_prod.silver_sanezdb.routemap rtmap on rt.route = rtmap.route
    join
        data_prod.silver_curated_eres.flight flt on rt.flightkey = flt.flightkey

    where 1=1 
        and rt.chargeproduct = 'Ticket'
        and rt.flight_dt between current_date() - 365 and current_date()

    group by
        rt.flightkey,
        rt.route,
        rt.flight_dt,
        rt.chargeproduct,
        flt.capacity, 
        rtmap.type, 
        rtmap.region
    """
).toPandas()

In [0]:
# Convert flight_dt itself to datetime64 so the date filter in the next cell
# compares Timestamps with Timestamps. The converted values must be written back
# to 'flight_dt', not to a new column; otherwise the filter runs on the raw
# values returned by toPandas() (presumably datetime.date objects).
# pd.to_datetime on an already-converted column is a no-op, so re-running is safe.
df['flight_dt'] = pd.to_datetime(df['flight_dt'])

In [0]:
# Pseudo-test window: keep only flights dated strictly after 2024-11-01
# (flight_dt must be datetime-like for this comparison).
df_test = df[df['flight_dt'].gt(pd.Timestamp('2024-11-01'))]

In [0]:
# Route-level summary of the pseudo-test window, one row per (route, region, type).
# capacity is summed over all rows of the route (presumably total seats offered in
# the window), whereas total_pax and total_rev are averaged per row.
# Rows with a missing grouping key are dropped by groupby's default dropna=True.
df_stats = (
    df_test
    .groupby(['route', 'region', 'type'])
    .agg(
        capacity=('capacity', 'sum'),
        total_pax=('total_pax', 'mean'),
        total_rev=('total_rev', 'mean'),
    )
    .reset_index()
)

In [0]:
# Quartiles of route capacity within each (region, type) segment. Only the median
# (capacity_50th) drives the High/Low split; the 25th/75th are kept for inspection.
df_capacity_quantiles = (
    df_stats
    .groupby(['region', 'type'])
    .agg(
        capacity_25th=('capacity', lambda caps: caps.quantile(0.25)),
        capacity_50th=('capacity', lambda caps: caps.quantile(0.5)),
        capacity_75th=('capacity', lambda caps: caps.quantile(0.75)),
    )
    .reset_index()
)

# Attach each route's segment thresholds (left join keeps every route).
# NOTE(review): df_stats is reassigned here, so re-running this cell without
# re-running the aggregation cell merges the quantile columns a second time and
# pandas suffixes them (_x/_y), breaking later references to capacity_50th.
df_stats = pd.merge(df_stats, df_capacity_quantiles, on=['region', 'type'], how='left')

def cap_group(row):
    """Classify a route row by capacity relative to its segment median.

    :param row: mapping (e.g. a DataFrame row) with 'capacity' and 'capacity_50th'.
    :return: 'Low' when capacity is at or below the median, otherwise 'High'.
        Comparisons involving NaN are False, so a missing value yields 'High'.
    """
    return 'Low' if row['capacity'] <= row['capacity_50th'] else 'High'
# High/Low capacity label per route, vectorised instead of a row-wise apply.
# A route is 'Low' when its capacity is at or below its (region, type) median.
# NaN comparisons evaluate False, so a missing capacity/median gives 'High'.
# np.where is also well-defined when df_stats is empty.
df_stats['cap_group'] = np.where(
    df_stats['capacity'] <= df_stats['capacity_50th'], 'Low', 'High'
)

# Routes kept out of the sampling frame.
# TODO(review): document why these routes are excluded. They are removed after
# the segment medians above were computed, so they still influence those medians.
EXCLUDED_ROUTES = ('EDITFS', 'LISNCE', 'NAPAMS', 'MANBCN', 'MANAGP', 'MADEDI', 'MXPSUF', 'MXPBRI')
df_stats = df_stats.loc[~df_stats['route'].isin(EXCLUDED_ROUTES)]

In [0]:
class StratifiedSampler:
    """Draw stratified random samples of rows from a DataFrame.

    Strata are the distinct value combinations of ``stratify_cols``. Rows with a
    missing value in any of those columns belong to no stratum (groupby's
    default ``dropna=True``).
    """

    def __init__(self, df: pd.DataFrame, stratify_cols: list, seed: int = None):
        """
        Initialize the stratified sampler.

        :param df: Pandas DataFrame containing the data (copied, so later changes
            to the caller's frame do not affect the sampler).
        :param stratify_cols: List of columns to stratify by.
        :param seed: Random seed for reproducibility. Used as ``random_state`` for
            every per-stratum ``DataFrame.sample`` call. When not None it also
            seeds NumPy's global RNG (kept for backward compatibility; note this
            affects any other code drawing from ``np.random``).
        """
        self.df = df.copy()
        self.stratify_cols = stratify_cols
        self.seed = seed
        # Result of the most recent sample() call; read by validate().
        self.sampled_df = None

        # `is not None` so that seed=0 is honoured like any other seed.
        if self.seed is not None:
            np.random.seed(self.seed)

        # One row per stratum: the stratify_cols values plus the stratum's row count.
        self.strata_counts = self.df.groupby(self.stratify_cols).size().reset_index(name="count")

    def _strata_keys(self) -> list:
        """Return stratum keys as tuples, in the row order of ``strata_counts``."""
        return list(self.strata_counts[self.stratify_cols].itertuples(index=False, name=None))

    def _proportional_sizes(self, n: int) -> np.ndarray:
        """Split ``n`` across strata in proportion to their size (largest-remainder method).

        Each stratum first gets ``floor(count * n / total)``. The shortfall left by
        flooring is then handed out one unit at a time to the strata with the
        largest fractional parts (ties go to the earlier stratum), so the sizes
        sum to exactly ``n``.
        """
        counts = self.strata_counts["count"].to_numpy()
        if counts.size == 0:
            return counts.astype(int)
        exact = counts * n / counts.sum()
        sizes = np.floor(exact).astype(int)
        shortfall = int(n) - int(sizes.sum())
        if shortfall > 0:
            order = np.argsort(-(exact - sizes), kind="stable")
            sizes[order[:shortfall]] += 1
        return sizes

    def sample(self, n: int = None, fixed_allocation: dict = None):
        """
        Perform stratified sampling.

        :param n: Total number of samples, allocated proportionally to stratum
            sizes. The per-stratum sizes sum exactly to ``n``.
        :param fixed_allocation: Dictionary mapping a stratum key (a tuple of values
            in ``stratify_cols`` order) to the exact number of rows to draw from
            that stratum. Strata not listed get 0 rows. Keys that match no stratum
            trigger a warning.
        :return: Sampled DataFrame with the input's columns and a reset index. It
            is also stored on ``self.sampled_df``.
        :raises ValueError: if both or neither of ``n`` and ``fixed_allocation``
            are given, or (from ``DataFrame.sample``) if a stratum is asked for
            more rows than it contains.
        """
        if n and fixed_allocation:
            raise ValueError("Provide either 'n' for proportional sampling or 'fixed_allocation', not both.")

        keys = self._strata_keys()
        if fixed_allocation:
            unknown = set(fixed_allocation) - set(keys)
            if unknown:
                warnings.warn(
                    f"fixed_allocation keys match no stratum and are ignored: {sorted(map(str, unknown))}",
                    stacklevel=2,
                )
            sizes = {key: fixed_allocation.get(key, 0) for key in keys}
        elif n is not None:
            sizes = dict(zip(keys, (int(s) for s in self._proportional_sizes(n))))
        else:
            raise ValueError("Provide either 'n' for proportional sampling or 'fixed_allocation'.")

        # groupby yields strata in the same sorted order as strata_counts and keeps
        # each stratum's rows in their original order, so a given seed picks the
        # same rows as filtering the frame one stratum at a time.
        pieces = []
        for key, stratum_df in self.df.groupby(self.stratify_cols):
            # Older pandas yields a scalar key when grouping by a 1-element list.
            key = key if isinstance(key, tuple) else (key,)
            sample_size = sizes.get(key, 0)
            if sample_size > 0:
                pieces.append(stratum_df.sample(n=sample_size, random_state=self.seed))

        # Concatenate once. An empty result keeps the input's columns and dtypes.
        sampled = pd.concat(pieces) if pieces else self.df.iloc[0:0]
        self.sampled_df = sampled.reset_index(drop=True)
        return self.sampled_df

    def validate(self):
        """
        Compare per-stratum population counts with the most recent sample.

        :return: ``strata_counts`` with an extra integer ``sampled_count`` column
            (0 for strata that received no rows).
        :raises RuntimeError: if ``sample()`` has not been called yet.
        """
        sampled_df = getattr(self, "sampled_df", None)
        if sampled_df is None:
            raise RuntimeError("Call sample() before validate().")
        if sampled_df.empty:
            return self.strata_counts.assign(sampled_count=0)
        sampled_counts = sampled_df.groupby(self.stratify_cols).size().reset_index(name="sampled_count")
        merged = self.strata_counts.merge(sampled_counts, on=self.stratify_cols, how="left")
        merged["sampled_count"] = merged["sampled_count"].fillna(0).astype(int)
        return merged



In [0]:
# Exact number of routes to draw per stratum. Keys are tuples in the same order
# as stratify_cols below: (type, region, cap_group).
fixed_allocation = {
    ("Beach", "UK-Regions", "High"): 4,
    ("City", "Portugal", "Low"): 2,
    ("City", "Italy", "High"): 2,
    ("City", "UK-Regions", "High"): 4,
    ("Domestic", "Italy", "High"): 4,
}

N_DRAWS = 10       # number of candidate route selections to generate
MASTER_SEED = 42   # fixes the whole set of draws; change it to get a new set

# Per-draw seeds come from a dedicated, seeded Generator rather than the global
# RNG. StratifiedSampler calls np.random.seed() itself, so drawing from the global
# RNG here would chain each seed off the previous sampler, and an unseeded first
# draw would make the whole run irreproducible. replace=False keeps seeds distinct.
seed_rng = np.random.default_rng(MASTER_SEED)
draw_seeds = seed_rng.choice(1000, size=N_DRAWS, replace=False)

samples = []
for seed in draw_seeds:
    sampler = StratifiedSampler(df_stats, stratify_cols=["type", "region", "cap_group"], seed=int(seed))
    sampled_df = sampler.sample(fixed_allocation=fixed_allocation)
    samples.append(sampled_df.route.to_list())

In [0]:
# Display the route list produced by each draw (one list per seed).
samples