In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl

In [2]:
from sampling_main_secondary import sample_df_via_main_and_secondary_flags, make_sampling_summary_df, decode_flags

### Spec: load the Titanic data and encode each passenger as a bitmask of flags

In [3]:
# Load the passenger table and report its dimensions (rows, columns).
df = pd.read_csv('titanic.csv')
print(f"{df.shape = }")

# Fraction of all rows falling into each `who` category.
df['who'].value_counts() / len(df)

df.shape = (891, 15)


who
man      0.602694
woman    0.304153
child    0.093154
Name: count, dtype: float64

In [4]:
# Show the first five rows as plain text; to_string() renders every column without width truncation.
print(df.head(n=5).to_string())

   survived  pclass     sex   age  sibsp  parch     fare embarked  class    who  adult_male deck  embark_town alive  alone
0         0       3    male  22.0      1      0   7.2500        S  Third    man        True  NaN  Southampton    no  False
1         1       1  female  38.0      1      0  71.2833        C  First  woman       False    C    Cherbourg   yes  False
2         1       3  female  26.0      0      0   7.9250        S  Third  woman       False  NaN  Southampton   yes   True
3         1       1  female  35.0      1      0  53.1000        S  First  woman       False    C  Southampton   yes  False
4         0       3    male  35.0      0      0   8.0500        S  Third    man        True  NaN  Southampton    no   True


In [5]:
from enum import IntFlag


class Flags(IntFlag):
    """Single-bit passenger attributes; a row's `flags` value ORs several of them together.

    Members are grouped by the source column they are derived from in this
    notebook. Each member occupies its own bit, so any combination can be
    decoded back into its component members.
    """

    # derived from `who`
    IS_MAN = 0x001
    IS_WOMAN = 0x002
    IS_CHILD = 0x004

    # derived from `age`; in this notebook AGE_ABOVE_18 is assigned for age >= 18
    AGE_BELOW_18 = 0x008
    AGE_ABOVE_18 = 0x010

    # derived from `pclass`
    CLASS_1 = 0x020
    CLASS_2 = 0x040
    CLASS_3 = 0x080

    # derived from `survived`
    SURVIVED = 0x100
    NOT_SURVIVED = 0x200

# Build each row's bitmask in `flags`. Every rule pairs a boolean row mask with
# the Flags member OR-ed into the rows where that mask is True.
# Rows with a missing `age` match neither age rule (comparisons with NaN are
# False), so they carry no AGE_* bit.
flag_rules = [
    (df['who'] == 'man', Flags.IS_MAN),
    (df['who'] == 'woman', Flags.IS_WOMAN),
    (df['who'] == 'child', Flags.IS_CHILD),
    (df['age'] < 18, Flags.AGE_BELOW_18),
    (df['age'] >= 18, Flags.AGE_ABOVE_18),  # 18 itself counts as "above"
    (df['pclass'] == 1, Flags.CLASS_1),
    (df['pclass'] == 2, Flags.CLASS_2),
    (df['pclass'] == 3, Flags.CLASS_3),
    (df['survived'] == 1, Flags.SURVIVED),
    (df['survived'] == 0, Flags.NOT_SURVIVED),
]

df['flags'] = 0  # reset first so re-running this cell is idempotent
for rule_mask, rule_flag in flag_rules:
    df.loc[rule_mask, 'flags'] |= int(rule_flag)

# Invariant: these columns are fully populated with one of the listed values,
# so every row must carry exactly one bit from each group.
exclusive_groups = {
    'who': (Flags.IS_MAN, Flags.IS_WOMAN, Flags.IS_CHILD),
    'pclass': (Flags.CLASS_1, Flags.CLASS_2, Flags.CLASS_3),
    'survived': (Flags.SURVIVED, Flags.NOT_SURVIVED),
}
for group_col, group_members in exclusive_groups.items():
    n_bits_set = sum((df['flags'] & int(member)) != 0 for member in group_members)
    assert (n_bits_set == 1).all(), f"some rows do not have exactly one {group_col!r} flag set"

df['flags_decoded'] = df['flags'].apply(lambda x: decode_flags(x, Flags))
print(df[['who', 'age', 'pclass', 'survived', 'flags', 'flags_decoded']].head().to_string())

     who   age  pclass  survived  flags                                  flags_decoded
0    man  22.0       3         0    657  [IS_MAN, AGE_ABOVE_18, CLASS_3, NOT_SURVIVED]
1  woman  38.0       1         1    306    [IS_WOMAN, AGE_ABOVE_18, CLASS_1, SURVIVED]
2  woman  26.0       3         1    402    [IS_WOMAN, AGE_ABOVE_18, CLASS_3, SURVIVED]
3  woman  35.0       1         1    306    [IS_WOMAN, AGE_ABOVE_18, CLASS_1, SURVIVED]
4    man  35.0       3         0    657  [IS_MAN, AGE_ABOVE_18, CLASS_3, NOT_SURVIVED]


### Process: sample passengers to meet per-flag quotas

In [6]:
# Requested sample counts per flag. The sampler takes each flag list and its
# count list as separate, position-aligned arguments; deriving both from a
# single dict keeps every flag paired with its quota by construction.
main_quotas = {
    Flags.IS_MAN: 10,
    Flags.IS_WOMAN: 10,
    Flags.IS_CHILD: 10,
}
secondary_quotas = {
    Flags.AGE_BELOW_18: 15,
    Flags.AGE_ABOVE_18: 15,
    Flags.SURVIVED: 10,
    Flags.NOT_SURVIVED: 10,
}

flags_main = list(main_quotas)
req_n_main = list(main_quotas.values())
flags_secondary = list(secondary_quotas)
req_n_secondary = list(secondary_quotas.values())

In [7]:
# Seed NumPy's legacy global RNG before sampling.
# NOTE(review): reproducibility assumes the sampler draws from np.random's
# global state — confirm in sampling_main_secondary.
np.random.seed(42)

df['chosen'] = sample_df_via_main_and_secondary_flags(
    df,
    flags_col='flags',
    flags_main=flags_main,
    req_n_main=req_n_main,
    flags_secondary=flags_secondary,
    req_n_secondary=req_n_secondary,
)


[32m2025-10-13 20:44:41.627[0m | [34m[1mDEBUG   [0m | [36msampling_main_secondary[0m:[36msample_df_via_main_and_secondary_flags[0m:[36m27[0m - [34m[1mChosen 10 samples for flag IS_MAN[0m
[32m2025-10-13 20:44:41.630[0m | [34m[1mDEBUG   [0m | [36msampling_main_secondary[0m:[36msample_df_via_main_and_secondary_flags[0m:[36m27[0m - [34m[1mChosen 10 samples for flag IS_WOMAN[0m
[32m2025-10-13 20:44:41.633[0m | [34m[1mDEBUG   [0m | [36msampling_main_secondary[0m:[36msample_df_via_main_and_secondary_flags[0m:[36m27[0m - [34m[1mChosen 10 samples for flag IS_CHILD[0m
[32m2025-10-13 20:44:41.636[0m | [34m[1mDEBUG   [0m | [36msampling_main_secondary[0m:[36msample_df_via_main_and_secondary_flags[0m:[36m42[0m - [34m[1mChosen 4 samples for flag AGE_BELOW_18[0m
[32m2025-10-13 20:44:41.639[0m | [34m[1mDEBUG   [0m | [36msampling_main_secondary[0m:[36msample_df_via_main_and_secondary_flags[0m:[36m34[0m - [34m[1mAlready have 17 samples 

In [8]:
# Per-flag table comparing chosen rows against all rows carrying that flag.
sdf = make_sampling_summary_df(
    df,
    sample_col='chosen',
    flags_col='flags',
    flags_main=flags_main,
    flags_secondary=flags_secondary,
)
print(sdf.to_string())

              Chosen  Total Percentage       Type   Ratio
Flag                                                     
IS_MAN            10    537       1.9%       Main  10/537
IS_WOMAN          10    271       3.7%       Main  10/271
IS_CHILD          14     83      16.9%       Main   14/83
AGE_BELOW_18      15    113      13.3%  Secondary  15/113
AGE_ABOVE_18      17    601       2.8%  Secondary  17/601
SURVIVED          18    342       5.3%  Secondary  18/342
NOT_SURVIVED      16    549       2.9%  Secondary  16/549


In [9]:
# Number of selected rows: summing the True/False `chosen` column counts the Trues.
print("total chosen:", df['chosen'].sum())

total chosen: 34


In [10]:
# Print every sampled row, all columns, as plain text.
print(df.loc[df['chosen']].to_string())


     survived  pclass     sex    age  sibsp  parch      fare embarked   class    who  adult_male deck  embark_town alive  alone  flags                                    flags_decoded  chosen
9           1       2  female  14.00      1      0   30.0708        C  Second  child       False  NaN    Cherbourg   yes  False    332      [IS_CHILD, AGE_BELOW_18, CLASS_2, SURVIVED]    True
63          0       3    male   4.00      3      2   27.9000        S   Third  child       False  NaN  Southampton    no  False    652  [IS_CHILD, AGE_BELOW_18, CLASS_3, NOT_SURVIVED]    True
127         1       3    male  24.00      0      0    7.1417        S   Third    man        True  NaN  Southampton   yes   True    401        [IS_MAN, AGE_ABOVE_18, CLASS_3, SURVIVED]    True
134         0       2    male  25.00      0      0   13.0000        S  Second    man        True  NaN  Southampton    no   True    593    [IS_MAN, AGE_ABOVE_18, CLASS_2, NOT_SURVIVED]    True
141         1       3  female  22.00    