<h1 style="text-align: center; font-size: 50px;"> Register Model </h1>

# Notebook Overview
- Configure the Environment
- Define Constants
- Define Module Dependencies
- Verify Assets and Load Data
- Log and Register Model
- Evaluate Model
- Log Execution Time

## Configure the Environment

In [None]:
import logging
import time
import yaml
import re
import os
import json
import importlib.util
import multiprocessing
from datetime import datetime
from pathlib import Path
from typing import Any, List, Dict, Tuple, Optional

import mlflow
import pandas as pd
import numpy as np
from mlflow.models import evaluate, ModelSignature
from mlflow.types import Schema, ColSpec
from mlflow.metrics import ari_grade_level, exact_match, rouge1, rougeL, make_metric
from mlflow.tracking import MlflowClient
from langchain.prompts import PromptTemplate
from llama_cpp import Llama
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Logger Configuration
# Handlers are attached only while the named logger has none, so re-running
# this cell does not produce duplicate log lines.
logger: logging.Logger = logging.getLogger("register_markdown_model_logger")
if not logger.handlers:
    stream_handler = logging.StreamHandler()
    formatter = logging.Formatter(
        fmt="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    # Do not forward records to ancestor (root) handlers as well.
    logger.propagate = False
    logger.setLevel(logging.INFO)

# Start Execution: reference timestamp for the total-runtime report at the end.
start_time: float = time.time()
logger.info("Notebook execution started.")

## Define Constants

In [None]:
# Path and Model Configuration
# Relative paths resolve against the current working directory
# (typically the folder containing this notebook).
CONFIG_PATH = Path("../configs") / "configs.yaml"
SECRETS_PATH = Path("../configs") / "secrets.yaml"
REQUIREMENTS_PATH = Path("..") / "requirements.txt"
LOCAL_MODEL_PATH = (
    Path("/home/jovyan/datafabric")
    / "llama3.1-8b-instruct"
    / "Meta-Llama-3.1-8B-Instruct-Q8_0.gguf"
)
EVAL_DATA_PATH = Path("./results.json")

# MLflow Configuration
EXPERIMENT_NAME = "markdown-correction-experiment"
# Timestamped so each execution logs under a distinct run name.
RUN_NAME = f"registration-run-{datetime.now():%Y%m%d-%H%M%S}"
MODEL_NAME = "MarkdownCorrector"

# Model and Path
# NOTE(review): DEFAULT_MODELS does not appear to be referenced elsewhere in
# this notebook — confirm before removing.
DEFAULT_MODELS: Dict[str, str] = dict(
    local="/home/jovyan/datafabric/llama2-7b/ggml-model-f16-Q5_K_M.gguf",
)
# GGUF weights loaded by LocalJudgeLlamaClient for the LLM-as-judge metric.
LOCAL_LLAMA_JUDGE_PATH: str = "/home/jovyan/datafabric/llama2-7b/ggml-model-f16-Q5_K_M.gguf"

In [None]:
%%time
# Install the packages listed in REQUIREMENTS_PATH, suppressing pip's progress output.
# NOTE(review): the environment cell above already imports mlflow, llama_cpp, langchain
# and sklearn; on a fresh kernel those imports execute before this install runs.
# Consider moving this cell first, or restarting the kernel after installing.
%pip install -r {REQUIREMENTS_PATH} --quiet

## Define Module Dependencies

In [None]:
# Shared TF-IDF vectorizer; semantic_similarity_eval_fn refits it on every call.
tfidf_vectorizer = TfidfVectorizer(max_features=5000, stop_words='english')

# --- Helper Functions ---
def get_markdown_correction_prompt() -> PromptTemplate:
    """Build the chat-formatted prompt used to grammar-correct Markdown.

    The template is written with Llama-3 style header/turn tokens and exposes a
    single input variable, ``markdown``, which is substituted into the user turn.

    Returns:
        A LangChain ``PromptTemplate`` created via ``PromptTemplate.from_template``.
    """
    prompt_text = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a markdown grammar correction assistant. Your job is to correct only grammatical errors in the user's Markdown content.
Strictly follow these rules:
- Do **not** modify any placeholders (e.g., <<PH1>>, <<PH93>>, [[BULLET3]], <<SEP>>). Leave them **exactly as they appear**.
- Do **not** alter Markdown formatting (e.g., headings, links, lists).
- Do **not** add or remove extra content.
- Only correct grammar **within natural language sentences**.
- **Always** maintain title case where present.
Respond only with the corrected Markdown content. Do not explain anything.<|eot_id|><|start_header_id|>user<|end_header_id|>
Original markdown:
{markdown}
Corrected markdown:<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
    correction_prompt = PromptTemplate.from_template(prompt_text)
    return correction_prompt

def load_config_and_secrets(config_path: Path, secrets_path: Path) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Load the run configuration and secrets from their YAML files.

    Args:
        config_path: Path to the configuration YAML file.
        secrets_path: Path to the secrets YAML file.

    Returns:
        A ``(config, secrets)`` tuple of the parsed YAML documents. An empty
        file yields an empty dict instead of ``None``.

    Raises:
        FileNotFoundError: If either file is missing (the secrets file is
            checked first).
    """
    if not secrets_path.exists():
        raise FileNotFoundError(f"secrets.yaml not found: {secrets_path}")
    if not config_path.exists():
        raise FileNotFoundError(f"config.yaml not found: {config_path}")
    # Read as UTF-8 explicitly instead of the platform-default encoding.
    with open(config_path, encoding="utf-8") as file:
        # safe_load returns None for an empty document; normalise to {}.
        config = yaml.safe_load(file) or {}
    with open(secrets_path, encoding="utf-8") as file:
        secrets = yaml.safe_load(file) or {}
    return config, secrets

def initialize_llm(
    model_source: str,
    secrets: Optional[Dict[str, Any]],
    local_model_path: str
) -> Any:
    """Create the LangChain ``LlamaCpp`` LLM used by the correction chain.

    Args:
        model_source: Which backend to use; only ``"local"`` is supported.
        secrets: Accepted for interface compatibility; not used by the local backend.
        local_model_path: Filesystem path to the GGUF model file.

    Returns:
        A configured ``LlamaCpp`` instance (greedy decoding, all layers offloaded
        to GPU via ``n_gpu_layers=-1``).

    Raises:
        ValueError: If ``model_source`` is not ``"local"``.
    """
    if model_source != "local":
        raise ValueError(f"Unsupported model source: {model_source}")
    # Imported lazily so the dependency is only needed when a local model is built.
    from langchain_community.llms import LlamaCpp
    return LlamaCpp(
        model_path=local_model_path,
        n_gpu_layers=-1,
        n_batch=512,
        n_ctx=32000,
        max_tokens=1024,
        f16_kv=True,
        use_mmap=False,
        temperature=0.0,
        # The thread-count option is `n_threads`; the previous `num_threads` is not
        # a LlamaCpp / llama.cpp option and therefore did not set the thread count.
        n_threads=multiprocessing.cpu_count(),
        verbose=False,
    )

# --- Custom MLflow Model Definition ---
class MarkdownCorrectionService(mlflow.pyfunc.PythonModel):
    """An MLflow PythonModel for correcting grammar in Markdown content.

    Expects three artifacts in the model context: ``config`` (YAML with a
    ``model_source`` key), ``secrets`` (YAML), and ``llm`` (path to the model file).
    """

    def load_context(self, context: mlflow.pyfunc.PythonModelContext) -> None:
        """Load config/secrets from the artifacts and build the prompt | LLM chain."""
        config_path = Path(context.artifacts["config"])
        secrets_path = Path(context.artifacts["secrets"])
        model_path = context.artifacts["llm"]
        config, secrets = load_config_and_secrets(config_path, secrets_path)
        self.prompt = get_markdown_correction_prompt()
        self.llm = initialize_llm(config["model_source"], secrets, model_path)
        # The formatted prompt is piped into the LLM.
        self.llm_chain = self.prompt | self.llm
        logger.info("MarkdownCorrectionService context loaded successfully.")

    def predict(self, context: Any, model_input: pd.DataFrame) -> pd.Series:
        """Correct every value of the ``markdown`` column.

        Args:
            context: MLflow model context (unused here).
            model_input: DataFrame with a ``markdown`` column.

        Returns:
            A Series named ``corrected`` with one chain output per input row,
            in input order (default integer index).

        Raises:
            KeyError: If the ``markdown`` column is missing.
        """
        if "markdown" not in model_input.columns:
            raise KeyError("Input DataFrame is missing the required 'markdown' column.")
        # Iterate the column directly: iterrows() builds a Series per row (slow) and
        # does not preserve dtypes when the frame has mixed-type columns.
        corrected = [self.llm_chain.invoke({"markdown": text}) for text in model_input["markdown"]]
        return pd.Series(corrected, name="corrected")

# --- Custom Metric Definitions ---
class LocalJudgeLlamaClient:
    """Lazily created, class-level cached llama.cpp model used as the LLM judge.

    The first ``get_client`` call must provide ``model_path``; later calls return
    the cached instance and ignore ``model_path``.
    """

    _client = None  # cached Llama instance shared through the class

    @classmethod
    def get_client(cls, model_path: Optional[str] = None) -> "Llama":
        """Return the cached judge model, constructing it on first use.

        Args:
            model_path: Path to the GGUF model file; required only on the first call.

        Returns:
            The shared ``Llama`` instance.

        Raises:
            ValueError: If no instance exists yet and ``model_path`` is None.
        """
        if cls._client is not None:
            return cls._client
        if model_path is None:
            raise ValueError("Must provide model_path to initialize local LLaMA judge.")
        # Generation options such as max_tokens/temperature are not Llama
        # constructor arguments (they were previously passed here and ignored);
        # they must be supplied on each call instead.
        cls._client = Llama(
            model_path=model_path,
            n_gpu_layers=-1,
            n_batch=512,
            n_ctx=32000,
            f16_kv=True,
            use_mmap=False,
            seed=42,
            verbose=False,
            # The llama-cpp-python option is `n_threads`; `num_threads` was not applied.
            n_threads=multiprocessing.cpu_count(),
        )
        return cls._client

# Load the judge model now: llm_judge_eval_fn_local later calls get_client() without a path.
LocalJudgeLlamaClient.get_client(LOCAL_LLAMA_JUDGE_PATH)

def simple_grammar_check(text: str) -> int:
    """Count crude grammar issues in ``text``.

    The text is split into sentences on runs of ``.``, ``!`` or ``?``. For each
    non-empty sentence, one issue is counted if it contains a double space, plus
    one per word that repeats the previous word (case-insensitive).

    Args:
        text: Input; converted with ``str()`` first, so ``None`` becomes ``"None"``.

    Returns:
        Total number of issues found.
    """
    total = 0
    for chunk in re.split(r'[.!?]+', str(text).strip()):
        chunk = chunk.strip()
        if not chunk:
            continue
        total += int('  ' in chunk)
        tokens = chunk.lower().split()
        total += sum(1 for prev, curr in zip(tokens, tokens[1:]) if prev == curr)
    return total

def semantic_similarity_eval_fn(predictions: List[str], targets: List[str]) -> float:
    """Mean TF-IDF cosine similarity between each prediction and its target.

    Targets followed by predictions form one corpus that refits the shared
    ``tfidf_vectorizer``; falsy items become ``""``. Returns 1.0 when the corpus
    has fewer than two distinct texts.
    """
    corpus = [str(item) if item else "" for item in [*targets, *predictions]]
    if len(set(corpus)) < 2:
        return 1.0
    matrix = tfidf_vectorizer.fit_transform(corpus)
    count = len(targets)
    # Row i is target i; row count + i is the matching prediction.
    pair_scores = [
        cosine_similarity(matrix[i:i + 1], matrix[count + i:count + i + 1])[0, 0]
        for i in range(count)
    ]
    return np.mean(pair_scores)

def grammar_error_count_eval_fn(predictions: List[str], targets: List[str]) -> float:
    """Mean ``simple_grammar_check`` issue count per prediction (``targets`` is unused)."""
    issue_counts = list(map(simple_grammar_check, map(str, predictions)))
    return np.mean(issue_counts)

def grammar_error_rate_eval_fn(predictions: List[str], targets: List[str]) -> float:
    """Mean of issues-per-word over predictions (word count floored at 1; ``targets`` unused)."""
    per_word = []
    for pred in predictions:
        as_text = str(pred)
        word_total = max(len(as_text.split()), 1)
        per_word.append(simple_grammar_check(as_text) / word_total)
    return np.mean(per_word)

def grammar_improvement_eval_fn(predictions: List[str], targets: List[str]) -> float:
    """Mean of (issues in target - issues in prediction); positive means fewer issues after correction."""
    deltas = []
    for pred, tgt in zip(predictions, targets):
        before = simple_grammar_check(str(tgt))
        after = simple_grammar_check(str(pred))
        deltas.append(before - after)
    return np.mean(deltas)

def grammar_score_eval_fn(predictions: List[str], targets: List[str]) -> float:
    """Mean per-prediction score: 100 minus 10 per detected issue, floored at 0 (``targets`` unused)."""
    def _score(pred) -> int:
        return max(100 - 10 * simple_grammar_check(str(pred)), 0)

    return np.mean([_score(pred) for pred in predictions])

def readability_improvement_eval_fn(predictions: List[str], targets: List[str]) -> float:
    """Mean change in a simple readability score from target to prediction.

    The score is ``max(20 - words_per_sentence, 0)``: shorter sentences score
    higher, and text with no words scores 0. Sentences are non-empty segments
    between runs of ``.``, ``!`` or ``?`` (the same splitting as
    ``simple_grammar_check``), with at least one sentence assumed.

    Returns:
        Mean of ``score(prediction) - score(target)``; 0.0 for empty input.
    """
    def _readability(text: str) -> float:
        words = text.split()
        if not words:
            return 0.0
        # Count only non-empty segments: a trailing "." must not add a phantom sentence.
        sentence_count = max(sum(1 for seg in re.split(r"[.!?]+", text) if seg.strip()), 1)
        return max(20 - (len(words) / sentence_count), 0)

    improvements = [_readability(str(p)) - _readability(str(t)) for p, t in zip(predictions, targets)]
    if not improvements:
        return 0.0
    return np.mean(improvements)

def llm_judge_eval_fn_local(predictions: List[str]) -> float:
    """Use the cached local LLaMA judge to rate each prediction's grammar 1-10.

    Relies on ``LocalJudgeLlamaClient`` having already been initialised with a
    model path. A prediction whose judge call fails or whose reply contains no
    number is logged and scored 5.0.

    Returns:
        The mean score, or 0.0 when there are no predictions.
    """
    llama = LocalJudgeLlamaClient.get_client()
    scores = []
    for pred in predictions:
        prompt = f"Rate the grammar of the following text on a scale of 1 to 10. Respond with only a single digit.\nText: {pred}\nAnswer:"
        try:
            # temperature must be given per call: the Llama constructor does not
            # accept it, so without this the call's non-zero default would apply.
            result = llama(prompt, max_tokens=3, stop=["\n"], temperature=0.0)
            score_text = result["choices"][0]["text"].strip()
            score = float(re.findall(r"\d+", score_text)[0])
            scores.append(score)
        except Exception as e:
            # Best-effort metric: one bad judge reply should not abort the evaluation.
            logger.warning(f"[LLaMA judge error]: {e}. Appending default score of 5.0.")
            scores.append(5.0)
    return np.mean(scores) if scores else 0.0

# --- Create MLflow Metric Objects ---
# Each wraps one eval function from this cell; greater_is_better tells MLflow
# which direction counts as better.
semantic_similarity_metric = make_metric(
    name="semantic_similarity",
    eval_fn=semantic_similarity_eval_fn,
    greater_is_better=True,
)
grammar_error_count_metric = make_metric(
    name="grammar_error_count",
    eval_fn=grammar_error_count_eval_fn,
    greater_is_better=False,
)
grammar_error_rate_metric = make_metric(
    name="grammar_error_rate",
    eval_fn=grammar_error_rate_eval_fn,
    greater_is_better=False,
)
grammar_improvement_metric = make_metric(
    name="grammar_improvement",
    eval_fn=grammar_improvement_eval_fn,
    greater_is_better=True,
)
grammar_score_metric = make_metric(
    name="grammar_score",
    eval_fn=grammar_score_eval_fn,
    greater_is_better=True,
)
readability_improvement_metric = make_metric(
    name="readability_improvement",
    eval_fn=readability_improvement_eval_fn,
    greater_is_better=True,
)
llm_judge_metric_local = make_metric(
    name="llm_judge_local_score",
    eval_fn=llm_judge_eval_fn_local,
    greater_is_better=True,
)

## Verify Assets and Load Data

In [None]:
def log_asset_status(asset_path: Path, asset_name: str) -> None:
    """Log whether ``asset_path`` exists: info if present, warning if missing."""
    if not asset_path.exists():
        logger.warning(f"⚠️ {asset_name} not found at: {asset_path}.")
        return
    logger.info(f"✅ {asset_name} is properly configured at: {asset_path}")

# Report on every required input; missing files only produce warnings here.
log_asset_status(CONFIG_PATH, "Config file")
log_asset_status(SECRETS_PATH, "Secrets file")
log_asset_status(LOCAL_MODEL_PATH, "Local LLaMA model")
log_asset_status(EVAL_DATA_PATH, "Evaluation data JSON")
log_asset_status(REQUIREMENTS_PATH, "Requirements file")

# JSON is UTF-8 by specification; do not rely on the platform-default encoding.
with open(EVAL_DATA_PATH, "r", encoding="utf-8") as f:
    results = json.load(f)
# Each record's "original" text becomes one row of the model's "markdown" input column.
original_texts = [item["original"] for item in results]
eval_df = pd.DataFrame(original_texts, columns=["markdown"])
logger.info(f"Loaded {len(eval_df)} records for evaluation.")
print(eval_df.head())

## Log and Register Model

In [None]:
%%time

mlflow.set_experiment(EXPERIMENT_NAME)

with mlflow.start_run(run_name=RUN_NAME) as run:
    run_id = run.info.run_id
    logger.info(f"MLflow Run Started. Run ID: {run_id}")

    # Files bundled with the model; MarkdownCorrectionService.load_context reads
    # them back by these keys.
    # NOTE(review): this copies the secrets file into the logged model artifacts,
    # so anyone with access to the run/registry can read it — confirm this is intended.
    artifacts = {
        "config": str(CONFIG_PATH),
        "secrets": str(SECRETS_PATH),
        "llm": str(LOCAL_MODEL_PATH),
    }

    signature = ModelSignature(
        inputs=Schema([ColSpec("string", "markdown")]),
        outputs=Schema([ColSpec("string", "corrected")]),
    )

    mlflow.pyfunc.log_model(
        artifact_path=MODEL_NAME,
        python_model=MarkdownCorrectionService(),
        artifacts=artifacts,
        signature=signature,
        registered_model_name=MODEL_NAME,
        pip_requirements=str(REQUIREMENTS_PATH),
    )

    model_uri = f"runs:/{run_id}/{MODEL_NAME}"
    logger.info(f"Model URI: {model_uri}")

# Validate Registered Model
client = MlflowClient()
try:
    # Resolve the version registered by *this* run. get_latest_versions(stages=["None"])
    # relies on deprecated model stages and can return a version from another run.
    run_versions = client.search_model_versions(f"name='{MODEL_NAME}' and run_id='{run_id}'")
    if not run_versions:
        raise LookupError(f"No version of '{MODEL_NAME}' was registered by run {run_id}.")
    latest_version_info = max(run_versions, key=lambda mv: int(mv.version))
    latest_version = latest_version_info.version
    logger.info(f"Successfully registered model '{MODEL_NAME}' version {latest_version}.")

    model_uri_latest = f"models:/{MODEL_NAME}/{latest_version}"
    loaded_model = mlflow.pyfunc.load_model(model_uri_latest)

    # Smoke test: run the registered model on the first evaluation record.
    sample_input = eval_df.head(1)
    logger.info(f"Performing sample prediction on: \n{sample_input['markdown'].iloc[0][:100]}...")

    prediction = loaded_model.predict(sample_input)
    logger.info(f"✅ Sample prediction successful. Output:\n{prediction.iloc[0][:100]}...")
except Exception as e:
    # Validation is best-effort; log with traceback instead of only the message.
    logger.exception(f"Failed to validate registered model. Error: {e}")

## Evaluate Model

In [None]:
%%time

logger.info("Starting model evaluation with mlflow.evaluate...")

with mlflow.start_run(run_id=run_id):
    # Built-in MLflow text metrics first, then the custom metrics defined earlier.
    metrics_to_log = [
        ari_grade_level(),
        exact_match(),
        rouge1(),
        rougeL(),
        semantic_similarity_metric,
        grammar_error_count_metric,
        grammar_error_rate_metric,
        grammar_improvement_metric,
        grammar_score_metric,
        readability_improvement_metric,
        llm_judge_metric_local,
    ]
    # The "markdown" input column doubles as the target, so target-based metrics
    # compare each corrected output against its original, uncorrected text.
    results = mlflow.evaluate(
        model=model_uri,
        data=eval_df,
        targets="markdown",
        feature_names=["markdown"],
        evaluators="default",
        extra_metrics=metrics_to_log,
    )

    logger.info("✅ Evaluation complete.")
    logger.info("Evaluation Metrics:")
    for metric_name, metric_value in results.metrics.items():
        logger.info(f"  - {metric_name}: {metric_value:.4f}")

## Log Execution Time

In [None]:
end_time: float = time.time()
elapsed_time: float = end_time - start_time
# divmod yields (elapsed_time // 60, elapsed_time % 60) in a single step.
elapsed_minutes: int
elapsed_seconds: float
elapsed_minutes, elapsed_seconds = divmod(elapsed_time, 60)
elapsed_minutes = int(elapsed_minutes)

logger.info(f"⏱️ Total execution time: {elapsed_minutes}m {elapsed_seconds:.2f}s")
logger.info("✅ Notebook execution completed successfully.")