# Setup

In [1]:
import os
import pandas as pd
import re
import pickle
from tqdm import tqdm

import asyncio
from tenacity import retry, stop_after_attempt
import tenacity
import time
import math

## Init Gcloud

In [2]:
# GCS location of the data: the bucket holds the images to transcribe (under `file_path`)
# and the few-shot prompt images (under "prompt/").
file_path = "images"
bucket_name = "BUCKETNAME"

In [3]:
from google.cloud import storage
# Storage client for the project (placeholder id "TODO" must be filled in) and the data bucket.
client = storage.Client(project="TODO")
bucket = client.bucket(bucket_name)

In [4]:
# Sorted blob names of the .jpg images below `file_path`. Of names containing "_" only the
# "_1.jpg" variant is kept (presumably the first page of a multi-part scan).
images = sorted(
    blob.name
    for blob in bucket.list_blobs(prefix=file_path)
    if blob.name.endswith(".jpg") and ("_1.jpg" in blob.name or "_" not in blob.name)
)

In [5]:
import vertexai
from vertexai.generative_models import GenerativeModel, Part, GenerationConfig
model_name = "gemini-2.5-flash"  # Gemini model used for all transcriptions
vertexai.init(location="europe-west1", project="TODO")

# Clustering

In [6]:
# Image clustering: maps an image file name (the last path component of a blob name) to its
# cluster id, e.g. 'other'. Only unpickle files produced by this project.
with open("clustering/clustering.pkl", mode="rb") as clustering_file:
    clustering = pickle.load(clustering_file)

In [7]:
# distinct cluster ids; build_prompt needs an entry in prompt_for_class for each of them
classes = {cluster_id for cluster_id in clustering.values()}
classes

{'135465157', '138196792', '4954189', '4954213', '4956212', '4957390', 'other'}

# vertexai

In [8]:
from google import genai
from google.genai.types import Content, CreateCachedContentConfig, HttpOptions, Part, GenerateContentConfig, ThinkingConfig
from google.genai import types

In [9]:
# google-genai client on the Vertex AI backend. Bound to its own name: it used to be assigned
# to `client`, silently replacing the storage.Client created in the "Init Gcloud" section.
# NOTE(review): project "its-ml" differs from the "TODO" placeholder passed to vertexai.init.
genai_client = genai.Client(project="its-ml", location="europe-west1", vertexai=True, http_options=HttpOptions(api_version="v1"))

In [8]:
model = GenerativeModel(model_name=model_name)  # vertexai model used by the async helpers

In [10]:
# Stems (file name without ".jpg") of the few-shot example images in few_shot_prompt_images/.
# known_image_prompt reads <stem>.txt from that folder and references the image as
# gs://<bucket_name>/prompt/<stem>.jpg.
# endswith(".jpg") replaces the substring test `"jpg" in f`, which also matched other names
# (e.g. "x_jpg.txt") even though f[:-4] assumes a ".jpg" suffix. The list is sorted so that
# the resulting few_shots mapping is deterministic.
few_shot_images = sorted(f[:-4] for f in os.listdir("few_shot_prompt_images") if f.endswith(".jpg"))

In [12]:
# Instruction text placed first in every request (see build_prompt). The triple-quoted text
# is flattened to a single line by dropping its line breaks.
pre_prompt = "".join("""
Your job is to transcribe information of persons from an image document. 
First detect if the image contains a table like structure with multiple persons in it. If it does not, return "FALSE". 
You are very accurate and you do not halucinate. The output should be provided in csv format separated by ';'. <header>first name; last name; date of birth</header>. 
Provide the date of birth in dd.mm.yyyy format. Dates are in the range from 1846 to 1945. Leave the date empty if it is not present. Fill missing day in dd.mm.yyyy with 00. 
Fill missing Month in dd.mm.yyyy with 00. Ignore degrees like dr.,  med., ing. Return only the csv, or False based on the stated condition.
""".split("\n"))

def known_image_prompt(image_name):
    """Build the few-shot example parts for one annotated prompt image.

    The example pairs the image gs://<bucket_name>/prompt/<image_name>.jpg with its
    reference transcription read from few_shot_prompt_images/<image_name>.txt.

    Args:
        image_name: file stem shared by the local .txt transcription and the bucket .jpg.

    Returns:
        list of Parts: "<EXAMPLE> INPUT:" text, the image, "OUTPUT:" + transcription,
        "</EXAMPLE>".

    NOTE(review): `Part` is whichever class was imported last (vertexai's or google.genai's).
    The positional URI below matches vertexai's `Part.from_uri(uri, mime_type)`. Confirm
    which import is intended.
    """
    # Read the file verbatim. The previous "\n".join(lines) kept each line's own "\n" and
    # therefore inserted an empty line between all rows of the example CSV.
    with open(f"few_shot_prompt_images/{image_name}.txt", "r") as f:
        output = f.read()
    return [
        Part.from_text(text="<EXAMPLE>\nINPUT:\n"),
        Part.from_uri(f"gs://{bucket_name}/prompt/{image_name}.jpg", mime_type="image/jpeg"),
        Part.from_text(text=f"OUTPUT:\n{output}"),
        Part.from_text(text="</EXAMPLE>"),
    ]

# Few-shot example parts keyed by cluster id. The id is the longest digit run in the
# example's file name; these keys are looked up with the cluster ids from `clustering`.
few_shots = {}
for example_name in few_shot_images:
    class_id = max(re.findall(r"\d+", example_name), key=len)
    few_shots[class_id] = known_image_prompt(example_name)

def _flatten_prompt(text):
    """Collapse a triple-quoted prompt into a single line.

    Each line break (together with the whitespace around it) becomes exactly one space, and
    the result is stripped. The former `.replace("\n", "")` glued lines without trailing
    whitespace onto the next one ("date of birthProvide ...", "column.Output ...").
    """
    return re.sub(r"\s*\n\s*", " ", text).strip()


# Instruction for documents with one major column of persons.
prompt_single_column = _flatten_prompt("""
The following document contains information of several persons. First detect if the image contains a table like structure with multiple persons in it. If it does not, return "FALSE". 
Otherwise, fully transcribe the first name, last name and date of birth of every person in each row of each major column. Transcribe each major column from top to bottom before transcribing the next major column.
Output the transcription in csv format: The last names may be in alphabetic order. Use latin characters only for first names and last names. 
first name; last name; date of birth
Provide the date of birth in dd.mm.yyyy format. Dates are in the range from 1846 to 1945. Leave the date empty if it is not present. Fill missing day in dd.mm.yyyy with 00. Fill missing Month in dd.mm.yyyy with 00. Ignore degrees like dr.,  med., ing. Return only the csv, or False based on the stated condition.

INPUT:
""")

# Instruction for documents whose rows are split into several major columns; the model is
# asked to first subdivide the page into those columns.
prompt_multi_column = _flatten_prompt("""
The following document contains information of several persons. First detect if the image contains a table like structure with multiple persons in it. If it does not, return "FALSE". 
Otherwise, first subdivide the image into major columns in a way that each row in a major column contains information of only one person. 
Then fully transcribe the first name, last name and date of birth of every person in each row of each major column. Transcribe each major column from top to bottom before transcribing the next major column.
Output the transcription in csv format: The last names may be in alphabetic order. Use latin characters only for first names and last names. 
first name; last name; date of birth
Provide the date of birth in dd.mm.yyyy format. Dates are in the range from 1846 to 1945. Leave the date empty if it is not present. Fill missing day in dd.mm.yyyy with 00. Fill missing Month in dd.mm.yyyy with 00. Ignore degrees like dr.,  med., ing. Return only the csv, or False based on the stated condition.

INPUT:
""")

# Cluster id -> instruction text; every id in `clustering` must be present (build_prompt indexes it).
prompt_for_class = {
    '135465157': prompt_single_column, 
    '138196792': prompt_multi_column, 
    '4954189': prompt_single_column, 
    '4956212': prompt_multi_column, 
    '4957390': prompt_single_column, 
    'other': prompt_single_column,
    '4954213': prompt_single_column,
}


def build_prompt(blob_name):
    """Assemble the request parts for one image blob.

    Order: the shared `pre_prompt`, the few-shot example of the image's cluster (none for
    the 'other' cluster), the cluster-specific instruction, the image itself, and a trailing
    "OUTPUT:" cue.

    Args:
        blob_name: object name inside `bucket_name`. Its last path component must be a
            key of `clustering`.

    Returns:
        list of Parts to send to the model.

    NOTE(review): `Part` is whichever class was imported last. The `uri=` keyword matches
    vertexai's `Part.from_uri`; google.genai's variant takes `file_uri=`. Confirm which
    import is intended.
    """
    img_class = clustering[blob_name.rsplit("/", 1)[-1]]
    parts = [Part.from_text(text=pre_prompt)]
    if img_class != "other":
        parts.extend(few_shots[img_class])
    parts.append(Part.from_text(text=prompt_for_class[img_class]))
    parts.append(Part.from_uri(uri=f"gs://{bucket_name}/{blob_name}", mime_type="image/jpeg"))
    parts.append(Part.from_text(text="OUTPUT:\n"))
    return parts



# Decoding settings consumed in order by async_generate's retries: greedy first, then
# increasingly random sampling. All of them allow 15000 output tokens.
MAX_OUTPUT_TOKENS = 15000
generation_configs = [
    GenerationConfig(temperature=temperature, max_output_tokens=MAX_OUTPUT_TOKENS)
    for temperature in (0, 0.2, 0.4)
]

# final, unparsed attempt of async_generate once generation_configs is exhausted
default_generation_config = GenerationConfig(temperature=0, max_output_tokens=MAX_OUTPUT_TOKENS)

# single-shot regeneration (async_generate_batch_no_rerun and the rerun cells).
# NOTE(review): currently identical to default_generation_config despite the "big" name.
default_big_generation_config = GenerationConfig(temperature=0, max_output_tokens=MAX_OUTPUT_TOKENS)

# Parse results

In [13]:
from aroa_etl.attribute_processing.string_utils import fix_name_uppercasing, fix_visual_character_decoding

In [14]:
defaut_header = "first name; last name; date of birth"  # CSV header requested by the prompts


def parse_and_validate_response(response):
    """Strictly parse a model response into a 3-column DataFrame.

    Any exception raised here signals that the caller should either retry the generation
    (async_generate is wrapped in tenacity's @retry) or fall back to
    parse_repair_and_validate_response.

    Args:
        response: model response. The text of its first candidate's first part is parsed.

    Returns:
        dict with "response" (the input) and "is_table". When is_table is True it also
        contains "data", a DataFrame with the header's three columns.

    Raises:
        ValueError: empty text, characters outside the allowed set, or data that does not
            form exactly three columns (pandas raises ValueError itself for rows that do
            not fit the header).
    """
    response_text = response.candidates[0].content.parts[0].text or ""
    if response_text.endswith("\n"):
        response_text = response_text[:-1]
    # the former response_text[-1] crashed with an IndexError on an empty answer
    if not response_text.strip():
        raise ValueError("empty response text")
    lines = response_text.split("\n")
    response_data = {"response": response}

    # The prompt writes "FALSE" with quotes and the model may echo them. The check runs
    # before the character whitelist, which rejects '"' and used to force useless retries.
    if lines[0].strip().strip('"').lower() == "false":
        response_data["is_table"] = False
        return response_data

    # Explicit raises instead of assert: asserts are removed under `python -O`, which
    # would silently disable this validation.
    if not re.match(r"^[a-zA-Z0-9\n\s\;\,\.]+$", response_text):
        raise ValueError("response contains unexpected characters")
    if "first" not in lines[0]:
        lines = [defaut_header, *lines]
    response_df = pd.DataFrame([line.split(";") for line in lines[1:]], columns=lines[0].split(";"))
    if response_df.shape[1] != 3:
        raise ValueError(f"response data is malformed shape: {response_df.shape}")
    response_data["data"] = response_df
    response_data["is_table"] = True
    return response_data

def fix_line_field_num(cells):
    """Coerce one split CSV row into exactly three fields (first name, last name, date).

    Steps:
      1. strip every cell;
      2. if there are more than three cells, drop the empty ones, then merge the two
         leading cells repeatedly until three remain (surplus text ends up in field 0);
      3. pad with "" up to three cells;
      4. if the first "date-like" cell (more than 30 % digits) is in column 0 or 1,
         shift the row right so that this cell lands in the date column.

    Args:
        cells: list of strings, one row of the model's CSV split on ";".

    Returns:
        list of exactly three strings.
    """
    cells = [cell.strip() for cell in cells]
    if len(cells) > 3:
        cells = [cell for cell in cells if cell != ""]
    # The old code merged at most twice, so rows with 6+ non-empty fields stayed longer
    # than three and made the DataFrame construction fail for the whole document.
    while len(cells) > 3:
        cells = [f"{cells[0]} {cells[1]}".strip(), *cells[2:]]
    cells = cells + [""] * (3 - len(cells))
    date_like = [
        idx for idx, cell in enumerate(cells)
        if len(cell) >= 1 and len(re.findall(r"\d", cell)) / len(cell) > 0.3
    ]
    if date_like and date_like[0] == 1:
        cells = [""] + cells[:2]
    elif date_like and date_like[0] == 0:
        cells = ["", ""] + cells[:1]
    return cells

def convert_date_field(date):
    """Convert a "dd.mm.yyyy" date string to "yyyymmdd" (always 8 characters).

    - NaN/None, and anything that is not three dot-separated groups of digits, -> "00000000".
    - Two-digit years are expanded: <= 45 -> 19yy, otherwise 18yy.
    - Four-digit years outside 1845..1945, and years of any other length, -> "0000".
    - Day and month are zero-padded to two digits. Empty day/month and day/month longer
      than two digits become "00"; before this, they produced results of the wrong length.
    """
    date = "" if pd.isna(date) else str(date)
    date_components = date.split(".")
    if len(date_components) != 3 or re.search(r"[^\d\.]", date):
        return "00000000"
    dd, mm, yyyy = date_components
    if len(yyyy) == 2:
        yyyy = f"19{yyyy}" if int(yyyy) <= 45 else f"18{yyyy}"
    elif len(yyyy) != 4 or not 1845 <= int(yyyy) <= 1945:
        yyyy = "0000"
    dd = dd.zfill(2) if len(dd) <= 2 else "00"
    mm = mm.zfill(2) if len(mm) <= 2 else "00"
    return f"{yyyy}{mm}{dd}"
    
def remove_wrong_inserted_whitespace(name):
    """Re-join name fragments that a stray space split into short pieces.

    Consider consecutive matches of "letters followed by one whitespace character". When
    both matches are at most 3 characters long (i.e. words of at most two letters), the
    whitespace after the first fragment is removed, e.g. "An na Schmidt" -> "Anna Schmidt".
    """
    tokens = list(re.finditer(r"[a-zA-Zäöüß]+\s", name))
    # w1.end() is the index just past the whitespace, so the whitespace itself sits at
    # w1.end() - 1. The old code removed w1.end(), i.e. the first letter of the next word.
    remove_at = {
        w1.end() - 1
        for w1, w2 in zip(tokens, tokens[1:])
        if len(w1.group()) <= 3 and len(w2.group()) <= 3
    }
    return "".join(c for idx, c in enumerate(name) if idx not in remove_at)

def _is_digit_string(text):
    """True if `text` is a non-empty run of ASCII digits 0-9, i.e. safe to pass to int()."""
    return re.fullmatch(r"[0-9]+", text) is not None


def postprocess_year(y):
    """Validate/repair a year string. Returns a 4-digit year in 1856..1953 or "0000".

    - 3 digits: multiplied by 10 ("192" -> "1920"; presumably a dropped last digit).
    - 4 digits outside the range: adding 800, then 900, is tried (e.g. "1023" -> "1923"),
      presumably to undo a misread leading digit.
    - anything else (other lengths, non-digit characters) -> "0000".
    """
    if not _is_digit_string(y):
        # non-digits used to reach int() and raise, aborting the whole document's repair
        return "0000"
    year = int(y)
    if len(y) == 3:
        candidates = [year * 10]
    elif len(y) == 4:
        candidates = [year, year + 800, year + 900]
    else:
        return "0000"
    for candidate in candidates:
        if 1855 < candidate < 1954:
            return str(candidate)
    return "0000"


def postprocess_day(d):
    """Return `d` if it is a 1-2 digit day <= 31, otherwise "00"."""
    if pd.notna(d) and len(d) in (1, 2) and _is_digit_string(d) and int(d) <= 31:
        return d
    return "00"


def postprocess_month(m):
    """Return `m` if it is a 1-2 digit month <= 12, otherwise "00"."""
    if pd.notna(m) and len(m) in (1, 2) and _is_digit_string(m) and int(m) <= 12:
        return m
    return "00"


def repair_date(date):
    """Sanitize a "dd.mm.yyyy" date string while keeping the dotted format.

    Returns "00000000" when the value contains latin letters (this includes
    str(NaN) == "nan") or does not consist of exactly three dot-separated parts. Otherwise
    each part is whitespace-stripped and checked by postprocess_day/month/year, which
    replace invalid parts by zeros, e.g. "32.13.1023" -> "00.00.1923".
    """
    date = str(date)
    if re.search(r"[a-zA-Z]", date):
        return "00000000"
    parts = [part.strip() for part in date.split(".")]
    if len(parts) != 3:
        # was "0000000" (7 zeros), inconsistent with the 8-zero sentinel above
        return "00000000"
    day, month, year = parts
    return ".".join([postprocess_day(day), postprocess_month(month), postprocess_year(year)])

def remove_numbers_from_string(string):
    """Return `string` without any decimal digit characters (Unicode category Nd)."""
    return "".join(char for char in string if not char.isdecimal())
    
def _clean_name_column(names):
    """Normalize one name column (Series of str) with the project's name helpers.

    Applies fix_visual_character_decoding, then removes digits, passing the result
    through fix_name_uppercasing after each step. Cells without any latin letter are
    blanked as noise.
    """
    names = names.apply(lambda name: fix_name_uppercasing(fix_visual_character_decoding(name)))
    names = names.apply(lambda name: fix_name_uppercasing(remove_numbers_from_string(name)))
    return names.apply(lambda name: name if re.search("[a-zA-Z]", name) else "")


def _restore_alphabetic_initials(last_names):
    """Fix the first letter of last names that break an apparent alphabetic ordering.

    The column counts as ordered when every name is non-empty and more than 75 % of the
    names are <= their successor (the last row always counts as ordered). In that case every
    name that is > its successor gets its first character replaced by the successor's first
    character. Otherwise the Series is returned unchanged.
    """
    values = list(last_names)
    if not values or any(name == "" for name in values):
        return last_names
    in_order = [n1 <= n2 for n1, n2 in zip(values, values[1:])] + [True]
    if sum(in_order) <= 0.75 * len(values):
        return last_names
    next_initial = [name[0] for name in values[1:]] + [values[-1][0]]
    fixed = [
        name if ordered else f"{initial}{name[1:]}"
        for name, ordered, initial in zip(values, in_order, next_initial)
    ]
    return pd.Series(fixed, index=last_names.index)


def parse_repair_and_validate_response(response):
    """Leniently parse a model response into a cleaned 3-column DataFrame.

    Used when parse_and_validate_response fails. Heuristic repairs applied:
      * a missing header is added and a duplicated header is dropped; a header that does
        not have three fields is replaced by the default header;
      * every row is coerced to exactly three fields (fix_line_field_num);
      * names are cleaned (_clean_name_column); first names split into 2-letter fragments
        are re-joined; rows with several first names and no last name use the last word
        as the last name;
      * when the last names look alphabetically ordered, out-of-order initials are fixed;
      * a date repeated in more than 5 rows is treated as not being a birth date and cleared;
      * dates are repaired (repair_date, dotted form) and then converted to "yyyymmdd"
        (convert_date_field). repair_date zeroes two-digit years;
      * rows that end up completely empty are dropped.

    Args:
        response: model response object or the raw response text.

    Returns:
        DataFrame with three columns (first name, last name, date as "yyyymmdd"). The
        original row index is kept. A "FALSE" answer (no table) yields an empty DataFrame.
    """
    if isinstance(response, str):
        response_text = response
    else:
        response_text = response.candidates[0].content.parts[0].text or ""
    # drop one trailing newline (the old response_text[-1] crashed on an empty answer)
    if response_text.endswith("\n"):
        response_text = response_text[:-1]
    lines = response_text.split("\n")
    columns = defaut_header.split(";")

    # "FALSE" means no table. Previously it was parsed as a person named "False".
    if lines[0].strip().strip('"').lower() == "false":
        return pd.DataFrame(columns=columns)

    if "first" not in lines[0]:
        lines = [defaut_header, *lines]
    if len(lines) > 1 and "first" in lines[1]:
        lines = lines[1:]  # duplicated header
    header = lines[0].split(";")
    if len(header) == 3:  # rows are always coerced to three fields
        columns = header

    # fix too many / too few fields in a row
    rows = [fix_line_field_num(line.split(";")) for line in lines[1:]]
    data = pd.DataFrame(rows, columns=columns)
    if data.empty:
        return data

    first_names = _clean_name_column(data.iloc[:, 0])
    last_names = _clean_name_column(data.iloc[:, 1])
    # first names split into several 2-letter fragments
    first_names = first_names.apply(lambda name: fix_name_uppercasing(remove_wrong_inserted_whitespace(name)))
    # several first names but no last name: the last word becomes the last name
    words = first_names.apply(lambda name: re.findall(r"[a-zA-Zäöß]+", name))
    only_first_names = (words.apply(len) >= 2) & (last_names == "")
    if only_first_names.any():
        # .loc takes the boolean Series mask directly (.iloc is documented for boolean arrays only)
        first_names.loc[only_first_names] = words[only_first_names].apply(lambda ws: " ".join(ws[:-1]))
        last_names.loc[only_first_names] = words[only_first_names].apply(lambda ws: ws[-1])

    # restore alphabetic order if it is an ordered document
    last_names = _restore_alphabetic_initials(last_names)

    dates = data.iloc[:, 2]
    # Some documents have a date that is mistaken for the birth date. A value repeated in
    # more than 5 rows is cleared; the old code discarded the result of .replace.
    date_counts = dates.value_counts()
    if len(date_counts) > 0 and date_counts.iloc[0] > 5:
        dates = dates.replace(date_counts.index[0], "00.00.0000")
    # Repair the dotted dd.mm.yyyy form first, then convert it to yyyymmdd. The old order
    # converted first, so repair_date never saw a dotted date and every date became "0000000".
    dates = dates.apply(repair_date).apply(convert_date_field)

    cleaned = pd.concat([first_names, last_names, dates], axis=1)
    cleaned.columns = data.columns
    # remove empty rows
    empty = (cleaned.iloc[:, 0] == "") & (cleaned.iloc[:, 1] == "") & (cleaned.iloc[:, 2] == "00000000")
    return cleaned.loc[~empty, :]

# async prompting

In [15]:
# `len(generation_configs)` in the decorator is the module-level list, evaluated when the
# function is defined. Inside the function the parameter of the same name is an iterator.
@retry(stop=stop_after_attempt(len(generation_configs)+1))
async def async_generate(job, generation_configs):
    """Generate and strictly parse one job, escalating through generation configs on retry.

    `generation_configs` is an iterator shared by all attempts of one call (callers pass
    ``iter(generation_configs)``), and each attempt takes the next config from it. The
    @retry decorator passes no `retry=` or `wait=` argument. Any exception from the API call
    or from parse_and_validate_response therefore triggers another attempt, with no wait in
    between, up to len(generation_configs) + 1 attempts. Once the iterator is exhausted, the
    final attempt uses default_generation_config and returns the raw response unparsed,
    flagged parsing_failure=True. If that attempt raises as well, tenacity raises RetryError
    (handled by async_generate_batch).

    NOTE(review): API errors (e.g. rate limiting) also consume a config from the iterator.

    Args:
        job: dict with at least "content" (the prompt parts). Mutated: job["response"] is set.
        generation_configs: iterator of GenerationConfig objects.

    Returns:
        the job dict, extended with the parse result ("data"/"is_table"/"response") and the
        "parsing_failure"/"generate_failure" flags.
    """
    generation_config = None
    parse = True
    try:
        # each (re)try takes the next config from the shared iterator
        generation_config = next(generation_configs)
    except StopIteration:
        # configs exhausted: final attempt with the default config, the result is not parsed
        generation_config = default_generation_config
        parse = False
    response = await model.generate_content_async(
        job["content"],
        generation_config=generation_config,
        #safety_settings=safety_settings,
        stream=False,
    )
    # mutates the caller's dict, so the latest response is kept even if parsing raises
    job["response"] = response
    # do not parse after last retries
    if not parse:
        job["parsing_failure"] = True
        job["generate_failure"] = False
        return job
    # raises on malformed output -> tenacity retries with the next config
    response_data = parse_and_validate_response(response)
    job = {**job, **response_data}
    job["parsing_failure"] = False
    job["generate_failure"] = False
    return job

async def async_generate_attempt(job, generation_config):
    """Send a single request for job["content"] (no retry, no parsing).

    The raw model response is stored under job["response"], mutating `job`, and the same
    dict is returned. Exceptions from the API call propagate to the caller.
    """
    job["response"] = await model.generate_content_async(
        job["content"], generation_config=generation_config, stream=False
    )
    return job

def postprocess_response(job):
    """Leniently parse job["response"] into job["data"] and set the failure flags.

    Args:
        job: dict holding the model response under "response". Mutated in place.

    Returns:
        the same dict. On success it has "data" (DataFrame), "parsing_failure"=False and
        "generate_failure"=False. If parsing raised (or "response" is missing), it has
        "parsing_failure"=True.
    """
    try:
        # The old body used an undefined name `response`; the NameError was swallowed by a
        # bare except, so every job was flagged as a parsing failure.
        data = parse_repair_and_validate_response(job["response"])
    except Exception:
        job["parsing_failure"] = True
        return job
    # stored under "data" like the other repair cells (the old code spread the DataFrame's
    # columns into the job with {**job, **data})
    job["data"] = data
    job["parsing_failure"] = False
    job["generate_failure"] = False
    return job
    
# Careful: spawns all tasks up front. Pass batches of bounded size.
async def async_generate_batch(jobs):
    """Run async_generate for all `jobs` concurrently.

    Args:
        jobs: list of job dicts (see async_generate).

    Returns:
        list aligned with `jobs`. A job whose generation raised (tenacity's RetryError after
        all attempts, or any other exception) is replaced by a copy of the input job with
        "generate_failure": True.
    """
    results = await asyncio.gather(
        *(async_generate(job, iter(generation_configs)) for job in jobs),
        return_exceptions=True,
    )
    # Previously only RetryError was mapped. Any other exception object was returned as the
    # "job" and broke later `"key" in job` checks.
    return [
        {**job, "generate_failure": True} if isinstance(result, BaseException) else result
        for job, result in zip(jobs, results)
    ]

async def async_generate_batch_no_rerun(jobs):
    """Regenerate all `jobs` concurrently with one attempt each (default_big_generation_config).

    Careful: spawns all tasks up front, so pass batches of bounded size. Responses are not
    parsed here.

    Args:
        jobs: list of job dicts with "content".

    Returns:
        list aligned with `jobs`: the job with its new "response", or a copy of the input
        job with "generate_failure": True when the request raised.
    """
    # The old comprehension iterated `for j in jobs` but passed the leaked global `job`,
    # so every slot regenerated (and returned) the same, unrelated job.
    results = await asyncio.gather(
        *(async_generate_attempt(job, default_big_generation_config) for job in jobs),
        return_exceptions=True,
    )
    # async_generate_attempt has no tenacity wrapper, so it raises the raw API error (never
    # RetryError). Any exception now marks the job as a generation failure.
    return [
        {**job, "generate_failure": True} if isinstance(result, BaseException) else result
        for job, result in zip(jobs, results)
    ]
#responses = await async_generate_batch([{"image":images[10],"content":build_prompt(images[10])}])

In [None]:
# cost estimation (prices per 1M tokens: 0.3 input, 2.5 output)
# 15k images, ~1200 input and ~1200 output tokens per image
# input:  15,000 * 1200 * 0.3 / 1M = 5.40 Euro
# output: 15,000 * 1200 * 2.5 / 1M = 45 Euro

In [32]:
responses = []
offset = 0
batch_size = 100
num_batches = math.ceil((len(images)-offset)/batch_size)
dump_size = 30
print(f"Start transcription of {len(images[offset:])} in {num_batches} batches")
pbar = tqdm(total=len(images)-offset)
for batch_idx in range(num_batches):
    image_idx = lambda idx_in_batch: batch_size*batch_idx + offset + idx_in_batch
    jobs = [
        {
            "image": images[image_idx(idx_in_batch)],
            "content": build_prompt(images[image_idx(idx_in_batch)])
        }
        for idx_in_batch in range(batch_size) if image_idx(idx_in_batch) < len(images)
    ]
    batch_responses = await async_generate_batch(jobs)
    responses = responses + batch_responses
    pbar.update(batch_size) 
    # dump every dump_size responses
    if batch_idx % dump_size == 0 and batch_idx !=0:
        responses_start_idx = batch_size*(batch_idx-dump_size) + offset+batch_size
        responses_end_idx = batch_size*batch_idx + offset+batch_size
        with open(f"transcriptions/transcriptions_{responses_start_idx}_{responses_end_idx}.pkl", "wb") as file:
            pickle.dump(responses, file)
        responses = []
        print(f"Dumped {responses_start_idx} {responses_end_idx}")
pbar.close()
# dump remaining responses
last_batch_start_idx = int(num_batches/dump_size)*dump_size
last_batch_end_idx = num_batches
responses_start_idx = batch_size*last_batch_start_idx + offset+batch_size
responses_end_idx = batch_size*last_batch_end_idx + offset+batch_size
with open(f"transcriptions/transcriptions_{responses_start_idx}_{responses_end_idx}.pkl", "wb") as file:
    pickle.dump(responses, file)

Start transcription of 17095 in 171 batches



  0%|                                                                                                                    | 0/17095 [20:49<?, ?it/s]

[A%|▌                                                                                                       | 100/17095 [03:07<8:51:06,  1.88s/it]
[A%|█▏                                                                                                      | 200/17095 [06:08<8:36:41,  1.83s/it]
[A%|█▊                                                                                                      | 300/17095 [09:11<8:33:59,  1.84s/it]
[A%|██▍                                                                                                     | 400/17095 [12:48<9:07:25,  1.97s/it]
[A%|███                                                                                                    | 500/17095 [18:06<11:05:11,  2.41s/it]
[A%|███▌                                                                                                   | 

Dumped 100 3100



[A%|██████████████████▉                                                                                  | 3200/17095 [1:39:17<5:37:37,  1.46s/it]
[A%|███████████████████▍                                                                                 | 3300/17095 [1:42:22<6:01:46,  1.57s/it]
[A%|████████████████████                                                                                 | 3400/17095 [1:45:24<6:16:12,  1.65s/it]
[A%|████████████████████▋                                                                                | 3500/17095 [1:48:25<6:24:29,  1.70s/it]
[A%|█████████████████████▎                                                                               | 3600/17095 [1:50:29<5:51:13,  1.56s/it]
[A%|█████████████████████▊                                                                               | 3700/17095 [1:53:21<5:58:42,  1.61s/it]
[A%|██████████████████████▍                                                                              | 380

Dumped 3100 6100



[A%|████████████████████████████████████▋                                                                | 6200/17095 [3:07:17<5:13:34,  1.73s/it]
[A%|█████████████████████████████████████▏                                                               | 6300/17095 [3:11:02<5:38:56,  1.88s/it]
[A%|█████████████████████████████████████▊                                                               | 6400/17095 [3:13:47<5:23:34,  1.82s/it]
[A%|██████████████████████████████████████▍                                                              | 6500/17095 [3:16:35<5:13:03,  1.77s/it]
[A%|██████████████████████████████████████▉                                                              | 6600/17095 [3:19:27<5:07:20,  1.76s/it]
[A%|███████████████████████████████████████▌                                                             | 6700/17095 [3:23:12<5:30:00,  1.90s/it]
[A%|████████████████████████████████████████▏                                                            | 680

Dumped 6100 9100



[A%|██████████████████████████████████████████████████████▎                                              | 9200/17095 [4:15:20<3:22:47,  1.54s/it]
[A%|██████████████████████████████████████████████████████▉                                              | 9300/17095 [4:18:03<3:23:59,  1.57s/it]
[A%|███████████████████████████████████████████████████████▌                                             | 9400/17095 [4:20:56<3:27:10,  1.62s/it]
[A%|████████████████████████████████████████████████████████▏                                            | 9500/17095 [4:24:02<3:33:54,  1.69s/it]
[A%|████████████████████████████████████████████████████████▋                                            | 9600/17095 [4:26:44<3:28:23,  1.67s/it]
[A%|█████████████████████████████████████████████████████████▎                                           | 9700/17095 [4:29:15<3:19:55,  1.62s/it]
[A%|█████████████████████████████████████████████████████████▉                                           | 980

Dumped 9100 12100



[A%|███████████████████████████████████████████████████████████████████████▎                            | 12200/17095 [5:42:02<2:17:27,  1.68s/it]
[A%|███████████████████████████████████████████████████████████████████████▉                            | 12300/17095 [5:44:41<2:12:17,  1.66s/it]
[A%|████████████████████████████████████████████████████████████████████████▌                           | 12400/17095 [5:47:27<2:09:39,  1.66s/it]
[A%|█████████████████████████████████████████████████████████████████████████                           | 12500/17095 [5:50:21<2:08:45,  1.68s/it]
[A%|█████████████████████████████████████████████████████████████████████████▋                          | 12600/17095 [5:58:37<3:19:45,  2.67s/it]
[A%|██████████████████████████████████████████████████████████████████████████▎                         | 12700/17095 [6:01:12<2:50:39,  2.33s/it]
[A%|██████████████████████████████████████████████████████████████████████████▉                         | 1280

Dumped 12100 15100



[A%|██████████████████████████████████████████████████████████████████████████████████████████▋           | 15200/17095 [7:12:17<49:25,  1.56s/it]
[A%|███████████████████████████████████████████████████████████████████████████████████████████▎          | 15300/17095 [7:14:55<46:57,  1.57s/it]
[A%|███████████████████████████████████████████████████████████████████████████████████████████▉          | 15400/17095 [7:17:30<44:11,  1.56s/it]
[A%|████████████████████████████████████████████████████████████████████████████████████████████▍         | 15500/17095 [7:20:07<41:34,  1.56s/it]
[A%|█████████████████████████████████████████████████████████████████████████████████████████████         | 15600/17095 [7:22:49<39:26,  1.58s/it]
[A%|█████████████████████████████████████████████████████████████████████████████████████████████▋        | 15700/17095 [7:25:29<36:52,  1.59s/it]
[A%|██████████████████████████████████████████████████████████████████████████████████████████████▎       | 15

In [34]:
# Load the per-range dumps written by the transcription loop, ordered by their first image
# index. transcriptions_tmp.pkl (the checkpoint saved further below) does not match the
# pattern and is skipped; previously it was loaded too, duplicating every response.
dump_name = re.compile(r"transcriptions_(\d+)_(\d+)\.pkl")
dump_files = sorted(
    (fname for fname in os.listdir("transcriptions/") if dump_name.fullmatch(fname)),
    key=lambda fname: int(dump_name.fullmatch(fname).group(1)),
)
responses = []
for fname in dump_files:
    with open(f"transcriptions/{fname}", "rb") as f:
        responses += pickle.load(f)

In [34]:
# Collect (index, job) pairs of jobs flagged as generation or parsing failures.
generation_failures, parsing_failures = [], []
for idx, job in enumerate(responses):
    if job.get("generate_failure"):
        generation_failures.append((idx, job))
    if job.get("parsing_failure"):
        parsing_failures.append((idx, job))
print(f"{len(generation_failures)} generation failures and {len(parsing_failures)} parsing failures")

0 generation failures and 0 parsing failures


# Re-run failed or truncated generations and repair unparsable responses

In [21]:
# rerun failed
did_not_regenerate = []
for idx, job in tqdm(enumerate(responses),total=len(responses)):
    if ("generate_failure" in job and job["generate_failure"]):# or ("parsing_failure" in job and job["parsing_failure"]):
        try:
            job = await async_generate(job, iter(generation_configs))
        except:
            did_not_regenerate.append(job)
        responses[idx] = job

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 17095/17095 [28:23<00:00, 10.03it/s]


In [18]:
# Jobs flagged as generation failures may still hold a response from an earlier attempt.
# Run the lenient parser on them and clear the flags when it yields at least one row.
# Jobs are mutated in place (they are the objects stored in `responses`).
for job in tqdm(responses, total=len(responses)):
    if not job.get("generate_failure"):
        continue
    try:
        job["data"] = parse_repair_and_validate_response(job["response"])
    except Exception:
        continue  # no stored "response" or unparsable text: keep the failure flags
    if job["data"].shape[0] > 0:
        job["parsing_failure"] = False
        job["generate_failure"] = False

100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 17095/17095 [00:00<00:00, 440692.97it/s]


In [None]:
# rerun max token reached
MAX_TOKENS_REACHED = 2
RECITATION = 4
did_not_regenerate = []
for idx, job in tqdm(enumerate(responses),total=len(responses)):
    finish_reason = job["response"].candidates[0].finish_reason
    if "response" in job and finish_reason==RECITATION:
        try:
            job = await async_generate_attempt(job, default_big_generation_config)
            job = postprocess_response(job)
            responses[idx] = job
        except:
            did_not_regenerate.append(job)
print(f"failed {len(did_not_regenerate)}")        

  0%|                                                   | 0/17095 [00:00<?, ?it/s]

In [29]:
# rerun max token reached batched
MAX_TOKENS_REACHED = 2
did_not_generate = []
for idx, job in tqdm(enumerate(responses),total=len(responses)):
    if "response" in job and job["response"].candidates[0].finish_reason == MAX_TOKENS_REACHED:
        did_not_generate.append((idx, job))
BATCH_SIZE = 100
num_batches = math.ceil(len(did_not_generate)/BATCH_SIZE)
pbar = tqdm(total=len(did_not_generate))
print(f"start regenerating {len(did_not_generate)} documents in {num_batches} batches")
for batch_idx in range(num_batches):
    image_idx = lambda idx_in_batch: BATCH_SIZE*batch_idx + idx_in_batch
    jobs = [
        did_not_generate[image_idx(idx_in_batch)][1]
        for idx_in_batch in range(BATCH_SIZE) if image_idx(idx_in_batch) < len(did_not_generate)
    ]
    response_indices = [
        did_not_generate[image_idx(idx_in_batch)][0]
        for idx_in_batch in range(BATCH_SIZE) if image_idx(idx_in_batch) < len(did_not_generate)
    ]
    batch_responses = await async_generate_batch_no_rerun(jobs)
    for idx, job in zip(response_indices, batch_responses):
        responses[idx] = job
    pbar.update(BATCH_SIZE) 


[A%|                                                   | 0/17095 [00:00<?, ?it/s]
100%|███████████████████████████████████| 17095/17095 [00:00<00:00, 139700.08it/s]

  0%|                                                    | 0/2084 [01:22<?, ?it/s]


start regenerating 2084 documents in 21 batches



[A%|██                                        | 100/2084 [00:30<10:00,  3.30it/s]
[A%|████                                      | 200/2084 [01:00<09:32,  3.29it/s]
[A%|██████                                    | 300/2084 [01:29<08:46,  3.39it/s]
[A%|████████                                  | 400/2084 [01:58<08:13,  3.41it/s]
[A%|██████████                                | 500/2084 [02:27<07:46,  3.39it/s]
[A%|████████████                              | 600/2084 [02:58<07:23,  3.35it/s]
[A%|██████████████                            | 700/2084 [03:30<07:02,  3.28it/s]
[A%|████████████████                          | 800/2084 [04:01<06:34,  3.25it/s]
[A%|██████████████████▏                       | 900/2084 [04:32<06:03,  3.26it/s]
[A%|███████████████████▋                     | 1000/2084 [05:02<05:32,  3.26it/s]
[A%|█████████████████████▋                   | 1100/2084 [05:32<05:00,  3.28it/s]
[A%|███████████████████████▌                 | 1200/2084 [06:03<04:30,  3.27it/s]
[A

In [17]:
# Run the lenient parser on every job that has a generated response and store its result
# under "data". Jobs where even the lenient parser fails are collected and flagged, so that
# a stale parsing_failure=False from an earlier step does not remain.
repair_failures = []
for job in responses:
    if job.get("generate_failure"):
        continue
    try:
        job["data"] = parse_repair_and_validate_response(job["response"])
        job["parsing_failure"] = False
    except Exception:
        job["parsing_failure"] = True
        repair_failures.append(job)
print(f"{len(repair_failures)} responses could not be repaired")

# save results

In [38]:
# checkpoint all responses (including parsed "data") to a single pickle
tmp_path = "transcriptions/transcriptions_tmp.pkl"
with open(tmp_path, "wb") as tmp_file:
    pickle.dump(responses, tmp_file)

In [27]:
# restore the checkpoint pickle (only load files this notebook wrote itself)
with open("transcriptions/transcriptions_tmp.pkl", mode="rb") as tmp_file:
    responses = pickle.load(tmp_file)

In [None]:
# Token usage and cost of all stored responses. Only the last response stored per job is
# counted (retried attempts overwrite job["response"]), so this is a lower bound.
# Prices are per 1M tokens, taken from the cost-estimation cell (0.3 input / 2.5 output).
# NOTE(review): verify them against the current Vertex AI price list for model_name.
# Previously the cell used an undefined `pm` and only counted a single leftover `response`.
INPUT_PRICE_PER_M = 0.3
OUTPUT_PRICE_PER_M = 2.5
prompt_token_count = 0
output_token_count = 0
for job in responses:
    if not isinstance(job, dict) or "response" not in job:
        continue  # no stored response to count
    usage_metadata = job["response"].usage_metadata
    prompt_token_count += usage_metadata.prompt_token_count or 0
    output_token_count += usage_metadata.candidates_token_count or 0
    # thinking tokens, if reported, are presumably billed as output -- verify
    output_token_count += getattr(usage_metadata, "thoughts_token_count", 0) or 0
cost = (prompt_token_count * INPUT_PRICE_PER_M + output_token_count * OUTPUT_PRICE_PER_M) / 1_000_000
print(f"""
Spent {cost} €
""")