**About**: This notebook trains CenterNet detection models, evaluates and visualizes their validation predictions, and runs inference on extracted dot charts.

In [None]:
# %load_ext nb_black
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib inline

In [None]:
cd ../src/

## Initialization

### Imports

In [None]:
import os
# Expose only the GPU with index 0 to CUDA. This has to be set before any
# CUDA-using library initialises the device to have an effect.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

In [None]:
import os
import cv2
import sys
import ast
import glob
import json
import yaml
import shutil
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from tqdm import tqdm

# Silence UserWarning noise and widen the pandas text display.
warnings.filterwarnings(action="ignore", category=UserWarning)
for option_name, option_value in (("display.width", 500), ("max_colwidth", 100)):
    pd.set_option(option_name, option_value)

In [None]:
from params import *

from util.plots import *
from util.metrics import *
from util.torch import seed_everything
from util.boxes import Boxes

from model_zoo.centernet import CenterNet
from data.dataset import CenterNetDataset
from data.transforms import get_transfos_centernet
from data.preparation import prepare_centernet_data

from training.losses import CenterLoss

from util.torch import init_distributed
from util.logger import (
    prepare_log_folder,
    save_config,
    create_logger,
    init_neptune
)

### Load data

In [None]:
# Build the two data frames used below (presumably train / validation - confirm
# in data.preparation), without the extra data (use_extra=False).
df, df_val = prepare_centernet_data(use_extra=False)

In [None]:
# Transform pipeline configured with a (512, 512) resize and strength=1;
# the exact meaning of `strength` is defined in data.transforms.
transfos = get_transfos_centernet(resize=(512, 512), strength=1)

In [None]:
# Dataset over df_val with the transforms above, used for the visual check below.
dataset = CenterNetDataset(df_val, transfos)

In [None]:
# Draw one random sample and show two image channels next to two target channels.
# `img` and `tgt` are kept in the namespace: the model and loss cells reuse them.
for idx in np.random.choice(len(dataset), 1):
    img, tgt, _ = dataset[idx]

    panels = [
        (img[0], {"cmap": "gray"}),
        (img[1], {"cmap": "gray"}),
        (tgt[:, :, 0], {"interpolation": None}),
        (tgt[:, :, 1], {"interpolation": None}),
    ]

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    for ax, (tensor, imshow_kwargs) in zip(axes, panels):
        ax.imshow(tensor.cpu().numpy(), **imshow_kwargs)

### Model

In [None]:
# Instantiate the detector with 3 output classes for a quick shape check.
model = CenterNet(num_classes=3)

In [None]:
# Forward pass on the previewed sample, with a leading batch dimension added.
y = model(img.unsqueeze(0))

In [None]:
# Shape of the input tensor, to compare with the model output shape below.
img.size()

In [None]:
# Shape of the model output for the single-sample batch.
y.size()

### Loss

In [None]:
# Loss with its default settings, for a quick sanity check.
loss = CenterLoss()

In [None]:
# Evaluate the loss on the single-sample batch (target gets a batch dimension too).
loss(y, tgt.unsqueeze(0))

### Main

In [None]:
class Config:
    """
    Parameters used for training.

    The class itself (not an instance) is passed around as the configuration
    object, so every setting is a class attribute.
    """
    # General
    seed = 42
    verbose = 1
    device = "cuda"
    save_weights = True

    # Images
    img_folder = "v13_sim/"
    data_path = "../input/"
    aug_strength = 1
    resize = (512, 512)
    use_extra = False

    # k-fold
    k = 4
    folds_file = None
    selected_folds = [0]

    # Model
    name = "resnet18"  # "eca_nfnet_l2"  # "tf_efficientnetv2_s" "eca_nfnet_l1"
    pretrained_weights = None
    num_classes = 3
    n_channels = 3
    drop_rate = 0.
    drop_path_rate = 0.
    syncbn = False

    # Training
    loss_config = {
        "name": "centerloss",  # bce ?
        "smoothing": 0.0,
        "activation": "",
        "aux_loss_weight": 0.,
    }

    data_config = {
        "batch_size": 16,
        "val_bs": 32,
        "mix": "cutmix",
        "mix_proba": 0,
        "mix_alpha": 4.0,
        # Mirrors the model's num_classes defined above.
        "num_classes": num_classes,
        "additive_mix": False,
    }

    optimizer_config = {
        "name": "Ranger",
        "lr": 1e-3,
        "warmup_prop": 0.1,
        "betas": (0.9, 0.999),
        "max_grad_norm": 10.0,
        "weight_decay": 0,  # 1e-2,
    }

    epochs = 10
    use_fp16 = True

    verbose_eval = 200

    fullfit = False
    n_fullfit = 1


In [None]:
# When DEBUG is True the training cell skips log-folder, config and logger setup,
# so log_folder and run stay at the None defaults set here.
DEBUG = True
log_folder = None
run = None

In [None]:
from training.main_centernet import k_fold

In [None]:
# Outside debug mode: create a log folder, save the config there and start a file logger.
# LOG_PATH is not defined in this notebook - presumably it comes from `from params import *`.
if not DEBUG:
    log_folder = prepare_log_folder(LOG_PATH)
    print(f"Logging results to {log_folder}")
    config_df = save_config(Config, log_folder + "config.json")
    create_logger(directory=log_folder, name="logs.txt")
#     run = init_neptune(Config, log_folder)

# The Config class itself (not an instance) is used as the config object.
config = Config
init_distributed(config)

# Launch k-fold training; the returned predictions are kept in `preds`.
preds = k_fold(config, log_folder=log_folder, run=run)

### Eval

In [None]:
import torch
from util.centernet import process_and_score, pred2box

In [None]:
# Load validation predictions from a hard-coded log folder.
# This overwrites the `preds` returned by k_fold in the training cell.
preds = np.load('../logs/2023-06-06/11/pred_val.npy')

In [None]:
# Score the predictions (cast to float32) against df_val with th=0.4 and pool_size=3.
f1s = process_and_score(preds.astype(np.float32), df_val, th=0.4, pool_size=3)

In [None]:
# Mean F1 over samples, and the fraction of samples whose F1 is exactly 1.
print(f'Avg F1: {np.mean(f1s):.3f}  \t Avg F1==1: {np.mean(np.array(f1s) == 1):.3f}')

In [None]:
# print(f'Avg F1: {np.mean(f1s):.3f}  \t Avg F1==1: {np.mean(np.array(f1s) == 1):.3f}')

### Visualization

In [None]:
# Same hard-coded prediction file as in the Eval section, reloaded here
# (presumably so this section can be run on its own).
preds = np.load('../logs/2023-06-06/11/pred_val.npy')

In [None]:
import re

# Point the validation image paths at the "v13/" folder instead of "v13_sim/".
sim_folder = re.compile('v13_sim/')
df_val['path'] = df_val['path'].map(lambda p: sim_folder.sub("v13/", p))

In [None]:
# Stride-1 max pooling with padding pool_size // 2 keeps the spatial size unchanged;
# the loop below compares the heatmap with its pooled version to keep local maxima only.
pool_size = 3
pool = torch.nn.MaxPool2d(pool_size, stride=1, padding=pool_size // 2)

# `shape` is passed to Boxes; `th` is the threshold passed to pred2box.
shape = (128, 128)
th = 0.4

# No transforms here (None is passed instead of `transfos`).
dataset = CenterNetDataset(df_val, None)

In [None]:
# PLOT: draw the predicted boxes for each sample in the loops below.
# DEBUG: stop those loops after the first sample.
PLOT = False
DEBUG = False

In [None]:
# Per-sample evaluation: decode the predicted heatmap into boxes, compare them
# with the ground-truth file and collect one F1 score per validation sample.
# `heatmap`, `reg`, `boxes` and `coords` of the last processed sample stay in
# the namespace and are reused by the plotting cells further down.
f1s = []
for i in range(len(df_val)):
    # Positional lookup, so the row stays aligned with preds[i] even if the
    # frame's index is not 0..n-1.
    gt_path = df_val['gt_path'].iloc[i]

    # Each non-empty line is "<class> <values...>": drop the first token and
    # parse the rest. Splitting on whitespace also copes with a missing final
    # newline or trailing spaces, and the file handle is closed on exit.
    with open(gt_path, 'r') as gt_file:
        coords = np.array(
            [line.split()[1:] for line in gt_file if line.strip()]
        ).astype(float)

    heatmap = torch.from_numpy(preds[i][0]).float()
    sz = heatmap.size(-1)

    # Constant 0.005 map stacked twice, passed to pred2box as `reg`
    # (presumably a placeholder size/regression output - confirm in util.centernet).
    reg = torch.ones_like(heatmap) * 0.005
    reg = torch.stack([reg, reg])

    # Keep only local maxima: a pixel survives if it equals its pooled neighbourhood max.
    heatmap = heatmap.unsqueeze(0).unsqueeze(0)
    heatmap = torch.where(heatmap == pool(heatmap), heatmap, 0)
    heatmap = heatmap[0, 0]

    boxes, confs = pred2box(heatmap, reg, th)

    # NOTE(review): the predicted box sizes are overwritten with the largest
    # ground-truth values of columns 2 and 3, so the score uses label information.
    if len(boxes):
        boxes[:, 2] = coords[:, 2].max() * sz
        boxes[:, 3] = coords[:, 3].max() * sz

    pred_boxes = Boxes(boxes / sz, shape, bbox_format="yolo")
    gt_boxes = Boxes(coords, shape, bbox_format="yolo")

    metrics = compute_metrics([pred_boxes], [gt_boxes])
    f1s.append(metrics['f1_score'])

    if PLOT:
        img, _, _ = dataset[i]

        pred_boxes.update_shape(img.shape)
        gt_boxes.update_shape(img.shape)

        plot_results(img, [[], [], [], pred_boxes['pascal_voc']])

    # In DEBUG mode only the first sample is processed.
    if DEBUG:
        break

In [None]:
from util.plots import plot_results

In [None]:
# Mean F1 over the samples processed by the loop above, and the fraction with F1 exactly 1.
print(f'Avg F1: {np.mean(f1s):.3f}  \t Avg F1==1: {np.mean(np.array(f1s) == 1):.3f}')

In [None]:
# Overlay for the last sample processed by the evaluation loop:
# thresholded heatmap, predicted centres (red) and ground-truth centres (lime).
plt.imshow(heatmap > th)
plt.scatter(boxes[:, 0], boxes[:, 1], s=1, c="r", label="pred")
# Ground-truth coordinates are normalised, so scale them by the actual heatmap
# size instead of a hard-coded 128.
grid_size = heatmap.size(-1)
plt.scatter(coords[:, 0] * grid_size, coords[:, 1] * grid_size, s=1, c="lime", label="gt")
plt.legend()
plt.show()

In [None]:
# Re-decode the last processed heatmap with a higher threshold.
# Note: this rebinds the notebook-level `th`.
th = 0.5

boxes, confs = pred2box(heatmap, reg, th)

# Thresholded heatmap with the re-decoded box centres on top.
plt.imshow(heatmap > th)
plt.scatter(boxes[:, 0], boxes[:, 1], s=1, c="r")


### Inference

In [None]:
from pathlib import Path

# One row per file found in ../input/dots/. The id is the file name without
# extension; the remaining columns are constants.
test_columns = {
    "id": lambda frame: frame["path"].map(lambda p: Path(p).stem),
    "source": "extracted",
    "chart-type": "dot",
    "gt_path": "",
}
df_test = pd.DataFrame({"path": glob.glob('../input/dots/*')}).assign(**test_columns)


In [None]:
from inference.main_centernet import kfold_inference

In [None]:
# Run k-fold inference on the test frame using a hard-coded experiment folder
# (presumably the folder holding the trained fold weights - confirm in inference.main_centernet).
pred_test = kfold_inference(df_test, '../logs/2023-06-06/12/')
# pred_test = kfold_inference(df_test, '../logs/2023-06-06/14/')

In [None]:
# Stride-1 max pooling with padding pool_size // 2 keeps the spatial size unchanged;
# the inference loop compares the heatmap with its pooled version to keep local maxima only.
pool_size = 3
pool = torch.nn.MaxPool2d(pool_size, stride=1, padding=pool_size // 2)

# `shape` is passed to Boxes; `th` is the threshold passed to pred2box.
shape = (128, 128)
th = 0.5

# Test dataset without transforms (None is passed).
dataset = CenterNetDataset(df_test, None)

In [None]:
# Draw the predicted boxes for every test image in the loop below.
PLOT = True

In [None]:
# Decode and optionally plot the predictions for every test image.
# Nothing is scored here (df_test has empty gt_path values), so the `f1s`
# list from the evaluation section is left untouched.
for i in range(len(dataset)):

    heatmap = torch.from_numpy(pred_test[i][0]).float()
    # Rescale so the per-image maximum is 1 before thresholding with `th`.
    heatmap = heatmap / heatmap.max()

    sz = heatmap.size(-1)

    # Constant 0.005 map stacked twice, passed to pred2box as `reg`
    # (presumably a placeholder size/regression output - confirm in util.centernet).
    reg = torch.ones_like(heatmap) * 0.005
    reg = torch.stack([reg, reg])

    # Keep only local maxima: a pixel survives if it equals its pooled neighbourhood max.
    heatmap = heatmap.unsqueeze(0).unsqueeze(0)
    heatmap = torch.where(heatmap == pool(heatmap), heatmap, 0)
    heatmap = heatmap[0, 0]

    boxes, confs = pred2box(heatmap, reg, th)

    pred_boxes = Boxes(boxes / sz, shape, bbox_format="yolo")

    # Report the id of the test image being processed. The dataset is built
    # from df_test, so the id must come from df_test (positional lookup), not df_val.
    print(i, df_test['id'].iloc[i])

    if PLOT:
        img, _, _ = dataset[i]
        pred_boxes.update_shape(img.shape)
        plot_results(img, [[], [], [], pred_boxes['pascal_voc']])

    # In DEBUG mode only the first image is processed.
    if DEBUG:
        break

Done!