# PATH MANAGEMENT

In [None]:
import os

print(os.getcwd())
if not os.getcwd().endswith("app"):
    os.chdir("../app")
    print(os.getcwd())

%load_ext autoreload
%autoreload 2

# TRAIN SEGMENTATION MODELS

## Configuration

In [None]:
from src.config import Configuration

# Only max_samples is overridden (presumably to cap how many samples are
# loaded so runs stay short); every other field keeps Configuration's default.
CONFIG = Configuration(max_samples=100)


## Datasets

In [None]:
import os

from torch.utils.data import DataLoader
from src.data import AUG_PIPELINES

from src.models import RoadSegmentationDataset

# Only the training split is given an augmentation pipeline ("single");
# validation/test are built without one (presumably meaning no augmentation
# — confirm against RoadSegmentationDataset's default).
train_dataset = RoadSegmentationDataset(CONFIG.train_folder, CONFIG, AUG_PIPELINES["single"])
valid_dataset = RoadSegmentationDataset(CONFIG.val_folder, CONFIG)
test_dataset  = RoadSegmentationDataset(CONFIG.test_folder, CONFIG)

# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader rejects num_workers=None; fall back to 0 (load in the main process).
n_cpu = os.cpu_count() or 0

# Shuffle only the training data; evaluation order stays deterministic.
train_dataloader = DataLoader(train_dataset, batch_size=CONFIG.batch_size, shuffle=True, num_workers=n_cpu)
valid_dataloader = DataLoader(valid_dataset, batch_size=CONFIG.batch_size, shuffle=False, num_workers=n_cpu)
test_dataloader  = DataLoader(test_dataset, batch_size=CONFIG.batch_size, shuffle=False, num_workers=n_cpu)

In [None]:
# Cursor into train_dataset: the next cell plots sample `i` and advances it.
i = 0
i  # echo the current cursor as the cell output

In [None]:
# Show the training sample under the cursor, then step forward so that
# re-running this cell walks through the dataset one sample at a time.
train_dataset.plot_sample(i)
i = i + 1

## Define model

In [None]:
from src.models import RoadSegmentationModel
from src.utils import get_device

# Build the model from the shared configuration and place it on the device
# chosen by get_device().
device = get_device()
model = RoadSegmentationModel(CONFIG).to(device)

## Train model

In [None]:
import pytorch_lightning as pl

# Log on every optimisation step (log_every_n_steps=1), presumably so that
# short runs over a small dataset still produce training curves.
trainer = pl.Trainer(
    max_epochs=CONFIG.epochs,
    log_every_n_steps=1,
)

# Train on the training split, validating on the validation split.
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)

## Metrics

In [None]:
# Re-run the validation loop on the trained model without the printed table;
# the returned list (one metrics dict per dataloader) is printed instead.
valid_metrics = trainer.validate(
    model,
    dataloaders=valid_dataloader,
    verbose=False,
)
print(valid_metrics)

In [None]:
# Evaluate on the held-out test split.
# NOTE(review): this reuses the validation loop, so the reported metrics carry
# whatever names validation_step logs, even though the data is the test set.
# If RoadSegmentationModel defines test_step, trainer.test(...) would report
# them under test names instead — confirm before switching.
test_metrics = trainer.validate(
    model,
    dataloaders=test_dataloader,
    verbose=False,
)
print(test_metrics)

## Use model

In [None]:
# Sanity check on the first test sample: the image is expected to be
# (C, H, W) and its mask (1, H, W).
first_sample = test_dataset[0]
first_sample[0].shape, first_sample[1].shape

In [None]:
import matplotlib.pyplot as plt
import torch
from src.utils import get_device, to_device
from tqdm import tqdm

# Run the model on up to `max_samples` test samples, then plot each image
# next to its ground-truth mask and the thresholded prediction.
max_samples = 20
threshold = 0.5  # probability cut-off for a positive (foreground) pixel

# Device placement and eval mode do not change between batches: set them once.
model = model.to(get_device())
model.eval()

all_images, all_masks, all_preds = [], [], []
n_collected = 0  # running count, avoids re-summing batch sizes every iteration

for images, masks in tqdm(test_dataloader, desc="Processing batches"):
    images = to_device(images)
    masks  = to_device(masks)

    with torch.inference_mode():
        logits = model(images)
        probs  = torch.sigmoid(logits)  # model outputs are treated as logits
        preds  = (probs > threshold).float()

    all_images.append(images.cpu())
    all_masks.append(masks.cpu())
    all_preds.append(preds.cpu())

    n_collected += images.shape[0]
    if n_collected >= max_samples:
        break

if not all_images:
    # torch.cat([]) would otherwise fail with a less helpful message.
    raise RuntimeError("test_dataloader produced no batches; nothing to visualize.")

# Flatten the batches and limit to max_samples
all_images = torch.cat(all_images)[:max_samples]
all_masks  = torch.cat(all_masks)[:max_samples]
all_preds  = torch.cat(all_preds)[:max_samples]

# Grid: 3 rows (Image, Ground Truth, Prediction) x one column per sample.
n_samples = all_images.shape[0]
# squeeze=False keeps `axes` 2-D even when only a single sample is plotted.
fig, axes = plt.subplots(3, n_samples, figsize=(3 * n_samples, 9), squeeze=False)

# `col` rather than `i`: `i` is the global sample cursor used by plot_sample.
for col in range(n_samples):
    img_np = all_images[col].permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C)
    img_cmap = None
    if img_np.shape[-1] == 1:
        # imshow rejects (H, W, 1); show single-channel images as grayscale.
        img_np = img_np[..., 0]
        img_cmap = "gray"
    mask_np = all_masks[col].squeeze(0).numpy()
    pred_np = all_preds[col].squeeze(0).numpy()

    panels = (
        (img_np, img_cmap, "Image"),
        (mask_np, "gray", "Ground Truth"),
        (pred_np, "gray", "Prediction"),
    )
    for row, (data, cmap, title) in enumerate(panels):
        ax = axes[row, col]
        ax.imshow(data, cmap=cmap)
        ax.axis("off")
        ax.set_title(title)

plt.tight_layout()
plt.show()
