# 1. MMDetection 설치

In [1]:
import torch

# The mmcv-full wheel index used for installation depends on both the PyTorch and the CUDA
# version, and the detector below is created on 'cuda:0', so report all of them up front.
print(torch.__version__)
print('CUDA version:', torch.version.cuda, '| GPU available:', torch.cuda.is_available())

1.10.0+cu111


In [2]:
# MMCV 설치
# torch 버전이 달라지면 달라진 버전을 입력해주어야 함
!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.10.0/index.html

Looking in links: https://download.openmmlab.com/mmcv/dist/cu111/torch1.10.0/index.html
Collecting mmcv-full
  Downloading https://download.openmmlab.com/mmcv/dist/cu111/torch1.10.0/mmcv_full-1.4.6-cp37-cp37m-manylinux1_x86_64.whl (46.0 MB)
[K     |████████████████████████████████| 46.0 MB 244 kB/s 
Collecting addict
  Downloading addict-2.4.0-py3-none-any.whl (3.8 kB)
Collecting yapf
  Downloading yapf-0.32.0-py2.py3-none-any.whl (190 kB)
[K     |████████████████████████████████| 190 kB 5.2 MB/s 
Installing collected packages: yapf, addict, mmcv-full
Successfully installed addict-2.4.0 mmcv-full-1.4.6 yapf-0.32.0


In [3]:
# MMDetection git clolne
!git clone https://github.com/open-mmlab/mmdetection.git

Cloning into 'mmdetection'...
remote: Enumerating objects: 23588, done.[K
remote: Total 23588 (delta 0), reused 0 (delta 0), pack-reused 23588[K
Receiving objects: 100% (23588/23588), 35.32 MiB | 24.92 MiB/s, done.
Resolving deltas: 100% (16486/16486), done.


In [4]:
# MMDetection 설치
!cd mmdetection; python setup.py install

running install
running bdist_egg
running egg_info
creating mmdet.egg-info
writing mmdet.egg-info/PKG-INFO
writing dependency_links to mmdet.egg-info/dependency_links.txt
writing requirements to mmdet.egg-info/requires.txt
writing top-level names to mmdet.egg-info/top_level.txt
writing manifest file 'mmdet.egg-info/SOURCES.txt'
reading manifest template 'MANIFEST.in'
adding license file 'LICENSE'
writing manifest file 'mmdet.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py
creating build
creating build/lib
creating build/lib/mmdet
copying mmdet/__init__.py -> build/lib/mmdet
copying mmdet/version.py -> build/lib/mmdet
creating build/lib/mmdet/utils
copying mmdet/utils/misc.py -> build/lib/mmdet/utils
copying mmdet/utils/logger.py -> build/lib/mmdet/utils
copying mmdet/utils/contextmanagers.py -> build/lib/mmdet/utils
copying mmdet/utils/profiling.py -> build/lib/mmdet/utils
copying mmdet/utils/util_random.py -> build/lib

In [1]:
# 임포트하기 전에 '런타임 다시 시작'을 해야 함
from mmdet.apis import init_detector, inference_detector
import mmcv

# 2. 사전 훈련된 Faster R-CNN으로 Inference 수행 (동영상 파일)

## 사전 훈련된 Faster R-CNN 다운로드

In [2]:
# 사전 훈련 모델을 다운로드 받기 위해 mmdetection/checkpoints 디렉터리 생성 
!cd mmdetection; mkdir checkpoints

In [3]:
# 사전 훈련된 Faster R-CNN 다운로드
!wget -O /content/mmdetection/checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth

--2022-03-17 12:56:10--  https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth
Resolving download.openmmlab.com (download.openmmlab.com)... 47.252.96.28
Connecting to download.openmmlab.com (download.openmmlab.com)|47.252.96.28|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 167287506 (160M) [application/octet-stream]
Saving to: ‘/content/mmdetection/checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth’


2022-03-17 12:56:32 (7.91 MB/s) - ‘/content/mmdetection/checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth’ saved [167287506/167287506]



## 사전 훈련 Faster R-CNN 모델 생성

In [4]:
# Model config from the cloned repository, and the downloaded pre-trained weights as its checkpoint.
config_file, checkpoint_file = (
    '/content/mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py',
    '/content/mmdetection/checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth',
)

In [5]:
# Build the object-detection model from the config file, load the pre-trained weights, and place it on cuda:0.
model = init_detector(config_file, checkpoint=checkpoint_file, device='cuda:0')

load checkpoint from local path: /content/mmdetection/checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth


## 객체 탐지 함수 정의

In [6]:
import numpy as np
import cv2

# Class-index -> display-name map: the i-th name below is class id i (0..79).
# NOTE(review): several names are VOC-style spellings ('motorbike', 'aeroplane', 'sofa',
# 'pottedplant', 'diningtable', 'tvmonitor'); mmdet's own COCO class list may spell them differently.
labels_to_names = dict(enumerate([
    'person', 'bicycle', 'car', 'motorbike', 'aeroplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
    'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'sofa', 'pottedplant', 'bed',
    'diningtable', 'toilet', 'tvmonitor', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush',
]))

def get_detected_img(model, img_arr, score_threshold=0.3,
                     bbox_color=(0, 255, 0), text_color=(0, 0, 255)):
  '''Run object detection on one image and return an annotated copy of it.

  Args:
    model: detector built with ``init_detector``; passed straight to ``inference_detector``.
    img_arr: image array (an OpenCV BGR frame in this notebook); it is not modified.
    score_threshold: a detection is drawn only if its confidence score is strictly greater
      than this value.
    bbox_color: BGR color of the bounding boxes (default green).
    text_color: BGR color of the caption text (default red).

  Returns:
    A copy of ``img_arr`` with a rectangle and a "class_name: score" caption for every
    detection above ``score_threshold``.
  '''
  annotated = img_arr.copy()  # draw on a copy so the caller's frame stays untouched

  results = inference_detector(model, img_arr)
  # Models with a mask head return (bbox_results, segm_results); only the boxes are drawn here.
  if isinstance(results, tuple):
    results = results[0]

  # One array per class id; each row is [x_min, y_min, x_max, y_max, score].
  for class_id, class_dets in enumerate(results):
    if len(class_dets) == 0:  # nothing detected for this class
      continue

    kept = class_dets[class_dets[:, 4] > score_threshold]
    # Fall back to the numeric id for models whose classes are not in labels_to_names.
    class_name = labels_to_names.get(class_id, str(class_id))

    for det in kept:
      # (x_min, y_min) is the top-left corner, (x_max, y_max) the bottom-right corner.
      x_min, y_min, x_max, y_max = (int(v) for v in det[:4])
      cv2.rectangle(annotated, (x_min, y_min), (x_max, y_max), color=bbox_color, thickness=2)

      caption = f'{class_name}: {det[4]:.4f}'
      # Caption sits just above the box; clamp it so boxes touching the top edge
      # do not get their caption drawn outside the image.
      text_y = max(y_min - 7, 10)
      cv2.putText(annotated, caption, (x_min, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.37, text_color, 1)

  return annotated

## 동영상 파일을 활용한 Inference 수행

#### 동영상 파일 다운로드

In [9]:
# data 디렉터리 생성
!mkdir data 
# 동영상 파일 다운로드
!wget -O /content/data/John_Wick_small.mp4 https://github.com/chulminkw/DLCV/blob/master/data/video/John_Wick_small.mp4?raw=true

--2022-03-17 12:57:54--  https://github.com/chulminkw/DLCV/blob/master/data/video/John_Wick_small.mp4?raw=true
Resolving github.com (github.com)... 140.82.112.4
Connecting to github.com (github.com)|140.82.112.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github.com/chulminkw/DLCV/raw/master/data/video/John_Wick_small.mp4 [following]
--2022-03-17 12:57:54--  https://github.com/chulminkw/DLCV/raw/master/data/video/John_Wick_small.mp4
Reusing existing connection to github.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/chulminkw/DLCV/master/data/video/John_Wick_small.mp4 [following]
--2022-03-17 12:57:54--  https://raw.githubusercontent.com/chulminkw/DLCV/master/data/video/John_Wick_small.mp4
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.

In [10]:
def write_detected_video(model, input_path, output_path, score_threshold=0.3, codec='XVID'):
    '''Run object detection on every frame of a video and save the annotated frames as a new video.

    Args:
        model: detector passed through to ``get_detected_img``.
        input_path: path of the source video.
        output_path: path of the video to write; it uses the source FPS and frame size.
        score_threshold: detections with a confidence score not strictly greater than this are not drawn.
        codec: FourCC code for the output writer (default 'XVID', as before).

    Raises:
        IOError: if ``input_path`` cannot be opened as a video.
    '''
    cap = cv2.VideoCapture(input_path)
    video_writer = None
    try:
        if not cap.isOpened():
            # Without this check a bad path silently produced an empty output file.
            raise IOError(f'Cannot open video file: {input_path}')

        video_fps = cap.get(cv2.CAP_PROP_FPS)
        video_size = (round(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                      round(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*codec),
                                       video_fps, video_size)

        frame_cnt = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print('총 프레임 개수:', frame_cnt)

        while True:
            has_frame, img_frame = cap.read()
            if not has_frame:  # no more frames to process
                break
            detected_frame = get_detected_img(model, img_frame, score_threshold=score_threshold)
            video_writer.write(detected_frame)
    finally:
        # Release both handles even if inference fails mid-video, so the output file is finalized.
        if video_writer is not None:
            video_writer.release()
        cap.release()

In [13]:
# Run object detection on every frame of the sample video and save the annotated video.
write_detected_video(
    model=model,
    input_path='/content/data/John_Wick_small.mp4',
    output_path='/content/data/John_Wick_small_out2.mp4',
    score_threshold=0.4,
)

총 프레임 개수: 58
