# In [None]:
import json
import pickle
import os
import re
import logging
import boto3
import zstandard as zstd
import numpy as np
from io import BytesIO
from sentence_transformers import SentenceTransformer

# Configure the root logger at INFO so informational handler messages are
# emitted (collected into CloudWatch by the hosting environment).
logger = logging.getLogger(None)
logger.setLevel("INFO")

# Process-wide state: model_fn populates these, predict_fn reads them.
sentence_model = None          # SentenceTransformer, assigned in model_fn
embeddings_cache = dict()      # material_type -> embeddings array
cluster_keywords_map = dict()  # int cluster id -> proposed keyword

def download_embeddings_from_s3(bucket, key):
    """
    Download a zstd-compressed ``.npy`` array from S3 and return it as float32.

    Args:
        bucket: S3 bucket name.
        key: Object key of the zstd-compressed ``.npy`` file.

    Returns:
        The loaded array cast to ``np.float32`` (no copy if it already is).

    Raises:
        Any S3, decompression, or ``np.load`` error, re-raised after logging
        it with its traceback.
    """
    try:
        s3_client = boto3.client('s3')
        response = s3_client.get_object(Bucket=bucket, Key=key)
        body = response['Body']
        try:
            compressed_data = body.read()
        finally:
            # Release the underlying HTTP connection even if the read fails.
            body.close()

        # decompressobj() also handles frames whose header omits the content
        # size (e.g. written by a streaming compressor); the one-shot
        # ZstdDecompressor.decompress() raises ZstdError for those.
        decompressed_data = zstd.ZstdDecompressor().decompressobj().decompress(compressed_data)

        # np.load keeps its default allow_pickle=False, so only plain arrays load.
        embeddings = np.load(BytesIO(decompressed_data))
        return embeddings.astype(np.float32, copy=False)
    except Exception:
        logger.exception("Error downloading embeddings from S3 (s3://%s/%s)", bucket, key)
        raise

def model_fn(model_dir):
    """
    Load the inference artifacts and initialize module-level state.

    Reads ``metadata.json`` and the pickled model ``model.pkl`` (logged as a
    KMeans model) from ``model_dir``. Optionally reads ``cluster_keywords.json``.
    Then downloads the compressed embeddings named by the metadata from S3 and
    loads the ``all-MiniLM-L6-v2`` SentenceTransformer.

    Side effects on module globals:
        * ``sentence_model``: the loaded SentenceTransformer.
        * ``cluster_keywords_map``: ``{int cluster id: keyword}``, or ``{}``
          when ``cluster_keywords.json`` is absent.
        * ``embeddings_cache[material_type]``: the downloaded embeddings,
          keyed by ``metadata["material_type"]`` (``"unknown"`` if absent).

    Args:
        model_dir: Directory containing the unpacked model artifacts.

    Returns:
        Tuple ``(model, metadata, embeddings)``.

    Raises:
        ValueError: If ``metadata.json`` or ``model.pkl`` is missing, or the
            metadata has no ``embedding_bucket`` / ``embedding_path``.
    """
    global sentence_model, cluster_keywords_map

    logger.info("Starting model loading process")

    # Load the metadata to get the S3 location of the embeddings.
    metadata_path = os.path.join(model_dir, 'metadata.json')
    if not os.path.exists(metadata_path):
        raise ValueError(f"Metadata file not found at {metadata_path}")

    with open(metadata_path, 'r', encoding='utf-8') as f:
        metadata = json.load(f)

    # Validate the S3 location up front so bad metadata fails before the
    # slower model / transformer loading.
    bucket = metadata.get("embedding_bucket")
    embeddings_key = metadata.get("embedding_path")
    if not bucket or not embeddings_key:
        raise ValueError("Embedding bucket or path not found in metadata")

    # Load the trained model.
    model_path = os.path.join(model_dir, 'model.pkl')
    if not os.path.exists(model_path):
        raise ValueError(f"Model file not found at {model_path}")

    # pickle can execute arbitrary code on load: model.pkl must only ever come
    # from the trusted training artifact, never from request data.
    with open(model_path, 'rb') as f:
        model = pickle.load(f)

    logger.info("KMeans Model loaded")

    # Load the optional cluster id -> keyword mapping.
    keywords_json_path = os.path.join(model_dir, 'cluster_keywords.json')
    if not os.path.exists(keywords_json_path):
        logger.warning(f"{keywords_json_path} not found. 'proposedkey' will not be available via this mapping.")
        cluster_keywords_map = {}
    else:
        with open(keywords_json_path, 'r', encoding='utf-8') as f:
            raw_map = json.load(f)
        # JSON object keys are always strings; convert so lookups by integer
        # cluster id succeed.
        cluster_keywords_map = {int(k): v for k, v in raw_map.items()}
        logger.info(f"Cluster keywords mapping loaded from {keywords_json_path}")

    logger.info(f"Downloading embeddings from s3://{bucket}/{embeddings_key}")
    embeddings = download_embeddings_from_s3(bucket, embeddings_key)
    logger.info("Embeddings downloaded and decompressed")

    # Sentence encoder; the original code notes this is the same model as
    # used in training.
    sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
    logger.info("SentenceTransformer model loaded")

    # Use the same "unknown" default as elsewhere instead of raising KeyError
    # after all the expensive loading has already happened.
    embeddings_cache[metadata.get("material_type", "unknown")] = embeddings

    return model, metadata, embeddings

def input_fn(request_body, request_content_type):
    """
    Deserialize the request body. Only JSON is accepted.

    Media-type parameters (e.g. ``; charset=utf-8``) are ignored and the
    comparison is case-insensitive, as media types are defined to be. So
    ``Application/JSON; charset=utf-8`` is accepted.

    Args:
        request_body: Raw payload; ``str``, ``bytes`` or ``bytearray`` (all
            accepted by ``json.loads``).
        request_content_type: The request's Content-Type value.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: If the media type is not ``application/json``, or the body
            is not valid JSON (``json.JSONDecodeError`` subclasses ValueError).
    """
    media_type = (request_content_type or '').split(';', 1)[0].strip().lower()
    if media_type != 'application/json':
        raise ValueError(f"Unsupported content type: {request_content_type}. Only 'application/json' is supported.")

    return json.loads(request_body)

def clean_text(text):
    """
    Normalize a description exactly as the training pipeline did.

    The input is coerced with ``str`` and lower-cased. Every character other
    than a-z, 0-9, whitespace, ``/``, ``-`` and ``.`` is replaced by a space.
    Whitespace runs are then collapsed to single spaces and the ends trimmed,
    so industrial notation such as ``m8x1.25`` or ``ss-304`` survives intact.
    """
    lowered = str(text).lower()
    # Disallowed characters become separators rather than being deleted,
    # so "a,b" turns into "a b" instead of "ab".
    filtered = re.sub(r'[^a-z0-9\s\/\-\.]', ' ', lowered)
    # Argument-less split() breaks on the same Unicode whitespace that \s
    # matches and discards empty pieces, so the join collapses and trims.
    return ' '.join(filtered.split())

def predict_fn(input_data, model_and_metadata_and_embeddings):
    """
    Assign each input record to a cluster and attach its proposed keyword.

    For each record:
      1. Clean ``componentdesc`` with ``clean_text`` (same preprocessing as
         training).
      2. Embed it with the module-level ``sentence_model`` (set by ``model_fn``).
      3. Take the row of ``embeddings`` with the highest cosine similarity.
         That row index is reported as ``cluster`` and used as the key into
         ``cluster_keywords_map``.

    Args:
        input_data: A list of record dicts, or a single record dict, as decoded
            from the request JSON. Keys read: component, mattype,
            componentdesc, werks, aods, aoci, keyword.
        model_and_metadata_and_embeddings: The ``(model, metadata, embeddings)``
            tuple returned by ``model_fn``.

    Returns:
        A list with one dict per record. Each dict echoes the input fields and
        adds ``cluster`` and ``proposedkey``. A record that fails gets both set
        to None plus an ``error`` message, rather than aborting the batch.
    """
    # NOTE(review): the loaded model is not used. The reported "cluster" is
    # the argmax row index into `embeddings`, which is only a cluster id if
    # `embeddings` holds one vector per cluster (e.g. centroids). Confirm
    # against the training job.
    _model, _metadata, embeddings = model_and_metadata_and_embeddings

    # Imported once per call, outside the per-record loop.
    from sklearn.metrics.pairwise import cosine_similarity

    # Treat a single JSON object as a one-record batch. Iterating a dict
    # directly would yield its keys (strings) and fail on `.get`.
    records = [input_data] if isinstance(input_data, dict) else input_data

    # Request fields echoed back unchanged in every response.
    echo_fields = ('component', 'mattype', 'werks', 'aoci', 'aods',
                   'componentdesc', 'keyword')

    responses = []
    for data in records:
        is_record = isinstance(data, dict)
        response = {
            field: data.get(field, '') if is_record else ''
            for field in echo_fields
        }

        try:
            if not is_record:
                raise ValueError(
                    f"Each input record must be a JSON object, got {type(data).__name__}."
                )

            # A JSON null would otherwise be cleaned to the literal text "none"
            # and produce a meaningless prediction.
            description = data.get('componentdesc')
            processed_desc = clean_text('' if description is None else description)
            if not processed_desc:
                raise ValueError("Processed component description is empty or invalid after cleaning.")

            input_embedding = sentence_model.encode([processed_desc])
            similarities = cosine_similarity(input_embedding, embeddings)
            # Plain int: JSON-serializable and matches the int keys of the map.
            predicted_cluster = int(np.argmax(similarities[0]))

            logger.info(f"Predicted cluster for '{processed_desc}': {predicted_cluster}")

            output_proposedkey = None
            if cluster_keywords_map:
                output_proposedkey = cluster_keywords_map.get(predicted_cluster)
                if output_proposedkey is not None:
                    logger.info(f"Found proposedkey for cluster {predicted_cluster}: {output_proposedkey}")
                else:
                    logger.warning(f"No 'proposedkey' found in map for cluster {predicted_cluster}. "
                                   "Check cluster_keywords.json or if this cluster was in training data.")
            else:
                logger.warning("Cluster keywords map not available. 'proposedkey' will be None.")

            response['cluster'] = predicted_cluster
            response['proposedkey'] = output_proposedkey

        except Exception as e:
            # Per-record isolation: one bad record yields an error entry
            # instead of failing the whole request.
            logger.error(f"Error processing input data: {e}", exc_info=True)
            response['cluster'] = None
            response['proposedkey'] = None
            response['error'] = str(e)

        responses.append(response)

    return responses

def output_fn(prediction, content_type='application/json'):
    """
    Serialize predictions as a JSON string.

    The body is always JSON, so the returned content type is always
    ``'application/json'``. Echoing the requested ``content_type`` back (as
    before) would mislabel the JSON body as e.g. ``text/csv`` or ``*/*``.

    Args:
        prediction: JSON-serializable result (e.g. the list from predict_fn).
        content_type: Requested response type. Kept for interface
            compatibility; it does not change the output format.

    Returns:
        Tuple ``(body, 'application/json')`` where ``body`` is a JSON string.
    """
    return json.dumps(prediction), 'application/json'