In [None]:
#installing required packages
!pip install transformers datasets scikit-learn torch pandas matplotlib gradio -q


# Import libraries
import inspect
import json
import os
import re
import shutil
import threading
import time
import traceback
import unicodedata

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

# Import transformer libraries
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
    pipeline,
    EarlyStoppingCallback
)
from datasets import Dataset, DatasetDict
import gradio as gr


# ## Configuration


# --- Configuration ---

# Input data
DATASET_FILENAME: str = 'dataset-tickets-multi-lang-4-20k.csv'
TEXT_COLUMNS: list = ['subject', 'body']  # joined with spaces into the model input text
TARGET_COLUMN: str = 'queue'              # column whose values are the classes to predict
TEXT_FEATURE_COLUMN: str = 'text'         # name of the combined, cleaned text column
LABEL_COLUMN: str = 'label'               # name of the integer class-id column

# Output locations
MODEL_OUTPUT_DIR: str = "./ticket-classifier-model"  # fine-tuned model is saved here
LABEL_MAPPING_DIR: str = MODEL_OUTPUT_DIR             # label JSON files sit next to the model

# Splitting and tokenization
TEST_SET_SIZE: float = 0.2
RANDOM_STATE: int = 42
MAX_TOKEN_LENGTH: int = 128  # tokenizer truncation length

# Optimization
TRAIN_BATCH_SIZE: int = 16  # lower this if the GPU runs out of memory
EVAL_BATCH_SIZE: int = 32
NUM_EPOCHS: int = 1         # 1 for a quick run; raise for better accuracy
LEARNING_RATE: float = 2e-5
WEIGHT_DECAY: float = 0.01

# Development mode: tiny model, subsampled data and a single epoch.
DEV_MODE: bool = True  # set to False for full training

# Label mappings; filled in by data processing or by loading saved mappings.
label2id: dict = {}
id2label: dict = {}
num_labels: int = 0





# Text cleaning function
def clean_text(text):
    """Normalize ticket text before tokenization.

    Lower-cases the text, removes punctuation/symbols, underscores and digits,
    and collapses whitespace runs into single spaces. Letters from any script
    are kept (accented Latin letters such as é, ñ, ç as well as ä, ö, ü, ß),
    so multilingual tickets are not mangled.

    Args:
        text: Raw ticket text. Non-string values (None, NaN, numbers) are
            treated as missing.

    Returns:
        The cleaned string, or "" for non-string input.
    """
    if not isinstance(text, str):
        return ""
    # Compose decomposed accents (base letter + combining mark) into single
    # code points first; a lone combining mark is not matched by \w and would
    # otherwise be stripped, turning "é" into "e".
    text = unicodedata.normalize("NFC", text).lower()
    # \w matches letters and digits of every script; underscore is also part
    # of \w but is punctuation here, so it is removed explicitly.
    text = re.sub(r"[^\w\s]|_", "", text)
    text = re.sub(r"\d+", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text

# Run Trainer.train() on a worker thread so it can be bounded by a wall-clock timeout.
def _stop_timed_out_training(trainer, thread, grace_seconds):
    """Ask a still-running trainer to stop, then wait briefly for its thread.

    ``transformers.Trainer`` exposes its ``TrainerControl`` as
    ``trainer.control``. Setting ``should_training_stop`` is the same signal
    that callbacks such as ``EarlyStoppingCallback`` use to end training.
    Trainers without that attribute cannot be stopped here and are left running.
    """
    control = getattr(trainer, "control", None)
    if control is not None and hasattr(control, "should_training_stop"):
        control.should_training_stop = True
        print("Requested training stop; waiting for the current step to finish...")
        thread.join(timeout=grace_seconds)
    if thread.is_alive():
        print("Warning: training is still running in the background; "
              "the model may keep changing while it is saved or evaluated.")


def safe_train_with_timeout(trainer, timeout_minutes=15, stop_grace_seconds=300):
    """Run ``trainer.train()`` on a daemon thread with a wall-clock timeout.

    The calling thread polls the training thread until it finishes or
    ``timeout_minutes`` elapses, printing a progress line about once a minute.
    On timeout, a stop is requested and the call waits up to
    ``stop_grace_seconds`` for training to wind down. This keeps callers from
    saving or evaluating a model that is still being updated.

    Args:
        trainer: Object with a ``train()`` method, normally a ``transformers.Trainer``.
        timeout_minutes: Training budget in minutes (may be fractional).
        stop_grace_seconds: Maximum wait after a timeout-triggered stop request.

    Returns:
        The value returned by ``trainer.train()``, or None if training raised
        an exception or hit the timeout.
    """
    print(f"Starting training with {timeout_minutes} minute timeout safety...")

    outcome = {"result": None, "error": None, "completed": False}

    def training_thread():
        try:
            outcome["result"] = trainer.train()
            outcome["completed"] = True
        except Exception as e:  # reported from the calling thread below
            outcome["error"] = e

    # Daemon thread: a hung training run never blocks interpreter shutdown.
    thread = threading.Thread(target=training_thread, daemon=True)
    thread.start()

    timeout_seconds = timeout_minutes * 60
    start_time = time.monotonic()  # unaffected by system clock changes
    next_report = 60.0

    while thread.is_alive():
        elapsed = time.monotonic() - start_time
        if elapsed > timeout_seconds:
            print(f"Training timeout after {timeout_minutes} minutes.")
            _stop_timed_out_training(trainer, thread, stop_grace_seconds)
            return None
        if elapsed >= next_report:
            print(f"Training in progress... {elapsed / 60:.1f} minutes elapsed")
            next_report += 60.0
        # join() returns as soon as training ends, and never waits past the deadline.
        thread.join(timeout=min(1.0, timeout_seconds - elapsed))

    error = outcome["error"]
    if error is not None:
        print(f"Training encountered an error: {error}")
        # Print the worker thread's traceback. print_exc() here would only
        # see this thread's (empty) exception state.
        traceback.print_exception(type(error), error, error.__traceback__)
        return None

    if outcome["completed"]:
        print("Training completed successfully!")
    return outcome["result"]

# Function to get a smaller model for development
def get_development_model_checkpoint():
    """Choose the Hugging Face checkpoint to fine-tune.

    In ``DEV_MODE`` the smallest candidate is used for quick iteration.
    Otherwise the multilingual DistilBERT checkpoint is used.

    Returns:
        A model identifier string suitable for ``from_pretrained``.
    """
    # Candidates ordered from smallest/fastest to largest/slowest.
    candidates = (
        "prajjwal1/bert-tiny",                 # 4.4M parameters
        "prajjwal1/bert-mini",                 # 11.3M parameters
        "distilbert-base-multilingual-cased",  # ~66M parameters
        "bert-base-multilingual-cased",        # 110M parameters
    )
    if DEV_MODE:
        return candidates[0]
    return candidates[2]


# ## Create Dummy Dataset
#
# This function creates a synthetic dataset for testing if you don't have real data:

# %%
def create_dummy_dataset(filename, num_samples=1000, seed=None):
    """Write a synthetic ticket CSV so the pipeline can run without real data.

    Each row picks a queue uniformly at random, then a subject and a body from
    that queue's templates. A ticket number or reference id is sometimes
    appended so rows are not exact duplicates.

    Args:
        filename: Destination CSV path (overwritten if it exists).
        num_samples: Number of rows to generate (0 writes a header-only CSV).
        seed: Optional seed for reproducible output. ``None`` draws fresh
            randomness on every call.

    Returns:
        ``filename``, for convenience.
    """
    import random

    print(f"Creating dummy dataset at {filename} for testing purposes...")

    # Sample subjects and bodies for each queue
    queue_templates = {
        "technical_support": {
            "subjects": [
                "Can't login", "App crashing", "Website error", "Connection issue",
                "Password reset not working", "Software bug", "Installation problem"
            ],
            "bodies": [
                "I can't log into my account", "The app keeps crashing when I try to use it",
                "I'm getting an error when I try to access the website",
                "My connection keeps dropping", "Password reset link doesn't work",
                "There's a bug in the latest version", "Having trouble installing the software"
            ]
        },
        "billing": {
            "subjects": [
                "Invoice question", "Billing error", "Payment issue", "Refund request",
                "Subscription problem", "Pricing question", "Payment method update"
            ],
            "bodies": [
                "I have a question about my recent invoice", "I think there's an error on my bill",
                "My payment was declined", "I would like a refund for my purchase",
                "Having issues with my subscription", "Can you explain the pricing?",
                "Need to update my payment information"
            ]
        },
        "account_management": {
            "subjects": [
                "Change email", "Update account details", "Close account", "Change password",
                "Account verification", "Account access", "Profile update"
            ],
            "bodies": [
                "I need to change my email address", "Please update my account information",
                "I want to close my account", "Need to change my password",
                "How do I verify my account?", "Cannot access my account",
                "Need to update my profile"
            ]
        },
        "sales": {
            "subjects": [
                "Product inquiry", "Purchase question", "Discount inquiry", "Enterprise plan",
                "Upgrade options", "Product comparison", "Bulk order"
            ],
            "bodies": [
                "I want to know more about your product", "Question about making a purchase",
                "Are there any discounts available?", "Information about enterprise plans",
                "Looking to upgrade my account", "How does your product compare to competitors?",
                "Interested in placing a bulk order"
            ]
        },
        "general_inquiry": {
            "subjects": [
                "General question", "Information request", "Help needed", "Contact info",
                "Hours of operation", "Shipping info", "Return policy"
            ],
            "bodies": [
                "I have a general question", "Can you provide more information?",
                "I need help with something", "What is your contact information?",
                "What are your hours of operation?", "Information about shipping",
                "What is your return policy?"
            ]
        }
    }
    # Derived from the templates so the two can never drift apart.
    queues = list(queue_templates)

    # A private generator: reproducible when seeded, and it leaves the
    # global `random` state untouched.
    rng = random.Random(seed)
    rows = []
    for _ in range(num_samples):
        queue = rng.choice(queues)
        templates = queue_templates[queue]

        subject = rng.choice(templates["subjects"])
        body = rng.choice(templates["bodies"])

        # Add some randomness to the text
        if rng.random() > 0.5:
            subject += f" #{rng.randint(1000, 9999)}"
        if rng.random() > 0.7:
            body += f". Reference ID: REF-{rng.randint(10000, 99999)}"

        rows.append({"subject": subject, "body": body, "queue": queue})

    # Explicit columns so an empty dataset still has a readable header row.
    df = pd.DataFrame(rows, columns=["subject", "body", "queue"])
    df.to_csv(filename, index=False)
    print(f"Created dummy dataset with {len(df)} samples across {len(queues)} queues.")
    return filename

# Fall back to a synthetic dataset when the real CSV is absent, so the rest of the notebook can run.
_dataset_missing = not os.path.exists(DATASET_FILENAME)
if _dataset_missing:
    print(f"Dataset file '{DATASET_FILENAME}' not found. Creating dummy dataset...")
    create_dummy_dataset(DATASET_FILENAME, num_samples=2000)


# ## Data Processing
#
# This class handles loading and preparing the ticket data:

# %%
class TicketDataProcessor:
    """Load, clean and split the ticket CSV into Hugging Face datasets.

    ``load_and_prepare`` also rebinds the module-level ``label2id``,
    ``id2label`` and ``num_labels`` globals, which later stages read.
    """

    def __init__(self, csv_path, text_cols, target_col):
        """
        Args:
            csv_path: Path of the CSV file containing the raw tickets.
            text_cols: Columns joined with spaces to form the model input text.
            target_col: Column holding the queue name to predict.
        """
        self.csv_path = csv_path
        self.text_cols = text_cols
        self.target_col = target_col

    def load_and_prepare(self):
        """Read the CSV, build cleaned text and integer labels, and split the data.

        Labels are numbered in sorted order of the distinct target values. In
        ``DEV_MODE`` at most 2000 rows are kept, sampled reproducibly with
        ``RANDOM_STATE``. ``TEST_SET_SIZE`` of the data becomes the test
        split; 10% of the remainder becomes the validation split.

        Returns:
            A ``DatasetDict`` with 'train', 'validation' and 'test' splits,
            or None if the file cannot be read or yields no usable data.
        """
        global label2id, id2label, num_labels

        print(f"Loading dataset from {self.csv_path}...")
        try:
            df = pd.read_csv(self.csv_path)
        except Exception as e:
            print(f"Error loading CSV: {e}")
            return None
        print(f"Loaded {len(df)} rows.")

        # Basic schema checks
        if self.target_col not in df.columns:
            print(f"Error: Target column '{self.target_col}' not found.")
            return None
        missing_cols = [col for col in self.text_cols if col not in df.columns]
        if missing_cols:
            print(f"Error: One or more text columns not found: {missing_cols}")
            return None

        # Drop unlabeled rows, then combine and clean the text columns.
        df.dropna(subset=[self.target_col], inplace=True)
        if df.empty:
            print("Error: No rows with a target label.")
            return None
        # astype(str) keeps ' '.join from failing on non-string cells (e.g. numeric subjects).
        combined = df[self.text_cols].fillna('').astype(str).agg(' '.join, axis=1)
        df[TEXT_FEATURE_COLUMN] = combined.apply(clean_text)

        # Label mappings (sorted, so ids are stable across runs on the same data)
        unique_labels = sorted(df[self.target_col].unique().tolist())
        num_labels = len(unique_labels)
        label2id = {name: i for i, name in enumerate(unique_labels)}
        id2label = {i: name for i, name in enumerate(unique_labels)}
        print(f"Found {num_labels} unique labels: {unique_labels}")

        df[LABEL_COLUMN] = df[self.target_col].map(label2id)

        # preserve_index=False: dropna leaves gaps in the index, which
        # from_pandas would otherwise store as an extra '__index_level_0__' column.
        raw_dataset = Dataset.from_pandas(
            df[[TEXT_FEATURE_COLUMN, LABEL_COLUMN]], preserve_index=False
        )

        # Limit dataset size in development mode
        if DEV_MODE:
            max_samples = 2000
            if len(raw_dataset) > max_samples:
                # Seeded so every dev run trains and evaluates on the same subset.
                rng = np.random.default_rng(RANDOM_STATE)
                indices = rng.choice(len(raw_dataset), max_samples, replace=False)
                raw_dataset = raw_dataset.select(indices)
                print(f"DEV MODE: Limited dataset to {max_samples} samples")

        try:
            train_test_split = raw_dataset.train_test_split(
                test_size=TEST_SET_SIZE,
                seed=RANDOM_STATE
            )
            train_val_split = train_test_split['train'].train_test_split(
                test_size=0.1,
                seed=RANDOM_STATE
            )
        except ValueError as e:
            # e.g. too few rows for the requested split sizes
            print(f"Error splitting dataset: {e}")
            return None

        dataset_dict = DatasetDict({
            'train': train_val_split['train'],
            'validation': train_val_split['test'],
            'test': train_test_split['test']
        })

        print("Dataset prepared and split:")
        for split, ds in dataset_dict.items():
            print(f"  {split}: {len(ds)} examples")

        return dataset_dict


# ## Model Training
#
# Tokenization and metric helpers, followed by the class that drives fine-tuning:

# %%
# Tokenizer function
def tokenize_function(examples, tokenizer):
    """Tokenize a batch of examples for sequence classification.

    Intended for ``Dataset.map(..., batched=True)``. Sequences are truncated
    to ``MAX_TOKEN_LENGTH`` but not padded here. ``DataCollatorWithPadding``
    pads each batch to its own longest sequence, so padding every example to
    the global maximum would only waste compute.

    Args:
        examples: Batch mapping of column name -> list of values; must contain
            ``TEXT_FEATURE_COLUMN``.
        tokenizer: A Hugging Face tokenizer (or any callable with the same
            keyword interface).

    Returns:
        The tokenizer output for the batch (``input_ids``, ``attention_mask``, ...).

    Raises:
        ValueError: If the text column is missing. The error is raised rather
            than replaced with placeholder token ids, which would silently add
            meaningless training examples.
    """
    if TEXT_FEATURE_COLUMN not in examples:
        raise ValueError(f"Text column '{TEXT_FEATURE_COLUMN}' not found.")

    # Missing/non-string cells become empty strings instead of breaking the whole batch.
    texts = [text if isinstance(text, str) else "" for text in examples[TEXT_FEATURE_COLUMN]]
    return tokenizer(
        texts,
        truncation=True,
        max_length=MAX_TOKEN_LENGTH,
    )

# Metrics computation
def compute_metrics(eval_pred):
    """Score classifier logits against the gold labels (Trainer callback).

    Args:
        eval_pred: A ``(logits, labels)`` pair; logits hold one row per
            example and one column per class.

    Returns:
        Dict with accuracy and macro-averaged F1, precision and recall.
    """
    logits, gold = eval_pred
    predicted = np.argmax(logits, axis=1)

    # Macro averaging weights every queue equally, regardless of its size.
    precision, recall, f1, _support = precision_recall_fscore_support(
        gold, predicted, average='macro', zero_division=0
    )

    return {
        'accuracy': accuracy_score(gold, predicted),
        'f1_macro': f1,
        'precision_macro': precision,
        'recall_macro': recall,
    }

# Trainer class
class TicketTrainer:
    """Fine-tune a sequence-classification model with ``transformers.Trainer``.

    Reads the module-level ``label2id`` / ``id2label`` / ``num_labels``
    globals, so data processing (or ``load_label_mappings``) must have
    populated them first.
    """

    def __init__(self, model_checkpoint, dataset_dict, output_dir):
        """
        Args:
            model_checkpoint: Hugging Face model id or local path to fine-tune.
            dataset_dict: ``DatasetDict`` with 'train', 'validation' and
                'test' splits containing the text and label columns.
            output_dir: Directory for checkpoints, logs and the final model.
        """
        self.model_checkpoint = model_checkpoint
        self.dataset_dict = dataset_dict
        self.output_dir = output_dir
        self.tokenizer = None
        self.model = None
        self.trainer = None
        self.tokenized_datasets = None

    def setup_trainer(self):
        """Load tokenizer and model, tokenize every split, and build the Trainer.

        Returns:
            True when ``self.trainer`` is ready, False on any failure (the
            reason is printed).
        """
        if num_labels == 0 or not label2id or not id2label:
            print("Error: Label mappings not initialized. Run data processing first.")
            return False

        if not self._load_tokenizer_and_model():
            return False

        try:
            self.tokenized_datasets = DatasetDict({
                split: self._tokenize_split(dataset, split)
                for split, dataset in self.dataset_dict.items()
            })
        except Exception as e:
            print(f"Error during dataset tokenization: {e}")
            return False

        try:
            self.trainer = self._build_trainer(self._build_training_args())
        except Exception as e:
            print(f"Error setting up trainer: {e}")
            return False

        print("Trainer setup complete.")
        return True

    def _load_tokenizer_and_model(self):
        """Load tokenizer and model as a pair, trying a fallback checkpoint on failure.

        Both objects always come from the same checkpoint, so the tokenizer's
        vocabulary matches the model's embedding table.
        """
        fallback = "prajjwal1/bert-tiny" if DEV_MODE else "bert-base-uncased"
        # dict.fromkeys de-duplicates while keeping order (fallback may equal the primary).
        for checkpoint in dict.fromkeys((self.model_checkpoint, fallback)):
            print(f"Loading tokenizer and model '{checkpoint}'...")
            try:
                tokenizer = AutoTokenizer.from_pretrained(checkpoint)
                model = AutoModelForSequenceClassification.from_pretrained(
                    checkpoint,
                    num_labels=num_labels,
                    id2label=id2label,
                    label2id=label2id,
                    ignore_mismatched_sizes=True
                )
            except Exception as e:
                print(f"Error loading '{checkpoint}': {e}")
                continue
            self.tokenizer, self.model = tokenizer, model
            return True
        return False

    def _tokenize_split(self, dataset, split_name):
        """Tokenize one split, drop the raw text column and switch to torch format."""
        print(f"Tokenizing {split_name} split...")
        tokenized = dataset.map(
            lambda examples: tokenize_function(examples, self.tokenizer),
            batched=True,
            batch_size=32
        )
        if TEXT_FEATURE_COLUMN in tokenized.column_names:
            tokenized = tokenized.remove_columns([TEXT_FEATURE_COLUMN])
        tokenized.set_format("torch")
        return tokenized

    def _build_training_args(self):
        """Create ``TrainingArguments`` that work on old and new transformers releases."""
        # transformers renamed `evaluation_strategy` to `eval_strategy` and later
        # removed the old keyword; use whichever this installed version accepts.
        accepted = inspect.signature(TrainingArguments).parameters
        strategy_key = "eval_strategy" if "eval_strategy" in accepted else "evaluation_strategy"
        return TrainingArguments(
            output_dir=self.output_dir,
            learning_rate=5e-5 if DEV_MODE else LEARNING_RATE,
            per_device_train_batch_size=TRAIN_BATCH_SIZE,
            per_device_eval_batch_size=EVAL_BATCH_SIZE,
            num_train_epochs=1 if DEV_MODE else NUM_EPOCHS,
            weight_decay=WEIGHT_DECAY,
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model="f1_macro",
            push_to_hub=False,
            fp16=torch.cuda.is_available(),
            logging_dir=f'{self.output_dir}/logs',
            logging_steps=100,
            report_to="none",
            save_total_limit=1,
            dataloader_num_workers=2 if DEV_MODE else 0,
            disable_tqdm=False,
            **{strategy_key: "epoch"},
        )

    def _build_trainer(self, training_args):
        """Create the ``Trainer`` with early stopping and per-batch dynamic padding."""
        # `tokenizer=` was superseded by `processing_class=` in newer transformers
        # releases; pass the tokenizer under whichever name this version accepts.
        accepted = inspect.signature(Trainer.__init__).parameters
        tokenizer_key = "processing_class" if "processing_class" in accepted else "tokenizer"
        return Trainer(
            model=self.model,
            args=training_args,
            train_dataset=self.tokenized_datasets["train"],
            eval_dataset=self.tokenized_datasets["validation"],
            data_collator=DataCollatorWithPadding(tokenizer=self.tokenizer),
            compute_metrics=compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
            **{tokenizer_key: self.tokenizer},
        )

    def evaluate_on_test(self):
        """Evaluate the current model on the held-out test split.

        Returns:
            The metrics dict from ``Trainer.evaluate``, or None if the trainer
            or test split is unavailable or evaluation fails.
        """
        if not self.trainer or not self.tokenized_datasets or "test" not in self.tokenized_datasets:
            return None

        try:
            print("Evaluating model on the test set...")
            return self.trainer.evaluate(self.tokenized_datasets["test"])
        except Exception as e:
            print(f"Error during test evaluation: {e}")
            return None

    def save_model_and_tokenizer(self):
        """Save the model and tokenizer to ``output_dir``; errors are printed, not raised."""
        if not self.trainer:
            return
        try:
            print(f"Saving model and tokenizer to {self.output_dir}...")
            self.trainer.save_model(self.output_dir)
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(self.output_dir)
            print("Model and tokenizer saved.")
        except Exception as e:
            print(f"Error saving model: {e}")


# ## Utilities for Label Mappings

# %%
def save_label_mappings(output_dir):
    """Write the current label mappings to ``output_dir`` as JSON files.

    Produces ``label2id.json`` and ``id2label.json``. Integer ids in the
    latter are written as string keys because JSON object keys are strings.

    Returns:
        True on success. False if the mappings are empty or a file cannot
        be written.
    """
    global label2id, id2label
    if not (label2id and id2label):
        print("Error: Label mappings not available to save.")
        return False

    os.makedirs(output_dir, exist_ok=True)
    try:
        with open(os.path.join(output_dir, "label2id.json"), 'w') as fh:
            json.dump(label2id, fh, indent=2)
        with open(os.path.join(output_dir, "id2label.json"), 'w') as fh:
            json.dump({str(idx): name for idx, name in id2label.items()}, fh, indent=2)
    except Exception as err:
        print(f"Error saving label mappings: {err}")
        return False
    return True

def load_label_mappings(input_dir):
    """Load the label mappings written by ``save_label_mappings``.

    On success, rebinds the module-level ``label2id``, ``id2label`` (with
    integer keys restored) and ``num_labels``. Both files are parsed before
    any global is touched, so a missing or malformed file leaves the previous
    mappings intact.

    Returns:
        True if both files were read, False otherwise.
    """
    global label2id, id2label, num_labels
    try:
        path_l2i = os.path.join(input_dir, "label2id.json")
        path_i2l = os.path.join(input_dir, "id2label.json")
        if not (os.path.exists(path_l2i) and os.path.exists(path_i2l)):
            return False

        with open(path_l2i, 'r', encoding='utf-8') as f:
            loaded_label2id = json.load(f)
        with open(path_i2l, 'r', encoding='utf-8') as f:
            # JSON keys are strings; the rest of the code indexes id2label by int.
            loaded_id2label = {int(k): v for k, v in json.load(f).items()}
    except Exception as e:
        print(f"Error loading label mappings: {e}")
        return False

    label2id, id2label = loaded_label2id, loaded_id2label
    num_labels = len(label2id)
    print(f"Loaded {num_labels} label mappings.")
    return True

def check_model_exists(model_dir):
    """Report whether ``model_dir`` holds a complete saved classifier.

    Returns:
        ``(exists, is_valid)``. ``exists`` is True when ``model_dir`` is a
        directory. ``is_valid`` additionally requires ``config.json``,
        ``tokenizer_config.json``, ``label2id.json`` and model weights.
        Weights may be in any layout ``save_pretrained`` produces:
        ``model.safetensors`` (the default in recent transformers),
        ``pytorch_model.bin``, or a sharded ``*.index.json`` checkpoint.
    """
    if not os.path.isdir(model_dir):
        return False, False

    def present(name):
        return os.path.exists(os.path.join(model_dir, name))

    weight_files = (
        "model.safetensors",
        "pytorch_model.bin",
        "model.safetensors.index.json",
        "pytorch_model.bin.index.json",
    )
    has_configs = present("config.json") and present("tokenizer_config.json")
    has_weights = any(present(name) for name in weight_files)
    has_mappings = present("label2id.json")
    return True, (has_configs and has_weights and has_mappings)


# ## Prediction Class

# %%
class TicketPredictor:
    """Classify ticket text with a fine-tuned model saved on disk.

    Prefers a ``transformers`` text-classification pipeline. If that cannot
    be built, falls back to loading the tokenizer and model directly.
    """

    def __init__(self, model_path):
        """
        Args:
            model_path: Directory produced by training (model, tokenizer and
                label-mapping JSON files).

        Raises:
            RuntimeError: If neither the pipeline nor the manual
                tokenizer/model pair can be loaded.
        """
        self.model_path = model_path
        self.pipeline = None
        self.id2label = None
        self.tokenizer = None
        self.model = None

        # Refresh the module-level label mappings (best effort).
        if not load_label_mappings(self.model_path):
            print("Warning: Failed to load label mappings.")

        try:
            device = 0 if torch.cuda.is_available() else -1
            self.pipeline = pipeline(
                "text-classification",
                model=self.model_path,
                tokenizer=self.model_path,
                device=device
            )
            self.id2label = self.pipeline.model.config.id2label
        except Exception as e:
            print(f"Error initializing pipeline: {e}")
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
                self.model = AutoModelForSequenceClassification.from_pretrained(self.model_path)
                self.model.eval()
                self.id2label = self.model.config.id2label
            except Exception as e2:
                raise RuntimeError("Failed to initialize prediction") from e2

    def _label_name(self, label_str):
        """Map a raw model label ('LABEL_3', '3' or an actual queue name) to a queue name."""
        label_id = None
        if label_str.startswith("LABEL_"):
            suffix = label_str[len("LABEL_"):]
            if suffix.isdecimal():
                label_id = int(suffix)
        elif label_str.isdecimal():  # isdecimal (not isdigit) guarantees int() succeeds
            label_id = int(label_str)
        if label_id is not None and self.id2label and label_id in self.id2label:
            return self.id2label[label_id]
        return label_str

    def _score_items(self, cleaned_text):
        """Return a list of ``{'label', 'score'}`` dicts, one per class."""
        if self.pipeline is not None:
            # top_k=None returns every class. Truncate to the training length
            # so long tickets do not overflow the model's position embeddings.
            results = self.pipeline(
                cleaned_text,
                top_k=None,
                truncation=True,
                max_length=MAX_TOKEN_LENGTH,
            )
        else:
            inputs = self.tokenizer(cleaned_text, return_tensors="pt",
                                    truncation=True, max_length=MAX_TOKEN_LENGTH)
            with torch.no_grad():
                logits = self.model(**inputs).logits
            probs = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()
            results = [{'label': str(i), 'score': prob} for i, prob in enumerate(probs)]

        if not isinstance(results, list):
            return []
        # A single-string call may return a flat list or a list wrapping one list.
        if results and isinstance(results[0], list):
            return results[0]
        return results

    def predict(self, text):
        """Predict the queue for ``text``.

        Returns:
            ``(label, score, probabilities)``: the top queue name, its
            probability, and a dict of queue name -> probability for every class.
            Returns ``("N/A", 0.0, {})`` for empty or non-string input, or text
            that is empty after cleaning. Returns ``("Error", 0.0, {})`` if
            inference fails.
        """
        if not isinstance(text, str) or not text.strip():
            return "N/A", 0.0, {}

        cleaned_text = clean_text(text)
        if not cleaned_text:
            return "N/A", 0.0, {}

        try:
            score_items = self._score_items(cleaned_text)
        except Exception as e:
            print(f"Prediction error: {e}")
            return "Error", 0.0, {}
        if not score_items:
            return "Error", 0.0, {}

        probabilities = {}
        top_label, top_score = "Unknown", 0.0
        for item in score_items:
            if not isinstance(item, dict) or 'label' not in item or 'score' not in item:
                continue
            queue_name = self._label_name(str(item['label']))
            score = item['score']
            probabilities[queue_name] = score
            if score > top_score:
                top_label, top_score = queue_name, score

        return top_label, top_score, probabilities


# ## Gradio Interface

# Builds the Gradio web interface around a trained TicketPredictor.

def create_gradio_interface(predictor):
    """Creates the Gradio web interface."""
    if not predictor:
        return None

    def predict_for_gradio(text_input):
        """Process prediction and create visualization."""
        if not text_input or not text_input.strip():
            return "Please enter some ticket text.", None

        try:
            pred_label, pred_score, probabilities = predictor.predict(text_input)

            if pred_label in ("Error", "N/A"):
                return "An error occurred or input was empty.", None

            output_md = f"""
            ## Ticket Classification Results

            **Predicted Queue:** `{pred_label}`

            **Confidence:** {pred_score:.2%}
            """

            # Generate visualization
            try:
                if probabilities and len(probabilities) > 0:
                    # Sort and limit to top categories
                    sorted_items = sorted(probabilities.items(), key=lambda x: x[1], reverse=True)
                    queues = [item[0] for item in sorted_items[:10]]  # Top 10 only
                    scores = [item[1] for item in sorted_items[:10]]

                    # Create plot
                    fig, ax = plt.subplots(figsize=(8, min(8, max(4, len(queues) * 0.5))))
                    bars = ax.barh(queues, scores, color='skyblue')
                    ax.set_xlabel('Probability')
                    ax.set_title('Prediction Probabilities')
                    ax.set_xlim(0, 1)

                    # Add labels
                    for i, bar in enumerate(bars):
                        width = bar.get_width()
                        ax.text(min(width + 0.01, 0.95), bar.get_y() + bar.get_height()/2,
                               f'{width:.1%}', va='center')

                    plt.tight_layout()
                    prob_plot = fig
                    plt.close(fig)
                else:
                    prob_plot = None
            except Exception:
                prob_plot = None

            return output_md, prob_plot

        except Exception:
            return "An error occurred during processing.", None

    # Create the interface - fixed flagging_mode parameter
    interface = gr.Interface(
        fn=predict_for_gradio,
        inputs=gr.Textbox(
            lines=8,
            label="Ticket Text Input",
            placeholder="Enter support ticket subject and/or body here..."
        ),
        outputs=[
            gr.Markdown(label="Classification Results"),
            gr.Plot(label="Probability Distribution")
        ],
        title="Support Ticket Queue Classifier",
        description="Predicts the appropriate queue for a support ticket.",
        examples=[
            ["Subject: Cannot login Body: My password reset link expired, please help me access my account."],
            ["Order #12345 not received yet, tracking shows stuck?"],
            ["Subject: Billing inquiry Body: I believe I was overcharged on my last invoice."],
            ["How do I update my company address?"],
            ["[fr] Sujet : Problème de connexion Corps : Impossible de me connecter à mon compte."]
        ],
        flagging_mode="never",  # This is the corrected parameter (was allow_flagging=False)
        analytics_enabled=False,
        cache_examples=False,
    )

    interface.queue(max_size=1)
    return interface



# ## Training and Prediction Pipeline

def train_and_launch(force_retrain=False):
    """Train (or reuse) the ticket classifier and launch the Gradio demo.

    A valid model already saved in ``MODEL_OUTPUT_DIR`` is reused unless
    ``force_retrain`` is set. Otherwise the dataset is processed, the model
    is fine-tuned (with a timeout), saved, evaluated on the test split and
    reloaded from disk as a predictor.

    Args:
        force_retrain: Train a new model even if a valid saved one exists.

    Returns:
        The value of ``interface.launch(...)``, or None if any stage fails.
    """
    model_checkpoint = get_development_model_checkpoint()
    print(f"Using model: {model_checkpoint}")

    predictor = None
    exists, is_valid = check_model_exists(MODEL_OUTPUT_DIR)

    if exists and is_valid and not force_retrain:
        print("✓ Valid model found. Loading for prediction.")
        try:
            # TicketPredictor loads the label mappings itself.
            predictor = TicketPredictor(MODEL_OUTPUT_DIR)
        except Exception as e:
            print(f"Error loading model: {e}")
            # The files exist but are unusable: mark the model invalid so the
            # cleanup below removes it before retraining.
            is_valid = False

    if predictor is None:  # Need to train
        if exists and not is_valid:
            print("Found incomplete model. Cleaning up and retraining...")
            shutil.rmtree(MODEL_OUTPUT_DIR, ignore_errors=True)

        os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)

        data_processor = TicketDataProcessor(DATASET_FILENAME, TEXT_COLUMNS, TARGET_COLUMN)
        dataset_dict = data_processor.load_and_prepare()
        if dataset_dict is None:
            print("Error: Failed to process data.")
            return None

        save_label_mappings(MODEL_OUTPUT_DIR)

        print("\nSetting up trainer...")
        ticket_trainer = TicketTrainer(model_checkpoint, dataset_dict, MODEL_OUTPUT_DIR)
        if not ticket_trainer.setup_trainer():
            print("Error: Failed to set up trainer.")
            return None

        print("\nStarting training (with timeout protection)...")
        train_result = safe_train_with_timeout(ticket_trainer.trainer, timeout_minutes=15)
        if train_result is None:
            print("Warning: training did not finish cleanly; saving the model in its current state.")

        # Save even after a timeout/error so that some model is available for the UI.
        print("\nSaving model...")
        ticket_trainer.save_model_and_tokenizer()
        save_label_mappings(MODEL_OUTPUT_DIR)

        try:
            test_results = ticket_trainer.evaluate_on_test()
            if test_results:
                print("\nTest metrics:", test_results)
        except Exception as e:
            print(f"Evaluation error: {e}")

        try:
            predictor = TicketPredictor(MODEL_OUTPUT_DIR)
        except Exception as e:
            print(f"Error initializing predictor: {e}")
            return None

    print("\nCreating Gradio interface...")
    interface = create_gradio_interface(predictor)
    if not interface:
        print("Failed to create interface.")
        return None

    print("\nLaunching interface. Share=True will generate a public link.")
    return interface.launch(share=True, debug=False)



# Runs when this cell/script executes: trains (or reuses) the model and launches the Gradio demo.
train_and_launch()



def text_interface():
    """Run a simple interactive, text-based prediction loop in the console.

    Loads the label mappings and a ``TicketPredictor`` from
    ``MODEL_OUTPUT_DIR``, then repeatedly reads ticket text from stdin and
    prints the predicted queue, its confidence and, when more than one class
    score is returned, the top-3 queues.

    Typing 'q', 'quit' or 'exit' (case-insensitive, surrounding whitespace
    ignored), or closing stdin (EOF), ends the loop. A failure while
    predicting a single ticket is reported and the loop continues.

    Returns:
        True if the model was loaded and the loop exited normally.
        False if no saved model exists or loading it failed.
    """
    if not os.path.exists(MODEL_OUTPUT_DIR):
        print("No model found. Please run training first.")
        return False

    # Guard only the loading step, so that prediction-time failures are not
    # misreported as "Error loading model" and do not end the session.
    try:
        load_label_mappings(MODEL_OUTPUT_DIR)
        predictor = TicketPredictor(MODEL_OUTPUT_DIR)
    except Exception as e:
        print(f"Error loading model: {e}")
        return False

    print("\n=== Text-Based Prediction Interface ===")
    print("Type 'q', 'quit' or 'exit' to exit")

    while True:
        try:
            user_input = input("\nEnter ticket text: ")
        except EOFError:
            # stdin closed (e.g. piped input exhausted): leave the loop cleanly.
            break

        command = user_input.strip()
        if command.lower() in ('q', 'quit', 'exit'):
            break
        if not command:
            print("Please enter some text.")
            continue

        print("Processing...")
        try:
            pred_label, pred_score, probabilities = predictor.predict(user_input)
        except Exception as e:
            # Keep the session alive after a single failed prediction.
            print(f"Prediction error: {e}")
            continue

        # The predictor signals failure through sentinel labels.
        if pred_label in ("Error", "N/A"):
            print(f"Prediction error: {pred_label}")
            continue

        print(f"\nPredicted Queue: {pred_label}")
        print(f"Confidence: {pred_score:.2%}")

        # A ranking is only informative when there is more than one score.
        if probabilities and len(probabilities) > 1:
            print("\nTop Predictions:")
            ranked = sorted(probabilities.items(), key=lambda kv: kv[1], reverse=True)
            for rank, (queue, score) in enumerate(ranked[:3], start=1):
                print(f"{rank}. {queue}: {score:.2%}")

    print("Interface closed.")
    return True


Using model: prajjwal1/bert-tiny
Found incomplete model. Cleaning up and retraining...
Loading dataset from dataset-tickets-multi-lang-4-20k.csv...
Loaded 20000 rows.
Found 10 unique labels: ['Billing and Payments', 'Customer Service', 'General Inquiry', 'Human Resources', 'IT Support', 'Product Support', 'Returns and Exchanges', 'Sales and Pre-Sales', 'Service Outages and Maintenance', 'Technical Support']
DEV MODE: Limited dataset to 2000 samples
Dataset prepared and split:
  train: 1440 examples
  validation: 160 examples
  test: 400 examples

Setting up trainer...
Loading tokenizer for 'prajjwal1/bert-tiny'...
Tokenizing train split...


Map:   0%|          | 0/1440 [00:00<?, ? examples/s]

Tokenizing validation split...


Map:   0%|          | 0/160 [00:00<?, ? examples/s]

Tokenizing test split...


Map:   0%|          | 0/400 [00:00<?, ? examples/s]

Loading model 'prajjwal1/bert-tiny'...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  self.trainer = Trainer(


Trainer setup complete.

Starting training (with timeout protection)...
Starting training with 15 minute timeout safety...


Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Precision Macro,Recall Macro
1,No log,2.066205,0.26875,0.047072,0.029861,0.111111


Training completed successfully!

Saving model...
Saving model and tokenizer to ./ticket-classifier-model...
Model and tokenizer saved.
Evaluating model on the test set...


Device set to use cuda:0



Test metrics: {'eval_loss': 2.0580272674560547, 'eval_accuracy': 0.305, 'eval_f1_macro': 0.0519369944657301, 'eval_precision_macro': 0.033888888888888885, 'eval_recall_macro': 0.1111111111111111, 'eval_runtime': 0.3531, 'eval_samples_per_second': 1132.756, 'eval_steps_per_second': 36.815, 'epoch': 1.0}
Loaded 10 label mappings.
Loaded 10 label mappings.

Creating Gradio interface...

Launching interface. Share=True will generate a public link.
Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://00fa84d3fba26877f0.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
