In [None]:
# installation
!pip install transformers accelerate

In [None]:
# imports
from google.colab import files
from pathlib import Path
import json
from transformers import pipeline
from dataclasses import dataclass

In [None]:
# !!!!!! huggingface login required to use llama model (ADD TOKEN AFTER `--token`) !!!!!!
!huggingface-cli login --token

In [None]:
# settings
# Hugging Face Hub identifier of the model that the text-generation pipeline loads.
checkpoint = "meta-llama/Llama-3.2-1B-Instruct"

In [None]:
# setup text generation pipeline
# Builds the text-generation pipeline for the model named in `checkpoint`.
pipe = pipeline(task="text-generation", model=checkpoint)

In [None]:
# utilities
def save_json(data, file_name):
    """Serialize ``data`` as JSON into ``file_name``, replacing any existing file."""
    with open(file_name, mode='w') as out_file:
        json.dump(obj=data, fp=out_file)

def load_json(file_name):
    """Read ``file_name`` and return the parsed JSON value.

    Raises whatever ``open`` or ``json.load`` raise, e.g. ``FileNotFoundError``
    for a missing file or ``json.JSONDecodeError`` for malformed content.
    """
    # JSON files are exchanged as UTF-8; decode explicitly instead of relying on
    # the platform's default text encoding, which is not UTF-8 everywhere.
    with open(file_name, 'r', encoding='utf-8') as f:
        return json.load(f)

def get_response(raw_output):
  """Return the ``'content'`` of the last message in the first result's ``'generated_text'``."""
  first_result = raw_output[0]
  conversation = first_result['generated_text']
  final_message = conversation[-1]
  return final_message['content']

In [None]:
# !!!! LOAD ALL FILES IN THE `ocr_results/` DIRECTORY !!!!!
# `files.upload()` returns a mapping keyed by file name; each name is then opened
# from the working directory (assumes Colab has written the uploads there).
uploaded_files = files.upload()
if not uploaded_files:
    # Fail here instead of sending an empty prompt to the model later on.
    raise ValueError('No OCR result files were uploaded.')
# Kept under a separate name so `ocr_extractions` only ever holds the parsed JSON.
ocr_extractions = [load_json(file_name) for file_name in uploaded_files.keys()]

In [None]:
# Template for schema inference
@dataclass
class Template:
  """Builds the schema-inference chat prompt from OCR extractions and runs it.

  Attributes:
    ocr_extractions: one entry per document. Each entry is passed to
      ``'\\n'.join``, so it must be an iterable of strings.
  """
  ocr_extractions: list

  def ocr_to_doc(self, ocr_extraction):
    """Join the strings of one OCR extraction into a single text, one per line."""
    return '\n'.join(ocr_extraction)

  def label_document(self, doc, num):
    """Prefix ``doc`` with a ``Document <num>:`` heading followed by a blank line."""
    return f'Document {num}:\n\n{doc}'

  def concat_docs(self, docs):
    """Label the documents 1..N in order and join them with ``---`` separator lines."""
    labelled = [self.label_document(doc, num) for num, doc in enumerate(docs, start=1)]
    return '\n---\n'.join(labelled)

  def ocr_to_string(self):
    """Render every entry of ``self.ocr_extractions`` as one labelled text."""
    docs = [self.ocr_to_doc(el) for el in self.ocr_extractions]
    return self.concat_docs(docs)

  def prompt(self):
    """Return the chat messages: the fixed system instructions, then the documents as the user turn."""
    system = '''The following is a document containing one or more emails.

Your task is to read the following emails and infer a JSON schema that will capture the structure that is common to emails in general.

There may be more than one email in the text, so make sure the schema can handle an arbitrary number of emails. It is very important to refrain from repeating keys.

Remember, your task is to infer a schema, not to fill it out with details from an email.

Now extract the schema and output only the JSON object below with no additional filler text.

JSON:'''

    return [
      {"role": "system", "content": system},
      {"role": "user", "content": self.ocr_to_string()},
    ]

  def generate(self, pipe, max_new_tokens=3000, do_sample=False,
               repetition_penalty=1.05, **generate_kwargs):
    """Call ``pipe`` on the prompt and return its result unchanged.

    Args:
      pipe: callable invoked as ``pipe(messages, **options)``.
      max_new_tokens: forwarded to ``pipe`` (default 3000).
      do_sample: forwarded to ``pipe`` (default False).
      repetition_penalty: forwarded to ``pipe`` (default 1.05).
      **generate_kwargs: any further keyword options, forwarded to ``pipe`` as-is.

    The defaults reproduce the previously hard-coded settings, so
    ``generate(pipe)`` behaves as before.
    """
    return pipe(self.prompt(),
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                repetition_penalty=repetition_penalty,
                **generate_kwargs)




In [None]:
# instantiates prompter with a few sample emails
# Wraps the loaded OCR extractions in the prompt builder.
template = Template(ocr_extractions=ocr_extractions)

In [None]:
# generates template
# Runs the pipeline on the prompt; `gen` holds the pipeline's raw return value.
gen = template.generate(pipe=pipe)

In [None]:
# discards extraneous output
# Keeps only the content of the last message from the first generation result.
schema = get_response(raw_output=gen)

In [None]:
# saves schema
# NOTE(review): `schema` is whatever `get_response` returned, presumably the model's
# reply text; if so this file holds a JSON-encoded string, not a JSON object.
# TODO confirm that is what downstream readers expect.
save_json(data=schema, file_name='schema.json')