In [None]:
# STEP 1: Run the sample data generator first
# (Creates all test documents)

# STEP 2: Run the main application
# (Auto-installs all packages, launches Gradio UI)

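# STEP 3: In the "Upload Documents" tab, upload the generated files from sample_tax_docs/
# - PDFs (salary slip, Form 16)
# - CSV/Excel files (investments, monthly expenses)
# - Images (PF receipt, ELSS certificate, LIC receipt), optionally with
#   comma-separated captions listed in the same order as the images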

# STEP 4: Start chatting!
# Ask questions like:
# - "What is my total taxable income?"
# - "Show me my PF contribution"
# - "List all 80C deductions"

In [None]:
# ==========================================
# SAMPLE TAX DATA GENERATOR
# Creates realistic sample documents for testing
# ==========================================

# ⚠️ RUN THIS CELL FIRST - Install reportlab
# Then run the rest of the code

print("📦 Installing reportlab for PDF generation...")
try:
    import reportlab
except ImportError:
    # Not importable yet: pip-install it quietly into the running interpreter.
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "reportlab"])
    print("✅ reportlab installed successfully!")
else:
    print("✅ reportlab already installed!")

# NOW import all libraries
import pandas as pd
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from PIL import Image, ImageDraw
import os

print("\n🎨 Creating sample tax documents...")

# Every generator below writes into this folder; exist_ok keeps re-runs safe.
os.makedirs(name='sample_tax_docs', exist_ok=True)

# ==========================================
# 1. CREATE SAMPLE SALARY SLIP PDF
# ==========================================

def create_salary_slip():
    """Render a one-page sample salary slip PDF and return its path.

    Draws an employee header, an earnings table with gross salary, a
    deductions table with its total, and the net pay line on a US-letter
    reportlab canvas written to ``sample_tax_docs/``.
    """
    out_file = 'sample_tax_docs/salary_slip_apr2024.pdf'
    pdf = canvas.Canvas(out_file, pagesize=letter)
    _, page_top = letter

    def draw_rows(rows, start_y):
        # Label at x=70, amount at x=300, 20pt between rows; returns next free y.
        cursor = start_y
        for label, value in rows:
            pdf.drawString(70, cursor, label)
            pdf.drawString(300, cursor, value)
            cursor -= 20
        return cursor

    # Header
    pdf.setFont("Helvetica-Bold", 16)
    pdf.drawString(50, page_top - 50, "SALARY SLIP")

    pdf.setFont("Helvetica", 10)
    header_lines = [
        "Employee: John Doe",
        "Employee ID: EMP001",
        "Month: April 2024",
        "Department: Engineering",
    ]
    for offset, line in enumerate(header_lines):
        pdf.drawString(50, page_top - 80 - 15 * offset, line)

    # Earnings
    pdf.setFont("Helvetica-Bold", 12)
    pdf.drawString(50, page_top - 160, "EARNINGS")

    pdf.setFont("Helvetica", 10)
    cursor = draw_rows(
        [
            ("Basic Salary", "Rs.50,000"),
            ("HRA", "Rs.20,000"),
            ("Special Allowance", "Rs.10,000"),
            ("Transport Allowance", "Rs.2,000"),
        ],
        page_top - 180,
    )

    pdf.setFont("Helvetica-Bold", 10)
    pdf.drawString(70, cursor - 10, "Gross Salary")
    pdf.drawString(300, cursor - 10, "Rs.82,000")

    # Deductions
    cursor -= 50
    pdf.setFont("Helvetica-Bold", 12)
    pdf.drawString(50, cursor, "DEDUCTIONS")

    pdf.setFont("Helvetica", 10)
    cursor = draw_rows(
        [
            ("Provident Fund (PF)", "Rs.6,000"),
            ("Professional Tax", "Rs.200"),
            ("TDS", "Rs.8,500"),
        ],
        cursor - 20,
    )

    pdf.setFont("Helvetica-Bold", 10)
    pdf.drawString(70, cursor - 10, "Total Deductions")
    pdf.drawString(300, cursor - 10, "Rs.14,700")

    # Net Pay
    cursor -= 40
    pdf.setFont("Helvetica-Bold", 12)
    pdf.drawString(70, cursor, "NET PAY")
    pdf.drawString(300, cursor, "Rs.67,300")

    pdf.save()
    print(f"✅ Created: {out_file}")
    return out_file

# ==========================================
# 2. CREATE FORM 16 PDF
# ==========================================

def create_form16():
    """Render a one-page sample Form 16 (FY 2023-24) PDF and return its path.

    Lays out the assessee header, the salary-income computation, the
    Chapter VI-A deductions and the resulting net taxable income and TDS
    figures on a US-letter reportlab canvas.
    """
    out_file = 'sample_tax_docs/form16_fy2023-24.pdf'
    pdf = canvas.Canvas(out_file, pagesize=letter)
    _, page_top = letter

    def draw_rows(rows, start_y):
        # Label at x=70, amount at x=350, 20pt between rows; returns next free y.
        cursor = start_y
        for label, value in rows:
            pdf.drawString(70, cursor, label)
            pdf.drawString(350, cursor, value)
            cursor -= 20
        return cursor

    pdf.setFont("Helvetica-Bold", 16)
    pdf.drawString(50, page_top - 50, "FORM 16 - Part A & B")

    pdf.setFont("Helvetica", 10)
    header_lines = [
        "Assessment Year: 2024-25",
        "Financial Year: 2023-24",
        "Employee: John Doe",
        "PAN: ABCDE1234F",
    ]
    for offset, line in enumerate(header_lines):
        pdf.drawString(50, page_top - 80 - 15 * offset, line)

    # Income Details
    pdf.setFont("Helvetica-Bold", 12)
    pdf.drawString(50, page_top - 160, "INCOME DETAILS")

    pdf.setFont("Helvetica", 10)
    cursor = draw_rows(
        [
            ("Gross Salary (Annual)", "Rs.9,84,000"),
            ("Less: Standard Deduction", "Rs.50,000"),
            ("Less: HRA Exemption", "Rs.1,20,000"),
            ("Taxable Salary Income", "Rs.8,14,000"),
        ],
        page_top - 180,
    )

    # Deductions
    cursor -= 30
    pdf.setFont("Helvetica-Bold", 12)
    pdf.drawString(50, cursor, "DEDUCTIONS UNDER CHAPTER VI-A")

    pdf.setFont("Helvetica", 10)
    cursor = draw_rows(
        [
            ("80C - EPF, ELSS", "Rs.1,50,000"),
            ("80D - Health Insurance", "Rs.25,000"),
            ("80TTA - Savings Interest", "Rs.10,000"),
        ],
        cursor - 20,
    )

    pdf.setFont("Helvetica-Bold", 10)
    pdf.drawString(70, cursor - 10, "Total Deductions")
    pdf.drawString(350, cursor - 10, "Rs.1,85,000")

    # Taxable Income (the bold 10pt font set above stays active)
    cursor -= 40
    pdf.drawString(70, cursor, "Net Taxable Income")
    pdf.drawString(350, cursor, "Rs.6,29,000")

    cursor -= 30
    pdf.drawString(70, cursor, "Total TDS Deducted")
    pdf.drawString(350, cursor, "Rs.52,500")

    pdf.save()
    print(f"✅ Created: {out_file}")
    return out_file

# ==========================================
# 3. CREATE INVESTMENT DETAILS CSV
# ==========================================

def create_investment_csv():
    """Write the sample FY 2023-24 investment ledger as CSV and return its path.

    Each row is one investment: its type, Chapter VI-A section, amount in
    rupees, payment date and whether a proof document exists. After writing,
    the 80C and 80D totals are printed.
    """
    out_file = 'sample_tax_docs/investments_fy2023-24.csv'

    # (investment type, section, amount, date) -- every proof is available.
    records = [
        ('Employee Provident Fund (EPF)', '80C', 72000, '2023-04-15'),
        ('Public Provident Fund (PPF)', '80C', 30000, '2023-05-20'),
        ('Equity Linked Savings Scheme (ELSS)', '80C', 50000, '2023-06-10'),
        ('Life Insurance Premium', '80C', 24000, '2023-07-05'),
        ('Health Insurance Premium', '80D', 25000, '2023-08-12'),
        ('National Savings Certificate (NSC)', '80C', 15000, '2023-09-08'),
        ('Home Loan Principal Repayment', '80C', 40000, '2023-10-15'),
    ]
    frame = pd.DataFrame(
        [(kind, section, amount, paid_on, 'Yes')
         for kind, section, amount, paid_on in records],
        columns=['Investment Type', 'Section', 'Amount (Rs)', 'Date', 'Document Available'],
    )

    frame.to_csv(out_file, index=False)
    print(f"✅ Created: {out_file}")
    for section in ('80C', '80D'):
        section_total = frame.loc[frame['Section'] == section, 'Amount (Rs)'].sum()
        print(f"   Total {section}: Rs.{section_total:,}")
    return out_file

# ==========================================
# 4. CREATE EXPENSE REPORT EXCEL
# ==========================================

def create_expense_excel():
    """Write a sample 12-month expense sheet to Excel and return its path.

    Covers April 2023 to March 2024, one row per month, with a ``Total``
    column summing every expense column. Writes to the ``Monthly_Expenses``
    sheet and prints the annual total.
    """
    out_file = 'sample_tax_docs/monthly_expenses_2023.xlsx'

    months = ['Apr-23', 'May-23', 'Jun-23', 'Jul-23', 'Aug-23', 'Sep-23',
              'Oct-23', 'Nov-23', 'Dec-23', 'Jan-24', 'Feb-24', 'Mar-24']
    frame = pd.DataFrame({
        'Month': months,
        'Rent': [25000] * len(months),
        'Medical': [2000, 1500, 3000, 1800, 2500, 2000, 2200, 1900, 2300, 2100, 1700, 2400],
        'Education': [0, 0, 15000, 0, 0, 0, 0, 0, 0, 0, 0, 10000],
        'Home Loan EMI': [18000] * len(months),
        'Other': [10000, 9500, 11000, 10500, 9800, 10200, 10800, 10300, 11500, 10700, 9900, 10400],
    })
    # Row total across every expense column, i.e. everything but the month label.
    frame['Total'] = frame.drop(columns='Month').sum(axis=1)

    frame.to_excel(out_file, index=False, sheet_name='Monthly_Expenses')
    print(f"✅ Created: {out_file}")
    print(f"   Annual Total Expenses: Rs.{frame['Total'].sum():,}")
    return out_file

# ==========================================
# 5. CREATE SAMPLE INVESTMENT PROOF IMAGES
# ==========================================

def create_sample_images():
    """Render three 600x400 JPEG investment-proof images and return their paths.

    Produces, in order, a PF receipt, an ELSS certificate and an LIC premium
    receipt. Each is a coloured background with a border rectangle and lines
    of text drawn with Pillow's default font.
    """
    # Per-image spec: output path, background, (border colour, border width),
    # then (position, text, fill colour) for every line of text, in draw order.
    receipt_specs = [
        (
            'sample_tax_docs/pf_receipt.jpg', 'white', ('black', 2),
            [
                ((150, 50), "PROVIDENT FUND RECEIPT", 'black'),
                ((50, 120), "Employee: John Doe", 'black'),
                ((50, 150), "PF Account: PF/123/456789", 'black'),
                ((50, 180), "Contribution (Apr 2023): Rs.6,000", 'black'),
                ((50, 210), "Employer Contribution: Rs.6,000", 'black'),
                ((50, 240), "Total: Rs.12,000", 'black'),
                ((50, 300), "Date: 05-May-2023", 'black'),
            ],
        ),
        (
            'sample_tax_docs/elss_certificate.jpg', 'lightblue', ('darkblue', 3),
            [
                ((100, 50), "ELSS INVESTMENT CERTIFICATE", 'darkblue'),
                ((50, 120), "Fund Name: Growth Equity Fund", 'black'),
                ((50, 150), "Investor: John Doe", 'black'),
                ((50, 180), "Investment Amount: Rs.50,000", 'black'),
                ((50, 210), "Units Allotted: 1,234.56", 'black'),
                ((50, 240), "NAV: Rs.40.50", 'black'),
                ((50, 300), "Date: 10-Jun-2023", 'black'),
                ((50, 330), "Lock-in Period: 3 Years", 'red'),
            ],
        ),
        (
            'sample_tax_docs/lic_premium.jpg', 'lightyellow', ('orange', 2),
            [
                ((150, 50), "LIC PREMIUM RECEIPT", 'darkorange'),
                ((50, 120), "Policy Holder: John Doe", 'black'),
                ((50, 150), "Policy No: 123456789", 'black'),
                ((50, 180), "Premium Amount: Rs.24,000", 'black'),
                ((50, 210), "Premium Type: Annual", 'black'),
                ((50, 240), "Coverage: Rs.10,00,000", 'black'),
                ((50, 300), "Payment Date: 05-Jul-2023", 'black'),
            ],
        ),
    ]

    saved_paths = []
    for out_file, background, (border_colour, border_width), text_lines in receipt_specs:
        picture = Image.new('RGB', (600, 400), color=background)
        pen = ImageDraw.Draw(picture)
        pen.rectangle([10, 10, 590, 390], outline=border_colour, width=border_width)
        for position, text, colour in text_lines:
            pen.text(position, text, fill=colour)

        picture.save(out_file)
        print(f"✅ Created: {out_file}")
        saved_paths.append(out_file)

    return saved_paths

# ==========================================
# GENERATE ALL SAMPLE DATA
# ==========================================

def generate_all_documents():
    """Create every sample document, print a per-category summary, return the paths.

    Returns:
        dict with keys ``'pdfs'``, ``'tables'`` and ``'images'``, each mapping
        to the list of file paths produced by the corresponding generators.
    """
    divider = "=" * 50
    print("\n" + divider)
    print("🎨 GENERATING SAMPLE TAX DOCUMENTS")
    print(divider + "\n")

    # Dict literals evaluate left to right, so the generators run in this order.
    files_created = {
        'pdfs': [create_salary_slip(), create_form16()],
        'tables': [create_investment_csv(), create_expense_excel()],
        'images': list(create_sample_images()),
    }

    print("\n" + divider)
    print("✅ ALL SAMPLE DOCUMENTS CREATED!")
    print(divider)
    print("\n📁 Files created in 'sample_tax_docs/' folder:")

    section_labels = (
        ('pdfs', "📄 PDFs"),
        ('tables', "📊 Tables"),
        ('images', "🖼️ Images"),
    )
    for key, label in section_labels:
        paths = files_created[key]
        print(f"\n{label} ({len(paths)}):")
        for path in paths:
            print(f"   - {os.path.basename(path)}")

    print("\n💡 Use these files to test the Multimodal RAG system!")
    print("📌 Image captions for upload: 'PF Receipt, ELSS Investment, LIC Premium'\n")

    return files_created

# Run the generator exactly once. Inside a Jupyter/Colab cell __name__ is also
# "__main__", so this guard covers interactive use too; an additional
# unguarded call would regenerate and re-report every document a second time.
if __name__ == "__main__":
    generate_all_documents()

📦 Installing reportlab for PDF generation...
✅ reportlab installed successfully!

🎨 Creating sample tax documents...

🎨 GENERATING SAMPLE TAX DOCUMENTS

✅ Created: sample_tax_docs/salary_slip_apr2024.pdf
✅ Created: sample_tax_docs/form16_fy2023-24.pdf
✅ Created: sample_tax_docs/investments_fy2023-24.csv
   Total 80C: Rs.231,000
   Total 80D: Rs.25,000
✅ Created: sample_tax_docs/monthly_expenses_2023.xlsx
   Annual Total Expenses: Rs.691,000
✅ Created: sample_tax_docs/pf_receipt.jpg
✅ Created: sample_tax_docs/elss_certificate.jpg
✅ Created: sample_tax_docs/lic_premium.jpg

✅ ALL SAMPLE DOCUMENTS CREATED!

📁 Files created in 'sample_tax_docs/' folder:

📄 PDFs (2):
   - salary_slip_apr2024.pdf
   - form16_fy2023-24.pdf

📊 Tables (2):
   - investments_fy2023-24.csv
   - monthly_expenses_2023.xlsx

🖼️ Images (3):
   - pf_receipt.jpg
   - elss_certificate.jpg
   - lic_premium.jpg

💡 Use these files to test the Multimodal RAG system!
📌 Image captions for upload: 'PF Receipt, ELSS Investment, 

{'pdfs': ['sample_tax_docs/salary_slip_apr2024.pdf',
  'sample_tax_docs/form16_fy2023-24.pdf'],
 'tables': ['sample_tax_docs/investments_fy2023-24.csv',
  'sample_tax_docs/monthly_expenses_2023.xlsx'],
 'images': ['sample_tax_docs/pf_receipt.jpg',
  'sample_tax_docs/elss_certificate.jpg',
  'sample_tax_docs/lic_premium.jpg']}

In [None]:
# ==========================================
# MULTIMODAL RAG TAX FILING ASSISTANT
# Colab-Ready Implementation
# ==========================================

# STEP 1: Install Required Packages
# Jupyter/Colab only: the leading "!" is an IPython shell escape, not Python
# syntax, so this line fails in a plain .py script.
# NOTE(review): openai, python-dotenv and tabulate are installed here but are
# not imported anywhere in this cell.
print("📦 Installing required packages...")
!pip install -q sentence-transformers faiss-cpu pypdf2 pandas pillow gradio openai python-dotenv transformers torch torchvision openpyxl tabulate

# STEP 2: Import Libraries
import base64
import json
import os
import re
import warnings
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

import faiss
import gradio as gr
import numpy as np
import pandas as pd
import PyPDF2
from PIL import Image
from sentence_transformers import SentenceTransformer
# Ignore every warning category for the rest of the session to keep notebook
# output readable (this also hides deprecation notices).
warnings.simplefilter('ignore')

print("✅ All packages installed and imported successfully!")

# ==========================================
# STEP 3: Initialize Models and Vector Store
# ==========================================

class MultimodalRAGSystem:
    """
    Multimodal RAG system supporting text (PDFs), tables (CSV/Excel), and images.
    Uses FAISS for vector storage and retrieval.

    Each modality keeps its own chunk list and its own exact (flat L2) FAISS
    index. PDF and table chunks are embedded with ``all-MiniLM-L6-v2``; image
    chunks are embedded with the ``clip-ViT-B-32`` sentence-transformer.
    Answers are assembled by keyword rules in ``_create_answer`` (no LLM).
    """

    # A sentence ends at '.', '!' or '?' *followed by whitespace*. A bare '.'
    # is not a boundary, so amounts like "Rs.50,000" and decimals like "40.50"
    # are never split apart.
    _SENTENCE_END = re.compile(r'(?<=[.!?])\s+')

    def __init__(self):
        """Load both embedding models and start with empty stores and no indices."""
        print("🚀 Initializing Multimodal RAG System...")

        # Text embeddings model (used for PDF text and table rows)
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
        print("✅ Text embedding model loaded")

        # Image embeddings model (using CLIP)
        self.image_model = SentenceTransformer('clip-ViT-B-32')
        print("✅ Image embedding model loaded")

        # Storage for different modalities
        self.text_chunks = []
        self.table_chunks = []
        self.image_chunks = []

        # Vector stores (FAISS indices); None until build_vector_store() creates them
        self.text_index = None
        self.table_index = None
        self.image_index = None

        # Document metadata (reserved; nothing in this class populates these yet)
        self.text_metadata = []
        self.table_metadata = []
        self.image_metadata = []

        print("✅ RAG System initialized successfully!")

    # ==========================================
    # HELPERS
    # ==========================================

    @staticmethod
    def _file_path(uploaded) -> str:
        """Return the filesystem path of an upload.

        Accepts either a plain path string or an object exposing the path as
        ``.name``, so both shapes of Gradio file values work.
        """
        return getattr(uploaded, 'name', uploaded)

    @classmethod
    def _split_sentences(cls, text: str) -> List[str]:
        """Split ``text`` at sentence ends (see ``_SENTENCE_END``), dropping empty parts."""
        return [part.strip() for part in cls._SENTENCE_END.split(text) if part.strip()]

    @staticmethod
    def _build_index(model, chunks: List[Dict]):
        """Embed each chunk's ``text`` with ``model`` into a new exact L2 FAISS index."""
        embeddings = model.encode([chunk['text'] for chunk in chunks], show_progress_bar=True)
        # FAISS expects a C-contiguous float32 matrix.
        embeddings = np.ascontiguousarray(embeddings, dtype='float32')
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        return index

    @staticmethod
    def _search(index, chunks: List[Dict], query_embedding, top_k: int) -> List[Dict]:
        """Return copies of the ``top_k`` nearest chunks, each with a ``score``.

        ``score`` is the distance reported by the flat L2 index (smaller means
        closer). FAISS pads results with id ``-1`` when the index holds fewer
        than the requested number of vectors. Such ids (and any other
        out-of-range id) are skipped instead of wrapping to ``chunks[-1]``.
        """
        k = min(top_k, index.ntotal)
        if k <= 0:
            return []
        distances, indices = index.search(query_embedding, k)

        hits = []
        for idx, dist in zip(indices[0], distances[0]):
            if 0 <= idx < len(chunks):
                hit = chunks[idx].copy()  # shallow copy so stored chunks get no 'score'
                hit['score'] = float(dist)
                hits.append(hit)
        return hits

    # ==========================================
    # TEXT PROCESSING (PDFs)
    # ==========================================

    def process_pdf(self, pdf_path: str) -> List[Dict]:
        """Extract text from a PDF and split it into sentence-based chunks.

        Chunking is done page by page, so every chunk records the 1-based page
        it came from. Each chunk holds up to three sentences; chunks of 20
        characters or fewer are discarded as noise.

        Returns:
            A list of dicts with keys ``text``, ``source`` (file name),
            ``type`` (``'pdf'``) and ``page``. On any read/parse error the
            error is printed and an empty list is returned.
        """
        print(f"📄 Processing PDF: {pdf_path}")
        chunks = []
        chunk_size = 3  # sentences per chunk
        source = os.path.basename(pdf_path)

        try:
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page_num, page in enumerate(pdf_reader.pages, start=1):
                    # extract_text() can yield None/"" for pages without a text layer.
                    sentences = self._split_sentences(page.extract_text() or "")

                    for i in range(0, len(sentences), chunk_size):
                        chunk_text = ' '.join(sentences[i:i + chunk_size]).strip()
                        if len(chunk_text) > 20:  # minimum chunk size
                            chunks.append({
                                'text': chunk_text,
                                'source': source,
                                'type': 'pdf',
                                'page': page_num
                            })
        except Exception as e:
            print(f"❌ Error processing PDF: {e}")
            return []

        print(f"✅ Extracted {len(chunks)} chunks from PDF")
        return chunks

    # ==========================================
    # TABLE PROCESSING (CSV/Excel)
    # ==========================================

    def process_table(self, table_path: str) -> List[Dict]:
        """Read a CSV/Excel file and turn it into text chunks.

        The first chunk summarises the table (row count and column names) and
        carries the whole frame as ``data``. Every row then becomes one chunk
        of the form ``"Row N: col=value, ..."`` with the raw values in
        ``row_data``. Files ending in ``.csv`` (any case) are read as CSV;
        anything else goes through ``pd.read_excel``. On error the error is
        printed and an empty list is returned.
        """
        print(f"📊 Processing table: {table_path}")
        chunks = []
        source = os.path.basename(table_path)

        try:
            # Read file based on extension (case-insensitive, e.g. "DATA.CSV")
            if table_path.lower().endswith('.csv'):
                df = pd.read_csv(table_path)
            else:
                df = pd.read_excel(table_path)

            # Store overall summary; str() guards against non-string headers.
            column_list = ', '.join(map(str, df.columns))
            summary = f"Table from {source} with {len(df)} rows and columns: {column_list}"
            chunks.append({
                'text': summary,
                'source': source,
                'type': 'table_summary',
                'data': df.to_dict()
            })

            # Convert each row to text
            for idx, row in df.iterrows():
                row_text = f"Row {idx+1}: " + ", ".join(f"{col}={val}" for col, val in row.items())
                chunks.append({
                    'text': row_text,
                    'source': source,
                    'type': 'table_row',
                    'row_data': row.to_dict()
                })

            print(f"✅ Extracted {len(chunks)} chunks from table")
            return chunks

        except Exception as e:
            print(f"❌ Error processing table: {e}")
            return []

    # ==========================================
    # IMAGE PROCESSING
    # ==========================================

    def process_image(self, image_path: str, caption: str = "") -> Optional[Dict]:
        """Load an image as RGB and wrap it in a chunk dict.

        The searchable ``text`` is ``"Image: <file name>"``, plus
        ``" - <caption>"`` when a caption is given. Returns None (after
        printing the error) if the image cannot be opened.
        """
        print(f"🖼️ Processing image: {image_path}")

        try:
            # The context manager closes the file handle; convert() returns an
            # independent, fully loaded copy that outlives it.
            with Image.open(image_path) as opened:
                image = opened.convert('RGB')

            # Create text description for the image
            image_text = f"Image: {os.path.basename(image_path)}"
            if caption:
                image_text += f" - {caption}"

            chunk = {
                'text': image_text,
                'source': os.path.basename(image_path),
                'type': 'image',
                'image_path': image_path,
                'image': image,
                'caption': caption
            }

            print(f"✅ Processed image successfully")
            return chunk

        except Exception as e:
            print(f"❌ Error processing image: {e}")
            return None

    # ==========================================
    # EMBEDDING AND INDEXING
    # ==========================================

    def build_vector_store(self):
        """(Re)build one FAISS index per non-empty modality from all stored chunks.

        Every call re-embeds everything, so the indices always mirror the
        current chunk lists.
        """
        print("\n🔨 Building vector stores...")

        if self.text_chunks:
            self.text_index = self._build_index(self.text_model, self.text_chunks)
            print(f"✅ Text index built: {len(self.text_chunks)} chunks")

        if self.table_chunks:
            self.table_index = self._build_index(self.text_model, self.table_chunks)
            print(f"✅ Table index built: {len(self.table_chunks)} chunks")

        # NOTE(review): images are indexed by their file name + caption *text*
        # through CLIP, not by pixel content. CLIP's encode() also accepts PIL
        # images if visual retrieval is wanted.
        if self.image_chunks:
            self.image_index = self._build_index(self.image_model, self.image_chunks)
            print(f"✅ Image index built: {len(self.image_chunks)} chunks")

        print("✅ All vector stores built successfully!")

    # ==========================================
    # RETRIEVAL
    # ==========================================

    def retrieve(self, query: str, top_k: int = 3) -> Dict[str, List]:
        """Return up to ``top_k`` nearest chunks per modality for ``query``.

        Result keys are ``'text'``, ``'table'`` and ``'image'``. A modality with
        no index yields an empty list. Each hit is a copy of the stored chunk
        with an added ``score`` (flat-L2 distance; smaller means closer).
        """
        results = {
            'text': [],
            'table': [],
            'image': []
        }

        # PDFs and tables share the MiniLM model, so encode the query once for both.
        text_query = None
        if self.text_index is not None or self.table_index is not None:
            text_query = np.asarray(self.text_model.encode([query]), dtype='float32')

        if self.text_index is not None:
            results['text'] = self._search(self.text_index, self.text_chunks, text_query, top_k)

        if self.table_index is not None:
            results['table'] = self._search(self.table_index, self.table_chunks, text_query, top_k)

        if self.image_index is not None:
            image_query = np.asarray(self.image_model.encode([query]), dtype='float32')
            results['image'] = self._search(self.image_index, self.image_chunks, image_query, top_k)

        return results

    # ==========================================
    # ANSWER GENERATION
    # ==========================================

    def generate_answer(self, query: str, retrieved_context: Dict) -> Tuple[str, List]:
        """Generate an answer from retrieved context (rule-based for demo).

        Uses at most the top two hits of each modality to build a labelled
        context string.

        Returns:
            ``(answer_text, images)``, where ``images`` are the PIL images of
            the (up to two) image hits used.
        """
        context_parts = []

        # Add text chunks
        for chunk in retrieved_context['text'][:2]:
            context_parts.append(f"[PDF - {chunk['source']}]: {chunk['text']}")

        # Add table chunks
        for chunk in retrieved_context['table'][:2]:
            context_parts.append(f"[Table - {chunk['source']}]: {chunk['text']}")

        # Add image descriptions
        images_to_show = []
        for chunk in retrieved_context['image'][:2]:
            context_parts.append(f"[Image - {chunk['source']}]: {chunk['caption'] or 'Related image'}")
            images_to_show.append(chunk['image'])

        # Build context
        context = "\n\n".join(context_parts)

        # Simple rule-based answer generation
        answer = self._create_answer(query, context, retrieved_context)

        return answer, images_to_show

    def _create_answer(self, query: str, context: str, retrieved: Dict) -> str:
        """Build the reply by matching keywords in ``query``.

        Branches are tried in order: income, deductions/80C, TDS, PF, then a
        generic fallback. Each embeds ``context`` verbatim. ``retrieved`` is
        only consulted to mention attached images in the PF branch.
        NOTE(review): matching is plain substring search, so short keys such as
        "pf" or "tds" can also fire inside unrelated words.
        """
        query_lower = query.lower()

        # Tax-specific queries
        if "taxable income" in query_lower or "total income" in query_lower:
            answer = "Based on the retrieved documents:\n\n"
            answer += f"Context:\n{context}\n\n"
            answer += "Your taxable income information has been extracted from the uploaded salary slips and Form 16. "
            answer += "Please review the retrieved context above for specific amounts."
            return answer

        elif "deduction" in query_lower or "80c" in query_lower:
            answer = "Tax Deductions Summary:\n\n"
            answer += f"Retrieved Information:\n{context}\n\n"
            answer += "Common Section 80C deductions include: EPF/PPF, ELSS, Life Insurance, Home Loan Principal, etc. "
            answer += "Check the retrieved documents for your specific investments."
            return answer

        elif "tds" in query_lower:
            answer = "TDS (Tax Deducted at Source) Information:\n\n"
            answer += f"{context}\n\n"
            answer += "TDS details are typically found in Form 26AS or your salary slips. "
            answer += "Compare your tax liability with TDS deducted to determine any balance payment needed."
            return answer

        elif "pf" in query_lower or "provident fund" in query_lower:
            answer = "Provident Fund Information:\n\n"
            answer += f"{context}\n\n"
            if retrieved['image']:
                answer += "📎 Related document images are displayed below."
            return answer

        else:
            # Generic answer
            answer = f"Query: {query}\n\n"
            answer += f"Retrieved Context:\n{context}\n\n"
            answer += "Based on the documents you've uploaded, the above information is most relevant to your query. "
            answer += "If you need more specific information, please try rephrasing your question."
            return answer

    # ==========================================
    # ADD DOCUMENTS
    # ==========================================

    def add_documents(self, pdf_files, table_files, image_files, image_captions):
        """Ingest uploaded files, rebuild all indices and return a status text.

        Args:
            pdf_files, table_files, image_files: lists of uploads (path strings
                or objects exposing ``.name``), or None.
            image_captions: optional comma-separated captions matched to
                ``image_files`` by position; missing captions become "".

        NOTE(review): chunks accumulate across calls (re-uploading a file
        indexes it again), and the chunk counts in the summary are cumulative
        while the file counts cover this call only.
        """
        print("\n📥 Adding documents to RAG system...")

        # Process PDFs
        for pdf_file in pdf_files or []:
            self.text_chunks.extend(self.process_pdf(self._file_path(pdf_file)))

        # Process tables
        for table_file in table_files or []:
            self.table_chunks.extend(self.process_table(self._file_path(table_file)))

        # Process images
        if image_files:
            captions_list = image_captions.split(',') if image_captions else []
            for idx, image_file in enumerate(image_files):
                caption = captions_list[idx].strip() if idx < len(captions_list) else ""
                chunk = self.process_image(self._file_path(image_file), caption)
                if chunk:
                    self.image_chunks.append(chunk)

        # Build indices
        self.build_vector_store()

        summary = f"""
        📊 Documents Added Successfully!

        - PDFs: {len(pdf_files or [])} files, {len(self.text_chunks)} text chunks
        - Tables: {len(table_files or [])} files, {len(self.table_chunks)} table chunks
        - Images: {len(image_files or [])} files, {len(self.image_chunks)} image chunks

        Total: {len(self.text_chunks) + len(self.table_chunks) + len(self.image_chunks)} chunks indexed

        You can now ask questions!
        """

        return summary

# ==========================================
# STEP 4: Create Gradio Interface
# ==========================================

# Initialize RAG system
# Module-level singleton used by the Gradio callbacks below. Constructing it
# loads both sentence-transformer models (downloaded on first run), so this
# line can take a while.
rag_system = MultimodalRAGSystem()

def upload_documents(pdf_files, table_files, image_files, image_captions):
    """Gradio callback for the upload button.

    Forwards the four upload widgets' values to ``rag_system.add_documents``
    and returns its status text. Any exception is turned into an error string
    shown in the status box instead of propagating to Gradio.
    """
    try:
        return rag_system.add_documents(pdf_files, table_files, image_files, image_captions)
    except Exception as exc:
        return f"❌ Error uploading documents: {str(exc)}"

def chat(message, history):
    """Gradio ChatInterface callback: answer ``message`` from the indexed documents.

    ``history`` is supplied by Gradio but not used. Returns a prompt to
    upload first when nothing is indexed yet; otherwise the rule-based answer
    built from the top-3 hits per modality. Errors become an error string.
    """
    has_documents = (
        rag_system.text_chunks or rag_system.table_chunks or rag_system.image_chunks
    )
    if not has_documents:
        return "⚠️ Please upload documents first using the 'Upload Documents' tab!"

    try:
        context = rag_system.retrieve(message, top_k=3)
        # generate_answer also returns matching images; this text-only chat ignores them.
        answer, _ = rag_system.generate_answer(message, context)
    except Exception as exc:
        return f"❌ Error processing query: {str(exc)}"

    return answer

# Create Gradio interface
# Layout is defined by statement order inside the nested context managers:
# an intro Markdown, then two tabs (upload and chat). `demo` is launched below.
with gr.Blocks(title="Multimodal RAG Tax Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 📊 Multimodal RAG Tax Filing Assistant

    Upload your tax documents (PDFs, CSVs, Images) and ask questions about your tax information!

    ### Sample Questions:
    - "What is my total taxable income this year?"
    - "List all Section 80C deductions I can claim"
    - "Did I submit enough investment proofs to avoid TDS?"
    - "Show me where my PF contribution is mentioned"
    """)

    with gr.Tabs():
        # Upload tab: three multi-file pickers plus a captions box, wired to
        # upload_documents(); its returned summary text fills the status box.
        with gr.Tab("📤 Upload Documents"):
            gr.Markdown("### Upload your tax-related documents")

            with gr.Row():
                pdf_input = gr.File(
                    label="📄 Upload PDFs (Salary Slips, Form 16, etc.)",
                    file_count="multiple",
                    file_types=[".pdf"]
                )

            with gr.Row():
                table_input = gr.File(
                    label="📊 Upload Tables (CSV/Excel)",
                    file_count="multiple",
                    file_types=[".csv", ".xlsx", ".xls"]
                )

            with gr.Row():
                image_input = gr.File(
                    label="🖼️ Upload Images (Investment Proofs, Receipts)",
                    file_count="multiple",
                    file_types=[".jpg", ".jpeg", ".png"]
                )
                # Captions are split on commas and matched to images by position.
                image_captions = gr.Textbox(
                    label="Image Captions (comma-separated, optional)",
                    placeholder="PF Receipt, LIC Premium, ELSS Investment"
                )

            upload_btn = gr.Button("📥 Upload and Process Documents", variant="primary")
            upload_output = gr.Textbox(label="Upload Status", lines=8)

            # Input order must match upload_documents(pdf_files, table_files, image_files, image_captions).
            upload_btn.click(
                fn=upload_documents,
                inputs=[pdf_input, table_input, image_input, image_captions],
                outputs=upload_output
            )

        # Chat tab: chat() returns text only, so retrieved images are not shown here.
        with gr.Tab("💬 Chat with Assistant"):
            gr.Markdown("### Ask questions about your tax documents")

            chatbot = gr.ChatInterface(
                fn=chat,
                examples=[
                    "What is my total taxable income this year?",
                    "List all Section 80C deductions I can claim",
                    "Did I submit enough investment proofs?",
                    "Show me my PF contribution details",
                    "What TDS was deducted from my salary?"
                ],
                title="Tax Filing Assistant Chat"
            )

# ==========================================
# STEP 5: Launch the Application
# ==========================================

_rule = "=" * 50
print(f"\n{_rule}")
print("🚀 LAUNCHING MULTIMODAL RAG TAX ASSISTANT")
print(_rule)

# share=True requests a temporary public *.gradio.live link (reachable from
# Colab); debug=True keeps the cell running so errors and logs stay visible.
demo.launch(debug=True, share=True)

📦 Installing required packages...
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.4/31.4 MB[0m [31m26.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m232.6/232.6 kB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
[?25h✅ All packages installed and imported successfully!
🚀 Initializing Multimodal RAG System...


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

✅ Text embedding model loaded


modules.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

0_CLIPModel/model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

0_CLIPModel/pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

tokenizer_config.json:   0%|          | 0.00/604 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


✅ Image embedding model loaded
✅ RAG System initialized successfully!

🚀 LAUNCHING MULTIMODAL RAG TAX ASSISTANT
Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://36435d5717aedf1c44.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)



📥 Adding documents to RAG system...
📄 Processing PDF: /tmp/gradio/66799a5914dbe7204cc644be31d7c7013572809bccd81e311282aa0cb34ea0bf/salary_slip_apr2024.pdf
✅ Extracted 4 chunks from PDF
📄 Processing PDF: /tmp/gradio/90af8476ffed6ac2feb87039a2cd6cf2f5991fef282836cd8a1d63ae38974c07/form16_fy2023-24.pdf
✅ Extracted 4 chunks from PDF
📄 Processing PDF: /tmp/gradio/78119d132cb20fe2f9c6f3f221736e4810fd778a28ea9663cb92f27870519fe6/interview react.pdf
✅ Extracted 16 chunks from PDF
📊 Processing table: /tmp/gradio/98a2cf8225fb733eec4a83eac4526268a78c301e40964e62b0efc61d3c4765a9/monthly_expenses_2023.xlsx
✅ Extracted 13 chunks from table
📊 Processing table: /tmp/gradio/62cb01f59dd22a91e9949ac01791afe0ffa60f252c560a46eded949a1ac3a99f/investments_fy2023-24.csv
✅ Extracted 8 chunks from table
🖼️ Processing image: /tmp/gradio/8c978c3385796d035be5d62ca82db84f50c5d22afc63e0ef97fc6da135ca9de7/pf_receipt.jpg
✅ Processed image successfully
🖼️ Processing image: /tmp/gradio/35dd6f004eeb332628e2fb51cf7bcf67d

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

✅ Text index built: 24 chunks


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

✅ Table index built: 21 chunks


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

✅ Image index built: 3 chunks
✅ All vector stores built successfully!
