# Faster R-CNN Inference

In [21]:
import os

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from PIL import Image

# Input locations default to the Kaggle layout. Set the DIR_INPUT and/or
# DIR_WEIGHTS environment variables to run this notebook somewhere else.
DIR_INPUT = os.environ.get('DIR_INPUT', '/kaggle/input/global-wheat-detection')
DIR_TRAIN = f'{DIR_INPUT}/train'
DIR_TEST = f'{DIR_INPUT}/test'

DIR_WEIGHTS = os.environ.get('DIR_WEIGHTS', '/kaggle/input/torchvision-faster-r-cnn-finetuning')
WEIGHTS_FILE = f'{DIR_WEIGHTS}/fasterrcnn_resnet50_fpn.pth'

### Submission File

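The submission file has one row per image: the `image_id`, a comma, and a `PredictionString` made of space-delimited `confidence x y width height` groups, one group per predicted bounding box. For example: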

`ce4833752,0.5 0 0 100 100`

indicates that image `ce4833752` has one bounding box with confidence `0.5`, whose top-left corner is at (`x`, `y`) = (0, 0) and whose `width` and `height` are both 100.

In [22]:
# The sample submission's image_id column is what defines the set of test
# images that need predictions.
test_df = pd.read_csv(filepath_or_buffer=f'{DIR_INPUT}/sample_submission.csv')
test_df.tail()

Unnamed: 0,image_id,PredictionString
5,348a992bb,1.0 0 0 50 50
6,cc3532ff6,1.0 0 0 50 50
7,2fd875eaa,1.0 0 0 50 50
8,cb8d261a3,1.0 0 0 50 50
9,53f253011,1.0 0 0 50 50


### Create Dataset

In [23]:
import cv2 as cv
import numpy as np

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import torch
from torch.utils.data import Dataset


class WheatTest(Dataset):
  """Inference-time dataset yielding ``(image, image_id)`` pairs.

  Each image is read from ``{image_dir}/{image_id}.jpg``, converted from
  OpenCV's BGR channel order to RGB, cast to ``float32`` and scaled to
  ``[0, 1]``; the optional ``transforms`` are applied afterwards.

  Args:
    image_ids: Indexable sequence of image ids (file stems without ``.jpg``).
    image_dir: Directory containing the ``.jpg`` files.
    transforms: Optional callable invoked as ``transforms(image=...)`` that
      returns a mapping with an ``'image'`` key (e.g. ``get_test_transform()``).
  """

  def __init__(self, image_ids, image_dir, transforms=None):
    super().__init__()
    self.image_ids = image_ids
    self.image_dir = image_dir
    self.transforms = transforms

  def __getitem__(self, idx: int):
    """Load, normalise and (optionally) transform the ``idx``-th image.

    Returns:
      ``(image, image_id)``. ``image`` is whatever ``transforms`` produces,
      or the RGB float32 numpy array when no transforms were given.

    Raises:
      FileNotFoundError: if the image file is missing or cannot be decoded.
    """
    image_id = self.image_ids[idx]
    path = f'{self.image_dir}/{image_id}.jpg'

    image = cv.imread(path, cv.IMREAD_COLOR)
    # cv.imread reports failure by returning None rather than raising; fail
    # here with a clear message instead of an opaque error in cv.cvtColor.
    if image is None:
      raise FileNotFoundError(f'Could not read image: {path}')
    image = cv.cvtColor(image, cv.COLOR_BGR2RGB).astype(np.float32)
    image /= 255.0

    # `is not None` instead of truthiness: container-like transform objects
    # can be falsy (e.g. when empty) and would otherwise be skipped silently.
    if self.transforms is not None:
      sample = self.transforms(image=image)
      image = sample['image']

    return image, image_id

  def __len__(self) -> int:
    """Number of image ids in the dataset."""
    return len(self.image_ids)

  @staticmethod
  def get_test_transform():
    """Test-time transform: only convert the image to a tensor."""
    return A.Compose([
      ToTensorV2(p=1.0)
    ])

### Create DataLoader

In [24]:
def get_image_ids(p):
  """Return the ids (file stems) of all ``.jpg`` images directly inside ``p``.

  Args:
    p: Directory to scan (non-recursive).

  Returns:
    Sorted list of file names without directory or ``.jpg`` extension.
    Sorting makes the order deterministic, which ``glob`` alone does not.
  """
  import glob
  # Escape the directory so characters such as '[' in its name are matched
  # literally instead of being interpreted as glob syntax.
  pattern = os.path.join(glob.escape(p), '*.jpg')
  return sorted(
    os.path.splitext(os.path.basename(path))[0] for path in glob.glob(pattern)
  )

# Set to True to run inference over the (larger) training image set instead of
# the test images listed in the sample submission, i.e. to try more images.
USE_TRAIN_IMAGES = False

if USE_TRAIN_IMAGES:
  test_dataset = WheatTest(get_image_ids(DIR_TRAIN), DIR_TRAIN, WheatTest.get_test_transform())
else:
  test_dataset = WheatTest(test_df["image_id"].unique(), DIR_TEST, WheatTest.get_test_transform())

In [25]:
from torch.utils.data import DataLoader

def collate_fn(batch):
  """Transpose a batch of ``(image, image_id)`` samples into
  ``(images, image_ids)``, keeping the images as a tuple instead of stacking
  them into a single tensor."""
  columns = zip(*batch)
  return tuple(columns)

# A batch size of 16 ran out of GPU memory during inference, so 8 is used.
test_data_loader = DataLoader(
  dataset=test_dataset,
  batch_size=8,
  shuffle=False,    # keep the dataset's image order
  drop_last=False,  # the final partial batch must still be predicted
  num_workers=4,
  collate_fn=collate_fn,
)

## Load Model

In [26]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the Faster R-CNN architecture without downloading any pre-trained
# weights; the fine-tuned weights are loaded from WEIGHTS_FILE below.
# NOTE(review): newer torchvision deprecates pretrained/pretrained_backbone in
# favour of weights=None / weights_backbone=None.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)

# Two classes: wheat head and background.
num_classes = 2

# Replace the box predictor with one sized for num_classes so its parameter
# shapes line up with the fine-tuned checkpoint.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code; only load weight files from a trusted source.
model.load_state_dict(torch.load(WEIGHTS_FILE, map_location=device))

# Move to the target device and switch to evaluation mode (torchvision
# detection models return predictions, not losses, in eval mode).
_ = model.to(device).eval()

## Produce Test Image Outputs

In [27]:
score_threshold = 0.7  # detections scoring below this are discarded
image_outputs = []

# Nothing here is back-propagated: torch.no_grad() stops autograd from
# recording a graph for every forward pass, which otherwise wastes GPU memory.
with torch.no_grad():
  for images, image_ids in test_data_loader:
    images = [image.to(device) for image in images]
    outputs = model(images)

    for image_id, output in zip(image_ids, outputs):
      # Boxes are [x_min, y_min, x_max, y_max] in pixels.
      boxes = output['boxes'].detach().cpu().numpy()
      scores = output['scores'].detach().cpu().numpy()

      keep = scores >= score_threshold
      # astype(np.int32) truncates the pixel coordinates toward zero.
      boxes = boxes[keep].astype(np.int32)
      scores = scores[keep]

      image_outputs.append((image_id, boxes, scores))

### Save Results

In [28]:
def format_prediction_string(boxes, scores):
  """Serialise detections into a submission ``PredictionString``.

  Each (box, score) pair contributes ``score`` (rounded to 4 decimals)
  followed by the four box values, all joined by single spaces. Empty
  inputs yield an empty string.
  """
  tokens = []
  for box, score in zip(boxes, scores):
    tokens += [round(score, 4), *box]
  return ' '.join(str(token) for token in tokens)

results = []

for image_id, boxes, scores in image_outputs:
  # Convert [x_min, y_min, x_max, y_max] to the submission's
  # [x, y, width, height]. Work on a copy: editing `boxes` in place would
  # modify the arrays stored in image_outputs, so re-running this cell would
  # subtract the offsets twice and emit wrong widths/heights.
  boxes_xywh = boxes.copy()
  boxes_xywh[:, 2] = boxes[:, 2] - boxes[:, 0]
  boxes_xywh[:, 3] = boxes[:, 3] - boxes[:, 1]

  result = {
    'image_id': image_id,
    'PredictionString': format_prediction_string(boxes_xywh, scores)
  }
  results.append(result)

# Preview a couple of rows rather than dumping every prediction string.
results[:2]

[{'image_id': 'aac893a91',
  'PredictionString': '0.9938 557 531 124 193 0.9923 70 2 101 159 0.9863 613 923 83 99 0.9843 740 772 81 115 0.9834 330 666 115 150 0.9828 696 394 117 173 0.982 820 708 105 198 0.9815 28 454 102 154 0.977 175 566 112 186 0.9766 590 779 94 119 0.9732 357 534 95 81 0.9713 458 863 81 92 0.9599 306 0 73 67 0.9402 236 843 155 106 0.927 550 56 142 196 0.9057 93 620 115 90 0.9043 249 97 126 134 0.8824 67 856 112 70 0.8533 828 629 81 120 0.8187 165 6 80 143 0.7453 825 910 123 112 0.7126 359 263 99 140'},
 {'image_id': '51f1be19e',
  'PredictionString': '0.9768 607 85 162 176 0.9728 842 260 129 207 0.9668 277 479 134 114 0.9656 65 692 127 213 0.9645 815 95 115 76 0.9641 803 761 113 100 0.9614 768 882 148 105 0.9608 186 922 117 101 0.9606 508 473 192 108 0.9565 697 920 85 88 0.9509 24 0 88 72 0.9424 0 381 57 100 0.9281 653 795 104 79 0.9169 353 143 96 175 0.8925 774 26 116 68 0.8735 245 120 113 123 0.8703 565 597 116 122 0.8592 654 583 115 90 0.8252 913 561 97 111 0.80

In [29]:
# One row per predicted image, with the same columns as the sample submission.
test_df = pd.DataFrame(columns=['image_id', 'PredictionString'], data=results)
test_df

Unnamed: 0,image_id,PredictionString
0,aac893a91,0.9938 557 531 124 193 0.9923 70 2 101 159 0.9...
1,51f1be19e,0.9768 607 85 162 176 0.9728 842 260 129 207 0...
2,f5a1f0358,0.9883 539 271 115 116 0.9868 139 747 162 125 ...
3,796707dd7,0.9821 894 332 114 96 0.9737 673 459 106 119 0...
4,51b3e36ab,0.9965 1 438 98 321 0.994 233 649 96 154 0.992...
5,348a992bb,0.9902 3 316 118 99 0.9891 734 223 138 89 0.98...
6,cc3532ff6,0.9932 768 835 171 162 0.9925 488 583 101 130 ...
7,2fd875eaa,0.993 458 505 82 128 0.9925 940 652 82 95 0.99...
8,cb8d261a3,0.9848 307 162 115 204 0.9826 21 866 82 142 0....
9,53f253011,0.9959 468 470 158 195 0.9956 931 205 93 131 0...


In [30]:
# index=False keeps the pandas row index out of the submission file.
test_df.to_csv(path_or_buf='submission.csv', index=False)