<a href="https://colab.research.google.com/github/ShovalBenjer/Bigdata_Pyspark_Spark_Hadoop_Apache/blob/main/CFPB_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install kafka-python transformers torch autoviz mlflow pyspark==3.5.5 findspark pyarrow pandas pyyaml ipython

Collecting kafka-python
  Downloading kafka_python-2.1.3-py2.py3-none-any.whl.metadata (9.1 kB)
Collecting autoviz
  Downloading autoviz-0.1.905-py3-none-any.whl.metadata (14 kB)
Collecting mlflow
  Downloading mlflow-2.21.2-py3-none-any.whl.metadata (30 kB)
Collecting findspark
  Downloading findspark-2.0.1-py2.py3-none-any.whl.metadata (352 bytes)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.

In [None]:
#!/usr/bin/env python
# Comprehensive Big Data Pipeline with Spark Structured Streaming, Kafka
# Designed for Google Colab Pro with External Kafka Setup
# VERSION incorporating fixes for Indentation, Task 2 Filter, Task 4 GPU/Except, Task 5 TypeError, AutoViz Inline

# Startup banner: lists the external prerequisites the notebook expects.
print("\n".join([
    "--- Initializing Pipeline Script ---",
    "Ensure Kafka/Zookeeper are running externally and topics are created.",
    "Ensure Consumer_Complaints.csv is available at /content/Consumer_Complaints.csv",
    "Ensure required packages are installed (run pip install cell).",
]))

import os
import sys
import torch
import numpy as np
import json
import time
import pandas as pd
import mlflow
import findspark
import contextlib
from typing import Iterator

# --- Spark Configuration ---
# Point both the PySpark driver and the workers at the interpreter running
# this notebook, then let findspark locate the Spark installation. This runs
# ahead of the pyspark imports that follow.
os.environ.update(
    PYSPARK_PYTHON=sys.executable,
    PYSPARK_DRIVER_PYTHON=sys.executable,
)
findspark.init()

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import *
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import (
    StringIndexer, OneHotEncoder, VectorAssembler, SQLTransformer # Import SQLTransformer
)
from pyspark.ml.base import Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable, MLWriter, MLReader, DefaultParamsWriter, DefaultParamsReader
from pyspark.ml.torch.distributor import TorchDistributor
from pyspark.sql.streaming import StreamingQueryListener
from kafka import KafkaProducer, errors as kafka_errors

# --- Transformers & PyTorch ---
from transformers import DistilBertTokenizer, DistilBertModel, get_linear_schedule_with_warmup
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader, RandomSampler, DistributedSampler
import torch.nn as nn
import torch.distributed as dist
from collections import OrderedDict

# --- Visualization ---
# AutoViz and matplotlib are optional: record whether both imported so the
# visualization step can be skipped when they are missing.
try:
    from autoviz import AutoViz_Class
    import matplotlib.pyplot as plt
except ImportError:
    AUTOVIZ_AVAILABLE = False
    print("WARN: autoviz or matplotlib not found. Visualization will be skipped.")
else:
    AUTOVIZ_AVAILABLE = True

# --- Configuration Variables ---
# Root directory for everything the pipeline writes (data, models,
# checkpoints, MLflow runs, visualizations).
BASE_DIR = "/content/consumer_complaints"
# Input CSV read by Task 1 (load_data_to_kafka).
CSV_FILE_PATH = "/content/Consumer_Complaints.csv"
# Parquet location where Task 3 stores the held-out test split.
TEST_DATA_PERSISTENCE_PATH = f"{BASE_DIR}/data/test_data_source.parquet"
TRAINING_PIPELINE_SAVE_PATH = f"{BASE_DIR}/models/training_pipeline"
EMBEDDING_MODEL_SAVE_PATH = f"{BASE_DIR}/models/embedding_model"
# Passed to Spark as spark.sql.streaming.checkpointLocation.
STREAMING_CHECKPOINT_LOCATION = f"{BASE_DIR}/checkpoints"
# BASE_DIR is absolute, so this expands to a "file:///..." URI.
MLFLOW_TRACKING_URI = f"file://{BASE_DIR}/mlflow"
# Default output directory for visualize_with_autoviz.
VISUALIZATION_DIR = f"{BASE_DIR}/visualizations"
# Parquet files read by the Task 4 training function.
TRAIN_PARQUET_PATH = f"{BASE_DIR}/data/train_data.parquet"
VAL_PARQUET_PATH = f"{BASE_DIR}/data/val_data.parquet"

# Kafka configuration (assuming external setup on localhost)
KAFKA_BROKERS = "localhost:9092"
KAFKA_TOPIC_RAW = "complaints-raw"                      # written by Task 1, read by Task 2
KAFKA_TOPIC_TRAINING = "complaints-training-data"       # written by Task 3
KAFKA_TOPIC_TESTING_STREAM = "complaints-testing-stream"
KAFKA_TOPIC_PREDICTIONS = "complaint-predictions"
KAFKA_TOPIC_METRICS = "streaming-metrics"

# Simulation parameters
# NOTE(review): not referenced in the visible part of the file; presumably
# used by a later streaming-simulation step — confirm.
MESSAGES_PER_MINUTE = 100
BATCH_SIZE = 20

# Training parameters
TRAIN_SAMPLE_LIMIT = 20000 # Adjust based on Colab RAM
VAL_SAMPLE_LIMIT = 2000   # Adjust based on Colab RAM
BERT_MAX_LENGTH = 128 # Tokenizer max_length used by the embedding UDF
BERT_BATCH_SIZE = 16 # Per GPU/Process
NUM_EPOCHS = 3       # Reduced for faster demo

# --- Create Directories ---
# The MLflow directory is derived by stripping the "file://" scheme from the
# tracking URI. Note that the loop variable `path` remains defined at module
# level after this loop.
for path in [
    f"{BASE_DIR}/data", f"{BASE_DIR}/models", f"{BASE_DIR}/checkpoints",
    MLFLOW_TRACKING_URI.replace("file://", ""), VISUALIZATION_DIR
]:
    os.makedirs(path, exist_ok=True)

# --- Initialize Spark Session ---
print("Initializing Spark Session...")
# Build (or reuse) the session: Arrow-backed pandas conversion, the Kafka
# connector package, a default streaming checkpoint location, 6g of memory
# for driver and executors, and adaptive query execution.
spark = (
    SparkSession.builder
    .appName("Consumer Complaints ML Pipeline")
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .config("spark.jars.packages", "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.5")
    .config("spark.sql.streaming.checkpointLocation", STREAMING_CHECKPOINT_LOCATION)
    .config("spark.executor.memory", "6g")
    .config("spark.driver.memory", "6g")
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .getOrCreate()
)

# Echo the effective runtime so the notebook output records it.
print("\n--- Spark Configuration ---")
print(f"Spark Version: {spark.version}")
print(f"Application ID: {spark.sparkContext.applicationId}")
print(f"Using PySpark: {spark.sparkContext.pythonVer}")

# --- MLflow Setup ---
def setup_mlflow_tracking():
    """Configure MLflow for this pipeline and return the ``mlflow`` module.

    Sets the tracking URI to ``MLFLOW_TRACKING_URI`` and selects the
    "complaint-classification" experiment.
    """
    tracking_uri = MLFLOW_TRACKING_URI
    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment("complaint-classification")
    print(f"MLflow tracking configured. URI: {tracking_uri}")
    return mlflow


# Rebinds the module-level name to the configured mlflow module.
mlflow = setup_mlflow_tracking()

In [None]:
# --- Schema Definitions ---
def get_full_schema():
    """Return the schema of the complaints CSV: 18 nullable string columns."""
    column_names = [
        "Date received",
        "Product",
        "Sub-product",
        "Issue",
        "Sub-issue",
        "Consumer complaint narrative",
        "Company public response",
        "Company",
        "State",
        "ZIP code",
        "Tags",
        "Consumer consent provided?",
        "Submitted via",
        "Date sent to company",
        "Company response to consumer",
        "Timely response?",
        "Consumer disputed?",
        "Complaint ID",
    ]
    # Every column is read as a nullable string; order matches the list above.
    return StructType([StructField(name, StringType(), True) for name in column_names])

def get_streaming_schema():
    """Return the reduced 7-column schema (all nullable strings).

    The columns are the same ones Task 3 selects when it persists the test
    split.
    """
    column_names = [
        "Date received",
        "Complaint ID",
        "Company",
        "State",
        "ZIP code",
        "Submitted via",
        "Consumer complaint narrative",
    ]
    return StructType([StructField(name, StringType(), True) for name in column_names])

# --- Custom BERT Embedding Transformer ---
class BERTEmbeddingTransformer(Transformer, HasInputCol, HasOutputCol, DefaultParamsReadable, DefaultParamsWritable):
    """Spark ML transformer that appends a DistilBERT embedding column.

    The text in ``inputCol`` is embedded with a pandas UDF that loads a
    DistilBERT tokenizer/model from ``modelPath`` and writes the first-token
    (CLS) vector as ``array<float>`` into ``outputCol``. Rows that cannot be
    embedded (empty text, model load failure, per-row exception) receive an
    all-zero vector instead of failing the job.
    """
    def __init__(self, inputCol=None, outputCol=None, modelPath=None):
        """Declare the three params and seed them from the constructor arguments."""
        super().__init__()
        # Params are (re)declared on the instance rather than as class attributes.
        self.inputCol = Param(self, "inputCol", "")
        self.outputCol = Param(self, "outputCol", "")
        self.modelPath = Param(self, "modelPath", "")
        # Constructor values become the param defaults (None when not supplied).
        self._setDefault(inputCol=inputCol, outputCol=outputCol, modelPath=modelPath)
        self.setModelPath(modelPath) # Set model path on init

    # Thin accessors over the param map; setters return the result of _set().
    def setInputCol(self, value): return self._set(inputCol=value)
    def getInputCol(self): return self.getOrDefault(self.inputCol)
    def setOutputCol(self, value): return self._set(outputCol=value)
    def getOutputCol(self): return self.getOrDefault(self.outputCol)
    def setModelPath(self, value): return self._set(modelPath=value)
    def getModelPath(self): return self.getOrDefault(self.modelPath)

    def _transform(self, dataset):
        """Return ``dataset`` with the embedding column added.

        Raises:
            ValueError: if the configured input column is not in ``dataset``.
        """
        schema = dataset.schema # Get schema once
        input_col_name = self.getInputCol()
        output_col_name = self.getOutputCol()
        # Copied to a local so the UDF closure captures a plain string, not self.
        _model_path = self.getModelPath() # Get path for closure

        # Ensure input column exists
        if input_col_name not in schema.fieldNames():
            raise ValueError(f"Input column '{input_col_name}' does not exist in DataFrame.")

        # Series -> Series pandas UDF producing one float array per input text.
        @F.pandas_udf(ArrayType(FloatType()))
        def bert_embed_batch(texts_series: pd.Series) -> pd.Series:
            # Lazy load: tokenizer/model are cached as attributes on the UDF
            # object and only loaded when those attributes are missing.
            # NOTE(review): the intent is one load per worker process; this
            # relies on the function object surviving across batches — confirm.
            if not hasattr(bert_embed_batch, 'model') or not hasattr(bert_embed_batch, 'tokenizer'):
                model_dir = _model_path # Use variable from closure
                try:
                    if not model_dir or not os.path.exists(model_dir):
                         raise ValueError(f"Model path '{model_dir}' not found or not specified.")
                    bert_embed_batch.tokenizer = DistilBertTokenizer.from_pretrained(model_dir)
                    bert_embed_batch.model = DistilBertModel.from_pretrained(model_dir)
                    bert_embed_batch.model.to("cpu").eval() # Use CPU in UDF
                    bert_embed_batch.embedding_dim = bert_embed_batch.model.config.dim
                    print(f"Worker loaded BERT embedder from {model_dir}.")
                except Exception as e:
                    # Best-effort fallback: mark the model as unavailable so
                    # every row gets a zero vector instead of failing the task.
                    print(f"Worker ERROR loading BERT embedder from '{model_dir}': {e}. Using zero embeddings.")
                    bert_embed_batch.tokenizer = None
                    bert_embed_batch.model = None
                    bert_embed_batch.embedding_dim = 768 # Default DistilBERT

            results = []
            if bert_embed_batch.model is None or bert_embed_batch.tokenizer is None:
                # Return zeros if model failed to load
                return pd.Series([[0.0] * bert_embed_batch.embedding_dim] * len(texts_series))

            # Texts are embedded one at a time (no batching across rows).
            for text in texts_series:
                try:
                    clean_text = str(text) if text is not None else ""
                    # Null / whitespace-only text gets a zero vector without
                    # running the model.
                    if len(clean_text.strip()) == 0:
                        results.append([0.0] * bert_embed_batch.embedding_dim)
                        continue

                    # Truncate/pad every text to BERT_MAX_LENGTH tokens.
                    inputs = bert_embed_batch.tokenizer(
                        clean_text, return_tensors="pt", truncation=True,
                        max_length=BERT_MAX_LENGTH, padding="max_length"
                    )
                    with torch.no_grad():
                        outputs = bert_embed_batch.model(**inputs)
                    # Use CLS token embedding [:, 0, :]
                    results.append(outputs.last_hidden_state[:, 0, :].squeeze().tolist())
                except Exception as e:
                    # Per-row failures are swallowed on purpose and replaced by
                    # a zero vector; the error itself is not logged.
                    # print(f"Worker ERROR embedding text snippet '{clean_text[:50]}...': {e}")
                    results.append([0.0] * bert_embed_batch.embedding_dim)
            return pd.Series(results)

        # Apply UDF
        return dataset.withColumn(output_col_name, bert_embed_batch(F.col(input_col_name)))

    # Persistence: the default params writer/reader are wrapped so that the
    # model path is also stored in a sidecar "bert_model_metadata.json" file.
    # NOTE(review): modelPath is itself a Param, so the default writer may
    # already persist it — confirm whether the sidecar file is still needed.
    def write(self):
        """Return a params writer whose ``save`` also writes the sidecar JSON."""
        writer = DefaultParamsWriter(self)
        original_save = writer.save
        def custom_save(path):
            original_save(path)
            extra_metadata = {"modelPath": self.getModelPath()}
            # Written with the built-in open(), i.e. `path` is treated as a
            # local filesystem path here.
            extra_metadata_path = os.path.join(path, "bert_model_metadata.json")
            with open(extra_metadata_path, "w") as f: json.dump(extra_metadata, f)
        writer.save = custom_save
        return writer
    @classmethod
    def read(cls):
        """Return a params reader whose ``load`` restores modelPath from the sidecar JSON."""
        reader = DefaultParamsReader(cls)
        original_load = reader.load
        def custom_load(path):
            instance = original_load(path)
            extra_metadata_path = os.path.join(path, "bert_model_metadata.json")
            # The sidecar is optional: without it the instance is returned as loaded.
            if os.path.exists(extra_metadata_path):
                with open(extra_metadata_path, "r") as f:
                    extra_metadata = json.load(f)
                    instance.setModelPath(extra_metadata.get("modelPath"))
            return instance
        reader.load = custom_load
        return reader

# --- PyTorch Classes ---
class ComplaintDataset(Dataset):
    """Torch dataset that tokenizes complaint texts on access.

    Each item is a dict with flattened ``input_ids`` and ``attention_mask``
    tensors plus a ``labels`` tensor of dtype ``torch.long``.
    """

    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_length = tokenizer, max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Falsy entries (None, empty string) are tokenized as "".
        narrative = self.texts[idx] or ""
        tokenizer_options = dict(
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        encoded = self.tokenizer.encode_plus(narrative, **tokenizer_options)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        # flatten() drops the leading batch dimension added by return_tensors='pt'.
        return {
            'input_ids': encoded['input_ids'].flatten(),
            'attention_mask': encoded['attention_mask'].flatten(),
            'labels': label,
        }

class EnhancedDistilBERTClassifier(torch.nn.Module):
    """DistilBERT encoder followed by a two-hidden-layer classification head.

    The head maps the encoder's first-token vector (or a precomputed
    embedding) through Linear(dim, 256) and Linear(256, 64), each followed by
    batch norm, ReLU and dropout, to 2 output logits.
    """

    def __init__(self, bert_model, dropout_rate=0.3):
        super().__init__()
        self.bert = bert_model
        encoder_dim = self.bert.config.dim

        # First hidden stage.
        self.dropout1 = torch.nn.Dropout(dropout_rate)
        self.dense1 = torch.nn.Linear(encoder_dim, 256)
        self.batch_norm1 = torch.nn.BatchNorm1d(256)
        self.relu1 = torch.nn.ReLU()

        # Second hidden stage.
        self.dropout2 = torch.nn.Dropout(dropout_rate)
        self.dense2 = torch.nn.Linear(256, 64)
        self.batch_norm2 = torch.nn.BatchNorm1d(64)
        self.relu2 = torch.nn.ReLU()

        # Output stage.
        self.dropout3 = torch.nn.Dropout(dropout_rate)
        self.classifier = torch.nn.Linear(64, 2) # Binary classification

    def _batch_norm_if_possible(self, norm_layer, activations):
        """Apply ``norm_layer`` unless training on a batch of at most one row."""
        if activations.shape[0] <= 1 and self.training:
            return activations
        try:
            return norm_layer(activations)
        except ValueError as e:
            print(f"WARN: BatchNorm1d error (input shape {activations.shape}): {e}. Skipping BN.")
            return activations

    def forward(self, input_ids=None, attention_mask=None, embeddings=None):
        """Return logits of shape (batch, 2).

        Either ``embeddings`` (used directly) or ``input_ids`` (run through
        the encoder) must be given; otherwise ``ValueError`` is raised.
        """
        if embeddings is not None:
            pooled = embeddings
        elif input_ids is None:
            raise ValueError("Must provide input_ids or embeddings")
        else:
            encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            pooled = encoded.last_hidden_state[:, 0, :] # CLS token

        hidden = self.dense1(self.dropout1(pooled))
        hidden = self._batch_norm_if_possible(self.batch_norm1, hidden)
        hidden = self.dropout2(self.relu1(hidden))

        hidden = self.dense2(hidden)
        hidden = self._batch_norm_if_possible(self.batch_norm2, hidden)
        hidden = self.dropout3(self.relu2(hidden))

        return self.classifier(hidden)

# --- Kafka Metrics Listener ---
class MetricsListener(StreamingQueryListener):
    """Streaming query listener that publishes query events to a Kafka topic.

    Each started/progress/terminated callback builds a dict and sends it,
    JSON-encoded, to ``topic``. If the producer cannot be created, the
    listener stays usable and simply drops metrics.
    """

    def __init__(self, kafka_brokers, topic):
        self.topic = topic
        self.producer = None # Stays None when the producer cannot be created
        producer_options = dict(
            bootstrap_servers=kafka_brokers.split(','),
            value_serializer=lambda v: json.dumps(v).encode('utf-8'),
            retries=3,
            linger_ms=5,
            request_timeout_ms=10000, # Shorter timeout
        )
        try:
            self.producer = KafkaProducer(**producer_options)
            print("MetricsListener Kafka Producer initialized.")
        except Exception as e: # Any failure during init leaves producer as None
            print(f"ERROR: MetricsListener failed to initialize Kafka Producer: {e}")

    def send_metric(self, metrics):
        """Send one metrics dict without blocking; no-op when there is no producer."""
        if not self.producer:
            return
        try:
            # Fire-and-forget: the returned future is deliberately not awaited
            # so listener callbacks never block on Kafka.
            self.producer.send(self.topic, value=metrics)
        except Exception as e:
            print(f"ERROR sending metric to Kafka topic '{self.topic}': {e}")

    def onQueryStarted(self, event):
        payload = {
            "queryName": event.name,
            "id": str(event.id),
            "runId": str(event.runId),
            "timestamp": event.timestamp,
            "event": "started",
        }
        self.send_metric(payload)

    def onQueryProgress(self, event):
        payload = {
            "queryName": event.progress.name,
            "id": str(event.progress.id),
            "runId": str(event.progress.runId),
            "timestamp": event.progress.timestamp,
            "event": "progress",
            "numInputRows": event.progress.numInputRows,
            "inputRowsPerSecond": event.progress.inputRowsPerSecond,
            "processedRowsPerSecond": event.progress.processedRowsPerSecond,
            "batchId": event.progress.batchId,
        }
        self.send_metric(payload)

    def onQueryTerminated(self, event):
        # The timestamp is taken from the local clock (epoch milliseconds)
        # rather than read from the event.
        payload = {
            "queryName": getattr(event, 'name', None),
            "id": str(event.id),
            "runId": str(event.runId),
            "timestamp": time.time() * 1000,
            "event": "terminated",
            "exception": str(event.exception) if event.exception else None,
        }
        self.send_metric(payload)

    def close_producer(self):
        """Flush and close the producer (5 s timeout each); safe to call repeatedly."""
        if not self.producer:
            return
        try:
            print("Flushing and closing MetricsListener Kafka Producer...")
            self.producer.flush(timeout=5) # 5 sec timeout
            self.producer.close(timeout=5)
            self.producer = None
            print("MetricsListener Kafka Producer closed.")
        except Exception as e:
            print(f"ERROR closing MetricsListener Kafka Producer: {e}")

    def __del__(self):
        # Fallback cleanup for callers that never invoke close_producer().
        self.close_producer()

# --- Task 1: Load Data to Kafka ---
# Filters for non-empty narratives BEFORE sampling, so the sample is drawn only from usable records.
def load_data_to_kafka():
    """Load the complaints CSV, keep rows with a narrative, and write them to Kafka.

    Rows with a null or blank "Consumer complaint narrative" are dropped. If
    more than 50,000 rows remain, a seeded random sample of roughly that size
    is written instead. Each message uses "Complaint ID" as key and the whole
    row as a JSON value.

    Returns:
        The DataFrame that was written (no longer cached), or None when the
        CSV is missing, no usable rows exist, or the Kafka write fails.
    """
    print("\n--- Task 1: Loading Data to Kafka ---")
    print(f"Reading CSV: {CSV_FILE_PATH}")
    if not os.path.exists(CSV_FILE_PATH):
        print(f"ERROR: CSV file not found at {CSV_FILE_PATH}. Download it first.")
        return None

    csv_reader = (
        spark.read.format("csv")
        .option("header", "true")
        .schema(get_full_schema())
        .option("escape", "\"")
        .option("multiLine", "true")
    )
    raw_df_unfiltered = csv_reader.load(CSV_FILE_PATH)

    initial_count = raw_df_unfiltered.count()
    print(f"Total records loaded initially from CSV: {initial_count}")
    if initial_count == 0:
        return None

    print("Filtering for non-empty 'Consumer complaint narrative'...")
    narrative = F.col("Consumer complaint narrative")
    has_narrative = narrative.isNotNull() & (F.length(F.trim(narrative)) > 0)
    raw_df = raw_df_unfiltered.filter(has_narrative)
    # cache() + count() materializes the filtered rows for the steps below.
    filtered_count = raw_df.cache().count()
    print(f"Records after filtering for non-empty narrative: {filtered_count}")
    raw_df_unfiltered.unpersist()

    if filtered_count == 0:
        print("ERROR: No records found with non-empty narratives.")
        raw_df.unpersist()
        return None

    MAX_KAFKA_LOAD_RECORDS = 50000
    if filtered_count > MAX_KAFKA_LOAD_RECORDS:
        sample_fraction = MAX_KAFKA_LOAD_RECORDS / filtered_count
        print(f"Sampling {sample_fraction:.2%} of filtered records ({MAX_KAFKA_LOAD_RECORDS}) for Kafka loading...")
        df_to_write = raw_df.sample(False, sample_fraction, seed=42)
        # The sample is cached so the count and the write see the same rows.
        print(f"Sample size for Kafka: {df_to_write.cache().count()} records")
    else:
        df_to_write = raw_df

    try:
        print(f"Writing data to Kafka topic: {KAFKA_TOPIC_RAW} at {KAFKA_BROKERS}")
        kafka_df = df_to_write.selectExpr("`Complaint ID` AS key", "to_json(struct(*)) AS value")
        kafka_writer = (
            kafka_df.write
            .format("kafka")
            .option("kafka.bootstrap.servers", KAFKA_BROKERS)
            .option("topic", KAFKA_TOPIC_RAW)
            .option("kafka.request.timeout.ms", "120000")
            .option("kafka.delivery.timeout.ms", "180000")
        )
        kafka_writer.save()
        print(f"Data successfully written to Kafka topic: {KAFKA_TOPIC_RAW}")
    except Exception as e:
        print(f"ERROR writing to Kafka: {e}")
        print("Consider checking Kafka broker status and topic existence.")
        return None # Signal failure to the caller
    finally:
        # Release both caches whether or not the write succeeded.
        if df_to_write.is_cached:
            df_to_write.unpersist()
        if raw_df.is_cached:
            raw_df.unpersist()

    return df_to_write

# --- Task 2: Preprocess, Filter, Visualize ---
# AutoViz charts are rendered inline; Task 2 applies no date filter.
def visualize_with_autoviz(df, max_rows=5000, save_dir=VISUALIZATION_DIR):
    """Generate AutoViz charts for (a sample of) a Spark DataFrame.

    At most ``max_rows`` rows are converted to pandas and passed to AutoViz,
    which is asked for HTML charts saved under ``save_dir``. All errors are
    caught and printed, so visualization never aborts the pipeline.

    Cache handling: ``DataFrame.cache()`` returns the same object, so
    unpersisting ``df`` here would also drop a cache the caller created. This
    function therefore only unpersists a cache that it created itself; a
    DataFrame passed in already cached is still cached on return.

    Args:
        df: Spark DataFrame to visualize.
        max_rows: Upper bound on the number of rows handed to AutoViz.
        save_dir: Directory for the generated charts; created if missing.

    Returns:
        None.
    """
    if not AUTOVIZ_AVAILABLE:
        print("AutoViz not available. Skipping visualization.")
        return

    cached_here = False  # True only while a cache created by this call is live
    try:
        print("\nStarting AutoViz data visualization...")
        os.makedirs(save_dir, exist_ok=True)
        if not df.is_cached:
            # count() and the sampling below would otherwise recompute df.
            df = df.cache()
            cached_here = True
        df_count = df.count()
        if df_count == 0:
            print("No data to visualize.")
            return

        if df_count > max_rows:
            print(f"Sampling {max_rows} rows for AutoViz.")
            sample_df = df.sample(False, max_rows / df_count, seed=42)
        else:
            sample_df = df

        # sample() is approximate, so limit() enforces the hard row cap.
        pandas_df = sample_df.limit(max_rows).toPandas()
        if cached_here:
            # The data now lives in pandas; release our cache before plotting.
            df.unpersist()
            cached_here = False

        if pandas_df.empty:
            print("Sampled DataFrame is empty, skipping AutoViz.")
            return

        AV = AutoViz_Class()
        print("Setting backend for inline display...")
        try:
            # Only has an effect when running under IPython/Jupyter.
            from IPython import get_ipython
            ipython = get_ipython()
            if ipython:
                ipython.run_line_magic('matplotlib', 'inline')
        except Exception as magic_e:
            print(f"WARN: Could not set matplotlib inline: {magic_e}")

        print("Generating AutoViz charts (verbose=1)...")
        AV.AutoViz(
            "", dfte=pandas_df, depVar="", header=0, verbose=1,
            lowess=False, chart_format="html",
            max_rows_analyzed=max_rows,
            save_plot_dir=save_dir
        )
        print(f"\nAutoViz visualizations saved to: {save_dir}")

    except Exception as e:
        print(f"Error during AutoViz: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Reached with cached_here still True only if we failed or returned
        # before the early release above.
        if cached_here and df.is_cached:
            df.unpersist()

def preprocess_filter_and_visualize():
    """Read raw complaints back from Kafka, clean them, and visualize the result.

    Steps: batch-read ``KAFKA_TOPIC_RAW`` from the earliest offset, parse the
    JSON values with the full schema, drop rows without a "Complaint ID",
    de-duplicate by "Complaint ID", drop rows whose company response is
    "In progress" or whose narrative is null/blank, then run AutoViz.

    Returns:
        The filtered DataFrame, or None if Kafka returned no messages,
        reading/parsing failed, or no rows survived the filters.
    """
    print("\n--- Task 2: Preprocessing, Filtering & Visualization ---")
    print(f"Reading data from Kafka topic: {KAFKA_TOPIC_RAW}")
    full_schema = get_full_schema()
    print("Waiting 5s before reading from Kafka...")
    time.sleep(5)
    try:
        kafka_reader = (
            spark.read
            .format("kafka")
            .option("kafka.bootstrap.servers", KAFKA_BROKERS)
            .option("subscribe", KAFKA_TOPIC_RAW)
            .option("startingOffsets", "earliest")
            .option("failOnDataLoss", "false")
        )
        kafka_raw_df = kafka_reader.load()
        kafka_read_count = kafka_raw_df.count()
        print(f"Read {kafka_read_count} raw messages from Kafka.")
        if kafka_read_count == 0:
            return None

        # Kafka values arrive as bytes: cast to string, parse, then flatten.
        payload = F.from_json(F.col("value").cast("string"), full_schema).alias("data")
        parsed_df = kafka_raw_df.select(payload).select("data.*").na.drop(subset=["Complaint ID"])
        dedup_df = parsed_df.dropDuplicates(["Complaint ID"]).cache()
        print(f"Records after JSON parsing and deduplication: {dedup_df.count()}")

    except Exception as e:
        print(f"ERROR reading or parsing from Kafka topic {KAFKA_TOPIC_RAW}: {e}")
        import traceback; traceback.print_exc()
        return None

    # Drop complaints still in progress, then re-apply the narrative filter
    # (already applied in Task 1; repeated here for consistency).
    narrative = F.col("Consumer complaint narrative")
    not_in_progress = F.col("Company response to consumer") != "In progress"
    has_narrative = narrative.isNotNull() & (F.length(F.trim(narrative)) > 0)
    filtered_df = dedup_df.filter(not_in_progress)
    filtered_df = filtered_df.filter(has_narrative).cache()

    filtered_count = filtered_df.count()
    print(f"Records after filtering (response not 'In progress'): {filtered_count}")
    dedup_df.unpersist() # The filtered result is materialized; release the input

    if filtered_count == 0:
        print("WARNING: No records left after filtering.")
        filtered_df.unpersist()
        return None

    # NOTE(review): whether filtered_df is still cached on return depends on
    # how visualize_with_autoviz treats the cache — confirm.
    visualize_with_autoviz(filtered_df)
    return filtered_df

# --- Task 3: Split, Label, Prepare Data ---
# Splits 80/20, labels the training split, publishes it to Kafka, and saves the test split as Parquet.
def split_label_and_prepare_data(filtered_df, seed_value=42):
    """Split the filtered data 80/20, label the training part, and persist both.

    The training split gets an ``is_target_complaint`` column: 1 when the
    consumer did not dispute, the response was timely, and the company
    response is one of the four "Closed..." values; otherwise 0. The labeled
    training split is written to ``KAFKA_TOPIC_TRAINING``; seven columns of
    the test split are saved as Parquet. Write failures are printed and do
    not stop the function.

    Args:
        filtered_df: Output of Task 2, or None.
        seed_value: Seed passed to ``randomSplit``.

    Returns:
        ``(training_labeled_df, test_base_df)``, or ``(None, None)`` when
        ``filtered_df`` is None. The training DataFrame is left cached.
    """
    print("\n--- Task 3: Data Splitting & Labeling ---")
    if filtered_df is None:
        return None, None
    df_for_split = filtered_df # May already be cached by the caller

    training_base_df, test_base_df = df_for_split.randomSplit([0.8, 0.2], seed=seed_value)

    target_responses = (
        'Closed with explanation', 'Closed',
        'Closed with monetary relief', 'Closed with non-monetary relief',
    )
    is_target = (
        (F.col("Consumer disputed?") == 'No')
        & (F.col("Timely response?") == 'Yes')
        & (F.col("Company response to consumer").isin(*target_responses))
    )
    label_column = F.when(is_target, 1).otherwise(0)
    # Cached because the labeled split is reused after this function returns.
    training_labeled_df = training_base_df.withColumn("is_target_complaint", label_column).cache()

    print("Target distribution in training data:")
    training_labeled_df.groupBy("is_target_complaint").count().show()

    try:
        print(f"Writing training data to Kafka topic: {KAFKA_TOPIC_TRAINING}")
        training_kafka_df = training_labeled_df.selectExpr("`Complaint ID` AS key", "to_json(struct(*)) AS value")
        training_writer = (
            training_kafka_df.write
            .format("kafka")
            .option("kafka.bootstrap.servers", KAFKA_BROKERS)
            .option("topic", KAFKA_TOPIC_TRAINING)
        )
        training_writer.save()
        print(f"Training data sent to Kafka topic: {KAFKA_TOPIC_TRAINING}")
    except Exception as e:
        print(f"ERROR writing training data to Kafka: {e}")

    try:
        print(f"Saving test data to Parquet: {TEST_DATA_PERSISTENCE_PATH}")
        # Only the columns needed for the inference stream are kept.
        test_cols = [
            "Date received", "Complaint ID", "Company", "State",
            "ZIP code", "Submitted via", "Consumer complaint narrative",
        ]
        test_writer = test_base_df.select(*test_cols).write.format("parquet").mode("overwrite")
        test_writer.save(TEST_DATA_PERSISTENCE_PATH)
        print(f"Test data saved to: {TEST_DATA_PERSISTENCE_PATH}")
    except Exception as e:
        print(f"ERROR saving test data to Parquet: {e}")

    # Release the input if it was cached; the labeled training split stays cached.
    if df_for_split.is_cached:
        df_for_split.unpersist()
    return training_labeled_df, test_base_df


# --- Task 4: Distributed BERT Training ---
# Use the version with the outer try/except and TorchDistributor fallback
def train_bert_model_distributed():
    """Stage training data on the driver and fine-tune DistilBERT via TorchDistributor.

    Driver-side steps:
      1. Batch-read the labeled records from ``KAFKA_TOPIC_TRAINING`` and
         de-duplicate them on ``Complaint ID``.
      2. When one class makes up less than 10% of the rows, undersample the
         majority class to at most ~7x the minority class.
      3. Split 90/10 into train/validation, cap each split at
         ``TRAIN_SAMPLE_LIMIT`` / ``VAL_SAMPLE_LIMIT`` and write both to
         Parquet (``TRAIN_PARQUET_PATH`` / ``VAL_PARQUET_PATH``).
      4. Run ``train_function`` through ``TorchDistributor`` in local mode,
         falling back to a single CPU process when the driver GPU check fails.

    Returns:
        tuple: ``(saved_model_path, tokenizer)`` on success; ``(None, None)``
        when no data is available or any step raises (the error and traceback
        are printed, not propagated).
    """
    print("\n--- Task 4: Distributed BERT Model Training ---")

    def train_function():
        """Worker entry point: train the classifier and checkpoint the best epoch.

        Reads the Parquet files written by the driver, trains for
        ``NUM_EPOCHS`` epochs, evaluates on the validation set after each
        epoch and, on rank 0, saves the model whenever validation F1 improves.

        Returns:
            ``EMBEDDING_MODEL_SAVE_PATH``, or ``None`` if reading the data or
            initializing the model/tokenizer failed.
        """
        # Imports needed on worker
        # contextlib is imported here like every other worker-side dependency so
        # nullcontext() below does not rely on a driver-side module-level import.
        import contextlib
        import os

        import mlflow
        import numpy as np  # Needed for metrics
        import pandas as pd
        import torch
        import torch.distributed as dist
        import torch.nn as nn
        from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
        from torch.optim import AdamW
        from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, DistributedSampler
        from transformers import DistilBertTokenizer, DistilBertModel, get_linear_schedule_with_warmup

        use_gpu = torch.cuda.is_available()
        # NOTE(review): this function never calls dist.init_process_group(), so
        # is_distributed is only True if a process group was initialised elsewhere
        # before this point - confirm for multi-process runs.
        is_distributed = dist.is_available() and dist.is_initialized()
        rank = dist.get_rank() if is_distributed else 0
        world_size = dist.get_world_size() if is_distributed else 1
        local_rank = rank % torch.cuda.device_count() if use_gpu else 0
        device = torch.device(f"cuda:{local_rank}" if use_gpu else "cpu")
        if use_gpu:
            torch.cuda.set_device(device)
        is_rank_0 = (rank == 0)
        if is_rank_0:
            print(f"Worker {rank}/{world_size}: Starting train_function. Device: {device}")

        if is_rank_0:
            print(f"Worker {rank}: Reading data from Parquet...")
        try:
            train_pd = pd.read_parquet(TRAIN_PARQUET_PATH)
            val_pd = pd.read_parquet(VAL_PARQUET_PATH)
            if is_rank_0:
                print(f"Worker {rank}: Loaded {len(train_pd)} train, {len(val_pd)} val records.")
        except Exception as e:
            print(f"Worker {rank} ERROR reading Parquet: {e}")
            if is_distributed:
                dist.barrier()
            return None

        train_texts = train_pd["Consumer complaint narrative"].fillna("").astype(str).tolist()
        train_labels = train_pd["is_target_complaint"].tolist()
        val_texts = val_pd["Consumer complaint narrative"].fillna("").astype(str).tolist()
        val_labels = val_pd["is_target_complaint"].tolist()

        if is_rank_0:
            print(f"Worker {rank}: Initializing model and tokenizer...")
        try:
            # Rank 0 loads (and possibly downloads) the pretrained files first;
            # the other ranks wait at the barrier and load afterwards.
            if is_distributed and not is_rank_0:
                dist.barrier()
            tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
            base_model = DistilBertModel.from_pretrained('distilbert-base-uncased')
            if is_distributed and is_rank_0:
                dist.barrier()
            classifier = EnhancedDistilBERTClassifier(base_model, dropout_rate=0.3)
            classifier.to(device)
        except Exception as e:
            print(f"Worker {rank}: ERROR initializing model/tokenizer: {e}")
            if is_distributed:
                dist.barrier()
            return None

        train_dataset = ComplaintDataset(train_texts, train_labels, tokenizer, BERT_MAX_LENGTH)
        val_dataset = ComplaintDataset(val_texts, val_labels, tokenizer, BERT_MAX_LENGTH)
        if is_distributed:
            train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
        else:
            train_sampler = RandomSampler(train_dataset)
        # Validation is not sharded: every rank iterates the full validation set.
        val_sampler = SequentialSampler(val_dataset)
        train_dataloader = DataLoader(
            train_dataset, sampler=train_sampler, batch_size=BERT_BATCH_SIZE,
            num_workers=2, pin_memory=use_gpu,
        )
        val_dataloader = DataLoader(
            val_dataset, sampler=val_sampler, batch_size=BERT_BATCH_SIZE * 2,
            num_workers=2, pin_memory=use_gpu,
        )

        if is_distributed:
            classifier = torch.nn.parallel.DistributedDataParallel(
                classifier, device_ids=[local_rank] if use_gpu else None, find_unused_parameters=False
            )

        optimizer = AdamW(classifier.parameters(), lr=3e-5)
        num_training_steps = len(train_dataloader) * NUM_EPOCHS
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(0.1 * num_training_steps),
            num_training_steps=num_training_steps,
        )

        # Inverse-frequency class weights, computed on rank 0 and broadcast.
        class_weights_tensor = torch.tensor([1.0, 1.0], dtype=torch.float)
        if is_rank_0:
            train_labels_tensor = torch.tensor(train_labels)
            total = len(train_labels_tensor)
            pos_count = torch.sum(train_labels_tensor == 1).item()
            neg_count = total - pos_count
            if pos_count > 0 and neg_count > 0:
                weight_for_0 = total / (2.0 * neg_count)
                weight_for_1 = total / (2.0 * pos_count)
                class_weights_tensor = torch.tensor([weight_for_0, weight_for_1], dtype=torch.float)
            print(f"Rank 0 calculated class weights: {class_weights_tensor}")

        if is_distributed:
            class_weights_tensor = class_weights_tensor.to(device)
            dist.broadcast(class_weights_tensor, src=0)
        class_weights = class_weights_tensor.to(device)
        if is_rank_0 or not is_distributed:
            print(f"Effective class weights on device {device}: {class_weights}")
        criterion = nn.CrossEntropyLoss(weight=class_weights)

        # NOTE(review): starting at 0.0 with the strict '>' below means an epoch
        # whose validation F1 is exactly 0 is never checkpointed.
        best_val_f1 = 0.0
        mlflow_run_id = None
        # Only rank 0 opens an MLflow run; other ranks get a no-op context (run is None).
        run_context = (
            mlflow.start_run(run_name="bert_classifier_dist", nested=True)
            if is_rank_0 else contextlib.nullcontext()
        )
        with run_context as run:
            if is_rank_0 and run:
                mlflow_run_id = run.info.run_id
                mlflow.log_params({  # Log multiple params
                    "learning_rate": 3e-5, "batch_size_per_worker": BERT_BATCH_SIZE, "world_size": world_size,
                    "total_batch_size": BERT_BATCH_SIZE * world_size, "epochs": NUM_EPOCHS, "bert_model": "distilbert-base-uncased",
                    "data_loading": "Parquet_Worker_Read", "num_train_samples_in_parquet": len(train_texts),
                    "num_val_samples_in_parquet": len(val_texts), "train_sample_limit": TRAIN_SAMPLE_LIMIT,
                    "val_sample_limit": VAL_SAMPLE_LIMIT
                })
            for epoch in range(NUM_EPOCHS):
                if is_rank_0:
                    print(f"\n--- Epoch {epoch+1}/{NUM_EPOCHS} ---")
                if isinstance(train_sampler, DistributedSampler):
                    train_sampler.set_epoch(epoch)

                # ---- Training pass ----
                classifier.train()
                total_train_loss = 0.0
                train_steps = 0
                for step, batch in enumerate(train_dataloader):
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    labels = batch['labels'].to(device)
                    optimizer.zero_grad()
                    logits = classifier(input_ids=input_ids, attention_mask=attention_mask)
                    loss = criterion(logits, labels)
                    if torch.isnan(loss):
                        # NOTE(review): skipping backward() on a single rank can
                        # desynchronise DDP gradient reduction - confirm if
                        # multi-process training is enabled.
                        print(f"Worker {rank}: NaN loss at step {step}! Skipping.")
                        continue
                    loss.backward()
                    optimizer.step()
                    scheduler.step()
                    total_train_loss += loss.item()
                    train_steps += 1
                    if is_rank_0 and step % 50 == 0 and step > 0:
                        print(f"  Epoch {epoch+1}, Step {step}/{len(train_dataloader)}, Batch Loss: {loss.item():.4f}")

                avg_train_loss = total_train_loss / train_steps if train_steps > 0 else 0.0
                if is_distributed:
                    # ReduceOp.AVG is only supported by the NCCL backend; SUM followed
                    # by a division by world_size works on every backend (e.g. gloo on CPU).
                    loss_tensor = torch.tensor(avg_train_loss, device=device)
                    dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
                    avg_train_loss = loss_tensor.item() / world_size
                if is_rank_0:
                    print(f"Epoch {epoch+1} Avg Training Loss: {avg_train_loss:.4f}")
                    if mlflow_run_id:
                        mlflow.log_metric("train_loss", avg_train_loss, step=epoch)

                # ---- Validation pass ----
                classifier.eval()
                all_val_preds = []
                all_val_labels = []
                total_val_loss = 0.0
                val_steps = 0
                with torch.no_grad():
                    for batch in val_dataloader:
                        input_ids = batch['input_ids'].to(device)
                        attention_mask = batch['attention_mask'].to(device)
                        labels = batch['labels'].to(device)
                        logits = classifier(input_ids=input_ids, attention_mask=attention_mask)
                        loss = criterion(logits, labels)
                        if not torch.isnan(loss):
                            total_val_loss += loss.item()
                            val_steps += 1
                        else:
                            print(f"Worker {rank}: NaN validation loss! Skipping.")
                        preds = torch.argmax(logits, dim=1)
                        all_val_preds.extend(preds.cpu().numpy())
                        all_val_labels.extend(labels.cpu().numpy())

                avg_val_loss = 0.0
                if is_distributed:
                    val_loss_sum_tensor = torch.tensor(total_val_loss, device=device)
                    val_steps_tensor = torch.tensor(val_steps, device=device)
                    dist.all_reduce(val_loss_sum_tensor, op=dist.ReduceOp.SUM)
                    dist.all_reduce(val_steps_tensor, op=dist.ReduceOp.SUM)
                    total_val_loss_agg = val_loss_sum_tensor.item()
                    total_val_steps_agg = val_steps_tensor.item()
                    avg_val_loss = total_val_loss_agg / total_val_steps_agg if total_val_steps_agg > 0 else 0.0
                else:
                    avg_val_loss = total_val_loss / val_steps if val_steps > 0 else 0.0

                if is_rank_0:
                    np_val_labels = np.array(all_val_labels)
                    np_val_preds = np.array(all_val_preds)
                    val_accuracy = accuracy_score(np_val_labels, np_val_preds)
                    val_precision = precision_score(np_val_labels, np_val_preds, average='binary', zero_division=0)
                    val_recall = recall_score(np_val_labels, np_val_preds, average='binary', zero_division=0)
                    val_f1 = f1_score(np_val_labels, np_val_preds, average='binary', zero_division=0)
                    print(f"Epoch {epoch+1} Avg Validation Loss: {avg_val_loss:.4f}")
                    print(f"Validation Metrics (Rank 0): Acc: {val_accuracy:.4f}, P: {val_precision:.4f}, R: {val_recall:.4f}, F1: {val_f1:.4f}")
                    if mlflow_run_id:
                        mlflow.log_metrics({"val_loss": avg_val_loss, "val_accuracy": val_accuracy, "val_precision": val_precision, "val_recall": val_recall, "val_f1": val_f1}, step=epoch)

                    if val_f1 > best_val_f1:
                        best_val_f1 = val_f1
                        print(f"  >>> New best F1: {val_f1:.4f}. Saving model to {EMBEDDING_MODEL_SAVE_PATH}...")
                        os.makedirs(EMBEDDING_MODEL_SAVE_PATH, exist_ok=True)
                        # Unwrap DDP so the saved state_dict keys carry no 'module.' prefix.
                        model_to_save = classifier.module if is_distributed else classifier
                        try:
                            torch.save(model_to_save.state_dict(), f"{EMBEDDING_MODEL_SAVE_PATH}/classifier.pt")
                            model_to_save.bert.save_pretrained(EMBEDDING_MODEL_SAVE_PATH)
                            tokenizer.save_pretrained(EMBEDDING_MODEL_SAVE_PATH)
                            print(f"  >>> Model components saved successfully.")
                            if mlflow_run_id:
                                mlflow.log_metric("best_val_f1", best_val_f1, step=epoch)
                        except Exception as save_e:
                            print(f"  >>> ERROR saving model: {save_e}")
                # Keep ranks in step while rank 0 evaluates and saves.
                if is_distributed:
                    dist.barrier()
            if is_rank_0:
                print(f"\nTraining complete. Best Validation F1 across epochs: {best_val_f1:.4f}")
                if mlflow_run_id:
                    mlflow.log_metric("final_best_f1", best_val_f1)
        return EMBEDDING_MODEL_SAVE_PATH

    # --- Driver-side setup ---
    try:
        print("Reading training data from Kafka for preprocessing...")
        full_schema_with_target = get_full_schema().add(StructField("is_target_complaint", IntegerType(), True))
        kafka_df = (
            spark.read.format("kafka")
            .option("kafka.bootstrap.servers", KAFKA_BROKERS)
            .option("subscribe", KAFKA_TOPIC_TRAINING)
            .option("startingOffsets", "earliest")
            .option("failOnDataLoss", "false")
            .load()
        )
        kafka_read_count = kafka_df.count()
        print(f"Read {kafka_read_count} raw training messages from Kafka.")
        if kafka_read_count == 0:
            return None, None
        training_df = (
            kafka_df
            .select(F.from_json(F.col("value").cast("string"), full_schema_with_target).alias("data"))
            .select("data.*")
            .na.drop(subset=["Complaint ID"])
            .dropDuplicates(["Complaint ID"])
        )
        print(f"Parsed {training_df.count()} unique training records.")

        pos_df = training_df.filter(F.col("is_target_complaint") == 1)
        neg_df = training_df.filter(F.col("is_target_complaint") == 0)
        pos_count = pos_df.count()
        neg_count = neg_df.count()
        print(f"Training data counts: Positive={pos_count}, Negative={neg_count}")
        ratio = pos_count / (pos_count + neg_count) if (pos_count + neg_count) > 0 else 0.5
        print(f"Positive class ratio: {ratio:.2f}")
        if ratio > 0 and ratio < 0.1 and neg_count > pos_count:
            print("Balancing dataset by undersampling negative class...")
            target_neg_fraction = min(1.0, (pos_count * 7.0) / neg_count) if neg_count > 0 else 1.0
            print(f"Target negative fraction: {target_neg_fraction:.2f}")
            neg_sample_df = neg_df.sample(False, target_neg_fraction, seed=42)
            balanced_df = pos_df.unionByName(neg_sample_df)
            # Reuse pos_count instead of launching another Spark job for the same number.
            print(f"Balanced dataset size: {balanced_df.count()} (+ve={pos_count}, -ve={neg_sample_df.count()})")
        elif ratio > 0.9 and pos_count > neg_count:
            print("Balancing dataset by undersampling positive class...")
            target_pos_fraction = min(1.0, (neg_count * 7.0) / pos_count) if pos_count > 0 else 1.0
            print(f"Target positive fraction: {target_pos_fraction:.2f}")
            pos_sample_df = pos_df.sample(False, target_pos_fraction, seed=42)
            balanced_df = neg_df.unionByName(pos_sample_df)
            # Reuse neg_count instead of launching another Spark job for the same number.
            print(f"Balanced dataset size: {balanced_df.count()} (+ve={pos_sample_df.count()}, -ve={neg_count})")
        else:
            print("Dataset already reasonably balanced or cannot balance further.")
            balanced_df = training_df
        balanced_df = balanced_df.cache()
        balanced_count = balanced_df.count()
        if balanced_count == 0:
            print("ERROR: Balanced DataFrame empty.")
            return None, None

        train_spark_df, val_spark_df = balanced_df.randomSplit([0.9, 0.1], seed=42)
        print(f"Split sizes: Train={train_spark_df.count()}, Val={val_spark_df.count()}")
        final_train_df = train_spark_df.limit(TRAIN_SAMPLE_LIMIT).cache()
        final_val_df = val_spark_df.limit(VAL_SAMPLE_LIMIT).cache()
        final_train_count = final_train_df.count()
        final_val_count = final_val_df.count()
        print(f"Final sampled sizes for Parquet: Train={final_train_count}, Val={final_val_count}")
        if final_train_count == 0 or final_val_count == 0:
            print("ERROR: Not enough data after sampling.")
            return None, None

        # Workers read these Parquet files directly; only the text and label columns are needed.
        print(f"Saving sampled training data to Parquet: {TRAIN_PARQUET_PATH}")
        final_train_df.select("Consumer complaint narrative", "is_target_complaint").write.mode("overwrite").parquet(TRAIN_PARQUET_PATH)
        print(f"Saving sampled validation data to Parquet: {VAL_PARQUET_PATH}")
        final_val_df.select("Consumer complaint narrative", "is_target_complaint").write.mode("overwrite").parquet(VAL_PARQUET_PATH)
        balanced_df.unpersist()
        final_train_df.unpersist()
        final_val_df.unpersist()

        print("Starting TorchDistributor...")
        if torch.cuda.is_available():
            num_processes = torch.cuda.device_count()
            use_gpu_dist_flag = True
            print(f"Torch reports GPU available. num_processes={num_processes}, use_gpu=True.")
        else:
            num_processes = 1
            use_gpu_dist_flag = False
            print("Torch reports NO GPU. num_processes=1, use_gpu=False.")
        print("NOTE: Each worker process loads full Parquet subset into memory.")
        distributor = None
        try:
            distributor = TorchDistributor(num_processes=num_processes, local_mode=True, use_gpu=use_gpu_dist_flag, _ssl_conf=None)
            print("TorchDistributor initialized.")
        except RuntimeError as e:
            # Torch can see a GPU that TorchDistributor's own driver check rejects;
            # in that case retry as a single CPU process instead of failing.
            if "GPUs were unable to be found on the driver" in str(e):
                print("WARN: TorchDistributor driver GPU check failed. Forcing CPU mode.")
                num_processes = 1
                use_gpu_dist_flag = False
                distributor = TorchDistributor(num_processes=num_processes, local_mode=True, use_gpu=use_gpu_dist_flag, _ssl_conf=None)
                print("TorchDistributor initialized in CPU fallback.")
            else:
                print(f"ERROR initializing TorchDistributor: {e}")
                raise
        except Exception as e_init:
            print(f"UNEXPECTED ERROR initializing TorchDistributor: {e_init}")
            raise

        saved_model_path = distributor.run(train_function)
        # config.json is written by save_pretrained(); its absence means no checkpoint was saved.
        if saved_model_path is None or not os.path.exists(os.path.join(saved_model_path, "config.json")):
            raise RuntimeError("Model training execution failed or model not saved.")
        print(f"Training finished. Model components expected in: {saved_model_path}")
        tokenizer = DistilBertTokenizer.from_pretrained(saved_model_path)
        return saved_model_path, tokenizer
    except Exception as e:
        print(f"ERROR during training setup or distribution: {e}")
        import traceback
        traceback.print_exc()
        return None, None

# --- Task 5: Unified Feature Pipeline ---
def create_unified_pipeline():
    """Build the unfitted Spark ML pipeline that produces the ``features`` vector.

    Stages, in order:
      1. ``SQLTransformer`` that keeps all input columns and adds a
         ``<col>_imputed`` string column (NULL replaced by ``'Unknown'``) for
         Company, State, Submitted via and ZIP code.
      2. ``BERTEmbeddingTransformer`` mapping ``Consumer complaint narrative``
         to ``narrative_features`` using the model at ``EMBEDDING_MODEL_SAVE_PATH``.
      3. A ``StringIndexer`` + ``OneHotEncoder`` pair per imputed column.
      4. ``VectorAssembler`` concatenating ``narrative_features`` (first) and
         the one-hot columns into ``features``.

    Returns:
        Pipeline: the assembled, not yet fitted, pipeline.
    """
    print("\n--- Task 5: Creating Unified Feature Pipeline ---")
    categorical_columns = ["Company", "State", "Submitted via"]
    zip_col = "ZIP code"
    columns_to_encode = categorical_columns + [zip_col]
    stages = []
    imputed_cols_map = {}  # Maps original column name -> imputed column name

    # Impute NULLs with SQLTransformer before indexing.
    impute_select_exprs = ["*"]  # Start with all existing columns
    for col_name in columns_to_encode:
        imputed_col_name = f"{col_name}_imputed"
        # Backticks are required: "Submitted via" and "ZIP code" contain spaces,
        # and an unquoted identifier with a space does not parse in Spark SQL.
        impute_select_exprs.append(
            f"COALESCE(CAST(`{col_name}` AS STRING), 'Unknown') AS `{imputed_col_name}`"
        )
        imputed_cols_map[col_name] = imputed_col_name

    impute_sql = f"SELECT {', '.join(impute_select_exprs)} FROM __THIS__"
    stages.append(SQLTransformer(statement=impute_sql))

    # BERT embedding transformer
    stages.append(BERTEmbeddingTransformer(
        inputCol="Consumer complaint narrative", outputCol="narrative_features", modelPath=EMBEDDING_MODEL_SAVE_PATH
    ))

    # Index and one-hot encode the imputed columns.
    encoded_columns = []
    for col_name in columns_to_encode:
        indexer_output = f"{col_name}_indexed"
        encoder_output = f"{col_name}_encoded"
        # handleInvalid="keep" gives labels unseen during fit their own index
        # instead of raising at transform time.
        stages.append(StringIndexer(inputCol=imputed_cols_map[col_name], outputCol=indexer_output, handleInvalid="keep"))
        stages.append(OneHotEncoder(inputCol=indexer_output, outputCol=encoder_output, dropLast=False))
        encoded_columns.append(encoder_output)

    # Feature assembly: narrative_features must stay first because the
    # prediction step slices the leading embedding dimensions off ``features``.
    feature_columns = ["narrative_features"] + encoded_columns
    stages.append(VectorAssembler(inputCols=feature_columns, outputCol="features", handleInvalid="keep"))

    pipeline = Pipeline(stages=stages)
    print(f"Pipeline created with stages: {[type(s).__name__ for s in stages]}")
    return pipeline


# --- Task 6: Simulation Script ---
def simulate_test_data_to_kafka():
    """Replay the persisted test set into Kafka at a throttled rate.

    Loads the Parquet file at ``TEST_DATA_PERSISTENCE_PATH``, then sends every
    row as a JSON message (keyed by ``Complaint ID``) to
    ``KAFKA_TOPIC_TESTING_STREAM`` in batches of ``BATCH_SIZE``, sleeping
    between batches so the overall rate approximates ``MESSAGES_PER_MINUTE``.

    Returns:
        int: number of messages handed to the producer; 0 when the Parquet
        path is missing, cannot be read, is empty, or no broker is reachable.
    """
    print("\n--- Task 6: Test Data Simulation ---")
    if not os.path.exists(TEST_DATA_PERSISTENCE_PATH):
        return 0

    print(f"Loading test data from: {TEST_DATA_PERSISTENCE_PATH}")
    try:
        test_frame = pd.read_parquet(TEST_DATA_PERSISTENCE_PATH)
        print(f"Loaded {len(test_frame)} test records.")
    except Exception as load_err:
        print(f"ERROR loading test data: {load_err}")
        return 0

    records = test_frame.to_dict('records')
    total_messages = len(records)
    if total_messages == 0:
        print("No test messages.")
        return 0

    def _encode_value(payload):
        return json.dumps(payload).encode('utf-8')

    def _encode_key(raw_key):
        return str(raw_key).encode('utf-8')

    producer = None
    try:
        producer = KafkaProducer(
            bootstrap_servers=KAFKA_BROKERS.split(','),
            value_serializer=_encode_value,
            key_serializer=_encode_key,
            batch_size=16384,
            linger_ms=10,
            retries=3,
        )
        print(f"Kafka Producer connected for simulation.")
    except kafka_errors.NoBrokersAvailable:
        print(f"ERROR: Simulation Kafka Producer failed connect.")
        return 0

    # Seconds each batch should take so the overall rate matches the target.
    if MESSAGES_PER_MINUTE > 0:
        delay = (60.0 / MESSAGES_PER_MINUTE) * BATCH_SIZE
    else:
        delay = 0
    print(f"Starting simulation: {total_messages} msgs @ ~{MESSAGES_PER_MINUTE}/min (Batch: {BATCH_SIZE}, Delay: {delay:.2f}s)")

    start_time = time.time()
    messages_sent = 0
    try:
        for offset in range(0, total_messages, BATCH_SIZE):
            batch_started_at = time.time()
            upper = min(offset + BATCH_SIZE, total_messages)
            for record in records[offset:upper]:
                producer.send(
                    KAFKA_TOPIC_TESTING_STREAM,
                    key=record.get("Complaint ID", str(messages_sent)),
                    value=record,
                )
                messages_sent += 1
            producer.flush()

            # Sleep off whatever is left of this batch's time budget.
            remaining = max(0, delay - (time.time() - batch_started_at))
            if remaining > 0:
                time.sleep(remaining)

            # Log less often: every 10 full batches, plus once at the end.
            at_log_point = messages_sent % (BATCH_SIZE * 10) == 0
            if at_log_point or messages_sent == total_messages:
                elapsed = time.time() - start_time
                rate = (messages_sent / elapsed * 60) if elapsed > 0 else 0
                print(f"  Sim Progress: {messages_sent}/{total_messages} ({messages_sent/total_messages*100:.1f}%) @ {rate:.1f} msgs/min")
    except KeyboardInterrupt:
        print("\nSim interrupted.")
    except Exception as sim_err:
        print(f"\nERROR during sim: {sim_err}")
    finally:
        if producer:
            producer.close()
        elapsed = time.time() - start_time
        rate = (messages_sent / elapsed * 60) if elapsed > 0 else 0
        print("\nSim Summary:")
        print(f"- Sent: {messages_sent}/{total_messages}")
        print(f"- Time: {elapsed:.2f}s")
        print(f"- Rate: {rate:.1f} msg/min")
    return messages_sent

# --- Task 7: Streaming Inference Job ---
@F.pandas_udf(DoubleType())
def predict_udf(features_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Pandas UDF that classifies assembled feature vectors with the saved classifier.

    The classifier weights are loaded once per iterator, on CPU, from
    ``EMBEDDING_MODEL_SAVE_PATH/classifier.pt``. For every incoming batch the
    first ``embedding_dim`` entries of each feature vector are passed to the
    classifier as embeddings and the argmax class is returned as a float.

    Fallbacks:
      * rows whose vector is None, or whose length differs from the first
        valid vector in the batch, get 0.0;
      * a batch that raises during prediction gets 0.0 for every row;
      * if the model could not be loaded, predictions are random draws from
        ``binomial(1, 0.3)``.

    ``EnhancedDistilBERTClassifier`` is expected to be defined at module level.
    """
    import os
    from collections import OrderedDict

    import numpy as np
    import pandas as pd
    import torch
    from transformers import DistilBertModel

    model_path = EMBEDDING_MODEL_SAVE_PATH
    classifier_state_path = os.path.join(model_path, "classifier.pt")
    device = torch.device("cpu")
    classifier_head = None
    embedding_dim = 768

    try:
        if os.path.exists(classifier_state_path) and os.path.exists(model_path):
            bert_config = DistilBertModel.from_pretrained(model_path).config
            embedding_dim = bert_config.dim
            # Architecture-only BERT instance; real weights come from the state dict below.
            classifier_head = EnhancedDistilBERTClassifier(DistilBertModel(bert_config), dropout_rate=0.3)
            state_dict = torch.load(classifier_state_path, map_location=device)
            # Strip the 'module.' prefix if the checkpoint was saved from a DDP wrapper.
            is_ddp = any(key.startswith('module.') for key in state_dict.keys())
            cleaned_state = OrderedDict(
                (key[7:] if is_ddp else key, tensor) for key, tensor in state_dict.items()
            )
            classifier_head.load_state_dict(cleaned_state)
            classifier_head.to(device)
            classifier_head.eval()
        else:
            print(f"Worker WARN: Model/State not found ('{model_path}', '{classifier_state_path}'). Default preds.")
    except Exception as e:
        print(f"Worker ERROR loading model: {e}. Default preds.")
        classifier_head = None

    def _score_batch(series):
        """Return one float prediction per row of *series* (0.0 where unusable)."""
        row_count = len(series)
        try:
            vectors = []
            kept_positions = []  # Row positions that contributed a usable vector
            expected_len = -1
            for position, raw in enumerate(series.tolist()):
                if raw is None:
                    continue
                vec = np.array(raw)
                current_len = len(vec)
                if expected_len == -1:
                    expected_len = current_len
                if current_len != expected_len:
                    print(f"WARN: Skip vector len {current_len} != exp {expected_len}")
                    continue
                vectors.append(vec)
                kept_positions.append(position)

            if not vectors:  # No valid vectors in batch
                return [0.0] * row_count

            batch_tensor = torch.tensor(np.stack(vectors), dtype=torch.float32).to(device)
            bert_embeddings = batch_tensor[:, :embedding_dim]  # Assumes correct order/dim
            with torch.no_grad():
                logits = classifier_head(embeddings=bert_embeddings)
                predictions = torch.argmax(logits, dim=1).cpu().numpy().astype(float)

            # Map predictions back to original row positions; skipped rows stay 0.0.
            scores = np.full(row_count, 0.0)
            for slot, position in enumerate(kept_positions):
                scores[position] = predictions[slot]
            return scores.tolist()
        except IndexError as slice_e:
            print(f"Worker ERROR slicing embed (dim {embedding_dim}): {slice_e}")
        except ValueError as stack_e:
            print(f"Worker ERROR stacking vectors: {stack_e}")
        except Exception as e:
            print(f"Worker ERROR during pred batch: {e}")
        return [0.0] * row_count

    for features_series in features_iterator:
        if features_series.empty:
            yield pd.Series([], dtype=float)
            continue
        if classifier_head is None:
            # Fallback when no model is available
            results = np.random.binomial(1, 0.3, len(features_series)).astype(float).tolist()
        else:
            results = _score_batch(features_series)
        yield pd.Series(results, dtype=float)

def run_streaming_inference_job(pipeline_model_path):
    """Start the streaming inference queries over the test-data Kafka topic.

    Loads the fitted pipeline from *pipeline_model_path*, registers a
    ``MetricsListener``, reads JSON messages from
    ``KAFKA_TOPIC_TESTING_STREAM``, applies the pipeline and ``predict_udf``,
    and starts two queries: one writing predictions to
    ``KAFKA_TOPIC_PREDICTIONS`` and one printing them to the console.

    Args:
        pipeline_model_path: location of the saved ``PipelineModel``.

    Returns:
        tuple: ``(kafka_query, console_query)``, or ``(None, None)`` if the
        pipeline model could not be loaded.
    """
    print("\n--- Task 7: Streaming Inference Job ---")
    try:
        print(f"Loading fitted pipeline model from: {pipeline_model_path}")
        pipeline_model = PipelineModel.load(pipeline_model_path)
        print("Pipeline model loaded.")
    except Exception as load_err:
        print(f"ERROR loading pipeline model: {load_err}")
        import traceback
        traceback.print_exc()
        return None, None

    spark.streams.addListener(MetricsListener(KAFKA_BROKERS, KAFKA_TOPIC_METRICS))
    stream_schema = get_streaming_schema()

    # --- Source: raw Kafka records ---
    print(f"Setting up streaming source from Kafka topic: {KAFKA_TOPIC_TESTING_STREAM}")
    source_reader = spark.readStream.format("kafka")
    source_options = (
        ("kafka.bootstrap.servers", KAFKA_BROKERS),
        ("subscribe", KAFKA_TOPIC_TESTING_STREAM),
        ("startingOffsets", "latest"),
        ("failOnDataLoss", "false"),
    )
    for option_name, option_value in source_options:
        source_reader = source_reader.option(option_name, option_value)
    kafka_stream = source_reader.load()

    # --- Parse JSON payloads and drop rows without a complaint id ---
    payload = F.from_json(F.col("value").cast("string"), stream_schema).alias("data")
    parsed_stream = kafka_stream.select(payload).select("data.*").na.drop(subset=["Complaint ID"])

    # --- Feature pipeline + prediction ---
    processed_stream = pipeline_model.transform(parsed_stream)
    with_features = processed_stream.filter(F.col("features").isNotNull())
    prediction_stream = with_features.withColumn("prediction", predict_udf(F.col("features")))

    output_columns = [
        F.col("Complaint ID").alias("complaint_id"),
        F.col("prediction"),
        F.col("State").alias("state"),
        F.col("ZIP code").alias("zip_code"),
        F.col("Submitted via").alias("submitted_via"),
        F.current_timestamp().alias("inference_time"),
    ]
    final_stream = prediction_stream.select(*output_columns).withColumn(
        "inference_time_str", F.date_format("inference_time", "yyyy-MM-dd HH:mm:ss")
    )

    # --- Sink 1: predictions back to Kafka ---
    kafka_output = final_stream.selectExpr("complaint_id AS key", "to_json(struct(*)) AS value")
    print(f"Starting streaming query to write predictions to Kafka topic: {KAFKA_TOPIC_PREDICTIONS}")
    kafka_writer = (
        kafka_output.writeStream.format("kafka")
        .option("kafka.bootstrap.servers", KAFKA_BROKERS)
        .option("topic", KAFKA_TOPIC_PREDICTIONS)
        .option("checkpointLocation", f"{STREAMING_CHECKPOINT_LOCATION}/predictions")
        .outputMode("append")
        .trigger(processingTime="10 seconds")
    )
    query = kafka_writer.start()

    # --- Sink 2: sample of predictions to the console ---
    console_writer = (
        final_stream.writeStream.format("console")
        .option("truncate", "false")
        .option("numRows", 5)
        .trigger(processingTime="15 seconds")
        .outputMode("append")
    )
    console_query = console_writer.start()

    print("Streaming queries started.")
    return query, console_query


# --- Main Execution ---
def main():
    """Run the end-to-end pipeline (Tasks 1-7) and always clean up afterwards.

    Steps, in order:
      1. Check that a Kafka broker is reachable; abort early if not.
      2. Load data to Kafka, preprocess/filter it, then split and label it.
      3. Train the BERT model, then fit and save the unified feature pipeline.
      4. Simulate the test data stream and start the streaming inference job.
      5. Poll until every tracked streaming query has stopped.

    Ctrl+C and any other exception are reported and mark the run as failed.
    The ``finally`` section always stops the tracked streaming queries, makes
    a best-effort attempt to close the metrics listener's producer, and
    prints the total run time together with the final status.

    Returns:
        None.
    """
    print("\n--- Starting Main Pipeline Execution ---")
    start_pipeline_time = time.time()
    pipeline_failed = False
    active_queries = []  # Streaming queries that must be stopped during cleanup

    try:
        print("Checking Kafka connection...")
        try:
            temp_producer = KafkaProducer(bootstrap_servers=KAFKA_BROKERS.split(','), request_timeout_ms=5000)
            temp_producer.close()
            print("Kafka connection test successful.")
        except kafka_errors.NoBrokersAvailable:
            print(f"FATAL ERROR: Cannot connect to Kafka at {KAFKA_BROKERS}. Check ZK/Kafka server status.")
            # Flag the failure so the final status line does not report success.
            pipeline_failed = True
            return

        # Task 1: Load data
        load_data_to_kafka()  # Just need data in Kafka

        # Task 2: Preprocess (reads from Kafka)
        filtered_df = preprocess_filter_and_visualize()
        if filtered_df is None:
            raise ValueError("Preprocessing failed or yielded no data.")

        # Task 3: Split & Label (writes train to Kafka, test to Parquet)
        # Keep the training DataFrame cached: it is reused to fit the pipeline.
        training_df_cached, _ = split_label_and_prepare_data(filtered_df)
        if training_df_cached is None:
            raise ValueError("Data splitting/labeling failed.")

        # Task 4: Train Model
        model_path, tokenizer = train_bert_model_distributed()
        if model_path is None or tokenizer is None:
            raise RuntimeError("Model training failed.")
        print(f"Model training completed. Model saved in: {model_path}")

        # Task 5: Create and Save Feature Pipeline
        print("Fitting the unified feature pipeline...")
        pipeline = create_unified_pipeline()
        fit_sample_df = training_df_cached.limit(100)  # Use cached labeled data for fitting
        pipeline_model = pipeline.fit(fit_sample_df)
        pipeline_model.write().overwrite().save(TRAINING_PIPELINE_SAVE_PATH)
        print(f"Fitted pipeline saved to: {TRAINING_PIPELINE_SAVE_PATH}")
        if training_df_cached.is_cached:
            training_df_cached.unpersist()  # Release the cache once fitting is done

        # Task 6: Simulate Test Data Stream
        messages_simulated = simulate_test_data_to_kafka()
        if messages_simulated == 0:
            print("WARNING: No test data simulated.")

        # Task 7: Start Streaming Inference Job
        # Pass the path where the fitted pipeline was saved above.
        query, console_query = run_streaming_inference_job(TRAINING_PIPELINE_SAVE_PATH)
        if query is None:
            raise RuntimeError("Failed to start streaming inference job.")
        active_queries.extend([query, console_query])  # Track queries for cleanup

        print("\n--- Pipeline Running ---")
        print("Streaming inference started.")
        print(">>> Press Ctrl+C in Colab cell to stop. <<<")
        # Keep main thread alive while streams run (alternative to awaitTermination)
        while any(q.isActive for q in active_queries):
            time.sleep(5)
        print("Streaming queries seem to have stopped.")

    except KeyboardInterrupt:
        print("\nKeyboardInterrupt received. Stopping pipeline...")
        pipeline_failed = True
    except Exception as e:
        print(f"\nFATAL ERROR in pipeline execution: {e}")
        import traceback
        traceback.print_exc()
        pipeline_failed = True
    finally:
        print("\n--- Cleaning up ---")
        stopped_count = 0
        for q in active_queries:
            if q and q.isActive:
                print(f"Stopping query '{q.name}' (id: {q.id})...")
                try:
                    q.stop()
                    q.awaitTermination(timeout=10)
                    stopped_count += 1
                    print("Stopped.")
                except Exception as stop_e:
                    print(f"Error stopping query '{q.name}': {stop_e}")
        print(f"Stopped {stopped_count} active streaming queries.")

        # Best-effort: close the metrics listener's producer.
        # NOTE(review): listListeners() does not appear to be part of the
        # documented PySpark StreamingQueryManager API (TODO confirm), so it is
        # looked up defensively. Any failure here is contained so the summary
        # below is still printed and the original error is not masked.
        try:
            list_listeners = getattr(spark.streams, "listListeners", None)
            registered_listeners = list_listeners() if callable(list_listeners) else []
            for listener in registered_listeners:
                if isinstance(listener, MetricsListener):
                    listener.close_producer()
                    break  # Assume only one metrics listener was added
        except Exception as listener_e:
            print(f"Error closing metrics listener producer: {listener_e}")

        # Optional: stop the Spark session here with spark.stop() if it is
        # no longer needed.

        end_pipeline_time = time.time()
        print(f"\nTotal Pipeline Execution Time: {end_pipeline_time - start_pipeline_time:.2f} seconds")
        status = "finished with errors or was interrupted" if pipeline_failed else "finished successfully (streaming stopped)"
        print(f"Pipeline {status}.")

# --- Run Main ---
# Start the pipeline only when this module is the entry point (__name__ is
# "__main__"); importing the module does not trigger main().
if __name__ == "__main__":
    main()