# In [None]:
# Imports
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from torchmetrics import Accuracy, Precision, Recall, F1Score, JaccardIndex
from torchmetrics.collections import MetricCollection
import torch.nn.functional as F
import random
from terratorch.tasks import SemanticSegmentationTask
import rasterio
import glob
import re
import traceback
import sys


# Configuration
# Root folders (relative to the working directory) for input data, checkpoints and CSV logs.
BASE_DATA_DIR_ACTUAL = "CerealDataset"
CHECKPOINT_SAVE_DIR_ACTUAL = "checkpoints_baseline"
LOGS_DIR_ACTUAL = "logs_baseline"

# Image patches and their segmentation masks live in sibling sub-folders of the data root.
PATCHES_DIR = os.path.join(BASE_DATA_DIR_ACTUAL, "GEE_WeeklyPatches")
TARGETS_DIR = os.path.join(BASE_DATA_DIR_ACTUAL, "GEE_WeeklyMasks")

# Output folders are created eagerly so later callbacks/loggers can write into them.
os.makedirs(CHECKPOINT_SAVE_DIR_ACTUAL, exist_ok=True)
os.makedirs(LOGS_DIR_ACTUAL, exist_ok=True)

print(f"Data will be loaded from: {BASE_DATA_DIR_ACTUAL}")
print(f"Checkpoints will be saved to: {CHECKPOINT_SAVE_DIR_ACTUAL}")
print(f"CSV Logs will be saved to: {LOGS_DIR_ACTUAL}")

# Missing input folders are only warned about here; the data-loading step reports the consequences.
if not os.path.isdir(PATCHES_DIR):
    print(f"WARNING: Patches directory not found at {PATCHES_DIR}")
if not os.path.isdir(TARGETS_DIR):
    print(f"WARNING: Masks directory not found at {TARGETS_DIR}")

# Band order handed to the backbone; the dataset expects one raster band per entry.
EXPECTED_BANDS_ORDER = ["B02", "B03", "B04", "B8A", "B11", "B12"]
NUM_CLASSES = 2
IMG_SIZE = 256  # patches and masks are resized to IMG_SIZE x IMG_SIZE

BATCH_SIZE = 4
EPOCHS = 100
LR = 1e-5
WEIGHT_DECAY = 0.01
HEAD_DROPOUT = 0.2
FREEZE_BACKBONE = False

EARLY_STOPPING_PATIENCE = 10
EARLY_STOPPING_MONITOR = "val_loss"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Report the torch version and whether CUDA is usable (the value printed is CUDA availability).
print(f"PyTorch version: {torch.__version__}; CUDA available: {torch.cuda.is_available()}")
if DEVICE == "cuda":
    print(f"CUDA device count: {torch.cuda.device_count()}")


# Dataset Definition
class CerealSensingDataset(Dataset):
    """Paired image/mask GeoTIFF dataset for binary cereal segmentation.

    Identifier ``X`` maps to ``<patches_dir>/X_S2.tif`` (all bands are read)
    and ``<targets_dir>/X_Mask.tif`` (only band 1 is read). Both rasters are
    resized to ``target_img_size`` x ``target_img_size``. The image uses
    bilinear interpolation. The mask uses nearest-neighbour interpolation so
    class ids are never blended.

    Args:
        file_identifiers: Sequence of file stems without the ``_S2``/``_Mask`` suffix.
        patches_dir: Directory holding the ``*_S2.tif`` images.
        targets_dir: Directory holding the ``*_Mask.tif`` masks.
        target_img_size: Output height and width in pixels.
        num_bands: Number of bands every image must have (default 6).
    """

    def __init__(self, file_identifiers, patches_dir, targets_dir, target_img_size, num_bands=6):
        self.file_identifiers = file_identifiers
        self.patches_dir = patches_dir
        self.targets_dir = targets_dir
        self.target_img_size = target_img_size
        self.num_bands = num_bands

    def __len__(self):
        return len(self.file_identifiers)

    @staticmethod
    def _prepare_tensors(image, label, target_img_size):
        """Convert raw raster arrays into resized ``(image float32, label int64)`` tensors.

        ``image`` is indexed as (bands, H, W) and ``label`` as (H, W).
        """
        size = (target_img_size, target_img_size)
        # Cast in NumPy first: torch cannot build tensors directly from some raster
        # dtypes (e.g. uint16 on older releases), and interpolation needs floats anyway.
        image_tensor = torch.tensor(np.asarray(image, dtype=np.float32))
        label_tensor = torch.tensor(np.asarray(label, dtype=np.float32))

        image_tensor = F.interpolate(
            image_tensor.unsqueeze(0), size=size, mode="bilinear", align_corners=False
        ).squeeze(0)
        # Nearest-neighbour keeps mask values as the original discrete class ids.
        label_tensor = F.interpolate(label_tensor[None, None], size=size, mode="nearest")[0, 0]
        return image_tensor, label_tensor.long()

    def __getitem__(self, idx):
        """Load, validate and resize sample ``idx``.

        Returns:
            ``(image, label)`` with shapes ``(num_bands, S, S)`` and ``(S, S)``,
            where ``S`` is ``target_img_size``.

        Raises:
            ValueError: If the band count or the resulting shapes are unexpected.
            Any rasterio/IO error is printed and then re-raised.
        """
        identifier = self.file_identifiers[idx]
        image_path = os.path.join(self.patches_dir, f"{identifier}_S2.tif")
        label_path = os.path.join(self.targets_dir, f"{identifier}_Mask.tif")

        try:
            with rasterio.open(image_path) as src:
                image = src.read()
            if image.shape[0] != self.num_bands:
                raise ValueError(
                    f"Image {identifier} has {image.shape[0]} bands, expected {self.num_bands}."
                )

            with rasterio.open(label_path) as src:
                label = src.read(1)

            image_tensor, label_tensor = self._prepare_tensors(image, label, self.target_img_size)

            expected_hw = (self.target_img_size, self.target_img_size)
            if image_tensor.shape != (self.num_bands, *expected_hw) or label_tensor.shape != expected_hw:
                raise ValueError(f"Shape mismatch for {identifier}")

            return image_tensor, label_tensor

        except Exception as e:
            # Surface which sample failed (useful inside DataLoader workers), then propagate.
            print(f"Error loading/processing {identifier}: {e}")
            raise


# Data loading, integrity check, splitting
print("\n=========== Data Loading & Prep ============")

# Seeded so the train/val split is reproducible across notebook runs.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8
# Image stems look like Patch<id>_W<nn>_<8 digits> (presumably week index and YYYYMMDD date);
# compiled once for the directory scan.
_IDENTIFIER_RE = re.compile(r"(Patch\d+_W\d+_\d{8})_S2\.tif")


def _integrity_problems(image_data):
    """Return the reasons ``image_data`` looks corrupted (an empty list means clean)."""
    msgs = []
    if np.isnan(image_data).any():
        msgs.append("NaNs")
    if np.any(image_data > 30000) or np.any(image_data < 0):
        msgs.append("Extreme values")
    if image_data.size > 0 and np.allclose(image_data, image_data[0, 0, 0], atol=1e-5):
        msgs.append("Uniform")
    return msgs


def _patch_key(identifier):
    """Return the spatial patch prefix (e.g. ``Patch12``) shared by all snapshots of a patch."""
    return identifier.split("_W", 1)[0]


def _cut(items, train_fraction):
    """Split ``items`` into (train, val), keeping at least one item on each side when len >= 2."""
    if len(items) < 2:
        return list(items), []
    n_train = min(len(items) - 1, max(1, int(len(items) * train_fraction)))
    return items[:n_train], items[n_train:]


def _split_ids(identifiers, train_fraction, seed):
    """Seeded train/val split that keeps every snapshot of one spatial patch on the same side.

    Splitting per identifier would put different weekly snapshots of the same
    patch into both train and val, which leaks near-duplicate images and
    inflates validation metrics. If only one patch exists, a leak-free split is
    impossible, so the function falls back to a per-identifier split and warns.
    """
    rng = random.Random(seed)
    groups = {}
    for ident in identifiers:
        groups.setdefault(_patch_key(ident), []).append(ident)

    if len(groups) >= 2:
        keys = sorted(groups)
        rng.shuffle(keys)
        train_keys, val_keys = _cut(keys, train_fraction)
        return (
            [ident for key in train_keys for ident in groups[key]],
            [ident for key in val_keys for ident in groups[key]],
        )

    if len(identifiers) >= 2:
        print("WARNING: Only one spatial patch available; falling back to a per-snapshot split "
              "(train and val will share that patch).")
    shuffled = sorted(identifiers)
    rng.shuffle(shuffled)
    return _cut(shuffled, train_fraction)


image_files_pattern = os.path.join(PATCHES_DIR, "Patch*_S2.tif")
all_image_filepaths = sorted(glob.glob(image_files_pattern))

initial_file_identifiers = []
if not all_image_filepaths:
    print(f"ERROR: No image files found in {PATCHES_DIR}")
else:
    print(f"Found {len(all_image_filepaths)} potential images. Checking masks...")

# Keep only images whose name parses and that have a matching mask file.
for img_path in all_image_filepaths:
    base_name = os.path.basename(img_path)
    identifier_match = _IDENTIFIER_RE.match(base_name)
    if not identifier_match:
        print(f"WARNING: Could not parse {base_name}")
        continue
    identifier = identifier_match.group(1)
    mask_path = os.path.join(TARGETS_DIR, f"{identifier}_Mask.tif")
    if os.path.exists(mask_path):
        initial_file_identifiers.append(identifier)
    else:
        print(f"WARNING: Missing mask for {identifier}")

print(f"Found {len(initial_file_identifiers)} initial pairs.")

print(f"\nPerforming data integrity check on {len(initial_file_identifiers)} samples...")
corrupted_identifiers = []
valid_identifiers = []

for identifier in initial_file_identifiers:
    image_path = os.path.join(PATCHES_DIR, f"{identifier}_S2.tif")
    try:
        with rasterio.open(image_path) as src:
            image_data = src.read()
        msgs = _integrity_problems(image_data)
    except Exception as e:
        # Unreadable files are treated as corrupted rather than aborting the scan.
        print(f"ERROR checking {identifier}: {e}")
        corrupted_identifiers.append(identifier)
        continue

    if msgs:
        print(f"WARNING: {identifier}: {msgs}")
        corrupted_identifiers.append(identifier)
    else:
        valid_identifiers.append(identifier)

print(f"Integrity check complete. Corrupted: {len(corrupted_identifiers)}")
# Copy so the split never reorders valid_identifiers as a side effect.
clean_identifiers = list(valid_identifiers)
num_clean_patches = len(clean_identifiers)
print(f"Clean patches: {num_clean_patches}")

train_ds, val_ds, train_loader, val_loader = None, None, None, None
# Pinned host memory only helps when batches are copied to a GPU.
_pin_memory = DEVICE == "cuda"

if num_clean_patches > 0:
    train_ids, val_ids = _split_ids(clean_identifiers, TRAIN_FRACTION, SPLIT_SEED)
    print(f"Training IDs: {len(train_ids)}, Validation IDs: {len(val_ids)}")

    if train_ids:
        train_ds = CerealSensingDataset(train_ids, PATCHES_DIR, TARGETS_DIR, IMG_SIZE)
        train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                                  num_workers=4, pin_memory=_pin_memory)

    if val_ids:
        val_ds = CerealSensingDataset(val_ids, PATCHES_DIR, TARGETS_DIR, IMG_SIZE)
        val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                                num_workers=4, pin_memory=_pin_memory)
else:
    print("ERROR: No clean patches available.")

# Model Setup with TerraTorch
print("\n=========== Model Setup ============")

# Encoder/decoder settings handed to TerraTorch's EncoderDecoderFactory:
# a pretrained Prithvi EO v2 (300) backbone fed EXPECTED_BANDS_ORDER, topped by an FCN decoder.
model_args = dict(
    backbone="prithvi_eo_v2_300",
    backbone_pretrained=True,
    backbone_bands=EXPECTED_BANDS_ORDER,
    decoder="FCNDecoder",
    decoder_channels=256,
    head_dropout=HEAD_DROPOUT,
    num_classes=NUM_CLASSES,
    rescale=True,
)

# The task builds the network (and applies freeze_backbone to it). Only its
# `.model` is reused further down; the optimizer settings given here are not
# what the custom LightningModule's configure_optimizers uses.
task_config_obj = SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args=model_args,
    loss="ce",
    optimizer="adamw",
    lr=LR,
    optimizer_hparams=dict(weight_decay=WEIGHT_DECAY),
    freeze_backbone=FREEZE_BACKBONE,
)

print(f"Backbone freezing: {FREEZE_BACKBONE}")


# Lightning Module
class LitSegmentationModel(LightningModule):
    """Lightning wrapper that trains a segmentation network with pixel-wise cross-entropy.

    The wrapped model's forward result must expose an ``.output`` attribute
    holding class logits. These are presumably (batch, num_classes, H, W),
    matching (batch, H, W) integer labels as ``F.cross_entropy`` expects; TODO
    confirm against the TerraTorch model. Optimisation uses AdamW. Validation
    logs aggregate metrics plus per-class metrics suffixed with ``class_names``.

    Args:
        model_arch: The segmentation network (not stored in hparams).
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        num_classes: Number of segmentation classes.
        class_names: Optional per-class suffixes for metric names; defaults to
            ``class0``, ``class1``, ... Must contain at least ``num_classes`` names.

    Raises:
        ValueError: If ``class_names`` has fewer entries than ``num_classes``.
    """

    def __init__(self, model_arch, learning_rate, weight_decay, num_classes, class_names=None):
        super().__init__()
        self.save_hyperparameters(ignore=['model_arch'])
        self.model = model_arch
        # Optional per-class loss weights (None = unweighted). It is not a registered
        # buffer, so it is moved to the logits' device on every step.
        self.class_weights_tensor = None
        self.class_names = class_names if class_names else [f"class{i}" for i in range(num_classes)]
        # Fail fast instead of hitting an IndexError at the end of the first validation epoch.
        if len(self.class_names) < num_classes:
            raise ValueError(
                f"class_names has {len(self.class_names)} entries but num_classes={num_classes}."
            )

        n_cls = self.hparams.num_classes
        metrics = MetricCollection({
            'accuracy_overall': Accuracy(task="multiclass", num_classes=n_cls, average='micro'),
            'precision_macro': Precision(task="multiclass", num_classes=n_cls, average='macro'),
            'recall_macro': Recall(task="multiclass", num_classes=n_cls, average='macro'),
            'f1_score_macro': F1Score(task="multiclass", num_classes=n_cls, average='macro'),
            # average='none' returns one value per class; these are unpacked at epoch end.
            'jaccard_index_per_class': JaccardIndex(task="multiclass", num_classes=n_cls, average='none'),
            'accuracy_per_class': Accuracy(task="multiclass", num_classes=n_cls, average='none'),
            'precision_per_class': Precision(task="multiclass", num_classes=n_cls, average='none'),
            'recall_per_class': Recall(task="multiclass", num_classes=n_cls, average='none'),
            'f1_score_per_class': F1Score(task="multiclass", num_classes=n_cls, average='none'),
        })

        self.val_metrics = metrics.clone(prefix='val_')

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch):
        """Run the model on ``batch`` and return ``(logits, targets, loss)``."""
        x, y = batch
        logits = self.model(x).output
        weights = None if self.class_weights_tensor is None else self.class_weights_tensor.to(logits.device)
        loss = F.cross_entropy(logits, y, weight=weights)
        return logits, y, loss

    def training_step(self, batch, batch_idx):
        _, _, loss = self._shared_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        logits, y, loss = self._shared_step(batch)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        preds_classes = torch.argmax(logits, dim=1)
        self.val_metrics.update(preds_classes, y)
        return loss

    def on_validation_epoch_end(self):
        """Log aggregate metrics directly and unpack per-class metric vectors by class name."""
        metrics_output = self.val_metrics.compute()

        for name, value in metrics_output.items():
            # Per-class metrics are 1-D tensors with one entry per class; scalars are aggregates.
            is_per_class = (
                isinstance(value, torch.Tensor)
                and value.ndim > 0
                and value.numel() == self.hparams.num_classes
            )
            if not is_per_class:
                log_to_prog_bar = name in ["val_accuracy_overall", "val_f1_score_macro"]
                self.log(name, value, prog_bar=log_to_prog_bar, logger=True, on_epoch=True, sync_dist=True)
                continue

            for class_name, class_value in zip(self.class_names, value):
                self.log(f"{name}_{class_name}", class_value, prog_bar=False, logger=True,
                         on_step=False, on_epoch=True, sync_dist=True)

                # Short progress-bar aliases for the class of interest (not sent to the logger).
                if class_name == "cereal":
                    if name == "val_jaccard_index_per_class":
                        self.log("val_iou_cereal", class_value, prog_bar=True, logger=False, on_epoch=True, sync_dist=True)
                    elif name == "val_f1_score_per_class":
                        self.log("val_f1_cereal", class_value, prog_bar=True, logger=False, on_epoch=True, sync_dist=True)

        self.val_metrics.reset()

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)


# Per-class metric suffixes; index i names class id i.
CLASS_NAMES_FOR_LOGGING = ["non_cereal", "cereal"]
lightning_model = None

# Guard against running this cell before the model-setup cell has defined the task.
if "task_config_obj" not in locals() and "task_config_obj" not in globals():
    print("ERROR: task_config_obj not defined.")
else:
    lightning_model = LitSegmentationModel(
        num_classes=NUM_CLASSES,
        class_names=CLASS_NAMES_FOR_LOGGING,
        model_arch=task_config_obj.model,
        learning_rate=LR,
        weight_decay=WEIGHT_DECAY,
    )
    print("LitSegmentationModel instantiated.")
