#### 1. Imports and global setup (plot style, random seed, sample size)

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Global setup used by every cell that follows.
N = 18_000                   # number of rows generated for the synthetic dataset
np.random.seed(42)           # seed NumPy's global RNG so the random draws are reproducible
sns.set(style="whitegrid")   # seaborn theme for any plots

#### 2. Core parameters

In [2]:
# Baseline yield per crop (presumably t/ha, matching the yield columns built later — confirm).
base_yield = {
    'Maize': 5.2,
    'Sorghum': 3.6,
    'Millet': 2.3,
    'Beans': 2.0,
    'Cassava': 15.0,
    'Groundnut': 2.7,
}

# Multiplier applied to the crop baseline for each country.
country_factor = {
    'Kenya': 0.92,
    'Ethiopia': 0.72,
    'Uganda': 0.95,
    'Zambia': 0.84,
    'Tanzania': 0.80,
    'Malawi': 0.76,
}

# Category lists follow the dict insertion order. The country order matters:
# the sampling probabilities used when the frame is built are positional.
crops = list(base_yield)
countries = list(country_factor)

#### 3. Create base dataframe

In [3]:
# Sampling weights for the country column, positionally aligned with `countries`.
country_weights = [0.24, 0.25, 0.16, 0.14, 0.12, 0.09]

# One row per synthetic record. The draws happen in the order country -> crop -> year,
# which keeps the random stream identical for a given seed.
df = pd.DataFrame({
    'country': np.random.choice(countries, size=N, p=country_weights),
    'crop': np.random.choice(crops, size=N),                 # uniform over crops
    'season_year': np.random.randint(2015, 2025, size=N),    # 2015..2024 (upper bound exclusive)
})

#### 4. Add weather variables (rainfall, temperature, heat stress)

In [4]:
# Raw rainfall draw: log-normal with log-mean 6.3 (median about 545 mm) and log-sd 0.72.
rain_raw = np.random.lognormal(6.3, 0.72, N)

# Country adjustment (Ethiopia drier, Uganda wetter)
rain_adj = {'Ethiopia':0.78, 'Kenya':0.94, 'Uganda':1.18, 'Zambia':0.97, 'Tanzania':0.96, 'Malawi':1.06}

# Scale first, clip last: clipping before the multiplier let the final column
# leave the intended [40, 2800] mm range (e.g. 2800 * 1.18 or 40 * 0.78).
df['rainfall_mm'] = (df['country'].map(rain_adj) * rain_raw).clip(40, 2800)

# Mean temperature: normal(25.2, 3.4), bounded to 15-36 °C.
df['avg_temp_c'] = np.random.normal(25.2, 3.4, N).clip(15, 36)

# Heat-stress days: 15-59 when the mean temperature exceeds 29 °C, otherwise 0-17.
# Both candidate arrays are always drawn (hot first), so the random stream does not
# depend on how many rows are hot.
df['heat_stress_days'] = np.where(df['avg_temp_c'] > 29, np.random.randint(15, 60, N), np.random.randint(0, 18, N))

#### 5. Vegetation & management

In [5]:
# Peak NDVI: beta(4.2, 3.8) draw, bounded to [0.04, 0.96].
df['ndvi_peak'] = np.random.beta(4.2, 3.8, N).clip(0.04, 0.96)

# Strong NDVI drop in dry/hot conditions: stressed rows keep 35-75 % of their value.
stress = (df['rainfall_mm'] < 400) | (df['avg_temp_c'] > 30)
df.loc[stress, 'ndvi_peak'] *= np.random.uniform(0.35, 0.75, stress.sum())
# Re-apply the bounds because the reduction can push values under the 0.04 floor.
df['ndvi_peak'] = df['ndvi_peak'].clip(0.04, 0.96)

df['soil_ph'] = np.random.normal(5.9, 0.8, N).clip(4.3, 8.4)
df['soc_percent'] = np.random.beta(1.8, 6.2, N) * 3.5   # beta draw rescaled to 0-3.5

# Nitrogen rate: exponential draw with mean 40; Kenya applies 1.4x, every other country 1.0x.
# The multiplier is explicit (the old mapping relied on a placeholder key that never
# matched plus fillna), and the 250 cap is applied after scaling so it holds for Kenya too.
fert_country_mult = np.where(df['country'] == 'Kenya', 1.4, 1.0)
df['fertilizer_n_kg_ha'] = (np.random.exponential(40, N) * fert_country_mult).clip(0, 250)

# Ordinal pest/disease level 0-3, with higher levels drawn less often.
df['pest_disease_level'] = np.random.choice([0,1,2,3], N, p=[0.48, 0.26, 0.17, 0.09])
# Irrigation flag: 1 for 8 % of rows.
df['irrigated'] = np.random.choice([0,1], N, p=[0.92, 0.08])

#### 6. Generate actual yield with stronger signal

In [10]:
# Attainable yield before any stress: crop baseline scaled by the country factor.
potential = df['crop'].map(base_yield) * df['country'].map(country_factor)

# Rainfall: stepwise drought penalty below 600 mm; above that a linear bonus of
# 0.0005 per mm that stops growing 1000 mm past the threshold (factor tops out at 1.5).
rain = df['rainfall_mm']
water_f = np.select(
    [(rain < 250).to_numpy(), (rain < 400).to_numpy(), (rain < 600).to_numpy()],
    [0.20, 0.45, 0.80],
    default=(1.0 + 0.0005 * (rain - 600).clip(0, 1000)).to_numpy(),
)

# NDVI: cubic penalty below 0.6 (very steep drop), mild linear bonus above it.
ndvi = df['ndvi_peak']
ndvi_f = np.where(ndvi < 0.6, (ndvi / 0.6) ** 3.0, 1.0 + 0.5 * (ndvi - 0.6))

# Heat: exponential decay for every degree of mean temperature above 24 °C.
degrees_over_24 = (df['avg_temp_c'] - 24).clip(lower=0)
temp_f = np.exp(-0.16 * degrees_over_24)

# Pest: each level removes 26 % of yield (level 3 leaves a factor of 0.22).
pest_f = 1 - 0.26 * df['pest_disease_level']

# Fertilizer: diminishing-returns response (power 0.75).
fert_f = 1 + 0.007 * df['fertilizer_n_kg_ha'] ** 0.75
# Irrigation: flat 50 % uplift on irrigated rows.
irr_f = np.where(df['irrigated'].astype(bool), 1.50, 1.0)

# Multiplicative log-normal noise (sigma 0.14, kept small for clearer patterns).
noise = np.random.lognormal(0, 0.14, N)

# Factors are multiplied in a fixed left-to-right order so the floating-point result is stable.
raw_yield = potential * water_f * ndvi_f * temp_f * pest_f * fert_f * irr_f * noise

# Bound each row to [0.03, 1.9 x potential] and keep two decimals.
df['actual_yield_t_ha'] = raw_yield.clip(0.03, potential * 1.9).round(2)

#### 7. Create insurance targets: expected yield, yield loss %, risk class, payout

In [11]:
def _running_prior_mean(yields):
    """Return, for each row, the mean of the rows before it in the same group.

    The first row of a group has no predecessors and is filled with the mean
    of the whole group.
    """
    return yields.expanding().mean().shift(1).fillna(yields.mean())


# Expected yield: running mean of earlier rows within each (country, crop) group.
# NOTE(review): "earlier" means row position here; the frame is not sorted by
# season_year before this step — confirm that is the intended benchmark.
df['expected_yield_t_ha'] = (
    df.groupby(['country', 'crop'])['actual_yield_t_ha'].transform(_running_prior_mean)
)

# Shortfall relative to expectation, stored as a 0-1 fraction (over-performance counts as 0).
loss_ratio = 1 - df['actual_yield_t_ha'] / df['expected_yield_t_ha']
df['yield_loss_pct'] = loss_ratio.clip(0, 1)

# Risk bands: Low = [0, 0.15], Medium = (0.15, 0.40], High = (0.40, 1].
df['risk_class'] = pd.cut(
    df['yield_loss_pct'],
    bins=[0, 0.15, 0.40, 1],
    labels=['Low', 'Medium', 'High'],
    include_lowest=True,
)

# Payout scales linearly with the loss fraction, reaching 600 at total loss.
df['payout_usd_per_ha'] = 600 * df['yield_loss_pct']

risk_share = df['risk_class'].value_counts(normalize=True).round(3)
mean_payout = df['payout_usd_per_ha'].mean().round(1)
print("Risk distribution:\n", risk_share)
print("Average payout:", mean_payout)

Risk distribution:
 risk_class
High      0.553
Low       0.383
Medium    0.064
Name: proportion, dtype: float64
Average payout: 290.0


#### 8. Save the complete dataset + final overall check

In [12]:
# Write the full frame to CSV without the index column. The path is relative to the
# notebook's working directory.
output_csv = "../data/processed/crop_risk_insurance_v2.csv"
df.to_csv(output_csv, index=False)
print("Saved v2 dataset with stronger signal")

Saved v2 dataset with stronger signal


In [9]:
# Preview the first five rows; as the cell's last expression, the frame is displayed.
df.head(n=5)

Unnamed: 0,country,crop,season_year,rainfall_mm,avg_temp_c,heat_stress_days,ndvi_peak,soil_ph,soc_percent,fertilizer_n_kg_ha,pest_disease_level,irrigated,actual_yield_t_ha,expected_yield_t_ha,yield_loss_pct,risk_class,payout_usd_per_ha
0,Ethiopia,Cassava,2020,213.948385,33.824047,18,0.146145,5.922404,0.98253,11.844722,2,1,0.07,2.622789,0.973311,High,583.986514
1,Malawi,Cassava,2016,284.730408,23.443005,10,0.421844,5.835399,0.77321,5.295611,0,0,1.74,3.12976,0.444047,High,266.428126
2,Zambia,Millet,2021,223.631276,24.732808,9,0.414059,5.870016,0.354913,58.889154,3,0,0.06,0.540215,0.888933,High,533.359904
3,Uganda,Maize,2024,261.619853,31.786345,18,0.322295,4.683592,0.874254,24.767924,0,0,0.16,1.414876,0.886916,High,532.149533
4,Kenya,Beans,2024,1500.581018,28.159429,12,0.621145,7.435826,0.7955,37.806496,2,0,0.54,0.502674,0.0,Low,0.0
