In [None]:
# Enable IPython's autoreload extension so that edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import numpy as np
import copy
import cv2
import time
from tqdm import tqdm

import matplotlib.pyplot as plt

import IPython
from IPython.display import Image as img
from IPython.display import display

from PIL import Image

import torch
import torch.nn as nn
import torchvision

In [None]:
from dass_det.models.yolox import YOLOX
from dass_det.models.yolo_head import YOLOXHead
from dass_det.models.yolo_head_stem import YOLOXHeadStem
from dass_det.models.yolo_pafpn import YOLOPAFPN

from dass_det.data.data_augment import ValTransform

from dass_det.utils import (
    postprocess,
    vis
)

## Set the Parameters Below

In [None]:
# ---- Model ----------------------------------------------------------------
# Checkpoint to load, e.g. "weights/...". None is rejected by the
# "Load Model" cell.
model_path = "weights/sample_model.pth"

# Model variant; the "Load Model" cell accepts "xs" or "xl".
model_size = "xs"

# 1 for only face, 2 for only body.
# NOTE(review): 0 presumably means "faces and bodies". The value is validated
# in the "Load Model" cell, but predict_and_draw always calls the model with
# mode=0 -- confirm whether this setting is meant to be forwarded.
model_mode = 0

# ---- Post-processing / preprocessing ----------------------------------------
# Both thresholds are forwarded to postprocess() by predict_and_draw.
nms_thold = 0.4
conf_thold = 0.65

# Target size handed to the transform and used to compute the resize scale.
resize_size = (1024, 1024)

# ---- Input ------------------------------------------------------------------
# Image file read with cv2.imread in the "Predict Image" cell.
image_path = "/datasets/COMICS/raw_page_images/3665/12.jpg"

In [None]:
# Preprocessing callable from dass_det.data.data_augment. It is applied to the
# raw image (together with the target size) in the "Predict Image" cell before
# the result is handed to predict_and_draw.
transform = ValTransform()

def _read_rgb_image(path):
    """Read the image referenced by ``path`` and return it with reversed channels.

    Args:
        path: Either a file path string, or an indexable record whose elements
            1 and 2 are combined into ``os.path.join(path[1], path[2] + ".jpg")``.

    Returns:
        The image array loaded by ``cv2.imread`` with its last axis reversed
        (OpenCV loads BGR, so the result is RGB).

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
    """
    if isinstance(path, str):
        img_file = path
    else:
        img_file = os.path.join(path[1], path[2] + ".jpg")

    bgr = cv2.imread(img_file)
    if bgr is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError("Could not read image: {}".format(img_file))
    return bgr[:, :, ::-1]


def predict_and_draw(model, imgs, path, scale, sizes, conf_thold, nms_thold):
    """Run the detector on one preprocessed image and display the detections.

    The face and body outputs of ``model`` are post-processed separately,
    merged, mapped back to the original image resolution and drawn on the
    image loaded from ``path``. If nothing is detected, the plain image is
    shown with matplotlib instead.

    Args:
        model: Detector called as ``model(batch, mode=0)``; it must return a
            ``(face_preds, body_preds)`` pair.
        imgs: Preprocessed image for a single sample (array-like); a batch
            dimension is added here.
        path: Image location; see ``_read_rgb_image`` for accepted forms.
        scale: Resize factor applied during preprocessing; box coordinates are
            divided by it to return to original-image coordinates.
        sizes: Indexable ``(height, width)`` of the original image;
            ``sizes[1]`` caps column 2 and ``sizes[0]`` caps column 3 of the
            predictions.
        conf_thold: Confidence threshold forwarded to ``postprocess``.
        nms_thold: NMS threshold forwarded to ``postprocess``.

    Returns:
        None. The result is rendered as notebook output.

    Raises:
        FileNotFoundError: If the image at ``path`` cannot be read.
    """
    # Follow the model's device instead of assuming CUDA; fall back to the
    # original behaviour (CUDA) if the device cannot be determined.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda")

    batch = torch.as_tensor(imgs, dtype=torch.float32).unsqueeze(0).to(device)
    # print("Predicting:", path)

    with torch.no_grad():
        # NOTE(review): mode is fixed to 0 here; the notebook-level
        # ``model_mode`` setting is not consulted.
        face_preds, body_preds = model(batch, mode=0)
        # postprocess returns one entry per image; that entry is None when no
        # detection survives the thresholds.
        face_preds = postprocess(face_preds, 1, conf_thold, nms_thold)[0]
        body_preds = postprocess(body_preds, 1, conf_thold, nms_thold)[0]

    del batch  # release the input tensor as early as possible

    detections = [p for p in (face_preds, body_preds) if p is not None]
    if not detections:
        print("No faces or bodies are found!")
        plt.imshow(_read_rgb_image(path))
        return

    len_faces = 0 if face_preds is None else face_preds.shape[0]
    len_bodies = 0 if body_preds is None else body_preds.shape[0]
    preds = torch.cat(detections, dim=0)

    # Class ids in the same order as the concatenation: 0 = face, 1 = body.
    classes = torch.cat([torch.zeros(len_faces), torch.ones(len_bodies)])

    # Undo the preprocessing resize, then clip the boxes to the image borders.
    # The column slices are views, so the in-place clamps update ``preds``.
    preds[:, :4] /= scale
    preds[:, 0].clamp_(min=0)
    preds[:, 1].clamp_(min=0)
    preds[:, 2].clamp_(max=sizes[1])
    preds[:, 3].clamp_(max=sizes[0])
    scores = preds[:, 4]

    p_img = _read_rgb_image(path)

    # conf=0.0: the detections were already thresholded by postprocess.
    # The copy keeps vis() from working on the reversed view directly.
    drawn = vis(copy.deepcopy(p_img), preds[:, :4], scores, classes,
                conf=0.0, class_names=["Face", "Body"])
    display(Image.fromarray(drawn))

## Load Model

In [None]:
# Validate the settings from the parameter cell before building anything.
assert model_path is not None, "model_path must point to a checkpoint file"
assert model_size in ["xs", "xl"], "model_size must be 'xs' or 'xl', got {!r}".format(model_size)
assert model_mode in [0, 1, 2], "model_mode must be 0, 1 or 2, got {!r}".format(model_mode)

# (depth, width) multipliers for each supported model size.
depth, width = {
    "xs": (0.33, 0.375),
    "xl": (1.33, 1.25),
}[model_size]

model = YOLOX(backbone=YOLOPAFPN(depth=depth, width=width),
              head_stem=YOLOXHeadStem(width=width),
              face_head=YOLOXHead(1, width=width),
              body_head=YOLOXHead(1, width=width))

# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source.
d = torch.load(model_path, map_location=torch.device('cpu'))

# The weights are stored under "teacher_model" when that key is present,
# otherwise under "model".
ckpt_key = next((k for k in ("teacher_model", "model") if k in d), None)
if ckpt_key is None:
    raise KeyError(
        "Checkpoint {!r} has neither a 'teacher_model' nor a 'model' entry; "
        "available keys: {}".format(model_path, list(d.keys()))
    )
model.load_state_dict(d[ckpt_key])
model = model.eval().cuda()

# Drop the CPU copy of the checkpoint now that the weights are in the model.
del d

## Predict Image

In [None]:
# Load the raw image. cv2.imread returns None (it does not raise) when the
# file is missing or unreadable, so fail early with a clear message.
imgs = cv2.imread(image_path)
if imgs is None:
    raise FileNotFoundError("Could not read image: {}".format(image_path))
h, w, c = imgs.shape

# Preprocess for the network; no labels are supplied at inference time.
imgs, labels = transform(imgs, None, resize_size)

# Factor used by predict_and_draw to map boxes back to the original image.
# NOTE(review): assumes the transform resizes with preserved aspect ratio, so
# the smaller of the two ratios is the one actually applied -- confirm.
scale = min(resize_size[0] / h, resize_size[1] / w)

predict_and_draw(model, copy.deepcopy(imgs), image_path, scale, [h, w], conf_thold, nms_thold)