### Model Training

In [None]:
import os

import albumentations
import cv2
import h5py
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import terratorch
import torch
import torch.optim as optim
# NOTE(review): Flip is unused in this notebook and may not exist in recent
# albumentations releases; drop it from this import if the import fails.
from albumentations import Compose, Flip, HorizontalFlip, Resize
from albumentations.pytorch import ToTensorV2
from huggingface_hub import hf_hub_download
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from terratorch.datamodules import Landslide4SenseNonGeoDataModule
from terratorch.datasets import Landslide4SenseNonGeo
from terratorch.tasks import SemanticSegmentationTask
from torch.utils.data import Dataset

# Fetch the Prithvi-EO-2.0 config.json from the Hugging Face Hub into a local
# cache directory.
# NOTE(review): this pulls the 300M repository, while the SemanticSegmentationTask
# is built with backbone "prithvi_eo_v2_600" — confirm which config is intended.
PRITHVI_REPO_ID = 'ibm-nasa-geospatial/Prithvi-EO-2.0-300M'
HF_CACHE_DIR = '/home/skaushik/flood_prithvi/'
hf_hub_download(repo_id=PRITHVI_REPO_ID, filename="config.json", cache_dir=HF_CACHE_DIR)


# Custom FloodDataModule
class FloodDataModule(pl.LightningDataModule):
    """Lightning data module for the flood-segmentation image/mask folders.

    ``data_root`` must contain ``images/{train,validation,test}`` and the
    matching ``annotations/{train,validation,test}`` directories. Each split is
    wrapped in a ``CustomFloodDataset``.

    Args:
        data_root: Root directory holding the ``images`` and ``annotations`` folders.
        batch_size: Samples per batch for every dataloader.
        train_transform: Optional callable used as ``transform(image=..., mask=...)``
            for the training split.
        val_transform: Same, for the validation split.
        test_transform: Same, for the test split.
        num_workers: Worker processes per ``DataLoader``. The default of 0 loads
            in the main process, matching the previous behaviour.
    """

    def __init__(self, data_root, batch_size=2, train_transform=None, val_transform=None,
                 test_transform=None, num_workers=0):
        super().__init__()
        self.data_root = data_root
        self.train_dir = os.path.join(data_root, "images/train")
        self.val_dir = os.path.join(data_root, "images/validation")
        self.test_dir = os.path.join(data_root, "images/test")
        self.mask_train_dir = os.path.join(data_root, "annotations/train")
        self.mask_val_dir = os.path.join(data_root, "annotations/validation")
        self.mask_test_dir = os.path.join(data_root, "annotations/test")
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform

    def setup(self, stage=None):
        """Build the datasets required for ``stage``.

        Lightning passes stage names such as "fit", "validate" and "test".
        ``None`` builds every split. The validation set is also built for
        "validate", so ``trainer.validate(...)`` works without a prior "fit" setup.
        """
        if stage in ("fit", None):
            self.train_dataset = self.create_dataset(self.train_dir, self.mask_train_dir, self.train_transform)

        if stage in ("fit", "validate", None):
            self.val_dataset = self.create_dataset(self.val_dir, self.mask_val_dir, self.val_transform)

        if stage in ("test", None):
            self.test_dataset = self.create_dataset(self.test_dir, self.mask_test_dir, self.test_transform)

    def create_dataset(self, image_dir, mask_dir, transform):
        """Return a ``CustomFloodDataset`` for one image/mask directory pair."""
        return CustomFloodDataset(image_dir, mask_dir, transform)

    def train_dataloader(self):
        # Only the training split is shuffled; evaluation order stays deterministic.
        return torch.utils.data.DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers
        )


# Custom Dataset
class CustomFloodDataset(Dataset):
    """Paired image/mask raster dataset read with rasterio.

    Images and masks are matched by position after sorting each directory's
    visible regular files. Both folders must therefore hold the same number of
    files, in corresponding sort order.

    Each item is a dict with:
        "image": the image raster as (height, width, bands) before ``transform``;
            after it, whatever ``transform`` returns under "image".
        "mask": band 1 of the mask raster binarised to {0, 1} (uint8) before
            ``transform``.
        "filename": base name of the image file.

    Raises:
        ValueError: if the image and mask directories hold different numbers of files.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = self._list_files(image_dir)
        self.masks = self._list_files(mask_dir)
        # Pairing is positional, so a count mismatch would silently misalign
        # every sample after the gap (or raise IndexError mid-epoch).
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {image_dir!r} but "
                f"{len(self.masks)} masks in {mask_dir!r}; they must match one-to-one."
            )

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of non-hidden regular files in ``directory``."""
        # Skip dot-entries (e.g. .ipynb_checkpoints, .DS_Store) and subdirectories;
        # they would otherwise be paired with real rasters and break rasterio.open.
        return sorted(
            name for name in os.listdir(directory)
            if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        # Read the image and mask using rasterio
        with rasterio.open(image_path) as src:
            image = src.read()  # (bands, height, width)
        with rasterio.open(mask_path) as src:
            mask = src.read(1)  # Single channel for mask

        mask = (mask > 0.5).astype(np.uint8)  # Binarize mask
        image = np.moveaxis(image, 0, -1)  # Convert to (height, width, bands)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Include the filename in the output
        filename = os.path.basename(image_path)

        return {"image": image, "mask": mask, "filename": filename}

#    def plot(self, sample):
#        """Plot a sample from the dataset."""
#        image = sample["image"].permute(1, 2, 0).numpy()  # Convert to HWC
#        mask = sample["mask"].numpy()
#
#        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
#        ax[0].imshow(image)
#        ax[0].set_title("Image")
#        ax[0].axis("off")
#
#        ax[1].imshow(mask, cmap="gray")
#        ax[1].set_title("Mask")
#        ax[1].axis("off")
#
#        plt.tight_layout()
#        plt.show()
#
#        return fig

# Transforms
# 896 = 64 * 14: divisible by 14 and close to 1024.
train_transform = Compose([
    HorizontalFlip(),
    Resize(896, 896),
    ToTensorV2(),
])

val_transform = Compose([
    Resize(896, 896),
    ToTensorV2(),
])

test_transform = Compose([
    Resize(896, 896),
    ToTensorV2(),
])

# Instantiate FloodDataModule
data_module = FloodDataModule(
    data_root='/home/skaushik/flood_prithvi/cambodia',
    batch_size=2,
    train_transform=train_transform,
    val_transform=val_transform,
    test_transform=test_transform,
)

# Logger
logger = TensorBoardLogger(
    save_dir="flood_logs",
    name="flood_segmentation"
)

# Define ModelCheckpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",  # Save all best models in a single directory
    filename="epoch-{epoch:02d}-val_f1-{val/Multiclass_F1_Score:.4f}",  # Include epoch and F1 in filename
    monitor="val/Multiclass_F1_Score",  # Track validation F1-score
    mode="max",  # Save models with highest F1-score
    save_top_k=2,  # Keep only the best 2 models
    every_n_epochs=1,
    save_on_train_epoch_end=False,  # Save based on validation performance
    # Keep the "/" in the metric name out of the file name (it would create a subfolder).
    auto_insert_metric_name=False
)

# Trainer (Trainer is imported from lightning.pytorch at the top of the notebook)
trainer = Trainer(
    max_epochs=100,
    logger=logger,
    callbacks=[checkpoint_callback]
)

# Model
model = SemanticSegmentationTask(
    model_args={
        "decoder": "UperNetDecoder",
        "backbone_pretrained": True,
        "backbone": "prithvi_eo_v2_600",
        "backbone_in_channels": 4,  # Match your dataset's number of channels
        "rescale": True,
        "backbone_bands": ["BLUE", "GREEN", "RED", "NIR_BROAD"],
        "backbone_num_frames": 1,
        "num_classes": 2,
        "head_dropout": 0.1,
        "decoder_channels": 256,
        "decoder_scale_modules": True,
        "head_channel_list": [128, 64],
        "necks": [
            {
                "name": "SelectIndices",
                "indices": [7, 15, 23, 31]
            },
            {
                "name": "ReshapeTokensToImage"
            }
        ]
    },
    plot_on_val=False,  # Plotting during validation is disabled (set True to enable)
    loss="focal",
    lr=1e-4,
    optimizer="AdamW",
    optimizer_hparams={"weight_decay": 0.1},
    scheduler="StepLR",
    scheduler_hparams={"step_size": 10, "gamma": 0.9},
    model_factory="EncoderDecoderFactory",
)

# Train
#trainer.fit(model, datamodule=data_module)


### Testing the trained model

In [None]:

# Evaluate the saved Cambodia checkpoint on the test split through the trainer.
CAMBODIA_CKPT_PATH = "/home/skaushik/Prithvi/Prithvi/Prithvi-EO-2.0-main/checkpoints/checkpoints/cambodia/epoch-27-val_f1-0.8856.ckpt"
trainer.test(model, datamodule=data_module, ckpt_path=CAMBODIA_CKPT_PATH)

## Inference and saving predictions


In [None]:
# Custom Dataset
class CustomFloodDataset(torch.utils.data.Dataset):
    """Paired image/mask raster dataset read with rasterio.

    Images and masks are matched by position after sorting each directory's
    visible regular files. Both folders must therefore hold the same number of
    files, in corresponding sort order.

    Each item is a dict with:
        "image": the image raster as (height, width, bands) before ``transform``;
            after it, whatever ``transform`` returns under "image".
        "mask": band 1 of the mask raster binarised to {0, 1} (uint8) before
            ``transform``.
        "filename": base name of the image file.

    Raises:
        ValueError: if the image and mask directories hold different numbers of files.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = self._list_files(image_dir)
        self.masks = self._list_files(mask_dir)
        # Pairing is positional, so a count mismatch would silently misalign
        # every sample after the gap (or raise IndexError mid-epoch).
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {image_dir!r} but "
                f"{len(self.masks)} masks in {mask_dir!r}; they must match one-to-one."
            )

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of non-hidden regular files in ``directory``."""
        # Skip dot-entries (e.g. .ipynb_checkpoints, .DS_Store) and subdirectories;
        # they would otherwise be paired with real rasters and break rasterio.open.
        return sorted(
            name for name in os.listdir(directory)
            if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        # Read the image and mask using rasterio
        with rasterio.open(image_path) as src:
            image = src.read()  # (bands, height, width)
        with rasterio.open(mask_path) as src:
            mask = src.read(1)  # Single channel for mask

        mask = (mask > 0.5).astype(np.uint8)  # Binarize mask
        image = np.moveaxis(image, 0, -1)  # Convert to (height, width, bands)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        filename = os.path.basename(image_path)
        return {"image": image, "mask": mask, "filename": filename}

# Data Module
class FloodDataModule:
    """Minimal, Lightning-free data module that serves only the test split.

    Args:
        data_root: Root directory containing ``images/test`` and ``annotations/test``.
        batch_size: Samples per test batch.
        test_transform: Optional callable used as ``transform(image=..., mask=...)``.
    """

    def __init__(self, data_root, batch_size=2, test_transform=None):
        self.data_root = data_root
        self.test_dir = os.path.join(data_root, "images/test")
        self.mask_test_dir = os.path.join(data_root, "annotations/test")
        self.batch_size = batch_size
        self.test_transform = test_transform

    def setup(self, stage=None):
        """Create the test dataset.

        ``stage`` is accepted and ignored, so callers can use the same
        ``setup("test")`` call as with a Lightning data module.
        """
        self.test_dataset = CustomFloodDataset(self.test_dir, self.mask_test_dir, self.test_transform)

    def test_dataloader(self):
        # Fully qualified, matching the Lightning data module's dataloaders.
        return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)

# Function to plot and save the image, ground truth mask, and predicted mask

def plot_and_save_sample(image_path, mask, prediction, filename, output_dir, rgb_bands=(4, 3, 2)):
    """
    Plot a three-band composite of the input image, the ground-truth mask and
    the predicted mask side by side, then save the predicted mask to ``output_dir``.

    Parameters:
    - image_path (str): Path to the input raster (read with rasterio).
    - mask (torch.Tensor or np.ndarray): Ground-truth mask of shape (H, W).
    - prediction (np.ndarray): Predicted class map of shape (H, W); expected to
      contain 0/1 values (binary segmentation).
    - filename (str): Sample name, used in the figure title and as the output file name.
    - output_dir (str): Directory for the saved mask; created if missing.
    - rgb_bands (tuple of 3 ints): 1-based raster band indices displayed as the
      red, green and blue channels. Defaults to (4, 3, 2), as before.
      NOTE(review): if the rasters follow the BLUE, GREEN, RED, NIR order used
      for the model's backbone_bands, (4, 3, 2) is a NIR/Red/Green false-colour
      composite; (3, 2, 1) would be true colour. Confirm the band order.

    Raises:
    - OSError: if OpenCV cannot write the prediction file.
    """
    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Read only the display bands: (3, H, W) in the requested order -> (H, W, 3).
    with rasterio.open(image_path) as src:
        composite = src.read(list(rgb_bands)).astype(np.float32)
    composite = np.moveaxis(composite, 0, -1)

    # Scale to [0, 1] for display. The guard avoids dividing by zero on an
    # all-zero tile. Clipping keeps imshow within its valid float range.
    peak = composite.max()
    if peak > 0:
        composite /= peak
    composite = np.clip(composite, 0.0, 1.0)

    # Accept either a torch tensor (as yielded by the DataLoader) or an array.
    mask_np = mask.cpu().numpy() if torch.is_tensor(mask) else np.asarray(mask)
    prediction = np.asarray(prediction)

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(f"Sample: {filename}", fontsize=16)

    ax[0].imshow(composite)
    ax[0].set_title(f"Image (Bands {', '.join(str(b) for b in rgb_bands)})")
    ax[0].axis("off")

    # Fix the gray scale to [0, 1]. Without this, a uniform mask autoscales
    # and an all-flood tile would render black, like an all-background one.
    ax[1].imshow(mask_np, cmap="gray", vmin=0, vmax=1)
    ax[1].set_title("Ground Truth Mask")
    ax[1].axis("off")

    ax[2].imshow(prediction, cmap="gray", vmin=0, vmax=1)
    ax[2].set_title("Predicted Mask")
    ax[2].axis("off")

    plt.tight_layout()
    plt.show()
    # Release the figure: this runs once per test sample, and open figures can
    # otherwise pile up on non-inline backends.
    plt.close(fig)

    # Map binary {0, 1} predictions to {0, 255} for an 8-bit image.
    prediction_uint8 = (prediction * 255).astype(np.uint8)
    output_path = os.path.join(output_dir, filename)
    # cv2.imwrite picks the format from the file extension and writes a
    # single-channel 8-bit image. It returns False rather than raising on
    # failure, so check the result.
    # NOTE(review): the prediction comes from the (possibly resized) model
    # input, so it may not match the source raster's size or georeferencing.
    if not cv2.imwrite(output_path, prediction_uint8):
        raise OSError(f"Failed to write prediction to {output_path}")
    print(f"Saved prediction to {output_path}")

# Ensure the Data Module is set up. Only the test split is needed for
# inference, so request the "test" stage instead of building every split.
data_module.setup("test")
test_loader = data_module.test_dataloader()

# Define the output directory for predictions
output_dir = '/home/skaushik/flood_prithvi/predictions_smoalia_ps'

# NOTE(review): `model` is used with whatever weights it currently holds —
# presumably the checkpoint restored by the trainer.test(...) cell. Run that
# cell first (or load the checkpoint explicitly) before trusting these outputs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Run inference and plot/save results
with torch.no_grad():
    for batch in test_loader:
        images = batch["image"].to(device)
        # Masks are only plotted, so keep them on the CPU (no GPU round trip).
        masks = batch["mask"]

        # Forward pass; the task returns an object whose .output holds the logits.
        outputs = model(images)
        preds = torch.argmax(outputs.output, dim=1).cpu().numpy()

        # Plot and save each sample in the batch
        for mask, pred, filename in zip(masks, preds, batch["filename"]):
            image_path = os.path.join(data_module.test_dir, filename)
            plot_and_save_sample(image_path, mask, pred, filename, output_dir)
