In [None]:
# load and process the census data function

def load_census(date):
    """
    Load the census household file for ``date`` and build the persons,
    households and memberships tables.

    Reads ``data/hh_rw_<date>_ms.csv``, re-labels inconsistent household
    types (``htype``), drops the households re-labelled as 31 / 40, assigns
    every remaining person a ``role`` and a ``primary_hh``, and harmonises
    ``marital_status``.

    NOTE: the htype counts are printed when the module-level flag ``to_print``
    is truthy; that flag is read from the enclosing scope, not passed in.

    Parameters
    ----------
    date : int
        Census year. Used in the file name, for ``start_date = date - 1`` and
        for ``birth = date - age``.

    Returns
    -------
    tuple of three DataFrames
        persons      indexed by ``person_id``: birth, gender, primary_hh, death
        households   indexed by ``household_id``: htype, hh_start, hh_end,
                     end_cause
        memberships  indexed by ``membership_id``: person_id, household_id,
                     start_date, end_date, role, leave_cause, marital_status
    """
    file_name = 'data/hh_rw_' + str(date) + '_ms.csv'

    temp = pd.read_csv(file_name, delimiter=',')
    temp.drop(columns=['hhweight', 'age_discrete', 'nind', 'total_num_licences'], inplace=True)

    # The row position in the file becomes the person identifier.
    temp.reset_index(inplace=True)
    temp.rename(columns={'hid': 'household_id', 'index': 'person_id', 'role': 'ind_rank'}, inplace=True)

    temp['start_date'] = date - 1
    temp['end_date'] = 9999
    temp['leave_cause'] = None

    temp['gender'] = temp['gender'].replace({1: 'M', 2: 'F'})

    # fix values in htype --> age of consent = 16: impossible to be a child if age gap with parent is < 16
    # (The original code called sort_values() here and discarded the result, so
    #  it never had any effect; every step below is index-aligned and does not
    #  depend on row order, so the call has been removed.)

    # Age difference from the oldest member in each household
    temp['age_diff'] = temp.groupby('household_id')['age'].transform('max') - temp['age']

    # Age difference from the oldest woman in the household (NaN when the
    # household has no woman). Only used in the "> 46" test below.
    oldest_woman_age = (temp['age'].where(temp['gender'] == 'F')
                        .groupby(temp['household_id']).transform('max'))
    temp['age_diff_woman'] = oldest_woman_age - temp['age']

    # single parent with children with size = 2 where the child has age diff < 16 --> set as couple
    too_close_pair = ((temp['age_diff'] < 16) &
                      (temp['ind_rank'] == 2) &
                      (temp['hsize'] == 2) &
                      (temp['htype'] == 230))
    temp.loc[temp['household_id'].isin(temp.loc[too_close_pair, 'household_id'].unique()), 'htype'] = 210

    # single parent with children with size > 2 where a non-head member has age diff < 16
    # --> set as couple with children
    too_close_family = ((temp['age_diff'] < 16) &
                        (temp['ind_rank'] > 1) &
                        (temp['hsize'] > 2) &
                        (temp['htype'] == 230))
    temp.loc[temp['household_id'].isin(temp.loc[too_close_family, 'household_id'].unique()), 'htype'] = 220

    if to_print:
        print('\nAFTER FIXING SINGLE PARENTS:', temp['htype'].value_counts())

    # multi-gen if age gap > 46 from the oldest woman or > 56 from the oldest member
    # --> new bracket "40" = 3-gen hh / complex situations.
    # Each mask is re-evaluated after the previous assignment on purpose:
    # households already moved to 40 no longer match htype_sel.
    temp['multi_gen'] = 0
    for htype_sel in [30, 220, 230]:
        gap_to_woman = (temp['age_diff_woman'] > 46) & (temp['htype'] == htype_sel)
        temp.loc[temp['household_id'].isin(temp.loc[gap_to_woman, 'household_id'].unique()),
                 ['htype', 'multi_gen']] = [40, 1]
        gap_to_oldest = (temp['age_diff'] > 56) & (temp['htype'] == htype_sel)
        temp.loc[temp['household_id'].isin(temp.loc[gap_to_oldest, 'household_id'].unique()),
                 ['htype', 'multi_gen']] = [40, 1]

    # htype 30 rows whose household contains a member at least 16 years younger
    # than its oldest member --> new bracket "31" (more than one generation).
    # NOTE(review): the flag set to 1 here is called 'single_gen' although it
    # marks the households that are NOT single-generation; the rows are dropped
    # right below, so the flag never leaves this function.
    temp['single_gen'] = 0
    not_one_gen = temp.loc[(temp['age_diff'] >= 16), 'household_id'].unique()
    temp.loc[(temp['htype'] == 30) & (temp['household_id'].isin(not_one_gen)),
             ['htype', 'single_gen']] = [31, 1]

    # drop complex situations (htype = 40 & 31) !!!!
    # .copy() so the assignments below write to an independent frame instead
    # of a slice of the unfiltered one.
    temp = temp.loc[~temp['htype'].isin([31, 40])].copy()
    # TO BE DONE --> EXPERIMENT: FIND BIGGEST GAPS BETWEEN AGE SORTED
    #                --> assuming htype 30 is 2-gen, htype 40 is 3-gen

    # fix primary_hh
    # Adults who head their own unit get a fresh id: person_id + max(household_id) + 1.
    max_hh_id = temp['household_id'].max()
    own_primary_hh = temp['person_id'] + max_hh_id + 1

    # - single and couple without children (htype = 10, 210, 30)
    temp.loc[(temp['htype'].isin([10, 210, 30])), 'primary_hh'] = own_primary_hh
    temp.loc[(temp['htype'].isin([10, 30])), 'role'] = 'single'
    temp.loc[(temp['htype'] == 210), 'role'] = 'partner'

    # - couple with children (htype = 220): rank 1-2 are the partners,
    #   higher ranks are children whose primary_hh is the current household
    couple_child = (temp['htype'] == 220) & (temp['ind_rank'] > 2)
    couple_parent = (temp['htype'] == 220) & (temp['ind_rank'] <= 2)
    temp.loc[couple_child, 'primary_hh'] = temp['household_id']
    temp.loc[couple_child, 'role'] = 'child'
    temp.loc[couple_parent, 'primary_hh'] = own_primary_hh
    temp.loc[couple_parent, 'role'] = 'partner'

    # - single parent with children (htype = 230): rank 1 is the parent
    single_child = (temp['htype'] == 230) & (temp['ind_rank'] > 1)
    single_parent = (temp['htype'] == 230) & (temp['ind_rank'] == 1)
    temp.loc[single_child, 'primary_hh'] = temp['household_id']
    temp.loc[single_child, 'role'] = 'child'
    temp.loc[single_parent, 'primary_hh'] = own_primary_hh
    temp.loc[single_parent, 'role'] = 'single'

    # Minors are always children of their current household; this overrides
    # whatever the htype-based rules above assigned.
    temp.loc[temp['age'] < 18, 'role'] = 'child'
    temp.loc[temp['age'] < 18, 'primary_hh'] = temp['household_id']

    temp.loc[temp.role == 'child', 'marital_status'] = 1
    # For partners in hh with htype [210, 220] where exactly one member has
    # marital_status == 2 --> marital_status = 1 (not married)
    married_in_hh = temp.groupby('household_id')['marital_status'].apply(lambda x: (x == 2).sum())
    # merge() returns a new frame with a fresh 0..n-1 index
    temp = temp.merge(married_in_hh.rename('married_in_hh'), on='household_id', how='left')
    temp.loc[(temp.htype.isin([210, 220])) & (temp.married_in_hh == 1) & (temp.role == 'partner'),
             'marital_status'] = 1

    temp['birth'] = date - temp['age']

    temp['membership_id'] = temp['person_id'].astype(str) + '_' + temp['household_id'].astype(str)
    temp.set_index('membership_id', inplace=True)

    # generate persons_df (explicit copy: a column is added right after)
    temp_pers = temp[['person_id', 'birth', 'gender', 'primary_hh']].copy()
    temp_pers['death'] = 9999
    temp_pers.set_index('person_id', inplace=True)

    # generate households_df
    temp_household = temp[['household_id', 'htype']].drop_duplicates().reset_index(drop=True)
    temp_household.set_index('household_id', inplace=True)
    # NOTE(review): hard-coded year; presumably meant to equal date - 1 like
    # start_date above -- confirm before parameterising.
    temp_household['hh_start'] = 2009
    temp_household['hh_end'] = 9999
    temp_household['end_cause'] = None

    return temp_pers, temp_household, temp[['person_id', 'household_id', 'start_date', 'end_date',
                                             'role', 'leave_cause', 'marital_status']]

In [None]:
# GENERATE households, persons and memberships dataframes 

# data_switch selects where the three tables come from:
#   0 -> existing data, 1 -> generated fake data, anything else -> census file
if data_switch not in (0, 1):
    # Census path; a copy of the memberships table is also kept as `temp`
    persons_df, households_df, memberships_df = load_census(initial_date)
    temp = memberships_df.copy()

elif data_switch == 1:
    # Create fake data
    persons_df, households_df, memberships_df = create_fake_data()

else:
    # Import existing data
    persons_df, households_df, memberships_df = read_data()



AFTER FIXING SINGLE PARENTS: htype
220    68600
210    38968
10     17397
230     9190
30      4752
Name: count, dtype: int64

WITH MULTI-GEN HH: htype
220    67727
210    38968
10     17397
230     9076
30      2043
31      1850
40      1846
Name: count, dtype: int64

AFTER FIXING MULTI-GEN HH: htype
220    67727
210    38968
10     17397
230     9076
30      2043
Name: count, dtype: int64


In [None]:
########### PAVEL - INITIAL DATASET MARRIAGE YEAR ###########

# BAYESIAN FORMULA: 
#                                                        (1)                             (2)
#                                          P(Active(t0)|Start(t),yob_M,yob_F) * P(Start(t)|yob_M,yob_F)  
# P(Start(t)|Active(t0),yob_M,yob_F) = ---------------------------------------------------------------------
#                                      SUM(t')[P(Active(t0)|Start(t'),yob_M,yob_F) * P(Start(t')|yob_M,yob_F)] 
#                                                        (3)                             (4)

### COMPUTATIONS ###
# (1) and (3): Prob of marriage still active at t0 given the start of marriage at time t and the YOBs of husband and wife


# Use Survival Model: 
# Assumption --> simplify conditional: only take duration into account (assume independence from YOBs of partners)
# P(Active(t0)|Start(t),yob_M,yob_F) --> P(Active(t0)|Start(t)) = S(t0 - t) = survival function at time t0 - t
# P(Active(t0)|Start(t)) = exp(-lambda * (t0 - t)) 

# get lambda from data --> mean of exponential distribution = 1/lambda --> lambda = 1/(mean(marriage_duration))


#                                       (A)                                     (B)   
#                             marriages ending in divorce                                                   
# mean(marriage_duration) = -------------------------------- * (mean(duration(marriages_ending_divorce)))  + 
#                                   total marriages 
#                     
#                                       (C)                                     (D)
#                               marriages ending in death
#                           + ----------------------------- * (mean(duration(marriages_ending_death)))
#                                   total marriages


# A) available
# B) available
# C) derive from: 1-A 
# D) derive from: life expectancy - mean(age @ marriage)


# (2) and (4): probability of marriage starting at time t given the YOBs of husband and wife

def mean_truncated_geometric(tm, tw, p, t0):
    """
    Compute E[T_u] for a geometric(pi=p) truncated to t ∈ [tmin, t0], where
      tmin = max(tm + 18, tw + 18)
      tmax = t0

    The (unconditional) Geometric(X; p) is defined on X = number of failures
    before first success, taking values X=0,1,2,... with P(X=k) = q^k p, q=1-p.
    Then T = tmin + X, truncated so that X <= N where N = tmax - tmin.

    The mean is evaluated as the weighted average
        E[X | X ≤ N] = Σ_{k=0..N} k q^k / Σ_{k=0..N} q^k
    rather than through the closed form. The previous closed-form code dropped
    a factor q (it returned tmin + 1/(1+q) for N = 1 instead of
    tmin + q/(1+q)) and lost all precision for p close to 0 because of
    cancellation. The direct sum has neither problem.

    Parameters
    ----------
    tm, tw : int
        The two birth years; the earliest admissible year is 18 years after
        the later of the two.
    p : float
        Geometric parameter, 0 <= p <= 1. p = 0 gives the uniform limit,
        p = 1 puts all mass on tmin.
    t0 : int
        Last admissible year.

    Returns
    -------
    float
        E[T] = tmin + E[X | X ≤ N].

    Raises
    ------
    ValueError
        If tmin > t0, or if p is not within [0, 1].
    """
    tmin = max(tm + 18, tw + 18)
    tmax = t0
    if tmin > tmax:
        raise ValueError(f"Invalid: tmin={tmin} > t0={t0}.")
    if not (0.0 <= p <= 1.0):
        raise ValueError(f"Invalid: p={p} is not within [0, 1].")

    q = 1.0 - p
    n_steps = int(tmax - tmin)
    # The common factor p cancels between numerator and denominator.
    # weights[0] == 1 for every q (0.0 ** 0 is 1.0), so the total is never 0.
    weights = [q ** k for k in range(n_steps + 1)]
    total = sum(weights)
    EX_trunc = sum(k * w for k, w in enumerate(weights)) / total
    return tmin + EX_trunc



####### Calibrate p so that E[T_u] = mean_obs ##########

def calibrate_p(tm, tw, t0, mean_obs, tol=1e-8, max_iter=100):
    """
    Bisect for the p ∈ (0,1) at which
    mean_truncated_geometric(tm, tw, p, t0) equals mean_obs.

    Behaviour at the edges:
      - mean_obs below tmin = max(tm + 18, tw + 18) (or NaN) raises ValueError;
      - if the mean at the smallest p is still below mean_obs, the smallest p
        (1e-12) is returned instead of raising;
      - if the mean at the largest p is still above mean_obs, the largest p
        (1 - 1e-12) is returned instead of raising.

    The search stops as soon as the mean is within ``tol`` of mean_obs, or
    after ``max_iter`` halvings, in which case the midpoint of the final
    bracket is returned.
    """
    earliest = max(tm + 18, tw + 18)
    if not (mean_obs >= earliest):
        raise ValueError(
            f"mean_obs={mean_obs} is outside feasible range [{earliest}, {t0}]."
        )

    def residual(prob):
        # signed distance between the model mean at `prob` and the target
        return mean_truncated_geometric(tm, tw, prob, t0) - mean_obs

    lo, hi = 1e-12, 1.0 - 1e-12
    residual_lo = residual(lo)
    residual_hi = residual(hi)

    # Clamp instead of raising when the target lies outside what the
    # bracket [lo, hi] can reach.
    if residual_lo < 0:
        # even the smallest p gives a mean below the target
        return lo
    if residual_hi > 0:
        # even the largest p gives a mean above the target
        return hi

    # The update rule assumes the mean decreases as p increases.
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        residual_mid = residual(mid)
        if abs(residual_mid) < tol:
            return mid
        # mean too high -> move the lower end up; otherwise pull the upper end down
        lo, hi = (mid, hi) if residual_mid > 0 else (lo, mid)

    return 0.5 * (lo + hi)


# ──────────────────────────────────────────────────────────────────────────────
# 3. Build the prior PMF P(tu = t | tm, tw) on t = tmin..t0
# ──────────────────────────────────────────────────────────────────────────────

def pmf_prior_tuple(tm, tw, p, t0):
    """
    Truncated-geometric prior over the years tmin..t0 (both included),
    with tmin = max(tm + 18, tw + 18).

    Returns the pair (t_values, prior_pmf):
      t_values  = array of the years tmin, tmin+1, ..., t0
      prior_pmf = q^(t - tmin) * p / (1 - q^(t0 - tmin + 1)), q = 1 - p

    When tmin > t0 there is no admissible year and two empty arrays are
    returned.
    """
    first_year = max(tm + 18, tw + 18)
    last_year = t0
    if first_year > last_year:
        return np.array([]), np.array([])

    fail_prob = 1.0 - p
    years = np.arange(first_year, last_year + 1)
    offsets = years - first_year  # 0, 1, ..., N
    # probability mass the untruncated geometric puts on 0..N
    mass_in_range = 1.0 - fail_prob ** (last_year - first_year + 1)
    return years, p * (fail_prob ** offsets) / mass_in_range


# posterior: P(tu | still‐married at t0, tm, tw) 
# we simplify the conditional by only taking into account the length of the marriage to obtain: 
# posterior ∝ exp(−λ⋅(t0−t)) ⋅ prior(t)

def posterior_tu(tm, tw, p, t0, lambda_haz, random_state=None):
    """
    Sample one marriage year tu from the discrete posterior
        P(tu = t | survived to t0, tm, tw)
      ∝ exp(−lambda_haz ⋅ (t0 − t)) ⋅ P_prior(t),
    where P_prior comes from pmf_prior_tuple(tm, tw, p, t0).

    ``random_state`` is handed to np.random.default_rng.

    Raises ValueError when no year is admissible and RuntimeError when the
    un-normalised posterior mass is not positive.
    """
    years, prior = pmf_prior_tuple(tm, tw, p, t0)
    if years.size == 0:
        raise ValueError("No valid tmin ≤ t0 for given birth years.")

    # exponential survival weight for a marriage that has lasted t0 - t years
    elapsed = t0 - years
    weighted = prior * np.exp(-lambda_haz * elapsed)
    mass = weighted.sum()
    if mass <= 0:
        raise RuntimeError("Posterior weights all zero; check parameters.")

    generator = np.random.default_rng(random_state)
    drawn_year = generator.choice(years, p=weighted / mass)
    return int(drawn_year)


######## compute marriage years ###############

def assign_marriage_dates(df_couples, t0, avg_duration, random_state=None):
    """
    Given df_couples with columns ['tm','tw','mean_obs'] (any other column,
    e.g. 'household_id', is carried through untouched), return a copy with
    three new columns:
      - p             = calibrated geometric parameter for that couple
      - marriage_date = sampled marriage year (drawn from the posterior)
      - duration      = t0 - marriage_date

    Parameters
    ----------
    df_couples : pandas.DataFrame
        One row per couple.
    t0 : int
        Reference year at which the marriages are known to be active.
    avg_duration : float
        Mean marriage duration; the hazard used in the posterior is
        1 / avg_duration.
    random_state : None, int or numpy.random.Generator, optional
        Seed or generator for the draws. One generator is created per call and
        shared by all rows, so a fixed seed makes the whole column
        reproducible. The default (None) keeps the draws non-deterministic.

    Returns
    -------
    pandas.DataFrame
        Copy of df_couples (same index and columns) plus the three columns
        above. A couple whose inputs are missing or for which calibration or
        sampling fails receives NaN for (p, marriage_date, duration).
    """
    p_list = []
    tu_list = []
    duration_list = []
    lambda_haz = 1.0 / avg_duration
    # default_rng returns a Generator unchanged, so this same object can be
    # passed to posterior_tu as its random_state for every row.
    rng = np.random.default_rng(random_state)

    for _, row in df_couples.iterrows():
        try:
            # int() raises ValueError on a missing (NaN) birth year; that
            # couple is treated like any other failed calibration.
            tm = int(row["tm"])
            tw = int(row["tw"])
            mean_obs = float(row["mean_obs"])
            # 1) calibrate p so that E[T_u] matches mean_obs
            p_hat = calibrate_p(tm, tw, t0, mean_obs)
            # 2) sample a marriage year from the posterior
            tu_sample = posterior_tu(tm, tw, p_hat, t0, lambda_haz, random_state=rng)
            # 3) compute resulting duration
            duration = t0 - tu_sample
        except (ValueError, RuntimeError, ArithmeticError):
            # Infeasible inputs (ValueError / RuntimeError) and numerical
            # failures (ZeroDivisionError, OverflowError) are the best-effort
            # cases; anything else is a programming error and propagates.
            p_hat = np.nan
            tu_sample = np.nan
            duration = np.nan

        p_list.append(p_hat)
        tu_list.append(tu_sample)
        duration_list.append(duration)

    df_couples = df_couples.copy()
    df_couples["p"] = p_list
    df_couples["marriage_date"] = tu_list
    df_couples["duration"] = duration_list
    return df_couples


In [None]:
######## APPLY INITIAL DATASET MARRIAGE YEAR ALGORITHM ##########

#### OBSERVED DATA ####

# Hazard of the exponential survival model:
#   mean of exponential distribution = 1/lambda --> lambda = 1/mean(marriage_duration)
# avg_duration below is estimated as (sum of active marriages) / (sum of divorces).
#
# Alternative estimate, also computed here as lambda_alt:
#   lambda = -ln(1 - D/M)
#   D = number of divorces at t
#   M = total number of marriages active at t

# Divorce counts: commas in text cells are turned into dots before the
# integer conversion.
divorces_tot = (
    pd.read_csv('data/bfs_divorces_tot.csv', delimiter=';')
    .replace(',', '.', regex=True)
    .astype(int)
)

active_marriages = pd.read_csv('data/bfs_active_marriages.csv', delimiter=';').astype(int)

# Totals over all rows of the two files
sum_D_t = divorces_tot['count'].sum()
sum_M_t = active_marriages['count'].sum()

avg_duration = sum_M_t / sum_D_t
lambda_alt = -np.log(1 - (sum_D_t / sum_M_t))

##########################################

av = get_available_actions(memberships_df, persons_df, initial_date)

# Partners who are married themselves (marital_status == 2) and live in a
# household flagged with married_in_hh == 2.
is_married_partner = (
    (av.role == 'partner')
    & (av.marital_status == 2)
    & (av.married_in_hh == 2)
)
av_married = av.loc[is_married_partner]

# One group per household, rows ordered by household then gender.
partners_by_hh = (
    av_married[['household_id', 'birth', 'gender']]
    .sort_values(['household_id', 'gender'])
    .groupby('household_id')
)

# Households with one 'F' and one 'M' partner are used as they are.
married_hetero_couples = partners_by_hh.filter(lambda hh: set(hh['gender']) == {'F', 'M'})

# All other households are kept too: their rows are relabelled alternately
# 'F' / 'M' (last column = gender) so the pivot below still yields one 'tw'
# and one 'tm' per household.
married_gay_couples = partners_by_hh.filter(lambda hh: set(hh['gender']) != {'F', 'M'})
married_gay_couples.iloc[0::2, -1] = 'F'
married_gay_couples.iloc[1::2, -1] = 'M'

# Stack both sets of couples (fresh 0..n-1 index)
married_couples = pd.concat([married_hetero_couples, married_gay_couples], ignore_index=True)

# One row per household with the two birth years side by side
pivot_mc = married_couples.pivot(index='household_id', columns='gender', values='birth')

# Rename to tw / tm and keep only households where both birth years are present
pivot_mc = (
    pivot_mc.rename(columns={'F': 'tw', 'M': 'tm'})
    .reset_index()
    .dropna(subset=['tw', 'tm'])
)

#################

mean_obs = pd.read_csv('data/oced_avg_age_marriage.csv', delimiter=';')
# Birth year of a woman marrying at the average age in each observed year
mean_obs['tw'] = (mean_obs['year'] - mean_obs['avg_age_w']).round()

# Linear regression of the average marriage age on the birth year, used to
# extend avg_age_w beyond the observed range in both directions.
X = mean_obs['tw'].values.reshape(-1, 1)
y = mean_obs['avg_age_w'].values
reg = LinearRegression().fit(X, y)

# Prediction grid: 60 years before the first observed birth year, up to the
# last birth year that is 18 at initial_date --> linear extrapolation
tw_min = int(mean_obs['tw'].min()) - 60
tw_max = initial_date - 18
tw_extended = np.arange(tw_min, tw_max + 1).reshape(-1, 1)
avg_age_pred = reg.predict(tw_extended).round(1)

# Table of birth year -> expected marriage YEAR (birth year + predicted age)
avg_age_pred_df = pd.DataFrame({
    'tw': tw_extended.flatten(),
    'mean_obs': avg_age_pred,
})
avg_age_pred_df['mean_obs'] = avg_age_pred_df['mean_obs'] + avg_age_pred_df['tw']

# Attach the expected marriage year to every couple via the 'tw' column
to_pred_marriage = pivot_mc.merge(avg_age_pred_df, how='left', on='tw')

###################


# Sample a marriage year for every couple; couples for which the sampling
# failed (NaN) fall back to the expected marriage year 'mean_obs'.
pred_marriage_results = assign_marriage_dates(to_pred_marriage, initial_date, avg_duration)
pred_marriage_results.loc[pred_marriage_results.marriage_date.isna(), 'marriage_date'] = pred_marriage_results['mean_obs']

# Merge results back to the original DataFrame (if you want to keep all rows)
pred_marriage_df = to_pred_marriage.merge(
    pred_marriage_results[['household_id', 'marriage_date']],
    on=['household_id'],
    how='left'
)

# av_married is a row selection of `av`; take an explicit copy before adding a
# column so the write goes to an independent frame (no chained assignment /
# SettingWithCopyWarning, and `av` is guaranteed to stay untouched).
av_married = av_married.copy()
av_married['membership_id'] = av_married.apply(
    lambda row: membership_id_str(row['person_id'], row['household_id']), axis=1)
to_add = av_married[['household_id', 'membership_id']].merge(
    pred_marriage_df[['household_id', 'marriage_date']], on=['household_id'], how='left').set_index('membership_id')

# Assign marriage_date to memberships_df: married partners without a predicted
# year get initial_date, every other membership gets the sentinel 9999.
memberships_df.loc[to_add.index, 'marriage_date'] = to_add['marriage_date'].fillna(initial_date).astype(int)
memberships_df['marriage_date'] = memberships_df['marriage_date'].fillna(9999).astype(int)
