# PDF Training Demo

### Dependencies
- Python 3.11
    - https://www.python.org/downloads/
- Poppler for pdf2image
    - https://pypi.org/project/pdf2image/

### Before Running
After installing the dependencies above, run:
```
python.exe -m pip install --upgrade pip
pip install -r requirements.txt
```

Then copy `.env.example` to `.env` and replace the placeholders in `.env` with your own Google API Key, Google Application Credentials, Google Project Name, and Hugging Face Token.

In [None]:
# Takes all pdfs in the directory tree, converts them to images using pdf2image
# then uses Google Vertex AI to parse the pdf images for text
# builds the text into a structured dataset for LoRA training
# then trains a LoRA for unsloth/Mistral-Small-Instruct-2409-bnb-4bit

# pip install pdf2image
# Poppler: https://github.com/oschwartz10612/poppler-windows/releases/
# F:\repos\HobieLLM\PDF Library\Beekeeping\10_things_to_help_bees_UK.pdf

# import module
from pdf2image import convert_from_path
import os

# Root folder that is searched recursively for PDF files.
input_dir_str = input("Directory containing the PDFs to convert: ")

def get_pdf_paths(input_dir):
    """Lazily yield the path of every PDF file found beneath ``input_dir``.

    The tree is traversed with ``os.walk``. A file qualifies when its name
    ends in ``.pdf``, compared case-insensitively. Each yielded path is the
    containing directory joined with the file name.
    """
    for folder, _subfolders, entries in os.walk(input_dir):
        yield from (
            os.path.join(folder, entry)
            for entry in entries
            if entry.lower().endswith(".pdf")
        )


# Materialize the generator so the complete list of PDF paths is held in memory.
pdf_paths = list(get_pdf_paths(input_dir_str))

def convert_and_save_pdf_images(pdf_path, output_dir='pdf_images'):
    """Render every page of a PDF to its own JPEG file.

    Pages are written to ``<output_dir>/<pdf name>-page<i>.jpg`` where ``i``
    is the zero-based page index and ``<pdf name>`` is the PDF's file name
    without its extension.

    Args:
        pdf_path: Path of the PDF to convert.
        output_dir: Directory that receives the page images. It is created
            if it does not exist. Defaults to ``pdf_images``.

    Note:
        Only the PDF's base name is used for the output files, so two PDFs
        with the same file name in different folders write to the same
        image files.
    """
    # Saving an image does not create missing directories, so make sure the
    # target folder exists before the first page is written.
    os.makedirs(output_dir, exist_ok=True)

    pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]

    for page_index, page_image in enumerate(convert_from_path(pdf_path)):
        image_path = os.path.join(output_dir, f"{pdf_name}-page{page_index}.jpg")
        page_image.save(image_path, 'JPEG')

# Convert each discovered PDF in turn, printing its path before the
# conversion starts so progress is visible in the cell output.
for pdf_path in pdf_paths:
    print(pdf_path)
    convert_and_save_pdf_images(pdf_path)

In [None]:
# Query Google Vertex AI to parse image for text
# Make text into a structured json dataset for LoRA training

from dotenv import load_dotenv
import json
import os
import io
from PIL import Image
import time
import base64
import vertexai
from vertexai.generative_models import HarmBlockThreshold, HarmCategory, GenerativeModel, Part, FinishReason, SafetySetting
import vertexai.preview.generative_models as generative_models
from google.oauth2 import service_account

# Read the .env file so the os.environ lookups below can find their values.
load_dotenv()

# Initialise the Vertex AI SDK with the project name and the credentials
# loaded from the service-account file referenced in the environment.
vertexai.init(
    project=os.environ["GOOGLE_PROJECT"],
    credentials=service_account.Credentials.from_service_account_file(
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
    ),
)

# Model that receives every page image.
model = GenerativeModel("gemini-1.5-flash-001")

# One safety setting per harm category, all with the BLOCK_ONLY_HIGH threshold.
safety_config = [
    SafetySetting(
        category=harm_category,
        threshold=HarmBlockThreshold.BLOCK_ONLY_HIGH,
    )
    for harm_category in (
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
    )
]

# Generation parameters sent with every request.
# NOTE(review): temperature 1 looks high for a transcription task - confirm intended.
generation_config = dict(
    max_output_tokens=8192,
    temperature=1,
    top_p=0.95,
)

# Instruction sent alongside each image; the sentences are joined with single
# spaces. '{unique_table_title}' is sent to the model literally: this is a
# plain string, not an f-string.
prompt_text = " ".join((
    "Return all the text in this image.",
    "Remove any linebreaks that occur in the middle of a sentence.",
    "Additionally, return any text or data presented in tables or charts in the form of "
    "comma-separated values starting with the text 'Table {unique_table_title}:'.",
    "Additionally, describe any non-text pictures or diagrams starting with the text "
    "'Diagram Description:'.",
))

# Accumulates one {"image_path", "text"} record per successfully parsed image.
pdf_text_data = []

def get_image_text(image_path, count=0):
    """Ask the generative model for the text of one image and record it.

    The image is re-encoded as JPEG and sent together with ``prompt_text``.
    The streamed reply chunks are joined into a single string. On success a
    ``{"image_path": ..., "text": ...}`` record is appended to the
    module-level ``pdf_text_data`` list.

    Args:
        image_path: Path of the image file to send.
        count: Number of attempts already made. A failed attempt is retried
            after a 5 second pause while this number is below 1, so the
            default allows two attempts in total.

    Returns:
        The extracted text, or ``None`` if every attempt failed. A failure
        is reported on stdout and nothing is appended to ``pdf_text_data``.
    """
    # Encode once; the same payload is reused if the request is retried.
    # The context manager closes the image file as soon as it is encoded.
    with Image.open(image_path) as pil_image:
        img_byte_arr = io.BytesIO()
        pil_image.save(img_byte_arr, format='jpeg')

    blob = Part.from_data(
        mime_type='image/jpeg',
        data=img_byte_arr.getvalue()
    )

    attempt = count
    while True:
        try:
            # Issuing the request and consuming the stream both sit inside
            # the try so a failure at either stage goes through the retry.
            responses = model.generate_content(
                [prompt_text, blob],
                generation_config=generation_config,
                safety_settings=safety_config,
                stream=True,
            )
            text = "".join(response.text for response in responses)
        except Exception as exc:
            # `Exception` rather than a bare except, so Ctrl-C
            # (KeyboardInterrupt) still stops a long-running extraction.
            print("Couldn't get text from response for " + image_path)
            print(f"    {type(exc).__name__}: {exc}")
            if attempt >= 1:
                return None
            # Only pause when another attempt will actually follow.
            time.sleep(5)
            attempt += 1
        else:
            pdf_text_data.append({
                "image_path": image_path,
                "text": text
            })
            return text

def get_image_paths(input_dir):
    """Yield the path of each ``.jpg`` file under ``input_dir``, including subfolders.

    The extension test ignores case. Files with any other extension
    (``.jpeg`` included) are skipped.
    """
    for current_dir, _, file_names in os.walk(input_dir):
        for file_name in file_names:
            if not file_name.lower().endswith(".jpg"):
                continue
            yield os.path.join(current_dir, file_name)

# open(..., 'w') does not create missing directories, so make sure the output
# folder exists before any API call is spent.
os.makedirs('./data', exist_ok=True)

for image_path in get_image_paths("./pdf_images"):
    time.sleep(2)  # pause 2s between requests to avoid going over the API quota
    print(image_path)
    get_image_text(image_path)
    # The whole dataset is rewritten after every image, so the file always
    # reflects the pages processed so far.
    with open('./data/data.json', 'w') as f:
        json.dump(pdf_text_data, f)

In [None]:
# Train LoRA for use with unsloth/Mistral-Small-Instruct-2409-bnb-4bit on the pdf text dataset
# Using unsloth FastLanguageModel based on the transformers package

from unsloth import FastLanguageModel 
from unsloth import is_bfloat16_supported
import torch
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset

from huggingface_hub import login

# Authenticate with the Hugging Face Hub using the token from the environment.
token = os.environ["HUGGINGFACE_TOKEN"]

login(
  token=token,
)

# Maximum sequence length used for model loading, the LoRA adapter and the trainer.
# NOTE(review): 2409 equals the release tag in the model name
# ("Mistral-Small-Instruct-2409") - confirm it is the intended context length.
max_seq_length = 2409

data_file_name = input("Path to the JSON training dataset: ")

dataset = load_dataset("json", data_files=data_file_name, split="train")

# Pick the column that holds the training text. The extraction step in this
# file writes records with a "text" key, while "output" was the field name
# originally configured here; accept either, preferring "output".
text_field = next(
    (name for name in ("output", "text") if name in dataset.column_names),
    None,
)
if text_field is None:
    raise ValueError(
        "Dataset needs an 'output' or 'text' column, found: "
        + ", ".join(dataset.column_names)
    )

# NOTE: this rebinds the module-level name `model`, which get_image_text()
# also reads; re-run the Vertex AI setup before extracting text again.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Mistral-Small-Instruct-2409-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = None,
    load_in_4bit = True,
)

# Wrap the base model with LoRA adapters on the listed projection modules.
model = FastLanguageModel.get_peft_model(
    model,
    r = 256,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 512,
    lora_dropout = 0, # 0 is optimized
    bias = "none",    # "none" is optimized
    use_gradient_checkpointing = "unsloth", # "unsloth" for very long context
    random_state = 2222,
    max_seq_length = max_seq_length,
    use_rslora = False,
    loftq_config = None,
)

trainer = SFTTrainer(
    model = model,
    train_dataset = dataset,
    dataset_text_field = text_field,
    max_seq_length = max_seq_length,
    tokenizer = tokenizer,
    packing = False,
    args = TrainingArguments(
        per_device_train_batch_size = 16,
        gradient_accumulation_steps = 4,
        warmup_steps = 100,
        num_train_epochs=3,
        # Exactly one of fp16/bf16 is enabled, depending on bf16 support.
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 100,
        save_steps = 1000,
        learning_rate=2e-4,
        output_dir = "outputs",
        optim = "adamw_8bit",
        seed = 2222,
        report_to="tensorboard",
    ),
)
# Always start a fresh run instead of resuming from a saved checkpoint.
trainer.train(resume_from_checkpoint=False)


# save model
trainer.save_model()