#### 1. Imports, settings & load the synthetic dataset

In [4]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

import joblib
import os

# Better visuals
%matplotlib inline
sns.set(style="whitegrid", palette="muted", font_scale=1.1)
plt.rcParams['figure.figsize'] = (10, 6)

# Reproducibility
np.random.seed(42)

# Load the dataset
data_path = "../data/processed/crop_risk_insurance_synthetic_v1.csv"
df = pd.read_csv(data_path)

print("Dataset loaded successfully!")
print("Shape:", df.shape)

df.head(3)

Dataset loaded successfully!
Shape: (12000, 18)


Unnamed: 0,country,crop,season_year,rainfall_mm,avg_temp_c,heat_stress_days,ndvi_peak,soil_ph,soc_percent,fertilizer_n_kg_ha,pest_disease_level,irrigated,yield_potential_t_ha,actual_yield_t_ha,expected_yield_t_ha,yield_loss_pct,risk_class,payout_usd_per_ha
0,Ethiopia,Beans,2015,515.6,22.9,0,0.693,5.82,0.69,39.0,1,0,1.425,1.29,0.78,0.0,Low,0.0
1,Ethiopia,Beans,2015,508.6,24.8,0,0.439,5.52,0.9,9.0,2,0,1.425,0.47,1.29,0.636,High,381.6
2,Ethiopia,Beans,2015,891.2,20.0,6,0.614,6.64,0.45,18.0,0,0,1.425,1.3,0.88,0.0,Low,0.0


#### 2. Basic data quality, descriptive stats & missing values check

- Basic info

In [6]:
# Per-column non-null counts and dtypes; df.info() writes its report to stdout itself.
print("DataFrame Info:")
df.info()

DataFrame Info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 12000 entries, 0 to 11999
Data columns (total 18 columns):
 #   Column                Non-Null Count  Dtype  
---  ------                --------------  -----  
 0   country               12000 non-null  object 
 1   crop                  12000 non-null  object 
 2   season_year           12000 non-null  int64  
 3   rainfall_mm           12000 non-null  float64
 4   avg_temp_c            12000 non-null  float64
 5   heat_stress_days      12000 non-null  int64  
 6   ndvi_peak             12000 non-null  float64
 7   soil_ph               12000 non-null  float64
 8   soc_percent           12000 non-null  float64
 9   fertilizer_n_kg_ha    12000 non-null  float64
 10  pest_disease_level    12000 non-null  int64  
 11  irrigated             12000 non-null  int64  
 12  yield_potential_t_ha  12000 non-null  float64
 13  actual_yield_t_ha     12000 non-null  float64
 14  expected_yield_t_ha   12000 non-null  float64
 15  yie

- Missing values (should be zero or very few)

In [7]:
# Count nulls per column and list only the columns that actually have some.
print("\nMissing Values:")
missing = df.isnull().sum()
if missing.sum() > 0:
    print(missing[missing > 0])
else:
    print("No missing values — perfect!")


Missing Values:
No missing values — perfect!


- Descriptive statistics for numerical columns

In [8]:
# Summary statistics (count, mean, std, quartiles, min/max) for numeric columns.
print("\nNumerical Features Summary:")
# "number" matches every numeric dtype, so a column stored as int32/float32
# is not silently left out the way an explicit ['float64', 'int64'] list would.
num_cols = df.select_dtypes(include="number").columns
print(df[num_cols].describe().round(2))


Numerical Features Summary:
       season_year  rainfall_mm  avg_temp_c  heat_stress_days  ndvi_peak  \
count     12000.00     12000.00    12000.00          12000.00   12000.00   
mean       2020.85       673.52       24.69              8.27       0.54   
std           3.06       466.72        3.16              9.73       0.19   
min        2015.00        50.00       14.80              0.00       0.05   
25%        2018.00       348.25       22.50              3.00       0.39   
50%        2021.00       545.00       24.70              6.00       0.54   
75%        2024.00       848.88       26.80              9.00       0.68   
max        2025.00      2500.00       34.80             54.00       0.95   

        soil_ph  soc_percent  fertilizer_n_kg_ha  pest_disease_level  \
count  12000.00     12000.00            12000.00            12000.00   
mean       5.94         0.75               37.65                0.85   
std        0.74         0.43               39.18                1.01  

- Descriptive statistics for categorical columns

In [9]:
# Share of rows in each category, in percent, for every categorical column.
print("\nCategorical Features Summary:")
cat_cols = ['country', 'crop', 'risk_class']
for col in cat_cols:
    # The unit goes in the heading: passing "%" as a second print() argument
    # would land after the Series' "Name/dtype" footer, not next to the values.
    print(f"\n{col} value counts (%):")
    # Scale to percent first, then round, so the rounding is the last
    # operation and the printed values carry no floating-point residue.
    print(df[col].value_counts(normalize=True).mul(100).round(1))


Categorical Features Summary:

country value counts:
country
Kenya       25.3
Ethiopia    22.3
Uganda      18.1
Zambia      13.5
Tanzania    12.9
Malawi       7.9
Name: proportion, dtype: float64 %

crop value counts:
crop
Sorghum      17.0
Groundnut    16.8
Maize        16.7
Beans        16.6
Millet       16.5
Cassava      16.4
Name: proportion, dtype: float64 %

risk_class value counts:
risk_class
Low       53.0
High      34.1
Medium    12.9
Name: proportion, dtype: float64 %


- Quick check for duplicates

In [10]:
# Flag rows that exactly repeat an earlier row (all columns compared), then count them.
dup_mask = df.duplicated()
duplicates = dup_mask.sum()
print(f"\nNumber of duplicate rows: {duplicates}")


Number of duplicate rows: 0
