In [1]:
from pathlib import Path

import cv2
import torch

from src.modelling.production import FPNMOCOEnsemble, FootPrintModel, UnetMOCO, \
    UnetPlusPlusMOCOEnsemble

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

WEIGHTS_DIR = "../artifacts/weights"
N_FOLDS = 5  # each ensemble loads checkpoints suffixed _0 .. _{N_FOLDS - 1}

# NOTE(review): unet_moco is loaded here but is commented out of the Ensemble
# in the next cell; drop this load if it is not needed to save memory/time.
unet_moco = UnetMOCO(f"{WEIGHTS_DIR}/Unet_rn50_MOCO.pth", device)

# Previously hard-coded to "cuda", which fails on CPU-only machines even though
# `device` above falls back to CPU. Pass the resolved device type string
# ("cuda" or "cpu"), keeping the string form the original call used.
footprint_model = FootPrintModel(device=device.type)

# NOTE(review): the two ensembles are not given `device`; confirm they select
# the same device as the models above.
unetplusplus = UnetPlusPlusMOCOEnsemble(weights_paths=[
    f"{WEIGHTS_DIR}/unet_plus_plus_{fold}.pth" for fold in range(N_FOLDS)
])
fpn = FPNMOCOEnsemble(weights_paths=[
    f"{WEIGHTS_DIR}/fpn_dice_{fold}.pth" for fold in range(N_FOLDS)
])

In [2]:
from src.modelling.ensemble import Ensemble

# Combine the production models into one Ensemble. The standalone UnetMOCO
# model is intentionally left out, and no explicit `weights` are passed, so
# Ensemble's default weighting applies.
ensemble_members = {
    "footprint": footprint_model,
    "fpn": fpn,
    "unetplusplus": unetplusplus,
}
ensemble = Ensemble(models=ensemble_members)

In [3]:
from src.modelling.predict import ShiftedPredictor

# Wrap the ensemble in a ShiftedPredictor; it is used below via
# `predict_many` on the full loaded images.
predictor = ShiftedPredictor(ensemble)

In [4]:
import glob

IMAGES_DIR = "../data/digital_leaders/images"

# Sorted so that `images`, `paths` and the later predictions line up by index.
paths = sorted(glob.glob(f"{IMAGES_DIR}/*.png"))
if not paths:
    raise FileNotFoundError(f"No .png images found in {IMAGES_DIR}")

images = []
for filename in paths:
    # cv2.imread returns None (instead of raising) for missing/unreadable files,
    # which would otherwise surface as an opaque error inside cvtColor.
    bgr_image = cv2.imread(filename)
    if bgr_image is None:
        raise IOError(f"Could not read image: {filename}")
    # OpenCV loads channels in BGR order; convert to RGB before prediction.
    images.append(cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB))

In [5]:
# Run the ensemble-backed predictor over every loaded image.
# NOTE(review): presumably returns one per-pixel score map per input image, in
# the same order as `images`; later cells threshold these and zip them with
# `paths`, so that ordering is relied upon — confirm.
preds = predictor.predict_many(images)

  0%|          | 0/225 [00:00<?, ?it/s]

  0%|          | 0/240 [00:00<?, ?it/s]

  0%|          | 0/96 [00:00<?, ?it/s]

  0%|          | 0/104 [00:00<?, ?it/s]

  0%|          | 0/169 [00:00<?, ?it/s]

  0%|          | 0/169 [00:00<?, ?it/s]

  0%|          | 0/192 [00:00<?, ?it/s]

  0%|          | 0/192 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

  0%|          | 0/399 [00:00<?, ?it/s]

  0%|          | 0/420 [00:00<?, ?it/s]

  0%|          | 0/143 [00:00<?, ?it/s]

  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/143 [00:00<?, ?it/s]

  0%|          | 0/168 [00:00<?, ?it/s]

  0%|          | 0/208 [00:00<?, ?it/s]

  0%|          | 0/208 [00:00<?, ?it/s]

  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

  0%|          | 0/99 [00:00<?, ?it/s]

  0%|          | 0/110 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?it/s]

  0%|          | 0/132 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/81 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/56 [00:00<?, ?it/s]

  0%|          | 0/110 [00:00<?, ?it/s]

  0%|          | 0/132 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?it/s]

  0%|          | 0/132 [00:00<?, ?it/s]

  0%|          | 0/425 [00:00<?, ?it/s]

  0%|          | 0/450 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?it/s]

  0%|          | 0/91 [00:00<?, ?it/s]

  0%|          | 0/112 [00:00<?, ?it/s]

In [6]:
# Cut-off applied to each prediction map to produce a 0/1 mask
# (presumably per-pixel probabilities — confirm against the ensemble output).
PRED_THRESHOLD = 0.5

binary_preds = [(pred > PRED_THRESHOLD).astype("uint8") for pred in preds]
# Quick sanity check on the first mask: shape, dtype and value range.
binary_preds[0].shape, binary_preds[0].dtype, binary_preds[0].min(), binary_preds[0].max()

((7347, 7526), dtype('uint8'), 0, 1)

In [7]:
from PIL import Image

def _pred_path(image_path: str) -> Path:
    """Map an input image path to the path its binary mask is written to.

    Only the parent directory name and the file name are rewritten
    ("images" / "image" -> "preds"), so other path components that happen to
    contain "image" are left untouched (plain str.replace on the whole path
    would mangle them).
    """
    src = Path(image_path)
    out_dir_name = src.parent.name.replace("images", "preds").replace("image", "preds")
    out_name = src.name.replace("images", "preds").replace("image", "preds")
    return src.parent.with_name(out_dir_name) / out_name


# zip() silently truncates on a length mismatch, which would drop masks unnoticed.
assert len(binary_preds) == len(paths), (len(binary_preds), len(paths))

for pred, filepath in zip(binary_preds, paths):
    out_path = _pred_path(filepath)
    # Image.save does not create missing directories.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Masks hold 0/1 values (not 0/255), so they look black in an image viewer.
    Image.fromarray(pred).save(out_path)