# Imports

In [None]:
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.api as sm
import statsmodels.formula.api as smf
from matplotlib.ticker import FormatStrFormatter
from scipy import stats
from statannotations.Annotator import Annotator
from statsmodels.stats.outliers_influence import variance_inflation_factor

import common_functions
import secret
import utils

# Aim of this notebook  


1. Construct generalised linear models (GLMs) to assess the effects of potential exposure sources and the effectiveness of mitigation measures on CB 138, CB 153 and CB 180 levels.



In [None]:
# Interim export of the HBM4EU e-waste template (V3, "all data") stored as parquet.
DATA_PATH = utils.Configuration.INTERIM_DATA_PATH.joinpath(
    "HBM4EU_E-waste_template_V3_all_data_INTERIM.parquet.gzip"
)

df = pd.read_parquet(DATA_PATH)
# Congener columns are labelled "PCB ..." in the file; the rest of this notebook
# refers to them as "CB ...".
df = df.rename(columns=lambda label: label.replace("PCB", "CB"))

In [None]:
# Analytes that are not modelled in this notebook.
columns_to_ignore = [
    "CB 28",
    "CB 52",
    "CB 101",
    "BDE 28",
    "BDE 47",
    "BDE 99",
    "BDE 100",
    "BDE 153",
    "BDE 154",
    "BDE 183",
    "BDE 209",
    "Dechlorane",
]

# Model-ready frame:
# - drop the unmodelled analytes;
# - make column names formula-safe (spaces -> underscores);
# - impute missing values;
# - add binned versions of Age, BMI and years_worked.
df_GLM = (
    df.loc[:, lambda frame: ~frame.columns.isin(columns_to_ignore)]
    .rename(columns=lambda column: column.replace(" ", "_"))
    .assign(
        # NOTE(review): missing distances are set to 20 km — confirm the rationale.
        how_many_km=lambda frame: frame.how_many_km.fillna(20),
        # Missing smoking status -> most frequent answer.
        cigarette_smoking=lambda frame: frame.cigarette_smoking.fillna(
            frame.cigarette_smoking.value_counts().index[0]
        ),
        years_smoked=lambda frame: frame.years_smoked.fillna(0),
        # Former smokers and non-smokers ("No") get 0 cigarettes per day; any
        # remaining gaps are zero-filled by the final .fillna(0).
        cigarettes_per_day=lambda frame: frame.cigarettes_per_day.mask(
            frame.cigarette_smoking == "Former smoker", 0
        ).mask(frame.cigarette_smoking == "No", 0),
        former_smoker_years_ago_stopped=lambda frame: frame.former_smoker_years_ago_stopped.fillna(
            0
        ),
        former_smoker_cigatette_a_day=lambda frame: frame.former_smoker_cigatette_a_day.fillna(
            0
        ),
        former_smoker_for_how_many_years=lambda frame: frame.former_smoker_for_how_many_years.fillna(
            0
        ),
        # "No" answers are recoded to 0 tonnes; other missing values are imputed below.
        Tonnes=lambda frame: frame.Tonnes.replace({"No": 0}).astype(float),
    )
    .assign(
        # NOTE(review): missing congener values are replaced by the lowest observed
        # value — presumably non-detects; confirm.
        CB_118=lambda frame: frame.CB_118.fillna(frame.CB_118.min()),
        CB_138=lambda frame: frame.CB_138.fillna(frame.CB_138.min()),
        CB_153=lambda frame: frame.CB_153.fillna(frame.CB_153.min()),
        CB_180=lambda frame: frame.CB_180.fillna(frame.CB_180.min()),
        # Remaining missing tonnages -> median tonnage of the same sub_category.
        Tonnes=lambda frame: frame["Tonnes"].fillna(
            frame.groupby("sub_category")["Tonnes"].transform("median")
        ),
        # The outer bins are open-ended (+/- inf). Otherwise values beyond the last
        # edge would become NaN and then a bogus 0 category after .fillna(0).
        Age_binned=lambda frame: pd.cut(
            frame.Age,
            bins=[17, 30, 40, 50, np.inf],
            labels=[
                "Between 18 and 30",
                "Between 30 and 40",
                "Between 40 and 50",
                "Over 50",
            ],
            right=True,
        ).astype(object),
        BMI_binned=lambda frame: pd.cut(
            frame.BMI,
            bins=[-np.inf, 18.5, 25, 30, np.inf],
            labels=["underweight", "healthy", "overweight", "obese"],
            right=True,
        ).astype(object),
        # NOTE(review): with right=True the first bin is (-1, 5], so "less than 5"
        # also contains exactly 5 years.
        years_worked_binned=lambda frame: pd.cut(
            frame.years_worked,
            bins=[-1, 5, 15, np.inf],
            labels=[
                "less than 5",
                "between 5 and 15",
                "over 15",
            ],
            right=True,
        ).astype(object),
    )
    .fillna(0)
)

In [None]:
# Cross-check the two RPE-use encodings: each (sum, bool) combination and its row count,
# instead of dumping every row.
df_GLM[["sum_RPE_use", "bool_RPE_use"]].value_counts()

In [None]:
# Variance inflation factors for every numeric column except the congener outcomes.
variables_to_test_1 = df_GLM.drop(
    columns=["CB_118", "CB_138", "CB_153", "CB_180"]
).select_dtypes("number")

# variance_inflation_factor regresses each column on the others exactly as given.
# Without an intercept the VIFs are uncentred and inflated for any variable with a
# non-zero mean. So prepend a constant and skip it (position 0) when reporting.
exog_1 = sm.add_constant(variables_to_test_1, prepend=True, has_constant="add")
vif_1 = pd.DataFrame(
    {
        "VIF": [
            variance_inflation_factor(exog_1.values, i)
            for i in range(1, exog_1.shape[1])
        ],
        "variable": variables_to_test_1.columns,
    }
)
vif_1.set_index("variable").plot.bar(legend=False, figsize=(8, 4))

In [None]:
# VIFs for the continuous and binary exposure/mitigation candidates.
# Not included here:
# - the other-PPE variables;
# - the categorical covariates (industrial_plants_in_surroundings, sub_category,
#   Sex, home_location).
variables_to_test_2 = df_GLM.loc[
    :,
    lambda frame: frame.columns.isin(
        [
            "Tonnes",
            "Age",
            "how_many_km",
            "years_smoked",
            "years_worked",
            "bool_RPE_use",
            "bool_enclosed_process",
            "bool_local_exhaust",
        ]
    ),
].astype(float)  # if any bool_* column is boolean dtype, .values would otherwise be object

# Include an intercept so the VIFs are centred. The constant sits at position 0 and
# is not reported.
exog_2 = sm.add_constant(variables_to_test_2, prepend=True, has_constant="add")
vif_2 = pd.DataFrame(
    {
        "VIF": [
            variance_inflation_factor(exog_2.values, i)
            for i in range(1, exog_2.shape[1])
        ],
        "variable": variables_to_test_2.columns,
    }
)
vif_2.set_index("variable").plot.barh(legend=False, figsize=(8, 4))

### PCB-138, PCB-153 and PCB-180

#### Identify outliers

This formula was used in an earlier version of the analysis (kept for reference):

```python
formula = "np.sqrt(CB_138) ~ Tonnes  + Age + how_many_km + years_smoked + years_worked +\
sum_RPE_use + C(bool_enclosed_process) + sum_local_exhaust +  C(industrial_plants_in_surroundings) + C(sub_category) + C(Sex) + C(home_location)"
```

In [None]:
# Screen each congener model for influential observations with Cook's distance.
# A Cook's distance above 0.5 is commonly treated as high and above 1 as extreme.
COOK_THRESHOLD = 0.5
OUTLIER_COMPOUNDS = ["CB_138", "CB_153", "CB_180"]

# One panel per congener, so the three models' distances are not overlaid on one axes.
fig, axs = plt.subplots(1, len(OUTLIER_COMPOUNDS), figsize=(12, 4), sharey=True)
for ax, compound in zip(axs, OUTLIER_COMPOUNDS):
    # NOTE(review): these exposure terms differ from the model fitted under
    # "Build GLM model":
    #   here:      C(bool_local_exhaust), C(bool_RPE_use), sum_enclosed_process
    #   the model: C(bool_RPE_use), C(bool_enclosed_process), sum_local_exhaust
    # Confirm which specification the outlier screening should use.
    formula = (
        f"np.sqrt({compound}) ~ Tonnes + Age + how_many_km + years_smoked + years_worked"
        " + C(bool_local_exhaust)"
        " + C(bool_RPE_use)"
        " + sum_enclosed_process"
        " + C(industrial_plants_in_surroundings) + C(sub_category) + C(Sex) + C(home_location)"
    )

    model = smf.glm(
        formula=formula,
        data=df_GLM,
        family=sm.families.Gamma(link=sm.families.links.Log()),
    ).fit()

    # Index the distances by the row labels the model actually used. The points read
    # off the plot are then the labels that df_GLM.drop(index=...) expects.
    cooks_d = pd.Series(
        model.get_influence().cooks_distance[0], index=model.fittedvalues.index
    )
    ax.scatter(cooks_d.index, cooks_d.to_numpy())
    ax.axhline(y=COOK_THRESHOLD, ls="dashed", color="r")
    ax.set(title=compound, xlabel="Row label", ylabel="Cook's distance")

# Row labels to exclude are kept in the project's `secret` module.
df_GLM_outlier_removed = df_GLM.drop(index=secret.SECRET.exclude_from_model_1)

#### Build GLM models (one per congener)

In [None]:
# Fit one Gamma GLM (log link) per congener on the square-root concentration, and
# collect the coefficient tables side by side.
# NOTE(review): the models are fit on df_GLM, not on df_GLM_outlier_removed from the
# Cook's-distance step — confirm which frame is intended.
GLM_COMPOUNDS = ["CB_138", "CB_153", "CB_180"]
GLM_result = []
models = []

for compound in GLM_COMPOUNDS:
    formula = (
        f"np.sqrt({compound}) ~ Tonnes + Age + how_many_km + years_smoked + years_worked"
        " + C(bool_RPE_use, Treatment(reference=False))"
        " + C(bool_enclosed_process, Treatment(reference=False))"
        " + sum_local_exhaust"
        " + C(industrial_plants_in_surroundings, Treatment(reference=False))"
        " + C(sub_category, Treatment(reference='outwith_CTR'))"
        " + C(Sex, Treatment(reference='Female'))"
        " + C(home_location, Treatment(reference='Rural'))"
    )

    model = smf.glm(
        formula=formula,
        data=df_GLM,
        family=sm.families.Gamma(link=sm.families.links.Log()),
    ).fit()

    # Use a dedicated name so the raw data frame `df` is not overwritten.
    coef_table = model.summary2().tables[1]
    GLM_result.append(coef_table)
    models.append(model)

# keys= labels each coefficient table with its congener, so the otherwise identical
# column names (Coef., Std.Err., ...) stay distinguishable.
glm_summary = pd.concat(GLM_result, axis=1, keys=GLM_COMPOUNDS)
glm_summary.to_clipboard()
glm_summary

#### Inspect models

In [None]:
# List each congener with its position in the `models` list built above.
for position, congener in enumerate(["CB_138", "CB_153", "CB_180"]):
    print(f"{position} {congener}")

In [None]:
# Diagnostic plots for each fitted congener model, written to disk by the helper.
diagnosed_compounds = ["CB_138", "CB_153", "CB_180"]
for position in range(len(diagnosed_compounds)):
    common_functions.plot_diagnostics(
        df=df_GLM,
        model=models[position],
        dependent_variable=diagnosed_compounds[position],
        save_to_disk=True,
    )

In [None]:
# Forest plot: coefficient estimates (intercept dropped) with 95% Wald confidence
# intervals, one panel per congener model.
PLOT_COMPOUNDS = ["CB_138", "CB_153", "CB_180"]

# Multiplier turning a standard error into a 95% CI half-width (~1.959964), see
# https://stats.stackexchange.com/questions/512789/converting-between-confidence-interval-and-standard-error
Z_95 = stats.norm.ppf(0.975)

# Binary terms: generate keys for both the reference=0 and reference=False spellings,
# with each level spelling patsy may emit (1, 1.0, True). The labels then apply
# whichever encoding the model cell used.
BINARY_TERM_LABELS = {
    "bool_local_exhaust": "Exhaust filter",
    "bool_RPE_use": "RPE use",
    "bool_enclosed_process": "E-waste processing involves work in enclosed environments?",
    "industrial_plants_in_surroundings": "Industrial plants nearby home?",
}
term_labels = {
    f"C({variable}, Treatment(reference={reference}))[T.{level}]": label
    for variable, label in BINARY_TERM_LABELS.items()
    for reference in ("0", "False")
    for level in ("1", "1.0", "True")
}
term_labels.update(
    {
        "C(sub_category, Treatment(reference='outwith_CTR'))[T.Batteries]": "Subcategory : Batteries",
        "C(sub_category, Treatment(reference='outwith_CTR'))[T.Brown goods]": "Subcategory : Brown goods",
        "C(sub_category, Treatment(reference='outwith_CTR'))[T.Metals and plastics]": "Subcategory : Metals and plastics",
        "C(sub_category, Treatment(reference='outwith_CTR'))[T.Miscellaneous]": "Subcategory : Miscellaneous",
        "C(sub_category, Treatment(reference='outwith_CTR'))[T.White goods]": "Subcategory : White goods",
        "C(Sex, Treatment(reference='Female'))[T.Male]": "Sex : Male",
        "C(sub_category, Treatment(reference='outwith_CTR'))[T.within_CTR]": "Subcategory : Within CTR",
        "C(home_location, Treatment(reference='Rural'))[T.Urban]": "Home dwelling : Urban",
        "Tonnes": "Amounts of e-waste processed",
        "how_many_km": "Distance of nearest industrial plant from home",
    }
)

with sns.plotting_context("paper", font_scale=1):
    fig, axs = plt.subplots(
        1, len(PLOT_COMPOUNDS), figsize=(8, 4), sharey=True, sharex=True
    )
    # Iterate over the fitted models without rebinding `models`, so the cell can be re-run.
    for ax, compound, fitted in zip(axs, PLOT_COMPOUNDS, models):
        coefs = (
            fitted.summary2()
            .tables[1]
            .iloc[1:, :]  # drop the first row (the intercept)
            .rename(index=term_labels)
            .rename(index=lambda x: x.replace("_", " ").title())
            # Restore acronyms mangled by .title().
            .rename(
                index={
                    "Rpe Use": "RPE Use",
                    "Subcategory : Within Ctr": "Subcategory : Within CTR",
                }
            )
            .sort_index()
        )
        ax.errorbar(
            y=coefs.index,
            x=coefs["Coef."],
            xerr=coefs["Std.Err."] * Z_95,
            ls="None",
            capsize=3,
            markersize=2,
            ecolor="black",
            fmt="o",
            c="black",
        )
        ax.axvline(x=0, linestyle="dashed", color="red")
        ax.set_title(compound.replace("_", " "), fontweight="bold", fontsize=12)

    fig.tight_layout()
    fig.savefig(utils.Configuration.PLOTS.joinpath("img3.png"), dpi=600)