In [None]:
from pathlib import Path
from course_ocr_t1.data import MidvPackage
from tqdm import tqdm
from matplotlib import pyplot as plt
import cv2
import numpy as np
import torch
from typing import Dict, Tuple
from skimage.morphology import area_closing
import os

from src.segmentation.lightning import LightningModel

from course_ocr_t1.metrics import dump_results_dict, measure_crop_accuracy
from course_ocr_t1.metrics import iou_relative_quads

In [2]:
ckpt_path = "deeplabv3plus/segmentation/1652827910/segmentation-14.ckpt"  # specify your own path

# Always a torch.device (never a bare string), so later cells can treat `device`
# uniformly whether or not a GPU is present.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Restore the checkpoint onto the CPU first so loading also works on machines
# without a GPU, then move the model to the selected device.
model = LightningModel.load_from_checkpoint(ckpt_path, map_location="cpu")
model.eval()  # evaluation mode: this notebook only runs inference
model.to(device);  # trailing ';' suppresses the module repr in the cell output

In [3]:
# Dataset root, expected one directory above the notebook's working directory.
DATASET_PATH = Path("..") / "midv500_compressed"
# Fail fast, reporting the absolute location that was looked up.
assert DATASET_PATH.exists(), DATASET_PATH.absolute()

In [4]:
class Dataset(torch.utils.data.Dataset):
    """Frames of one MIDV-500 split, zero-padded onto a fixed-size canvas.

    Every frame is pasted onto a black ``(BASE * 3, BASE * 2)`` canvas with the
    offsets ``h_pad`` / ``w_pad``; those offsets centre a frame of size
    ``ORIGINAL_SHAPE``. Ground-truth quads are returned untouched, i.e. in the
    coordinates stored in the dataset, not shifted by the padding.

    Attributes set by ``_init``:
        images: array of image paths, one per selected frame.
        masks: array of ground-truth quads (``gt_data["quad"]``), same order.
        keys: array of ``unique_key`` values, same order.
        h_pad, w_pad: top / left offset of the frame inside the canvas.
    """

    # Canvas unit: the padded canvas is (BASE * 3) rows by (BASE * 2) columns.
    BASE: int = 320
    # (height, width) used to compute the centring offsets.
    ORIGINAL_SHAPE: Tuple[int, int] = (800, 450)

    def __init__(self, path: str, is_test: bool = False):
        """Index the dataset located at ``path``.

        Args:
            path: root directory handed to ``MidvPackage.read_midv500_dataset``.
            is_test: keep only items whose ``is_test_split()`` equals this flag.
        """
        super().__init__()
        self.path = Path(path)
        self.is_test = is_test
        self.images = None
        self.masks = None
        self.keys = None
        self._init()

    def _init(self):
        """Collect paths, quads and keys of all correct items of the chosen split."""
        data_packs = MidvPackage.read_midv500_dataset(self.path)
        self.images = []
        self.masks = []
        self.keys = []
        for pack in data_packs:
            for di in pack:
                # Items flagged as incorrect by the dataset are dropped.
                if self.is_test == di.is_test_split() and di.is_correct():
                    self.images.append(di.img_path)
                    self.masks.append(np.array(di.gt_data["quad"]))
                    self.keys.append(di.unique_key)
        self.images = np.array(self.images)
        self.masks = np.array(self.masks)
        self.keys = np.array(self.keys)
        # Offsets that centre an ORIGINAL_SHAPE frame inside the canvas.
        self.h_pad = (self.BASE * 3 - self.ORIGINAL_SHAPE[0]) // 2
        self.w_pad = (self.BASE * 2 - self.ORIGINAL_SHAPE[1]) // 2

    def __len__(self):
        """Number of frames selected for this split."""
        return len(self.images)

    def __getitem__(self, item: int) -> Dict[str, np.ndarray]:
        """Load one frame and place it on the padded canvas.

        Returns:
            A dict with
            ``"image"``: float tensor of shape ``(3, BASE * 3, BASE * 2)`` scaled
            to ``[0, 1]``, channels reversed from OpenCV's order;
            ``"mask"``: the ground-truth quad as stored in the dataset;
            ``"key"``: the item's unique key.
            (The annotation is kept for compatibility; the image is actually a
            ``torch.Tensor`` and the key is not an array.)

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        image_path = self.path / self.images[item]
        bgr = cv2.imread(str(image_path))
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising;
            # surface the offending path rather than an opaque TypeError.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        rgb = bgr[:, :, ::-1]

        canvas = np.zeros((self.BASE * 3, self.BASE * 2, 3), dtype=np.uint8)
        height, width = rgb.shape[:2]
        canvas[self.h_pad:self.h_pad + height, self.w_pad:self.w_pad + width] = rgb

        # HWC uint8 -> CHW float in [0, 1].
        image = torch.permute(torch.from_numpy(canvas).float(), (2, 0, 1)) / 255.
        return {
            "image": image,
            "mask": self.masks[item],
            "key": self.keys[item]
        }

In [5]:
# Evaluate on the test split only; order is kept fixed (no shuffling) so that
# batch indices are stable between runs of the inference cell.
batch_size = 16
dataset = Dataset(DATASET_PATH, is_test=True)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

In [6]:
# State shared with the inference cell. Re-run this cell to restart the
# evaluation from the first batch.
idx: int = 0        # batch index the inference cell starts from
result: dict = {}   # sample key -> predicted polygon
ious: list = []     # IoU values collected so far

In [7]:
# Relative-coordinate quad covering the whole frame; used as the prediction
# when no usable contour can be extracted from the mask.
FULL_FRAME_QUAD = ((0.0, 0.0), (0.0, 1.0), (1.0, 1.0), (1.0, 0.0))

# Cut-off applied to the model output to obtain a binary mask.
# NOTE(review): assumes the model emits probabilities in [0, 1] - confirm.
MASK_THRESHOLD = 0.5


def mask_to_relative_polygon(mask: np.ndarray, frame_shape: Tuple[int, int]) -> np.ndarray:
    """Turn a binary mask into a polygon in relative (x, y) coordinates.

    The mask is cleaned with ``area_closing``, the contour with the longest
    perimeter is simplified with ``approxPolyDP`` (tolerance: 1% of the
    perimeter), and the vertices are divided by the frame width / height.

    Args:
        mask: 2-D array of zeros and ones, cropped to the original frame.
        frame_shape: ``(height, width)`` used to normalise the coordinates.

    Returns:
        ``(N, 2)`` float array of polygon vertices, or the full-frame quad when
        no contour is found or the simplified contour has fewer than 3 points.
    """
    closed = area_closing(mask)
    # [-2] picks the contour list under both OpenCV return conventions:
    # (contours, hierarchy) in 4.x and (image, contours, hierarchy) in 3.x.
    contours = cv2.findContours(
        closed.astype(np.uint8),
        cv2.RETR_TREE,
        cv2.CHAIN_APPROX_SIMPLE
    )[-2]
    if len(contours) > 0:
        # max() keeps the first contour among ties, like a strict '>' scan.
        largest = max(contours, key=lambda cnt: cv2.arcLength(cnt, True))
        perimeter = cv2.arcLength(largest, True)
        polygon = cv2.approxPolyDP(largest, 0.01 * perimeter, True)[:, 0].astype(float)
        if len(polygon) >= 3:
            height, width = frame_shape
            polygon[:, 0] /= width
            polygon[:, 1] /= height
            return polygon
    return np.array(FULL_FRAME_QUAD, dtype=float)


# Resume support: `idx` keeps the last batch index visited by a previous
# (possibly interrupted) run of this cell; earlier batches are skipped.
# Skipped batches are still loaded by the dataloader.
global_batch = idx

frame_h, frame_w = dataset.ORIGINAL_SHAPE
# Region of the padded canvas that holds the original frame.
row_crop = slice(dataset.h_pad, dataset.h_pad + frame_h)
col_crop = slice(dataset.w_pad, dataset.w_pad + frame_w)

with torch.no_grad():
    for idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        if global_batch > idx:
            continue
        res = model(batch["image"].to(device))
        # Iterate over what the model actually returned: the last batch may be
        # shorter than batch_size.
        for in_batch_idx in range(len(res)):
            key = batch["key"][in_batch_idx]
            if key in result:
                # Already scored before an interrupted run; do not count twice.
                continue
            pred = (res[in_batch_idx] > MASK_THRESHOLD).float()[0]
            mask = pred[row_crop, col_crop].cpu().numpy()
            polygon = mask_to_relative_polygon(mask, dataset.ORIGINAL_SHAPE)

            gt = batch["mask"][in_batch_idx].numpy().astype(float)
            gt[:, 0] /= frame_w
            gt[:, 1] /= frame_h

            # Fallback predictions are scored as well, so failed detections
            # lower the running mean instead of being silently left out.
            ious.append(iou_relative_quads(gt, polygon))
            result[key] = polygon.tolist()
        if (idx + 1) % 25 == 0 and ious:
            print(np.array(ious).mean())
        global_batch += 1

  9%|▉         | 25/266 [02:23<22:05,  5.50s/it]

0.9771624118328396


 19%|█▉        | 50/266 [04:39<19:37,  5.45s/it]

0.9775590479454631


 28%|██▊       | 75/266 [06:58<17:48,  5.59s/it]

0.9777399498438721


 38%|███▊      | 100/266 [09:16<16:11,  5.85s/it]

0.9785266707786286


 47%|████▋     | 125/266 [11:37<12:52,  5.48s/it]

0.9776938350261406


 56%|█████▋    | 150/266 [13:59<10:24,  5.39s/it]

0.9771373022349827


 66%|██████▌   | 175/266 [16:25<09:01,  5.95s/it]

0.9776889793926895


 75%|███████▌  | 200/266 [18:46<06:18,  5.74s/it]

0.9779303939735113


 85%|████████▍ | 225/266 [21:10<04:01,  5.90s/it]

0.9739944424153301


 94%|█████████▍| 250/266 [23:31<01:29,  5.60s/it]

0.9713333940186761


100%|██████████| 266/266 [25:01<00:00,  5.65s/it]


In [9]:
# Persist the predictions collected in `result` to pred.json in the working directory.
dump_results_dict(result, Path("pred.json"))

---