In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from collections import Counter

from efficientdet.data_pipeline.utils import xyxy_to_xywh, xywh_to_xyxy
from efficientdet.data_pipeline.anchors import AnchorBox
from efficientdet.data_pipeline.target_encoder import TargetEncoder
from efficientdet.data_pipeline.input_dataset import create_combined_dataset

# Load data

In [None]:
# Training annotations CSV; each row is one ground-truth box for an image.
path_train = '../data/train.csv'

# The image path stays a string; box corners and the class label are read as 32-bit ints.
train_dtypes = {'img_path': str, **{col: 'int32' for col in ('x1', 'y1', 'x2', 'y2', 'label')}}
df_train = pd.read_csv(path_train, dtype=train_dtypes)
df_train.head(2)

# Create the tf.data dataset

In [None]:
# Build the input pipeline from the same CSV. The cells below consume it as batches of
# (images, {'regression': ..., 'classification': ...}).
# NOTE(review): no batch size is passed here — presumably it is fixed inside
# create_combined_dataset; confirm it matches the 8-panel grid used for visualization.
ds = create_combined_dataset(path_train)

# Visualize samples: assigned anchors (red) vs. ground-truth boxes (cyan)

In [None]:
# Encoder whose anchor generator (`_anchor_box`) is reused below to recover anchor boxes.
target_encoder = TargetEncoder()
# Presumably the number of samples/batch size for visualization (8 = the 2x4 grid below).
bs = 8

In [None]:
# Pull exactly one batch from the pipeline. Using next() instead of `for inp in ...: pass`
# means an empty dataset fails loudly here instead of leaving `inp` undefined, or silently
# stale from an earlier run of this cell.
inp = next(iter(ds.take(1)), None)
if inp is None:
    raise ValueError("Dataset from create_combined_dataset yielded no batches")

# Shapes of the image batch and of the two encoded target tensors.
print(inp[0].shape)
print(inp[1]['regression'].shape)
print(inp[1]['classification'].shape)

In [None]:
def add_box(ax, bbox, ec='r', fc='none', lw=0.5):
    """Draw an axis-aligned rectangle on a matplotlib Axes.

    Args:
        ax: Axes to draw on.
        bbox: Four corner values (x1, y1, x2, y2) in data coordinates
            (pixels when drawn over an ``imshow`` image).
        ec: Edge colour passed to ``Rectangle``.
        fc: Face colour passed to ``Rectangle`` ('none' = unfilled).
        lw: Line width passed to ``Rectangle``.
    """
    x_min, y_min, x_max, y_max = bbox
    # Rectangle is anchored at (x_min, y_min) and takes a width and height, not a second corner.
    ax.add_patch(Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, ec=ec, fc=fc, lw=lw))


fig, axes = plt.subplots(2, 4, figsize=(20, 10))
axes_flat = np.ravel(axes)

# Convert the whole batch once instead of slicing and converting tensors per panel.
images = inp[0].numpy()
box_targets_batch = inp[1]['regression'].numpy()
cls_targets_batch = inp[1]['classification'].numpy()

# Anchors for a 512x512 input (assumes the pipeline resizes images to that size — TODO confirm).
all_anchors = target_encoder._anchor_box.get_all_anchors(512, 512).numpy()

# NOTE(review): ground-truth rows come from df_train grouped (and sorted) by img_path and are
# paired with batch images purely by position. The cyan boxes only belong to the displayed
# image if create_combined_dataset yields images in that same order without shuffling — confirm.
samples = zip(images, box_targets_batch, cls_targets_batch, df_train.groupby('img_path'))

n_drawn = 0
# zip stops at the shortest input, so a batch smaller than the grid no longer raises IndexError.
for ax, (img, box_targets, cls_targets, (_, g)) in zip(axes_flat, samples):
    # Assumes the last regression column flags anchors assigned to an object with 1 — TODO confirm.
    valid_anchors = all_anchors[box_targets[:, -1] == 1]

    # Anchor count per label, skipping negative targets (presumably background/ignored anchors).
    labels, counts = np.unique(cls_targets[cls_targets >= 0], return_counts=True)
    title = "(label: anchors) " + ", ".join(
        "{}: {}".format(int(label), count) for label, count in zip(labels, counts)
    )
    ax.set_title(title, fontsize=14)
    ax.imshow(img)

    # Assigned anchors, drawn with add_box's default red edge.
    for bbox in xywh_to_xyxy(valid_anchors):
        add_box(ax, bbox)
    # Ground-truth boxes in cyan.
    for _, row in g.iterrows():
        add_box(ax, row[['x1', 'y1', 'x2', 'y2']], ec='cyan', lw=1)
    n_drawn += 1

# Blank out panels left over when the batch has fewer images than the grid.
for ax in axes_flat[n_drawn:]:
    ax.set_axis_off()