# To Bleach or Not to Bleach?

## Quickstart Notebook

**Goal:** identify drivers of coral bleaching and predict the percentage of coral that is bleached (`Percent_Bleached`).  
**Dataset:** a CSV file of coral samples, grouped at site level (a "location", i.e. a coral reef) by `Site_ID`.

### Cleaning

*Oh no, the data seems to be corrupted!* It looks like the people from the fossil fuel industry have been tampering with the data. There are some missing values, some outliers, and invalid/duplicate rows. **Start by cleaning it up!**


### Imports

In [1]:
import pandas as pd
from typing import Iterable, Dict, Tuple, Literal
from sklearn.impute import KNNImputer
from sklearn.preprocessing import StandardScaler

### Reusable functions

In [2]:
Strategy = Literal["impute", "drop"]

def iqr_mask(df: pd.DataFrame, columns: Iterable[str], k: float = 1.5) -> Dict[str, pd.Series]:
    """
    Return a boolean mask per selected column: True where value is an IQR outlier.
    k=1.5 -> Tukey 'mild' outliers; use 3.0 for 'extreme'.
    """
    masks: Dict[str, pd.Series] = {}
    for col in columns:
        s = pd.to_numeric(df[col], errors="coerce")
        q1 = s.quantile(0.25)
        q3 = s.quantile(0.75)
        iqr = q3 - q1
        if pd.isna(iqr) or iqr == 0:
            masks[col] = pd.Series(False, index=df.index)
        else:
            lower = q1 - k * iqr
            upper = q3 + k * iqr
            masks[col] = (s < lower) | (s > upper)
    return masks

def clean_outliers(
    df: pd.DataFrame,
    outlier_columns: Iterable[str],
    *,
    k_iqr: float = 1.5,
    strategy: Strategy = "impute",
    n_neighbors: int = 5,
    weights: str = "distance",
    output_csv: str | None = None,
) -> Tuple[pd.DataFrame, pd.Series]:
    """
    Pipeline:
      1) Mark IQR outliers in `outlier_columns` -> NaN
      2) strategy == 'impute': KNN-impute numeric columns (scaled)
         strategy == 'drop'  : drop rows having NaN in any of `outlier_columns`
      3) Drop any remaining NaNs (whole-row drop)
    Returns: (clean_df, outlier_row_mask_before_cleaning)
    """
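    df_work = df.copy()
    # Materialise once: the iterable is consumed twice below (masking and dropna)
    outlier_columns = list(outlier_columns)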

    # 1) Detect outliers & set to NaN (only in requested columns)
    masks = iqr_mask(df_work, outlier_columns, k=k_iqr)
    mask_df = pd.DataFrame(masks, index=df_work.index) if masks else pd.DataFrame(index=df_work.index)
    any_outlier = mask_df.any(axis=1) if not mask_df.empty else pd.Series(False, index=df_work.index)

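    for col, m in masks.items():
        # Coerce to numeric so the column stays float and outliers become NaN (not pd.NA)
        df_work[col] = pd.to_numeric(df_work[col], errors="coerce").mask(m)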

    # 2) Clean by strategy
    if strategy == "impute":
        num_cols = df_work.select_dtypes(include="number").columns.tolist()
        if num_cols:
            scaler = StandardScaler()
            imputer = KNNImputer(n_neighbors=n_neighbors, weights=weights)

            X_num = df_work[num_cols].astype(float).values
            X_scaled = scaler.fit_transform(X_num)
            X_imputed_scaled = imputer.fit_transform(X_scaled)
            X_imputed = scaler.inverse_transform(X_imputed_scaled)

            df_work[num_cols] = X_imputed
    elif strategy == "drop":
        df_work = df_work.dropna(subset=list(outlier_columns), how="any")
    else:
        raise ValueError("strategy must be 'impute' or 'drop'")

    # 3) Drop any remaining NaNs anywhere (whole row)
    df_clean = df_work.dropna(axis=0, how="any").reset_index(drop=True)

    if output_csv:
        df_clean.to_csv(output_csv, index=False)

    return df_clean, any_outlier
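
To make the Tukey rule used by `iqr_mask` concrete, here is a minimal worked sketch on a toy series. The values are invented for illustration and are not taken from the dataset. It computes the same quartiles and fences as the function above.

```python
import pandas as pd

# Toy temperature readings in Kelvin (hypothetical); 350.0 is planted as an implausible value.
s = pd.Series([299.0, 300.0, 300.5, 301.0, 301.5, 302.0, 350.0])

k = 1.5                      # Tukey multiplier for "mild" outliers
q1 = s.quantile(0.25)        # first quartile (linear interpolation, pandas default)
q3 = s.quantile(0.75)        # third quartile
iqr = q3 - q1                # interquartile range
lower = q1 - k * iqr         # lower fence
upper = q3 + k * iqr         # upper fence

# A value is flagged only if it lies strictly outside the fences.
is_outlier = (s < lower) | (s > upper)

print(f"Q1={q1}, Q3={q3}, IQR={iqr}, fences=({lower}, {upper})")
print(s[is_outlier])         # only the planted 350.0 is outside the fences
```

Raising `k` to 3.0 widens the fences and flags only extreme values, which matches the note in the docstring.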


### Data cleaning

Reading the CSV

In [3]:
filepath_train = r"../data/coral_students.csv"
df = pd.read_csv(filepath_train)

df.head()

|  | Sample_ID | Site_ID | ClimSST | Temperature_Kelvin | Temperature_Mean | Temperature_Minimum | Temperature_Maximum | Temperature_Kelvin_Standard_Deviation | Windspeed | SSTA | ... | TSA_DHW | TSA_DHW_Standard_Deviation | TSA_DHWMax | TSA_DHWMean | Depth_m | Distance_to_Shore | Exposure | Turbidity | Cyclone_Frequency | Percent_Bleached |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 10274495.0 | 12082.0 | 301.65 | 303.5 | 299.79 | 293.35 | 305.54 | 2.52 | 2.0 | 0.49 | ... | 1.27 | 0.74 | 6.05 | 0.22 | 8.3 | 8311.0 | exposed | 0.0586 | 56.583448 | 4.76 |
| 1 | 10274496.0 | 12083.0 | 299.31 | 300.84 | 299.75 | 293.68 | 305.44 | 2.54 | 6.0 | -0.42 | ... | 1.2 | 0.93 | 10.39 | 0.27 | 14.9 | 10747.0 | exposed | 0.0543 | 52.842523 | 21.88 |
| 2 | 10274497.0 | 12084.0 | 300.56 | 302.65 | 299.81 | 293.35 | 305.47 | 2.5 | 5.0 | 0.36 | ... | 2.71 | 0.83 | 7.18 | 0.23 | 10.7 | 9396.0 | exposed | 0.0571 | 56.583448 | 19.66 |
| 3 | 10274498.0 | 12085.0 | 299.75 | 302.43 | 299.81 | 293.35 | 305.47 | 2.5 | 7.0 | 0.54 | ... | 3.6 | 0.83 | 7.18 | 0.23 | 7.6 | 9408.0 | exposed | 0.0571 | 56.583448 | 28.03 |
| 4 | 10274499.0 | 12086.0 | 297.65 | 295.69 | 299.81 | 293.35 | 305.47 | 2.5 | 7.0 | -0.91 | ... | 0.0 | 0.83 | 7.18 | 0.23 | 10.0 | 9362.0 | exposed | 0.0571 | 56.583448 | 2.75 |


Checking the structure of the data

In [4]:
print(df.shape)
df.isna().sum()

(4446, 41)


Sample_ID                                  1
Site_ID                                    1
ClimSST                                    1
Temperature_Kelvin                         1
Temperature_Mean                           1
Temperature_Minimum                        1
Temperature_Maximum                        1
Temperature_Kelvin_Standard_Deviation      1
Windspeed                                  1
SSTA                                       1
SSTA_Standard_Deviation                    1
SSTA_Mean                                  1
SSTA_Minimum                               1
SSTA_Maximum                               1
SSTA_Frequency                             1
SSTA_Frequency_Standard_Deviation          1
SSTA_FrequencyMax                          1
SSTA_FrequencyMean                         1
SSTA_DHW                                   1
SSTA_DHW_Standard_Deviation                1
SSTA_DHWMax                                1
SSTA_DHWMean                               1
TSA       
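
When many columns report the same small number of missing values, one possible explanation is a single row that is empty across the board. The sketch below shows one way to check this by counting missing cells per row. The frame `toy` and its values are hypothetical stand-ins, not rows from the dataset.

```python
import numpy as np
import pandas as pd

# Hypothetical miniature frame: row 2 is entirely empty, row 1 misses one value.
toy = pd.DataFrame({
    "Sample_ID": [1.0, 2.0, np.nan, 4.0],
    "Windspeed": [2.0, 6.0, np.nan, 7.0],
    "Depth_m": [8.3, np.nan, np.nan, 10.0],
})

# Number of missing cells in each row.
missing_per_row = toy.isna().sum(axis=1)

# Rows in which every column is missing.
fully_empty = toy.isna().all(axis=1)

print(missing_per_row.value_counts().sort_index())
print("rows with every column missing:", int(fully_empty.sum()))
```

Running the same two lines on `df` would show whether the per-column counts above come from one row or from several.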

Searching for duplicate rows

In [5]:
uniqueId = ["Sample_ID"]

# 1) All rows that share a Sample_ID with at least one other row
#    (duplicates by key, not necessarily identical in every column)
dups = df[df.duplicated(subset=uniqueId, keep=False)]

# 2) Preview of the de-duplicated frame: first occurrence of each Sample_ID is kept
df_no_dups = df.drop_duplicates(subset=uniqueId, keep="first")

# 3) Flag repeats in a separate Series so no helper column leaks into `df`
is_dup = df.duplicated(subset=uniqueId, keep="first")

# 4) Count how many rows each duplicated Sample_ID has
dup_counts = (
    df.groupby(uniqueId, dropna=False)
    .size()
    .reset_index(name="count")
    .query("count > 1")
)
print(dup_counts)

       Sample_ID  count
125   10274739.0      2
311   10275110.0      2
376   10275267.0      2
918   10276398.0      2
1236  10290633.0      2
1267  10322434.0      2
1424  10323868.0      2
2231  10326707.0      2
2246  10326723.0      2
2263  10326760.0      2
2267  10326765.0      2
2310  10326841.0      2
2318  10326849.0      2
2570  10327590.0      2
2672  10327888.0      2
2749  10328138.0      2
2830  10328411.0      2
2838  10328421.0      2
2878  10328513.0      2
3206  10329203.0      2
3365  10329491.0      2
3552  10329892.0      2
3629  10329992.0      2
3634  10330006.0      2
3856  10330530.0      2
3866  10330547.0      4
3878  10330569.0      2
3953  10330776.0      4
3994  10330878.0      4
4227  10331344.0      2
4368  10331601.0      2
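
Keeping the first occurrence is only safe if the repeated rows are true copies. As a hedged sketch, the snippet below separates duplicated keys whose rows are identical from keys whose rows disagree. The frame `toy` is invented for illustration.

```python
import pandas as pd

# Hypothetical frame: Sample_ID 1 is an exact copy, Sample_ID 2 has conflicting values.
toy = pd.DataFrame({
    "Sample_ID": [1, 1, 2, 2, 3],
    "Windspeed": [2.0, 2.0, 5.0, 7.0, 6.0],
})

# All rows whose Sample_ID appears more than once.
dup_rows = toy[toy.duplicated(subset=["Sample_ID"], keep=False)]

# After removing exact row copies, a key that still has more than one row is conflicting.
distinct_per_key = dup_rows.drop_duplicates().groupby("Sample_ID").size()
conflicting_ids = distinct_per_key[distinct_per_key > 1].index.tolist()

print("conflicting Sample_IDs:", conflicting_ids)
```

If the list is empty for the real data, `keep="first"` loses nothing. Otherwise the conflicting rows may deserve a manual look before one of them is discarded.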


Removing duplicates

In [6]:
df = df.drop_duplicates(subset=["Sample_ID"], keep="first").reset_index(drop=True)

Searching for outliers

In [7]:
cols = [
    "Temperature_Kelvin", "Temperature_Mean", "Temperature_Minimum",
    "Temperature_Maximum", "Temperature_Kelvin_Standard_Deviation",
    "Windspeed", "Depth_m", "Distance_to_Shore"
]

masks = iqr_mask(df, cols, k=1.5)

df_flagged = df.copy()
for col, m in masks.items():
    df_flagged[f"{col}_is_outlier"] = m

if masks:
    df_flagged["any_outlier"] = pd.concat(masks.values(), axis=1).any(axis=1)
else:
    df_flagged["any_outlier"] = False

outlier_rows = df_flagged[df_flagged["any_outlier"]]
print(outlier_rows.count())


Sample_ID                                           590
Site_ID                                             590
ClimSST                                             590
Temperature_Kelvin                                  590
Temperature_Mean                                    590
Temperature_Minimum                                 590
Temperature_Maximum                                 590
Temperature_Kelvin_Standard_Deviation               590
Windspeed                                           590
SSTA                                                590
SSTA_Standard_Deviation                             590
SSTA_Mean                                           590
SSTA_Minimum                                        590
SSTA_Maximum                                        590
SSTA_Frequency                                      590
SSTA_Frequency_Standard_Deviation                   590
SSTA_FrequencyMax                                   590
SSTA_FrequencyMean                              
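
`outlier_rows.count()` reports non-missing cells per column among the flagged rows, so the same total repeats for every column. A per-column tally of the flags may be more informative. The sketch below assumes flag columns named like the `*_is_outlier` columns created above and uses made-up values.

```python
import pandas as pd

# Hypothetical flag columns, shaped like the `*_is_outlier` columns in df_flagged.
flags = pd.DataFrame({
    "Windspeed_is_outlier": [False, True, False, True],
    "Depth_m_is_outlier": [False, False, False, True],
})

# How many values were flagged in each column, largest first.
per_column = flags.sum().sort_values(ascending=False)

# How many rows carry at least one flag, and what share of the data that is.
n_rows_flagged = int(flags.any(axis=1).sum())
share_flagged = n_rows_flagged / len(flags)

print(per_column)
print(f"{n_rows_flagged} rows flagged ({share_flagged:.0%} of the data)")
```

On the real frame, `df_flagged.filter(like="_is_outlier")` would select the same kind of columns, which shows which variable drives most of the flagged rows before they are dropped.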

Removing the outliers and missing (NaN) values, then saving the result to a CSV file

In [8]:
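# Note: the second return value is a boolean row mask; it overwrites the `outlier_rows` DataFrame from In [7]
clean_df, outlier_rows = clean_outliers(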
    df,
    cols,
    k_iqr=1.5,
    strategy="drop",
    output_csv="../data/cleaned_data_dropped.csv",
)
clean_df.isna().sum()

Sample_ID                                0
Site_ID                                  0
ClimSST                                  0
Temperature_Kelvin                       0
Temperature_Mean                         0
Temperature_Minimum                      0
Temperature_Maximum                      0
Temperature_Kelvin_Standard_Deviation    0
Windspeed                                0
SSTA                                     0
SSTA_Standard_Deviation                  0
SSTA_Mean                                0
SSTA_Minimum                             0
SSTA_Maximum                             0
SSTA_Frequency                           0
SSTA_Frequency_Standard_Deviation        0
SSTA_FrequencyMax                        0
SSTA_FrequencyMean                       0
SSTA_DHW                                 0
SSTA_DHW_Standard_Deviation              0
SSTA_DHWMax                              0
SSTA_DHWMean                             0
TSA                                      0
TSA_Standar
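
The cell above uses `strategy="drop"`. For comparison, here is a minimal standalone sketch of what the `"impute"` path does: scale, KNN-impute, then undo the scaling. The toy values are loosely based on the preview rows with two cells blanked for illustration, and `n_neighbors=2` is an assumption chosen to suit the tiny frame.

```python
import numpy as np
import pandas as pd
from sklearn.impute import KNNImputer
from sklearn.preprocessing import StandardScaler

# Toy frame with two cells blanked on purpose (illustrative, not the real data).
toy = pd.DataFrame({
    "Depth_m": [8.3, 14.9, np.nan, 7.6, 10.0],
    "Distance_to_Shore": [8311.0, 10747.0, 9396.0, 9408.0, 9362.0],
    "Windspeed": [2.0, 6.0, 5.0, np.nan, 7.0],
})

scaler = StandardScaler()                                # ignores NaN when fitting
imputer = KNNImputer(n_neighbors=2, weights="distance")  # closer rows weigh more

scaled = scaler.fit_transform(toy)                 # put columns on a comparable scale
filled_scaled = imputer.fit_transform(scaled)      # fill NaN from the nearest rows
filled = scaler.inverse_transform(filled_scaled)   # back to the original units

toy_imputed = pd.DataFrame(filled, columns=toy.columns, index=toy.index)
print(toy_imputed)
```

Scaling matters here because `Distance_to_Shore` is in the thousands. Without it, that column would dominate the distance computation and the neighbours would be chosen almost entirely by distance to shore.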