In [None]:
import os
import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling
from tqdm import tqdm
import numpy as np
from datetime import datetime
from dateutil.relativedelta import relativedelta

def get_file_paths(folder_path):
    """Return the absolute path of every file found beneath ``folder_path``.

    The directory tree is walked recursively with ``os.walk``; only files are
    returned, never the directories themselves.
    """
    return [
        os.path.abspath(os.path.join(parent, name))
        for parent, _subdirs, names in os.walk(folder_path)
        for name in names
    ]

def resample_to_fixed_cell_size(src_path, dst_path, crs, width, height, cell_size=30):
    """
    Reproject a raster onto a fixed-size grid and write it as int16 with nodata 0.

    The source bounds are transformed into ``crs`` and stretched over exactly
    ``width`` x ``height`` pixels, so the output pixel size follows from the
    bounds and the requested dimensions. Each band is resampled with
    nearest-neighbour interpolation. Any exception is caught and reported via
    ``print``; nothing is raised to the caller.

    Args:
        src_path: Path to the source raster
        dst_path: Path to save the resampled raster
        crs: Target coordinate reference system
        width: Fixed width (in pixels) to use for the output raster
        height: Fixed height (in pixels) to use for the output raster
        cell_size: Nominal cell size in CRS units. Not used by this function;
            kept so existing callers keep working.
    """
    # Output pixel type and nodata marker shared by the profile, the
    # destination buffer and the reprojection call.
    out_dtype = 'int16'
    out_nodata = 0

    try:
        with rasterio.open(src_path) as src:
            # Remember the source nodata value so reproject can honour it
            original_nodata = src.nodata

            # Bounds of the source raster expressed in the target CRS
            west, south, east, north = rasterio.warp.transform_bounds(
                src.crs, crs, *src.bounds
            )

            # Affine transform that maps those bounds onto width x height pixels
            dst_transform = rasterio.transform.from_bounds(
                west, south, east, north, width, height
            )

            # Start from the source profile and override what changes
            dst_kwargs = src.profile.copy()
            dst_kwargs.update({
                'crs': crs,
                'transform': dst_transform,
                'width': width,
                'height': height,
                'dtype': out_dtype,
                'nodata': out_nodata,
            })

            # Create the output directory; skip when dst_path has no directory
            # component, because os.makedirs('') raises.
            dst_dir = os.path.dirname(dst_path)
            if dst_dir:
                os.makedirs(dst_dir, exist_ok=True)

            with rasterio.open(dst_path, 'w', **dst_kwargs) as dst:
                # Bands are 1-indexed in rasterio
                for i in range(1, src.count + 1):
                    source_data = src.read(i)

                    # Destination buffer pre-filled with the nodata value
                    dest_data = np.full((height, width), out_nodata, dtype=out_dtype)

                    reproject(
                        source=source_data,
                        destination=dest_data,
                        src_transform=src.transform,
                        src_crs=src.crs,
                        dst_transform=dst_transform,
                        dst_crs=crs,
                        src_nodata=original_nodata,
                        dst_nodata=out_nodata,
                        resampling=Resampling.nearest,
                        num_threads=4
                    )

                    dst.write(dest_data, i)
    except Exception as e:
        # Best-effort: report the failure and let the caller continue with other files
        print(f"Error processing {src_path}: {e}")

def processScene(albedo_files):
    """
    Resample every raster of each scene onto one common grid.

    For each path in ``albedo_files`` the containing directory is treated as a
    scene. All regular files in that directory are resampled with
    ``resample_to_fixed_cell_size`` to the same width/height, using the CRS of
    the first listed file as the target CRS. Outputs that already exist are
    skipped.

    Args:
        albedo_files: Iterable of file paths; only the directory of each path
            is used to locate the scene.
    """
    cell_size = 30

    for albedo_path in tqdm(albedo_files, desc='Preprocessing images...'):
        scene_dir = os.path.dirname(albedo_path)

        # All regular files in the scene directory are treated as rasters
        raster_paths = [
            os.path.join(scene_dir, name)
            for name in os.listdir(scene_dir)
            if os.path.isfile(os.path.join(scene_dir, name))
        ]
        if not raster_paths:
            continue

        # Get reference CRS from the first raster
        with rasterio.open(raster_paths[0]) as src:
            reference_crs = src.crs

        # Collect the pixel dimensions each raster would have at cell_size
        widths = []
        heights = []
        for path in raster_paths:
            try:
                with rasterio.open(path) as scene_src:
                    west, south, east, north = rasterio.warp.transform_bounds(
                        scene_src.crs, reference_crs, *scene_src.bounds
                    )
                    widths.append(max(int(round((east - west) / cell_size)), 1))
                    heights.append(max(int(round((north - south) / cell_size)), 1))
            except Exception as e:
                print(f"Error reading dimensions from {path}: {e}")
                continue

        # Without at least one readable raster there is no valid size to use;
        # previously an infinite placeholder leaked through as NaN dimensions.
        if not widths:
            print(f"No readable rasters in {scene_dir}; skipping scene")
            continue

        # Smallest dimensions across the scene, rounded up to a multiple of 128
        min_width = ((min(widths) + 127) // 128) * 128
        min_height = ((min(heights) + 127) // 128) * 128

        # Process each raster with the standardized dimensions
        for src_path in raster_paths:
            dst_path = src_path.replace('Cities/', 'Cities_Preprocessed/')
            dst_path = dst_path.replace('DEM_2014/', 'DEM_2014_Preprocessed/')
            dst_path = dst_path.replace('Dataset/', 'ML/')
            # Skip if already processed (also true when no replacement matched,
            # which protects the source file from being overwritten)
            if os.path.exists(dst_path):
                continue

            resample_to_fixed_cell_size(src_path, dst_path, reference_crs, min_width, min_height, cell_size)

def preprocessImages(data_dir: str, debug: bool):
    """
    Collect albedo and DEM files under ``data_dir`` and preprocess their scenes.

    Args:
        data_dir: Root directory that is searched recursively.
        debug: When true, albedo files are only kept if the name of their
            parent folder starts with "2013" or "2014". DEM files are
            collected regardless of this flag.
    """
    debug_years = ["2013", "2014"]
    albedo_files = []
    dem_files = []

    for file_path in tqdm(get_file_paths(data_dir), desc='Gathering scenes (Preprocessing)...'):
        # The parent folder name is expected to begin with the scene year
        parent_folder = file_path.split('/')[-2]

        # DEM files are gathered before the debug filter is applied
        if 'DEM.tif' in file_path:
            dem_files.append(file_path)

        outside_debug_years = parent_folder[:4] not in debug_years
        if debug and outside_debug_years:
            continue

        if 'albedo' in file_path:
            albedo_files.append(file_path)

    # Albedo scenes first, then the DEM folders
    processScene(albedo_files)
    processScene(dem_files)

In [None]:
import os
import glob
from pathlib import Path
import numpy as np
import rioxarray as rxr
import xarray as xr
import logging

# Configure logging: INFO level, each record formatted as "time - level - message"
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Logger named after this module; used by the NODATA analysis functions below
logger = logging.getLogger(__name__)

def calculate_nodata_percentage(tiff_path):
    """
    Calculate the percentage of NODATA values in a TIFF file using rioxarray.

    The file is opened with ``masked=True`` and NaN cells are counted band by
    band. The dataset is always closed, even when reading a band fails.

    Args:
        tiff_path (str): Path to the TIFF file

    Returns:
        float: Percentage of NODATA values (0-100), or 0.0 if the file has no pixels
        None: If file cannot be processed (the error is logged)
    """
    try:
        dataset = rxr.open_rasterio(tiff_path, masked=True)

        if dataset is None:
            logger.warning(f"Could not open file: {tiff_path}")
            return None

        total_pixels = 0
        nodata_pixels = 0

        try:
            # Read one band at a time rather than materialising the whole file
            for band_idx in range(dataset.sizes['band']):
                band_values = dataset.isel(band=band_idx).values

                total_pixels += band_values.size
                # NOTE(review): the earlier `.mask` attribute check was dropped
                # because xarray DataArrays are not NumPy masked arrays; with
                # masked=True nodata cells are expected to be NaN, so NaN is
                # what gets counted here.
                nodata_pixels += int(np.isnan(band_values).sum())
        finally:
            # Release the file handle on success and on failure alike
            dataset.close()

        if total_pixels == 0:
            return 0.0
        return (nodata_pixels / total_pixels) * 100

    except Exception as e:
        logger.error(f"Error processing {tiff_path}: {str(e)}")
        return None

def find_all_tiff_files(root_directory):
    """
    Recursively find all TIFF files in the given directory and its subdirectories.

    Matching is case-insensitive on the ``.tif`` / ``.tiff`` extension and is
    done in a single pass, so each file is returned exactly once regardless of
    whether the filesystem is case-sensitive.

    Args:
        root_directory (str): Root directory to search

    Returns:
        list: Sorted list of paths (as strings) to TIFF files
    """
    tiff_endings = ('.tif', '.tiff')

    return sorted(
        str(candidate)
        for candidate in Path(root_directory).rglob('*')
        if candidate.name.lower().endswith(tiff_endings) and candidate.is_file()
    )

def calculate_mean_nodata_percentage(dataset_path):
    """
    Summarise the NODATA share of every TIFF file found under a dataset directory.

    Args:
        dataset_path (str): Path to the dataset directory

    Returns:
        dict: Always contains 'mean_nodata_percentage', 'total_files',
        'processed_files', 'failed_files' and 'individual_percentages'.
        The std/min/max keys are only present when at least one file was
        processed successfully.
    """
    logger.info(f"Searching for TIFF files in: {dataset_path}")

    candidates = find_all_tiff_files(dataset_path)

    # Nothing to analyse at all
    if not candidates:
        logger.warning("No TIFF files found in the specified directory.")
        return {
            'mean_nodata_percentage': 0.0,
            'total_files': 0,
            'processed_files': 0,
            'failed_files': 0,
            'individual_percentages': []
        }

    file_count = len(candidates)
    logger.info(f"Found {file_count} TIFF files")

    # Per-file percentages for the files that could be read
    per_file = []
    failures = 0

    for position, path in enumerate(candidates, 1):
        logger.info(f"Processing file {position}/{file_count}: {os.path.basename(path)}")

        value = calculate_nodata_percentage(path)
        if value is None:
            failures += 1
            logger.warning("  Failed to process file")
            continue

        per_file.append(value)
        logger.info(f"  NODATA percentage: {value:.2f}%")

    # Every file failed: report it and return the reduced result shape
    if not per_file:
        logger.error("No files were successfully processed.")
        return {
            'mean_nodata_percentage': 0.0,
            'total_files': file_count,
            'processed_files': 0,
            'failed_files': failures,
            'individual_percentages': []
        }

    results = {
        'mean_nodata_percentage': np.mean(per_file),
        'std_nodata_percentage': np.std(per_file),
        'min_nodata_percentage': np.min(per_file),
        'max_nodata_percentage': np.max(per_file),
        'total_files': file_count,
        'processed_files': len(per_file),
        'failed_files': failures,
        'individual_percentages': per_file
    }

    logger.info("\n--- RESULTS ---")
    logger.info(f"Total files found: {results['total_files']}")
    logger.info(f"Successfully processed: {results['processed_files']}")
    logger.info(f"Failed to process: {results['failed_files']}")

    # The four percentage statistics share one output format
    summary_rows = (
        ("Mean NODATA percentage", 'mean_nodata_percentage'),
        ("Standard deviation", 'std_nodata_percentage'),
        ("Minimum NODATA percentage", 'min_nodata_percentage'),
        ("Maximum NODATA percentage", 'max_nodata_percentage'),
    )
    for label, key in summary_rows:
        logger.info(f"{label}: {results[key]:.2f}%")

    return results

def main():
    """
    Run the NODATA analysis on the configured dataset and append a report to out.txt.
    """
    # Change this path to point at a different dataset directory
    dataset_path = "./Data/Dataset/Cities"

    if not os.path.exists(dataset_path):
        logger.error(f"Directory does not exist: {dataset_path}")
        return

    if not os.path.isdir(dataset_path):
        logger.error(f"Path is not a directory: {dataset_path}")
        return

    results = calculate_mean_nodata_percentage(dataset_path)

    # Assemble the report first, then append it to the file in one go
    report = [
        "TIFF NODATA Percentage Analysis Results\n",
        "=" * 40 + "\n\n",
        f"Dataset Path: {dataset_path}\n",
        f"Total files found: {results['total_files']}\n",
        f"Successfully processed: {results['processed_files']}\n",
        f"Failed to process: {results['failed_files']}\n",
        f"Mean NODATA percentage: {results['mean_nodata_percentage']:.2f}%\n",
    ]

    # Spread statistics exist only when at least one file was processed
    if 'std_nodata_percentage' in results:
        report.append(f"Standard deviation: {results['std_nodata_percentage']:.2f}%\n")
        report.append(f"Minimum NODATA percentage: {results['min_nodata_percentage']:.2f}%\n")
        report.append(f"Maximum NODATA percentage: {results['max_nodata_percentage']:.2f}%\n")

    report.append("\nIndividual file percentages:\n")
    report.extend(
        f"File {number}: {share:.2f}%\n"
        for number, share in enumerate(results['individual_percentages'], 1)
    )

    with open("out.txt", 'a') as f:
        f.writelines(report)

        logger.info("Results saved to: out.txt")

# Run the NODATA analysis only when this code is executed as the main module
# (not when it is imported).
if __name__ == "__main__":
    main()

In [None]:
# Preprocess everything under ./Data/Dataset; debug=False disables the 2013/2014 folder filter
preprocessImages("./Data/Dataset", False)

In [None]:
import os
import shutil
import xarray as xr
import rioxarray as rxr
import numpy as np
from datetime import datetime
from dateutil.relativedelta import relativedelta
from tqdm import tqdm
from collections import defaultdict
import pandas as pd

def _month_start(moment):
    """Return ``moment`` truncated to midnight on the first day of its month."""
    return moment.replace(day=1, hour=0, minute=0, second=0, microsecond=0)


def _write_interpolated_raster(before_file, after_file, weight_before, weight_after,
                               output_scene_path, output_file_path):
    """Blend two rasters per pixel and save the result as int16 with nodata 0.

    Both inputs are closed again whether or not the blend succeeds.
    """
    before_data = rxr.open_rasterio(before_file, chunks=True)
    try:
        after_data = rxr.open_rasterio(after_file, chunks=True)
        try:
            # Mask nodata values (0 is nodata)
            before_masked = before_data.where(before_data != 0)
            after_masked = after_data.where(after_data != 0)

            # Only pixels that are valid in both neighbours are interpolated
            valid_mask = (~before_masked.isnull()) & (~after_masked.isnull())

            # Linear interpolation: before * weight_before + after * weight_after
            interpolated = before_masked * weight_before + after_masked * weight_after

            # Where either neighbour is nodata, the result is nodata
            interpolated = interpolated.where(valid_mask, 0)

            interpolated = interpolated.astype('int16')
            interpolated.rio.write_nodata(0, inplace=True)

            os.makedirs(output_scene_path, exist_ok=True)
            interpolated.rio.to_raster(output_file_path, dtype='int16', nodata=0)
        finally:
            after_data.close()
    finally:
        before_data.close()


def interpolateLinearScenePixels(data_dir="./Data/ML/Cities_Preprocessed", monthSpan=6, tolerance_for_missing=0.4):
    """
    Fill missing monthly scenes by per-pixel linear interpolation.

    Every city folder in ``data_dir`` is copied to ./Data/ML/Cities_Processed
    (an existing copy is deleted first). Scene folders are expected to be named
    with ISO timestamps; folders that do not parse are ignored. A monthly grid
    is built from the first to the last scene month, and a window of
    ``monthSpan`` months is slid over it. In each window whose share of months
    without an original scene is at most ``tolerance_for_missing`` (and that
    contains at least one original scene), every missing month is interpolated
    between the nearest original scenes before and after it. The synthetic
    scene is dated the 15th of the month at 12:00.

    Months for which at least one raster was actually written are appended to
    ./Data/ML/interpolated.txt as ``<city>/<timestamp>`` lines. Failures for
    individual rasters are printed and do not stop processing.

    Args:
        data_dir: Directory containing one sub-directory per city.
        monthSpan: Window length in months.
        tolerance_for_missing: Maximum fraction of a window that may lack
            original scenes for its gaps to be filled.
    """
    output_dir = "./Data/ML/Cities_Processed"
    os.makedirs(output_dir, exist_ok=True)

    # File types to process
    file_types = ['albedo.tif', 'blue.tif', 'green.tif', 'LST.tif',
                  'ndbi.tif', 'ndvi.tif', 'ndwi.tif', 'red.tif']

    cities = [d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))]

    for city_name in tqdm(cities, desc='Processing cities'):
        city_path = os.path.join(data_dir, city_name)

        # Start from a fresh copy of all original scenes
        output_city_path = os.path.join(output_dir, city_name)
        if os.path.exists(output_city_path):
            shutil.rmtree(output_city_path)
        shutil.copytree(city_path, output_city_path)

        # Collect (date, folder name) for every scene folder with a parseable name
        original_scenes = []
        for scene_folder in os.listdir(city_path):
            if not os.path.isdir(os.path.join(city_path, scene_folder)):
                continue
            try:
                date_obj = datetime.fromisoformat(scene_folder.replace('Z', '+00:00'))
            except ValueError:
                continue
            original_scenes.append((date_obj, scene_folder))

        original_scenes.sort(key=lambda scene: scene[0])

        # Interpolation needs a scene on both sides of a gap
        if len(original_scenes) < 2:
            continue

        # Monthly grid from the first scene's month to the last scene's month
        end_date = original_scenes[-1][0]
        end_year = end_date.year
        end_month = end_date.month
        current_date = _month_start(original_scenes[0][0])
        monthly_grid = []
        while (current_date.year < end_year) or (current_date.year == end_year and current_date.month <= end_month):
            monthly_grid.append(current_date)
            current_date = current_date + relativedelta(months=1)

        # Month -> (scene date, copied scene path). When a month holds several
        # scenes, the latest one wins because the scenes are sorted by date.
        month_to_scene = {}
        for date_obj, scene_folder in original_scenes:
            month_to_scene[_month_start(date_obj)] = (
                date_obj, os.path.join(output_city_path, scene_folder)
            )
        original_months = set(month_to_scene)

        # Decide up front which months get filled, in processing order. Each
        # month is attempted once, in the first qualifying window containing it,
        # so the progress-bar total matches the work actually performed.
        attempted_months = set()
        months_to_fill = []
        for start_idx in range(len(monthly_grid) - monthSpan + 1):
            span_months = monthly_grid[start_idx:start_idx + monthSpan]
            original_count = sum(1 for month in span_months if month in original_months)
            missing_ratio = (monthSpan - original_count) / monthSpan
            if missing_ratio > tolerance_for_missing or original_count <= 0:
                continue
            for month in span_months:
                if month not in original_months and month not in attempted_months:
                    attempted_months.add(month)
                    months_to_fill.append(month)

        # Months for which at least one raster really was written
        written_months = set()

        total_tasks = len(months_to_fill) * len(file_types)
        with tqdm(total=total_tasks, desc=f'Interpolating {city_name}', leave=False) as pbar:
            for missing_month in months_to_fill:
                # Nearest original months on either side of the gap
                before_month = max((m for m in original_months if m < missing_month), default=None)
                after_month = min((m for m in original_months if m > missing_month), default=None)

                if before_month is None or after_month is None:
                    pbar.update(len(file_types))
                    continue

                before_date, before_path = month_to_scene[before_month]
                after_date, after_path = month_to_scene[after_month]
                missing_date = missing_month.replace(day=15, hour=12, minute=0, second=0)

                # Temporal weights for the linear interpolation
                total_duration = (after_date - before_date).total_seconds()
                if total_duration == 0:
                    pbar.update(len(file_types))
                    continue

                weight_after = (missing_date - before_date).total_seconds() / total_duration
                weight_before = 1.0 - weight_after

                scene_folder_name = missing_date.strftime('%Y-%m-%dT%H:%M:%SZ')
                output_scene_path = os.path.join(output_city_path, scene_folder_name)

                for file_type in file_types:
                    before_file = os.path.join(before_path, file_type)
                    after_file = os.path.join(after_path, file_type)

                    if os.path.exists(before_file) and os.path.exists(after_file):
                        try:
                            _write_interpolated_raster(
                                before_file, after_file, weight_before, weight_after,
                                output_scene_path, os.path.join(output_scene_path, file_type)
                            )
                            written_months.add(missing_month)
                        except Exception as e:
                            # Best-effort per raster, but never silently
                            print(f"Error interpolating {file_type} for {city_name}/{scene_folder_name}: {e}")

                    pbar.update(1)

        # Record only months that were actually produced, so the list matches
        # the scene folders on disk.
        # NOTE: the file is opened in append mode, so re-running the function
        # adds the same entries again.
        if written_months:
            interpolated_file_path = os.path.join("./Data/ML/", "interpolated.txt")
            with open(interpolated_file_path, 'a') as f:
                for month in sorted(written_months):
                    # Same timestamp format as the synthetic scene folder names
                    timestamp = month.replace(day=15, hour=12, minute=0, second=0).strftime('%Y-%m-%dT%H:%M:%SZ')
                    f.write(f"{city_name}/{timestamp}\n")

In [None]:
# Run the interpolation with its defaults (./Data/ML/Cities_Preprocessed, monthSpan=6, tolerance_for_missing=0.4)
interpolateLinearScenePixels()

In [None]:
import os
import numpy as np
import xarray as xr
import rioxarray as rxr
from datetime import datetime
from tqdm import tqdm
import pandas as pd

def validate_linear_interpolation(data_dir="./Data/Dataset/Cities_Processed", 
                                interpolated_file_path="./Data/Dataset/interpolated.txt",
                                sample_pixels=100, 
                                tolerance=1e-6):
    """
    Validate that interpolated scenes are indeed linear interpolations between their temporal neighbors.
    
    Parameters:
    - data_dir: Directory containing processed cities with interpolated scenes
    - interpolated_file_path: Path to the global interpolated.txt file
    - sample_pixels: Number of random pixels to test per scene (to avoid memory issues)
    - tolerance: Numerical tolerance for floating point comparisons
    
    Returns:
    - Dictionary with validation results for each city
    """
    
    file_types = ['albedo.tif', 'blue.tif', 'green.tif', 'LST.tif', 
                  'ndbi.tif', 'ndvi.tif', 'ndwi.tif', 'red.tif']
    
    # Read the global interpolated.txt file
    if not os.path.exists(interpolated_file_path):
        print(f"Global interpolated.txt not found at {interpolated_file_path}")
        return {}
    
    with open(interpolated_file_path, 'r') as f:
        interpolated_entries = [line.strip() for line in f.readlines()]
    
    # Group interpolated entries by city
    city_interpolated = {}
    for entry in interpolated_entries:
        if '/' in entry:
            city_name, timestamp = entry.split('/', 1)
            if city_name not in city_interpolated:
                city_interpolated[city_name] = []
            city_interpolated[city_name].append(timestamp)
    
    validation_results = {}
    
    # Process each city that has interpolated scenes
    for city_name, interpolated_timestamps in tqdm(city_interpolated.items(), desc='Validating cities'):
        city_path = os.path.join(data_dir, city_name)
        
        if not os.path.isdir(city_path):
            print(f"City directory not found: {city_path}")
            continue
        
        # Get all scenes in the city
        all_scenes = []
        for scene_folder in os.listdir(city_path):
            scene_path = os.path.join(city_path, scene_folder)
            if os.path.isdir(scene_path) and scene_folder != "interpolated.txt":
                try:
                    date_obj = datetime.fromisoformat(scene_folder.replace('Z', '+00:00'))
                    all_scenes.append((date_obj, scene_folder, scene_path))
                except ValueError:
                    continue
        
        # Sort scenes by date
        all_scenes.sort(key=lambda x: x[0])
        
        city_results = {
            'total_interpolated': len(interpolated_timestamps),
            'validated_scenes': 0,
            'failed_scenes': 0,
            'validation_details': []
        }
        
        # Validate each interpolated scene
        for interp_timestamp in tqdm(interpolated_timestamps, desc=f'Validating {city_name}', leave=False):
            interp_date = datetime.fromisoformat(interp_timestamp.replace('Z', '+00:00'))
            
            # Find the interpolated scene
            interp_scene_path = None
            for date_obj, scene_folder, scene_path in all_scenes:
                if scene_folder == interp_timestamp:
                    interp_scene_path = scene_path
                    break
            
            if interp_scene_path is None:
                print(f"Could not find interpolated scene {interp_timestamp} for {city_name}")
                continue
            
            # Find the nearest original scenes before and after
            before_scene = None
            after_scene = None
            
            # Find closest original scene before
            for date_obj, scene_folder, scene_path in reversed(all_scenes):
                if (date_obj < interp_date and 
                    scene_folder not in interpolated_timestamps):
                    before_scene = (date_obj, scene_folder, scene_path)
                    break
            
            # Find closest original scene after
            for date_obj, scene_folder, scene_path in all_scenes:
                if (date_obj > interp_date and 
                    scene_folder not in interpolated_timestamps):
                    after_scene = (date_obj, scene_folder, scene_path)
                    break
            
            if before_scene is None or after_scene is None:
                print(f"Could not find before/after scenes for {interp_timestamp} in {city_name}")
                city_results['failed_scenes'] += 1
                continue
            
            # Calculate expected interpolation weights
            before_date = before_scene[0]
            after_date = after_scene[0]
            total_duration = (after_date - before_date).total_seconds()
            missing_duration = (interp_date - before_date).total_seconds()
            weight_after = missing_duration / total_duration
            weight_before = 1.0 - weight_after
            
            scene_validation = {
                'timestamp': interp_timestamp,
                'before_scene': before_scene[1],
                'after_scene': after_scene[1],
                'weight_before': weight_before,
                'weight_after': weight_after,
                'file_results': {}
            }
            
            scene_valid = True
            
            # Test each file type
            for file_type in file_types:
                before_file = os.path.join(before_scene[2], file_type)
                after_file = os.path.join(after_scene[2], file_type)
                interp_file = os.path.join(interp_scene_path, file_type)
                
                if not all(os.path.exists(f) for f in [before_file, after_file, interp_file]):
                    scene_validation['file_results'][file_type] = {'status': 'missing_files'}
                    continue
                
                try:
                    # Load the data
                    before_data = rxr.open_rasterio(before_file, chunks=True)
                    after_data = rxr.open_rasterio(after_file, chunks=True)
                    interp_data = rxr.open_rasterio(interp_file, chunks=True)
                    
                    # Get data shape
                    height, width = before_data.shape[1], before_data.shape[2]
                    
                    # Sample random pixels to test
                    np.random.seed(42)  # For reproducibility
                    test_pixels = min(sample_pixels, height * width)
                    
                    if height * width > 0:
                        # Generate random pixel coordinates
                        pixel_indices = np.random.choice(height * width, test_pixels, replace=False)
                        row_indices = pixel_indices // width
                        col_indices = pixel_indices % width
                        
                        # Extract pixel values
                        before_values = before_data.values[0, row_indices, col_indices]
                        after_values = after_data.values[0, row_indices, col_indices]
                        interp_values = interp_data.values[0, row_indices, col_indices]
                        
                        # Test linear interpolation
                        valid_mask = (before_values != 0) & (after_values != 0) & (interp_values != 0)
                        
                        if np.sum(valid_mask) > 0:
                            expected_values = (before_values[valid_mask] * weight_before + 
                                             after_values[valid_mask] * weight_after)
                            actual_values = interp_values[valid_mask]
                            
                            # Check if values match within tolerance
                            differences = np.abs(expected_values - actual_values)
                            max_diff = np.max(differences)
                            mean_diff = np.mean(differences)
                            
                            is_valid = max_diff <= tolerance
                            
                            scene_validation['file_results'][file_type] = {
                                'status': 'valid' if is_valid else 'invalid',
                                'tested_pixels': np.sum(valid_mask),
                                'max_difference': float(max_diff),
                                'mean_difference': float(mean_diff),
                                'tolerance': tolerance
                            }
                            
                            if not is_valid:
                                scene_valid = False
                                print(f"Validation failed for {city_name}/{interp_timestamp}/{file_type}")
                                print(f"  Max difference: {max_diff}, Mean difference: {mean_diff}")
                        else:
                            scene_validation['file_results'][file_type] = {
                                'status': 'no_valid_pixels'
                            }
                    
                    # Clean up
                    before_data.close()
                    after_data.close()
                    interp_data.close()
                    
                except Exception as e:
                    scene_validation['file_results'][file_type] = {
                        'status': 'error',
                        'error': str(e)
                    }
                    scene_valid = False
            
            if scene_valid:
                city_results['validated_scenes'] += 1
            else:
                city_results['failed_scenes'] += 1
            
            city_results['validation_details'].append(scene_validation)
        
        validation_results[city_name] = city_results
        
        # Print summary for this city
        total = city_results['validated_scenes'] + city_results['failed_scenes']
        if total > 0:
            success_rate = city_results['validated_scenes'] / total * 100
            print(f"{city_name}: {city_results['validated_scenes']}/{total} scenes validated successfully ({success_rate:.1f}%)")
    
    return validation_results

def print_validation_summary(validation_results):
    """Print an aggregate and per-city summary of validation results.

    Args:
        validation_results: Mapping of city name to a result dict that holds
            the integer counters 'total_interpolated', 'validated_scenes'
            and 'failed_scenes'.

    The success rate is computed over the scenes that were actually checked
    (validated + failed). That sum can be zero even when interpolated scenes
    are listed, so the rate is only printed when at least one scene was
    checked; this avoids a ZeroDivisionError.
    """
    print("\n" + "="*50)
    print("VALIDATION SUMMARY")
    print("="*50)

    per_city = list(validation_results.values())
    total_cities = len(validation_results)
    total_interpolated = sum(r['total_interpolated'] for r in per_city)
    total_validated = sum(r['validated_scenes'] for r in per_city)
    total_failed = sum(r['failed_scenes'] for r in per_city)
    total_checked = total_validated + total_failed

    print(f"Cities processed: {total_cities}")
    print(f"Total interpolated scenes: {total_interpolated}")
    print(f"Successfully validated: {total_validated}")
    print(f"Failed validation: {total_failed}")

    # Guard on the real denominator, not only on total_interpolated
    if total_interpolated > 0 and total_checked > 0:
        success_rate = total_validated / total_checked * 100
        print(f"Overall success rate: {success_rate:.1f}%")

    print("\nPer-city results:")
    for city, results in validation_results.items():
        total = results['validated_scenes'] + results['failed_scenes']
        if total > 0:
            rate = results['validated_scenes'] / total * 100
            print(f"  {city}: {results['validated_scenes']}/{total} ({rate:.1f}%)")

# Script entry point: validate the interpolated scenes, report, and persist the details
if __name__ == "__main__":
    import json

    results = validate_linear_interpolation(
        data_dir="./Data/ML/Cities_Processed",
        interpolated_file_path="./Data/ML/interpolated.txt",
        sample_pixels=100,
        tolerance=1.0  # Allow 1 unit difference due to integer rounding
    )

    # Human-readable overview on stdout
    print_validation_summary(results)

    # Full per-scene details on disk; non-JSON values are stringified via default=str
    with open('validation_results.json', 'w') as report_file:
        json.dump(results, report_file, indent=2, default=str)

In [None]:
import rasterio

def check_tif_divisibility(tif_path, divisor=128):
    """
    Report whether both raster dimensions are exact multiples of ``divisor``.

    Args:
        tif_path (str): Path to the TIF file
        divisor (int): Value both dimensions are tested against

    Returns:
        dict: 'is_divisible' (bool), the raster 'height' and 'width', and the
        remainders 'height_remainder' and 'width_remainder'
    """
    # Only the header is needed, so the dataset is closed right away
    with rasterio.open(tif_path) as dataset:
        height, width = dataset.height, dataset.width

    leftover_rows = height % divisor
    leftover_cols = width % divisor

    report = {'is_divisible': not (leftover_rows or leftover_cols)}
    report['height'] = height
    report['width'] = width
    report['height_remainder'] = leftover_rows
    report['width_remainder'] = leftover_cols
    return report

# Spot-check one preprocessed DEM raster against the default divisor (128)
result = check_tif_divisibility("/root/projects/STAC/Data/Dataset/DEM_2014_Preprocessed/Abilene_TX/DEM.tif")
if not result['is_divisible']:
    print(f"TIF not divisible by 128. Dimensions: {result['height']}x{result['width']}")
    print(f"Remainders: {result['height_remainder']}x{result['width_remainder']}")
else:
    print(f"TIF is divisible by 128: {result['height']}x{result['width']}")

In [None]:
from tqdm import tqdm
import os
def list_files_in_folder(folder_path):
    """Return the paths of the regular files directly inside folder_path (non-recursive)."""
    files = []
    for entry in os.listdir(folder_path):
        candidate = os.path.join(folder_path, entry)
        # Sub-directories and other non-file entries are left out
        if os.path.isfile(candidate):
            files.append(candidate)
    return files

def get_file_paths(folder_path):
    """Recursively collect the absolute path of every file below folder_path."""
    return [
        os.path.abspath(os.path.join(dirpath, filename))
        for dirpath, _, filenames in os.walk(folder_path)
        for filename in filenames
    ]
# Every file under the dataset root whose full path mentions 'albedo' / 'DEM'.
# NOTE: this is a substring test on the whole path, so directory names count too.
dataset_files = get_file_paths('./Data/Dataset')
allAlbedoPixelFiles = [path for path in dataset_files if 'albedo' in path]
allDEMPixelFiles = [path for path in dataset_files if 'DEM' in path]

# One {file name: file path} dictionary per matched file, built from all files
# that share its folder. Albedo folders come first, then DEM folders.
file_list = []
for anchor_files in (allAlbedoPixelFiles, allDEMPixelFiles):
    for xPath in tqdm(anchor_files, desc="Packing to dictionary..."):
        scene_dir = os.path.dirname(os.path.abspath(xPath))
        file_list.append({
            rasterPath.split('/')[-1]: rasterPath
            for rasterPath in list_files_in_folder(scene_dir)
        })

In [None]:
# Notebook cell output: show the list of {file name: file path} dictionaries
file_list

In [None]:
import rasterio
import json

def get_tif_ranges(data_list):
    """Compute the overall min/max of band 1 for each known raster type.

    Args:
        data_list: Iterable of dicts mapping a raster file name (for example
            'red.tif') to its path. The text before the first '.' of the file
            name selects the range entry and must be one of the tracked names.

    Returns:
        dict: {name: {'min': float, 'max': float}}. 'min' ignores pixels equal
        to 0, 'max' considers every pixel. Entries that were never updated
        keep their initial +inf / -inf values.

    Raises:
        KeyError: If a file name does not map to one of the tracked names.
    """
    tracked_names = ('red', 'ndwi', 'ndvi', 'ndbi', 'LST', 'green', 'blue', 'DEM', 'albedo')
    ranges = {
        name: {'min': float('inf'), 'max': float('-inf')}
        for name in tracked_names
    }

    for item in tqdm(data_list, desc="Getting Ranges"):
        for tif_type, path in item.items():
            with rasterio.open(path) as src:
                data = src.read(1)

            entry = ranges[tif_type.split('.')[0]]

            # A band made only of zeros has no nonzero pixels; calling .min()
            # on that empty selection raises ValueError, so 'min' is left as is.
            nonzero = data[data != 0]
            if nonzero.size > 0:
                entry['min'] = min(entry['min'], float(nonzero.min()))
            entry['max'] = max(entry['max'], float(data.max()))
    return ranges

# Compute the value ranges over every collected dictionary and pretty-print them as JSON
ranges = get_tif_ranges(file_list)
print(json.dumps(ranges, indent=4))

In [None]:
import os
import rasterio
from rasterio.windows import Window
from rasterio.transform import from_bounds
import numpy as np
from tqdm import tqdm
from pathlib import Path
import math

def convert_to_tiles(data_dir: str, tile_size: int = 128, overlap: int = 0):
    """
    Divide preprocessed rasters into square tiles while maintaining coordinate systems.

    Args:
        data_dir: Directory that contains the 'Cities_Processed' and/or
            'DEM_2014_Preprocessed' folders. A missing folder is skipped.
        tile_size: Edge length of each tile in pixels (default 128)
        overlap: Overlap between neighbouring tiles in pixels (default 0)

    Raises:
        ValueError: If tile_size is not positive, or if overlap is so large
            that the step between tiles (tile_size - overlap) is not positive.

    Output structure (row/col indices are zero-padded to 3 digits):
        <data_dir>/
        ├── Cities_Tiles/
        │   ├── <city>/
        │   │   ├── <timestamp>/
        │   │   │   ├── LST_row_000_col_000.tif
        │   │   │   ├── red_row_000_col_001.tif
        │   │   │   └── ...
        └── DEM_2014_Tiles/
            ├── <city>/
            │   ├── DEM_row_000_col_000.tif
            │   └── ...
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")

    # Calculate stride (step size between tiles)
    stride = tile_size - overlap
    if stride <= 0:
        # A non-positive stride would divide by zero or produce no tiles at all
        raise ValueError(
            f"overlap ({overlap}) must be smaller than tile_size ({tile_size})"
        )

    data_path = Path(data_dir)

    print(f"Converting to {tile_size}x{tile_size} tiles with {overlap}px overlap...")

    # Process Cities data
    cities_preprocessed = data_path / "Cities_Processed"
    cities_tiles = data_path / "Cities_Tiles"

    if cities_preprocessed.exists():
        print("Processing Cities data...")
        _process_cities_tiles(cities_preprocessed, cities_tiles, tile_size, stride)

    # Process DEM data
    dem_preprocessed = data_path / "DEM_2014_Preprocessed"
    dem_tiles = data_path / "DEM_2014_Tiles"

    if dem_preprocessed.exists():
        print("Processing DEM data...")
        _process_dem_tiles(dem_preprocessed, dem_tiles, tile_size, stride)

    print("Tiling complete!")

def _tile_grid_shape(height: int, width: int, tile_size: int, stride: int):
    """Return (n_rows, n_cols): the number of tile rows/columns for a height x width raster."""
    n_rows = math.ceil((height - tile_size) / stride) + 1 if height > tile_size else 1
    n_cols = math.ceil((width - tile_size) / stride) + 1 if width > tile_size else 1

    # Ensure we don't go beyond image boundaries
    n_rows = min(n_rows, math.ceil(height / stride))
    n_cols = min(n_cols, math.ceil(width / stride))
    return n_rows, n_cols


def _write_raster_tiles(src, output_dir: Path, band_name: str, grid_shape, image_shape,
                        crs, tile_size: int, stride: int):
    """Cut the open raster ``src`` into tile_size x tile_size GeoTIFFs in ``output_dir``.

    Args:
        src: Open rasterio dataset to read from.
        output_dir: Existing directory that receives the tiles.
        band_name: File name prefix; tiles are named
            '<band_name>_row_RRR_col_CCC.tif'.
        grid_shape: (n_rows, n_cols) of the tile grid.
        image_shape: (height, width) in pixels used to clip the tile windows.
        crs: CRS written into every tile profile.
        tile_size: Edge length of each tile in pixels.
        stride: Step in pixels between the origins of neighbouring tiles.
    """
    n_rows, n_cols = grid_shape
    height, width = image_shape
    fill_value = src.nodata if src.nodata is not None else 0

    for row in range(n_rows):
        for col in range(n_cols):
            tile_path = output_dir / f"{band_name}_row_{row:03d}_col_{col:03d}.tif"

            # Resume support: skip finished tiles before doing any raster I/O
            if tile_path.exists():
                continue

            # Nominal tile boundaries, clipped to the image
            start_row = row * stride
            start_col = col * stride
            end_row = min(start_row + tile_size, height)
            end_col = min(start_col + tile_size, width)

            # A tile that would run past the edge is shifted back so it stays
            # full-size (it then overlaps its neighbour)
            if end_row - start_row < tile_size:
                start_row = max(0, end_row - tile_size)
            if end_col - start_col < tile_size:
                start_col = max(0, end_col - tile_size)

            tile_height = end_row - start_row
            tile_width = end_col - start_col

            # Skip if tile is too small (less than half a tile in either direction)
            if tile_height < tile_size // 2 or tile_width < tile_size // 2:
                continue

            window = Window(start_col, start_row, tile_width, tile_height)
            tile_data = src.read(window=window)
            tile_transform = rasterio.windows.transform(window, src.transform)

            # Still short after the shift (image dimension below tile_size):
            # pad bottom/right, so the upper-left corner and the transform stay valid
            if tile_height != tile_size or tile_width != tile_size:
                padded_data = np.full(
                    (src.count, tile_size, tile_size),
                    fill_value,
                    dtype=tile_data.dtype
                )
                padded_data[:, :tile_height, :tile_width] = tile_data
                tile_data = padded_data

            tile_profile = src.profile.copy()
            tile_profile.update({
                'height': tile_size,
                'width': tile_size,
                'transform': tile_transform,
                'crs': crs
            })

            with rasterio.open(tile_path, 'w', **tile_profile) as dst:
                dst.write(tile_data)


def _process_cities_tiles(cities_preprocessed: Path, cities_tiles: Path, tile_size: int, stride: int):
    """Tile every '*.tif' found under <cities_preprocessed>/<city>/<timestamp>/.

    Tiles are written to <cities_tiles>/<city>/<timestamp>/. Within one
    timestamp the tile grid, image size and CRS are taken from the first
    '*.tif' file and reused for all files of that timestamp.
    NOTE(review): this assumes all rasters of a timestamp share the same
    dimensions; that is not verified here.
    """
    for city_dir in tqdm(list(cities_preprocessed.iterdir()), desc="Processing cities"):
        if not city_dir.is_dir():
            continue

        output_city_dir = cities_tiles / city_dir.name

        # Process each timestamp
        for timestamp_dir in city_dir.iterdir():
            if not timestamp_dir.is_dir():
                continue

            output_timestamp_dir = output_city_dir / timestamp_dir.name
            output_timestamp_dir.mkdir(parents=True, exist_ok=True)

            tif_files = list(timestamp_dir.glob("*.tif"))
            if not tif_files:
                continue

            # Use first file to determine grid dimensions and spatial reference
            with rasterio.open(tif_files[0]) as reference:
                image_shape = (reference.height, reference.width)
                crs = reference.crs
            grid_shape = _tile_grid_shape(image_shape[0], image_shape[1], tile_size, stride)

            for tif_file in tif_files:
                # tif_file.stem is the band name, e.g. 'LST', 'red', 'ndvi'
                with rasterio.open(tif_file) as src:
                    _write_raster_tiles(
                        src, output_timestamp_dir, tif_file.stem,
                        grid_shape, image_shape, crs, tile_size, stride
                    )


def _process_dem_tiles(dem_preprocessed: Path, dem_tiles: Path, tile_size: int, stride: int):
    """Tile <dem_preprocessed>/<city>/DEM.tif into <dem_tiles>/<city>/DEM_row_RRR_col_CCC.tif.

    City folders without a 'DEM.tif' are skipped (their output folder is
    still created).
    """
    for city_dir in tqdm(list(dem_preprocessed.iterdir()), desc="Processing DEM"):
        if not city_dir.is_dir():
            continue

        output_city_dir = dem_tiles / city_dir.name
        output_city_dir.mkdir(parents=True, exist_ok=True)

        dem_file = city_dir / "DEM.tif"
        if not dem_file.exists():
            continue

        with rasterio.open(dem_file) as src:
            image_shape = (src.height, src.width)
            grid_shape = _tile_grid_shape(image_shape[0], image_shape[1], tile_size, stride)
            _write_raster_tiles(
                src, output_city_dir, "DEM",
                grid_shape, image_shape, src.crs, tile_size, stride
            )

def get_tile_info(data_dir: str, max_cities_shown: int = 3):
    """
    Print information about the tiled dataset.

    Args:
        data_dir: Directory containing 'Cities_Tiles' and/or 'DEM_2014_Tiles'.
            A missing folder is silently skipped.
        max_cities_shown: Number of cities for which a detail line is printed
            (default 3). Totals always cover every city.

    Only sub-directories count as cities and scenes. The tile count is an
    estimate: the number of '*.tif' files in a city's first scene multiplied
    by that city's number of scenes.
    """
    data_path = Path(data_dir)
    cities_tiles = data_path / "Cities_Tiles"
    dem_tiles = data_path / "DEM_2014_Tiles"

    print("=== TILE DATASET INFORMATION ===")

    if cities_tiles.exists():
        # Sorted so the cities that get a detail line are deterministic
        cities = sorted(p for p in cities_tiles.iterdir() if p.is_dir())
        print(f"Cities with tiles: {len(cities)}")

        total_scenes = 0
        total_tiles = 0
        for index, city_dir in enumerate(cities):
            timestamps = sorted(p for p in city_dir.iterdir() if p.is_dir())
            city_scenes = len(timestamps)
            total_scenes += city_scenes

            if not timestamps:
                continue

            # Count tiles in first timestamp and extrapolate to the whole city
            tiles_per_scene = len(list(timestamps[0].glob("*.tif")))
            city_total_tiles = city_scenes * tiles_per_scene
            total_tiles += city_total_tiles

            # Details are limited to the first few cities; totals are not
            if index < max_cities_shown:
                print(f"  {city_dir.name}: {city_scenes} scenes, ~{tiles_per_scene} tiles/scene, ~{city_total_tiles} total tiles")

        print(f"Total scenes: {total_scenes}")
        print(f"Estimated total tiles: {total_tiles}")

    if dem_tiles.exists():
        dem_cities = [p for p in dem_tiles.iterdir() if p.is_dir()]
        total_dem_tiles = sum(len(list(city_dir.glob("*.tif"))) for city_dir in dem_cities)

        print(f"DEM cities: {len(dem_cities)}")
        print(f"Total DEM tiles: {total_dem_tiles}")

# Script entry point: tile the dataset, then report what was produced
if __name__ == "__main__":
    dataset_root = "./Data/ML"

    convert_to_tiles(dataset_root)
    get_tile_info(dataset_root)

In [None]:
import os
import shutil
import rasterio
from rasterio.warp import reproject, Resampling
from rasterio.windows import from_bounds
import numpy as np
from pathlib import Path
from tqdm import tqdm
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

def get_majority_class_from_classification(dem_tile_path, class_src):
    """Return the classification value that covers the largest area of a tile.

    Args:
        dem_tile_path: Path to the tile raster whose footprint is evaluated.
        class_src: The classification raster, given either as an open
            rasterio dataset or as a path (str / os.PathLike). A path is
            opened and closed on every call, so pass an open dataset when
            calling this in a loop.

    Returns:
        int | None: The class (0..17) with the largest summed overlap area
        with the tile bounds, or None when no valid class overlaps the tile
        or when any error occurs (the error is printed with a traceback).
    """
    try:
        # Accept a path as well as an open dataset
        if isinstance(class_src, (str, os.PathLike)):
            with rasterio.open(class_src) as opened_class_src:
                return get_majority_class_from_classification(dem_tile_path, opened_class_src)

        # Only the footprint and CRS of the tile are needed
        with rasterio.open(dem_tile_path) as dem_src:
            dem_bounds = dem_src.bounds
            dem_crs = dem_src.crs

        # Express the tile bounds in the classification raster's CRS
        if dem_crs != class_src.crs:
            from rasterio.warp import transform_bounds
            dem_bounds_in_class_crs = transform_bounds(
                dem_crs, class_src.crs, *dem_bounds
            )
        else:
            dem_bounds_in_class_crs = dem_bounds
        tile_left, tile_bottom, tile_right, tile_top = dem_bounds_in_class_crs

        # Get the window of class_src that overlaps with the tile and read it
        class_window = rasterio.windows.from_bounds(
            *dem_bounds_in_class_crs,
            class_src.transform
        )
        class_data = class_src.read(1, window=class_window)

        # Affine terms of the window: [0] pixel width, [2] x origin,
        # [4] pixel height (added to move down one row), [5] y origin.
        # Rotation terms are not used.
        window_transform = rasterio.windows.transform(class_window, class_src.transform)
        pixel_width = window_transform[0]
        pixel_height = window_transform[4]
        origin_x = window_transform[2]
        origin_y = window_transform[5]

        # Sum, per class, the area each classification pixel shares with the tile
        class_counts_weighted = {}
        n_rows, n_cols = class_data.shape[0], class_data.shape[1]
        for i in range(n_rows):
            for j in range(n_cols):
                class_val = class_data[i, j]

                # Skip invalid classes
                if class_val < 0 or class_val > 17:
                    continue

                # Pixel bounds
                class_pixel_left = origin_x + j * pixel_width
                class_pixel_right = class_pixel_left + pixel_width
                class_pixel_top = origin_y + i * pixel_height
                class_pixel_bottom = class_pixel_top + pixel_height

                # Intersection of the pixel with the tile bounds
                intersect_left = max(class_pixel_left, tile_left)
                intersect_right = min(class_pixel_right, tile_right)
                intersect_top = min(class_pixel_top, tile_top)
                intersect_bottom = max(class_pixel_bottom, tile_bottom)

                if intersect_right > intersect_left and intersect_top > intersect_bottom:
                    intersection_area = (
                        (intersect_right - intersect_left) * (intersect_top - intersect_bottom)
                    )
                    class_counts_weighted[class_val] = (
                        class_counts_weighted.get(class_val, 0) + intersection_area
                    )

        if not class_counts_weighted:
            print("No valid classes found")
            return None

        # Get majority class by area
        majority_class = max(class_counts_weighted, key=class_counts_weighted.get)
        return int(majority_class)

    except Exception as e:
        # Best effort: report the problem and let the caller skip this tile
        print(f"Error processing {dem_tile_path}: {e}")
        import traceback
        traceback.print_exc()
        return None

def copy_tile_with_structure(src_path, base_src_dir, base_dst_dir):
    """
    Copy one file to a new root, keeping its path relative to the old root.

    Args:
        src_path: File to copy
        base_src_dir: Root that src_path is made relative to
            (e.g., './Dataset/Cities_Tiles')
        base_dst_dir: Root under which the relative path is recreated
            (e.g., './Dataset/Clustered/1')

    Returns:
        str: Path of the copy
    """
    dst_path = os.path.join(base_dst_dir, os.path.relpath(src_path, base_src_dir))

    # Make sure the target folder chain exists before copying
    target_folder = os.path.dirname(dst_path)
    os.makedirs(target_folder, exist_ok=True)

    # copy2 carries file metadata along with the content
    shutil.copy2(src_path, dst_path)
    return dst_path

def find_corresponding_dem_tile(cities_tile_path, cities_tiles_dir, dem_tiles_dir):
    """
    Find the corresponding DEM tile for a given cities tile.

    The cities tile is expected at
    <cities_tiles_dir>/<city>/<timestamp>/<band>_row_RRR_col_CCC.tif and is
    matched to <dem_tiles_dir>/<city>/DEM_row_RRR_col_CCC.tif.

    Args:
        cities_tile_path: Path to cities tile
        cities_tiles_dir: Base cities tiles directory
        dem_tiles_dir: Base DEM tiles directory

    Returns:
        str: Path to corresponding DEM tile, or None if the path or file name
        does not have the expected shape, the DEM tile does not exist, or an
        error occurs (the error is printed).
    """
    try:
        # Parse the cities tile path to extract city and tile file name
        rel_path = os.path.relpath(cities_tile_path, cities_tiles_dir)
        path_parts = rel_path.split(os.sep)

        if len(path_parts) < 3:
            return None

        city_name = path_parts[0]
        tile_filename = path_parts[2]

        # Drop the extension first; otherwise it stays glued to the column
        # token and the DEM name becomes 'DEM_row_004_col_005.tif.tif'
        tile_stem, _ = os.path.splitext(tile_filename)
        stem_parts = tile_stem.split('_')

        # Locate the 'row', <r>, 'col', <c> run; the band name before it may
        # itself contain underscores
        row_col_part = None
        for i, part in enumerate(stem_parts):
            if part == 'row' and i + 3 < len(stem_parts) and stem_parts[i + 2] == 'col':
                row_col_part = f"row_{stem_parts[i + 1]}_col_{stem_parts[i + 3]}"
                break

        if not row_col_part:
            return None

        # Construct DEM tile path
        dem_tile_path = os.path.join(dem_tiles_dir, city_name, f"DEM_{row_col_part}.tif")
        return dem_tile_path if os.path.exists(dem_tile_path) else None

    except Exception as e:
        print(f"Error finding corresponding DEM tile for {cities_tile_path}: {e}")
        return None

def cluster_tiles_by_classification(dataset_dir, classification_raster_path):
    """
    Main function to cluster tiles by classification.

    Every ``*.tif`` found under ``<dataset_dir>/Cities_Tiles/<city>/<timestamp>/``
    is assigned a majority class and copied, with its relative folder
    structure, to ``<dataset_dir>/Clustered/<class>/Cities_Tiles``. The matching
    DEM tile, when one is found, is copied to
    ``<dataset_dir>/Clustered/<class>/DEM_2014_Tiles``. A summary is printed at
    the end.

    Tiles whose majority class is None or outside 1-5 are counted as failed
    and skipped instead of aborting the whole run.

    Args:
        dataset_dir: Path to Dataset directory
        classification_raster_path: Path to classification raster (100m, values 1-5)
    """
    dataset_path = Path(dataset_dir)
    cities_tiles_dir = dataset_path / "Cities_Tiles"
    dem_tiles_dir = dataset_path / "DEM_2014_Tiles"
    clustered_dir = dataset_path / "Clustered"

    # Create cluster directories
    for class_num in range(1, 6):
        cluster_dir = clustered_dir / str(class_num)
        cluster_dir.mkdir(parents=True, exist_ok=True)

        # Create subdirectories for cities and DEM tiles
        (cluster_dir / "Cities_Tiles").mkdir(exist_ok=True)
        (cluster_dir / "DEM_2014_Tiles").mkdir(exist_ok=True)

    # Statistics tracking
    stats = {i: 0 for i in range(1, 6)}
    total_processed = 0
    failed_processing = 0

    print("Starting tile clustering by classification...")
    print(f"Classification raster: {classification_raster_path}")
    print(f"Cities tiles directory: {cities_tiles_dir}")
    print(f"DEM tiles directory: {dem_tiles_dir}")

    # Collect all cities tiles: <city>/<timestamp>/*.tif
    cities_tiles_list = []
    for city_dir in cities_tiles_dir.iterdir():
        if city_dir.is_dir():
            for timestamp_dir in city_dir.iterdir():
                if timestamp_dir.is_dir():
                    for tile_file in timestamp_dir.glob("*.tif"):
                        cities_tiles_list.append(tile_file)

    print(f"Found {len(cities_tiles_list)} cities tiles to process")

    # Process each cities tile
    for cities_tile_path in tqdm(cities_tiles_list, desc="Processing tiles"):
        total_processed += 1

        # Get majority classification for this tile.
        # NOTE(review): this passes the raster *path*; a later definition of
        # get_majority_class_from_classification in this file expects an open
        # dataset instead - confirm which definition is in scope when running.
        majority_class = get_majority_class_from_classification(
            str(cities_tile_path),
            classification_raster_path
        )

        if majority_class is None:
            failed_processing += 1
            continue

        # A class outside 1-5 has no counter and no cluster folder; skip the
        # tile rather than letting a KeyError stop the whole run.
        if majority_class not in stats:
            print(f"Warning: Invalid class {majority_class} for tile {cities_tile_path}")
            failed_processing += 1
            continue

        # Update statistics
        stats[majority_class] += 1

        # Copy cities tile to appropriate cluster directory
        cluster_cities_dir = clustered_dir / str(majority_class) / "Cities_Tiles"
        copy_tile_with_structure(
            str(cities_tile_path),
            str(cities_tiles_dir),
            str(cluster_cities_dir)
        )

        # Find and copy corresponding DEM tile
        dem_tile_path = find_corresponding_dem_tile(
            str(cities_tile_path),
            str(cities_tiles_dir),
            str(dem_tiles_dir)
        )

        if dem_tile_path and os.path.exists(dem_tile_path):
            cluster_dem_dir = clustered_dir / str(majority_class) / "DEM_2014_Tiles"
            copy_tile_with_structure(
                dem_tile_path,
                str(dem_tiles_dir),
                str(cluster_dem_dir)
            )

    # Print statistics
    print("\n=== CLUSTERING RESULTS ===")
    print(f"Total tiles processed: {total_processed}")
    print(f"Failed to process: {failed_processing}")
    print(f"Successfully clustered: {sum(stats.values())}")
    print("\nTiles per class:")
    for class_num in range(1, 6):
        print(f"  Class {class_num}: {stats[class_num]} tiles")

    print(f"\nClustered tiles saved to: {clustered_dir}")

    # Verify folder structure by counting what is actually on disk
    print("\n=== FOLDER STRUCTURE VERIFICATION ===")
    for class_num in range(1, 6):
        class_dir = clustered_dir / str(class_num)
        cities_count = len(list((class_dir / "Cities_Tiles").rglob("*.tif")))
        dem_count = len(list((class_dir / "DEM_2014_Tiles").rglob("*.tif")))
        print(f"Class {class_num}: {cities_count} cities tiles, {dem_count} DEM tiles")

# Example usage
if __name__ == "__main__":
    # Set your paths here
    dataset_directory = "./Data/Dataset"
    classification_raster = "./cluster5.tif"  # Update this path

    # Verify paths exist before doing any work.
    # SystemExit is raised directly: `exit` is a helper the `site` module
    # installs for interactive sessions and is not meant for use in programs.
    if not os.path.exists(dataset_directory):
        print(f"Error: Dataset directory not found: {dataset_directory}")
        raise SystemExit(1)

    if not os.path.exists(classification_raster):
        print(f"Error: Classification raster not found: {classification_raster}")
        print("Please update the 'classification_raster' variable with the correct path")
        raise SystemExit(1)

    # Run clustering
    cluster_tiles_by_classification(dataset_directory, classification_raster)

    print("\nClustering complete!")

In [None]:
import os
import shutil
import rasterio
from rasterio.warp import reproject, Resampling
from rasterio.windows import from_bounds
import numpy as np
from pathlib import Path
from tqdm import tqdm
from collections import Counter
import warnings
# warnings.filterwarnings('ignore')
def get_majority_class_from_classification(dem_tile_path, class_src):
    """
    Return the classification value that covers the largest area of a tile.

    The tile footprint is converted to the CRS of the classification raster
    when the two differ. The classification pixels read for that footprint
    are then weighted by their overlap area with the footprint.

    Args:
        dem_tile_path: Path to the tile raster whose bounds are evaluated
        class_src: Open rasterio dataset of the classification raster

    Returns:
        int: Class value (0-17) with the largest overlap area, or None when
        no valid class overlaps the tile or any error occurs
    """
    try:
        # Only the footprint and CRS of the tile are needed
        with rasterio.open(dem_tile_path) as tile_src:
            tile_bounds = tile_src.bounds
            tile_crs = tile_src.crs

        # Express the tile footprint in the classification raster's CRS
        footprint = tile_bounds
        if tile_crs != class_src.crs:
            from rasterio.warp import transform_bounds
            footprint = transform_bounds(tile_crs, class_src.crs, *tile_bounds)
        tile_left, tile_bottom, tile_right, tile_top = footprint

        # Read the part of the classification raster under the footprint
        window = rasterio.windows.from_bounds(*footprint, class_src.transform)
        class_data = class_src.read(1, window=window)

        # Affine terms of the window: origin (top-left) and per-pixel steps
        window_transform = rasterio.windows.transform(window, class_src.transform)
        origin_x = window_transform[2]
        origin_y = window_transform[5]
        step_x = window_transform[0]
        step_y = window_transform[4]

        # Accumulated overlap area per class value
        area_by_class = {}
        for i, row_values in enumerate(class_data):
            for j, class_val in enumerate(row_values):
                # Skip invalid classes
                if class_val < 0 or class_val > 17:
                    continue

                # Pixel edges in classification CRS units
                pixel_left = origin_x + j * step_x
                pixel_right = pixel_left + step_x
                pixel_top = origin_y + i * step_y
                pixel_bottom = pixel_top + step_y

                # Clip the pixel to the tile footprint
                overlap_left = max(pixel_left, tile_left)
                overlap_right = min(pixel_right, tile_right)
                overlap_top = min(pixel_top, tile_top)
                overlap_bottom = max(pixel_bottom, tile_bottom)

                # Ignore pixels that do not actually overlap the footprint
                if not (overlap_right > overlap_left and overlap_top > overlap_bottom):
                    continue

                overlap_area = (overlap_right - overlap_left) * (overlap_top - overlap_bottom)
                area_by_class[class_val] = area_by_class.get(class_val, 0) + overlap_area

        if not area_by_class:
            print("No valid classes found")
            return None

        # Majority class is the one with the largest accumulated area
        return int(max(area_by_class, key=area_by_class.get))

    except Exception as e:
        print(f"Error processing {dem_tile_path}: {e}")
        import traceback
        traceback.print_exc()
        return None

def count_tiles_by_classification(dataset_dir, classification_raster_path):
    """
    Tally DEM tiles by majority classification value and print a summary.

    Nothing is copied; tiles matching ``DEM_*.tif`` directly inside each city
    folder of ``<dataset_dir>/DEM_2014_Tiles`` are classified and counted.
    Classes 7 and 9 are left out of the per-class printout.

    Args:
        dataset_dir: Path to Dataset directory
        classification_raster_path: Path to classification raster (counters cover values 0-17)
    """
    dem_tiles_dir = Path(dataset_dir) / "DEM_2014_Tiles"

    # One counter per class value 0..17
    stats = dict.fromkeys(range(0, 18), 0)
    total_processed = 0
    failed_processing = 0

    print("Starting tile counting by classification...")
    print(f"Classification raster: {classification_raster_path}")
    print(f"DEM tiles directory: {dem_tiles_dir}")

    # Gather the DEM tiles of every city folder
    dem_tiles_list = [
        tile_file
        for city_dir in dem_tiles_dir.iterdir()
        if city_dir.is_dir()
        for tile_file in city_dir.glob("DEM_*.tif")
    ]

    print(f"Found {len(dem_tiles_list)} dem tiles to process")

    # The classification raster is opened once and shared by all tiles
    with rasterio.open(classification_raster_path) as class_src:
        print("Classification raster opened, starting processing...")

        for dem_tile_path in tqdm(dem_tiles_list, desc="Processing tiles"):
            total_processed += 1

            majority_class = get_majority_class_from_classification(
                str(dem_tile_path),
                class_src
            )

            if majority_class is None:
                failed_processing += 1
            else:
                stats[majority_class] += 1

    # Summary
    classified_total = sum(stats.values())
    print("\n=== COUNTING RESULTS ===")
    print(f"Total tiles processed: {total_processed}")
    print(f"Failed to process: {failed_processing}")
    print(f"Successfully classified: {classified_total}")
    print("\nTiles per class:")
    for class_num, tile_count in stats.items():
        if class_num in (7, 9):
            continue
        percentage = (tile_count / classified_total * 100) if classified_total > 0 else 0
        print(f"  Class {class_num:2d}: {tile_count:6d} tiles ({percentage:5.1f}%)")

# Example usage
if __name__ == "__main__":
    # Set your paths here
    dataset_directory = "./Data/Dataset"
    classification_raster = "./original_LCZ.tif"

    # Verify paths exist before doing any work.
    # SystemExit is raised directly: `exit` is a helper the `site` module
    # installs for interactive sessions and is not meant for use in programs.
    if not os.path.exists(dataset_directory):
        print(f"Error: Dataset directory not found: {dataset_directory}")
        raise SystemExit(1)

    if not os.path.exists(classification_raster):
        print(f"Error: Classification raster not found: {classification_raster}")
        print("Please update the 'classification_raster' variable with the correct path")
        raise SystemExit(1)

    # Run counting
    count_tiles_by_classification(dataset_directory, classification_raster)

    print("\nCounting complete! *Multiply the tile counts by 12.")

In [None]:
# Inspect the LCZ raster: print its declared NoData value and the distinct
# values found in band 1 once NoData pixels are filtered out.
# NOTE(review): if the raster declares no NoData value (src.nodata is None)
# the filter below presumably keeps every pixel - confirm for this file.
with rasterio.open("./original_LCZ.tif") as src:
    print(f"NoData value: {src.nodata}")
    # Band 1 holds the classification values
    data = src.read(1)
    unique_vals = np.unique(data[data != src.nodata])  # Exclude NoData
    print(f"Valid class values: {sorted(unique_vals)}")

In [None]:
import os
import shutil
import rasterio
from rasterio.windows import from_bounds
import numpy as np
from pathlib import Path
from tqdm import tqdm
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

def map_class_to_group(class_val):
    """
    Translate an original classification value into its group number.

    Grouping: 1-3 -> group 1, 4-6 -> group 2, 7-10 -> group 3,
    11-17 and 0 -> group 4.

    Args:
        class_val: Original classification value (0-17)

    Returns:
        int: Group number (1-4) or None if invalid
    """
    # (group, member classes) pairs, checked in order
    group_members = (
        (1, (1, 2, 3)),
        (2, (4, 5, 6)),
        (3, (7, 8, 9, 10)),
        (4, (11, 12, 13, 14, 15, 16, 17, 0)),
    )
    for group, members in group_members:
        if class_val in members:
            return group
    return None

def get_majority_class_from_classification(dem_tile_path, class_src):
    """
    Return the class group (1-4) that covers the largest area of a tile.

    The tile footprint is converted to the CRS of the classification raster
    when the two differ. Each classification pixel read for that footprint is
    mapped to a group with map_class_to_group and weighted by its overlap
    area with the footprint.

    Args:
        dem_tile_path: Path to the tile raster whose bounds are evaluated
        class_src: Open rasterio dataset of the classification raster

    Returns:
        int: Group with the largest overlap area, or None when no pixel maps
        to a group or any error occurs
    """
    try:
        # Only the footprint and CRS of the tile are needed
        with rasterio.open(dem_tile_path) as tile_src:
            tile_bounds = tile_src.bounds
            tile_crs = tile_src.crs

        # Express the tile footprint in the classification raster's CRS
        footprint = tile_bounds
        if tile_crs != class_src.crs:
            from rasterio.warp import transform_bounds
            footprint = transform_bounds(tile_crs, class_src.crs, *tile_bounds)
        tile_left, tile_bottom, tile_right, tile_top = footprint

        # Read the part of the classification raster under the footprint
        window = rasterio.windows.from_bounds(*footprint, class_src.transform)
        class_data = class_src.read(1, window=window)

        # Affine terms of the window: origin (top-left) and per-pixel steps
        window_transform = rasterio.windows.transform(window, class_src.transform)
        origin_x = window_transform[2]
        origin_y = window_transform[5]
        step_x = window_transform[0]
        step_y = window_transform[4]

        # Accumulated overlap area per group (1-4)
        area_by_group = {}
        for i, row_values in enumerate(class_data):
            for j, class_val in enumerate(row_values):
                # Pixels whose class has no group are ignored
                group_val = map_class_to_group(class_val)
                if group_val is None:
                    continue

                # Pixel edges in classification CRS units
                pixel_left = origin_x + j * step_x
                pixel_right = pixel_left + step_x
                pixel_top = origin_y + i * step_y
                pixel_bottom = pixel_top + step_y

                # Clip the pixel to the tile footprint
                overlap_left = max(pixel_left, tile_left)
                overlap_right = min(pixel_right, tile_right)
                overlap_top = min(pixel_top, tile_top)
                overlap_bottom = max(pixel_bottom, tile_bottom)

                # Ignore pixels that do not actually overlap the footprint
                if not (overlap_right > overlap_left and overlap_top > overlap_bottom):
                    continue

                overlap_area = (overlap_right - overlap_left) * (overlap_top - overlap_bottom)
                area_by_group[group_val] = area_by_group.get(group_val, 0) + overlap_area

        if not area_by_group:
            print("No valid classes found")
            return None

        # Majority group is the one with the largest accumulated area
        return int(max(area_by_group, key=area_by_group.get))

    except Exception as e:
        print(f"Error processing {dem_tile_path}: {e}")
        import traceback
        traceback.print_exc()
        return None

def copy_tile_with_structure(src_path, base_src_dir, base_dst_dir):
    """
    Copy a file to a new base directory, keeping its relative sub-path.

    The location of ``src_path`` relative to ``base_src_dir`` is re-created
    under ``base_dst_dir``; missing destination folders are created.

    Args:
        src_path: Source file path
        base_src_dir: Base source directory (e.g., './Dataset/Cities_Tiles')
        base_dst_dir: Base destination directory (e.g., './Dataset/Clustered/1')

    Returns:
        str: Destination path where file was copied
    """
    # Same relative location, rooted at the destination base
    destination = os.path.join(base_dst_dir, os.path.relpath(src_path, base_src_dir))

    # Make sure the target folder is there before copying
    target_dir, _ = os.path.split(destination)
    os.makedirs(target_dir, exist_ok=True)

    shutil.copy2(src_path, destination)
    return destination

def find_corresponding_dem_tile(cities_tile_path, cities_tiles_dir, dem_tiles_dir):
    """
    Find the corresponding DEM tile for a given cities tile.

    The cities tile is expected at ``<cities_tiles_dir>/<city>/<timestamp>/<file>``
    and its file name must contain the token sequence ``row_<r>_col_<c>``
    (e.g. "red_row_004_col_005.tif"). The DEM tile is looked up at
    ``<dem_tiles_dir>/<city>/DEM_row_<r>_col_<c>.tif``.

    Args:
        cities_tile_path: Path to cities tile
        cities_tiles_dir: Base cities tiles directory
        dem_tiles_dir: Base DEM tiles directory

    Returns:
        str: Path to corresponding DEM tile or None if not found
    """
    try:
        # Parse the cities tile path to extract city and tile coordinates
        rel_path = os.path.relpath(cities_tile_path, cities_tiles_dir)
        path_parts = rel_path.split(os.sep)

        # Need at least <city>/<timestamp>/<tile file>
        if len(path_parts) < 3:
            return None

        city_name = path_parts[0]
        tile_filename = path_parts[2]

        # Drop the extension before tokenizing; otherwise ".tif" stays glued to
        # the column number and the DEM name ends in ".tif.tif".
        tile_stem = os.path.splitext(tile_filename)[0]
        name_tokens = tile_stem.split('_')

        # A match needs four consecutive tokens: "row", <r>, "col", <c>.
        row_col_part = None
        for i in range(len(name_tokens) - 3):
            if name_tokens[i] == 'row' and name_tokens[i + 2] == 'col':
                row_col_part = f"row_{name_tokens[i + 1]}_col_{name_tokens[i + 3]}"
                break

        if not row_col_part:
            return None

        # Construct DEM tile path and report it only when the file is present
        dem_tile_filename = f"DEM_{row_col_part}.tif"
        dem_tile_path = os.path.join(dem_tiles_dir, city_name, dem_tile_filename)
        return dem_tile_path if os.path.exists(dem_tile_path) else None

    except Exception as e:
        # Best-effort lookup: report the problem and let the caller carry on
        print(f"Error finding corresponding DEM tile for {cities_tile_path}: {e}")
        return None

def cluster_tiles_by_classification(dataset_dir, classification_raster_path):
    """
    Copy every city tile into the folder of its majority class group.

    Tiles found under ``<dataset_dir>/Cities_Tiles/<city>/<timestamp>/*.tif``
    are assigned a group (1-4) and copied, with their relative folder
    structure, to ``<dataset_dir>/Clustered/<group>/Cities_Tiles``. The
    ``DEM_2014_Tiles`` subfolders are created and counted in the final report,
    but this function does not copy anything into them.

    Args:
        dataset_dir: Path to Dataset directory
        classification_raster_path: Path to classification raster (original classes 0-17, grouped into 1-4)
    """
    group_numbers = range(1, 5)
    group_descriptions = {
        1: "Classes 1-3",
        2: "Classes 4-6",
        3: "Classes 7-10",
        4: "Classes 11-17, 0"
    }

    dataset_root = Path(dataset_dir)
    cities_tiles_dir = dataset_root / "Cities_Tiles"
    dem_tiles_dir = dataset_root / "DEM_2014_Tiles"
    clustered_dir = dataset_root / "Clustered"

    # Output layout: Clustered/<group>/{Cities_Tiles,DEM_2014_Tiles}
    for group_num in group_numbers:
        group_dir = clustered_dir / str(group_num)
        group_dir.mkdir(parents=True, exist_ok=True)
        for subfolder in ("Cities_Tiles", "DEM_2014_Tiles"):
            (group_dir / subfolder).mkdir(exist_ok=True)

    # Per-group counters plus overall bookkeeping
    stats = dict.fromkeys(group_numbers, 0)
    total_processed = 0
    failed_processing = 0

    print("Starting tile clustering by classification...")
    print("Classification mapping:")
    print("  Group 1: Classes 1-3")
    print("  Group 2: Classes 4-6")
    print("  Group 3: Classes 7-10")
    print("  Group 4: Classes 11-17, 0")
    print(f"Classification raster: {classification_raster_path}")
    print(f"Cities tiles directory: {cities_tiles_dir}")
    print(f"DEM tiles directory: {dem_tiles_dir}")

    # The classification raster is opened once and shared by all tiles
    with rasterio.open(classification_raster_path) as class_src:
        # Gather <city>/<timestamp>/*.tif
        cities_tiles_list = [
            tile_file
            for city_dir in cities_tiles_dir.iterdir()
            if city_dir.is_dir()
            for timestamp_dir in city_dir.iterdir()
            if timestamp_dir.is_dir()
            for tile_file in timestamp_dir.glob("*.tif")
        ]

        print(f"Found {len(cities_tiles_list)} cities tiles to process")

        for cities_tile_path in tqdm(cities_tiles_list, desc="Processing tiles"):
            total_processed += 1

            majority_group = get_majority_class_from_classification(
                str(cities_tile_path),
                class_src
            )

            # No usable classification for this tile
            if majority_group is None:
                failed_processing += 1
                continue

            # Guard against a group without counter or target folder
            if majority_group < 1 or majority_group > 4:
                print(f"Warning: Invalid group {majority_group} for tile {cities_tile_path}")
                failed_processing += 1
                continue

            stats[majority_group] += 1

            # Copy the tile below the folder of its group
            copy_tile_with_structure(
                str(cities_tile_path),
                str(cities_tiles_dir),
                str(clustered_dir / str(majority_group) / "Cities_Tiles")
            )

    # Summary
    print("\n=== CLUSTERING RESULTS ===")
    print(f"Total tiles processed: {total_processed}")
    print(f"Failed to process: {failed_processing}")
    print(f"Successfully clustered: {sum(stats.values())}")
    print("\nTiles per group:")
    for group_num in group_numbers:
        print(f"  Group {group_num} ({group_descriptions[group_num]}): {stats[group_num]} tiles")

    print(f"\nClustered tiles saved to: {clustered_dir}")

    # Count what is actually on disk in each group folder
    print("\n=== FOLDER STRUCTURE VERIFICATION ===")
    for group_num in group_numbers:
        group_dir = clustered_dir / str(group_num)
        cities_count = sum(1 for _ in (group_dir / "Cities_Tiles").rglob("*.tif"))
        dem_count = sum(1 for _ in (group_dir / "DEM_2014_Tiles").rglob("*.tif"))
        print(f"Group {group_num} ({group_descriptions[group_num]}): {cities_count} cities tiles, {dem_count} DEM tiles")

# Example usage
if __name__ == "__main__":
    # Set your paths here
    dataset_directory = "./Data/Dataset"
    classification_raster = "./original_LCZ.tif"  # Your classification raster with values 0-17

    # Verify paths exist before doing any work.
    # SystemExit is raised directly: `exit` is a helper the `site` module
    # installs for interactive sessions and is not meant for use in programs.
    if not os.path.exists(dataset_directory):
        print(f"Error: Dataset directory not found: {dataset_directory}")
        raise SystemExit(1)

    if not os.path.exists(classification_raster):
        print(f"Error: Classification raster not found: {classification_raster}")
        print("Please update the 'classification_raster' variable with the correct path")
        raise SystemExit(1)

    # Run clustering
    cluster_tiles_by_classification(dataset_directory, classification_raster)

    print("\nClustering complete!")