# LingoTax GNN Deduction Predictor — Training Notebook

This notebook trains the GNN model on synthetic tax-profile data.
Designed to run in **Google Colab** with GPU acceleration.

**Author:** LingoTax Team (HackAI 2026)

## Cell 1: GPU Check

In [None]:
# Cell 1: Report the PyTorch build and whether a CUDA GPU is visible
import torch

print(f'PyTorch version: {torch.__version__}')

# Query availability once and reuse it for both the report line and the branch.
cuda_ok = torch.cuda.is_available()
print(f'CUDA available:  {cuda_ok}')

if not cuda_ok:
    print('WARNING: No GPU detected. Training will run on CPU (slower).')
else:
    # Device index 0 is the first visible GPU.
    print(f'GPU device:      {torch.cuda.get_device_name(0)}')
    print(f'CUDA version:    {torch.version.cuda}')

## Cell 2: Install Dependencies

In [None]:
# Cell 2: Install dependencies (core packages first, then PyTorch Geometric)
import subprocess, sys, torch

# Common command prefix: pip is invoked through the interpreter running this kernel.
PIP_INSTALL = [sys.executable, '-m', 'pip', 'install', '-q']

# Core deps — a failure here raises and stops the cell.
subprocess.check_call(
    PIP_INSTALL + ['pandas', 'numpy', 'scikit-learn', 'joblib', 'matplotlib'])

# PyTorch Geometric — point pip at the wheel index matching the installed torch+CUDA build.
TORCH_VERSION = torch.__version__.split('+')[0]  # drop any local '+cuXYZ' suffix, e.g. '2.1.0'
CUDA_VERSION = torch.version.cuda or 'cpu'       # e.g. '12.1'; 'cpu' when torch reports no CUDA
if CUDA_VERSION == 'cpu':
    CUDA_TAG = 'cpu'
else:
    CUDA_TAG = 'cu' + CUDA_VERSION.replace('.', '')  # '12.1' -> 'cu121'

wheel_index = f'https://data.pyg.org/whl/torch-{TORCH_VERSION}+{CUDA_TAG}.html'

print(f'Installing PyG for torch={TORCH_VERSION}, cuda={CUDA_TAG}')
try:
    subprocess.check_call(PIP_INSTALL + ['torch-geometric', '-f', wheel_index])
    print('PyTorch Geometric installed successfully!')
except Exception as e:
    # Best-effort: a failed PyG install is reported but does not abort the notebook.
    print(f'PyG install failed ({e}). Falling back to MLP mode (no graph convolution).')
    print('The model will still work — just without GraphSAGE layers.')

## Cell 3: Clone Repository

In [None]:
# Cell 3: Clone the LingoTax repo
# Replace <your-repo> with your actual GitHub URL.
# If the repo is private, use a Personal Access Token (PAT):
#   !git clone https://<PAT>@github.com/<org>/lingotax.git

!git clone https://github.com/<your-org>/lingotax.git 2>/dev/null || echo 'Repo already cloned'
%cd lingotax/model
!ls -la

## Cell 4: Generate Synthetic Data

In [None]:
# Cell 4: Generate 20,000 synthetic tax profiles
!python data/gen_synthetic.py --n 20000 --distribution-aware --noise 0.05 --seed 42

import pandas as pd
df = pd.read_csv('data/users.csv')
print(f'Generated {len(df)} profiles')
print(df.head())
print('\nDeduction label sums:')
for col in ['foreign_tax_credit', 'student_loan_interest', 'standard_deduction',
            'earned_income_credit', 'child_tax_credit', 'educator_expense',
            'ira_deduction', 'home_ownership_credit']:
    print(f'  {col}: {df[col].sum()}')

## Cell 5: Train the GNN Model

In [None]:
# Cell 5: Train the GNN (use --gpu if CUDA is available)
import torch
gpu_flag = '--gpu' if torch.cuda.is_available() else ''
!python train/train_gnn.py --epochs 50 --hidden-dim 64 --lr 0.001 --seed 42 {gpu_flag}

## Cell 6: Visualize Results & Save to Google Drive

In [None]:
# Cell 6: View training metadata and save model to Google Drive
import json
import matplotlib.pyplot as plt

with open('models/metadata/gnn_v1.json') as f:
    meta = json.load(f)

print(f"Macro AUC: {meta['metrics']['macro_auc']}")
print(f"Training time: {meta['elapsed_seconds']}s")
print(f"Train size: {meta['train_size']}")

# Per-deduction AUC bar chart
deductions = [k for k in meta['metrics'] if k != 'macro_auc']
aucs = [meta['metrics'][d]['auc'] for d in deductions]

plt.figure(figsize=(10, 5))
plt.barh(deductions, aucs, color='#4F46E5')
plt.xlabel('ROC AUC')
plt.title('GNN Deduction Predictor — Per-Deduction AUC')
plt.xlim(0, 1)
plt.tight_layout()
plt.savefig('models/metadata/auc_chart.png', dpi=150)
plt.show()

# Save to Google Drive
try:
    from google.colab import drive
    drive.mount('/content/drive')
    import shutil
    dst = '/content/drive/MyDrive/lingotax_models/'
    !mkdir -p {dst}
    shutil.copy('models/checkpoints/gnn_v1.pt', dst)
    shutil.copy('models/checkpoints/gnn_v1.meta.json', dst)
    shutil.copy('models/metadata/gnn_v1.json', dst)
    print(f'Model saved to Google Drive: {dst}')
except Exception as e:
    print(f'Google Drive save skipped: {e}')
    print('Model is available at: models/checkpoints/gnn_v1.pt')

## Cell 7 (Optional): Test Inference via FastAPI

In [None]:
# Cell 7 (Optional): Launch FastAPI in background and test with curl
#
# The commented-out block below, once enabled:
#   1. starts uvicorn serving 'api.fastapi_infer:app' on 0.0.0.0:8001 as a child process,
#   2. sleeps 3 seconds before sending the request,
#   3. POSTs one sample profile as JSON to /predict_deductions with curl,
#   4. terminates the uvicorn process.
# NOTE(review): the fixed 3-second sleep presumably gives the server time to start —
# confirm it is long enough on your runtime.
# NOTE(review): the module path 'api.fastapi_infer' presumably resolves relative to the
# current working directory — verify the cell is run from the right folder.
#
# Uncomment to run:
# import subprocess
# proc = subprocess.Popen(['uvicorn', 'api.fastapi_infer:app', '--host', '0.0.0.0', '--port', '8001'])
# import time; time.sleep(3)
# !curl -X POST http://localhost:8001/predict_deductions \
#   -H 'Content-Type: application/json' \
#   -d '{"visa_type": "H-1B", "filing_status": "single", "total_income": 75000, "foreign_income": 10000, "foreign_tax_paid": 1, "state": "OH", "num_dependents": 0, "years_in_us": 4}'
# proc.terminate()

# As shipped, this cell only prints a reminder; nothing is launched.
print('Uncomment the lines above to test the FastAPI inference endpoint.')