In [1]:
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torchvision as tv
from torch import nn

# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Функция для загрузки retinaNet
def get_retina_model(num_classes):
    
    model = tv.models.detection.retinanet_resnet50_fpn_v2(num_classes=num_classes,
                                                          weights_backbone=tv.models.ResNet50_Weights.DEFAULT,
                                                          trainable_backbone_layers=5,
                                                          nms_thresh=0.5,
                                                          score_thresh=0.5
                                                        )
    return model
# Функция для загрузки maxVIT
def get_maxvit(num_classes):
    model = tv.models.maxvit_t(weights=tv.models.MaxVit_T_Weights.DEFAULT)
    model.classifier[-1] = nn.Linear(512, num_classes, bias=False)
    return model

In [3]:
# Класс хранит модели и выдает предсказания по изображениям
class Predictor:
    def __init__(self, retina, maxvit):
        self.retina = retina 
        self.maxvit = maxvit
    # Функция для предсказаний
    def predict(self, orig_img, device):
        copy_img = orig_img.copy()
        img_width = orig_img.shape[1]
        img_height = orig_img.shape[0]
        # Изображение к диапазону от 0 до 1.
        img = orig_img.astype(np.float32)/255.
        # Уменьшаем
        img = cv2.resize(img, (380, 380), interpolation=cv2.INTER_AREA)
        # Располагаем каналы в порядке принятом в pyTorch
        img = img.transpose((2, 0, 1))
        t_img = torch.from_numpy(img).to(device)
        # Три класса для положения головы
        classes = ['face', 'side_left', 'side_right']
        with torch.no_grad():
            predict = self.retina([t_img])
            boxes = predict[0]['boxes']
            # Для каждой распознанной рамки вырезаем лицо
            for box in boxes:
                xmin = int(box[0] / 380 * img_width)
                xmax = int(box[2] / 380 * img_width)
                ymin = int(box[1] / 380 * img_height)
                ymax = int(box[3] / 380 * img_height)
                
                face_img = copy_img[ymin:ymax,xmin:xmax]
                face_img = face_img.astype(np.float32)/255.
                # MaxVit работает только с размером 224x224
                face_img = cv2.resize(face_img, (224, 224), interpolation=cv2.INTER_AREA)
                
                face_img = face_img.transpose((2, 0, 1))
                t_face_img = torch.from_numpy(face_img).to(device)
                # Добавляем входному тензору + 1 измерение
                t_face_img = torch.unsqueeze(t_face_img, 0)
                pred = self.maxvit(t_face_img)
                # Визуализируем предсказания рисуя рамку и подписывая ее
                # pred[0].argmax().item() получение номера класса с самым большим значением
                cv2.rectangle(orig_img, (xmin, ymin), (xmax, ymax), [255, 0, 0], 10)
                cv2.putText(orig_img, 
                                classes[pred[0].argmax().item()], 
                                (xmin, ymin-10),
                                cv2.FONT_HERSHEY_SIMPLEX, 
                                3, 
                                [155, 255, 0], 
                                5, 
                                lineType=cv2.LINE_AA)
        return orig_img

In [4]:
# У данной модели RetinaNet два класса фон и лицо
retina = get_retina_model(2)

retina_checkpoint = torch.load(
            'best_retina2_model.pth',
            map_location=device
        )
retina.load_state_dict(retina_checkpoint["model_state_dict"])
retina.to(device).eval()


RetinaNet(
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      

In [5]:
# У MaxVIT 3 класса: прамо, влево, вправо
maxvit = get_maxvit(num_classes=3)

checkpoint = torch.load('best_maxvit2_model.pth', map_location=device)
maxvit.load_state_dict(checkpoint["model_state_dict"])
    
maxvit = maxvit.to(device).eval()
device

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


device(type='cuda')

In [6]:
# Bundle the detector and the head-pose classifier into one annotator
predictor = Predictor(retina=retina, maxvit=maxvit)

In [7]:
# Открываем тестовое видео
cap = cv2.VideoCapture('/home/andrey/testvideo/track5108_5.mp4')

In [8]:

fps = cap.get(cv2.CAP_PROP_FPS)
fourcc = cv2.VideoWriter_fourcc(*'XVID')
# Куда записываем видео
out = cv2.VideoWriter('output3-005.avi', fourcc, fps, (1920, 1080))

In [9]:
# Читаем кадр, распознаем, рисуем рамки и сохраняем кадр в новое видео
while cap.isOpened():
    
    ret, frame = cap.read()
    if not ret:
        print("no ret")
        break
    
    orig_img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pred_img = predictor.predict(orig_img, device)
    pred_img = cv2.cvtColor(pred_img, cv2.COLOR_RGB2BGR)
    out.write(pred_img)
    if cv2.waitKey(1) == ord('q'):
        break
cap.release()
print("cap closed")
out.release()
print("out closed")

[h264 @ 0x56341a9b2400] error while decoding MB 69 44, bytestream -5
[h264 @ 0x56341a87f0c0] error while decoding MB 44 44, bytestream -5
[h264 @ 0x56341cdd5940] error while decoding MB 66 49, bytestream -7
[h264 @ 0x563422edda80] error while decoding MB 70 44, bytestream -7
[h264 @ 0x5634214f8b00] error while decoding MB 66 54, bytestream -9
[h264 @ 0x563421293640] error while decoding MB 36 48, bytestream -7
[h264 @ 0x5634221d2ec0] error while decoding MB 46 56, bytestream -5
[h264 @ 0x5634240e1640] error while decoding MB 67 21, bytestream -7
[h264 @ 0x5634214f6900] error while decoding MB 1 31, bytestream -13
[h264 @ 0x56341a82ec80] error while decoding MB 26 53, bytestream -11
[h264 @ 0x5634214f8b00] error while decoding MB 22 38, bytestream -19
[h264 @ 0x5634210be6c0] error while decoding MB 117 31, bytestream -13
[h264 @ 0x56341a8398c0] error while decoding MB 55 26, bytestream -7
[h264 @ 0x563422edda80] error while decoding MB 65 44, bytestream -9
[h264 @ 0x56342173db40] error 

no ret
cap closed
out closed
