In [1]:
import os
import sys

sys.path.append("../../../")

import yaml
import torch
from typing import Dict
from architectures.build_architecture import build_architecture
from dataloaders.build_dataset import build_dataset
from torchvision.utils import save_image

In [2]:
def load_config(config_path: str) -> Dict:
    """loads the yaml config file

    Args:
        config_path (str): _description_

    Returns:
        Dict: _description_
    """
    with open(config_path, "r") as file:
        config = yaml.safe_load(file)
    return config

In [3]:
config = load_config("config.yaml")

Set Up

In [4]:
model = build_architecture(config=config)
checkpoint = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(checkpoint)
model = model.to("cpu")
model = model.eval()

In [5]:
# build validation dataset & validataion data loader
valset = build_dataset(
    dataset_type=config["dataset_parameters"]["dataset_type"],
    dataset_args=config["dataset_parameters"]["val_dataset_args"],
    augmentation_args=config["test_augmentation_args"],
)
testloader = torch.utils.data.DataLoader(
    valset, batch_size=1, shuffle=False, num_workers=1
)

Inference

In [6]:
if not os.path.isdir("predictions"):
    os.makedirs("predictions/rgb")
    os.makedirs("predictions/gt")
    os.makedirs("predictions/pred")

In [7]:
model.eval()
counter = 4000
for idx, data in enumerate(testloader):
    image = data["image"].cuda()
    mask = data["mask"].cuda()
    out = model.forward(image)
    out = torch.sigmoid(out)
    out[out < 0.5] = 0
    out[out >= 0.5] = 1

    # rgb input
    img = data["image"]
    img = img.detach()
    print(img.shape)
    img[:, 0, :, :] = (img[:, 0, :, :] * 0.1577) + 0.7128
    img[:, 1, :, :] = (img[:, 1, :, :] * 0.1662) + 0.6000
    img[:, 2, :, :] = (img[:, 2, :, :] * 0.1829) + 0.5532
    save_image(img, f"predictions/rgb/image_{idx}.png")

    # prediction
    pred = out.detach()
    pred = pred * 255.0
    save_image(pred, f"predictions/pred/pred_{idx}.png")

    # ground truth
    gt = data["mask"]
    gt = gt.detach()
    gt = gt * 255.0
    save_image(gt, f"predictions/gt/gt_{idx}.png")

    if idx == counter:
        break

torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1, 3, 512, 512])
torch.Size([1,

Save RGB, GT and Pred

In [14]:
# # RGB INPUT

# img = data["image"]
# img = img.detach()
# print(img.shape)
# img[:, 0, :, :] = (img[:, 0, :, :] * 0.1577) + 0.7128
# img[:, 1, :, :] = (img[:, 1, :, :] * 0.1662) + 0.6000
# img[:, 2, :, :] = (img[:, 2, :, :] * 0.1829) + 0.5532
# save_image(img, f"predictions/rgb/image_{idx}.png")

In [15]:
# # PREDICTION

# pred = out.detach()
# pred = pred * 255.0
# save_image(pred, f"predictions/pred/pred_{idx}.png")

In [16]:
# # GROUND TRUTH

# gt = data[1]
# gt = gt.detach()
# gt = gt * 255.0
# save_image(gt, f"predictions/gt/gt_{idx}.png")