# Imports

In [23]:
import sys
import yaml

import torch
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2

sys.path.append('../')
from model import UNET
from utils import (
    load_checkpoint,
)

# Configs

In [6]:
# Read the settings.
# yaml.safe_load only constructs plain Python objects (dicts, lists, scalars),
# so the config file cannot trigger arbitrary code execution. The encoding is
# given explicitly so the file is decoded the same way on every platform.
with open('../configs/test_model.yaml', 'r', encoding='utf-8') as f:
    content = yaml.safe_load(f)

# Each setting is a required top-level key; a missing key raises KeyError.
DEVICE = content['DEVICE']  # presumably a torch device string such as "cuda"/"cpu" — confirm in config
IMAGE_HEIGHT = content['IMAGE_HEIGHT']  # resize target height for input images
IMAGE_WIDTH = content['IMAGE_WIDTH']  # resize target width for input images
IN_CHANNELS = content['IN_CHANNELS']  # UNET input channels
OUT_CHANNELS = content['OUT_CHANNELS']  # UNET output channels
BATCH_SIZE = content['BATCH_SIZE']
NUM_WORKERS = content['NUM_WORKERS']
PIN_MEMORY = content['PIN_MEMORY']
IMAGE_DIR = '../data/custom_images'  # folder of images to run inference on; edit as needed

In [5]:
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

class CustomDataset(Dataset):
    """Unlabeled image dataset over the image files in a single directory.

    Each item is one image, opened with PIL, converted to RGB and turned into a
    NumPy array. If ``transform`` is given, it is called as
    ``transform(image=array)`` and the ``"image"`` entry of its result is
    returned instead (albumentations-style interface).

    Args:
        image_dir: Directory containing the input images.
        transform: Optional callable applied to each image (see above).
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort them so each
        # index always maps to the same file. Skip sub-directories (e.g.
        # ".ipynb_checkpoints") and hidden files (e.g. ".DS_Store"), which
        # PIL cannot open as images.
        self.images = sorted(
            name
            for name in os.listdir(image_dir)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, index):
        """Load image ``index`` as RGB and apply ``transform`` if one is set."""
        img_path = os.path.join(self.image_dir, self.images[index])
        image = np.array(Image.open(img_path).convert("RGB"))

        if self.transform is not None:
            augmentations = self.transform(image=image)
            image = augmentations["image"]

        return image

# Model Loading and Dataset Creation

In [19]:
# Preprocessing applied to every inference image. With mean 0 and std 1,
# A.Normalize only rescales pixel values from [0, 255] to [0, 1]; presumably
# this mirrors the training-time preprocessing — confirm against training code.
test_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

model = UNET(in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS).to(DEVICE)

# Load the trained weights: model checkpoints should be in this directory.
# map_location remaps tensors that were saved on another device (e.g. a GPU
# checkpoint) onto DEVICE, so loading also works on a CPU-only machine.
load_checkpoint(
    torch.load("../model_checkpoints/checkpoint_last.pth", map_location=DEVICE),
    model,
)

# Build the inference dataset over the images in IMAGE_DIR.
dataset = CustomDataset(image_dir=IMAGE_DIR,
                        transform=test_transforms)

=> Loading checkpoint


# Dataloaders

In [20]:
from torch.utils.data import DataLoader

# Batched iteration over the inference dataset. Batch size, worker count and
# pinned-memory behaviour all come from the YAML config.
data_loader = DataLoader(
    dataset,
    shuffle=False,  # keep dataset order so batch indices are reproducible
    batch_size=BATCH_SIZE,
    pin_memory=PIN_MEMORY,
    num_workers=NUM_WORKERS,
)

# Inference

In [25]:
# Do the inference: write a thresholded prediction mask and the corresponding
# preprocessed input image to `folder` for every batch.
import torchvision

folder = '../inferences'
os.makedirs(folder, exist_ok=True)

# Send inputs to whatever device the model's weights actually live on (the
# model was moved to DEVICE). Picking "cuda if available" independently would
# put inputs and weights on different devices whenever DEVICE is "cpu" on a
# CUDA machine, and the forward pass would fail.
device = next(model.parameters()).device
model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    for idx, x in enumerate(data_loader):
        x = x.to(device=device)
        # Sigmoid + 0.5 threshold yields a binary {0, 1} mask per output
        # channel (assumes a binary / multi-label segmentation head — confirm).
        preds = torch.sigmoid(model(x))
        preds = (preds > 0.5).float()
        # save_image tiles a batch of several images into one grid, so each
        # file holds a whole batch.
        torchvision.utils.save_image(
            preds, f"{folder}/pred_{idx}.png"
        )
        torchvision.utils.save_image(
            x, f"{folder}/original_{idx}.png"
        )