In [1]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join('../', 'people')))

In [2]:
from utilities import DataGenerator, AgeGroup
from family import Family
from person import Person
import random

class FamilyGraph:
    """Multi-generation family tree built one level (generation) at a time.

    ``levels[0]`` holds the progenitors and, once
    :meth:`generate_initial_families` has run, the partners generated for
    them. Every later entry of ``levels`` holds the children produced by
    couples formed from the level before it.
    """

    def __init__(self, generator: DataGenerator, number_of_progenitors: int, oldest_group: AgeGroup):
        """Create the graph and immediately populate level 0 with progenitors.

        Args:
            generator: data source handed to every Person/Family and used to
                draw progenitor ages.
            number_of_progenitors: how many progenitors to create.
            oldest_group: age group the progenitors' ages are drawn from.
        """
        self.generator = generator
        self.levels = []  # list of levels; each level is a list of Person
        # Index of the most recently added level; -1 while no level exists.
        self.current_level = -1
        self.generate_progenitors(number_of_progenitors, oldest_group)

    def add_level(self):
        """Append a new empty level and make it the current one."""
        self.levels.append([])
        self.current_level += 1

    def generate_progenitors(self, number_of_progenitors: int, oldest_group: AgeGroup):
        """Open a new level and fill it with ``number_of_progenitors`` people.

        Each person's age comes from ``generator.get_age(group=oldest_group, n=1)``.
        """
        self.add_level()
        level = self.levels[self.current_level]
        for _ in range(number_of_progenitors):
            age = self.generator.get_age(group=oldest_group, n=1)
            level.append(Person(self.generator, age=age))

    def generate_initial_families(self, max_children: int = 10):
        """Give every progenitor a family, a generated partner and children.

        A new level is opened for the children. Each generated partner is
        appended to level 0 next to the progenitors, and both spouses record
        the family as their original and their new family.

        Args:
            max_children: upper bound forwarded to ``Family.create_family``.
        """
        self.add_level()
        children_level = self.levels[self.current_level]
        # Snapshot: partners are appended to levels[0] inside the loop and
        # must not themselves be iterated over as progenitors.
        progenitors = list(self.levels[0])

        for progenitor in progenitors:
            family = Family(self.generator, progenitor)
            progenitor.set_original_family(family)
            progenitor.set_new_family(family)
            family.create_family(n_children_max=max_children)

            partner = family.get_partner()
            self.levels[0].append(partner)
            partner.set_original_family(family)
            partner.set_new_family(family)
            children_level.extend(family.get_children())

    def generate_next_level(self, max_children: int = 4) -> bool:
        """Form couples from the current level and add their children as a new level.

        Only people whose ``age.group`` is at least ``AgeGroup.LATE_YOUTH``
        are eligible. The more numerous gender is randomly down-sampled to
        the size of the other. The sampled people become family roots, and
        the other gender supplies partners. A partner must come from a
        different original family than the root and is used at most once.
        A root with no eligible partner gets a family created with
        ``no_partner=True``.

        Args:
            max_children: upper bound forwarded to ``Family.create_family``.

        Returns:
            True if at least one child was produced. False if the current
            level is empty, has no eligible couple, or produced no children.
            In the last case an empty level has already been appended.
        """
        current_level_individuals = self.levels[self.current_level]
        if not current_level_individuals:
            return False

        min_age_group = AgeGroup.LATE_YOUTH

        males = [p for p in current_level_individuals if p.gender == 'M' and p.age.group >= min_age_group]
        females = [p for p in current_level_individuals if p.gender == 'F' and p.age.group >= min_age_group]

        if len(males) > len(females):
            family_roots = random.sample(males, len(females))
            potential_partners = females
        else:
            family_roots = random.sample(females, len(males))
            potential_partners = males

        if not family_roots:
            return False

        used_partners = set()

        self.add_level()
        new_level = self.levels[self.current_level]

        for root in family_roots:
            family = Family(self.generator, root)

            # First unused partner who is not a sibling of the root.
            partner = next(
                (p for p in potential_partners
                 if p.original_family != root.original_family and p not in used_partners),
                None,
            )

            if partner is not None:
                used_partners.add(partner)
                family.create_family(partner=partner, n_children_max=max_children)
                # Match generate_initial_families: both spouses record the
                # family they founded, not only the root.
                partner.set_new_family(family)
            else:
                family.create_family(n_children_max=max_children, no_partner=True)

            root.set_new_family(family)
            new_level.extend(family.get_children())

        return bool(new_level)

    def generate_full_family_tree(self, limit_group: AgeGroup = AgeGroup.INFANT, start_max_children: int = 8,
                                  children_decrement: int = 3):
        """Build the whole tree: initial families, then generation after generation.

        Generation stops when any of these happens:
          * a level produces no children;
          * the per-family children cap drops to zero or below;
          * the average age of the newest level falls in ``limit_group`` or
            a younger group (per ``generator._get_age_group``).

        Args:
            limit_group: stop once the newest level's average age maps to
                this group or lower.
            start_max_children: children cap for the first generated level
                after the initial families.
            children_decrement: amount the cap shrinks by per generation.
        """
        self.generate_initial_families()

        max_children = start_max_children
        while True:
            if not self.generate_next_level(max_children=max_children):
                break

            max_children -= children_decrement
            # A non-positive cap cannot yield children. Passing it on to
            # Family.create_family is meaningless and may raise, depending
            # on how the child count is sampled there.
            if max_children <= 0:
                break

            newest_level = self.levels[self.current_level]
            avg_age = sum(individual.age.age_value for individual in newest_level) / len(newest_level)

            if self.generator._get_age_group(avg_age) <= limit_group:
                break

    def get_csv(self, filename: str):
        """Write every individual, level by level, to ``filename`` as UTF-8 CSV.

        The header is fixed. Each row is whatever ``Person.to_csv()``
        returns, presumably in the header's column order (not verifiable
        here).
        """
        with open(filename, 'w', encoding='utf-8') as f:
            f.write("cf,name,last_name,birthdate,gender,city,gen1,gen2,partner_of\n")
            f.writelines(f"{individual.to_csv()}\n" for level in self.levels for individual in level)

In [3]:
# Seed 50 progenitors whose ages come from AgeGroup.EARLY_LATE_ELDERLY,
# then grow the tree generation by generation.
graph = FamilyGraph(
    generator=DataGenerator(),
    number_of_progenitors=50,
    oldest_group=AgeGroup.EARLY_LATE_ELDERLY,
)
graph.generate_full_family_tree()

[1;33m[DataGenerator][0m: Initializing data generator
[1;32m[DataGenerator][0m: Data are ready


In [4]:
# Print the size of every level plus the overall population.
# enumerate gives the index directly; list.index() would rescan the levels
# with value equality on every iteration.
c = 0

for level_index, level in enumerate(graph.levels):
    print(f"Level {level_index}:")
    print(f"Number of individuals: {len(level)}")
    c += len(level)

print(f"Total number of individuals: {c}")

Level 0:
Number of individuals: 100
Level 1:
Number of individuals: 248
Level 2:
Number of individuals: 445
Level 3:
Number of individuals: 524
Level 4:
Number of individuals: 64
Total number of individuals: 1381


In [5]:
# Dump every individual of the tree to a CSV file in the working directory.
graph.get_csv(filename="family_tree.csv")