In [None]:
import os
import torch
from torch.utils.data import Dataset
import xarray as xr
import numpy as np
import rasterio
import torchvision.transforms as transforms
from torchvision.transforms.functional import resize
from datetime import datetime

class MultiModalDataset(Dataset):
    """Dataset that yields one randomly placed patch per Sentinel-1 date folder.

    Each item combines six modalities (Sentinel-1, Sentinel-2, MODIS, crop,
    soil, weather), every one resized to ``patch_size`` x ``patch_size``.

    NOTE(review): all rasters and NetCDF grids are assumed to share the
    Sentinel-1 pixel grid (same resolution and alignment) -- the same pixel
    offsets are used for every source. Confirm this against the data.
    """

    def __init__(self, sentinel1_dir, sentinel2_dir, modis_dir, crop_dir, soil_file, weather_dir, transform=None):
        """
        Initialize the dataset using Sentinel-1 folder dates (YYYY-MM-DD) and random 1 mile x 1 mile patches,
        including all bands and variables.

        Args:
            sentinel1_dir (str): Directory for Sentinel-1 data (YYYY-MM-DD/vv.tif, vh.tif).
            sentinel2_dir (str): Directory for Sentinel-2 data (YYYY-MM-DD/B01.tif, etc.).
            modis_dir (str): Directory for MODIS data (YYYY-MM-DD/Band1.tif, etc.).
            crop_dir (str): Directory containing crop data NetCDF files (IA_year.nc).
            soil_file (str): Path to the soil NetCDF file (e.g., IA.nc) with variables 'nccpi3all', etc.
                A path not ending in '.nc' is treated as a directory containing 'IA.nc'.
            weather_dir (str): Directory containing weather data NetCDF files (IA_year.nc).
            transform (callable, optional): Optional transform applied to each modality tensor.

        Raises:
            ValueError: If no Sentinel-1 date folder falls between April and September.
        """
        self.sentinel1_dir = sentinel1_dir
        self.sentinel2_dir = sentinel2_dir
        self.modis_dir = modis_dir
        self.crop_dir = crop_dir
        self.soil_file = soil_file if soil_file.endswith('.nc') else os.path.join(soil_file, 'IA.nc')
        self.weather_dir = weather_dir
        self.transform = transform
        self.patch_size = 224  # ViT input size
        self.mile_in_meters = 1609.34  # 1 mile in meters

        # Get week_start_dates from Sentinel-1 folder, filter for April to September.
        # Sorted because os.listdir returns entries in arbitrary order; sorting makes
        # the index -> date mapping reproducible between runs and machines.
        self.week_start_dates = sorted(
            d for d in os.listdir(sentinel1_dir)
            if os.path.isdir(os.path.join(sentinel1_dir, d)) and self._is_in_april_to_september(d)
        )
        if not self.week_start_dates:
            raise ValueError("No Sentinel-1 data found for April to September.")

        # Load static soil data and define all variables
        self.soil_ds = xr.open_dataset(self.soil_file)
        self.soil_vars = ['nccpi3all', 'nccpi3corn', 'rootznaws', 'soc150', 'soc999', 'pctearthmc']  # 6 variables
        self.soil_height, self.soil_width = self.soil_ds['nccpi3all'].shape  # Using 'nccpi3all' as reference

        # Load crop and weather datasets per year (one IA_<year>.nc file each)
        self.crop_ds = {}
        self.weather_ds = {}
        unique_years = sorted({date.split('-')[0] for date in self.week_start_dates})
        for year in unique_years:
            crop_file = os.path.join(crop_dir, f'IA_{year}.nc')
            weather_file = os.path.join(weather_dir, f'IA_{year}.nc')
            self.crop_ds[year] = xr.open_dataset(crop_file)
            self.weather_ds[year] = xr.open_dataset(weather_file)

        # Define all bands/variables
        self.s1_bands = ['vv', 'vh']  # 2 bands
        self.s2_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12']  # 13 bands
        self.modis_bands = ['Band1', 'Band2', 'Band3', 'Band4', 'Band5', 'Band6', 'Band7']  # 7 bands
        self.weather_vars = ['dayl', 'prcp', 'srad', 'swe', 'tmax', 'tmin', 'vp']  # 7 variables
        self.crop_vars = ['__xarray_dataarray_variable__']  # Placeholder; replace with actual crop variables

        # Get spatial bounds from a sample Sentinel-1 GeoTIFF
        sample_s1_path = os.path.join(sentinel1_dir, self.week_start_dates[0], 'vv.tif')
        with rasterio.open(sample_s1_path) as src:
            # NOTE(review): treated as meters per pixel with square pixels; if the
            # raster CRS uses degrees this value is not in meters -- confirm.
            self.resolution = src.res[0]
            self.width, self.height = src.width, src.height

        # Patch edge length in PIXELS covering one mile. The attribute name says
        # "meters" for historical reasons and is kept so existing code keeps working.
        self.patch_size_meters = int(self.mile_in_meters / self.resolution)

    def close(self):
        """Close every NetCDF dataset opened by ``__init__``."""
        self.soil_ds.close()
        for yearly in (self.crop_ds, self.weather_ds):
            for ds in yearly.values():
                ds.close()

    def _is_in_april_to_september(self, date_str):
        """Return True if date_str parses as YYYY-MM-DD and falls in April through September.

        Strings that do not parse as a date return False instead of raising.
        """
        try:
            month = datetime.strptime(date_str, '%Y-%m-%d').month
        except ValueError:
            return False
        # Any valid day in months 4..9 lies between April 1st and September 30th.
        return 4 <= month <= 9

    def _random_patch_coords(self):
        """Generate random top-left pixel coordinates for one patch.

        Returns:
            tuple[int, int]: ``(x_start, y_start)`` such that the patch lies fully
            inside the ``width`` x ``height`` raster.

        Raises:
            ValueError: If the patch size is not positive or exceeds the raster.

        Note: draws from NumPy's global RNG. With multi-worker DataLoaders, seed
        NumPy per worker (``worker_init_fn``) or workers may draw identical offsets.
        """
        size = self.patch_size_meters
        max_x = self.width - size
        max_y = self.height - size
        if size <= 0 or max_x < 0 or max_y < 0:
            raise ValueError(
                f"Cannot place a {size}x{size} pixel patch inside a "
                f"{self.width}x{self.height} pixel raster."
            )
        # randint's upper bound is exclusive: +1 keeps the last valid offset
        # reachable and allows a raster that is exactly one patch wide/high.
        x_start = int(np.random.randint(0, max_x + 1))
        y_start = int(np.random.randint(0, max_y + 1))
        return x_start, y_start

    @staticmethod
    def _to_float_tensor(array):
        """Convert a NumPy array to a float32 tensor.

        The cast happens in NumPy so that dtypes ``torch.from_numpy`` may reject
        (e.g. uint16) are still accepted.
        """
        return torch.from_numpy(np.ascontiguousarray(array, dtype=np.float32))

    def _read_raster_stack(self, folder, bands, x_start, y_start):
        """Read one patch from ``<folder>/<band>.tif`` for each band and resize it.

        Returns a tensor of shape (len(bands), patch_size, patch_size).
        """
        # Local import keeps this block self-contained; rasterio itself is
        # already imported at the top of the file.
        from rasterio.windows import Window

        size = self.patch_size_meters
        # Window(col_off, row_off, width, height): only the patch is read from
        # disk instead of loading the whole raster and slicing it afterwards.
        window = Window(x_start, y_start, size, size)
        patches = []
        for band in bands:
            band_path = os.path.join(folder, f'{band}.tif')
            with rasterio.open(band_path) as src:
                patches.append(self._to_float_tensor(src.read(1, window=window)))
        stacked = torch.stack(patches, dim=0)  # (n_bands, N, N)
        return resize(stacked, [self.patch_size, self.patch_size])

    def _read_netcdf_stack(self, dataset, variables, x_start, y_start, time_idx=None):
        """Cut one patch from each variable of an xarray dataset and resize it.

        When ``time_idx`` is given, that step of the 'time' dimension is selected
        first. Returns a tensor of shape (len(variables), patch_size, patch_size).
        """
        size = self.patch_size_meters
        y_end, x_end = y_start + size, x_start + size
        patches = []
        for var in variables:
            data = dataset[var]
            if time_idx is not None:
                data = data.isel(time=time_idx)
            # Slice the DataArray before materialising so only the patch is loaded.
            patch = data[y_start:y_end, x_start:x_end].values
            patches.append(self._to_float_tensor(patch))
        stacked = torch.stack(patches, dim=0)  # (n_vars, N, N)
        return resize(stacked, [self.patch_size, self.patch_size])

    @staticmethod
    def _time_index(dataset, week_start_date, source_name):
        """Return the index along 'time' whose value equals ``week_start_date``.

        Raises:
            IndexError: If the dataset has no matching time step.
        """
        matches = np.where(dataset['time'].values == np.datetime64(week_start_date))[0]
        if matches.size == 0:
            raise IndexError(f"{source_name} data has no time step for {week_start_date}.")
        return matches[0]

    def __len__(self):
        """Number of samples: one per retained Sentinel-1 date folder."""
        return len(self.week_start_dates)

    def __getitem__(self, idx):
        """Load a random patch of every modality for the date at position ``idx``.

        Returns:
            tuple: ``(modalities, label)`` where ``modalities`` is a list of six
            tensors ordered [Sentinel-1, Sentinel-2, MODIS, crop, soil, weather],
            each (channels, patch_size, patch_size), and ``label`` is the dummy
            value 0.

        Raises:
            IndexError: If ``idx`` is out of range, or the crop/weather dataset
                has no time step for the selected date.
            ValueError: If a patch cannot be placed inside the raster.
        """
        week_start_date = self.week_start_dates[idx]
        year = week_start_date.split('-')[0]
        x_start, y_start = self._random_patch_coords()

        # 1-3. Satellite rasters (all bands), one <band>.tif per band in the date folder
        s1_tensor = self._read_raster_stack(
            os.path.join(self.sentinel1_dir, week_start_date), self.s1_bands, x_start, y_start
        )  # (2, 224, 224)
        s2_tensor = self._read_raster_stack(
            os.path.join(self.sentinel2_dir, week_start_date), self.s2_bands, x_start, y_start
        )  # (13, 224, 224)
        modis_tensor = self._read_raster_stack(
            os.path.join(self.modis_dir, week_start_date), self.modis_bands, x_start, y_start
        )  # (7, 224, 224)

        # NetCDF grids are indexed with the same pixel offsets
        # (assuming same resolution and alignment as the rasters).

        # 4. Crop Data (all variables)
        # NOTE(review): if the crop variable holds categorical class codes, the
        # default interpolation in resize() blends class ids -- confirm.
        crop_ds = self.crop_ds[year]
        crop_tensor = self._read_netcdf_stack(
            crop_ds, self.crop_vars, x_start, y_start,
            time_idx=self._time_index(crop_ds, week_start_date, 'Crop'),
        )  # (N_crop, 224, 224)

        # 5. Soil Data (static: no time dimension is selected)
        soil_tensor = self._read_netcdf_stack(
            self.soil_ds, self.soil_vars, x_start, y_start
        )  # (6, 224, 224)

        # 6. Weather Data (all variables: dayl, prcp, srad, swe, tmax, tmin, vp)
        weather_ds = self.weather_ds[year]
        weather_tensor = self._read_netcdf_stack(
            weather_ds, self.weather_vars, x_start, y_start,
            time_idx=self._time_index(weather_ds, week_start_date, 'Weather'),
        )  # (7, 224, 224)

        # Kept as a list because the channel count differs per modality
        modalities = [s1_tensor, s2_tensor, modis_tensor, crop_tensor, soil_tensor, weather_tensor]

        if self.transform:
            modalities = [self.transform(m) for m in modalities]

        # Dummy label (replace with actual labels if available)
        label = 0

        return modalities, label


In [None]:
# Example usage: build the dataset from the Sentinel-1/2, MODIS, crop, soil and weather sources.
dataset = MultiModalDataset(
    sentinel1_dir='/work/mech-ai-scratch/rtali/gis-sentinel1/final_s1',
    sentinel2_dir='/work/mech-ai-scratch/rtali/gis-sentinel2/final_s2_v3',
    modis_dir='/work/mech-ai-scratch/rtali/gis-modis/final_modis_data',
    crop_dir='/work/mech-ai-scratch/rtali/AI_READY_IOWA/CDL/IN4326',
    soil_file='/work/mech-ai-scratch/rtali/AI_READY_IOWA/SOIL/IA.nc',
    weather_dir='/work/mech-ai-scratch/rtali/AI_READY_IOWA/WEATHER',
    # Adjust mean/std per channel if needed
    transform=transforms.Normalize(mean=[0.5], std=[0.5]),
)

# Channel count of each modality, in the order the dataset returns them:
# Sentinel-1 (2), Sentinel-2 (13), MODIS (7), crop (update with actual count),
# soil (6), weather (7).
num_channels_list = [
    len(getattr(dataset, channel_attr))
    for channel_attr in (
        's1_bands',
        's2_bands',
        'modis_bands',
        'crop_vars',
        'soil_vars',
        'weather_vars',
    )
]

KeyError: "No variable named 'soil_variable'. Variables on the dataset include ['band', 'x', 'y', 'spatial_ref', 'nccpi3all', ..., 'nccpi3soy', 'rootznaws', 'soc150', 'soc999', 'pctearthmc']"