# 🧾 TaxAlly - Real-Time Tax & Compliance Copilot

**Using Hugging Face + Qwen2.5-7B-Instruct**

This notebook runs entirely on a Colab GPU (T4 or A100) and requires no API keys.

---

## 1️⃣ Setup & Installation

In [None]:
# Install required packages (takes ~2-3 minutes)
!pip install -q transformers accelerate bitsandbytes torch gradio
print("✅ Packages installed")

In [None]:
# Check GPU availability
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 2️⃣ Load Qwen2.5-7B Model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

# Model selection - uncomment your preferred model
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"  # Best reasoning (recommended)
# MODEL_NAME = "microsoft/Phi-3.5-mini-instruct"  # Faster, smaller
# MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"  # Lightweight option

print(f"Loading {MODEL_NAME}...")
print("This takes 2-4 minutes on first run (downloading ~14GB)")

# 4-bit quantization for memory efficiency
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load model with quantization
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True
)

print(f"\n✅ Model loaded! Memory used: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

In [None]:
# LLM Generation Function

def generate_response(system_prompt: str, user_message: str, max_tokens: int = 1024) -> str:
    """Generate response using the loaded model."""
    
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message}
    ]
    
    # Apply chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    
    # Tokenize
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=4096)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    
    # Decode only new tokens
    response = tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[1]:],
        skip_special_tokens=True
    )
    
    return response.strip()

# Quick test
test = generate_response("You are a helpful assistant.", "Say hello in one sentence.")
print(f"Test response: {test}")
print("\n✅ LLM ready!")
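`generate_response` relies on `tokenizer.apply_chat_template` to turn the message list into one prompt string. For Qwen2.5 that template follows the ChatML layout. The stdlib-only sketch below approximates the string it builds for a system + user pair, so the role markers and the open assistant header added by `add_generation_prompt=True` are visible, and it mirrors the "decode only new tokens" slice. It assumes the plain ChatML form; the real template also handles tool definitions and a default system prompt, so keep calling `apply_chat_template` in the notebook. `render_chatml` and `new_tokens_only` are hypothetical helper names.

```python
def render_chatml(messages: list[dict], add_generation_prompt: bool = True) -> str:
    """Approximate the ChatML prompt produced by Qwen2.5's chat template.

    Assumption: each turn is wrapped as <|im_start|>{role}, a newline, the content,
    then <|im_end|> and a newline; the generation prompt is an open assistant header.
    """
    parts = [f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages]
    if add_generation_prompt:
        # Open an assistant turn so the model continues speaking as the assistant.
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


def new_tokens_only(output_ids: list[int], prompt_len: int) -> list[int]:
    """Mirror generate_response's slice: generate() returns prompt + new tokens."""
    return output_ids[prompt_len:]


prompt = render_chatml([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Say hello in one sentence."},
])
print(prompt)
print(new_tokens_only([101, 102, 103, 7, 8, 9], prompt_len=3))  # [7, 8, 9]
```

Because `model.generate` echoes the prompt ids before the continuation, slicing at the prompt length is what keeps the system prompt out of the decoded answer.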

## 3️⃣ Agent Core & Tools

In [None]:
# === AGENT CORE ===

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Optional
from datetime import datetime, timedelta, timezone
import json
import re


class AgentMode(Enum):
    INDIVIDUAL = "individual"
    MICRO_BUSINESS = "micro_business"


@dataclass
class AgentContext:
    user_id: str
    session_id: str
    mode: AgentMode
    entity_id: Optional[str] = None
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))


class BaseTool(ABC):
    @property
    @abstractmethod
    def name(self) -> str:
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        pass

    @abstractmethod
    def execute(self, params: dict, context: AgentContext) -> Any:
        pass


print("✅ Agent core loaded")
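Every tool plugs into the agent by subclassing `BaseTool` and implementing `name`, `description` and `execute`. Because these are declared abstract, Python refuses to instantiate a subclass that forgets one. The self-contained sketch below repeats trimmed copies of `AgentMode`, `AgentContext` and `BaseTool` so it runs on its own (run it in a scratch cell or a fresh interpreter, since it redefines those names), and adds a hypothetical `ThresholdLookup` tool that is not part of TaxAlly.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any


class AgentMode(Enum):
    INDIVIDUAL = "individual"


@dataclass
class AgentContext:
    user_id: str
    session_id: str
    mode: AgentMode


class BaseTool(ABC):
    @property
    @abstractmethod
    def name(self) -> str: ...

    @property
    @abstractmethod
    def description(self) -> str: ...

    @abstractmethod
    def execute(self, params: dict, context: AgentContext) -> Any: ...


class ThresholdLookup(BaseTool):
    """Hypothetical example tool: look up a named threshold (values in ₹)."""

    THRESHOLDS = {"gst_registration": 2_000_000, "advance_tax": 10_000}

    @property
    def name(self) -> str:
        return "threshold_lookup"

    @property
    def description(self) -> str:
        return "Return a compliance threshold by key"

    def execute(self, params: dict, context: AgentContext) -> dict:
        key = params.get("key", "")
        if key not in self.THRESHOLDS:
            return {"error": f"unknown threshold {key!r}"}
        return {"key": key, "value": self.THRESHOLDS[key]}


class IncompleteTool(BaseTool):
    """Forgets to implement execute(), so it cannot be instantiated."""

    @property
    def name(self) -> str:
        return "incomplete"

    @property
    def description(self) -> str:
        return "missing execute"


ctx = AgentContext(user_id="demo", session_id="demo", mode=AgentMode.INDIVIDUAL)
# Register tools by their `name`, the same shape as the notebook's `tools` dict.
registry = {tool.name: tool for tool in [ThresholdLookup()]}
print(registry["threshold_lookup"].execute({"key": "gst_registration"}, ctx))

incomplete_rejected = False
try:
    IncompleteTool()
except TypeError as exc:  # abstract method `execute` is still unimplemented
    incomplete_rejected = True
    print(f"TypeError: {exc}")
```

The abstract base class turns a missing method into an error at construction time rather than a failure halfway through a chat turn.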

In [None]:
# === COMPLIANCE TOOLS ===

class ComplianceRuleEngine(BaseTool):
    """Check compliance against Indian tax rules."""
    
    # FY 2024-25 thresholds (in ₹)
    THRESHOLDS = {
        "gst_registration": 2000000,  # 20 lakh (general / services limit)
        "gst_goods": 4000000,  # 40 lakh for exclusive suppliers of goods
        "advance_tax": 10000,  # advance tax applies if liability is ₹10,000 or more
        "presumptive_44ad": 20000000,  # 2 crore
        "presumptive_44ada": 5000000,  # 50 lakh
    }
    
    @property
    def name(self) -> str:
        return "compliance_check"
    
    @property
    def description(self) -> str:
        return "Check GST, Income Tax, TDS compliance and identify risks"
    
    def execute(self, params: dict, context: AgentContext) -> dict:
        profile = params.get("profile", {})
        turnover = params.get("turnover", 0)
        income = params.get("income", 0)
        
        return {
            "gst": self._check_gst(profile, turnover),
            "income_tax": self._check_income_tax(income),
            "advance_tax": self._check_advance_tax(income),
            "presumptive": self._check_presumptive(profile, turnover)
        }
    
    def _check_gst(self, profile: dict, turnover: float) -> dict:
        threshold = self.THRESHOLDS["gst_registration"]
        is_registered = profile.get("gst_registered", False)
        
        if turnover > threshold and not is_registered:
            return {
                "status": "🔴 NON-COMPLIANT",
                "risk": f"GST registration REQUIRED. Turnover ₹{turnover:,.0f} exceeds ₹{threshold:,.0f} threshold.",
                "action": "Apply for GST registration on gst.gov.in immediately.",
                "penalty_risk": "Late registration may attract penalty up to ₹25,000"
            }
        elif turnover > threshold * 0.8 and not is_registered:
            return {
                "status": "🟡 WARNING",
                "message": f"Approaching GST threshold. Current: ₹{turnover:,.0f}, Limit: ₹{threshold:,.0f}",
                "action": "Consider voluntary registration to claim ITC"
            }
        return {"status": "🟢 COMPLIANT", "message": "GST status OK"}
    
    def _check_income_tax(self, income: float) -> dict:
        if income > 300000:
            tax_old = self._calculate_tax_old(income)
            tax_new = self._calculate_tax_new(income)
            
            return {
                "filing_required": True,
                "tax_comparison": {
                    "old_regime": f"₹{tax_old:,.0f}",
                    "new_regime": f"₹{tax_new:,.0f}",
                    "recommended": "New Regime" if tax_new < tax_old else "Old Regime",
                    "savings": f"₹{abs(tax_new - tax_old):,.0f}"
                },
                # Simplification: tax-audit applicability (Sec 44AB) depends on
                # turnover/gross receipts, not on income as approximated here.
                "due_date": "July 31, 2025 (non-audit)" if income < 10000000 else "October 31, 2025 (audit)"
            }
        return {"filing_required": False, "message": "Income below taxable threshold"}
    
    def _check_advance_tax(self, income: float) -> dict:
        tax = self._calculate_tax_new(income)
        if tax >= self.THRESHOLDS["advance_tax"]:  # Sec 208: liability of ₹10,000 or more
            return {
                "required": True,
                "estimated_tax": f"₹{tax:,.0f}",
                "schedule": [
                    {"date": "June 15", "cumulative": "15%", "amount": f"₹{tax * 0.15:,.0f}"},
                    {"date": "Sept 15", "cumulative": "45%", "amount": f"₹{tax * 0.45:,.0f}"},
                    {"date": "Dec 15", "cumulative": "75%", "amount": f"₹{tax * 0.75:,.0f}"},
                    {"date": "Mar 15", "cumulative": "100%", "amount": f"₹{tax:,.0f}"}
                ],
                "next_deadline": self._get_next_advance_tax_deadline()
            }
        return {"required": False}
    
    def _check_presumptive(self, profile: dict, turnover: float) -> dict:
        income_source = profile.get("income_source", "business")
        
        if income_source == "profession":
            if turnover <= self.THRESHOLDS["presumptive_44ada"]:
                return {
                    "eligible": True,
                    "scheme": "44ADA (Professionals)",
                    "presumed_income": f"50% = ₹{turnover * 0.5:,.0f}",
                    "benefit": "No books of accounts required, file ITR-4"
                }
        else:
            if turnover <= self.THRESHOLDS["presumptive_44ad"]:
                return {
                    "eligible": True,
                    "scheme": "44AD (Business)",
                    "presumed_income": f"8% = ₹{turnover * 0.08:,.0f} (cash), 6% = ₹{turnover * 0.06:,.0f} (digital)",
                    "benefit": "No books of accounts required, file ITR-4"
                }
        return {"eligible": False, "message": "Turnover exceeds presumptive limits"}
    
    def _calculate_tax_new(self, income: float) -> float:
        """New-regime slab tax, FY 2024-25 (before Sec 87A rebate, surcharge and 4% cess)."""
        if income <= 300000: return 0
        elif income <= 700000: return (income - 300000) * 0.05
        elif income <= 1000000: return 20000 + (income - 700000) * 0.10
        elif income <= 1200000: return 50000 + (income - 1000000) * 0.15
        elif income <= 1500000: return 80000 + (income - 1200000) * 0.20
        else: return 140000 + (income - 1500000) * 0.30
    
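    def _calculate_tax_old(self, income: float, deductions: float = 150000) -> float:
        """Old-regime slab tax, assuming ₹1.5L of deductions (e.g. the Sec 80C cap) by default.

        Like the new-regime calculation, this excludes the Sec 87A rebate, surcharge and cess."""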
        taxable = max(0, income - deductions)
        if taxable <= 250000: return 0
        elif taxable <= 500000: return (taxable - 250000) * 0.05
        elif taxable <= 1000000: return 12500 + (taxable - 500000) * 0.20
        else: return 112500 + (taxable - 1000000) * 0.30
    
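    def _get_next_advance_tax_deadline(self) -> str:
        """Return the nearest upcoming advance-tax instalment date."""
        today = datetime.now().date()
        deadlines = [(6, 15, "Q1"), (9, 15, "Q2"), (12, 15, "Q3"), (3, 15, "Q4")]
        upcoming = []
        for m, d, q in deadlines:
            dt = today.replace(month=m, day=d)
            if dt < today:  # already passed this calendar year -> next year's instalment
                dt = dt.replace(year=today.year + 1)
            upcoming.append((dt, q))
        dt, q = min(upcoming)  # earliest upcoming date across all four instalments
        return f"{dt.strftime('%B %d, %Y')} ({q})"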


class CalendarTracker(BaseTool):
    """Track tax compliance deadlines."""
    
    DEADLINES = {
        "gstr1": {"day": 11, "desc": "GSTR-1 (Outward supplies)", "for": "gst"},
        "gstr3b": {"day": 20, "desc": "GSTR-3B (Summary return + payment)", "for": "gst"},
        "tds_payment": {"day": 7, "desc": "TDS payment for previous month", "for": "tds"},
    }
    
    @property
    def name(self) -> str:
        return "calendar"
    
    @property
    def description(self) -> str:
        return "Get upcoming tax compliance deadlines"
    
    def execute(self, params: dict, context: AgentContext) -> list:
        days_ahead = params.get("days_ahead", 30)
        gst_registered = params.get("gst_registered", False)
        
        today = datetime.now().date()  # compare calendar dates so a deadline due today still counts
        upcoming = []
        
        for key, dl in self.DEADLINES.items():
            if dl["for"] == "gst" and not gst_registered:
                continue
            
            day = dl["day"]  # all configured due days are <= 28, so valid in every month
            next_date = today.replace(day=day)
            if next_date < today:
                # Already passed this month: roll over to next month (or January next year)
                if today.month == 12:
                    next_date = next_date.replace(year=today.year + 1, month=1)
                else:
                    next_date = next_date.replace(month=today.month + 1)
            
            days_until = (next_date - today).days
            if days_until <= days_ahead:
                upcoming.append({
                    "deadline": key.upper(),
                    "description": dl["desc"],
                    "date": next_date.strftime("%Y-%m-%d"),
                    "days_until": days_until,
                    "urgency": "🔴 URGENT" if days_until <= 3 else "🟡 SOON" if days_until <= 7 else "🟢 OK"
                })
        
        return sorted(upcoming, key=lambda x: x["days_until"])


class TransactionCategorizer(BaseTool):
    """Categorize transactions for tax."""
    
    CATEGORIES = {
        "salary": {"patterns": ["salary", "payroll", "wages"], "head": "Salaries", "section": "Sec 17"},
        "professional": {"patterns": ["consulting", "freelance", "professional fee"], "head": "PGBP", "section": "Sec 44ADA"},
        "rent_received": {"patterns": ["rent received", "rental income"], "head": "House Property", "section": "Sec 24"},
        "interest": {"patterns": ["interest", "fd interest", "savings"], "head": "Other Sources", "section": "Sec 56"},
        "dividend": {"patterns": ["dividend"], "head": "Other Sources", "section": "Sec 56"},
        "capital_gain": {"patterns": ["stock sale", "mf redemption", "property sale"], "head": "Capital Gains", "section": "Sec 45"},
    }
    
    @property
    def name(self) -> str:
        return "categorize"
    
    @property
    def description(self) -> str:
        return "Categorize income/transaction for tax purposes"
    
    def execute(self, params: dict, context: AgentContext) -> dict:
        desc = params.get("description", "").lower()
        amount = params.get("amount", 0)
        
        for cat, info in self.CATEGORIES.items():
            if any(p in desc for p in info["patterns"]):
                result = {
                    "category": cat,
                    "income_head": info["head"],
                    "relevant_section": info["section"],
                    "confidence": 0.9
                }
                # Add specific guidance
                if cat == "professional":
                    result["tds"] = "TDS @10% u/s 194J if payment > ₹30,000"
                    result["gst"] = "18% GST applicable if turnover > ₹20L"
                elif cat == "interest":
                    result["tds"] = "TDS @10% u/s 194A if interest > ₹40,000 (₹50,000 for seniors)"
                return result
        
        return {"category": "unknown", "confidence": 0, "action": "Please provide more details"}


# Initialize tools
tools = {
    "compliance_check": ComplianceRuleEngine(),
    "calendar": CalendarTracker(),
    "categorize": TransactionCategorizer()
}

print("✅ Tools loaded: compliance_check, calendar, categorize")
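The `if/elif` chains in `_calculate_tax_new` and `_calculate_tax_old` encode progressive slabs: each rate applies only to the slice of income inside its band. A table-driven sketch makes that explicit and reproduces the figures the tool computes for the demo profile (₹15 lakh; no deductions in the new regime, the default ₹1.5L in the old one). The slab tables are copied from those methods and, like them, the sketch ignores the Sec 87A rebate, surcharge and cess. `slab_tax` and `advance_tax_schedule` are hypothetical helpers. Rates are kept as integer percentages so the arithmetic stays exact.

```python
# Slab tables copied from _calculate_tax_new / _calculate_tax_old:
# (upper bound of band in ₹, rate in %); None means "no upper bound".
NEW_REGIME_FY2024_25 = [
    (300_000, 0), (700_000, 5), (1_000_000, 10),
    (1_200_000, 15), (1_500_000, 20), (None, 30),
]
OLD_REGIME = [(250_000, 0), (500_000, 5), (1_000_000, 20), (None, 30)]


def slab_tax(taxable: int, slabs: list) -> float:
    """Progressive tax: each rate applies only to the income inside its band."""
    total = 0  # accumulated as rupees x percent, divided by 100 at the end
    lower = 0
    for upper, rate_pct in slabs:
        if taxable <= lower:
            break
        top = taxable if upper is None else min(taxable, upper)
        total += (top - lower) * rate_pct
        lower = upper if upper is not None else taxable
    return total / 100


def advance_tax_schedule(tax: float) -> list:
    """Cumulative instalments due by Jun 15, Sep 15, Dec 15 and Mar 15 (15/45/75/100%)."""
    return [tax * pct / 100 for pct in (15, 45, 75, 100)]


income = 1_500_000
new = slab_tax(income, NEW_REGIME_FY2024_25)
old = slab_tax(max(0, income - 150_000), OLD_REGIME)  # default ₹1.5L deductions
print(f"New regime: ₹{new:,.0f}")  # New regime: ₹140,000
print(f"Old regime: ₹{old:,.0f}")  # Old regime: ₹217,500
print("Advance tax (cumulative):", [f"₹{x:,.0f}" for x in advance_tax_schedule(new)])
```

Under these assumptions the new regime is ₹77,500 cheaper for the demo's ₹15 lakh freelancer, and the cumulative advance-tax instalments come to ₹21,000, ₹63,000, ₹105,000 and ₹140,000.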

## 4️⃣ TaxAlly Agent

In [None]:
class TaxAllyAgent:
    """AI Tax Compliance Copilot."""
    
    def __init__(self):
        self.tools = tools
        self.user_profile = {}
        self.conversation = []
    
    def _build_system_prompt(self) -> str:
        profile_str = json.dumps(self.user_profile, indent=2) if self.user_profile else "Not collected yet"
        
        return f"""You are TaxAlly, an expert AI tax compliance assistant for India.

CURRENT USER PROFILE:
{profile_str}

AVAILABLE TOOLS (use when you need accurate data):
1. compliance_check - Check GST/Income Tax/TDS status. Params: turnover (number), income (number)
2. calendar - Get upcoming deadlines. Params: days_ahead (number), gst_registered (boolean)  
3. categorize - Classify a transaction. Params: description (string), amount (number)

TO USE A TOOL, include this exact format in your response:
```tool
{{"tool": "tool_name", "params": {{"turnover": 1500000, "income": 1500000}}}}
```

IMPORTANT: params must contain numbers/booleans, NOT strings for numeric values.

GUIDELINES:
1. Extract profile info from user messages (income, turnover, GST status)
2. Use tools to get accurate compliance data - don't guess
3. Explain tax concepts simply
4. Always highlight risks clearly with severity
5. Provide specific, actionable next steps
6. Mention relevant sections/forms where applicable

DISCLAIMERS:
- You provide guidance, not legal/tax advice
- Complex cases need a Chartered Accountant
- Tax laws change - verify with official sources

Be concise but thorough. Use bullet points for clarity."""
    
    def _parse_tool_calls(self, response: str) -> list:
        """Extract tool calls from response."""
        tool_calls = []
        pattern = r"```tool\n(.*?)\n```"
        matches = re.findall(pattern, response, re.DOTALL)
        
        for match in matches:
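            try:
                data = json.loads(match)
            except json.JSONDecodeError:
                continue  # skip malformed JSON rather than failing the whole turn
            if isinstance(data, dict):  # a tool call must be a JSON object
                tool_calls.append(data)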
        return tool_calls
    
    def _execute_tools(self, tool_calls: list) -> str:
        """Execute tools and format results."""
        results = []
        ctx = AgentContext(user_id="demo", session_id="demo", mode=AgentMode.INDIVIDUAL)
        
        for call in tool_calls:
            tool_name = call.get("tool")
            params = call.get("params", {})
            
            # Ensure params is a dict
            if not isinstance(params, dict):
                params = {}
            
            # Ensure profile is always a dict, never a string
            if "profile" in params and isinstance(params["profile"], str):
                # LLM passed profile as string, replace with actual profile
                params["profile"] = self.user_profile
            elif "profile" not in params:
                params["profile"] = self.user_profile
            
            # Inject values from user_profile if not in params
            if "turnover" not in params and "turnover" in self.user_profile:
                params["turnover"] = self.user_profile["turnover"]
            if "income" not in params and "income" in self.user_profile:
                params["income"] = self.user_profile["income"]
            if "gst_registered" not in params:
                params["gst_registered"] = self.user_profile.get("gst_registered", False)
            
            # Ensure numeric values are actually numbers
            for key in ["turnover", "income", "amount", "days_ahead"]:
                if key in params and isinstance(params[key], str):
                    try:
                        params[key] = float(params[key].replace(",", ""))
                    except ValueError:
                        params[key] = 0
            
            if tool_name in self.tools:
                try:
                    result = self.tools[tool_name].execute(params, ctx)
                    results.append(f"[{tool_name}]:\n{json.dumps(result, indent=2)}")
                except Exception as e:
                    results.append(f"[{tool_name}]: Error - {str(e)}")
        
        return "\n\n".join(results)
    
    def _extract_profile(self, message: str):
        """Extract profile info from message."""
        msg = message.lower()
        
        # Entity type
        if any(x in msg for x in ["freelancer", "freelance", "consultant"]):
            self.user_profile["entity_type"] = "individual"
            self.user_profile["income_source"] = "profession"
        elif "business" in msg:
            self.user_profile["entity_type"] = "proprietorship"
            self.user_profile["income_source"] = "business"
        
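        # Income/turnover extraction. Patterns with explicit units are tried before the
        # bare-number fallback; otherwise "turnover is 25 lakh" would be read as ₹25.
        patterns = [
            (r"(\d+(?:\.\d+)?)\s*(?:lakh|lac|l)\s*(?:per|a|/)\s*(?:year|annum|yr)", 100000),
            (r"(\d+(?:\.\d+)?)\s*(?:crore|cr)\s*(?:per|a|/)\s*(?:year|annum|yr)", 10000000),
            (r"(\d+(?:\.\d+)?)\s*(?:lakh|lac)", 100000),
            (r"(\d+(?:\.\d+)?)\s*(?:crore|cr)\b", 10000000),
            (r"(?:income|earning|turnover|revenue)[^0-9]*(\d+(?:,\d+)*)", 1),
        ]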
        
        for pattern, multiplier in patterns:
            match = re.search(pattern, msg)
            if match:
                amount = float(match.group(1).replace(",", "")) * multiplier
                self.user_profile["income"] = amount
                self.user_profile["turnover"] = amount
                break
        
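        # GST status: check negative phrasings first, because "don't have gst"
        # contains "have gst" and "not gst registered" contains "gst registered".
        if "not gst" in msg or "no gst" in msg or "don't have gst" in msg:
            self.user_profile["gst_registered"] = False
        elif "gst registered" in msg or "have gst" in msg:
            self.user_profile["gst_registered"] = True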
    
    def chat(self, user_message: str) -> str:
        """Main chat interface."""
        # Extract profile updates
        self._extract_profile(user_message)
        
        # Generate response
        system_prompt = self._build_system_prompt()
        response = generate_response(system_prompt, user_message)
        
        # Check for tool calls
        tool_calls = self._parse_tool_calls(response)
        
        if tool_calls:
            tool_results = self._execute_tools(tool_calls)
            
            # Get final response with tool data
            followup = f"""Original question: {user_message}

TOOL RESULTS (use this accurate data in your response):
{tool_results}

Now provide a clear, helpful response using the tool results above. Be specific with numbers and dates."""
            
            response = generate_response(system_prompt, followup)
        
        self.conversation.append({"user": user_message, "assistant": response})
        return response


# Initialize agent
agent = TaxAllyAgent()
print("\n✅ TaxAlly Agent ready!")
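The tool protocol is plain text: the system prompt asks the model to emit a fenced `tool` block containing JSON, `_parse_tool_calls` extracts it with a regex, and `_execute_tools` repairs common slips such as numbers sent as strings. The standalone sketch below runs that round trip on a canned model reply, so it needs no GPU. The reply text is made up, and `parse_tool_calls` / `coerce_numbers` are hypothetical stand-ins for the agent's private methods.

```python
import json
import re

FENCE = "`" * 3  # built programmatically so this example can sit inside Markdown
TOOL_PATTERN = re.compile(FENCE + r"tool\n(.*?)\n" + FENCE, re.DOTALL)


def parse_tool_calls(response: str) -> list[dict]:
    """Return every well-formed JSON object found in a fenced tool block."""
    calls = []
    for body in TOOL_PATTERN.findall(response):
        try:
            data = json.loads(body)
        except json.JSONDecodeError:
            continue  # skip malformed JSON instead of failing the whole turn
        if isinstance(data, dict):
            calls.append(data)
    return calls


def coerce_numbers(params: dict, keys=("turnover", "income", "amount", "days_ahead")) -> dict:
    """Convert numeric strings such as "15,00,000" to floats; unparseable ones become 0."""
    fixed = dict(params)
    for key in keys:
        if isinstance(fixed.get(key), str):
            try:
                fixed[key] = float(fixed[key].replace(",", ""))
            except ValueError:
                fixed[key] = 0
    return fixed


# A canned reply: one valid tool block (with a string-typed number) and one broken block.
reply = (
    "Let me check your compliance status.\n"
    f'{FENCE}tool\n{{"tool": "compliance_check", "params": {{"income": "15,00,000"}}}}\n{FENCE}\n'
    f"{FENCE}tool\nnot json\n{FENCE}\n"
)
calls = parse_tool_calls(reply)
print(calls)
print(coerce_numbers(calls[0]["params"]))  # {'income': 1500000.0}
```

Stripping commas before `float()` handles Indian digit grouping ("15,00,000") as well as Western grouping, which matters because the model often copies amounts in the user's format.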

## 5️⃣ Demo Conversations

In [None]:
# Test the agent
print("=" * 70)
print("🧾 TaxAlly Demo")
print("=" * 70)

test_queries = [
    "I'm a freelance software developer earning 15 lakh per year. What are my tax obligations?",
]

for query in test_queries:
    print(f"\n👤 User: {query}")
    print("-" * 50)
    response = agent.chat(query)
    print(f"\n🤖 TaxAlly:\n{response}")
    print("=" * 70)
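`_extract_profile` turns phrases like "15 lakh per year" into rupees with an ordered list of regexes, where the first match wins. The sketch below isolates that idea in a hypothetical `parse_inr_amount` helper (not part of the notebook), using 1 lakh = ₹1,00,000 and 1 crore = ₹1,00,00,000, and orders the unit-bearing patterns before the bare-number fallback so that "turnover is 25 lakh" is not read as ₹25.

```python
import re
from typing import Optional

# Ordered (pattern, multiplier) pairs: the first pattern that matches wins, so the
# unit-bearing forms come before the bare-number fallback.
AMOUNT_PATTERNS = [
    (r"(\d+(?:\.\d+)?)\s*(?:lakhs?|lacs?)\b", 100_000),   # 1 lakh = 1,00,000
    (r"(\d+(?:\.\d+)?)\s*(?:crores?|cr)\b", 10_000_000),  # 1 crore = 1,00,00,000
    (r"(\d+(?:,\d+)+|\d+)", 1),                           # plain rupees: "50,000", "1,50,000"
]


def parse_inr_amount(text: str) -> Optional[float]:
    """Return the first rupee amount found in `text`, or None if there is none."""
    msg = text.lower()
    for pattern, multiplier in AMOUNT_PATTERNS:
        match = re.search(pattern, msg)
        if match:
            return float(match.group(1).replace(",", "")) * multiplier
    return None


for phrase in [
    "I'm a freelance software developer earning 15 lakh per year.",
    "My business turnover is 25 lakh.",
    "I received consulting income of 50,000.",
    "Turnover of 1.5 crore this year",
]:
    print(f"{phrase!r} -> {parse_inr_amount(phrase)}")
```

As in the notebook, a one-off receipt mentioned later ("2 lakh as consulting fees") would also be picked up as an amount, so a production extractor would need to tell annual figures apart from individual transactions.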

In [None]:
# Additional queries (run one at a time)

# Query 2: Deadlines
response = agent.chat("What are my upcoming tax deadlines?")
print(f"🤖 TaxAlly:\n{response}")
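`CalendarTracker` finds each monthly due date by taking this month's due day and rolling over to next month (and to January of the next year in December) once it has passed. A pure-function sketch with an explicit `today` makes that roll-over easy to check. `next_monthly_due` is a hypothetical helper; the due days (7, 11, 20) are the ones in `CalendarTracker.DEADLINES`, and in this sketch a deadline falling today counts as due (0 days away).

```python
from datetime import date


def next_monthly_due(today: date, due_day: int) -> date:
    """Next occurrence of `due_day` on or after `today`; a deadline due today counts."""
    if not 1 <= due_day <= 28:
        raise ValueError("due_day must be 1-28 so it exists in every month")
    candidate = today.replace(day=due_day)
    if candidate < today:  # already passed this month
        if today.month == 12:
            candidate = candidate.replace(year=today.year + 1, month=1)
        else:
            candidate = candidate.replace(month=today.month + 1)
    return candidate


# Due days taken from CalendarTracker.DEADLINES.
DUE_DAYS = {"TDS_PAYMENT": 7, "GSTR1": 11, "GSTR3B": 20}

today = date(2024, 12, 15)  # fixed date so the output is reproducible
for name, day in sorted(DUE_DAYS.items(), key=lambda kv: next_monthly_due(today, kv[1])):
    due = next_monthly_due(today, day)
    print(f"{name:12} {due.isoformat()} ({(due - today).days} days)")
```

With `today` fixed at 15 December 2024 this lists GSTR3B on 2024-12-20 (5 days), then TDS_PAYMENT on 2025-01-07 (23 days) and GSTR1 on 2025-01-11 (27 days), which exercises the December-to-January roll-over.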

In [None]:
# Query 3: GST
response = agent.chat("Do I need to register for GST?")
print(f"🤖 TaxAlly:\n{response}")
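`ComplianceRuleEngine._check_gst` compares turnover against the ₹20 lakh `gst_registration` threshold and warns from 80% of it; its `THRESHOLDS` table also lists a ₹40 lakh `gst_goods` limit that the method never uses. The sketch below shows one way to pick the limit by supply type. The `supply_type` argument and the `gst_status` helper are hypothetical, and the limits assume normal-category states (special-category states have lower ones).

```python
GST_LIMITS = {"services": 2_000_000, "goods": 4_000_000}  # ₹20L / ₹40L, as in THRESHOLDS
WARNING_FRACTION = 0.8  # same early-warning band as _check_gst


def gst_status(turnover: float, registered: bool, supply_type: str = "services") -> str:
    """Classify GST registration status for an aggregate annual turnover."""
    limit = GST_LIMITS[supply_type]
    if registered:
        return "COMPLIANT"
    if turnover > limit:
        return "NON-COMPLIANT"
    if turnover > limit * WARNING_FRACTION:
        return "WARNING"
    return "COMPLIANT"


for turnover, kind in [(2_500_000, "services"), (2_500_000, "goods"), (1_700_000, "services")]:
    print(f"₹{turnover:,} {kind:8} -> {gst_status(turnover, registered=False, supply_type=kind)}")
```

The same ₹25 lakh turnover that requires registration for a service provider stays below the goods limit, which is why the supply type is worth capturing in the profile.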

In [None]:
# Query 4: Transaction categorization
response = agent.chat("I received 2 lakh as consulting fees from a US client. How should I handle this?")
print(f"🤖 TaxAlly:\n{response}")

## 6️⃣ Gradio Web Interface

In [None]:
import gradio as gr

def chat_fn(message, history):
    response = agent.chat(message)
    return response

# Check Gradio version for compatibility
print(f"Gradio version: {gr.__version__}")

demo = gr.ChatInterface(
    fn=chat_fn,
    title="🧾 TaxAlly - AI Tax Compliance Copilot",
    description=f"Powered by {MODEL_NAME} | For Indian individuals & micro-businesses",
    examples=[
        "I'm a freelancer earning 12 lakh per year. What are my tax obligations?",
        "My business turnover is 25 lakh. Do I need GST?",
        "What are my upcoming tax deadlines?",
        "Should I opt for old or new tax regime?",
        "I received consulting income of 50,000. What's the tax treatment?"
    ],
    theme="soft"
)

demo.launch(share=True, debug=True)
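`chat_fn` ignores the `history` argument that `gr.ChatInterface` passes in, because `TaxAllyAgent` keeps its own profile and sends only the latest message to the model. If the model should see earlier turns, the history must first be converted into the role/content list that `apply_chat_template` expects. Depending on the Gradio version and the `type` setting, history arrives either as `[user, assistant]` pairs or as dicts with `role` and `content`; the hypothetical `history_to_messages` below accepts both.

```python
def history_to_messages(history: list, system_prompt: str, message: str) -> list[dict]:
    """Build a chat-template message list from Gradio history plus the new message."""
    messages = [{"role": "system", "content": system_prompt}]
    for turn in history:
        if isinstance(turn, dict):  # "messages"-style history
            messages.append({"role": turn["role"], "content": turn["content"]})
        else:  # "tuples"-style history: [user_text, assistant_text]
            user_text, assistant_text = turn
            if user_text:
                messages.append({"role": "user", "content": user_text})
            if assistant_text:
                messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})
    return messages


pairs = history_to_messages([["Hi", "Hello!"]], "You are TaxAlly.", "Any deadlines?")
dicts = history_to_messages(
    [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}],
    "You are TaxAlly.",
    "Any deadlines?",
)
print(pairs == dicts)  # True
```

Passing the result to the model would need a variant of `generate_response` that accepts a full message list, since the notebook's version takes only a system prompt and a single user message.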

---

## Notes

### Model Options
| Model | VRAM | Speed | Quality |
|-------|------|-------|--------|
| Qwen2.5-7B-Instruct | ~5GB (4-bit) | Medium | Best |
| Phi-3.5-mini-instruct | ~3GB (4-bit) | Fast | Good |
| Qwen2.5-3B-Instruct | ~2GB (4-bit) | Fastest | OK |
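The ~5GB figure for Qwen2.5-7B in 4-bit can be sanity-checked with back-of-the-envelope arithmetic. The sketch assumes (per the Qwen2.5-7B model card) about 7.61B parameters, of which 6.53B are non-embedding, and assumes that bitsandbytes quantizes the transformer's linear layers to 4 bits while the input embedding and output head stay in fp16. It counts weights only; per-block quantization constants, activations, the KV cache and CUDA overhead come on top. `weight_memory_gb` is a hypothetical helper.

```python
def weight_memory_gb(params: float, bits_per_param: float) -> float:
    """Weight storage in decimal GB: params x bits / 8 bytes."""
    return params * bits_per_param / 8 / 1e9


# Assumed figures for Qwen2.5-7B (model card): total and non-embedding parameters.
TOTAL_PARAMS = 7.61e9
NON_EMBEDDING_PARAMS = 6.53e9
EMBEDDING_PARAMS = TOTAL_PARAMS - NON_EMBEDDING_PARAMS  # input embedding + output head

fp16_total = weight_memory_gb(TOTAL_PARAMS, 16)
nf4_mixed = weight_memory_gb(NON_EMBEDDING_PARAMS, 4) + weight_memory_gb(EMBEDDING_PARAMS, 16)

print(f"fp16 weights:           ~{fp16_total:.1f} GB")
print(f"4-bit body + fp16 emb:  ~{nf4_mixed:.1f} GB")
```

This gives roughly 15.2 GB for 16-bit weights (about the size of the checkpoint fetched on first run) and roughly 5.4 GB once the transformer body is stored in 4 bits.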

### Architecture
- **LLM**: Reasoning & response generation
- **Tools**: Deterministic tax calculations
- **Profile**: In-memory, extracted from conversation
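This split (the LLM decides which tool to call, deterministic Python computes the numbers, and the LLM phrases the answer) is the two-pass loop in `TaxAllyAgent.chat`. The sketch below reproduces that control flow with a scripted stand-in for `generate_response`, so the orchestration can be tested without loading a model. `fake_llm`, `TOOLS` and `run_turn` are hypothetical names for illustration, and the `gst_check` tool is a toy.

```python
import json
import re

FENCE = "`" * 3  # avoids a literal triple backtick inside this Markdown example
TOOL_RE = re.compile(FENCE + r"tool\n(.*?)\n" + FENCE, re.DOTALL)

# Deterministic "tools": plain functions keyed by name (stand-ins for BaseTool subclasses).
TOOLS = {
    "gst_check": lambda p: {"turnover": p["turnover"], "needs_gst": p["turnover"] > 2_000_000},
}


def fake_llm(system_prompt: str, user_message: str) -> str:
    """Scripted stand-in for generate_response: request a tool, then phrase its result."""
    if "TOOL RESULTS:\n" not in user_message:
        call = {"tool": "gst_check", "params": {"turnover": 2_500_000}}
        return f"{FENCE}tool\n{json.dumps(call)}\n{FENCE}"
    result = json.loads(user_message.split("TOOL RESULTS:\n", 1)[1])[0]
    verdict = "required" if result["needs_gst"] else "not required"
    return f"Turnover ₹{result['turnover']:,}: GST registration {verdict}."


def run_turn(user_message: str, llm=fake_llm) -> str:
    system_prompt = "You are TaxAlly."
    # Pass 1: the model either answers directly or asks for tools.
    first = llm(system_prompt, user_message)
    calls = []
    for body in TOOL_RE.findall(first):
        try:
            calls.append(json.loads(body))
        except json.JSONDecodeError:
            pass  # ignore malformed tool blocks
    if not calls:
        return first
    # Run the requested tools in plain Python, then pass 2: phrase the results.
    results = [
        TOOLS[c["tool"]](c["params"])
        for c in calls
        if isinstance(c, dict) and c.get("tool") in TOOLS
    ]
    followup = f"Original question: {user_message}\n\nTOOL RESULTS:\n{json.dumps(results)}"
    return llm(system_prompt, followup)


print(run_turn("Do I need to register for GST?"))  # Turnover ₹2,500,000: GST registration required.
```

Keeping the arithmetic in Python means the second pass only has to restate numbers it was given, which is the reason the notebook routes compliance checks through tools instead of asking the model to calculate.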

### Limitations
- No persistent storage (resets on restart)
- Document parsing not implemented
- Single user session