# This file contains Form Recognizer model training and inferencing code

#### Read configuration file and get endpoint, key of the service

In [None]:
########### Python Form Recognizer Labeled Async Train #############
import json
import time
from requests import get, post

# Read the Form Recognizer service parameters (endpoint and key) from config.json
with open('config.json', 'r') as config_file:
    config = json.load(config_file)

# Endpoint URL (kept exactly as configured; other code in this file builds on it)
endpoint = config['endpoint']
# Drop any trailing slash before appending the path so the training URL never
# contains a double slash, whichever way the endpoint is written in config.json.
post_url = endpoint.rstrip('/') + r"/formrecognizer/v2.1/custom/models"
apim_key = config['apim-key']
filetype = 'application/json'


# Headers sent with the training requests (JSON payload + subscription key)
headers = {
    # Request headers
    'Content-Type': filetype,
    'Ocp-Apim-Subscription-Key': apim_key,
}

# Template of the training request payload; train_fr_model fills in
# "source" and the "sourceFilter" prefix for each training run.
body = {
    "source": "",
    "sourceFilter": {
        "prefix": "",
        "includeSubFolders": False
    },
    "useLabelFile": False
}

# Unsupervised training 

### Function to train unsupervised Form Recognizer Model

In [None]:
# Polling budget used by train_fr_model while it waits for training to finish:
# seconds to pause between two status checks ...
wait_sec = 60
# ... and the maximum number of status checks before giving up.
n_tries = 60
# Module-level attempt counter, initialised to zero.
n_try = 0


def train_fr_model(sas_url, folder_path, model_file):
    """Train a custom Form Recognizer model and save its details to a file.

    Sends a training request to ``post_url`` and then polls the URL returned
    in the ``location`` response header until the model status is ``ready``
    or ``invalid``, an error occurs, or ``n_tries`` status checks have been
    made (pausing ``wait_sec`` seconds between checks).

    Args:
        sas_url: Value placed in the ``source`` field of the training request.
        folder_path: Value placed in the ``sourceFilter.prefix`` field.
        model_file: Path of the JSON file that receives the service response
            once the model is ready.

    Returns:
        None. Progress and failures are reported with ``print``.
    """
    # Build the payload from the module-level template without mutating it,
    # so repeated calls do not leak values into each other.
    request_body = dict(body, source=sas_url)
    request_body['sourceFilter'] = dict(body['sourceFilter'], prefix=folder_path)

    # trigger training
    try:
        resp = post(url=post_url, json=request_body, headers=headers)
        if resp.status_code != 201:
            print("Training model failed (%s):\n%s" % (resp.status_code, json.dumps(resp.json())))
            return
        print("Training Started:\n%s" % resp.headers)
        get_url = resp.headers["location"]
    except Exception as e:
        print("Error occurred when triggering training:\n%s" % str(e))
        # Return instead of quit(): a failed request must not terminate the
        # whole interpreter / notebook kernel.
        return

    # wait for training to complete and save model to json file
    for _ in range(n_tries):
        try:
            resp = get(url=get_url, headers=headers)
            resp_json = resp.json()
            if resp.status_code != 200:
                print("Model training failed (%s):\n%s" % (resp.status_code, json.dumps(resp_json)))
                return
            model_status = resp_json["modelInfo"]["status"]
            print("Model Status:", model_status)
            if model_status == "ready":
                print("Training succeeded:")
                with open(model_file, "w") as f:
                    json.dump(resp_json, f)
                return
            if model_status == "invalid":
                print("Training failed. Model is invalid:\n%s" % json.dumps(resp_json))
                return
        except Exception as e:
            print("Model training returned error:\n%s" % str(e))
            return
        # Training still running. Wait and retry.
        time.sleep(wait_sec)

    # Only reached when every status check saw a model that was still training.
    print("Train operation did not complete within the allocated time.")

# FR Model Inferencing

In [None]:
import requests
import glob
import os
import datetime
import tempfile
import pandas as pd
import shutil

from concurrent.futures import ThreadPoolExecutor, as_completed

In [None]:

# Query-string parameters attached to every analyze POST request.
params_infer = dict(includeTextDetails=True)

# Headers for the analyze POST request: the request body is sent as PDF
# bytes and authenticated with the subscription key.
headers_infer = {}
headers_infer['Content-Type'] = 'application/pdf'
headers_infer['Ocp-Apim-Subscription-Key'] = apim_key

#### Form Recognizer inferencing function

In [None]:
#######################################################
# FR Inference function multithreading
#######################################################

def _fr_submit_file(session, post_url, fl):
    """POST one file to the analyze endpoint and return its result URL (or None on failure)."""
    with open(fl, "rb") as f:
        data_bytes = f.read()

    try:
        # send post request (wait and resend while the service answers 429)
        while True:
            resp = session.post(url=post_url, data=data_bytes, headers=headers_infer, params=params_infer)
            if resp.status_code != 429:
                break
            time.sleep(1)

        if resp.status_code != 202:
            print("POST analyze failed:\n%s" % json.dumps(resp.json()))
            return None

        return resp.headers["operation-location"]
    except Exception as e:
        print("POST analyze failed 1:\n%s" % str(e))
        return None


def fr_mt_inference(files, json_fld, model_id):
    """Analyze a batch of files with a custom model and save the results as JSON.

    All files are first submitted to the analyze endpoint (at most one
    submission per second), then the result URLs are polled up to 15 times,
    15 seconds apart. Each successful result is written to ``json_fld`` under
    the input file's base name with a ``.json`` extension.

    Args:
        files: Iterable of paths of the files to analyze.
        json_fld: Directory that receives the result JSON files.
        model_id: Identifier of the custom model used in the analyze URL.

    Returns:
        The string "All files successfully inferred by FR" when every file
        succeeded, otherwise a list of ``(path, file_name, result_url)``
        tuples (in submission order) for the files that were not inferred;
        ``result_url`` is None when the submission itself failed.
    """
    # Normalise the separator so the URL is valid whether or not the
    # configured endpoint ends with a slash.
    post_url = endpoint.rstrip("/") + "/formrecognizer/v2.1/custom/models/%s/analyze" % model_id

    gap_between_requests = 1  # in seconds
    n_tries = 15
    wait_sec = 15

    # The context manager closes the session even if a file cannot be read.
    with requests.Session() as session:

        ###################
        # send all requests in one go
        ###################
        url_list = []
        for fl in files:
            fname = os.path.basename(fl)
            # Monotonic clock: immune to wall-clock changes and independent of
            # how the datetime module was imported.
            st_time = time.monotonic()
            get_url = _fr_submit_file(session, post_url, fl)
            url_list.append((fl, fname, get_url))

            # keep at least gap_between_requests seconds between submissions
            delta = time.monotonic() - st_time
            if delta < gap_between_requests:
                time.sleep(gap_between_requests - delta)

        ####################################
        # get all responses in one go
        ####################################
        # Indexes into url_list that are still worth polling; files whose
        # submission failed have no result URL and are never polled.
        awaiting = [i for i, entry in enumerate(url_list) if entry[2] is not None]
        succeeded = set()

        for cnt in range(n_tries):

            still_running = []
            for i in awaiting:
                fl, fname, get_url = url_list[i]
                try:
                    resp = session.get(url=get_url, headers={"Ocp-Apim-Subscription-Key": apim_key})
                    resp_json = resp.json()

                    if resp.status_code != 200:
                        print("GET analyze results failed:%s \n%s" % (fl, json.dumps(resp_json)))
                        # Retry on the next iteration without blocking the
                        # remaining files of this iteration.
                        still_running.append(i)
                        continue

                    status = resp_json["status"]
                    if status == "succeeded":
                        print("Analysis succeeded for %s:\n" % fl)
                        out_name = os.path.splitext(fname)[0] + '.json'
                        with open(os.path.join(json_fld, out_name), 'w') as outfile:
                            json.dump(resp_json, outfile)
                        succeeded.add(i)
                    elif status == "failed":
                        # Terminal state: report it and stop polling this file.
                        print("Analysis failed:%s \n%s" % (fl, json.dumps(resp_json)))
                    else:
                        still_running.append(i)
                except Exception as e:
                    print("GET analyze results failed 2:\n%s" % str(e))
                    still_running.append(i)

            awaiting = still_running

            print("iteration", cnt, "complete. Still", len(url_list) - len(succeeded), " to infer")
            if not awaiting:
                break

            # No point in sleeping after the last polling round.
            if cnt + 1 < n_tries:
                time.sleep(wait_sec)

    ####################################
    # return files not inferred
    ####################################
    not_inferred = [entry for i, entry in enumerate(url_list) if i not in succeeded]

    if not not_inferred:
        return "All files successfully inferred by FR"
    return not_inferred


#### Form Recognizer multi-threading inferencing function

In [None]:
# Form Recognizer inference

def fr_model_inference(src_dir, json_dir, model_file, thread_cnt):
    """Run Form Recognizer inference on every PDF in a folder using worker threads.

    Reads the model id from ``model_file``, collects the ``*.pdf`` files in
    ``src_dir``, splits them with ``chunkify`` and submits each resulting
    file group to ``fr_mt_inference`` on a thread pool. The result of every
    task is printed as it completes.

    Args:
        src_dir: Directory containing the PDF files to analyze.
        json_dir: Directory passed to ``fr_mt_inference`` for the JSON output.
        model_file: Path of the JSON file holding the trained model details
            (expects ``modelInfo.modelId``).
        thread_cnt: Maximum number of worker threads.

    Returns:
        None. When the model details cannot be read, a message is printed and
        nothing is submitted.
    """
    missing_model_msg = ("Model details not present, either model training "
                         "is not performed or the file is missing")

    # read model details
    try:
        with open(model_file, 'r') as model_fh:
            model = json.load(model_fh)
    except (OSError, ValueError) as e:
        # Missing/unreadable file or invalid JSON: report it the same way as
        # a file without a model id instead of raising.
        print(missing_model_msg)
        print("Details: %s" % str(e))
        return

    model_info = model.get('modelInfo') if isinstance(model, dict) else None
    model_id = model_info.get('modelId') if isinstance(model_info, dict) else None
    if model_id is None:
        print(missing_model_msg)
        return
    print("model id: %s" % model_id)

    # Read files and divide into chunks
    fls = glob.glob(os.path.join(src_dir, "*.pdf"))
    print("inferencing ", len(fls), "files with", thread_cnt, "thread count")
    fchunk = chunkify(fls, 100)

    for chunk in fchunk:

        # An empty chunk would give a thread count of zero; nothing to do.
        if not chunk:
            continue

        fr_threads = min(len(chunk), thread_cnt)

        flist = chunkify(chunk, fr_threads)

        # Call FR inference
        threads = []
        with ThreadPoolExecutor(max_workers=thread_cnt) as executor:
            for files in flist:
                threads.append(executor.submit(fr_mt_inference, files, json_dir, model_id))

            for task in as_completed(threads):
                print(task.result())