# In [None]:
# Imports
import glob
import os
import random
import re
import traceback

import matplotlib.pyplot as plt
import numpy as np
import rasterio
import satlaspretrain_models # Using the allenai library
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy, Precision, Recall, F1Score, JaccardIndex
from torchmetrics.collections import MetricCollection

# ============ PATHS ============ #
# Root locations. The angle-bracket values are placeholders to be replaced
# with real directories before running.
BASE_DATA_DIR_ACTUAL = "<BASE_DATA_DIR>"
CHECKPOINT_SAVE_DIR_ACTUAL = "<CHECKPOINT_SAVE_DIR>"
LOGS_DIR_ACTUAL = "<LOGS_DIR>"

# Input imagery and masks live in fixed sub-folders of the data root.
PATCHES_DIR = os.path.join(BASE_DATA_DIR_ACTUAL, "GEE_WeeklyPatches")
TARGETS_DIR = os.path.join(BASE_DATA_DIR_ACTUAL, "GEE_WeeklyMasks")

# Output folders are created up front; input folders are only checked.
for _output_dir in (CHECKPOINT_SAVE_DIR_ACTUAL, LOGS_DIR_ACTUAL):
    os.makedirs(_output_dir, exist_ok=True)

print(f"Data Dir: {BASE_DATA_DIR_ACTUAL}")
print(f"Checkpoints Dir: {CHECKPOINT_SAVE_DIR_ACTUAL}")
print(f"Logs Dir: {LOGS_DIR_ACTUAL}")
print(f"INFO: allenai/satlaspretrain_models will use its default cache path (usually ~/.cache/torch/hub/satlaspretrain/)")

# A missing input folder is reported but does not stop execution.
if not os.path.isdir(PATCHES_DIR):
    print(f"WARNING: Patches directory not found: {PATCHES_DIR}")
if not os.path.isdir(TARGETS_DIR):
    print(f"WARNING: Masks directory not found: {TARGETS_DIR}")

# ============ MODEL CONFIG (allenai/satlaspretrain_models) ============ #
SATLAS_MODEL_IDENTIFIER = "Sentinel2_SwinB_SI_MS"

# Band order inside each patch GeoTIFF; the position in this list is the band index.
RAW_BANDS_IN_GEOTIFF_ORDER = ['B2', 'B3', 'B4', 'B8', 'B11', 'B12']
NUM_RAW_BANDS_AVAILABLE = len(RAW_BANDS_IN_GEOTIFF_ORDER)
RAW_BAND_TO_IDX_MAP = dict(zip(RAW_BANDS_IN_GEOTIFF_ORDER, range(NUM_RAW_BANDS_AVAILABLE)))

# One row per model input channel, in channel order: (channel name, raw band it is filled from).
# The "*_duplicate" channels have no band of their own in the GeoTIFF and reuse B4/B3/B2.
_SATLAS_CHANNEL_SOURCES = (
    ("TCI_R", "B4"),
    ("TCI_G", "B3"),
    ("TCI_B", "B2"),
    ("B05_duplicate", "B4"),
    ("B06_duplicate", "B3"),
    ("B07_duplicate", "B2"),
    ("B08_from_B8", "B8"),
    ("B11", "B11"),
    ("B12", "B12"),
)
SATLAS_INPUT_CHANNEL_NAMES = [channel for channel, _ in _SATLAS_CHANNEL_SOURCES]
NUM_SATLAS_INPUT_BANDS = len(SATLAS_INPUT_CHANNEL_NAMES)
SATLAS_CHANNEL_TO_SOURCE_RAW_BAND_MAP = dict(_SATLAS_CHANNEL_SOURCES)

# Raw pixel values are divided by this constant before being clipped to [0, 1].
NORMALIZATION_DIVISOR = 8160.0

NUM_CLASSES = 2
IMG_SIZE = 256

# ============ TRAINING HYPERPARAMETERS & OTHER CONFIGS ============ #
BATCH_SIZE = 2
EPOCHS = 100
LR = 1e-5
WEIGHT_DECAY = 0.01
FREEZE_BACKBONE_FPN = False

# --- RANDOMAUGMENT STYLE CONFIG ---
# Master switch for the training set, and N = how many distinct operations
# are drawn from the pool of enabled augmentations for each sample.
APPLY_AUGMENTATIONS_TRAIN_SET = True
NUM_RAND_AUGMENT_OPS = 2

# Geometric operations (applied to image and mask alike).
RAND_AUG_H_FLIP_ENABLE = True
RAND_AUG_V_FLIP_ENABLE = True
RAND_AUG_ROTATION_ENABLE = True
RAND_AUG_ROTATION_DEGREES_OPTIONS = [0, 90, 180, 270]  # one angle is picked at random

# Radiometric operations (image only); each range is (low, high) for a uniform draw.
RAND_AUG_BRIGHTNESS_ENABLE = True
RAND_AUG_BRIGHTNESS_RANGE = (0.75, 1.25)
RAND_AUG_CONTRAST_ENABLE = True
RAND_AUG_CONTRAST_RANGE = (0.75, 1.25)
RAND_AUG_GAMMA_ENABLE = True
RAND_AUG_GAMMA_RANGE = (0.8, 1.2)

# Additive Gaussian noise (image only).
RAND_AUG_GAUSSIAN_NOISE_ENABLE = True
RAND_AUG_GAUSSIAN_NOISE_STD = 0.03

# Early stopping.
EARLY_STOPPING_PATIENCE = 15
EARLY_STOPPING_MONITOR = "val_loss"

# DEVICE determination: CUDA only when it is both requested and available.
USE_CUDA_IF_AVAILABLE = True
DEVICE = "cuda" if (USE_CUDA_IF_AVAILABLE and torch.cuda.is_available()) else "cpu"
print(f"Device selected: {DEVICE} (CUDA available: {torch.cuda.is_available()})")

# Configuration sanity checks. Problems are only reported, not raised:
# every model input channel must map to a raw band that exists in the GeoTIFF,
# and the channel count must be the 9 expected by the Satlas MS model.
for name in SATLAS_INPUT_CHANNEL_NAMES:
    source = SATLAS_CHANNEL_TO_SOURCE_RAW_BAND_MAP.get(name)
    source_is_known_band = bool(source) and source in RAW_BAND_TO_IDX_MAP
    if not source_is_known_band:
        print(f"CRITICAL WARNING: Source band for Satlas channel '{name}' ('{source}') is not valid. Check config.")

if NUM_SATLAS_INPUT_BANDS != 9:
    print(f"CRITICAL WARNING: NUM_SATLAS_INPUT_BANDS is {NUM_SATLAS_INPUT_BANDS}, Satlas MS expects 9.")

class CerealSensingDataset(Dataset):
    """Paired patch/mask GeoTIFF dataset that assembles multi-channel model inputs.

    For an identifier ``X`` the image is read from ``<patches_dir>/X_S2.tif`` and
    the label from band 1 of ``<targets_dir>/X_Mask.tif``. The image channels are
    built by picking, for every name in ``satlas_input_channel_names``, the raw
    band given by ``satlas_channel_to_source_raw_band_map`` / ``raw_band_to_idx_map``.
    Images are resized (bilinear) and labels resized (nearest) to
    ``target_img_size``; images are divided by ``normalization_divisor`` and
    clipped to [0, 1].

    When ``augmentations`` is true and ``num_rand_augment_ops`` > 0, each sample
    gets that many distinct operations drawn at random from the pool of enabled
    augmentations (RandAugment style). Geometric operations are applied to image
    and mask; radiometric and noise operations to the image only.

    ``__getitem__`` returns ``(image, label)`` with shapes ``[C, H, W]`` (float)
    and ``[H, W]`` (long), where ``C == len(satlas_input_channel_names)``.
    """

    def __init__(self, file_identifiers, patches_dir, targets_dir, target_img_size,
                 raw_band_to_idx_map, satlas_input_channel_names,
                 satlas_channel_to_source_raw_band_map, normalization_divisor,
                 augmentations=False, # General switch for applying augmentations
                 num_rand_augment_ops=0,
                 # Flags and parameters for each potential augmentation in the pool
                 rand_aug_h_flip_enable=False,
                 rand_aug_v_flip_enable=False,
                 rand_aug_rotation_enable=False,
                 rand_aug_rotation_degrees_options=None,
                 rand_aug_brightness_enable=False,
                 rand_aug_brightness_range=(1.0, 1.0),
                 rand_aug_contrast_enable=False,
                 rand_aug_contrast_range=(1.0, 1.0),
                 rand_aug_gamma_enable=False,
                 rand_aug_gamma_range=(1.0, 1.0),
                 rand_aug_gaussian_noise_enable=False,
                 rand_aug_gaussian_noise_std=0.0
                ):
        self.file_identifiers = file_identifiers
        self.patches_dir = patches_dir
        self.targets_dir = targets_dir
        self.target_img_size = target_img_size
        self.raw_band_to_idx_map = raw_band_to_idx_map
        self.satlas_input_channel_names = satlas_input_channel_names
        self.satlas_channel_to_source_raw_band_map = satlas_channel_to_source_raw_band_map
        self.normalization_divisor = normalization_divisor

        self.augmentations = augmentations
        self.num_rand_augment_ops = num_rand_augment_ops

        # Per-operation switches and parameters, consumed by _init_augmentation_pool.
        self._h_flip_enable = rand_aug_h_flip_enable
        self._v_flip_enable = rand_aug_v_flip_enable
        self._rotation_enable = rand_aug_rotation_enable
        # An empty/None option list degrades to "no rotation".
        self._rotation_degrees_options = rand_aug_rotation_degrees_options if rand_aug_rotation_degrees_options else [0]

        self._brightness_enable = rand_aug_brightness_enable
        self._brightness_range = rand_aug_brightness_range
        self._contrast_enable = rand_aug_contrast_enable
        self._contrast_range = rand_aug_contrast_range
        self._gamma_enable = rand_aug_gamma_enable
        self._gamma_range = rand_aug_gamma_range

        self._noise_enable = rand_aug_gaussian_noise_enable
        self._gaussian_noise_std = rand_aug_gaussian_noise_std

        if self.augmentations and self.num_rand_augment_ops > 0:
            self.augmentation_pool = self._init_augmentation_pool()
            if not self.augmentation_pool:
                print("WARNING: Augmentations enabled, N > 0, but augmentation pool is empty. Check enable flags in config.")
        else:
            self.augmentation_pool = []

    @staticmethod
    def _adjust_per_channel(adjust_fn, image, *args, **kwargs):
        """Apply a torchvision colour adjustment to an image with any channel count.

        Images with 1 or 3 channels are passed to ``adjust_fn`` as a whole; any
        other channel count is processed one ``[1, H, W]`` slice at a time and
        re-stacked, since the torchvision adjustments used here are written for
        1- or 3-channel inputs.
        """
        if image.shape[0] in (1, 3):
            return adjust_fn(image, *args, **kwargs)
        adjusted_channels = [adjust_fn(image[i:i + 1], *args, **kwargs) for i in range(image.shape[0])]
        return torch.cat(adjusted_channels, dim=0)

    def _init_augmentation_pool(self):
        """Build the list of enabled augmentations.

        Every entry is a callable ``op(image, mask) -> (image, mask)``; random
        parameters are drawn inside the callable, i.e. once per application.
        """
        pool = []

        # Geometric augmentations (affect image and mask).
        if self._h_flip_enable:
            def horizontal_flip(image, mask):
                return TF.hflip(image), TF.hflip(mask)
            pool.append(horizontal_flip)

        if self._v_flip_enable:
            def vertical_flip(image, mask):
                return TF.vflip(image), TF.vflip(mask)
            pool.append(vertical_flip)

        if self._rotation_enable:
            def rotation(image, mask):
                angle = random.choice(self._rotation_degrees_options)
                if angle != 0:
                    image = TF.rotate(image, angle, interpolation=TF.InterpolationMode.BILINEAR)
                    # Nearest-neighbour keeps the mask's class ids intact.
                    mask = TF.rotate(mask, angle, interpolation=TF.InterpolationMode.NEAREST)
                return image, mask
            pool.append(rotation)

        # Radiometric augmentations (affect image only).
        if self._brightness_enable:
            def adjust_brightness(image, mask):
                factor = random.uniform(self._brightness_range[0], self._brightness_range[1])
                return self._adjust_per_channel(TF.adjust_brightness, image, factor), mask
            pool.append(adjust_brightness)

        if self._contrast_enable:
            def adjust_contrast(image, mask):
                factor = random.uniform(self._contrast_range[0], self._contrast_range[1])
                return self._adjust_per_channel(TF.adjust_contrast, image, factor), mask
            pool.append(adjust_contrast)

        if self._gamma_enable:
            def adjust_gamma(image, mask):
                # Gamma is floored at 0.1 regardless of the configured range.
                gamma_val = max(0.1, random.uniform(self._gamma_range[0], self._gamma_range[1]))
                return self._adjust_per_channel(TF.adjust_gamma, image, gamma_val, gain=1), mask
            pool.append(adjust_gamma)

        # Noise augmentation (affects image only); a non-positive std disables it.
        if self._noise_enable and self._gaussian_noise_std > 0:
            def add_gaussian_noise(image, mask):
                noise = torch.randn_like(image) * self._gaussian_noise_std
                return torch.clamp(image + noise, 0.0, 1.0), mask
            pool.append(add_gaussian_noise)

        return pool

    def __len__(self):
        return len(self.file_identifiers)

    def _load_process_and_normalize_image(self, image_identifier):
        """Read ``<image_identifier>_S2.tif`` and return the normalized ``[C, H, W]`` float tensor."""
        image_path = os.path.join(self.patches_dir, f"{image_identifier}_S2.tif")
        with rasterio.open(image_path) as src:
            raw_image_data_from_file = src.read()

        # Pick one raw band per model channel; a raw band may be used more than once.
        processed_bands_for_satlas = [
            raw_image_data_from_file[self.raw_band_to_idx_map[self.satlas_channel_to_source_raw_band_map[channel_name]]]
            for channel_name in self.satlas_input_channel_names
        ]
        image_tensor = torch.from_numpy(np.stack(processed_bands_for_satlas, axis=0).astype(np.float32))

        target_hw = (self.target_img_size, self.target_img_size)
        if image_tensor.shape[1:] != target_hw:
            image_tensor = F.interpolate(image_tensor.unsqueeze(0), size=target_hw,
                                         mode='bilinear', align_corners=False).squeeze(0)
        return torch.clip(image_tensor / self.normalization_divisor, 0.0, 1.0)

    def _load_label(self, identifier):
        """Read band 1 of ``<identifier>_Mask.tif`` and return it as an ``[H, W]`` long tensor."""
        label_path = os.path.join(self.targets_dir, f"{identifier}_Mask.tif")
        with rasterio.open(label_path) as src:
            label = src.read(1)
        # F.interpolate needs a float [N, C, H, W] input; 'nearest' never blends class ids.
        label_4d = torch.tensor(label, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        resized = F.interpolate(label_4d, size=(self.target_img_size, self.target_img_size), mode='nearest')
        return resized.squeeze(0).squeeze(0).long()

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``, augmented if configured.

        Raises:
            ValueError: if the final tensors do not have the expected shapes.
            Exception: any loading/augmentation error is printed with its
                traceback and then re-raised unchanged.
        """
        identifier = self.file_identifiers[idx]
        try:
            image_tensor = self._load_process_and_normalize_image(identifier)  # [C, H, W]
            label_tensor = self._load_label(identifier)  # [H, W]

            if self.augmentations and self.augmentation_pool and self.num_rand_augment_ops > 0:
                # The ops expect a channel dimension on the mask as well.
                label_tensor_aug = label_tensor.unsqueeze(0)
                num_ops_to_apply = min(self.num_rand_augment_ops, len(self.augmentation_pool))
                for op_func in random.sample(self.augmentation_pool, num_ops_to_apply):
                    image_tensor, label_tensor_aug = op_func(image_tensor, label_tensor_aug)
                label_tensor = label_tensor_aug.squeeze(0)

            # Validate against this instance's own channel list rather than a module global,
            # so the dataset stays correct for any channel configuration it was built with.
            expected_image_shape = (len(self.satlas_input_channel_names), self.target_img_size, self.target_img_size)
            expected_label_shape = (self.target_img_size, self.target_img_size)
            if image_tensor.shape != expected_image_shape or label_tensor.shape != expected_label_shape:
                raise ValueError(f"Shape mismatch for {identifier}. Img:{image_tensor.shape}, Lbl:{label_tensor.shape}")

            return image_tensor, label_tensor
        except Exception as e:
            print(f"ERROR in __getitem__ for {identifier}: {e}")
            traceback.print_exc()
            raise
# ============ DATA DISCOVERY ============ #
# Collect identifiers of patches whose file name parses and whose mask file exists.
print("\n=========== Data Loading & Prep (allenai/satlaspretrain_models - Single GPU V16 - RandAugment) ============") # Updated print
image_files_pattern = os.path.join(PATCHES_DIR, "Patch*_S2.tif")
all_image_filepaths = sorted(glob.glob(image_files_pattern))
initial_file_identifiers = []

if all_image_filepaths:
    print(f"Found {len(all_image_filepaths)} potential image files. Checking masks...")
else:
    print(f"ERROR: No image files found in {PATCHES_DIR}")

for img_path in all_image_filepaths:
    base_name = os.path.basename(img_path)
    # The captured group is the identifier shared by the image and its mask.
    identifier_match = re.match(r"(Patch\d+_W\d+_\d{8})_S2\.tif", base_name)
    if not identifier_match:
        print(f"WARNING: Could not parse {base_name}. Skipping.")
        continue

    identifier = identifier_match.group(1)
    mask_path = os.path.join(TARGETS_DIR, f"{identifier}_Mask.tif")
    if not os.path.exists(mask_path):
        print(f"WARNING: Mask for {identifier} not found. Skipping.")
        continue

    initial_file_identifiers.append(identifier)

print(f"Found {len(initial_file_identifiers)} initial image-mask pairs.")

# ============ DATA INTEGRITY CHECK ============ #
def _find_raw_data_issues(raw_data, num_expected_bands):
    """Return a list of problem descriptions for one raw image array (empty list = clean).

    Args:
        raw_data: array as returned by ``rasterio``'s ``read()``; indexed as
            ``[band, row, col]`` by this function.
        num_expected_bands: number of leading bands that must be present and
            that are inspected.

    Checks, in order: enough bands, non-empty data, no NaNs, values within
    [-1000, 20000], and - only if nothing else was found - that the bands are
    not all constant (within an absolute tolerance of 0.1).
    """
    if raw_data.shape[0] < num_expected_bands:
        return [f"File has {raw_data.shape[0]} bands, expected {num_expected_bands}."]

    data = raw_data[:num_expected_bands, :, :].astype(np.float32)
    # Must come before min()/max(): those reductions raise on zero-size arrays,
    # which would otherwise hide this condition behind a generic error.
    if data.size == 0:
        return ["Empty raw data"]

    issues = []
    if np.isnan(data).any():
        issues.append("NaNs")
    lowest, highest = data.min(), data.max()
    if lowest < -1000:
        issues.append(f"Low raw values: {lowest}")
    if highest > 20000:
        issues.append(f"High raw values: {highest}")

    # A patch is flagged as uniform only when every inspected band is constant.
    if not issues and all(np.allclose(band, band[0, 0], atol=1e-1) for band in data):
        issues.append("All raw bands appear uniform")
    return issues


# Only the image patches are inspected here; mask files are not opened.
print(f"\nPerforming data integrity check on {len(initial_file_identifiers)} pairs...")
corrupted_identifiers, valid_identifiers = [], []
if not initial_file_identifiers:
    print("No identifiers to check.")
for identifier in initial_file_identifiers:
    image_path = os.path.join(PATCHES_DIR, f"{identifier}_S2.tif")
    try:
        with rasterio.open(image_path) as src:
            image_data_raw_check = src.read()
        msgs = _find_raw_data_issues(image_data_raw_check, NUM_RAW_BANDS_AVAILABLE)
    except Exception as e:
        # Best effort: an unreadable file is excluded rather than aborting the run.
        print(f"ERROR checking {identifier}: {e}. Skipping.")
        corrupted_identifiers.append(identifier)
        continue

    if msgs:
        print(f"WARNING: {identifier} issues: {'; '.join(msgs)}. Skipping.")
        corrupted_identifiers.append(identifier)
    else:
        valid_identifiers.append(identifier)

print(f"Integrity check complete. Corrupted: {len(corrupted_identifiers)}.")
clean_identifiers = valid_identifiers
num_clean_patches = len(clean_identifiers)
print(f"Clean patches: {num_clean_patches}")


# ============ DATA SPLIT AND LOADERS (pin_memory=False, num_workers=0) ============ #
# Random 80/20 train/validation split of the clean identifiers. Loaders stay
# None when the corresponding split is empty.
train_ds, val_ds, train_loader, val_loader = None, None, None, None
if num_clean_patches > 0:
    random.shuffle(clean_identifiers)  # in-place shuffle
    split_idx = int(num_clean_patches * 0.8)
    if num_clean_patches >= 2:
        train_ids, val_ids = clean_identifiers[:split_idx], clean_identifiers[split_idx:]
    else:
        train_ids, val_ids = clean_identifiers, []
    # Fallback: move one training sample to validation if validation came out empty.
    if not val_ids and len(train_ids) > 1:
        val_ids = [train_ids.pop()]
    print(f"Training IDs: {len(train_ids)}, Validation IDs: {len(val_ids)}")

    # Arguments common to the training and validation datasets.
    _shared_dataset_kwargs = dict(
        patches_dir=PATCHES_DIR,
        targets_dir=TARGETS_DIR,
        target_img_size=IMG_SIZE,
        raw_band_to_idx_map=RAW_BAND_TO_IDX_MAP,
        satlas_input_channel_names=SATLAS_INPUT_CHANNEL_NAMES,
        satlas_channel_to_source_raw_band_map=SATLAS_CHANNEL_TO_SOURCE_RAW_BAND_MAP,
        normalization_divisor=NORMALIZATION_DIVISOR,
    )

    if train_ids:
        # The training set receives the full RandAugment configuration.
        _train_augmentation_kwargs = dict(
            augmentations=APPLY_AUGMENTATIONS_TRAIN_SET,
            num_rand_augment_ops=NUM_RAND_AUGMENT_OPS,
            rand_aug_h_flip_enable=RAND_AUG_H_FLIP_ENABLE,
            rand_aug_v_flip_enable=RAND_AUG_V_FLIP_ENABLE,
            rand_aug_rotation_enable=RAND_AUG_ROTATION_ENABLE,
            rand_aug_rotation_degrees_options=RAND_AUG_ROTATION_DEGREES_OPTIONS,
            rand_aug_brightness_enable=RAND_AUG_BRIGHTNESS_ENABLE,
            rand_aug_brightness_range=RAND_AUG_BRIGHTNESS_RANGE,
            rand_aug_contrast_enable=RAND_AUG_CONTRAST_ENABLE,
            rand_aug_contrast_range=RAND_AUG_CONTRAST_RANGE,
            rand_aug_gamma_enable=RAND_AUG_GAMMA_ENABLE,
            rand_aug_gamma_range=RAND_AUG_GAMMA_RANGE,
            rand_aug_gaussian_noise_enable=RAND_AUG_GAUSSIAN_NOISE_ENABLE,
            rand_aug_gaussian_noise_std=RAND_AUG_GAUSSIAN_NOISE_STD,
        )
        train_ds = CerealSensingDataset(file_identifiers=train_ids,
                                        **_shared_dataset_kwargs,
                                        **_train_augmentation_kwargs)
        train_loader = DataLoader(train_ds, BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=False)
        print(f"Train DataLoader: {len(train_loader.dataset)} samples. Augmentations: {APPLY_AUGMENTATIONS_TRAIN_SET}, N_ops={NUM_RAND_AUGMENT_OPS if APPLY_AUGMENTATIONS_TRAIN_SET else 0}")

    if val_ids:
        # IMPORTANT: the validation set is never augmented.
        val_ds = CerealSensingDataset(file_identifiers=val_ids,
                                      **_shared_dataset_kwargs,
                                      augmentations=False)
        val_loader = DataLoader(val_ds, BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=False)
        print(f"Val DataLoader: {len(val_loader.dataset)} samples. Augmentations: False")
    elif train_ids:
        print("WARNING: Validation dataset is empty (but training data exists).")
    else:
        print("WARNING: Training dataset is empty, so validation dataset is also empty.")
else:
    print("ERROR: No clean patches to create datasets.")
# Quick Visualisation (Sanity Check - allenai/satlaspretrain_models V16)
# Shows the first training sample (first three channels as an RGB-like view)
# next to its mask and prints per-channel statistics. Any failure is reported
# and does not stop the script.
if train_ds and len(train_ds) > 0:
    try:
        print("\n--- Visualizing a sample from TRAIN_DS (allenai/satlaspretrain_models V16, TCI view, Geometric Aug) ---") # Updated print
        sample_idx = 0
        # Indexing goes through __getitem__, so augmentations are applied if enabled.
        sample_img_tensor, sample_mask_tensor = train_ds[sample_idx]
        print(f"Sample Img Shape: {sample_img_tensor.shape}, Mask Shape: {sample_mask_tensor.shape}, Unique Mask: {torch.unique(sample_mask_tensor)}")

        # Channels 0..2 are taken as the displayable view.
        tci_view_tensor = sample_img_tensor[:3, :, :].cpu()
        print(f"TCI View (Ch 0,1,2 - SR data / {NORMALIZATION_DIVISOR}) - Min: {tci_view_tensor.min():.4f}, Max: {tci_view_tensor.max():.4f}, Mean: {tci_view_tensor.mean():.4f}")

        # Brighten for display only: channels-last layout, gain, then clip to [0, 1].
        GAIN_FOR_VISUALIZATION = 3.5
        img_for_vis_gained = np.transpose(tci_view_tensor.numpy(), (1, 2, 0)) * GAIN_FOR_VISUALIZATION
        img_display_np = np.clip(img_for_vis_gained, 0, 1)

        fig, axs = plt.subplots(1, 2, figsize=(10, 5))
        axs[0].imshow(img_display_np)
        axs[0].set_title(f'Input ID: {train_ds.file_identifiers[sample_idx]} (TCI, Vis Gain x{GAIN_FOR_VISUALIZATION})\nAugmented Sample')
        axs[0].axis('off')
        axs[1].imshow(sample_mask_tensor.cpu().numpy(), cmap='gray')
        axs[1].set_title('Mask (Augmented)')
        axs[1].axis('off')
        plt.tight_layout()
        plt.show()

        print("\n--- Ranges for ALL 9 Normalized Channels in Sample (all divided by 8160): ---")
        for i, ch_name_vis in enumerate(SATLAS_INPUT_CHANNEL_NAMES):
            ch_data = sample_img_tensor[i, :, :].cpu()
            source_raw_band = SATLAS_CHANNEL_TO_SOURCE_RAW_BAND_MAP[ch_name_vis]
            print(f"  Channel {i} ('{ch_name_vis}' from '{source_raw_band}'): Min={ch_data.min():.6f}, Max={ch_data.max():.6f}, Mean={ch_data.mean():.6f}")

    except Exception as e:
        print(f"Visualization error: {e}")
        traceback.print_exc()
else:
    print("Training dataset empty. Skipping visualization.")

# Load Satlas Model using Official `weights_manager` (allenai/satlaspretrain_models - V16)
# On any failure the error is printed and satlas_full_segmentation_model stays None,
# which later code uses as the "model not available" signal.
print(f"\n--- Loading Satlas Model via weights_manager: {SATLAS_MODEL_IDENTIFIER} ---")
satlas_full_segmentation_model = None
try:
    weights_manager = satlaspretrain_models.Weights()

    # Pretrained backbone + FPN with a segmentation head sized for NUM_CLASSES.
    _pretrained_model_kwargs = dict(
        model_identifier=SATLAS_MODEL_IDENTIFIER,
        fpn=True,
        head=satlaspretrain_models.Head.SEGMENT,
        num_categories=NUM_CLASSES,
    )
    satlas_full_segmentation_model = weights_manager.get_pretrained_model(**_pretrained_model_kwargs)
    print(f"Successfully loaded Satlas model with segmentation head: {SATLAS_MODEL_IDENTIFIER}")

except Exception as e:
    print(f"Error loading Satlas model via weights_manager: {e}")
    traceback.print_exc()
    satlas_full_segmentation_model = None

# --- Horizontal Line ---
class LitSegmentationModel(LightningModule):
    """Lightning wrapper that trains a segmentation model with cross-entropy.

    Args:
        full_satlas_model: the complete model (called as ``model(x)``); excluded
            from the saved hyperparameters.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        num_classes: number of segmentation classes.
        class_names: optional names used to suffix per-class metric keys;
            defaults to ``class_0``, ``class_1``, ...
        freeze_backbone_fpn: if true, parameters of ``model.backbone`` and
            ``model.fpn`` (when present) are frozen.

    Raises:
        ValueError: if fewer class names than ``num_classes`` are supplied.

    Logged keys: ``train_loss``, ``val_loss``, the ``val_``-prefixed overall
    metrics, ``<metric>_<class name>`` for per-class metrics, and the aliases
    ``val_iou_cereal`` / ``val_f1_cereal`` for a class literally named "cereal".
    """

    # (metric key, class name) -> short alias shown on the progress bar.
    _PROG_BAR_ALIASES = {
        ("val_jaccard_index_per_class", "cereal"): "val_iou_cereal",
        ("val_f1_score_per_class", "cereal"): "val_f1_cereal",
    }
    # Overall metrics that are shown on the progress bar.
    _PROG_BAR_OVERALL = ("val_f1_score_macro", "val_jaccard_index_macro")

    def __init__(self, full_satlas_model, learning_rate, weight_decay, num_classes, class_names=None, freeze_backbone_fpn=False):
        super().__init__()
        self.save_hyperparameters(ignore=['full_satlas_model'])
        self.model = full_satlas_model

        if freeze_backbone_fpn:
            self._freeze_backbone_and_fpn()

        self.class_weights_tensor = None # You might want to compute and set this
        self.class_names = class_names if class_names else [f"class_{i}" for i in range(num_classes)]
        # Fail at construction instead of with an IndexError at the end of the
        # first validation epoch, where the names are indexed per class.
        if len(self.class_names) < num_classes:
            raise ValueError(
                f"class_names has {len(self.class_names)} entries but num_classes is {num_classes}."
            )

        metrics = MetricCollection({
            'accuracy_overall': Accuracy(task="multiclass", num_classes=num_classes, average='micro'),
            'f1_score_macro': F1Score(task="multiclass", num_classes=num_classes, average='macro'),
            'jaccard_index_macro': JaccardIndex(task="multiclass", num_classes=num_classes, average='macro'),
            'accuracy_per_class': Accuracy(task="multiclass", num_classes=num_classes, average='none'),
            'jaccard_index_per_class': JaccardIndex(task="multiclass", num_classes=num_classes, average='none'),
            'f1_score_per_class': F1Score(task="multiclass", num_classes=num_classes, average='none')
        })
        self.val_metrics = metrics.clone(prefix='val_')

    def _freeze_backbone_and_fpn(self):
        """Disable gradients for ``model.backbone`` / ``model.fpn`` where those attributes exist."""
        frozen_count = 0
        if hasattr(self.model, 'backbone'):
            for param in self.model.backbone.parameters():
                param.requires_grad = False
            print("Satlas backbone parameters frozen.")
            frozen_count += 1
        if hasattr(self.model, 'fpn'):
            for param in self.model.fpn.parameters():
                param.requires_grad = False
            print("Satlas FPN parameters frozen.")
            frozen_count += 1
        if frozen_count == 0:
            print("WARNING: Could not find 'backbone' or 'fpn' attributes directly on self.model to freeze.")
        if hasattr(self.model, 'head'):
            print("Segmentation head (self.model.head) parameters are expected to be trainable.")

    def forward(self, x):
        """Run the wrapped model; if it returns a tuple, only its first element is used."""
        model_output = self.model(x)
        # NOTE(review): the first element is fed to F.cross_entropy, which expects
        # unnormalized logits - confirm the segmentation head does not already
        # apply a softmax to its output.
        if isinstance(model_output, tuple):
            return model_output[0]
        return model_output

    def _cross_entropy(self, logits, targets):
        """Cross-entropy loss, weighted by ``class_weights_tensor`` when one is set."""
        weight = None
        if self.class_weights_tensor is not None:
            weight = self.class_weights_tensor.to(logits.device)
        return F.cross_entropy(logits, targets, weight=weight)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self._cross_entropy(self.forward(x), y)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = self._cross_entropy(logits, y)
        self.log("val_loss", loss, prog_bar=True, logger=True, sync_dist=True)
        # Metrics are accumulated here and computed once per epoch.
        self.val_metrics.update(torch.argmax(logits, dim=1), y)
        return loss

    def on_validation_epoch_end(self):
        """Compute, log and reset the accumulated validation metrics."""
        num_classes = self.hparams.num_classes
        for name, value in self.val_metrics.compute().items():
            # A tensor with one entry per class is treated as a per-class metric.
            if isinstance(value, torch.Tensor) and value.numel() == num_classes:
                for i in range(num_classes):
                    class_name = self.class_names[i]
                    alias = self._PROG_BAR_ALIASES.get((name, class_name))
                    if alias is not None:
                        # Short alias: progress bar only, not written by the logger.
                        self.log(alias, value[i], prog_bar=True, logger=False, sync_dist=True)
                    # Full per-class key, e.g. val_accuracy_per_class_cereal: logger only.
                    self.log(f"{name}_{class_name}", value[i], prog_bar=False, logger=True, sync_dist=True)
            else:
                self.log(name, value, prog_bar=name in self._PROG_BAR_OVERALL, logger=True, sync_dist=True)
        self.val_metrics.reset()

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)

# Build the LightningModule around the loaded Satlas model.
# Requires satlas_full_segmentation_model to have been loaded successfully;
# LR, WEIGHT_DECAY, NUM_CLASSES and FREEZE_BACKBONE_FPN come from the config section.
CLASS_NAMES_FOR_LOGGING = ["non_cereal", "cereal"] # Ensure this matches your classes
lightning_model = None # Reset before re-instantiation

if 'satlas_full_segmentation_model' not in locals() or satlas_full_segmentation_model is None:
    print("ERROR: satlas_full_segmentation_model not loaded (from Cell 6). Cannot instantiate LitSegmentationModel.")
else:
    _lit_model_kwargs = dict(
        full_satlas_model=satlas_full_segmentation_model,
        learning_rate=LR,
        weight_decay=WEIGHT_DECAY,
        num_classes=NUM_CLASSES,
        class_names=CLASS_NAMES_FOR_LOGGING,
        freeze_backbone_fpn=FREEZE_BACKBONE_FPN,
    )
    lightning_model = LitSegmentationModel(**_lit_model_kwargs)
    print("LitSegmentationModel re-instantiated with per-class accuracy.")
# Training (Simplified for Single GPU - V16)
# --- CUDA Initialization Check ---
# Purely informational: reports whether CUDA was initialized before the Trainer is created.
if DEVICE == "cuda" and torch.cuda.is_available():
    if torch.cuda.is_initialized():
        print("WARNING FROM DEBUG: CUDA IS ALREADY INITIALIZED before Trainer instantiation!")
    else:
        print("INFO FROM DEBUG: CUDA is NOT yet initialized before Trainer instantiation. This is ideal.")
# --- End of debug check ---

csv_logger = CSVLogger(LOGS_DIR_ACTUAL, name="cereal_seg_satlas_singleGPU_v16_geom_aug_run") # Updated logger name
print(f"Logs: {csv_logger.log_dir}")

# One checkpoint for the lowest validation loss, one for the highest "cereal" IoU.
loss_cb = ModelCheckpoint(dirpath=CHECKPOINT_SAVE_DIR_ACTUAL,
                          filename='model_singleGPU_geom_loss-{epoch:02d}-{val_loss:.3f}',
                          monitor='val_loss', mode='min', save_top_k=1, verbose=True)
iou_cb = ModelCheckpoint(dirpath=CHECKPOINT_SAVE_DIR_ACTUAL,
                         filename='model_singleGPU_geom_iou_cereal-{epoch:02d}-{val_iou_cereal:.3f}',
                         monitor='val_iou_cereal', mode='max', save_top_k=1, verbose=True)
# Loss-like monitors are minimized, everything else is maximized.
early_stop_cb = EarlyStopping(monitor=EARLY_STOPPING_MONITOR, patience=EARLY_STOPPING_PATIENCE, verbose=True,
                              mode=("min" if "loss" in EARLY_STOPPING_MONITOR else "max"))

if lightning_model and train_loader:
    trainer_accelerator = DEVICE
    trainer_devices_setting = 1
    trainer_strategy = "auto"

    if DEVICE == "cuda" and torch.cuda.is_available():
        print(f"INFO: CUDA is available. Configuring for single GPU training on the first GPU.")
        trainer_devices_setting = [0] # Explicitly use the first GPU (index 0)
    elif DEVICE == "cpu":
        print(f"INFO: Configuring for CPU training.")
        trainer_accelerator = "cpu"
    else:
        print(f"WARNING: DEVICE='{DEVICE}' but CUDA not actually available. Forcing CPU training.")
        trainer_accelerator = "cpu"
        trainer_devices_setting = 1

    # Validation is used only when a non-empty validation loader exists.
    val_dl_arg = val_loader if val_loader and hasattr(val_loader, 'dataset') and len(val_loader.dataset) > 0 else None

    # Decide the callback list BEFORE constructing the Trainer: callbacks that
    # monitor a "val_*" metric cannot work without validation data. Filtering up
    # front avoids mutating trainer.callbacks afterwards and guarantees the
    # Trainer always receives a list (never None).
    cbs_to_use = [early_stop_cb, loss_cb, iou_cb]
    if val_dl_arg is None:
        print("No val_loader, adjusting callbacks that monitor validation metrics.")
        kept_callbacks = []
        for cb_instance in cbs_to_use:
            monitor_attr = getattr(cb_instance, 'monitor', None)
            if monitor_attr and monitor_attr.startswith("val_"):
                print(f"Removing callback: {cb_instance.__class__.__name__} on {monitor_attr}")
            else:
                kept_callbacks.append(cb_instance)
        cbs_to_use = kept_callbacks

    trainer = Trainer(accelerator=trainer_accelerator,
                      devices=trainer_devices_setting,
                      strategy=trainer_strategy,
                      max_epochs=EPOCHS,
                      logger=csv_logger,
                      callbacks=cbs_to_use,
                      gradient_clip_val=0.5)

    print(f"Trainer initialized: accelerator='{trainer_accelerator}', devices={trainer_devices_setting}, strategy='{trainer_strategy}'")
    print(f"Starting training: {SATLAS_MODEL_IDENTIFIER} (all bands / {NORMALIZATION_DIVISOR}, Geometric Aug)...")
    try:
        trainer.fit(lightning_model, train_dataloaders=train_loader, val_dataloaders=val_dl_arg)
        print("Training finished.")
        # best_model_path is empty when a checkpoint callback never saved anything.
        if loss_cb.best_model_path:
            print(f"Best loss model: {loss_cb.best_model_path}")
        if iou_cb.best_model_path:
            print(f"Best IoU Cereal model: {iou_cb.best_model_path}")
    except Exception as e:
        # Top-level boundary: report the failure with its traceback instead of crashing.
        print(f"Training error: {e}")
        traceback.print_exc()
else:
    print("ERROR: Training prerequisites not met (lightning_model or train_loader is None).")