<h1 style="text-align: center; font-size: 50px;"> Register Model </h1>

# Notebook Overview

- Configure the Environment
- Define Constants and Paths
- Define Necessary Modules
- Define MLflow Model Class
- Log and Register Model
- Prepare Eval Data
- Evaluate Model
- Log Execution Time

## Configure the Environment

In [None]:
import logging
import time

# Configure the notebook's dedicated logger. `logging.getLogger` returns the
# same object on every call, so re-running this cell must not attach another
# handler (each extra handler would print every message one more time).
logger: logging.Logger = logging.getLogger("register_model_logger")
logger.setLevel(logging.INFO)
logger.propagate = False  # Keep records away from the root logger's handlers

# Set formatter
formatter: logging.Formatter = logging.Formatter(
    fmt="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

# Configure the stream handler; attach it only if the logger has none yet,
# which keeps the cell idempotent when re-executed.
stream_handler: logging.StreamHandler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
if not logger.handlers:
    logger.addHandler(stream_handler)

# Wall-clock start time; presumably read by the "Log Execution Time" step.
start_time = time.time()
logger.info("✅ Notebook execution started.")

In [None]:
%%time

%pip install -r ../requirements.txt --quiet

In [None]:
# Standard & Third-Party Libraries
import os
import sys
import json
import re
from pathlib import Path
from collections import defaultdict
import pandas as pd
import mlflow
from mlflow.models import ModelSignature, evaluate
from mlflow.types.schema import Schema, ColSpec
from typing import List

# Add the project's src directory to the import path. Only append it once so
# re-running this cell does not keep accumulating duplicate sys.path entries.
_SRC_DIR = str(Path("..").resolve() / "src")
if _SRC_DIR not in sys.path:
    sys.path.append(_SRC_DIR)

# --- Imports for Evaluation ---
# Import standard MLflow metrics
from mlflow.metrics import exact_match, rouge1, rougeL

logger.info("✅ All libraries and metrics imported successfully.")

## Define Constants and Paths

In [None]:
# Configuration Paths (relative paths resolve against the current working directory)
# NOTE(review): load_config_and_secrets defaults to "../../configs/config.yaml"
# (singular "config"), while CONFIG_PATH uses "configs.yaml" — confirm which file exists.
CONFIG_PATH = Path("../configs/configs.yaml")
SECRETS_PATH = Path("../configs/secrets.yaml")
REQUIREMENTS_PATH = Path("../requirements.txt")
# Absolute path to a local GGUF model file (Llama 3.1 8B Instruct, Q8_0 according
# to the file name). NOTE(review): environment-specific path — confirm it exists.
LOCAL_MODEL_PATH = Path("/home/jovyan/datafabric/llama3.1-8b-instruct/Meta-Llama-3.1-8B-Instruct-Q8_0.gguf")

# MLflow Configuration (experiment and registered-model names; their use is not shown in this cell)
MLFLOW_EXPERIMENT_NAME = "Markdown_Correction_Service"
MLFLOW_MODEL_NAME = "markdown-corrector"

logger.info("Constants and paths defined.")

## Define Necessary Modules

In [None]:
############################################
###--------GITHUB EXTRACTOR START--------###
############################################

import requests
from urllib.parse import urlparse
import sys
import base64 
import re 
from markdown_it import MarkdownIt
from markdown_it.token import Token 
import os
from bs4 import BeautifulSoup
from typing import Optional, Tuple, Dict, Union

class GitHubMarkdownProcessor:
    """
    Fetches Markdown files from a GitHub repository via the GitHub REST API.

    The processor checks repository visibility, walks the default branch's
    file tree, downloads every `.md` blob and returns the raw contents keyed
    by repository path.

    Attributes:
        repo_url (str): GitHub repository URL.
        repo_owner (str): Owner of the GitHub repository.
        repo_name (str): Name of the GitHub repository.
        access_token (Optional[str]): GitHub Personal Access Token (if needed).
        save_dir (str): Output directory stored for callers; no method of
            this class writes to it.
        api_base_url (str): Base GitHub API URL for the repository.
    """

    # Seconds to wait for each GitHub API response. Without a timeout,
    # ``requests`` can block forever on a stalled connection.
    REQUEST_TIMEOUT: float = 30

    def __init__(self, repo_url: str, access_token: Optional[str] = None, save_dir: str = './parsed_repo'):
        """
        Initializes the Markdown processor with a GitHub repo URL.

        Args:
            repo_url (str): Full URL to the GitHub repository.
            access_token (Optional[str]): GitHub token for private repo access.
            save_dir (str): Output directory stored on the instance for callers.

        Raises:
            ValueError: If ``repo_url`` does not contain an owner and a repository name.
        """
        self.repo_url = repo_url
        self.access_token = access_token
        self.save_dir = save_dir

        owner, repo, error = self.parse_url()
        if error:
            raise ValueError(error)

        self.repo_owner = owner
        self.repo_name = repo
        self.api_base_url = f"https://api.github.com/repos/{self.repo_owner}/{self.repo_name}"

    def _auth_headers(self) -> Dict[str, str]:
        """Build request headers, adding a bearer token when one was supplied."""
        headers: Dict[str, str] = {}
        if self.access_token:
            headers["Authorization"] = f"Bearer {self.access_token}"
        return headers

    def parse_url(self) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """
        Parses ``self.repo_url`` and extracts the repository owner and name.

        Only the first two path segments are used, so URLs such as
        ``https://github.com/owner/repo/tree/main/docs`` are accepted. A
        trailing ``.git`` (clone URLs) is removed from the repository name.

        Returns:
            Tuple[Optional[str], Optional[str], Optional[str]]:
                - The repository owner (e.g., 'openai')
                - The repository name (e.g., 'whisper')
                - An error message string if parsing fails; otherwise, None
        """
        parsed_url = urlparse(self.repo_url)
        path_parts = parsed_url.path.strip("/").split("/")

        # Validate URL format
        if len(path_parts) < 2:
            return None, None, "Invalid GitHub URL format."

        owner, repo = path_parts[0], path_parts[1]
        # Clone URLs end in ".git", which is not part of the API repository name.
        if repo.endswith(".git"):
            repo = repo[:-len(".git")]
        if not owner or not repo:
            return None, None, "Invalid GitHub URL format."

        return owner, repo, None

    def check_repo(self) -> str:
        """
        Determines the visibility (public or private) of the GitHub repository.

        Queries the GitHub API for the repository described by ``self.repo_url``
        (authenticating with ``self.access_token`` when set).

        Returns:
            str: "private" or "public" if the repository is accessible;
                 otherwise, a descriptive error message (including network failures).
        """
        # Parse url into components
        owner, name, error = self.parse_url()
        if error:
            return error

        url = f"https://api.github.com/repos/{owner}/{name}"
        try:
            response = requests.get(url, headers=self._auth_headers(), timeout=self.REQUEST_TIMEOUT)
        except requests.RequestException as exc:
            return f"Error: request failed ({exc})"

        # Determine privacy of repo
        if response.status_code == 200:
            repo_data = response.json()
            return "private" if repo_data.get("private") else "public"
        elif response.status_code == 404:
            return "Repository is inaccessible. Please authenticate."
        else:
            return f"Error: {response.status_code}, {response.text}"

    def extract_md_files(self) -> Tuple[Optional[Dict], Optional[str]]:
        """
        Downloads every Markdown (.md) file on the default branch into a nested dict.

        Fetches the repository metadata to find the default branch, lists its
        tree recursively, and downloads each `.md` blob through the contents
        API. Files whose download fails are skipped; files whose content cannot
        be decoded are stored as an "Error decoding content: ..." string.

        Returns:
            Tuple[Optional[Dict], Optional[str]]:
                - A nested dictionary mirroring the directory structure, with
                  file names mapped to Markdown contents.
                - An error message string if the repository or tree request
                  fails; otherwise, None.
        """
        from urllib.parse import quote

        # Parse url into components
        owner, name, error = self.parse_url()
        if error:
            return None, error

        headers = self._auth_headers()
        url = f"https://api.github.com/repos/{owner}/{name}"
        try:
            response = requests.get(url, headers=headers, timeout=self.REQUEST_TIMEOUT)
        except requests.RequestException as exc:
            return None, f"Error: request failed ({exc})"
        if not response.ok:
            return None, f"Error: {response.status_code}, {response.text}"

        # Get default branch
        default_branch = response.json().get("default_branch", "main")

        # List the whole file tree of the default branch
        tree_url = f"https://api.github.com/repos/{owner}/{name}/git/trees/{quote(default_branch)}?recursive=1"
        try:
            tree_response = requests.get(tree_url, headers=headers, timeout=self.REQUEST_TIMEOUT)
        except requests.RequestException as exc:
            return None, f"Error: request failed ({exc})"
        if not tree_response.ok:
            return None, f"Error: {tree_response.status_code}, {tree_response.text}"

        # Dictionary to hold directory structure
        dir_structure = {}

        for item in tree_response.json().get("tree", []):
            path = item["path"]

            # Skip all non md files
            if item["type"] != "blob" or not path.endswith(".md"):
                continue

            # Paths may contain spaces, '#', '?' etc.; percent-encode them (keeping '/').
            content_url = f"https://api.github.com/repos/{owner}/{name}/contents/{quote(path)}"
            try:
                content_response = requests.get(content_url, headers=headers, timeout=self.REQUEST_TIMEOUT)
            except requests.RequestException:
                # Best effort: skip files that cannot be fetched, like non-OK responses.
                continue
            if not content_response.ok:
                continue

            # Decode the base64 payload returned by the contents API
            file_data = content_response.json()
            try:
                content = base64.b64decode(file_data["content"]).decode("utf-8")
            except (KeyError, TypeError, ValueError) as e:
                # KeyError: no "content" field; ValueError covers bad base64
                # (binascii.Error) and non-UTF-8 bytes (UnicodeDecodeError).
                content = f"Error decoding content: {e}"

            # Build directory structure
            parts = path.split("/")
            current = dir_structure
            for part in parts[:-1]:
                current = current.setdefault(part, {})
            current[parts[-1]] = content

        return dir_structure, None

    def run(self) -> Dict[str, str]:
        """
        High-level method to:
        1. Check repository access
        2. Extract markdown files
        3. Return raw markdown content by file path

        Returns:
            Dict[str, str]: Mapping from file paths to raw markdown content.

        Raises:
            PermissionError: If the repository responds with 404 (inaccessible).
            RuntimeError: If Markdown extraction fails.
        """
        visibility = self.check_repo()
        print(f"Repository visibility: {visibility}")

        if visibility == "Repository is inaccessible. Please authenticate.":
            raise PermissionError("Cannot access repository. Check your access token.")

        structure, error = self.extract_md_files()
        if error:
            raise RuntimeError(f"Markdown extraction failed: {error}")

        raw_data = {}

        def process_structure(structure: Dict[str, Union[str, dict]], path: str = "") -> None:
            """
            Recursively flattens a nested directory structure of markdown files.

            Args:
                structure (Dict[str, Union[str, dict]]): Nested dictionary of directories and file contents.
                path (str, optional): Path prefix accumulated during traversal. Defaults to "".

            Returns:
                None: Updates the outer `raw_data` dictionary in-place with path-to-content mappings.
            """
            for name, content in structure.items():
                current_path = os.path.join(path, name)
                if isinstance(content, dict):
                    process_structure(content, current_path)
                else:
                    raw_data[current_path] = content

        process_structure(structure)
        print("Raw markdown extraction complete.")
        return raw_data

############################################
###---------GITHUB EXTRACTOR END---------###
############################################

############################################
###-------------PARSER START-------------###
############################################

import re
from typing import Tuple, Dict
from markdown_it import MarkdownIt

def parse_md_for_grammar_correction(md_content: str) -> Tuple[Dict[str, str], str]:
    """
    Parse Markdown content, replacing non-prose elements with placeholders.

    Tables, fenced code, HTML blocks, headings, block quotes, inline code,
    URLs, link targets, list markers, low-prose lines, horizontal rules and
    newlines are swapped for ``<<PHn>>`` / ``[[BULLETn]]`` tokens so that only
    natural-language prose is left for grammar correction. Runs of adjacent
    ``<<PHn>>`` tokens are merged into one token, and a ``<<SEP>>`` marker is
    inserted where a placeholder is immediately followed by a word character.

    Block quotes that contain paragraph text are protected verbatim (apart
    from the leading quote marker on their first line, which stays visible),
    so restoring the placeholders reproduces every line of the quote exactly.

    Args:
        md_content (str): Raw markdown content to process.

    Returns:
        Tuple[Dict[str, str], str]:
            - A dictionary mapping placeholder keys (e.g. "PH3", "BULLET2",
              "SEP") to the original text they replaced.
            - The transformed markdown with placeholders in place of structure.
    """

    md = MarkdownIt()

    placeholder_map = {}
    counter = 0

    def get_next_placeholder(value: str, prefix="PH") -> str:
        """
        Generates a unique placeholder token for the given value and stores the mapping.

        If the input contains any previously assigned placeholders, they are unwrapped and reassigned
        as a new single placeholder. This helps deduplicate and simplify complex inline structures.

        Args:
            value (str): The text content to be replaced by a placeholder.
            prefix (str): The placeholder prefix (e.g., "PH", "BULLET").

        Returns:
            str: The generated placeholder token (e.g., "<<PH3>>" or "[[BULLET2]]").
        """
        nonlocal counter
        value = re.sub(
            r'<<(PH|BULLET|SEP)\d*>>|\[\[BULLET\d+\]\]',
            lambda m: placeholder_map.get(m.group(0).strip('<>[]'), m.group(0)),
            value
        )
        counter += 1
        key = f"{prefix}{counter}"
        placeholder_map[key] = value

        if prefix == "BULLET":
            return f"[[{key}]]"
        else:
            return f"<<{key}>>"

    def protect_tables(content: str) -> str:
        """
        Detects Markdown tables and replaces them with placeholders.

        Args:
            content (str): Input markdown content.

        Returns:
            str: Markdown with tables replaced by placeholders.
        """
        lines = content.splitlines()
        protected_lines = []
        in_table = False
        table_buffer = []

        for line in lines:
            if '|' in line and line.strip().startswith('|') and line.strip().endswith('|'):
                if not in_table:
                    in_table = True
                    table_buffer = [line]
                else:
                    table_buffer.append(line)
            elif in_table and re.match(r'^\s*\|[\s\-\|:]+\|\s*$', line):
                table_buffer.append(line)
            elif in_table:
                if len(table_buffer) >= 2:
                    table_content = '\n'.join(table_buffer)
                    placeholder = get_next_placeholder(table_content)
                    protected_lines.append(placeholder)
                else:
                    protected_lines.extend(table_buffer)

                table_buffer = []
                in_table = False
                protected_lines.append(line)
            else:
                protected_lines.append(line)

        # Handle final table
        if in_table and len(table_buffer) >= 2:
            table_content = '\n'.join(table_buffer)
            placeholder = get_next_placeholder(table_content)
            protected_lines.append(placeholder)
        elif in_table:
            protected_lines.extend(table_buffer)

        return '\n'.join(protected_lines)

    def is_low_prose_line(line: str, threshold: float = 0.1) -> bool:
        """
        Determines if a line contains mostly structure and little natural language prose.

        Args:
            line (str): A single markdown line.
            threshold (float): Minimum proportion of prose content.

        Returns:
            bool: True if it's mostly non-prose.
        """
        # Treat empty lines as low prose
        if not line.strip():
            return True

        # Create a "clean" version by removing all known non-prose elements
        no_placeholders = re.sub(r'<<(PH|BULLET|SEP)\d*>>|\[\[BULLET\d+\]\]', '', line)

        # Remove common markdown characters
        no_markdown = re.sub(r'[*_`[\]()#|-]', '', no_placeholders)

        # What's left is considered "prose"
        prose_content = no_markdown.strip()

        # If the ratio of prose to the total line length is below the threshold, protect it
        if len(line) > 0 and (len(prose_content) / len(line)) < threshold:
            return True

        return False

    md_content = protect_tables(md_content)

    tokens = md.parse(md_content)
    lines = md_content.splitlines()
    block_replacements = []

    i = 0
    while i < len(tokens):
        token = tokens[i]

        # First, check for the special case of an HTML block followed by a heading.
        if (token.type == "html_block"
            and i + 3 < len(tokens)
            and tokens[i + 1].type == "heading_open"):

            html_start, _ = token.map
            _, heading_end = tokens[i + 1].map

            raw_block = '\n'.join(lines[html_start:heading_end])
            placeholder = get_next_placeholder(raw_block)
            block_replacements.append((html_start, heading_end, placeholder))

            # Skip html_block, heading_open, inline, heading_close
            i += 4
            continue

        elif token.type == "fence" or token.type == "html_block":
            start, end = token.map
            raw_block = '\n'.join(lines[start:end])
            placeholder = get_next_placeholder(raw_block)
            block_replacements.append((start, end, placeholder))
            i += 1
            continue

        elif token.type == "heading_open":
            level = int(token.tag[1])
            inline_token = tokens[i + 1]
            header_text = inline_token.content.strip()
            placeholder = get_next_placeholder(header_text)
            start, end = token.map
            markdown_prefix = "#" * level
            block_replacements.append((start, end, f"{markdown_prefix} {placeholder}"))
            i += 3
            continue

        elif token.type == "blockquote_open":
            # Walk to the *matching* blockquote_close, counting nested quotes
            # so an inner quote's close does not end the outer one early.
            j = i + 1
            depth = 1
            has_prose = False
            while j < len(tokens):
                inner = tokens[j]
                if inner.type == "blockquote_open":
                    depth += 1
                elif inner.type == "blockquote_close":
                    depth -= 1
                    if depth == 0:
                        break
                elif (inner.type == "paragraph_open"
                      and j + 1 < len(tokens)
                      and tokens[j + 1].type == "inline"):
                    has_prose = True
                j += 1

            if j < len(tokens):
                # markdown-it sets `map` ([first_line, end_line_exclusive]) on
                # blockquote_open but not on blockquote_close, so the open
                # token is the only reliable source for the quote's full span.
                if has_prose and token.map:
                    start, end = token.map
                    raw_block = '\n'.join(lines[start:end])
                    # Keep the first line's quote marker (e.g. "> ") visible
                    # and protect the rest verbatim for an exact restore.
                    marker_match = re.match(r'^[^>\n]*>[ \t]?', raw_block)
                    marker = marker_match.group(0) if marker_match else ""
                    placeholder = get_next_placeholder(raw_block[len(marker):])
                    block_replacements.append((start, end, f"{marker}{placeholder}"))
                i = j + 1
            else:
                i += 1
            continue

        i += 1

    # Apply bottom-up so earlier line indices remain valid
    for start, end, replacement in sorted(block_replacements, reverse=True):
        lines[start:end] = [replacement]

    # Process inline links, code, and URLs
    def replace_md_links(match):
        """
        Replaces standard Markdown links with placeholders for the URL target.
        """
        text, url = match.group(1), match.group(2)
        if re.match(r'^<<PH\d+>>$', url):
            return match.group(0)
        return f"[{text}]({get_next_placeholder(url)})"

    def replace_internal_links(match):
        """
        Replaces internal anchor links with placeholders for the anchor target.
        """
        text, anchor = match.group(1), f"#{match.group(2)}"
        if re.match(r'^<<PH\d+>>$', anchor):
            return match.group(0)
        return f"[{text}]({get_next_placeholder(anchor)})"

    processed_lines = []
    for line in lines:
        line = re.sub(r'`([^`]+)`', lambda m: f"`{get_next_placeholder(m.group(1))}`", line)
        line = re.sub(r'https?://[^\s)\]}]+', lambda m: get_next_placeholder(m.group(0)), line)
        line = re.sub(r'\[([^\]]+)]\(([^)]+)\)', replace_md_links, line)
        line = re.sub(r'\[([^\]]+)]\(#([^)]+)\)', replace_internal_links, line)
        processed_lines.append(line)

    def is_title_line(content: str) -> bool:
        """Heuristic check to determine if a line is a title-style phrase."""
        clean = re.sub(r'<<[^>]+>>', '', content)
        clean = re.sub(r'[*_`[\]\(\)]', '', clean).strip()
        words = re.findall(r"[A-Za-z]+(?:-[A-Za-z]+)*", clean)
        if len(words) < 2: return False
        alpha = re.sub(r'[^A-Za-z]', '', clean)
        if alpha.isupper(): return True
        upper_words = [w for w in words if w[0].isupper()]
        if len(upper_words) / len(words) >= 0.75: return True
        return False

    bullet_placeholder_lines = []
    for line in processed_lines:
        if is_low_prose_line(line):
            placeholder = get_next_placeholder(line)
            bullet_placeholder_lines.append(placeholder)
            continue

        m = re.match(r'^(\s*)([-*+]|\d+\.)\s+(.*)$', line)
        if m:
            indent, bullet, content = m.groups()
            if is_title_line(content):
                ph = get_next_placeholder(content)
                bullet_placeholder_lines.append(f"{indent}{bullet} {ph}")
            else:
                bph = get_next_placeholder(bullet, prefix="BULLET")
                bullet_placeholder_lines.append(f"{indent}{bph} {content}")
        else:
            bullet_placeholder_lines.append(line)

    # Handle newline preservation
    raw_processed = "\n".join(bullet_placeholder_lines)
    raw_processed = re.sub(r'^\s*---\s*$', lambda m: get_next_placeholder(m.group(0)), raw_processed, flags=re.MULTILINE)

    final_lines = []
    for line in raw_processed.splitlines(keepends=True):
        if line.endswith('\n'):
            content = line.rstrip('\n')
            trailing_newline = True
        else:
            content = line
            trailing_newline = False

        newline_placeholder = get_next_placeholder("\n", prefix="PH")
        final_lines.append(content + (newline_placeholder if trailing_newline else ''))

    processed_content = ''.join(final_lines)

    # Merge back-to-back placeholders
    merged_placeholder_map = {}
    pattern = re.compile(r'(?:<<PH\d+>>){2,}')

    while True:
        match = pattern.search(processed_content)
        if not match: break

        ph_sequence = re.findall(r'<<PH\d+>>', match.group(0))
        keys = [ph.strip('<>') for ph in ph_sequence]
        merged_value = ''.join(placeholder_map.get(k, '') for k in keys)

        counter += 1
        new_key = f"PH{counter}"
        new_ph = f"<<{new_key}>>"
        placeholder_map[new_key] = merged_value

        processed_content = processed_content[:match.start()] + new_ph + processed_content[match.end():]
        merged_placeholder_map[new_key] = keys

    # Add <<SEP>> marker between prose and placeholders if needed
    processed_content = re.sub(r'(<<PH\d+>>)(?!<<)(?=\w)', r'\1<<SEP>>', processed_content)
    placeholder_map["SEP"] = ""

    return placeholder_map, processed_content


def restore_placeholders(corrected_text: str, placeholder_map: Dict[str, str]) -> str:
    """
    Substitute every placeholder token in ``corrected_text`` with its original text.

    Keys starting with "BULLET" are matched as ``[[KEY]]``; every other key
    is matched as ``<<KEY>>``. Keys are applied longest-first (ties keep the
    mapping's insertion order), and any ``<<SEP>>`` markers that remain are
    removed at the end.

    Args:
        corrected_text (str): Markdown content containing placeholders.
        placeholder_map (Dict[str, str]): Placeholder key -> original content.

    Returns:
        str: Markdown with all known placeholders expanded.
    """
    ordered_entries = sorted(placeholder_map.items(), key=lambda entry: len(entry[0]), reverse=True)

    result = corrected_text
    for key, original in ordered_entries:
        token = f"[[{key}]]" if key.startswith("BULLET") else f"<<{key}>>"
        result = result.replace(token, original)

    return result.replace('<<SEP>>', '')

############################################
###--------------PARSER END--------------###
############################################

############################################
###------------CHUNKER  START------------###
############################################

import re
from typing import List

def estimate_token_count(text: str) -> int:
    """
    Roughly estimate how many tokens ``text`` would occupy.

    Uses the rule of thumb of four characters per token, rounding down.

    Args:
        text (str): Input text.

    Returns:
        int: Estimated token count.
    """
    chars_per_token = 4
    whole_tokens, _remainder = divmod(len(text), chars_per_token)
    return whole_tokens

def split_by_top_level_headers(markdown: str) -> List[str]:
    """
    Split markdown into sections that each start at a header (#, ##, ..., ######).

    Any text that appears before the first header is returned as its own
    leading section instead of being dropped.

    Args:
        markdown (str): The input markdown text.

    Returns:
        List[str]: Sections in document order; concatenating them yields the
        original text. Returns ``[markdown]`` when no header is present.
    """
    # Find all headers from # to ###### at the beginning of a line
    matches = list(re.finditer(r'^\s*#{1,6}\s.*', markdown, flags=re.MULTILINE))
    if not matches:
        # No headers found, return whole content
        return [markdown]

    sections = []

    # Keep any preamble before the first header; previously it was lost.
    if matches[0].start() > 0:
        sections.append(markdown[:matches[0].start()])

    for i, match in enumerate(matches):
        start = match.start()
        end = matches[i + 1].start() if i + 1 < len(matches) else len(markdown)
        sections.append(markdown[start:end])

    return sections

def smart_sentence_split(text: str) -> List[str]:
    """
    Break ``text`` into stripped pieces at sentence or placeholder boundaries.

    A sentence boundary is a '.', '!' or '?' followed by whitespace and then
    an uppercase letter; the punctuation stays with the preceding piece. A
    ``__PLACEHOLDERn__`` token also closes the current piece.
    NOTE(review): the parser in this notebook emits ``<<PHn>>`` tokens, which
    this pattern does not match — only the sentence rule applies to its output.

    Args:
        text (str): The input text.

    Returns:
        List[str]: Non-empty, whitespace-stripped pieces in original order.
    """
    boundary = re.compile(r'([.!?]\s+(?=[A-Z]))|(__PLACEHOLDER\d+__)')

    pieces: List[str] = []
    pending = ''
    for fragment in filter(lambda frag: frag is not None, boundary.split(text)):
        pending = pending + fragment
        ends_sentence = re.match(r'[.!?]\s+$', fragment) is not None
        is_placeholder = re.match(r'__PLACEHOLDER\d+__', fragment.strip()) is not None
        if ends_sentence or is_placeholder:
            pieces.append(pending.strip())
            pending = ''

    leftover = pending.strip()
    if leftover:
        pieces.append(leftover)

    return pieces

def chunk_large_section(section: str, max_tokens: int = 1000) -> List[str]:
    """
    Greedily pack sentence-level pieces of ``section`` into token-limited chunks.

    Pieces come from ``smart_sentence_split`` and are joined with single
    spaces. A new chunk starts whenever adding the next piece would exceed
    ``max_tokens``; a single piece larger than the limit still forms its own
    chunk rather than being cut.

    Args:
        section (str): A section of text to be chunked.
        max_tokens (int): Maximum estimated token count per chunk.

    Returns:
        List[str]: The packed chunks, in order.
    """
    packed: List[str] = []
    pending: List[str] = []
    pending_tokens = 0

    for raw_piece in smart_sentence_split(section):
        piece = raw_piece.strip()
        if not piece:
            continue

        piece_tokens = estimate_token_count(piece)
        would_overflow = pending_tokens + piece_tokens > max_tokens
        if would_overflow and pending:
            packed.append(' '.join(pending).strip())
            pending, pending_tokens = [], 0

        pending.append(piece)
        pending_tokens += piece_tokens

    if pending:
        packed.append(' '.join(pending).strip())

    return packed

def chunk_markdown(markdown: str, max_tokens: int = 100, max_chars: int = 3500) -> List[str]:
    '''
    Chunk a full markdown document into smaller parts based on headers and size limits.

    The document is split at headers; a section whose estimated token count
    exceeds ``max_tokens`` is packed further at sentence boundaries. Finally
    every chunk longer than ``max_chars`` characters is split again,
    preferring newline boundaries and then sentence boundaries.

    Args:
        markdown (str): The complete markdown content.
        max_tokens (int): Maximum allowed (estimated) tokens per chunk.
        max_chars (int): Character limit for the final pass. A single
            sentence longer than this is kept whole rather than cut mid-sentence.

    Returns:
        List[str]: A list of markdown chunks that are token-limited and structured.
    '''
    # Split by top level headers
    sections = split_by_top_level_headers(markdown)
    final_chunks = []

    for section in sections:
        # If the section is small enough, keep it as is
        if estimate_token_count(section) <= max_tokens:
            final_chunks.append(section.strip())
        else:
            # Otherwise, chunk it further based on sentence boundaries
            final_chunks.extend(chunk_large_section(section, max_tokens=max_tokens))

    def split_long_chunk(chunk: str, limit: int = max_chars) -> List[str]:
        """
        Split a chunk longer than ``limit`` characters, preferring newline then sentence boundaries.

        Never emits empty subchunks.

        Args:
            chunk (str): A markdown chunk that may be too long.
            limit (int): Maximum number of characters per subchunk.

        Returns:
            List[str]: Subchunks of the input chunk.
        """
        if len(chunk) <= limit:
            return [chunk]

        # Try splitting on newlines first (each part keeps its trailing '\n')
        parts = re.split(r'(?<=\n)', chunk)
        subchunks = []
        buffer = ""

        for part in parts:
            if len(buffer) + len(part) > limit:
                # Guard against flushing a whitespace-only buffer as an empty chunk
                if buffer.strip():
                    subchunks.append(buffer.strip())
                buffer = part
            else:
                buffer += part

        if buffer.strip():
            subchunks.append(buffer.strip())

        # If subchunks are still too long, split on sentence boundaries
        final_subchunks = []
        for sub in subchunks:
            if len(sub) <= limit:
                final_subchunks.append(sub)
                continue

            sentences = re.split(r'(?<=[.!?])\s+', sub)
            sentence_buffer = ""
            for s in sentences:
                # Include the joining space so packed chunks really fit the limit
                joined_len = len(sentence_buffer) + len(s) + (1 if sentence_buffer else 0)
                if sentence_buffer and joined_len > limit:
                    final_subchunks.append(sentence_buffer.strip())
                    sentence_buffer = s
                else:
                    sentence_buffer += (" " if sentence_buffer else "") + s
            if sentence_buffer.strip():
                final_subchunks.append(sentence_buffer.strip())

        return final_subchunks

    # Apply post-splitting to each chunk to enforce character limits
    adjusted_chunks = []
    for chunk in final_chunks:
        adjusted_chunks.extend(split_long_chunk(chunk))

    return adjusted_chunks

############################################
###-------------CHUNKER  END-------------###
############################################

############################################
###-------------PROMPT START-------------###
############################################

from langchain.prompts import PromptTemplate

# Template for llama3-instruct format
MARKDOWN_CORRECTION_TEMPLATE_LLAMA3 = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a markdown grammar correction assistant. Your job is to correct only grammatical errors in the user's Markdown content.

Strictly follow these rules:
- Do **not** modify any placeholders (e.g., <<PH1>>, <<PH93>>, [[BULLET3]], <<SEP>>). Leave them **exactly as they appear**, including spacing.
- Do **not** remove, reword, rename, reformat, or relocate any placeholder.
- Do **not** alter Markdown formatting (e.g., headings, links, lists, or indentation).
- Do **not** remove Markdown styling characters (e.g., **, *, _, __, `, [, ]).
- Do **not** add or remove extra content from the original text.
- Only correct grammar **within natural language sentences**, leaving structure unchanged.
- **Always** maintain title case wherever it is present in the original text.

If a sentence spans multiple lines or has placeholders in it, correct the grammar but preserve formatting and placeholders **as-is**.

Example:
- Original: "<<SEP>>We use <<PH4>> to **builds** model **likke** this:<<PH17>><<PH18>>"
- Corrected: "<<SEP>>We use <<PH4>> to **build** models **like** this:<<PH17>><<PH18>>"

Example:
- Original: "[[BULLET1]] **It Will Be More Profitablr<<PH12>>**"
- Corrected: "[[BULLET1]] **It Will Be More Profitable<<PH12>>**"

Example:
- Original: "This methd is **not necessary** the way ti build *AI* agents <<PH32>>"
- Corrected: "This method is **not necessary** the way to build *AI* agents <<PH32>>"

All placeholders are present and stay exactly the same with no additional spaces — only grammar is corrected.

Respond only with the corrected Markdown content. Do not explain anything.<|eot_id|><|start_header_id|>user<|end_header_id|>
Original markdown:
{markdown}

Corrected markdown:<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""


def get_markdown_correction_prompt() -> PromptTemplate:
    """
    Build the LangChain prompt that asks LLaMA 3 for grammar-only Markdown fixes.

    The prompt wraps ``MARKDOWN_CORRECTION_TEMPLATE_LLAMA3``, whose single
    input variable ``markdown`` receives the chunk of Markdown to correct.

    Returns:
        PromptTemplate: A newly constructed template in LLaMA 3 instruct format.
    """
    template_text = MARKDOWN_CORRECTION_TEMPLATE_LLAMA3
    prompt = PromptTemplate.from_template(template_text)
    return prompt

############################################
###--------------PROMPT END--------------###
############################################

############################################
###-------------UTILS  START-------------###
############################################

import os
import yaml
import importlib.util
import multiprocessing
from typing import Dict, Any, Optional, Union, List, Tuple

# Default model identifiers used by `initialize_llm` for each `model_source`
# when the caller does not supply one:
# - "local": GGUF file path, the default `local_model_path` (loaded with LlamaCpp).
# - "hugging-face-local" / "hugging-face-cloud": Hugging Face Hub repo ids used
#   when `hf_repo_id` is empty.
# - "tensorrt": empty; `initialize_llm` does not read this entry.
# NOTE(review): the "local" default points at a LLaMA 2 model, while the
# notebook's LOCAL_MODEL_PATH is a LLaMA 3.1 model — confirm which is intended.
DEFAULT_MODELS = {
    "local": "/home/jovyan/datafabric/llama2-7b/ggml-model-f16-Q5_K_M.gguf",
    "tensorrt": "",
    "hugging-face-local": "meta-llama/Llama-3.2-3B-Instruct",
    "hugging-face-cloud": "mistralai/Mistral-7B-Instruct-v0.3"
}

# Context-window sizes (in tokens) for known models. `initialize_llm` looks
# entries up by the GGUF file basename for "local" models and by the Hugging
# Face repo id for hub models. It stores the matched value on the returned
# model as `_context_window`. This table does not change the `n_ctx` actually
# passed to LlamaCpp.
# NOTE(review): several values look outdated relative to the providers'
# published limits (e.g. the Claude 3 and Mistral-7B-Instruct-v0.3 entries) —
# verify before relying on them.
MODEL_CONTEXT_WINDOWS = {
    # LlamaCpp models
    'ggml-model-f16-Q5_K_M.gguf': 4096,
    'ggml-model-7b-q4_0.bin': 4096,
    'gguf-model-7b-4bit.bin': 4096,

    # HuggingFace models
    'mistralai/Mistral-7B-Instruct-v0.3': 8192,
    'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B': 4096,
    'meta-llama/Llama-2-7b-chat-hf': 4096,
    'meta-llama/Llama-3-8b-chat-hf': 8192,
    'google/flan-t5-base': 512,
    'google/flan-t5-large': 512,
    'TheBloke/WizardCoder-Python-7B-V1.0-GGUF': 4096,

    # OpenAI models
    'gpt-3.5-turbo': 16385,
    'gpt-4': 8192,
    'gpt-4-32k': 32768,
    'gpt-4-turbo': 128000,
    'gpt-4o': 128000,

    # Anthropic models
    'claude-3-opus-20240229': 200000,
    'claude-3-sonnet-20240229': 180000,
    'claude-3-haiku-20240307': 48000,

    # Other models
    'qwen/Qwen-7B': 8192,
    'microsoft/phi-2': 2048,
    'tiiuae/falcon-7b': 4096,
    "meta-llama/Llama-3.2-3B-Instruct": 128000,
}

def load_config_and_secrets(
    config_path: Union[str, os.PathLike] = "../../configs/config.yaml",
    secrets_path: Union[str, os.PathLike] = "../../configs/secrets.yaml"
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Load configuration and secrets from YAML files.

    Both files are parsed with ``yaml.safe_load``. An empty YAML document
    (for which ``safe_load`` returns ``None``) becomes an empty dictionary, so
    callers can always use ``.get`` on the results.

    Args:
        config_path: Path to the configuration YAML file (str or path-like).
        secrets_path: Path to the secrets YAML file (str or path-like).

    Returns:
        Tuple containing (config, secrets) as dictionaries.

    Raises:
        FileNotFoundError: If either the config or secrets file is not found.
            The secrets file is checked first.
    """
    # Resolve relative paths against the current working directory so the
    # error message shows exactly which location was searched.
    config_path = os.path.abspath(config_path)
    secrets_path = os.path.abspath(secrets_path)

    if not os.path.exists(secrets_path):
        raise FileNotFoundError(f"secrets.yaml file not found in path: {secrets_path}")

    if not os.path.exists(config_path):
        raise FileNotFoundError(f"config.yaml file not found in path: {config_path}")

    # YAML is UTF-8 by specification; don't depend on the platform locale.
    with open(config_path, encoding="utf-8") as file:
        config = yaml.safe_load(file)

    with open(secrets_path, encoding="utf-8") as file:
        secrets = yaml.safe_load(file)

    # safe_load yields None for an empty document; normalise to {}.
    if config is None:
        config = {}
    if secrets is None:
        secrets = {}

    return config, secrets

def initialize_llm(
    model_source: str = "local",
    secrets: Optional[Dict[str, Any]] = None,
    local_model_path: str = DEFAULT_MODELS["local"],
    hf_repo_id: str = ""
) -> Any:
    """
    Initialize a language model based on the specified source.

    Args:
        model_source: Source of the model. One of "local" (LlamaCpp GGUF file),
            "hugging-face-local", "hugging-face-cloud", or "tensorrt".
        secrets: Dictionary of credentials. "HUGGINGFACE_API_KEY" is required
            for "hugging-face-cloud" and, when present, exported as
            ``HF_TOKEN`` for "hugging-face-local". May be ``None`` otherwise.
        local_model_path: Path to the local model: a GGUF file for "local",
            or a directory containing ``config.json`` for "tensorrt".
        hf_repo_id: Hugging Face repo id. It overrides the default model for
            the Hugging Face sources. For "tensorrt", a non-empty value is
            used as the model path. An empty string (or ``None``) means
            "use the default".

    Returns:
        Initialized language model object. For every source except
        "tensorrt", the looked-up context window (``None`` if unknown) is
        stored on the model as ``_context_window`` when the object has a
        ``__dict__``.

    Raises:
        ImportError: If required LangChain libraries (or tensorrt-llm for
            "tensorrt") are not installed.
        ValueError: If an unsupported ``model_source`` is provided, the
            Hugging Face API key is missing for "hugging-face-cloud", or the
            local model is not in a TensorRT-LLM compatible format.
    """
    # Check dependencies up front so the error lists every missing package.
    missing_deps = [
        module
        for module in ("langchain_huggingface", "langchain_core.callbacks", "langchain_community.llms")
        if not importlib.util.find_spec(module)
    ]
    if missing_deps:
        raise ImportError(f"Missing required dependencies: {', '.join(missing_deps)}")

    # Import required libraries
    from langchain_huggingface import HuggingFacePipeline, HuggingFaceEndpoint
    from langchain_community.llms import LlamaCpp

    model = None
    context_window = None

    if model_source == "hugging-face-cloud":
        repo_id = hf_repo_id or DEFAULT_MODELS["hugging-face-cloud"]
        if not secrets or "HUGGINGFACE_API_KEY" not in secrets:
            raise ValueError("HuggingFace API key is required for cloud model access")

        # Context window from the static lookup table (None if unknown).
        context_window = MODEL_CONTEXT_WINDOWS.get(repo_id)
        model = HuggingFaceEndpoint(
            huggingfacehub_api_token=secrets["HUGGINGFACE_API_KEY"],
            repo_id=repo_id,
        )

    elif model_source == "hugging-face-local":
        from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

        # `secrets` defaults to None, so guard before the membership test.
        if secrets and "HUGGINGFACE_API_KEY" in secrets:
            os.environ["HF_TOKEN"] = secrets["HUGGINGFACE_API_KEY"]
        model_id = hf_repo_id or DEFAULT_MODELS["hugging-face-local"]
        context_window = MODEL_CONTEXT_WINDOWS.get(model_id)

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        hf_model = AutoModelForCausalLM.from_pretrained(model_id)

        # Prefer the tokenizer's own limit whenever it reports one.
        max_len = getattr(tokenizer, "model_max_length", None)
        if max_len not in (None, -1):
            context_window = max_len

        pipe = pipeline("text-generation", model=hf_model, tokenizer=tokenizer, max_new_tokens=100, device=0)
        model = HuggingFacePipeline(pipeline=pipe)

    elif model_source == "tensorrt":
        # A non-empty hf_repo_id takes precedence. Otherwise local_model_path
        # must be a directory holding a TensorRT-LLM checkpoint (config.json).
        # Only the import itself is guarded, so unrelated ImportErrors raised
        # while building the model are not misreported.
        try:
            import tensorrt_llm
        except ImportError as err:
            raise ImportError(
                "Could not import tensorrt-llm library. "
                "Please make sure tensorrt-llm is installed properly, or "
                "consider using workspaces based on the NeMo Framework"
            ) from err

        sampling_params = tensorrt_llm.SamplingParams(temperature=0.1, top_p=0.95, max_tokens=512)
        # NOTE(review): TensorRTLangchain is not defined in this section;
        # presumably it is defined/imported elsewhere — confirm.
        if hf_repo_id:
            return TensorRTLangchain(model_path=hf_repo_id, sampling_params=sampling_params)
        model_config = os.path.join(local_model_path, "config.json")
        if os.path.isdir(local_model_path) and os.path.isfile(model_config):
            return TensorRTLangchain(model_path=local_model_path, sampling_params=sampling_params)
        raise ValueError("Model format incompatible with TensorRT LLM")

    elif model_source == "local":
        # LlamaCpp models are looked up by file name; unknown files get 4096.
        model_filename = os.path.basename(local_model_path)
        context_window = MODEL_CONTEXT_WINDOWS.get(model_filename, 4096)

        # NOTE(review): n_ctx below is fixed at 32000 regardless of the
        # context_window recorded above — confirm which value is intended.
        model = LlamaCpp(
            model_path=local_model_path,
            n_gpu_layers=-1,
            n_batch=512,
            n_ctx=32000,
            max_tokens=1024,
            f16_kv=True,
            use_mmap=False,
            low_vram=False,
            rope_scaling=None,
            temperature=0.0,
            repeat_penalty=1.0,
            streaming=False,
            stop=None,
            seed=42,
            # LlamaCpp's thread-count field is `n_threads`; unknown keywords
            # such as `num_threads` are forwarded to llama_cpp and ignored.
            n_threads=multiprocessing.cpu_count(),
            verbose=False,
        )
    else:
        raise ValueError(f"Unsupported model source: {model_source}")

    # Store context window as model attribute for easy access
    if model and hasattr(model, '__dict__'):
        model.__dict__['_context_window'] = context_window

    return model

############################################
###--------------UTILS  END--------------###
############################################

############################################
###-----------LLM METRIC START-----------###
############################################

import os
import re
import numpy as np
from typing import List, Optional
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from mlflow.metrics import make_metric
from llama_cpp import Llama
import multiprocessing

# Shared TF-IDF vectorizer used by `semantic_similarity_eval_fn`. It is re-fit
# on every call, so its vocabulary reflects only the most recent batch.
# Note: TF-IDF measures lexical overlap, so the "semantic" similarity built on
# it is an approximation.
tfidf_vectorizer = TfidfVectorizer(stop_words='english', max_features=5000)

# Path to the local GGUF model used as the LLM grammar judge
# (loaded through `LocalJudgeLlamaClient`).
LOCAL_LLAMA_JUDGE_PATH = "/home/jovyan/datafabric/llama2-7b/ggml-model-f16-Q5_K_M.gguf" 

class LocalJudgeLlamaClient:
    """Process-wide cache of the ``llama_cpp.Llama`` model used as the grammar judge.

    The model is created on the first :meth:`get_client` call that supplies a
    ``model_path``. Every later call returns that same instance, and any
    ``model_path`` passed then is ignored.
    """
    # Cached Llama instance; None until the first successful get_client call.
    _client = None

    @classmethod
    def get_client(cls, model_path: Optional[str] = None) -> Llama:
        """
        Get or initialize the cached LLaMA judge client.

        Args:
            model_path (Optional[str]): Path to the local GGUF LLaMA model.
                Required on the first call; ignored once the client exists.

        Returns:
            Llama: Loaded LLaMA instance.

        Raises:
            ValueError: If the client has not been created yet and
                ``model_path`` is None.
        """
        if cls._client is None:
            if model_path is None:
                raise ValueError("Must provide model_path to initialize local LLaMA judge.")

            # Only model-loading options belong in the constructor. In
            # llama-cpp-python, sampling options (temperature, max_tokens,
            # stop, repeat_penalty) are per-call arguments, and unknown
            # constructor keywords are silently ignored. Callers must
            # therefore pass sampling options when invoking the client.
            # The thread-count keyword is `n_threads`.
            cls._client = Llama(
                model_path=model_path,
                n_gpu_layers=-1,  # -1: offload every layer to the GPU when one is available
                n_batch=512,
                n_ctx=32000,
                use_mmap=False,
                seed=42,
                n_threads=multiprocessing.cpu_count(),
                verbose=False,
            )
        return cls._client

# Preload the judge model when this code runs, so the first metric call does
# not pay the load cost.
# NOTE(review): this loads a full GGUF model as an import-time side effect and
# presumably fails if the file is missing — consider loading lazily instead.
LocalJudgeLlamaClient.get_client(model_path=LOCAL_LLAMA_JUDGE_PATH)

def simple_grammar_check(text: str) -> int:
    """
    Count crude grammar problems in ``text`` using only string heuristics.

    The text is split into sentences on runs of ``.``, ``!`` and ``?``. Each
    non-empty sentence adds one issue for each of the following:
      * it does not start with an uppercase character;
      * it contains two consecutive spaces;
      * every "i are" / "i were" word pair (case-insensitive);
      * every word that repeats the word immediately before it.

    Args:
        text (str): Input text; non-string values are converted with ``str``.

    Returns:
        int: Total number of potential issues found.
    """
    total_issues = 0
    for raw_sentence in re.split(r'[.!?]+', str(text).strip()):
        candidate = raw_sentence.strip()
        if not candidate:
            continue

        total_issues += 0 if candidate[0].isupper() else 1
        total_issues += 1 if '  ' in candidate else 0

        # Walk adjacent word pairs for agreement and repetition checks.
        tokens = candidate.lower().split()
        for current, following in zip(tokens, tokens[1:]):
            if current == 'i' and following in ('are', 'were'):
                total_issues += 1
            if following == current:
                total_issues += 1

    return total_issues

def semantic_similarity_eval_fn(predictions: List[str], targets: List[str]) -> float:
    """
    Average TF-IDF cosine similarity between each prediction and its target.

    Targets and predictions are vectorised together by re-fitting the
    module-level ``tfidf_vectorizer``. Pair ``k`` compares ``targets[k]`` with
    ``predictions[k]``. Falsy entries (``None``, ``""``) are treated as empty
    strings.

    Args:
        predictions (List[str]): Model-generated texts.
        targets (List[str]): Ground-truth references aligned with predictions.

    Returns:
        float: Mean cosine similarity. Returns 1.0 when there are fewer than
        two distinct texts overall (e.g. everything identical or empty).
    """
    corpus = [str(doc) if doc else "" for doc in [*targets, *predictions]]

    # Nothing to tell apart: treat it as a perfect match.
    if len(set(corpus)) < 2:
        return 1.0

    matrix = tfidf_vectorizer.fit_transform(corpus)
    n_targets = len(targets)
    target_rows, prediction_rows = matrix[:n_targets], matrix[n_targets:]

    pair_scores = [
        cosine_similarity(target_rows[k:k + 1], prediction_rows[k:k + 1])[0][0]
        for k in range(n_targets)
    ]
    return np.mean(pair_scores)

def grammar_error_count_eval_fn(predictions: List[str], targets: List[str]) -> float:
    """
    Average number of heuristic grammar issues per prediction.

    Args:
        predictions (List[str]): Model outputs.
        targets (List[str]): Reference texts (unused here).

    Returns:
        float: Mean of ``simple_grammar_check`` over the predictions. This is
        NaN (with a NumPy warning) when ``predictions`` is empty.
    """
    error_counts = [simple_grammar_check(str(pred)) for pred in predictions]
    return np.mean(error_counts)

def grammar_error_rate_eval_fn(predictions: List[str], targets: List[str]) -> float:
    """
    Mean density of heuristic grammar issues, measured per whitespace-separated word.

    Args:
        predictions (List[str]): Model outputs.
        targets (List[str]): Reference texts (unused here).

    Returns:
        float: Average of ``issues / max(word_count, 1)`` across predictions.
    """
    def _rate(sample: str) -> float:
        as_text = str(sample)
        issue_total = simple_grammar_check(as_text)
        word_total = max(len(as_text.split()), 1)  # empty text must not divide by zero
        return issue_total / word_total

    return np.mean([_rate(pred) for pred in predictions])

def grammar_improvement_eval_fn(predictions: List[str], targets: List[str]) -> float:
    """
    Average reduction in heuristic grammar issues from each target to its prediction.

    Args:
        predictions (List[str]): Corrected text.
        targets (List[str]): Original input text.

    Returns:
        float: Mean of ``issues(target) - issues(prediction)``. Positive values
        mean the predictions contain fewer issues.
    """
    deltas = [
        simple_grammar_check(str(original)) - simple_grammar_check(str(corrected))
        for corrected, original in zip(predictions, targets)
    ]
    return np.mean(deltas)

def grammar_score_eval_fn(predictions: List[str], targets: List[str]) -> float:
    """
    Score each prediction from 0 to 100 by deducting 10 points per heuristic issue.

    Args:
        predictions (List[str]): Model outputs.
        targets (List[str]): Reference texts (unused here).

    Returns:
        float: Mean score, where 100 means no issues were found. Predictions
        with no words score 0.
    """
    def _score(sample: str) -> int:
        as_text = str(sample)
        issue_total = simple_grammar_check(as_text)
        if not as_text.split():
            return 0  # empty output earns no credit
        # 10 points off per issue, floored at zero.
        return max(100 - issue_total * 10, 0)

    return np.mean([_score(pred) for pred in predictions])

def readability_improvement_eval_fn(predictions: List[str], targets: List[str]) -> float:
    """
    Estimate improvement in readability, using average sentence length as a proxy.

    Args:
        predictions (List[str]): Corrected output.
        targets (List[str]): Original input.

    Returns:
        float: Mean of ``readability(prediction) - readability(target)`` over
        aligned pairs. Positive values mean the predictions use shorter
        sentences.
    """
    def calculate_readability_score(text: str) -> float:
        """
        Estimate a basic readability score for a given text.

        Sentences are the non-empty segments between runs of ``.``, ``!`` or
        ``?``, the same splitting rule used by ``simple_grammar_check``. Text
        with words but no terminator counts as one sentence. The score is
        ``max(20 - average_words_per_sentence, 0)``, so shorter sentences
        score higher.

        Args:
            text (str): The input text to evaluate.

        Returns:
            float: Readability score where higher is better; 0 for text without words.
        """
        words = text.split()
        if not words:
            return 0

        # Drop empty segments (e.g. after a trailing period) so they don't
        # inflate the sentence count.
        sentences = [segment for segment in re.split(r'[.!?]+', text) if segment.strip()]
        sentence_count = max(len(sentences), 1)

        avg_sentence_length = len(words) / sentence_count
        return max(20 - avg_sentence_length, 0)

    improvements = [
        calculate_readability_score(str(pred)) - calculate_readability_score(str(target))
        for pred, target in zip(predictions, targets)
    ]
    return np.mean(improvements)

def llm_judge_eval_fn_local(predictions: List[str]) -> float:
    """
    Use the local LLaMA judge to rate the grammar of each prediction from 1 to 10.

    Each prediction is scored with a few-shot prompt and greedy decoding
    (``temperature=0.0``), so repeated evaluations give the same ratings.
    The first integer in the judge's reply is taken as the score. If the call
    fails, or the reply contains no number, a neutral 5.0 is used instead.
    Raw replies are appended to ``llm_eval_logs/local_llama_responses.txt``
    on a best-effort basis.

    Args:
        predictions (List[str]): Model outputs.

    Returns:
        float: Average rating, or NaN when ``predictions`` is empty.
    """
    # Materialise first: a pandas Series has no unambiguous truth value.
    predictions = list(predictions)
    if not predictions:
        return float("nan")

    llama = LocalJudgeLlamaClient.get_client()
    scores = []

    for pred in predictions:
        prompt = f"""Rate the following text solely on grammar. Respond with a single digit from 1 to 10. DO NOT include any explanation, label, or punctuation. Reply with just the number.

Text: I has a apple.
Answer: 3

Text: The dog chased the ball across the yard.
Answer: 9

Text: Him don't know where she is.
Answer: 2

Text: {pred}
Answer:"""

        try:
            # Sampling options are per-call in llama-cpp-python. Without an
            # explicit temperature the judge would sample (library default 0.8).
            result = llama(prompt, stop=["\n"], temperature=0.0)
            text = result["choices"][0]["text"].strip()
        except Exception as e:
            print(f"[LLaMA judge error]: {e}")
            scores.append(5.0)
            continue

        try:
            os.makedirs("llm_eval_logs", exist_ok=True)
            with open("llm_eval_logs/local_llama_responses.txt", 'a', encoding='utf-8') as f:
                f.write(f"Response: {text}\n")
        except OSError:
            pass  # response logging is best-effort and must not fail the metric

        match = re.search(r"\d+", text)
        if match is None:
            print(f"[LLaMA judge error]: no numeric score in response {text!r}")
            scores.append(5.0)
            continue
        scores.append(float(match.group()))

    return sum(scores) / len(scores)

# ---- Create all the metric wrappers for MLflow ----
# Each `make_metric` call wraps one eval function above as an MLflow metric
# named `name`. `greater_is_better` tells MLflow which direction counts as an
# improvement.
# NOTE(review): newer MLflow versions expect `eval_fn` to return an
# `mlflow.metrics.MetricValue`, while these functions return plain floats —
# verify against the installed MLflow version before enabling evaluation.

semantic_similarity_metric = make_metric(
    eval_fn=semantic_similarity_eval_fn,
    greater_is_better=True,
    name="semantic_similarity"
)

# Lower is better: fewer heuristic grammar issues per prediction.
grammar_error_count_metric = make_metric(
    eval_fn=grammar_error_count_eval_fn,
    greater_is_better=False,
    name="grammar_error_count"
)

# Lower is better: fewer issues per word.
grammar_error_rate_metric = make_metric(
    eval_fn=grammar_error_rate_eval_fn,
    greater_is_better=False,
    name="grammar_error_rate"
)

grammar_improvement_metric = make_metric(
    eval_fn=grammar_improvement_eval_fn,
    greater_is_better=True,
    name="grammar_improvement"
)

grammar_score_metric = make_metric(
    eval_fn=grammar_score_eval_fn,
    greater_is_better=True,
    name="grammar_score"
)

readability_improvement_metric = make_metric(
    eval_fn=readability_improvement_eval_fn,
    greater_is_better=True,
    name="readability_improvement"
)

# LLM-as-judge rating (1-10) from the local LLaMA judge model.
llm_judge_metric_local = make_metric(
    eval_fn=llm_judge_eval_fn_local,
    greater_is_better=True,
    name="llm_judge_local_score"
)

############################################
###------------LLM METRIC END------------###
############################################

## Define MLflow Model Class

In [None]:
class MarkdownCorrectorModel(mlflow.pyfunc.PythonModel):
    """An MLflow model that encapsulates the entire markdown correction pipeline.

    For each input GitHub repository URL, the model:
      1. fetches the repository's Markdown files;
      2. replaces parts of each file with placeholders (via
         ``parse_md_for_grammar_correction``) and splits it into chunks;
      3. sends every chunk through the LLaMA 3 correction prompt;
      4. rejoins the chunks and restores the placeholders.

    The output is one JSON string per repository, mapping file name to
    corrected Markdown.
    """

    def load_context(self, context: mlflow.pyfunc.PythonModelContext) -> None:
        """Load config and secrets from the model artifacts and build the LLM chain.

        Expects the artifacts ``config`` and ``secrets`` (YAML files) and, for
        the "local" model source, ``local_model`` (the GGUF model path). The
        model source is read from ``config["model_source"]`` (default "local").

        Raises:
            ValueError: If the "local" source is configured but no
                ``local_model`` artifact was packaged with the model.
        """
        config_path = context.artifacts["config"]
        secrets_path = context.artifacts["secrets"]
        self.config, self.secrets = load_config_and_secrets(config_path, secrets_path)

        model_source = self.config.get("model_source", "local")
        local_model_path = context.artifacts.get("local_model")
        if model_source == "local" and not local_model_path:
            # Fail with a clear message instead of a TypeError inside initialize_llm.
            raise ValueError("model_source 'local' requires a 'local_model' artifact.")
        llm = initialize_llm(model_source, self.secrets, local_model_path)

        # Prompt -> LLM as a LangChain runnable, invoked once per chunk in predict().
        self.llm_chain = get_markdown_correction_prompt() | llm
        logger.info("✅ Model context loaded and LLM chain initialized.")

    @staticmethod
    def _safe_join_chunks(chunks: List[str]) -> str:
        """Concatenate corrected chunks back into one document.

        Chunks are appended unchanged, with one exception. When the previous
        chunk (ignoring trailing whitespace) ends with ``.`` and the next chunk
        (ignoring leading whitespace) starts with an uppercase ASCII letter or
        ``"``, the next chunk's leading whitespace is replaced by a single
        space.

        NOTE(review): that replacement also drops newlines at the start of the
        next chunk, so a paragraph break at a chunk boundary can collapse into
        a space — confirm this is intended for the chunker's output.
        """
        pieces: List[str] = []
        for i, chunk in enumerate(chunks):
            stripped = chunk.lstrip()
            if i > 0 and chunks[i - 1].rstrip().endswith('.') and re.match(r'^[A-Z\"]', stripped):
                pieces.append(' ' + stripped)
            else:
                pieces.append(chunk)
        return "".join(pieces)

    def predict(self, context: mlflow.pyfunc.PythonModelContext, model_input: pd.DataFrame) -> pd.Series:
        """Run the full correction workflow for every repository URL in ``model_input``.

        Args:
            context: MLflow model context (state comes from ``load_context``).
            model_input: DataFrame with a ``repo_url`` column.

        Returns:
            pd.Series: One JSON string per input row. Each string maps file
            name to corrected Markdown. Files that produced no chunks are
            omitted.
        """
        # The token is the same for every repository; look it up once.
        access_token = self.secrets.get("GITHUB_ACCESS_TOKEN")

        results = []
        for repo_url in model_input["repo_url"]:
            logger.info(f"Processing repository: {repo_url}")

            # 1. Extract Markdown
            processor = GitHubMarkdownProcessor(repo_url=repo_url, access_token=access_token)
            markdowns = processor.run()

            # 2. Parse (swap protected spans for placeholders) & chunk
            placeholder_maps, all_chunks = {}, {}
            for filename, content in markdowns.items():
                placeholder_map, processed_content = parse_md_for_grammar_correction(content)
                placeholder_maps[filename] = placeholder_map
                all_chunks[filename] = chunk_markdown(processed_content)

            # 3. Invoke the LLM chain once per chunk
            corrected_chunks_by_file = defaultdict(list)
            for file_name, chunks in all_chunks.items():
                for chunk in chunks:
                    response = self.llm_chain.invoke({"markdown": chunk})
                    corrected_chunks_by_file[file_name].append(response)

            # 4. Rebuild each file and restore its placeholders
            final_corrected_files = {}
            for file_name, corrected_chunks in corrected_chunks_by_file.items():
                rebuilt_content = self._safe_join_chunks(corrected_chunks)
                placeholder_map = placeholder_maps.get(file_name, {})
                final_corrected_files[file_name] = restore_placeholders(rebuilt_content, placeholder_map)

            results.append(json.dumps(final_corrected_files, indent=2))
            logger.info(f"✅ Finished processing for {repo_url}.")

        return pd.Series(results)

## Verify Assets and Load Data

In [None]:
def log_asset_status(asset_path: Path, asset_name: str) -> None:
    """Report whether ``asset_path`` exists: INFO when present, WARNING otherwise.

    Args:
        asset_path: File or directory to check.
        asset_name: Human-readable label used in the log message.
    """
    if not asset_path.exists():
        logger.warning(f"⚠️ {asset_name} not found at: {asset_path}.")
        return
    logger.info(f"✅ {asset_name} is properly configured at: {asset_path}")

# Report which required assets are present before anything is loaded.
log_asset_status(CONFIG_PATH, "Config file")
log_asset_status(SECRETS_PATH, "Secrets file")
log_asset_status(LOCAL_MODEL_PATH, "Local LLaMA model")
log_asset_status(EVAL_DATA_PATH, "Evaluation data JSON")
log_asset_status(REQUIREMENTS_PATH, "Requirements file")

# Load the evaluation records. Each record must have an "original" key, whose
# value becomes the "markdown" column. JSON is read as UTF-8 explicitly so
# the result does not depend on the platform locale.
with open(EVAL_DATA_PATH, "r", encoding="utf-8") as f:
    results = json.load(f)
original_texts = [item["original"] for item in results]
eval_df = pd.DataFrame(original_texts, columns=["markdown"])
logger.info(f"Loaded {len(eval_df)} records for evaluation.")
print(eval_df.head())

## Log and Register Model

In [None]:
%%time

# NOTE(review): this cell overlaps with the "Log and Register the Model" cell
# further down. It relies on EXPERIMENT_NAME, RUN_NAME, MODEL_NAME,
# MarkdownCorrectionService and MlflowClient, whereas the constants cell and
# the later cell use MLFLOW_EXPERIMENT_NAME, MLFLOW_MODEL_NAME and
# MarkdownCorrectorModel. Confirm these names are defined/imported before
# running it. Running both cells logs and registers a model in each.
mlflow.set_experiment(EXPERIMENT_NAME)

with mlflow.start_run(run_name=RUN_NAME) as run:
    run_id = run.info.run_id
    logger.info(f"MLflow Run Started. Run ID: {run_id}")

    # NOTE(review): MarkdownCorrectorModel.load_context reads the model path
    # from the "local_model" artifact key, not "llm" — confirm which key the
    # logged class expects.
    artifacts = {
        "config": str(CONFIG_PATH),
        "secrets": str(SECRETS_PATH),
        "llm": str(LOCAL_MODEL_PATH),
    }

    # Signature: a "markdown" string column in, a "corrected" string column out.
    # NOTE(review): MarkdownCorrectorModel.predict reads a "repo_url" column instead.
    signature = ModelSignature(
        inputs=Schema([ColSpec("string", "markdown")]),
        outputs=Schema([ColSpec("string", "corrected")]),
    )

    mlflow.pyfunc.log_model(
        artifact_path=MODEL_NAME,
        python_model=MarkdownCorrectionService(),
        artifacts=artifacts,
        signature=signature,
        registered_model_name=MODEL_NAME,
        pip_requirements=str(REQUIREMENTS_PATH),
    )

    model_uri = f"runs:/{run_id}/{MODEL_NAME}"
    logger.info(f"Model URI: {model_uri}")

# Validate Registered Model: load the newest registered version and run one
# sample prediction on the first evaluation record. Failures are logged, not raised.
# NOTE(review): `get_latest_versions` with `stages` relies on model-registry
# stages, which newer MLflow releases deprecate — consider aliases or
# `search_model_versions`.
client = MlflowClient()
try:
    latest_version_info = client.get_latest_versions(MODEL_NAME, stages=["None"])[0]
    latest_version = latest_version_info.version
    logger.info(f"Successfully registered model '{MODEL_NAME}' version {latest_version}.")

    model_uri_latest = f"models:/{MODEL_NAME}/{latest_version}"
    loaded_model = mlflow.pyfunc.load_model(model_uri_latest)

    sample_input = eval_df.head(1)
    logger.info(f"Performing sample prediction on: \n{sample_input['markdown'].iloc[0][:100]}...")

    prediction = loaded_model.predict(sample_input)
    logger.info(f"✅ Sample prediction successful. Output:\n{prediction.iloc[0][:100]}...")
except Exception as e:
    logger.error(f"Failed to validate registered model. Error: {e}")
## Log and Register the Model

In [None]:
%%time

# Define artifacts to be packaged with the model. The keys must match what
# MarkdownCorrectorModel.load_context reads ("config", "secrets", "local_model").
# NOTE: the secrets file is copied into the model artifacts, so anyone who can
# download the registered model can read it.
artifacts = {
    "config": str(CONFIG_PATH),
    "secrets": str(SECRETS_PATH),
    "local_model": str(LOCAL_MODEL_PATH)
}

# Define the model's signature: one "repo_url" string in, one JSON string out
# (file name -> corrected Markdown).
input_schema = Schema([ColSpec("string", "repo_url")])
output_schema = Schema([ColSpec("string")])  # Output is a JSON string
signature = ModelSignature(inputs=input_schema, outputs=output_schema)

# Set up MLflow experiment
mlflow.set_tracking_uri('/phoenix/mlflow')  # Adjust if your tracking server is elsewhere
mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)

with mlflow.start_run(run_name="MarkdownCorrectorRegistry") as run:
    run_id = run.info.run_id
    logger.info(f"Starting MLflow run with ID: {run_id}")

    # Log the model and register it under MLFLOW_MODEL_NAME in one step.
    # Use the shared REQUIREMENTS_PATH constant rather than repeating the path.
    mlflow.pyfunc.log_model(
        artifact_path=MLFLOW_MODEL_NAME,
        python_model=MarkdownCorrectorModel(),
        artifacts=artifacts,
        pip_requirements=str(REQUIREMENTS_PATH),
        code_paths=["../src"],
        signature=signature,
        registered_model_name=MLFLOW_MODEL_NAME,
        input_example=pd.DataFrame([{"repo_url": "https://github.com/hp-david/test"}])
    )
    logger.info(f"✅ Model '{MLFLOW_MODEL_NAME}' logged and registered successfully.")

## Prepare Eval Data

In [None]:
# NOTE: this cell is disabled — the code below is wrapped in a string literal
# and is not executed. It loads the latest registered model, runs it on the
# test repository, and pairs each original file with its corrected version in
# `evaluation_df`.
'''
logger.info("Preparing data for evaluation...")

# Load the newly registered model
model_uri = f"models:/{MLFLOW_MODEL_NAME}/latest"
loaded_model = mlflow.pyfunc.load_model(model_uri)

# Get original content from the test repo
test_repo_url = "https://github.com/hp-david/test"
config, secrets = load_config_and_secrets(CONFIG_PATH, SECRETS_PATH)
processor = GitHubMarkdownProcessor(repo_url=test_repo_url, access_token=secrets.get("GITHUB_ACCESS_TOKEN"))
original_markdowns = processor.run()

# Get the model's corrected predictions
test_input = pd.DataFrame([{"repo_url": test_repo_url}])
prediction_result = loaded_model.predict(test_input)
corrected_markdowns = json.loads(prediction_result[0])

# Align originals and predictions into a single DataFrame
eval_data = []
for filename, original_content in original_markdowns.items():
    if filename in corrected_markdowns:
        eval_data.append({
            "original": original_content,
            "corrected": corrected_markdowns[filename]
        })

evaluation_df = pd.DataFrame(eval_data)
logger.info(f"✅ Created evaluation DataFrame with {len(evaluation_df)} file(s).")
display(evaluation_df.head())
'''

## Evaluate Model

In [None]:
# NOTE: this cell is disabled — the code below is wrapped in a string literal
# and is not executed. When enabled, it depends on `run_id` from the
# registration cell and `evaluation_df` from the data-preparation cell.
'''
with mlflow.start_run(run_id=run_id):
    logger.info(f"Starting evaluation for run ID: {run_id}")
    
    results = mlflow.evaluate(
        data=evaluation_df,
        targets="original",
        predictions="corrected",
        extra_metrics=[
            # Standard metrics, now treated as extra metrics
            exact_match(),
            rouge1(),
            rougeL(),
            # Your custom metrics
            semantic_similarity_metric,
            grammar_improvement_metric,
            llm_judge_metric_local,
        ]
    )
    
    logger.info("✅ Evaluation complete.")
    print(json.dumps(results.metrics, indent=2))
'''

## Log Execution Time

In [None]:
# Report total wall-clock time since `start_time` was recorded at the top of the notebook.
end_time: float = time.time()
elapsed_time: float = end_time - start_time
whole_minutes, elapsed_seconds = divmod(elapsed_time, 60)
elapsed_minutes: int = int(whole_minutes)

logger.info(f"⏱️ Total execution time: {elapsed_minutes}m {elapsed_seconds:.2f}s")
logger.info("✅ Notebook execution completed successfully.")