In [None]:
!pip install rasterio pytorch_tabular

Collecting pytorch_tabular
  Downloading pytorch_tabular-1.1.1-py2.py3-none-any.whl.metadata (24 kB)
Collecting numpy>=1.24 (from rasterio)
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
Collecting pytorch-lightning<2.5.0,>=2.0.0 (from pytorch_tabular)
  Downloading pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)
Collecting torchmetrics<1.6.0,>=0.10.0 (from pytorch_tabular)
  Downloading torchmetrics-1.5.2-py3-none-any.whl.metadata (20 kB)
Collecting protobuf<5.29.0,>=3.20.0 (from pytorch_tabular)
  Downloading protobuf-5.28.3-cp38-abi3-manylinux2014_x86_64.whl.metadata (592 bytes)
Collecting pytorch-tabnet==4.1 (from pytorch_tabular)
  Downloading pytorch_tabnet-4.1.0-py3-none-any.whl.metadata (15 kB)
Collecting einops<0.8.0,>=0.6.0 (from pytorch_tabular)
  Downloading einops-0.7.0-py3-none-any.whl.met

In [None]:
import os
import gc
import json
import glob
import joblib
import pickle
import rasterio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
from rasterio.transform import from_bounds
from sklearn.preprocessing import StandardScaler

# Install pyproj if needed for coordinate transformation.
# Only a missing package triggers the install; any other error surfaces unchanged.
try:
    from pyproj import Transformer
except ImportError:
    print("Installing pyproj...")
    import subprocess
    import sys
    # Invoke pip through the running interpreter so the package is installed
    # into the same environment the kernel imports from.
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'pyproj'])
    from pyproj import Transformer


import torch
import torch.nn as nn
import torch.nn.functional as F

import tensorflow as tf

from pytorch_tabular import TabularModel
from pytorch_tabular.models import FTTransformerConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig

In [None]:
# Pick the compute device: CUDA when a GPU is visible to torch, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cpu


In [None]:
# VARIABLES
# NOTE(review): MODEL_PATH is not referenced by the cells shown in this notebook;
# the model-loading cell passes a different literal path (without the "cnn-lstm"
# folder). Confirm which location is current and use a single constant.
MODEL_PATH = '/content/drive/MyDrive/AGRI/Planting_Method/model/cnn-lstm/CNN-LSTM_dry_model.pth'
# Folder scanned for "*.tfrecord.gz" shards and their "-mixer.json" sidecars.
TFRECORD_DIR = '/content/drive/MyDrive/AGRI/Planting_Method/tfrecord'
# Destination for the per-tile GeoTIFFs and PNG previews.
OUTPUT_DIR = '/content/drive/MyDrive/AGRI/Planting_Method/results'
os.makedirs(OUTPUT_DIR, exist_ok=True)  # create the output folder up front

PATCH_SIZE = 256      # Side length (pixels) of one square TFRecord patch
BATCH_SIZE = 512      # Number of pixels sent through the model per forward pass
SEQUENCE_LENGTH = 18  # Number of timesteps
EXPECTED_BANDS = 1    # Single band per timestep: VH polarization

In [None]:
class SimplifiedCNNLSTM(nn.Module):
    """
    Compact CNN -> bidirectional LSTM -> attention-pooling classifier.

    Pipeline:
        1. Two Conv1d/BatchNorm/ReLU/Dropout stages (input_dim -> 64 -> 128 channels).
        2. A bidirectional LSTM over the convolved sequence, followed by LayerNorm.
        3. A learned attention score per step, softmax-normalised over the
           sequence axis and used to pool the LSTM outputs into one vector.
        4. A two-layer classifier producing ``num_classes`` logits.

    Args:
        input_dim: Size of the last axis of the input (Conv1d input channels).
        hidden_dim: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
        dropout_rate: Dropout probability used in the CNN, between LSTM layers
            (only when ``num_layers > 1``) and in the classifier.
    """
    def __init__(self, input_dim, hidden_dim=128, num_layers=2, num_classes=2, dropout_rate=0.3):
        super().__init__()

        # Both LSTM directions are concatenated, doubling the feature width.
        recurrent_width = hidden_dim * 2

        # Convolutional front-end: the same four-layer stage applied twice.
        conv_stages = []
        for channels_in, channels_out in ((input_dim, 64), (64, 128)):
            conv_stages.extend([
                nn.Conv1d(channels_in, channels_out, kernel_size=3, padding=1),
                nn.BatchNorm1d(channels_out),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ])
        self.cnn = nn.Sequential(*conv_stages)

        # nn.LSTM only applies dropout between stacked layers, so it is
        # disabled for a single-layer configuration.
        inter_layer_dropout = dropout_rate if num_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_size=128,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=inter_layer_dropout
        )

        # Normalise the recurrent features before attention.
        self.ln = nn.LayerNorm(recurrent_width)

        # Scores each sequence step with a single scalar.
        self.attention = nn.Sequential(
            nn.Linear(recurrent_width, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

        # Maps the pooled context vector to class logits.
        self.classifier = nn.Sequential(
            nn.Linear(recurrent_width, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, num_classes)
        )

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for conv weights; Xavier-uniform weights and zero bias for linear layers."""
        for module in self.modules():
            if isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(module.weight)
                continue
            if not isinstance(module, nn.Linear):
                continue
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for ``x`` of shape (batch, length, input_dim)."""
        # Conv1d wants channels on axis 1; swap in, convolve, swap back.
        conv_features = self.cnn(x.permute(0, 2, 1)).permute(0, 2, 1)

        # Recurrent encoding followed by layer normalisation.
        sequence, _ = self.lstm(conv_features)
        sequence = self.ln(sequence)

        # Attention pooling over the sequence axis (dim=1).
        step_weights = F.softmax(self.attention(sequence), dim=1)
        pooled = torch.sum(step_weights * sequence, dim=1)

        return self.classifier(pooled)


class ResidualBlock(nn.Module):
    """
    Two-convolution 1-D residual block with batch normalisation.

    The shortcut is the identity when the channel count is unchanged;
    otherwise it is a 1x1 convolution followed by batch normalisation so the
    shortcut matches the main branch's channel count.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, dropout=0.3):
        super().__init__()

        half_kernel = kernel_size // 2

        # Main branch: conv -> BN -> leaky ReLU -> dropout -> conv -> BN.
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, padding=half_kernel)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, padding=half_kernel)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.dropout = nn.Dropout(dropout)

        # Shortcut branch: project only when the channel count changes.
        if in_channels != out_channels:
            self.skip = nn.Conv1d(in_channels, out_channels, 1)
            self.bn_skip = nn.BatchNorm1d(out_channels)
        else:
            self.skip = nn.Identity()
            self.bn_skip = nn.Identity()

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, in_channels, length)."""
        shortcut = self.bn_skip(self.skip(x))

        hidden = F.leaky_relu(self.bn1(self.conv1(x)), 0.1)
        hidden = self.dropout(hidden)
        hidden = self.bn2(self.conv2(hidden))

        # Add the shortcut in place, then apply the final activation.
        hidden += shortcut
        return F.leaky_relu(hidden, 0.1)


class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation block for channel attention.

    Each channel is averaged over the length axis ("squeeze"), passed through
    a two-layer bottleneck ending in a sigmoid ("excitation"), and the
    resulting per-channel gate in (0, 1) rescales the input.

    Args:
        channels: Number of input (and output) channels.
        reduction: Bottleneck reduction factor. The bottleneck width is
            ``channels // reduction`` but never less than 1, so blocks with
            fewer channels than ``reduction`` still get a usable bottleneck
            instead of a zero-width layer.
    """
    def __init__(self, channels, reduction=16):
        super().__init__()
        # Guard against a zero-width bottleneck when channels < reduction.
        bottleneck = max(1, channels // reduction)

        self.squeeze = nn.AdaptiveAvgPool1d(1)
        self.excitation = nn.Sequential(
            nn.Linear(channels, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Rescale ``x`` of shape (batch, channels, length) channel-wise."""
        batch, channels, _ = x.size()
        # (batch, channels, length) -> (batch, channels)
        y = self.squeeze(x).view(batch, channels)
        # Per-channel gate, reshaped so it broadcasts over the length axis.
        y = self.excitation(y).view(batch, channels, 1)
        return x * y.expand_as(x)


class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding for temporal sequences.

    A fixed table of shape (1, max_len, d_model) is precomputed: even feature
    indices hold sines and odd indices hold cosines of position-dependent
    angles. The table is registered as a buffer, so it follows the module
    across devices but is not a trainable parameter.

    Args:
        d_model: Feature dimension of the sequences (even or odd).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest sequence length supported by the table.
    """
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd index than even index, so the
        # cosine terms use only the first d_model // 2 frequencies. For even
        # d_model the slice is a no-op.
        pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding to ``x`` of shape (batch, seq_len, d_model) and apply dropout."""
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

def load_trained_model(model_path, model, device, load_optimizer=False, load_scheduler=False):
    """
    Load a trained model from checkpoint.

    Two checkpoint layouts are accepted: a dictionary holding the weights
    under ``'model_state_dict'`` (optionally with ``'epoch'`` and
    ``'best_accuracy'``), or a bare state_dict. After loading, the model is
    moved to ``device`` and switched to evaluation mode.

    Parameters:
    -----------
    model_path : str
        Path to the saved model checkpoint
    model : nn.Module
        Model instance to load weights into
    device : torch.device
        Device to load the model on
    load_optimizer : bool
        Accepted for interface compatibility; currently ignored because no
        optimizer instance is passed in to restore state into.
    load_scheduler : bool
        Accepted for interface compatibility; currently ignored for the same
        reason.

    Returns:
    --------
    model : nn.Module
        Model with loaded weights
    checkpoint : dict
        The object returned by ``torch.load`` (the full checkpoint dictionary,
        or the bare state_dict for the old format)
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=device)

    # Handle different checkpoint formats
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        # New format (from train_model_full)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"✓ Loaded model from epoch {checkpoint.get('epoch', 'unknown')}")

        # Format the accuracy numerically only when it really is a number;
        # a missing or non-numeric entry must not abort the load.
        best_accuracy = checkpoint.get('best_accuracy')
        try:
            accuracy_text = f"{float(best_accuracy):.4f}"
        except (TypeError, ValueError):
            accuracy_text = 'unknown'
        print(f"✓ Best validation accuracy: {accuracy_text}")
    else:
        # Old format (direct state_dict)
        model.load_state_dict(checkpoint)
        print(f"✓ Loaded model state_dict")

    model.to(device)
    model.eval()

    return model, checkpoint

In [None]:
# Instantiate the simplified CNN-LSTM and restore its trained weights.
# The architecture arguments must match the checkpoint, otherwise
# load_state_dict (called inside load_trained_model) rejects the weights.
model_kwargs = dict(
    input_dim=18,
    hidden_dim=128,
    num_layers=2,
    num_classes=2,
    dropout_rate=0.4,
)
model = SimplifiedCNNLSTM(**model_kwargs).to(device)

model_new, checkpoint = load_trained_model(
    model_path='/content/drive/MyDrive/AGRI/Planting_Method/model/CNN-LSTM_dry_model.pth',
    model=model,
    device=device,
)

# Make sure the restored model is on the target device and in inference mode;
# the trailing expression also renders the architecture as the cell output.
model_new.to(device)
model_new.eval()

✓ Loaded model from epoch 224
✓ Best validation accuracy: 0.7981


SimplifiedCNNLSTM(
  (cnn): Sequential(
    (0): Conv1d(18, 64, kernel_size=(3,), stride=(1,), padding=(1,))
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.4, inplace=False)
    (4): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    (5): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.4, inplace=False)
  )
  (lstm): LSTM(128, 128, num_layers=2, batch_first=True, dropout=0.4, bidirectional=True)
  (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (attention): Sequential(
    (0): Linear(in_features=256, out_features=64, bias=True)
    (1): Tanh()
    (2): Linear(in_features=64, out_features=1, bias=True)
  )
  (classifier): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.4, inplace=False)
    (3): Linear(in_features=128, out_features=2, bias=True)
  )


In [None]:
# =============================
# PROCESS TFRECORD - CLASSIFICATION VERSION
# =============================

def _gather_patch_timeseries(features, band_names, patch_size):
    """
    Collect one (patch_size, patch_size, 1) array per timestep from a parsed Example.

    Timesteps whose band is absent from ``features`` are filled by linear
    interpolation between the nearest available neighbours, or by copying the
    single nearest neighbour at either end of the series.

    Returns:
        (series, had_missing): ``series`` is a list with one array per entry
        of ``band_names``, or None when no band at all is present;
        ``had_missing`` is True when at least one band was absent.
    """
    available = {}
    for name in band_names:
        if name in features:
            values = np.array(features[name].float_list.value)
            available[name] = values.reshape(patch_size, patch_size, 1)

    n_steps = len(band_names)
    if len(available) == n_steps:
        return [available[name] for name in band_names], False
    if not available:
        # Nothing to interpolate from.
        return None, True

    series = []
    for i, name in enumerate(band_names):
        if name in available:
            series.append(available[name])
            continue

        # Nearest available timestep before and after position i.
        prev_idx = i - 1
        while prev_idx >= 0 and band_names[prev_idx] not in available:
            prev_idx -= 1
        next_idx = i + 1
        while next_idx < n_steps and band_names[next_idx] not in available:
            next_idx += 1

        has_prev = prev_idx >= 0
        has_next = next_idx < n_steps
        if has_prev and has_next:
            prev_band = available[band_names[prev_idx]]
            next_band = available[band_names[next_idx]]
            weight = (i - prev_idx) / (next_idx - prev_idx)
            series.append(prev_band * (1 - weight) + next_band * weight)
        elif has_prev:
            series.append(available[band_names[prev_idx]])
        else:
            # At least one band exists, so a later neighbour must be present.
            series.append(available[band_names[next_idx]])

    return series, True


def process_tfrecord_streaming_timeseries(tfrecord_base_pattern, model, scaler, tile_bounds,
                                          expected_bands=1, sequence_length=18,
                                          confidence_threshold=0.6):
    """
    Process TFRecord for time series classification (Transplanted vs Direct-Seeded).

    Patches are read one at a time from ``<base>-*.tfrecord.gz``, placed on the
    grid described by ``<base>-mixer.json`` in row-major order, and every valid
    pixel (no NaN and no exact zero at any timestep) is classified by ``model``.
    The patch size and inference batch size come from the module-level
    ``PATCH_SIZE`` and ``BATCH_SIZE`` constants.

    Args:
        tfrecord_base_pattern: Base path pattern for TFRecord files
        model: PyTorch CNN-LSTM model; it is switched to eval mode and fed
            tensors on the device its parameters live on
        scaler: Fitted scaler for normalization (or None)
        tile_bounds: [min_lon, min_lat, max_lon, max_lat]
        expected_bands: Number of bands per timestep (1 for VH only); the
            reshaping below only supports 1
        sequence_length: Number of timesteps in the series; bands are looked
            up as "0_VH", "1_VH", ...
        confidence_threshold: Minimum max-probability for a pixel to count as
            "high confidence" in the printed summary (reporting only)

    Returns:
        classification_map: uint8 array of shape (H, W) with class predictions
            (0 or 1) and 255 where no prediction was made
        confidence_map: float32 array of shape (H, W) with the winning class
            probability (0 where no prediction was made)
        actual_bounds: Spatial bounds associated with the output arrays

        ``(None, None, None)`` is returned when no TFRecord files or no usable
        mixer.json are found.
    """
    import warnings  # only needed to silence scaler warnings; kept local to this function

    print(f"\nStreaming processing: {os.path.basename(tfrecord_base_pattern)}")

    # Find all TFRecord files
    tfrecord_files = sorted(glob.glob(f"{tfrecord_base_pattern}-*.tfrecord.gz"))

    if not tfrecord_files:
        print("  ⚠️ No TFRecord files found!")
        return None, None, None

    print(f"  Found {len(tfrecord_files)} file(s)")

    # Load mixer.json for grid information
    mixer_path = f"{tfrecord_base_pattern}-mixer.json"
    mixer = None
    if os.path.exists(mixer_path):
        with open(mixer_path, 'r') as f:
            mixer = json.load(f)
        print(f"  ✓ Loaded mixer.json")

    # A usable mixer needs a non-zero patchesPerRow (it is a divisor below)
    # and a totalPatches entry.
    if not mixer or not mixer.get('patchesPerRow') or 'totalPatches' not in mixer:
        print("  ⚠️ No valid mixer.json found")
        return None, None, None

    # Run inference on whatever device the model's weights are on, and make
    # sure dropout / batch-norm are in inference mode.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')
    model.eval()

    # Get grid dimensions
    patches_per_row = mixer['patchesPerRow']
    total_patches = mixer['totalPatches']
    num_rows = int(np.ceil(total_patches / patches_per_row))

    print(f"  Grid: {num_rows} rows × {patches_per_row} cols ({total_patches} patches)")

    # Calculate output dimensions
    output_h = num_rows * PATCH_SIZE
    output_w = patches_per_row * PATCH_SIZE

    print(f"  Output size: {output_h} × {output_w} pixels")

    # Calculate adjusted bounds. The pixel size is derived from the input
    # extent and the output size, so this reproduces the input bounds up to
    # floating-point rounding.
    min_lon, min_lat, max_lon, max_lat = tile_bounds

    actual_pixel_size_lon = (max_lon - min_lon) / output_w
    actual_pixel_size_lat = (max_lat - min_lat) / output_h

    actual_max_lon = min_lon + (output_w * actual_pixel_size_lon)
    actual_max_lat = min_lat + (output_h * actual_pixel_size_lat)

    actual_bounds = [min_lon, min_lat, actual_max_lon, actual_max_lat]

    print(f"  Adjusted bounds: [{actual_bounds[0]:.6f}, {actual_bounds[1]:.6f}, "
          f"{actual_bounds[2]:.6f}, {actual_bounds[3]:.6f}]")

    # Initialize output arrays - use 255 as nodata for uint8
    classification_map = np.full((output_h, output_w), 255, dtype=np.uint8)
    confidence_map = np.zeros((output_h, output_w), dtype=np.float32)

    # Band names for the S1 VH time series
    band_names_ordered = [f"{t}_VH" for t in range(sequence_length)]

    pixels_per_patch = PATCH_SIZE * PATCH_SIZE

    # Counts patches that had at least one band missing
    missing_bands_count = 0

    # patch_idx is the running row-major position on the mixer grid
    patch_idx = 0
    total_valid_pixels = 0

    for file_idx, tfrecord_file in enumerate(tfrecord_files):
        print(f"  Processing file {file_idx+1}/{len(tfrecord_files)}")

        dataset = tf.data.TFRecordDataset(tfrecord_file, compression_type='GZIP')

        for raw_record in dataset:
            # Parse TFRecord
            example = tf.train.Example()
            example.ParseFromString(raw_record.numpy())
            features = example.features.feature

            # Extract time series data, interpolating missing bands
            patch_timeseries, had_missing = _gather_patch_timeseries(
                features, band_names_ordered, PATCH_SIZE
            )
            if had_missing:
                missing_bands_count += 1

            if patch_timeseries is None or len(patch_timeseries) != sequence_length:
                # Unusable patch: leave its grid cell as nodata but still
                # advance the grid position.
                patch_idx += 1
                continue

            # Stack timesteps: (seq_len, PATCH_SIZE, PATCH_SIZE, 1)
            patch = np.stack(patch_timeseries, axis=0)

            # Calculate patch position in output grid
            row_idx, col_idx = divmod(patch_idx, patches_per_row)
            start_h = row_idx * PATCH_SIZE
            start_w = col_idx * PATCH_SIZE

            # Reshape to pixels: (seq_len, n_pixels, num_bands)
            pixels = patch.reshape(sequence_length, pixels_per_patch, expected_bands)

            # Transpose to: (n_pixels, seq_len, num_bands)
            pixels = np.transpose(pixels, (1, 0, 2))

            # Find valid pixels (no NaN or 0 across all timesteps)
            valid_mask = ~np.any(np.isnan(pixels) | (pixels == 0), axis=(1, 2))
            valid_indices = np.where(valid_mask)[0]
            n_valid = len(valid_indices)

            if n_valid > 0:
                total_valid_pixels += n_valid

                # Normalize if scaler is provided
                if scaler is not None:
                    valid_pixels_flat = pixels[valid_indices].reshape(-1, expected_bands)

                    with warnings.catch_warnings():
                        warnings.filterwarnings('ignore')
                        scaled_flat = scaler.transform(valid_pixels_flat)

                    valid_features = scaled_flat.reshape(n_valid, sequence_length, expected_bands)
                else:
                    valid_features = pixels[valid_indices]

                # Initialize prediction arrays
                patch_classes = np.full(pixels_per_patch, 255, dtype=np.uint8)  # 255 = nodata
                patch_conf = np.zeros(pixels_per_patch, dtype=np.float32)

                # Process in batches
                for start_idx in range(0, n_valid, BATCH_SIZE):
                    end_idx = min(start_idx + BATCH_SIZE, n_valid)
                    batch_indices = valid_indices[start_idx:end_idx]

                    batch_features = valid_features[start_idx:end_idx]

                    # Drop the band axis: (batch, seq_len, 1) -> (batch, seq_len)
                    batch_features_reshaped = batch_features.squeeze(-1)
                    batch_tensor = torch.from_numpy(batch_features_reshaped).float().to(model_device)

                    # Insert a singleton axis: (batch, seq_len) -> (batch, 1, seq_len)
                    batch_tensor = batch_tensor.unsqueeze(1)

                    # Predict with PyTorch model
                    with torch.no_grad():
                        outputs = model(batch_tensor)  # (batch_size, num_classes)
                        probs = torch.softmax(outputs, dim=1)

                        # Get predicted class (0 or 1)
                        predicted_classes = torch.argmax(probs, dim=1).cpu().numpy()

                        # Confidence: max probability
                        confidences = torch.max(probs, dim=1)[0].cpu().numpy()

                    patch_classes[batch_indices] = predicted_classes
                    patch_conf[batch_indices] = confidences

                # Reshape and store results
                patch_class_map = patch_classes.reshape(PATCH_SIZE, PATCH_SIZE)
                patch_conf_map = patch_conf.reshape(PATCH_SIZE, PATCH_SIZE)

                classification_map[start_h:start_h+PATCH_SIZE, start_w:start_w+PATCH_SIZE] = patch_class_map
                confidence_map[start_h:start_h+PATCH_SIZE, start_w:start_w+PATCH_SIZE] = patch_conf_map

            # Clean up
            del patch, pixels

            patch_idx += 1

            # Progress update
            if patch_idx % 50 == 0:
                progress = (patch_idx / total_patches) * 100
                print(f"    Progress: {progress:.1f}% ({patch_idx}/{total_patches})")

            # Memory cleanup
            if patch_idx % 20 == 0:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

    print(f"  ✓ Processed all {patch_idx} patches")

    if missing_bands_count > 0:
        print(f"  ⚠️ Interpolated missing bands in {missing_bands_count} patches")

    # Statistics
    valid_mask = classification_map != 255
    valid_classes = classification_map[valid_mask]

    if len(valid_classes) > 0:
        # Basic statistics
        n_direct = (valid_classes == 0).sum()
        n_transplanted = (valid_classes == 1).sum()

        # Confidence-based statistics
        valid_confidence = confidence_map[valid_mask]
        high_conf_mask = valid_confidence >= confidence_threshold

        high_conf_classes = valid_classes[high_conf_mask]
        n_high_conf = len(high_conf_classes)

        # Summing an empty comparison yields 0, so no special case is needed.
        n_direct_highconf = (high_conf_classes == 0).sum()
        n_transplanted_highconf = (high_conf_classes == 1).sum()

        print(f"\n  Classification Statistics:")
        print(f"    Total valid pixels: {len(valid_classes):,}")
        print(f"    ")
        print(f"    All Predictions:")
        print(f"      Class 0 (Direct-Seeded): {n_direct:,} ({n_direct/len(valid_classes)*100:.1f}%)")
        print(f"      Class 1 (Transplanted):  {n_transplanted:,} ({n_transplanted/len(valid_classes)*100:.1f}%)")
        print(f"    ")
        print(f"    High Confidence (≥{confidence_threshold}) Predictions:")
        print(f"      Total: {n_high_conf:,} ({n_high_conf/len(valid_classes)*100:.1f}%)")
        if n_high_conf > 0:
            print(f"      Class 0 (Direct-Seeded): {n_direct_highconf:,} ({n_direct_highconf/n_high_conf*100:.1f}%)")
            print(f"      Class 1 (Transplanted):  {n_transplanted_highconf:,} ({n_transplanted_highconf/n_high_conf*100:.1f}%)")
        print(f"    ")
        print(f"    Confidence Statistics:")
        print(f"      Mean confidence:  {valid_confidence.mean():.3f}")
        print(f"      Median confidence: {np.median(valid_confidence):.3f}")
        print(f"      Min confidence:   {valid_confidence.min():.3f}")
        print(f"      Max confidence:   {valid_confidence.max():.3f}")

    return classification_map, confidence_map, actual_bounds

In [None]:
# =============================
# SAVE GEOTIFF (SAME AS BEFORE)
# =============================

def save_geotiff_aligned(array, output_path, bounds, crs, dtype=None, nodata=None):
    """
    Save a 2-D array as a single-band, LZW-compressed GeoTIFF.

    The affine transform is derived from ``bounds`` and the array shape, so
    the raster exactly fills the given extent.

    Args:
        array: 2-D array of shape (height, width).
        output_path: Destination file path.
        bounds: (min_x, min_y, max_x, max_y) in the units of ``crs``.
        crs: Coordinate reference system accepted by rasterio (e.g. "EPSG:4326").
        dtype: Raster data type. Defaults to the array's own dtype for integer
            arrays and to float32 for everything else.
        nodata: Value recorded as nodata. Defaults to the largest value of an
            unsigned integer dtype (255 for uint8, matching the classification
            maps produced in this notebook) and to -9999 otherwise.

    Raises:
        ValueError: If ``array`` is not two-dimensional.
    """
    array = np.asarray(array)
    if array.ndim != 2:
        raise ValueError(f"expected a 2-D array, got shape {array.shape}")
    h, w = array.shape

    # Integer rasters keep their dtype so their nodata sentinel stays
    # representable; everything else is stored as float32.
    if dtype is None:
        dtype = array.dtype if np.issubdtype(array.dtype, np.integer) else np.float32
    dtype = np.dtype(dtype)

    if nodata is None:
        if np.issubdtype(dtype, np.unsignedinteger):
            nodata = int(np.iinfo(dtype).max)
        elif np.issubdtype(dtype, np.integer):
            # -9999 does not fit in very narrow signed types such as int8.
            nodata = max(-9999, int(np.iinfo(dtype).min))
        else:
            nodata = -9999

    min_x, min_y, max_x, max_y = bounds

    transform = rasterio.transform.from_bounds(
        min_x, min_y, max_x, max_y, w, h
    )

    print(f"    Saving: {os.path.basename(output_path)}")
    print(f"      Size: {w} x {h}")
    print(f"      Bounds: {bounds}")

    with rasterio.open(
        output_path, 'w',
        driver='GTiff',
        height=h,
        width=w,
        count=1,
        dtype=dtype.name,
        crs=crs,
        transform=transform,
        compress='lzw',
        nodata=nodata
    ) as dst:
        dst.write(array.astype(dtype, copy=False), 1)

    print(f"    ✓ Saved")


# =============================
# READ TILE MIXER (SAME AS BEFORE)
# =============================

def read_tile_mixer(tile_base_path):
    """
    Read ``<tile_base_path>-mixer.json`` and extract georeferencing info.

    The extent is computed from the affine ``doubleMatrix`` of the mixer
    (x scale at index 0, x origin at index 2, y scale at index 4, y origin at
    index 5) and the patch grid size. The formula for ``min_y`` relies on the
    y scale being negative (origin at the top edge).

    Args:
        tile_base_path: Path of the tile without the "-mixer.json" suffix.

    Returns:
        dict with keys:
            'crs'        - CRS string from the mixer
            'mixer'      - the full parsed mixer dictionary
            'patch_dims' - patch dimensions (defaults to [256, 256])
            'bounds'     - [min_x, min_y, max_x, max_y]
            'grid_size'  - (patches_per_row, patches_per_col)
            'pixel_size' - (scale_x, abs(scale_y))

    Raises:
        FileNotFoundError: If the mixer file does not exist.
        KeyError: If the mixer lacks the projection/affine entries.
    """
    mixer_path = f"{tile_base_path}-mixer.json"

    if not os.path.exists(mixer_path):
        raise FileNotFoundError(f"mixer.json not found: {mixer_path}")

    with open(mixer_path, 'r') as f:
        mixer = json.load(f)

    # Extract info
    crs = mixer['projection']['crs']
    patch_dims = mixer.get('patchDimensions', [256, 256])
    patches_per_row = mixer.get('patchesPerRow', 0)
    total_patches = mixer.get('totalPatches', 0)

    # Round up so a partially filled last row still counts as a row. This
    # matches the ceil used when the output raster is allocated during
    # inference, keeping the bounds and the raster size in agreement.
    patches_per_col = -(-total_patches // patches_per_row) if patches_per_row > 0 else 0

    # Extract affine transform
    affine_matrix = mixer['projection']['affine']['doubleMatrix']

    scale_x = affine_matrix[0]
    translate_x = affine_matrix[2]
    scale_y = affine_matrix[4]
    translate_y = affine_matrix[5]

    # Calculate bounds
    patch_width_pixels = patch_dims[0]
    patch_height_pixels = patch_dims[1]

    total_width_pixels = patches_per_row * patch_width_pixels
    total_height_pixels = patches_per_col * patch_height_pixels

    # The origin is the top-left corner: x grows to the right, y shrinks downward.
    min_x = translate_x
    max_y = translate_y
    max_x = min_x + (total_width_pixels * scale_x)
    min_y = max_y + (total_height_pixels * scale_y)

    bounds = [min_x, min_y, max_x, max_y]

    return {
        'crs': crs,
        'mixer': mixer,
        'patch_dims': patch_dims,
        'bounds': bounds,
        'grid_size': (patches_per_row, patches_per_col),
        'pixel_size': (scale_x, abs(scale_y))
    }

In [None]:
# =============================
# MAIN PROCESSING
# =============================
# Group the exported TFRecord shards into tiles and read each tile's
# georeferencing from its mixer.json sidecar.
import re

print("\n" + "="*70)
print("DISCOVERING TILES")
print("="*70)

# A shard file name is "<tile base>-<suffix>.tfrecord.gz"; stripping everything
# after the last '-' yields the tile base shared by all shards of a tile.
all_files = glob.glob(f"{TFRECORD_DIR}/*.tfrecord.gz")
tile_bases = sorted({
    os.path.join(TFRECORD_DIR, os.path.basename(shard_path).rsplit('-', 1)[0])
    for shard_path in all_files
})
print(f"\nFound {len(tile_bases)} unique tiles")

# One record per tile whose mixer.json could be read.
tile_info = []
tile_number_pattern = re.compile(r'tile[_-](\d+)')

for tile_base in tile_bases:
    basename = os.path.basename(tile_base)

    # The tile number is the digit run after "tile_" / "tile-"; None when absent.
    number_match = tile_number_pattern.search(basename)
    tile_num = None if number_match is None else int(number_match.group(1))

    try:
        mixer_data = read_tile_mixer(tile_base)

        record = {'base': tile_base, 'number': tile_num}
        for key in ('bounds', 'crs', 'mixer', 'grid_size', 'pixel_size'):
            record[key] = mixer_data[key]
        tile_info.append(record)

        grid_cols, grid_rows = mixer_data['grid_size']
        print(f"\n  Tile {tile_num}: {basename}")
        print(f"    CRS: {mixer_data['crs']}")
        print(f"    Grid: {grid_cols} x {grid_rows}")
        print(f"    Bounds: {mixer_data['bounds']}")

    except Exception as e:
        # A tile without readable georeferencing is reported and skipped.
        print(f"\n  ⚠ Error: {e}")
        continue

print(f"\n✓ Loaded {len(tile_info)} tiles with georeferencing")

if not tile_info:
    raise ValueError("No valid tiles found!")


DISCOVERING TILES

Found 4 unique tiles

  Tile 1: S1_composite_dry2025_tile_001
    CRS: EPSG:4326
    Grid: 23 x 23
    Bounds: [120.58759795199427, 15.193026737256643, 121.11652599128384, 15.721954776546218]

  Tile 2: S1_composite_dry2025_tile_002
    CRS: EPSG:4326
    Grid: 23 x 20
    Bounds: [120.58759795199427, 15.67569153941406, 121.11652599128384, 16.135628964883256]

  Tile 3: S1_composite_dry2025_tile_003
    CRS: EPSG:4326
    Grid: 12 x 23
    Bounds: [121.08760023913518, 15.19293690572823, 121.3635626944167, 15.721864945017805]

  Tile 4: S1_composite_dry2025_tile_004
    CRS: EPSG:4326
    Grid: 12 x 20
    Bounds: [121.08760023913518, 15.675601707885647, 121.3635626944167, 16.135539133354843]

✓ Loaded 4 tiles with georeferencing


In [None]:
# =============================
# MAIN PROCESSING LOOP
# =============================
# For every discovered tile: classify it, print summary statistics, write the
# classification and confidence GeoTIFFs, and save a PNG preview.
import traceback

from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

print("\n" + "="*70)
print("PROCESSING TILES FOR CLASSIFICATION")
print("="*70)

# Green for Direct-Seeded (class 0), Orange for Transplanted (class 1)
CLASS_COLORS = ['#2ecc71', '#e67e22']

results = []

for idx, tile_data in enumerate(tile_info):
    tile_base = tile_data['base']
    bounds = tile_data['bounds']
    tile_num = tile_data['number']
    crs = tile_data['crs']

    print(f"\n{'='*70}")
    print(f"TILE {tile_num} ({idx+1}/{len(tile_info)})")
    print(f"Base: {os.path.basename(tile_base)}")
    print(f"CRS: {crs}")
    print(f"Bounds: {bounds}")
    print('='*70)

    # The tile number is None when the file name carries no "tile_<n>" part;
    # fall back to the file's base name so the output name is still unique.
    if tile_num is not None:
        tile_name = f"tile_{tile_num:03d}"
    else:
        tile_name = os.path.basename(tile_base)

    try:
        # Process tile
        class_map, confidence_map, actual_bounds = process_tfrecord_streaming_timeseries(
            tile_base, model_new, None, bounds, EXPECTED_BANDS, SEQUENCE_LENGTH
        )

        if class_map is None:
            print("  ⚠️ Failed to process")
            continue

        # Use actual bounds
        if actual_bounds:
            bounds = actual_bounds

        # Statistics (255 marks pixels without a prediction)
        valid_mask = class_map != 255
        valid_classes = class_map[valid_mask]

        # Per-tile defaults so the legend below never sees undefined values
        # or numbers left over from the previous tile.
        n_direct = 0
        n_transplanted = 0
        pct_direct = 0.0
        pct_transplanted = 0.0
        mean_conf = float('nan')

        if len(valid_classes) > 0:
            n_direct = int((valid_classes == 0).sum())
            n_transplanted = int((valid_classes == 1).sum())
            pct_direct = (n_direct / len(valid_classes)) * 100
            pct_transplanted = (n_transplanted / len(valid_classes)) * 100
            mean_conf = float(confidence_map[valid_mask].mean())

            print(f"\n  Tile Statistics:")
            print(f"    Valid pixels:        {len(valid_classes):,}")
            print(f"    Direct-Seeded (0):   {n_direct:,} ({pct_direct:.1f}%)")
            print(f"    Transplanted (1):    {n_transplanted:,} ({pct_transplanted:.1f}%)")
            print(f"    Mean confidence:     {mean_conf:.3f}")

            results.append({
                'tile': tile_num,
                'valid_pixels': len(valid_classes),
                'direct_seeded_count': n_direct,
                'transplanted_count': n_transplanted,
                'direct_seeded_pct': pct_direct,
                'transplanted_pct': pct_transplanted,
                'mean_confidence': mean_conf
            })

        # Save outputs
        # NOTE(review): pixels without a prediction hold 0 in the confidence
        # map, not the nodata value declared by the GeoTIFF writer - confirm
        # whether downstream consumers expect them to be masked.
        print(f"\n  Saving outputs...")
        save_geotiff_aligned(
            class_map,
            f"{OUTPUT_DIR}/{tile_name}_classification.tif",
            bounds,
            crs
        )
        save_geotiff_aligned(
            confidence_map,
            f"{OUTPUT_DIR}/{tile_name}_confidence.tif",
            bounds,
            crs
        )

        # Create visualization
        print(f"  Creating visualization...")

        fig, axes = plt.subplots(1, 2, figsize=(16, 7))
        try:
            # Classification map with custom colors; nodata pixels are masked out
            class_vis = np.ma.masked_where(class_map == 255, class_map)
            axes[0].imshow(class_vis, cmap=ListedColormap(CLASS_COLORS),
                           vmin=0, vmax=1, interpolation='nearest')
            axes[0].set_title(f'Rice Planting Method Classification (Tile {tile_num})', fontsize=14, fontweight='bold')
            axes[0].axis('off')

            # Custom legend
            legend_elements = [
                Patch(facecolor=CLASS_COLORS[0], label=f'Direct-Seeded ({n_direct:,} pixels, {pct_direct:.1f}%)'),
                Patch(facecolor=CLASS_COLORS[1], label=f'Transplanted ({n_transplanted:,} pixels, {pct_transplanted:.1f}%)'),
                Patch(facecolor='white', edgecolor='black', label=f'No Data')
            ]
            axes[0].legend(handles=legend_elements, loc='upper right', fontsize=10)

            # Confidence map
            conf_vis = np.ma.masked_where(class_map == 255, confidence_map)
            im2 = axes[1].imshow(conf_vis, cmap='RdYlGn', vmin=0, vmax=1)
            axes[1].set_title('Prediction Confidence', fontsize=14, fontweight='bold')
            axes[1].axis('off')
            cbar2 = fig.colorbar(im2, ax=axes[1], fraction=0.046)
            cbar2.set_label('Confidence', rotation=270, labelpad=20, fontsize=11)

            if len(valid_classes) > 0:
                fig.suptitle(
                    f'Classification Summary | Valid Pixels: {len(valid_classes):,} | Mean Confidence: {mean_conf:.3f}',
                    fontsize=12, y=0.98
                )

            fig.tight_layout()
            fig.savefig(f"{OUTPUT_DIR}/{tile_name}_classification_result.png", dpi=200, bbox_inches='tight')
        finally:
            # Always release the figure, even when plotting or saving fails,
            # so figures do not pile up in memory across tiles.
            plt.close(fig)

        print(f"  ✓ Saved all outputs")

        # Clean up the large rasters before the next tile
        del class_map, confidence_map
        gc.collect()

    except Exception as e:
        # One failing tile must not stop the remaining tiles.
        print(f"  ✗ Error: {e}")
        traceback.print_exc()
        gc.collect()


PROCESSING TILES FOR CLASSIFICATION

TILE 1 (1/4)
Base: S1_composite_dry2025_tile_001
CRS: EPSG:4326
Bounds: [120.58759795199427, 15.193026737256643, 121.11652599128384, 15.721954776546218]

Streaming processing: S1_composite_dry2025_tile_001
  Found 25 file(s)
  ✓ Loaded mixer.json
  Grid: 23 rows × 23 cols (529 patches)
  Output size: 5888 × 5888 pixels
  Adjusted bounds: [120.587598, 15.193027, 121.116526, 15.721955]
  Processing file 1/25
  Processing file 2/25
  Processing file 3/25
    Progress: 9.5% (50/529)
  Processing file 4/25
  Processing file 5/25
    Progress: 18.9% (100/529)
  Processing file 6/25
  Processing file 7/25
    Progress: 28.4% (150/529)
  Processing file 8/25
  Processing file 9/25
  Processing file 10/25
    Progress: 37.8% (200/529)
  Processing file 11/25
  Processing file 12/25
    Progress: 47.3% (250/529)
  Processing file 13/25
  Processing file 14/25
    Progress: 56.7% (300/529)
  Processing file 15/25
  Processing file 16/25
    Progress: 66.2% (3

Traceback (most recent call last):
  File "/tmp/ipython-input-799641944.py", line 68, in <cell line: 0>
    save_geotiff_aligned(
  File "/tmp/ipython-input-3635084374.py", line 22, in save_geotiff_aligned
    with rasterio.open(
         ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/rasterio/env.py", line 463, in wrapper
    return f(*args, **kwds)
           ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/rasterio/__init__.py", line 366, in open
    dataset = writer(
              ^^^^^^^
  File "rasterio/_io.pyx", line 1553, in rasterio._io.DatasetWriterBase.__init__
ValueError: Given nodata value, nan, is beyond the valid range of its data type, uint8.



TILE 2 (2/4)
Base: S1_composite_dry2025_tile_002
CRS: EPSG:4326
Bounds: [120.58759795199427, 15.67569153941406, 121.11652599128384, 16.135628964883256]

Streaming processing: S1_composite_dry2025_tile_002
  Found 21 file(s)
  ✓ Loaded mixer.json
  Grid: 20 rows × 23 cols (460 patches)
  Output size: 5120 × 5888 pixels
  Adjusted bounds: [120.587598, 15.675692, 121.116526, 16.135629]
  Processing file 1/21
  Processing file 2/21


In [None]:
# =============================
# PROCESS ALL TILES
# =============================
# For every entry in `tile_info`: run the streaming time-series classifier,
# then write the class map and the confidence map as GeoTIFFs into OUTPUT_DIR.
# A per-tile summary is collected in `results`. A failure on one tile is
# reported (with traceback) and the loop moves on to the next tile.

import traceback

# No-data sentinel for integer class maps. The visualisation code earlier in
# this notebook masks `class_map == 255`; assumed to hold for the maps returned
# here as well -- TODO confirm against process_tfrecord_streaming_timeseries.
CLASS_NODATA_VALUE = 255

print("\n" + "="*70)
print("PROCESSING TILES FOR CLASSIFICATION")
print("="*70)

results = []

for idx, tile_data in enumerate(tile_info):
    tile_base = tile_data['base']
    bounds = tile_data['bounds']
    tile_num = tile_data['number']
    crs = tile_data['crs']

    print(f"\n{'='*70}")
    print(f"TILE {tile_num} ({idx+1}/{len(tile_info)})")
    print(f"Base: {os.path.basename(tile_base)}")
    print(f"CRS: {crs}")
    print(f"Bounds: {bounds}")
    print('='*70)

    tile_name = f"tile_{tile_num}"
    class_path = f"{OUTPUT_DIR}/{tile_name}_class.tif"
    confidence_path = f"{OUTPUT_DIR}/{tile_name}_confidence.tif"

    # Pre-bind so the `finally` clause can release them on every exit path.
    class_map = confidence_map = class_out = None

    try:
        # Run inference on the tile.
        class_map, confidence_map, actual_bounds = process_tfrecord_streaming_timeseries(
                    tile_base,
                    model,
                    None,  # scaler - set to your scaler if you used normalization during training
                    bounds,
                    1,     # expected_bands
                    18     # sequence_length
        )

        if class_map is None:
            print("  ⚠️ Failed to process")
            results.append({'tile': tile_num, 'status': 'failed'})
            continue

        # Prefer the bounds reported by the processing step when it returns any.
        save_bounds = actual_bounds if actual_bounds else bounds

        # Statistics.
        # np.isnan() is all-False on an integer array, so an integer class map
        # must be tested against its sentinel instead of against NaN.
        class_map = np.asarray(class_map)
        if np.issubdtype(class_map.dtype, np.floating):
            valid_mask = ~np.isnan(class_map)
            class_out = class_map
        else:
            valid_mask = class_map != CLASS_NODATA_VALUE
            # The traceback captured from the previous run shows
            # save_geotiff_aligned opening the dataset with nodata=nan, which
            # rasterio rejects for uint8. Writing the class map as float32 with
            # NaN no-data keeps that helper usable for integer maps.
            class_out = class_map.astype(np.float32)
            class_out[~valid_mask] = np.nan
        n_valid = int(valid_mask.sum())
        del valid_mask
        print(f"  Valid pixels: {n_valid:,}")

        # Save outputs
        print("\n  Saving outputs...")
        save_geotiff_aligned(class_out, class_path, save_bounds, crs)
        save_geotiff_aligned(confidence_map, confidence_path, save_bounds, crs)

        print("  ✓ Saved all outputs")

        results.append({
            'tile': tile_num,
            'status': 'ok',
            'n_valid': n_valid,
            'class_path': class_path,
            'confidence_path': confidence_path,
        })

    except Exception as e:
        # Keep going with the remaining tiles, but leave a full traceback.
        print(f"  ✗ Error: {e}")
        traceback.print_exc()
        results.append({'tile': tile_num, 'status': 'error', 'error': str(e)})

    finally:
        # Clean up: drop the large rasters on success, failure and `continue`.
        class_map = confidence_map = class_out = None
        gc.collect()