In [1]:
# Mount Google Drive at /content/drive so the project folder (weights, test data) is reachable.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [2]:
%cd drive/MyDrive/BKAI_CV/

/content/drive/MyDrive/BKAI_CV


In [None]:
# Imports
import pathlib

import numpy as np
import torch
from skimage.io import imread
from skimage.transform import resize

from UNet.inference import predict
from UNet.transformations import normalize_01, re_normalize
from UNet.unet import UNet

In [None]:
# Anchor project paths (weights, data/test) to the current working directory.
root = pathlib.Path.cwd()
print(f"{root}")

/content/drive/MyDrive/BKAI_CV


In [None]:
model_name = "Polyp.pt"  # file name of the saved state_dict loaded in the model cell

In [None]:
def predict(img, model, preprocess, postprocess, device):
    """Run a single image through ``model`` and return the post-processed prediction.

    NOTE(review): this definition shadows the ``predict`` imported from
    ``UNet.inference`` in the imports cell; keep only one of the two.

    Args:
        img: raw input handed to ``preprocess``.
        model: torch module; put into eval mode before the forward pass.
        preprocess: callable mapping ``img`` to a numpy array the model accepts.
        postprocess: callable applied to the softmax probabilities.
        device: torch device the input tensor is moved to.

    Returns:
        Whatever ``postprocess`` returns.
    """
    model.eval()
    batch = torch.from_numpy(preprocess(img)).to(device)
    with torch.no_grad():
        logits = model(batch)
    # Softmax over dim 1 -- presumably the class/channel axis of the model output.
    probabilities = torch.softmax(logits, dim=1)
    return postprocess(probabilities)

In [None]:
# device: use the GPU when torch can see one, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
print(f"{device}")  # show which device inference will run on

cuda


In [None]:

# model: 2-D U-Net taking 3-channel images and producing 3 output channels
# (one score map per class; predict/postprocess take softmax/argmax over dim 1).
model = UNet(in_channels=3,
             out_channels=3,
             n_blocks=4,
             start_filters=32,
             activation='relu',
             normalization='batch',
             conv_mode='same',
             dim=2).to(device)

# map_location remaps tensors onto the selected device, so a checkpoint saved from a
# GPU session still loads when the cell above fell back to the CPU.
model_weights = torch.load(root / model_name, map_location=device)

model.load_state_dict(model_weights)

<All keys matched successfully>

In [None]:
# preprocess function
def preprocess(img: np.ndarray):
    """Turn an [H, W, C] image into a float32 [1, C, 128, 128] batch for the network.

    Steps: resize to (128, 128, 3), move channels first, linearly scale to [0, 1]
    with ``normalize_01``, prepend a batch axis, cast to float32.
    NOTE(review): inputs with a channel count other than 3 (e.g. RGBA or grayscale)
    are interpolated along the channel axis by ``resize`` -- confirm inputs are 3-channel.
    """
    resized = resize(img, output_shape=(128, 128, 3))
    channels_first = np.moveaxis(resized, -1, 0)   # [H, W, C] -> [C, H, W]
    scaled = normalize_01(channels_first)          # linear scaling to range [0-1]
    batched = np.expand_dims(scaled, 0)            # [C, H, W] -> [B=1, C, H, W]
    return batched.astype(np.float32)

In [None]:
# postprocess function
def postprocess(img: torch.Tensor):
    """Collapse per-class scores into a single-channel mask scaled for display.

    Args:
        img: score/probability tensor with classes along dim 1, e.g. the
            [B, C, H, W] softmax output produced inside ``predict``.

    Returns:
        numpy array of per-pixel class indices passed through ``re_normalize``;
        [H, W] when the batch holds a single image.
    """
    img = torch.argmax(img, dim=1)  # class index per pixel; argmax removes dim 1 -> [B, H, W]
    img = img.cpu().numpy()  # send to cpu and convert to numpy.ndarray
    img = np.squeeze(img)  # drop the batch axis for B == 1 -> [H, W] (any other size-1 axis is dropped too)
    # NOTE(review): if re_normalize is a per-image min-max rescale, the label -> intensity
    # mapping depends on which classes occur in this image; confirm in UNet.transformations.
    img = re_normalize(img)  # scale it to the range [0-255]
    return img

In [None]:
def get_filenames_of_path(path: pathlib.Path, ext: str = '*'):
    """Return the files (not directories) in ``path`` matching glob ``ext``, sorted.

    Args:
        path: directory to search; non-recursive unless ``ext`` contains ``**``.
        ext: glob pattern, e.g. ``'*'`` (default) or ``'*.png'``.

    Returns:
        list of ``pathlib.Path`` in sorted order, so anything indexed by position
        (e.g. predictions built from these files) is reproducible across runs.
    """
    # Path.glob yields entries in filesystem order, which is not guaranteed to be stable.
    return sorted(file for file in path.glob(ext) if file.is_file())

In [None]:
# Load every test image and resize it to the network's 128x128x3 input size.
images_names = get_filenames_of_path(root / 'data' / 'test')
images = list(map(imread, images_names))
# NOTE(review): preprocess() resizes to (128, 128, 3) again, so this pre-resize looks redundant.
images_res = [resize(image, (128, 128, 3)) for image in images]

In [None]:
# One prediction per test image; output[i] corresponds to images_names[i].
output = [
    predict(image, model, preprocess, postprocess, device)
    for image in images_res
]

In [None]:
np.shape(output[0])  # spatial size of the first predicted mask

(128, 128)