In [1]:
# Mount Google Drive at /content/drive so that files stored under
# /content/drive/MyDrive (used by the paths further down) are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Work from the project folder so the relative paths used later
# (CHECKPOINT_PATH, SAVE_DIR) resolve inside it.
%cd /content/drive/MyDrive/proj_image_segmentation_valid/

/content/drive/MyDrive/proj_image_segmentation_valid


In [None]:
# Use the %pip magic (not !pip) so the packages are installed into the
# environment of the running kernel rather than whatever `pip` the shell finds.
# TODO: pin exact versions / a commit hash once known, so re-runs are reproducible.
%pip install -q -U git+https://github.com/qubvel/segmentation_models.pytorch albumentations

In [4]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import numpy as np
import segmentation_models_pytorch as smp
import torchvision

# --- Inference configuration -------------------------------------------------

# Spatial size every input image is resized to before the forward pass.
IMAGE_HEIGHT = IMAGE_WIDTH = 256

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Folder holding the images to run inference on.
TEST_IMG_DIR = "/content/drive/MyDrive/proj_image_segmentation_valid/dataset/Data set II/Images/TEST_DATA"

# Checkpoint file to restore; relative to the current working directory.
CHECKPOINT_PATH = "mycheckpoint.pth.tar"

# Output folder for the predicted masks; relative to the current working directory.
SAVE_DIR = "test_predictions_no_dice/"

class TestDataset(Dataset):
    """Unlabelled image-folder dataset used for inference.

    Every regular file directly inside ``image_dir`` is treated as one sample.
    Items are returned as ``(image, file_name)`` so the caller can name its
    output after the input file.

    Args:
        image_dir: Directory containing the input images.
        transform: Optional callable invoked as ``transform(image=array)`` that
            returns a mapping with an ``"image"`` entry. When ``None`` the
            RGB NumPy array is returned as is.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        # os.listdir yields entries in arbitrary order and also lists
        # sub-directories. Keep regular files only and sort them so the sample
        # order is reproducible and a stray folder cannot crash __getitem__.
        self.images = sorted(
            entry
            for entry in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, entry))
        )

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, file_name)``."""
        img_name = self.images[index]
        img_path = os.path.join(self.image_dir, img_name)
        # Open through a context manager so the file handle is released as soon
        # as the pixels have been copied into the array.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))

        if self.transform is not None:
            augmentations = self.transform(image=image)
            image = augmentations["image"]

        return image, img_name

def get_transforms():
    """Build the preprocessing pipeline applied to every test image.

    Returns:
        An ``A.Compose`` that resizes to ``IMAGE_HEIGHT`` x ``IMAGE_WIDTH``,
        normalizes, and converts the result with ``ToTensorV2``.
    """
    resize = A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH)
    # Zero mean / unit std with max_pixel_value=255: the pixel values are only
    # rescaled by 1/255, no per-channel standardization is applied.
    rescale = A.Normalize(
        mean=[0.0, 0.0, 0.0],
        std=[1.0, 1.0, 1.0],
        max_pixel_value=255.0
    )
    to_tensor = ToTensorV2()

    return A.Compose([resize, rescale, to_tensor])

def get_model(in_channels=3, out_channels=3, encoder_name="resnet18", encoder_weights="imagenet"):
    """Create the U-Net segmentation network.

    Args:
        in_channels: Number of channels of the input images.
        out_channels: Number of output classes (channels of the logits).
        encoder_name: Backbone used as the U-Net encoder.
        encoder_weights: Pre-trained weights to initialize the encoder with.
            Pass ``None`` to skip the pre-trained weight download, e.g. when
            a checkpoint is loaded into the model right afterwards and would
            overwrite those weights anyway.

    Returns:
        An ``smp.Unet`` instance (not yet moved to any device).
    """
    model = smp.Unet(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=out_channels,
    )
    return model

def load_checkpoint(checkpoint_path, model, device):
    """Restore ``model`` weights from a checkpoint file, in place.

    Two checkpoint layouts are accepted: a dict that stores the weights under
    the ``"state_dict"`` key, or a bare state dict saved directly.

    Args:
        checkpoint_path: Path of the file written by ``torch.save``.
        model: Module whose parameters are overwritten.
        device: Passed to ``torch.load`` as ``map_location`` so tensors saved
            on another device can be loaded here.

    Raises:
        RuntimeError: From ``load_state_dict`` when the stored keys or shapes
            do not match ``model``.
    """
    print("=> Loading checkpoint")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        # The file holds the state dict itself rather than a wrapper dict.
        state_dict = checkpoint

    model.load_state_dict(state_dict)

def make_predictions(loader, model, save_dir):
    """Run ``model`` over ``loader`` and save one grayscale PNG mask per image.

    The predicted class of each pixel is the argmax over the channel dimension
    of the model output. Class indices are spread evenly over the gray range
    (class 0 -> black, last class -> white) so that every class stays
    distinguishable in the saved file.

    Args:
        loader: Iterable yielding ``(images, file_names)`` batches.
        model: Segmentation network returning per-class scores along dim 1.
        save_dir: Output directory; created when missing. Files are written
            as ``pred_<input stem>.png``.

    Side effects:
        The model is switched to eval mode for the duration of the call and
        put back into whatever mode it was in before, even on error.
    """
    os.makedirs(save_dir, exist_ok=True)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for data, img_names in tqdm(loader):
                logits = model(data.to(device=DEVICE))
                # softmax is monotonic, so argmax over the raw scores selects
                # the same class as argmax over the probabilities.
                preds = torch.argmax(logits, dim=1).float()
                # save_image multiplies by 255 and clamps to [0, 255]; raw
                # labels 1, 2, ... would therefore all saturate to white.
                # Rescale the labels into [0, 1] first.
                preds = preds / max(logits.shape[1] - 1, 1)

                for pred, img_name in zip(preds, img_names):
                    # Swap whatever extension the input had for .png (lossless),
                    # instead of only rewriting a literal ".tif" substring.
                    stem, _ = os.path.splitext(os.path.basename(img_name))
                    pred_path = os.path.join(save_dir, f"pred_{stem}.png")
                    torchvision.utils.save_image(pred, pred_path)
    finally:
        # Restore the caller's mode rather than forcing training mode.
        model.train(was_training)

def test():
    """Run the whole inference pipeline with the module-level configuration.

    Builds the network, restores ``CHECKPOINT_PATH`` into it, wraps
    ``TEST_IMG_DIR`` in a dataset/loader and writes the predictions to
    ``SAVE_DIR``.
    """
    preprocessing = get_transforms()

    net = get_model(in_channels=3, out_channels=3)
    net = net.to(DEVICE)
    load_checkpoint(CHECKPOINT_PATH, net, DEVICE)

    # One image per batch, loaded in the main process, in dataset order.
    dataset = TestDataset(image_dir=TEST_IMG_DIR, transform=preprocessing)
    batches = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)

    make_predictions(batches, net, SAVE_DIR)

# Entry point: run the inference pipeline when this code is executed as the
# main module (not when it is imported from elsewhere).
if __name__ == "__main__":
    test()


Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 235MB/s]


=> Loading checkpoint


100%|██████████| 3/3 [00:02<00:00,  1.46it/s]
