<h1> Birth Weight Prediction </h1>

This notebook was used to obtain the coefficients of an ordinary least squares (OLS) regression of log birth weight on gestational age.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.api as sm
# Seaborn theme for every figure in this notebook.
# One set_theme call replaces sns.set(...) followed by two set_context("paper")
# calls: set_context resets font_scale to 1 (silently discarding
# font_scale=1.5), and the second call was a duplicate.
# sns.despine() only affects axes of the current figure, and none existed yet
# in this cell, so it had no effect on later plots. despine's defaults plus
# left=True remove the top, right and left spines; that intent is expressed
# through rcParams here so it applies to every figure created afterwards.
sns.set_theme(
    context="paper",
    style="whitegrid",
    palette="colorblind",
    font_scale=1.5,
    rc={
        "axes.spines.top": False,
        "axes.spines.right": False,
        "axes.spines.left": False,
    },
)

In [None]:
# Load the raw dataset; later cells reassign `df` after filtering and add
# derived columns to it.
df = pd.read_csv(filepath_or_buffer="example_dataset.csv")

In [None]:
# Distribution of raw birth weights (before any NaN/outlier filtering),
# normalised to a density.
fig, ax = plt.subplots(figsize=(10, 6), dpi=400)
sns.histplot(df["birth_weight"], bins=100, stat="density", ax=ax)
ax.set_xlabel("birth weight in grams")
ax.set_ylabel("density")
ax.set_title("Distribution of birth weights")
plt.show()

In [None]:
# Summary statistics (count, mean, std, quartiles) of raw birth weight.
birth_weight_summary = df["birth_weight"].describe()
print(birth_weight_summary)


In [None]:
# Column dtypes and non-null counts; buf=None (the default) writes to stdout.
df.info(buf=None)

In [None]:
# Drop rows missing either gestational age or birth weight; both are needed
# for the log transform and the regression further down.
df = df.dropna(subset=["gestational_age", "birth_weight"])

In [None]:
# Natural log of birth weight; this column is the regression target later on.
# np.log maps 0 to -inf and negative values to NaN, which would silently
# corrupt the fit, so fail loudly if any non-positive weight is present.
assert (df["birth_weight"] > 0).all(), "birth_weight must be positive to take its log"
df["ln_birth_weight"] = np.log(df["birth_weight"])


In [None]:
# Remove rows that combine a high gestational age with a very low birth
# weight (presumably data errors — TODO confirm). A row is dropped only when
# BOTH conditions hold; exp(7.2) is roughly 1339 g.
# NOTE(review): the original comment said "> 260" while the code used 250;
# the code's value is kept here — confirm which threshold was intended.
GA_OUTLIER_MIN = 250
LN_WEIGHT_OUTLIER_MAX = 7.2
is_implausible = (df["gestational_age"] > GA_OUTLIER_MIN) & (
    df["ln_birth_weight"] < LN_WEIGHT_OUTLIER_MAX
)
df_filtered = df[~is_implausible]

In [None]:
# Scatter of birth weight against gestational age on df_filtered (rows removed
# by the outlier filter are excluded). No regression line is drawn here.
# NOTE(review): the outlier threshold of 250 on gestational_age only makes
# sense in days (250 weeks is impossible), so the axis is labelled in days —
# confirm against the data dictionary.
fig, ax = plt.subplots(figsize=(10, 6), dpi=400)
sns.scatterplot(data=df_filtered, x="gestational_age", y="birth_weight", ax=ax)
ax.set_xlabel("gestational age in days")
ax.set_ylabel("birth weight in grams")
ax.set_title("Birth weight by gestational age")
plt.show()


In [None]:
# Scatter of log birth weight (the regression target) against gestational age
# on df_filtered.
# NOTE(review): gestational_age is labelled in days because the outlier
# threshold of 250 is only plausible in days — confirm against the data.
fig, ax = plt.subplots(figsize=(10, 6), dpi=400)
sns.scatterplot(data=df_filtered, x="gestational_age", y="ln_birth_weight", ax=ax)
ax.set_xlabel("gestational age in days")
ax.set_ylabel("ln birth weight (ln grams)")
ax.set_title("Log birth weight by gestational age")
plt.show()



In [None]:
# Mean birth weight (grams) per c_section group, on the outlier-filtered data.
# NOTE(review): the original comment said "by gesl" (sex), but the code groups
# by c_section — confirm which grouping was intended.
mean_weight_by_c_section = df_filtered.groupby("c_section")["birth_weight"].mean()
print(mean_weight_by_c_section)

In [None]:
# Binary indicator: 1 where sex == 1, otherwise 0 (missing sex also yields 0,
# since NaN == 1 is False).
# NOTE(review): assumes sex code 1 means male, as the original comment implies.
# The column is added to `df` after `df_filtered` was created, so df_filtered
# does not carry it.
df["is_male"] = (df["sex"] == 1).astype("int64")

In [None]:
# OLS regression of log birth weight on gestational age, with an intercept.
# Fit on df_filtered so the records removed by the outlier filter (and absent
# from the scatter plots) do not influence the coefficients; the original fit
# used the unfiltered `df`.
X = sm.add_constant(df_filtered[["gestational_age"]])
y = df_filtered["ln_birth_weight"]
model = sm.OLS(y, X)
results = model.fit()
# Last expression: Jupyter renders the summary tables via their HTML repr.
results.summary()




In [None]:
# Goodness-of-fit numbers for the fitted model.
# results.mse_resid is SSR / df_resid, i.e. it uses the model's residual
# degrees of freedom (n - number of parameters). Series.var() uses ddof=1 and
# subtracts the sample mean, so it is not the regression MSE. The residual
# std reported here is the square root of that MSE.
residuals = results.resid
mse = results.mse_resid
residual_std = np.sqrt(mse)
r2 = results.rsquared
print(f"Residual std: {residual_std}")
print(f"Residual mse: {mse}")
print(f"R2: {r2}")




In [None]:
# Histogram of OLS residuals (in ln birth weight units) to inspect their shape.
fig, ax = plt.subplots(figsize=(10, 6), dpi=400)
sns.histplot(residuals, bins=20, stat="density", ax=ax)
ax.set_xlabel("residual (ln birth weight)")
ax.set_ylabel("density")
ax.set_title("Distribution of OLS residuals")
plt.show()