In [1]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# RTSP endpoint of the camera. Credentials come from the environment so they are
# never committed with the notebook; set BIRDHUB_STREAM_USER / BIRDHUB_STREAM_PASSWORD
# before starting the kernel. Host and path can be overridden the same way.
# TODO(review): the password that used to be hard-coded here should be rotated.
STREAM_HOST = os.environ.get("BIRDHUB_STREAM_HOST", "192.168.0.164:554")
STREAM_PATH = os.environ.get("BIRDHUB_STREAM_PATH", "stream1")
STREAM_USER = os.environ.get("BIRDHUB_STREAM_USER", "")
STREAM_PASSWORD = os.environ.get("BIRDHUB_STREAM_PASSWORD", "")

_stream_auth = f"{STREAM_USER}:{STREAM_PASSWORD}@" if STREAM_USER else ""
STREAMURL = f"rtsp://{_stream_auth}{STREAM_HOST}/{STREAM_PATH}"

# Load model

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [4]:
from torchvision.transforms import ToTensor, Compose, Normalize, Resize

In [20]:
# Preprocessing for one digit crop (an HxWxC uint8 frame slice): convert to a CxHxW
# float tensor in [0, 1], resize to 16x25, then normalize.
# NOTE(review): 0.1307 / 0.3081 look like the standard single-channel MNIST mean/std,
# broadcast over all three channels here; they must match whatever was used when the
# loaded weights were trained, so do not change them in isolation.
image_transform = Compose(
    [
        ToTensor(),
        Resize((16, 25), antialias=True),
        Normalize((0.1307,), (0.3081,)),
    ]
)

In [44]:
def transform_images(images):
    """Preprocess a sequence of image crops into a single batch tensor.

    Every element of ``images`` is passed through ``image_transform`` and the
    results are stacked along a new leading batch dimension. An empty
    ``images`` makes ``torch.stack`` raise.
    """
    tensors = list(map(image_transform, images))
    return torch.stack(tensors)

In [47]:
def get_predictions(images, model):
    """Return the argmax class index predicted by ``model`` for each image.

    ``model`` is switched to eval mode (and left that way) and inference runs
    without gradient tracking. For 2-D (batch, classes) model outputs the result
    is a 1-D tensor with one index per image.
    """
    model.eval()
    with torch.no_grad():
        batch = transform_images(images)
        scores = model(batch)
    return scores.argmax(dim=1)

In [7]:
class Net(nn.Module):
    """Small CNN that classifies one digit crop into one of 10 classes.

    Expects a batch of 3-channel images shaped (N, 3, 16, 25), the size produced
    by ``image_transform``. Two valid 5x5 convolutions, each followed by 2x2
    max-pooling, shrink the spatial size 16x25 -> 12x21 -> 6x10 -> 2x6 -> 1x3,
    so 20 * 1 * 3 = 60 features reach ``fc1``.

    ``forward`` returns per-class log-probabilities of shape (N, 10).
    The submodule names (conv1, conv2, conv2_drop, fc1, fc2) must not change:
    they are the keys of the state dict loaded from the weights file.
    """

    # Flattened feature count after the conv stack for 16x25 inputs (20 channels * 1 * 3).
    _FLAT_FEATURES = 60

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Explicit size: a wrongly sized input raises here instead of being silently reshaped.
        x = x.view(x.size(0), self._FLAT_FEATURES)
        x = F.relu(self.fc1(x))
        # Functional dropout is only active while self.training is True.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Functional form: no new LogSoftmax module is constructed on every call.
        return F.log_softmax(x, dim=1)

In [8]:
# Instantiate the OCR network; its random initial weights are replaced from the checkpoint below.
network: Net = Net()

In [9]:
# Load the trained weights into `network`.
# map_location="cpu": a checkpoint saved from a GPU still loads on a CPU-only machine.
# weights_only=True: unpickling is restricted to tensors and plain containers, so a
# tampered .pt file cannot execute arbitrary code.
network.load_state_dict(
    torch.load("../weights/ocr_v3.pt", map_location="cpu", weights_only=True)
)

<All keys matched successfully>

# Get frames from the stream and run the digit OCR on them

In [10]:
from time import sleep

In [11]:
def get_time_digits(frame, start_x=350, end_x=605, digit_height=50):
    """Cut the six digit crops out of the on-screen timestamp of ``frame``.

    The horizontal span [start_x, end_x] is split into 8 equal-width slots and the
    top ``digit_height`` rows of each slot are taken. Slots 2 and 5 are dropped,
    presumably the ':' separators of an HH:MM:SS overlay (TODO confirm), leaving
    six crops in left-to-right order. Crops are views into ``frame`` (basic slicing).

    Returns:
        list of six arrays shaped (<= digit_height, slot width, channels).
    """
    edges = np.linspace(start_x, end_x, 9)
    kept_slots = (0, 1, 3, 4, 6, 7)
    return [
        frame[:digit_height, int(edges[slot]):int(edges[slot + 1]), :]
        for slot in kept_slots
    ]

In [16]:
# Digit crops of the most recent frame; populated by the capture loop below.
digits = list()

In [63]:
cap = cv2.VideoCapture(STREAMURL)
try:
    while True:
        ok, frame = cap.read()
        if not ok:
            # read() reports failure when the stream cannot be opened or drops;
            # stop cleanly instead of passing a None frame to get_time_digits().
            print("No frame received from the stream; stopping.")
            break
        # `digits` stays a notebook-level name: later cells inspect the last frame's crops.
        digits = get_time_digits(frame)
        predictions = get_predictions(digits, network).tolist()
        # Overlay the predicted digits on the frame for a visual check.
        cv2.putText(frame, str(predictions), (10, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
        cv2.imshow("frame", frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release the RTSP connection and close the window even on Interrupt Kernel or errors.
    # destroyAllWindows() is a no-op if the window was never created (first read failed).
    cap.release()
    cv2.destroyAllWindows()

In [41]:
# Preprocess the digit crops of the last captured frame into a model-ready batch.
result = transform_images(images=digits)

In [42]:
# Inspect the preprocessed batch built from the last frame read by the capture loop.
result

[tensor([[[-0.1187, -0.1187, -0.1187,  ...,  0.0395,  0.0619,  0.0722],
          [-0.1187, -0.1187, -0.1187,  ...,  0.0409,  0.0620,  0.0722],
          [-0.1187, -0.1187, -0.1187,  ...,  0.0533,  0.0656,  0.0722],
          ...,
          [-0.1696, -0.1387, -0.0933,  ...,  0.0145, -0.0243, -0.0367],
          [-0.1696, -0.1372, -0.0895,  ..., -0.0085, -0.0276, -0.0349],
          [-0.1678, -0.1328, -0.0898,  ..., -0.0040, -0.0003, -0.0024]],
 
         [[ 0.0340,  0.0340,  0.0340,  ...,  0.1922,  0.2146,  0.2249],
          [ 0.0340,  0.0340,  0.0340,  ...,  0.1936,  0.2148,  0.2249],
          [ 0.0340,  0.0340,  0.0340,  ...,  0.2061,  0.2184,  0.2249],
          ...,
          [ 0.0722,  0.0670,  0.0595,  ...,  0.1594,  0.1163,  0.1039],
          [ 0.0722,  0.0685,  0.0632,  ...,  0.1355,  0.1049,  0.0976],
          [ 0.0740,  0.0790,  0.0780,  ...,  0.1175,  0.0998,  0.0976]],
 
         [[-0.0678, -0.0678, -0.0678,  ...,  0.0904,  0.1128,  0.1231],
          [-0.0678, -0.0678,

In [45]:
# Run the raw network on the last frame's digits. eval() disables both dropout layers
# so the logits are deterministic even if get_predictions() has not run yet in this
# kernel; no_grad() avoids building an autograd graph for pure inference.
network.eval()
with torch.no_grad():
    result = network(transform_images(digits))

In [46]:
# Predicted class (digit) per crop from the log-probabilities above.
result.argmax(dim=1)

tensor([1, 6, 3, 0, 3, 7])

In [50]:
# Same prediction via the helper, as a plain Python list.
get_predictions(model=network, images=digits).tolist()

[1, 6, 3, 0, 3, 7]

# Test implementation in spoc

In [3]:
from birdhub.video import Stream
from time import sleep

In [4]:
stream = Stream(STREAMURL)
try:
    while True:
        frame = stream.get_frame()
        cv2.imshow("frame", frame.image)
        # NOTE(review): sleep() blocks the HighGUI event loop; waitKey(200) would throttle
        # the same way while keeping the window responsive.
        sleep(0.2)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Runs on 'q' and also on Interrupt Kernel / errors, so neither the window nor the
    # stream object is left behind. destroyAllWindows() does not fail if the window was
    # never created (e.g. the first get_frame() raised).
    cv2.destroyAllWindows()
    del stream

In [14]:
# Manual cleanup, safe to run any number of times: the cell above may already have
# deleted `stream`, and destroyAllWindows() is a no-op when no window is open.
cv2.destroyAllWindows()
if "stream" in globals():
    del stream