# PATH MANAGEMENT

In [None]:
import os

# Report where the kernel started, then make sure the working directory is the
# `app` folder so that the `src.*` imports used by later cells resolve.
start_dir = os.getcwd()
print(start_dir)
if not start_dir.endswith("app"):
    os.chdir("../app")
    print(os.getcwd())

%load_ext autoreload
%autoreload 2

# TRAIN SEGMENTATION MODELS

## Configuration

In [None]:
from src.config import Configuration

# Project settings object; later cells read its train/val/test folder paths.
CONFIG = Configuration()

## Load images

In [None]:
import torch
import numpy as np
from torch.utils.data import Dataset
from maikol_utils.file_utils import list_dir_files
import matplotlib.pyplot as plt

from src.data import SampleImage

class RoadSegmentationDataset(Dataset):
    """Dataset of (satellite image, road mask) pairs read from a folder.

    Files whose path contains ``"_sat"`` are treated as inputs and files whose
    path contains ``"_mask"`` as targets. The two lists are paired by position,
    so the pairing relies on the natural sort placing every image and its mask
    at the same relative index.
    """

    def __init__(self, root_dir: str, max_samples: int = None, recursive: bool = True):
        """Collect the image/mask paths found under ``root_dir``.

        Args:
            root_dir: Folder that holds the ``*_sat*`` and ``*_mask*`` files.
            max_samples: Upper bound on the number of pairs to keep; ``None``
                keeps every pair.
            recursive: Forwarded to ``list_dir_files``.
        """
        original_files, _ = list_dir_files(
            root_dir,
            nat_sorting=True,  # for Numbers to be properly sorted
            absolute_path=True,
            recursive=recursive,
        )
        # Slicing with ``None`` keeps the whole list.
        path_images_X = [img for img in original_files if "_sat" in img][:max_samples]
        path_images_Y = [img for img in original_files if "_mask" in img][:max_samples]

        self.sample_points: list[SampleImage] = [
            SampleImage(path_img_x, path_img_y)
            for path_img_x, path_img_y in zip(path_images_X, path_images_Y)
        ]
        # The size is the number of pairs actually built. Deriving it from the
        # raw file count ignored `max_samples` (and any unpaired file), which
        # made `__len__` over-report and let a DataLoader index out of range.
        self.N = len(self.sample_points)
        print(f"Dataset initialized with {self.N} samples from {root_dir}")

    def __len__(self):
        """Return the number of image/mask pairs that can be indexed."""
        return len(self.sample_points)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return it as two float tensors.

        Returns:
            ``(x, y)`` where ``x`` has its last axis moved to the front and
            ``y`` has a leading axis of size 1 added; both are divided by 255.
        """
        x, y = self.sample_points[idx].get_images(keep_in_memory=False)

        # Convert to tensors and normalize to [0, 1]
        # Make sure the channels are aligned with the models library (C, H, W)
        # NOTE(review): assumes 8-bit inputs, x as (H, W, C) and y as (H, W) - confirm
        # against SampleImage.get_images.
        x = torch.tensor(np.array(x)).permute(2, 0, 1).float() / 255.0
        y = torch.tensor(np.array(y)).unsqueeze(0).float() / 255.0

        return x, y

    def plot_sample(self, idx: int):
        """Plot the image and mask for the given index."""
        x, y = self[idx]  # uses __getitem__

        # Convert tensors back to numpy for plotting
        img_np = x.permute(1, 2, 0).numpy()
        mask_np = y.squeeze(0).numpy()  # remove channel dim

        fig, axes = plt.subplots(1, 2, figsize=(10, 5))
        axes[0].imshow(img_np)
        axes[0].set_title("Image")
        axes[0].axis("off")

        axes[1].imshow(mask_np, cmap="gray")
        axes[1].set_title("Mask")
        axes[1].axis("off")

        plt.show()

In [None]:
from torch.utils.data import DataLoader

# One dataset per split, each capped at 100 samples.
train_dataset = RoadSegmentationDataset(CONFIG.train_folder, max_samples=100)
valid_dataset = RoadSegmentationDataset(CONFIG.val_folder, max_samples=100)
test_dataset  = RoadSegmentationDataset(CONFIG.test_folder, max_samples=100)

# os.cpu_count() returns None when the count cannot be determined, and
# DataLoader needs an integer; 0 means "load in the main process".
n_cpu = os.cpu_count() or 0
# Only the training split is shuffled.
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=n_cpu)
valid_dataloader = DataLoader(valid_dataset, batch_size=4, shuffle=False, num_workers=n_cpu)
test_dataloader  = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=n_cpu)

In [None]:
# Index of the training sample that the plotting cell shows next;
# re-run this cell to start again from the first sample.
i = 0
i

In [None]:
# Show sample `i`, then advance the index so that re-running this cell
# steps through the training set one sample at a time.
train_dataset.plot_sample(i)
i += 1

## Define model

In [None]:
# import segmentation_models_pytorch as smp
# import pytorch_lightning as pl
# from torch.optim import lr_scheduler

# class RoadSegmentationModel(pl.LightningModule):
#     def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):
#         super().__init__()
#         self.model = smp.create_model(
#             arch,
#             encoder_name=encoder_name,
#             in_channels=in_channels,
#             classes=out_classes,
#             **kwargs,
#         )
#         # for image segmentation dice loss could be the best first choice
#         self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)

#         # initialize step metrics
#         self.training_step_outputs = []
#         self.validation_step_outputs = []
#         self.test_step_outputs = []

#     def forward(self, image):
#         # normalize image here
#         image = (image - self.mean) / self.std
#         mask = self.model(image)
#         return mask

#     def shared_step(self, batch, stage):
#         image = batch["image"]

#         # Shape of the image should be (batch_size, num_channels, height, width)
#         # if you work with grayscale images, expand channels dim to have [batch_size, 1, height, width]
#         assert image.ndim == 4

#         # Check that image dimensions are divisible by 32:
#         # the encoder and decoder are connected by skip connections, and the encoder usually has
#         # 5 downsampling stages of factor 2 (2 ^ 5 = 32). E.g. for an 84x84 image the encoder
#         # feature sizes are 84, 42, 21, 10, 5 while the decoder produces 5, 10, 20, 40, 80,
#         # so concatenating these features raises an error
#         h, w = image.shape[2:]
#         assert h % 32 == 0 and w % 32 == 0

#         mask = batch["mask"]
#         assert mask.ndim == 4

#         # Check that mask values in between 0 and 1, NOT 0 and 255 for binary segmentation
#         assert mask.max() <= 1.0 and mask.min() >= 0

#         logits_mask = self.forward(image)

#         # Predicted mask contains logits, and loss_fn param `from_logits` is set to True
#         loss = self.loss_fn(logits_mask, mask)

#         # Lets compute metrics for some threshold
#         # first convert mask values to probabilities, then
#         # apply thresholding
#         prob_mask = logits_mask.sigmoid()
#         pred_mask = (prob_mask > 0.5).float()

#         # We will compute IoU metric by two ways
#         #   1. dataset-wise
#         #   2. image-wise
#         # but for now we just compute true positive, false positive, false negative and
#         # true negative 'pixels' for each image and class
#         # these values will be aggregated in the end of an epoch
#         tp, fp, fn, tn = smp.metrics.get_stats(
#             pred_mask.long(), mask.long(), mode="binary"
#         )
#         return {
#             "loss": loss,
#             "tp": tp,
#             "fp": fp,
#             "fn": fn,
#             "tn": tn,
#         }

#     def shared_epoch_end(self, outputs, stage):
#         # aggregate step metrics
#         tp = torch.cat([x["tp"] for x in outputs])
#         fp = torch.cat([x["fp"] for x in outputs])
#         fn = torch.cat([x["fn"] for x in outputs])
#         tn = torch.cat([x["tn"] for x in outputs])

#         # per image IoU means that we first calculate IoU score for each image
#         # and then compute mean over these scores
#         per_image_iou = smp.metrics.iou_score(
#             tp, fp, fn, tn, reduction="micro-imagewise"
#         )

#         # dataset IoU means that we aggregate intersection and union over whole dataset
#         # and then compute IoU score. The difference between dataset_iou and per_image_iou scores
#         # in this particular case will not be much, however for dataset
#         # with "empty" images (images without target class) a large gap could be observed.
#         # Empty images influence a lot on per_image_iou and much less on dataset_iou.
#         dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
#         metrics = {
#             f"{stage}_per_image_iou": per_image_iou,
#             f"{stage}_dataset_iou": dataset_iou,
#         }

#         self.log_dict(metrics, prog_bar=True)

#     def training_step(self, batch, batch_idx):
#         train_loss_info = self.shared_step(batch, "train")
#         # append the metrics of each step to the epoch-level output list
#         self.training_step_outputs.append(train_loss_info)
#         return train_loss_info

#     def on_train_epoch_end(self):
#         self.shared_epoch_end(self.training_step_outputs, "train")
#         # empty set output list
#         self.training_step_outputs.clear()
#         return

#     def validation_step(self, batch, batch_idx):
#         valid_loss_info = self.shared_step(batch, "valid")
#         self.validation_step_outputs.append(valid_loss_info)
#         return valid_loss_info

#     def on_validation_epoch_end(self):
#         self.shared_epoch_end(self.validation_step_outputs, "valid")
#         self.validation_step_outputs.clear()
#         return

#     def test_step(self, batch, batch_idx):
#         test_loss_info = self.shared_step(batch, "test")
#         self.test_step_outputs.append(test_loss_info)
#         return test_loss_info

#     def on_test_epoch_end(self):
#         self.shared_epoch_end(self.test_step_outputs, "test")
#         # empty set output list
#         self.test_step_outputs.clear()
#         return

#     def configure_optimizers(self):
#         optimizer = torch.optim.Adam(self.parameters(), lr=2e-4)
#         scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_MAX, eta_min=1e-5)
#         return {
#             "optimizer": optimizer,
#             "lr_scheduler": {
#                 "scheduler": scheduler,
#                 "interval": "step",
#                 "frequency": 1,
#             },
#         }

In [None]:
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
from torch.optim import lr_scheduler
from src.config import ModelConfiguration

class RoadSegmentationModel(pl.LightningModule):
    """Lightning wrapper around a `segmentation_models_pytorch` network.

    The network is trained with a binary Dice loss computed on raw logits,
    using Adam and a cosine-annealed learning rate.
    """

    def __init__(self, arch, M_CONFIG: ModelConfiguration, **kwargs):
        """Build the network and store the optimisation settings.

        Args:
            arch: Architecture name passed to ``smp.create_model``.
            M_CONFIG: Supplies ``encoder_name``, ``in_channels``,
                ``out_classes``, ``learning_rate`` and ``max_steps``.
            **kwargs: Extra keyword arguments forwarded to ``smp.create_model``.
        """
        super().__init__()
        self.model = smp.create_model(
            arch,
            encoder_name=M_CONFIG.encoder_name,
            in_channels=M_CONFIG.in_channels,
            classes=M_CONFIG.out_classes,
            **kwargs,
        )
        self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
        self.lr = M_CONFIG.learning_rate
        self.t_max = M_CONFIG.max_steps
        # The cosine schedule bottoms out at a thousandth of the base rate.
        self.eta_min = M_CONFIG.learning_rate / 1000

    def forward(self, x):
        """Return the raw logits of the wrapped network for batch ``x``."""
        return self.model(x)

    def _batch_loss(self, batch):
        """Run the network on an ``(x, y)`` batch and return the Dice loss."""
        x, y = batch
        return self.loss_fn(self(x), y)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        loss = self._batch_loss(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation loss for one batch."""
        loss = self._batch_loss(batch)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return Adam plus a cosine schedule that advances every optimizer step."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.t_max, eta_min=self.eta_min
        )
        # T_max comes from `max_steps`, so the scheduler has to advance per
        # step; Lightning's default interval is once per epoch.
        # NOTE(review): assumes `max_steps` counts optimizer steps - confirm.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }

## Train model

In [None]:
import pytorch_lightning as pl
from src.config import ModelConfiguration
from src.utils import get_device

M_CONFIG = ModelConfiguration()

trainer = pl.Trainer(max_epochs=M_CONFIG.epochs, log_every_n_steps=1)
# RoadSegmentationModel takes (arch, M_CONFIG, **kwargs): the encoder name and
# the input/output channel counts are read from M_CONFIG, so the configuration
# object itself is passed rather than separate values.
model = RoadSegmentationModel("FPN", M_CONFIG)
model = model.to(get_device())

trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader,
)

## Use model

In [None]:
# Shapes of the first test sample; expected (C, H, W) for the image and (1, H, W) for the mask.
test_dataset[0][0].shape, test_dataset[0][1].shape

In [None]:
import torch
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
from src.utils import to_device, get_device

# -----------------------------------------------------------------------------
# DATA: take one batch and move both tensors to the compute device
# -----------------------------------------------------------------------------
batch = next(iter(train_dataloader))  # valid_dataloader / test_dataloader work too
images, masks = (to_device(tensor) for tensor in batch)  # [B, 3, H, W], [B, 1, H, W]

# -----------------------------------------------------------------------------
# MODEL: U-Net with an ImageNet-pretrained resnet18 encoder, one output channel
# NOTE(review): this rebinds `model` to a fresh network, so the predictions
# below do not come from the model trained above - confirm this is intended.
# -----------------------------------------------------------------------------
model = smp.Unet(
    encoder_name="resnet18",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
).to(get_device())
model.eval()

# -----------------------------------------------------------------------------
# PREDICTION: forward pass without gradient tracking
# -----------------------------------------------------------------------------
with torch.inference_mode():
    logits = model(images)         # [B, 1, H, W]
    probs = logits.sigmoid()       # logits -> probabilities in [0, 1]
    preds = (probs > 0.5).float()  # threshold into a binary mask

# -----------------------------------------------------------------------------
# PLOT: first sample of the batch - input, ground truth and prediction
# -----------------------------------------------------------------------------
img_np = images[0].permute(1, 2, 0).cpu().numpy()
mask_np = masks[0].squeeze(0).cpu().numpy()
pred_np = preds[0].squeeze(0).cpu().numpy()

panels = [
    ("Image", img_np, None),  # None keeps matplotlib's default colormap handling
    ("Ground Truth", mask_np, "gray"),
    ("Prediction", pred_np, "gray"),
]
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, (title, data, cmap) in zip(axes, panels):
    ax.imshow(data, cmap=cmap)
    ax.set_title(title)
    ax.axis("off")

plt.show()
