In [1]:
import pandas as pd
import numpy as np

In [6]:

# Toy panel for a single person (pid 111) covering syear 5 through 14,
# with several missing health_state readings to exercise the gap filling.
data = {
    "pid": [111] * 10,
    "syear": list(range(5, 15)),
    "health_state": [1, np.nan, np.nan, np.nan, 1, 0, np.nan, np.nan, 0, np.nan],
}

# Build the frame and key it by (pid, syear) in one expression
df = pd.DataFrame(data).set_index(["pid", "syear"])

# Define a function to fill gaps with conditions
def fill_health_gaps(group):
    """Fill gaps in ``health_state`` only where both neighbours agree.

    A missing value is filled when the nearest observed value before it and
    the nearest observed value after it are equal. Gaps bounded by differing
    values stay NaN, and so do leading/trailing gaps, because one side has no
    observed value to compare against. Rows are processed in their existing
    order.

    Parameters
    ----------
    group : pandas.DataFrame
        Frame containing a ``health_state`` column (typically the rows of one
        group handed over by ``groupby(...).apply``).

    Returns
    -------
    pandas.DataFrame
        A copy of ``group`` with the qualifying gaps filled. The input frame
        is left untouched.
    """
    health = group["health_state"]
    ffilled = health.ffill()
    bfilled = health.bfill()

    # NaN == x is always False, so a leading or trailing gap (where one of the
    # two fills is still NaN) can never count as "agreeing".
    fill_here = health.isna() & (ffilled == bfilled)

    # groupby.apply does not support functions that mutate the object they are
    # given, so the result is written into a copy instead of into `group`.
    filled = group.copy()
    filled["health_state"] = health.mask(fill_here, ffilled)
    return filled

# Apply the gap filling per person. group_keys=False keeps the original
# (pid, syear) index, so there is no duplicated pid level to drop afterwards.
# Whether apply prepends that level by default differs between pandas
# versions, which made a positional droplevel(1) fragile.
df = df.groupby("pid", group_keys=False).apply(fill_health_gaps)
print("Obs. after filling health gaps:", len(df))

# Find the last known health state for each individual (intended as the state
# before the death year). The "last" aggregation skips NaN, so a trailing gap
# does not hide the last observed value.
df["last_known_health_state"] = df.groupby("pid")["health_state"].transform("last")

display(df)


Obs. after filling health gaps: 10


Unnamed: 0_level_0,Unnamed: 1_level_0,health_state,last_known_health_state
pid,syear,Unnamed: 2_level_1,Unnamed: 3_level_1
111,5,1.0,0.0
111,6,1.0,0.0
111,7,1.0,0.0
111,8,1.0,0.0
111,9,1.0,0.0
111,10,0.0,0.0
111,11,0.0,0.0
111,12,0.0,0.0
111,13,0.0,0.0
111,14,,0.0
