In [None]:
# Colab only: mount Google Drive at /content/drive so the project folder
# (ROOT_PATH in the training cell) is reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Work from the project folder on Drive: relative paths used later
# ("saved_images/train", "saved_images/val", "mycheckpoint.pth.tar") resolve against it.
%cd /content/drive/MyDrive/proj_image_segmentation_valid/

/content/drive/MyDrive/proj_image_segmentation_valid


In [None]:
# Pin segmentation_models.pytorch to the commit resolved in the recorded run and
# albumentations to the version that was installed, so Restart & Run All
# reproduces the same environment instead of whatever is on `main` today.
# %pip (rather than !pip) installs into the running kernel's own environment.
%pip install -U "git+https://github.com/qubvel/segmentation_models.pytorch@3d6da1d74636873372c265f300862a6a6d01777d" "albumentations==1.4.8"


Collecting git+https://github.com/qubvel/segmentation_models.pytorch
  Cloning https://github.com/qubvel/segmentation_models.pytorch to /tmp/pip-req-build-0hw3x9zg
  Running command git clone --filter=blob:none --quiet https://github.com/qubvel/segmentation_models.pytorch /tmp/pip-req-build-0hw3x9zg
  Resolved https://github.com/qubvel/segmentation_models.pytorch to commit 3d6da1d74636873372c265f300862a6a6d01777d
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting albumentations
  Downloading albumentations-1.4.8-py3-none-any.whl (156 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m156.8/156.8 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
Collecting pretrainedmodels==0.7.4 (from segmentation_models_pytorch==0.3.4.dev0)
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/

In [None]:
# tqdm and numpy are used by the training cell below.
# %pip (rather than !pip) installs into the running kernel's own environment;
# a single call installs both packages.
%pip install tqdm numpy



In [None]:
import os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import segmentation_models_pytorch as smp

# ---- Optimisation ----
LEARNING_RATE = 1e-4
BATCH_SIZE = 8
NUM_EPOCHS = 400

# ---- Data loading / input geometry ----
NUM_WORKERS = 0
IMAGE_HEIGHT, IMAGE_WIDTH = 256, 256
PIN_MEMORY = True

# ---- Runtime device: prefer the GPU when one is visible ----
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# ---- Checkpoint behaviour ----
LOAD_MODEL = False
SAVE_MODEL = True

# ---- Dataset locations (all relative to the project folder on Drive) ----
ROOT_PATH = "/content/drive/MyDrive/proj_image_segmentation_valid/"
TRAIN_IMG_DIR = os.path.join(ROOT_PATH, "dataset/Data set I/Images/TRAIN_DATA")
TRAIN_MASK_DIR = os.path.join(ROOT_PATH, "dataset/Data set I/Masks/TRAIN_DATA")
VAL_IMG_DIR = os.path.join(ROOT_PATH, "dataset/Data set I/Images/VALIDATION_DATA")
VAL_MASK_DIR = os.path.join(ROOT_PATH, "dataset/Data set I/Masks/VALIDATION_DATA")

class CarbonDataset(Dataset):
    """Paired image/mask segmentation dataset read from two directories.

    Every regular, non-hidden file in ``image_dir`` is one sample. Its mask is
    looked up in ``mask_dir`` under the same file name, with a trailing
    ``.tif``/``.tiff`` extension swapped for ``.png`` (other names are used as-is).

    Args:
        image_dir: Directory holding the input images.
        mask_dir: Directory holding the masks.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries
            (e.g. an albumentations ``Compose``).
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir order is arbitrary; sort so sample indices are reproducible.
        # Hidden files (e.g. .DS_Store) and sub-directories are skipped because
        # Image.open would fail on them in __getitem__.
        self.images = sorted(
            name
            for name in os.listdir(image_dir)
            if not name.startswith(".") and os.path.isfile(os.path.join(image_dir, name))
        )

    @staticmethod
    def _mask_name_for(img_name):
        """Return the mask file name paired with the image file ``img_name``."""
        stem, ext = os.path.splitext(img_name)
        # Swap only the trailing extension: str.replace would also rewrite a
        # ".tif" occurring mid-name and would turn ".tiff" into ".pngf".
        if ext in (".tif", ".tiff"):
            return stem + ".png"
        return img_name

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(image, mask, img_name)`` for sample ``index``.

        ``image`` is the RGB-converted image; ``mask`` is the grayscale ("L")
        mask as float32 holding the file's raw pixel values. Both are passed
        through ``transform`` when one is set.
        """
        img_name = self.images[index]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, self._mask_name_for(img_name))
        # Context managers close the underlying files deterministically instead
        # of leaving that to Pillow / the garbage collector.
        with Image.open(img_path) as img_file:
            image = np.array(img_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.float32)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask, img_name

def save_checkpoint(state, filename="mycheckpoint.pth.tar"):
    """Serialize ``state`` to ``filename`` with ``torch.save``, announcing it on stdout.

    Args:
        state: Any object ``torch.save`` can serialize (here, a dict of state_dicts).
        filename: Destination path; an existing file is overwritten.
    """
    announcement = "=> Saving checkpoint"
    print(announcement)
    torch.save(obj=state, f=filename)

def load_checkpoint(checkpoint, model):
    """Restore ``model`` weights from the ``"state_dict"`` entry of ``checkpoint``.

    ``checkpoint`` is a mapping such as the one written by ``save_checkpoint``;
    any other entries (for example ``"optimizer"``) are not used here.
    """
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

def get_dataloaders(train_img_dir, train_mask_dir, val_img_dir, val_mask_dir, batch_size, transform):
    """Build the training and validation DataLoaders.

    Both splits use the same ``transform``. Training batches are shuffled;
    validation batches keep the dataset order.

    Returns:
        ``(train_loader, val_loader)``.
    """
    # PIN_MEMORY was configured but never applied. Pinned (page-locked) host
    # memory only matters for host->CUDA copies, so enable it only with a GPU
    # (DataLoader warns when pin_memory=True and no accelerator is present).
    pin_memory = PIN_MEMORY and torch.cuda.is_available()

    def _make_loader(image_dir, mask_dir, shuffle):
        # One CarbonDataset + DataLoader pair with the shared loader settings.
        dataset = CarbonDataset(image_dir=image_dir, mask_dir=mask_dir, transform=transform)
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=NUM_WORKERS,
            shuffle=shuffle,
            pin_memory=pin_memory,
        )

    train_loader = _make_loader(train_img_dir, train_mask_dir, shuffle=True)
    val_loader = _make_loader(val_img_dir, val_mask_dir, shuffle=False)
    return train_loader, val_loader

def get_model():
    """Create a U-Net (ResNet-18 encoder with ImageNet weights, 3 input channels,
    3 output classes) and move it to ``DEVICE``."""
    unet_kwargs = dict(
        encoder_name="resnet18",
        encoder_weights="imagenet",
        in_channels=3,
        classes=3,
    )
    network = smp.Unet(**unet_kwargs)
    return network.to(DEVICE)

def dice_score(preds, targets, num_classes=3, smooth=1e-6):
    """Mean Dice coefficient over classes ``0 .. num_classes - 1``.

    For each class, a Dice value is computed per sample by summing over dims 1
    and 2 (so inputs are expected to be label maps like ``(batch, H, W)``). Those
    values are averaged over the remaining dims, and the per-class means are then
    averaged. ``smooth`` is added to the numerator and denominator, so a class
    absent from both prediction and target scores 1.0.

    Returns:
        The score as a Python float.
    """
    pred_labels = preds.contiguous()
    true_labels = targets.contiguous()
    spatial_dims = [1, 2]

    def _class_dice(cls):
        pred_mask = (pred_labels == cls).float()
        true_mask = (true_labels == cls).float()
        overlap = (pred_mask * true_mask).sum(dim=spatial_dims)
        total = pred_mask.sum(dim=spatial_dims) + true_mask.sum(dim=spatial_dims)
        return ((2. * overlap + smooth) / (total + smooth)).mean()

    per_class = [_class_dice(cls) for cls in range(num_classes)]
    return torch.stack(per_class).mean().item()

def save_images(preds, targets, img_names, save_dir, num_classes=3):
    """Write predicted and ground-truth label maps to ``save_dir`` as grayscale images.

    For each sample, ``pred_<name>`` and ``mask_<name>`` are written, where
    ``<name>`` is the matching entry of ``img_names``. Class indices
    ``0 .. num_classes - 1`` are mapped linearly onto intensities in ``[0, 1]``
    (black ... white).

    Args:
        preds: Predicted class indices, one label map per sample.
        targets: Ground-truth class indices, same layout as ``preds``.
        img_names: File names used to build the output names.
        save_dir: Output directory; created if missing.
        num_classes: Number of classes, used to scale labels into ``[0, 1]``.
    """
    os.makedirs(save_dir, exist_ok=True)
    # torchvision.utils.save_image expects floats in [0, 1] and multiplies by 255
    # itself; pre-multiplying by 255 clamped every non-zero class to pure white,
    # making classes 1 and 2 indistinguishable.
    scale = max(num_classes - 1, 1)
    for pred, target, name in zip(preds, targets, img_names):
        torchvision.utils.save_image(
            pred.float() / scale,
            os.path.join(save_dir, f"pred_{name}"),
        )
        torchvision.utils.save_image(
            target.float() / scale,
            os.path.join(save_dir, f"mask_{name}"),
        )

def train_fn(loader, model, optimizer, loss_fn, scaler, save_dir="saved_images/train"):
    """Run one training epoch over ``loader``.

    Each batch is ``(data, targets, img_names)``. The forward pass and loss run
    under autocast when ``DEVICE == "cuda"``, and the update goes through
    ``scaler``. The argmax predictions and targets of every 10th batch are written
    to ``save_dir`` via ``save_images``.

    Returns:
        ``(mean_loss, mean_dice)`` averaged over the epoch's batches, or
        ``(nan, nan)`` if ``loader`` yielded no batches.
    """
    model.train()
    loop = tqdm(loader)
    total_loss = 0.0
    dice_scores = []
    for batch_idx, (data, targets, img_names) in enumerate(loop):
        data = data.to(DEVICE)
        targets = targets.long().to(DEVICE)

        # torch.cuda.amp.autocast is deprecated; torch.autocast is the same
        # context manager with an explicit device type.
        with torch.autocast(device_type="cuda", enabled=DEVICE == "cuda"):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # argmax of the logits is already the predicted class (the softmax that
        # used to be computed here was immediately overwritten). detach() keeps
        # the metric / visualisation work out of the autograd graph.
        preds = predictions.detach().argmax(dim=1)
        dice = dice_score(preds, targets, num_classes=3)

        loss_value = loss.item()
        total_loss += loss_value
        dice_scores.append(dice)
        loop.set_postfix(loss=loss_value, dice=dice)

        if batch_idx % 10 == 0:
            save_images(preds, targets, img_names, save_dir)

    if not dice_scores:
        return float("nan"), float("nan")
    return total_loss / len(dice_scores), np.mean(dice_scores)

def validate_fn(loader, model, loss_fn, save_dir="saved_images/val"):
    """Evaluate ``model`` on ``loader`` in eval mode without gradient tracking.

    Each batch is ``(data, targets, img_names)``. The argmax predictions and
    targets of every 10th batch are written to ``save_dir`` via ``save_images``.

    Returns:
        ``(mean_loss, mean_dice)`` averaged over batches, or ``(nan, nan)`` when
        ``loader`` is empty (previously a ``ZeroDivisionError``).
    """
    model.eval()
    total_loss = 0.0
    dice_scores = []
    with torch.no_grad():
        loop = tqdm(loader)
        for batch_idx, (data, targets, img_names) in enumerate(loop):
            data = data.to(DEVICE)
            targets = targets.long().to(DEVICE)

            predictions = model(data)
            loss = loss_fn(predictions, targets)
            loss_value = loss.item()
            total_loss += loss_value

            # argmax of the logits is the predicted class; a softmax first would
            # not change it (and its result was discarded anyway).
            preds = predictions.argmax(dim=1)
            dice = dice_score(preds, targets, num_classes=3)
            dice_scores.append(dice)

            loop.set_postfix(loss=loss_value, dice=dice)

            if batch_idx % 10 == 0:
                save_images(preds, targets, img_names, save_dir)

    if not dice_scores:
        return float("nan"), float("nan")
    return total_loss / len(dice_scores), np.mean(dice_scores)

def get_transforms():
    """Return the preprocessing pipeline: resize, scale pixels to [0, 1], convert to tensors.

    ``Normalize`` with zero mean, unit std and ``max_pixel_value=255.0`` amounts to
    dividing by 255. NOTE(review): ``get_model`` loads ImageNet encoder weights;
    consider normalising with ImageNet mean/std to match how they were trained.
    """
    steps = [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ]
    return A.Compose(steps)

def main():
    """Train the U-Net, validating (and optionally checkpointing) after every epoch.

    Configuration comes from the module-level constants (``NUM_EPOCHS``,
    ``LEARNING_RATE``, ``BATCH_SIZE``, ``LOAD_MODEL``, ``SAVE_MODEL`` and the data
    directories). With ``SAVE_MODEL`` set, ``save_checkpoint`` overwrites its
    default file each epoch with the model and optimizer state.
    """
    train_transform = get_transforms()
    model = get_model()
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # Loss scaling only matters with CUDA autocast; on CPU the scaler is made an
    # explicit no-op instead of warning that CUDA is unavailable.
    scaler = torch.cuda.amp.GradScaler(enabled=DEVICE == "cuda")

    if LOAD_MODEL:
        # map_location lets a checkpoint written on a GPU be resumed on a CPU runtime.
        checkpoint = torch.load("mycheckpoint.pth.tar", map_location=DEVICE)
        load_checkpoint(checkpoint, model)
        # The optimizer state is saved below; restore it as well so Adam's
        # moment estimates carry over when resuming training.
        if "optimizer" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer"])

    train_loader, val_loader = get_dataloaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform
    )

    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
        train_fn(train_loader, model, optimizer, loss_fn, scaler, save_dir="saved_images/train")

        val_loss, val_dice = validate_fn(val_loader, model, loss_fn, save_dir="saved_images/val")
        print(f"Validation Loss: {val_loss}, Validation Dice Score: {val_dice}")

        if SAVE_MODEL:
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            save_checkpoint(checkpoint)

if __name__ == "__main__":
    main()
