In [76]:
from ultralytics import YOLO
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import torch
from torch import nn
import cv2
import time
import matplotlib.pyplot as plt
from crnn_dataset import get_split, TRDataset
from crnn_model import CRNN
from crnn_decoder import ctc_decode
from crnn_evaluate import evaluate

In [4]:
def predict(crnn, data_loader, label2char=None):
    crnn.eval()
    with torch.no_grad():
        for data in data_loader:
            device = 'cuda' if next(crnn.parameters()).is_cuda else 'cpu'
            images = data.to(device)
            
            logits = crnn(images)
            log_probs = torch.nn.functional.log_softmax(logits, dim=2)
            
            preds = ctc_decode(log_probs, label2char=label2char)
            texts = []
            for pred in preds:
                text = ''.join(pred)
                texts.append(text)
    
    return texts

In [79]:
def STR(img_path, yolo_weight, crnn_config):
    start = time.time()
    model = YOLO(yolo_weight)
    
    if type(img_path) is str:
        org_img = Image.open(img_path)
    else :
        org_img = img_path

    results = model(org_img, conf=0.4, verbose=False)
    
    cropped_texts = []
    for r in results:
        for text in r.boxes.xywh:
            x,y,w,h = text.cpu().numpy()
            cropped_img = org_img.crop((x-w/2, y-w/2, x + w/2, y + h/2))
            cropped_texts.append(cropped_img)
            
    config = torch.load(crnn_config)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    if cropped_texts:
        detect_dataset = TRDataset(images=cropped_texts, img_height=config['img_height'], img_width=config['img_width'])
        detect_loader = torch.utils.data.DataLoader(detect_dataset, batch_size=64, shuffle=False)

        num_class = len(TRDataset.LABEL2CHAR) + 1

        crnn = CRNN(1, config['img_height'], config['img_width'], num_class,
                        map_to_seq=config['map_to_seq'],
                        rnn_hidden=config['rnn_hidden'])

        if config['state_dict']:
            crnn.load_state_dict(config['state_dict'])

        crnn.to(device)

        texts = predict(crnn, detect_loader, label2char=TRDataset.LABEL2CHAR)
    else :
        print('No text detected')
        return org_img

    image_draw = org_img.copy()
    draw = ImageDraw.Draw(image_draw)

    font = ImageFont.truetype("arialbd.ttf", size=16)

    for r in results:
        for (x,y,w,h),text in zip(r.boxes.xywh.cpu().numpy(),texts):
            draw.rectangle([int(x-w/2), int(y-h/2), int(x+w/2), int(y+h/2)], outline="red", width=2)

            draw.text((int(x-w/2), int(y-h/2-18)), text, fill="red", font=font)
            
    end = time.time()
    print(f"Total time : {end-start}, Detected {len(cropped_texts)} texts")
    return image_draw

In [74]:
img_path = '../demo/herta.jpeg'
yolo_weight = '../runs/detect/train_5k_2/weights/best.pt'
crnn_config = '../checkpoints/crnn_synth_100k_config.pt'

In [80]:
cap = cv2.VideoCapture('../demo/street.mp4')

# Get video properties
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)

# Define the codec and create VideoWriter object
fourcc = cv2.VideoWriter_fourcc(*'XVID')  # You can choose the codec based on your needs
output_video = cv2.VideoWriter('output_video.avi', fourcc, fps, (width, height))

while cap.isOpened():
    ret, frame = cap.read()

    if ret:
        start = time.perf_counter()
        
        frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        
        results_pil = STR(frame_pil, yolo_weight, crnn_config)
        
        results_cv2 = cv2.cvtColor(np.array(results_pil), cv2.COLOR_RGB2BGR)
        end = time.perf_counter()
        
        total_time = end - start
        fps = 1 / total_time
        
        cv2.putText(results_cv2, f"FPS: {int(fps)}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
        cv2.imshow("YOLOv8 Inference", results_cv2)

        # Write the frame to the output video
        output_video.write(results_cv2)
    
    if cv2.waitKey(1) & 0xFF == ord('q') or not ret:
        break

# Release the VideoWriter and VideoCapture objects
cap.release()
output_video.release()
cv2.destroyAllWindows()

Total time : 0.5273473262786865, Detected 4 texts
Total time : 0.41216564178466797, Detected 4 texts
Total time : 0.3135855197906494, Detected 2 texts
Total time : 0.3175499439239502, Detected 2 texts
Total time : 0.3347146511077881, Detected 4 texts
Total time : 0.2955448627471924, Detected 2 texts
Total time : 0.28964757919311523, Detected 1 texts
Total time : 0.30158519744873047, Detected 1 texts
Total time : 0.29828381538391113, Detected 2 texts
Total time : 0.3056056499481201, Detected 2 texts
Total time : 0.3111131191253662, Detected 1 texts
Total time : 0.3045468330383301, Detected 2 texts
Total time : 0.30056023597717285, Detected 2 texts
Total time : 0.3275735378265381, Detected 3 texts
Total time : 0.31212687492370605, Detected 2 texts
Total time : 0.28459763526916504, Detected 1 texts
Total time : 0.29463887214660645, Detected 1 texts
Total time : 0.2841532230377197, Detected 1 texts
Total time : 0.3085660934448242, Detected 1 texts
Total time : 0.3272838592529297, Detected 