In [None]:
!pip install segmentation-models-pytorch
!pip install torchmetrics

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as func
import torch.optim as optim
import torchvision.models as models

Завантаження датасету

In [None]:
import kagglehub

# Fetch the DeepGlobe Road Extraction dataset from Kaggle via kagglehub;
# `path` is the local dataset directory that the following cells read from.
path = kagglehub.dataset_download(
    "balraj98/deepglobe-road-extraction-dataset"
)
print(path)

Організація даних

In [None]:
from pathlib import Path
import pandas as pd
import numpy as np

DATA_DIR = Path(path)
SEED = 42               # shared random_state for the shuffle and the train/valid split
VALID_FRACTION = 0.075  # share of the labelled images held out for validation

# Keep only rows of the 'train' split (the ones with mask paths used below),
# turn the relative paths into absolute ones, and shuffle once with a fixed seed.
# Building the frame as one method chain (.loc + .assign) avoids assigning
# columns into a filtered slice, which is what triggers SettingWithCopyWarning.
metadata_df = (
    pd.read_csv(DATA_DIR / 'metadata.csv')
    .loc[lambda df: df['split'] == 'train', ['image_id', 'sat_image_path', 'mask_path']]
    .assign(
        sat_image_path=lambda df: df['sat_image_path'].apply(lambda p: DATA_DIR / p),
        mask_path=lambda df: df['mask_path'].apply(lambda p: DATA_DIR / p),
    )
    .sample(frac=1, random_state=SEED)
    .reset_index(drop=True)
)

valid_df = metadata_df.sample(frac=VALID_FRACTION, random_state=SEED)
train_df = metadata_df.drop(valid_df.index)
# Sanity check: every shuffled row lands in exactly one of the two subsets.
assert len(train_df) + len(valid_df) == len(metadata_df)

class_dict = pd.read_csv(DATA_DIR / 'class_dict.csv')
class_names = class_dict['name'].tolist()

select_classes = ['background', 'road']
# Positions of the selected classes in class_dict.csv; list.index raises
# ValueError if a class name is missing from the file.
select_class_indices = [class_names.index(cls.lower()) for cls in select_classes]


Реалізація класу завантаження та попередньої обробки даних

In [None]:
import cv2
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision import tv_tensors
import numpy as np
import torch

class RoadExtractionDataset(Dataset):
    """Satellite-image / road-mask pairs for binary road segmentation.

    ``__getitem__`` returns ``(image, mask)``:
      * ``image`` -- ``tv_tensors.Image`` of shape (3, H, W), uint8 RGB values;
      * ``mask``  -- ``tv_tensors.Mask`` of shape (1, H, W), uint8 values in {0, 1}.

    Args:
        df: Table with ``sat_image_path`` and ``mask_path`` columns, indexed
            positionally via ``.iloc``.
        transform: Optional torchvision v2 transform called on the
            ``(image, mask)`` tuple, so geometric augmentations are applied to
            both consistently.
        mask_threshold: Grayscale value above which a mask pixel counts as road.
    """

    def __init__(self, df, transform=None, mask_threshold=128):
        self.df = df
        self.transform = transform
        self.mask_threshold = mask_threshold

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _read_rgb(file_path):
        """Read an image file as an RGB uint8 array, raising if it cannot be read."""
        bgr = cv2.imread(str(file_path))
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if bgr is None:
            raise FileNotFoundError(f"Could not read image: {file_path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _binarize_mask(mask_gray, threshold=128):
        """Map an (H, W) grayscale mask to a (1, H, W) uint8 array of 0/1 labels."""
        return (np.asarray(mask_gray) > threshold).astype(np.uint8)[np.newaxis, ...]

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image = self._read_rgb(row['sat_image_path'])
        mask_gray = cv2.cvtColor(self._read_rgb(row['mask_path']), cv2.COLOR_RGB2GRAY)
        binary_mask = self._binarize_mask(mask_gray, self.mask_threshold)

        # HWC -> CHW for the image; the mask is already channel-first (1, H, W).
        image = tv_tensors.Image(torch.from_numpy(image).permute(2, 0, 1))
        output_mask = tv_tensors.Mask(torch.from_numpy(binary_mask))
        if self.transform:
            image, output_mask = self.transform((image, output_mask))

        return image, output_mask

Пайплайни препроцесингу

In [None]:
from torchvision.transforms import v2
# CPU-side augmentations, run inside the DataLoader workers on (image, mask)
# pairs. Geometric ops transform image and mask together; ColorJitter only
# alters the image.
train_transform_cpu = v2.Compose([
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomVerticalFlip(p=0.5),
    # RandomRotation defaults to NEAREST interpolation, which makes rotated
    # imagery blocky; resample the image bilinearly instead. torchvision
    # still rotates tv_tensors.Mask with nearest-neighbour, so labels stay 0/1.
    v2.RandomRotation(degrees=30, interpolation=v2.InterpolationMode.BILINEAR),
    v2.RandomCrop(size=(512, 512), pad_if_needed=True),
    v2.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.0)
])

# Applied on the device to whole uint8 image batches in both the training and
# validation loops: scale to float32 in [0, 1], then normalise with ImageNet
# mean/std to match the ImageNet-pretrained encoder.
train_transform_gpu = v2.Compose([
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])

Ініціалізація та конфігурація завантажувачів даних

In [None]:
import os

# os.cpu_count() returns None when the core count cannot be determined;
# fall back to 0 workers (load in the main process) in that case.
cpu_num_cores = os.cpu_count() or 0
train_dataset = RoadExtractionDataset(train_df, transform=train_transform_cpu)
# No CPU augmentation for validation; batches only get the device-side
# normalisation in the training cell.
valid_dataset = RoadExtractionDataset(valid_df)

BATCH_SIZE = 8
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=cpu_num_cores, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=cpu_num_cores, pin_memory=True)

print(f"Number of training batches: {len(train_loader)}")
print(f"Number of validation batches: {len(valid_loader)}")

Програмна реалізація циклу навчання (Training Loop) та валідації моделі зі збереженням проміжних ваг та значень метрик

In [None]:
from google.colab import files
import torchmetrics
import segmentation_models_pytorch as smp
from torch.cuda.amp import GradScaler, autocast
import torch.compiler
import warnings

# torch.cuda.amp.* emits deprecation FutureWarnings on recent PyTorch versions.
warnings.filterwarnings("ignore", category=FutureWarning)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Mixed precision is only used on CUDA; disable autocast/GradScaler explicitly on CPU.
use_amp = device.type == "cuda"

model = smp.DeepLabV3Plus(
    encoder_name="timm-mobilenetv3_large_100",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
).to(device)

criterion = nn.BCEWithLogitsLoss().to(device)
# OneCycleLR below drives the actual learning rate at every step.
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
NUM_EPOCHS = 100
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=5e-4,
    steps_per_epoch=len(train_loader),
    epochs=NUM_EPOCHS,
    pct_start=0.3
)
scaler = GradScaler(enabled=use_amp)

train_iou_metric = torchmetrics.JaccardIndex(task="binary").to(device)
valid_iou_metric = torchmetrics.JaccardIndex(task="binary").to(device)

CHECKPOINT_PATH = "my_checkpoint.pth"
PATIENCE = 5       # stale epochs (no validation-IoU gain) tolerated before stopping
MIN_DELTA = 1e-5   # smallest validation-IoU gain that counts as an improvement


def _train_one_epoch(epoch):
    """Run one optimisation pass over train_loader; return (mean loss, IoU)."""
    model.train()
    running_loss = 0.0
    train_iou_metric.reset()

    for batch_idx, (images, masks) in enumerate(train_loader, start=1):
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True).float()
        images_normalized = train_transform_gpu(images)

        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            outputs = model(images_normalized)
            loss = criterion(outputs, masks)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # OneCycleLR is a per-batch schedule (steps_per_epoch above).
        scheduler.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        train_iou_metric.update(outputs.detach().sigmoid(), masks.int())
        print(f"\rEpoch {epoch + 1}/{NUM_EPOCHS}, Batch {batch_idx}/{len(train_loader)}, Loss: {batch_loss:.4f}", end="")

    return running_loss / len(train_loader), train_iou_metric.compute().item()


def _validate_one_epoch(epoch):
    """Evaluate the model on valid_loader without gradients; return (mean loss, IoU)."""
    model.eval()
    running_loss = 0.0
    valid_iou_metric.reset()

    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(valid_loader, start=1):
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True).float()
            images_normalized = train_transform_gpu(images)

            with autocast(enabled=use_amp):
                outputs = model(images_normalized)
                loss = criterion(outputs, masks)

            batch_loss = loss.item()
            running_loss += batch_loss
            valid_iou_metric.update(outputs.sigmoid(), masks.int())
            print(f"\rEpoch {epoch + 1}/{NUM_EPOCHS}, Valid batch {batch_idx}/{len(valid_loader)}, Loss: {batch_loss:.4f}", end="")

    return running_loss / len(valid_loader), valid_iou_metric.compute().item()


train_losses = []
valid_losses = []
train_ious = []
valid_ious = []
best_valid_iou = float("-inf")  # IoU: higher is better
epochs_without_improvement = 0

print("Starting training...")
for epoch in range(NUM_EPOCHS):
    epoch_train_loss, epoch_train_iou = _train_one_epoch(epoch)
    train_losses.append(epoch_train_loss)
    train_ious.append(epoch_train_iou)
    print(f"\rEpoch {epoch + 1}/{NUM_EPOCHS}, Train Loss: {epoch_train_loss:.4f}, Train IoU: {epoch_train_iou:.4f}")

    epoch_valid_loss, epoch_valid_iou = _validate_one_epoch(epoch)
    valid_losses.append(epoch_valid_loss)
    valid_ious.append(epoch_valid_iou)
    print(f"\rEpoch {epoch + 1}/{NUM_EPOCHS}, Valid Loss: {epoch_valid_loss:.4f}, Valid IoU: {epoch_valid_iou:.4f}")

    # Save a checkpoint only when validation IoU actually improves; a drop in
    # IoU must count towards early stopping, not overwrite the best weights.
    if epoch_valid_iou > best_valid_iou + MIN_DELTA:
        best_valid_iou = epoch_valid_iou
        epochs_without_improvement = 0

        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': epoch_valid_loss,
            'iou': epoch_valid_iou
        }
        torch.save(checkpoint, CHECKPOINT_PATH)
        print(f"Checkpoint saved! (Improved Val IoU: {epoch_valid_iou:.4f}, Val Loss: {epoch_valid_loss:.4f})")
    else:
        epochs_without_improvement += 1
        print(f"No improvement. Early stopping counter: {epochs_without_improvement}/{PATIENCE}")
        if epochs_without_improvement >= PATIENCE:
            print(f"Early stopping triggered at epoch {epoch + 1}.")
            break

print("Training finished.")

Збереження значень метрик

In [None]:
import pandas as pd
# Collect the per-epoch learning curves recorded by the training loop into one
# table, write it to CSV, and trigger a browser download (google.colab.files).
data = pd.DataFrame(
    dict(
        train_loss=train_losses,
        valid_loss=valid_losses,
        train_iou=train_ious,
        valid_iou=valid_ious,
    )
)
data.to_csv("train_loss.csv")
files.download("train_loss.csv")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Завантаження моделі

In [None]:
# Send the checkpoint file written by the training loop to the local machine.
checkpoint_file = "my_checkpoint.pth"
files.download(checkpoint_file)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>