In [3]:
# merge_and_correlate.py

import numpy as np
from shapely.geometry import box
from rasterio.transform import from_bounds
from scipy.stats import pearsonr
from tqdm import tqdm
from data.dataset_v3 import ImageMaskDataset
from sentence_transformers import SentenceTransformer
from collections import defaultdict
import rasterio
from rasterio.transform import Affine
import os
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import rasterio
from rasterio.crs import CRS

import os
import random
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import rasterio
from rasterio.transform import from_bounds
from rasterio.crs import CRS

def merge_masks(
    dataset,
    resolution=100,
    output_dir="/project/biocomplexity/wyr6fx(Nibir)/NeurIPS_Irrigation_Mapping_Model/Results",
    crs="EPSG:5070",
    num_irr_classes=4,
    state='Arizona',
    max_patches=20000,
    seed=42,
    num_crop_classes=255,
):
    """Mosaic per-patch crop and irrigation masks into state-wide GeoTIFFs.

    Up to ``max_patches`` samples whose ``irr_mask`` is not None are picked in a
    seeded random order and pasted onto one grid that spans the union of their
    ``polygon`` bounds. Three rasters are written to ``output_dir``:

    * ``irrigation_majority_{state}.tif`` (uint8): per pixel, the irrigation
      class with the most votes across overlapping patches (ties -> lowest class;
      pixels covered by no patch are 0).
    * ``valid_mask_{state}.tif`` (uint8): 1 where at least one patch was pasted.
    * ``crop_distribution_{state}.tif`` (float32, ``num_crop_classes`` bands):
      band ``k`` (1-based) holds the fraction of overlapping patches whose crop
      value ``v`` satisfies ``round(v * 255) == k``.

    Args:
        dataset: indexable dataset of dicts with ``'irr_mask'`` (tensor or None),
            ``'crop_mask'`` (tensor) and ``'polygon'`` (object with ``.bounds``).
            Presumably an ImageMaskDataset; masks are converted via ``.numpy()``.
        resolution: output pixel size in CRS units. NOTE(review): patches are
            pasted 1:1, so this assumes each mask pixel also covers
            ``resolution`` units — confirm against the dataset's pixel size.
        output_dir: directory the GeoTIFFs are written to (created if missing).
        crs: CRS string written into the GeoTIFFs.
        num_irr_classes: irrigation classes 0..num_irr_classes-1 that are voted on.
        state: suffix used in the output file names and log messages.
        max_patches: upper bound on the number of samples merged.
        seed: seed for the sample-selection order.
        num_crop_classes: number of crop bands; crop types above it are ignored.

    Raises:
        ValueError: if no sample carries an irrigation mask.
    """
    from rasterio.transform import from_origin

    # Step 1: scan samples in a seeded random order and keep up to `max_patches`
    # that have an irrigation mask. A private RNG leaves the notebook's global
    # `random` state untouched while producing the same permutation as
    # `random.seed(seed); random.shuffle(indices)`.
    indices = list(range(len(dataset)))
    random.Random(seed).shuffle(indices)

    valid_samples = []
    for i in tqdm(indices, desc="Scanning and selecting samples"):
        if len(valid_samples) >= max_patches:
            break
        sample = dataset[i]
        if sample['irr_mask'] is None:
            continue
        valid_samples.append((sample, sample['polygon'].bounds))

    print(f"\n✅ Selected {len(valid_samples)} valid samples (out of {len(dataset)})")
    if not valid_samples:
        raise ValueError(f"No samples with an irrigation mask found for {state}; nothing to merge.")

    # Step 2: global grid covering the union of all selected patch bounds.
    bounds_array = np.array([bounds for _, bounds in valid_samples])
    xmin, ymin = bounds_array[:, 0].min(), bounds_array[:, 1].min()
    xmax, ymax = bounds_array[:, 2].max(), bounds_array[:, 3].max()

    width = int(np.ceil((xmax - xmin) / resolution))
    height = int(np.ceil((ymax - ymin) / resolution))
    # Pixels are exactly `resolution` wide, anchored at the top-left corner, so
    # the georeferencing matches the `/ resolution` offsets used to paste patches.
    # (from_bounds would shrink pixels to fit the ceil'd grid and drift by up to
    # one pixel across the mosaic.)
    transform = from_origin(xmin, ymax, resolution, resolution)
    os.makedirs(output_dir, exist_ok=True)

    # Step 3: accumulate per-pixel crop counts, irrigation votes and coverage.
    crop_classes = np.zeros((num_crop_classes, height, width), dtype=np.float32)
    irr_stack = np.zeros((num_irr_classes, height, width), dtype=np.uint32)
    count_raster = np.zeros((height, width), dtype=np.uint32)

    merged = 0
    for sample, bounds in tqdm(valid_samples, desc="Merging patches"):
        crop = sample['crop_mask'].squeeze().numpy()
        irr = sample['irr_mask'].squeeze().numpy().astype(np.uint8)

        h, w = crop.shape
        # Offsets are non-negative: the grid origin is the min/max of these bounds.
        x0 = int((bounds[0] - xmin) / resolution)
        y0 = int((ymax - bounds[3]) / resolution)
        y1 = min(y0 + h, height)
        x1 = min(x0 + w, width)
        if y1 <= y0 or x1 <= x0:
            continue

        crop = crop[:y1 - y0, :x1 - x0]
        irr = irr[:y1 - y0, :x1 - x0]

        for val in np.unique(crop):
            # Crop values are presumably crop ids scaled by 1/255; 0 is background.
            if val <= 0.0 or val > 1.0:
                continue
            crop_type = int(round(val * 255))
            if crop_type == 0 or crop_type > num_crop_classes:
                continue
            crop_classes[crop_type - 1, y0:y1, x0:x1] += (crop == val)  # band index is 0-based

        for c in range(num_irr_classes):
            irr_stack[c, y0:y1, x0:x1] += (irr == c)

        count_raster[y0:y1, x0:x1] += 1
        merged += 1

    # Step 4: post-process outputs.
    irr_raster = np.argmax(irr_stack, axis=0).astype(np.uint8)
    valid_mask = (count_raster > 0).astype(np.uint8)
    count_raster[count_raster == 0] = 1  # avoid division by zero on uncovered pixels

    # Divide in place: `float32 / uint32` would otherwise allocate a full
    # float64 copy of the (num_crop_classes, H, W) stack before casting back.
    crop_classes /= count_raster
    crop_stack = crop_classes

    crs_obj = CRS.from_string(crs)

    def save_raster(path, data, count=1, dtype="uint8"):
        # Write `data` as a GeoTIFF on the mosaic grid: a 2-D array for a single
        # band, or a (count, H, W) array otherwise.
        with rasterio.open(
            path, "w", driver="GTiff",
            height=height, width=width, count=count,
            dtype=dtype, transform=transform, crs=crs_obj
        ) as dst:
            if count == 1:
                dst.write(data, 1)
            else:
                dst.write(data)

    save_raster(os.path.join(output_dir, f"irrigation_majority_{state}.tif"), irr_raster, count=1, dtype="uint8")
    save_raster(os.path.join(output_dir, f"valid_mask_{state}.tif"), valid_mask, count=1, dtype="uint8")
    save_raster(os.path.join(output_dir, f"crop_distribution_{state}.tif"), crop_stack, count=num_crop_classes, dtype="float32")

    print(f"\n✅ Merged and saved {merged} patches for {state}")
    print(f"✅ Files saved to: {output_dir}")



import rasterio
import numpy as np
from scipy.stats import pearsonr

def load_and_compute_crop_irr_correlation(
    crop_path: str,
    irr_path: str,
    valid_path: str,
    num_irr_classes: int = 4
):
    """
    Correlate every crop-fraction band with one-hot irrigation classes.

    Band ``i`` (1-based) of ``crop_path`` is treated as crop id ``i``. For each
    band and each irrigation class ``c`` in ``range(num_irr_classes)``, the
    Pearson correlation between the band values and the indicator
    ``irr == c`` is computed over pixels where the valid mask is non-zero.

    Args:
        crop_path: multi-band crop-distribution GeoTIFF.
        irr_path: single-band irrigation-class GeoTIFF on the same grid.
        valid_path: single-band mask GeoTIFF; non-zero pixels are used.
        num_irr_classes: irrigation classes to correlate against.

    Returns:
        Nested dict ``correlations[crop_id][irr_class] = pearson_corr``.
        Degenerate cases (constant band, constant class indicator, or fewer
        than two valid pixels) yield 0.0.
    """
    # --- Load categorical irrigation raster and valid pixel mask ---
    with rasterio.open(irr_path) as src:
        irr_raster = src.read(1)  # shape: (H, W)

    with rasterio.open(valid_path) as src:
        valid_mask = src.read(1).astype(bool)

    irr_flat = irr_raster[valid_mask]  # boolean indexing already yields 1-D
    enough_pixels = irr_flat.size >= 2  # pearsonr needs at least two samples

    # The class indicators do not depend on the crop band: build them once.
    irr_binaries = {c: (irr_flat == c).astype(np.float32) for c in range(num_irr_classes)}
    irr_constant = {c: (not enough_pixels) or np.std(b) == 0 for c, b in irr_binaries.items()}

    correlations = {}
    # --- Stream crop bands one at a time: the full stack can be many GB ---
    with rasterio.open(crop_path) as src:
        for crop_id in range(1, src.count + 1):
            crop_flat = src.read(crop_id)[valid_mask]
            crop_constant = (not enough_pixels) or np.std(crop_flat) == 0

            crop_corrs = {}
            for irr_class, irr_binary in irr_binaries.items():
                if crop_constant or irr_constant[irr_class]:
                    crop_corrs[irr_class] = 0.0  # fallback for degenerate case
                else:
                    crop_corrs[irr_class] = pearsonr(crop_flat, irr_binary)[0]

            correlations[crop_id] = crop_corrs

    return correlations



def compute_text_irrigation_correlation(dataset, model_name="all-MiniLM-L6-v2", batch_size=64):
    """
    Correlate each sentence-embedding dimension of the text prompts with irrigated fraction.

    For every sample whose ``irr_mask`` is not None, ``sample['text_prompt']``
    is embedded with a SentenceTransformer. The label is the fraction of
    ``irr_mask`` pixels with a value > 0. A Pearson correlation is then computed
    between each embedding dimension and the labels across samples.

    Args:
        dataset: indexable dataset yielding dicts with ``'irr_mask'`` (tensor or
            None) and ``'text_prompt'`` — presumably an ImageMaskDataset.
        model_name: SentenceTransformer model to load.
        batch_size: number of prompts encoded per forward pass.

    Returns:
        ``(correlations, mean_abs)``: per-dimension Pearson r (0.0 where the
        dimension or the labels are constant) and the mean absolute value.

    Raises:
        ValueError: if fewer than two samples carry an irrigation mask.
    """
    model = SentenceTransformer(model_name)  # load first so a bad name fails fast
    texts = []
    labels = []

    for i in tqdm(range(len(dataset)), desc="Text-Irrigation correlation"):
        sample = dataset[i]
        if sample['irr_mask'] is None:
            continue
        texts.append(sample['text_prompt'])
        irr = sample['irr_mask'].squeeze().numpy()
        labels.append(np.mean(irr > 0))

    if len(texts) < 2:
        raise ValueError(
            f"Need at least two samples with an irrigation mask to correlate; got {len(texts)}."
        )

    # Encode all prompts in batches instead of one forward pass per sample.
    embeddings = np.asarray(model.encode(texts, batch_size=batch_size))
    labels = np.asarray(labels)

    # Constant inputs make Pearson r undefined (pearsonr returns nan, which would
    # poison the mean); fall back to 0.0 like the raster correlation helpers.
    labels_constant = np.std(labels) == 0
    correlations = [
        0.0 if labels_constant or np.std(embeddings[:, d]) == 0
        else pearsonr(embeddings[:, d], labels)[0]
        for d in range(embeddings.shape[1])
    ]
    return correlations, np.mean(np.abs(correlations))

import pandas as pd

def save_correlation_to_csv(correlations: dict, output_path: str):
    """
    Write a nested correlation mapping to a CSV file.

    ``correlations`` maps a row key to ``{irrigation class: correlation}``.
    Row keys become the ``CropID`` index column, and each inner key ``c``
    becomes a column named ``IrrClass_{c}``.
    """
    table = (
        pd.DataFrame.from_dict(correlations, orient='index')
        .rename(columns=lambda irr_class: f"IrrClass_{irr_class}")
        .rename_axis("CropID")
    )
    table.to_csv(output_path)
    print(f"✅ Saved correlation CSV to: {output_path}")

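# NOTE: stale example kept for reference only. `find_crop_irr_raster` and
# `compute_crop_irr_correlation` are not defined in the visible cells, and
# `merge_masks` returns None rather than the tuple unpacked below.
# if __name__ == "__main__":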
#     data_dir = "/project/biocomplexity/wyr6fx(Nibir)/NeurIPS_irrigation_data/Train-Test-Split"
#     dataset = ImageMaskDataset(
#         data_dir=data_dir,
#         states=[('Arizona', 1.0)],
#         train_type='cross-state',
#         split='train',
#         transform=False,
#         vision_indices=['image']
#     )

#     width, height, transform, valid_samples, xmin, ymin, xmax, ymax, resolution = merge_masks(dataset)
#     crop_rasters, irr_raster, valid_mask = find_crop_irr_raster(width, height, transform, valid_samples, xmin, ymin, xmax, ymax, resolution)
#     crop_irr_corr = compute_crop_irr_correlation(crop_rasters, irr_raster, valid_mask)
#     print("\n✅ Crop-Irrigation Correlation per Crop Type:")
#     for crop_id, corr in crop_irr_corr.items():
#         print(f"  Crop {crop_id:3d} → Corr: {corr:.4f}")

#     correlations, mean_corr = compute_text_irrigation_correlation(dataset)
#     print(f"\n📝 Mean absolute correlation (text → irrigation): {mean_corr:.4f}")


In [2]:
data_dir = "/project/biocomplexity/wyr6fx(Nibir)/NeurIPS_irrigation_data/Train-Test-Split"
state = 'Georgia'

# Georgia training split of the cross-state setup, image channels only.
# (state, 1.0): the second element is presumably the sampling fraction — confirm.
dataset = ImageMaskDataset(
    data_dir=data_dir,
    states=[(state, 1.0)],
    train_type='cross-state',
    split='train',
    transform=False,
    vision_indices=['image'],
)

# Writes irrigation_majority_/valid_mask_/crop_distribution_{state}.tif to merge_masks' default output_dir.
merge_masks(dataset, state=state)

Scanning and selecting samples: 100%|██████████| 7083/7083 [11:39<00:00, 10.12it/s]



✅ Selected 7083 valid samples (out of 7083)


Merging patches: 100%|██████████| 7083/7083 [00:24<00:00, 291.67it/s]



✅ Merged and saved 7083 patches for Georgia
✅ Files saved to: /project/biocomplexity/wyr6fx(Nibir)/NeurIPS_Irrigation_Mapping_Model/Results


In [5]:
state = 'Georgia'

# Inputs produced by merge_masks for this state, and the per-crop correlation table to write.
crop_path = f'/project/biocomplexity/wyr6fx(Nibir)/NeurIPS_Irrigation_Mapping_Model/Results/crop_distribution_{state}.tif'
irr_path = f'/project/biocomplexity/wyr6fx(Nibir)/NeurIPS_Irrigation_Mapping_Model/Results/irrigation_majority_{state}.tif'
valid_path = f'/project/biocomplexity/wyr6fx(Nibir)/NeurIPS_Irrigation_Mapping_Model/Results/valid_mask_{state}.tif'
output_path = f'/project/biocomplexity/wyr6fx(Nibir)/NeurIPS_Irrigation_Mapping_Model/Results/crop_irr_corr_{state}.csv'

corr = load_and_compute_crop_irr_correlation(
    crop_path=crop_path,
    irr_path=irr_path,
    valid_path=valid_path,
)
save_correlation_to_csv(corr, output_path)

✅ Saved correlation CSV to: /project/biocomplexity/wyr6fx(Nibir)/NeurIPS_Irrigation_Mapping_Model/Results/crop_irr_corr_Georgia.csv


In [4]:
###### crop-irr correlation

from collections import defaultdict

from collections import defaultdict

# Crop-id groups used for grouped crop/irrigation correlation. Ids are the
# 1-based band numbers of the crop-distribution raster; dict order is the row
# order of the resulting table.
# NOTE(review): groups overlap — "DoubleCrop/Other" (26-66) also contains every
# id of Grains 27-30, Oilseeds 31-35, Forage 36 and 58-60, Legumes 37-39,
# Specialty 41-44 and Vegetables 45-57, so those crops are counted twice.
# Confirm the intended crop-code mapping before interpreting group results.
CROP_GROUPS = {
    "Grains": [1, 3, 4, *range(21, 25), *range(27, 31)],
    "Oilseeds": [5, 6, *range(31, 36)],
    "Legumes": [10, *range(37, 40)],
    "Vegetables": [12, *range(45, 58)],
    "Orchards": [*range(68, 81)],
    "Forage": [36, *range(58, 61)],
    "Specialty": [2, 11, 13, 14, *range(41, 45)],
    "DoubleCrop/Other": [*range(26, 67), *range(81, 95)],  # double-crop combinations and other mixed classes
}

def load_and_compute_group_crop_irr_correlation(
    crop_path: str,
    irr_path: str,
    valid_path: str,
    crop_groups: dict = CROP_GROUPS,
    num_irr_classes: int = 4
):
    """
    Correlate grouped crop fractions with one-hot irrigation class maps.

    For each group, the bands of ``crop_path`` listed in ``crop_groups``
    (band number == crop id, 1-based) are summed over valid pixels. The Pearson
    correlation between that group total and the indicator ``irr == c`` is then
    computed for each ``c`` in ``range(num_irr_classes)``. Pearson r is
    scale-invariant, so summing vs. averaging the member bands gives the same r.

    Args:
        crop_path: multi-band crop-distribution GeoTIFF.
        irr_path: single-band irrigation-class GeoTIFF on the same grid.
        valid_path: single-band mask GeoTIFF; non-zero pixels are used.
        crop_groups: mapping group name -> list of crop ids (band numbers).
        num_irr_classes: irrigation classes to correlate against.

    Returns:
        dict[group_name][irr_class] = pearson_corr. Groups with no band present
        in the raster (including empty groups) are omitted. Degenerate cases
        (constant input, or fewer than two valid pixels) yield 0.0.
    """
    import rasterio
    from scipy.stats import pearsonr
    import numpy as np

    with rasterio.open(irr_path) as src:
        irr_raster = src.read(1)

    with rasterio.open(valid_path) as src:
        valid_mask = src.read(1).astype(bool)

    irr_flat = irr_raster[valid_mask]  # boolean indexing already yields 1-D
    enough_pixels = irr_flat.size >= 2  # pearsonr needs at least two samples

    # Class indicators are shared by every group: build them once.
    irr_binaries = {c: (irr_flat == c).astype(np.float32) for c in range(num_irr_classes)}
    irr_constant = {c: (not enough_pixels) or np.std(b) == 0 for c, b in irr_binaries.items()}

    group_corrs = {}
    with rasterio.open(crop_path) as src:
        num_bands = src.count
        for group_name, crop_ids in crop_groups.items():
            present = [cid for cid in crop_ids if 1 <= cid <= num_bands]
            if not present:
                # Empty group, or none of its ids exist as bands in this raster
                # (e.g. ids above the band count). Summing nothing would yield a
                # scalar 0 that cannot be masked, so skip the group.
                continue

            # Accumulate only valid pixels, one band at a time, to keep memory
            # at O(valid pixels) instead of loading the whole band stack.
            crop_flat = np.zeros(irr_flat.shape, dtype=np.float64)
            for cid in present:
                crop_flat += src.read(cid)[valid_mask]
            crop_constant = (not enough_pixels) or np.std(crop_flat) == 0

            group_corrs[group_name] = {
                c: 0.0 if crop_constant or irr_constant[c] else pearsonr(crop_flat, b)[0]
                for c, b in irr_binaries.items()
            }

    return group_corrs


In [7]:
state = 'Utah'

# Inputs produced by merge_masks for this state, and the grouped correlation table to write.
crop_path = f'/project/biocomplexity/wyr6fx(Nibir)/NeurIPS_Irrigation_Mapping_Model/Results/crop_distribution_{state}.tif'
irr_path = f'/project/biocomplexity/wyr6fx(Nibir)/NeurIPS_Irrigation_Mapping_Model/Results/irrigation_majority_{state}.tif'
valid_path = f'/project/biocomplexity/wyr6fx(Nibir)/NeurIPS_Irrigation_Mapping_Model/Results/valid_mask_{state}.tif'
output_path = f'/project/biocomplexity/wyr6fx(Nibir)/NeurIPS_Irrigation_Mapping_Model/Results/crop_irr_corr_{state}_group.csv'

corr = load_and_compute_group_crop_irr_correlation(
    crop_path=crop_path,
    irr_path=irr_path,
    valid_path=valid_path,
)
save_correlation_to_csv(corr, output_path)

✅ Saved correlation CSV to: /project/biocomplexity/wyr6fx(Nibir)/NeurIPS_Irrigation_Mapping_Model/Results/crop_irr_corr_Utah_group.csv


In [22]:
# Band/row/column layout of the crop-distribution raster at `crop_path`.
# `crop_stack` only exists inside the loader functions, so referencing it here
# relied on leaked kernel state; read the shape from the file header instead
# (no pixel data is loaded).
import rasterio

with rasterio.open(crop_path) as src:
    crop_stack_shape = (src.count, src.height, src.width)
crop_stack_shape

(69, 4986, 6358)

In [1]:
from model_v3.TeacherModel import TeacherModel
from data.data_module_v2 import IrrigationDataModule
import yaml

In [21]:
from omegaconf import OmegaConf

state = 'Florida'
cfg_path = '/project/biocomplexity/wyr6fx(Nibir)/NeurIPS_Irrigation_Mapping_Model/Output/cross-state/vision/unet/result_stats/configs/hydra_config.yaml'

# Reuse the saved run config (presumably the vision U-Net run, per the path),
# but point the dataset at `state` in unsupervised mode.
cfg = OmegaConf.load(cfg_path)
cfg.dataset.train_type = 'unsupervised'
cfg.dataset.states = [[state, 1]]



In [None]:
ckpt_path = '/sfs/gpfs/tardis/project/bii_nssac/people/wyr6fx/NeurIPS_Irrigation_Mapping_Model/Output/cross-state/vision/unet/result_stats/checkpoints/epoch=19-val_iou_macro_irr=0.819.ckpt'

# Build the data module from the edited config and prepare both stages.
data_module = IrrigationDataModule(cfg)
data_module.setup('fit')
data_module.setup('test')

# Restore the trained model; the config entries are forwarded as keyword arguments.
model = TeacherModel.load_from_checkpoint(ckpt_path, **cfg)

In [None]:
import torch
import numpy as np
import geopandas as gpd
from shapely import wkt
from rasterio.features import rasterize
from rasterio.transform import from_bounds, from_origin
from tqdm import tqdm

# --- Load census polygons for `state`, projected to EPSG:5070 ---
gdf = gpd.read_file(f'/project/biocomplexity/wyr6fx(Nibir)/NeurIPS_irrigation_data/Agcensus/{state}_Irrigation.geojson')
gdf = gdf.to_crs("EPSG:5070")

# --- Define global raster extent from GDF ---
total_bounds = gdf.total_bounds  # [xmin, ymin, xmax, ymax]
resolution = 30  # in meters
xmin, ymin, xmax, ymax = total_bounds
width = int(np.ceil((xmax - xmin) / resolution))
height = int(np.ceil((ymax - ymin) / resolution))
# Pixels are exactly `resolution` wide, anchored at the top-left corner, so the
# grid agrees with the `/ resolution` offsets used when pasting patches below
# (from_bounds would shrink pixels slightly to fit the ceil'd width/height).
transform = from_origin(xmin, ymax, resolution, resolution)

# --- Init global rasters ---
irrigation_raster = np.zeros((height, width), dtype=np.uint8)  # 1 = predicted irrigated
coverage_mask = np.zeros((height, width), dtype=np.uint8)      # 1 = inside some patch polygon

# --- Model inference + patch merging ---
# Nothing above moves `model` to a GPU, so put the model and the inputs on the
# same device explicitly (sending only the batch to 'cuda' is a device mismatch).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

# NOTE(review): train_dataloader() may shuffle and/or augment. Shuffling is
# harmless (each patch is placed by its own polygon), but geometric augmentation
# would misalign predictions — confirm it is disabled for this config.
# NOTE(review): pasting is 1:1, which assumes each prediction pixel covers
# `resolution` metres — confirm against the dataset's patch pixel size.
for batch in tqdm(data_module.train_dataloader(), desc="Merging patches"):
    polygons = batch['polygon']  # WKT strings, one per patch
    inputs = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
    with torch.no_grad():
        preds = model(inputs)['predictions'].argmax(dim=1).cpu().numpy()

    for pred, poly_wkt in zip(preds, polygons):
        patch_mask = (pred > 0).astype(np.uint8)
        poly = wkt.loads(poly_wkt)
        bounds = poly.bounds

        h, w = patch_mask.shape
        patch_transform = from_bounds(*bounds, w, h)

        # Rasterize patch shape and mask the prediction to inside the polygon.
        patch_raster = rasterize(
            [(poly, 1)],
            out_shape=(h, w),
            transform=patch_transform,
            fill=0,
            dtype=np.uint8
        )
        patch_mask = patch_mask * patch_raster

        # Global position of the patch's top-left pixel. floor() (not int()
        # truncation) keeps negative offsets correct.
        col0 = int(np.floor((bounds[0] - xmin) / resolution))
        row0 = int(np.floor((ymax - bounds[3]) / resolution))

        # Clip on all four sides: patches can extend past the census extent, and
        # a negative slice start would wrap around to the far edge of the raster.
        gx0, gy0 = max(col0, 0), max(row0, 0)
        gx1, gy1 = min(col0 + w, width), min(row0 + h, height)
        if gx1 <= gx0 or gy1 <= gy0:
            continue
        px0, py0 = gx0 - col0, gy0 - row0
        px1, py1 = px0 + (gx1 - gx0), py0 + (gy1 - gy0)

        # Union (max) with what earlier patches already wrote at these pixels.
        irrigation_raster[gy0:gy1, gx0:gx1] = np.maximum(
            irrigation_raster[gy0:gy1, gx0:gx1], patch_mask[py0:py1, px0:px1]
        )
        coverage_mask[gy0:gy1, gx0:gx1] = np.maximum(
            coverage_mask[gy0:gy1, gx0:gx1], patch_raster[py0:py1, px0:px1]
        )


In [None]:
# --- Evaluate each county polygon covered by at least one patch ---
# Each polygon is rasterized only inside its own bounding-box window rather than
# over the full state grid (31806 x 113450 here), which avoided several GB of
# allocation and multiplication per county. Pixels whose centre lies inside the
# polygon are necessarily inside its bbox, so the counts are unchanged.
from rasterio.transform import Affine

inv_transform = ~transform
valid_rows = []
irrigated_counts = []
gdf_group = gdf.groupby(['geometry', 'County'])['Irrigated Acres'].sum().reset_index()
for idx, geom in enumerate(gdf_group.geometry):
    minx, miny, maxx, maxy = geom.bounds
    # Fractional (col, row) of the bbox corners; pad by one pixel so rounding
    # can never drop a boundary pixel.
    c_left, r_top = inv_transform * (minx, maxy)
    c_right, r_bottom = inv_transform * (maxx, miny)
    col_start = max(int(np.floor(c_left)) - 1, 0)
    row_start = max(int(np.floor(r_top)) - 1, 0)
    col_stop = min(int(np.ceil(c_right)) + 1, width)
    row_stop = min(int(np.ceil(r_bottom)) + 1, height)
    if col_stop <= col_start or row_stop <= row_start:
        continue  # polygon lies entirely outside the raster, so it is not covered

    window_transform = transform * Affine.translation(col_start, row_start)
    poly_mask = rasterize(
        [(geom, 1)],
        out_shape=(row_stop - row_start, col_stop - col_start),
        transform=window_transform,
        fill=0,
        dtype=np.uint8
    )
    rows = slice(row_start, row_stop)
    cols = slice(col_start, col_stop)

    # Skip if polygon not covered by any patch
    if not np.any(poly_mask & coverage_mask[rows, cols]):
        continue

    # Predicted irrigated pixel count (not acres; one 30 m pixel is ~0.222 acres).
    irrigated_pixels = (irrigation_raster[rows, cols] * poly_mask).sum()
    valid_rows.append(idx)
    irrigated_counts.append(irrigated_pixels)

# --- Keep only covered counties and attach their predicted pixel counts ---
gdf_group = gdf_group.loc[valid_rows].copy()
gdf_group['irrigated_pixels'] = irrigated_counts

In [15]:
# (rows, cols) of the merged state-wide irrigation raster.
np.shape(irrigation_raster)

(31806, 113450)

Unnamed: 0,geometry,County,Irrigated Acres
0,"POLYGON ((-1745434.141 1227920.179, -1745419.4...",YUMA,368786
1,"POLYGON ((-1413199.63 1073336.653, -1413347.47...",PIMA,42284
2,"POLYGON ((-1381055.931 1023969.623, -1381518.2...",SANTA CRUZ,680
3,"POLYGON ((-1357944.934 1068525.433, -1357908.5...",COCHISE,165368
4,"POLYGON ((-1245816.532 1215358.67, -1245798.73...",GREENLEE,9059
5,"POLYGON ((-1336936.328 1203808.587, -1336903.0...",GRAHAM,56944
6,"POLYGON ((-1420622.549 1162051.968, -1420939.1...",PINAL,472567
7,"POLYGON ((-1431952.572 1351760.221, -1431942.3...",GILA,1633
8,"POLYGON ((-1157883.284 1547341.315, -1157900.1...",APACHE,12428
9,"POLYGON ((-1248850.321 1513458.937, -1248913.9...",NAVAJO,5346
