In [1]:
import sys
import pandas as pd
import numpy as np
import numpy.random as rnd
import matplotlib.pyplot as plt
import altair as alt
import scipy.optimize as opt
from scipy.special import logit, expit, logsumexp

# Load raw data

In [2]:
# Directory holding the three input CSV files; edit this one line to relocate the data.
DATA_DIR = '/home/alex/src/martin/data'

temp_df = pd.read_csv(f'{DATA_DIR}/temp_data.csv')
redd_df = pd.read_csv(f'{DATA_DIR}/redd_data.csv')
surv_df = pd.read_csv(f'{DATA_DIR}/survival_data.csv')

# From here on 'survival' is kept on the logit scale; sigma_surv is the sample
# standard deviation (ddof=1) of those logit-transformed values.
surv_df['survival'] = logit(surv_df['survival'])
sigma_surv = surv_df['survival'].std(ddof=1)

# Parse the dates with pd.to_datetime and snap them to midnight. The earlier
# astype('datetime64[D]') asked for a day-resolution dtype, which newer pandas
# releases do not support.
temp_df['day'] = pd.to_datetime(temp_df['day']).dt.normalize()
redd_df['day'] = pd.to_datetime(redd_df['day']).dt.normalize()

# One row per temperature record; records without a matching redd row get 0 redds.
trdf = temp_df.merge(redd_df, on=['location', 'day'], how='left')
trdf['redds'] = trdf['redds'].fillna(0).astype(np.float64)
trdf = trdf.sort_values(by=['location', 'day'])
# Drop rows without a temperature reading; the explicit copy lets the 'year'
# column be added without a SettingWithCopyWarning.
trdf = trdf.dropna(subset=['temperature']).copy()
trdf['year'] = trdf['day'].dt.year

def normalize_redds(df):
    """Return a copy of ``df`` whose 'redds' column holds log-proportions.

    Every value is replaced by ``log(redds) - log(total)``, where ``total`` is
    the sum of the 'redds' column of ``df``. Rows with zero redds become
    ``-inf``; when the total itself is zero every row becomes NaN.

    The input frame is left untouched, which keeps the function safe to use
    with ``groupby(...).apply``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a numeric 'redds' column.

    Returns
    -------
    pandas.DataFrame
        Copy of ``df`` with 'redds' replaced by its log-proportion.
    """
    out = df.copy()
    # log(0) is an expected input here, so silence the divide-by-zero warning
    # instead of flooding the notebook output with it.
    with np.errstate(divide='ignore'):
        out['redds'] = np.log(out['redds']) - np.log(out['redds'].sum())
    return out

# Turn each year's redd counts into log-proportions of that year's total.
trdf = trdf.groupby(['year'], as_index=False).apply(normalize_redds)
# Keep only the rows whose normalised value is defined (not NaN).
trdf = trdf.loc[trdf['redds'].notna()]

  result = getattr(ufunc, method)(*inputs, **kwargs)
  df['redds'] = np.log(df['redds']) - np.log(df['redds'].sum())


# Compute AUT and hatching day

In [3]:
# Coefficients of the linear per-row maturation increment computed below as
# `beta * temperature + alpha`: alpha is the intercept, beta the temperature slope.
# NOTE(review): source and units of these two values are not stated in this notebook — confirm.
alpha = 0.00056
beta = 0.001044

def compute_incubation(df):
    """Return a copy of ``df`` with an integer 'incubation' column.

    For row ``i`` the incubation is the number of rows one has to move forward
    until the cumulative 'mat' value has grown by at least 1, i.e. the smallest
    ``j >= 1`` with ``mat[i + j] >= mat[i] + 1``. Rows for which that level is
    never reached inside ``df`` get the sentinel ``-1``.

    The sentinel replaces the earlier NaN -> int64 cast, whose result is
    undefined in numpy and platform dependent; callers that test
    ``incubation >= 0`` keep working.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a 'mat' column holding cumulative, non-decreasing values
        (``np.searchsorted`` requires sorted input).

    Returns
    -------
    pandas.DataFrame
        Copy of ``df`` with the extra int64 column 'incubation'.
    """
    mat = df['mat'].to_numpy()
    n_rows = len(mat)

    # Row index at which the cumulative value first reaches mat[i] + 1. With a
    # non-decreasing ``mat`` that index always lies after row i, so searching
    # the whole array is equivalent to searching only the tail mat[i + 1:].
    hatch_row = np.searchsorted(mat, mat + 1, side='left')
    reached = hatch_row < n_rows

    incubation = np.full(n_rows, -1, dtype=np.int64)
    incubation[reached] = hatch_row[reached] - np.flatnonzero(reached)

    out = df.copy()
    out['incubation'] = incubation
    return out

# Per-row maturation increment, linear in temperature.
trdf['mat'] = alpha + beta * trdf['temperature']

# Replace the increments by their running total within each location.
trdf_grp = trdf.groupby('location', as_index=False)
trdf['mat'] = trdf_grp['mat'].transform('cumsum')

# Re-group (the 'mat' column changed) and attach the incubation length per row.
trdf_grp = trdf.groupby('location', as_index=False)
trdf = trdf_grp.apply(compute_incubation)


def compute_hazard_at_hatching(beta_t, t_crit, data=None):
    """Accumulate the temperature hazard over each row's incubation window.

    Within every location the running hazard is
    ``beta_t * cumsum(max(temperature - t_crit, 0))``. For a row ``i`` with
    ``incubation = j >= 0`` the reported hazard is the difference of that
    running total between row ``i + j`` and row ``i``. Rows with a negative
    incubation are dropped.

    Compared with the earlier version this one does not mutate the frames it
    is given and never writes into an array obtained from ``.values`` (which
    is read-only under pandas Copy-on-Write). It also iterates over the groups
    explicitly instead of relying on ``groupby.apply`` handing the grouping
    column to the callback.

    Parameters
    ----------
    beta_t : float
        Multiplier applied to the accumulated temperature excess.
    t_crit : float
        Temperature above which hazard accumulates.
    data : pandas.DataFrame, optional
        Frame with columns 'year', 'day', 'location', 'temperature',
        'incubation' and 'redds', ordered in time within each location.
        Defaults to the notebook-level ``trdf``.

    Returns
    -------
    pandas.DataFrame
        Columns ['year', 'day', 'location', 'hazard', 'redds'], one row per
        input row with a non-negative incubation, grouped by location.
    """
    if data is None:
        data = trdf

    def hazard_aux(df):
        # Running hazard for one location, as a plain array.
        cum_hazard = (beta_t * np.fmax(df['temperature'] - t_crit, 0).cumsum()).to_numpy()
        incubation = df['incubation'].to_numpy()
        hatched = incubation >= 0
        start = np.flatnonzero(hatched)
        end = start + incubation[hatched].astype(np.int64)
        out = df.loc[hatched].copy()
        out['hazard'] = cum_hazard[end] - cum_hazard[start]
        return out

    pieces = [hazard_aux(group) for _, group in data.groupby('location')]
    if not pieces:
        # No groups at all (empty input): still return a correctly shaped frame.
        pieces = [hazard_aux(data)]
    res = pd.concat(pieces)
    return res[['year', 'day', 'location', 'hazard', 'redds']]



def compute_annual_hazard(beta_t, t_crit, base_surv):
    """Aggregate the per-row hazards into one survival value per year.

    Despite its name the function returns survival, not hazard: for each year
    it evaluates ``exp(base_surv + logsumexp(redds - hazard))``, where 'redds'
    and 'hazard' come from ``compute_hazard_at_hatching(beta_t, t_crit)``.

    Parameters
    ----------
    beta_t, t_crit : float
        Passed straight through to ``compute_hazard_at_hatching``.
    base_surv : float
        Additive offset applied before exponentiating (i.e. on the log scale).

    Returns
    -------
    pandas.DataFrame
        Columns 'year' and 'surv', one row per year.
    """
    per_row = compute_hazard_at_hatching(beta_t, t_crit)
    per_row['surv'] = per_row['redds'] - per_row['hazard']

    yearly = (
        per_row[['year', 'surv']]
        .groupby(['year'], as_index=False)
        .agg(logsumexp)
    )
    yearly['surv'] = np.exp(yearly['surv'] + base_surv)
    return yearly


def objective(x, srv):
    """Sum of squared logit-scale residuals between model and observations.

    Parameters
    ----------
    x : sequence of three floats
        ``(beta_t, t_crit, base_surv)``, forwarded to ``compute_annual_hazard``.
    srv : pandas.DataFrame
        Observations with a 'year' column and a 'survival' column that is
        compared directly against ``logit`` of the modelled survival.

    Returns
    -------
    float
        Sum over the years present in both frames of
        ``(logit(surv) - survival) ** 2``.
    """
    beta_t, t_crit, base_surv = x
    modelled = compute_annual_hazard(beta_t, t_crit, base_surv)
    # Inner join: only years available on both sides contribute.
    paired = srv.merge(modelled, on=['year'], how='inner')
    resid = logit(paired['surv'].values) - paired['survival'].values
    return (resid * resid).sum()

def resample_and_fit(x0, num=100, seed=None):
    """Refit the model ``num`` times on randomly perturbed survival data.

    Each replicate adds independent normal noise with standard deviation
    ``sigma_surv`` to ``surv_df['survival']`` and minimises ``objective`` with
    Nelder-Mead starting from ``x0``. Replicates whose optimisation does not
    report success are silently left out, so the result can have fewer than
    ``num`` rows.

    Parameters
    ----------
    x0 : sequence of three floats
        Starting point ``(beta_t, t_crit, base_surv)``.
    num : int, default 100
        Number of replicates to attempt.
    seed : int, optional
        Seed for a private ``RandomState`` so a run can be reproduced. With
        the default ``None`` the noise is drawn from numpy's global generator,
        exactly as before.

    Returns
    -------
    pandas.DataFrame
        One row per successful fit, columns 'beta_t', 't_crit', 'base_surv'.
    """
    # Both the numpy.random module and a RandomState instance expose .normal().
    noise_source = rnd if seed is None else rnd.RandomState(seed)

    fitted = []
    for _ in range(num):
        srv = surv_df.copy()
        srv['survival'] = srv['survival'] + noise_source.normal(0, sigma_surv, size=len(srv))
        fit = opt.minimize(objective, x0, args=(srv,), method='Nelder-Mead')
        if fit.success:
            fitted.append(fit.x)
    return pd.DataFrame(columns=['beta_t', 't_crit', 'base_surv'], data=fitted)


In [11]:
# 1000 resampled fits started from beta_t=0.024, t_crit=12.0, base_surv=log(0.366).
original = resample_and_fit(x0=[0.024, 12.0, np.log(0.366)], num=1000)

In [12]:
# Same procedure from a different starting point: beta_t=1.0, t_crit=16.0, base_surv=log(0.2).
bogus = resample_and_fit(x0=[1.0, 16.0, np.log(0.2)], num=1000)

# Plot results

In [13]:
def scatter_plot_with_marginals(df, xcol, ycol, nbins=25):
    """Scatter plot of ``ycol`` against ``xcol`` with marginal histograms.

    The scatter uses a log-scaled y axis. A histogram of ``xcol`` sits above
    it and a histogram of ``log10(ycol)`` to its right.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to plot. ``np.log10`` is applied to the extremes of ``ycol``.
    xcol, ycol : str
        Column names for the x and y axes. ``ycol`` is interpolated into a
        Vega expression as ``datum.<ycol>``.
    nbins : int, default 25
        ``maxbins`` handed to both histograms.

    Returns
    -------
    altair chart
        Vertical concatenation of the top histogram with the horizontally
        concatenated scatter and right-hand histogram.
    """
    x_lo, x_hi = df[xcol].min(), df[xcol].max()
    y_lo, y_hi = df[ycol].min(), df[ycol].max()

    # Axis domains are pinned to the data range; the right-hand histogram
    # works on log10(y), so it gets a linear scale over the log10 range.
    xscale = alt.Scale(domain=(x_lo, x_hi), nice=False)
    yscale = alt.Scale(domain=(y_lo, y_hi), nice=False, type='log')
    log_yscale = alt.Scale(domain=(np.log10(y_lo), np.log10(y_hi)), nice=False)

    hist_style = dict(opacity=.5, interpolate='step')
    chart = alt.Chart(df)

    scatter = chart.mark_circle().encode(
        alt.X(f'{xcol}:Q', scale=xscale),
        alt.Y(f'{ycol}:Q', scale=yscale),
    )

    # Marginal of x: the bin extent is taken from the scatter's x domain and
    # the same scale object is passed along as well.
    x_binned = alt.X(
        f'{xcol}:Q',
        bin=alt.Bin(maxbins=nbins, extent=xscale.domain),
        stack=None,
        scale=xscale,
        axis=None,
        title='',
    )
    x_counts = alt.Y('count()', stack=None, title='')
    x_marginal = chart.mark_area(**hist_style).encode(x_binned, x_counts).properties(height=60)

    # Marginal of y: binned on log10(y), computed inside the chart.
    y_binned = alt.Y(
        'logy:Q',
        bin=alt.Bin(maxbins=nbins, extent=log_yscale.domain),
        stack=None,
        scale=log_yscale,
        axis=None,
        title='',
    )
    y_counts = alt.X('count()', stack=None, title='')
    y_marginal = (
        chart
        .transform_calculate(logy=f'log(datum.{ycol})/log(10)')
        .mark_bar(**hist_style)
        .encode(y_binned, y_counts)
        .properties(width=60)
    )

    return x_marginal & (scatter | y_marginal)

def test(n):
    pass

In [14]:
# Fitted beta_t against t_crit for the fits stored in `original`.
scatter_plot_with_marginals(original, xcol='t_crit', ycol='beta_t', nbins=50)

In [15]:
# Fitted beta_t against t_crit for the fits stored in `bogus`.
scatter_plot_with_marginals(bogus, xcol='t_crit', ycol='beta_t', nbins=50)

In [16]:
# base_surv is fitted as an additive term inside an exponential, so
# exponentiate it before summarising; `original` itself is left unchanged.
orig = original.assign(base_surv=np.exp(original['base_surv']))
orig.describe(percentiles=[0.025, 0.25, 0.5, 0.75, 0.975])

Unnamed: 0,beta_t,t_crit,base_surv
count,984.0,984.0,984.0
mean,0.144812,11.954456,0.328521
std,0.996184,0.899302,0.184202
min,0.00451,9.393987,0.181022
2.5%,0.007635,9.903933,0.211992
25%,0.014311,11.492485,0.249135
50%,0.023829,12.161144,0.277306
75%,0.038996,12.410977,0.316157
97.5%,1.068628,13.5735,0.953274
max,25.580107,14.243848,1.631903


In [17]:
# base_surv is fitted as an additive term inside an exponential, so
# exponentiate it before summarising; `bogus` itself is left unchanged.
bog = bogus.assign(base_surv=np.exp(bogus['base_surv']))
bog.describe(percentiles=[0.025, 0.25, 0.5, 0.75, 0.975])

Unnamed: 0,beta_t,t_crit,base_surv
count,991.0,991.0,991.0
mean,13.976089,13.813046,0.253501
std,170.019929,0.825059,0.060592
min,0.010891,9.394073,0.166118
2.5%,0.221472,13.139371,0.193564
25%,1.211327,13.5735,0.230684
50%,1.654055,13.688,0.249895
75%,3.086347,13.716,0.270585
97.5%,9.964467,17.151757,0.314324
max,3020.394792,18.512515,1.82093


In [18]:
# Write both tables of fitted parameters to CSV. The paths are relative, so
# they resolve against the notebook's current working directory.
# NOTE(review): nothing here creates the 'data' directory — presumably it already exists; confirm.
original.to_csv('data/original.csv')
bogus.to_csv('data/bogus.csv')