# Generate and Save All Crop Maps

**Goal:** This notebook automates the process of creating and saving high-quality maps for every available crop dataset. It will generate:
1.  A single map of the average yield (1981-2016) for each crop.
2.  A full series of yearly maps (1981-2016) for each crop.

All files will be saved in the `reports/figures/` directory, organized by crop name.

In [1]:
# Cell 1: The Complete Map Generation Script
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob
import re
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import os

def _plot_yield_map(data, title, file_path, **plot_kwargs):
    """Draw ``data`` on a PlateCarree world map and save it to ``file_path`` (300 dpi)."""
    fig, ax = plt.subplots(figsize=(15, 8), subplot_kw={'projection': ccrs.PlateCarree()})
    try:
        ax.add_feature(cfeature.LAND, edgecolor='gray', facecolor='#f0f0f0', linewidth=0.5)
        ax.add_feature(cfeature.BORDERS, linestyle=':', edgecolor='gray', linewidth=0.5)
        ax.coastlines(linewidth=0.5)

        data.plot(ax=ax, cmap='viridis', transform=ccrs.PlateCarree(),
                  cbar_kwargs={'shrink': 0.7}, **plot_kwargs)
        ax.set_title(title)

        # Save through the figure object rather than pyplot's implicit "current figure".
        fig.savefig(file_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def save_all_maps_for_crop(crop_name: str, data_dir: str = '../data/', output_dir: str = '../reports/figures/') -> None:
    """
    Loads a crop dataset, then generates and saves its average map and all yearly maps.

    The files matching ``<data_dir>/<crop_name>/yield_*.nc4`` are concatenated along
    a new ``time`` axis whose labels are the 4-digit years parsed from the file names.
    If the data cannot be loaded, an error is printed and the function returns
    without creating any output.

    Args:
        crop_name (str): The name of the folder in the data directory (e.g., 'wheat_winter').
        data_dir (str): Path to the main data folder.
        output_dir (str): Path to the main output folder for figures.
    """
    print(f"--- Processing Dataset: {crop_name.upper()} ---")

    # --- 1. Load Data ---
    data_path = os.path.join(data_dir, crop_name, 'yield_*.nc4')
    filepaths = sorted(glob.glob(data_path))
    if not filepaths:
        print(f"ERROR: Could not load data for '{crop_name}'. Skipping. Reason: no files match '{data_path}'")
        return

    opened_ds = None
    try:
        years = []
        for path in filepaths:
            match = re.search(r'(\d{4})\.nc4$', path)
            if match is None:
                raise ValueError(f"cannot parse a 4-digit year from file name '{path}'")
            years.append(int(match.group(1)))

        # Open the exact list the years were parsed from, so that file order and
        # year labels cannot drift apart.
        opened_ds = xr.open_mfdataset(filepaths, combine='nested', concat_dim='time')
        ds = opened_ds.assign_coords(time=years)
        yield_data = ds['var']
        print(f"Loaded {len(years)} years of data.")
    except Exception as e:
        # Deliberately broad: one unreadable crop must not abort the whole batch.
        if opened_ds is not None:
            opened_ds.close()
        print(f"ERROR: Could not load data for '{crop_name}'. Skipping. Reason: {e}")
        return

    try:
        # --- 2. Create Output Directories ---
        crop_output_dir = os.path.join(output_dir, crop_name)
        yearly_output_dir = os.path.join(crop_output_dir, 'yearly')
        os.makedirs(yearly_output_dir, exist_ok=True)  # Creates both parent and child directories

        # --- 3. Generate and Save Average Map ---
        print("Generating average map...")
        time_mean = yield_data.mean(dim='time')  # computed once, reused for the mask
        mean_yield_map = time_mean.where(time_mean > 0)

        avg_file_path = os.path.join(crop_output_dir, f'average_yield_{crop_name}.png')
        _plot_yield_map(
            mean_yield_map,
            f'Average Yield Map ({min(years)}-{max(years)}): {crop_name}',
            avg_file_path,
            robust=True,
        )
        print(f"Saved average map to {avg_file_path}")

        # --- 4. Generate and Save Yearly Maps ---
        print("Generating yearly maps...")
        # Shared colour limits for every year. They are taken from the same
        # masked (> 0) values that are drawn, so cells hidden on the maps do not
        # skew the scale. Both percentiles come from a single pass over the data.
        plotted_values = yield_data.where(yield_data > 0)
        limits = plotted_values.quantile([0.02, 0.98]).compute()
        vmin, vmax = (float(v) for v in limits.values)

        for year in years:
            one_year = yield_data.sel(time=year)
            yield_for_one_year = one_year.where(one_year > 0)

            yearly_file_path = os.path.join(yearly_output_dir, f'yield_{crop_name}_{year}.png')
            _plot_yield_map(
                yield_for_one_year,
                f'Yield in {year}: {crop_name}',
                yearly_file_path,
                vmin=vmin,
                vmax=vmax,
            )

        print(f"Finished saving yearly maps for {crop_name}.")
        print("-" * 50 + "\n")
    finally:
        # Release the NetCDF file handles before the next crop is processed.
        opened_ds.close()

# --- Main Execution ---
# Automatically find all the crop folders in the data directory.
# Hidden folders (e.g. '.ipynb_checkpoints') are not crop datasets, and
# os.listdir returns entries in arbitrary order, so filter and sort them to
# get a reproducible run.
base_data_dir = '../data/'
crop_folders = sorted(
    entry for entry in os.listdir(base_data_dir)
    if not entry.startswith('.') and os.path.isdir(os.path.join(base_data_dir, entry))
)

print(f"Found {len(crop_folders)} crop datasets to process: {crop_folders}\n")

# Loop through each folder and generate the maps. The data directory is passed
# explicitly so discovery and loading always look at the same place.
for crop in crop_folders:
    save_all_maps_for_crop(crop, data_dir=base_data_dir)

print("--- All Done! ---")

Found 10 crop datasets to process: ['maize', 'maize_major', 'maize_second', 'rice', 'rice_major', 'rice_second', 'soybean', 'wheat', 'wheat_spring', 'wheat_winter']

--- Processing Dataset: MAIZE ---
Loaded 36 years of data.
Generating average map...
Saved average map to ../reports/figures/maize\average_yield_maize.png
Generating yearly maps...
Finished saving yearly maps for maize.
--------------------------------------------------

--- Processing Dataset: MAIZE_MAJOR ---
Loaded 36 years of data.
Generating average map...
Saved average map to ../reports/figures/maize_major\average_yield_maize_major.png
Generating yearly maps...
Finished saving yearly maps for maize_major.
--------------------------------------------------

--- Processing Dataset: MAIZE_SECOND ---
Loaded 36 years of data.
Generating average map...
Saved average map to ../reports/figures/maize_second\average_yield_maize_second.png
Generating yearly maps...
Finished saving yearly maps for maize_second.
------------------

# Analysis of Combined Crop Yields

**Goal:** To create a single dataset representing the total yield of the four major crops (maize, rice, wheat, and soybean) for each grid cell and year. This will give us a macro view of overall agricultural productivity.

**Methodology:**
We will combine the top-level datasets for each crop (`maize`, `rice`, `wheat`, `soybean`). We are intentionally *not* using the seasonal datasets (e.g., `maize_major`, `maize_second`) in this sum to avoid double-counting yields in regions with multiple harvests. We will sum the yield values at each point in space and time.

In [2]:
# Cell 2: Combine, Analyze, and Save Total Yield Maps

# --- Configuration ---
# Define which crop folders to combine. We use the main ones to avoid double-counting.
CROPS_TO_COMBINE = ['maize', 'rice', 'wheat', 'soybean']
DATA_DIR = '../data/'
OUTPUT_DIR = '../reports/figures/combined_total_yield/'
os.makedirs(os.path.join(OUTPUT_DIR, 'yearly'), exist_ok=True)

print("--- Starting Combined Yield Analysis ---")

# --- 1. Load and Combine Datasets ---
all_crop_dataarrays = []
for crop in CROPS_TO_COMBINE:
    print(f"Loading {crop}...")
    data_path = os.path.join(DATA_DIR, crop, 'yield_*.nc4')
    try:
        filepaths = sorted(glob.glob(data_path))
        if not filepaths:
            raise FileNotFoundError(f"no files match '{data_path}'")

        # Label the time axis with the years found in the file names (same
        # convention as Cell 1) instead of assuming a fixed 1981-2016 range.
        crop_years = []
        for path in filepaths:
            match = re.search(r'(\d{4})\.nc4$', path)
            if match is None:
                raise ValueError(f"cannot parse a 4-digit year from file name '{path}'")
            crop_years.append(int(match.group(1)))

        ds = xr.open_mfdataset(filepaths, combine='nested', concat_dim='time')
        # We only need the data variable, not the full dataset
        all_crop_dataarrays.append(ds['var'].assign_coords(time=crop_years))
    except Exception as e:
        # Deliberately broad: a crop that cannot be read is left out of the sum.
        print(f"Could not load {crop}, skipping. Reason: {e}")

if not all_crop_dataarrays:
    # Actually stop here; continuing would only fail later with a NameError
    # because `combined_yield` was never created.
    raise RuntimeError("No datasets were loaded. Stopping.")

# Use xarray's concat and sum methods to combine them. Because every array now
# carries a year-labelled time coordinate, the crops are aligned by year.
# We fill any missing values (NaNs) with 0 so the sum works correctly.
# For example, a grid cell with wheat but no maize should be treated as wheat_yield + 0.
combined_yield = xr.concat(all_crop_dataarrays, dim='crop').fillna(0).sum(dim='crop')

# Years actually present in the combined data.
years = combined_yield['time'].values.tolist()

print("\nAll datasets combined successfully.")


# --- 2. Generate and Save the Average Combined Map ---
print("Generating average combined map...")

# Average over time, then hide cells whose mean is not positive.
time_mean = combined_yield.mean(dim='time')
mean_combined_map = time_mean.where(time_mean > 0)

# Base map: land fill, country borders and coastlines on a PlateCarree axis.
fig, ax = plt.subplots(
    figsize=(15, 8),
    subplot_kw={'projection': ccrs.PlateCarree()},
)
ax.add_feature(cfeature.LAND, edgecolor='gray', facecolor='#f0f0f0', linewidth=0.5)
ax.add_feature(cfeature.BORDERS, linestyle=':', edgecolor='gray', linewidth=0.5)
ax.coastlines(linewidth=0.5)

# Yield layer; robust=True lets xarray pick percentile-based colour limits.
mean_combined_map.plot(
    ax=ax,
    cmap='viridis',
    robust=True,
    transform=ccrs.PlateCarree(),
    cbar_kwargs={'shrink': 0.7},
)
ax.set_title('Average Combined Yield (Maize, Rice, Wheat, Soybean) 1981-2016')

avg_file_path = os.path.join(OUTPUT_DIR, 'average_combined_yield.png')
fig.savefig(avg_file_path, dpi=300, bbox_inches='tight')
plt.close(fig)
print(f"Saved average combined map to {avg_file_path}")

# --- 3. Generate and Save Yearly Combined Maps ---
print("Generating yearly combined maps...")
# Shared colour limits for every year. They are taken from the same masked
# (> 0) values that are drawn on the maps: the combined array has had its
# missing values filled with 0, and those zeros would otherwise enter the
# percentiles even though they are never plotted. Both percentiles come from
# a single pass over the data.
plotted_values = combined_yield.where(combined_yield > 0)
color_limits = plotted_values.quantile([0.02, 0.98]).compute()
vmin, vmax = (float(v) for v in color_limits.values)
yearly_output_dir = os.path.join(OUTPUT_DIR, 'yearly')

for year in combined_yield['time'].values:
    one_year = combined_yield.sel(time=year)
    yield_for_one_year = one_year.where(one_year > 0)

    fig, ax = plt.subplots(figsize=(15, 8), subplot_kw={'projection': ccrs.PlateCarree()})
    try:
        ax.add_feature(cfeature.LAND, edgecolor='gray', facecolor='#f0f0f0', linewidth=0.5)
        ax.add_feature(cfeature.BORDERS, linestyle=':', edgecolor='gray', linewidth=0.5)
        ax.coastlines(linewidth=0.5)

        yield_for_one_year.plot(ax=ax, cmap='viridis', transform=ccrs.PlateCarree(), vmin=vmin, vmax=vmax, cbar_kwargs={'shrink': 0.7})
        ax.set_title(f'Combined Yield in {year}')

        yearly_file_path = os.path.join(yearly_output_dir, f'combined_yield_{year}.png')
        # Save through the figure object rather than pyplot's implicit "current figure".
        fig.savefig(yearly_file_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

print(f"Finished saving all yearly combined maps to {yearly_output_dir}")
print("\n--- All Done! ---")

--- Starting Combined Yield Analysis ---
Loading maize...
Loading rice...
Loading wheat...
Loading soybean...

All datasets combined successfully.
Generating average combined map...
Saved average combined map to ../reports/figures/combined_total_yield/average_combined_yield.png
Generating yearly combined maps...
Finished saving all yearly combined maps to ../reports/figures/combined_total_yield/yearly

--- All Done! ---
