# 🐡 Starfish detection: Flash ⚡ EfficientDet

Flash, your PyTorch AI Factory, enables you to easily configure and run complex AI recipes for over 15 tasks across 7 data domains.

## Installs

Install the package with additional extras for computer vision.

In [None]:
# Install FiftyOne, effdet, IceVision (with all extras) and Lightning Flash
# (straight from its GitHub repository, with the `image` extra).
# `-q` keeps pip's output quiet.
!pip install -q fiftyone effdet "icevision[all]" 'git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[image]'
# Optional pin of PyTorch Lightning to the 1.4 series (currently disabled).
# !pip install -q "pytorch-lightning==1.4.*"
# Remove wandb; `-y` skips the confirmation prompt.
# NOTE(review): presumably removed so the Weights & Biases integration does not
# activate during training — confirm the actual reason.
!pip uninstall -y wandb

## Imports

In [None]:
import ast
from pathlib import Path

import torch
import flash
import fiftyone as fo
import numpy as np
import pandas as pd
from flash.image import ObjectDetectionData, ObjectDetector

## Paths

In [None]:
# Root directory under which Kaggle mounts the input datasets.
INPUT_DIR = Path("/kaggle/input")
# Folder of the "tensorflow-great-barrier-reef" dataset; the training images are
# read from its "train_images" sub-folder further below.
DATA_DIR = INPUT_DIR / "tensorflow-great-barrier-reef"
# CSV that is loaded into `train_df`; its `video_id`, `video_frame` and
# `annotations` columns are the ones used by the conversion cell.
TRAIN_CSV_PATH = DATA_DIR / "train.csv"
# Location inside the Kaggle working directory.
# NOTE(review): not referenced by any visible cell — presumably meant for a
# COCO-format export of the dataset; confirm or remove.
COCO_DATA_DIR = Path("/kaggle/working/gbr-coco")

## Convert Dataset

In [None]:
# Frame size as (width, height) used to clamp bounding boxes.
IMAGE_DIMS = (1280, 720)


def format_bbox(annotation):
    """Convert one raw annotation into an ``xmin/ymin/width/height`` dict.

    Args:
        annotation: Mapping with the keys ``"x"``, ``"y"``, ``"width"`` and
            ``"height"``.

    Returns:
        A dict with the keys ``"xmin"``, ``"ymin"``, ``"width"`` and
        ``"height"``. When the box would touch or cross the right or bottom
        edge given by ``IMAGE_DIMS``, the corresponding extent is shortened
        so that the box ends one pixel inside the frame.
    """
    frame_w = IMAGE_DIMS[0]
    frame_h = IMAGE_DIMS[1]

    left, top = annotation["x"], annotation["y"]
    box_w, box_h = annotation["width"], annotation["height"]

    # Horizontal overflow: pull the right edge back inside the frame.
    if left + box_w >= frame_w:
        box_w = frame_w - left - 1

    # Vertical overflow: pull the bottom edge back inside the frame.
    if top + box_h >= frame_h:
        box_h = frame_h - top - 1

    return {"xmin": left, "ymin": top, "width": box_w, "height": box_h}

# Load the per-frame annotation table.
train_df = pd.read_csv(TRAIN_CSV_PATH)

# Parallel lists: entry i of each list describes the same image.
train_files, train_labels, train_bboxes = [], [], []

for _, record in train_df.iterrows():
    # The `annotations` column holds a Python-literal list of box dicts.
    frame_boxes = [format_bbox(raw_box) for raw_box in ast.literal_eval(record["annotations"])]

    # Skip images with no annotations
    if not frame_boxes:
        continue

    frame_path = DATA_DIR / "train_images" / f"video_{record['video_id']}" / f"{record['video_frame']}.jpg"
    train_files.append(frame_path)
    # Every box gets the same single class label.
    train_labels.append(["cots"] * len(frame_boxes))
    train_bboxes.append(frame_boxes)

## ⚡ Flash in 3 Steps

All training arguments can be configured; see the object detection reference: https://lightning-flash.readthedocs.io/en/stable/reference/object_detection.html

### Step 1. Load your data

In [None]:
# Single size value shared by the data transforms (here) and the model (Step 2).
IMAGE_SIZE = 1024

# Build the Flash datamodule from the three parallel lists prepared in the
# conversion cell: image paths, per-image class labels and per-image boxes.
datamodule = ObjectDetectionData.from_files(
    train_files=train_files,
    train_targets=train_labels,
    train_bboxes=train_bboxes,
    # NOTE(review): presumably the fraction of the training data held out for
    # validation — confirm against the Flash documentation.
    val_split=0.15,
    transform_kwargs={"image_size": IMAGE_SIZE},
    batch_size=4,
    num_workers=4,
)

### Step 2: Configure your model

In [None]:
# Object detector using the "efficientdet" head with the "d3" backbone.
model = ObjectDetector(
    head="efficientdet",
    backbone="d3",
    # Take the class count from the datamodule so it follows the labels found in the data.
    num_classes=datamodule.num_classes,
    # Same value that is passed to the datamodule transforms.
    image_size=IMAGE_SIZE,
    pretrained=True,
    optimizer="AdamW",
    learning_rate=0.001,
)
# Disabled tweak of the wrapped model's `max_detection_points` attribute.
# model.adapter.model.max_detection_points = 1000

### Step 3: Finetune

In [None]:
from pytorch_lightning.loggers import CSVLogger
# from pytorch_lightning.callbacks import StochasticWeightAveraging

# Trainer Args
# Number of CUDA devices visible to torch; 0 when the notebook has no GPU enabled.
GPUS = torch.cuda.device_count()  # Set to 1 if GPU is enabled for notebook

# swa = StochasticWeightAveraging(swa_epoch_start=0.6)
# Log metrics as CSV under `logs/`; the visualization cell reads them back
# through `trainer.logger.log_dir`.
logger = CSVLogger(save_dir='logs/')

trainer = flash.Trainer(
    # fast_dev_run=False,
    # callbacks=[swa],
    gradient_clip_val=0.01,
    gpus=GPUS,
    max_epochs=15,
    # 16-bit precision only when at least one GPU is available, otherwise 32-bit.
    precision=16 if GPUS else 32,
    logger=logger,
)

# Silence DeprecationWarning messages from this point on (i.e. during training;
# warnings raised while building the trainer above are not affected).
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 

# NOTE(review): "no_freeze" presumably trains all layers from the start; the
# commented alternative would use a freeze/unfreeze schedule — confirm in Flash docs.
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")  # strategy=("freeze_unfreeze", 5)

# Persist the trained model to a checkpoint file in the current working directory.
trainer.save_checkpoint("object_detection_model.pt")

### Training visualizations

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Apply seaborn's default plot styling.
sns.set()

# Load the metrics CSV from the trainer's logger directory, drop the "step"
# column so it is not plotted, and index the rows by epoch.
metrics = (
    pd.read_csv(f'{trainer.logger.log_dir}/metrics.csv')
    .drop(columns="step")
    .set_index("epoch")
)

# Preview the first rows, leaving out columns that hold no values at all.
display(metrics.dropna(axis=1, how="all").head())

# Line plot of every metric column against the epoch index, resized to a wide figure.
g = sns.relplot(data=metrics, kind="line")
plt.gcf().set_size_inches(15, 5)