In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split

In [None]:
# Mount Google Drive (Colab only) so the CSVs under
# '/content/drive/MyDrive/Capstone Project/' can be read by the cells below.
# If Drive is already mounted, this only prints a notice (as in the output below).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Load each domain's merged CSV from Drive, align the facility schema
# ('Description_morpho' -> 'Description_grammar'), and tag every row with its
# domain. The tag is stored twice: as 'domain' and as 'source_dataset'.
facility_df = (
    pd.read_csv('/content/drive/MyDrive/Capstone Project/facility_merged_dataset.csv')
    .rename(columns={'Description_morpho': 'Description_grammar'})
    .assign(domain='facility', source_dataset='facility')
)
aviation_df = (
    pd.read_csv('/content/drive/MyDrive/Capstone Project/aviation_merged_dataset.csv')
    .assign(domain='aviation', source_dataset='aviation')
)
vehicle_df = (
    pd.read_csv('/content/drive/MyDrive/Capstone Project/vehicle_merged_dataset.csv')
    .assign(domain='vehicle', source_dataset='vehicle')
)

In [None]:
# Remove duplicates: keep only the first row for each (word, POS) pair within every domain.
facility_clean, aviation_clean, vehicle_clean = (
    raw.drop_duplicates(subset=['word', 'Part of Speech (POS)'], keep='first')
    for raw in (facility_df, aviation_df, vehicle_df)
)

# Report row counts before → after deduplication, one line per domain.
print("\n".join(
    f"{name}: {len(before)} → {len(after)}"
    for name, before, after in (
        ("Facility", facility_df, facility_clean),
        ("Aviation", aviation_df, aviation_clean),
        ("Vehicle", vehicle_df, vehicle_clean),
    )
))

Facility: 63 → 49
Aviation: 41 → 36
Vehicle: 46 → 36


In [None]:
# Ensure column alignment: the three deduplicated domains must share an identical
# column layout (same names, same order) before they are stacked.
assert list(facility_clean.columns) == list(aviation_clean.columns) == list(vehicle_clean.columns)
combined_df = pd.concat([facility_clean, aviation_clean, vehicle_clean], ignore_index=True)

# Two-stage split, stratified on 'domain' both times so every split keeps the overall
# domain mix: 70% train, then the remaining 30% halved into validation and test.
train_df, temp_df = train_test_split(combined_df, test_size=0.30,
                                     stratify=combined_df['domain'], random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.50,
                                   stratify=temp_df['domain'], random_state=42)

# The percentages are the nominal shares; the counts are the actual split sizes.
print(f"Train: {len(train_df)} (70%)\n"
      f"Val: {len(val_df)} (15%)\n"
      f"Test: {len(test_df)} (15%)")

Train: 84 (70%)
Val: 18 (15%)
Test: 19 (15%)


In [None]:
# For domain generalization testing
# Not used in the cells shown here; kept in case later cells reference it.
aviation_test = aviation_clean.copy()

# Hold out 30% of the deduplicated aviation rows as the "unseen domain" test set.
# Caveat: aviation_clean was also part of combined_df above, so some of these rows
# are in train_df as well. The overlap is measured and reported below so the
# "unseen" label is not taken at face value.
aviation_train_val, aviation_test_unseen = train_test_split(
    aviation_clean,
    test_size=0.30,  # 30% for unseen domain test
    random_state=42
)

# (word, POS) is the deduplication key, so it identifies a row within a domain.
aviation_rows_in_train = train_df.loc[train_df['domain'] == 'aviation', ['word', 'Part of Speech (POS)']]
aviation_unseen_overlap = aviation_test_unseen.merge(
    aviation_rows_in_train, on=['word', 'Part of Speech (POS)'], how='inner'
)

print(f"\nDomain Generalization Test:")
print(f"Aviation unseen: {len(aviation_test_unseen)} samples")
if len(aviation_unseen_overlap):
    print(f"WARNING: {len(aviation_unseen_overlap)} of these samples also appear in "
          "train_df, so they are not truly unseen.")


Domain Generalization Test:
Aviation unseen: 11 samples


In [None]:
# Write every split to the project folder on Drive. A bare relative filename would
# land in the Colab runtime's working directory, which does not persist between
# sessions and would not sit next to the other split files.
train_df.to_csv('/content/drive/MyDrive/Capstone Project/train_dataset.csv', index=False)
val_df.to_csv('/content/drive/MyDrive/Capstone Project/val_dataset.csv', index=False)
test_df.to_csv('/content/drive/MyDrive/Capstone Project/test_dataset.csv', index=False)
aviation_test_unseen.to_csv('/content/drive/MyDrive/Capstone Project/aviation_unseen_test.csv', index=False)

print("✓ Data preparation complete!")
print(f"Train: {len(train_df)} | Val: {len(val_df)} | Test: {len(test_df)}")
print(f"Aviation Unseen Test: {len(aviation_test_unseen)}")
print(f"\nDomain Distribution in Train:")
print(train_df['domain'].value_counts())

✓ Data preparation complete!
Train: 84 | Val: 18 | Test: 19
Aviation Unseen Test: 11

Domain Distribution in Train:
domain
facility    34
vehicle     25
aviation    25
Name: count, dtype: int64


In [None]:
"""
Multi-Task NLP Model for Ticket Triage & Fault Category Classification
========================================================================
Production-ready implementation with DistilBERT for maintenance ticket classification

Author: AI Engineer
Date: 2025
Environment: Google Colab with GPU (T4)
"""

# ============================================================================
# IMPORTS
# ============================================================================
import os
import json
import pickle
import random
import logging
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
from collections import Counter
from datetime import datetime

# Sklearn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import (
    classification_report, confusion_matrix, f1_score,
    precision_recall_fscore_support, hamming_loss
)

# NLP & Augmentation
import nltk
from nltk.corpus import wordnet

# PyTorch & Transformers
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler

from transformers import (
    DistilBertTokenizer, DistilBertModel,
    get_linear_schedule_with_warmup
)

# NOTE(review): this silences *every* Python warning for the rest of the kernel
# session (e.g. pandas SettingWithCopyWarning or library deprecation warnings, if
# raised), which can hide real problems. Consider narrowing it to specific categories.
warnings.filterwarnings('ignore')
# Root logging at INFO with timestamped lines. logging.basicConfig does nothing if the
# root logger already has handlers (some notebook runtimes install one); if INFO
# messages don't appear, consider passing force=True.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# ============================================================================
# CONFIGURATION & SETUP
# ============================================================================
class Config:
    """Configuration class for all hyperparameters and paths.

    Every setting is a class attribute read directly (e.g. ``Config.BATCH_SIZE``);
    the class is not instantiated in the code shown here.
    """
    # Paths
    # Drive project folder: holds the input CSVs; outputs go to sub-folders of it.
    BASE_PATH = '/content/drive/MyDrive/Capstone Project/'
    MODELS_DIR = os.path.join(BASE_PATH, 'models')
    RESULTS_DIR = os.path.join(BASE_PATH, 'results')
    LOGS_DIR = os.path.join(BASE_PATH, 'logs')

    # Model parameters
    MODEL_NAME = 'distilbert-base-uncased'
    # Token length each example is padded/truncated to by MaintenanceDataset.
    MAX_LENGTH = 128
    BATCH_SIZE = 16
    LEARNING_RATE = 2e-5
    WEIGHT_DECAY = 0.01
    EPOCHS = 12
    # NOTE(review): with ~84 training rows (doubled to ~168 by augmentation_factor=1)
    # and BATCH_SIZE=16, one epoch is ~11 batches, so EPOCHS=12 gives ~132 steps, far
    # fewer than 500. If this is passed as num_warmup_steps to
    # get_linear_schedule_with_warmup, the LR never finishes warming up
    # (it peaks near 132/500 of LEARNING_RATE). Confirm in the training loop.
    WARMUP_STEPS = 500
    # Presumably early-stopping patience in epochs; confirm in the training loop.
    PATIENCE = 3

    # Multi-task loss weights
    # Combined loss = URGENCY_WEIGHT * urgency CE + FAULT_WEIGHT * fault CE.
    URGENCY_WEIGHT = 0.4
    FAULT_WEIGHT = 0.6

    # Random seed
    # Used as random_state for the data splits and the LinearSVC baselines.
    SEED = 42

    # Device
    # Evaluated once, when the class body executes.
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def set_seed(seed=42):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA
    devices) with ``seed``, then put cuDNN into deterministic, non-benchmark mode."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    # Deterministic kernels and no autotuner, so repeated runs pick the same algorithms.
    cudnn.deterministic = True
    cudnn.benchmark = False

def create_directories():
    """Create the models/, results/ and logs/ output folders plus models/tokenizer/.

    Uses ``exist_ok=True``, so calling it repeatedly is harmless.
    """
    for directory in (Config.MODELS_DIR, Config.RESULTS_DIR, Config.LOGS_DIR):
        os.makedirs(directory, exist_ok=True)
    # Only the models folder gets a 'tokenizer' sub-folder (presumably where the
    # tokenizer is saved; confirm). results/ and logs/ need no tokenizer folder.
    os.makedirs(os.path.join(Config.MODELS_DIR, 'tokenizer'), exist_ok=True)
    logger.info("✓ Directories created successfully")

# ============================================================================
# DATA PREPARATION
# ============================================================================
def prepare_datasets():
    """
    Load the three domain CSVs and return domain-stratified train/val/test splits.

    Pipeline:
    1. Read ``facility_``, ``aviation_`` and ``vehicle_merged_dataset.csv`` from
       ``Config.BASE_PATH``, and rename facility's ``Description_morpho`` column to
       ``Description_grammar`` so all three share a schema.
    2. Add a ``domain`` column naming the source dataset.
    3. Drop duplicate (``word``, ``Part of Speech (POS)``) rows within each domain.
    4. Concatenate and split 70/15/15 (train/val/test), stratified on ``domain``,
       using ``Config.SEED`` as the random state.
    5. Sample 20% of the deduplicated aviation rows as ``aviation_unseen_test``.

    Returns:
        tuple: (train_data, val_data, test_data, aviation_unseen_test) DataFrames.

    NOTE(review): the aviation sample in step 5 is drawn from rows that also go
    through the step-4 split, so it overlaps ``train_data`` and is not a truly
    unseen domain. The earlier notebook cell uses a 30% split for the same purpose;
    confirm which fraction is intended.
    """
    divider = "=" * 80
    logger.info(divider)
    logger.info("STEP 1: DATA PREPARATION")
    logger.info(divider)

    source_files = {
        'facility': 'facility_merged_dataset.csv',
        'aviation': 'aviation_merged_dataset.csv',
        'vehicle': 'vehicle_merged_dataset.csv',
    }
    raw = {
        domain: pd.read_csv(os.path.join(Config.BASE_PATH, filename))
        for domain, filename in source_files.items()
    }

    # The facility export names the grammar-description column differently.
    raw['facility'] = raw['facility'].rename(columns={'Description_morpho': 'Description_grammar'})

    for domain, frame in raw.items():
        frame['domain'] = domain

    logger.info(f"Facility records: {len(raw['facility'])}")
    logger.info(f"Aviation records: {len(raw['aviation'])}")
    logger.info(f"Vehicle records: {len(raw['vehicle'])}")

    deduped = {
        domain: frame.drop_duplicates(subset=['word', 'Part of Speech (POS)'], keep='first')
        for domain, frame in raw.items()
    }
    logger.info(f"After deduplication - Facility: {len(deduped['facility'])}, Aviation: {len(deduped['aviation'])}, Vehicle: {len(deduped['vehicle'])}")

    all_data = pd.concat(list(deduped.values()), ignore_index=True)

    # 70% train / 30% holdout, then the holdout is halved into val and test.
    train_data, holdout = train_test_split(
        all_data, test_size=0.30, stratify=all_data['domain'], random_state=Config.SEED
    )
    val_data, test_data = train_test_split(
        holdout, test_size=0.50, stratify=holdout['domain'], random_state=Config.SEED
    )

    aviation_unseen_test = deduped['aviation'].sample(frac=0.20, random_state=Config.SEED)

    logger.info("\n✓ Data Split Complete:")
    logger.info(f"  Train: {len(train_data)} (70%)")
    logger.info(f"  Val: {len(val_data)} (15%)")
    logger.info(f"  Test: {len(test_data)} (15%)")
    logger.info(f"  Aviation Unseen: {len(aviation_unseen_test)}")

    return train_data, val_data, test_data, aviation_unseen_test

# ============================================================================
# LABEL GENERATION
# ============================================================================
def generate_urgency_labels(df):
    """
    Add a rule-based ``urgency`` column to ``df`` (in place) and return ``df``.

    The lower-cased ``Example`` and ``Description_grammar`` text of each row is
    searched tier by tier; the first tier with a hit decides the label:

    - critical: fire, explosion, hazard, emergency, quit, failed, suddenly
    - high: leak, broken, inoperable, crack, damage, stuck, alarm
      (also checked against the row's ``word`` column)
    - low: normal, routine, check, inspect
    - medium: default when nothing matches

    Keywords must start at a word boundary and may carry a common inflection.
    For example, ``leak`` matches "leaks", "leaking" and "leakage", and ``inspect``
    matches "inspection". Arbitrary surrounding letters do not match, so "abnormal"
    does not trigger ``normal`` and "quite" does not trigger ``quit``.

    Args:
        df: DataFrame with ``Example``, ``Description_grammar`` and ``word`` columns.

    Returns:
        The same DataFrame object with an added/overwritten ``urgency`` column.
    """
    import re

    # Optional inflectional endings accepted after a keyword, before the closing \b.
    inflection = r'(?:s|es|d|ed|ing|ly|y|age|ous|ion)?'

    def compile_keywords(keywords):
        # One alternation per tier; re.escape keeps keywords literal.
        alternatives = '|'.join(re.escape(kw) for kw in keywords)
        return re.compile(r'\b(?:' + alternatives + r')' + inflection + r'\b')

    critical_pattern = compile_keywords(
        ['fire', 'explosion', 'hazard', 'emergency', 'quit', 'failed', 'suddenly'])
    high_pattern = compile_keywords(
        ['leak', 'broken', 'inoperable', 'crack', 'damage', 'stuck', 'alarm'])
    low_pattern = compile_keywords(['normal', 'routine', 'check', 'inspect'])

    def classify_urgency(row):
        text = str(row['Example']).lower() + ' ' + str(row['Description_grammar']).lower()
        word = str(row['word']).lower()

        if critical_pattern.search(text):
            return 'critical'
        if high_pattern.search(text) or high_pattern.search(word):
            return 'high'
        if low_pattern.search(text):
            return 'low'
        return 'medium'

    if len(df) == 0:
        # DataFrame.apply(axis=1) on an empty frame returns a DataFrame, which cannot
        # be assigned to a single column, so create an empty column instead.
        df['urgency'] = pd.Series(dtype=object)
    else:
        df['urgency'] = df.apply(classify_urgency, axis=1)
    return df

def generate_fault_labels(df):
    """
    Derive a normalized ``fault_category`` column from ``Description_grammar``.

    Missing descriptions become ``'unknown'``; every value is lower-cased and
    stripped. The generic buckets ``'repair'``, ``'unknown'`` and the literal string
    ``'nan'`` are then collapsed into ``'general_repair'``. The column is written to
    ``df`` in place, and ``df`` is returned.
    """
    collapse_to_generic = {name: 'general_repair' for name in ('repair', 'unknown', 'nan')}
    normalized = (
        df['Description_grammar']
        .fillna('unknown')
        .str.lower()
        .str.strip()
    )
    df['fault_category'] = normalized.replace(collapse_to_generic)
    return df

def _encode_with_fallback(labels, encoder, fallback_label, split_name):
    """
    Encode string ``labels`` with a fitted ``LabelEncoder``, tolerating unseen values.

    A value missing from ``encoder.classes_`` is mapped to the code of
    ``fallback_label`` when that label is a known class, and to code 0 otherwise.
    The number of remapped rows is logged as a warning, so relabelling is visible.

    Returns:
        pd.Series of int64 codes aligned with ``labels.index``.
    """
    code_of = {label: code for code, label in enumerate(encoder.classes_)}
    fallback_code = code_of.get(fallback_label, 0)
    codes = labels.map(code_of)
    n_unseen = int(codes.isna().sum())
    if n_unseen:
        target = encoder.classes_[fallback_code] if len(encoder.classes_) else 'index 0'
        logger.warning(
            f"{split_name}: {n_unseen} label(s) not seen in training mapped to '{target}'"
        )
    return codes.fillna(fallback_code).astype('int64')


def prepare_labels(train_data, val_data, test_data, aviation_test):
    """
    Generate urgency/fault labels for every split and integer-encode them.

    The encoders are fit on the training split only. Val/test/aviation labels the
    encoders have never seen are mapped to the generic bucket ('medium' for urgency,
    'general_repair' for faults) when it exists in training, and to code 0 otherwise.
    The remap is logged. An arbitrary first class is not silently substituted.

    Returns:
        (train_data, val_data, test_data, aviation_test, urgency_encoder, fault_encoder),
        where each DataFrame gains ``urgency``, ``fault_category``, ``urgency_label``
        and ``fault_label`` columns.
    """
    logger.info("\n" + "="*80)
    logger.info("STEP 2: LABEL GENERATION")
    logger.info("="*80)

    # Generate urgency labels
    train_data = generate_urgency_labels(train_data)
    val_data = generate_urgency_labels(val_data)
    test_data = generate_urgency_labels(test_data)
    aviation_test = generate_urgency_labels(aviation_test)

    # Generate fault labels
    train_data = generate_fault_labels(train_data)
    val_data = generate_fault_labels(val_data)
    test_data = generate_fault_labels(test_data)
    aviation_test = generate_fault_labels(aviation_test)

    logger.info(f"Urgency distribution in train:")
    logger.info(train_data['urgency'].value_counts())

    logger.info(f"\nFault category distribution in train (top 10):")
    logger.info(train_data['fault_category'].value_counts().head(10))

    # Fit encoders on training data only
    urgency_encoder = LabelEncoder()
    fault_encoder = LabelEncoder()
    urgency_encoder.fit(train_data['urgency'])
    fault_encoder.fit(train_data['fault_category'])

    # Every training label is known to the encoders by construction.
    train_data['urgency_label'] = urgency_encoder.transform(train_data['urgency'])
    train_data['fault_label'] = fault_encoder.transform(train_data['fault_category'])

    # Other splits may contain labels absent from training. 'medium' and
    # 'general_repair' are the default buckets of the two label generators.
    for split_name, split in (('val', val_data), ('test', test_data), ('aviation', aviation_test)):
        split['urgency_label'] = _encode_with_fallback(
            split['urgency'], urgency_encoder, 'medium', f"{split_name} urgency"
        )
        split['fault_label'] = _encode_with_fallback(
            split['fault_category'], fault_encoder, 'general_repair', f"{split_name} fault"
        )

    logger.info(f"\n✓ Label Encoding Complete:")
    logger.info(f"  Urgency classes: {len(urgency_encoder.classes_)} - {urgency_encoder.classes_}")
    logger.info(f"  Fault classes: {len(fault_encoder.classes_)}")

    return train_data, val_data, test_data, aviation_test, urgency_encoder, fault_encoder

# ============================================================================
# TEXT PREPROCESSING & AUGMENTATION
# ============================================================================
def preprocess_text(text):
    """Lower-case ``str(text)``, turn each character that is neither alphanumeric nor
    whitespace into a space, and collapse all whitespace runs into single spaces."""
    def keep_or_blank(ch):
        return ch if ch.isalnum() or ch.isspace() else ' '

    scrubbed = ''.join(map(keep_or_blank, str(text).lower()))
    return ' '.join(scrubbed.split())

def get_synonyms(word):
    """
    Return WordNet synonyms of ``word`` as a sorted list, excluding ``word`` itself.

    Underscores in lemma names are replaced with spaces. The list is sorted so its
    order does not depend on Python's per-process string-hash randomization;
    iterating a ``set`` of strings does. That keeps ``random.choice`` over the result
    reproducible under a fixed seed.

    If the WordNet lookup fails (for example a ``LookupError`` because the corpus was
    never fetched with ``nltk.download('wordnet')``), an empty list is returned and
    one warning is logged. The failure is not swallowed silently on every call.
    """
    try:
        synonyms = {
            lemma.name().replace('_', ' ')
            for synset in wordnet.synsets(word)
            for lemma in synset.lemmas()
        }
    except Exception as exc:  # best-effort: augmentation must not crash training
        if not getattr(get_synonyms, '_warned', False):
            logger.warning(
                f"WordNet lookup failed ({exc!r}); returning no synonyms. "
                "Run nltk.download('wordnet') if the corpus is missing."
            )
            get_synonyms._warned = True
        return []
    synonyms.discard(word)
    return sorted(synonyms)

def augment_text_synonym(text, aug_prob=0.3):
    """Return ``text`` with some whitespace-separated tokens replaced by a random
    WordNet synonym.

    For each token, one ``random.random()`` draw is made. The token is a replacement
    candidate when the draw is below ``aug_prob`` and the token is longer than
    3 characters. A candidate with no synonyms is kept as-is. Tokens are re-joined
    with single spaces.
    """
    out = []
    for token in text.split():
        # Draw first, so the RNG advances exactly once per token regardless of length.
        replace = random.random() < aug_prob and len(token) > 3
        candidates = get_synonyms(token) if replace else []
        out.append(random.choice(candidates) if candidates else token)
    return ' '.join(out)

def augment_text_noise(text, noise_prob=0.1):
    """Inject typo-style noise by swapping one random pair of adjacent characters
    inside some tokens (no characters are deleted).

    Each whitespace-separated token costs one ``random.random()`` draw. Tokens longer
    than 3 characters whose draw is below ``noise_prob`` get a swap at a position
    chosen with ``random.randint``. Tokens are re-joined with single spaces.
    """
    def maybe_swap(token):
        if not (random.random() < noise_prob and len(token) > 3):
            return token
        i = random.randint(0, len(token) - 2)
        return token[:i] + token[i + 1] + token[i] + token[i + 2:]

    return ' '.join([maybe_swap(token) for token in text.split()])

def create_text_features(df):
    """Set ``df['combined_text']`` (in place) to the space-joined ``word``,
    ``Example``, ``Description_grammar`` and ``Lemma`` columns (missing values count
    as empty strings), normalized with ``preprocess_text``. Returns ``df``."""
    source_columns = ['word', 'Example', 'Description_grammar', 'Lemma']
    joined = df[source_columns[0]].fillna('')
    for column in source_columns[1:]:
        joined = joined + ' ' + df[column].fillna('')
    df['combined_text'] = joined.apply(preprocess_text)
    return df

def augment_dataset(df, augmentation_factor=1):
    """
    Return a new DataFrame where each row of ``df`` is followed by
    ``augmentation_factor`` perturbed copies.

    Each copy is identical to its source row except for ``combined_text``. With
    probability 0.5 that text is rewritten by ``augment_text_synonym``; otherwise it
    is rewritten by ``augment_text_noise``. ``df`` itself is not modified.
    """
    logger.info(f"\nAugmenting dataset (factor={augmentation_factor})...")

    records = []
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Augmentation"):
        source_text = row['combined_text']
        records.append(row.to_dict())

        for _ in range(augmentation_factor):
            augmenter = augment_text_synonym if random.random() < 0.5 else augment_text_noise
            variant = row.to_dict()
            variant['combined_text'] = augmenter(source_text)
            records.append(variant)

    augmented_df = pd.DataFrame(records)
    logger.info(f"✓ Dataset augmented: {len(df)} → {len(augmented_df)}")

    return augmented_df

# ============================================================================
# BASELINE MODEL (TF-IDF + SVM)
# ============================================================================
def train_baseline_model(train_data, val_data, test_data):
    """
    Train and evaluate the TF-IDF + LinearSVC baseline (one classifier per task).

    ``combined_text`` is built for any split that lacks it. A single TF-IDF
    vectorizer (up to 5000 unigram/bigram features) is fit on the training split and
    reused for val/test. Two separate LinearSVC models predict ``urgency_label`` and
    ``fault_label``. Macro-F1 scores are logged and written to
    ``RESULTS_DIR/baseline_results.json``. The vectorizer and both classifiers are
    pickled to ``MODELS_DIR/baseline_model.pkl``.

    Args:
        train_data, val_data, test_data: DataFrames with ``urgency_label`` and
            ``fault_label`` columns. If ``combined_text`` is missing, they also need
            the raw text columns read by ``create_text_features``.

    Returns:
        dict: validation/test macro-F1 for each task plus the test average.
    """
    logger.info("\n" + "="*80)
    logger.info("STEP 3: BASELINE MODEL (TF-IDF + LinearSVC)")
    logger.info("="*80)

    # Check each split independently, so a val/test frame lacking 'combined_text'
    # is featurized instead of raising KeyError at the transform step below.
    train_data, val_data, test_data = (
        split if 'combined_text' in split.columns else create_text_features(split)
        for split in (train_data, val_data, test_data)
    )

    # TF-IDF vectorization (fit on train only, to avoid vocabulary leakage)
    vectorizer = TfidfVectorizer(max_features=5000, ngram_range=(1, 2))

    X_train = vectorizer.fit_transform(train_data['combined_text'])
    X_val = vectorizer.transform(val_data['combined_text'])
    X_test = vectorizer.transform(test_data['combined_text'])

    y_train_urgency = train_data['urgency_label'].values
    y_train_fault = train_data['fault_label'].values

    y_val_urgency = val_data['urgency_label'].values
    y_val_fault = val_data['fault_label'].values

    y_test_urgency = test_data['urgency_label'].values
    y_test_fault = test_data['fault_label'].values

    # Train urgency classifier
    logger.info("Training urgency classifier...")
    urgency_clf = LinearSVC(random_state=Config.SEED, max_iter=2000)
    urgency_clf.fit(X_train, y_train_urgency)

    # Train fault classifier
    logger.info("Training fault classifier...")
    fault_clf = LinearSVC(random_state=Config.SEED, max_iter=2000)
    fault_clf.fit(X_train, y_train_fault)

    # Predictions
    urgency_pred_val = urgency_clf.predict(X_val)
    fault_pred_val = fault_clf.predict(X_val)

    urgency_pred_test = urgency_clf.predict(X_test)
    fault_pred_test = fault_clf.predict(X_test)

    # Evaluation
    logger.info("\n--- Validation Set Performance ---")
    urgency_f1_val = f1_score(y_val_urgency, urgency_pred_val, average='macro')
    fault_f1_val = f1_score(y_val_fault, fault_pred_val, average='macro')
    logger.info(f"Urgency Macro-F1: {urgency_f1_val:.4f}")
    logger.info(f"Fault Macro-F1: {fault_f1_val:.4f}")

    logger.info("\n--- Test Set Performance ---")
    urgency_f1_test = f1_score(y_test_urgency, urgency_pred_test, average='macro')
    fault_f1_test = f1_score(y_test_fault, fault_pred_test, average='macro')
    logger.info(f"Urgency Macro-F1: {urgency_f1_test:.4f}")
    logger.info(f"Fault Macro-F1: {fault_f1_test:.4f}")

    # Save results
    baseline_results = {
        'urgency_f1_val': float(urgency_f1_val),
        'fault_f1_val': float(fault_f1_val),
        'urgency_f1_test': float(urgency_f1_test),
        'fault_f1_test': float(fault_f1_test),
        'avg_f1_test': float((urgency_f1_test + fault_f1_test) / 2)
    }

    with open(os.path.join(Config.RESULTS_DIR, 'baseline_results.json'), 'w') as f:
        json.dump(baseline_results, f, indent=4)

    # Save model
    baseline_model = {
        'vectorizer': vectorizer,
        'urgency_clf': urgency_clf,
        'fault_clf': fault_clf
    }

    with open(os.path.join(Config.MODELS_DIR, 'baseline_model.pkl'), 'wb') as f:
        pickle.dump(baseline_model, f)

    logger.info(f"\n✓ Baseline model saved. Average F1: {baseline_results['avg_f1_test']:.4f}")

    return baseline_results

# ============================================================================
# PYTORCH DATASET
# ============================================================================
class MaintenanceDataset(Dataset):
    """Map-style PyTorch dataset over a ticket DataFrame.

    Item ``idx`` is the tokenized ``combined_text`` of row ``idx`` (padded/truncated
    to ``max_length`` and flattened to 1-D), together with ``urgency_label`` and
    ``fault_label`` as ``torch.long`` scalars.
    """

    def __init__(self, dataframe, tokenizer, max_length=128):
        # Re-index to 0..n-1 so positional indices from the DataLoader are valid labels.
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        frame = self.data
        encoded = self.tokenizer(
            frame.loc[idx, 'combined_text'],
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        def as_long(value):
            return torch.tensor(value, dtype=torch.long)

        return {
            'input_ids': encoded['input_ids'].flatten(),
            'attention_mask': encoded['attention_mask'].flatten(),
            'urgency_label': as_long(frame.loc[idx, 'urgency_label']),
            'fault_label': as_long(frame.loc[idx, 'fault_label']),
        }

# ============================================================================
# DISTILBERT MULTI-TASK MODEL
# ============================================================================
class DistilBERTMultiTask(nn.Module):
    """
    Shared DistilBERT encoder with two independent classification heads.

    The pretrained ``Config.MODEL_NAME`` encoder has its embedding sub-module frozen;
    the transformer layers stay trainable. Both heads read the first-token ([CLS])
    hidden state and apply Dropout -> Linear(hidden, 256) -> ReLU -> Dropout ->
    Linear(256, n_classes). ``forward`` returns ``(urgency_logits, fault_logits)``.
    """

    def __init__(self, n_urgency_classes, n_fault_classes, dropout=0.3):
        super().__init__()

        self.bert = DistilBertModel.from_pretrained(Config.MODEL_NAME)
        # Only the embedding layer is frozen; every transformer block is fine-tuned.
        for p in self.bert.embeddings.parameters():
            p.requires_grad = False

        width = self.bert.config.hidden_size

        def classification_head(n_out):
            # The layer order fixes the nn.Sequential indices, which become the
            # state_dict keys (e.g. 'urgency_classifier.4.weight').
            return nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(width, 256),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(256, n_out),
            )

        # Built urgency-first so parameter registration/initialization order is stable.
        self.urgency_classifier = classification_head(n_urgency_classes)
        self.fault_classifier = classification_head(n_fault_classes)

    def forward(self, input_ids, attention_mask):
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        cls_vector = hidden[:, 0]
        return self.urgency_classifier(cls_vector), self.fault_classifier(cls_vector)

# ============================================================================
# TRAINING FUNCTIONS
# ============================================================================
def train_epoch(model, dataloader, optimizer, scheduler, scaler, device):
    """
    Run one training pass over ``dataloader`` and return the mean per-batch loss.

    Each batch runs under mixed-precision ``autocast``. The loss is
    ``Config.URGENCY_WEIGHT * urgency_CE + Config.FAULT_WEIGHT * fault_CE``.
    Gradients are scaled by ``scaler``, unscaled, clipped to a global norm of 1.0,
    and then the optimizer, the scaler and the scheduler are stepped. The scheduler
    steps once per batch.

    Args:
        model: called as ``model(input_ids, attention_mask)`` and expected to return
            ``(urgency_logits, fault_logits)``.
        dataloader: yields dicts with 'input_ids', 'attention_mask',
            'urgency_label' and 'fault_label' tensors.
        optimizer: optimizer over the model's parameters.
        scheduler: learning-rate scheduler, stepped after every batch.
        scaler: gradient scaler (presumably a ``GradScaler`` as imported above; confirm).
        device: device the batch tensors are moved to.

    Returns:
        float: sum of batch losses divided by ``len(dataloader)``
        (ZeroDivisionError if the loader is empty).

    NOTE(review): ``torch.cuda.amp.autocast`` is deprecated in recent PyTorch in
    favour of ``torch.autocast('cuda', ...)``. On a CPU-only runtime it is expected
    to warn and disable itself; confirm against the installed torch version.
    """
    model.train()
    total_loss = 0
    urgency_criterion = nn.CrossEntropyLoss()
    fault_criterion = nn.CrossEntropyLoss()

    progress_bar = tqdm(dataloader, desc="Training")

    for batch in progress_bar:
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        urgency_labels = batch['urgency_label'].to(device)
        fault_labels = batch['fault_label'].to(device)

        # Mixed precision training
        with autocast():
            urgency_logits, fault_logits = model(input_ids, attention_mask)

            urgency_loss = urgency_criterion(urgency_logits, urgency_labels)
            fault_loss = fault_criterion(fault_logits, fault_labels)

            # Multi-task weighted loss
            loss = Config.URGENCY_WEIGHT * urgency_loss + Config.FAULT_WEIGHT * fault_loss

        # Backward pass with gradient scaling
        scaler.scale(loss).backward()
        # Unscale before clipping, so the norm threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        # GradScaler skips the optimizer step if it found inf/NaN gradients; the
        # scheduler below still advances either way.
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        total_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})

    return total_loss / len(dataloader)

def evaluate_model(model, dataloader, device):
    """
    Score ``model`` on ``dataloader`` without updating its weights.

    The model runs in eval mode under ``torch.no_grad()``. The loop accumulates the
    same weighted multi-task cross-entropy used in training and records argmax
    predictions next to the true labels for both tasks.

    Returns:
        dict with 'loss' (mean per-batch loss), 'urgency_f1' and 'fault_f1'
        (macro-F1), 'avg_f1' (their mean), and the raw 'urgency_preds',
        'urgency_labels', 'fault_preds' and 'fault_labels' lists.
    """
    model.eval()
    ce_loss = nn.CrossEntropyLoss()
    running_loss = 0
    pred_u, true_u, pred_f, true_f = [], [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            gold_u = batch['urgency_label'].to(device)
            gold_f = batch['fault_label'].to(device)

            logits_u, logits_f = model(ids, mask)
            batch_loss = (
                Config.URGENCY_WEIGHT * ce_loss(logits_u, gold_u)
                + Config.FAULT_WEIGHT * ce_loss(logits_f, gold_f)
            )
            running_loss += batch_loss.item()

            pred_u.extend(logits_u.argmax(dim=1).cpu().numpy())
            true_u.extend(gold_u.cpu().numpy())
            pred_f.extend(logits_f.argmax(dim=1).cpu().numpy())
            true_f.extend(gold_f.cpu().numpy())

    mean_loss = running_loss / len(dataloader)
    urgency_f1 = f1_score(true_u, pred_u, average='macro')
    fault_f1 = f1_score(true_f, pred_f, average='macro')

    return {
        'loss': mean_loss,
        'urgency_f1': urgency_f1,
        'fault_f1': fault_f1,
        'avg_f1': (urgency_f1 + fault_f1) / 2,
        'urgency_preds': pred_u,
        'urgency_labels': true_u,
        'fault_preds': pred_f,
        'fault_labels': true_f,
    }

def plot_confusion_matrices(y_true_urgency, y_pred_urgency, y_true_fault, y_pred_fault,
                            urgency_encoder, fault_encoder, prefix='test'):
    """
    Save side-by-side confusion-matrix heatmaps for the urgency and fault tasks.

    Labels are the integer codes produced by the given fitted encoders. Each matrix
    is built over *all* encoder classes (``labels=0..n_classes-1``), so its rows and
    columns line up with the ``classes_`` tick labels even when some classes are
    absent from this split. With more than 15 fault classes, the right panel shows a
    text note instead of a heatmap.

    The figure is written to ``RESULTS_DIR/confusion_matrices_{prefix}.png`` and
    then closed.
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # Urgency confusion matrix
    urgency_names = urgency_encoder.classes_
    cm_urgency = confusion_matrix(y_true_urgency, y_pred_urgency,
                                  labels=np.arange(len(urgency_names)))
    sns.heatmap(cm_urgency, annot=True, fmt='d', cmap='Blues', ax=axes[0],
                xticklabels=urgency_names,
                yticklabels=urgency_names)
    axes[0].set_title('Urgency Classification Confusion Matrix', fontweight='bold', fontsize=14)
    axes[0].set_ylabel('True Label', fontweight='bold')
    axes[0].set_xlabel('Predicted Label', fontweight='bold')

    # Fault confusion matrix (only drawn when the class count is readable)
    fault_names = fault_encoder.classes_
    if len(fault_names) > 15:
        axes[1].text(0.5, 0.5, f'Fault categories: {len(fault_names)}\nToo many to display',
                     ha='center', va='center', fontsize=12)
        axes[1].set_title('Fault Classification (Too many categories)', fontweight='bold', fontsize=14)
    else:
        cm_fault = confusion_matrix(y_true_fault, y_pred_fault,
                                    labels=np.arange(len(fault_names)))
        sns.heatmap(cm_fault, annot=True, fmt='d', cmap='Greens', ax=axes[1],
                    xticklabels=fault_names,
                    yticklabels=fault_names)
        axes[1].set_title('Fault Classification Confusion Matrix', fontweight='bold', fontsize=14)
        axes[1].set_ylabel('True Label', fontweight='bold')
        axes[1].set_xlabel('Predicted Label', fontweight='bold')

    fig.tight_layout()
    fig.savefig(os.path.join(Config.RESULTS_DIR, f'confusion_matrices_{prefix}.png'),
                dpi=300, bbox_inches='tight')
    plt.close(fig)
    logger.info(f"✓ Confusion matrices saved for {prefix}")

def _save_training_curves(train_losses, val_losses, val_f1_scores):
    """Plot per-epoch train/val loss and validation average F1; save as training_curves.png."""
    # 1-based epoch numbers on the x-axis, matching the "Epoch k/N" training logs.
    epochs = list(range(1, len(train_losses) + 1))
    fig, (ax_loss, ax_f1) = plt.subplots(1, 2, figsize=(14, 5))

    ax_loss.plot(epochs, train_losses, label='Train Loss', marker='o')
    ax_loss.plot(epochs, val_losses, label='Val Loss', marker='s')
    ax_loss.set_xlabel('Epoch', fontweight='bold')
    ax_loss.set_ylabel('Loss', fontweight='bold')
    ax_loss.set_title('Training and Validation Loss', fontweight='bold', fontsize=14)
    ax_loss.legend()
    ax_loss.grid(alpha=0.3)

    ax_f1.plot(epochs, val_f1_scores, label='Val Avg F1', marker='o', color='green')
    ax_f1.set_xlabel('Epoch', fontweight='bold')
    ax_f1.set_ylabel('F1 Score', fontweight='bold')
    ax_f1.set_title('Validation F1 Score', fontweight='bold', fontsize=14)
    ax_f1.legend()
    ax_f1.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(os.path.join(Config.RESULTS_DIR, 'training_curves.png'), dpi=300, bbox_inches='tight')
    plt.close(fig)
    logger.info("\n✓ Training curves saved")


def _build_domain_report(test_metrics, aviation_metrics):
    """Render the mixed-domain test vs. aviation comparison as plain text."""
    return f"""
{'='*80}
DOMAIN GENERALIZATION REPORT - AVIATION UNSEEN TEST
{'='*80}

Test Set (Mixed Domains):
  - Urgency F1: {test_metrics['urgency_f1']:.4f}
  - Fault F1: {test_metrics['fault_f1']:.4f}
  - Average F1: {test_metrics['avg_f1']:.4f}

Aviation Unseen Domain:
  - Urgency F1: {aviation_metrics['urgency_f1']:.4f}
  - Fault F1: {aviation_metrics['fault_f1']:.4f}
  - Average F1: {aviation_metrics['avg_f1']:.4f}

Performance Drop:
  - Urgency: {(test_metrics['urgency_f1'] - aviation_metrics['urgency_f1']):.4f}
  - Fault: {(test_metrics['fault_f1'] - aviation_metrics['fault_f1']):.4f}
  - Average: {(test_metrics['avg_f1'] - aviation_metrics['avg_f1']):.4f}

Interpretation:
  {'Good generalization!' if abs(test_metrics['avg_f1'] - aviation_metrics['avg_f1']) < 0.05 else 'Moderate domain shift detected'}

{'='*80}
"""


def train_distilbert_model(train_data, val_data, test_data, aviation_test,
                           urgency_encoder, fault_encoder):
    """Train, select, evaluate and persist the DistilBERT multi-task model.

    Builds text features, augments the training split, then trains for up to
    ``Config.EPOCHS`` epochs. Early stopping uses validation average F1. The best
    checkpoint is reloaded and evaluated on the mixed-domain test split and on the
    aviation split. Reports, the tokenizer and the label encoders are written to
    ``Config.RESULTS_DIR`` / ``Config.MODELS_DIR``.

    Args:
        train_data, val_data, test_data, aviation_test: Splits as returned by
            ``prepare_labels``; each is passed through ``create_text_features``.
        urgency_encoder, fault_encoder: Fitted label encoders; their class counts
            size the two classification heads.

    Returns:
        tuple: ``(distilbert_results, model, tokenizer)``. ``distilbert_results`` is
        the metrics dict also saved to ``distilbert_results.json``. ``model`` holds
        the best-checkpoint weights.
    """
    logger.info("\n" + "="*80)
    logger.info("STEP 4: DISTILBERT MULTI-TASK MODEL")
    logger.info("="*80)

    tokenizer = DistilBertTokenizer.from_pretrained(Config.MODEL_NAME)

    # Prepare text features
    train_data = create_text_features(train_data)
    val_data = create_text_features(val_data)
    test_data = create_text_features(test_data)
    aviation_test = create_text_features(aviation_test)

    # Only the training split is augmented; evaluation splits stay untouched.
    train_data_aug = augment_dataset(train_data, augmentation_factor=1)

    train_dataset = MaintenanceDataset(train_data_aug, tokenizer, Config.MAX_LENGTH)
    val_dataset = MaintenanceDataset(val_data, tokenizer, Config.MAX_LENGTH)
    test_dataset = MaintenanceDataset(test_data, tokenizer, Config.MAX_LENGTH)
    aviation_dataset = MaintenanceDataset(aviation_test, tokenizer, Config.MAX_LENGTH)

    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE)
    test_loader = DataLoader(test_dataset, batch_size=Config.BATCH_SIZE)
    aviation_loader = DataLoader(aviation_dataset, batch_size=Config.BATCH_SIZE)

    logger.info("✓ Datasets created:")
    logger.info(f"  Train batches: {len(train_loader)}")
    logger.info(f"  Val batches: {len(val_loader)}")
    logger.info(f"  Test batches: {len(test_loader)}")
    logger.info(f"  Aviation batches: {len(aviation_loader)}")

    # Initialize model
    n_urgency_classes = len(urgency_encoder.classes_)
    n_fault_classes = len(fault_encoder.classes_)

    model = DistilBERTMultiTask(n_urgency_classes, n_fault_classes)
    model = model.to(Config.DEVICE)

    logger.info(f"\n✓ Model initialized on {Config.DEVICE}")
    logger.info(f"  Urgency classes: {n_urgency_classes}")
    logger.info(f"  Fault classes: {n_fault_classes}")
    logger.info(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    logger.info(f"  Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    # Optimizer and scheduler
    optimizer = AdamW(model.parameters(), lr=Config.LEARNING_RATE, weight_decay=Config.WEIGHT_DECAY)

    total_steps = len(train_loader) * Config.EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=Config.WARMUP_STEPS,
        num_training_steps=total_steps
    )

    scaler = GradScaler()

    logger.info(f"\n{'='*80}")
    logger.info("STARTING TRAINING")
    logger.info(f"{'='*80}\n")

    best_ckpt_path = os.path.join(Config.MODELS_DIR, 'distilbert_multitask_best.pt')
    # Start below any attainable F1 so the first epoch always writes a checkpoint.
    # Starting at 0 with a strict ">" meant an all-zero-F1 run never saved one,
    # and the torch.load below then failed.
    best_val_f1 = float('-inf')
    patience_counter = 0
    train_losses = []
    val_losses = []
    val_f1_scores = []

    for epoch in range(Config.EPOCHS):
        logger.info(f"\n--- Epoch {epoch + 1}/{Config.EPOCHS} ---")

        train_loss = train_epoch(model, train_loader, optimizer, scheduler, scaler, Config.DEVICE)
        train_losses.append(train_loss)
        logger.info(f"Train Loss: {train_loss:.4f}")

        val_metrics = evaluate_model(model, val_loader, Config.DEVICE)
        val_losses.append(val_metrics['loss'])
        val_f1_scores.append(val_metrics['avg_f1'])

        logger.info(f"Val Loss: {val_metrics['loss']:.4f}")
        logger.info(f"Val Urgency F1: {val_metrics['urgency_f1']:.4f}")
        logger.info(f"Val Fault F1: {val_metrics['fault_f1']:.4f}")
        logger.info(f"Val Avg F1: {val_metrics['avg_f1']:.4f}")

        if val_metrics['avg_f1'] > best_val_f1:
            best_val_f1 = val_metrics['avg_f1']
            torch.save({
                'epoch': epoch,  # stored 0-based
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Plain float: a NumPy scalar here may be rejected by
                # torch.load(weights_only=True), the default in newer PyTorch.
                'val_f1': float(best_val_f1),
            }, best_ckpt_path)
            logger.info(f"✓ Best model saved! (F1: {best_val_f1:.4f})")
            patience_counter = 0
        else:
            patience_counter += 1
            logger.info(f"Patience: {patience_counter}/{Config.PATIENCE}")

        if patience_counter >= Config.PATIENCE:
            logger.info(f"\nEarly stopping triggered at epoch {epoch + 1}")
            break

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    _save_training_curves(train_losses, val_losses, val_f1_scores)

    # Load best model for evaluation; map_location keeps tensors on the active device.
    checkpoint = torch.load(best_ckpt_path, map_location=Config.DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    # +1 so the reported epoch matches the 1-based "Epoch k/N" training logs.
    logger.info(f"✓ Best model loaded (Epoch {checkpoint['epoch'] + 1}, Val F1: {checkpoint['val_f1']:.4f})")

    # Evaluate on test set
    logger.info("\n" + "="*80)
    logger.info("EVALUATING ON TEST SET")
    logger.info("="*80)

    test_metrics = evaluate_model(model, test_loader, Config.DEVICE)
    logger.info(f"\nTest Loss: {test_metrics['loss']:.4f}")
    logger.info(f"Test Urgency F1: {test_metrics['urgency_f1']:.4f}")
    logger.info(f"Test Fault F1: {test_metrics['fault_f1']:.4f}")
    logger.info(f"Test Avg F1: {test_metrics['avg_f1']:.4f}")

    logger.info("\n--- Urgency Classification Report ---")
    # Explicit labels keep target_names aligned even if a class is absent from the
    # test split (otherwise sklearn raises "Number of classes ... does not match").
    print(classification_report(test_metrics['urgency_labels'], test_metrics['urgency_preds'],
                                labels=list(range(len(urgency_encoder.classes_))),
                                target_names=urgency_encoder.classes_,
                                zero_division=0))

    logger.info("\n--- Fault Classification Report (Summary) ---")
    fault_report = classification_report(test_metrics['fault_labels'], test_metrics['fault_preds'],
                                         output_dict=True, zero_division=0)
    logger.info(f"Macro Avg F1: {fault_report['macro avg']['f1-score']:.4f}")
    logger.info(f"Weighted Avg F1: {fault_report['weighted avg']['f1-score']:.4f}")

    # For single-label multiclass targets, Hamming loss equals 1 - accuracy.
    hamming_urgency = hamming_loss(test_metrics['urgency_labels'], test_metrics['urgency_preds'])
    hamming_fault = hamming_loss(test_metrics['fault_labels'], test_metrics['fault_preds'])
    logger.info(f"\nHamming Loss - Urgency: {hamming_urgency:.4f}, Fault: {hamming_fault:.4f}")

    plot_confusion_matrices(
        test_metrics['urgency_labels'], test_metrics['urgency_preds'],
        test_metrics['fault_labels'], test_metrics['fault_preds'],
        urgency_encoder, fault_encoder, prefix='test'
    )

    # Evaluate on aviation unseen domain
    logger.info("\n" + "="*80)
    logger.info("EVALUATING ON AVIATION UNSEEN DOMAIN")
    logger.info("="*80)

    aviation_metrics = evaluate_model(model, aviation_loader, Config.DEVICE)
    logger.info(f"\nAviation Loss: {aviation_metrics['loss']:.4f}")
    logger.info(f"Aviation Urgency F1: {aviation_metrics['urgency_f1']:.4f}")
    logger.info(f"Aviation Fault F1: {aviation_metrics['fault_f1']:.4f}")
    logger.info(f"Aviation Avg F1: {aviation_metrics['avg_f1']:.4f}")

    domain_report = _build_domain_report(test_metrics, aviation_metrics)
    logger.info(domain_report)

    with open(os.path.join(Config.RESULTS_DIR, 'domain_generalization_report.txt'), 'w') as f:
        f.write(domain_report)

    distilbert_results = {
        'val_best_f1': float(best_val_f1),
        'test_urgency_f1': float(test_metrics['urgency_f1']),
        'test_fault_f1': float(test_metrics['fault_f1']),
        'test_avg_f1': float(test_metrics['avg_f1']),
        'test_hamming_urgency': float(hamming_urgency),
        'test_hamming_fault': float(hamming_fault),
        'aviation_urgency_f1': float(aviation_metrics['urgency_f1']),
        'aviation_fault_f1': float(aviation_metrics['fault_f1']),
        'aviation_avg_f1': float(aviation_metrics['avg_f1']),
        'domain_shift': float(test_metrics['avg_f1'] - aviation_metrics['avg_f1'])
    }

    with open(os.path.join(Config.RESULTS_DIR, 'distilbert_results.json'), 'w') as f:
        json.dump(distilbert_results, f, indent=4)

    # Persist tokenizer and encoders so load_trained_model() can rebuild the pipeline.
    tokenizer.save_pretrained(os.path.join(Config.MODELS_DIR, 'tokenizer'))

    encoders = {
        'urgency_encoder': urgency_encoder,
        'fault_encoder': fault_encoder
    }
    with open(os.path.join(Config.MODELS_DIR, 'label_encoders.pkl'), 'wb') as f:
        pickle.dump(encoders, f)

    logger.info("\n✓ All results and models saved successfully!")

    return distilbert_results, model, tokenizer

# ============================================================================
# INFERENCE FUNCTION
# ============================================================================
def load_trained_model():
    """Load the saved DistilBERT model, tokenizer and label encoders for inference.

    Reads ``label_encoders.pkl``, the ``tokenizer`` directory and
    ``distilbert_multitask_best.pt`` from ``Config.MODELS_DIR``.

    Returns:
        tuple: ``(model, tokenizer, urgency_encoder, fault_encoder)``. The model is
        on ``Config.DEVICE`` and in eval mode.
    """
    # NOTE: pickle.load executes code embedded in the file; only load encoders
    # produced by this pipeline, never files from an untrusted source.
    with open(os.path.join(Config.MODELS_DIR, 'label_encoders.pkl'), 'rb') as f:
        encoders = pickle.load(f)

    urgency_encoder = encoders['urgency_encoder']
    fault_encoder = encoders['fault_encoder']

    tokenizer = DistilBertTokenizer.from_pretrained(os.path.join(Config.MODELS_DIR, 'tokenizer'))

    # Head sizes must match the checkpoint, so derive them from the saved encoders.
    n_urgency_classes = len(urgency_encoder.classes_)
    n_fault_classes = len(fault_encoder.classes_)

    model = DistilBERTMultiTask(n_urgency_classes, n_fault_classes)
    # map_location lets a checkpoint written on a GPU load on a CPU-only host.
    checkpoint = torch.load(os.path.join(Config.MODELS_DIR, 'distilbert_multitask_best.pt'),
                            map_location=Config.DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(Config.DEVICE)
    model.eval()

    return model, tokenizer, urgency_encoder, fault_encoder

def predict_ticket(text, model, tokenizer, urgency_encoder, fault_encoder):
    """
    Classify a single free-text maintenance ticket.

    Args:
        text (str): Raw ticket description; normalised with ``preprocess_text``.
        model: Trained multi-task model returning ``(urgency_logits, fault_logits)``.
        tokenizer: Tokenizer for the model, called with ``Config.MAX_LENGTH``.
        urgency_encoder: Label encoder mapping urgency indices to class names.
        fault_encoder: Label encoder mapping fault indices to class names.

    Returns:
        dict: ``{'urgency': {'class', 'confidence', 'all_probabilities'},
        'fault': {'class', 'confidence'}}``. Each ``confidence`` is the softmax
        probability of the predicted class.
    """
    model.eval()

    batch = tokenizer(
        preprocess_text(text),
        add_special_tokens=True,
        max_length=Config.MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    ids = batch['input_ids'].to(Config.DEVICE)
    mask = batch['attention_mask'].to(Config.DEVICE)

    with torch.no_grad():
        logits_u, logits_f = model(ids, mask)

        probs_u = torch.softmax(logits_u, dim=1)
        probs_f = torch.softmax(logits_f, dim=1)

        # Row 0: a single ticket is scored per call.
        top_u = torch.argmax(probs_u, dim=1).item()
        top_f = torch.argmax(probs_f, dim=1).item()

        conf_u = probs_u[0, top_u].item()
        conf_f = probs_f[0, top_f].item()

    urgency_distribution = {
        label: probs_u[0, position].item()
        for position, label in enumerate(urgency_encoder.classes_)
    }

    return {
        'urgency': {
            'class': urgency_encoder.inverse_transform([top_u])[0],
            'confidence': conf_u,
            'all_probabilities': urgency_distribution,
        },
        'fault': {
            'class': fault_encoder.inverse_transform([top_f])[0],
            'confidence': conf_f,
        },
    }

# ============================================================================
# MAIN EXECUTION
# ============================================================================
def _download_nltk_resources():
    """Fetch the NLTK corpora used by augmentation; a failure only logs a warning."""
    try:
        nltk.download('wordnet', quiet=True)
        nltk.download('omw-1.4', quiet=True)
    except Exception:
        # Deliberately non-fatal. Catching Exception (not a bare except) still
        # lets KeyboardInterrupt / SystemExit stop the pipeline.
        logger.warning("Could not download NLTK data. Augmentation may be limited.")


def _format_comparison(baseline_results, distilbert_results):
    """Render the baseline vs. DistilBERT summary that is logged and written to disk."""
    baseline_avg = baseline_results['avg_f1_test']
    avg_gain = distilbert_results['test_avg_f1'] - baseline_avg
    # Guard the relative gain. A zero baseline F1 would otherwise raise
    # ZeroDivisionError (or yield inf for NumPy scalars) at the very end of a long run.
    relative_gain = f"{avg_gain / baseline_avg * 100:.1f}%" if baseline_avg else "n/a"
    return f"""
Model Performance Summary:
--------------------------

Baseline (TF-IDF + LinearSVC):
  - Test Avg F1: {baseline_results['avg_f1_test']:.4f}
  - Urgency F1: {baseline_results['urgency_f1_test']:.4f}
  - Fault F1: {baseline_results['fault_f1_test']:.4f}

DistilBERT Multi-Task:
  - Test Avg F1: {distilbert_results['test_avg_f1']:.4f}
  - Urgency F1: {distilbert_results['test_urgency_f1']:.4f}
  - Fault F1: {distilbert_results['test_fault_f1']:.4f}

Improvement:
  - Avg F1: {avg_gain:.4f} ({relative_gain})

Domain Generalization (Aviation Unseen):
  - Avg F1: {distilbert_results['aviation_avg_f1']:.4f}
  - Domain Shift: {distilbert_results['domain_shift']:.4f}

{'='*80}
"""


def _run_inference_demo(model, tokenizer, urgency_encoder, fault_encoder):
    """Log urgency / fault predictions for a few hand-written example tickets."""
    logger.info("\n" + "="*80)
    logger.info("INFERENCE DEMO")
    logger.info("="*80)

    demo_texts = [
        "engine leak detected high pressure cylinder gasket broken",
        "routine maintenance check normal operation",
        "emergency alarm fire hazard critical repair needed"
    ]

    for i, demo_text in enumerate(demo_texts, 1):
        logger.info(f"\nExample {i}: '{demo_text}'")
        prediction = predict_ticket(demo_text, model, tokenizer, urgency_encoder, fault_encoder)
        logger.info(f"  Urgency: {prediction['urgency']['class']} (confidence: {prediction['urgency']['confidence']:.3f})")
        logger.info(f"  Fault: {prediction['fault']['class']} (confidence: {prediction['fault']['confidence']:.3f})")


def main():
    """Run the end-to-end pipeline.

    Steps: data preparation, label encoding, baseline training, DistilBERT
    training/evaluation, the final comparison report (logged and written to
    ``final_comparison.txt``) and a short inference demo.
    """
    start_time = datetime.now()

    logger.info("="*80)
    logger.info("MULTI-TASK NLP MODEL FOR TICKET TRIAGE & FAULT CLASSIFICATION")
    logger.info("="*80)
    logger.info(f"Start Time: {start_time.strftime('%Y-%m-%d %H:%M:%S')}")
    logger.info(f"Device: {Config.DEVICE}")
    logger.info("="*80)

    # Setup
    set_seed(Config.SEED)
    create_directories()
    _download_nltk_resources()

    # Step 1: Prepare datasets
    train_data, val_data, test_data, aviation_test = prepare_datasets()

    # Step 2: Generate and encode labels
    train_data, val_data, test_data, aviation_test, urgency_encoder, fault_encoder = prepare_labels(
        train_data, val_data, test_data, aviation_test
    )

    # Step 3: Train baseline model
    baseline_results = train_baseline_model(train_data, val_data, test_data)

    # Step 4: Train DistilBERT model
    distilbert_results, model, tokenizer = train_distilbert_model(
        train_data, val_data, test_data, aviation_test,
        urgency_encoder, fault_encoder
    )

    # Step 5: Final comparison
    logger.info("\n" + "="*80)
    logger.info("FINAL PERFORMANCE COMPARISON")
    logger.info("="*80)

    comparison = _format_comparison(baseline_results, distilbert_results)
    logger.info(comparison)

    with open(os.path.join(Config.RESULTS_DIR, 'final_comparison.txt'), 'w') as f:
        f.write(comparison)

    # Step 6: Demo inference
    _run_inference_demo(model, tokenizer, urgency_encoder, fault_encoder)

    # Completion
    end_time = datetime.now()
    duration = end_time - start_time

    logger.info("\n" + "="*80)
    logger.info("PIPELINE COMPLETED SUCCESSFULLY!")
    logger.info("="*80)
    logger.info(f"End Time: {end_time.strftime('%Y-%m-%d %H:%M:%S')}")
    logger.info(f"Total Duration: {duration}")
    logger.info("\nAll outputs saved to:")
    logger.info(f"  - Models: {Config.MODELS_DIR}")
    logger.info(f"  - Results: {Config.RESULTS_DIR}")
    logger.info(f"  - Logs: {Config.LOGS_DIR}")
    logger.info("="*80)

# ============================================================================
# RUN
# ============================================================================
if __name__ == "__main__":
    # Script entry point. Inside a Jupyter kernel __name__ is also "__main__",
    # so executing this cell launches the full training pipeline.
    main()

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Augmentation:   0%|          | 0/84 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Training:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/2 [00:00<?, ?it/s]

Training:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/2 [00:00<?, ?it/s]

Training:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/2 [00:00<?, ?it/s]

Training:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/2 [00:00<?, ?it/s]

              precision    recall  f1-score   support

    critical       1.00      0.25      0.40         4
        high       0.53      0.89      0.67         9
      medium       0.00      0.00      0.00         6

    accuracy                           0.47        19
   macro avg       0.51      0.38      0.36        19
weighted avg       0.46      0.47      0.40        19



Evaluating:   0%|          | 0/1 [00:00<?, ?it/s]

## Model Testing

## Baseline Model

In [None]:
import pickle
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer

# Directory holding the artifacts written by the training pipeline (defined once
# instead of repeating the absolute path for every file).
MODELS_DIR = '/content/drive/MyDrive/Capstone Project/Model 2/models'

# Load the baseline model: a dict with the fitted TF-IDF vectorizer and the two
# per-task classifiers.
# NOTE: pickle.load can execute arbitrary code -- only open artifacts produced
# by this project's own training run.
with open(f'{MODELS_DIR}/baseline_model.pkl', 'rb') as f:
    baseline_model = pickle.load(f)

# Extract components
vectorizer = baseline_model['vectorizer']
urgency_clf = baseline_model['urgency_clf']
fault_clf = baseline_model['fault_clf']

# Load label encoders (the same file the DistilBERT model uses)
with open(f'{MODELS_DIR}/label_encoders.pkl', 'rb') as f:
    encoders = pickle.load(f)

urgency_encoder = encoders['urgency_encoder']
fault_encoder = encoders['fault_encoder']

# Function to preprocess text (same as in training)
def preprocess_text(text):
    """Normalise text for the classifiers.

    Lower-cases the input, turns every character that is neither alphanumeric
    nor whitespace into a space, and collapses whitespace runs into single spaces.
    """
    kept = []
    for ch in str(text).lower():
        kept.append(ch if (ch.isalnum() or ch.isspace()) else ' ')
    return ' '.join(''.join(kept).split())

def create_combined_text(word, example, description, lemma):
    """Join the four text fields with single spaces and normalise via ``preprocess_text``."""
    raw = "{} {} {} {}".format(word, example, description, lemma)
    return preprocess_text(raw)

# Test on new data
def predict_baseline(word, example, description_grammar, lemma):
    """
    Classify one vocabulary entry with the TF-IDF baseline.

    Uses the notebook-level ``vectorizer``, ``urgency_clf``, ``fault_clf`` and the
    two label encoders loaded earlier in this cell.

    Args:
        word: The maintenance word
        example: Example sentence
        description_grammar: Grammar description
        lemma: Lemma of the word

    Returns:
        dict with keys ``'urgency'`` and ``'fault_category'`` holding decoded class names.
    """
    features = vectorizer.transform([create_combined_text(word, example, description_grammar, lemma)])

    urgency_idx = urgency_clf.predict(features)[0]
    fault_idx = fault_clf.predict(features)[0]

    return {
        'urgency': urgency_encoder.inverse_transform([urgency_idx])[0],
        'fault_category': fault_encoder.inverse_transform([fault_idx])[0],
    }

# Example usage: score a single engine-failure entry with the baseline
result = predict_baseline(
    word="engine",
    example="engine failed at first run",
    description_grammar="aircraft engine",
    lemma="engine",
)

print(
    "Baseline Prediction:",
    f"  Urgency: {result['urgency']}",
    f"  Fault Category: {result['fault_category']}",
    sep="\n",
)

Baseline Prediction:
  Urgency: critical
  Fault Category: aircraft start failed


## DistilBERT

In [None]:
import torch
import pickle
from transformers import DistilBertTokenizer

# Directory holding the artifacts written by the training pipeline (same value as
# in the baseline cell; repeated so this cell also runs on its own).
MODELS_DIR = '/content/drive/MyDrive/Capstone Project/Model 2/models'

# Load label encoders
# NOTE: pickle.load can execute arbitrary code -- only open trusted artifacts.
with open(f'{MODELS_DIR}/label_encoders.pkl', 'rb') as f:
    encoders = pickle.load(f)

urgency_encoder = encoders['urgency_encoder']
fault_encoder = encoders['fault_encoder']

# Load the tokenizer saved alongside the trained model
tokenizer = DistilBertTokenizer.from_pretrained(f'{MODELS_DIR}/tokenizer')

# Load model architecture (must define the class first)
import torch.nn as nn
from transformers import DistilBertModel

class DistilBERTMultiTask(nn.Module):
    """Shared DistilBERT encoder feeding two classification heads (urgency, fault).

    The layer layout must match the definition used in training so the saved
    checkpoint's ``state_dict`` keys line up with this module's.
    """

    def __init__(self, n_urgency_classes, n_fault_classes, dropout=0.3):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        # Embedding weights are frozen (excluded from gradient updates).
        self.bert.embeddings.requires_grad_(False)

        width = self.bert.config.hidden_size
        # Urgency head is built first so registration / init order is unchanged.
        self.urgency_classifier = self._make_head(width, n_urgency_classes, dropout)
        self.fault_classifier = self._make_head(width, n_fault_classes, dropout)

    @staticmethod
    def _make_head(in_features, n_classes, dropout):
        """Dropout -> Linear(in_features, 256) -> ReLU -> Dropout -> Linear(256, n_classes)."""
        return nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, n_classes),
        )

    def forward(self, input_ids, attention_mask):
        """Return ``(urgency_logits, fault_logits)`` from the hidden state at token position 0."""
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        first_token = hidden[:, 0]
        return self.urgency_classifier(first_token), self.fault_classifier(first_token)

# Build the network with head sizes taken from the saved encoders, then restore
# the best checkpoint produced during training.
n_urgency_classes = len(urgency_encoder.classes_)
n_fault_classes = len(fault_encoder.classes_)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DistilBERTMultiTask(n_urgency_classes, n_fault_classes)

# map_location lets a GPU-trained checkpoint load on a CPU-only runtime.
checkpoint = torch.load(
    '/content/drive/MyDrive/Capstone Project/Model 2/models/distilbert_multitask_best.pt',
    map_location=device,
)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device).eval()

print(
    "✓ Model loaded successfully!",
    f"  Urgency classes: {urgency_encoder.classes_}",
    f"  Fault classes: {n_fault_classes}",
    sep="\n",
)

# Prediction function
def predict_distilbert(word, example, description_grammar, lemma, max_length=128):
    """
    Classify one vocabulary entry with the loaded DistilBERT multi-task model.

    Relies on the notebook-level ``model``, ``tokenizer``, ``device``,
    ``urgency_encoder`` and ``fault_encoder`` defined earlier in this cell.

    Returns:
        dict: ``{'urgency': {'class', 'confidence'}, 'fault': {'class', 'confidence'}}``.
        Each ``confidence`` is the softmax probability of the predicted class.
    """
    def _normalise(raw):
        # Same rules as preprocess_text: lower-case, non-alphanumerics -> space,
        # collapse whitespace.
        kept = [ch if (ch.isalnum() or ch.isspace()) else ' ' for ch in str(raw).lower()]
        return ' '.join(''.join(kept).split())

    text = _normalise(f"{word} {example} {description_grammar} {lemma}")

    enc = tokenizer(
        text,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    with torch.no_grad():
        logits_u, logits_f = model(enc['input_ids'].to(device), enc['attention_mask'].to(device))

        probs_u = torch.softmax(logits_u, dim=1)
        probs_f = torch.softmax(logits_f, dim=1)

        idx_u = torch.argmax(probs_u, dim=1).item()
        idx_f = torch.argmax(probs_f, dim=1).item()

        conf_u = probs_u[0, idx_u].item()
        conf_f = probs_f[0, idx_f].item()

    return {
        'urgency': {
            'class': urgency_encoder.inverse_transform([idx_u])[0],
            'confidence': conf_u,
        },
        'fault': {
            'class': fault_encoder.inverse_transform([idx_f])[0],
            'confidence': conf_f,
        },
    }

# Example usage: an engine entry phrased as a critical failure
result = predict_distilbert(
    word="engine",
    example="engine failed suddenly critical alarm",
    description_grammar="aircraft engine failure",
    lemma="engine",
)

print(
    "\nDistilBERT Prediction:",
    f"  Urgency: {result['urgency']['class']} (confidence: {result['urgency']['confidence']:.3f})",
    f"  Fault: {result['fault']['class']} (confidence: {result['fault']['confidence']:.3f})",
    sep="\n",
)

✓ Model loaded successfully!
  Urgency classes: ['critical' 'high' 'medium']
  Fault classes: 31

DistilBERT Prediction:
  Urgency: medium (confidence: 0.351)
  Fault: engine power control (confidence: 0.041)
