**About**: This notebook prepares the YOLO-format data (images, labels and config file) used to train detection models.

In [None]:
# %load_ext nb_black
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
# Work from the source directory: the project imports and the '../input' paths used below are relative to it.
cd ../src/

## Initialization

### Imports

In [None]:
import os

# Set before any CUDA-aware library is imported; "-1" is conventionally used to expose no GPU.
os.environ.update({'CUDA_VISIBLE_DEVICES': "-1"})

In [None]:
import os
import cv2
import sys
import ast
import glob
import json
import yaml
import shutil
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from tqdm import tqdm

# Widen the console and column display so long values stay readable in outputs.
for option_name, option_value in (('display.width', 500), ('max_colwidth', 100)):
    pd.set_option(option_name, option_value)

In [None]:
from params import *
from util.plots import plot_annotated_image, plot_sample
from util.torch import seed_everything
from util.yolo import *

### Load data

In [None]:
# Load the four input tables in one pass: chart metadata, texts, targets and elements.
df, df_text, df_target, df_elt = (
    pd.read_csv(csv_path)
    for csv_path in (
        '../input/df_train.csv',
        '../input/texts.csv',
        '../input/y_train.csv',
        '../input/elements.csv',
    )
)

### Split

In [None]:
# SEED = 42
# seed_everything(SEED)

# split = {}
# for i in range(len(df)):
#     split[df['id'][i]] = "train"

#     if df['source'][i] == "extracted":
#         split[df['id'][i]] = "val"
        
#         if df['chart-type'][i] == "horizontal_bar":
#             if np.random.random() > 0.3:
#                 split[df['id'][i]] = "train"
#     else:
#         if df['chart-type'][i] == "dot":
#             if np.random.random() < 0.2:
#                 split[df['id'][i]] = "val"
                
# df_split = pd.DataFrame.from_dict(split, orient="index").reset_index()
# df_split.columns = ['id', 'split']
# df_split.to_csv('../input/df_split.csv', index=False)

## EDA

In [None]:
# Attach the precomputed split assignment (merged on the columns both frames share).
df_split = pd.read_csv('../input/df_split.csv')
df = df.merge(df_split)

# Number of charts per chart type and split; log scale keeps the rare groups visible.
split_ax = sns.countplot(x="chart-type", hue="split", data=df)
split_ax.set_yscale('log')
plt.show()

In [None]:
# Number of charts per chart type and data source; log scale keeps the rare groups visible.
source_ax = sns.countplot(x="chart-type", hue="source", data=df)
source_ax.set_yscale('log')
plt.show()

In [None]:
# df[df['source'] != "generated"].head()
# df.head()

In [None]:
# Chart ids excluded from the dataset, grouped by the reason they were flagged.
# Each id is listed once ("3ef41bbc82c3" used to appear twice in the last group).
ANOMALIES = [
    # DUPLICATED STUFF
    'ae686738e744', 'c76f6d0d5239', '760c3fa4e3d9', 'c0c1f4046222', '3e568d136b85', '913447978a74', '2ff071a45cce', 'a9a07d74ee31',
    # MISSING or MISLABELED TICKS ANNOTS
    "36079df3b5b2", "3968efe9cbfc", "6ce4bc728dd5", "733b9b19e09a", "aa9df520a5f2", "d0cf883b1e13",
    # WEIRD
    "9f6b7c57e6cd", "e1034ff92655", "e796b10718bd", "f8bdbaf0b97d", "3ef41bbc82c3", "73cfbba65962", "872d1be39bae",
]

In [None]:
# Remove every chart whose id is flagged in ANOMALIES, then renumber the remaining rows.
is_anomaly = df['id'].isin(ANOMALIES)
df = df.loc[~is_anomaly].reset_index(drop=True)

In [None]:
# Identifier of the sample passed to plot_annotated_image in the next cell
# (presumably a chart id like those in df['id'] — TODO confirm).
file = '90c504a8b320'

In [None]:
# Visual sanity check on one sample; plot_annotated_image comes from util.plots
# (presumably draws the image with its annotations — verify against the helper).
fig = plot_annotated_image(file)

In [None]:
# Chart types kept from here on; rows with any other chart type are dropped.
CLASSES = ["dot", "line", "scatter"]

in_scope = df['chart-type'].isin(CLASSES)
df = df.loc[in_scope].reset_index(drop=True)

## Yolo preparation

### Folders

In [None]:
YOLO_PATH = '../yolov7/'  # folder where the dataset config file (data_2.yaml) is written
DATA_PATH = '../input/'  # root under which the v2/images and v2/labels folders are created

In [None]:
# Detection class names; label_dict maps each name to its index in this list.
labels = ["chart", "text", "tick", "point"]
label_dict = {k: i for i, k in enumerate(labels)}

#(1) image file path
yolo_train_img_dir = f'{DATA_PATH}/v2/images/train/'
yolo_valid_img_dir = f'{DATA_PATH}/v2/images/valid/'

#(2) label file path
yolo_train_label_dir = f'{DATA_PATH}/v2/labels/train/'
yolo_valid_label_dir = f'{DATA_PATH}/v2/labels/valid/'

#(3) config file path
yaml_file = f'{YOLO_PATH}/data_2.yaml'

# Reset every output folder to an empty directory so files from a previous
# export cannot leak into this one.
# NOTE: this wipes any existing export each time the cell runs, regardless of
# the SAVE flag set further down.
for yolo_dir in (
    yolo_train_img_dir,
    yolo_valid_img_dir,
    yolo_train_label_dir,
    yolo_valid_label_dir,
):
    if os.path.isdir(yolo_dir):
        shutil.rmtree(yolo_dir)
    os.makedirs(yolo_dir, exist_ok=True)

### Loop

In [None]:
# Keep only texts that have an axis value; rows without one are titles and are ignored.
has_axis = df_text['axis'].notna()
df_text = df_text.loc[has_axis].reset_index(drop=True)

In [None]:
# Per-chart lookup of text rows, keyed by chart_id, each with a fresh 0-based index.
dfts = {
    chart_id: chart_texts.reset_index(drop=True)
    for chart_id, chart_texts in tqdm(df_text.groupby('chart_id'))
}

In [None]:
# Per-chart lookup of element rows, keyed by chart_id, each with a fresh 0-based index.
dfes = {
    chart_id: chart_elts.reset_index(drop=True)
    for chart_id, chart_elts in tqdm(df_elt.groupby('chart_id'))
}

In [None]:
PLOT = True  # plot every processed chart with its boxes (otherwise only when i is a multiple of 10000)
SAVE = False  # False: only extract boxes (preview); True: copy images and write label files

In [None]:
# Export loop: for each chart, extract its boxes and either preview them
# (SAVE is False) or copy the image and write its label file (SAVE is True).

# The loop iterates over groups (distinct ids), so the progress-bar total is
# the number of distinct ids rather than the number of rows.
n_charts = df['id'].nunique()

for i, (id_, dfg) in tqdm(enumerate(df.groupby('id')), total=n_charts):
    # Debug: uncomment to force a single chart.
    #     id_ = 'e93bed1228d6'
    #     dfg = df[df['id'] == id_]

    img_file = f'../input/train/images/{id_}.jpg'
    src = dfg['source'].values[0]
    split = dfg['split'].values[0]

    # Anything not explicitly marked 'train' is written to the validation folders.
    if split == 'train':
        yolo_img_dir = yolo_train_img_dir
        yolo_label_dir = yolo_train_label_dir
    else:
        yolo_img_dir = yolo_valid_img_dir
        yolo_label_dir = yolo_valid_label_dir

    # Charts with no entry in dfts or dfes are skipped without an error.
    if id_ not in dfts or id_ not in dfes:
        continue
    dft = dfts[id_]
    dfe = dfes[id_]

    img_h = dfg['img_h'].values[0]
    img_w = dfg['img_w'].values[0]

    if not SAVE:
        boxes = extract_bboxes_2(dfg, dft, dfe, img_h, img_w)

        # Debug: uncomment to only look at charts containing very small boxes.
        #         bs = np.concatenate([b for b in boxes if len(b)])
        #         szs = bs[:, 2] * bs[:, 3] * dfg['img_h'].values[0] * dfg['img_w'].values[0]
        #         minsz = np.min(szs)
        #         if minsz > 10:
        #             continue
        #         print(id_)

    else:
        # Copy image
        dst_file = f'{yolo_img_dir}/{id_}.jpg'
        shutil.copyfile(img_file, dst_file)

        # Save boxes; a chart whose label file already exists is left untouched.
        file_name = f'{yolo_label_dir}/{id_}.txt'
        if os.path.exists(file_name):
            continue

        boxes = extract_bboxes_2(dfg, dft, dfe, img_h, img_w)

        # One line per box: "<class index> <box values>", values at 4 significant digits.
        # A set gives constant-time duplicate checks; first occurrences keep their order.
        written = set()
        with open(file_name, 'w') as f:
            for c, boxes_c in enumerate(boxes):
                for box in boxes_c:
                    str_bbox = ' '.join([str(c)] + [f"{b:.4g}" for b in box])
                    if str_bbox in written:
                        continue
                    written.add(str_bbox)
                    f.write(str_bbox)
                    f.write('\n')

    if PLOT or not (i % 10000):
        img = cv2.imread(img_file)
        plot_sample(img, boxes)
        plt.title(f"{id_} - {src} {dfg['chart-type'].values[0]}")
        plt.show()

    # Preview mode stops after the first processed chart. An export run
    # (SAVE is True) must go through every chart, so it never breaks here.
    if not SAVE:
        break

In [None]:
# for i, (id_, dfg) in tqdm(enumerate(df.groupby('id')), total=len(df)):    
# #     id_ = "3ef41bbc82c3"
# #     dfg = df[df['id'] == id_]

#     img_file = f'../input/train/images/{id_}.jpg'
#     src = dfg['source'].values[0]
#     split = dfg['split'].values[0]

#     if split == 'train':
#         yolo_img_dir = yolo_train_img_dir
#         yolo_label_dir = yolo_train_label_dir
#     else:
#         yolo_img_dir = yolo_valid_img_dir
#         yolo_label_dir = yolo_valid_label_dir

#     # Save boxes
#     file_name = f'{yolo_label_dir}/{id_}.txt'
#     file_name = "../input/v2/labels/train/4913c6a99055.txt"

# #     try:
#     with open(file_name, 'r') as f:
#         cs = [l[:-1] for l in f.readlines()]

# #     print(np.array(cs))
#     a, b = np.unique(cs, return_counts=True)
#     print(a[b > 1])
# #     except:
# #         continue
        
#     break

### Model
- WARNING: Extremely small objects found. 11577 of 2816497 labels are < 3 pixels in size.

In [None]:
# Dump config file
data_yaml = dict(
    train=yolo_train_img_dir,
    val=yolo_valid_img_dir,
    nc=len(labels),
    names=labels
)

print(data_yaml)

with open(yaml_file, 'w') as outfile:
    yaml.dump(data_yaml, outfile, default_flow_style=True)

yaml_file

Done!