In [None]:
# ## Install the library (the repository is private, which is presumably why `git clone` does not work here, so a zipped copy is downloaded with gdown instead)
# !rm -rf ./boat_downloader
# !gdown 1sRm-gUKWYxo7B8UQOjttgPQUmav1NgJ3
# !unzip boat_downloader.zip
# !rm -rf __MACOSX/

# ## Run the code
# %cd ./boat_downloader
# !python boat.py --download 1

In [None]:
from ssl import get_server_certificate
import cv2 as cv
import os
from PIL import Image
from matplotlib import pyplot as plt
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from keras_ocr import detection,recognition

import numpy as np
from pprint import pprint



# Dataset directories, all under one root.
# NOTE(review): these are relative to the kernel's current directory; the (commented-out)
# setup cell runs `%cd ./boat_downloader`, after which this root would need adjusting.
_boats_root = "./boat_downloader/datasets/boats"

train_images = f"{_boats_root}/train_images"
train_gts = f"{_boats_root}/train_gts"

test_images = f"{_boats_root}/test_images"
test_gts = f"{_boats_root}/test_gts"


class BoatTrain(Dataset):
    """Training split: images plus per-image text annotations.

    Images and ground-truth files are paired by sorted filename order, so both
    directories must hold the same number of entries in matching sort order.

    Each ground-truth line has the form ``c1,c2,...,cN,label``: integer
    coordinates followed by a text label (the label is everything after the
    last comma).

    Args:
        images_path: directory containing the training images (read with OpenCV).
        gts_path: directory containing the ground-truth text files.

    Raises:
        ValueError: if the two directories hold different numbers of files.
    """

    def __init__(self, images_path, gts_path):
        image_files = self._sorted_files(images_path)
        gt_files = self._sorted_files(gts_path)
        # Pairing is positional, so a count mismatch would silently misalign every
        # image with the wrong annotations; fail early instead.
        if len(image_files) != len(gt_files):
            raise ValueError(
                f"found {len(image_files)} images in {images_path!r} but "
                f"{len(gt_files)} ground-truth files in {gts_path!r}"
            )

        self.images = [cv.imread(path) for path in image_files]
        self.gts = [self._parse_gt_file(path) for path in gt_files]

        # NOTE(review): cv.imread yields BGR channel order, while these mean/std values are
        # the usual RGB ImageNet statistics; confirm the intended channel order.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _sorted_files(directory):
        """Return ``directory/name`` for every entry in ``directory``, sorted."""
        return sorted(f"{directory}/{name}" for name in os.listdir(directory))

    @staticmethod
    def _parse_gt_file(gt_path):
        """Parse one ground-truth file into a list of ``[coords, label]`` pairs.

        Blank lines (including the one after a trailing newline) are skipped, so the
        last annotation is kept whether or not the file ends with a newline.
        """
        with open(gt_path, 'r') as handle:
            rows = handle.read().splitlines()
        entity = []
        for row in rows:
            if not row.strip():
                continue
            *coords, label = row.split(',')
            entity.append([[int(cor) for cor in coords], label])
        return entity

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, gt)``, applying ``self.transform`` to the image when set."""
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        # Fetched unconditionally so `gt` is bound even when no transform is set.
        gt = self.gts[idx]
        return image, gt

class BoatTest(Dataset):
    """Unlabelled test split: every image in a directory, returned as a tensor.

    Images are read eagerly with OpenCV in sorted filename order. The transform
    only converts to a tensor; unlike the training split, no mean/std
    normalization is applied.

    Args:
        images_path: directory containing the test images.
    """

    def __init__(self, images_path):
        file_paths = sorted(f"{images_path}/{entry}" for entry in os.listdir(images_path))
        self.images = list(map(cv.imread, file_paths))
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return the image at ``idx``, transformed when a transform is set."""
        sample = self.images[idx]
        return self.transform(sample) if self.transform else sample


# Text detector and recognizer from keras_ocr.
torch_detector = detection.Detector()
recognizer = recognition.Recognizer()

# shuffle=False keeps batches in sorted filename order, so predictions[i] corresponds to the
# i-th test image (and, assuming matching names, the i-th file in test_gts), and repeated
# runs produce the same ordering.
test_data = BoatTest(test_images)
test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False)

# Training split; not used by the recognition loop below.
# train_data = BoatTrain(train_images, train_gts)
# train_dataloader = DataLoader(train_data, batch_size=1, shuffle=True)

## Loop for recognition pipeline
## ---------------------------------------------------------------------------------------------------------------------
# Run detection, then recognition, over every test batch.
# detections[i] holds the detector output for batch i; predictions[i] holds one word list
# per image in that batch.
detections = []
predictions = []
for samples in test_dataloader:
    # CHW float tensors in [0, 1] -> HWC uint8 arrays. Round before casting: ToTensor divides
    # by 255, and float error can turn e.g. 255 into 254.9999, which a bare astype truncates.
    images = [
        np.clip(np.rint(np.moveaxis(sample.cpu().numpy(), 0, -1) * 255), 0, 255).astype(np.uint8)
        for sample in samples
    ]
    # NOTE(review): these images come from cv.imread and are therefore BGR, whereas keras_ocr's
    # own image loader produces RGB; confirm whether a BGR->RGB conversion is needed here.

    # Named `boxes`, not `detection`: `detection` is the keras_ocr module imported above, and
    # rebinding it makes `detection.Detector()` fail when the setup code is re-run.
    boxes = torch_detector.detect(samples)
    detections.append(boxes)

    # The detector output appears to hold one box array per image (detections[0][0] is an
    # ndarray), so its length is the batch size and says nothing about whether text was found;
    # check the per-image box counts instead.
    if any(len(image_boxes) > 0 for image_boxes in boxes):
        predictions.append(recognizer.recognize_from_boxes(images, boxes))
    else:
        # Same shape as a recognizer result: one (empty) word list per image.
        predictions.append([[] for _ in images])
## ---------------------------------------------------------------------------------------------------------------------

    

In [10]:
# Type of the detector output for the first image of the first batch.
type(detections[0][0])

numpy.ndarray

In [3]:
# Recognized words per batch: predictions[i][j] is the list of words for image j of batch i.
predictions

[[['orda', 'polisi', 'deniz']],
 [['alova', 'marina']],
 [['setur', 'guanawd', 'alova', 'marina']],
 [['nicersmen']],
 [['oyamans', 'setur', 'alova', 'marina']],
 [['wn']],
 [['ke']],
 [['sillaa', 'setur', 'yalova', 'marina']],
 [['crman',
   'tari',
   've',
   'caaig',
   'mooorlogo',
   'yalov',
   'l',
   'kontrol',
   'cronler',
   'taona']],
 [['ss', 'ss', 's']],
 [['nesgon']],
 [['tosunl', 'ylova']],
 [['herca']],
 [['buyukde', 'duzgit']],
 [['yalona', 'esa']],
 [[]],
 [['hion', 'oa', 'sahil', 'guvenlik', 'tcsg', '85']],
 [['tarfattaxi', 'vixa', 'boat']],
 [['baane']],
 [['celon']],
 [['melisz']]]