In [1]:
from src.data.evaluation import ocr_metrics
from src.data.generator import DataGenerator
from src.network.model import HTRModel
from rest.clients import groq_client 

In [2]:
from dotenv import load_dotenv
import os
load_dotenv()

GROQ_API_KEYS=os.getenv("GROQ_API_KEYS", "[]")
GROQ_API_KEYS = [key.removeprefix('"').removesuffix('"') for key in GROQ_API_KEYS.removeprefix('[').removesuffix(']').split(",")]
GROQ_MODEL=os.getenv("GROQ_MODEL", "llama3-8b-8192")
FUEL_SHOT_SIZE=int(os.getenv("FUEL_SHOT_SIZE", "50"))

In [3]:
# --- Paths -------------------------------------------------------------
source_path = "./data/mine_logs.hdf5"
target_path = "./ml_models/text_detection_model.hdf5"

# --- Data / tokenizer settings ------------------------------------------
batch_size = 16
# The backslash is written as "\\" so the literal holds no invalid escape
# sequence; the resulting string is unchanged (digits, letters, ASCII
# punctuation, a space, then the accented letters).
charset_base = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ ČčĆćĐđŽžŠš"""
max_text_length = 256
predict = True

# --- Model settings ------------------------------------------------------
architecture = "flor"
input_size = (1024, 128, 1)
vocab_size = 110
beam_width = 30
reduce_tolerance = 20
stop_tolerance = 30

In [4]:
# Data generator over the HDF5 source; `predict` comes from the config cell
# rather than being repeated here as a literal.
dtgen = DataGenerator(
    source=source_path,
    batch_size=batch_size,
    charset=charset_base,
    max_text_length=max_text_length,
    predict=predict,
)

# The vocabulary size is taken from the generator's tokenizer so the model
# always matches the charset actually in use.
model = HTRModel(
    architecture=architecture,
    input_size=input_size,
    vocab_size=dtgen.tokenizer.vocab_size,
    beam_width=beam_width,
    top_paths=10,
)

# LLM client used for optional post-correction of the recognised text.
client = groq_client.GroqClient(GROQ_MODEL, GROQ_API_KEYS, FUEL_SHOT_SIZE)

model.compile(learning_rate=0.001)
model.load_checkpoint(target=target_path)

In [5]:
import datetime
import asyncio

async def test(ds="test", use_llm=False, max_concurrency=None):
    """
    Run the HTR model, optionally post-correct with the LLM, and print metrics.

    ds: partition key used to look up the step count, ground truth and size
        in dtgen; vals "test", "valid", "train"
    use_llm: when True, every decoded prediction is passed through
        client.correct_extraction before scoring
    max_concurrency: optional upper bound on LLM requests in flight at once.
        None (the default) launches all requests together, as before.
        Must be >= 1 when given; otherwise ValueError is raised.

    Prints the report; returns nothing.
    """
    if max_concurrency is not None and max_concurrency < 1:
        raise ValueError("max_concurrency must be >= 1 or None")

    start_time = datetime.datetime.now()

    # NOTE(review): input batches always come from next_test_batch(), whatever
    # `ds` is, while steps / ground truth / size are looked up by `ds`.
    # Confirm that partitions other than "test" are really supported here.
    predicts, _ = model.predict(x=dtgen.next_test_batch(),
                                steps=dtgen.steps[ds],
                                ctc_decode=True,
                                verbose=1)
    # Keep only the first decoded path of each prediction.
    predicts = [dtgen.tokenizer.decode(x[0]) for x in predicts]
    ground_truth = [x.decode() for x in dtgen.dataset[ds]['gt']]

    if use_llm:
        if max_concurrency is None:
            pending = [client.correct_extraction(txt) for txt in predicts]
        else:
            # Throttle the requests so the client is not flooded all at once.
            limiter = asyncio.Semaphore(max_concurrency)

            async def _bounded(txt):
                async with limiter:
                    return await client.correct_extraction(txt)

            pending = [_bounded(txt) for txt in predicts]
        # gather preserves input order, so results stay aligned with ground_truth.
        predicts = await asyncio.gather(*pending)

    # Elapsed time covers prediction, decoding and (if enabled) LLM correction.
    total_time = datetime.datetime.now() - start_time
    evaluate = ocr_metrics(predicts=predicts, ground_truth=ground_truth)

    e_corpus = "\n".join([
        f"Total test images:    {dtgen.size[ds]}",
        f"Total time:           {total_time}",
        f"Time per item:        {total_time / dtgen.size[ds]}\n",
        "Metrics:",
        f"Character Error Rate: {evaluate[0]:.8f}",
        f"Word Error Rate:      {evaluate[1]:.8f}",
        f"Sequence Error Rate:  {evaluate[2]:.8f}"
    ])
    print(e_corpus)

In [6]:
# Baseline run: default arguments, i.e. the "test" partition with LLM correction disabled.
await test()

Model Predict
CTC Decode
Total test images:    111
Total time:           0:00:08.858345
Time per item:        0:00:00.079805

Metrics:
Character Error Rate: 0.14501020
Word Error Rate:      0.34817858
Sequence Error Rate:  0.66666667


In [8]:
# Same evaluation, with each decoded prediction passed through the LLM correction client.
await test(use_llm=True)

Model Predict
CTC Decode
To many requests made, returning unaltered data
To many requests made, returning unaltered data
To many requests made, returning unaltered data
To many requests made, returning unaltered data
To many requests made, returning unaltered data
To many requests made, returning unaltered data
To many requests made, returning unaltered data
To many requests made, returning unaltered data
To many requests made, returning unaltered data
To many requests made, returning unaltered data
To many requests made, returning unaltered data
To many requests made, returning unaltered data
To many requests made, returning unaltered data
To many requests made, returning unaltered data
To many requests made, returning unaltered data
To many requests made, returning unaltered data
To many requests made, returning unaltered data
To many requests made, returning unaltered data
To many requests made, returning unaltered data
To many requests made, returning unaltered data
To many request