In [None]:
# installation
!pip install transformers accelerate datasets outlines

In [None]:
# libraries
import getpass
import json
import os
import re
from dataclasses import dataclass
from pathlib import Path

from google.colab import files
from outlines import models, generate
from tqdm import tqdm
from transformers import AutoTokenizer

In [None]:
# utilities
def save_json(data, file_name, **dump_kwargs):
    """Serialize ``data`` as JSON into ``file_name``, overwriting any existing file.

    Args:
        data: any JSON-serializable object.
        file_name: path of the file to write.
        **dump_kwargs: extra keyword arguments forwarded to ``json.dump``
            (e.g. ``indent=2`` or ``ensure_ascii=False``).
    """
    # Explicit UTF-8 (the encoding JSON interchange requires) so the file stays valid
    # even when non-ASCII output is requested via ``ensure_ascii=False``.
    with open(file_name, 'w', encoding='utf-8') as f:
        json.dump(data, f, **dump_kwargs)

def load_json(file_name):
    """Read ``file_name`` and return the parsed JSON value.

    The file is decoded as UTF-8 (the encoding JSON interchange requires) instead of
    the platform's locale default, so non-ASCII content loads correctly everywhere.
    """
    with open(file_name, 'r', encoding='utf-8') as f:
        return json.load(f)

def get_response(raw_output):
    """Return the ``content`` of the final message in the first result's ``generated_text``."""
    first_result = raw_output[0]
    final_message = first_result['generated_text'][-1]
    return final_message['content']


In [None]:
# Hugging Face authentication is required to download the gated Llama checkpoint.
# Never paste the token into the notebook: set the HF_TOKEN environment variable
# (e.g. from a Colab secret) or type it into the hidden prompt below. The Hugging Face
# libraries read HF_TOKEN from the environment when downloading.
if not os.environ.get('HF_TOKEN'):
    os.environ['HF_TOKEN'] = getpass.getpass('Hugging Face token: ')

In [None]:
# settings: Hugging Face model id shared by the tokenizer and the outlines model below
checkpoint = "meta-llama/Llama-3.2-1B-Instruct"

In [None]:
# make directories for data and results: `ocr/` receives uploads, `info/` the extracted JSON
for directory in ('ocr', 'info'):
    os.makedirs(directory, exist_ok=True)

In [None]:
# upload JSON schema from `schema/schema.json`
uploaded_schema = files.upload()
if not uploaded_schema:
    raise ValueError('No schema file was uploaded; please upload `schema/schema.json`.')
# Parse the bytes returned by the upload itself (preferring an entry named
# 'schema.json') instead of re-reading a hard-coded path from disk, so the schema in
# use is exactly the file that was just uploaded, whatever name it was saved under.
schema_bytes = uploaded_schema.get('schema.json') or next(iter(uploaded_schema.values()))
schema = json.loads(schema_bytes)
# Pretty-printed schema text, passed to the outlines JSON generator.
schema_string = json.dumps(schema, indent=2)

In [None]:
# Upload all JSON files from `ocr_results/`.
# `files.upload()` saves into the current working directory, hence the temporary switch
# into `ocr/`. The previous directory is always restored, even if the upload is
# interrupted, so re-running this cell does not start from inside `ocr/`.
previous_dir = os.getcwd()
os.chdir('ocr')
try:
    ocr_paths = files.upload()
finally:
    os.chdir(previous_dir)

# Parse each upload from the returned bytes (after the directory is restored) rather
# than re-opening it by name; `key` is the file name without its extension.
ocr_extractions = [{'key': Path(name).stem, 'ocr': json.loads(contents)}
                   for name, contents in ocr_paths.items()]

In [None]:
# Template for information extraction
@dataclass
class Template:
  """Prompt builder that turns one OCR'd document into a structured-extraction request.

  Fields:
    key: identifier of the document (the uploaded file's stem in this notebook); the
      generation loop uses it to name the output file.
    ocr_extraction: the document's OCR output, either a single string or a sequence
      of text lines (presumably a list of lines -- confirm against the OCR JSON files).
  """
  key: str
  ocr_extraction: 'str | list[str]'

  @classmethod
  def from_dict(cls, dict_):
    """Build a Template from a ``{'key': ..., 'ocr': ...}`` record."""
    return cls(dict_['key'], dict_['ocr'])

  def ocr_to_string(self):
    """Return the OCR text as one string.

    A sequence of lines is joined with newlines. A plain string is returned unchanged,
    because joining a str would put a newline between every character.
    """
    if isinstance(self.ocr_extraction, str):
      return self.ocr_extraction
    return '\n'.join(self.ocr_extraction)

  def prompt(self, schema=None):
    """Return Hugging Face-style chat messages for this document.

    Args:
      schema: optional JSON-schema text to show the model in the system message.
        When omitted, the notebook-level ``schema_string`` is used if it is defined;
        if neither is available the schema is simply not shown.

    Returns:
      A list of two ``{"role", "content"}`` dicts: the system instructions, then the
      OCR text as the user turn.
    """
    if schema is None:
      # Looked up lazily so this class can be defined (and tested) before, or
      # without, the cell that defines `schema_string`.
      schema = globals().get('schema_string')

    system = '''The following is a document containing one or more emails.

Your task is to read the following emails and extract a JSON object that will capture the common structure of emails.

There may be multiple emails in the document—extract them all if so.

It is of utmost importance that you extract EXACTLY according to the JSON schema.
'''
    # The instructions refer to "the JSON schema", so include it when available; the
    # constrained decoder enforces it, but the model should also see the field names.
    if schema:
      system += f'\nJSON schema:\n{schema}\n'
    system += '\nNow extract the JSON object.\n'

    return [
      # The chat template recognizes the system prompt by the exact role name "system".
      {"role": "system", "content": system},
      {"role": "user", "content": self.ocr_to_string()},
    ]

  def hf_to_outlines(self, messages):
    """Render chat ``messages`` into a single prompt string for outlines.

    Uses the notebook-level ``tokenizer`` (created in the model cell) and appends the
    generation prompt so the model begins the assistant reply.
    """
    return tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)

  def generate(self, generator, schema=None):
    """Build this document's prompt and run ``generator`` on it.

    Args:
      generator: callable mapping a prompt string to model output (in this notebook,
        the outlines JSON generator).
      schema: optional schema text forwarded to ``prompt``.

    Returns:
      Whatever ``generator`` returns for the rendered prompt.
    """
    messages = self.prompt(schema)
    outlines_input = self.hf_to_outlines(messages)
    return generator(outlines_input)

In [None]:
# instantiate models
# outlines wraps the checkpoint for schema-constrained decoding: `generator` maps a
# prompt string to output constrained by `schema_string`.
# NOTE(review): no device is passed to `models.transformers`; confirm the model is
# placed on the GPU (accelerate is installed but not referenced here).
model = models.transformers(checkpoint)
generator = generate.json(model, schema_string)

# Plain Hugging Face tokenizer, used by `Template.hf_to_outlines` to render chat prompts.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
# generate records and save: one `info/<key>.json` per uploaded OCR document
for el in tqdm(ocr_extractions):
  template = Template.from_dict(el)
  out_path = f"info/{el['key']}.json"
  info = template.generate(generator)
  save_json(info, out_path)