In [3]:
import thinkbayes2

# Scratch cell: pokes at the thinkbayes2 API that the survival code relies on.
# An empty Pmf, constructed with no arguments.
pmf = thinkbayes2.Pmf()
# NOTE(review): Cdf is called here with two scalars; elsewhere in this file
# it receives sequences (values, cumulative probabilities) -- confirm this
# call is intentional and not leftover experimentation.
cdf = thinkbayes2.Cdf(.5, .2)

# Unbound reference to Pmf.Mean; the same attribute is used as the default
# `func` argument of SurvivalFunction.RemainingLifetime.
func = thinkbayes2.Pmf.Mean

In [13]:
"""This file contains code for use with "Think Stats",
by Allen B. Downey, available from greenteapress.com

Copyright 2014 Allen B. Downey
License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html
"""

from __future__ import print_function, division

import numpy as np
import pandas as pd

import thinkbayes2
import thinkplot

from collections import Counter

# File formats passed as `formats=` to every thinkplot.Save call in this file.
FORMATS = ['pdf', 'eps', 'png']


class SurvivalFunction(object):
    """Represents a survival function.

    The function is stored as two parallel sequences: `ts` (times) and
    `ss` (the survival value S(t) at each time), plus a display `label`.
    """

    def __init__(self, ts, ss, label=''):
        """Stores the sampled survival function.

        ts: sequence of times
        ss: sequence of survival values, one per element of ts
        label: string
        """
        self.ts = ts
        self.ss = ss
        self.label = label

    def __len__(self):
        """Returns the number of stored times."""
        return len(self.ts)

    def __getitem__(self, t):
        """Returns S(t); shorthand for Prob(t)."""
        return self.Prob(t)

    def Prob(self, t):
        """Returns S(t), the probability that corresponds to value t.

        Values between stored times are linearly interpolated by
        np.interp; for t below the first stored time the result is 1.0.

        t: time
        returns: float probability
        """
        return np.interp(t, self.ts, self.ss, left=1.0)

    def Probs(self, ts):
        """Gets probabilities for a sequence of values.

        ts: sequence of times
        returns: array of S(t), computed the same way as Prob
        """
        return np.interp(ts, self.ts, self.ss, left=1.0)

    def Items(self):
        """Returns a list of (t, s) pairs, in the order they are stored."""
        # Materialize the pairs: a bare zip() is a one-shot iterator on
        # Python 3, which contradicts the documented "list" contract.
        return list(zip(self.ts, self.ss))

    def Render(self):
        """Generates a sequence of points suitable for plotting.

        returns: tuple of (times, survival values)
        """
        return self.ts, self.ss

    def MakeHazardFunction(self, label=''):
        """Computes the hazard function.

        This simple version does not take into account the
        spacing between the ts.  If the ts are not equally
        spaced, it is not valid to compare the magnitude of
        the hazard function across different time steps.

        label: string

        returns: HazardFunction object
        """
        # An explicit float dtype is required: a Series built from an index
        # alone otherwise defaults to object dtype on current pandas.
        lams = pd.Series(index=self.ts, dtype=float)

        # The hazard at t is the fraction of the survival remaining at the
        # previous time that is lost by t; survival starts at 1.
        prev = 1.0
        for t, s in zip(self.ts, self.ss):
            lams[t] = (prev - s) / prev
            prev = s

        return HazardFunction(lams, label=label)

    def MakePmf(self, filler=None):
        """Makes a PMF of lifetimes.

        filler: value to replace missing values; when given, it receives
                the probability mass not covered by the stored times

        returns: Pmf
        """
        # The CDF is the complement of the survival function.
        cdf = thinkbayes2.Cdf(self.ts, 1-self.ss)
        pmf = thinkbayes2.Pmf()
        # NOTE(review): presumably Cdf.Items yields (value, probability mass)
        # pairs rather than cumulative values -- confirm in thinkbayes2.
        for val, prob in cdf.Items():
            pmf.Set(val, prob)

        # Last cumulative probability; whatever is left up to 1 belongs to
        # lifetimes beyond the last stored time.
        cutoff = cdf.ps[-1]
        if filler is not None:
            pmf[filler] = 1-cutoff

        return pmf

    def RemainingLifetime(self, filler=None, func=thinkbayes2.Pmf.Mean):
        """Computes remaining lifetime as a function of age.

        filler: passed to MakePmf; value that stands in for lifetimes
                beyond the last stored time
        func: function from conditional Pmf to expected lifetime

        returns: Series that maps from age to remaining lifetime
        """
        pmf = self.MakePmf(filler=filler)
        d = {}
        # Walk the values in increasing order, skipping the largest.  Zeroing
        # each value and renormalizing leaves a pmf of the lifetimes that
        # exceed t, so func(pmf) - t is the remaining lifetime at age t.
        for t in sorted(pmf.Values())[:-1]:
            pmf[t] = 0
            pmf.Normalize()
            d[t] = func(pmf) - t

        return pd.Series(d)


def MakeSurvivalFromSeq(values, label=''):
    """Makes a survival function based on a complete dataset.

    values: sequence of observed lifespans
    label: string

    returns: SurvivalFunction

    raises: ValueError if values is empty
    """
    counter = Counter(values)
    if not counter:
        # Same exception type the tuple unpacking below would raise,
        # but with a message that names the actual problem.
        raise ValueError("values must contain at least one lifespan")

    # Distinct lifespans in increasing order, with how often each occurs.
    ts, freqs = zip(*sorted(counter.items()))
    ts = np.asarray(ts)
    # Use the builtin float: the np.float alias no longer exists in NumPy.
    ps = np.cumsum(freqs, dtype=float)
    # Normalize the cumulative counts into an empirical CDF.
    ps /= ps[-1]
    ss = 1 - ps
    return SurvivalFunction(ts, ss, label)


def MakeSurvivalFromCdf(cdf, label=''):
    """Builds the survival function that complements a CDF.

    cdf: Cdf; its xs become the times and 1 - ps the survival values
    label: string

    returns: SurvivalFunction
    """
    return SurvivalFunction(cdf.xs, 1 - cdf.ps, label)


class HazardFunction(object):
    """Represents a hazard function.

    The function is stored in `series`, a pandas Series that maps from
    time to hazard, plus a display `label`.
    """

    def __init__(self, d, label=''):
        """Initialize the hazard function.

        d: dictionary (or anything that can initialize a series)
        label: string
        """
        self.series = pd.Series(d)
        self.label = label

    def __len__(self):
        """Returns the number of times in the hazard function."""
        # Extend relies on len(self); without this method it raises TypeError.
        return len(self.series)

    def __getitem__(self, t):
        """Returns the hazard stored for time t."""
        return self.series[t]

    def Render(self):
        """Generates a sequence of points suitable for plotting.

        returns: tuple of (times, hazard values)
        """
        return self.series.index, self.series.values

    def MakeSurvival(self, label=''):
        """Makes the survival function.

        label: string

        returns: SurvivalFunction
        """
        ts = self.series.index
        # Survival is the running product of the complements of the hazard.
        ss = (1 - self.series).cumprod()
        sf = SurvivalFunction(ts, ss, label=label)
        return sf

    def Extend(self, other):
        """Extends this hazard function by copying the tail from another.

        Only the entries of other whose time is greater than the last
        time already present are appended; an empty hazard function
        takes every entry of other with time greater than 0.

        other: HazardFunction
        """
        last_index = self.series.index[-1] if len(self) else 0
        more = other.series[other.series.index > last_index]
        self.series = pd.concat([self.series, more])

    def Truncate(self, t):
        """Truncates this hazard function at the given value of t.

        Entries with time greater than or equal to t are dropped.

        t: number
        """
        self.series = self.series[self.series.index < t]

def ConditionalSurvival(pmf, t0):
    """Computes the survival function conditioned on lasting at least t0.

    The result describes the probability that the duration exceeds
    t0+t, given that the duration is >= t0; times are measured from t0.

    pmf: Pmf of durations
    t0: minimum time

    returns: SurvivalFunction of the remaining duration
    """
    shifted = thinkbayes2.Pmf()
    for duration, prob in pmf.Items():
        # keep only the durations that reach t0
        if not duration >= t0:
            continue
        shifted.Set(duration - t0, prob)

    shifted.Normalize()
    cond_cdf = shifted.MakeCdf()
    return MakeSurvivalFromCdf(cond_cdf)


def PlotConditionalSurvival(durations):
    """Draws one conditional survival curve per starting time t0.

    The starting times are fixed at 8, 16, 24 and 32.

    durations: list of durations
    """
    pmf = thinkbayes2.Pmf(durations)

    start_times = [8, 16, 24, 32]
    thinkplot.PrePlot(len(start_times))

    for start in start_times:
        thinkplot.Plot(ConditionalSurvival(pmf, start),
                       label='t0=%d' % start)

    thinkplot.Show()


def PlotSurvival(complete):
    """Plots the survival curve with its CDF, then the hazard curve.

    Also prints the CDF and survival values at 13 and the hazard at 39.

    complete: list of complete lifetimes
    """
    thinkplot.PrePlot(3, rows=2)

    lifetime_cdf = thinkbayes2.Cdf(complete, label='cdf')
    survival = MakeSurvivalFromCdf(lifetime_cdf, label='survival')
    for curve in (lifetime_cdf, survival):
        print(curve[13])

    # first subplot: survival curve over a faint CDF
    thinkplot.Plot(survival)
    thinkplot.Cdf(lifetime_cdf, alpha=0.2)
    thinkplot.Config()

    # second subplot: hazard curve
    thinkplot.SubPlot(2)
    hazard = survival.MakeHazardFunction(label='hazard')
    print(hazard[39])
    thinkplot.Plot(hazard)
    thinkplot.Config(ylim=[0, 0.75])


def PlotHazard(complete, ongoing):
    """Plots the estimated hazard function and the survival it implies.

    A faint survival curve based on the complete lifetimes alone is
    drawn first for comparison.

    complete: list of complete lifetimes
    ongoing: list of ongoing lifetimes
    """
    # S(t) from the complete cases only
    naive_survival = MakeSurvivalFromSeq(complete)
    thinkplot.Plot(naive_survival, label='old S(t)', alpha=0.1)

    thinkplot.PrePlot(2)

    # hazard estimated from complete and ongoing cases together
    hazard = EstimateHazardFunction(complete, ongoing)
    thinkplot.Plot(hazard, label='lams(t)', alpha=0.5)

    # survival function derived from that hazard
    thinkplot.Plot(hazard.MakeSurvival(), label='S(t)')
    thinkplot.Show(xlabel='t (weeks)')


def EstimateHazardFunction(complete, ongoing, label='', verbose=False):
    """Estimates the hazard function by Kaplan-Meier.

    http://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator

    complete: list of complete lifetimes
    ongoing: list of ongoing lifetimes
    label: string
    verbose: whether to display intermediate results

    returns: HazardFunction

    raises: ValueError if either input contains NaN
    """
    if np.sum(np.isnan(complete)):
        raise ValueError("complete contains NaNs")
    if np.sum(np.isnan(ongoing)):
        raise ValueError("ongoing contains NaNs")

    hist_complete = Counter(complete)
    hist_ongoing = Counter(ongoing)

    # Every time at which a case ended or was censored, in increasing order.
    ts = list(hist_complete | hist_ongoing)
    ts.sort()

    at_risk = len(complete) + len(ongoing)

    # An explicit float dtype is required: a Series built from an index
    # alone otherwise defaults to object dtype on current pandas.
    lams = pd.Series(index=ts, dtype=float)
    for t in ts:
        ended = hist_complete[t]
        censored = hist_ongoing[t]

        # Hazard at t: fraction of the cases still at risk that end at t.
        lams[t] = ended / at_risk
        if verbose:
            print(t, at_risk, ended, censored, lams[t])
        # Ended and censored cases are both no longer at risk after t.
        at_risk -= ended + censored

    return HazardFunction(lams, label=label)


def EstimateHazardNumpy(complete, ongoing, label=''):
    """Estimates the hazard function by Kaplan-Meier.

    Just for fun, this is a version that uses NumPy to
    eliminate loops.

    complete: list of complete lifetimes
    ongoing: list of ongoing lifetimes
    label: string

    returns: HazardFunction
    """
    hist_complete = Counter(complete)
    hist_ongoing = Counter(ongoing)

    # The times must be in increasing order: the cumulative sums below are
    # only meaningful chronologically, and a set has no defined order.
    ts = sorted(set(hist_complete) | set(hist_ongoing))
    if not ts:
        # No observations: return an empty hazard function, as the
        # loop-based EstimateHazardFunction does.
        return HazardFunction({}, label=label)

    at_risk = len(complete) + len(ongoing)

    ended = np.array([hist_complete[t] for t in ts])
    ended_c = np.cumsum(ended)
    censored_c = np.cumsum([hist_ongoing[t] for t in ts])

    # Cases removed strictly before each time: the running totals shifted
    # one step to the right, with nothing removed before the first time.
    not_at_risk = np.roll(ended_c, 1) + np.roll(censored_c, 1)
    not_at_risk[0] = 0

    at_risk_array = at_risk - not_at_risk
    hs = ended / at_risk_array

    lams = dict(zip(ts, hs))

    return HazardFunction(lams, label=label)


def AddLabelsByDecade(groups, **options):
    """Plots one dummy point per group so each gets a legend entry.

    groups: GroupBy object
    options: keyword arguments forwarded to thinkplot.Plot
    """
    thinkplot.PrePlot(len(groups))
    for decade, _ in groups:
        thinkplot.Plot([15], [1], label='%d0s' % decade, **options)


def EstimateMarriageSurvivalByDecade(groups, **options):
    """Estimates and plots one survival curve per group.

    groups: GroupBy object
    options: keyword arguments forwarded to thinkplot.Plot
    """
    thinkplot.PrePlot(len(groups))
    for _, respondents in groups:
        survival = EstimateMarriageSurvival(respondents)[1]
        thinkplot.Plot(survival, **options)


def PlotPredictionsByDecade(groups, **options):
    """Plots survival curves with each group's hazard extended by the previous one.

    Every hazard function after the first is extended with the tail of
    the (already extended) hazard function of the preceding group before
    its survival curve is drawn.

    groups: GroupBy object
    options: keyword arguments forwarded to thinkplot.Plot
    """
    hazards = [EstimateMarriageSurvival(group)[0] for _, group in groups]

    thinkplot.PrePlot(len(hazards))
    previous = None
    for hazard in hazards:
        if previous is not None:
            hazard.Extend(previous)
        thinkplot.Plot(hazard.MakeSurvival(), **options)
        previous = hazard


def ResampleSurvival(resp, iters=101):
    """Resamples respondents and estimates the survival function.

    Plots the survival curve for resp, then a shaded band between the
    5th and 95th percentiles of the resampled curves, and saves the
    figure under the root name 'survival3'.

    resp: DataFrame of respondents
    iters: number of resamples
    """
    _, sf = EstimateMarriageSurvival(resp)
    thinkplot.Plot(sf)

    # Evaluate every resampled curve on the same monthly grid of ages.
    first_age, last_age = resp.agemarry.min(), resp.agemarry.max()
    ts = np.arange(first_age, last_age, 1/12.0)

    ss_seq = []
    for _ in range(iters):
        # This file imports thinkbayes2, not thinkstats2; use the same
        # module PlotResampledByDecade calls for resampling.
        sample = thinkbayes2.ResampleRowsWeighted(resp)
        _, sf = EstimateMarriageSurvival(sample)
        ss_seq.append(sf.Probs(ts))

    low, high = thinkbayes2.PercentileRows(ss_seq, [5, 95])
    thinkplot.FillBetween(ts, low, high, color='gray', label='90% CI')
    thinkplot.Save(root='survival3',
                   xlabel='age (years)',
                   ylabel='prob unmarried',
                   xlim=[12, 46],
                   ylim=[0, 1],
                   formats=FORMATS)


def EstimateMarriageSurvival(resp):
    """Estimates the hazard and survival functions for age at marriage.

    Respondents with evrmarry == 1 contribute agemarry as a complete
    lifetime; those with evrmarry == 0 contribute age as an ongoing one.

    resp: DataFrame of respondents

    returns: pair of HazardFunction, SurvivalFunction
    """
    married = resp[resp.evrmarry == 1]
    unmarried = resp[resp.evrmarry == 0]

    # NOTE: Filling missing values would be better than dropping them.
    complete = married.agemarry.dropna()
    ongoing = unmarried.age

    hazard = EstimateHazardFunction(complete, ongoing)
    return hazard, hazard.MakeSurvival()


def PlotMarriageData(resp):
    """Plots the hazard and survival functions in two stacked subplots.

    The figure is saved under the root name 'survival2'.

    resp: DataFrame of respondents

    returns: SurvivalFunction
    """
    hazard, survival = EstimateMarriageSurvival(resp)

    # top: hazard
    thinkplot.PrePlot(rows=2)
    thinkplot.Plot(hazard)
    thinkplot.Config(ylabel='hazard', legend=False)

    # bottom: survival
    thinkplot.SubPlot(2)
    thinkplot.Plot(survival)
    save_options = dict(xlabel='age (years)',
                        ylabel='prob unmarried',
                        ylim=[0, 1],
                        legend=False,
                        formats=FORMATS)
    thinkplot.Save(root='survival2', **save_options)
    return survival


def PlotPregnancyData(preg):
    """Plots survival and hazard curves based on pregnancy lengths.

    Outcomes 1, 3 and 4 are treated as complete pregnancies and
    outcome 6 as ongoing.  The figure is saved under the root name
    'survival1'.

    preg: DataFrame with columns outcome and prglngth

    returns: SurvivalFunction estimated from complete and ongoing
             pregnancies

    Outcome codes from http://www.icpsr.umich.edu/nsfg6/Controller?
    displayPage=labelDetails&fileCode=PREG&section=&subSec=8016&srtLabel=611932

    code  outcome             count
    1     LIVE BIRTH          9148
    2     INDUCED ABORTION    1862
    3     STILLBIRTH          120
    4     MISCARRIAGE         1921
    5     ECTOPIC PREGNANCY   190
    6     CURRENT PREGNANCY   352
    """
    complete = preg.query('outcome in [1, 3, 4]').prglngth
    print('Number of complete pregnancies', len(complete))

    is_ongoing = preg.outcome == 6
    ongoing = preg[is_ongoing].prglngth
    print('Number of ongoing pregnancies', len(ongoing))

    PlotSurvival(complete)
    thinkplot.Save(root='survival1',
                   xlabel='t (weeks)',
                   formats=FORMATS)

    return EstimateHazardFunction(complete, ongoing).MakeSurvival()


def PlotRemainingLifetime(sf1, sf2):
    """Plots remaining lifetimes for pregnancy and age at first marriage.

    The left panel uses the default summary of RemainingLifetime; the
    right panel uses the 50th percentile with np.inf as the filler.
    The figure is saved under the root name 'survival6'.

    sf1: SurvivalFunction for pregnancy length
    sf2: SurvivalFunction for age at first marriage
    """
    def median(pmf):
        """Returns the 50th percentile of pmf."""
        return pmf.Percentile(50)

    # left panel: pregnancy length
    thinkplot.PrePlot(cols=2)
    thinkplot.Plot(sf1.RemainingLifetime())
    thinkplot.Config(title='remaining pregnancy length',
                     xlabel='weeks',
                     ylabel='mean remaining weeks')

    # right panel: age at first marriage
    thinkplot.SubPlot(2)
    thinkplot.Plot(sf2.RemainingLifetime(filler=np.inf, func=median))
    thinkplot.Config(title='years until first marriage',
                     ylim=[0, 15],
                     xlim=[11, 31],
                     xlabel='age (years)',
                     ylabel='median remaining years')

    thinkplot.Save(root='survival6',
                   formats=FORMATS)



def PlotResampledByDecade(resps, iters=11, predict_flag=False, omit=None):
    """Plots survival curves, grouped by decade, for resampled data.

    resps: list of DataFrames
    iters: number of resamples to plot
    predict_flag: whether to also plot predictions
    omit: optional collection of decade names to leave out
    """
    for iteration in range(iters):
        resampled = pd.concat(
            [thinkbayes2.ResampleRowsWeighted(resp) for resp in resps],
            ignore_index=True)
        groups = resampled.groupby('decade')

        if omit:
            groups = [(name, group) for name, group in groups
                      if name not in omit]

        # TODO: refactor this to collect resampled estimates and
        # plot shaded areas
        if iteration == 0:
            # legend entries are only needed once
            AddLabelsByDecade(groups, alpha=0.7)

        if predict_flag:
            PlotPredictionsByDecade(groups, alpha=0.1)

        # the plain curves are drawn fainter when predictions are shown too
        survival_alpha = 0.1 if predict_flag else 0.2
        EstimateMarriageSurvivalByDecade(groups, alpha=survival_alpha)



# NOTE: The functions below are copied from marriage.py in
# the MarriageNSFG repo.

# def ReadFemResp1995():
#     """Reads respondent data from NSFG Cycle 5.

#     returns: DataFrame
#     """
#     dat_file = '1995FemRespData.dat.gz'
#     names = ['cmintvw', 'timesmar', 'cmmarrhx', 'cmbirth', 'finalwgt']
#     colspecs = [(12360-1, 12363),
#                 (4637-1, 4638),
#                 (11759-1, 11762),
#                 (14-1, 16),
#                 (12350-1, 12359)]
#     df = pd.read_fwf(dat_file, 
#                          compression='gzip', 
#                          colspecs=colspecs, 
#                          names=names)

#     df.timesmar.replace([98, 99], np.nan, inplace=True)
#     df['evrmarry'] = (df.timesmar > 0)

#     CleanFemResp(df)
#     return df


# def ReadFemResp2002():
#     """Reads respondent data from NSFG Cycle 6.

#     returns: DataFrame
#     """
#     usecols = ['caseid', 'cmmarrhx', 'cmdivorcx', 'cmbirth', 'cmintvw', 
#                'evrmarry', 'parity', 'finalwgt']
#     df = ReadFemResp(usecols=usecols)
#     df['evrmarry'] = (df.evrmarry == 1)
#     CleanFemResp(df)
#     return df


# def ReadFemResp2010():
#     """Reads respondent data from NSFG Cycle 7.

#     returns: DataFrame
#     """
#     usecols = ['caseid', 'cmmarrhx', 'cmdivorcx', 'cmbirth', 'cmintvw',
#                'evrmarry', 'parity', 'wgtq1q16']
#     df = ReadFemResp('2006_2010_FemRespSetup.dct',
#                        '2006_2010_FemResp.dat.gz',
#                         usecols=usecols)
#     df['evrmarry'] = (df.evrmarry == 1)
#     df['finalwgt'] = df.wgtq1q16
#     CleanFemResp(df)
#     return df


# def ReadFemResp2013():
#     """Reads respondent data from NSFG Cycle 8.

#     returns: DataFrame
#     """
#     usecols = ['caseid', 'cmmarrhx', 'cmdivorcx', 'cmbirth', 'cmintvw',
#                'evrmarry', 'parity', 'wgt2011_2013']
#     df = ReadFemResp('2011_2013_FemRespSetup.dct',
#                         '2011_2013_FemRespData.dat.gz',
#                         usecols=usecols)
#     df['evrmarry'] = (df.evrmarry == 1)
#     df['finalwgt'] = df.wgt2011_2013
#     CleanFemResp(df)
#     return df


# def ReadFemResp(dct_file='2002FemResp.dct',
#                 dat_file='2002FemResp.dat.gz',
#                 **options):
#     """Reads the NSFG respondent data.

#     dct_file: string file name
#     dat_file: string file name

#     returns: DataFrame
#     """
#     dct = thinkbayes2.ReadStataDct(dct_file, encoding='iso-8859-1')
#     df = dct.ReadFixedWidth(dat_file, compression='gzip', **options)
#     return df


# def CleanFemResp(resp):
#     """Cleans a respondent DataFrame.

#     resp: DataFrame of respondents

#     Adds columns: agemarry, age, decade, fives
#     """
#     resp.cmmarrhx.replace([9997, 9998, 9999], np.nan, inplace=True)

#     resp['agemarry'] = (resp.cmmarrhx - resp.cmbirth) / 12.0
#     resp['age'] = (resp.cmintvw - resp.cmbirth) / 12.0

#     month0 = pd.to_datetime('1899-12-15')
#     dates = [month0 + pd.DateOffset(months=cm) 
#              for cm in resp.cmbirth]
#     resp['year'] = (pd.DatetimeIndex(dates).year - 1900)
#     resp['decade'] = resp.year // 10
#     resp['fives'] = resp.year // 5


# def main():
#     thinkbayes2.RandomSeed(17)
    
#     preg = nsfg.ReadFemPreg()
#     sf1 = PlotPregnancyData(preg)

#     # make the plots based on Cycle 6
#     resp6 = ReadFemResp2002()

#     sf2 = PlotMarriageData(resp6)

#     ResampleSurvival(resp6)

#     PlotRemainingLifetime(sf1, sf2)

#     # read Cycles 5 and 7
#     resp5 = ReadFemResp1995()
#     resp7 = ReadFemResp2010()

#     # plot resampled survival functions by decade
#     resps = [resp5, resp6, resp7]
#     PlotResampledByDecade(resps)
#     thinkplot.Save(root='survival4',
#                    xlabel='age (years)',
#                    ylabel='prob unmarried',
#                    xlim=[13, 45],
#                    ylim=[0, 1],
#                    formats=FORMATS)

#     # plot resampled survival functions by decade, with predictions
#     PlotResampledByDecade(resps, predict_flag=True, omit=[5])
#     thinkplot.Save(root='survival5',
#                    xlabel='age (years)',
#                    ylabel='prob unmarried',
#                    xlim=[13, 45],
#                    ylim=[0, 1],
#                    formats=FORMATS)


# if __name__ == '__main__':
#     main()


In [14]:
# Demo cell: build a survival function from a small list of complete
# lifetimes and plot it.
# NOTE(review): judging by the label and axis text, the values are presumably
# months until each person found a job -- confirm with the author.
complete = [1, 2, 3, 1, 5, 4, 6, 8, 2, 4, 6, 2, 7, 4, 9, 3, 3]
# cdf = thinkbayes2.Cdf(complete, label='cdf')
sf = MakeSurvivalFromSeq(complete, label='ppls getting jobs')
# sf = SurvivalFunction(cdf, ss, label='survival')
thinkplot.Plot(sf)
thinkplot.Config(xlabel='months', ylabel='survival function')