**About**: This notebook contains the pipeline for dot charts (chart-type `dot`).

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib inline

In [None]:
cd ../src/

## Initialization

### Imports

In [None]:
import os
# Expose only GPU 0 to this process. This cell runs before the cells that import
# the detection / OCR libraries further down.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

In [None]:
import os
import cv2
import sys
import ast
import glob
import json
import yaml
import shutil
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from tqdm import tqdm

warnings.filterwarnings("ignore", category=UserWarning)
pd.set_option('display.width', 500)
pd.set_option('max_colwidth', 100)

In [None]:
from mmdet.apis import init_detector, inference_detector

In [None]:
from params import *
from util.plots import *
from inference.yolox import retrieve_yolox_model, predict, YoloXWrapper
from inference.utils import get_transfos, InferenceDataset
from util.metrics import *
from util.boxes import Boxes

from post_process.retrieve import retrieve_missing_boxes
from post_process.in_graph import post_process_preds_dots
from post_process.dots import constraint_size, restrict_labels_x, assign_dots, cluster_on_x

In [None]:
# Version tag. NOTE(review): Config.version below repeats the same value - keep them in sync.
VERSION = "v13"

### Load data

In [None]:
# Load the four input tables in one pass.
# NOTE: `df_target` may be replaced later when TEST is enabled.
df, df_text, df_target, df_elt = (
    pd.read_csv(csv_path)
    for csv_path in (
        '../input/df_train.csv',
        '../input/texts.csv',
        '../input/y_train.csv',
        '../input/elements.csv',
    )
)

In [None]:
# Drop the ids listed in ANOMALIES (star-imported from params).
is_anomaly = df['id'].isin(ANOMALIES)
df = df.loc[~is_anomaly].reset_index(drop=True)

In [None]:
# Attach the split assignment. No key is given, so pandas joins on every
# column the two frames have in common (inner join).
df_split = pd.read_csv('../input/df_split.csv')
df = pd.merge(df, df_split)

In [None]:
# Keep only the chart types handled by this notebook.
CLASSES = ["dot"]
is_selected = df['chart-type'].isin(CLASSES)
df = df.loc[is_selected].reset_index(drop=True)

In [None]:
from pathlib import Path

# One row per file found in ../input/dots; the id is the file name without extension.
# NOTE(review): glob does not guarantee an ordering, so row indices of df_test
# depend on the filesystem.
dot_paths = glob.glob('../input/dots/*')
df_test = pd.DataFrame({"path": dot_paths})
df_test['id'] = [Path(p).stem for p in dot_paths]

# Constant metadata columns, same for every extracted image.
for column, value in (("source", "extracted"), ("chart-type", "dot"), ("gt_path", "")):
    df_test[column] = value

### Model

In [None]:
class Config:
    """
    Inference settings for the YoloX detector (single label: "point").
    """
    # --- Model identity and weights ---
    selected_model = "yolo"
    name = "benetech_1_m_1"
    version = "v13"
    cfg = f"../yolox/exps/{name}.py"
    ckpt = f"../yolox/YOLOX_outputs/{name}/best_ckpt.pth"

    # --- Labels and box formats ---
    labels = ["point"]
    bbox_format = "yolo"
    pred_format = "pascal_voc"

    # --- Input ---
    size = (1024, 1024)

    # --- NMS ---
    conf_thresh = 0.05
    iou_thresh = 0.1
    max_per_img = 500
    min_per_img = 1

    # --- Runtime ---
    val_bs = 1
    device = "cuda"


# Alias used when building the point detector.
config_marker = Config

In [None]:
# Load the YoloX network from its experiment file and checkpoint, then wrap it
# together with the inference config.
model = YoloXWrapper(
    retrieve_yolox_model(config_marker.cfg, config_marker.ckpt),
    config_marker,
)

In [None]:
from mmdet.apis import init_detector, inference_detector  # depend heavily on mmcv

# Second detector, loaded from the cached work directory. Its per-class outputs
# are read in the main loop via CACHED_CLASSES.
wdir = '../input/cached/work_dirs'
config_file = f"{wdir}/custom.py"
checkpoint_file = f"{wdir}/cascade_rcnn_swin-t_fpn_LGF_VCE_PCE_coco_focalsmoothloss/checkpoint.pth"

cached_model = init_detector(config_file, checkpoint_file, device='cuda')

### Predict

In [None]:
# True: evaluate on the images in ../input/dots with labels from ../output/dot_labels.csv.
# False: evaluate on the first 100 rows of the "val" split.
TEST = True

In [None]:
# Choose the evaluation frame (`df_val`) and, in TEST mode, the matching targets.
if TEST:
    # Image sizes are computed only once; re-running the cell reuses the columns.
    if "img_h" not in df_test.columns:
        shapes = []
        for path in df_test['path']:
            img = cv2.imread(path)
            if img is None:
                # cv2.imread returns None instead of raising when a file cannot be decoded.
                raise ValueError(f"cv2.imread could not read image: {path}")
            shapes.append(img.shape[:2])

        # reshape keeps the (n, 2) layout even when no image was found.
        shapes = np.array(shapes, dtype=int).reshape(-1, 2)
        df_test['img_h'] = shapes[:, 0]
        df_test['img_w'] = shapes[:, 1]

    df_val = df_test

    # In TEST mode the targets loaded earlier are replaced by the dot labels file.
    df_target = pd.read_csv("../output/dot_labels.csv")
else:
    # Copy so the column assignments below do not write into a slice of `df`.
    df_val = df[df['split'] == "val"].reset_index(drop=True).head(100).copy()
    df_val['path'] = '../input/v2/images/valid/' + df_val['id'] + '.jpg'
    df_val['gt_path'] = '../input/v2/labels/valid/' + df_val['id'] + '.txt'

In [None]:
# Restrict the evaluation frame to the chart types of interest.
TYPES = ["dot"]
type_mask = df_val['chart-type'].isin(TYPES)
df_val = df_val.loc[type_mask].reset_index(drop=True)

In [None]:
# Dataset fed to the point detector: transforms are built for Config.size and padding is enabled.
transforms = get_transfos(size=Config.size)
dataset = InferenceDataset(df_val, transforms, pad=True)

In [None]:
%%time
# Run the point detector over the dataset; the second return value is discarded.
# `meter` is indexed per image in the main loop (meter.preds[idx], meter.labels[idx]).
meter, _ = predict(model, dataset, Config)

### OCR

In [None]:
import transformers
transformers.utils.logging.set_verbosity_error()

from transformers import TrOCRProcessor
from transformers import VisionEncoderDecoderModel

from util.boxes import expand_boxes
from util.ocr import *

In [None]:
# Checkpoint id for the TrOCR processor and model.
# NOTE: this is a notebook-level `name`, unrelated to the `Config.name` class attribute.
name = "microsoft/trocr-base-stage1"

processor = TrOCRProcessor.from_pretrained(name)
# Move the model to the GPU once (the second, redundant `.cuda()` call was removed).
ocr_model = VisionEncoderDecoderModel.from_pretrained(name).cuda()

### Loop

In [None]:
# Class names of the cached detector, in output order. The main loop reads
# index 2 ("plot_area") and index 4 ("xlabel") of the detector result.
CACHED_CLASSES = [
    'x_title',        # 0
    'y_title',        # 1
    'plot_area',      # 2
    'other',          # 3
    'xlabel',         # 4
    'ylabel',         # 5
    'chart_title',    # 6
    'x_tick',         # 7
    'y_tick',         # 8
    'legend_patch',   # 9
    'legend_label',   # 10
    'legend_title',   # 11
    'legend_area',    # 12
    'mark_label',     # 13
    'value_label',    # 14
    'y_axis_area',    # 15
    'x_axis_area',    # 16
    'tick_grouping',  # 17
]

In [None]:
# PLOT: plot the final boxes for every image processed by the main loop.
PLOT = False
# DEBUG: enable intermediate plots/prints; the main loop stops after the first processed image.
DEBUG = False

In [None]:
# Rebinds `dataset` to a version without transforms and without padding; the main loop
# reads images (and dataset.paths) from this one.
dataset = InferenceDataset(df_val, None, pad=False)

In [None]:
%matplotlib inline

In [None]:
# Per-image pipeline: cached-detector boxes + point detections -> post-processing ->
# x-clustering of the points -> OCR of the x labels -> one (x, y) series per image,
# scored against `df_target`.

# Positions of the classes read from the cached detector's per-class result list.
xlabel_cls = CACHED_CLASSES.index("xlabel")
plot_area_cls = CACHED_CLASSES.index("plot_area")

scores = []
df_preds = []
for idx in range(len(dataset)):
    # NOTE(review): sample 7 is skipped unconditionally and the reason is not recorded here.
    # The index depends on the row order of df_val - confirm it still targets the intended image.
    if idx == 7:
        continue

    img, gt, shape = dataset[idx]

    # Wide images get a virtual bottom padding before the stored predictions are rescaled.
    # NOTE(review): presumably mirrors the padding applied by InferenceDataset(pad=True)
    # at prediction time - TODO confirm.
    if img.shape[1] > img.shape[0] * 1.4:
        padding = int(img.shape[1] * 0.9) - img.shape[0]
    else:
        padding = 0
    meter.preds[idx].update_shape((img.shape[0] + padding, img.shape[1]))

    # Cached
    cached_result = inference_detector(cached_model, dataset.paths[idx])  # list[array]

    # Threshold chosen so that at least the three best-scored x labels pass, capped at 0.1.
    # With fewer than three detections there is no third score to read, so fall back to the cap.
    xlabel_dets = cached_result[xlabel_cls]
    score_th = min(0.1, xlabel_dets[2, 4]) if len(xlabel_dets) > 2 else 0.1

    if DEBUG:
        # Keep only the best plot area and the x labels for the visualisation.
        for i, (r, c) in enumerate(zip(cached_result, CACHED_CLASSES)):
            if c == "plot_area":
                cached_result[i] = r[:1]
            elif c not in ['plot_area', "xlabel"]:
                cached_result[i] = np.empty((0, 5))

        cached_model.show_result(
            dataset.paths[idx],
            cached_result,
            out_file='../output/sample_result.jpg',
            score_thr=score_th,
            thickness=1,
            font_size=5,
        )
        plt.figure(figsize=(15, 10))
        plt.imshow(cv2.imread('../output/sample_result.jpg'))
        plt.axis(False)
        plt.show()

    id_ = df_val.id[idx]

    print(idx, id_[:10], end="\t")
    title = f"{id_} - {df_val.source[idx]} {df_val['chart-type'][idx]}"

    # Point boxes (label 0) from the YoloX predictions.
    point_boxes = meter.preds[idx]['pascal_voc'][meter.labels[idx] == 0]

    # preds layout: [plot area, x labels, <unused here>, points]
    preds = [[], [], [], point_boxes]

    # Override with cached
    preds[1] = cached_result[xlabel_cls][cached_result[xlabel_cls][:, -1] > score_th][:, :4].astype(int)
    preds[0] = cached_result[plot_area_cls][:1, :4].astype(int)

    if DEBUG:
        plot_results(img, preds, figsize=(12, 7), title=title)

    preds = post_process_preds_dots(preds, margin_pt=5, margin_text=5)

    margin = (img.shape[0] + img.shape[1]) / (2 * 20)
    preds = restrict_labels_x(preds, margin=margin)

    # Visual similarity: best-effort retrieval of extra point boxes. A failure here must not
    # stop the run, but it is reported instead of being silently swallowed.
    try:
        retrieved_boxes = retrieve_missing_boxes(
            preds, img, verbose=DEBUG, min_sim=0.8, seed=100, hw=None, max_retrieved=20, margin=-1
        )
        if len(retrieved_boxes):
            print('RETRIEVED', len(retrieved_boxes), end="\t")
            preds[-1] = np.concatenate([preds[-1], retrieved_boxes])
    except Exception as err:
        print(f'RETRIEVAL FAILED ({type(err).__name__})', end="\t")

    if DEBUG:
        plot_results(img, preds, figsize=(12, 7), title=title)

    # Cluster the points along x. `clusters` is read as a mapping cluster index -> count;
    # clusters with a zero count are dropped from `centers`.
    mapping = {}  # always defined, so the DEBUG print below cannot hit a NameError
    try:
        centers, clusters = cluster_on_x(preds[-1], shape[1], plot=DEBUG)
        kept = [(i, c) for i, c in enumerate(centers) if clusters[i] > 0]
        centers = np.array([c for _, c in kept])
        kept_counts = [clusters[i] for i, _ in kept]  # counts aligned with filtered `centers`
    except Exception as err:
        centers, clusters, kept_counts = None, None, []
        print(f'CLUSTERING FAILED ({type(err).__name__})', end="\t")

    if len(preds[1]):
        xlabels = preds[1]
        xlabels_loc = (xlabels[:, 0] + xlabels[:, 2]) / 2

        if centers is not None:
            mapping, retrieved_xlabels = assign_dots(preds[1], centers, retrieve_missing=True)
            if len(retrieved_xlabels):
                xlabels = np.concatenate([xlabels, retrieved_xlabels])
            xlabels_loc = (xlabels[:, 0] + xlabels[:, 2]) / 2

            preds[1] = xlabels

        if DEBUG:
            print(centers, clusters)
            print(mapping, )

        # OCR
        x_texts = ocr(ocr_model, processor, img, preds[1], margin=1, plot=DEBUG)

        xs, ys, locs = [], [], []
        for i, txt in enumerate(x_texts):
            if clusters is not None:
                if i in mapping.keys():
                    # NOTE(review): mapping values presumably index the filtered `centers`, while
                    # `clusters` is keyed by the unfiltered cluster index - verify against assign_dots.
                    xs.append(txt)
                    locs.append(xlabels_loc[i])
                    ys.append(clusters.get(mapping[i], 0))
                else:
                    # Unmatched label: kept with a zero count only if it lies right of the
                    # plot area's left edge.
                    if xlabels_loc[i] > preds[0][0][0]:
                        xs.append(txt)
                        locs.append(xlabels_loc[i])
                        ys.append(0)
            else:
                xs.append(txt)
                locs.append(xlabels_loc[i])
                ys.append(0)
    else:
        # No x labels: fall back to one entry per cluster, named by its position.
        if centers is None:
            # Clustering failed as well: emit an empty prediction instead of crashing on None.
            xs, ys, locs = [], [], []
        else:
            xs = [str(i) for i in range(len(centers))]
            locs = centers
            ys = kept_counts

    if PLOT:
        plot_results(img, preds, figsize=(12, 7), title=title)

    pred = pd.DataFrame({"x": xs, "y": np.array(ys).astype(int), "loc": locs})
    pred = pred.sort_values('loc').reset_index(drop=True)

    if df_target is not None:
        # `gt` is rebound here from the dataset item to the target rows of this image.
        gt = df_target[df_target['id'] == id_].reset_index(drop=True)
        gt['y'] = gt["y"].astype(int)

        score_x = score_series(gt['x'].values, pred['x'].values)
        score_y = score_series(gt['y'].values, pred['y'].values)
        print(f"Scores  -  x: {score_x:.3f}  - y: {score_y:.3f}")

        scores += [score_x, score_y]

    if DEBUG and not TEST:
        print('GT')
        display(gt)

    pred['id'] = id_
    df_preds.append(pred)
    if DEBUG:
        print('PRED')
        display(pred)

    # In DEBUG mode only the first processed image is inspected.
    if DEBUG:
        break

In [None]:
# Report the mean of all collected x/y scores, if any image was scored.
if scores:
    mean_score = np.mean(scores)
    print(f'Dots CV : {mean_score:.3f}')

Done!