In [None]:
# Mount Google Drive; the project folder (dataset, checkpoint, outputs) used below lives on it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Work from the project folder so any relative paths resolve against it.
%cd /content/drive/MyDrive/proj_image_segmentation_valid/

/content/drive/MyDrive/proj_image_segmentation_valid


In [None]:
!pip install -U git+https://github.com/qubvel/segmentation_models.pytorch albumentations

Collecting git+https://github.com/qubvel/segmentation_models.pytorch
  Cloning https://github.com/qubvel/segmentation_models.pytorch to /tmp/pip-req-build-c3mqphsm
  Running command git clone --filter=blob:none --quiet https://github.com/qubvel/segmentation_models.pytorch /tmp/pip-req-build-c3mqphsm
  Resolved https://github.com/qubvel/segmentation_models.pytorch to commit 3d6da1d74636873372c265f300862a6a6d01777d
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting albumentations
  Downloading albumentations-1.4.8-py3-none-any.whl (156 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m156.8/156.8 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
Collecting pretrainedmodels==0.7.4 (from segmentation_models_pytorch==0.3.4.dev0)
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/

In [None]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import numpy as np
import segmentation_models_pytorch as smp
import torchvision

# Spatial size every test image and mask is resized to before inference.
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Root of the project on the mounted Drive. Every path below is anchored here so this
# cell does not silently depend on the `%cd` cell having been run first.
PROJECT_DIR = "/content/drive/MyDrive/proj_image_segmentation_valid"
DATASET_DIR = os.path.join(PROJECT_DIR, "dataset", "Data set I")
TEST_IMG_DIR = os.path.join(DATASET_DIR, "Images", "TEST_DATA")
TEST_MASK_DIR = os.path.join(DATASET_DIR, "Masks", "TEST_DATA")
CHECKPOINT_PATH = os.path.join(PROJECT_DIR, "mycheckpoint.pth.tar")
SAVE_DIR = os.path.join(PROJECT_DIR, "test_predictions")

class TestDataset(Dataset):
    """Paired (image, mask) test set read from two parallel directories.

    Each file in ``image_dir`` is matched with the file of the same name in
    ``mask_dir``, except that a ``.tif`` substring in the name is replaced by ``.png``.

    Args:
        image_dir: directory holding the input images.
        mask_dir: directory holding the ground-truth masks.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``"image"`` and ``"mask"`` keys
            (e.g. an albumentations ``Compose``).
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary, filesystem-dependent order; sort so
        # sample indices (and hence evaluation/saving order) are reproducible.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(image, mask, img_name)`` for the ``index``-th file.

        ``image`` is loaded as 3-channel RGB and ``mask`` as a single-channel
        float32 array; both go through ``transform`` when one is set.
        """
        img_name = self.images[index]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name.replace(".tif", ".png"))

        # Context managers close the underlying file handles; PIL may otherwise keep
        # them open (e.g. for multi-frame TIFFs) until garbage collection.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        with Image.open(mask_path) as mask_img:
            mask = np.array(mask_img.convert("L"), dtype=np.float32)

        # Remap the 255 label to class index 2. NOTE(review): this assumes the
        # remaining mask values are already the class indices 0/1 — confirm
        # against the mask encoding used during training.
        mask[mask == 255.0] = 2.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask, img_name

def get_transforms():
    """Build the deterministic preprocessing pipeline applied at test time.

    Steps: resize to IMAGE_HEIGHT x IMAGE_WIDTH, rescale pixel values from
    [0, 255] to [0, 1], then convert the numpy image to a torch tensor.
    """
    resize = A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH)
    # With mean 0 and std 1, Normalize reduces to a plain division by max_pixel_value.
    scale_to_unit = A.Normalize(
        mean=[0.0, 0.0, 0.0],
        std=[1.0, 1.0, 1.0],
        max_pixel_value=255.0,
    )
    to_tensor = ToTensorV2()
    return A.Compose([resize, scale_to_unit, to_tensor])

def get_model(in_channels=3, out_channels=3, encoder_name="resnet18", encoder_weights="imagenet"):
    """Build a U-Net segmentation network from segmentation_models_pytorch.

    Args:
        in_channels: number of input image channels.
        out_channels: number of output classes (one logit map per class).
        encoder_name: encoder backbone name understood by ``smp.Unet``.
        encoder_weights: pretrained encoder weights to download, or ``None`` to
            skip the download. Passing ``None`` is useful when every weight is
            restored from a checkpoint right afterwards, since the pretrained
            weights would be overwritten anyway.

    Returns:
        The (untrained or ImageNet-initialised) ``smp.Unet`` model.
    """
    return smp.Unet(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=out_channels,
    )

def load_checkpoint(checkpoint_path, model, device):
    """Restore ``model``'s weights in place from a file written by ``torch.save``.

    Accepts either a training checkpoint dict holding the weights under
    ``"state_dict"`` or a bare state dict.

    Args:
        checkpoint_path: path to the checkpoint file.
        model: module whose parameters are overwritten in place.
        device: forwarded as ``map_location`` so tensors are loaded onto this device
            (e.g. restoring a GPU-saved checkpoint on a CPU-only runtime).

    Returns:
        ``model`` itself, for convenient chaining.
    """
    print("=> Loading checkpoint")
    # SECURITY: torch.load unpickles the file and can execute arbitrary code.
    # Only load trusted checkpoints (or pass weights_only=True on torch >= 1.13).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    is_wrapped = isinstance(checkpoint, dict) and "state_dict" in checkpoint
    state_dict = checkpoint["state_dict"] if is_wrapped else checkpoint
    model.load_state_dict(state_dict)
    return model

def dice_score(preds, targets, num_classes=3, smooth=1e-6):
    """Per-class Dice coefficient between two label maps.

    For each class index ``c`` in ``range(num_classes)``, both inputs are turned
    into binary masks (``== c``) and scored as
    ``(2 * |P & T| + smooth) / (|P| + |T| + smooth)``.
    ``smooth`` avoids division by zero, so a class absent from both maps scores 1.0.

    Returns:
        A list of ``num_classes`` Python floats, ordered by class index.
    """
    def _class_dice(cls):
        pred_mask = (preds == cls).float()
        target_mask = (targets == cls).float()
        overlap = (pred_mask * target_mask).sum()
        total = pred_mask.sum() + target_mask.sum()
        return ((2. * overlap + smooth) / (total + smooth)).item()

    return [_class_dice(cls) for cls in range(num_classes)]

def make_predictions(loader, model, save_dir, num_classes=3):
    """Run ``model`` over ``loader``, save predicted label maps and report mean Dice.

    Args:
        loader: iterable of ``(images, masks, img_names)`` batches.
        model: segmentation network returning per-class logits along dim 1.
        save_dir: directory for the ``pred_<name>.png`` images (created if missing).
        num_classes: number of predicted classes; used for Dice and for scaling
            the saved label maps into distinct grey levels.

    Returns:
        Mean of all per-image, per-class Dice scores, or ``nan`` if ``loader``
        yielded no samples.
    """
    was_training = model.training
    model.eval()
    os.makedirs(save_dir, exist_ok=True)
    # save_image expects intensities in [0, 1] and clamps anything above, so raw
    # class indices 1 and 2 would both be written as white. Scaling by C-1 keeps
    # every class at a distinct grey level (0 -> black, C-1 -> white).
    label_scale = max(num_classes - 1, 1)
    dice_scores = []

    try:
        for data, masks, img_names in tqdm(loader):
            data = data.to(device=DEVICE)
            masks = masks.to(device=DEVICE)
            with torch.no_grad():
                # softmax is monotonic, so argmax over raw logits gives the same labels.
                preds = torch.argmax(model(data), dim=1).float()

            for pred, mask, img_name in zip(preds, masks, img_names):
                base_name = os.path.basename(img_name).replace(".tif", ".png")
                pred_path = os.path.join(save_dir, f"pred_{base_name}")
                torchvision.utils.save_image(pred / label_scale, pred_path)
                dice_scores.extend(dice_score(pred, mask, num_classes=num_classes))
    finally:
        # Restore whatever train/eval mode the caller had the model in.
        model.train(was_training)

    if not dice_scores:
        # np.mean([]) would warn and return nan; make the empty case explicit.
        print("Average Dice Score: nan (no samples)")
        return float("nan")

    avg_dice_score = np.mean(dice_scores)
    print(f"Average Dice Score: {avg_dice_score}")
    return avg_dice_score

def test():
    """Evaluate the saved checkpoint on the test split.

    Builds the model, restores its weights from CHECKPOINT_PATH, runs inference
    over the images in TEST_IMG_DIR (scored against TEST_MASK_DIR) one at a time,
    writes predictions to SAVE_DIR and prints the average Dice score.
    """
    transform = get_transforms()
    network = get_model(in_channels=3, out_channels=3).to(DEVICE)
    load_checkpoint(CHECKPOINT_PATH, network, DEVICE)

    dataset = TestDataset(
        image_dir=TEST_IMG_DIR,
        mask_dir=TEST_MASK_DIR,
        transform=transform,
    )
    # shuffle=False iterates the dataset in index order; num_workers=0 loads
    # batches in the main process.
    loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)

    make_predictions(loader, network, SAVE_DIR)


# Inside a notebook kernel __name__ is "__main__", so running this cell runs the evaluation.
if __name__ == "__main__":
    test()


Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 168MB/s]


=> Loading checkpoint


100%|██████████| 50/50 [00:50<00:00,  1.00s/it]

Average Dice Score: 0.8992987055579821



