In [None]:
from pathlib import Path
import pandas as pd 
from typing import Set, Optional


class CompleteSuicideRates:
    """A data class for suicide rates across many countries.

    The data is read from a CSV file. Only the 'Period', 'Location',
    'Dim1', 'Dim2' and 'Value' columns are kept; 'Dim1' is matched against
    a gender and 'Dim2' against an age group when querying. The file must
    contain exactly one distinct 'Period' value.
    """

    def __init__(self, csv_file_path: str):
        """Load the CSV at ``csv_file_path`` and prepare the query state.

        Raises:
            AssertionError: if the path does not end in '.csv', does not
                exist, or the file holds more than one distinct 'Period'.
        """
        assert self.is_valid_csv(csv_file_path), (
            'expected a path to an existing .csv file')
        self.raw_data_frame = pd.read_csv(csv_file_path)
        self._setup()
        self._genders = {'Male', 'Both sexes', 'Female'}
        # 'all ages' is not matched against 'Dim2'; it asks
        # suicide_rates_in_countries to sum over every age group.
        self._age_groups = {'15-24 years', '25-34 years of age',
                            '35-44 years of age', '45-54 years',
                            '55-64 years', '65-74 years',
                            '75-84 years', '85+ years',
                            'all ages'}
        # The raw frame is only needed while deriving the state above.
        del self.raw_data_frame

    def _setup(self) -> None:
        """Derive the filtered frame, the year and the country names."""
        self._set_data_frame()
        self._set_year()
        self._set_all_countries()

    def is_valid_csv(self, file: str) -> bool:
        """Return True if ``file`` has a '.csv' suffix and exists on disk."""
        file_path = Path(file)
        return ((file_path.suffix == '.csv')
                and (file_path.exists()))

    def _set_data_frame(self) -> None:
        """Keep only the columns the queries rely on."""
        self._data_frame = (self.raw_data_frame.filter(
            items=['Period', 'Location', 'Dim1', 'Dim2', 'Value']))

    def _set_year(self) -> None:
        """Store the single 'Period' value of the data as a string."""
        years = set(self.raw_data_frame.Period.to_numpy())
        assert len(years) == 1, (
            "expected exactly one distinct value in the 'Period' column")
        self._year = str(years.pop())

    def year(self) -> str:
        """Return the period covered by the data, as a string."""
        return self._year

    def _set_all_countries(self) -> None:
        """Collect the distinct values of the 'Location' column."""
        self._all_countries = set(self.raw_data_frame.Location.to_list())

    def all_countries(self) -> Set[str]:
        """Return the names of all countries in the data.

        A new set is returned on every call, so callers may mutate the
        result without affecting this object.
        """
        # typing.Set is an annotation-only alias and cannot be instantiated;
        # the builtin set() makes the copy.
        return set(self._all_countries)

    # for testing
    def _data(self) -> pd.DataFrame:
        """Return a copy of the filtered data frame."""
        return self._data_frame.copy()

    def _sum_across_age_groups(self, df: pd.DataFrame) -> pd.DataFrame:
        """Sum ``df`` per ('Location', 'Period', 'Dim1') group.

        NOTE(review): every remaining column is summed, including 'Dim2';
        if 'Dim2' holds strings pandas may concatenate them rather than
        drop the column -- confirm whether only 'Value' should be summed.
        """
        return df.groupby(
                          ['Location', 'Period', 'Dim1'], as_index=False
                          ).sum()

    def suicide_rates_in_countries(self, *,
                                   gender: str = 'Both sexes',
                                   age_group: Optional[str] = None
                                   ) -> pd.DataFrame:
        """
        Return suicide rates in countries by gender and age group.

        Note:
        if age_group is None, return suicide rates for all age groups
        if age_group is 'all ages', return the summed rates across all
        age groups

        Raises:
            AssertionError: if ``gender`` or ``age_group`` is not one of
                the known values.
        """
        assert gender in self._genders, 'unknown gender: %r' % (gender,)
        df = self._data_frame[self._data_frame.Dim1 == gender]
        if age_group is None:
            return df
        assert age_group in self._age_groups, (
            'unknown age group: %r' % (age_group,))
        if age_group == 'all ages':
            return self._sum_across_age_groups(df)
        # Build the mask from the already gender-filtered frame so its index
        # lines up with ``df``; a mask taken from the full frame is
        # misaligned and has to be reindexed by pandas (with a warning).
        return df[df.Dim2 == age_group]