# In [None]:
# Imports
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from torchmetrics import Accuracy, Precision, Recall, F1Score, JaccardIndex
from torchmetrics.collections import MetricCollection
import torchvision.transforms.functional as TF
import random
from terratorch.tasks import SemanticSegmentationTask
import satlaspretrain_models
import rasterio
import glob
import re
import traceback

# Global configuration for the MoE (mixture-of-experts) training script.
# Everything in this section is a module-level constant read by the code below.
print("--- MoE System Configuration ---")

# --- User-Specified / Cloud-Safe Paths (placeholders) ---
# Relative output directories; both are created right below if missing.
MOE_CHECKPOINT_SAVE_DIR = "checkpoints_moe"
MOE_LOGS_DIR = "metrics_moe"

# Lightning checkpoints of the two already-trained experts. They are handed to
# load_prithvi_expert / load_satlas_expert in the main block.
PRITHVI_CHECKPOINT_PATH = "prithvi_checkpoint.ckpt"
SATLAS_CHECKPOINT_PATH = "satlas_checkpoint.ckpt"

# Import-time side effect: the output directories are created as soon as this
# section runs (exist_ok=True makes re-runs harmless).
os.makedirs(MOE_CHECKPOINT_SAVE_DIR, exist_ok=True)
os.makedirs(MOE_LOGS_DIR, exist_ok=True)

print(f"MoE Checkpoints will be saved to: {MOE_CHECKPOINT_SAVE_DIR}")
print(f"MoE CSV Logs will be saved to: {MOE_LOGS_DIR}")
print(f"Prithvi Expert Checkpoint: {PRITHVI_CHECKPOINT_PATH}")
print(f"Satlas Expert Checkpoint: {SATLAS_CHECKPOINT_PATH}")

# --- Dataset Directories (Clean placeholders) ---
BASE_DATA_DIR_ACTUAL = "CerealDataset"   # no absolute path
# Image patches are looked up as "<identifier>_S2.tif" in PATCHES_DIR and masks
# as "<identifier>_Mask.tif" in TARGETS_DIR.
PATCHES_DIR = f"{BASE_DATA_DIR_ACTUAL}/GEE_WeeklyPatches"
TARGETS_DIR = f"{BASE_DATA_DIR_ACTUAL}/GEE_WeeklyMasks"

# Number of segmentation classes; must agree with CLASS_NAMES_FOR_LOGGING below.
NUM_CLASSES = 2
# NOTE(review): presumably the square size patches are resized to
# (target_img_size of the dataset) -- the usage is not visible here, confirm.
IMG_SIZE = 256

# --- MoE Hyperparameters ---
MOE_BATCH_SIZE = 1
MOE_EPOCHS = 50
# Learning rate / weight decay passed to MoESegmentationSystem, whose AdamW
# optimizer only receives the gating network's parameters.
MOE_LR = 1e-4
MOE_WEIGHT_DECAY = 0.01
# NOTE(review): the augmentation constants below are presumably forwarded to
# CerealSensingDataset_MoE (rand_augment_n, rotation_degrees, *_range,
# gaussian_noise_std) -- the call site is not visible here, confirm.
MOE_RAND_AUGMENT_N = 2

MOE_ROTATION_DEGREES = 180
MOE_BRIGHTNESS_RANGE = (0.9, 1.1)
MOE_CONTRAST_RANGE = (0.9, 1.1)
MOE_GAMMA_RANGE = (0.9, 1.1)
MOE_GAUSSIAN_NOISE_STD = 0.01

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"MoE Training Device: {DEVICE}")

# Class index -> name (index 0 = "non_cereal", index 1 = "cereal"). The literal
# name "cereal" is matched by the metric-logging code to pick progress-bar
# metrics, so renaming it changes what is shown there.
CLASS_NAMES_FOR_LOGGING = ["non_cereal", "cereal"]

# ==============================================================================
# SECTION 1: Configurations and Definitions from Prithvi Training Script
# ==============================================================================
# Settings needed to rebuild the Prithvi expert's architecture before its
# fine-tuned weights are loaded from PRITHVI_CHECKPOINT_PATH.
print("\n--- Prithvi Configuration for Expert Loading ---")
# Band names handed to the backbone as "backbone_bands".
# NOTE(review): this list names "B8A" while the MoE input order defined further
# down uses "B8" in the same position -- confirm the NIR band in the GeoTIFFs
# is the one the Prithvi expert was fine-tuned on.
PRITHVI_EXPECTED_BANDS_ORDER = ["B02", "B03", "B04", "B8A", "B11", "B12"]
# Passed to LitSegmentationModel.load_from_checkpoint when the expert is loaded.
PRITHVI_LR = 1e-5
PRITHVI_WEIGHT_DECAY = 0.01
PRITHVI_HEAD_DROPOUT = 0.2
# Not referenced by the loader code visible in this file.
PRITHVI_FREEZE_BACKBONE = False

# model_args for terratorch's SemanticSegmentationTask. "backbone_pretrained" is
# False because the weights come from the fine-tuned checkpoint instead.
prithvi_model_args_for_reconstruction = {
    "backbone": "prithvi_eo_v2_300", "backbone_pretrained": False,
    "backbone_bands": PRITHVI_EXPECTED_BANDS_ORDER, "decoder": "FCNDecoder",
    "decoder_channels": 256, "head_dropout": PRITHVI_HEAD_DROPOUT,
    "num_classes": NUM_CLASSES, "rescale": True
}

# ==============================================================================
# SECTION 2: Configurations and Definitions from Satlas Training Script
# ==============================================================================
# Settings for rebuilding the Satlas expert and for converting the 6-band MoE
# input into the 9-channel tensor that expert is fed.
print("\n--- Satlas Configuration for Expert Loading & Input Mapping ---")
SATLAS_MODEL_IDENTIFIER = "Sentinel2_SwinB_SI_MS"
# Order of the raw bands in the tensor produced by CerealSensingDataset_MoE.
MOE_INPUT_RAW_BANDS_ORDER = ['B2','B3','B4','B8','B11','B12'] 
# Band name -> channel index in the 6-channel MoE input (0-based).
MOE_INPUT_RAW_BAND_TO_IDX_MAP = {band_name: i for i, band_name in enumerate(MOE_INPUT_RAW_BANDS_ORDER)}

# The 9 channels of the Satlas input, in the order they are stacked.
SATLAS_EXPECTED_INPUT_CHANNEL_NAMES = [
    "TCI_R", "TCI_G", "TCI_B", "B05_duplicate", "B06_duplicate", "B07_duplicate",
    "B08_from_B8", "B11", "B12"
]
# Which raw band fills each Satlas channel. The three "*_duplicate" slots are
# filled with copies of B4/B3/B2 because the 6-band input has no B05/B06/B07.
SATLAS_CHANNEL_TO_SOURCE_RAW_BAND_MAP = {
    "TCI_R": "B4", "TCI_G": "B3", "TCI_B": "B2",
    "B05_duplicate": "B4", "B06_duplicate": "B3", "B07_duplicate": "B2",
    "B08_from_B8": "B8", "B11": "B11", "B12": "B12"
}
# The Satlas input is divided by this value and then clipped to [0, 1]
# (see MoESegmentationSystem._map_input_to_satlas_format).
SATLAS_NORMALIZATION_DIVISOR = 8160.0
# Passed to LitSegmentationModel.load_from_checkpoint when the expert is loaded.
SATLAS_LR = 1e-5
SATLAS_WEIGHT_DECAY = 0.01
# Not referenced by the loader code visible in this file (the loader passes a
# literal freeze_backbone_fpn=False).
SATLAS_FREEZE_BACKBONE_FPN = False

# ==============================================================================
# SECTION 3: CerealSensingDataset_MoE
# This dataset will output 6-channel images (MOE_INPUT_RAW_BANDS_ORDER)
# ==============================================================================
class CerealSensingDataset_MoE(Dataset):
    """Paired Sentinel-2 patch / mask dataset feeding the MoE system.

    Each item is read from ``<patches_dir>/<identifier>_S2.tif`` and
    ``<targets_dir>/<identifier>_Mask.tif``, resized to
    ``target_img_size`` x ``target_img_size`` and, when enabled, augmented with
    a RandAugment-like scheme: ``rand_augment_n`` operations are sampled
    without replacement from the augmentation pool on every access.

    The image is returned in the units stored in the GeoTIFF (no
    normalisation is applied here). For that reason the photometric
    augmentations are written as plain tensor arithmetic: the torchvision
    ``adjust_brightness`` / ``adjust_contrast`` / ``adjust_gamma`` helpers clamp
    floating-point images to [0, 1], which would flatten a raw-scale channel to
    a constant.

    Args:
        file_identifiers: Sequence of patch identifiers (file name stems).
        patches_dir: Directory holding the ``*_S2.tif`` image files.
        targets_dir: Directory holding the ``*_Mask.tif`` label files.
        target_img_size: Side length, in pixels, of the returned square tensors.
        raw_band_to_idx_map: Band name -> 0-based band index inside the image
            file. Looked up with the names in ``MOE_INPUT_RAW_BANDS_ORDER``.
        augmentations: Master switch for augmentation.
        rand_augment_n: Number of operations applied per sample.
        rotation_degrees: Rotation angle is drawn from ``[-deg, +deg]``;
            ``0`` removes rotation from the pool.
        brightness_range: ``(low, high)`` multiplicative factor range.
        contrast_range: ``(low, high)`` contrast factor range.
        gamma_range: ``(low, high)`` gamma range (values below 0.1 are raised
            to 0.1).
        gaussian_noise_std: Standard deviation of the additive noise, in the
            same units as the image; ``0`` removes noise from the pool.
    """

    def __init__(self, file_identifiers, patches_dir, targets_dir, target_img_size,
                 raw_band_to_idx_map,  # For loading the 6 raw bands from GeoTIFF
                 augmentations=False,
                 rand_augment_n=0,
                 rotation_degrees=0,
                 brightness_range=(1.0, 1.0),
                 contrast_range=(1.0, 1.0),
                 gamma_range=(1.0, 1.0),
                 gaussian_noise_std=0.0):
        self.file_identifiers = file_identifiers
        self.patches_dir = patches_dir
        self.targets_dir = targets_dir
        self.target_img_size = target_img_size
        self.raw_band_to_idx_map = raw_band_to_idx_map
        self.num_raw_bands = len(raw_band_to_idx_map)  # Should be 6 for MoE input

        self.augmentations = augmentations
        self.rand_augment_n = rand_augment_n

        # Parameters of the individual augmentation operations.
        self.rotation_degrees = rotation_degrees
        self.brightness_range = brightness_range
        self.contrast_range = contrast_range
        self.gamma_range = gamma_range
        self.gaussian_noise_std = gaussian_noise_std

        self.augmentation_pool = []
        if self.augmentations and self.rand_augment_n > 0:
            self._build_augmentation_pool()

    # --- Helper methods for individual augmentations ---
    # Every helper takes (image [C, H, W], mask [H, W] or [1, H, W]) and returns
    # the pair again; the mask keeps the rank it came in with.
    def _aug_hflip(self, image_tensor, mask_tensor_float):
        """Flip image and mask horizontally."""
        return TF.hflip(image_tensor), TF.hflip(mask_tensor_float)

    def _aug_vflip(self, image_tensor, mask_tensor_float):
        """Flip image and mask vertically."""
        return TF.vflip(image_tensor), TF.vflip(mask_tensor_float)

    def _aug_rotate(self, image_tensor, mask_tensor_float):
        """Rotate image (bilinear) and mask (nearest) by the same random angle."""
        if self.rotation_degrees <= 0:
            return image_tensor, mask_tensor_float

        angle = random.uniform(-self.rotation_degrees, self.rotation_degrees)
        image_tensor = TF.rotate(image_tensor, angle, interpolation=TF.InterpolationMode.BILINEAR)

        # TF.rotate needs a channel dimension; add one temporarily for 2-D masks
        # and remove it again so the caller gets back the rank it passed in.
        mask_was_2d = mask_tensor_float.ndim == 2
        if mask_was_2d:
            mask_tensor_float = mask_tensor_float.unsqueeze(0)
        # Nearest-neighbour keeps the mask values discrete class ids.
        mask_tensor_float = TF.rotate(mask_tensor_float, angle, interpolation=TF.InterpolationMode.NEAREST)
        if mask_was_2d:
            mask_tensor_float = mask_tensor_float.squeeze(0)
        return image_tensor, mask_tensor_float

    def _aug_brightness_single_channel(self, image_tensor, mask_tensor_float):
        """Scale one randomly chosen channel by a random brightness factor.

        The channel is modified in place. No clamping is applied, so the
        value range of raw data is preserved.
        """
        num_channels = image_tensor.shape[0]  # Should be 6 for MoE input
        if num_channels > 0:
            channel_to_aug = random.randint(0, num_channels - 1)
            factor = random.uniform(self.brightness_range[0], self.brightness_range[1])
            image_tensor[channel_to_aug] = image_tensor[channel_to_aug] * factor
        return image_tensor, mask_tensor_float

    def _aug_contrast_single_channel(self, image_tensor, mask_tensor_float):
        """Stretch one randomly chosen channel around its own mean.

        Computes ``mean + factor * (channel - mean)`` in place, without
        clamping.
        """
        num_channels = image_tensor.shape[0]
        if num_channels > 0:
            channel_to_aug = random.randint(0, num_channels - 1)
            factor = random.uniform(self.contrast_range[0], self.contrast_range[1])
            channel = image_tensor[channel_to_aug]
            channel_mean = channel.mean()
            image_tensor[channel_to_aug] = channel_mean + factor * (channel - channel_mean)
        return image_tensor, mask_tensor_float

    def _aug_gamma_single_channel(self, image_tensor, mask_tensor_float):
        """Apply a random gamma curve to one randomly chosen channel (in place).

        Negative values are clamped to zero first because a fractional power of
        a negative number is undefined. The curve is applied on the channel
        rescaled by ``max(channel.max(), 1)`` and scaled back afterwards, so
        data already in [0, 1] gets the usual ``x ** gamma`` while raw-scale
        data keeps its value range.
        """
        num_channels = image_tensor.shape[0]
        if num_channels > 0:
            channel_to_aug = random.randint(0, num_channels - 1)
            gamma = random.uniform(self.gamma_range[0], self.gamma_range[1])
            gamma = max(0.1, gamma)
            channel_non_neg = torch.clamp(image_tensor[channel_to_aug], min=0)
            # The lower bound of 1 also rules out a division by zero.
            scale = torch.clamp(channel_non_neg.max(), min=1.0)
            image_tensor[channel_to_aug] = scale * (channel_non_neg / scale) ** gamma
        return image_tensor, mask_tensor_float

    def _aug_gaussian_noise_all_channels(self, image_tensor, mask_tensor_float):
        """Add zero-mean Gaussian noise to every channel (returns a new tensor)."""
        if self.gaussian_noise_std > 0:
            # NOTE(review): the std is in image units; a value such as 0.01 is
            # negligible if the image holds raw values in the thousands.
            noise = torch.randn_like(image_tensor) * self.gaussian_noise_std
            image_tensor = image_tensor + noise
        return image_tensor, mask_tensor_float

    def _build_augmentation_pool(self):
        """Collect the augmentation operations that are enabled by the settings."""
        # Geometric operations (applied to image and mask alike).
        pool = [self._aug_hflip, self._aug_vflip]
        if self.rotation_degrees > 0:
            pool.append(self._aug_rotate)
        # Photometric operations (applied to one of the raw bands).
        pool.extend([
            self._aug_brightness_single_channel,
            self._aug_contrast_single_channel,
            self._aug_gamma_single_channel,
        ])
        if self.gaussian_noise_std > 0:
            pool.append(self._aug_gaussian_noise_all_channels)
        self.augmentation_pool = pool
        if not self.augmentation_pool:
            print("Note: Augmentation pool is empty based on current settings.")

    def __len__(self):
        return len(self.file_identifiers)

    def __getitem__(self, idx):
        """Load, resize and optionally augment one image / mask pair.

        Returns:
            ``(image_tensor, label_tensor)``: a float32 image of shape
            ``[len(MOE_INPUT_RAW_BANDS_ORDER), S, S]`` and an int64 mask of
            shape ``[S, S]`` with ``S = target_img_size``.

        Raises:
            ValueError: If the image file has fewer bands than
                ``raw_band_to_idx_map`` has entries.
            Exception: Any I/O or decoding error is printed with its traceback
                and re-raised unchanged.
        """
        identifier = self.file_identifiers[idx]
        image_filename = f"{identifier}_S2.tif"
        mask_filename = f"{identifier}_Mask.tif"
        image_path = os.path.join(self.patches_dir, image_filename)
        label_path = os.path.join(self.targets_dir, mask_filename)

        try:
            with rasterio.open(image_path) as src:
                all_bands_data = src.read()
                if all_bands_data.shape[0] < self.num_raw_bands:
                    raise ValueError(f"Image {identifier} has {all_bands_data.shape[0]} bands, expected at least {self.num_raw_bands}.")
                # Select the 6 channels for MoE based on MOE_INPUT_RAW_BANDS_ORDER and the map
                band_indices = [self.raw_band_to_idx_map[bname] for bname in MOE_INPUT_RAW_BANDS_ORDER]
                image_6_channel_data = all_bands_data[band_indices, :, :]

            with rasterio.open(label_path) as src:
                label = src.read(1)

            # Cast in NumPy first: some PyTorch versions cannot convert uint16
            # arrays (a common GeoTIFF dtype) directly.
            image_tensor = torch.from_numpy(image_6_channel_data.astype(np.float32))
            # The mask stays float for interpolation and augmentation.
            label_tensor_float = torch.from_numpy(label.astype(np.float32))

            # Resize (bilinear for the image, nearest for the class-id mask).
            image_tensor = F.interpolate(image_tensor.unsqueeze(0),
                                         size=(self.target_img_size, self.target_img_size),
                                         mode='bilinear', align_corners=False).squeeze(0)
            label_tensor_float = F.interpolate(label_tensor_float.unsqueeze(0).unsqueeze(0),
                                               size=(self.target_img_size, self.target_img_size),
                                               mode='nearest').squeeze(0).squeeze(0)

            # Apply RandAugment-like strategy
            if self.augmentations and self.rand_augment_n > 0 and self.augmentation_pool:
                num_ops_to_apply = min(self.rand_augment_n, len(self.augmentation_pool))
                ops_to_apply = random.sample(self.augmentation_pool, num_ops_to_apply)

                # Give the mask a channel dimension while it is being augmented.
                current_mask_for_aug = label_tensor_float
                if current_mask_for_aug.ndim == 2:
                    current_mask_for_aug = current_mask_for_aug.unsqueeze(0)

                for op_func in ops_to_apply:
                    image_tensor, current_mask_for_aug = op_func(image_tensor, current_mask_for_aug)

                # Back to [H, W] for the loss / metrics.
                if current_mask_for_aug.ndim == 3 and current_mask_for_aug.shape[0] == 1:
                    label_tensor_float = current_mask_for_aug.squeeze(0)
                else:
                    label_tensor_float = current_mask_for_aug

            label_tensor = label_tensor_float.long()

            # This dataset outputs the 6-channel image directly.
            # Normalization specific to Satlas will happen later in the MoESegmentationSystem.
            return image_tensor, label_tensor
        except Exception as e:
            print(f"Error in CerealSensingDataset_MoE for {identifier}: {e}")
            traceback.print_exc()
            raise  # bare raise keeps the original traceback intact

# ==============================================================================
# SECTION 4: LitSegmentationModel Definition 
# ==============================================================================
class LitSegmentationModel(LightningModule):
    """Lightning wrapper around a single segmentation network.

    Exactly one of ``full_satlas_model`` or ``model_arch`` supplies the wrapped
    network (``full_satlas_model`` wins when both are given); it is stored as
    ``self.model``. The module trains with (optionally class-weighted)
    cross-entropy and tracks a collection of validation metrics.

    Args:
        model_arch: Network to wrap when no Satlas model is given.
        full_satlas_model: Satlas network to wrap.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        num_classes: Number of segmentation classes.
        class_names: Names used in per-class metric keys; defaults to
            ``class_<i>``.
        freeze_backbone_fpn: When a Satlas model is given, freeze its
            ``backbone`` and ``fpn`` sub-modules if they exist.

    Raises:
        ValueError: If neither network argument is provided.
    """

    def __init__(self, model_arch=None, full_satlas_model=None,
                 learning_rate=1e-5, weight_decay=0.01,
                 num_classes=2, class_names=None, freeze_backbone_fpn=False):
        super().__init__()
        # The networks themselves are not hyperparameters and are left out.
        self.save_hyperparameters(ignore=['model_arch', 'full_satlas_model'])

        if full_satlas_model is None and model_arch is None:
            raise ValueError("Either model_arch (for Prithvi-like) or full_satlas_model must be provided.")

        if full_satlas_model is not None:
            self.model = full_satlas_model
            if freeze_backbone_fpn:
                self._freeze_satlas_backbone_and_fpn()
        else:
            self.model = model_arch

        # Optional per-class loss weights; left unset here.
        self.class_weights_tensor = None
        if class_names:
            self.class_names = class_names
        else:
            self.class_names = [f"class_{i}" for i in range(num_classes)]

        n_classes = self.hparams.num_classes

        def multiclass(metric_cls, average):
            # All metrics share the task and class count; only averaging differs.
            return metric_cls(task="multiclass", num_classes=n_classes, average=average)

        metric_template = MetricCollection({
            'accuracy_overall': multiclass(Accuracy, 'micro'),
            'f1_score_macro': multiclass(F1Score, 'macro'),
            'jaccard_index_macro': multiclass(JaccardIndex, 'macro'),  # overall IoU
            'jaccard_index_per_class': multiclass(JaccardIndex, 'none'),
            'accuracy_per_class': multiclass(Accuracy, 'none'),
            'f1_score_per_class': multiclass(F1Score, 'none'),
        })
        self.val_metrics = metric_template.clone(prefix='val_')

    def _freeze_satlas_backbone_and_fpn(self):
        """Disable gradients for the wrapped model's backbone and FPN, if present."""
        freeze_targets = (
            ('backbone', "Satlas backbone parameters frozen during LitModel init."),
            ('fpn', "Satlas FPN parameters frozen during LitModel init."),
        )
        froze_something = False
        for attr_name, message in freeze_targets:
            if not hasattr(self.model, attr_name):
                continue
            for weight in getattr(self.model, attr_name).parameters():
                weight.requires_grad = False
            print(message)
            froze_something = True
        if not froze_something:
            print("WARNING: Could not find 'backbone' or 'fpn' in full_satlas_model to freeze.")

    def forward(self, x):
        """Run the wrapped model and return its logits tensor.

        Handles three output shapes: a tuple (first element is used), an object
        with an ``output`` attribute, or the logits themselves.
        """
        result = self.model(x)
        if isinstance(result, tuple):
            return result[0]
        if hasattr(result, 'output'):
            return result.output
        return result

    def _weighted_cross_entropy(self, logits, targets):
        """Cross-entropy with the optional class weights moved to the logits' device."""
        class_weights = None
        if self.class_weights_tensor is not None:
            class_weights = self.class_weights_tensor.to(logits.device)
        return F.cross_entropy(logits, targets, weight=class_weights)

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        logits = self.forward(inputs)
        loss = self._weighted_cross_entropy(logits, targets)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        logits = self.forward(inputs)
        loss = self._weighted_cross_entropy(logits, targets)
        self.log("val_loss", loss, prog_bar=True, logger=True, sync_dist=True)
        # Metrics are accumulated here and computed once per epoch.
        self.val_metrics.update(torch.argmax(logits, dim=1), targets)
        return loss

    def on_validation_epoch_end(self):
        """Compute, log and reset the validation metrics.

        Per-class results are logged as ``<metric>_<class name>``. For the
        class named "cereal", IoU and F1 are additionally shown on the
        progress bar under short aliases (not sent to the logger).
        """
        n_classes = self.hparams.num_classes
        headline_metrics = ("val_f1_score_macro", "val_jaccard_index_macro", "val_accuracy_overall")
        cereal_aliases = {
            "val_jaccard_index_per_class": "val_iou_cereal",
            "val_f1_score_per_class": "val_f1_cereal",
        }

        for name, value in self.val_metrics.compute().items():
            # A tensor with one entry per class is treated as a per-class metric.
            is_per_class = isinstance(value, torch.Tensor) and value.numel() == n_classes
            if not is_per_class:
                self.log(name, value, prog_bar=name in headline_metrics, logger=True, sync_dist=True)
                continue

            for class_idx in range(n_classes):
                class_name = self.class_names[class_idx]
                alias = cereal_aliases.get(name) if class_name == "cereal" else None
                if alias:
                    self.log(alias, value[class_idx], prog_bar=True, logger=False, sync_dist=True)
                self.log(f"{name}_{class_name}", value[class_idx], prog_bar=False, logger=True, sync_dist=True)

        self.val_metrics.reset()

    def configure_optimizers(self):
        hp = self.hparams
        return torch.optim.AdamW(self.parameters(), lr=hp.learning_rate, weight_decay=hp.weight_decay)


# ==============================================================================
# SECTION 5: Helper Functions to Load Expert Models 
# ==============================================================================
def load_prithvi_expert(checkpoint_path, model_args_for_recon, device):
    """Rebuild the Prithvi network, load its fine-tuned weights and freeze it.

    The architecture is created through terratorch's SemanticSegmentationTask,
    then handed to ``LitSegmentationModel.load_from_checkpoint`` so the
    checkpoint's weights are loaded into it (on the CPU first).

    Args:
        checkpoint_path: Path of the LitSegmentationModel checkpoint.
        model_args_for_recon: ``model_args`` dict describing the architecture.
        device: Device the returned module is moved to.

    Returns:
        The bare network (not the Lightning wrapper), in eval mode, with
        gradients disabled for every parameter, on ``device``.
    """
    print(f"Loading Prithvi expert from: {checkpoint_path}")

    # The task object is only needed as a factory for the network.
    architecture = SemanticSegmentationTask(
        model_args=model_args_for_recon,
        model_factory="EncoderDecoderFactory",
        loss="ce",
        lr=1e-5,
        optimizer="adamw",
    ).model

    lit_wrapper = LitSegmentationModel.load_from_checkpoint(
        checkpoint_path,
        map_location=torch.device('cpu'),
        model_arch=architecture,
        learning_rate=PRITHVI_LR,
        weight_decay=PRITHVI_WEIGHT_DECAY,
        num_classes=NUM_CLASSES,
        class_names=CLASS_NAMES_FOR_LOGGING,
    )

    expert = lit_wrapper.model
    expert.eval()
    for weight in expert.parameters():
        weight.requires_grad = False
    print("Prithvi expert model loaded, set to eval, and frozen.")
    return expert.to(device)

def load_satlas_expert(checkpoint_path, satlas_model_id, num_seg_classes, device):
    """Rebuild the Satlas network, load its fine-tuned weights and freeze it.

    Args:
        checkpoint_path: Path of the LitSegmentationModel checkpoint.
        satlas_model_id: Model identifier passed to satlaspretrain_models.
        num_seg_classes: Number of categories of the segmentation head.
        device: Device the returned module is moved to.

    Returns:
        The bare network in eval mode with gradients disabled, on ``device``,
        or ``None`` when the base architecture could not be instantiated
        (the error and traceback are printed in that case).
    """
    print(f"Loading Satlas expert from: {checkpoint_path}")

    try:
        architecture = satlaspretrain_models.Weights().get_pretrained_model(
            model_identifier=satlas_model_id,
            fpn=True,
            head=satlaspretrain_models.Head.SEGMENT,
            num_categories=num_seg_classes,
        )
        print(f"Base Satlas architecture '{satlas_model_id}' with segmentation head loaded.")
    except Exception as e:
        print(f"Error instantiating base Satlas architecture: {e}")
        traceback.print_exc()
        return None

    # The checkpoint's weights are loaded into the architecture on the CPU.
    lit_wrapper = LitSegmentationModel.load_from_checkpoint(
        checkpoint_path,
        map_location=torch.device('cpu'),
        full_satlas_model=architecture,
        learning_rate=SATLAS_LR,
        weight_decay=SATLAS_WEIGHT_DECAY,
        num_classes=NUM_CLASSES,
        class_names=CLASS_NAMES_FOR_LOGGING,
        freeze_backbone_fpn=False,
    )

    expert = lit_wrapper.model
    expert.eval()
    for weight in expert.parameters():
        weight.requires_grad = False
    print("Satlas expert model loaded, set to eval, and frozen.")
    return expert.to(device)

# ==============================================================================
# SECTION 6: Gating Network Definition 
# ==============================================================================
class GatingNetwork(nn.Module):
    """Small CNN that turns an image batch into per-sample expert weights.

    Two conv / ReLU / max-pool stages are followed by global average pooling
    and a linear layer; a softmax over the expert dimension makes each row of
    the output sum to one.

    Args:
        input_channels: Number of channels of the input images.
        num_experts: Number of weights produced per sample.
    """

    def __init__(self, input_channels, num_experts):
        super().__init__()
        self.num_experts = num_experts

        # Stage 1: input_channels -> 16 feature maps, spatial size halved.
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=16,
                               kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: 16 -> 32 feature maps, spatial size halved again.
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32,
                               kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Collapse the spatial dimensions, then score the experts.
        self.adaptive_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.Linear(in_features=32, out_features=num_experts)

    def forward(self, x):  # x is [B, C, H, W]
        """Return softmax expert weights of shape ``[B, num_experts]``."""
        features = self.conv1(x)
        features = self.relu1(features)
        features = self.pool1(features)

        features = self.conv2(features)
        features = self.relu2(features)
        features = self.pool2(features)

        pooled = self.adaptive_pool(features)      # [B, 32, 1, 1]
        flattened = pooled.flatten(start_dim=1)    # [B, 32]
        return F.softmax(self.fc(flattened), dim=1)

# ==============================================================================
# SECTION 7: MoE LightningModule Definition 
# ==============================================================================
class MoESegmentationSystem(LightningModule):
    """Mixture of two frozen segmentation experts blended by a trainable gate.

    For every sample the gating network produces two weights (softmax); the
    final logits are ``w_prithvi * logits_prithvi + w_satlas * logits_satlas``.
    Only the gating network is optimised (see ``configure_optimizers``).

    The experts are treated as fixed feature providers:

    * ``train()`` is overridden so the experts always stay in eval mode.
      ``nn.Module.train()`` recurses into all sub-modules, so a call on this
      module (made by some Lightning versions when fitting starts or after
      validation) would otherwise re-enable their dropout layers.
    * Their forward passes run under ``torch.no_grad()``; no gradient is
      propagated into them.

    Args:
        prithvi_expert_nn_module: Prithvi network, fed the raw 6-channel input.
        satlas_expert_nn_module: Satlas network, fed the remapped 9-channel input.
        moe_input_raw_band_to_idx_map: Band name -> channel index in the input.
        satlas_expected_input_channel_names: Satlas channel names, in stacking order.
        satlas_channel_to_source_raw_band_map: Satlas channel name -> raw band name.
        satlas_normalization_divisor: Divisor applied to the Satlas input before
            it is clipped to [0, 1].
        num_moe_input_channels: Channel count of the input (gating network input).
        num_output_classes: Number of segmentation classes.
        learning_rate: AdamW learning rate for the gating network.
        weight_decay: AdamW weight decay for the gating network.
        class_names: Names used in per-class metric keys; defaults to ``class<i>``.
    """

    def __init__(self, prithvi_expert_nn_module, satlas_expert_nn_module,
                 moe_input_raw_band_to_idx_map,
                 satlas_expected_input_channel_names,
                 satlas_channel_to_source_raw_band_map,
                 satlas_normalization_divisor,
                 num_moe_input_channels,
                 num_output_classes,
                 learning_rate, weight_decay, class_names=None):
        super().__init__()
        # The expert networks are not hyperparameters and are left out.
        self.save_hyperparameters(ignore=['prithvi_expert_nn_module', 'satlas_expert_nn_module'])

        self.prithvi_expert = prithvi_expert_nn_module
        self.satlas_expert = satlas_expert_nn_module

        self.gating_network = GatingNetwork(input_channels=num_moe_input_channels, num_experts=2)

        self.moe_input_raw_band_to_idx_map = moe_input_raw_band_to_idx_map
        self.satlas_expected_input_channel_names = satlas_expected_input_channel_names
        self.satlas_channel_to_source_raw_band_map = satlas_channel_to_source_raw_band_map
        self.satlas_normalization_divisor = satlas_normalization_divisor

        self.class_names = class_names if class_names else [f"class{i}" for i in range(num_output_classes)]
        # Optional per-class loss weights; left unset here.
        self.class_weights_tensor = None

        moe_metrics = MetricCollection({
            'accuracy_overall': Accuracy(task="multiclass", num_classes=num_output_classes, average='micro'),
            'precision_macro': Precision(task="multiclass", num_classes=num_output_classes, average='macro'),
            'recall_macro': Recall(task="multiclass", num_classes=num_output_classes, average='macro'),
            'f1_score_macro': F1Score(task="multiclass", num_classes=num_output_classes, average='macro'),
            'jaccard_index_macro': JaccardIndex(task="multiclass", num_classes=num_output_classes, average='macro'),
            'accuracy_per_class': Accuracy(task="multiclass", num_classes=num_output_classes, average='none'),
            'precision_per_class': Precision(task="multiclass", num_classes=num_output_classes, average='none'),
            'recall_per_class': Recall(task="multiclass", num_classes=num_output_classes, average='none'),
            'f1_score_per_class': F1Score(task="multiclass", num_classes=num_output_classes, average='none'),
            'jaccard_index_per_class': JaccardIndex(task="multiclass", num_classes=num_output_classes, average='none')
        })
        self.val_moe_metrics = moe_metrics.clone(prefix='val_moe_')

    def train(self, mode=True):
        """Switch the training mode of the gate while keeping the experts in eval.

        Same contract as ``nn.Module.train`` (returns ``self``).
        """
        super().train(mode)
        # getattr guards against a call made before the experts are assigned.
        for expert in (getattr(self, "prithvi_expert", None), getattr(self, "satlas_expert", None)):
            if expert is not None:
                expert.eval()
        return self

    @staticmethod
    def _extract_logits(model_output):
        """Return the logits tensor from an expert's output.

        Accepts a tuple (first element), an object with an ``output``
        attribute, or the logits tensor itself -- the same cases
        ``LitSegmentationModel.forward`` handles.
        """
        if isinstance(model_output, tuple):
            return model_output[0]
        if hasattr(model_output, 'output'):
            return model_output.output
        return model_output

    def _map_input_to_satlas_format(self, x_6_channel_raw):
        """Build the normalised Satlas input from the raw input batch.

        For every Satlas channel name the corresponding raw band is looked up
        and copied, giving ``[B, len(satlas_expected_input_channel_names), H, W]``.
        The result is divided by ``satlas_normalization_divisor`` and clipped
        to [0, 1].
        """
        source_channel_indices = [
            self.moe_input_raw_band_to_idx_map[self.satlas_channel_to_source_raw_band_map[channel_name]]
            for channel_name in self.satlas_expected_input_channel_names
        ]
        # Indexing with a list gathers (and duplicates) channels in one step.
        x_for_satlas = x_6_channel_raw[:, source_channel_indices, :, :]
        x_for_satlas = x_for_satlas / self.satlas_normalization_divisor
        return torch.clip(x_for_satlas, 0.0, 1.0)

    def forward(self, x_raw_6_channel):
        """Return the gate-weighted combination of both experts' logits."""
        gate_weights = self.gating_network(x_raw_6_channel)  # [B, 2]

        # The experts are frozen; skipping autograd bookkeeping for them makes
        # that explicit. Gradients still reach the gate through gate_weights.
        with torch.no_grad():
            logits_prithvi = self._extract_logits(self.prithvi_expert(x_raw_6_channel))
            x_for_satlas = self._map_input_to_satlas_format(x_raw_6_channel)
            logits_satlas = self._extract_logits(self.satlas_expert(x_for_satlas))

        # Reshape to [B, 1, 1, 1] so each sample's weight broadcasts over
        # classes and pixels.
        w_prithvi = gate_weights[:, 0].view(-1, 1, 1, 1)
        w_satlas = gate_weights[:, 1].view(-1, 1, 1, 1)
        final_logits = w_prithvi * logits_prithvi + w_satlas * logits_satlas
        return final_logits

    def training_step(self, batch, batch_idx):
        x, y = batch
        final_logits = self.forward(x)
        weights = self.class_weights_tensor.to(final_logits.device) if self.class_weights_tensor is not None else None
        loss = F.cross_entropy(final_logits, y, weight=weights)
        self.log("train_moe_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        final_logits = self.forward(x)
        weights = self.class_weights_tensor.to(final_logits.device) if self.class_weights_tensor is not None else None
        loss = F.cross_entropy(final_logits, y, weight=weights)
        self.log("val_moe_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        # Metrics are accumulated here and computed once per epoch.
        preds_classes = torch.argmax(final_logits, dim=1)
        self.val_moe_metrics.update(preds_classes, y)
        return loss

    def on_validation_epoch_end(self):
        """Compute, log and reset the validation metrics.

        Per-class results are logged as ``<metric>_<class name>``; the IoU of
        the class named "cereal" and three overall metrics are also shown on
        the progress bar.
        """
        metrics_output = self.val_moe_metrics.compute()
        for name, value in metrics_output.items():
            # A tensor with one entry per class is treated as a per-class metric.
            if isinstance(value, torch.Tensor) and value.numel() == self.hparams.num_output_classes:
                for i in range(self.hparams.num_output_classes):
                    class_metric_name = f"{name}_{self.class_names[i]}"
                    current_metric_prog_bar_flag = (
                        self.class_names[i] == "cereal" and "jaccard_index_per_class" in name
                    )
                    self.log(class_metric_name, value[i],
                             on_step=False, on_epoch=True,
                             prog_bar=current_metric_prog_bar_flag,
                             logger=True, sync_dist=True)
            else:
                current_metric_prog_bar_flag = name in ["val_moe_accuracy_overall", "val_moe_f1_score_macro", "val_moe_jaccard_index_macro"]
                self.log(name, value,
                         on_step=False, on_epoch=True,
                         prog_bar=current_metric_prog_bar_flag,
                         logger=True, sync_dist=True)
        self.val_moe_metrics.reset()

    def configure_optimizers(self):
        """AdamW over the gating network only; the experts are never updated."""
        optimizer = torch.optim.AdamW(self.gating_network.parameters(),
                                      lr=self.hparams.learning_rate,
                                      weight_decay=self.hparams.weight_decay)
        return optimizer

# ==============================================================================
# SECTION 8: Main Execution Block for MoE Training 
# ==============================================================================
if __name__ == "__main__":
    print("\n--- Initializing MoE System ---")

    if DEVICE == "cuda":
        torch.cuda.empty_cache()

    # --- 1. Load Prithvi Expert ---
    prithvi_expert = load_prithvi_expert(PRITHVI_CHECKPOINT_PATH, prithvi_model_args_for_reconstruction, DEVICE)
    if prithvi_expert is None: print("CRITICAL ERROR: Failed to load Prithvi expert. Exiting."); exit()

    # --- 2. Load Satlas Expert ---
    satlas_expert = load_satlas_expert(SATLAS_CHECKPOINT_PATH, SATLAS_MODEL_IDENTIFIER, NUM_CLASSES, DEVICE)
    if satlas_expert is None: print("CRITICAL ERROR: Failed to load Satlas expert. Exiting."); exit()

    # --- 3. Instantiate MoE System ---
    moe_lightning_system = MoESegmentationSystem(
        prithvi_expert_nn_module=prithvi_expert, satlas_expert_nn_module=satlas_expert,
        moe_input_raw_band_to_idx_map=MOE_INPUT_RAW_BAND_TO_IDX_MAP,
        satlas_expected_input_channel_names=SATLAS_EXPECTED_INPUT_CHANNEL_NAMES,
        satlas_channel_to_source_raw_band_map=SATLAS_CHANNEL_TO_SOURCE_RAW_BAND_MAP,
        satlas_normalization_divisor=SATLAS_NORMALIZATION_DIVISOR,
        num_moe_input_channels=len(MOE_INPUT_RAW_BANDS_ORDER),
        num_output_classes=NUM_CLASSES, learning_rate=MOE_LR,
        weight_decay=MOE_WEIGHT_DECAY, class_names=CLASS_NAMES_FOR_LOGGING
    )
    print("MoESegmentationSystem instantiated.")

    # --- 4. Setup DataLoaders for MoE training ---
    print("\n--- MoE Data Preparation ---")
    
    # --- Initial Data Discovery ---
    moe_image_files_pattern = os.path.join(PATCHES_DIR, "Patch*_S2.tif")
    moe_all_image_filepaths = sorted(glob.glob(moe_image_files_pattern))
    initial_moe_identifiers = []
    
    if not moe_all_image_filepaths: print(f"ERROR: No image files found with pattern {moe_image_files_pattern}"); exit()
    else: print(f"Found {len(moe_all_image_filepaths)} potential images. Checking masks...")

    for img_path in moe_all_image_filepaths:
        base_name = os.path.basename(img_path)
        identifier_match = re.match(r"(Patch\d+_W\d+_\d{8})_S2\.tif", base_name)
        if identifier_match:
            identifier = identifier_match.group(1)
            mask_path = os.path.join(TARGETS_DIR, f"{identifier}_Mask.tif")
            if os.path.exists(mask_path): initial_moe_identifiers.append(identifier)
            else: print(f"WARNING: Mask for {identifier} not found. Skipping.")
        else: print(f"WARNING: Could not parse {base_name}. Skipping.")
    
    if not initial_moe_identifiers: print("ERROR: No image-mask pairs found. Exiting."); exit()
    print(f"Found {len(initial_moe_identifiers)} initial image-mask pairs.")

    # --- Data Integrity Check ---
    # Open every candidate image, read the bands listed in
    # MOE_INPUT_RAW_BANDS_ORDER and reject files that are short on bands,
    # empty, contain NaNs, hold values outside [-2000, 20000], or have a band
    # that is (near-)constant. Identifiers end up in exactly one of
    # `clean_moe_identifiers` or `corrupted_identifiers`.
    print(f"\nPerforming data integrity check on {len(initial_moe_identifiers)} pairs...")
    corrupted_identifiers, clean_moe_identifiers = [], []
    num_bands_expected_in_raw_file = len(MOE_INPUT_RAW_BANDS_ORDER)

    for identifier in initial_moe_identifiers:
        image_path = os.path.join(PATCHES_DIR, f"{identifier}_S2.tif")
        msgs = []  # one entry per detected issue; an empty list means the file is clean
        try:
            data_to_check = None  # stays None when the band layout is already wrong
            with rasterio.open(image_path) as src:
                if src.count < num_bands_expected_in_raw_file:
                    msgs.append(f"File has {src.count} bands, expected {num_bands_expected_in_raw_file}.")
                else:
                    indices_to_read = [MOE_INPUT_RAW_BAND_TO_IDX_MAP[b_name] for b_name in MOE_INPUT_RAW_BANDS_ORDER]
                    if any(idx >= src.count for idx in indices_to_read):
                        msgs.append(f"Band indices out of range for file with {src.count} bands.")
                    else:
                        # The map holds 0-based indices; the read call takes 1-based ones.
                        bands_to_read_1_indexed = [idx + 1 for idx in indices_to_read]
                        data_to_check = src.read(bands_to_read_1_indexed).astype(np.float32)

            # The array is fully in memory here, so the value checks run after
            # the file handle has been released.
            if data_to_check is not None:
                if data_to_check.size == 0:
                    # Checked first: min()/max() raise ValueError on a zero-size
                    # array, which previously made this branch unreachable.
                    msgs.append("Empty data read.")
                else:
                    if data_to_check.shape[0] != num_bands_expected_in_raw_file:
                        msgs.append(f"Read {data_to_check.shape[0]} bands, expected {num_bands_expected_in_raw_file}.")
                    if np.isnan(data_to_check).any():
                        msgs.append("NaNs found.")
                    lowest, highest = data_to_check.min(), data_to_check.max()
                    if lowest < -2000:
                        msgs.append(f"Low values: {lowest}.")
                    if highest > 20000:
                        msgs.append(f"High values: {highest}.")
                    if not msgs:
                        # Only report the first uniform band; one is enough to reject.
                        for band_name, band_data in zip(MOE_INPUT_RAW_BANDS_ORDER, data_to_check):
                            if np.allclose(band_data, band_data.flat[0], atol=1e-1):
                                msgs.append(f"Band {band_name} uniform.")
                                break

            if msgs:
                print(f"WARNING: {identifier} issues: {'; '.join(msgs)}. Skipping.")
                corrupted_identifiers.append(identifier)
            else:
                clean_moe_identifiers.append(identifier)
        except Exception as e:
            # Best-effort scan: an unreadable file is recorded and skipped rather
            # than aborting the whole check.
            print(f"ERROR checking {identifier}: {e}. Skipping.")
            corrupted_identifiers.append(identifier)
            traceback.print_exc()

    print(f"Integrity check done. Corrupted: {len(corrupted_identifiers)}. Valid: {len(clean_moe_identifiers)}.")
    if not clean_moe_identifiers:
        print("ERROR: No clean image-mask pairs after integrity check. Exiting.")
        # Raise SystemExit explicitly instead of calling the `exit()` helper:
        # the helper is meant for interactive shells, reports success (status 0)
        # in a script, and may not halt the running cell in a notebook kernel.
        raise SystemExit(1)

    # Shuffle in place, then use the first 80% for training and the rest for
    # validation.
    # NOTE(review): the shuffle uses the global `random` state; unless it is
    # seeded earlier in the file the split differs between runs - confirm.
    random.shuffle(clean_moe_identifiers)
    moe_split_idx = int(len(clean_moe_identifiers) * 0.8)
    moe_train_ids = clean_moe_identifiers[:moe_split_idx]
    moe_val_ids = clean_moe_identifiers[moe_split_idx:]

    if not moe_train_ids:
        print("ERROR: MoE training set empty. Exiting.")
        # Raise SystemExit explicitly instead of calling the `exit()` helper:
        # the helper is meant for interactive shells, reports success (status 0)
        # in a script, and may not halt the running cell in a notebook kernel.
        raise SystemExit(1)

    # Arguments shared by the training and validation datasets / loaders.
    moe_common_ds_kwargs = dict(
        patches_dir=PATCHES_DIR,
        targets_dir=TARGETS_DIR,
        target_img_size=IMG_SIZE,
        raw_band_to_idx_map=MOE_INPUT_RAW_BAND_TO_IDX_MAP,
    )
    moe_common_loader_kwargs = dict(
        batch_size=MOE_BATCH_SIZE,
        num_workers=0,
        pin_memory=(DEVICE == "cuda"),
    )

    moe_train_ds = CerealSensingDataset_MoE(
        file_identifiers=moe_train_ids,
        augmentations=True,  # Master switch for augmentations
        rand_augment_n=MOE_RAND_AUGMENT_N,
        rotation_degrees=MOE_ROTATION_DEGREES,
        brightness_range=MOE_BRIGHTNESS_RANGE,
        contrast_range=MOE_CONTRAST_RANGE,
        gamma_range=MOE_GAMMA_RANGE,
        gaussian_noise_std=MOE_GAUSSIAN_NOISE_STD,
        **moe_common_ds_kwargs,
    )
    moe_train_loader = DataLoader(moe_train_ds, shuffle=True, **moe_common_loader_kwargs)
    print(f"MoE Train DataLoader: {len(moe_train_loader.dataset)} samples. Augmentations: RandAugment-like (N={MOE_RAND_AUGMENT_N})")

    # The validation loader stays None when the split left no validation samples.
    moe_val_loader = None
    if moe_val_ids:
        moe_val_ds = CerealSensingDataset_MoE(
            file_identifiers=moe_val_ids,
            augmentations=False,  # No augmentations for validation set
            **moe_common_ds_kwargs,
        )
        moe_val_loader = DataLoader(moe_val_ds, shuffle=False, **moe_common_loader_kwargs)
        print(f"MoE Val DataLoader: {len(moe_val_loader.dataset)} samples.")
    else:
        print("WARNING: MoE Validation dataset is empty.")

    # --- 5. Setup Trainer for MoE ---
    # Metrics are logged as CSV under MOE_LOGS_DIR.
    moe_csv_logger = CSVLogger(name="moe_rand_augment_run", save_dir=MOE_LOGS_DIR)

    # Two checkpoint callbacks, each keeping its single best model: one tracks
    # the lowest `val_moe_loss`, the other the highest
    # `val_moe_jaccard_index_per_class_cereal`.
    moe_loss_ckpt_cb = ModelCheckpoint(
        monitor='val_moe_loss',
        mode='min',
        save_top_k=1,
        verbose=True,
        dirpath=MOE_CHECKPOINT_SAVE_DIR,
        filename='moe-best-loss-{epoch:02d}-{val_moe_loss:.3f}',
    )
    moe_iou_ckpt_cb = ModelCheckpoint(
        monitor='val_moe_jaccard_index_per_class_cereal',
        mode='max',
        save_top_k=1,
        verbose=True,
        dirpath=MOE_CHECKPOINT_SAVE_DIR,
        filename='moe-best-iou-cereal-{epoch:02d}-{val_moe_jaccard_index_per_class_cereal:.3f}',
    )
    # Stop once `val_moe_loss` has not improved for 10 validation checks.
    moe_early_stop_cb = EarlyStopping(
        monitor="val_moe_loss",
        mode="min",
        patience=10,
        verbose=True,
    )

    # `get_ipython` only exists inside IPython-based sessions; a NameError means
    # a plain interpreter. Only the 'ZMQInteractiveShell' class counts as
    # interactive here.
    try:
        shell = get_ipython().__class__.__name__
    except NameError:
        is_interactive_moe = False
    else:
        is_interactive_moe = (shell == 'ZMQInteractiveShell')

    # Pick a distributed strategy only when more than one CUDA device is visible.
    moe_strategy = "auto"
    multiple_gpus = DEVICE == "cuda" and torch.cuda.is_available() and torch.cuda.device_count() > 1
    if not multiple_gpus:
        print(f"INFO (MoE): Single GPU or CPU. Using '{moe_strategy}' strategy.")
    else:
        if is_interactive_moe:
            moe_strategy = "ddp_notebook"
        else:
            moe_strategy = "ddp_find_unused_parameters_true"
        print(f"INFO (MoE): Multiple GPUs detected. Using '{moe_strategy}' strategy.")

    # On CUDA let the trainer choose the device count; otherwise use one device.
    moe_devices = "auto" if DEVICE == "cuda" else 1
    moe_trainer = Trainer(
        accelerator=DEVICE,
        devices=moe_devices,
        strategy=moe_strategy,
        max_epochs=MOE_EPOCHS,
        logger=moe_csv_logger,
        callbacks=[moe_early_stop_cb, moe_loss_ckpt_cb, moe_iou_ckpt_cb],
        gradient_clip_val=0.5,
    )

    print("\n--- Starting MoE Training ---")
    try:
        # Validate only when a loader exists and its dataset is non-empty.
        # An explicit None check avoids relying on DataLoader truthiness.
        has_val_data = (
            moe_val_loader is not None
            and hasattr(moe_val_loader, 'dataset')
            and len(moe_val_loader.dataset) > 0
        )
        val_dl_arg_moe = moe_val_loader if has_val_data else None

        if val_dl_arg_moe is None:
            print("WARNING (MoE): Validation data not available. Adjusting callbacks that monitor validation metrics.")
            # Callbacks monitoring a `val_moe_*` metric would never see that
            # metric without validation data, so they are dropped.
            active_cbs = []
            for cb_instance in moe_trainer.callbacks:
                monitor_attr = getattr(cb_instance, 'monitor', None)
                if monitor_attr and monitor_attr.startswith("val_moe_"):
                    print(f"  INFO (MoE): Removing callback: {cb_instance.__class__.__name__} monitoring {monitor_attr}")
                else:
                    active_cbs.append(cb_instance)
            # Always assign a list (possibly empty): replacing the callback
            # list with None would break any later iteration over it.
            moe_trainer.callbacks = active_cbs

        moe_trainer.fit(
            moe_lightning_system,
            train_dataloaders=moe_train_loader,
            val_dataloaders=val_dl_arg_moe,
        )
        print("MoE Training finished.")

        # `best_model_path` is empty when the callback never saved a checkpoint.
        best_loss_path = getattr(moe_loss_ckpt_cb, 'best_model_path', None)
        if best_loss_path:
            print(f"Best MoE model (by loss) saved at: {best_loss_path}")
        best_iou_path = getattr(moe_iou_ckpt_cb, 'best_model_path', None)
        if best_iou_path:
            print(f"Best MoE model (by IoU for cereal) saved at: {best_iou_path}")
    except Exception as e:
        # Top-level boundary: report the failure with its traceback instead of
        # letting it propagate.
        print(f"An error occurred during MoE training: {e}")
        traceback.print_exc()