### Packages Installation

In [1]:
!uv pip install -q addict \
    transformers==4.46.3 \
    tokenizers==0.20.3 \
    PyMuPDF \
    img2pdf \
    einops \
    easydict \
    Pillow \
    numpy \
    groq \
    streamlit \
    rank-bm25 \
    nltk \
    flash-attn==2.7.3 --no-build-isolation

print(f"ALL REQUIRED PACKAGES INSTALLED!!!")

ALL REQUIRED PACKAGES INSTALLED!!!


### Run the assistant in notebook

In [None]:
import os
import json
import re
import warnings
from datetime import datetime
from io import BytesIO
import tempfile  # For PDF processing

# Basic environment / warning control
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
warnings.filterwarnings("ignore", message=".*GetPrototype.*")

# Third-party libs
import torch
import gc
try:
    import flash_attn  # optional
except Exception:
    flash_attn = None

# Transformers / sentence-transformers
from transformers import AutoModel, AutoTokenizer, logging as hf_logging
hf_logging.set_verbosity_error()
from sentence_transformers import SentenceTransformer

# Aux libs used in CLI
import requests
from PIL import Image
import pandas as pd
import numpy as np
from tabulate import tabulate

# Optional Colab secrets and Groq client (only used if available)
try:
    from google.colab import userdata
    colab_secrets_available = True
except Exception:
    colab_secrets_available = False

try:
    from groq import Groq
except Exception:
    Groq = None

# PDF processing dependencies
try:
    import fitz  # PyMuPDF
    PDF_SUPPORT = True
except ImportError:
    PDF_SUPPORT = False
    print("‚ö†Ô∏è PyMuPDF not installed. PDF processing disabled. Install with: pip install PyMuPDF")

# Hybrid Search Imports
try:
    from rank_bm25 import BM25Okapi
    import string
    import nltk
    from nltk.corpus import stopwords
    nltk.download('stopwords', quiet=True)
    HYBRID_SEARCH_AVAILABLE = True
except ImportError:
    HYBRID_SEARCH_AVAILABLE = False
    print("‚ö†Ô∏è Hybrid search dependencies missing. Install with: pip install rank-bm25 nltk")

class BillAssistant:
    """
    Class-based bill assistant with semantic Q&A using sentence-transformer embeddings.

    Usage:
        assistant = BillAssistant(model_name='deepseek-ai/DeepSeek-OCR')
        assistant.run_cli()
    """

    def __init__(self, model_name: str = "deepseek-ai/DeepSeek-OCR", use_colab_secrets: bool = False):
        """
        Create an assistant with empty bill/search state and load its models.

        Args:
            model_name: Identifier stored on the instance and later passed to
                the tokenizer/model loaders by ``load_models``.
            use_colab_secrets: When True, and both the Colab ``userdata`` helper
                and the ``groq`` package were importable, read ``GROQ_API_KEY``
                from Colab secrets and build the Groq client from it.
        """
        self.model_name = model_name
        self.device_info = self._gather_device_info()

        # Model handles stay empty until load_models() fills them in below.
        self.model = None
        self.tokenizer = None
        self.sentence_model = None

        # Groq client (optional) and the most recently processed bill.
        self.client = None
        self.current_bill = None
        self.bill_text = None
        self.pdf_support = PDF_SUPPORT

        # Retrieval state: text chunks, their embeddings, whole-bill embedding.
        self.chunks = []
        self.chunk_embeddings = None
        self.bill_embeddings = None

        # Keyword (BM25) side of the hybrid search.
        self.bm25 = None
        self.bm25_corpus = None
        self.hybrid_search_available = HYBRID_SEARCH_AVAILABLE

        # The Groq client is created only when the caller opted in and both
        # optional imports succeeded at module load time.
        secrets_usable = use_colab_secrets and colab_secrets_available and Groq is not None
        if secrets_usable:
            try:
                api_key = userdata.get("GROQ_API_KEY")
                if api_key:
                    self.client = Groq(api_key=api_key)
            except Exception as e:
                print(f"‚ö†Ô∏è Failed to init Colab secrets / Groq client: {e}")

        # Models are loaded eagerly as the last construction step.
        self.load_models()

    # ------ Utility / device info ------
    def _gather_device_info(self):
        try:
            cuda_available = torch.cuda.is_available()
            if cuda_available:
                try:
                    device_name = torch.cuda.get_device_name(0)
                except Exception:
                    device_name = "Unknown CUDA device"
                try:
                    compute_cap = torch.cuda.get_device_capability(0)
                except Exception:
                    compute_cap = ("N/A",)
            else:
                device_name = "CPU"
                compute_cap = ("N/A",)
            return {
                "torch_version": torch.__version__,
                "cuda_available": cuda_available,
                "device_name": device_name,
                "compute_capability": compute_cap,
                "flash_attn": getattr(flash_attn, "__version__", None) if flash_attn else None
            }
        except Exception as e:
            return {"error": str(e)}

    def print_device_info(self):
        info = self.device_info
        print(f"PyTorch version: {info.get('torch_version')}")
        print(f"CUDA available: {info.get('cuda_available')}")
        print(f"GPU: {info.get('device_name')}")
        print(f"Compute capability: {info.get('compute_capability')}")
        if info.get("flash_attn"):
            print(f"‚úì Flash Attention version: {info.get('flash_attn')}")
        else:
            print("‚úó Flash Attention not installed or not available")

    # ------ Model loading ------
    def load_models(self):
        "Load tokenizer, model and sentence-transformer used for embeddings."
        print("Loading Assistant Eye...", end='\t')
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
            self.model = AutoModel.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                use_safetensors=True
            )
            self.model = self.model.eval()
            print("‚úÖ ASSISTANT EYE LOADED SUCCESSFULLY!!!")
        except Exception as e:
            print(f"‚ö†Ô∏è Failed to load model/tokenizer: {e}")
            self.model = None
            self.tokenizer = None

        try:
            print("Loading Assistant Brain...", end='\t')
            self.sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
            print("‚úÖ ASSISTANT BRAIN LOADED SUCCESSFULLY!!!")
        except Exception as e:
            print(f"‚ö†Ô∏è Failed to load sentence-transformer: {e}")
            self.sentence_model = None

    # ------ Core OCR / inference run ------
    def model_run(self, prompt: str, image_file: str):
        """
        Run the OCR/inference model.
        - If the real model is available, call model.infer(...) as per original script.
        - If not available (or for debugging), returns a hardcoded sample result.
        """
        output_path = f"/content/outputs/{os.path.splitext(os.path.basename(image_file))[0]}"
        os.makedirs(output_path, exist_ok=True)

        if self.model is None or self.tokenizer is None:
            # fallback -- return debug sample text (same as your test)
            print("‚ö†Ô∏è Model/tokenizer unavailable.")
            return ""

        # If real model exists, call its inference method (kept as in original code)
        try:
            torch.cuda.empty_cache()
            res = self.model.infer(
                self.tokenizer,
                prompt=prompt,
                image_file=image_file,
                output_path=output_path,
                base_size=1536,
                image_size=1024,
                crop_mode=False,
                save_results=True,
                test_compress=False,
                eval_mode=True,  # return instead of printing
            )
            print("Extractin complete.")
            return res
        except Exception as e:
            print(f"‚ö†Ô∏è Model inference failed: {e}")
            return ""

    def chunk_text(self, text: str, chunk_size: int = 400, overlap: int = 50):
        """
        Split text into overlapping chunks (approx. chunk_size tokens/characters).
        This uses naive character-based splitting for simplicity. For production,
        use token-based splitting (e.g., tiktoken) to respect model token counts.
        """
        if not text:
            return []
        text = text.strip()
        chunks = []
        start = 0
        length = len(text)
        while start < length:
            end = start + chunk_size
            chunk = text[start:end].strip()
            if chunk:
                chunks.append(chunk)
            start = end - overlap  # overlap
        return chunks

    def compute_chunk_embeddings(self):
        """
        Compute embeddings for each chunk and also store embedding for whole bill.
        """
        if self.sentence_model is None:
            print("‚ö†Ô∏è Sentence model not available; cannot compute embeddings.")
            self.chunk_embeddings = None
            self.bill_embeddings = None
            self.bm25 = None
            self.bm25_corpus = None
            return

        if not self.chunks:
            self.chunk_embeddings = None
            self.bill_embeddings = None
            self.bm25 = None
            self.bm25_corpus = None
            return

        print("üß† Computing semantic embeddings...")
        emb = self.sentence_model.encode(self.chunks, convert_to_numpy=True)
        # Normalize embeddings (helps cosine similarity)
        norms = np.linalg.norm(emb, axis=1, keepdims=True)
        norms[norms == 0] = 1e-10
        emb_norm = emb / norms
        self.chunk_embeddings = emb_norm  # shape (n_chunks, d)

        # whole-bill embedding
        whole_emb = self.sentence_model.encode([self.bill_text], convert_to_numpy=True)
        whole_emb /= (np.linalg.norm(whole_emb, axis=1, keepdims=True) + 1e-10)
        self.bill_embeddings = whole_emb[0]

       # Initialize BM25 for hybrid search if available
        if self.hybrid_search_available:
            print("üìö Building BM25 index for hybrid search...")
            try:
                self.bm25_corpus = [self._preprocess_text(chunk) for chunk in self.chunks]
                self.bm25 = BM25Okapi(self.bm25_corpus)
                print(f"‚úÖ BM25 index built with {len(self.chunks)} chunks!")
            except Exception as e:
                print(f"‚ö†Ô∏è Failed to build BM25 index: {e}")
                self.bm25 = None
                self.bm25_corpus = None
        else:
            self.bm25 = None
            self.bm25_corpus = None

    # ------ High-level processing ------
    def process_bill(self, image_path: str, prompt: str = "<image>\nStrict OCR. Extract all the text in the image as Markdown."):
        """
        Process a bill image or PDF: run OCR, parse to structured JSON, and compute embeddings.
        Returns a status message (string).
        """
        if not image_path:
            return "‚ùå No image path provided."

        # If URL, download
        if image_path.startswith("http"):
            try:
                response = requests.get(image_path)
                img = Image.open(BytesIO(response.content)).convert("RGB")
                tmp_path = "/content/tmp/bill_download"
                _, ext = os.path.splitext(image_path)
                tmp_path += ext.lower() if ext else ".jpg"
                img.save(tmp_path)
                image_path = tmp_path
            except Exception as e:
                return f"‚ùå Failed to download image: {e}"

        if not os.path.exists(image_path):
            return "‚ùå File not found!"

        # Handle PDF files
        if image_path.lower().endswith('.pdf'):
            if not self.pdf_support:
                return "‚ùå PDF processing not available. Install PyMuPDF with: pip install PyMuPDF"

            print("üìÑ Processing PDF file (converting pages to images)...")
            temp_dir = tempfile.mkdtemp()
            image_paths = []
            try:
                # Convert PDF to images
                doc = fitz.open(image_path)
                for page_num in range(doc.page_count):
                    page = doc[page_num]
                    pix = page.get_pixmap(dpi=150)  # 150 DPI for good quality
                    img_path = os.path.join(temp_dir, f"page_{page_num+1}.png")
                    pix.save(img_path)
                    image_paths.append(img_path)
                doc.close()

                # Process each page
                bill_texts = []
                for i, img_path in enumerate(image_paths):
                    print(f"üìù Extracting text from page {i+1}/{len(image_paths)}...")
                    page_text = self.model_run(prompt, img_path)
                    if not page_text.strip():
                        page_text = "[No text extracted from this page]"
                    bill_texts.append(f"--- Page {i+1} ---\n{page_text}")

                bill_text = "\n\n".join(bill_texts)
            except Exception as e:
                return f"‚ùå PDF processing failed: {e}"
            finally:
                # Clean up temporary images
                for img_path in image_paths:
                    try:
                        os.remove(img_path)
                    except:
                        pass
                try:
                    os.rmdir(temp_dir)
                except:
                    pass
        else:
            # Process single image
            print("üìù Extracting text...")
            bill_text = self.model_run(prompt, image_path)

        if not bill_text or not bill_text.strip():
            return "‚ùå No text extracted!"

        print("üß† Parsing with AI...")
        parsed_data = self.parse_bill_with_llm(bill_text)
        if not parsed_data:
            return "‚ùå Parsing failed!"

        self.current_bill = parsed_data
        self.bill_text = bill_text

        # chunk & compute embeddings (semantic Q&A)
        print("üîç Creating text chunks...")
        self.chunks = self.chunk_text(bill_text, chunk_size=400, overlap=50)
        if not self.chunks:
            self.chunks = [bill_text]

        print(f"üî¢ {len(self.chunks)} chunks created. Computing embeddings...")
        self.compute_chunk_embeddings()

        return "‚úÖ Bill processed successfully!"

    def _preprocess_text(self, text: str):
        """Preprocess text for BM25: lowercase, remove punctuation, remove stopwords."""
        text = text.lower()
        text = text.translate(str.maketrans('', '', string.punctuation))
        tokens = text.split()
        stop_words = set(stopwords.words('english'))
        tokens = [token for token in tokens if token not in stop_words and len(token) > 2]
        return tokens

    def _cosine_sim(self, a: np.ndarray, b: np.ndarray):
        "Compute cosine similarity between 1D a and 2D b (b is list of vectors)."
        if a.ndim == 1:
            a = a.reshape(1, -1)
        a_norm = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-10)
        b_norm = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-10)
        return np.dot(a_norm, b_norm.T).squeeze(0)  # shape (n_b,)

    def semantic_search(self, query: str, top_k: int = 3, alpha: float = 0.7):
        """
        Hybrid search combining BM25 (keyword) and semantic (embedding) scores.
        alpha: weight for semantic similarity (0.0 = BM25 only, 1.0 = semantic only)
        Returns list of tuples: (chunk_text, combined_score, chunk_index)
        """
        if not self.chunks:
            return []

        # Always compute semantic scores if available
        semantic_scores = np.zeros(len(self.chunks))
        if self.chunk_embeddings is not None and self.sentence_model is not None:
            try:
                q_emb = self.sentence_model.encode([query], convert_to_numpy=True)[0]
                q_emb = q_emb / (np.linalg.norm(q_emb) + 1e-10)
                semantic_scores = self._cosine_sim(q_emb, self.chunk_embeddings)
            except Exception as e:
                print(f"‚ö†Ô∏è Semantic search failed: {e}")

        # Compute BM25 scores if available
        bm25_scores = np.zeros(len(self.chunks))
        if self.bm25 is not None:
            try:
                query_tokens = self._preprocess_text(query)
                bm25_scores = self.bm25.get_scores(query_tokens)
            except Exception as e:
                print(f"‚ö†Ô∏è BM25 search failed: {e}")

        # Normalize scores to [0, 1] range
        def normalize(scores):
            min_score = np.min(scores)
            max_score = np.max(scores)
            if max_score - min_score < 1e-10:  # Avoid division by zero
                return scores
            return (scores - min_score) / (max_score - min_score + 1e-10)

        norm_semantic = normalize(semantic_scores)
        norm_bm25 = normalize(bm25_scores)

        # Combine scores with weight alpha
        combined_scores = alpha * norm_semantic + (1 - alpha) * norm_bm25

        # Get top indices (descending order)
        top_idx = np.argsort(-combined_scores)[:top_k]
        results = [(self.chunks[i], float(combined_scores[i]), int(i)) for i in top_idx]
        return results


    def _clean_json(self, text: str) -> str:
      if text is None:
          return ""
      text = text.replace('\\r\\n', '\n').replace('\\n', '\n').strip()
      text = re.sub(r"^```(?:json)?\s*\n?", "", text.strip(), flags=re.IGNORECASE)
      text = re.sub(r"\n?```$", "", text, flags=re.IGNORECASE)
      return text.strip()

    def _basic_repair(self, text: str) -> str:
      """
      Apply a fixed sequence of regex rewrites meant to turn almost-JSON into JSON.

      The substitutions run in the order written and each one sees the output
      of the previous one, so the order is significant. The result is stripped
      of leading/trailing whitespace. Brackets are never added here.

      NOTE(review): this pass is lossy - it drops every character outside
      printable ASCII (except newline) and rewrites any pair of single quotes,
      including apostrophes inside double-quoted values. Callers should only
      use it on text that has already failed to parse.
      """
      t = text
      # common fixes (same idea as earlier)
      # quote + whitespace + quote  ->  a single quote
      t = re.sub(r'"\s+"', r'"', t)
      # quote + whitespace + two quotes  ->  two quotes
      t = re.sub(r'"\s+""', r'""', t)
      # value opened with two quotes after a colon  ->  one opening quote
      t = re.sub(r':\s*"\s*"\s*([^"\n\r]+)"', r': "\1"', t)
      # normalise `:` + optional whitespace + quoted single-line value to `: "value"`
      t = re.sub(r':\s*"\s*([^"\n\r]+)"', r': "\1"', t)  # keep trying to remove stray quotes
      # same, anchored on the key's closing quote, for values with leading whitespace
      t = re.sub(r'":\s*"\s+([^"\n\r]+)"', r'": "\1"', t)
      # any '...' run becomes "..."
      t = re.sub(r"\'([^\']*)\'", r'"\1"', t)
      # remove a comma that directly precedes } or ]
      t = re.sub(r',\s*(\}|\])', r'\1', t)
      # quote, `}` and `{` separated only by whitespace: insert the missing comma
      t = re.sub(r'"\s*\}\s*\s*\{', r'"\n},\n{', t)
      # keep only newline and characters with code points 32..126
      t = ''.join(ch for ch in t if ch == '\n' or (31 < ord(ch) < 127))
      # wrap a bare key at the start of a line in double quotes
      t = re.sub(r'(?m)^(\s*)([A-Za-z0-9_\-]+)\s*:', r'\1"\2":', t)
      # trim whitespace padding just inside a pair of quotes
      t = re.sub(r'"\s+([^"]+?)\s+"', lambda m: f'"{m.group(1).strip()}"', t)
      # don't auto-append braces/brackets here - leave to the more aggressive routine
      return t.strip()

    def _balance_closers(self, s: str) -> str:
      # Add minimal closers to match opens
      open_braces = s.count('{')
      close_braces = s.count('}')
      open_brackets = s.count('[')
      close_brackets = s.count(']')
      if open_braces > close_braces:
          s = s + ('}' * (open_braces - close_braces))
      if open_brackets > close_brackets:
          s = s + (']' * (open_brackets - close_brackets))
      return s

    def safe_load_json_recover(self, raw_text: str, debug: bool = False, max_trim_chars: int = 3000):
      """
      Attempt to clean/repair LLM-produced JSON and recover a Python object.
      Strategy:
        1) Clean and do basic regex repairs.
        2) Try json.loads.
        3) If it fails, attempt balancing quotes/braces/brackets and try again.
        4) If still fails, progressively trim from the end (up to `max_trim_chars`) and try balancing + loads.
      Returns: Python object (dict/list) or raises ValueError with repaired snippet for inspection.
      """
      cleaned = self._clean_json(raw_text)
      repaired = self._basic_repair(cleaned)

      if debug:
          print("=== Initial repaired ===")
          print(repaired)
          print("=== Trying json.loads ===")

      # 1) Try direct load after basic repair
      try:
          return json.loads(repaired)
      except json.JSONDecodeError as e:
          pass

      # 2) If odd number of double-quotes, try closing the last quote
      if repaired.count('"') % 2 == 1:
          cand = repaired + '"'
          try:
              return json.loads(self._balance_closers(cand))
          except Exception:
              # continue to trimming attempts
              pass

      # 3) Try adding minimal closers and reloading
      cand = self._balance_closers(repaired)
      try:
          return json.loads(cand)
      except Exception:
          pass

      # 4) Progressive trimming: remove trailing characters (one by one or in small chunks)
      #    and try to parse the prefix + balanced closers.
      L = len(repaired)
      # we'll try trimming up to max_trim_chars, in steps (bigger steps at first)
      step = 1
      trimmed = None
      for trim in range(0, min(max_trim_chars, L), step):
          if trim == 0:
              candidate = repaired
          else:
              candidate = repaired[:L - trim]

          # If candidate ends in a partial token like ' "qty": 3, "amount":', remove a trailing incomplete token:
          # remove trailing sequences after the last '}' or ']' if they exist
          last_close = max(candidate.rfind('}'), candidate.rfind(']'))
          if last_close != -1 and last_close > len(candidate) - 200:
              # keep up to last_close (close object/array)
              candidate = candidate[:last_close+1]

          # Close any open quotes (best-effort)
          if candidate.count('"') % 2 == 1:
              candidate = candidate + '"'

          # Balance braces/brackets
          candidate = self._balance_closers(candidate)

          try:
              parsed = json.loads(candidate)
              if debug:
                  print(f"Recovered by trimming {trim} chars.")
              return parsed
          except Exception:
              # increase step size after initial few tries to speed up
              if trim < 50:
                  step = 1
              elif trim < 200:
                  step = 5
              else:
                  step = 20
              continue

      # If all attempts fail, raise with helpful diagnostic including the best-effort repaired text
      # Provide truncated snippet to avoid enormous message
      best_snippet = repaired[:4000] + ("... (truncated)" if len(repaired) > 4000 else "")
      msg = (
          "Failed to parse JSON after attempted repairs and progressive trimming.\n\n"
          "A best-effort repaired snippet is shown below (inspect to decide next action):\n\n"
          f"{best_snippet}\n\n"
          "Common fixes: ask the LLM to re-output only the JSON inside a single ```json``` codeblock, "
          "increase the model `max_tokens` for longer outputs, or detect why the output was truncated."
      )
      raise ValueError(msg)

    # ------ Parsing (LLM or hardcoded) ------
    def parse_bill_with_llm(self, text: str):
        "Parse bill text into structured JSON."
        # If you want to call Groq LLM, you might do something like:
        if self.client is not None:
            prompt = f"""
                Extract structured data from this bill text as JSON:
                {text}

                Required fields:
                - invoice_number, invoice_date, due_date, po_number
                - bill_to (name, address, phone)
                - ship_to (name, address, phone)
                - items (qty, description, unit_price, amount)
                - subtotal, tax, total
                - company_name, company_address

                Return ONLY valid JSON with these fields. Use empty strings for missing data.
                """
            print(f"\n\nPROMPT:\n{prompt}\n\n")
            response = self.client.chat.completions.create(
                        messages=[{"role": "user", "content": prompt}],
                        model="llama-3.1-8b-instant",
                        temperature=0.1,
                        max_tokens=1024
                    )
            response_content = response.choices[0].message.content

            try:
                parsed = self.safe_load_json_recover(response_content, debug=True)
                print("Parsed OK:", parsed)
                return parsed
            except ValueError as e:
                print("Could not parse. Inspect repaired output:")
                print(str(e))
                return {}

    # ---------------- answering queries ----------------
    def answer_query(self, query: str, top_k: int = 3):
        if not self.current_bill:
            return "‚ùå Process a bill first!"

        query_lower = query.lower()
        if "total" in query_lower:
            return f"üí∞ Total: ${self.current_bill.get('total', 'N/A')}"
        if "invoice number" in query_lower or "invoice #" in query_lower or "invoice" == query_lower.strip():
            return f"üìã Invoice: {self.current_bill.get('invoice_number', 'N/A')}"
        if "date" in query_lower:
            return f"üìÖ Date: {self.current_bill.get('invoice_date', 'N/A')}"
        if "items" in query_lower:
            items = self.current_bill.get("items", [])
            if items:
                df = pd.DataFrame(items)
                return f"üõí Items:\n{tabulate(df, headers='keys', tablefmt='grid')}"
            return "‚ÑπÔ∏è No items found"

        # fallback to semantic retrieval
        retrieved = self.semantic_search(query, top_k=top_k)
        if not retrieved:
            return "‚ÑπÔ∏è No relevant content found in the bill."
        # print(f"\nRETRIEVED DATA\n{retrieved}")

        if self.client is not None:
            context_text = "\n\n---\n\n".join([f"Chunk {idx} (score {score:.4f}):\n{text}" for text, score, idx in retrieved])
            prompt = (
                f"Use ONLY the context below (do NOT hallucinate). "
                f"Extract an answer to the question. If information is not present, say 'Not found'.\n\n"
                f"CONTEXT:\n{context_text}\n\n"
                f"QUESTION: {query}\n\n"
                f"Answer concisely based ONLY on the context above:"
            )
            try:
                response = self.client.chat.completions.create(
                    messages=[{"role": "user", "content": prompt}],
                    model="llama-3.1-8b-instant",
                    temperature=0.0,
                    max_tokens=256
                )
                return response.choices[0].message.content
            except Exception as e:
                # fallback to returning chunks
                return f"‚ö†Ô∏è LLM query failed: {e}\n\nTop relevant text:\n\n" + "\n\n---\n\n".join([t for t, s, i in retrieved])

        # If no LLM, return the top chunks concatenated with scores
        best_text = "\n\n---\n\n".join([f"(score {s:.4f})\n{t}" for t, s, i in retrieved])
        return f"üîé Top relevant bill text (best {len(retrieved)} chunks):\n\n{best_text}"

    # ------ Export ------
    def export_to_json(self, filename: str = None):
        "Export the current bill to a JSON file."
        if not self.current_bill:
            return "‚ùå No bill data available!"

        if not filename:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"bill_{timestamp}.json"

        try:
            with open(filename, "w") as f:
                json.dump(self.current_bill, f, indent=2)
        except Exception as e:
            return f"‚ùå Failed to export: {e}"
        return f"‚úÖ Exported to {filename}"

    # ------ Chat ------
    def chat(self, message: str):
        "Simple chat handler. Uses Groq if available for richer replies."
        msg_lower = message.lower().strip()

        if msg_lower in ("hi", "hello", "hey"):
            return "üëã Hello! I'm your bill assistant. I can help you extract data from bills, answer questions, and export information."

        if "help" in msg_lower:
            help_text = "üí° I can:\n1. Process bill images and PDFs\n2. Answer questions about bills\n3. Export data to JSON\n4. General chat"
            if self.current_bill:
                help_text += "\n\n‚úÖ Current bill loaded! Ask about totals, dates, items, or custom questions."
            else:
                help_text += "\n\n‚ö†Ô∏è No bill processed yet. Please process a bill first to ask questions about it."
            return help_text

        # Handle queries when bill is loaded
        if self.current_bill is not None:
            return self.answer_query(message, top_k=3)

        # No bill loaded - provide helpful guidance without LLM calls
        if any(keyword in msg_lower for keyword in ["bill", "invoice", "receipt", "document", "pdf", "image", "process", "upload", "load", "scan"]):
            return "üìé Please process a bill first using option 1 in the main menu. I can handle images and PDFs!"

        if any(keyword in msg_lower for keyword in ["thank", "bye", "exit", "goodbye"]):
            return "üëã You can exit chat mode anytime by typing 'exit' or 'quit'."

        # General fallback when no bill is loaded
        return "ü§ñ I'm ready to help with bills! Please process a bill first (option 1), then ask questions like 'What's the total?' or 'Show items'. Type 'help' for options."

    def chat_loop(self):
        "Continuous chat loop until user exits."
        print("\nüí¨ Chat mode activated (type 'exit' or 'quit' to return to main menu)")
        print("ü§ñ Assistant: Hello! I'm your bill assistant. How can I help you today?")

        while True:
            message = input("You: ").strip()

            if message.lower() in ('exit', 'quit', 'bye', 'goodbye'):
                print("üëã Exiting chat mode...")
                break

            response = self.chat(message)
            print(f"ü§ñ Assistant: {response}")

    def run_cli(self):
        """Interactive command-line menu."""
        self.print_device_info()
        print("\n" + "=" * 50)
        print("BILL ASSISTANT")
        print("=" * 50)

        print("\n1. üì∏ Process bill image/PDF")
        print("2. üíæ Export to JSON")
        print("3. üí¨ Chat")
        print("4. üö™ Exit")

        while True:
            choice = input("\nSelect option (1-4): ").strip()

            if choice == "1":
                torch.cuda.empty_cache()
                gc.collect()
                img_path = input("Enter image/PDF path (or URL): ").strip()
                result = self.process_bill(img_path)
                print(result)
                if self.current_bill:
                    print("\nüìä Bill Summary:")
                    print(f"Invoice: {self.current_bill.get('invoice_number', 'N/A')}")
                    print(f"Date: {self.current_bill.get('invoice_date', 'N/A')}")
                    print(f"Total: ${self.current_bill.get('total', 'N/A')}")
                print("\n" + "=" * 50)
                print("BILL ASSISTANT")
                print("=" * 50)
                print("\n1. üì∏ Process bill image/PDF")
                print("2. üíæ Export to JSON")
                print("3. üí¨ Chat")
                print("4. üö™ Exit")

            elif choice == "2":
                filename = input("Enter filename (or press Enter for default): ").strip()
                result = self.export_to_json(filename if filename else None)
                print(result)

            elif choice == "3":
                self.chat_loop()  # Call the new continuous chat loop
                # Show menu again after exiting chat
                print("\n" + "=" * 50)
                print("BILL ASSISTANT")
                print("=" * 50)
                print("\n1. üì∏ Process bill image/PDF")
                print("2. üíæ Export to JSON")
                print("3. üí¨ Chat")
                print("4. üö™ Exit")

            elif choice == "4":
                print("üëã Goodbye!")
                break

            else:
                print("‚ùå Invalid choice!")

# Entry point: when this code runs with __name__ == "__main__", build the
# assistant with Colab secrets enabled and start the interactive CLI.
if __name__ == "__main__":
    assistant = BillAssistant(use_colab_secrets=True)
    assistant.run_cli()

Loading Assistant Eye...	‚úÖ ASSISTANT EYE LOADED SUCCESSFULLY!!!
Loading Assistant Brain...	‚úÖ ASSISTANT BRAIN LOADED SUCCESSFULLY!!!
PyTorch version: 2.8.0+cu126
CUDA available: True
GPU: Tesla T4
Compute capability: (7, 5)
‚úì Flash Attention version: 2.7.3

BILL ASSISTANT

1. üì∏ Process bill image/PDF
2. üíæ Export to JSON
3. üí¨ Chat
4. üö™ Exit

Select option (1-4): 1
Enter image/PDF path (or URL): /content/img.jpg
üìù Extracting text...




BASE:  torch.Size([1, 256, 1280])
NO PATCHES
Extractin complete.
üß† Parsing with AI...


PROMPT:

                Extract structured data from this bill text as JSON:
                East Repair Inc.  
1912 Harvest Lane  
New York, NY 12210  

Bill To  
John Smith  
2 Court Square  
New York, NY 12210  

Ship To  
John Smith  
3787 Pineview Drive  
Cambridge, MA 12210  

Invoice #  
US-001  

Invoice Date  
11/02/2019  

P.O.#  
2312/2019  

Due Date  
26/02/2019  

| QTY | DESCRIPTION | UNIT PRICE | AMOUNT |
|-----|-------------|------------|--------|
| 1   | Front and rear brake cables | 100.00 | 100.00 |
| 2   | New set of pedal arms | 15.00 | 30.00 |
| 3   | Labor 3hrs | 5.00 | 15.00 |

Subtotal  
145.00  

Sales Tax 6.25%  
9.06  

TOTAL  
$154.06  

Terms & Conditions  
Payment is due within 15 days  

Please make checks payable to: East Repair Inc.

                Required fields:
                - invoice_number, invoice_date, due_date, po_number
                - bill_to (n

KeyboardInterrupt: Interrupted by user

### Deploy the application

In [2]:
%%writefile assistant.py
import os
import json
import warnings
from datetime import datetime
from io import BytesIO
import tempfile  # For PDF processing
import gc
import re
import string

# Basic environment / warning control.
# This runs BEFORE the heavy third-party imports so that any library reading
# TF_CPP_MIN_LOG_LEVEL at import time actually sees the value.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
warnings.filterwarnings("ignore", message=".*GetPrototype.*")

# Third-party libs (required)
import torch

# Transformers / sentence-transformers (required)
from transformers import AutoModel, AutoTokenizer, logging as hf_logging
hf_logging.set_verbosity_error()
from sentence_transformers import SentenceTransformer

# Aux libs used in CLI (required)
import requests
from PIL import Image
import pandas as pd
import numpy as np
from tabulate import tabulate

# ---- Optional dependencies -------------------------------------------------
# BillAssistant checks `flash_attn`, `colab_secrets_available`, `Groq`,
# `PDF_SUPPORT` and `HYBRID_SEARCH_AVAILABLE` at runtime, so each optional
# import is guarded: a missing package disables its feature instead of making
# the whole module un-importable (e.g. when deployed outside Colab).

try:
    import flash_attn  # optional
except Exception:
    flash_attn = None

try:
    from google.colab import userdata
    colab_secrets_available = True
except Exception:
    userdata = None
    colab_secrets_available = False

try:
    from groq import Groq
except Exception:
    Groq = None

try:
    import fitz  # PyMuPDF
    PDF_SUPPORT = True
except ImportError:
    fitz = None
    PDF_SUPPORT = False
    print("PyMuPDF not installed. PDF processing disabled. Install with: pip install PyMuPDF")

try:
    from rank_bm25 import BM25Okapi
    import nltk
    from nltk.corpus import stopwords
    nltk.download('stopwords', quiet=True)
    HYBRID_SEARCH_AVAILABLE = True
except ImportError:
    BM25Okapi = None
    stopwords = None
    HYBRID_SEARCH_AVAILABLE = False
    print("Hybrid search dependencies missing. Install with: pip install rank-bm25 nltk")

class BillAssistant:
    """
    Class-based bill assistant with semantic Q&A using sentence-transformer embeddings.

    Usage:
        assistant = BillAssistant(model_name='deepseek-ai/DeepSeek-OCR')
        assistant.run_cli()
    """

    def __init__(self, model_name: str = "deepseek-ai/DeepSeek-OCR", use_colab_secrets: bool = False):
        """
        Set up empty state, optionally create the Groq client, then load models.

        Args:
            model_name: identifier handed to the tokenizer/model loaders in load_models().
            use_colab_secrets: when True (and both the Colab secrets module and the
                Groq class are available) read GROQ_API_KEY from Colab secrets and
                build the Groq client from it.
        """
        # Identity and runtime environment
        self.model_name = model_name
        self.device_info = self._gather_device_info()

        # Feature flags copied from the module-level import checks
        self.pdf_support = PDF_SUPPORT
        self.hybrid_search_available = HYBRID_SEARCH_AVAILABLE

        # Heavy objects; load_models() fills these in
        self.model = None
        self.tokenizer = None
        self.sentence_model = None

        # Groq client (optional) and per-bill state
        self.client = None
        self.current_bill = None
        self.bill_text = None

        # Retrieval structures
        self.chunks = []               # list[str]
        self.chunk_embeddings = None   # numpy.ndarray shape (n_chunks, emb_dim)
        self.bill_embeddings = None    # embedding of whole bill (optional)
        self.bm25 = None
        self.bm25_corpus = None

        # Optional Groq client from Colab secrets; any failure is reported, not raised
        if use_colab_secrets and colab_secrets_available and Groq is not None:
            try:
                api_key = userdata.get("GROQ_API_KEY")
                if api_key:
                    self.client = Groq(api_key=api_key)
            except Exception as err:
                print(f"‚ö†Ô∏è Failed to init Colab secrets / Groq client: {err}")

        # Models are loaded eagerly, then the environment summary is printed
        self.load_models()
        self.print_device_info()

    # ------ Utility / device info ------
    def _gather_device_info(self):
        try:
            cuda_available = torch.cuda.is_available()
            if cuda_available:
                try:
                    device_name = torch.cuda.get_device_name(0)
                except Exception:
                    device_name = "Unknown CUDA device"
                try:
                    compute_cap = torch.cuda.get_device_capability(0)
                except Exception:
                    compute_cap = ("N/A",)
            else:
                device_name = "CPU"
                compute_cap = ("N/A",)
            return {
                "torch_version": torch.__version__,
                "cuda_available": cuda_available,
                "device_name": device_name,
                "compute_capability": compute_cap,
                "flash_attn": getattr(flash_attn, "__version__", None) if flash_attn else None
            }
        except Exception as e:
            return {"error": str(e)}

    def print_device_info(self):
        info = self.device_info
        print(f"PyTorch version: {info.get('torch_version')}")
        print(f"CUDA available: {info.get('cuda_available')}")
        print(f"GPU: {info.get('device_name')}")
        print(f"Compute capability: {info.get('compute_capability')}")
        if info.get("flash_attn"):
            print(f"‚úì Flash Attention version: {info.get('flash_attn')}")
        else:
            print("‚úó Flash Attention not installed or not available")

    # ------ Model loading ------
    def load_models(self):
        """
        Load the OCR tokenizer/model and the sentence-transformer.

        Each half is loaded independently; a failure is printed and leaves the
        corresponding attributes set to None rather than raising.
        """
        print("Loading Assistant Eye...", end='\t')
        try:
            tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
            ocr_model = AutoModel.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                use_safetensors=True
            )
            self.tokenizer = tokenizer
            self.model = ocr_model.eval()  # inference only
            print("‚úÖ ASSISTANT EYE LOADED SUCCESSFULLY!!!")
        except Exception as err:
            print(f"‚ö†Ô∏è Failed to load model/tokenizer: {err}")
            self.model = None
            self.tokenizer = None

        try:
            print("Loading Assistant Brain...", end='\t')
            self.sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
            print("‚úÖ ASSISTANT BRAIN LOADED SUCCESSFULLY!!!")
        except Exception as err:
            print(f"‚ö†Ô∏è Failed to load sentence-transformer: {err}")
            self.sentence_model = None

    # ------ Core OCR / inference run ------
    def model_run(self, prompt: str, image_file: str):
        """
        Run the OCR/inference model.
        - If the real model is available, call model.infer(...) as per original script.
        - If not available (or for debugging), returns a hardcoded sample result.
        """
        output_path = f"/content/outputs/{os.path.splitext(os.path.basename(image_file))[0]}"
        os.makedirs(output_path, exist_ok=True)

        if self.model is None or self.tokenizer is None:
            # fallback -- return debug sample text (same as your test)
            print("‚ö†Ô∏è Model/tokenizer unavailable.")
            return ""

        # If real model exists, call its inference method (kept as in original code)
        try:
            print("OCR...")
            torch.cuda.empty_cache()
            res = self.model.infer(
                self.tokenizer,
                prompt=prompt,
                image_file=image_file,
                output_path=output_path,
                base_size=1536,
                image_size=1024,
                crop_mode=False,
                save_results=True,
                test_compress=False,
                eval_mode=True,  # return instead of printing
            )
            print(f"Extraction complete.\n{res}")
            return res
        except Exception as e:
            print(f"‚ö†Ô∏è Model inference failed: {e}")
            return ""

    def chunk_text(self, text: str, chunk_size: int = 400, overlap: int = 50):
        """
        Split text into overlapping character windows.

        This is naive character-based splitting; use a token-based splitter if
        chunk sizes must respect model token counts.

        Args:
            text: text to split; empty/None yields [].
            chunk_size: maximum characters per chunk (must be > 0).
            overlap: characters shared by consecutive chunks
                (must satisfy 0 <= overlap < chunk_size).

        Returns:
            List of whitespace-stripped, non-empty chunks. Splitting stops as soon
            as a window reaches the end of the text, so no trailing chunk is
            emitted that is already fully contained in the previous one.

        Raises:
            ValueError: if chunk_size/overlap would make the window fail to advance.
        """
        if not text:
            return []
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        if not 0 <= overlap < chunk_size:
            # overlap >= chunk_size would never advance the window (endless loop)
            raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

        text = text.strip()
        length = len(text)
        stride = chunk_size - overlap

        chunks = []
        for start in range(0, length, stride):
            end = start + chunk_size
            chunk = text[start:end].strip()
            if chunk:
                chunks.append(chunk)
            if end >= length:
                # this window already covered the tail of the text
                break
        return chunks

    def compute_chunk_embeddings(self):
        """
        Compute embeddings for each chunk and also store embedding for whole bill.
        """
        if self.sentence_model is None:
            print("‚ö†Ô∏è Sentence model not available; cannot compute embeddings.")
            self.chunk_embeddings = None
            self.bill_embeddings = None
            self.bm25 = None
            self.bm25_corpus = None
            return

        if not self.chunks:
            self.chunk_embeddings = None
            self.bill_embeddings = None
            self.bm25 = None
            self.bm25_corpus = None
            return

        print("üß† Computing semantic embeddings...")
        emb = self.sentence_model.encode(self.chunks, convert_to_numpy=True)
        # Normalize embeddings (helps cosine similarity)
        norms = np.linalg.norm(emb, axis=1, keepdims=True)
        norms[norms == 0] = 1e-10
        emb_norm = emb / norms
        self.chunk_embeddings = emb_norm  # shape (n_chunks, d)

        # whole-bill embedding
        whole_emb = self.sentence_model.encode([self.bill_text], convert_to_numpy=True)
        whole_emb /= (np.linalg.norm(whole_emb, axis=1, keepdims=True) + 1e-10)
        self.bill_embeddings = whole_emb[0]

       # Initialize BM25 for hybrid search if available
        if self.hybrid_search_available:
            print("üìö Building BM25 index for hybrid search...")
            try:
                self.bm25_corpus = [self._preprocess_text(chunk) for chunk in self.chunks]
                self.bm25 = BM25Okapi(self.bm25_corpus)
                print(f"‚úÖ BM25 index built with {len(self.chunks)} chunks!")
            except Exception as e:
                print(f"‚ö†Ô∏è Failed to build BM25 index: {e}")
                self.bm25 = None
                self.bm25_corpus = None
        else:
            self.bm25 = None
            self.bm25_corpus = None

    # ------ High-level processing ------
    def process_bill(self, image_path: str, prompt: str = "<image>\nStrict OCR. Extract all the text in the image as Markdown."):
        """
        Process a bill image or PDF: run OCR, parse to structured JSON, and compute embeddings.
        Returns a status message (string).
        """
        if not image_path:
            return "‚ùå No image path provided."

        # If URL, download
        if image_path.startswith("http"):
            try:
                response = requests.get(image_path)
                img = Image.open(BytesIO(response.content)).convert("RGB")
                tmp_path = "/content/tmp/bill_download"
                _, ext = os.path.splitext(image_path)
                tmp_path += ext.lower() if ext else ".jpg"
                img.save(tmp_path)
                image_path = tmp_path
            except Exception as e:
                return f"‚ùå Failed to download image: {e}"

        if not os.path.exists(image_path):
            return "‚ùå File not found!"

        # Handle PDF files
        if image_path.lower().endswith('.pdf'):
            if not self.pdf_support:
                return "‚ùå PDF processing not available. Install PyMuPDF with: pip install PyMuPDF"

            print("üìÑ Processing PDF file (converting pages to images)...")
            temp_dir = tempfile.mkdtemp()
            image_paths = []
            try:
                # Convert PDF to images
                doc = fitz.open(image_path)
                for page_num in range(doc.page_count):
                    page = doc[page_num]
                    pix = page.get_pixmap(dpi=150)  # 150 DPI for good quality
                    img_path = os.path.join(temp_dir, f"page_{page_num+1}.png")
                    pix.save(img_path)
                    image_paths.append(img_path)
                doc.close()

                # Process each page
                bill_texts = []
                for i, img_path in enumerate(image_paths):
                    print(f"üìù Extracting text from page {i+1}/{len(image_paths)}...")
                    page_text = self.model_run(prompt, img_path)
                    if not page_text.strip():
                        page_text = "[No text extracted from this page]"
                    bill_texts.append(f"--- Page {i+1} ---\n{page_text}")

                bill_text = "\n\n".join(bill_texts)
            except Exception as e:
                return f"‚ùå PDF processing failed: {e}"
            finally:
                # Clean up temporary images
                for img_path in image_paths:
                    try:
                        os.remove(img_path)
                    except:
                        pass
                try:
                    os.rmdir(temp_dir)
                except:
                    pass
        else:
            # Process single image
            print("üìù Extracting text...")
            bill_text = self.model_run(prompt, image_path)

        if not bill_text or not bill_text.strip():
            return "‚ùå No text extracted!"

        print("üß† Parsing with AI...")
        parsed_data = self.parse_bill_with_llm(bill_text)
        print(f"ParsedData:\n{parsed_data}")
        if not parsed_data:
            return "‚ùå Parsing failed!"

        self.current_bill = parsed_data
        self.bill_text = bill_text

        # chunk & compute embeddings (semantic Q&A)
        print("üîç Creating text chunks...")
        self.chunks = self.chunk_text(bill_text, chunk_size=400, overlap=50)
        if not self.chunks:
            self.chunks = [bill_text]

        print(f"üî¢ {len(self.chunks)} chunks created. Computing embeddings...")
        self.compute_chunk_embeddings()

        return "‚úÖ Bill processed successfully!"

    def _preprocess_text(self, text: str):
        """Tokenise text for BM25: lowercase, drop punctuation, split on whitespace, drop English stopwords and tokens of 2 characters or fewer."""
        strip_punctuation = str.maketrans('', '', string.punctuation)
        words = text.lower().translate(strip_punctuation).split()
        ignored = set(stopwords.words('english'))
        return [word for word in words if word not in ignored and len(word) > 2]

    def _cosine_sim(self, a: np.ndarray, b: np.ndarray):
        "Compute cosine similarity between 1D a and 2D b (b is list of vectors)."
        if a.ndim == 1:
            a = a.reshape(1, -1)
        a_norm = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-10)
        b_norm = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-10)
        return np.dot(a_norm, b_norm.T).squeeze(0)  # shape (n_b,)

    def semantic_search(self, query: str, top_k: int = 3, alpha: float = 0.7):
        """
        Hybrid search combining BM25 (keyword) and semantic (embedding) scores.
        alpha: weight for semantic similarity (0.0 = BM25 only, 1.0 = semantic only)
        Returns list of tuples: (chunk_text, combined_score, chunk_index)
        """
        if not self.chunks:
            return []

        # Always compute semantic scores if available
        semantic_scores = np.zeros(len(self.chunks))
        if self.chunk_embeddings is not None and self.sentence_model is not None:
            try:
                q_emb = self.sentence_model.encode([query], convert_to_numpy=True)[0]
                q_emb = q_emb / (np.linalg.norm(q_emb) + 1e-10)
                semantic_scores = self._cosine_sim(q_emb, self.chunk_embeddings)
            except Exception as e:
                print(f"‚ö†Ô∏è Semantic search failed: {e}")

        # Compute BM25 scores if available
        bm25_scores = np.zeros(len(self.chunks))
        if self.bm25 is not None:
            try:
                query_tokens = self._preprocess_text(query)
                bm25_scores = self.bm25.get_scores(query_tokens)
            except Exception as e:
                print(f"‚ö†Ô∏è BM25 search failed: {e}")

        # Normalize scores to [0, 1] range
        def normalize(scores):
            min_score = np.min(scores)
            max_score = np.max(scores)
            if max_score - min_score < 1e-10:  # Avoid division by zero
                return scores
            return (scores - min_score) / (max_score - min_score + 1e-10)

        norm_semantic = normalize(semantic_scores)
        norm_bm25 = normalize(bm25_scores)

        # Combine scores with weight alpha
        combined_scores = alpha * norm_semantic + (1 - alpha) * norm_bm25

        # Get top indices (descending order)
        top_idx = np.argsort(-combined_scores)[:top_k]
        results = [(self.chunks[i], float(combined_scores[i]), int(i)) for i in top_idx]
        return results


    def _clean_json(self, text: str) -> str:
      if text is None:
          return ""
      text = text.replace('\\r\\n', '\n').replace('\\n', '\n').strip()
      text = re.sub(r"^```(?:json)?\s*\n?", "", text.strip(), flags=re.IGNORECASE)
      text = re.sub(r"\n?```$", "", text, flags=re.IGNORECASE)
      return text.strip()

    def _basic_repair(self, text: str) -> str:
      """
      Apply a fixed sequence of regex substitutions that patch common defects in
      LLM-emitted JSON text.

      The substitutions run in order, each on the output of the previous one.
      The result is stripped of leading/trailing whitespace; it is not guaranteed
      to be valid JSON, and no braces/brackets are appended here.
      """
      t = text
      # common fixes (same idea as earlier)
      # quote, whitespace, quote  ->  a single quote
      t = re.sub(r'"\s+"', r'"', t)
      # quote, whitespace, two quotes  ->  two quotes
      t = re.sub(r'"\s+""', r'""', t)
      # value that opens with two quotes after the colon  ->  `: "value"`
      t = re.sub(r':\s*"\s*"\s*([^"\n\r]+)"', r': "\1"', t)
      # normalise `:<ws>"<ws>value"` to `: "value"` (drops whitespace right after the opening quote)
      t = re.sub(r':\s*"\s*([^"\n\r]+)"', r': "\1"', t)  # keep trying to remove stray quotes
      # same normalisation, anchored on the key's closing quote and requiring whitespace inside the value quote
      t = re.sub(r'":\s*"\s+([^"\n\r]+)"', r'": "\1"', t)
      # single-quoted spans become double-quoted
      # NOTE(review): this also rewrites text between two apostrophes inside values - confirm acceptable
      t = re.sub(r"\'([^\']*)\'", r'"\1"', t)
      # remove a trailing comma that sits directly before } or ]
      t = re.sub(r',\s*(\}|\])', r'\1', t)
      # quote + `}` + `{` (two objects back to back) gets a comma and newlines inserted between them
      t = re.sub(r'"\s*\}\s*\s*\{', r'"\n},\n{', t)
      # keep only newlines and characters with code points 32..126; tabs, carriage returns
      # and every non-ASCII character are dropped
      t = ''.join(ch for ch in t if ch == '\n' or (31 < ord(ch) < 127))
      # wrap an unquoted key at the start of a line in double quotes
      t = re.sub(r'(?m)^(\s*)([A-Za-z0-9_\-]+)\s*:', r'\1"\2":', t)
      # trim whitespace just inside a quoted span that has whitespace on both sides
      t = re.sub(r'"\s+([^"]+?)\s+"', lambda m: f'"{m.group(1).strip()}"', t)
      # don't auto-append braces/brackets here - leave to the more aggressive routine
      return t.strip()

    def _balance_closers(self, s: str) -> str:
      # Add minimal closers to match opens
      open_braces = s.count('{')
      close_braces = s.count('}')
      open_brackets = s.count('[')
      close_brackets = s.count(']')
      if open_braces > close_braces:
          s = s + ('}' * (open_braces - close_braces))
      if open_brackets > close_brackets:
          s = s + (']' * (open_brackets - close_brackets))
      return s

    def safe_load_json_recover(self, raw_text: str, debug: bool = False, max_trim_chars: int = 3000):
      """
      Attempt to clean/repair LLM-produced JSON and recover a Python object.
      Strategy:
        1) Clean and do basic regex repairs.
        2) Try json.loads.
        3) If it fails, attempt balancing quotes/braces/brackets and try again.
        4) If still fails, progressively trim from the end (up to `max_trim_chars`) and try balancing + loads.
      Returns: Python object (dict/list) or raises ValueError with repaired snippet for inspection.
      """
      cleaned = self._clean_json(raw_text)
      repaired = self._basic_repair(cleaned)

      if debug:
          print("=== Initial repaired ===")
          print(repaired)
          print("=== Trying json.loads ===")

      # 1) Try direct load after basic repair
      try:
          return json.loads(repaired)
      except json.JSONDecodeError as e:
          pass

      # 2) If odd number of double-quotes, try closing the last quote
      if repaired.count('"') % 2 == 1:
          cand = repaired + '"'
          try:
              return json.loads(self._balance_closers(cand))
          except Exception:
              # continue to trimming attempts
              pass

      # 3) Try adding minimal closers and reloading
      cand = self._balance_closers(repaired)
      try:
          return json.loads(cand)
      except Exception:
          pass

      # 4) Progressive trimming: remove trailing characters (one by one or in small chunks)
      #    and try to parse the prefix + balanced closers.
      L = len(repaired)
      # we'll try trimming up to max_trim_chars, in steps (bigger steps at first)
      step = 1
      trimmed = None
      for trim in range(0, min(max_trim_chars, L), step):
          if trim == 0:
              candidate = repaired
          else:
              candidate = repaired[:L - trim]

          # If candidate ends in a partial token like ' "qty": 3, "amount":', remove a trailing incomplete token:
          # remove trailing sequences after the last '}' or ']' if they exist
          last_close = max(candidate.rfind('}'), candidate.rfind(']'))
          if last_close != -1 and last_close > len(candidate) - 200:
              # keep up to last_close (close object/array)
              candidate = candidate[:last_close+1]

          # Close any open quotes (best-effort)
          if candidate.count('"') % 2 == 1:
              candidate = candidate + '"'

          # Balance braces/brackets
          candidate = self._balance_closers(candidate)

          try:
              parsed = json.loads(candidate)
              if debug:
                  print(f"Recovered by trimming {trim} chars.")
              return parsed
          except Exception:
              # increase step size after initial few tries to speed up
              if trim < 50:
                  step = 1
              elif trim < 200:
                  step = 5
              else:
                  step = 20
              continue

      # If all attempts fail, raise with helpful diagnostic including the best-effort repaired text
      # Provide truncated snippet to avoid enormous message
      best_snippet = repaired[:4000] + ("... (truncated)" if len(repaired) > 4000 else "")
      msg = (
          "Failed to parse JSON after attempted repairs and progressive trimming.\n\n"
          "A best-effort repaired snippet is shown below (inspect to decide next action):\n\n"
          f"{best_snippet}\n\n"
          "Common fixes: ask the LLM to re-output only the JSON inside a single ```json``` codeblock, "
          "increase the model `max_tokens` for longer outputs, or detect why the output was truncated."
      )
      raise ValueError(msg)

    # ------ Parsing (LLM or hardcoded) ------
    def parse_bill_with_llm(self, text: str):
        "Parse bill text into structured JSON."
        print(f"Parsing with LLM...")
        # If you want to call Groq LLM, you might do something like:
        if self.client is not None:
            print(f"Cient found.")
            prompt = f"""
                Extract structured data from this bill text as JSON:
                {text}

                Required fields:
                - invoice_number, invoice_date, due_date, po_number
                - bill_to (name, address, phone)
                - ship_to (name, address, phone)
                - items (qty, description, unit_price, amount)
                - subtotal, tax, total
                - company_name, company_address

                Return ONLY valid JSON with these fields. Use empty strings for missing data.
                """
            print(f"\n\nPROMPT:\n{prompt}\n\n")
            response = self.client.chat.completions.create(
                        messages=[{"role": "user", "content": prompt}],
                        model="llama-3.1-8b-instant",
                        temperature=0.1,
                        max_tokens=1024
                    )
            response_content = response.choices[0].message.content

            try:
                parsed = self.safe_load_json_recover(response_content, debug=True)
                print("Parsed OK:", parsed)
                return parsed
            except ValueError as e:
                print("Could not parse. Inspect repaired output:")
                print(str(e))
                return {}

    # ---------------- answering queries ----------------
    def answer_query(self, query: str, top_k: int = 3):
        if not self.current_bill:
            return "‚ùå Process a bill first!"

        query_lower = query.lower()
        if "total" in query_lower:
            return f"üí∞ Total: ${self.current_bill.get('total', 'N/A')}"
        if "invoice number" in query_lower or "invoice #" in query_lower or "invoice" == query_lower.strip():
            return f"üìã Invoice: {self.current_bill.get('invoice_number', 'N/A')}"
        if "date" in query_lower:
            return f"üìÖ Date: {self.current_bill.get('invoice_date', 'N/A')}"
        if "items" in query_lower:
            items = self.current_bill.get("items", [])
            if items:
                df = pd.DataFrame(items)
                return f"üõí Items:\n{tabulate(df, headers='keys', tablefmt='grid')}"
            return "‚ÑπÔ∏è No items found"

        # fallback to semantic retrieval
        retrieved = self.semantic_search(query, top_k=top_k)
        if not retrieved:
            return "‚ÑπÔ∏è No relevant content found in the bill."
        # print(f"\nRETRIEVED DATA\n{retrieved}")

        if self.client is not None:
            context_text = "\n\n---\n\n".join([f"Chunk {idx} (score {score:.4f}):\n{text}" for text, score, idx in retrieved])
            prompt = (
                f"Use ONLY the context below (do NOT hallucinate). "
                f"Extract an answer to the question. If information is not present, say 'Not found'.\n\n"
                f"CONTEXT:\n{context_text}\n\n"
                f"QUESTION: {query}\n\n"
                f"Answer concisely based ONLY on the context above:"
            )
            try:
                response = self.client.chat.completions.create(
                    messages=[{"role": "user", "content": prompt}],
                    model="llama-3.1-8b-instant",
                    temperature=0.0,
                    max_tokens=256
                )
                return response.choices[0].message.content
            except Exception as e:
                # fallback to returning chunks
                return f"‚ö†Ô∏è LLM query failed: {e}\n\nTop relevant text:\n\n" + "\n\n---\n\n".join([t for t, s, i in retrieved])

        # If no LLM, return the top chunks concatenated with scores
        best_text = "\n\n---\n\n".join([f"(score {s:.4f})\n{t}" for t, s, i in retrieved])
        return f"üîé Top relevant bill text (best {len(retrieved)} chunks):\n\n{best_text}"

    # ------ Export ------
    def export_to_json(self, filename: str = None):
        "Export the current bill to a JSON file."
        if not self.current_bill:
            return "‚ùå No bill data available!"

        if not filename:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"bill_{timestamp}.json"

        try:
            with open(filename, "w") as f:
                json.dump(self.current_bill, f, indent=2)
        except Exception as e:
            return f"‚ùå Failed to export: {e}"
        return f"‚úÖ Exported to {filename}"

    # ------ Chat ------
    def chat(self, message: str):
        "Simple chat handler. Uses Groq if available for richer replies."
        msg_lower = message.lower().strip()

        if msg_lower in ("hi", "hello", "hey"):
            return "üëã Hello! I'm your bill assistant. I can help you extract data from bills, answer questions, and export information."

        if "help" in msg_lower:
            help_text = "üí° I can:\n1. Process bill images and PDFs\n2. Answer questions about bills\n3. Export data to JSON\n4. General chat"
            if self.current_bill:
                help_text += "\n\n‚úÖ Current bill loaded! Ask about totals, dates, items, or custom questions."
            else:
                help_text += "\n\n‚ö†Ô∏è No bill processed yet. Please process a bill first to ask questions about it."
            return help_text

        # Handle queries when bill is loaded
        if self.current_bill is not None:
            return self.answer_query(message, top_k=3)

        # No bill loaded - provide helpful guidance without LLM calls
        if any(keyword in msg_lower for keyword in ["bill", "invoice", "receipt", "document", "pdf", "image", "process", "upload", "load", "scan"]):
            return "üìé Please process a bill first using option 1 in the main menu. I can handle images and PDFs!"

        if any(keyword in msg_lower for keyword in ["thank", "bye", "exit", "goodbye"]):
            return "üëã You can exit chat mode anytime by typing 'exit' or 'quit'."

        # General fallback when no bill is loaded
        return "ü§ñ I'm ready to help with bills! Please process a bill first (option 1), then ask questions like 'What's the total?' or 'Show items'. Type 'help' for options."

    def chat_loop(self):
        "Continuous chat loop until user exits."
        print("\nüí¨ Chat mode activated (type 'exit' or 'quit' to return to main menu)")
        print("ü§ñ Assistant: Hello! I'm your bill assistant. How can I help you today?")

        while True:
            message = input("You: ").strip()

            if message.lower() in ('exit', 'quit', 'bye', 'goodbye'):
                print("üëã Exiting chat mode...")
                break

            response = self.chat(message)
            print(f"ü§ñ Assistant: {response}")

    def run_cli(self):
        """Interactive command-line menu."""
        self.print_device_info()
        print("\n" + "=" * 50)
        print("BILL ASSISTANT")
        print("=" * 50)

        print("\n1. üì∏ Process bill image/PDF")
        print("2. üíæ Export to JSON")
        print("3. üí¨ Chat")
        print("4. üö™ Exit")

        while True:
            choice = input("\nSelect option (1-4): ").strip()

            if choice == "1":
                torch.cuda.empty_cache()
                gc.collect()
                img_path = input("Enter image/PDF path (or URL): ").strip()
                result = self.process_bill(img_path)
                print(result)
                if self.current_bill:
                    print("\nüìä Bill Summary:")
                    print(f"Invoice: {self.current_bill.get('invoice_number', 'N/A')}")
                    print(f"Date: {self.current_bill.get('invoice_date', 'N/A')}")
                    print(f"Total: ${self.current_bill.get('total', 'N/A')}")
                print("\n" + "=" * 50)
                print("BILL ASSISTANT")
                print("=" * 50)
                print("\n1. üì∏ Process bill image/PDF")
                print("2. üíæ Export to JSON")
                print("3. üí¨ Chat")
                print("4. üö™ Exit")

            elif choice == "2":
                filename = input("Enter filename (or press Enter for default): ").strip()
                result = self.export_to_json(filename if filename else None)
                print(result)

            elif choice == "3":
                self.chat_loop()  # Call the new continuous chat loop
                # Show menu again after exiting chat
                print("\n" + "=" * 50)
                print("BILL ASSISTANT")
                print("=" * 50)
                print("\n1. üì∏ Process bill image/PDF")
                print("2. üíæ Export to JSON")
                print("3. üí¨ Chat")
                print("4. üö™ Exit")

            elif choice == "4":
                print("üëã Goodbye!")
                break

            else:
                print("‚ùå Invalid choice!")

# # If run as a script, start the CLI (left disabled here; main.py imports BillAssistant from this module)
# if __name__ == "__main__":
#     assistant = BillAssistant(use_colab_secrets=True)
#     assistant.run_cli()

Writing assistant.py


In [1]:
%%writefile main.py
import streamlit as st
import pandas as pd
import json
import os
import tempfile
from io import BytesIO
from PIL import Image
from assistant import BillAssistant  # assume this is the provided class
import hashlib

# Cache the assistant so models load only once
@st.cache_resource
def load_assistant():
    """Build the BillAssistant (its constructor loads the models) with Colab secrets enabled."""
    shared_assistant = BillAssistant(use_colab_secrets=True)
    return shared_assistant

# Page configuration comes before anything else that talks to Streamlit:
# Streamlit releases that require set_page_config to be the first Streamlit
# command reject it once the cached loader below has run.
st.set_page_config(page_title="Bill Assistant", layout="centered")

assistant = load_assistant()

st.title("Bill Assistant")
st.write("Upload a bill image (PNG/JPEG) or PDF below and then click **Process bill**. Processing will run once per uploaded file unless you clear it.")

# Initialize session state keys we use (only when absent, so values survive reruns)
_SESSION_DEFAULTS = {
    "uploaded_file_id": None,   # id of the file that was last processed
    "temp_path": None,          # where the uploaded file was saved on disk
    "process_result": None,     # status message from the last processing run
    "processed_ok": False,      # whether the last processing run succeeded
    "messages": [],             # chat history
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

def make_file_id(uploaded_file) -> str:
    """Build an identifier for an upload so the same content is not reprocessed.

    The id is ``<name>-<size>-<sha1 of the bytes>``. If the bytes cannot be
    read or hashed, it degrades to ``<name>-<size>``.
    """
    try:
        digest = hashlib.sha1(uploaded_file.getbuffer()).hexdigest()
    except Exception:
        # Content unavailable: fall back to the weaker name + size id
        return f"{uploaded_file.name}-{uploaded_file.size}"
    else:
        return f"{uploaded_file.name}-{uploaded_file.size}-{digest}"

def save_uploaded_to_temp(uploaded_file) -> str:
    """Save a Streamlit uploaded file to a unique temp path and return the path.

    Args:
        uploaded_file: object exposing ``name``, ``type`` and ``getbuffer()``
            (as Streamlit's uploaded-file objects do).

    Returns:
        Path of the newly created file, named ``bill_*<suffix>``. The file is
        created with ``delete=False``, so the caller is responsible for
        removing it.
    """
    if uploaded_file.type == "application/pdf" or uploaded_file.name.lower().endswith(".pdf"):
        suffix = ".pdf"
    else:
        # splitext returns either "" or an extension that already carries its
        # leading dot, so the only fallback needed is the default ".png"
        suffix = os.path.splitext(uploaded_file.name)[1] or ".png"

    # Write through the handle NamedTemporaryFile already opened and close it
    # via the context manager, instead of leaking it and reopening by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix, prefix="bill_") as tmp:
        tmp.write(uploaded_file.getbuffer())
    return tmp.name

# File uploader for image or PDF bills
uploaded_file = st.file_uploader("Choose an image (PNG/JPEG) or PDF", type=["png", "jpg", "jpeg", "pdf"])


def _remove_temp_file(path) -> None:
    """Best-effort removal of a temp copy of an upload; a missing file is fine."""
    if not path:
        return
    try:
        os.remove(path)
    except OSError:
        pass


if uploaded_file is not None:
    file_id = make_file_id(uploaded_file)

    # Treat the upload as a PDF by MIME type or by extension, the same rule
    # save_uploaded_to_temp uses when choosing the temp-file suffix.
    is_pdf = (
        uploaded_file.type == "application/pdf"
        or uploaded_file.name.lower().endswith(".pdf")
    )
    if is_pdf:
        st.write("PDF uploaded. Preview not shown.")
    else:
        try:
            preview_image = Image.open(uploaded_file).convert("RGB")
            st.image(preview_image, caption="Preview", width="content")
        except Exception:
            # The preview is optional; a decode failure must not block processing
            st.write("Preview not available for this image.")

    # Same file already processed successfully in this session?
    already_processed = (
        st.session_state.uploaded_file_id == file_id
        and st.session_state.processed_ok
    )
    if already_processed:
        st.success("This file was already processed in this session.")

    # Action buttons. "Clear / Reset" and "Show raw extracted text" are always
    # rendered so they stay usable after a successful run; only the
    # "Process bill" button is hidden for a file that is already processed.
    cols = st.columns([1, 1, 1])
    with cols[0]:
        if not already_processed and st.button("Process bill"):
            # Drop the temp copy left by an earlier run before creating a new one
            _remove_temp_file(st.session_state.temp_path)
            tmp_path = save_uploaded_to_temp(uploaded_file)
            st.session_state.temp_path = tmp_path
            st.session_state.uploaded_file_id = file_id
            # Reset the outcome first so that an exception in process_bill
            # cannot leave a stale success flag attached to this new file id
            st.session_state.process_result = None
            st.session_state.processed_ok = False

            # run processing with a spinner
            with st.spinner("Processing bill (OCR + parsing)... this may take a while for large PDFs"):
                # Optional: free GPU memory (harmless if torch/CUDA is unavailable).
                # Collect garbage first so freed tensors can be released from the cache.
                try:
                    import gc
                    import torch
                    gc.collect()
                    torch.cuda.empty_cache()
                except Exception:
                    pass

                result = assistant.process_bill(tmp_path)
                st.session_state.process_result = result
                # Success is signalled by the status prefix; guard against a
                # non-string result instead of crashing on .startswith
                st.session_state.processed_ok = isinstance(result, str) and result.startswith("‚úÖ")

            if st.session_state.processed_ok:
                st.success("Bill processed successfully!")
            else:
                st.error("Processing failed: " + str(st.session_state.process_result or "Unknown error"))

    with cols[1]:
        if st.button("Clear / Reset"):
            # remove temp file if any
            _remove_temp_file(st.session_state.temp_path)
            # reset session-state fields
            st.session_state.uploaded_file_id = None
            st.session_state.temp_path = None
            st.session_state.process_result = None
            st.session_state.processed_ok = False
            st.session_state.messages = []
            # st.experimental_rerun was removed from Streamlit; st.rerun replaces it
            st.rerun()

    with cols[2]:
        if st.button("Show raw extracted text (if processed)"):
            # Only show text that belongs to the file currently uploaded
            if already_processed and assistant.bill_text:
                st.text_area("Extracted text", assistant.bill_text, height=300)
            else:
                st.info("No extracted text available. Process the file first.")

# Summary and JSON download, rendered only after a successful processing run
# and only when the assistant actually holds a parsed bill
if st.session_state.processed_ok and assistant.current_bill:
    bill = assistant.current_bill
    st.write("---")
    st.header("Extracted bill summary")

    # Headline fields; anything missing from the parsed bill is shown as N/A
    for label, field in (("Invoice", "invoice_number"), ("Date", "invoice_date")):
        st.markdown(f"**{label}:** {bill.get(field, 'N/A')}")
    st.markdown(f"**Total:** ${bill.get('total', 'N/A')}")

    # Line items table, skipped when the bill has none
    line_items = bill.get("items", [])
    if line_items:
        items_table = pd.DataFrame(line_items)
        st.write("**Line Items:**")
        st.table(items_table)

    # Offer the whole parsed bill as a pretty-printed JSON file
    serialized_bill = json.dumps(bill, indent=2)
    st.download_button(
        "Download Bill as JSON",
        data=serialized_bill,
        file_name="extracted_bill.json",
        mime="application/json",
    )

# Chat interface for queries about the bill
st.write("---")
st.write("### Ask questions about the bill or chit-chat")

# Replay the conversation stored in session state (oldest first)
for msg in st.session_state.messages:
    role = msg.get("role", "user")
    with st.chat_message(role):
        st.markdown(msg["content"])

# Chat input
user_input = st.chat_input("Your message...")
if user_input:
    # Record the user's message and echo it right away: the history loop above
    # already ran for this script run, so without the echo the message would
    # only appear after the next rerun.
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    # Get the assistant's reply, store it, and display it
    response = assistant.chat(user_input)
    st.session_state.messages.append({"role": "assistant", "content": response})
    with st.chat_message("assistant"):
        st.markdown(response)

Writing main.py


In [None]:
!streamlit run main.py & npx localtunnel --port 8501 --y