In [None]:
import synthpops as sp
import sciris as sc
import numpy as np
import random
import pandas as pd
import copy
import os
import time
from collections import Counter
import warnings
warnings.filterwarnings('ignore')


## Functions 

In [2]:
def con_net_init(n, rand_seed, location, state_location, country_location, sheet_name, use_default):
    """
    Generate synthetic population data and its corresponding contact network
    based on the given parameters.

    Parameters:
      n                : Total population
      rand_seed        : Random seed
      location         : Location parameter
      state_location   : State parameter
      country_location : Country parameter
      sheet_name       : Spreadsheet name
      use_default      : Whether to use default configuration
      
    Returns:
      PopAttr         : DataFrame with population attributes
      contact_network1: DataFrame of the merged contact network
      pop             : Synthpops population object
    """
    # Build the parameter dictionary
    pars = sc.objdict(
        n=n,
        rand_seed=rand_seed,
        household_method='fixed_ages',
        smooth_ages=1,
        location=location,
        state_location=state_location,
        country_location=country_location,
        sheet_name=sheet_name,
        with_school_types=1,
        with_non_teaching_staff=True,
        use_default=use_default,
        do_make=True,
    )
    pars.school_mixing_type = {
        'pk': 'age_and_class_clustered',
        'es': 'age_and_class_clustered',
        'ms': 'age_and_class_clustered',
        'hs': 'random',
        'uv': 'random'
    }
    
    # Generate synthetic population data
    pop = sp.Pop(**pars)
    popdict = pop.to_dict()
    popdf = pd.DataFrame.from_dict(popdict).transpose()
    cols = [col for col in popdf.columns if col != 'contacts']
    PopAttr = popdf[cols].copy()
    PopAttr.insert(0, 'uid', PopAttr.index)
    
    # Generate contact layers
    contacts = pop.to_people().contacts
    h = pd.DataFrame(contacts[0])
    s = pd.DataFrame(contacts[1])
    w = pd.DataFrame(contacts[2])
    c = pd.DataFrame(contacts[3])
    
    # Assign contact types
    h['ctype'], s['ctype'], w['ctype'], c['ctype'] = 'h', 's', 'w', 'c'
    
    # Merge different contact types and remove the beta column if present
    contact_network1 = pd.concat([h, s, w, c], axis=0)
    if 'beta' in contact_network1.columns:
        del contact_network1['beta']
    
    print(Counter(contact_network1['ctype']))
    return PopAttr, contact_network1, pop


def con_net_gen(PopAttr, distribution, br, dr, n, rand_seed, country_location):
    """
    Update the population attributes based on a given distribution as well as birth and death rates,
    and adjust household gender and age in order to later position new mothers.

    Parameters:
      PopAttr      : DataFrame of population attributes
      distribution : Dictionary of death rate distribution across age groups, where the key 
                     is an age interval string (e.g., "[00-01]")
      br, dr       : Birth and death rates
      n            : Total population
      rand_seed    : Random seed
      country_location : Country information (for subsequent processing)
      
    Returns:
      Updated PopAttr DataFrame
    """
    # Initialize status, set age 0 to 'newborn'
    PopAttr['status'] = 'alive'
    age_groups = list(distribution.keys())
    death_prob = list(distribution.values())
    people_count = [
        len(PopAttr[PopAttr['age'].between(int(age_group[1:3]), int(age_group[4:6]))])
        for age_group in age_groups
    ]
    death_count = np.floor(np.array(people_count) * np.array(death_prob)).astype(int)
    
    PopAttr.loc[PopAttr['age'] == 0, 'status'] = 'newborn'
    nb = len(PopAttr[PopAttr['age'] == 0])
    
    # Randomly select deaths in each age group (ensuring the selection does not exceed available samples)
    for age_group, count in zip(age_groups, death_count):
        indices = PopAttr[PopAttr['age'].between(int(age_group[1:3]), int(age_group[4:6]))].index.tolist()
        if len(indices) >= count and count > 0:
            selected_indices = random.sample(indices, count)
            PopAttr.loc[selected_indices, 'status'] = 'dead'
    
    # Adjust household gender distribution among new mothers
    hhid_list = list(set(PopAttr[PopAttr['age'] == 0]['hhid']))
    change_0to1 = 0
    change_1to0 = 0
    newborn_list = []
    for i in hhid_list:
        hhset = PopAttr[PopAttr['hhid'] == i].copy()
        adult = hhset[(hhset['age'] >= 18) & (hhset['age'] <= 65)]
        if len(adult) == 1:
            if adult['sex'].iloc[0] == 0:
                idx = random.choice(list(adult.index))
                hhset.loc[idx, 'sex'] = 1
                change_0to1 += 1
            elif adult['sex'].iloc[0] == 1:
                idx = random.choice(list(adult.index))
                hhset.loc[idx, 'sex'] = 0
                change_1to0 += 1
        newborn_list.append(hhset)
    newborn = pd.concat(newborn_list)
    newborn_len = len(newborn[newborn['age'] == 0])
    
    # Adjust newborn status: randomly mark a number of newborns as 'newborn'
    random_row = np.random.choice(
        newborn[newborn['age'] == 0].index,
        size=int(nb) if int(nb) < newborn_len else newborn_len,
        replace=False
    )
    for i in random_row:
        if newborn.loc[i, 'status'] == 'alive':
            newborn.loc[i, 'status'] = 'newborn'
    
    # Overwrite original PopAttr records with new newborn data
    PopAttr.loc[newborn['uid']] = newborn
    PopAttr['age'] = PopAttr['age'].astype(float)
    PopAttr.loc[PopAttr['age'] < 1, 'age'] = [float(random.randint(0, 9) / 10)
                                             for _ in range(len(PopAttr[PopAttr['age'] == 0]))]
    
    # Process households with only one member, adjust gender accordingly
    groups = PopAttr.groupby('hhid')
    halone = groups.filter(lambda x: len(x) == 1)
    d0 = halone[halone['sex'] == 0]
    d1 = halone[halone['sex'] == 1]
    list0to1 = np.unique(np.random.choice(list(d0.index), change_0to1, replace=False)) if len(d0) >= change_0to1 else []
    list1to0 = np.unique(np.random.choice(list(d1.index), change_1to0, replace=False)) if len(d1) >= change_1to0 else []
    PopAttr.loc[list0to1, 'sex'] = 1
    PopAttr.loc[list1to0, 'sex'] = 0
    PopAttr['sex'] = PopAttr['sex'].astype(int)
    
    # Identify households with newborns to find new mothers
    hh_with_newborn = PopAttr[PopAttr['hhid'].isin(list(set(PopAttr[PopAttr['age'] == 0]['hhid'])))]
    hh_without_newborn = PopAttr[~PopAttr.index.isin(hh_with_newborn.index)]
    
    # For families with newborns, select appropriate subjects for replacement
    sub_family_group = []
    for _, group in hh_with_newborn.groupby('hhid'):
        adults = group[group['age'] > 0]
        if len(adults) == 1:
            # Single parent household
            if int(adults['sex'].iloc[0]) == 0:
                if int(adults['age'].iloc[0]) < 18 or int(adults['age'].iloc[0]) > 49:
                    sub_family_group.append(int(adults['uid'].iloc[0]))
            else:
                sub_family_group.append(int(adults['uid'].iloc[0]))
        elif len(adults) == 2:
            # Two-parent household: if the female's age is not within the range, replace her
            if sum(adults['sex'].tolist()) == 1:
                female = adults[adults['sex'] == 0]
                if int(female['age'].iloc[0]) < 18 or int(female['age'].iloc[0]) > 49:
                    sub_family_group.append(int(female['uid'].iloc[0]))
            elif sum(adults['sex'].tolist()) == 2:
                sub_family_group.append(adults['uid'].tolist()[0])
            else:
                if int(adults.iloc[0]['age']) < 18 or int(adults.iloc[0]['age']) > 49:
                    sub_family_group.append(adults.iloc[0]['uid'])
                elif int(adults.iloc[1]['age']) < 18 or int(adults.iloc[1]['age']) > 49:
                    sub_family_group.append(adults.iloc[1]['uid'])
                else:
                    sub_family_group.append(adults.iloc[0]['uid'])
        else:
            # Households with more than two members
            if len(adults['sex'].unique()) == 1:
                sub_family_group.append(adults.iloc[0]['uid'])
            else:
                hh_female = adults[adults['sex'] == 0]
                for j in range(len(hh_female)):
                    if int(hh_female.iloc[j]['age']) < 18 or int(hh_female.iloc[j]['age']) > 49:
                        sub_family_group.append(hh_female.iloc[j]['uid'])
                        break
                        
    hh_Tobereplaced__id = pd.concat(
        [hh_with_newborn[hh_with_newborn['uid'] == i] for i in sub_family_group]
    )

    # Query eligible single parent households (females meeting the condition)
    condition = "49>= age >= 18 & sex == 0"
    hh_Tobechanged__id = hh_without_newborn.query(condition).sample(n=len(hh_Tobereplaced__id))
    
    hh_Tobereplaced__id_copy = copy.deepcopy(hh_Tobereplaced__id)
    hh_with_newborn.loc[hh_with_newborn['uid'].isin(hh_Tobereplaced__id['uid']),
                          ['age', 'sex']] = hh_Tobechanged__id[['age', 'sex']].values
    hh_without_newborn.loc[hh_without_newborn['uid'].isin(hh_Tobechanged__id['uid']),
                             ['age', 'sex']] = hh_Tobereplaced__id_copy[['age', 'sex']].values
    
    # For households with newborns, select the new mother by choosing a suitable candidate
    grouped_data = hh_with_newborn[hh_with_newborn['hhid'].isin(
        hh_with_newborn[hh_with_newborn['age'] == 0]['hhid'].unique()
    )].groupby('hhid')
    newmom_index = []
    for _, df in grouped_data:
        if len(df) > 1:
            condition = (df['age'].between(18, 49)) & (df['sex'] == 0)
            eligible = df[condition]
            if not eligible.empty:
                selected_row = eligible.sample(n=1, random_state=1)
                newmom_index.append(int(selected_row['uid']))
    hh_with_newborn.loc[hh_with_newborn['uid'].isin(newmom_index), 'status'] = 'new mother'
    
    PopAttr = pd.concat([hh_without_newborn, hh_with_newborn]).sort_index()
    PopAttr = PopAttr.drop(columns=['loc', 'ltcf_res', 'ltcf_staff', 'wpindcode', 'ltcfid'])
    print(Counter(PopAttr.status))
    return PopAttr


## Australia

In [None]:
# # Location Parameters
location = 'aus'
state_location = 'aus'
country_location = 'australia'
sheet_name       = 'Australia'

# # Birth Rate Distribution
distribution = {
    "[00-04]": 0.006844,
    "[05-09]": 0.000533,
    "[10-14]": 0.000862,
    "[15-19]": 0.002951,
    "[20-24]": 0.004526,
    "[25-29]": 0.0056978,
    "[30-34]": 0.00695,
    "[35-39]": 0.009051,
    "[40-44]": 0.011234,
    "[45-49]": 0.017954,
    "[50-54]": 0.024644,
    "[55-59]": 0.036541,
    "[60-64]": 0.050775,
    "[65-69]": 0.065728,
    "[70-74]": 0.096237,
    "[75-79]": 0.114619,
    "[80-84]": 0.145152,
    "[85-89]": 0.166398,
    "[90-99]": 0.233261,
}

# # Rates of Birth and Death
br = 0.0115
dr = 0.0063

## Canada

In [4]:
# location = 'can'
# state_location = 'can'
# country_location = 'canada'
# sheet_name       = 'Canada'
# distribution = {
#     "[00-01]": 0.0044,
#     "[02-04]": 0.0002,
#     "[05-09]": 0.0001,
#     "[10-14]": 0.0001,
#     "[15-19]": 0.0003,
#     "[20-24]": 0.0006,
#     "[25-29]": 0.0008,
#     "[30-34]": 0.001,
#     "[35-39]": 0.0011,
#     "[40-44]": 0.0015,
#     "[45-49]": 0.002,
#     "[50-54]": 0.0031,
#     "[55-59]": 0.0048,
#     "[60-64]": 0.0074,
#     "[65-69]": 0.0112,
#     "[70-74]": 0.0174,
#     "[75-79]": 0.029,
#     "[80-84]": 0.0509,
#     "[85-89]": 0.0928,
#     "[90-99]": 0.1972,
# }
# # define rate of died and birth
# br = 0.0094
# dr = 0.0081

## France

In [5]:
# location = 'fra'
# state_location = 'fra'
# country_location = 'france'
# sheet_name       = 'France'
# distribution = {
#     "[00-01]": 0.003,
#     "[02-04]": 0.0002,
#     "[05-09]": 0.0001,
#     "[10-14]": 0.0001,
#     "[15-19]": 0.0002,
#     "[20-24]": 0.0004,
#     "[25-29]": 0.0004,
#     "[30-34]": 0.0006,
#     "[35-39]": 0.0008,
#     "[40-44]": 0.0012,
#     "[45-49]": 0.002,
#     "[50-54]": 0.0032,
#     "[55-59]": 0.005,
#     "[60-64]": 0.0077,
#     "[65-69]": 0.0112,
#     "[70-74]": 0.0196,
#     "[75-79]": 0.0627,
#     "[80-89]": 0.091211622,
#     "[90-99]": 0.1195,
# }
# # define rate of died and birth
# br = 0.0112
# dr = 0.0091

## Germany

In [6]:
# location = 'deu'
# state_location = 'deu'
# country_location = 'germany'
# sheet_name       = 'Germany'
# distribution = {
#     "[00-01]": 0.003030424,
#     "[02-04]": 0.000140538,
#     "[05-09]": 7.71E-05,
#     "[10-14]": 8.26E-05,
#     "[15-19]": 0.000226943,
#     "[20-24]": 0.000301145,
#     "[25-29]": 0.00033992,
#     "[30-34]": 0.000504843,
#     "[35-39]": 0.000742879,
#     "[40-44]": 0.001137294,
#     "[45-49]": 0.001892382,
#     "[50-54]": 0.003216337,
#     "[55-59]": 0.005485791,
#     "[60-64]": 0.009185772,
#     "[65-69]": 0.013686804,
#     "[70-74]": 0.021051378,
#     "[75-79]": 0.03164996,
#     "[80-84]": 0.057917924,
#     "[85-99]": 0.100721335,
# }
# # define rate of died and birth
# br = 0.0094
# dr = 0.0113

## India

In [7]:
# location = 'inr'
# state_location = 'inr'
# country_location = 'india'
# sheet_name       = 'India'
# distribution = {
#     "[00-01]": 0.0283,
#     "[02-04]": 0.001559223,
#     "[05-09]": 0.000558316,
#     "[10-14]": 0.000542428,
#     "[15-19]": 0.000838727,
#     "[20-24]": 0.001301223,
#     "[25-29]": 0.001498421,
#     "[30-34]": 0.00190733,
#     "[35-39]": 0.002547802,
#     "[40-44]": 0.003358957,
#     "[45-49]": 0.004790502,
#     "[50-54]": 0.007797099,
#     "[55-59]": 0.011786385,
#     "[60-64]": 0.017562732,
#     "[65-69]": 0.026607315,
#     "[70-74]": 0.041230302,
#     "[75-79]": 0.062868256,
#     "[80-84]": 0.10828625,
#     "[85-99]": 0.178643404,
# }
# # define rate of died and birth
# br = 0.016949
# dr = 0.007416

## Ireland

In [8]:
# location = 'irl'
# state_location = 'irl'
# country_location = 'ireland'
# sheet_name       = 'Ireland'
# distribution = {
#     "[00-01]": 0.003,
#     "[02-04]": 0.0001,
#     "[05-09]": 0.0001,
#     "[10-14]": 0.0001,
#     "[15-19]": 0.0001,
#     "[20-24]": 0.0004,
#     "[25-29]": 0.0005,
#     "[30-34]": 0.0005,
#     "[35-39]": 0.0006,
#     "[40-44]": 0.0009,
#     "[45-49]": 0.0016,
#     "[50-54]": 0.0027,
#     "[55-59]": 0.0042,
#     "[60-64]": 0.0061,
#     "[65-69]": 0.0103,
#     "[70-74]": 0.0177,
#     "[75-79]": 0.0301,
#     "[80-84]": 0.0579,
#     "[85-99]": 0.1623,
# }
# # define rate of died and birth
# br = 0.012
# dr = 0.0063

## Italy

In [9]:
# location         = 'ita'
# state_location = 'ita'
# country_location = 'italy'
# sheet_name       = 'Italy'
# distribution = {
#     "[00-01]": 0.0019,
#     "[02-04]": 0.0002,
#     "[05-09]": 0.0001,
#     "[10-14]": 0.0001,
#     "[15-19]": 0.0002,
#     "[20-24]": 0.0003,
#     "[25-29]": 0.0003,
#     "[30-34]": 0.0004,
#     "[35-39]": 0.0006,
#     "[40-44]": 0.0008,
#     "[45-49]": 0.0014,
#     "[50-54]": 0.0023,
#     "[55-59]": 0.0034,
#     "[60-64]": 0.0054,
#     "[65-69]": 0.0095,
#     "[70-74]": 0.0128,
#     "[75-79]": 0.0230,
#     "[80-84]": 0.0435,
#     "[85-99]": 0.1113,
# }
# br = 0.007
# dr = 0.011

## Japan

In [10]:
# # Location Parameters
# location = 'jpn'
# state_location = 'jpn'
# country_location = 'japan'
# sheet_name = 'Japan'

# # Birth Rate Distribution
# distribution = {
#     "[00-01]": 0.0283,
#     "[02-04]": 0.001559223,
#     "[05-09]": 0.000558316,
#     "[10-14]": 0.000542428,
#     "[15-19]": 0.000838727,
#     "[20-24]": 0.001301223,
#     "[25-29]": 0.001498421,
#     "[30-34]": 0.00190733,
#     "[35-39]": 0.002547802,
#     "[40-44]": 0.003358957,
#     "[45-49]": 0.004790502,
#     "[50-54]": 0.007797099,
#     "[55-59]": 0.011786385,
#     "[60-64]": 0.017562732,
#     "[65-69]": 0.026607315,
#     "[70-74]": 0.041230302,
#     "[75-79]": 0.062868256,
#     "[80-84]": 0.10828625,
#     "[85-99]": 0.178643404,
# }

# # Rates of Birth and Death
# br = 0.016949
# dr = 0.007416


## Spain

In [11]:
# location='esp'
# state_location='esp'
# country_location='spain'
# sheet_name='Spain'
# distribution = {
#     "[00-01]": 0.0026,
#     "[02-04]": 0.0001,
#     "[05-09]": 0.0001,
#     "[10-14]": 0.0001,
#     "[15-19]": 0.0001,
#     "[20-24]": 0.0002,
#     "[25-29]": 0.0003,
#     "[30-34]": 0.0003,
#     "[35-39]": 0.0005,
#     "[40-44]": 0.0008,
#     "[45-49]": 0.0015,
#     "[50-54]": 0.0027,
#     "[55-59]": 0.0044,
#     "[60-64]": 0.0067,
#     "[65-69]": 0.0100,
#     "[70-74]": 0.0146,
#     "[75-79]": 0.0256,
#     "[80-84]": 0.0479,
#     "[85-99]": 0.1280,
# }
# # define rate of died and birth
# br = 0.0076
# dr = 0.0088

## Sweden

In [12]:
# location='swe'
# state_location='swe'
# country_location='sweden'
# sheet_name='Sweden'
# distribution = {
#     "[00-01]": 0.002,
#     "[02-04]": 0.0001,
#     "[05-09]": 0.0001,
#     "[10-14]": 0.0001,
#     "[15-19]": 0.0002,
#     "[20-24]": 0.0004,
#     "[25-29]": 0.0005,
#     "[30-34]": 0.0006,
#     "[35-39]": 0.0006,
#     "[40-44]": 0.0008,
#     "[45-49]": 0.0013,
#     "[50-54]": 0.0021,
#     "[55-59]": 0.0035,
#     "[60-64]": 0.0062,
#     "[65-69]": 0.0101,
#     "[70-74]": 0.0166,
#     "[75-79]": 0.0288,
#     "[80-84]": 0.054,
#     "[85-99]": 0.1570,
# }
# # define rate of died and birth
# br = 0.0111
# dr = 0.0086

## The Netherlands

In [None]:
# location='nld'
# state_location='nld'
# country_location='netherlands'
# sheet_name='Netherlands'
# distribution = {
#     "[00-01]": 0.0035,
#     "[02-04]": 0.0001,
#     "[05-09]": 0.0001,
#     "[10-14]": 0.0001,
#     "[15-19]": 0.0002,
#     "[20-24]": 0.0003,
#     "[25-29]": 0.0003,
#     "[30-34]": 0.0004,
#     "[35-39]": 0.0006,
#     "[40-44]": 0.0009,
#     "[45-49]": 0.0015,
#     "[50-54]": 0.0026,
#     "[55-59]": 0.0044,
#     "[60-64]": 0.0073,
#     "[65-69]": 0.0115,
#     "[70-74]": 0.0184,
#     "[75-79]": 0.0306,
#     "[80-84]": 0.0574,
#     "[85-99]": 0.1504,
# }
# # define rate of died and birth
# br = 0.0098
# dr = 0.0088

## UK

In [14]:
# location = 'england'
# state_location = 'england'
# country_location = 'uk'
# sheet_name       = 'United Kingdom of Great Britain'
# distribution = {
#     "[00-01]": 0.00375,
#     "[02-04]": 0.0001,
#     "[05-09]": 0.00005,
#     "[10-14]": 0.0001,
#     "[15-19]": 0.0002,
#     "[20-24]": 0.0003,
#     "[25-29]": 0.00045,
#     "[30-34]": 0.00065,
#     "[35-39]": 0.001,
#     "[40-44]": 0.0015,
#     "[45-49]": 0.00235,
#     "[50-54]": 0.00355,
#     "[55-59]": 0.00525,
#     "[60-64]": 0.0083,
#     "[65-69]": 0.013,
#     "[70-74]": 0.0205,
#     "[75-79]": 0.03615,
#     "[80-84]": 0.0658,
#     "[85-89]": 0.1209,
#     "[90-99]": 0.2434,
# }
# # define rate of died and birth
# br = 0.0107
# dr = 0.009

## US

In [32]:
location = 'seattle_metro'
state_location = 'Washington'
country_location = 'usa'
sheet_name       = 'United States of America'
distribution = {
    "[00-01]": 0.005286,
    "[02-04]": 0.000247,
    "[05-14]": 0.000146,
    "[15-24]": 0.000902,
    "[25-34]": 0.001786,
    "[35-44]": 0.00292,
    "[45-54]": 0.005385,
    "[55-64]": 0.011318,
    "[65-74]": 0.021482,
    "[75-84]": 0.048826,
    "[85-99]": 0.138262,
}
# define rate of died and birth
br = 0.0114
dr = 0.0087

## Main function

In [None]:
def main(location, state_location, country_location, sheet_name, distribution, br, dr, rand_seed, experiment_count, n):
    """
    Generate synthetic populations and contact networks for a given location.
    
    Parameters:
        location (str): Location code
        state_location (str): State location code
        country_location (str): Country location name
        sheet_name (str): Sheet name in input data
        distribution (dict): Age-specific death rate distribution
        br (float): Birth rate
        dr (float): Death rate
        rand_seed (int): Random seed to start with
        experiment_count (int): Number of successful experiments to run
        n (int): Population size
    """
    successful_experiments = 0
    location_name = location

    # Synthetic Population and Contact Network Generation
    while successful_experiments < experiment_count:
        start_time = time.time()
        try:
            PopAttr, contact_network1, pop = con_net_init(
                n, rand_seed, location, state_location, country_location, sheet_name, use_default=True
            )
            PopAttr = con_net_gen(PopAttr, distribution, br, dr, n, rand_seed, country_location)
            print('rand_seed:', rand_seed)

            # rename the location name for the output file
            if country_location == 'usa':
                location_name = 'US'
            elif country_location == 'uk':
                location_name = 'UK'
            elif country_location == 'canada':
                location_name = 'Canada'
            elif country_location == 'france':
                location_name = 'France'
            elif country_location == 'germany':
                location_name = 'Germany'
            elif country_location == 'japan':
                location_name = 'Japan'
            elif country_location == 'spain':
                location_name = 'Spain'
            elif country_location == 'netherlands':
                location_name = 'The Netherlands'
            elif country_location == 'sweden':
                location_name = 'Sweden'
            elif country_location == 'ireland':
                location_name = 'Ireland'
            elif country_location == 'italy':
                location_name = 'Italy'
            elif country_location == 'australia':
                location_name = 'Australia'
            elif country_location == 'india':
                location_name = 'India'

            # Create output folders if they don't exist
            country_folder_net = os.path.join('./net', location_name)
            country_folder_pop = os.path.join('./pop', location_name)
            os.makedirs(country_folder_net, exist_ok=True)
            os.makedirs(country_folder_pop, exist_ok=True)

            # Save contact network and population attributes to CSV files
            print(f"Saving {location_name} contact network and population attributes")
            net_csv = os.path.join(country_folder_net, f"{location_name}_net_{successful_experiments}.csv")
            pop_csv = os.path.join(country_folder_pop, f"{location_name}_pop_{successful_experiments}.csv")
            contact_network1.to_csv(net_csv, index=False)
            columns_to_remove = ['loc', 'ltcf_res', 'ltcf_staff', 'ltcfid','wpindcode']
            PopAttr = PopAttr.drop(columns=[col for col in columns_to_remove if col in PopAttr.columns])
            PopAttr.to_csv(pop_csv, index=False)

            successful_experiments += 1
            rand_seed += 1

        except Exception as e:
            print(f"rand_seed: {rand_seed} encountered error: {e}, skipping")
            rand_seed += 1

        finally:
            elapsed_time = time.time() - start_time
            if elapsed_time > 30:
                print(f"rand_seed: {rand_seed} took {elapsed_time:.2f} seconds, exceeding 30 seconds, skipping")

    print(f"Completed {experiment_count} successful random experiments")

In [34]:
if __name__ == "__main__":
    # parameters for the main function
    pars = {
        'location': location,
        'state_location': state_location,
        'country_location': country_location,
        'sheet_name': sheet_name,
        'distribution': distribution,
        'br': br,
        'dr': dr,
        'rand_seed': 0,
        'experiment_count': 100, # # of contact network and population attributes to generate
        'n': 10000 # population size
    }

    main(**pars)

Counter({'w': 42775, 'c': 30104, 's': 26434, 'h': 11043})
Counter({'alive': 9775, 'newborn': 124, 'dead': 85, 'new mother': 16})
rand_seed: 0
Saving US contact network and population attributes
Counter({'w': 40908, 'c': 30031, 's': 25468, 'h': 11054})
Counter({'alive': 9783, 'newborn': 121, 'dead': 85, 'new mother': 11})
rand_seed: 1
Saving US contact network and population attributes
Counter({'w': 40029, 'c': 30077, 's': 26203, 'h': 11038})
Counter({'alive': 9766, 'newborn': 134, 'dead': 81, 'new mother': 19})
rand_seed: 2
Saving US contact network and population attributes
Counter({'w': 40713, 'c': 29973, 's': 26921, 'h': 11075})
Counter({'alive': 9781, 'newborn': 126, 'dead': 83, 'new mother': 10})
rand_seed: 3
Saving US contact network and population attributes
Counter({'w': 43264, 'c': 29776, 's': 25501, 'h': 11071})
Counter({'alive': 9761, 'newborn': 144, 'dead': 84, 'new mother': 11})
rand_seed: 4
Saving US contact network and population attributes
Counter({'w': 45581, 'c': 3027