# In [None]:
# Cell 1: Imports
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from torchmetrics import Accuracy, Precision, Recall, F1Score, JaccardIndex
from torchmetrics.collections import MetricCollection
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import random
from terratorch.tasks import SemanticSegmentationTask
import rasterio
import glob
import re
import traceback

# Cell 2: Configuration
# --- Filesystem locations ---------------------------------------------------
BASE_DATA_DIR_ACTUAL = "<DATASET_DIR>"
CHECKPOINT_SAVE_DIR_ACTUAL = "<CHECKPOINT_DIR>"
LOGS_DIR_ACTUAL = "<LOGS_DIR>"

# Image patches and their masks live in two sibling folders under the data root.
PATCHES_DIR, TARGETS_DIR = (
    os.path.join(BASE_DATA_DIR_ACTUAL, sub_dir)
    for sub_dir in ("GEE_WeeklyPatches", "GEE_WeeklyMasks")
)

# Create the output folders up front; already-existing folders are left alone.
for _output_dir in (CHECKPOINT_SAVE_DIR_ACTUAL, LOGS_DIR_ACTUAL):
    os.makedirs(_output_dir, exist_ok=True)

# --- Data description -------------------------------------------------------
EXPECTED_BANDS_ORDER = ["B02", "B03", "B04", "B8A", "B11", "B12"]
NUM_CLASSES = 2
IMG_SIZE = 256

# --- Optimisation hyper-parameters -----------------------------------------
BATCH_SIZE = 4
EPOCHS = 100
LR = 1e-5
WEIGHT_DECAY = 0.01
HEAD_DROPOUT = 0.2
FREEZE_BACKBONE = False

# --- Augmentation and early stopping ---------------------------------------
ROTATION_DEGREES = 180
EARLY_STOPPING_PATIENCE = 15
EARLY_STOPPING_MONITOR = "val_loss"

# Use the GPU whenever torch can see one, otherwise fall back to the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

# Cell 3: Custom Dataset Definition (Geometric Augmentations)
class CerealSensingDataset(Dataset):
    """Dataset of multi-band image patches paired with single-band masks.

    A sample is addressed by a string identifier. Its image is read from
    ``<patches_dir>/<identifier>_S2.tif`` (all bands) and its mask from
    ``<targets_dir>/<identifier>_Mask.tif`` (band 1). Both are resized to a
    ``target_img_size`` x ``target_img_size`` square -- bilinear for the
    image, nearest-neighbour for the mask so label values are not blended --
    and optionally augmented with random flips and a random rotation that are
    applied identically to image and mask.

    Args:
        file_identifiers: Sequence of sample identifiers (file-name stems).
        patches_dir: Directory holding the ``*_S2.tif`` image files.
        targets_dir: Directory holding the ``*_Mask.tif`` mask files.
        target_img_size: Side length in pixels of the returned tensors.
        augmentations: If true, apply random geometric augmentations.
        rotation_degrees: Maximum absolute rotation angle in degrees; a value
            of ``0`` disables the rotation.
        expected_bands: Number of bands every image file must contain.
    """

    def __init__(self, file_identifiers, patches_dir, targets_dir, target_img_size,
                 augmentations=False, rotation_degrees=0, expected_bands=6):
        self.file_identifiers = file_identifiers
        self.patches_dir = patches_dir
        self.targets_dir = targets_dir
        self.target_img_size = target_img_size
        self.augmentations = augmentations
        self.rotation_degrees = rotation_degrees
        self.expected_bands = expected_bands

    def __len__(self):
        """Return the number of samples."""
        return len(self.file_identifiers)

    @staticmethod
    def _as_float_tensor(array):
        """Copy a NumPy array into a float32 tensor.

        The cast is done in NumPy first because some torch versions cannot
        ingest certain integer dtypes (e.g. uint16) directly.
        """
        return torch.from_numpy(np.array(array, dtype=np.float32))

    def _augment(self, image_tensor, label_tensor):
        """Apply the same random flips / rotation to a (C, H, W) image and (H, W) mask."""
        if random.random() > 0.5:
            image_tensor = TF.hflip(image_tensor)
            label_tensor = TF.hflip(label_tensor)
        if random.random() > 0.5:
            image_tensor = TF.vflip(image_tensor)
            label_tensor = TF.vflip(label_tensor)
        if self.rotation_degrees > 0 and random.random() > 0.5:
            angle = random.uniform(-self.rotation_degrees, self.rotation_degrees)
            image_tensor = TF.rotate(image_tensor, angle, interpolation=TF.InterpolationMode.BILINEAR)
            # The mask gets a temporary channel axis for the rotation and is
            # resampled with NEAREST so no new label values are invented.
            label_tensor = TF.rotate(
                label_tensor.unsqueeze(0),
                angle,
                interpolation=TF.InterpolationMode.NEAREST
            ).squeeze(0)
        return image_tensor, label_tensor

    def __getitem__(self, idx):
        """Load, resize and (optionally) augment the sample at position ``idx``.

        Returns:
            A tuple ``(image, label)``: a float32 tensor of shape
            ``(bands, target_img_size, target_img_size)`` and an int64 tensor
            of shape ``(target_img_size, target_img_size)``.

        Raises:
            ValueError: If the image does not have ``expected_bands`` bands.
            Exception: Any error raised while reading or transforming the
                files is reported on stdout and then re-raised unchanged.
        """
        identifier = self.file_identifiers[idx]
        image_path = os.path.join(self.patches_dir, f"{identifier}_S2.tif")
        label_path = os.path.join(self.targets_dir, f"{identifier}_Mask.tif")

        try:
            with rasterio.open(image_path) as src:
                image = src.read()
            if image.shape[0] != self.expected_bands:
                raise ValueError(
                    f"Image {identifier} has {image.shape[0]} bands, expected {self.expected_bands}."
                )

            with rasterio.open(label_path) as src:
                label = src.read(1)

            image_tensor = self._as_float_tensor(image)
            label_tensor_float = self._as_float_tensor(label)

            target_size = (self.target_img_size, self.target_img_size)
            # F.interpolate expects a batch axis, hence the unsqueeze/squeeze pairs.
            image_tensor = F.interpolate(image_tensor.unsqueeze(0),
                                         size=target_size,
                                         mode='bilinear', align_corners=False).squeeze(0)
            label_tensor_float = F.interpolate(
                label_tensor_float.unsqueeze(0).unsqueeze(0),
                size=target_size,
                mode='nearest'
            ).squeeze(0).squeeze(0)

            if self.augmentations:
                image_tensor, label_tensor_float = self._augment(image_tensor, label_tensor_float)

            return image_tensor, label_tensor_float.long()

        except Exception as e:
            print(f"Error loading {identifier}: {e}")
            traceback.print_exc()
            # Bare raise keeps the original exception and traceback intact.
            raise

# Cell 4: Data Discovery and Loaders
print("\n=========== Data Loading & Prep ============")

# Candidate images are the "Patch*_S2.tif" files directly inside PATCHES_DIR,
# visited in sorted order.
image_files_pattern = os.path.join(PATCHES_DIR, "Patch*_S2.tif")
all_image_filepaths = glob.glob(image_files_pattern)
all_image_filepaths.sort()

# An identifier is kept only when the file name starts with the
# "Patch<digits>_W<digits>_<8 digits>_S2.tif" layout AND the corresponding
# "<identifier>_Mask.tif" exists in TARGETS_DIR.
initial_file_identifiers = []
for img_path in all_image_filepaths:
    base_name = os.path.basename(img_path)
    match = re.match(r"(Patch\d+_W\d+_\d{8})_S2\.tif", base_name)
    if match:
        identifier = match.group(1)
        mask_path = os.path.join(TARGETS_DIR, f"{identifier}_Mask.tif")
        if os.path.exists(mask_path):
            initial_file_identifiers.append(identifier)

# Screen every discovered image once and sort identifiers into "valid" and
# "corrupted". An image is rejected when it
#   * contains NaN values,
#   * contains values above 30000 or below -1000,
#   * is constant within every one of its bands, or
#   * cannot be read / checked at all.
corrupted_identifiers, valid_identifiers = [], []

for identifier in initial_file_identifiers:
    image_path = os.path.join(PATCHES_DIR, f"{identifier}_S2.tif")
    try:
        with rasterio.open(image_path) as src:
            image_data = src.read()

        has_nan = bool(np.isnan(image_data).any())
        out_of_range = bool(np.any(image_data > 30000) or np.any(image_data < -1000))
        # Only flagged when *all* bands are flat; a single varying band is enough to keep it.
        is_uniform = all(bool(np.all(band == band[0, 0])) for band in image_data)
        problem = has_nan or out_of_range or is_uniform
    except Exception as e:
        # Catch Exception rather than everything so Ctrl-C / SystemExit still
        # stop the scan, and say why the file was rejected.
        print(f"Skipping {identifier}: could not read or validate image ({e})")
        problem = True

    if problem:
        corrupted_identifiers.append(identifier)
    else:
        valid_identifiers.append(identifier)

print(f"Integrity scan: {len(valid_identifiers)} valid, {len(corrupted_identifiers)} corrupted")

clean_identifiers = valid_identifiers
num_clean = len(clean_identifiers)

# Stay None when no clean sample is available, so later cells can test for it.
train_ds, val_ds = None, None
train_loader, val_loader = None, None

if num_clean > 0:
    # In-place shuffle followed by an 80/20 cut, with at least one training sample.
    # NOTE(review): the shuffle is unseeded and splits per file, not per patch
    # location -- confirm this is the intended train/validation protocol.
    random.shuffle(clean_identifiers)
    num_train = max(1, int(num_clean * 0.8))
    train_ids, val_ids = clean_identifiers[:num_train], clean_identifiers[num_train:]
    if len(val_ids) == 0 and len(train_ids) > 1:
        # Move one training sample over so validation is not empty.
        val_ids = [train_ids.pop()]

    # Only the training set is augmented.
    train_ds = CerealSensingDataset(
        train_ids,
        PATCHES_DIR,
        TARGETS_DIR,
        IMG_SIZE,
        augmentations=True,
        rotation_degrees=ROTATION_DEGREES,
    )
    val_ds = CerealSensingDataset(
        val_ids,
        PATCHES_DIR,
        TARGETS_DIR,
        IMG_SIZE,
        augmentations=False,
    )

    # Both loaders share batch size and worker count; only training is shuffled.
    train_loader = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE, num_workers=4)
    val_loader = DataLoader(val_ds, shuffle=False, batch_size=BATCH_SIZE, num_workers=4)

# Cell 5: Visualization
# Best-effort preview of the first training sample (image next to its mask).
# A failure here must not stop the notebook, but it is reported, not hidden.
if train_ds is not None and len(train_ds) > 0:
    try:
        sample_img_aug, sample_mask_aug = train_ds[0]
        # Channels 0..2 are B02, B03, B04 (blue, green, red) when the file
        # follows EXPECTED_BANDS_ORDER, so reorder to R, G, B for imshow.
        # NOTE(review): assumes the GeoTIFF band order matches EXPECTED_BANDS_ORDER.
        img_display_aug = sample_img_aug[[2, 1, 0]].numpy()
        img_min_aug, img_max_aug = img_display_aug.min(), img_display_aug.max()
        # Min-max scale to [0, 1]; the floor on the range avoids division by zero.
        img_norm_aug = (img_display_aug - img_min_aug) / max(1e-6, img_max_aug - img_min_aug)
        img_norm_aug = img_norm_aug.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)

        fig, axs = plt.subplots(1, 2, figsize=(10, 5))
        axs[0].imshow(img_norm_aug)
        axs[1].imshow(sample_mask_aug.numpy(), cmap='gray')
        axs[0].axis('off'); axs[1].axis('off')
        plt.show()
    except Exception as e:
        print(f"Sample visualization skipped: {e}")

# Cell 6: Model Setup
# Keyword arguments handed to the model factory through `model_args`.
model_args = dict(
    backbone="prithvi_eo_v2_300",
    backbone_pretrained=True,
    backbone_bands=EXPECTED_BANDS_ORDER,
    decoder="FCNDecoder",
    decoder_channels=256,
    head_dropout=HEAD_DROPOUT,
    num_classes=NUM_CLASSES,
    rescale=True,
)

# Optimizer settings are passed here as well, alongside the model description.
optimizer_hparams = {"weight_decay": WEIGHT_DECAY}

task_config_obj = SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args=model_args,
    loss="ce",
    lr=LR,
    optimizer="adamw",
    optimizer_hparams=optimizer_hparams,
    freeze_backbone=FREEZE_BACKBONE,
)

# Cell 7: Lightning Module
class LitSegmentationModel(LightningModule):
    """Lightning wrapper that trains a segmentation network with cross-entropy.

    The wrapped network is called as ``model_arch(x)`` and its ``.output``
    attribute is used as the logits. Validation tracks overall accuracy,
    macro precision / recall / F1 and a per-class Jaccard index, all logged
    with a ``val_`` prefix.

    Args:
        model_arch: The network to train.
        learning_rate: Learning rate for AdamW.
        weight_decay: Weight decay for AdamW.
        num_classes: Number of segmentation classes.
        class_names: Optional names used as suffixes when logging per-class
            metrics; defaults to ``class0``, ``class1``, ...

    Raises:
        ValueError: If ``class_names`` is given but its length differs from
            ``num_classes``.
    """

    def __init__(self, model_arch, learning_rate, weight_decay, num_classes, class_names=None):
        super().__init__()
        self.save_hyperparameters(ignore=['model_arch'])
        self.model = model_arch
        self.class_weights_tensor = None
        self.class_names = class_names if class_names else [f"class{i}" for i in range(num_classes)]
        # Fail now instead of with an IndexError at the end of the first validation epoch.
        if len(self.class_names) != num_classes:
            raise ValueError(
                f"class_names has {len(self.class_names)} entries, expected {num_classes}."
            )

        metrics = MetricCollection({
            'accuracy_overall': Accuracy(task="multiclass", num_classes=num_classes, average='micro'),
            'precision_macro': Precision(task="multiclass", num_classes=num_classes, average='macro'),
            'recall_macro': Recall(task="multiclass", num_classes=num_classes, average='macro'),
            'f1_score_macro': F1Score(task="multiclass", num_classes=num_classes, average='macro'),
            'jaccard_index_per_class': JaccardIndex(task="multiclass", num_classes=num_classes, average='none'),
        })

        self.val_metrics = metrics.clone(prefix='val_')

    def forward(self, x):
        """Return the wrapped network's raw result for ``x`` (not just the logits)."""
        return self.model(x)

    def _logits_and_loss(self, batch):
        """Run the network on a ``(x, y)`` batch; return ``(logits, y, cross-entropy loss)``."""
        x, y = batch
        logits = self.model(x).output
        return logits, y, F.cross_entropy(logits, y)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        _, _, loss = self._logits_and_loss(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accumulate the validation metrics."""
        logits, y, loss = self._logits_and_loss(batch)
        self.log("val_loss", loss, prog_bar=True)
        preds = torch.argmax(logits, dim=1)
        self.val_metrics.update(preds, y)
        return loss

    def on_validation_epoch_end(self):
        """Log the accumulated validation metrics, then reset them."""
        num_classes = self.hparams.num_classes
        metrics = self.val_metrics.compute()
        for name, value in metrics.items():
            # Per-class results are 1-D with one entry per class. Requiring
            # ndim == 1 keeps 0-dim scalars out of this branch even when
            # num_classes == 1, where indexing them would fail.
            is_per_class = (
                isinstance(value, torch.Tensor)
                and value.ndim == 1
                and value.numel() == num_classes
            )
            if is_per_class:
                for i in range(num_classes):
                    self.log(f"{name}_{self.class_names[i]}", value[i])
            else:
                self.log(name, value)
        self.val_metrics.reset()

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters of this module."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)

# Names appended to the per-class metric keys, indexed by class id.
CLASS_NAMES_FOR_LOGGING = ["non_cereal", "cereal"]

# Only the network held by the task object is reused; optimisation is
# configured by LitSegmentationModel itself.
lightning_model = LitSegmentationModel(
    model_arch=task_config_obj.model,
    num_classes=NUM_CLASSES,
    class_names=CLASS_NAMES_FOR_LOGGING,
    learning_rate=LR,
    weight_decay=WEIGHT_DECAY,
)

# Cell 8: Training Setup
csv_logger = CSVLogger(name="cereal_segmentation_geometric_run", save_dir=LOGS_DIR_ACTUAL)

# Keep only the single checkpoint with the lowest validation loss.
loss_checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    dirpath=CHECKPOINT_SAVE_DIR_ACTUAL,
    filename="best-loss",
)

# Stop once the monitored value has not improved for the configured patience.
early_stop_callback = EarlyStopping(
    mode="min",
    monitor=EARLY_STOPPING_MONITOR,
    patience=EARLY_STOPPING_PATIENCE,
)

accelerator_type = "cpu" if DEVICE != "cuda" else "gpu"
devices_setting = strategy_to_use = "auto"

training_callbacks = [early_stop_callback, loss_checkpoint_callback]

trainer = Trainer(
    max_epochs=EPOCHS,
    accelerator=accelerator_type,
    devices=devices_setting,
    strategy=strategy_to_use,
    logger=csv_logger,
    callbacks=training_callbacks,
    gradient_clip_val=0.5,
)

if __name__ == "__main__":
    if train_loader is None:
        # No clean sample was found, so no loaders were built; say so plainly
        # instead of letting the trainer fail on a missing dataloader.
        print("Training error: no valid training data found; skipping training.")
    else:
        try:
            # Pass a validation loader only if it actually holds samples.
            # NOTE(review): when val_arg is None, the callbacks that monitor
            # "val_loss" have no such metric to read -- confirm this case is acceptable.
            has_val_data = val_loader is not None and len(val_loader.dataset) > 0
            val_arg = val_loader if has_val_data else None
            trainer.fit(lightning_model, train_dataloaders=train_loader, val_dataloaders=val_arg)
        except Exception as e:
            # Top-level boundary: report the failure with its traceback and return.
            print(f"Training error: {e}")
            traceback.print_exc()
