## Libraries

In [None]:
!pip install pymupdf PyMuPDF pdfplumber transformers torch nltk indic_transliteration pillow
# import nltk
# nltk.download('all')

## Main

In [None]:
import fitz  # PyMuPDF
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import re
import time

# Configuration
# Hugging Face checkpoint id passed to both from_pretrained() calls below.
MODEL_NAME = "facebook/nllb-200-3.3B"
# Use the GPU when torch reports one is available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Supported target languages, keyed by the display name the user types.
#   "code": language token that translate_text() converts to forced_bos_token_id
#   "iso":  value rebuild_pdf() writes into the HTML lang attribute
LANGUAGES = {
    "Hindi": {"code": "hin_Deva", "iso": "hi"},
    "Tamil": {"code": "tam_Taml", "iso": "ta"},
    "Telugu": {"code": "tel_Telu", "iso": "te"}
}

# Initialize NLLB model
# NOTE: this runs when the cell/module is executed, not lazily; `tokenizer`
# and `model` are module-level globals read by translate_text().
print("🔄 Initializing translation model...")
start = time.time()
# The tokenizer is configured with src_lang="eng_Latn", i.e. the input text is
# tagged as English.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, src_lang="eng_Latn")
# float16 weights on GPU, float32 on CPU; .eval() switches the module to
# evaluation mode after it has been moved to DEVICE.
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
).to(DEVICE).eval()
print(f"✅ Model loaded in {time.time()-start:.2f}s")

def parse_user_entities(user_input):
    """Parse a comma-separated entity list into a longest-first match order.

    Args:
        user_input: Raw string such as ``"New York, Alice"``. Items are
            stripped of surrounding whitespace and blank items are dropped.

    Returns:
        A list of unique entities sorted by descending length, so that a
        longer entity is replaced before any shorter entity it contains.
        Entities of equal length keep the order in which they were first
        typed.
    """
    entities = [e.strip() for e in user_input.split(',') if e.strip()]
    print(f"📌 Entities to preserve: {', '.join(entities) if entities else 'None'}")
    # dict.fromkeys() de-duplicates while preserving first-seen order. A set
    # would make the relative order of equal-length entities depend on string
    # hashing, which varies between interpreter runs.
    unique_entities = list(dict.fromkeys(entities))
    # sorted() is stable (also with reverse=True), so ties keep input order.
    return sorted(unique_entities, key=len, reverse=True)

def parse_user_languages(user_input):
    """Parse a comma-separated language list against ``LANGUAGES``.

    Each item is stripped and capitalized (``"hindi"`` -> ``"Hindi"``) before
    being looked up in ``LANGUAGES``; unknown names are ignored.

    Args:
        user_input: Raw string such as ``"hindi, Tamil"``.

    Returns:
        The recognised language names in the order first typed, without
        repeats. If none are recognised, every key of ``LANGUAGES``.
    """
    selected = (lang.strip().capitalize() for lang in user_input.split(','))
    # De-duplicate (order-preserving) so a language typed twice is not
    # translated twice and its output file rewritten.
    valid = list(dict.fromkeys(lang for lang in selected if lang in LANGUAGES))
    if not valid:
        print("⚠️ No valid languages selected. Using all available.")
        return list(LANGUAGES.keys())
    print(f"🌍 Selected languages: {', '.join(valid)}")
    return valid

def replace_with_placeholders(text, entities):
    """Swap protected entities in *text* for ``__ENTnnn__`` placeholders.

    Entities are processed in the order given (callers pass them longest
    first). Matching is case-insensitive and restricted to whole tokens: the
    match must not be directly preceded or followed by a word character.

    Args:
        text: Source text to protect.
        entities: Iterable of entity strings; empty strings are skipped.

    Returns:
        A ``(modified_text, placeholder_map)`` tuple where ``placeholder_map``
        maps each placeholder to the exact substring it replaced (original
        casing preserved). Every occurrence gets its own placeholder.
    """
    placeholder_map = {}

    def replacer(match):
        # Number placeholders by how many have been issued so far, so each
        # occurrence is unique across all entities.
        original = match.group()
        placeholder = f"__ENT{len(placeholder_map):03d}__"
        placeholder_map[placeholder] = original
        return placeholder

    modified_text = text
    for entity in entities:
        if not entity:
            # An empty pattern would match at arbitrary positions.
            continue
        # Lookarounds instead of \b: \b requires a word character on one side,
        # so it can never match entities that start or end with punctuation
        # (e.g. "C++", ".NET"). For ordinary words the two forms are equivalent.
        pattern = re.compile(r'(?<!\w)' + re.escape(entity) + r'(?!\w)', re.IGNORECASE)
        modified_text, count = pattern.subn(replacer, modified_text)
        if count > 0:
            print(f"🔧 Replaced '{entity}' {count} time(s) in text: '{text}'")
    print(f"🔍 Modified text with placeholders: '{modified_text}'")
    return modified_text, placeholder_map

def translate_text(text, target_lang_code):
    """Translate *text* with the module-level NLLB ``model``/``tokenizer``.

    Args:
        text: Source text. Blank or whitespace-only input is returned as is
            without calling the model.
        target_lang_code: Language token (e.g. ``"hin_Deva"``) whose token id
            is forced as the first generated token.

    Returns:
        The decoded first generated sequence with special tokens removed.
        Input is truncated to 256 tokens and output is capped at length 256.
    """
    if text.strip() == "":
        return text

    encoded = tokenizer(
        [text],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256,
    ).to(DEVICE)
    generation_options = {
        "forced_bos_token_id": tokenizer.convert_tokens_to_ids(target_lang_code),
        "max_length": 256,
        "num_beams": 5,
        "early_stopping": True,
    }
    with torch.inference_mode():
        generated = model.generate(**encoded, **generation_options)

    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

def translate_with_entities(text, entities, target_lang):
    """Translate *text* while keeping the listed entities untranslated.

    The entities are masked with placeholders, the masked text is translated,
    and each placeholder still present in the output is swapped back for the
    original substring. A warning is printed for any placeholder that is not
    found in the translated output.

    Args:
        text: Source text.
        entities: Entity strings to protect.
        target_lang: A key of ``LANGUAGES`` (raises ``KeyError`` otherwise).

    Returns:
        The translated text with surviving placeholders restored.
    """
    lang_code = LANGUAGES[target_lang]["code"]
    masked_text, restore_map = replace_with_placeholders(text, entities)
    result = translate_text(masked_text, lang_code)
    print(f"✅ Translated: '{masked_text}' -> '{result}'")

    for token, original in restore_map.items():
        if token not in result:
            print(f"⚠️ Placeholder '{token}' not found in: '{result}'")
            continue
        result = result.replace(token, original)
        print(f"🔄 Restored '{token}' to '{original}' in text: '{result}'")

    return result

def join_spans(spans):
    """Concatenate the text of a line's spans in left-to-right order.

    Spans are ordered by the left edge of their bounding box. Between two
    neighbouring non-empty spans a single space is inserted unless they
    overlap or the horizontal gap is smaller than half of the narrower
    average character width of the two. A span with empty text never
    contributes a separator on either side.

    Args:
        spans: Sequence of dicts with ``"text"`` and ``"bbox"``
            (``x0, y0, x1, y1``) entries.

    Returns:
        The joined text, or ``""`` when *spans* is empty.
    """
    if not spans:
        return ""

    ordered = sorted(spans, key=lambda s: s["bbox"][0])
    pieces = [ordered[0]["text"]]

    for left, right in zip(ordered, ordered[1:]):
        left_text = left["text"]
        right_text = right["text"]
        gap = right["bbox"][0] - left["bbox"][2]

        separator = ""
        if len(left_text) > 0 and len(right_text) > 0:
            left_char_width = (left["bbox"][2] - left["bbox"][0]) / len(left_text)
            right_char_width = (right["bbox"][2] - right["bbox"][0]) / len(right_text)
            narrowest = min(left_char_width, right_char_width)
            adjacent = gap < 0.5 * narrowest or gap < 0
            if not adjacent:
                separator = " "

        pieces.append(separator + right_text)

    return "".join(pieces)

def extract_pdf_components(pdf_path):
    """Collect the text layout of every page of a PDF.

    Args:
        pdf_path: Path of the PDF to open with ``fitz.open``.

    Returns:
        One dict per page with:
          * ``"page_num"``: zero-based page index,
          * ``"size"``: ``(width, height)`` of the page rectangle,
          * ``"text_blocks"``: list of ``{"bbox", "lines"}`` dicts for text
            blocks (block type 0) that contain at least one non-empty line.
            Each line is ``{"text", "y_pos", "font_size", "line_bbox"}`` where
            ``y_pos`` and ``font_size`` come from the line's first span.
    """
    print(f"\n📄 Extracting components from {pdf_path}...")
    components = []
    # Context manager so the document is closed even if extraction raises.
    with fitz.open(pdf_path) as doc:
        for page_num, page in enumerate(doc):
            print(f"\n📖 Processing page {page_num+1}")
            text_blocks = []
            for b in page.get_text("dict")["blocks"]:
                if b["type"] != 0:  # Only text blocks carry "lines"
                    continue
                text_lines = []
                for line in b["lines"]:
                    if not line["spans"]:
                        continue
                    first_span = line["spans"][0]
                    text_lines.append({
                        "text": join_spans(line["spans"]),
                        "y_pos": first_span["origin"][1],
                        "font_size": first_span["size"],
                        "line_bbox": line["bbox"]
                    })
                if text_lines:
                    text_blocks.append({
                        "bbox": b["bbox"],
                        "lines": text_lines
                    })
            components.append({
                "page_num": page_num,
                "text_blocks": text_blocks,
                "size": (page.rect.width, page.rect.height)
            })
    return components

def split_block_into_subblocks(block):
    """Group a block's lines into sub-blocks separated by vertical gaps.

    Consecutive lines stay in the same sub-block until the distance between
    a line's ``y_pos`` and the next line's ``y_pos``, minus the line's font
    size, exceeds half of that font size.

    Args:
        block: Dict with a ``"lines"`` list; each line provides ``"text"``,
            ``"y_pos"`` and ``"font_size"``.

    Returns:
        A list of ``{"text", "lines"}`` dicts. ``"text"`` is the stripped
        line texts joined by single spaces and ``"lines"`` the original line
        dicts. Empty list when the block has no lines.
    """
    lines = block["lines"]
    if not lines:
        return []

    final_index = len(lines) - 1
    result = []
    joined_text = ""
    members = []

    for index, line in enumerate(lines):
        stripped = line["text"].strip()
        joined_text = f"{joined_text} {stripped}" if joined_text else stripped
        members.append(line)

        if index == final_index:
            # The last line always closes the sub-block in progress.
            result.append({"text": joined_text, "lines": members})
            continue

        # Split on significant vertical gaps
        following = lines[index + 1]
        gap = following["y_pos"] - line["y_pos"] - line["font_size"]
        if gap > line["font_size"] * 0.5:  # Adjust threshold as needed
            result.append({"text": joined_text, "lines": members})
            joined_text = ""
            members = []

    return result

def redistribute_translated_text(translated_text, original_lines):
    """Distribute translated words over the geometry of the original lines.

    Words are placed greedily: each original line receives as many of the
    remaining words as fit into the width of its ``"line_bbox"`` at its
    ``"font_size"``. Words that fit nowhere are appended to the last line.

    Widths are measured with the built-in ``"helv"`` font (each word plus one
    trailing space). NOTE(review): that is only a measuring font, so the fit
    is an estimate and may differ from how the text is eventually rendered.

    Args:
        translated_text: Translated text for the whole block.
        original_lines: Line dicts with ``"line_bbox"`` and ``"font_size"``.

    Returns:
        A list with exactly one string per entry of *original_lines*
        (all empty strings when the text is blank).
    """
    if not original_lines or not translated_text.strip():
        return [""] * len(original_lines)

    translated_words = translated_text.split()
    total_words = len(translated_words)
    default_font = fitz.Font("helv")

    translated_lines = []
    word_idx = 0
    for line in original_lines:
        max_width = line["line_bbox"][2] - line["line_bbox"][0]
        font_size = line["font_size"]
        line_start = word_idx
        current_width = 0

        while word_idx < total_words:
            word_width = default_font.text_length(
                translated_words[word_idx] + " ", fontsize=font_size
            )
            if current_width + word_width <= max_width:
                current_width += word_width
                word_idx += 1
            else:
                break

        # Exactly one entry per original line, so no padding is needed later.
        translated_lines.append(" ".join(translated_words[line_start:word_idx]))

    if word_idx < total_words:
        # Overflow: keep the leftover words rather than dropping them.
        remaining_text = " ".join(translated_words[word_idx:])
        if translated_lines[-1]:
            translated_lines[-1] += " " + remaining_text
        else:
            translated_lines[-1] = remaining_text

    return translated_lines

def translate_pdf_components(components, entities, target_lang):
    """Translate every text block of the extracted components in place.

    Each block is split into sub-blocks, every sub-block is translated on
    its own, and the results are joined with single spaces. The block gains
    two keys: ``"translated_text"`` and ``"original_lines"`` (the same list
    object as ``block["lines"]``).

    Args:
        components: Page dicts, each with a ``"text_blocks"`` list.
        entities: Entity strings to keep untranslated.
        target_lang: A key of ``LANGUAGES``.

    Returns:
        The same *components* object that was passed in, now annotated.
    """
    print(f"\n🚀 Starting {target_lang} translation")

    every_block = (
        block
        for page in components
        for block in page["text_blocks"]
    )
    for block in every_block:
        # Translate each sub-block independently
        pieces = [
            translate_with_entities(part["text"], entities, target_lang)
            for part in split_block_into_subblocks(block)
        ]
        # Combine translated sub-blocks with a space (for distribution)
        block["translated_text"] = " ".join(pieces)
        block["original_lines"] = block["lines"]

    return components

def rebuild_pdf(components, target_lang, output_path, original_pdf_path, use_white_background=True):
    """Write a copy of the original PDF with text blocks replaced by translations.

    For every block that has a non-blank ``"translated_text"``, the original
    block area is hidden (white rectangle or redaction) and the translated
    text, redistributed over the original line boxes, is inserted line by
    line with ``page.insert_htmlbox``. The result is saved to *output_path*
    and then reopened to print the extracted text of each page.

    Args:
        components: Page dicts as returned by ``translate_pdf_components``.
        target_lang: A key of ``LANGUAGES``; its ``"iso"`` value is used as the
            HTML ``lang`` attribute.
        output_path: Destination path of the rebuilt PDF.
        original_pdf_path: Path of the source PDF that is used as the canvas.
        use_white_background: When true, cover blocks with an opaque white
            rectangle; otherwise remove the original content via redaction.

    Returns:
        None.

    Raises:
        KeyError: If *target_lang* is not in ``LANGUAGES``.
    """
    # Function-scope import: keeps this block self-contained and avoids
    # clashing with the local HTML string below.
    from html import escape as escape_html

    print(f"\n🏗️ Rebuilding {target_lang} PDF...")
    lang_iso = LANGUAGES[target_lang]["iso"]

    # Context manager so the source document is always closed (it used to be
    # leaked when the name was rebound to the verification document).
    with fitz.open(original_pdf_path) as doc:
        for page_data in components:
            page = doc[page_data["page_num"]]
            links = list(page.get_links())

            # Pass 1: hide the original text of every translated block first.
            # Doing this before any insertion prevents a later block whose
            # bbox overlaps an earlier one from covering or redacting text
            # that was already inserted.
            pending = []
            for block in page_data["text_blocks"]:
                translated_text = block.get("translated_text", "")
                if not translated_text.strip():
                    continue
                original_bbox = fitz.Rect(block["bbox"])
                original_lines = block["original_lines"]
                translated_lines = redistribute_translated_text(translated_text, original_lines)
                pending.append((original_lines, translated_lines))

                if use_white_background:
                    page.draw_rect(original_bbox, color=(1, 1, 1), fill=(1, 1, 1), fill_opacity=1.0)
                else:
                    page.add_redact_annot(original_bbox)

            if pending and not use_white_background:
                # One call processes all redaction annotations added above.
                page.apply_redactions()

            # Pass 2: insert the translated lines into the original line boxes.
            css_template = """
p {{
    font-size: {font_size}pt;
}}
"""
            for original_lines, translated_lines in pending:
                for i, (original_line, translated_line) in enumerate(zip(original_lines, translated_lines)):
                    if not translated_line.strip():
                        continue
                    line_rect = fitz.Rect(original_line["line_bbox"])
                    font_size = original_line["font_size"]
                    # Escape &, <, > so translated text cannot break the markup.
                    safe_line = escape_html(translated_line, quote=False)
                    html_snippet = f"""
<div style="width: 100%; height: 100%; padding: 0; margin: 0;">
    <p lang="{lang_iso}" style="margin: 0; padding: 0;">{safe_line}</p>
</div>
"""
                    css = css_template.format(font_size=font_size)
                    try:
                        page.insert_htmlbox(
                            line_rect,
                            html_snippet,
                            css=css,
                            scale_low=0,
                            rotate=0,
                            oc=0,
                            opacity=1,
                            overlay=True
                        )
                        print(f"✓ Inserted line {i+1} at {line_rect.top_left}: '{translated_line[:30]}...'")
                    except Exception as e:
                        # Deliberate best effort: one failing line must not
                        # abort the rest of the document.
                        print(f"⚠️ Error inserting line {i+1} at {line_rect.top_left}: {e}")

            # NOTE(review): every link captured before editing is re-inserted
            # unconditionally; links that survived the edits may end up
            # duplicated — confirm against the produced file.
            for link in links:
                page.insert_link(link)
                print(f"🔗 Restored link to: {link.get('uri', 'unknown destination')}")

        print(f"💾 Saving to {output_path}")
        doc.save(output_path, garbage=4, deflate=True)

    print(f"\n🔍 Verifying text in {output_path}...")
    with fitz.open(output_path) as saved_doc:
        for page_num, page in enumerate(saved_doc):
            text = page.get_text("text")
            print(f"Extracted text from page {page_num+1}:\n{text}\n")

# Interactive entry point: asks for the entities to protect, the target
# languages and the background mode, then produces one translated PDF per
# selected language.
if __name__ == "__main__":
    # Hard-coded input document path.
    pdf_path = "/content/PSY-225_FiveParagraphTheme_Examples.pdf"
    print("\n" + "="*40)
    print("📝 Enter entities to preserve (comma-separated, e.g., 'Name, Place etc'):")
    entities = parse_user_entities(input().strip())
    print("\n" + "="*40)
    print("🌐 Available languages:", ", ".join(LANGUAGES.keys()))
    print("📢 Enter target languages (comma-separated):")
    languages = parse_user_languages(input().strip())
    print("\n" + "="*40)
    print("🎨 Use white background for blocks? (yes/no):")
    # Any answer other than these affirmative spellings selects redaction mode.
    use_white = input().strip().lower() in ('yes', 'y', 'true', 't', '1')

    # The layout is extracted once and reused for every language;
    # translate_pdf_components() overwrites the per-block translation in place.
    components = extract_pdf_components(pdf_path)
    for lang in languages:
        start_time = time.time()
        # translate_pdf_components() prints the same banner again.
        print(f"\n🚀 Starting {lang} translation")
        translated = translate_pdf_components(components, entities, lang)
        output_path = f"/content/translated_{lang}.pdf"
        rebuild_pdf(translated, lang, output_path, pdf_path, use_white_background=use_white)
        print(f"\n✅ {lang} translation completed in {time.time()-start_time:.2f}s")