In [None]:
import os


# Run from the project root so relative paths like ./params and ./logs resolve.
# Guarded so that re-running this cell does not keep walking up the tree:
# only step up when the `params/` directory used by later cells isn't here yet.
if not os.path.isdir("params"):
    os.chdir("..")

# Must be set before CUDA is initialised by torch: enumerate GPUs in PCI bus order
# and expose only device 0 to this kernel.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
import numpy as np
from PIL import Image
import torch
from matplotlib import pyplot as plt
import segmentation_models_pytorch as smp
from tqdm import tqdm
import safitty

In [3]:
from src import prepare_for_inference, imread, GlobalDice, GlobalIoU
from src.dataset import SegmentationDataset

In [4]:
# Use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Single source of truth for the experiment config and checkpoint: the same YAML
# rebuilds the model and supplies the dataset/dataloader params used later.
CONFIG_PATH = "./params/pan_eff_net_b0.yml"
CHECKPOINT_PATH = "./logs/lightning_logs/version_1/checkpoints/epoch=39.ckpt"

setup = prepare_for_inference(CONFIG_PATH, CHECKPOINT_PATH)

# Move to the target device and switch to inference behaviour (dropout off,
# BatchNorm uses running stats). Harmless if prepare_for_inference already did it.
model = setup["model"].to(device)
model.eval()

preproc = setup["preprocessing_function"]
params = safitty.load(CONFIG_PATH)

In [5]:
# Work on a copy so the loaded config (`params`) is not mutated and this cell
# stays idempotent when re-run.
dataset_params = dict(params["dataset"]["val"])

# os.listdir order is arbitrary; sort for a reproducible sample order.
images_dir = os.path.join(dataset_params["root_path"], dataset_params["images_folder"])
dataset_params["images"] = sorted(os.listdir(images_dir))

# NOTE(review): this uses smp's ImageNet preprocessing for the configured encoder,
# not `preproc` returned by prepare_for_inference — confirm the two are equivalent.
dataset_params["image_transforms"] = smp.encoders.get_preprocessing_fn(
    params["model"]["params"]["encoder_name"],
    pretrained="imagenet",
)
dataset = SegmentationDataset(**dataset_params)
dataloader = torch.utils.data.DataLoader(dataset, **params["dataloader"]["val"]["params"])

In [6]:
# Accumulate global Dice / IoU over the full validation set.
# Ensure inference behaviour right here, so the metrics are valid even if the
# model's mode was changed after it was loaded.
model.eval()

dice_meter = GlobalDice()
iou_meter = GlobalIoU()

# No gradients are needed for evaluation; one context around the whole loop.
with torch.no_grad():
    for batch in tqdm(dataloader):
        image = batch["image"].to(device)
        mask = batch["mask"].to(device)
        # NOTE(review): assumes the model returns logits; sigmoid turns them into
        # per-pixel probabilities before they are fed to the meters.
        probs = torch.sigmoid(model(image))
        dice_meter.update(probs, mask)
        iou_meter.update(probs, mask)

100%|██████████| 13/13 [00:02<00:00,  4.92it/s]


In [7]:
# Report each accumulated validation metric on its own line.
for metric_label, metric_meter in (("Dice", dice_meter), ("IoU", iou_meter)):
    print(f"{metric_label} on Validation: {metric_meter.get_metric()}")

Dice on Validation: 0.9563417769809763
IoU on Validation: 0.916336168189953
