# Detectron2 for Happy Whales 🐋
![Whale](https://c.tenor.com/6_3NxH30Ud4AAAAM/whale-blue-whale.gif)

This notebook trains a segmentation model to segment whales & dolphins out of images.
Note: this notebook was written, and the model was trained, using whale images only.

## Install detectron2

In [None]:
!pip3 install pyyaml==5.1
!python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'

## Common Required Libraries for the notebook 🤔

In [None]:
import torch
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

import numpy as np
import os, json, cv2, random

from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog

## Importing the dataset and structuring it for training 😬

We were too lazy to do a train-test split since we were trying out the notebook just for fun xD, so the same images end up in both `train` and `val`.

The images are annotated using [LabelMe](https://github.com/wkentaro/labelme).

In [None]:
!cp -r ../input/happywhales-labelme-segmentation-dataset . 

In [None]:
# -p creates parent folders as needed and does not fail if the cell is re-run
!mkdir -p Final_Data/train/
!mkdir -p Final_Data/val/

In [None]:
!cp ./happywhales-labelme-segmentation-dataset/test/*.jpg Final_Data/train 
!cp ./happywhales-labelme-segmentation-dataset/test/*.json Final_Data/train
!cp ./happywhales-labelme-segmentation-dataset/test/*.jpg Final_Data/val
!cp ./happywhales-labelme-segmentation-dataset/test/*.json Final_Data/val
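
The cell above copies the same files into both folders. If a real split were wanted, it could look something like the sketch below. The helper name `split_labelme_pairs`, the 80/20 ratio and the seed are illustrative assumptions, not something this notebook uses.

```python
import random
import shutil
from pathlib import Path


def split_labelme_pairs(src_dir, train_dir, val_dir, val_fraction=0.2, seed=42):
    """Copy each .jpg and its matching LabelMe .json into train or val.

    Hypothetical helper: returns (number of train pairs, number of val pairs).
    """
    src = Path(src_dir)
    # Keep only images that have an annotation file next to them.
    stems = sorted(
        p.stem for p in src.glob("*.jpg") if (src / f"{p.stem}.json").exists()
    )
    # Seeded shuffle so the split is reproducible between runs.
    random.Random(seed).shuffle(stems)
    n_val = int(len(stems) * val_fraction)
    val_stems = set(stems[:n_val])

    for stem in stems:
        dest = Path(val_dir if stem in val_stems else train_dir)
        dest.mkdir(parents=True, exist_ok=True)
        for ext in (".jpg", ".json"):
            shutil.copy(src / f"{stem}{ext}", dest / f"{stem}{ext}")
    return len(stems) - n_val, n_val


# Possible usage with the folders from this notebook:
# split_labelme_pairs("happywhales-labelme-segmentation-dataset/test",
#                     "Final_Data/train", "Final_Data/val")
```

With a split like this, each folder would need its own COCO file from the conversion step, rather than one shared JSON.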

In [None]:
# Convert the LabelMe annotations to COCO format, which is what this notebook trains from.
!git clone https://github.com/Tony607/labelme2coco

In [None]:
!pip3 install labelme

In [None]:
# -f keeps the cell from failing when the stale JSON files are not there
!rm -f Final_Data/train/train.json Final_Data/train/val.json
!python3 labelme2coco/labelme2coco.py /kaggle/working/Final_Data/train/ --output ./test_whale.json
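
Before registering the dataset, it can help to sanity-check the converted file. The sketch below assumes the output follows the standard COCO layout (`images`, `annotations`, `categories`); `summarize_coco` is a hypothetical helper, not part of labelme2coco.

```python
import json
from collections import Counter


def summarize_coco(coco: dict) -> dict:
    """Count images, annotations and per-category instances in a COCO dict."""
    images = coco.get("images", [])
    annotations = coco.get("annotations", [])
    # Map category ids to readable names for the per-category count.
    names = {c["id"]: c["name"] for c in coco.get("categories", [])}
    per_category = Counter(
        names.get(a["category_id"], "unknown") for a in annotations
    )
    annotated_ids = {a["image_id"] for a in annotations}
    return {
        "images": len(images),
        "annotations": len(annotations),
        "per_category": dict(per_category),
        # Images with no polygons are worth a second look in LabelMe.
        "images_without_annotations": sum(
            1 for im in images if im["id"] not in annotated_ids
        ),
    }


# Possible usage with the file written by the cell above:
# with open("test_whale.json") as f:
#     print(summarize_coco(json.load(f)))
```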

## Loading the dataset 🧐

In [None]:
from detectron2.data.datasets import register_coco_instances

# Both splits reuse the same COCO file because `train` and `val` hold identical images.
for d in ["train", "val"]:
    register_coco_instances(f"whale_{d}", {}, "test_whale.json", f"Final_Data/{d}")

## Visualising the annotations 😇

In [None]:
import random
import matplotlib.pyplot as plt
from detectron2.data import DatasetCatalog, MetadataCatalog

dataset_dicts = DatasetCatalog.get("whale_train")
whale_metadata = MetadataCatalog.get("whale_train")

for d in random.sample(dataset_dicts, 3):
    img = cv2.imread(d["file_name"])
    v = Visualizer(img[:, :, ::-1], metadata=whale_metadata, scale=0.5)
    v = v.draw_dataset_dict(d)
    plt.figure(figsize = (14, 10))
    # The Visualizer was given an RGB image, so get_image() is already RGB
    plt.imshow(v.get_image())
    plt.show()

## Training 😭
### The model used here is Mask R-CNN R50-FPN

P.S. Takes about 30 minutes to train. Hold your horses.

In [None]:
from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg
import os

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("whale_train",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 4
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.SOLVER.IMS_PER_BATCH = 8
cfg.SOLVER.BASE_LR = 0.00010
cfg.SOLVER.MAX_ITER = 500
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg) 
trainer.resume_or_load(resume=False)
trainer.train()
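
Detectron2 counts training in iterations rather than epochs, and each iteration consumes `IMS_PER_BATCH` images. A rough way to translate the settings above into epochs is sketched below. The dataset size is not stated in this notebook, so the 200 in the usage line is purely a placeholder, and `approx_epochs` is a hypothetical helper.

```python
def approx_epochs(max_iter: int, ims_per_batch: int, num_images: int) -> float:
    """Approximate passes over the dataset for an iteration-based schedule."""
    if num_images <= 0:
        raise ValueError("num_images must be positive")
    # Total images seen during training divided by the dataset size.
    return max_iter * ims_per_batch / num_images


# MAX_ITER=500 and IMS_PER_BATCH=8 come from the config above;
# 200 is an assumed image count, swap in len(dataset_dicts).
epochs = approx_epochs(max_iter=500, ims_per_batch=8, num_images=200)
print(f"approx. {epochs:.1f} epochs")
```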

## Inference 🚀
### Passing single image and visualising the output

In [None]:
from detectron2.utils.visualizer import ColorMode

cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.4
predictor = DefaultPredictor(cfg)

# Pass an arbitrary test image path as input :)
im = cv2.imread("../input/happy-whale-and-dolphin/test_images/000e246888710c.jpg")
outputs = predictor(im)
print(outputs)
v = Visualizer(im[:, :, ::-1],
               metadata=whale_metadata, 
               scale=0.8, 
               instance_mode=ColorMode.IMAGE_BW 
)


v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
plt.figure(figsize = (14, 10))
# The Visualizer was given an RGB image, so get_image() is already RGB
plt.imshow(v.get_image())
plt.show()
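
For Mask R-CNN, `outputs["instances"].pred_masks` holds one boolean mask per detected instance, with shape `(N, H, W)`. A minimal sketch of collapsing those into a single black-and-white image follows. `merge_instance_masks` is a hypothetical helper, and the demo array only stands in for real predictions.

```python
import numpy as np


def merge_instance_masks(pred_masks: np.ndarray) -> np.ndarray:
    """Collapse an (N, H, W) stack of instance masks into one uint8 image."""
    if pred_masks.ndim != 3:
        raise ValueError("expected masks with shape (N, H, W)")
    # Logical OR across instances; an empty stack gives an all-black image.
    merged = pred_masks.astype(bool).any(axis=0)
    return merged.astype(np.uint8) * 255


# Stand-in for: outputs["instances"].pred_masks.to("cpu").numpy()
demo_masks = np.zeros((2, 4, 4), dtype=bool)
demo_masks[0, 0, 0] = True
demo_masks[1, 3, 3] = True
binary = merge_instance_masks(demo_masks)
```

The result can be shown with `plt.imshow(binary, cmap="gray")` or used to blank out the background of `im`.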

# Under Construction 🥷
## Using segments to plot the mask on a binary image and EDA.