# FBT Classification Pipeline
## Fringe Benefit Tax: Yes or No?

**Goal:** Classify whether an expense is subject to Fringe Benefits Tax (FBT)

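**Output:**
- `fbt_label`: Y (taxable) / N (not taxable)
- `category`: Detailed category used to explain the label (derived from the workpaper tax descriptions when building the training labels; it is not a column of the prediction output)
- `location`: Location extracted from the expense text, reported together with `state`, `distance_km` and `travel_hours`
- `confidence`: Highest predicted class probability for the returned label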

---

## 1. Configuration

In [None]:
# Central pipeline settings: input workbooks, the geographic reference point,
# split / cross-validation parameters, TF-IDF limits and output file names.
CONFIG = dict(
    # NOTE(review): this file name contains a space ("data _raw"); confirm it
    # matches the workbook on disk.
    raw_data_path='data/data _raw_2024-25.xlsx',
    # Labeled workpapers used to build the training set.
    wp_files=[
        'data/WP_1_Apr_to_Dec_24_FBT_ent_acc.xlsx',
        'data/WP_2_Apr_to_Dec_24_FBT_ent_acc.xlsx',
        'data/WP_3_Jan_to_Mar_25_FBT_ent_acc.xlsx',
        'data/WP_4_Jan_to_Mar_25_FBT_ent_acc.xlsx',
    ],
    # Point from which distances to extracted locations are measured.
    reference_location=dict(lat=-33.8688, lon=151.2093, name='Sydney CBD'),
    # Hold-out split and cross-validation.
    test_size=0.2,
    random_state=42,
    cv_folds=5,
    # TF-IDF vectorizer settings.
    max_features=5000,
    ngram_range=(1, 3),
    min_df=2,
    max_df=0.95,
    # Output artefacts.
    model_output='fbt_model.joblib',
    predictions_output='fbt_predictions.csv',
)

## 2. Imports

In [None]:
import pandas as pd
import numpy as np
import re
import math
import warnings
from typing import List, Dict, Optional
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier, VotingClassifier, GradientBoostingClassifier
from sklearn.naive_bayes import ComplementNB
from sklearn.svm import LinearSVC
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import classification_report, confusion_matrix, f1_score, ConfusionMatrixDisplay, accuracy_score
from scipy import sparse
import joblib

print("Imports successful")

## 3. FBT Category Mapping

**Fringe Benefit = Y (Taxable):**
- Meal Entertainment (ME)
- Celebrations (Xmas, farewell, birthday)
- Alcohol
- Travel FOR entertainment
- Lodging FOR entertainment

**Fringe Benefit = N (Not Taxable):**
- Business travel meals
- Business expenses
- Training/conference
- 58P Recreation (non-deductible but no FBT)
- Office consumables
- Pure business travel/lodging

In [None]:
# -----------------------------------------------------------------------------
# FBT classification rules
# -----------------------------------------------------------------------------

# Categories whose expenses are labelled FBT = 'Y'.
FBT_YES_CATEGORIES = [
    # ME: client dinners, staff lunches
    'MEAL_ENTERTAINMENT',
    # Xmas party, farewell, birthday
    'CELEBRATION',
    # Drinks, wine, beer
    'ALCOHOL',
    # Travel connected to entertainment
    'TRAVEL_FOR_ENTERTAINMENT',
    # Accommodation for entertainment
    'LODGING_FOR_ENTERTAINMENT',
]

# Categories whose expenses are labelled FBT = 'N'.
FBT_NO_CATEGORIES = [
    # Meal whilst travelling for business
    'BUSINESS_TRAVEL_MEAL',
    # General business expense
    'BUSINESS_EXPENSE',
    # Conference, seminar, workshop
    'TRAINING',
    # 58P recreational (no FBT but non-deductible)
    'RECREATION_58P',
    # Client-focused events (field days)
    'CLIENT_EVENT',
    # Morning tea, working lunch on premises
    'SUSTENANCE',
    # Pure business travel
    'BUSINESS_TRAVEL',
    # Pure business accommodation
    'BUSINESS_LODGING',
]

# Mapping rules: Tax Description keywords -> Category
CATEGORY_RULES = {
    # FBT = YES
    'MEAL_ENTERTAINMENT': [
        'me', 'meal entertainment', 'dinner', 'lunch', 'breakfast', 
        'restaurant', 'catering', 'me - no workplace'
    ],
    'CELEBRATION': [
        'xmas', 'christmas', 'party', 'farewell', 'celebration', 
        'birthday', 'bd cake', 'welcome', 'anniversary', 'end of year'
    ],
    'ALCOHOL': [
        'alcohol', 'drinks', 'wine', 'beer', 'liquor'
    ],
    'TRAVEL_FOR_ENTERTAINMENT': [
        'travel connected to entertainment', 'travel associated with entertainment',
        'travel related to rabo me', 'travel associated with me',
        'travel for entertainment', 'travel associated with dinner',
        'travel to entertainment'
    ],
    'LODGING_FOR_ENTERTAINMENT': [
        'accommodation for entertainment', 'lodging for entertainment',
        'hotel for entertainment', 'stay for entertainment'
    ],
    
    # FBT = NO
    'BUSINESS_TRAVEL_MEAL': [
        'accept travel meal', 'travel meal', 'meal whilst travel',
        'business travel food', 'accept business travel food',
        'ok - accept dinner on travel'
    ],
    'BUSINESS_EXPENSE': [
        'business expense', 'ok', 'accept', 'no fbt', 'tax deductible',
        'office consumable', 'deductible', 'ok - business',
        'accept predominately business', 'business related'
    ],
    'TRAINING': [
        'training', 'conference', 'seminar', 'workshop', 'course',
        'accept seminar', 'accept seminar food'
    ],
    'RECREATION_58P': [
        '58p', 'recreational', 'team building', 'bowling', 'golf', 
        'sailing', 'fun day', 'escape room'
    ],
    'CLIENT_EVENT': [
        'client event', 'client focused', 'field day', 'client visit',
        'travel - predominant purpose client visit'
    ],
    'SUSTENANCE': [
        'sustenance', 'morning tea', 'on premises', 'light lunch', 'working lunch'
    ],
    'BUSINESS_TRAVEL': [
        'accept business travel', 'accept work travel', 'business travel - accept'
    ],
    'BUSINESS_LODGING': [
        'accept business lodging', 'business accommodation'
    ]
}

# Keywords this short are abbreviations ('me', 'ok', '58p'). As raw substrings
# they fire inside ordinary words ('meal', 'payment', 'book'), so they are only
# accepted when they appear as a whole token.
_SHORT_KEYWORD_MAX_LEN = 3


def _compile_category_matchers(rules):
    """Return (compiled pattern, category) pairs, longest keyword first."""
    pairs = [(kw, category) for category, keywords in rules.items() for kw in keywords]
    # sorted() is stable, so keywords of equal length keep their rule order.
    pairs.sort(key=lambda pair: len(pair[0]), reverse=True)

    matchers = []
    for kw, category in pairs:
        pattern = re.escape(kw)
        if len(kw) <= _SHORT_KEYWORD_MAX_LEN:
            pattern = r'(?<!\w)' + pattern + r'(?!\w)'
        matchers.append((re.compile(pattern), category))
    return matchers


# Built once when the rules are defined; re-run this cell after editing
# CATEGORY_RULES so the matchers are rebuilt.
_CATEGORY_MATCHERS = _compile_category_matchers(CATEGORY_RULES)


def _me_label_is_yes(me_label) -> bool:
    """Return True when the ME? flag reads 'Y' (case and padding ignored)."""
    return bool(me_label) and str(me_label).strip().upper() == 'Y'


def map_to_category(tax_desc: str, me_label: Optional[str] = None) -> str:
    """
    Map a workpaper Tax Description to a category name.

    The most specific rule wins: keywords from CATEGORY_RULES are tried from
    longest to shortest, so 'accept travel meal' resolves to
    BUSINESS_TRAVEL_MEAL instead of being captured by the generic 'me' or
    'accept' keywords. Keywords of three characters or fewer must match as a
    whole token.

    Args:
        tax_desc: Tax Description cell; may be None, NaN or blank.
        me_label: Optional ME? flag; 'Y' is used as a fallback signal.

    Returns:
        A key of CATEGORY_RULES, 'UNLABELED' when there is no description and
        no ME flag, or 'OTHER' when no rule matches and the ME flag is not 'Y'.
    """
    if pd.isna(tax_desc) or not str(tax_desc).strip():
        # No tax description - the ME flag is the only signal available.
        return 'MEAL_ENTERTAINMENT' if _me_label_is_yes(me_label) else 'UNLABELED'

    # Lower-case and collapse runs of whitespace so multi-word keywords still
    # match descriptions typed with irregular spacing.
    tax_lower = ' '.join(str(tax_desc).lower().split())

    for pattern, category in _CATEGORY_MATCHERS:
        if pattern.search(tax_lower):
            return category

    # No rule matched: fall back to the ME flag.
    if _me_label_is_yes(me_label):
        return 'MEAL_ENTERTAINMENT'

    return 'OTHER'


def get_fbt_label(category: str) -> str:
    """
    Translate a category name into its FBT flag.

    Returns 'Y' for members of FBT_YES_CATEGORIES, 'N' for members of
    FBT_NO_CATEGORIES and 'UNKNOWN' for anything else.
    """
    for label, members in (('Y', FBT_YES_CATEGORIES), ('N', FBT_NO_CATEGORIES)):
        if category in members:
            return label
    return 'UNKNOWN'


# Confirm the rule tables loaded and show both category lists.
print("FBT mapping defined")
for flag, cats in (('Y', FBT_YES_CATEGORIES), ('N', FBT_NO_CATEGORIES)):
    print(f"\nFBT = {flag} categories: {cats}")

## 4. Location Database

In [None]:
# Known place names (lower-case) mapped to coordinates, a state/country code
# and a coarse type ('city', 'regional' or 'international').
# Each row below is (name, lat, lon, state, type).
LOCATIONS = {
    name: {'lat': lat, 'lon': lon, 'state': state, 'type': kind}
    for name, lat, lon, state, kind in [
        # Major Cities
        ('sydney', -33.8688, 151.2093, 'NSW', 'city'),
        ('melbourne', -37.8136, 144.9631, 'VIC', 'city'),
        ('brisbane', -27.4698, 153.0251, 'QLD', 'city'),
        ('perth', -31.9505, 115.8605, 'WA', 'city'),
        ('adelaide', -34.9285, 138.6007, 'SA', 'city'),
        ('hobart', -42.8821, 147.3272, 'TAS', 'city'),
        ('darwin', -12.4634, 130.8456, 'NT', 'city'),
        ('canberra', -35.2809, 149.1300, 'ACT', 'city'),
        # Regional NSW
        ('dubbo', -32.2569, 148.6011, 'NSW', 'regional'),
        ('wagga wagga', -35.1082, 147.3598, 'NSW', 'regional'),
        ('wagga', -35.1082, 147.3598, 'NSW', 'regional'),
        ('tamworth', -31.0830, 150.9170, 'NSW', 'regional'),
        ('orange', -33.2840, 149.1004, 'NSW', 'regional'),
        ('bathurst', -33.4190, 149.5778, 'NSW', 'regional'),
        ('albury', -36.0737, 146.9135, 'NSW', 'regional'),
        ('moree', -29.4640, 149.8470, 'NSW', 'regional'),
        ('griffith', -34.2890, 146.0400, 'NSW', 'regional'),
        ('parkes', -33.1370, 148.1750, 'NSW', 'regional'),
        ('narrabri', -30.3250, 149.7830, 'NSW', 'regional'),
        ('newcastle', -32.9283, 151.7817, 'NSW', 'regional'),
        ('wollongong', -34.4278, 150.8931, 'NSW', 'regional'),
        # Regional QLD
        ('roma', -26.5700, 148.7850, 'QLD', 'regional'),
        ('toowoomba', -27.5598, 151.9507, 'QLD', 'regional'),
        ('rockhampton', -23.3791, 150.5100, 'QLD', 'regional'),
        ('mackay', -21.1411, 149.1861, 'QLD', 'regional'),
        ('townsville', -19.2590, 146.8169, 'QLD', 'regional'),
        ('cairns', -16.9186, 145.7781, 'QLD', 'regional'),
        ('longreach', -23.4420, 144.2500, 'QLD', 'regional'),
        ('mount isa', -20.7256, 139.4927, 'QLD', 'regional'),
        ('emerald', -23.5270, 148.1640, 'QLD', 'regional'),
        ('dalby', -27.1810, 151.2650, 'QLD', 'regional'),
        ('goondiwindi', -28.5470, 150.3100, 'QLD', 'regional'),
        ('charleville', -26.4030, 146.2430, 'QLD', 'regional'),
        ('cloncurry', -20.7050, 140.5060, 'QLD', 'regional'),
        ('gold coast', -28.0167, 153.4000, 'QLD', 'regional'),
        # Regional VIC
        ('geelong', -38.1499, 144.3617, 'VIC', 'regional'),
        ('ballarat', -37.5622, 143.8503, 'VIC', 'regional'),
        ('bendigo', -36.7570, 144.2794, 'VIC', 'regional'),
        ('shepparton', -36.3833, 145.4000, 'VIC', 'regional'),
        ('mildura', -34.2087, 142.1311, 'VIC', 'regional'),
        ('horsham', -36.7117, 142.2000, 'VIC', 'regional'),
        # Regional SA/WA
        ('port lincoln', -34.7333, 135.8500, 'SA', 'regional'),
        ('port augusta', -32.4936, 137.7825, 'SA', 'regional'),
        ('geraldton', -28.7775, 114.6147, 'WA', 'regional'),
        ('kalgoorlie', -30.7489, 121.4658, 'WA', 'regional'),
        ('broome', -17.9614, 122.2359, 'WA', 'regional'),
        ('karratha', -20.7361, 116.8467, 'WA', 'regional'),
        # International
        ('utrecht', 52.0907, 5.1214, 'NL', 'international'),
        ('amsterdam', 52.3676, 4.9041, 'NL', 'international'),
        ('singapore', 1.3521, 103.8198, 'SG', 'international'),
        ('hong kong', 22.3193, 114.1694, 'HK', 'international'),
        ('london', 51.5074, -0.1278, 'UK', 'international'),
        ('new zealand', -40.9006, 174.8860, 'NZ', 'international'),
        ('auckland', -36.8509, 174.7645, 'NZ', 'international'),
        ('tokyo', 35.6762, 139.6503, 'JP', 'international'),
    ]
}

class LocationExtractor:
    """
    Finds the first known place name in a piece of text and derives distance
    features relative to a reference point.

    The reference point is a dict with 'lat' and 'lon' keys; place names and
    their coordinates come from the module-level LOCATIONS table.
    """

    def __init__(self, ref: Dict):
        self.ref = ref
        # Longest names go first in the alternation so 'wagga wagga' is
        # preferred over its prefix 'wagga'.
        alternatives = '|'.join(
            re.escape(name) for name in sorted(LOCATIONS.keys(), key=len, reverse=True)
        )
        self.pattern = re.compile(r'\b(' + alternatives + r')\b', re.I)

    def haversine(self, lat1, lon1, lat2, lon2):
        """Great-circle distance in kilometres between two lat/lon points given in degrees."""
        earth_radius_km = 6371
        lat1, lon1, lat2, lon2 = (math.radians(deg) for deg in (lat1, lon1, lat2, lon2))
        dlat = lat2 - lat1
        dlon = lon2 - lon1
        a = math.sin(dlat/2)**2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon/2)**2
        return earth_radius_km * 2 * math.atan2(math.sqrt(a), math.sqrt(1-a))

    def extract(self, text: str) -> Dict:
        """
        Return location features for the first known place name in *text*.

        When the text is missing, empty, or mentions no known place, every
        field keeps its blank/zero default.
        """
        blank = {
            'location_name': '', 'location_state': '',
            'has_location': 0, 'distance_km': 0, 'travel_hours': 0,
            'is_international': 0, 'is_regional': 0, 'is_remote': 0
        }
        if pd.isna(text) or not text:
            return blank

        found = self.pattern.findall(str(text).lower())
        place = found[0].lower() if found else None
        if place is None or place not in LOCATIONS:
            return blank

        entry = LOCATIONS[place]
        distance = self.haversine(self.ref['lat'], self.ref['lon'], entry['lat'], entry['lon'])

        # Travel-time heuristic: beyond 500 km use distance/800 plus a fixed
        # 2 hours, otherwise distance/80.
        if distance > 500:
            travel = (distance / 800) + 2
        else:
            travel = distance / 80

        kind = entry['type']
        return {
            'location_name': place.title(),
            'location_state': entry['state'],
            'has_location': 1,
            'distance_km': round(distance, 0),
            'travel_hours': round(travel, 1),
            'is_international': int(kind == 'international'),
            'is_regional': int(kind == 'regional'),
            'is_remote': int(distance > 500)
        }

# Shared extractor measuring distances from the configured reference point.
loc_extractor = LocationExtractor(ref=CONFIG['reference_location'])
print(f"Loaded {len(LOCATIONS)} locations")

## 5. Data Loading

In [None]:
def load_raw_data(filepath: str) -> pd.DataFrame:
    """
    Read every sheet of an Excel workbook into one DataFrame.

    Each sheet's rows are tagged with the sheet name in a '_sheet' column and
    the sheets are concatenated with a fresh integer index.

    Args:
        filepath: Path to the workbook.

    Returns:
        The concatenated rows of all sheets.

    Raises:
        Whatever pandas raises when the workbook cannot be opened or parsed.
    """
    print(f"Loading: {filepath}")
    frames = []
    # Open the workbook once and parse each sheet from the same handle; the
    # context manager closes the file even if a sheet fails to parse.
    with pd.ExcelFile(filepath) as workbook:
        for sheet in workbook.sheet_names:
            frame = workbook.parse(sheet_name=sheet)
            frame['_sheet'] = sheet
            frames.append(frame)
            print(f"  {sheet}: {len(frame)} rows")
    combined = pd.concat(frames, ignore_index=True)
    print(f"Total: {len(combined)} rows")
    return combined


def parse_workpaper(filepath: str, sheet: str) -> Optional[pd.DataFrame]:
    """
    Extract the data table from one workpaper sheet.

    The sheet is read without a header, the first row containing the text
    'BUSINESS_UNIT_CODE' is taken as the header row, and everything below it
    becomes the data. The first column whose name contains both 'tax' and
    'desc' is renamed to 'TAX_DESCRIPTION' and the first column containing
    'me?' is renamed to 'ME_LABEL'.

    Args:
        filepath: Path to the workbook.
        sheet: Name of the sheet to parse.

    Returns:
        The data rows with '_sheet' and '_file' columns added, or None when
        the sheet has no header row or cannot be read (the reason is printed).
    """
    import os  # function-scope import: only needed for the '_file' column

    try:
        df = pd.read_excel(filepath, sheet_name=sheet, header=None)

        header_idx = None
        for i, row in df.iterrows():
            row_text = ' '.join(str(v).upper() for v in row.values if pd.notna(v))
            if 'BUSINESS_UNIT_CODE' in row_text:
                header_idx = i
                break
        if header_idx is None:
            return None

        # Blank header cells get a '_col_<position>' placeholder name.
        headers = [str(h).strip() if pd.notna(h) else f'_col_{i}' for i, h in enumerate(df.iloc[header_idx])]
        data = df.iloc[header_idx + 1:].copy()
        data.columns = headers

        # Rename label columns. Only the first match per target is renamed so
        # the frame never ends up with two columns of the same label name.
        renames = {}
        for col in headers:
            cl = col.lower()
            if 'tax' in cl and 'desc' in cl:
                target = 'TAX_DESCRIPTION'
            elif 'me?' in cl:
                target = 'ME_LABEL'
            else:
                continue
            if target not in renames.values():
                renames[col] = target
        data = data.rename(columns=renames)

        data['_sheet'] = sheet
        data['_file'] = os.path.basename(filepath)
        return data
    except Exception as exc:
        # Best effort: an unreadable sheet is skipped, but the reason is
        # reported, and KeyboardInterrupt / SystemExit are no longer swallowed.
        print(f"  Skipped sheet {sheet!r}: {exc}")
        return None


def load_labeled_data(wp_files: List[str]) -> pd.DataFrame:
    """
    Parse every relevant sheet of the given workpaper files into one DataFrame.

    Sheets whose name contains 'summary' or 'trial balance' are skipped. Each
    remaining sheet is handed to parse_workpaper; non-empty results are
    concatenated. A file that cannot be processed is reported and skipped.

    Args:
        wp_files: Paths of the workpaper workbooks.

    Returns:
        The concatenated rows, or an empty DataFrame when nothing was parsed.
    """
    all_dfs = []
    for fp in wp_files:
        print(f"\nProcessing: {fp}")
        try:
            # Only the sheet names are needed from this handle, so close it
            # before the per-sheet parsing re-opens the file.
            with pd.ExcelFile(fp) as workbook:
                sheet_names = list(workbook.sheet_names)

            for sheet in sheet_names:
                sheet_lower = sheet.lower()
                if 'summary' in sheet_lower or 'trial balance' in sheet_lower:
                    continue
                df = parse_workpaper(fp, sheet)
                if df is not None and len(df) > 0:
                    all_dfs.append(df)
                    print(f"  {sheet}: {len(df)} rows")
        except Exception as e:
            print(f"  Error: {e}")

    combined = pd.concat(all_dfs, ignore_index=True) if all_dfs else pd.DataFrame()
    print(f"\nTotal: {len(combined)} rows")
    return combined

In [None]:
# Read the raw workbook and the labeled workpapers from the configured paths.
raw_data = load_raw_data(filepath=CONFIG['raw_data_path'])
labeled_data = load_labeled_data(wp_files=CONFIG['wp_files'])

## 6. Data Cleaning

In [None]:
def clean_data(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a cleaned copy of *df*.

    Steps, in order: drop 'Unnamed*' / '_col_*' columns, normalise column
    names (whitespace to '_', other non-alphanumerics removed, upper-case),
    strip text columns and turn blank / 'nan' / 'None' strings into NaN,
    upper-case the ME label, coerce amount and journal date, and drop rows
    that repeat the same business unit / account / amount / date.
    """
    cleaned = df.copy()

    # Remove junk columns
    is_junk = cleaned.columns.str.contains('^Unnamed|^_col_', case=False, na=False)
    cleaned = cleaned.loc[:, ~is_junk]

    # Standardize column names
    cleaned.columns = [
        re.sub(r'[^a-zA-Z0-9_]', '', re.sub(r'\s+', '_', str(name).strip())).upper()
        for name in cleaned.columns
    ]

    # Clean text columns
    for col in ('PURPOSE', 'CHARGE_DESCRIPTION', 'LINE_DESCR', 'DESCRIPTION',
                'INVOICE_DESCR', 'VENDOR_NAME', 'TAX_DESCRIPTION'):
        if col not in cleaned.columns:
            continue
        stripped = cleaned[col].fillna('').astype(str).str.strip()
        cleaned[col] = stripped.replace(['nan', 'NaN', 'None', ''], np.nan)

    # Clean ME label
    if 'ME_LABEL' in cleaned.columns:
        me = cleaned['ME_LABEL'].fillna('').astype(str).str.strip().str.upper()
        cleaned['ME_LABEL'] = me.replace(['NAN', ''], np.nan)

    # Clean numeric/date
    if 'BASE_AMOUNT' in cleaned.columns:
        cleaned['BASE_AMOUNT'] = pd.to_numeric(cleaned['BASE_AMOUNT'], errors='coerce')
    if 'JOURNAL_DATE' in cleaned.columns:
        cleaned['JOURNAL_DATE'] = pd.to_datetime(cleaned['JOURNAL_DATE'], errors='coerce')

    # Dedupe
    # NOTE(review): the key ignores every description column, so two distinct
    # expenses with the same unit, account, amount and date collapse into one
    # row - confirm this is intended.
    key_cols = [c for c in ('BUSINESS_UNIT_CODE', 'ACCOUNT_CODE', 'BASE_AMOUNT', 'JOURNAL_DATE')
                if c in cleaned.columns]
    if key_cols:
        cleaned = cleaned.drop_duplicates(subset=key_cols, keep='first')

    return cleaned

# Clean both datasets with the same routine and report the surviving row counts.
raw_clean, labeled_clean = (clean_data(frame) for frame in (raw_data, labeled_data))
print(f"Raw: {len(raw_clean)}, Labeled: {len(labeled_clean)}")

## 7. Apply FBT Mapping

In [None]:
# Derive a category for every labeled row, then the Y/N/UNKNOWN flag from it.
labeled_clean['CATEGORY'] = labeled_clean.apply(
    lambda row: map_to_category(
        tax_desc=row.get('TAX_DESCRIPTION'),
        me_label=row.get('ME_LABEL'),
    ),
    axis=1,
)
labeled_clean['FBT_LABEL'] = labeled_clean['CATEGORY'].map(get_fbt_label)

# Show how the rows are distributed over categories and over FBT flags.
print("Category Distribution:")
print(labeled_clean['CATEGORY'].value_counts())

print("\nFBT Label Distribution:")
print(labeled_clean['FBT_LABEL'].value_counts())

In [None]:
# Two side-by-side charts: top categories (left) and FBT flag counts (right).
# Red marks FBT, green marks no FBT, grey marks anything unclassified.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
cat_ax, fbt_ax = axes

# Left: the 15 most frequent categories as horizontal bars.
cat_counts = labeled_clean['CATEGORY'].value_counts().head(15)
colors = []
for c in cat_counts.index:
    if c in FBT_YES_CATEGORIES:
        colors.append('#e74c3c')
    elif c in FBT_NO_CATEGORIES:
        colors.append('#2ecc71')
    else:
        colors.append('#95a5a6')
bar_positions = range(len(cat_counts))
cat_ax.barh(bar_positions, cat_counts.values, color=colors)
cat_ax.set_yticks(bar_positions)
cat_ax.set_yticklabels(cat_counts.index)
cat_ax.set_xlabel('Count')
cat_ax.set_title('Category Distribution (Red=FBT, Green=No FBT)')
cat_ax.invert_yaxis()

# Right: one bar per FBT flag.
fbt_counts = labeled_clean['FBT_LABEL'].value_counts()
flag_colors = {'Y': '#e74c3c', 'N': '#2ecc71'}
colors = [flag_colors.get(x, '#95a5a6') for x in fbt_counts.index]
fbt_ax.bar(fbt_counts.index.astype(str), fbt_counts.values, color=colors)
fbt_ax.set_xlabel('FBT Label')
fbt_ax.set_ylabel('Count')
fbt_ax.set_title('FBT Distribution')

plt.tight_layout()
plt.show()

## 8. Feature Engineering

In [None]:
def create_combined_text(df: pd.DataFrame) -> pd.Series:
    """
    Build one normalised text string per row from the descriptive columns.

    The non-blank values of the known text columns present in *df* are joined
    with spaces, lower-cased, stripped of everything except a-z, 0-9 and
    whitespace, and collapsed to single spaces.
    """
    candidates = ('PURPOSE', 'CHARGE_DESCRIPTION', 'LINE_DESCR', 'DESCRIPTION',
                  'INVOICE_DESCR', 'VENDOR_NAME')
    present = [name for name in candidates if name in df.columns]

    def combine(row):
        pieces = []
        for name in present:
            value = row.get(name)
            # Skip missing values and strings that are only whitespace.
            if pd.notna(value) and str(value).strip():
                pieces.append(str(value).strip())
        lowered = ' '.join(pieces).lower()
        letters_only = re.sub(r'[^a-z0-9\s]', ' ', lowered)
        return ' '.join(letters_only.split())

    return df.apply(combine, axis=1)


def create_features(df: pd.DataFrame, text_series: pd.Series) -> pd.DataFrame:
    """
    Build the non-text feature table for *df*.

    Amount, date and sheet features are added only when the corresponding
    column exists in *df*. Location features are extracted from
    *text_series* with the module-level loc_extractor. Columns whose name
    starts with '_' carry location details for reporting.
    """
    feat = pd.DataFrame(index=df.index)
    available = df.columns

    # Amount
    if 'BASE_AMOUNT' in available:
        amount = df['BASE_AMOUNT'].fillna(0)
        feat['amount_log'] = np.log1p(np.abs(amount))
        feat['amount_is_negative'] = (amount < 0).astype(int)

    # Date
    if 'JOURNAL_DATE' in available:
        when = pd.to_datetime(df['JOURNAL_DATE'], errors='coerce')
        month = when.dt.month
        weekday = when.dt.dayofweek
        # Unparseable dates become 0 in the month / weekday columns.
        feat['month'] = month.fillna(0).astype(int)
        feat['day_of_week'] = weekday.fillna(0).astype(int)
        feat['is_friday'] = (weekday == 4).astype(int)
        feat['is_december'] = (month == 12).astype(int)

    # Sheet/Account type
    if '_SHEET' in available:
        sheet_name = df['_SHEET'].fillna('').str.lower()
        sheet_flags = (
            ('is_travel_sheet', 'travel|lodging'),
            ('is_entertainment_sheet', 'entertainment|meal'),
            ('is_training_sheet', 'training'),
        )
        for flag, pattern in sheet_flags:
            feat[flag] = sheet_name.str.contains(pattern).astype(int)

    # Location features
    located = [loc_extractor.extract(text) for text in text_series]
    for key in ('has_location', 'distance_km', 'travel_hours',
                'is_international', 'is_regional', 'is_remote'):
        feat[key] = [info[key] for info in located]

    # Location info for output
    for key in ('location_name', 'location_state'):
        feat['_' + key] = [info[key] for info in located]

    return feat

In [None]:
# Build the text column first; the location features are extracted from it.
labeled_clean['combined_text'] = create_combined_text(df=labeled_clean)
labeled_features = create_features(
    df=labeled_clean,
    text_series=labeled_clean['combined_text'],
)

# Columns starting with '_' are reporting-only and are left out of this list.
model_feature_names = [c for c in labeled_features.columns if not c.startswith('_')]
print(f"Features: {model_feature_names}")

## 9. Prepare Training Data

In [None]:
# Keep only rows with a definite FBT flag ('Y' or 'N'); 'UNKNOWN' rows are dropped.
mask = labeled_clean['FBT_LABEL'].isin(['Y', 'N'])
df_train, feat_train = labeled_clean[mask].copy(), labeled_features[mask].copy()

# Model inputs: the combined text and the string labels.
X_text = df_train['combined_text'].values
y = df_train['FBT_LABEL'].values

# Encode the string labels as integers.
fbt_encoder = LabelEncoder().fit(y)
y_encoded = fbt_encoder.transform(y)

print(f"Training samples: {len(df_train)}")
print(f"Classes: {fbt_encoder.classes_}")
print(f"\nFBT distribution:")
print(pd.Series(y).value_counts())

In [None]:
# Stratified hold-out split; text, feature table and labels are split together
# so their rows stay aligned.
(X_text_train, X_text_test,
 X_feat_train, X_feat_test,
 y_train, y_test) = train_test_split(
    X_text, feat_train, y_encoded,
    stratify=y_encoded,
    test_size=CONFIG['test_size'],
    random_state=CONFIG['random_state'],
)
print(f"Train: {len(X_text_train)}, Test: {len(X_text_test)}")

## 10. Vectorization

In [None]:
# Cell-local import: MinMaxScaler is not part of the notebook's import cell.
from sklearn.preprocessing import MinMaxScaler

# TF-IDF on the combined text (fitted on the training split only).
tfidf = TfidfVectorizer(
    max_features=CONFIG['max_features'], ngram_range=CONFIG['ngram_range'],
    min_df=CONFIG['min_df'], max_df=CONFIG['max_df'], sublinear_tf=True
)
X_tfidf_train = tfidf.fit_transform(X_text_train)
X_tfidf_test = tfidf.transform(X_text_test)

# Numerical features ('_'-prefixed columns are reporting-only and excluded).
num_cols = [c for c in X_feat_train.columns if not c.startswith('_')]

# Scale into [0, 1] rather than standardising: ComplementNB, which is trained
# on this matrix, raises on negative feature values, and zero-mean
# standardised columns always contain some. clip=True keeps values seen after
# fitting (test split, inference) inside the same range.
scaler = MinMaxScaler(clip=True)
X_num_train = scaler.fit_transform(X_feat_train[num_cols].fillna(0))
X_num_test = scaler.transform(X_feat_test[num_cols].fillna(0))

# Combine text and numerical blocks; CSR supports the row slicing that
# cross-validation performs.
X_train = sparse.hstack([X_tfidf_train, sparse.csr_matrix(X_num_train)], format='csr')
X_test = sparse.hstack([X_tfidf_test, sparse.csr_matrix(X_num_test)], format='csr')

print(f"Features: {X_train.shape[1]} (TF-IDF: {X_tfidf_train.shape[1]}, Numerical: {len(num_cols)})")

## 11. Model Training

In [None]:
# Every stochastic component takes its seed from CONFIG so that changing
# 'random_state' there affects the whole comparison, not just the split.
seed = CONFIG['random_state']

# Candidate classifiers compared on the same train/test split.
# NOTE: ComplementNB rejects negative feature values, so X_train must be
# non-negative for this comparison to run.
models = {
    'Logistic Regression': LogisticRegression(max_iter=1000, random_state=seed, class_weight='balanced', n_jobs=-1),
    'Complement NB': ComplementNB(alpha=0.1),
    'Linear SVC': CalibratedClassifierCV(LinearSVC(max_iter=2000, random_state=seed, class_weight='balanced')),
    'Random Forest': RandomForestClassifier(n_estimators=200, max_depth=20, random_state=seed, class_weight='balanced', n_jobs=-1),
    'Gradient Boosting': GradientBoostingClassifier(n_estimators=100, max_depth=5, random_state=seed),
    'Extra Trees': ExtraTreesClassifier(n_estimators=200, max_depth=20, random_state=seed, class_weight='balanced', n_jobs=-1)
}

# One splitter shared by all models: with a fixed seed it yields the same
# folds on every use, so each model is scored on identical folds.
cv = StratifiedKFold(n_splits=CONFIG['cv_folds'], shuffle=True, random_state=seed)

results = []
for name, model in models.items():
    print(f"Training {name}...")
    # Cross-validated F1 on the training split, then a refit on the full
    # training split for the hold-out scores.
    cv_f1 = cross_val_score(model, X_train, y_train, cv=cv, scoring='f1').mean()
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    test_f1 = f1_score(y_test, y_pred)
    test_acc = accuracy_score(y_test, y_pred)
    results.append({'Model': name, 'CV F1': round(cv_f1, 4), 'Test F1': round(test_f1, 4), 'Test Acc': round(test_acc, 4)})
    print(f"  CV F1: {cv_f1:.4f}, Test F1: {test_f1:.4f}, Acc: {test_acc:.4f}")

results_df = pd.DataFrame(results).sort_values('Test F1', ascending=False)
print("\nResults:")
print(results_df.to_string(index=False))

In [None]:
# Ensemble: soft voting averages the predicted probabilities of four models.
# Seeds come from CONFIG['random_state'] so the ensemble follows the same
# configuration as the train/test split.
ensemble = VotingClassifier(
    estimators=[
        ('lr', LogisticRegression(max_iter=1000, random_state=CONFIG['random_state'], class_weight='balanced')),
        ('nb', ComplementNB(alpha=0.1)),
        ('svc', CalibratedClassifierCV(LinearSVC(max_iter=2000, random_state=CONFIG['random_state'], class_weight='balanced'))),
        ('rf', RandomForestClassifier(n_estimators=100, max_depth=15, random_state=CONFIG['random_state'], class_weight='balanced', n_jobs=-1))
    ],
    voting='soft'
)
ensemble.fit(X_train, y_train)

# y_pred now holds the ensemble's hold-out predictions; the evaluation cells
# below read it.
y_pred = ensemble.predict(X_test)

print(f"\nEnsemble Results:")
print(f"  Test F1: {f1_score(y_test, y_pred):.4f}")
print(f"  Test Accuracy: {accuracy_score(y_test, y_pred):.4f}")

## 12. Evaluation

In [None]:
# Per-class precision / recall / F1 for the most recent y_pred on the hold-out split.
report_text = classification_report(y_test, y_pred, target_names=fbt_encoder.classes_)
print("\nClassification Report:")
print(report_text)

In [None]:
# Confusion matrix of the hold-out predictions, with axes labelled by the
# original string classes.
fig, ax = plt.subplots(figsize=(8, 6))
ConfusionMatrixDisplay.from_predictions(
    y_test,
    y_pred,
    display_labels=fbt_encoder.classes_,
    cmap='Blues',
    ax=ax,
)
ax.set_title('FBT Classification: Fringe Benefit Y/N')
plt.tight_layout()
plt.show()

## 13. Save Model

In [None]:
# Bundle everything inference needs (fitted model and transformers, plus the
# rule tables and configuration used to build them) into a single file.
model_pkg = dict(
    model=ensemble,
    tfidf=tfidf,
    scaler=scaler,
    fbt_encoder=fbt_encoder,
    num_cols=num_cols,
    fbt_yes_cats=FBT_YES_CATEGORIES,
    fbt_no_cats=FBT_NO_CATEGORIES,
    category_rules=CATEGORY_RULES,
    config=CONFIG,
)
output_path = CONFIG['model_output']
joblib.dump(model_pkg, output_path)
print(f"Saved: {output_path}")

## 14. Inference Function

In [None]:
def predict_fbt(texts: List[str], model_path: str = CONFIG['model_output']) -> pd.DataFrame:
    """
    Predict Fringe Benefit Tax classification for free-text expense descriptions.

    Each text is normalised the same way as the training text (lower-cased,
    everything except a-z, 0-9 and whitespace replaced by a space, whitespace
    collapsed) before vectorisation and location extraction. Numeric features
    that cannot be derived from text alone (amount, date, sheet) are set to 0
    before scaling.

    Args:
        texts: Expense descriptions; None / NaN entries are treated as empty text.
        model_path: Path of the package written by the "Save Model" step.

    Returns:
        DataFrame with columns text, fbt_label, confidence, location, state,
        distance_km, travel_hours - one row per input text, in input order.
        An empty DataFrame with those columns when *texts* is empty.
    """
    output_columns = ['text', 'fbt_label', 'confidence', 'location', 'state',
                      'distance_km', 'travel_hours']

    # Materialise once so any iterable works and is not consumed twice.
    raw_texts = list(texts)
    if not raw_texts:
        return pd.DataFrame(columns=output_columns)

    def normalise(value) -> str:
        # Mirrors the normalisation applied to the training text.
        if not isinstance(value, str):
            if pd.isna(value):
                return ''
            value = str(value)
        return ' '.join(re.sub(r'[^a-z0-9\s]', ' ', value.lower()).split())

    clean_texts = [normalise(t) for t in raw_texts]

    # NOTE: joblib.load unpickles arbitrary objects - only load model files
    # that come from a trusted source.
    pkg = joblib.load(model_path)

    # Extract location from the normalised text, as done for the training rows.
    loc_info = [loc_extractor.extract(t) for t in clean_texts]

    # Create features
    X_tfidf = pkg['tfidf'].transform(clean_texts)
    loc_keys = ['has_location', 'distance_km', 'travel_hours',
                'is_international', 'is_regional', 'is_remote']
    feat_df = pd.DataFrame({key: [d[key] for d in loc_info] for key in loc_keys})

    # Align with the training columns; anything not derivable from text is 0.
    feat_df = feat_df.reindex(columns=pkg['num_cols'], fill_value=0)

    X_num = pkg['scaler'].transform(feat_df.fillna(0))
    X = sparse.hstack([X_tfidf, sparse.csr_matrix(X_num)])

    # Predict
    y_pred = pkg['model'].predict(X)
    fbt_labels = pkg['fbt_encoder'].inverse_transform(y_pred)

    # Confidence: the highest class probability for each row.
    probs = pkg['model'].predict_proba(X)
    confidence = probs.max(axis=1)

    return pd.DataFrame({
        'text': raw_texts,
        'fbt_label': fbt_labels,
        'confidence': confidence.round(3),
        'location': [d['location_name'] for d in loc_info],
        'state': [d['location_state'] for d in loc_info],
        'distance_km': [d['distance_km'] for d in loc_info],
        'travel_hours': [d['travel_hours'] for d in loc_info]
    })

In [None]:
# Smoke test: run the saved model on a handful of hand-written descriptions.
test_texts = [
    "Client dinner at restaurant Melbourne with 5 partners",
    "Taxi to airport business trip Brisbane",
    "Team building activity bowling Sydney",
    "Training seminar registration fee",
    "Christmas party catering staff 50 people",
    "Flight to Roma client visit farm inspection",
    "Travel connected to entertainment dinner Utrecht",
    "Morning tea for team meeting",
    "Accommodation for client dinner event Melbourne",
    "Business travel meal whilst in Dubbo",
]

predictions = predict_fbt(texts=test_texts)
print("\nPredictions:")
print(predictions.to_string(index=False))

## 15. Summary

In [None]:
# Closing recap: goal, category lists, data size and output artefacts.
banner = "=" * 60
print(banner, "FBT CLASSIFICATION PIPELINE", banner, sep="\n")

print("\nGoal: Predict if expense is subject to Fringe Benefit Tax")

category_sections = (
    ("\nFBT = Y (Taxable):", FBT_YES_CATEGORIES),
    ("\nFBT = N (Not Taxable):", FBT_NO_CATEGORIES),
)
for heading, cats in category_sections:
    print(heading)
    for cat in cats:
        print(f"  - {cat}")

print(f"\nTraining samples: {len(df_train)}")
print(f"Features: {X_train.shape[1]}")

print(f"\nOutput columns:")
print(f"  fbt_label, confidence, location, state, distance_km, travel_hours")

print(f"\nModel saved: {CONFIG['model_output']}")