In [1]:
import os, sys, json
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
import warnings
import glob

import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import models
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision import transforms

In [2]:
# Run configuration: the GPU to use and the checkpoint to load
# (read from models/<MODEL_NAME>.pth further below).
DEVICE_NAME = 'cuda:0'

MODEL_NAME = 'detector_2020-06-08--23-17-22'
MODEL_NAME

'detector_2020-06-08--23-17-22'

In [3]:
SEED = 1234

# Seed every RNG source used by this notebook, then pin cuDNN to
# deterministic algorithms so reruns are as reproducible as possible.
for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(SEED)
del seed_fn
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [4]:
# Seaborn "pastel" colour palette; not used by the other cells visible here.
PALETTE = sns.color_palette(palette='pastel')

In [21]:
data_path = '../data/'

# Preprocessed annotation splits (JSON), loaded in the cells below.
preproc_val_file, preproc_train_file, preproc_deduplicated_train_file = (
    os.path.join(data_path, split_fname)
    for split_fname in (
        'preproc_val.json',
        'preproc_train.json',
        'preproc_deduplicated_train.json',
    )
)

# Destination for the detector's test-set box predictions.
test_pred_boxes_file = os.path.join(data_path, 'test_pred_boxes.json')

In [6]:
# Validation annotations. JSON is UTF-8 by specification, so pin the encoding
# instead of relying on the OS locale default (e.g. cp1251 on Windows).
with open(preproc_val_file, encoding='utf-8') as rf:
    preproc_val = json.load(rf)
len(preproc_val)

1000

In [7]:
# Training annotations. JSON is UTF-8 by specification, so pin the encoding
# instead of relying on the OS locale default.
with open(preproc_train_file, encoding='utf-8') as rf:
    preproc_train = json.load(rf)
len(preproc_train)

24632

In [8]:
# Deduplicated training annotations. JSON is UTF-8 by specification, so pin
# the encoding instead of relying on the OS locale default.
with open(preproc_deduplicated_train_file, encoding='utf-8') as rf:
    preproc_deduplicated_train = json.load(rf)
len(preproc_deduplicated_train)

21800

# Загружаем модель детекции

In [9]:
def get_detector_model(mask_hid_size, num_classes=2, pretrained=True, trainable_backbone_layers=4):
    """Build a torchvision Mask R-CNN (ResNet-50 FPN) with heads sized for ``num_classes``.

    The stock box and mask predictors are replaced with freshly initialised
    ones, so only the remaining layers keep the pretrained weights.

    Args:
        mask_hid_size: hidden channel count of the new mask predictor
            (``dim_reduced`` in ``MaskRCNNPredictor``).
        num_classes: number of classes including background (torchvision's
            convention); the default 2 means background + one object class.
        pretrained: whether to start from torchvision's pretrained detection
            weights. Setting it to False skips them (useful when a full state
            dict is loaded right afterwards).
        trainable_backbone_layers: number of backbone stages left trainable.

    Returns:
        The configured Mask R-CNN model (on CPU; the caller moves it).
    """
    model = models.detection.maskrcnn_resnet50_fpn(
        pretrained=pretrained,
        progress=True,
        trainable_backbone_layers=trainable_backbone_layers,
    )

    # Swap the box classifier/regressor for one that outputs `num_classes`.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Swap the mask head likewise, keeping the incoming feature width.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, mask_hid_size, num_classes)

    return model

In [10]:
# Inference settings:
#   BATCH_SIZE, NUM_WORKERS - loader settings (not used by the cells shown here)
#   HID_DIM                 - mask-head hidden width; must match the saved checkpoint
#   THRESHOLD               - minimum detection score kept by predict_boxes
BATCH_SIZE, NUM_WORKERS, HID_DIM, THRESHOLD = 2, 4, 256, 0.85

In [11]:
# Use the configured GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device(DEVICE_NAME if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [12]:
# Build the detector and place it on the chosen device (Module.to returns the module itself).
model = get_detector_model(HID_DIM).to(device)

In [13]:
# Deserialize onto CPU first so the checkpoint loads regardless of the device
# it was saved from; load_state_dict then copies the tensors into the model.
# NOTE: torch.load unpickles the file - only load checkpoints from trusted sources.
with open(f'models/{MODEL_NAME}.pth', "rb") as checkpoint_fp:
    best_state_dict = torch.load(checkpoint_fp, map_location="cpu")
model.load_state_dict(best_state_dict);

In [14]:
# Inference mode: equivalent to model.eval(); torchvision detectors return predictions (not losses) in this mode.
model.train(False);

In [15]:
warnings.filterwarnings("ignore", message="The default behavior for")
warnings.filterwarnings("ignore", message="This overload of nonzero")

# Получаем предсказания детектора

In [16]:
# glob returns files in filesystem order; sort so the image order (and hence
# the order of the saved predictions) is reproducible across machines.
test_img_paths = sorted(glob.glob(os.path.join(data_path, 'test/*')))
assert test_img_paths, f"no test images found under {os.path.join(data_path, 'test')}"
# Names relative to data_path (e.g. 'test/0.jpg'); predict_boxes joins them back onto data_path.
test_img_names = [os.path.relpath(path, data_path) for path in test_img_paths]
test_img_names[0]

'test/0.jpg'

In [20]:
to_tensor = transforms.ToTensor()

@torch.no_grad()
def predict_boxes(model, img_names, threshold=THRESHOLD, data_path=data_path):
    """Run the detector on each image and keep the boxes scoring at least ``threshold``.

    Args:
        model: torchvision detection model. It is switched to eval mode, and
            inputs are sent to whatever device its parameters live on.
        img_names: image paths relative to ``data_path``.
        threshold: minimum score for a box to be kept. NOTE: the default is
            bound when this cell runs, so re-running the constants cell alone
            does not change it - pass it explicitly to be safe.
        data_path: directory that ``img_names`` are relative to (default is
            also bound at definition time).

    Returns:
        One dict per input image, in input order:
        ``{'file': <name as given>, 'boxes': [[x1, y1, x2, y2], ...]}``
        (torchvision's box format), with boxes sorted by their first coordinate.

    Raises:
        OSError: if OpenCV cannot read an image file.
    """
    # Follow the model wherever it was placed instead of relying on a global `device`.
    model_device = next(model.parameters()).device
    results = []
    model.eval()
    for fname in tqdm(img_names, position=0, leave=True):
        fpath = os.path.join(data_path, fname)
        img = cv2.imread(fpath)
        if img is None:
            # cv2.imread reports missing/undecodable files by returning None;
            # fail with the offending path instead of an opaque cvtColor error.
            raise OSError(f"could not read image: {fpath}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes to BGR; convert to RGB
        img_tensor = to_tensor(img).to(model_device)  # HWC uint8 -> CHW float in [0, 1]

        prediction = model([img_tensor])[0]
        pred_boxes = prediction['boxes'].cpu().numpy().tolist()
        scores = prediction['scores'].cpu().numpy()
        pred_boxes = [box for (box, score) in zip(pred_boxes, scores) if score >= threshold]
        pred_boxes = sorted(pred_boxes, key=lambda box: box[0])
        results.append({'file': fname, 'boxes': pred_boxes})
    return results

In [18]:
# Pass the threshold explicitly: predict_boxes' default was frozen when its cell ran.
test_box_predictions = predict_boxes(model, test_img_names, threshold=THRESHOLD)

                                                   

In [19]:
# Sanity check: predict_boxes emits exactly one record per test image.
assert len(test_box_predictions) == len(test_img_names), "expected one prediction record per test image"
test_box_predictions[0]

{'file': 'test/0.jpg',
 'boxes': [[486.7810363769531,
   564.0574951171875,
   772.1798706054688,
   625.64111328125]]}

In [22]:
# Persist the test-set predictions (a list of {'file', 'boxes'} records) as JSON.
with open(test_pred_boxes_file, 'w', encoding='utf-8') as f:
    json.dump(test_box_predictions, f)