In [4]:
import pandas as pd
import requests
from time import sleep

# ---------- 1. Load FAOSTAT data ----------

faostat_path = "data/faostat_yield.csv"
df = pd.read_csv(faostat_path)

# Keep only needed columns
df = df[['Area', 'Item', 'Year', 'Unit', 'Value']]

# Your crops (make sure names match exactly what FAOSTAT uses)
target_crops = ["Wheat", "Maize (corn)", "Rice", "Sugar cane"]
df = df[df['Item'].isin(target_crops)]

# Rename columns for convenience
df = df.rename(columns={
    'Area': 'country',
    'Item': 'crop',
    'Year': 'year',
    'Value': 'yield_kg_ha'
})

# Filter years 2001–2023 just in case more exist
df = df[(df['year'] >= 2001) & (df['year'] <= 2023)]

# Keep only non-null yields
df = df.dropna(subset=['yield_kg_ha'])

print("FAOSTAT rows after filtering:", len(df))
df



FAOSTAT rows after filtering: 11534


Unnamed: 0,country,crop,year,Unit,yield_kg_ha
0,Afghanistan,Maize (corn),2001,kg/ha,2000.0
1,Afghanistan,Maize (corn),2002,kg/ha,2980.0
2,Afghanistan,Maize (corn),2003,kg/ha,840.0
3,Afghanistan,Maize (corn),2004,kg/ha,1600.0
4,Afghanistan,Maize (corn),2005,kg/ha,1206.9
...,...,...,...,...,...
11529,Zimbabwe,Wheat,2019,kg/ha,3914.9
11530,Zimbabwe,Wheat,2020,kg/ha,4779.6
11531,Zimbabwe,Wheat,2021,kg/ha,5075.9
11532,Zimbabwe,Wheat,2022,kg/ha,5154.2


In [5]:

# ---------- 2. Your selected countries (from your attached list) ----------

selected_countries = [
    "Morocco",
    "Uruguay",
    "Peru",
    "Angola",
    "Madagascar",
    "Argentina",
    "Niger",
    "Australia",
    "Thailand",
    "Bangladesh",
    "Kenya",
    "Bhutan",
    "Mali",
    "Bolivia (Plurinational State of)",
    "Myanmar",
    "Brazil",
    "Pakistan",
    "Burundi",
    "Somalia",
    "Cameroon",
    "United Republic of Tanzania",
    "China",
    "Japan",
    "China, mainland",
    "Afghanistan",
    "China, Taiwan Province of",
    "Malawi",
    "Colombia",
    "Mexico",
    "Democratic Republic of the Congo",
    "Mozambique",
    "Ecuador",
    "Nepal",
    "Egypt",
    "Nigeria",
    "Eswatini",
    "Paraguay",
    "Ethiopia",
    "Rwanda",
    "Honduras",
    "South Africa",
    "India",
    "Uganda",
    "Iran (Islamic Republic of)",
    "United States of America",
    "Zambia",
    "Venezuela (Bolivarian Republic of)",
    "Zimbabwe",
    "Chad",
    "Guatemala",
    "Portugal",
    "Spain",
    "Iraq",
    "Senegal",
    "Uzbekistan",
    "Hungary",
    "Côte d'Ivoire",
    "Republic of Korea",
    "Kazakhstan",
    "Bulgaria",
    "Cuba",
    "Ukraine",
    "Kyrgyzstan",
    "Philippines",
    "Lao People's Democratic Republic",
    "Russian Federation",
    "Democratic People's Republic of Korea",
    "Guyana",
    "Cambodia",
    "Suriname",
    "Dominican Republic",
    "Turkmenistan",
    "Azerbaijan",
    "Benin",
    "Mauritania",
    "Viet Nam",
    "Central African Republic",
    "Guinea",
    "El Salvador",
    "Romania",
    "Algeria",
    "Guinea-Bissau",
    "Chile",
    "Sierra Leone",
    "Fiji",
    "Haiti",
    "Nicaragua",
    "Sri Lanka",
    "France",
    "Tajikistan",
    "Gabon",
    "Türkiye",
    "North Macedonia",
    "Burkina Faso",
    "Ghana",
    "Indonesia",
    "Panama",
    "Congo",
    "Papua New Guinea",
    "Italy",
    "Greece",
    "Costa Rica",
    "Belize"
]

df = df[df['country'].isin(selected_countries)]
print("Rows after keeping only selected countries:", len(df))



Rows after keeping only selected countries: 8299


In [8]:

# ---------- 3. Country → (lat, lon) dictionary ----------

country_coords = {
    "Morocco": (31.8, -7.1),
    "Uruguay": (-32.5, -56.0),
    "Peru": (-9.2, -75.0),
    "Angola": (-12.5, 18.5),
    "Madagascar": (-19.0, 47.0),
    "Argentina": (-34.0, -64.0),
    "Niger": (17.6, 8.1),
    "Australia": (-25.0, 133.0),
    "Thailand": (15.0, 101.0),
    "Bangladesh": (23.7, 90.4),
    "Kenya": (0.0, 37.9),
    "Bhutan": (27.4, 90.4),
    "Mali": (17.3, -3.4),
    "Bolivia (Plurinational State of)": (-16.7, -64.7),
    "Myanmar": (21.9, 95.9),
    "Brazil": (-10.0, -55.0),
    "Pakistan": (30.4, 69.4),
    "Burundi": (-3.4, 29.9),
    "Somalia": (5.2, 45.5),
    "Cameroon": (5.7, 12.7),
    "United Republic of Tanzania": (-6.3, 35.0),
    "China": (35.9, 104.2),
    "Japan": (36.2, 138.3),
    "China, mainland": (35.9, 104.2),
    "Afghanistan": (33.9, 67.7),
    "China, Taiwan Province of": (23.7, 121.0),
    "Malawi": (-13.3, 34.3),
    "Colombia": (4.6, -74.1),
    "Mexico": (23.6, -102.5),
    "Democratic Republic of the Congo": (-2.9, 23.7),
    "Mozambique": (-18.7, 35.5),
    "Ecuador": (-1.8, -78.2),
    "Nepal": (28.4, 84.1),
    "Egypt": (26.8, 30.8),
    "Nigeria": (9.1, 8.7),
    "Eswatini": (-26.5, 31.5),
    "Paraguay": (-23.4, -58.4),
    "Ethiopia": (9.1, 40.5),
    "Rwanda": (-1.9, 29.9),
    "Honduras": (14.8, -86.2),
    "South Africa": (-30.6, 22.9),
    "India": (22.0, 79.0),
    "Uganda": (1.4, 32.3),
    "Iran (Islamic Republic of)": (32.4, 53.7),
    "United States of America": (39.8, -98.6),
    "Zambia": (-13.1, 27.8),
    "Venezuela (Bolivarian Republic of)": (7.0, -66.0),
    "Zimbabwe": (-19.0, 29.2),
    "Chad": (15.4, 18.7),
    "Guatemala": (15.6, -90.3),
    "Portugal": (39.6, -8.0),
    "Spain": (40.3, -3.7),
    "Iraq": (33.0, 44.3),
    "Senegal": (14.4, -14.5),
    "Uzbekistan": (41.4, 64.6),
    "Hungary": (47.2, 19.5),
    "Côte d'Ivoire": (7.5, -5.5),
    "Republic of Korea": (36.5, 128.0),
    "Kazakhstan": (48.0, 67.0),
    "Bulgaria": (42.8, 25.5),
    "Cuba": (21.5, -79.4),
    "Ukraine": (49.0, 31.4),
    "Kyrgyzstan": (41.2, 74.8),
    "Philippines": (12.9, 122.9),
    "Lao People's Democratic Republic": (19.9, 102.6),
    "Russian Federation": (61.5, 105.3),
    "Democratic People's Republic of Korea": (40.0, 127.0),
    "Guyana": (5.0, -58.9),
    "Cambodia": (12.6, 104.9),
    "Suriname": (4.1, -55.9),
    "Dominican Republic": (19.0, -70.0),
    "Turkmenistan": (39.1, 59.4),
    "Azerbaijan": (40.1, 47.5),
    "Benin": (9.3, 2.3),
    "Mauritania": (20.3, -10.9),
    "Viet Nam": (14.1, 108.3),
    "Central African Republic": (6.6, 20.9),
    "Guinea": (10.4, -10.9),
    "El Salvador": (13.8, -88.9),
    "Romania": (45.9, 24.9),
    "Algeria": (28.0, 1.7),
    "Guinea-Bissau": (12.0, -14.9),
    "Chile": (-35.7, -71.5),
    "Sierra Leone": (8.5, -11.8),
    "Fiji": (-17.8, 178.0),
    "Haiti": (18.9, -72.3),
    "Nicaragua": (13.0, -85.0),
    "Sri Lanka": (7.9, 80.7),
    "France": (46.2, 2.2),
    "Tajikistan": (38.9, 71.3),
    "Gabon": (-0.8, 11.6),
    "Türkiye": (39.0, 35.2),
    "North Macedonia": (41.6, 21.7),
    "Burkina Faso": (12.2, -1.6),
    "Ghana": (7.9, -1.2),
    "Indonesia": (-0.8, 113.9),
    "Panama": (8.5, -80.8),
    "Congo": (-0.2, 15.8),
    "Papua New Guinea": (-6.3, 143.9),
    "Italy": (41.9, 12.7),
    "Greece": (39.1, 22.9),
    "Costa Rica": (9.7, -84.2),
    "Belize": (17.2, -88.5)
}

missing_coords = sorted(set(df['country']) - set(country_coords.keys()))
if missing_coords:
    print("WARNING: No coordinates for:", missing_coords)
else:
    print('good')

good


In [14]:
import requests
import pandas as pd

def fetch_nasa_power(lat, lon, start=2001, end=2023):
    """
    Fetch NASA POWER monthly climate data in JSON format.
    Aggregate to annual means/sums.
    Returns: DataFrame(year, t2m, precip, rad, rh2m)
    """

    url = (
        "https://power.larc.nasa.gov/api/temporal/monthly/point"
        f"?start={start}&end={end}"
        f"&latitude={lat}&longitude={lon}"
        "&parameters=T2M,PRECTOTCORR,ALLSKY_SFC_SW_DWN,RH2M"
        "&community=AG"
        "&format=JSON"
    )

    # Fetch JSON
    r = requests.get(url)
    r.raise_for_status()
    data = r.json()

    # Monthly values are under properties → parameter → variable → {YYYYMM: value}
    params = data["properties"]["parameter"]

    # Build a monthly table
    records = []
    for var_name, monthly_dict in params.items():
        for ym_str, value in monthly_dict.items():
            year = int(ym_str[:4])
            month = int(ym_str[4:6])
            records.append({
                "year": year,
                "month": month,
                "variable": var_name,
                "value": value
            })

    df = pd.DataFrame(records)

    # Pivot into columns: T2M, PRECTOTCORR, ...
    df = df.pivot_table(index=["year", "month"], columns="variable", values="value").reset_index()

    # Aggregate monthly → yearly
    df_yearly = df.groupby("year").agg({
        "T2M": "mean",
        "ALLSKY_SFC_SW_DWN": "mean",
        "RH2M": "mean",
        "PRECTOTCORR": "sum"  # precip should be summed across months
    }).reset_index()

    # Rename columns
    df_yearly = df_yearly.rename(columns={
        "T2M": "t2m",
        "PRECTOTCORR": "precip",
        "ALLSKY_SFC_SW_DWN": "rad",
        "RH2M": "rh2m"
    })

    return df_yearly


In [24]:
# lat, lon = 31.8, -7.1  # Morocco from your coordinates
# test_df = fetch_nasa_power(lat, lon, start=2001, end=2023)
# print(test_df.head())
# print(test_df.tail())
# print(test_df.columns)
#UNIT TESTING THI, later commented

In [17]:

# ---------- 5. Build climate table for all selected countries ----------

climate_rows = []
for country in selected_countries:
    if country not in country_coords:
        print("Skipping (no coords):", country)
        continue
    lat, lon = country_coords[country]
    print("Fetching NASA POWER for:", country, lat, lon)
    try:
        cdf = fetch_nasa_power(lat, lon, start=2001, end=2023)
        cdf['country'] = country
        climate_rows.append(cdf)
        sleep(0.3)  # be nice to API
    except Exception as e:
        print("Error fetching", country, ":", e)

climate_df = pd.concat(climate_rows, ignore_index=True)
print("Climate rows:", len(climate_df))
print(climate_df.head())



Fetching NASA POWER for: Morocco 31.8 -7.1
Fetching NASA POWER for: Uruguay -32.5 -56.0
Fetching NASA POWER for: Peru -9.2 -75.0
Fetching NASA POWER for: Angola -12.5 18.5
Fetching NASA POWER for: Madagascar -19.0 47.0
Fetching NASA POWER for: Argentina -34.0 -64.0
Fetching NASA POWER for: Niger 17.6 8.1
Fetching NASA POWER for: Australia -25.0 133.0
Fetching NASA POWER for: Thailand 15.0 101.0
Fetching NASA POWER for: Bangladesh 23.7 90.4
Fetching NASA POWER for: Kenya 0.0 37.9
Fetching NASA POWER for: Bhutan 27.4 90.4
Fetching NASA POWER for: Mali 17.3 -3.4
Fetching NASA POWER for: Bolivia (Plurinational State of) -16.7 -64.7
Fetching NASA POWER for: Myanmar 21.9 95.9
Fetching NASA POWER for: Brazil -10.0 -55.0
Fetching NASA POWER for: Pakistan 30.4 69.4
Fetching NASA POWER for: Burundi -3.4 29.9
Fetching NASA POWER for: Somalia 5.2 45.5
Fetching NASA POWER for: Cameroon 5.7 12.7
Fetching NASA POWER for: United Republic of Tanzania -6.3 35.0
Fetching NASA POWER for: China 35.9 104.2


In [19]:

# ---------- 6. Merge FAOSTAT yields with climate ----------

merged = pd.merge(
    df,
    climate_df,
    on=['country', 'year'],
    how='inner'
)

print("Merged rows:", len(merged))
print(merged.head())


Merged rows: 8299
       country          crop  year   Unit  yield_kg_ha       t2m        rad  \
0  Afghanistan  Maize (corn)  2001  kg/ha       2000.0  5.570769  21.180000   
1  Afghanistan  Maize (corn)  2002  kg/ha       2980.0  5.046154  20.841538   
2  Afghanistan  Maize (corn)  2003  kg/ha        840.0  4.636923  20.180769   
3  Afghanistan  Maize (corn)  2004  kg/ha       1600.0  5.232308  21.035385   
4  Afghanistan  Maize (corn)  2005  kg/ha       1206.9  3.538462  20.149231   

        rh2m  precip  
0  36.056154    3.26  
1  40.205385    5.49  
2  44.924615    6.76  
3  42.566154    6.56  
4  46.926154    8.85  


In [21]:
print(len(df), len(climate_df), len(merged))
merged = merged.drop(columns=['Unit'])


8299 2369 8299


In [22]:

# Optional: create a lagged yield feature (previous year's yield per country+crop)
merged = merged.sort_values(['country', 'crop', 'year'])
merged['lag_yield_kg_ha'] = merged.groupby(['country', 'crop'])['yield_kg_ha'].shift(1)

# Drop first-year rows where lag is NaN if you want clean models
merged = merged.dropna(subset=['lag_yield_kg_ha'])

# Save final dataset
merged.to_csv("yield_climate_merged.csv", index=False)
print("Saved merged data to yield_climate_merged.csv")

Saved merged data to yield_climate_merged.csv


In [23]:
print(len(merged))
print(merged.head())


7937
       country          crop  year  yield_kg_ha       t2m        rad  \
1  Afghanistan  Maize (corn)  2002       2980.0  5.046154  20.841538   
2  Afghanistan  Maize (corn)  2003        840.0  4.636923  20.180769   
3  Afghanistan  Maize (corn)  2004       1600.0  5.232308  21.035385   
4  Afghanistan  Maize (corn)  2005       1206.9  3.538462  20.149231   
5  Afghanistan  Maize (corn)  2006       2620.4  4.574615  20.161538   

        rh2m  precip  lag_yield_kg_ha  
1  40.205385    5.49           2000.0  
2  44.924615    6.76           2980.0  
3  42.566154    6.56            840.0  
4  46.926154    8.85           1600.0  
5  48.684615   12.50           1206.9  
