# 🐠 Reef - Pytorch Starter - FasterRCNN Train

## A self-contained, simple, pure pytorch 🔥 Faster R-CNN implementation with `LB=0.416`

![](https://storage.googleapis.com/kaggle-competitions/kaggle/31703/logos/header.png)

#### Faster R-CNN is one of the SOTA models for object detection.

### In this notebook we will cover a full pipeline from zero to a submission using a simple, pure pytorch implementation of a Faster R-CNN with pretrained weights, and fine-tuning it to this specific dataset. 
### The inference notebook is kept separate for organization purposes, as is common practice on Kaggle nowadays.

I hope it helps to get started in this amazing competition!

## You can find the [inference notebook here](https://www.kaggle.com/julian3833/coral-reef-pytorch-fasterrcnn-infer-0-xxx).

## Details: 

- It is an adapted version of [this notebook](https://www.kaggle.com/pestipeti/pytorch-starter-fasterrcnn-train) mentioned in [this comment](https://www.kaggle.com/c/tensorflow-great-barrier-reef/discussion/290016).
- FasterRCNN from torchvision
- Use Resnet50 backbone


# Please, _DO_ upvote if you find this useful!!


&nbsp;
&nbsp;
&nbsp;

#### Changelog

| Version | Description| Dataset| Best LB |
| --- | ----| --- | --- |
| [**V8**](https://www.kaggle.com/julian3833/reef-starter-torch-fasterrcnn-train-lb-0-293?scriptVersionId=80517118)  | 2 epochs - Save last epoch | [coral-reef-pytorch-starter-fasterrcnn-weights](https://www.kaggle.com/julian3833/coral-reef-pytorch-starter-fasterrcnn-weights)| `0.293`|
| [**V16**](https://www.kaggle.com/julian3833/reef-starter-torch-fasterrcnn-train-lb-0-293?scriptVersionId=80601095) | 4 epochs - Save all epochs | [reef-starter-torch-fasterrcnn-4e](https://www.kaggle.com/julian3833/reef-starter-torch-fasterrcnn-4e)| `0.361` |
| [**V17**](https://www.kaggle.com/julian3833/reef-starter-torch-fasterrcnn-train-lb-0-369?scriptVersionId=80604402) | **Add validation (on subsequences)**. 95-5 split. 8 epochs, keeping track of validation loss. | [reef-starter-torch-fasterrcnn-8e](https://www.kaggle.com/julian3833/reef-starter-torch-fasterrcnn-8e)| `0.369` |
| [**V19**](https://www.kaggle.com/julian3833/reef-starter-torch-fasterrcnn-train-lb-0-369?scriptVersionId=80610403) | 12 epochs, lower LR | [reef-starter-torch-fasterrcnn-12e](https://www.kaggle.com/julian3833/reef-starter-torch-fasterrcnn-12e)| `0.413` |
| [**V24**](https://www.kaggle.com/julian3833/reef-starter-torch-fasterrcnn-train-lb-0-416?scriptVersionId=806809369) | V19 with 90-10 train-validation split. Tidy up code. Add Flip. Correct problem with augmentations. | [reef-starter-torch-fasterrcnn-12e-v2](https://www.kaggle.com/julian3833/reef-starter-torch-fasterrcnn-12e-v2) | `0.416` |
| [**V30**](https://www.kaggle.com/julian3833/reef-starter-torch-fasterrcnn-train-lb-0-416) | Split in 90-10 subsequences. Train 16 epochs. Rollback validation to standard metrics. Log bbox validation score. Remove VerticalFlip (Only Horizontal). Add clipping of boxes as suggested [here](https://www.kaggle.com/julian3833/reef-starter-torch-fasterrcnn-train-lb-0-416/comments#1597371)|  | `??` |

---

### References:

* [TorchVision Object Detection Finetuning Tutorial - PyTorch](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html)
* [Finetuning Torchvision Models - PyTorch](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html)


Ok, enough chit chat, show me the code!!

# Imports

In [None]:
# Very few imports. This is a pure torch solution!
import os
import cv2
import time
import random

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN


# Fix randomness

def fix_all_seeds(seed):
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
fix_all_seeds(42)

# Configuration

In [None]:
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

BASE_DIR = "../input/tensorflow-great-barrier-reef/train_images/"


# Configuration for the Optimizer
LEARNING_RATE = 0.0025
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0005

# Number of epochs
NUM_EPOCHS = 12

BATCH_SIZE = 8

# Load `df`

In [None]:
df = pd.read_csv("../input/reef-cv-strategy-subsequences-dataframes/train-validation-split/train-0.1.csv")

# Turn annotations from strings into lists of dictionaries
df['annotations'] = df['annotations'].apply(eval)

# Create the image path for the row
df['image_path'] = "video_" + df['video_id'].astype(str) + "/" + df['video_frame'].astype(str) + ".jpg"

df.head()

There are a lot of images with no annotations, about `80%` or `18k`.

In [None]:
(df['annotations'].str.len() > 0).value_counts()

In [None]:
(df['annotations'].str.len() > 0).value_counts(normalize=True).round(2)

 We drop them to make it easier and faster to train.

In [None]:
# Drop images with no annotations. The background works as negative samples anyway
df = df[df['annotations'].str.len() > 0].reset_index(drop=True)

# Train-validation split

We split using subsequences. I have tried other strategies and this is the one that works best so far. The dataset has just 3 videos, each split into sequences, for a total of only 20 sequences. A **subsequence**, as we define it, is a part of a sequence in which objects are either continually present or continually absent.

&nbsp;

Let's see an **example**. Consider the sequence `A` with the following frames:
* `1-20` - No annotations present
* `21-30` - Annotations present
* `31-60` - No annotations
* `61-80` - Annotations present

In this case, we say that the sequence `A` has `4` subsequences (`1-20`, `21-30`, `31-60`, `61-80`).


See: [🐠 Reef - CV strategy: subsequences!](https://www.kaggle.com/julian3833/reef-cv-strategy-subsequences) for more details about this
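
The example above can be reproduced with a minimal sketch. It assumes a per-frame list of booleans ordered by frame within a single sequence (`has_annotations` and `split_into_subsequences` are hypothetical names) and groups consecutive frames that share the same value. It only illustrates the definition; it is not the code used to build the split file loaded in this notebook.

```python
from itertools import groupby


def split_into_subsequences(has_annotations):
    """Return (start_frame, end_frame, annotated) for each run of equal values.

    Frames are numbered from 1, matching the example above.
    """
    subsequences = []
    frame = 1
    for annotated, run in groupby(has_annotations):
        length = len(list(run))
        subsequences.append((frame, frame + length - 1, annotated))
        frame += length
    return subsequences


# Sequence A: 20 empty frames, 10 annotated, 30 empty, 20 annotated
sequence_a = [False] * 20 + [True] * 10 + [False] * 30 + [True] * 20
subsequences_a = split_into_subsequences(sequence_a)

for start, end, annotated in subsequences_a:
    print(f"{start}-{end}: {'annotations' if annotated else 'no annotations'}")
```

Splitting train and validation at the subsequence level keeps nearly identical neighbouring frames on the same side of the split.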

In [None]:
df_train = df[df['is_train']].reset_index(drop=True)
df_val = df[~df['is_train']].reset_index(drop=True)

df_train.shape[0], df_val.shape[0]

# Dataset class

In [None]:
class ReefDataset:

    def __init__(self, df, transforms=None):
        self.df = df
        self.transforms = transforms

    def get_boxes(self, row):
        """Returns the bboxes for a given row as a 2D array of shape (n_boxes, 4) with format [x_min, y_min, x_max, y_max]"""
        
        boxes = pd.DataFrame(row['annotations'], columns=['x', 'y', 'width', 'height']).astype(float).values
        
        # Change from [x_min, y_min, w, h] to [x_min, y_min, x_max, y_max]
        boxes[:, 2] = boxes[:, 0] + boxes[:, 2]
        boxes[:, 3] = boxes[:, 1] + boxes[:, 3]
        
        # Clip boxes as suggested by Lukazs in the comment below:
        # https://www.kaggle.com/julian3833/reef-starter-torch-fasterrcnn-train-lb-0-416/comments#1597371
        boxes[:, 2] = np.clip(boxes[:, 2], 0, 1280)
        boxes[:, 3] = np.clip(boxes[:, 3], 0, 720)
        
        return boxes
    
    def get_image(self, row):
        """Gets the image for a given row"""
        
        image = cv2.imread(f'{BASE_DIR}/{row["image_path"]}', cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.0
        return image
    
    def __getitem__(self, i):

        row = self.df.iloc[i]
        image = self.get_image(row)
        boxes = self.get_boxes(row)
        
        n_boxes = boxes.shape[0]
        
        # Calculate the area
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        
        
        target = {
            'boxes': torch.as_tensor(boxes, dtype=torch.float32),
            'area': torch.as_tensor(area, dtype=torch.float32),
            
            'image_id': torch.tensor([i]),
            
            # There is only one class
            'labels': torch.ones((n_boxes,), dtype=torch.int64),
            
            # Suppose all instances are not crowd
            'iscrowd': torch.zeros((n_boxes,), dtype=torch.int64)            
        }

        
        sample = {
            'image': image,
            'bboxes': target['boxes'],
            'labels': target['labels']
        }
        sample = self.transforms(**sample)
        image = sample['image']

        if n_boxes > 0:
            target['boxes'] = torch.stack(tuple(map(torch.tensor, zip(*sample['bboxes'])))).permute(1, 0)

        return image, target

    def __len__(self):
        return len(self.df)
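
To make the box conversion in `get_boxes` concrete, here is a dependency-free sketch for a single box. The function name `xywh_to_xyxy` and the sample values are made up for illustration; the 1280x720 frame size is the one used for clipping in the class above.

```python
def xywh_to_xyxy(box, img_w=1280, img_h=720):
    """Convert [x_min, y_min, width, height] to [x_min, y_min, x_max, y_max].

    Only x_max and y_max are clipped to the frame, as in get_boxes.
    """
    x, y, w, h = box
    x_max = min(max(x + w, 0), img_w)
    y_max = min(max(y + h, 0), img_h)
    return [x, y, x_max, y_max]


# A box that fits inside the frame is only converted
inside = xywh_to_xyxy([100, 200, 50, 40])

# A box that runs past the right and bottom edges is clipped to the frame
overflow = xywh_to_xyxy([1250, 700, 60, 50])

print(inside)
print(overflow)
```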

## Augmentations

A very simple set of augmentations. There should be a lot of low-hanging fruit here to explore!

In [None]:
def get_train_transform():
    return A.Compose([
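        # Horizontal flip only, as described in the changelog.
        # `p` is passed by keyword: in some albumentations versions the first
        # positional argument is `always_apply`, not the probability
        A.HorizontalFlip(p=0.5),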
        ToTensorV2(p=1.0)
    ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})


def get_valid_transform():
    return A.Compose([
        ToTensorV2(p=1.0)
    ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})
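
When an image is flipped, its boxes have to move with it, which is what the `bbox_params` argument asks albumentations to take care of. The sketch below shows the arithmetic for a horizontal flip of one `pascal_voc` box. `hflip_box` is a hypothetical helper, not a library function, and it assumes the simple continuous-coordinate convention `x -> width - x`.

```python
def hflip_box(box, img_w=1280):
    """Mirror a [x_min, y_min, x_max, y_max] box around the vertical axis."""
    x_min, y_min, x_max, y_max = box
    # The old right edge becomes the new left edge, and vice versa
    return [img_w - x_max, y_min, img_w - x_min, y_max]


box = [100, 200, 150, 240]
flipped = hflip_box(box)
print(flipped)

# Flipping twice gives back the original box
restored = hflip_box(flipped)
print(restored)
```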

In [None]:
# Define datasets
ds_train = ReefDataset(df_train, get_train_transform())
ds_val = ReefDataset(df_val, get_valid_transform())

## Check one sample

In [None]:
try:
    # Let's get an interesting one ;)
    idx = df_train[df_train.annotations.str.len() > 12].iloc[0].name
except IndexError:  # No image with more than 12 annotations
    idx = 0
    

image, targets = ds_train[idx]


boxes = targets['boxes'].cpu().numpy().astype(np.int32)
img = image.permute(1,2,0).cpu().numpy()
fig, ax = plt.subplots(1, 1, figsize=(16, 8))

for box in boxes:
    cv2.rectangle(img,
                  (box[0], box[1]),
                  (box[2], box[3]),
                  (220, 0, 0), 3)
    
ax.set_axis_off()
ax.imshow(img);

## DataLoaders

In [None]:
# Create dataloaders

def collate_fn(batch):
    return tuple(zip(*batch))

# Shuffle the training set so that batches are not made of consecutive video frames
dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, collate_fn=collate_fn)
dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, collate_fn=collate_fn)
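
The custom `collate_fn` is needed because each image has a different number of boxes, so the targets cannot be stacked into a single tensor the way the default collate function would try to. Here is a small sketch with plain Python objects standing in for images and targets (the strings and dictionaries are placeholders, not real samples):

```python
def collate_fn(batch):
    # Turn a list of (image, target) pairs into (images, targets) tuples
    return tuple(zip(*batch))


batch = [
    ("image_0", {"boxes": [[10, 10, 50, 50]]}),
    ("image_1", {"boxes": [[5, 5, 20, 20], [30, 30, 60, 60]]}),
]

images, targets = collate_fn(batch)
boxes_per_image = [len(t["boxes"]) for t in targets]

print(images)
print(boxes_per_image)
```

Each target keeps its own length, which is the list-of-dicts format that the torchvision detection models expect in the training loop.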

# Create the model

In [None]:
def get_model():
    # load a model; pre-trained on COCO
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

    num_classes = 2  # 1 class (starfish) + background

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features

    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    model.to(DEVICE)
    return model

model = get_model()

# Training loop!!

In [None]:
# Create the optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

n_batches, n_batches_val = len(dl_train), len(dl_val)
val_losses = []
val_box_losses = []

for epoch in range(NUM_EPOCHS):
    
    model.train()
    
    time_start = time.time()
    loss_accum = 0
    loss_box_accum = 0
    
    # Go over training batches
    for batch_idx, (images, targets) in enumerate(dl_train, 1):
        images = list(image.to(DEVICE) for image in images)
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        # Predict!
        # This dict has the following keys:
        #    loss_classifier, loss_box_reg, loss_objectness, loss_rpn_box_reg
        loss_dict = model(images, targets)
        
        # We optimize the full set of losses
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()

        loss_accum += loss_value
        loss_box_accum += loss_dict['loss_box_reg'].item()
        

        # Back-prop
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        
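    # Validation step!
    # The model is kept in train mode on purpose: torchvision detection models
    # only return the loss dict in train mode (in eval mode they return predictions)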
    val_loss_accum = 0
    val_loss_box_accum = 0
    
    with torch.no_grad():
        for batch_idx, (images, targets) in enumerate(dl_val, 1):
            images = list(image.to(DEVICE) for image in images)
            targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
            
            val_loss_dict = model(images, targets)
            val_batch_loss = sum(loss for loss in val_loss_dict.values())
            
            val_loss_accum += val_batch_loss.item()
            val_loss_box_accum += val_loss_dict['loss_box_reg'].item()

    
    # Calculate epoch losses
    val_loss = val_loss_accum / n_batches_val
    val_loss_box = val_loss_box_accum / n_batches_val
    
    train_loss = loss_accum / n_batches
    train_loss_box = loss_box_accum / n_batches
    
    val_losses.append(val_loss)
    val_box_losses.append(val_loss_box)
    
    # Save model
    chk_name = f'pytorch_model-e{epoch}.bin'
    torch.save(model.state_dict(), chk_name)
    
    
    # Logging
    elapsed = time.time() - time_start
    
    prefix = f"[Epoch {epoch+1:2d} / {NUM_EPOCHS:2d}]"
    print()
    print(f"{prefix} Train loss: {train_loss:.3f}.  Train loss (bbox only): {train_loss_box:.3f}.  Val loss (bbox only): {val_loss_box:.3f}")   
    print(prefix)
    print(f"{prefix} Saved to  : {chk_name}  [{elapsed:.0f} secs]")
    print(f"{prefix} Val loss  : {val_loss:.3f}")
    

In [None]:
# Best model based on lowest validation loss
# (zero-based epoch index, matching the checkpoint file names)
np.argmin(val_losses)

In [None]:
# Best model based on lowest bbox validation loss
# (zero-based epoch index, matching the checkpoint file names)
np.argmin(val_box_losses)
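
Since checkpoints are saved as `pytorch_model-e{epoch}.bin` with a zero-based `epoch`, the index of the lowest validation loss maps directly to a file name. A minimal sketch with made-up loss values (these are not results from this notebook):

```python
# Hypothetical per-epoch validation losses, for illustration only
example_val_losses = [0.52, 0.47, 0.45, 0.48]

# Index of the smallest loss, equivalent to np.argmin on a list
best_epoch = min(range(len(example_val_losses)), key=example_val_losses.__getitem__)

# Same naming scheme as the training loop above
best_checkpoint = f"pytorch_model-e{best_epoch}.bin"
print(best_checkpoint)
```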

# Check result

In [None]:
idx = 0

images, targets = next(iter(dl_val))
images = list(img.to(DEVICE) for img in images)
targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

boxes = targets[idx]['boxes'].cpu().numpy().astype(np.int32)
sample = images[idx].permute(1,2,0).cpu().numpy()

model.eval()

# No gradients are needed for inference
with torch.no_grad():
    outputs = model(images)
outputs = [{k: v.detach().cpu().numpy() for k, v in t.items()} for t in outputs]

fig, ax = plt.subplots(1, 1, figsize=(16, 8))

# Red for ground truth
for box in boxes:
    cv2.rectangle(sample,
                  (box[0], box[1]),
                  (box[2], box[3]),
                  (220, 0, 0), 3)

    
# Green for predictions
# Draw the first 5, cast to int because cv2.rectangle expects integer coordinates
for box in outputs[idx]['boxes'][:5].astype(np.int32):
    cv2.rectangle(sample,
                  (box[0], box[1]),
                  (box[2], box[3]),
                  (0, 220, 0), 3)

ax.set_axis_off()
ax.imshow(sample);
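
In eval mode, each entry of `outputs` holds `boxes`, `labels` and `scores`. Taking the first 5 boxes ignores how confident the model is about them, so an alternative is to keep only predictions above a score threshold. Below is a sketch with a made-up prediction and an arbitrary threshold of `0.5`; `filter_by_score` is a hypothetical helper and plain lists stand in for the numpy arrays.

```python
def filter_by_score(prediction, threshold=0.5):
    """Keep only the detections whose score is at least `threshold`."""
    keep = [i for i, score in enumerate(prediction["scores"]) if score >= threshold]
    return {key: [values[i] for i in keep] for key, values in prediction.items()}


# Made-up prediction for one image, for illustration only
example_prediction = {
    "boxes": [[10, 10, 50, 50], [100, 80, 160, 140], [300, 200, 340, 230]],
    "labels": [1, 1, 1],
    "scores": [0.91, 0.62, 0.18],
}

confident = filter_by_score(example_prediction, threshold=0.5)
print(confident["boxes"])
print(confident["scores"])
```

The right threshold depends on the metric being optimized, so it is worth tuning on the validation set rather than fixing it up front.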

# Please, _DO_ upvote if you found it useful!