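**About**: This notebook evaluates the trained detection models on the scatter charts of the validation split: a two-model YOLOX ensemble for data points, an mmdetection (Cascade R-CNN) chart-element detector, and a TrOCR model that reads the axis labels.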

In [None]:
# %load_ext nb_black
# Re-import edited project modules automatically before each cell executes
# (autoreload mode 2 reloads all modules).
%load_ext autoreload
%autoreload 2

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Work from the source directory: the project packages imported below (params,
# inference, util, post_process) and the relative '../input' / '../output' paths
# presumably resolve from here.
cd ../src/

## Initialization

### Imports

In [None]:
import os

# Pin this kernel to GPU #1. This only takes effect if it is set before the
# CUDA runtime is initialised (i.e. before torch touches the GPU).
os.environ.update({'CUDA_VISIBLE_DEVICES': "1"})

In [None]:
import os
import cv2
import sys
import ast
import glob
import json
import yaml
import shutil
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from tqdm import tqdm

# Hide UserWarnings and widen pandas' text rendering of frames and cells.
warnings.filterwarnings("ignore", category=UserWarning)
pd.set_option('display.width', 500, 'max_colwidth', 100)

In [None]:
from params import *

from inference.yolox import retrieve_yolox_model, predict, YoloXWrapper
from inference.utils import get_transfos, InferenceDataset, nms

from util.plots import *
from util.metrics import *
from util.torch import seed_everything
from util.boxes import Boxes

from post_process.retrieve import retrieve_missing_boxes
from post_process.remove import remove_outlier_boxes
from post_process.reg import rounding, linear_regression
from post_process.ticks import restrict_on_line, assign
from post_process.in_graph import post_process_preds
from post_process.tick_point import post_process_arrow, post_process_point_as_tick

### Load data

In [None]:
# Chart metadata (one row per chart id) and the ground-truth (x, y) points per id.
df, df_target = (
    pd.read_csv(path)
    for path in ('../input/df_train.csv', '../input/y_train.csv')
)

In [None]:
# Drop the charts listed in ANOMALIES (presumably provided by `from params import *`).
df = (
    df.loc[~df['id'].isin(ANOMALIES)]
    .reset_index(drop=True)
)

In [None]:
# Attach the train/val split assignment (inner merge on the shared columns).
df_split = pd.read_csv('../input/df_split.csv')
df = pd.merge(df, df_split)

In [None]:
# This notebook only handles scatter charts.
CLASSES = ["scatter"]
df = (
    df.loc[df['chart-type'].isin(CLASSES)]
    .reset_index(drop=True)
)

### Model

In [None]:
class ConfigMarker:
    """Inference settings for the first YOLOX point detector (run ``benetech_1_m_1``).

    Used as a class (never instantiated): the model wrapper and ``predict`` read
    its attributes directly.
    """

    # Model and box formats
    selected_model = "yolo"
    bbox_format = "yolo"
    pred_format = "pascal_voc"

    # Checkpoint: YOLOX experiment file and best checkpoint of the run
    name = "benetech_1_m_1"
    cfg = f"../yolox/exps/{name}.py"
    ckpt = f"../yolox/YOLOX_outputs/{name}/best_ckpt.pth"

    # Data: image folder version (alternative: "v13_sim"), classes, input size
    version = "v13"
    labels = ["point"]
    size = (1024, 1024)

    # NMS
    conf_thresh = 0.6
    iou_thresh = 0.4
    max_per_img = 500
    min_per_img = 1

    # Runtime
    val_bs = 1  # if size[0] > 1024 else 16
    device = "cuda"


config_marker = ConfigMarker

In [None]:
# Build the YOLOX network from its experiment file and checkpoint, then wrap it for inference.
model_marker = YoloXWrapper(
    retrieve_yolox_model(config_marker.cfg, config_marker.ckpt),
    config_marker,
)

In [None]:
class ConfigMarker2:
    """Inference settings for the second YOLOX point detector (run ``benetech_1_l_1``).

    Same settings as ``ConfigMarker`` except the run name and a lower confidence
    threshold (0.55). Used as a class, never instantiated.
    """

    # Model and box formats
    selected_model = "yolo"
    bbox_format = "yolo"
    pred_format = "pascal_voc"

    # Checkpoint: YOLOX experiment file and best checkpoint of the run
    name = "benetech_1_l_1"
    cfg = f"../yolox/exps/{name}.py"
    ckpt = f"../yolox/YOLOX_outputs/{name}/best_ckpt.pth"

    # Data: image folder version (alternative: "v13_sim"), classes, input size
    version = "v13"
    labels = ["point"]
    size = (1024, 1024)

    # NMS
    conf_thresh = 0.55
    iou_thresh = 0.4
    max_per_img = 500
    min_per_img = 1

    # Runtime
    val_bs = 1  # if size[0] > 1024 else 16
    device = "cuda"


config_marker_2 = ConfigMarker2

In [None]:
# Second point detector of the ensemble, built the same way as the first one.
model_marker_2 = YoloXWrapper(
    retrieve_yolox_model(config_marker_2.cfg, config_marker_2.ckpt),
    config_marker_2,
)

### Evaluate

In [None]:
# Chart types evaluated below, one pass per type.
chart_types = [
    "scatter",
]

In [None]:
# Padding options forwarded to InferenceDataset (pad / pad_advanced).
PAD = PAD_ADV = True

In [None]:
# Validation set: image paths and YOLO-format label files for every chart.
df_val = df[df['split'] == "val"].reset_index(drop=True)
df_val['path'] = f'../input/{config_marker.version}/val2017/' + df_val['id'] + '.jpg'
df_val['gt_path'] = f'../input/{config_marker.version}/labels/valid/' + df_val['id'] + '.txt'
df_val_ = df_val.copy()

# Ensemble outputs (NMS-merged point boxes and their confidences), one entry per
# image, consumed by the "Main" loop below.
# NOTE: entries are appended across chart types, while the "Main" loop indexes them
# by position within the *last* chart type's ``df_val``. They only line up when
# ``chart_types`` holds a single type.
merged_boxes_list, confs_list = [], []
n_labels = len(config_marker.labels)

for t in chart_types:
    print(f'\n-> Chart type : {t}\n')
    df_val = df_val_[df_val_['chart-type'] == t].reset_index(drop=True)

    # 1. Point predictions of both YOLOX models (padding controlled by PAD / PAD_ADV).
    print('- Predict 1')
    transforms = get_transfos(size=config_marker.size)
    dataset = InferenceDataset(df_val, transforms, pad=PAD, pad_advanced=PAD_ADV)
    meter, _ = predict(model_marker, dataset, config_marker, extract_fts=False)

    print('- Predict 2')
    transforms = get_transfos(size=config_marker_2.size)
    dataset = InferenceDataset(df_val, transforms, pad=PAD, pad_advanced=PAD_ADV)
    # Each model runs with its own config (model 2 has conf_thresh=0.55).
    meter_2, _ = predict(model_marker_2, dataset, config_marker_2, extract_fts=False)

    # 2. Give every prediction the (padded) image shape, presumably so boxes are
    #    expressed in image coordinates rather than model-input coordinates.
    print('- Update shapes')
    dataset = InferenceDataset(df_val, None, pad=PAD, pad_advanced=PAD_ADV)
    for i in range(len(dataset)):
        shape = dataset[i][2]
        meter.preds[i].update_shape(shape)
        meter_2.preds[i].update_shape(shape)

    # 3. Merge both models with NMS and score against the un-padded ground truth.
    f1s = {c: [] for c in config_marker.labels}
    recalls = {c: [] for c in config_marker.labels}

    dataset = InferenceDataset(df_val, None, pad=False)

    print('- Evaluate')
    for idx in range(len(dataset)):
        _, gt, shape = dataset[idx]

        gt = Boxes(gt, (shape[0], shape[1]), bbox_format="yolo")['pascal_voc']
        gt = [gt[dataset.classes[idx] == i] for i in range(n_labels)]

        assert len(gt[-1]), f"No ground-truth point for chart {df_val['id'][idx]}"

        preds = [meter.preds[idx]['pascal_voc'][meter.labels[idx] == i] for i in range(n_labels)]
        preds_2 = [meter_2.preds[idx]['pascal_voc'][meter_2.labels[idx] == i] for i in range(n_labels)]

        scores = [meter.confidences[idx][meter.labels[idx] == i] for i in range(n_labels)]
        scores_2 = [meter_2.confidences[idx][meter_2.labels[idx] == i] for i in range(n_labels)]

        # Single class ("point"): pool both models' boxes, then deduplicate with NMS.
        boxes = np.concatenate([preds[0], preds_2[0]], 0)
        confs = np.concatenate([scores[0], scores_2[0]], 0)

        merged_boxes, merged_confs = nms(boxes, confs, threshold=0.4)
        merged_boxes = [merged_boxes]
        merged_boxes_list.append(merged_boxes)
        confs_list.append(merged_confs)

        for i, (gt_boxes, pred_boxes) in enumerate(zip(gt, merged_boxes)):
            metrics = compute_metrics(pred_boxes, gt_boxes)
            f1s[config_marker.labels[i]].append(metrics['f1_score'])
            recalls[config_marker.labels[i]].append(metrics['recall'])

    for k, v in f1s.items():
        print(f'{k} \t Avg F1: {np.mean(v):.3f}  \t Avg F1==1: {np.mean(np.array(v) == 1):.3f}', end="\t")
        print(f'Avg Recall==1: {np.mean(np.array(recalls[k]) == 1):.3f}')

Ens best fixed :
- point 	 Avg F1: 0.933  	 Avg F1==1: 0.612	Avg Recall==1: 0.745

### Predict

In [None]:
# df_val = df[df['split'] == "val"].reset_index(drop=True)
# # df_val['path'] = f'../input/{config_marker.version}/images/valid/' + df_val['id'] + '.jpg'
# df_val['path'] = f'../input/{config_marker.version}/val2017/' + df_val['id'] + '.jpg'
# df_val['gt_path'] = f'../input/{config_marker.version}/labels/valid/' + df_val['id'] + '.txt'

# TYPES = ["scatter"]
# df_val = df_val[df_val['chart-type'].isin(TYPES)].reset_index(drop=True)

# transforms = get_transfos(size=config_marker.size)
# dataset = InferenceDataset(df_val, transforms)

In [None]:
# %%time
# meter_marker, _ = predict(model_marker, dataset, config_marker)

# for i, p in enumerate(meter_marker.preds):
#     p.update_shape((df_val['img_h'][i], df_val['img_w'][i]))

### Chart model

In [None]:
from mmdet.apis import init_detector, inference_detector  # depend heavily on mmcv

In [None]:
# Class order of the chart-element detector: position i is the i-th array
# returned by inference_detector (the main loop indexes them by position).
CACHED_CLASSES = [
    'x_title',         # 0
    'y_title',         # 1
    'plot_area',       # 2
    'other',           # 3
    'xlabel',          # 4
    'ylabel',          # 5
    'chart_title',     # 6
    'x_tick',          # 7
    'y_tick',          # 8
    'legend_patch',    # 9
    'legend_label',    # 10
    'legend_title',    # 11
    'legend_area',     # 12
    'mark_label',      # 13
    'value_label',     # 14
    'y_axis_area',     # 15
    'x_axis_area',     # 16
    'tick_grouping',   # 17
]

In [None]:
# mmdetection chart-element detector. Its checkpoint folder name indicates a
# Cascade R-CNN with a Swin-T FPN backbone.
wdir = '../input/cached/work_dirs'
config_file, checkpoint_file = (
    wdir + '/custom.py',
    wdir + '/cascade_rcnn_swin-t_fpn_LGF_VCE_PCE_coco_focalsmoothloss/checkpoint.pth',
)

cached_model = init_detector(config_file, checkpoint_file, device='cuda')

### OCR

In [None]:
import torch
import transformers
# Only surface errors from the transformers library.
transformers.logging.set_verbosity_error()

from transformers import TrOCRProcessor
from transformers import VisionEncoderDecoderModel

from util.ocr import ocr, post_process_texts

In [None]:
# name = "microsoft/trocr-base-stage1"
# processor = TrOCRProcessor.from_pretrained(name)
# ocr_model = VisionEncoderDecoderModel.from_pretrained(name).cuda()

In [None]:
# from transformers import AutoConfig
# config = AutoConfig.from_pretrained(name)
# torch.save(config, "../output/weights/ocr/config.pth")

# processor.save_pretrained("../output/weights/ocr/")
# ocr_model.save_pretrained("../output/weights/ocr/")

In [None]:
# TrOCR processor saved locally, so no download is needed.
processor = TrOCRProcessor.from_pretrained("../output/weights/ocr/")

# Rebuild the architecture from the pickled config object, then load the weights.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted files.
# Recent PyTorch releases default to weights_only=True, which may refuse a pickled
# config object. Confirm against the installed torch version.
config = torch.load("../output/weights/ocr/config.pth")

ocr_model = VisionEncoderDecoderModel(config)
ocr_model.load_state_dict(torch.load('../output/weights/ocr/pytorch_model.bin'))

# Inference only: evaluation mode, on the GPU.
ocr_model = ocr_model.eval().cuda()

### Main
**TODO**
- Enforce similarity between detections.
- The conv-based similarity is not robust to col (color?), see #26.
- Make sure … *(note left unfinished)*

In [None]:
# Inline plotting is already enabled at the top of the notebook; this is a harmless repeat.
%matplotlib inline

In [None]:
# PLOT: show the final detections per chart. DEBUG: verbose output and plots,
# and the main loop stops after the first processed chart.
PLOT = DEBUG = False

In [None]:
# Dataset for the main loop, with no transforms (second argument None).
# ``df_val`` here is the subset of the *last* chart type processed in the
# "Evaluate" loop. With a single chart type, its positional indices line up with
# ``merged_boxes_list`` / ``confs_list``.
dataset = InferenceDataset(df_val, None)

In [None]:
# Chart ids skipped by the main loop (excluded from the CV score).
TO_REMOVE = [
    "513147edc8a1",
    "a7a81c55df4c",
    "039d3e82ebaf",
    "82c3706f2698",
    "6d4d21bdc9a8",
    "ca30ad3528c4",
    "1ab7f626447d",
]

In [None]:
# Manual replacements for the ground-truth ``y`` series, keyed by chart id.
# The main loop assigns each list to ``gt["y"]`` after sorting gt by (x, y), so a
# list must hold one value per ground-truth point, in that order.
# NOTE: "6d4d21bdc9a8" is also in TO_REMOVE, so its entry is currently unused.
FIXES = {
    "17000b60f53e": [
        -2.36, -1.6, -1.3, -0.8, -0.5, 0.009,
        0.416, 0.768, 1.296, 1.539, 2.027,
    ],
    "6d4d21bdc9a8": [
        7400.0, 8100.0, 9300.0, 6300.0, 10000.0, 11800.0,
        9800.0, 11700.0, 16128.0, 18823.0, 21519.0,
    ],
    "e93bed1228d6": [5.0, 10.0, 15.0, 20.0, 22.0, 25.0, 26.0, 30.0, 32.0],
}

In [None]:
# np.abs(gt.y.values)

In [None]:
# plt.imshow(img)

In [None]:
# Main loop. For every validation chart:
# - combine the point ensemble with the chart-element detector;
# - convert point pixel positions to data values using the OCR'd axis labels;
# - score the predicted (x, y) series against the ground truth.
seed_everything(0)

# Position of each chart-element class in the list returned by inference_detector.
CACHED_IDX = {c: i for i, c in enumerate(CACHED_CLASSES)}


def select_by_score(dets, max_th=0.1, rank=2):
    """Keep the detections of one chart-element class above an adaptive threshold.

    The threshold is ``min(max_th, dets[rank, 4])``, so it is lowered when the
    ``rank``-th detection is weak. Presumably this keeps a few boxes even when
    every score is low, assuming rows are sorted by decreasing score (TODO
    confirm). When there are at most ``rank`` rows, ``max_th`` is used instead;
    indexing ``dets[rank]`` would raise IndexError there.

    Args:
        dets (np.ndarray): Detections of one class, one row per box, columns
            ``x0, y0, x1, y1, score``.
        max_th (float): Upper bound of the threshold.
        rank (int): Row whose score can lower the threshold.

    Returns:
        np.ndarray: Kept boxes ``(n, 4)`` cast to int.
    """
    th = min(max_th, dets[rank, 4]) if len(dets) > rank else max_th
    return dets[dets[:, -1] > th][:, :4].astype(int)


scores = []
for idx in range(len(dataset)):
    img, gt, _ = dataset[idx]
    id_ = df_val.id[idx]

    if id_ in TO_REMOVE:
        continue

    print(idx, id_, end="\t")
    title = f"{id_} - {df_val.source[idx]} {df_val['chart-type'][idx]}"

    # Points: NMS-merged boxes of the two YOLOX models (computed in "Evaluate").
    preds_marker = [p.copy() for p in merged_boxes_list[idx]]

    # Chart elements: one array of (x0, y0, x1, y1, score) rows per CACHED_CLASSES entry.
    cached_result = inference_detector(cached_model, dataset.paths[idx])  # list[array]

    if DEBUG:
        # Keep only the classes used below (and the top plot area) for display.
        for i, (r, c) in enumerate(zip(cached_result, CACHED_CLASSES)):
            if c == "plot_area":
                cached_result[i] = r[:1]
            elif c not in ['plot_area', "xlabel", "ylabel", "x_tick", "y_tick", "legend_area"]:
                cached_result[i] = np.empty((0, 5))

        cached_model.show_result(
            dataset.paths[idx],
            cached_result,
            out_file='../output/sample_result.jpg',
            score_thr=0.1,
            thickness=1,
            font_size=5,
        )
        plt.figure(figsize=(15, 10))
        plt.imshow(cv2.imread('../output/sample_result.jpg'))
        plt.axis(False)
        plt.show()

    # Initial predictions: [plot area (top box), labels (x + y), ticks (x + y), points].
    preds = [
        cached_result[CACHED_IDX["plot_area"]][:1, :4].astype(int),
        np.concatenate([
            select_by_score(cached_result[CACHED_IDX["xlabel"]]),
            select_by_score(cached_result[CACHED_IDX["ylabel"]]),
        ]),
        np.concatenate([
            select_by_score(cached_result[CACHED_IDX["x_tick"]]),
            select_by_score(cached_result[CACHED_IDX["y_tick"]]),
        ]),
        preds_marker[-1],
    ]

    if DEBUG:
        plot_results(img, preds, figsize=(12, 7), title=title)

    preds = post_process_point_as_tick(
        preds, marker_conf=confs_list[idx], th=0.5, max_dist=4, max_dist_o=10, verbose=DEBUG
    )
    preds = post_process_preds(preds)
    # NOTE(review): always verbose, unlike the DEBUG-gated calls around it. Intentional?
    preds = post_process_arrow(preds, verbose=1)

    if DEBUG:
        plot_results(img, preds, figsize=(12, 7), title=title)

    margin = (img.shape[0] + img.shape[1]) / (2 * 20)
    preds = restrict_on_line(preds, margin=margin)

    # Visual similarity: blank out a confident legend area before searching for
    # missed points, presumably so legend markers are not retrieved as data points.
    img_r = img.copy()
    legend_dets = cached_result[CACHED_IDX["legend_area"]]
    if len(legend_dets) and legend_dets[0][-1] > 0.5:
        legend_area = legend_dets[0].astype(int)
        if DEBUG:
            print('Clear legend :', legend_area)
        # Clip at 0: a negative start index would wrap around and skip the clearing.
        x0, y0 = max(legend_area[0], 0), max(legend_area[1], 0)
        img_r[y0:legend_area[3], x0:legend_area[2]] = 255

    retrieved_boxes = retrieve_missing_boxes(preds, img_r, verbose=DEBUG, min_sim=0.75, seed=0)
    if len(retrieved_boxes):
        preds[-1] = np.concatenate([preds[-1], retrieved_boxes])

    if PLOT or DEBUG:
        plot_results(img, preds, figsize=(12, 7), title=title)

    # OCR of the x labels.
    x_texts = ocr(ocr_model, processor, img, preds[1], margin=1, plot=DEBUG)
    x_values, x_errors = post_process_texts(x_texts)

    if DEBUG:
        print("x labels :", x_values, " - errors:", x_errors)

    pred = None  # predicted series; stays None when no point is detected
    if len(preds[-1]):
        # Assumes restrict_on_line returns
        # [plot area, x labels, y labels, x ticks, y ticks, points]
        # (matches the indices used here). TODO confirm.
        reg_x = linear_regression(preds[3], x_values, x_errors, preds[-1], mode="x", verbose=DEBUG)

        y_texts = ocr(ocr_model, processor, img, preds[2], margin=3, plot=DEBUG)
        y_values, y_errors = post_process_texts(y_texts)

        if DEBUG:
            print("y labels :", y_values, " - errors:", y_errors)

        reg_y = linear_regression(preds[4], y_values, y_errors, preds[-1], mode="y", verbose=DEBUG)

        # Ground-truth series, sorted like the prediction so they can be compared.
        gt = df_target[df_target['id'] == id_].reset_index(drop=True)
        gt[["x", "y"]] = gt[["x", "y"]].astype(float)
        gt = gt.sort_values(['x', 'y'], ignore_index=True)

        if id_ in FIXES:
            gt["y"] = FIXES[id_]

        reg_x = np.round(reg_x, rounding(np.max(reg_x)))
        pred = pd.DataFrame({"x": reg_x, "y": reg_y}).sort_values(['x', 'y'], ignore_index=True)

        score_x = score_series(gt['x'].values, pred['x'].values)
        score_y = score_series(gt['y'].values, pred['y'].values)
    else:
        score_x, score_y = 0, 0

    # Only meaningful when a prediction exists; otherwise `pred` would be stale or undefined.
    if DEBUG and pred is not None and len(retrieved_boxes):
        print(len(pred), "preds,", len(gt), "gts")

    print(f"Scores  -  x: {score_x:.3f}  - y: {score_y:.3f}")

    scores += [score_x, score_y]

    if DEBUG:
        if pred is not None:
            print(f'GT : {len(gt)}')
            print(f'PRED : {len(pred)}')
        break  # debug mode inspects a single chart

In [None]:
# Cross-validation score: mean of the per-chart x and y series scores.
print('Scatter CV : {:.3f}'.format(np.mean(scores)))

Done!