# 🦷 Dental Implant 10-Year Survival Prediction

## Notebook 01: Exploratory Data Analysis (EDA)

**Objective:** Explore and understand the dataset before building any models. This includes understanding feature distributions, identifying patterns, and detecting potential issues like missing values or class imbalance.

---


### 🎨 Setup: Import Libraries & Configure Plotting Style

We'll use the Periospot brand colors for all visualizations.


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Periospot Brand Colors
COLORS = {
    'periospot_blue': '#15365a',
    'mystic_blue': '#003049',
    'periospot_red': '#6c1410',
    'crimson_blaze': '#a92a2a',
    'vanilla_cream': '#f7f0da',
    'black': '#000000',
    'white': '#ffffff',
    'classic_periospot_blue': '#0031af',
    'periospot_light_blue': '#0297ed',
    'periospot_dark_blue': '#02011e',
    'periospot_yellow': '#ffc430',
    'periospot_bright_blue': '#1040dd'
}

# Create a custom color palette for plots
periospot_palette = [
    COLORS['periospot_blue'], 
    COLORS['crimson_blaze'], 
    COLORS['periospot_light_blue'],
    COLORS['periospot_yellow'],
    COLORS['mystic_blue'], 
    COLORS['periospot_red']
]

# Configure matplotlib
plt.rcParams['font.family'] = 'DejaVu Sans'
plt.rcParams['axes.titlesize'] = 16
plt.rcParams['axes.labelsize'] = 12
plt.rcParams['xtick.labelsize'] = 10
plt.rcParams['ytick.labelsize'] = 10
plt.rcParams['figure.facecolor'] = COLORS['white']
plt.rcParams['axes.facecolor'] = COLORS['vanilla_cream']
plt.rcParams['axes.edgecolor'] = COLORS['periospot_blue']

# Make the Periospot palette seaborn's default color cycle
sns.set_palette(periospot_palette)

print("✅ Libraries imported and plotting style configured!")


---

### 1. Load Data & Initial Inspection

First, let's load the training data and get a feel for what we're working with.


In [None]:
# TODO: Load the train.csv dataset from the /data/raw/ folder.
# Hint: Use pd.read_csv() with the correct relative path.
df = ...

# TODO: Display the first 5 rows of the dataframe.
# Hint: Use the .head() method.
...


In [None]:
# TODO: Get the shape of the dataframe (rows, columns).
# Hint: Use the .shape attribute.
print(f"Dataset shape: ...")


In [None]:
# TODO: Get a concise summary of the dataframe, including data types and non-null values.
# Hint: Use the .info() method.
...


In [None]:
# TODO: Check for missing values in each column.
# Hint: Use df.isnull().sum() to count missing values per column.
...
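A raw count of missing values is hard to judge unless you also know how many rows there are. The sketch below reports both the count and the percentage of missing values for each column, and keeps only the columns that have gaps. `missing_summary` is a hypothetical helper name. The toy columns (`age`, `smoker`, `bone_density`) are invented for illustration. On the real data you would call `missing_summary(df)`.

```python
import numpy as np
import pandas as pd


# Hypothetical helper, not part of pandas
def missing_summary(df: pd.DataFrame) -> pd.DataFrame:
    """Count and percentage of missing values per column.

    Only columns with at least one missing value are kept, largest gaps first.
    """
    counts = df.isnull().sum()
    summary = pd.DataFrame({
        'missing_count': counts,
        'missing_pct': counts / len(df) * 100,
    })
    summary = summary[summary['missing_count'] > 0]
    return summary.sort_values('missing_count', ascending=False)


# Toy data with made-up columns, for illustration only
toy = pd.DataFrame({
    'age': [45, np.nan, 61, 52],
    'smoker': ['yes', 'no', None, None],
    'bone_density': [1.2, 0.9, 1.1, 1.0],
})
print(missing_summary(toy))
```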


---

### 2. Summary Statistics & Target Distribution

Understanding the distribution of the target variable is crucial: it tells us whether we're dealing with a balanced or an imbalanced classification problem.


In [None]:
# TODO: Generate descriptive statistics for the numerical columns.
# Hint: Use the .describe() method.
...


In [None]:
# TODO: Create a count plot to visualize the distribution of the target variable 'implant_survival_10y'.
# Hint: Use sns.countplot(). Remember to save the figure to the /figures/ folder using plt.savefig().

fig, ax = plt.subplots(figsize=(8, 6))

# TODO: Create the count plot
# sns.countplot(data=df, x='implant_survival_10y', palette=periospot_palette, ax=ax)
...

ax.set_title('Distribution of Implant Survival (10-Year)', fontweight='bold')
ax.set_xlabel('Implant Survival (10-Year)')
ax.set_ylabel('Count')

# TODO: Add value labels on bars
# Hint: Use ax.bar_label(ax.containers[0])

plt.tight_layout()
# TODO: Save the figure
# plt.savefig('../figures/target_distribution.png', dpi=150, bbox_inches='tight')
plt.show()
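`savefig()` does not create missing folders. Saving to `../figures/` therefore raises a `FileNotFoundError` if that folder does not exist yet. The sketch below shows a hypothetical `save_figure` helper that creates the folder before saving. It assumes the notebook runs from a directory whose parent contains, or should contain, `figures/`.

```python
from pathlib import Path

import matplotlib.pyplot as plt


# Hypothetical helper: wraps Figure.savefig and creates the output folder first
def save_figure(fig, filename, folder='../figures', dpi=150):
    """Save `fig` to `folder/filename`, creating `folder` if needed."""
    out_dir = Path(folder)
    out_dir.mkdir(parents=True, exist_ok=True)  # savefig() will not create it for us
    out_path = out_dir / filename
    fig.savefig(out_path, dpi=dpi, bbox_inches='tight')
    return out_path


# Example usage inside a plotting cell:
# save_figure(fig, 'target_distribution.png')
# save_figure(plt.gcf(), 'numerical_distributions.png')  # df.hist() cell has no `fig` variable
```

Call it in place of the commented `plt.savefig(...)` lines.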


In [None]:
# TODO: Calculate the percentage of each class in the target variable.
# Hint: Use df['implant_survival_10y'].value_counts(normalize=True) * 100
# This will help you understand if the dataset is imbalanced.
...
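The majority-to-minority ratio condenses the class percentages into one number. The sketch below uses a hypothetical `class_balance` helper on a toy target. It assumes the target is encoded as 1 = implant survived and 0 = implant failed, so check the real encoding in your data. There is no universal cut-off for "imbalanced". Treat the ratio as a descriptive number, not a verdict.

```python
import pandas as pd


# Hypothetical helper, not part of pandas
def class_balance(y: pd.Series) -> dict:
    """Counts, percentages, and majority/minority ratio of a target."""
    counts = y.value_counts()
    percentages = (y.value_counts(normalize=True) * 100).round(1)
    return {
        'counts': counts.to_dict(),
        'percentages': percentages.to_dict(),
        # 1.0 means perfectly balanced; larger values mean more imbalance
        'imbalance_ratio': float(counts.max() / counts.min()),
    }


# Toy target, assuming 1 = survived and 0 = failed
y_toy = pd.Series([1] * 8 + [0] * 2, name='implant_survival_10y')
print(class_balance(y_toy))
```

With an 80/20 split like this toy one, a model that always predicts "survived" would already be 80% accurate. This is why accuracy alone can be misleading on imbalanced data.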


---

### 3. Univariate Analysis (Visualizing Features)

Let's examine each feature on its own to understand its distribution.


In [None]:
# TODO: Identify numerical and categorical columns.
# Hint: Use df.select_dtypes(include='number').columns for numerical columns
# and df.select_dtypes(exclude='number').columns for categorical ones.

numerical_cols = ...
categorical_cols = ...

print(f"Numerical columns: {list(numerical_cols)}")
print(f"Categorical columns: {list(categorical_cols)}")
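If the target is stored as text (e.g. "yes"/"no"), it lands in `categorical_cols`. The bivariate cells would then plot the target against itself. The sketch below shows one way to split the columns while leaving the target out. `split_columns` is a hypothetical helper, and the toy columns are invented. It assumes every non-numeric column (text, pandas `category`, bool) should be treated as categorical, which is not true for datetime columns, for example.

```python
import pandas as pd

TARGET = 'implant_survival_10y'


# Hypothetical helper, not part of pandas
def split_columns(df: pd.DataFrame, target: str = TARGET):
    """Return (numerical, categorical) feature names, leaving the target out."""
    features = df.drop(columns=[target], errors='ignore')
    # 'number' covers every int/float width (int32, float32, ...), not only int64/float64
    numerical = features.select_dtypes(include='number').columns.tolist()
    # Assumption: every remaining (non-numeric) column is categorical
    categorical = [col for col in features.columns if col not in numerical]
    return numerical, categorical


# Toy data with made-up columns, for illustration only
toy = pd.DataFrame({
    'age': [45, 61],
    'bone_density': [1.2, 0.9],
    'smoker': ['yes', 'no'],
    'implant_type': pd.Categorical(['A', 'B']),
    'implant_survival_10y': [1, 0],
})
print(split_columns(toy))
```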


In [None]:
# TODO: Create histograms for all numerical features to understand their distributions.
# Hint: Use df[numerical_cols].hist() and adjust the figure size.

# TODO: Create histograms
# df[numerical_cols].hist(bins=20, figsize=(15, 10), color=COLORS['periospot_blue'], 
#                         edgecolor=COLORS['mystic_blue'])
...

plt.suptitle('Distribution of Numerical Features', fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()

# TODO: Save the figure
# plt.savefig('../figures/numerical_distributions.png', dpi=150, bbox_inches='tight')
plt.show()
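Histograms show skew visually. Sample skewness puts a number on it, which helps when deciding on transformations during preprocessing. The sketch below flags numeric columns whose absolute skewness exceeds a threshold, using pandas' `DataFrame.skew()`. The helper name is hypothetical. The threshold of 1.0 is a common rule of thumb, not a hard rule.

```python
import pandas as pd


# Hypothetical helper, not part of pandas
def skewed_features(df: pd.DataFrame, threshold: float = 1.0) -> pd.Series:
    """Numeric columns with |skewness| above `threshold`, most skewed first."""
    skewness = df.select_dtypes(include='number').skew()
    flagged = skewness[skewness.abs() > threshold]
    return flagged.sort_values(key=lambda s: s.abs(), ascending=False)


# Toy data, for illustration only
toy = pd.DataFrame({
    'symmetric': [1, 2, 3, 4, 5],
    'right_skewed': [1, 1, 1, 2, 20],
})
print(skewed_features(toy))
```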


In [None]:
# TODO: Create count plots for all categorical features.
# Hint: Loop through the categorical columns and use sns.countplot().

# First, determine how many categorical columns we have to set up the subplot grid
n_cats = len(categorical_cols) if 'categorical_cols' in dir() else 0

if n_cats > 0:
    n_cols = 2
    n_rows = (n_cats + 1) // 2
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 4 * n_rows))
    axes = axes.flatten()  # n_cols = 2, so plt.subplots() always returns an array of Axes
    
    # TODO: Loop through categorical columns and create count plots
    # for i, col in enumerate(categorical_cols):
    #     sns.countplot(data=df, x=col, palette=periospot_palette, ax=axes[i])
    #     axes[i].set_title(f'Distribution of {col}')
    #     axes[i].tick_params(axis='x', rotation=45)
    ...
    
    # Hide any unused subplots
    for j in range(n_cats, len(axes)):
        axes[j].set_visible(False)
    
    plt.tight_layout()
    # TODO: Save the figure
    # plt.savefig('../figures/categorical_distributions.png', dpi=150, bbox_inches='tight')
    plt.show()
else:
    print("No categorical columns found or not yet defined.")


---

### 4. Bivariate Analysis (Features vs. Target)

Now let's see how each feature relates to the target variable. This can give us hints about which features might be predictive.


In [None]:
# TODO: For each categorical feature, create a count plot showing the distribution of the target variable.
# Hint: Use sns.countplot() with the 'hue' parameter set to the target variable.

n_cats = len(categorical_cols) if 'categorical_cols' in dir() else 0

if n_cats > 0:
    n_cols = 2
    n_rows = (n_cats + 1) // 2
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 4 * n_rows))
    axes = axes.flatten()  # n_cols = 2, so plt.subplots() always returns an array of Axes
    
    # TODO: Loop through categorical columns and create count plots with hue
    # for i, col in enumerate(categorical_cols):
    #     sns.countplot(data=df, x=col, hue='implant_survival_10y', 
    #                   palette=[COLORS['periospot_blue'], COLORS['crimson_blaze']], ax=axes[i])
    #     axes[i].set_title(f'{col} vs. Implant Survival')
    #     axes[i].tick_params(axis='x', rotation=45)
    #     axes[i].legend(title='Survival')
    ...
    
    for j in range(n_cats, len(axes)):
        axes[j].set_visible(False)
    
    plt.tight_layout()
    # TODO: Save the figure
    # plt.savefig('../figures/categorical_vs_target.png', dpi=150, bbox_inches='tight')
    plt.show()
else:
    print("No categorical columns found or not yet defined.")
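When categories have very different sizes, raw counts are hard to compare. The share of each outcome within a category is easier to read. The sketch below uses `pd.crosstab(..., normalize='index')`, so each row sums to 1. The helper name and the `smoker` column are hypothetical. The target is assumed to be 0/1.

```python
import pandas as pd


# Hypothetical helper, not part of pandas
def survival_rate_by_category(df, col, target='implant_survival_10y'):
    """Share of each target class within every category of `col` (rows sum to 1)."""
    return pd.crosstab(df[col], df[target], normalize='index')


# Toy data with a made-up column, for illustration only
toy = pd.DataFrame({
    'smoker': ['yes', 'yes', 'yes', 'yes', 'no', 'no', 'no', 'no'],
    'implant_survival_10y': [1, 0, 0, 1, 1, 1, 1, 0],
})
print(survival_rate_by_category(toy, 'smoker'))
```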


In [None]:
# TODO: For each numerical feature, create a box plot to see how its distribution varies with the target.
# Hint: Use sns.boxplot() with x='implant_survival_10y'.

# Exclude the target from numerical columns if it's there
num_features = [col for col in numerical_cols if col != 'implant_survival_10y'] if 'numerical_cols' in dir() else []
n_nums = len(num_features)

if n_nums > 0:
    n_cols_plot = 2
    n_rows_plot = (n_nums + 1) // 2
    
    fig, axes = plt.subplots(n_rows_plot, n_cols_plot, figsize=(14, 4 * n_rows_plot))
    axes = axes.flatten()  # n_cols_plot = 2, so plt.subplots() always returns an array of Axes
    
    # TODO: Loop through numerical features and create box plots
    # for i, col in enumerate(num_features):
    #     sns.boxplot(data=df, x='implant_survival_10y', y=col,
    #                 palette=[COLORS['periospot_blue'], COLORS['crimson_blaze']], ax=axes[i])
    #     axes[i].set_title(f'{col} by Implant Survival')
    ...
    
    for j in range(n_nums, len(axes)):
        axes[j].set_visible(False)
    
    plt.tight_layout()
    # TODO: Save the figure
    # plt.savefig('../figures/numerical_vs_target.png', dpi=150, bbox_inches='tight')
    plt.show()
else:
    print("No numerical columns found or not yet defined.")
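Box plots can be complemented with a small table of group statistics, which is easier to quote in the findings section below. The sketch below computes the median and mean of each numeric feature per target class with `groupby(...).agg(...)`. The helper name and the `bone_density` column are hypothetical.

```python
import pandas as pd


# Hypothetical helper, not part of pandas
def numeric_summary_by_target(df, target='implant_survival_10y'):
    """Median and mean of each numeric feature, split by target class."""
    features = df.select_dtypes(include='number').drop(columns=[target], errors='ignore')
    return features.groupby(df[target]).agg(['median', 'mean'])


# Toy data with a made-up column, for illustration only
toy = pd.DataFrame({
    'bone_density': [1.0, 1.2, 1.4, 0.6, 0.8],
    'implant_survival_10y': [1, 1, 1, 0, 0],
})
print(numeric_summary_by_target(toy))
```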


---

### 5. Correlation Analysis

Correlation analysis quantifies the linear relationships between numerical features. It can also reveal multicollinearity: pairs of features so strongly correlated that they carry largely redundant information.
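A large heatmap makes it easy to miss the pairs that matter for multicollinearity. The sketch below lists every feature pair whose absolute correlation exceeds a threshold. It visits each unordered pair once with `itertools.combinations`. The helper name is hypothetical, and the 0.8 threshold is an illustrative choice, not a standard. If the matrix includes the target, pairs involving it will also appear.

```python
from itertools import combinations

import pandas as pd


# Hypothetical helper, not part of pandas
def high_correlation_pairs(corr: pd.DataFrame, threshold: float = 0.8) -> pd.DataFrame:
    """Feature pairs whose absolute correlation exceeds `threshold`, strongest first."""
    rows = []
    # combinations() yields each unordered pair once and skips the diagonal
    for a, b in combinations(corr.columns, 2):
        r = corr.loc[a, b]
        if abs(r) > threshold:
            rows.append((a, b, r))
    pairs = pd.DataFrame(rows, columns=['feature_1', 'feature_2', 'correlation'])
    return pairs.sort_values('correlation', key=lambda s: s.abs(),
                             ascending=False).reset_index(drop=True)


# Toy data: 'b' is exactly 2 * 'a'; 'c' is only weakly related to either
toy = pd.DataFrame({
    'a': [1, 2, 3, 4, 5],
    'b': [2, 4, 6, 8, 10],
    'c': [2, 5, 1, 4, 3],
})
print(high_correlation_pairs(toy.corr()))
```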


In [None]:
# TODO: Calculate the correlation matrix for the numerical features.
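# Hint: Use df.corr(numeric_only=True). Since pandas 2.0, plain .corr() raises an
# error when the dataframe still contains text (object) columns.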

correlation_matrix = ...

# Display the correlation matrix
correlation_matrix


In [None]:
# TODO: Create a heatmap to visualize the correlation matrix.
# Hint: Use sns.heatmap() with the correlation matrix.

fig, ax = plt.subplots(figsize=(12, 10))

# 'coolwarm' is a diverging colormap; with center=0, uncorrelated pairs map to its neutral midpoint
# TODO: Create the heatmap
# sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0,
#             square=True, linewidths=0.5, fmt='.2f', ax=ax,
#             cbar_kws={'shrink': 0.8})
...

ax.set_title('Correlation Matrix of Numerical Features', fontsize=16, fontweight='bold')

plt.tight_layout()
# TODO: Save the figure
# plt.savefig('../figures/correlation_heatmap.png', dpi=150, bbox_inches='tight')
plt.show()
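The heatmap cell uses the built-in `coolwarm` map. To keep the figure on-brand, you could build a diverging colormap from the Periospot colors with `matplotlib.colors.LinearSegmentedColormap.from_list`. The sketch below is one option. The colormap name and the choice of the three anchor colors are assumptions.

```python
import matplotlib.colors as mcolors

# Hex values copied from the COLORS dictionary in the setup cell
PERIOSPOT_BLUE = '#15365a'   # COLORS['periospot_blue']
VANILLA_CREAM = '#f7f0da'    # COLORS['vanilla_cream']
CRIMSON_BLAZE = '#a92a2a'    # COLORS['crimson_blaze']

# Diverging map: blue for negative, cream around zero, crimson for positive
periospot_diverging = mcolors.LinearSegmentedColormap.from_list(
    'periospot_diverging', [PERIOSPOT_BLUE, VANILLA_CREAM, CRIMSON_BLAZE]
)

# Usage in the heatmap cell:
# sns.heatmap(correlation_matrix, annot=True, cmap=periospot_diverging,
#             center=0, vmin=-1, vmax=1, square=True, fmt='.2f', ax=ax)
```

Fixing `vmin=-1, vmax=1` keeps the color scale comparable across datasets, because correlations always fall in that range.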


In [None]:
# TODO: Identify features most correlated with the target variable.
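# Hint: Use correlation_matrix['implant_survival_10y'].sort_values(ascending=False)
# This requires the target to be stored as a number (e.g., 0/1). Strong negative
# correlations end up at the bottom of that list, so check both ends.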

# This will help you understand which features have the strongest relationship
# with the outcome we're trying to predict.
...
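With a 0/1 target, the Pearson correlation from `.corr()` equals the point-biserial correlation. A plain descending sort ranks a strong negative correlation below a weak positive one. The sketch below ranks features by absolute correlation instead and drops the target's 1.0 self-correlation. The helper name and the toy columns `a` and `p` are hypothetical. Pearson only captures linear association, so a feature with a non-linear relationship to survival may still show a near-zero value.

```python
import pandas as pd


# Hypothetical helper, not part of pandas
def rank_by_target_correlation(corr: pd.DataFrame, target: str = 'implant_survival_10y') -> pd.Series:
    """Correlations with `target`, strongest (by absolute value) first."""
    r = corr[target].drop(labels=[target])  # drop the trivial self-correlation
    return r.sort_values(key=lambda s: s.abs(), ascending=False)


# Toy data with made-up columns, for illustration only
toy = pd.DataFrame({
    'a': [1, 2, 3, 4, 5],          # strongly negative with the target
    'p': [5, 3, 2, 1, 4],          # weakly positive with the target
    'implant_survival_10y': [1, 1, 1, 0, 0],
})
corr = toy.corr()
print(corr['implant_survival_10y'].sort_values(ascending=False))  # plain sort: 'a' comes last
print(rank_by_target_correlation(corr))                           # by strength: 'a' comes first
```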


---

### 6. Key Findings & Insights

Summarize your findings from the EDA before moving to preprocessing.


#### TODO: Write your EDA summary here

**Dataset Overview:**
- Number of samples: _____
- Number of features: _____
- Target variable: `implant_survival_10y`

**Class Balance:**
- Is the dataset balanced or imbalanced? _____
- What percentage is each class? _____

**Missing Values:**
- Are there any missing values? _____
- Which columns have missing data? _____

**Key Observations:**
1. _____
2. _____
3. _____

**Features Most Correlated with Target:**
1. _____
2. _____
3. _____

**Recommendations for Preprocessing:**
- _____
- _____


---

### ✅ EDA Complete!

**Next Step:** Proceed to `02_Data_Preprocessing.ipynb` to clean and prepare the data for modeling.

