# In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from torchmetrics import Accuracy, Precision, Recall, F1Score, JaccardIndex
from torchmetrics.collections import MetricCollection
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import random
from terratorch.tasks import SemanticSegmentationTask
import rasterio
import glob
import re
import traceback

# ---------------------------------------------------------------------------
# Filesystem layout
# ---------------------------------------------------------------------------
# Dataset root; image patches and masks live in two sub-folders of it.
BASE_DATA_DIR_ACTUAL = "CerealDataBetter"
# Run artefacts (checkpoints, CSV metrics) go below the current working directory.
CHECKPOINT_SAVE_DIR_ACTUAL = os.path.join(os.getcwd(), "checkpoints_auto_augment")
LOGS_DIR_ACTUAL = os.path.join(os.getcwd(), "metrics_auto_augment")
PATCHES_DIR = os.path.join(BASE_DATA_DIR_ACTUAL, "GEE_WeeklyPatches")
TARGETS_DIR = os.path.join(BASE_DATA_DIR_ACTUAL, "GEE_WeeklyMasks")

# Create the output folders up front; a no-op when they already exist.
for _output_dir in (CHECKPOINT_SAVE_DIR_ACTUAL, LOGS_DIR_ACTUAL):
    os.makedirs(_output_dir, exist_ok=True)

# ---------------------------------------------------------------------------
# Model / optimisation hyper-parameters
# ---------------------------------------------------------------------------
# Band names handed to the backbone as ``backbone_bands``.
EXPECTED_BANDS_ORDER = ["B02", "B03", "B04", "B8A", "B11", "B12"]
NUM_CLASSES = 2
IMG_SIZE = 256  # side length (pixels) every patch and mask is resized to
BATCH_SIZE = 4
EPOCHS = 100  # upper bound; early stopping may end training sooner
LR = 1e-5
WEIGHT_DECAY = 0.01
HEAD_DROPOUT = 0.2
FREEZE_BACKBONE = False

# ---------------------------------------------------------------------------
# Augmentation settings (used for the training dataset only)
# ---------------------------------------------------------------------------
RAND_AUGMENT_N = 2  # number of distinct augmentation ops drawn per sample
ROTATION_DEGREES = 180  # rotation angle is drawn from [-180, 180]
BRIGHTNESS_RANGE = (0.9, 1.1)
CONTRAST_RANGE = (0.9, 1.1)
GAMMA_RANGE = (0.9, 1.1)
GAUSSIAN_NOISE_STD = 0.01

# ---------------------------------------------------------------------------
# Training control
# ---------------------------------------------------------------------------
EARLY_STOPPING_PATIENCE = 15
EARLY_STOPPING_MONITOR = "val_loss"
# Prefer the GPU whenever torch can see one.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class CerealSensingDataset(Dataset):
    """Dataset of 6-band image patches paired with single-band segmentation masks.

    A sample is addressed by a string identifier. Its image is read from
    ``<patches_dir>/<identifier>_S2.tif`` and its mask (band 1 only) from
    ``<targets_dir>/<identifier>_Mask.tif``. Both are resized to a square of
    ``target_img_size`` pixels: bilinear for the image, nearest-neighbour for
    the mask.

    When ``augmentations`` is true and ``rand_augment_n > 0``, that many
    distinct operations are sampled from ``augmentation_pool`` for every item
    and applied in the sampled order.

    Args:
        file_identifiers: Sequence of sample identifiers.
        patches_dir: Directory containing the ``*_S2.tif`` image rasters.
        targets_dir: Directory containing the ``*_Mask.tif`` mask rasters.
        target_img_size: Side length of the returned square tensors.
        augmentations: Master switch for random augmentation.
        rand_augment_n: Number of augmentation ops applied per sample (capped
            at the pool size).
        rotation_degrees: Rotation angle is drawn from
            ``[-rotation_degrees, rotation_degrees]``; ``0`` disables rotation.
        brightness_range: ``(low, high)`` range of the brightness factor.
        contrast_range: ``(low, high)`` range of the contrast factor.
        gamma_range: ``(low, high)`` range of the gamma exponent.
        gaussian_noise_std: Standard deviation of additive Gaussian noise;
            the noise op joins the pool only when this is ``> 0``.
    """

    # Number of bands every image raster must contain.
    EXPECTED_NUM_BANDS = 6

    def __init__(self, file_identifiers, patches_dir, targets_dir, target_img_size,
                 augmentations=False,
                 rand_augment_n=0,
                 rotation_degrees=0,
                 brightness_range=(1.0, 1.0),
                 contrast_range=(1.0, 1.0),
                 gamma_range=(1.0, 1.0),
                 gaussian_noise_std=0.0):
        self.file_identifiers = file_identifiers
        self.patches_dir = patches_dir
        self.targets_dir = targets_dir
        self.target_img_size = target_img_size
        self.augmentations = augmentations
        self.rand_augment_n = rand_augment_n
        self.rotation_degrees = rotation_degrees
        self.brightness_range = brightness_range
        self.contrast_range = contrast_range
        self.gamma_range = gamma_range
        self.gaussian_noise_std = gaussian_noise_std
        self.augmentation_pool = [
            self._aug_hflip,
            self._aug_vflip,
            self._aug_rotate,
            self._aug_brightness_single_channel,
            self._aug_contrast_single_channel,
            self._aug_gamma_single_channel,
        ]
        if self.gaussian_noise_std > 0:
            self.augmentation_pool.append(self._aug_gaussian_noise_all_channels)

    # ------------------------------------------------------------------
    # Geometric augmentations (applied to image and mask alike)
    # ------------------------------------------------------------------
    def _aug_hflip(self, image_tensor, mask_tensor_float):
        """Flip image and mask horizontally."""
        return TF.hflip(image_tensor), TF.hflip(mask_tensor_float)

    def _aug_vflip(self, image_tensor, mask_tensor_float):
        """Flip image and mask vertically."""
        return TF.vflip(image_tensor), TF.vflip(mask_tensor_float)

    def _aug_rotate(self, image_tensor, mask_tensor_float):
        """Rotate image (bilinear) and mask (nearest) by one shared random angle."""
        if self.rotation_degrees > 0:
            angle = random.uniform(-self.rotation_degrees, self.rotation_degrees)
            image_tensor = TF.rotate(image_tensor, angle, interpolation=TF.InterpolationMode.BILINEAR)
            # TF.rotate needs a leading channel dimension, so add one around the 2-D mask.
            mask_tensor_float = TF.rotate(
                mask_tensor_float.unsqueeze(0), angle, interpolation=TF.InterpolationMode.NEAREST
            ).squeeze(0)
        return image_tensor, mask_tensor_float

    # ------------------------------------------------------------------
    # Photometric augmentations (image only, one random channel)
    # ------------------------------------------------------------------
    @staticmethod
    def _clamp_to_value_range(reference, adjusted):
        """Clamp ``adjusted`` to ``[min(0, reference.min()), max(1, reference.max())]``.

        torchvision's colour ops clamp float tensors to ``[0, 1]``, which would
        flatten a band stored on a larger scale. Widening the bounds to the
        channel's own range keeps such data intact while behaving exactly like
        the ``[0, 1]`` clamp for inputs that already lie in ``[0, 1]``.
        """
        if reference.numel() == 0:
            return adjusted
        lower = min(0.0, float(reference.min()))
        upper = max(1.0, float(reference.max()))
        return adjusted.clamp(lower, upper)

    @staticmethod
    def _with_channel_replaced(image_tensor, channel_idx, new_channel):
        """Return a copy of ``image_tensor`` with one channel swapped out.

        Working on a copy keeps the caller's tensor untouched.
        """
        result = image_tensor.clone()
        result[channel_idx:channel_idx + 1] = new_channel
        return result

    def _aug_brightness_single_channel(self, image_tensor, mask_tensor_float):
        """Scale one random channel of a float image by a random brightness factor."""
        num_channels = image_tensor.shape[0]
        channel_to_aug = random.randint(0, num_channels - 1)
        factor = random.uniform(self.brightness_range[0], self.brightness_range[1])
        channel = image_tensor[channel_to_aug:channel_to_aug + 1]
        adjusted = self._clamp_to_value_range(channel, channel * factor)
        return self._with_channel_replaced(image_tensor, channel_to_aug, adjusted), mask_tensor_float

    def _aug_contrast_single_channel(self, image_tensor, mask_tensor_float):
        """Blend one random channel of a float image with its own mean.

        A factor of ``1`` leaves the channel unchanged, ``0`` replaces it by
        its mean value.
        """
        num_channels = image_tensor.shape[0]
        channel_to_aug = random.randint(0, num_channels - 1)
        factor = random.uniform(self.contrast_range[0], self.contrast_range[1])
        channel = image_tensor[channel_to_aug:channel_to_aug + 1]
        mean = channel.mean(dim=(-3, -2, -1), keepdim=True)
        adjusted = self._clamp_to_value_range(channel, factor * channel + (1.0 - factor) * mean)
        return self._with_channel_replaced(image_tensor, channel_to_aug, adjusted), mask_tensor_float

    def _aug_gamma_single_channel(self, image_tensor, mask_tensor_float):
        """Apply a random gamma curve to one random channel of a float image.

        Negative values are clamped to zero first. The exponent is applied to
        the channel divided by ``max(1, channel.max())`` and the result is
        scaled back, so the curve keeps the channel within its own range; for
        data already in ``[0, 1]`` this reduces to ``x ** gamma``.
        """
        num_channels = image_tensor.shape[0]
        channel_to_aug = random.randint(0, num_channels - 1)
        gamma = random.uniform(self.gamma_range[0], self.gamma_range[1])
        gamma = max(0.1, gamma)  # guard against degenerate exponents
        channel_non_neg = torch.clamp(image_tensor[channel_to_aug:channel_to_aug + 1], min=0)
        scale = max(1.0, float(channel_non_neg.max())) if channel_non_neg.numel() else 1.0
        adjusted = ((channel_non_neg / scale) ** gamma * scale).clamp(0, scale)
        return self._with_channel_replaced(image_tensor, channel_to_aug, adjusted), mask_tensor_float

    def _aug_gaussian_noise_all_channels(self, image_tensor, mask_tensor_float):
        """Add zero-mean Gaussian noise with std ``gaussian_noise_std`` to every channel."""
        if self.gaussian_noise_std > 0:
            noise = torch.randn_like(image_tensor) * self.gaussian_noise_std
            image_tensor = image_tensor + noise
        return image_tensor, mask_tensor_float

    # ------------------------------------------------------------------
    # Dataset protocol
    # ------------------------------------------------------------------
    def __len__(self):
        """Return the number of samples."""
        return len(self.file_identifiers)

    def __getitem__(self, idx):
        """Load, resize and optionally augment the sample at position ``idx``.

        Returns:
            ``(image, label)`` where ``image`` is a float32 tensor of shape
            ``(6, S, S)`` and ``label`` an int64 tensor of shape ``(S, S)``,
            with ``S = target_img_size``.

        Raises:
            ValueError: If the image does not have 6 bands or the final
                tensors have unexpected shapes.
            Exception: Any I/O or processing error is printed with its
                traceback and re-raised unchanged.
        """
        identifier = self.file_identifiers[idx]
        image_path = os.path.join(self.patches_dir, f"{identifier}_S2.tif")
        label_path = os.path.join(self.targets_dir, f"{identifier}_Mask.tif")
        target_size = (self.target_img_size, self.target_img_size)

        try:
            with rasterio.open(image_path) as src:
                image = src.read()
            if image.shape[0] != self.EXPECTED_NUM_BANDS:
                raise ValueError(f"Image {identifier} has {image.shape[0]} bands, expected 6.")
            with rasterio.open(label_path) as src:
                label = src.read(1)

            # Convert to float32 in numpy first so the result does not depend on
            # whether torch can ingest the raster's native dtype (e.g. uint16).
            image_tensor = torch.from_numpy(image.astype(np.float32))
            label_tensor_float = torch.from_numpy(label.astype(np.float32))

            # F.interpolate expects a batch dimension (and a channel one for the mask).
            image_tensor = F.interpolate(image_tensor.unsqueeze(0), size=target_size,
                                         mode='bilinear', align_corners=False).squeeze(0)
            label_tensor_float = F.interpolate(label_tensor_float.unsqueeze(0).unsqueeze(0),
                                               size=target_size, mode='nearest').squeeze(0).squeeze(0)

            if self.augmentations and self.rand_augment_n > 0 and self.augmentation_pool:
                num_ops_to_apply = min(self.rand_augment_n, len(self.augmentation_pool))
                # random.sample draws without replacement: no op is applied twice.
                for op_func in random.sample(self.augmentation_pool, num_ops_to_apply):
                    image_tensor, label_tensor_float = op_func(image_tensor, label_tensor_float)

            label_tensor = label_tensor_float.long()
            image_tensor = image_tensor.float()

            if image_tensor.shape != (self.EXPECTED_NUM_BANDS, *target_size) or \
               label_tensor.shape != target_size:
                raise ValueError(f"Shape mismatch for {identifier} post-processing. Img: {image_tensor.shape}, Lbl: {label_tensor.shape}")
            return image_tensor, label_tensor
        except Exception as e:
            print(f"Error loading/processing {identifier}: {e}")
            traceback.print_exc()
            raise

# Collect the identifier of every image patch that also has a mask on disk.
# Only file names of the form "Patch<digits>_W<digits>_<8 digits>_S2.tif" qualify.
image_files_pattern = os.path.join(PATCHES_DIR, "Patch*_S2.tif")
all_image_filepaths = sorted(glob.glob(image_files_pattern))

initial_file_identifiers = []
for img_path in all_image_filepaths:
    base_name = os.path.basename(img_path)
    identifier_match = re.match(r"(Patch\d+_W\d+_\d{8})_S2\.tif", base_name)
    if not identifier_match:
        continue  # file name does not follow the expected naming scheme
    identifier = identifier_match.group(1)
    mask_path = os.path.join(TARGETS_DIR, f"{identifier}_Mask.tif")
    if not os.path.exists(mask_path):
        continue  # an image without a mask cannot be used
    initial_file_identifiers.append(identifier)

def _band_is_uniform(band):
    """Return True when every pixel of a 2-D band equals its first pixel.

    Floating-point bands are compared with ``np.allclose`` (``atol=1e-5``),
    all other dtypes with exact equality.
    """
    first_value = band[0, 0]
    if np.issubdtype(band.dtype, np.floating):
        return bool(np.allclose(band, first_value, atol=1e-5))
    return bool(np.all(band == first_value))


def _patch_is_corrupted(image_data):
    """Return True when a ``(bands, H, W)`` array fails any sanity check.

    A patch is rejected when it is empty, contains NaN, has values outside
    ``[-32768, 65535]``, or when every band is constant.
    """
    if image_data.size == 0:
        return True
    if np.isnan(image_data).any():
        return True
    if image_data.max() > 65535 or image_data.min() < -32768:
        return True
    # all() stops at the first non-uniform band.
    return all(_band_is_uniform(band) for band in image_data)


# Sort the discovered patches into usable and unusable ones.
corrupted_identifiers, valid_identifiers = [], []
for identifier in initial_file_identifiers:
    image_path = os.path.join(PATCHES_DIR, f"{identifier}_S2.tif")
    try:
        with rasterio.open(image_path) as src:
            image_data = src.read()
        problem = _patch_is_corrupted(image_data)
    except Exception as e:
        # Best effort: an unreadable patch is excluded, but say why instead of
        # dropping it silently.
        print(f"Skipping {identifier}: could not read or validate patch ({e})")
        problem = True
    if problem:
        corrupted_identifiers.append(identifier)
    else:
        valid_identifiers.append(identifier)

clean_identifiers = valid_identifiers
num_clean_patches = len(clean_identifiers)

# Split the clean patches roughly 80/20 into train/validation and wrap them in
# DataLoaders. Everything stays None / empty when there is nothing to train on.
train_ds, val_ds, train_loader, val_loader = None, None, None, None
train_ids, val_ids = [], []

# Defined up front so both loaders can rely on it regardless of which one is built.
num_data_workers = 4

if num_clean_patches > 0:
    # NOTE(review): the shuffle is unseeded, so the split differs between runs.
    random.shuffle(clean_identifiers)
    if num_clean_patches == 1:
        # A single patch can only be used for training.
        train_ids, val_ids = clean_identifiers, []
    else:
        num_train = max(1, int(num_clean_patches * 0.8))
        train_ids, val_ids = clean_identifiers[:num_train], clean_identifiers[num_train:]
        # Defensive: keep at least one validation sample when possible.
        if not val_ids and len(train_ids) > 1:
            val_ids = [train_ids.pop()]

# Arguments shared by both loaders.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=num_data_workers,
    pin_memory=DEVICE == "cuda",
    # persistent_workers is only valid together with worker processes.
    persistent_workers=num_data_workers > 0,
)

if train_ids:
    train_ds = CerealSensingDataset(
        train_ids, PATCHES_DIR, TARGETS_DIR, IMG_SIZE,
        augmentations=True,
        rand_augment_n=RAND_AUGMENT_N,
        rotation_degrees=ROTATION_DEGREES,
        brightness_range=BRIGHTNESS_RANGE,
        contrast_range=CONTRAST_RANGE,
        gamma_range=GAMMA_RANGE,
        gaussian_noise_std=GAUSSIAN_NOISE_STD
    )
    train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
if val_ids:
    # Validation data is never augmented.
    val_ds = CerealSensingDataset(val_ids, PATCHES_DIR, TARGETS_DIR, IMG_SIZE, augmentations=False)
    val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

# Architecture description handed to terratorch's model factory.
model_args = dict(
    backbone="prithvi_eo_v2_300",
    backbone_pretrained=True,
    backbone_bands=EXPECTED_BANDS_ORDER,
    decoder="FCNDecoder",
    decoder_channels=256,
    head_dropout=HEAD_DROPOUT,
    num_classes=NUM_CLASSES,
    rescale=True,
)

# The task object is built here; only its ``.model`` attribute is used later on.
task_config_obj = SemanticSegmentationTask(
    model_args=model_args,
    model_factory="EncoderDecoderFactory",
    loss="ce",
    freeze_backbone=FREEZE_BACKBONE,
    lr=LR,
    optimizer="adamw",
    optimizer_hparams={"weight_decay": WEIGHT_DECAY},
)

class LitSegmentationModel(LightningModule):
    """Lightning wrapper that trains a segmentation network with cross-entropy.

    The wrapped ``model_arch`` is called with the image batch and must return
    an object exposing the class logits as ``.output``.

    Args:
        model_arch: The network to train (not stored in the hyper-parameters).
        learning_rate: Learning rate for AdamW.
        weight_decay: Weight decay for AdamW.
        num_classes: Number of segmentation classes.
        class_names: Optional per-class names used in metric keys; defaults to
            ``class0``, ``class1``, ...

    Raises:
        ValueError: If ``class_names`` has fewer entries than ``num_classes``.
    """

    # Scalar validation metrics that are also shown in the progress bar.
    PROG_BAR_SCALAR_METRICS = ("val_accuracy_overall", "val_f1_score_macro")
    # Per-class metrics whose "cereal" entry gets a short progress-bar alias.
    CEREAL_PROG_BAR_ALIASES = {
        "val_jaccard_index_per_class": "val_iou_cereal",
        "val_f1_score_per_class": "val_f1_cereal",
    }

    def __init__(self, model_arch, learning_rate, weight_decay, num_classes, class_names=None):
        super().__init__()
        self.save_hyperparameters(ignore=['model_arch'])
        self.model = model_arch
        # Optional per-class loss weights; None means unweighted cross-entropy.
        self.class_weights_tensor = None
        if class_names and len(class_names) < num_classes:
            # Fail here rather than with an IndexError at the first validation epoch.
            raise ValueError(
                f"class_names has {len(class_names)} entries but num_classes is {num_classes}"
            )
        self.class_names = class_names if class_names else [f"class{i}" for i in range(num_classes)]

        def _metric_kwargs(average):
            return dict(task="multiclass", num_classes=self.hparams.num_classes, average=average)

        metrics = MetricCollection({
            'accuracy_overall': Accuracy(**_metric_kwargs('micro')),
            'precision_macro': Precision(**_metric_kwargs('macro')),
            'recall_macro': Recall(**_metric_kwargs('macro')),
            'f1_score_macro': F1Score(**_metric_kwargs('macro')),
            'jaccard_index_per_class': JaccardIndex(**_metric_kwargs('none')),
            'accuracy_per_class': Accuracy(**_metric_kwargs('none')),
            'precision_per_class': Precision(**_metric_kwargs('none')),
            'recall_per_class': Recall(**_metric_kwargs('none')),
            'f1_score_per_class': F1Score(**_metric_kwargs('none')),
        })
        self.val_metrics = metrics.clone(prefix='val_')

    def forward(self, x):
        """Return the wrapped model's raw output for ``x``."""
        return self.model(x)

    def _logits_and_loss(self, batch):
        """Run the model on a batch and return ``(logits, targets, loss)``."""
        x, y = batch
        logits = self.model(x).output
        weights = self.class_weights_tensor.to(logits.device) if self.class_weights_tensor is not None else None
        return logits, y, F.cross_entropy(logits, y, weight=weights)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        _, _, loss = self._logits_and_loss(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accumulate the validation metrics."""
        logits, y, loss = self._logits_and_loss(batch)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        preds_classes = torch.argmax(logits, dim=1)
        self.val_metrics.update(preds_classes, y)
        return loss

    def on_validation_epoch_end(self):
        """Log all accumulated validation metrics, then reset them.

        Scalar metrics are logged under their own name. Metrics with one value
        per class are logged as ``<metric>_<class_name>``; for the class named
        ``"cereal"`` the IoU and F1 values are additionally shown in the
        progress bar under short aliases.
        """
        epoch_log_kwargs = dict(on_step=False, on_epoch=True, sync_dist=True)
        num_classes = self.hparams.num_classes

        for name, value in self.val_metrics.compute().items():
            is_per_class = isinstance(value, torch.Tensor) and value.numel() == num_classes
            if not is_per_class:
                self.log(name, value, prog_bar=name in self.PROG_BAR_SCALAR_METRICS,
                         logger=True, **epoch_log_kwargs)
                continue

            for i in range(num_classes):
                class_name = self.class_names[i]
                alias = self.CEREAL_PROG_BAR_ALIASES.get(name) if class_name == "cereal" else None
                if alias is not None:
                    # Progress-bar-only entry; the logger gets the full name below.
                    self.log(alias, value[i], prog_bar=True, logger=False, **epoch_log_kwargs)
                self.log(f"{name}_{class_name}", value[i], prog_bar=False, logger=True, **epoch_log_kwargs)
        self.val_metrics.reset()

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters of this module."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
        return optimizer

# Class names in label order; "cereal" is the name the model's validation
# logging looks for when choosing progress-bar entries.
CLASS_NAMES_FOR_LOGGING = ["non_cereal", "cereal"]

# Wrap the network only when the task object exists in this session.
_task_is_defined = 'task_config_obj' in locals() or 'task_config_obj' in globals()
lightning_model = (
    LitSegmentationModel(
        model_arch=task_config_obj.model,
        learning_rate=LR,
        weight_decay=WEIGHT_DECAY,
        num_classes=NUM_CLASSES,
        class_names=CLASS_NAMES_FOR_LOGGING,
    )
    if _task_is_defined
    else None
)

# Metrics are written as CSV below LOGS_DIR_ACTUAL.
csv_logger = CSVLogger(
    save_dir=LOGS_DIR_ACTUAL,
    name="cereal_segmentation_rand_aug_run",
)

# Keep the single best checkpoint by validation loss ...
loss_checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    verbose=True,
    dirpath=CHECKPOINT_SAVE_DIR_ACTUAL,
    filename='model-rand-aug-best-loss-{epoch:02d}-{val_loss:.3f}',
)
# ... and the single best checkpoint by IoU of the "cereal" class.
iou_cereal_checkpoint_callback = ModelCheckpoint(
    monitor='val_jaccard_index_per_class_cereal',
    mode='max',
    save_top_k=1,
    verbose=True,
    dirpath=CHECKPOINT_SAVE_DIR_ACTUAL,
    filename='model-rand-aug-best-iou-cereal-{epoch:02d}-{val_jaccard_index_per_class_cereal:.3f}',
)

# Loss-like monitors are minimised, everything else is maximised.
if "loss" in EARLY_STOPPING_MONITOR.lower():
    _early_stopping_mode = "min"
else:
    _early_stopping_mode = "max"
early_stop_callback = EarlyStopping(
    monitor=EARLY_STOPPING_MONITOR,
    patience=EARLY_STOPPING_PATIENCE,
    mode=_early_stopping_mode,
    verbose=True,
)

# Build the Trainer and run the fit loop, provided a model and training data exist.
# Explicit None checks avoid relying on DataLoader truthiness (its __len__).
if lightning_model is not None and train_loader is not None:
    # get_ipython only exists inside IPython; a ZMQ shell means a Jupyter kernel.
    try:
        is_interactive = get_ipython().__class__.__name__ == 'ZMQInteractiveShell'
    except NameError:
        is_interactive = False

    # Multi-GPU runs need a DDP flavour that matches the execution environment.
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        strategy_to_use = "ddp_notebook" if is_interactive else "ddp_find_unused_parameters_true"
    else:
        strategy_to_use = "auto"

    has_validation = val_loader is not None and len(val_loader.dataset) > 0
    val_dataloaders_arg = val_loader if has_validation else None

    # Decide on the callbacks BEFORE constructing the Trainer: assigning
    # trainer.callbacks afterwards would also discard the callbacks Lightning
    # installs itself. Without validation data no "val_*" metric is ever
    # logged, so every callback monitoring one has to be left out.
    callbacks_list = [early_stop_callback, loss_checkpoint_callback, iou_cereal_checkpoint_callback]
    if not has_validation:
        callbacks_list = [
            c for c in callbacks_list
            if not (getattr(c, "monitor", None) or "").startswith("val_")
        ]

    trainer = Trainer(
        accelerator="gpu" if DEVICE == "cuda" else "cpu",
        devices="auto" if DEVICE == "cuda" else 1,
        strategy=strategy_to_use,
        num_nodes=1, max_epochs=EPOCHS, log_every_n_steps=10, logger=csv_logger,
        callbacks=callbacks_list,
        gradient_clip_val=0.5,
    )
    try:
        trainer.fit(lightning_model, train_dataloaders=train_loader, val_dataloaders=val_dataloaders_arg)
    except Exception as e:
        # Report the failure without raising, so the session keeps running.
        print(f"An error occurred during training: {e}")
        traceback.print_exc()