In [1]:
!pip install -q segmentation-models-pytorch timm albumentations

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import os
import cv2
import numpy as np
import torch
import random
import albumentations as A
import segmentation_models_pytorch as smp

from torch.utils.data import Dataset, DataLoader
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

In [3]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed()

In [4]:
TRAIN_IMG = "/kaggle/input/datasets/arunkumarkorra/hwi-dataset/Offroad_Segmentation_Training_Dataset/Offroad_Segmentation_Training_Dataset/train/Color_Images"
TRAIN_MASK = "/kaggle/input/datasets/arunkumarkorra/hwi-dataset/Offroad_Segmentation_Training_Dataset/Offroad_Segmentation_Training_Dataset/train/Segmentation"

VAL_IMG = "/kaggle/input/datasets/arunkumarkorra/hwi-dataset/Offroad_Segmentation_Training_Dataset/Offroad_Segmentation_Training_Dataset/val/Color_Images"
VAL_MASK = "/kaggle/input/datasets/arunkumarkorra/hwi-dataset/Offroad_Segmentation_Training_Dataset/Offroad_Segmentation_Training_Dataset/val/Segmentation"

TEST_IMG = "/kaggle/input/datasets/arunkumarkorra/hwi-dataset/Offroad_Segmentation_testImages/Offroad_Segmentation_testImages/Color_Images"

In [5]:
value_map = {
    0: 0,
    100: 1,
    200: 2,
    300: 3,
    500: 4,
    550: 5,
    700: 6,
    800: 7,
    7100: 8,
    10000: 9
}

NUM_CLASSES = 10

In [6]:
class OffroadDataset(Dataset):
    def __init__(self, img_dir, mask_dir=None, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.images = sorted(os.listdir(img_dir))
        self.transform = transform

    def convert_mask(self, mask):
        new_mask = np.zeros_like(mask)
        for raw, new in value_map.items():
            new_mask[mask == raw] = new
        return new_mask

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.img_dir, img_name)

        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.mask_dir:
            mask_path = os.path.join(self.mask_dir, img_name)
            mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
            mask = self.convert_mask(mask)

            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"].long()

            return image, mask
        else:
            augmented = self.transform(image=image)
            image = augmented["image"]
            return image, img_name

In [7]:
train_transform = A.Compose([
    A.Resize(512, 512),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.4),
    A.HueSaturationValue(p=0.3),
    A.GaussianBlur(p=0.2),
    A.Normalize(),
    ToTensorV2()
])

val_transform = A.Compose([
    A.Resize(512, 512),
    A.Normalize(),
    ToTensorV2()
])

In [8]:
train_dataset = OffroadDataset(TRAIN_IMG, TRAIN_MASK, train_transform)
val_dataset = OffroadDataset(VAL_IMG, VAL_MASK, val_transform)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = smp.Segformer(
    encoder_name="mit_b3",
    encoder_weights="imagenet",
    in_channels=3,
    classes=NUM_CLASSES
)

model.to(device)

config.json:   0%|          | 0.00/135 [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/178M [00:00<?, ?B/s]

Segformer(
  (encoder): MixVisionTransformerEncoder(
    (patch_embed1): OverlapPatchEmbed(
      (proj): Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
      (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed2): OverlapPatchEmbed(
      (proj): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed3): OverlapPatchEmbed(
      (proj): Conv2d(128, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed4): OverlapPatchEmbed(
      (proj): Conv2d(320, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (block1): Sequential(
      (0): Block(
        (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_featur

In [10]:
dice_loss = smp.losses.DiceLoss(mode="multiclass")
ce_loss = torch.nn.CrossEntropyLoss()

def loss_fn(pred, target):
    return dice_loss(pred, target) + ce_loss(pred, target)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=40
)

In [11]:
def compute_iou(pred, mask):
    pred = torch.argmax(pred, dim=1)

    ious = []
    for cls in range(NUM_CLASSES):
        intersection = ((pred == cls) & (mask == cls)).sum().item()
        union = ((pred == cls) | (mask == cls)).sum().item()
        if union == 0:
            continue
        ious.append(intersection / union)
    return np.mean(ious)

In [12]:
best_iou = 0
EPOCHS = 40

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0

    for imgs, masks in tqdm(train_loader):
        imgs, masks = imgs.to(device), masks.to(device)

        preds = model(imgs)
        loss = loss_fn(preds, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    scheduler.step()

    # Validation
    model.eval()
    val_iou = []

    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            preds = model(imgs)
            val_iou.append(compute_iou(preds, masks))

    mean_iou = np.mean(val_iou)

    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {train_loss/len(train_loader):.4f} | Val IoU: {mean_iou:.4f}")

    if mean_iou > best_iou:
        best_iou = mean_iou
        torch.save(model.state_dict(), "best_model.pth")
        print("Saved best model")

100%|██████████| 715/715 [14:18<00:00,  1.20s/it]


Epoch 1/40 | Loss: 0.9547 | Val IoU: 0.5083
Saved best model


100%|██████████| 715/715 [14:15<00:00,  1.20s/it]


Epoch 2/40 | Loss: 0.7605 | Val IoU: 0.5338
Saved best model


100%|██████████| 715/715 [14:15<00:00,  1.20s/it]


Epoch 3/40 | Loss: 0.7115 | Val IoU: 0.5573
Saved best model


100%|██████████| 715/715 [14:15<00:00,  1.20s/it]


Epoch 4/40 | Loss: 0.7023 | Val IoU: 0.5663
Saved best model


100%|██████████| 715/715 [14:15<00:00,  1.20s/it]


Epoch 5/40 | Loss: 0.6608 | Val IoU: 0.5770
Saved best model


100%|██████████| 715/715 [14:15<00:00,  1.20s/it]


Epoch 6/40 | Loss: 0.6513 | Val IoU: 0.5823
Saved best model


100%|██████████| 715/715 [14:15<00:00,  1.20s/it]


Epoch 7/40 | Loss: 0.6426 | Val IoU: 0.5771


100%|██████████| 715/715 [14:15<00:00,  1.20s/it]


Epoch 8/40 | Loss: 0.6723 | Val IoU: 0.5705


100%|██████████| 715/715 [14:15<00:00,  1.20s/it]


Epoch 9/40 | Loss: 0.6325 | Val IoU: 0.5935
Saved best model


100%|██████████| 715/715 [14:15<00:00,  1.20s/it]


Epoch 10/40 | Loss: 0.6211 | Val IoU: 0.5958
Saved best model


100%|██████████| 715/715 [14:15<00:00,  1.20s/it]


Epoch 11/40 | Loss: 0.6100 | Val IoU: 0.5997
Saved best model


100%|██████████| 715/715 [14:15<00:00,  1.20s/it]


Epoch 12/40 | Loss: 0.6083 | Val IoU: 0.6025
Saved best model


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 13/40 | Loss: 0.6059 | Val IoU: 0.6017


100%|██████████| 715/715 [14:15<00:00,  1.20s/it]


Epoch 14/40 | Loss: 0.5988 | Val IoU: 0.6025
Saved best model


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 15/40 | Loss: 0.5950 | Val IoU: 0.6004


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 16/40 | Loss: 0.5912 | Val IoU: 0.6093
Saved best model


100%|██████████| 715/715 [14:15<00:00,  1.20s/it]


Epoch 17/40 | Loss: 0.5814 | Val IoU: 0.6054


100%|██████████| 715/715 [14:15<00:00,  1.20s/it]


Epoch 18/40 | Loss: 0.5800 | Val IoU: 0.6093
Saved best model


100%|██████████| 715/715 [14:15<00:00,  1.20s/it]


Epoch 19/40 | Loss: 0.5837 | Val IoU: 0.5845


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 20/40 | Loss: 0.5748 | Val IoU: 0.5889


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 21/40 | Loss: 0.5812 | Val IoU: 0.6030


100%|██████████| 715/715 [14:17<00:00,  1.20s/it]


Epoch 22/40 | Loss: 0.5720 | Val IoU: 0.5920


100%|██████████| 715/715 [14:17<00:00,  1.20s/it]


Epoch 23/40 | Loss: 0.5636 | Val IoU: 0.5763


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 24/40 | Loss: 0.5664 | Val IoU: 0.4926


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 25/40 | Loss: 0.5628 | Val IoU: 0.4277


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 26/40 | Loss: 0.5624 | Val IoU: 0.3070


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 27/40 | Loss: 0.5613 | Val IoU: 0.1834


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 28/40 | Loss: 0.5608 | Val IoU: 0.1994


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 29/40 | Loss: 0.5602 | Val IoU: 0.5318


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 30/40 | Loss: 0.5594 | Val IoU: 0.4638


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 31/40 | Loss: 0.5611 | Val IoU: 0.5539


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 32/40 | Loss: 0.5533 | Val IoU: 0.4839


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 33/40 | Loss: 0.5521 | Val IoU: 0.5152


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 34/40 | Loss: 0.5516 | Val IoU: 0.4525


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 35/40 | Loss: 0.5494 | Val IoU: 0.4987


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 36/40 | Loss: 0.5498 | Val IoU: 0.5581


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 37/40 | Loss: 0.5496 | Val IoU: 0.5944


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 38/40 | Loss: 0.5493 | Val IoU: 0.6146
Saved best model


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 39/40 | Loss: 0.5488 | Val IoU: 0.6257
Saved best model


100%|██████████| 715/715 [14:16<00:00,  1.20s/it]


Epoch 40/40 | Loss: 0.5533 | Val IoU: 0.6284
Saved best model
