# 🧾 TaxAlly API Server (Colab → Webapp)

This notebook runs the TaxAlly HuggingFace server on Colab and exposes it via ngrok.

Your webapp can then connect to it!

## Setup
1. Enable GPU: Runtime → Change runtime type → T4 GPU
2. Run all cells
3. Copy the ngrok URL to your webapp's `.env`

In [None]:
#@title 1Ô∏è‚É£ Install Dependencies
# Model-side packages used by the model-loading cell (transformers, torch and
# the bitsandbytes 4-bit quantization config).
!pip install -q transformers accelerate bitsandbytes torch
# Serving-side packages: the FastAPI app, the uvicorn server, the pyngrok
# tunnel, and nest-asyncio (applied before uvicorn.run in the last cell).
!pip install -q fastapi uvicorn pyngrok nest-asyncio
# pydantic supplies the BaseModel request/response classes of the API cell.
!pip install -q pydantic

print("‚úÖ Dependencies installed!")

In [None]:
#@title 2Ô∏è‚É£ Setup ngrok (Get free token at ngrok.com)
# Paste the auth token from https://dashboard.ngrok.com/get-started/your-authtoken
# into the form field below. An empty value skips ngrok configuration.

NGROK_AUTH_TOKEN = ""  #@param {type:"string"}

if not NGROK_AUTH_TOKEN:
    # No token supplied: nothing is configured, only guidance is printed.
    print("‚ö†Ô∏è No ngrok token - will use Colab's default tunneling")
    print("Get free token at: https://dashboard.ngrok.com/get-started/your-authtoken")
else:
    # Import lazily so pyngrok is only touched when a token is present.
    from pyngrok import ngrok
    ngrok.set_auth_token(NGROK_AUTH_TOKEN)
    print("‚úÖ ngrok configured!")

In [None]:
#@title 3Ô∏è‚É£ Load HuggingFace Model (Qwen2.5-7B-Instruct)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Hugging Face Hub id of the checkpoint loaded below.
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"

print(f"üîÑ Loading {MODEL_NAME}...")
# Report the first CUDA device's name, or "CPU" when CUDA is unavailable.
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

# 4-bit quantization for memory efficiency
# (NF4 quant type, float16 compute dtype, double quantization enabled).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

# `tokenizer` and `model` become notebook-level globals read by
# generate_response() in a later cell.
# NOTE(review): trust_remote_code=True lets the loader run Python code shipped
# in the model repository - confirm it is actually required for this checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)

# The figure printed is torch.cuda.memory_allocated() divided by 1e9.
print(f"‚úÖ Model loaded! Memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")

In [None]:
#@title 4Ô∏è‚É£ Define TaxAlly Tools
from datetime import datetime, timezone
from typing import Dict, Any, List, Optional

# Tax slabs (AY 2025-26)
# Each slab is a dict with "min"/"max" income bounds and a "rate" in percent
# (calculate_slab_tax divides the rate by 100). The slabs are consumed in list
# order and only the width max - min is used, so they must stay contiguous and
# start at 0. Amounts are presumably in rupees - TODO confirm.
OLD_REGIME_SLABS = [
    {"min": 0, "max": 250000, "rate": 0},
    {"min": 250000, "max": 500000, "rate": 5},
    {"min": 500000, "max": 1000000, "rate": 20},
    {"min": 1000000, "max": float('inf'), "rate": 30}
]

NEW_REGIME_SLABS = [
    {"min": 0, "max": 300000, "rate": 0},
    {"min": 300000, "max": 700000, "rate": 5},
    {"min": 700000, "max": 1000000, "rate": 10},
    {"min": 1000000, "max": 1200000, "rate": 15},
    {"min": 1200000, "max": 1500000, "rate": 20},
    {"min": 1500000, "max": float('inf'), "rate": 30}
]

# States for which check_gst_compliance applies the lower registration
# threshold. Matching is done against these exact spellings.
# NOTE(review): verify this list against the current GST notifications.
SPECIAL_STATES = ['Arunachal Pradesh', 'Assam', 'Manipur', 'Meghalaya', 'Mizoram', 'Nagaland', 'Sikkim', 'Tripura']

def calculate_slab_tax(income: float, slabs: list) -> tuple:
    """Apply progressive slab rates to *income*.

    The slabs are walked in list order; each one absorbs at most its own
    width (``max - min``) of the income that is still unallocated and taxes
    that portion at its percentage ``rate``.

    Args:
        income: Taxable income to distribute over the slabs.
        slabs: Ordered list of dicts with ``"min"``, ``"max"`` and ``"rate"``.

    Returns:
        ``(total_tax, breakdown)`` where ``breakdown`` lists
        ``{"slab": "<rate>%", "tax": amount}`` for every slab that produced
        a non-zero tax.
    """
    tax_lines = []
    running_total = 0
    unallocated = income

    for band in slabs:
        # Stop as soon as the whole income has been assigned to a slab.
        if unallocated <= 0:
            break

        band_width = band["max"] - band["min"]
        portion = min(unallocated, band_width)
        band_tax = portion * band["rate"] / 100

        running_total += band_tax
        unallocated -= portion

        # Zero-rated slabs are left out of the breakdown.
        if band_tax > 0:
            tax_lines.append({"slab": f"{band['rate']}%", "tax": band_tax})

    return running_total, tax_lines

class TaxAllyTools:
    """Stateless tax and compliance calculators exposed as TaxAlly tools.

    Every public method is a ``@staticmethod`` that returns a plain dict of
    strings, numbers and booleans, so a result can be passed straight to
    ``json.dumps`` or returned from an HTTP endpoint.
    """

    @staticmethod
    def calculate_income_tax(income: float, deductions_80c: float = 0, deductions_80d: float = 0) -> dict:
        """Compare the total tax payable under the old and the new regime.

        Args:
            income: Gross annual income.
            deductions_80c: Section 80C amount claimed (old regime only);
                clamped to the range 0..150,000.
            deductions_80d: Section 80D amount claimed (old regime only);
                clamped to the range 0..50,000.

        Returns:
            A dict with ``"old_regime"`` and ``"new_regime"`` entries (each
            holding ``"totalTax"`` including 4% cess and ``"effectiveRate"``
            in percent of ``income``) plus a ``"recommendation"`` string.
        """
        # The raised 75,000 standard deduction applies only to the new regime;
        # the old regime keeps the 50,000 figure.
        old_std_deduction = 50000
        new_std_deduction = 75000

        # Old regime: capped deductions. Negative claims are treated as zero
        # so they cannot increase the taxable income.
        claimed_80c = min(max(deductions_80c, 0), 150000)
        claimed_80d = min(max(deductions_80d, 0), 50000)
        old_taxable = max(0, income - (claimed_80c + claimed_80d + old_std_deduction))
        old_tax, _ = calculate_slab_tax(old_taxable, OLD_REGIME_SLABS)
        if old_taxable <= 500000:
            # Rebate of up to 12,500 for taxable income up to 5 lakh.
            old_tax = max(0, old_tax - 12500)
        old_total = old_tax + old_tax * 0.04  # add 4% cess

        # New regime: only the standard deduction is allowed.
        new_taxable = max(0, income - new_std_deduction)
        new_tax, _ = calculate_slab_tax(new_taxable, NEW_REGIME_SLABS)
        if new_taxable <= 700000:
            # Rebate of up to 25,000 for taxable income up to 7 lakh.
            new_tax = max(0, new_tax - 25000)
        new_total = new_tax + new_tax * 0.04  # add 4% cess

        if old_total == new_total:
            # Avoid claiming that one regime "saves 0" when they are equal.
            recommendation = "Both regimes result in the same tax"
        else:
            cheaper = "Old" if old_total < new_total else "New"
            savings = abs(old_total - new_total)
            recommendation = f"{cheaper} Regime saves ₹{savings:,.0f}"

        return {
            "old_regime": {"totalTax": old_total, "effectiveRate": (old_total / income * 100) if income > 0 else 0},
            "new_regime": {"totalTax": new_total, "effectiveRate": (new_total / income * 100) if income > 0 else 0},
            "recommendation": recommendation
        }

    @staticmethod
    def check_gst_compliance(turnover: float, is_service: bool, state: str) -> dict:
        """Check whether *turnover* crosses the GST registration threshold.

        Args:
            turnover: Annual aggregate turnover.
            is_service: True for a service provider, False for goods.
            state: State name; compared with ``SPECIAL_STATES`` ignoring case
                and surrounding whitespace.

        Returns:
            A dict with ``"registrationRequired"``, the applied ``"threshold"``,
            a human-readable ``"limitDescription"`` and a ``"recommendedAction"``.
        """
        # Callers (including the LLM) do not reliably send the exact spelling,
        # so the comparison is case- and whitespace-insensitive.
        normalized = state.strip().casefold() if isinstance(state, str) else ""
        is_special = normalized in {name.casefold() for name in SPECIAL_STATES}

        if is_service:
            threshold = 1000000 if is_special else 2000000
        else:
            threshold = 2000000 if is_special else 4000000

        required = turnover > threshold

        if required:
            action = "Register for GST immediately"
        else:
            headroom_lakhs = (threshold - turnover) / 100000
            action = f"No registration needed. Headroom: ₹{headroom_lakhs:.1f}L"

        state_kind = "Special" if is_special else "Regular"
        return {
            "registrationRequired": required,
            "threshold": threshold,
            "limitDescription": f"₹{threshold/100000:.0f} Lakhs ({state_kind} State)",
            "recommendedAction": action
        }

    @staticmethod
    def check_presumptive(gross_receipts: float, business_type: str) -> dict:
        """Check 44AD/44ADA presumptive-taxation eligibility.

        Args:
            gross_receipts: Gross receipts or turnover for the year.
            business_type: ``"professional"`` (any case, surrounding spaces
                ignored) selects section 44ADA; anything else selects 44AD.

        Returns:
            A dict with the ``"section"``, whether the receipts are within the
            limit (``"eligible"``), the ``"deemedIncome"`` and an
            ``"explanation"`` of how it was derived. The deemed income is
            computed even when ``"eligible"`` is False.
        """
        is_professional = (
            isinstance(business_type, str)
            and business_type.strip().casefold() == "professional"
        )
        section = "44ADA" if is_professional else "44AD"
        limit = 7500000 if is_professional else 30000000

        eligible = gross_receipts <= limit

        if is_professional:
            deemed_income = gross_receipts * 0.50
            explanation = f"50% of ₹{gross_receipts:,.0f} = ₹{deemed_income:,.0f}"
        else:
            deemed_income = gross_receipts * 0.06  # Assuming mostly digital
            explanation = f"6% of ₹{gross_receipts:,.0f} = ₹{deemed_income:,.0f}"

        return {
            "section": section,
            "eligible": eligible,
            "deemedIncome": deemed_income,
            "explanation": explanation
        }

    @staticmethod
    def _next_annual_date(today, month: int, day: int):
        """Return the first ``month``/``day`` date that is on or after *today*."""
        candidate = today.replace(month=month, day=day)
        if candidate < today:
            candidate = candidate.replace(year=today.year + 1)
        return candidate

    @staticmethod
    def _next_monthly_date(today, day: int):
        """Return the first date with day-of-month *day* on or after *today*."""
        if today.day <= day:
            return today.replace(day=day)
        if today.month == 12:
            return today.replace(year=today.year + 1, month=1, day=day)
        return today.replace(month=today.month + 1, day=day)

    @staticmethod
    def get_deadlines(has_gst: bool = False, *, today: Optional[datetime] = None) -> dict:
        """Get the upcoming tax deadlines, nearest first.

        Each deadline is the next occurrence on or after the reference day,
        so the list never contains a date that has already passed.

        Args:
            has_gst: Also include the monthly GSTR-1 (11th) and GSTR-3B
                (20th) due dates.
            today: Reference moment; only its calendar date is used.
                Defaults to the current time in IST (UTC+05:30).

        Returns:
            ``{"upcoming_deadlines": [...], "total_count": n}`` where every
            entry has ``"name"``, ``"date"`` (``YYYY-MM-DD``), ``"category"``,
            ``"daysUntil"`` (whole calendar days) and ``"urgency"``
            (``CRITICAL`` <= 7 days, ``WARNING`` <= 15 days, else ``NORMAL``).
        """
        if today is None:
            # Local import: the notebook only imports datetime and timezone.
            from datetime import timedelta
            # Indian due dates are calendar days in IST, while the host
            # running this notebook may well be on UTC.
            today = datetime.now(timezone(timedelta(hours=5, minutes=30)))
        # Work on dates, not timestamps, so the time of day cannot shave a
        # day off the countdown.
        today_date = today.date()

        next_annual = TaxAllyTools._next_annual_date
        next_monthly = TaxAllyTools._next_monthly_date

        deadlines = [
            {"name": "Advance Tax Q4", "date": next_annual(today_date, 3, 15), "category": "Income Tax"},
            {"name": "ITR Filing", "date": next_annual(today_date, 7, 31), "category": "Income Tax"},
        ]

        if has_gst:
            deadlines.append({"name": "GSTR-3B", "date": next_monthly(today_date, 20), "category": "GST"})
            deadlines.append({"name": "GSTR-1", "date": next_monthly(today_date, 11), "category": "GST"})

        # Calculate days until, then replace the date object by its ISO string.
        for d in deadlines:
            due = d["date"]
            d["date"] = due.isoformat()
            d["daysUntil"] = (due - today_date).days
            d["urgency"] = "CRITICAL" if d["daysUntil"] <= 7 else "WARNING" if d["daysUntil"] <= 15 else "NORMAL"

        return {"upcoming_deadlines": sorted(deadlines, key=lambda x: x["daysUntil"]), "total_count": len(deadlines)}

# Shared instance used by the chat tool dispatch and the HTTP endpoints. All
# TaxAllyTools methods are static, so the instance itself carries no state.
tools = TaxAllyTools()
print("‚úÖ TaxAlly Tools initialized!")

In [None]:
#@title 5Ô∏è‚É£ Define LLM Chat Function
import json
import re

# System prompt sent as the first chat message on every request. It lists the
# four tools and the fenced ```tool JSON format that generate_response() looks
# for with a regex, so the tool names and the fence syntax written here must
# stay in sync with that parser.
SYSTEM_PROMPT = """You are TaxAlly, an expert AI tax compliance assistant for Indian individuals and micro-businesses.

You have access to these tools:
- calculate_income_tax(income, deductions_80c, deductions_80d): Compare old vs new tax regime
- check_gst_compliance(turnover, is_service, state): Check GST registration requirement
- check_presumptive(gross_receipts, business_type): Check 44AD/44ADA eligibility
- get_deadlines(has_gst): Get upcoming tax deadlines

When you need to use a tool, output it in this format:
```tool
{"tool": "tool_name", "params": {"param1": value1}}
```

Guidelines:
- Be precise and cite specific sections/rules
- Always explain your reasoning
- Flag compliance risks clearly
- Recommend professional consultation for complex cases
"""

def _execute_tool_call(tool_name, params) -> dict:
    """Run one tool requested by the model and return its result dict.

    Never raises for bad model output: an unknown tool name, a non-object
    ``params`` value or arguments the tool rejects all yield ``{"error": ...}``.
    """
    dispatch = {
        "calculate_income_tax": tools.calculate_income_tax,
        "check_gst_compliance": tools.check_gst_compliance,
        "check_presumptive": tools.check_presumptive,
        "get_deadlines": tools.get_deadlines,
    }
    # The isinstance guard keeps an unhashable name (e.g. a JSON list) from
    # raising inside dict.get.
    handler = dispatch.get(tool_name) if isinstance(tool_name, str) else None
    if handler is None:
        return {"error": f"Unknown tool: {tool_name}"}

    if params is None:
        params = {}
    if not isinstance(params, dict):
        return {"error": f"Invalid params for {tool_name}: expected a JSON object"}

    try:
        return handler(**params)
    except (TypeError, ValueError, KeyError, AttributeError, ArithmeticError) as exc:
        # Wrong or missing argument names, or wrongly typed values produced
        # by the model.
        return {"error": f"Tool {tool_name} failed: {exc}"}


def generate_response(user_message: str, profile: Optional[dict] = None, history: Optional[list] = None) -> dict:
    """Generate a reply with the LLM and execute any tool calls it emits.

    Args:
        user_message: The new user turn.
        profile: Optional user profile; appended to the system prompt as JSON.
        history: Optional list of ``{"user": ..., "assistant": ...}`` turns;
            only the last five are replayed to the model.

    Returns:
        ``{"response": text, "tool_calls": [...]}``. Every well-formed
        ```` ```tool ```` block in the generated text is replaced by a
        "Tool Result" line, and each executed call is reported as
        ``{"tool", "params", "result"}``. The model is not called a second
        time with the tool results.
    """
    # Build messages
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    if profile:
        messages[0]["content"] += f"\nUser Profile: {json.dumps(profile)}"

    if history:
        for h in history[-5:]:  # Last 5 turns
            messages.append({"role": "user", "content": h.get("user", "")})
            messages.append({"role": "assistant", "content": h.get("assistant", "")})

    messages.append({"role": "user", "content": user_message})

    # Format for Qwen
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Generate
    outputs = model.generate(
        **inputs,
        max_new_tokens=1024,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

    # Slice off the prompt tokens so only the newly generated text is decoded.
    response_text = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)

    # Check for tool calls
    tool_calls = []
    tool_pattern = r"```tool\n(.*?)\n```"
    matches = re.findall(tool_pattern, response_text, re.DOTALL)

    for match in matches:
        try:
            tool_data = json.loads(match)
        except json.JSONDecodeError:
            # Malformed JSON: leave the block in the reply untouched.
            continue
        if not isinstance(tool_data, dict):
            # Valid JSON that is not an object cannot describe a tool call.
            continue

        tool_name = tool_data.get("tool")
        params = tool_data.get("params", {})
        result = _execute_tool_call(tool_name, params)

        tool_calls.append({"tool": tool_name, "params": params, "result": result})

        # Replace the tool block in the reply by its result.
        response_text = response_text.replace(f"```tool\n{match}\n```", f"\n**Tool Result ({tool_name}):** {json.dumps(result)}\n")

    return {
        "response": response_text,
        "tool_calls": tool_calls
    }

# Test
# Smoke test: runs one real generation every time this cell executes, so it
# needs the model from the loading cell and takes as long as a normal request.
test_response = generate_response("What is GST threshold for services in Maharashtra?")
# Only the first 200 characters of the reply are shown.
print("Test response:", test_response["response"][:200])

In [None]:
#@title 6Ô∏è‚É£ Create FastAPI Server
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List, Dict, Any
import uuid

# ASGI application; the endpoints below register themselves on it through
# decorators and the last cell serves it with uvicorn.
app = FastAPI(title="TaxAlly API", version="1.0.0")

# Enable CORS for webapp
# NOTE(review): every origin, method and header is allowed, with credentials
# enabled. Convenient while the tunnel URL changes per run, but this should
# presumably be narrowed to the webapp's origin before non-demo use - confirm.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Request/Response models
class ChatRequest(BaseModel):
    """Body of ``POST /chat``."""
    # The user's new message.
    message: str
    # Conversation id; when omitted the chat handler generates a fresh UUID.
    session_id: Optional[str] = None
    # Accepted but not read by any handler in this file.
    user_id: Optional[str] = None
    # Free-form profile; forwarded to generate_response(), which appends it to
    # the system prompt as JSON.
    profile: Optional[Dict[str, Any]] = None

class ChatResponse(BaseModel):
    """Body returned by ``POST /chat``."""
    # Generated reply text, with tool blocks replaced by their results.
    response: str
    # Session id to send back with the next message of the same conversation.
    session_id: str
    # One {"tool", "params", "result"} dict per executed tool call.
    tool_calls: List[dict] = []
    # Follow-up prompts; the chat handler currently returns a fixed list.
    suggestions: List[str] = []

class ToolRequest(BaseModel):
    """Body of ``POST /tools/execute``."""
    # Name of the tool to run, e.g. "calculate_income_tax".
    tool: str
    # Keyword arguments expanded into the tool call (``**params``).
    params: Dict[str, Any] = {}

# In-memory session storage
# Maps session_id -> list of {"user": ..., "assistant": ...} turns; the chat
# handler trims each list to its last 10 turns. Nothing in this file removes a
# session, and the dict is lost when the process stops.
sessions = {}

@app.get("/health")
async def health():
    """Liveness probe reporting the served model id and the compute device."""
    # Name of the first CUDA device when CUDA is available, otherwise "CPU".
    if torch.cuda.is_available():
        device_label = torch.cuda.get_device_name(0)
    else:
        device_label = "CPU"

    payload = {
        "status": "healthy",
        "model": "Qwen/Qwen2.5-7B-Instruct",
        "gpu": device_label,
    }
    return payload

@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Answer one chat message and record the turn in the session history.

    NOTE: generate_response() is a synchronous call made directly from this
    coroutine, so the handler does not yield to the event loop while the
    model is generating.
    """
    # Reuse the caller's session id, or mint one for a new conversation.
    sid = request.session_id if request.session_id else str(uuid.uuid4())

    # Turns recorded so far. A new session is only stored after a successful
    # generation.
    prior_turns = sessions.get(sid, [])

    llm_result = generate_response(request.message, request.profile, prior_turns)
    reply = llm_result["response"]

    # Record this exchange and keep only the 10 most recent turns.
    prior_turns.append({"user": request.message, "assistant": reply})
    sessions[sid] = prior_turns[-10:]

    static_suggestions = ["Ask about GST", "Calculate tax", "Check deadlines"]
    return ChatResponse(
        response=reply,
        session_id=sid,
        tool_calls=llm_result.get("tool_calls", []),
        suggestions=static_suggestions,
    )

@app.get("/tools")
async def list_tools():
    """List the name and a one-line description of every available tool."""
    catalog = (
        ("calculate_income_tax", "Compare old vs new tax regime"),
        ("check_gst_compliance", "Check GST registration requirement"),
        ("check_presumptive", "Check 44AD/44ADA eligibility"),
        ("get_deadlines", "Get upcoming tax deadlines"),
    )
    entries = [{"name": name, "description": summary} for name, summary in catalog]
    return {"tools": entries}

@app.post("/tools/execute")
async def execute_tool(request: ToolRequest):
    """Run one TaxAlly tool directly, without going through the LLM.

    Returns:
        ``{"success": True, "tool": ..., "result": ...}`` on success, or
        ``{"success": False, "tool": ..., "error": ...}`` when the tool itself
        raises (for example on wrong or missing parameters).

    Raises:
        HTTPException: 400 when ``request.tool`` is not a known tool name.
    """
    handlers = {
        "calculate_income_tax": tools.calculate_income_tax,
        "check_gst_compliance": tools.check_gst_compliance,
        "check_presumptive": tools.check_presumptive,
        "get_deadlines": tools.get_deadlines,
    }

    handler = handlers.get(request.tool)
    if handler is None:
        # Raised outside the try block: HTTPException is itself an Exception,
        # so inside it the broad handler would turn the intended 400 into a
        # 200 response with success=False.
        raise HTTPException(status_code=400, detail=f"Unknown tool: {request.tool}")

    try:
        result = handler(**request.params)
    except Exception as e:
        # Best-effort boundary: report any tool failure in the response body.
        return {"success": False, "tool": request.tool, "error": str(e)}

    return {"success": True, "tool": request.tool, "result": result}

@app.get("/deadlines")
async def get_deadlines(has_gst: bool = False):
    """Return the upcoming deadlines; ``has_gst`` adds the GST filing dates."""
    upcoming = tools.get_deadlines(has_gst)
    return upcoming

print("‚úÖ FastAPI app created!")

In [None]:
#@title 7Ô∏è‚É£ üöÄ Start Server with ngrok (RUN THIS!)
import nest_asyncio
from pyngrok import ngrok
import uvicorn

# Lets uvicorn start an event loop inside the notebook's running loop.
nest_asyncio.apply()

# Single source of truth for the port shared by the tunnel and the server.
SERVER_PORT = 8000

# Start ngrok tunnel
# ngrok.connect() returns an NgrokTunnel object in current pyngrok releases.
# Formatting that object directly prints its description, not the bare URL,
# so extract the URL string (falling back to the value itself for old
# releases that returned a plain string).
tunnel = ngrok.connect(SERVER_PORT)
public_url = getattr(tunnel, "public_url", tunnel)
print("=" * 60)
print("üéâ TaxAlly Server is LIVE!")
print("=" * 60)
print(f"\nüåê PUBLIC URL: {public_url}")
print(f"\nüìã Copy this to your webapp's .env file:")
print(f"   TAXALLY_API_URL={public_url}")
print("\n" + "=" * 60)
print("\nüì° Endpoints:")
print(f"   GET  {public_url}/health")
print(f"   POST {public_url}/chat")
print(f"   GET  {public_url}/tools")
print(f"   POST {public_url}/tools/execute")
print(f"   GET  {public_url}/deadlines")
print("\n" + "=" * 60)
print("\n‚è≥ Server running... (Keep this cell running!)")
print("   Press the STOP button to shut down.\n")

# Run server (blocks this cell until it is interrupted)
uvicorn.run(app, host="0.0.0.0", port=SERVER_PORT)

## 🔧 Webapp Configuration

After running the server, update your webapp:

### Server `.env` file:
```bash
# /Users/aditya/developer/hackathons/GFGVB/server/.env
TAXALLY_API_URL=https://xxxx-xx-xxx-xxx-xxx.ngrok-free.app
USE_TAXALLY_SERVER=true
```

### Test the connection:
```bash
curl https://your-ngrok-url/health
```

### Expected response:
```json
{"status": "healthy", "model": "Qwen/Qwen2.5-7B-Instruct", "gpu": "Tesla T4"}
```