<h1 style="text-align: center; font-size: 50px;"> Register Model </h1>

# Notebook Overview

- Configure the Environment
- Define Constants and Paths
- Define Necessary Modules
- Define MLflow Model Class
- Log and Register Model
- Prepare Eval Data
- Evaluate Model
- Log Execution Time

## Configure the Environment

In [1]:
import logging
import time

# Configure logger
logger: logging.Logger = logging.getLogger("register_model_logger")
logger.setLevel(logging.INFO)
logger.propagate = False  # Prevent duplicate logs

# Set formatter
formatter: logging.Formatter = logging.Formatter(
    fmt="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

# Configure and attach stream handler (only once, so re-running this cell does not duplicate logs)
if not logger.handlers:
    stream_handler: logging.StreamHandler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

In [2]:
start_time = time.time()
logger.info("✅ Notebook execution started.")

2025-08-08 00:18:52 - INFO - ✅ Notebook execution started.


In [3]:
%%time

%pip install -r ../requirements.txt --quiet

Note: you may need to restart the kernel to use updated packages.
CPU times: user 38.4 s, sys: 17.2 s, total: 55.6 s
Wall time: 30min 56s


In [4]:
# Standard & Third-Party Libraries
import os
import sys
import json
import re
import yaml
from pathlib import Path
from collections import defaultdict
from typing import List
import pandas as pd
import mlflow
from mlflow.models import ModelSignature, evaluate
from mlflow.types.schema import Schema, ColSpec

# Add src directory to system path
sys.path.append(str(Path("..").resolve() / "src"))

# Import standard MLflow metrics
from mlflow.metrics import exact_match, rouge1, rougeL

In [5]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## Define Constants and Paths

In [6]:
# Configuration Paths
CONFIG_PATH = Path("../configs/configs.yaml")
REQUIREMENTS_PATH = Path("../requirements.txt")
LOCAL_MODEL_PATH = Path("/home/jovyan/datafabric/meta-llama3.1-8b-Q8/Meta-Llama-3.1-8B-Instruct-Q8_0.gguf")
SECRETS_PATH = Path("../configs/secrets.yaml")

# MLflow Configuration
MLFLOW_EXPERIMENT_NAME = "Markdown_Correction_Service"
MLFLOW_MODEL_NAME = "markdown-corrector"

## Define Necessary Modules
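
As a rough illustration of the placeholder technique that the parser in this section relies on, the sketch below swaps inline code and raw URLs for `<<PHn>>` tokens and then restores them. It is a simplified stand-in rather than the notebook's parser: the helper names `protect`, `restore`, and `stash` are hypothetical, and only two of the many protected element types are handled.

```python
import re
from typing import Dict, Tuple


def protect(text: str) -> Tuple[Dict[str, str], str]:
    """Swap inline code and raw URLs for <<PHn>> placeholders (simplified sketch)."""
    mapping: Dict[str, str] = {}

    def stash(match: re.Match) -> str:
        # Store the matched fragment under the next free key and return its token.
        key = f"PH{len(mapping) + 1}"
        mapping[key] = match.group(0)
        return f"<<{key}>>"

    # Inline code first, then raw URLs, so URLs inside code spans are left alone.
    text = re.sub(r"`[^`]+`", stash, text)
    text = re.sub(r"https?://[^\s)\]}]+", stash, text)
    return mapping, text


def restore(text: str, mapping: Dict[str, str]) -> str:
    """Put the original fragments back in place of their placeholders."""
    return re.sub(
        r"<<(PH\d+)>>",
        lambda m: mapping.get(m.group(1), m.group(0)),
        text,
    )


sample = "Run `pip install mlflow` and see https://mlflow.org for docs."
mapping, protected = protect(sample)
restored = restore(protected, mapping)
print(protected)
print(restored == sample)
```

Only the text between placeholders is meant to reach the language model, so code spans and URLs cannot be altered by a grammar correction as long as the model returns the tokens unchanged.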

In [7]:
############################################
###--------GITHUB EXTRACTOR START--------###
############################################

import requests
from urllib.parse import urlparse
import sys
import base64 
import re 
from markdown_it import MarkdownIt
from markdown_it.token import Token 
import os
from bs4 import BeautifulSoup
from typing import Optional, Tuple, Dict, Union

class GitHubMarkdownProcessor:
    """
    Processor for extracting and parsing Markdown files from GitHub repositories.

    This class fetches `.md` files from a public or private GitHub repository,
    replaces structural and inline components with labeled placeholders, and 
    saves the resulting structure locally.

    Attributes:
        repo_url (str): GitHub repository URL.
        repo_owner (str): Owner of the GitHub repository.
        repo_name (str): Name of the GitHub repository.
        access_token (Optional[str]): GitHub Personal Access Token (if needed).
        save_dir (str): Directory to save the processed files.
        api_base_url (str): Base GitHub API URL for the repository.
    """

    def __init__(self, repo_url: str, access_token: Optional[str] = None, save_dir: str = './parsed_repo'):
        """
        Initializes the Markdown processor with a GitHub repo URL.

        Args:
            repo_url (str): Full URL to the GitHub repository.
            access_token (Optional[str]): GitHub token for private repo access.
            save_dir (str): Output directory to store processed Markdown files.
        """
        self.repo_url = repo_url
        self.access_token = access_token
        self.save_dir = save_dir

        owner, repo, error = self.parse_url()
        if error:
            raise ValueError(error)
        
        self.repo_owner = owner
        self.repo_name = repo
        self.api_base_url = f"https://api.github.com/repos/{self.repo_owner}/{self.repo_name}"

    def parse_url(self) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """
        Parses `self.repo_url` and extracts the repository owner and name.
    
        The method takes no arguments; it reads the URL stored on the instance.
    
        Returns:
            Tuple[Optional[str], Optional[str], Optional[str]]:
                - The repository owner (e.g., 'openai')
                - The repository name (e.g., 'whisper')
                - An error message string if parsing fails; otherwise, None
        """
        parsed_url = urlparse(self.repo_url)
        path_parts = parsed_url.path.strip("/").split("/")
    
        # Validate URL format
        if len(path_parts) < 2:
            return None, None, "Invalid GitHub URL format."
        
        # Return owner and repo
        return path_parts[0], path_parts[1], None
    
    def check_repo(self) -> str:
        """
        Determines the visibility (public or private) of a GitHub repository.
    
        Parses the provided GitHub URL, queries the GitHub API to fetch repository details,
        and returns a string indicating its visibility or an error message if access fails.
    
        Uses `self.repo_url` and, when set, `self.access_token`; the method takes no arguments.
    
        Returns:
            str: "private" or "public" if the repository is accessible;
                 otherwise, a descriptive error message.
        """
        # Parse url into components
        owner, name, error = self.parse_url()
        if error:
            return error
        
        # Build GitHub URL
        url = f"https://api.github.com/repos/{owner}/{name}"
    
        # Build authentication header
        headers = {}
        if self.access_token:
            headers["Authorization"] = f"Bearer {self.access_token}"
    
        response = requests.get(url, headers=headers)
    
        # Determine privacy of repo
        if response.status_code == 200:
            repo_data = response.json()
            return "private" if repo_data.get("private") else "public"
        elif response.status_code == 404:
            return "Repository is inaccessible. Please authenticate."
        else:
            return f"Error: {response.status_code}, {response.text}"
        
    def extract_md_files(self) -> Tuple[Optional[Dict], Optional[str]]:
        """
        Traverses a GitHub repository to extract all Markdown (.md) files and organize them in a nested directory structure.
    
        Connects to the GitHub API, retrieves the file tree of the default branch, downloads any Markdown files,
        and reconstructs their paths locally in a dictionary format. Supports optional authentication with a personal access token.
    
        Uses `self.repo_url` and, when set, `self.access_token`; the method takes no arguments.
    
        Returns:
            Tuple[Optional[Dict], Optional[str]]:
                - A nested dictionary representing the directory structure and Markdown file contents.
                - An error message string if any step fails; otherwise, None.
        """
        
        # Parse url into components
        owner, name, error = self.parse_url()
        if error:
            return None, error
        
        # Build GitHub URL
        url = f"https://api.github.com/repos/{owner}/{name}"
    
        # Build authentication header
        headers = {}
        if self.access_token:
            headers["Authorization"] = f"Bearer {self.access_token}"
    
        response = requests.get(url, headers=headers)
        if not response.ok:
            return None, f"Error: {response.status_code}, {response.text}"
    
        # Get default branch
        default_branch = response.json().get("default_branch", "main")
    
        # Build GitHub URL for file tree
        tree_url = f"https://api.github.com/repos/{owner}/{name}/git/trees/{default_branch}?recursive=1"
        tree_response = requests.get(tree_url, headers=headers)
        if not tree_response.ok:
            return None, f"Error: {tree_response.status_code}, {tree_response.text}"
        
        # Dictionary to hold directory structure
        dir_structure = {}
    
        # Iterate through repo tree structure
        for item in tree_response.json().get("tree", []):
            path=item["path"]
        
            # Skip all non md files
            if item["type"] != "blob" or not path.endswith(".md"):
                continue
    
            # Fetch md file content
            content_url = f"https://api.github.com/repos/{owner}/{name}/contents/{path}"
            content_response = requests.get(content_url, headers=headers)
            if not content_response.ok:
                continue
    
            # Decode content response
            file_data = content_response.json()
            try:
                content = base64.b64decode(file_data["content"]).decode("utf-8")
            except Exception as e:
                content = f"Error decoding content: {e}"
    
            # Build directory structure
            parts = path.split("/")
            current = dir_structure 
            for part in parts[:-1]:
                current = current.setdefault(part, {})
            current[parts[-1]] = content 
    
        return dir_structure, None
    
    def run(self) -> Dict[str, str]:
        """
        High-level method to:
        1. Check repository access
        2. Extract markdown files
        3. Return raw markdown content by file path
    
        Returns:
            Dict[str, str]: Mapping from file paths to raw markdown content.
        """
        visibility = self.check_repo()
        logger.info(f"Repository visibility: {visibility}")
    
        if visibility == "Repository is inaccessible. Please authenticate.":
            raise PermissionError("Cannot access repository. Check your access token.")
    
        structure, error = self.extract_md_files()
        if error:
            raise RuntimeError(f"Markdown extraction failed: {error}")
    
        raw_data = {}
    
        def process_structure(structure: Dict[str, Union[str, dict]], path: str = "") -> None:
            """
            Recursively flattens a nested directory structure of markdown files.
        
            Args:
                structure (Dict[str, Union[str, dict]]): Nested dictionary representing directories and markdown file contents.
                path (str, optional): Current path used to build the full file path during traversal. Defaults to "".
        
            Returns:
                None: Updates the outer `raw_data` dictionary in-place with path-to-content mappings.
            """
            for name, content in structure.items():
                current_path = os.path.join(path, name)
                if isinstance(content, dict):
                    process_structure(content, current_path)
                else:
                    raw_data[current_path] = content
    
        process_structure(structure)
        logger.info("Raw markdown extraction complete.")
        return raw_data

############################################
###---------GITHUB EXTRACTOR END---------###
############################################

############################################
###-------------PARSER START-------------###
############################################

import re
from typing import Dict, Tuple
from markdown_it import MarkdownIt


def parse_md_for_grammar_correction(md_content: str) -> Tuple[Dict[str, str], str]:
    """
    Parses Markdown content by replacing non-prose elements with placeholders.

    Args:
        md_content (str): Raw markdown content to process.

    Returns:
        Tuple[Dict[str, str], str]:
            - A dictionary mapping placeholder keys to original markdown blocks.
            - The transformed markdown with placeholders in place of structure.
    """

    md = MarkdownIt()

    placeholder_map = {}
    counter = 0

    def get_next_placeholder(value: str, prefix="PH") -> str:
        """
        Generates a unique placeholder token for the given value and stores the mapping.

        If the input contains any previously assigned placeholders, they are unwrapped and reassigned
        as a new single placeholder. This helps deduplicate and simplify complex inline structures.

        Args:
            value (str): The text content to be replaced by a placeholder.
            prefix (str): The placeholder prefix (e.g., "PH", "BULLET").

        Returns:
            str: The generated placeholder token (e.g., "<<PH3>>" or "[[BULLET2]]").
        """
        nonlocal counter
        value = re.sub(
            r"<<(PH|BULLET|SEP)\d*>>|\[\[BULLET\d+\]\]",
            lambda m: placeholder_map.get(m.group(0).strip("<>[]"), m.group(0)),
            value,
        )
        counter += 1
        key = f"{prefix}{counter}"
        placeholder_map[key] = value

        if prefix == "BULLET":
            return f"[[{key}]]"
        else:
            return f"<<{key}>>"

    def protect_front_matter(content: str) -> str:
        """
        Detects and replaces a YAML front matter block with a single placeholder.
        The front matter must be at the very beginning of the string.

        Args:
            content (str): Input markdown content.

        Returns:
            str: Markdown with front matter replaced by a placeholder.
        """
        front_matter_pattern = re.compile(r'\A---\s*\n.*?\n---\s*\n?', re.DOTALL)
        
        match = front_matter_pattern.search(content)
        if match:
            front_matter_block = match.group(0)
            placeholder = get_next_placeholder(front_matter_block)
            return front_matter_pattern.sub(placeholder, content, count=1)
        
        return content

    def protect_tables(content: str) -> str:
        """
        Detects Markdown tables and replaces them with placeholders.

        Args:
            content (str): Input markdown content.

        Returns:
            str: Markdown with tables replaced by placeholders.
        """
        lines = content.splitlines()
        protected_lines = []
        in_table = False
        table_buffer = []

        for line in lines:
            if (
                "|" in line
                and line.strip().startswith("|")
                and line.strip().endswith("|")
            ):
                if not in_table:
                    in_table = True
                    table_buffer = [line]
                else:
                    table_buffer.append(line)
            elif in_table and re.match(r"^\s*\|[\s\-\|:]+\|\s*$", line):
                table_buffer.append(line)
            elif in_table:
                if len(table_buffer) >= 2:
                    table_content = "\n".join(table_buffer)
                    placeholder = get_next_placeholder(table_content)
                    protected_lines.append(placeholder)
                else:
                    protected_lines.extend(table_buffer)

                table_buffer = []
                in_table = False
                protected_lines.append(line)
            else:
                protected_lines.append(line)

        # Handle final table
        if in_table and len(table_buffer) >= 2:
            table_content = "\n".join(table_buffer)
            placeholder = get_next_placeholder(table_content)
            protected_lines.append(placeholder)
        elif in_table:
            protected_lines.extend(table_buffer)

        return "\n".join(protected_lines)

    def is_low_prose_line(line: str, threshold: float = 0.1) -> bool:
        """
        Determines if a line contains mostly structure and little natural language prose.

        Args:
            line (str): A single markdown line.
            threshold (float): Minimum proportion of prose content.

        Returns:
            bool: True if it's mostly non-prose.
        """
        if not line.strip():
            return True
        no_placeholders = re.sub(r"<<(PH|BULLET|SEP)\d*>>|\[\[BULLET\d+\]\]", "", line)
        no_markdown = re.sub(r"[*_`[\]()#|-]", "", no_placeholders)
        prose_content = no_markdown.strip()
        if len(line) > 0 and (len(prose_content) / len(line)) < threshold:
            return True
        return False
    
    md_content = protect_front_matter(md_content)
    md_content = protect_tables(md_content)

    tokens = md.parse(md_content)
    lines = md_content.splitlines()
    block_replacements = []

    i = 0
    while i < len(tokens):
        token = tokens[i]

        if (
            token.type == "html_block"
            and i + 3 < len(tokens)
            and tokens[i + 1].type == "heading_open"
        ):
            html_start, _ = token.map
            _, heading_end = tokens[i + 1].map
            raw_block = "\n".join(lines[html_start:heading_end])
            placeholder = get_next_placeholder(raw_block)
            block_replacements.append((html_start, heading_end, placeholder))
            i += 4
            continue
        elif token.type == "fence" or token.type == "html_block":
            start, end = token.map
            raw_block = "\n".join(lines[start:end])
            placeholder = get_next_placeholder(raw_block)
            block_replacements.append((start, end, placeholder))
            i += 1
            continue
        elif token.type == "heading_open":
            level = int(token.tag[1])
            inline_token = tokens[i + 1]
            header_text = inline_token.content.strip()
            placeholder = get_next_placeholder(header_text)
            start, end = token.map
            markdown_prefix = "#" * level
            block_replacements.append((start, end, f"{markdown_prefix} {placeholder}"))
            i += 3
            continue
        elif token.type == "blockquote_open":
            start = token.map[0] if token.map else i
            j = i + 1
            blockquote_content = []
            while j < len(tokens) and tokens[j].type != "blockquote_close":
                if (
                    tokens[j].type == "paragraph_open"
                    and j + 1 < len(tokens)
                    and tokens[j + 1].type == "inline"
                ):
                    blockquote_content.append(tokens[j + 1].content)
                j += 1
            if j < len(tokens):
                end = tokens[j].map[1] if tokens[j].map else start + 1
                if blockquote_content:
                    text_content = " ".join(blockquote_content)
                    placeholder = get_next_placeholder(text_content)
                    block_replacements.append((start, end, f"> {placeholder}"))
                i = j + 1
            else:
                i += 1
            continue
        i += 1

    for start, end, replacement in sorted(block_replacements, reverse=True):
        lines[start:end] = [replacement]

    def replace_md_links(match):
        """
        Replaces standard Markdown links with placeholders for the URL target.
        """
        text, url = match.group(1), match.group(2)
        if re.match(r"^<<PH\d+>>$", url):
            return match.group(0)
        return f"[{text}]({get_next_placeholder(url)})"

    def replace_internal_links(match):
        """
        Replaces internal anchor links with placeholders for the anchor target.
        """
        text, anchor = match.group(1), f"#{match.group(2)}"
        if re.match(r"^<<PH\d+>>$", anchor):
            return match.group(0)
        return f"[{text}]({get_next_placeholder(anchor)})"

    processed_lines = []
    for line in lines:
        # Protect inline code first, as it can contain any character.
        line = re.sub(
            r"`([^`]+)`", lambda m: f"`{get_next_placeholder(m.group(1))}`", line
        )
        
        # Protect Markdown images 
        line = re.sub(
            r'!\[([^\]]*)\]\(([^)]+)\)',
            lambda m: get_next_placeholder(m.group(0)),
            line
        )

        # Protect standard Markdown links []() and internal links [](#).
        line = re.sub(r"\[([^\]]+)]\(([^)]+)\)", replace_md_links, line)
        line = re.sub(r"\[([^\]]+)]\(#([^)]+)\)", replace_internal_links, line)
        
        # Protect raw URLs last
        line = re.sub(
            r"https?://[^\s)\]}]+", lambda m: get_next_placeholder(m.group(0)), line
        )
        
        processed_lines.append(line)

    def is_title_line(content: str) -> bool:
        """Heuristic check to determine if a line is a title-style phrase."""
        clean = re.sub(r"<<[^>]+>>", "", content)
        clean = re.sub(r"[*_`[\]\(\)]", "", clean).strip()
        words = re.findall(r"[A-Za-z]+(?:-[A-Za-z]+)*", clean)
        if len(words) < 2:
            return False
        alpha = re.sub(r"[^A-Za-z]", "", clean)
        if alpha.isupper():
            return True
        upper_words = [w for w in words if w[0].isupper()]
        if len(words) > 0 and len(upper_words) / len(words) >= 0.75:
            return True
        return False

    bullet_placeholder_lines = []
    for line in processed_lines:
        if is_low_prose_line(line):
            placeholder = get_next_placeholder(line)
            bullet_placeholder_lines.append(placeholder)
            continue

        m = re.match(r"^(\s*)([-*+]|\d+\.)\s+(.*)$", line)
        if m:
            indent, bullet, content = m.groups()
            if is_title_line(content):
                ph = get_next_placeholder(content)
                bullet_placeholder_lines.append(f"{indent}{bullet} {ph}")
            else:
                bph = get_next_placeholder(bullet, prefix="BULLET")
                bullet_placeholder_lines.append(f"{indent}{bph} {content}")
        else:
            bullet_placeholder_lines.append(line)

    raw_processed = "\n".join(bullet_placeholder_lines)
    raw_processed = re.sub(
        r"^\s*---\s*$",
        lambda m: get_next_placeholder(m.group(0)),
        raw_processed,
        flags=re.MULTILINE,
    )

    final_lines = []
    for line in raw_processed.splitlines(keepends=True):
        if line.endswith("\n"):
            content = line.rstrip("\n")
            trailing_newline = True
        else:
            content = line
            trailing_newline = False
        newline_placeholder = get_next_placeholder("\n", prefix="PH")
        final_lines.append(content + (newline_placeholder if trailing_newline else ""))

    processed_content = "".join(final_lines)

    merged_placeholder_map = {}
    pattern = re.compile(r"(?:<<PH\d+>>){2,}")

    while True:
        match = pattern.search(processed_content)
        if not match:
            break
        ph_sequence = re.findall(r"<<PH\d+>>", match.group(0))
        keys = [ph.strip("<>") for ph in ph_sequence]
        merged_value = "".join(placeholder_map.get(k, "") for k in keys)
        counter += 1
        new_key = f"PH{counter}"
        new_ph = f"<<{new_key}>>"
        placeholder_map[new_key] = merged_value
        processed_content = (
            processed_content[: match.start()]
            + new_ph
            + processed_content[match.end() :]
        )
        merged_placeholder_map[new_key] = keys

    processed_content = re.sub(
        r"(<<PH\d+>>)(?!<<)(?=\w)", r"\1<<SEP>>", processed_content
    )
    placeholder_map["SEP"] = ""

    return placeholder_map, processed_content

def restore_placeholders(corrected_text: str, placeholder_map: Dict[str, str]) -> str:
    """
    Replaces placeholders in the corrected markdown content with their original values.

    Args:
        corrected_text (str): Markdown content containing placeholders.
        placeholder_map (Dict[str, str]): Map of placeholders to original content.

    Returns:
        str: Fully restored markdown with original formatting.
    """
    restored_text = corrected_text

    # Replace placeholder tokens with original values
    # Sort by length descending to handle nested placeholders correctly
    for placeholder, original in sorted(
        placeholder_map.items(), key=lambda x: -len(x[0])
    ):
        if placeholder.startswith("BULLET"):
            restored_text = restored_text.replace(f"[[{placeholder}]]", original)
        else:
            restored_text = restored_text.replace(f"<<{placeholder}>>", original)

    # Remove SEP markers
    restored_text = restored_text.replace("<<SEP>>", "")

    return restored_text
    
############################################
###--------------PARSER END--------------###
############################################

############################################
###------------CHUNKER  START------------###
############################################

import re
from typing import List

def estimate_token_count(text: str) -> int:
    '''
    Estimate the number of tokens in a given string.

    Args:
        text (str): Input text.

    Returns:
        int: Approximate number of tokens, assuming 4 characters per token
    '''
    return len(text) // 4 

def split_by_top_level_headers(markdown: str) -> List[str]:
    """
    Split markdown into sections at every ATX header (levels # through ######).

    Text that precedes the first header is kept as its own leading section.

    Args:
        markdown (str): The input markdown text.

    Returns:
        List[str]: Sections in document order, each starting at a header
            (except an optional leading section of pre-header text).
    """
    # Find all headers from # to ###### at the beginning of a line
    matches = list(re.finditer(r'^\s*#{1,6}\s.*', markdown, flags=re.MULTILINE))
    if not matches:
        # No headers found, return whole content
        return [markdown]

    sections = []
    # Keep any non-blank text before the first header instead of dropping it
    preamble = markdown[:matches[0].start()]
    if preamble.strip():
        sections.append(preamble)
    for i, match in enumerate(matches):
        start = match.start()
        end = matches[i + 1].start() if i + 1 < len(matches) else len(markdown)
        sections.append(markdown[start:end])

    return sections

def smart_sentence_split(text: str) -> List[str]:
    """
    Split text at sentence boundaries or placeholder boundaries.

    Args: 
        text (str): The input text.

    Returns:
        List[str]: A list of sentence-level parts, preserving placeholder boundaries.
    """
    # Pattern to capture sentence boundaries and `__PLACEHOLDERn__` tokens.
    # Note: the parser above emits `<<PHn>>` and `[[BULLETn]]`, which this pattern does not match.
    pattern = r'([.!?]\s+(?=[A-Z]))|(__PLACEHOLDER\d+__)'
    matches = re.split(pattern, text)

    # Reconstruct complete sentence or placeholder segments
    parts = []
    buffer = ''
    for chunk in matches:
        if chunk is None:
            continue
        buffer += chunk
        if re.match(r'[.!?]\s+$', chunk) or re.match(r'__PLACEHOLDER\d+__', chunk.strip()):
            parts.append(buffer.strip())
            buffer = ''
    if buffer.strip():
        parts.append(buffer.strip())

    return parts

def chunk_large_section(section: str, max_tokens: int = 1000) -> List[str]:
    """
    Chunk a section while avoiding breaks mid-sentence or mid-placeholder.

    Args:
        section (str): A section of text to be chunked.
        max_tokens (int): Maximum token count per chunk.

    Returns:
        List[str]: A list of token-limited chunks.
    """
    chunks = []
    current_chunk = []
    current_token_count = 0

    # Split section into sentence-safe pieces
    parts = smart_sentence_split(section)

    for part in parts:
        part = part.strip()
        if not part:
            continue

        part_token_count = estimate_token_count(part)

        # Finalize current chunk if adding this part would exceed max tokens
        if current_token_count + part_token_count > max_tokens:
            if current_chunk:
                chunks.append(' '.join(current_chunk).strip())
                current_chunk = []
                current_token_count = 0

        current_chunk.append(part)
        current_token_count += part_token_count

    if current_chunk:
        chunks.append(' '.join(current_chunk).strip())

    return chunks

def chunk_markdown(markdown: str, max_tokens: int = 100) -> List[str]:
    '''
    Chunk a full markdown document into smaller parts based on headers and token limits.

    Args:
        markdown (str): The complete markdown content.
        max_tokens (int): Maximum allowed tokens per chunk.

    Returns:
        List[str]: A list of markdown chunks that are token-limited and structured.
    '''
    # Split by top level headers
    sections = split_by_top_level_headers(markdown)
    final_chunks = []

    for section in sections:
        # If the section is small enough, keep it as is
        if estimate_token_count(section) <= max_tokens:
            final_chunks.append(section.strip())
        else:
            # Otherwise, chunk it further based on sentence boundaries
            final_chunks.extend(chunk_large_section(section, max_tokens=max_tokens))

    def split_long_chunk(chunk: str, max_chars: int = 3500) -> List[str]:
        """
        Split a long chunk into smaller character-limited subchunks, preferring newline or sentence boundaries.

        Args:
            chunk (str): A markdown chunk that may be too long.
            max_chars (int): Maximum number of characters per subchunk.

        Returns:
            List[str]: Subchunks of the input chunk that fit the character limit.
        """
        if len(chunk) <= max_chars:
            return [chunk]

        # Try splitting on newlines first
        parts = re.split(r'(?<=\n)', chunk)
        subchunks = []
        buffer = ""

        for part in parts:
            if len(buffer) + len(part) > max_chars:
                if buffer:
                    subchunks.append(buffer.strip())
                buffer = part
            else:
                buffer += part

        if buffer.strip():
            subchunks.append(buffer.strip())

        # If subchunks are still too long, split on sentence boundary
        final_subchunks = []
        for sub in subchunks:
            if len(sub) <= max_chars:
                final_subchunks.append(sub)
            else:
                sentences = re.split(r'(?<=[.!?])\s+', sub)
                sentence_buffer = ""
                for s in sentences:
                    if len(sentence_buffer) + len(s) > max_chars:
                        # Flush only a non-empty buffer so no empty subchunk is emitted
                        if sentence_buffer.strip():
                            final_subchunks.append(sentence_buffer.strip())
                        sentence_buffer = s
                    else:
                        sentence_buffer += (" " if sentence_buffer else "") + s
                if sentence_buffer.strip():
                    final_subchunks.append(sentence_buffer.strip())

        return final_subchunks

    # Apply post-splitting to each chunk to enforce character limits
    adjusted_chunks = []
    for chunk in final_chunks:
        adjusted_chunks.extend(split_long_chunk(chunk))

    return adjusted_chunks

############################################
###-------------CHUNKER  END-------------###
############################################

############################################
###-------------PROMPT START-------------###
############################################

from langchain.prompts import PromptTemplate

# Template for llama3-instruct format
MARKDOWN_CORRECTION_TEMPLATE_LLAMA3 = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a markdown grammar correction assistant. Your job is to correct only grammatical errors in the user's Markdown content.

Strictly follow these rules:
- Do **NOT** modify any placeholders (e.g., <<PH1>>, <<PH93>>, [[BULLET3]], <<SEP>>). Leave them **EXACTLY as they appear**, including spacing and underscores.
- Do **NOT** remove, reword, rename, reformat, or relocate any placeholder.
- Do **NOT** add any extra markdown or other symbols (e.g., #, >)
- Do **NOT** add or remove any extra space around placeholders.
- Do **NOT** remove Markdown styling characters (e.g., **, *, _, __, `, [, ]).
- Do **NOT** add or remove extra content from the original text.
- Do **NOT** change or swap markdown syntax ([], (), *, `). They should remain **EXACTLY** as they are in the original text. 
- Only correct grammar **within natural language sentences**, leaving structure unchanged.
- **ALWAYS** maintain title case wherever it is present in the original text.

Example:
- Original: "<<SEP>>We use <<PH4>> to **builds** 2 model **likke** this:<<PH17>><<PH18>>"
- Corrected: "<<SEP>>We use <<PH4>> to **build** 2 models **like** this:<<PH17>><<PH18>>"

Example:
- Original: "[[BULLET1]] **It Will Be More Profitablr<<PH12>>**"
- Corrected: "[[BULLET1]] **It Will Be More Profitable<<PH12>>**"

Example:
- Original: "<<PH27>><<PH28>><<SEP>>This methd is **not necessary** the way ti build *AI* agents <<PH32>>"
- Corrected: "<<PH27>><<PH28>><<SEP>>This method is **not necessarily** the way to build *AI* agents <<PH32>>"

All placeholders are present and stay exactly the same with no additional spaces — only grammar is corrected.

Respond only with the corrected Markdown content. Do not explain anything.<|eot_id|><|start_header_id|>user<|end_header_id|>
Original markdown:
{markdown}

Corrected markdown:<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""


def get_markdown_correction_prompt() -> PromptTemplate:
    """
    Get the markdown correction prompt formatted for LLaMA 3 instruct.

    Returns:
        PromptTemplate: Ready to use in LangChain with LLaMA 3 format.
    """
    return PromptTemplate.from_template(MARKDOWN_CORRECTION_TEMPLATE_LLAMA3)

############################################
###--------------PROMPT END--------------###
############################################

############################################
###-------------UTILS  START-------------###
############################################

import os
import yaml
import importlib.util
import multiprocessing
from typing import Dict, Any, Optional, Union, List, Tuple

# Default model for each supported model source
DEFAULT_MODELS = {
    "local": "/home/jovyan/datafabric/llama2-7b/ggml-model-f16-Q5_K_M.gguf",
    "tensorrt": "",
    "hugging-face-local": "meta-llama/Llama-3.2-3B-Instruct",
    "hugging-face-cloud": "mistralai/Mistral-7B-Instruct-v0.3"
}

# Context window sizes for various models
MODEL_CONTEXT_WINDOWS = {
    # LlamaCpp models
    'ggml-model-f16-Q5_K_M.gguf': 4096,
    'ggml-model-7b-q4_0.bin': 4096,
    'gguf-model-7b-4bit.bin': 4096,

    # HuggingFace models
    'mistralai/Mistral-7B-Instruct-v0.3': 8192,
    'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B': 4096,
    'meta-llama/Llama-2-7b-chat-hf': 4096,
    'meta-llama/Llama-3-8b-chat-hf': 8192,
    'google/flan-t5-base': 512,
    'google/flan-t5-large': 512,
    'TheBloke/WizardCoder-Python-7B-V1.0-GGUF': 4096,

    # OpenAI models
    'gpt-3.5-turbo': 16385,
    'gpt-4': 8192,
    'gpt-4-32k': 32768,
    'gpt-4-turbo': 128000,
    'gpt-4o': 128000,

    # Anthropic models
    'claude-3-opus-20240229': 200000,
    'claude-3-sonnet-20240229': 200000,
    'claude-3-haiku-20240307': 200000,

    # Other models
    'qwen/Qwen-7B': 8192,
    'microsoft/phi-2': 2048,
    'tiiuae/falcon-7b': 4096,
    "meta-llama/Llama-3.2-3B-Instruct": 128000,
}

def load_config_and_secrets(
    config_path: str = "../../configs/config.yaml",
    secrets_path: str = "../../configs/secrets.yaml"
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Load configuration and secrets from YAML files.

    Args:
        config_path: Path to the configuration YAML file.
        secrets_path: Path to the secrets YAML file.

    Returns:
        Tuple containing (config, secrets) as dictionaries.

    Raises:
        FileNotFoundError: If either the config or secrets file is not found.
    """
    # Convert to absolute paths if needed
    config_path = os.path.abspath(config_path)
    secrets_path = os.path.abspath(secrets_path)

    if not os.path.exists(secrets_path):
        raise FileNotFoundError(f"secrets.yaml file not found in path: {secrets_path}")

    if not os.path.exists(config_path):
        raise FileNotFoundError(f"config.yaml file not found in path: {config_path}")

    with open(config_path) as file:
        config = yaml.safe_load(file)

    with open(secrets_path) as file:
        secrets = yaml.safe_load(file)

    return config, secrets

def initialize_llm(
    model_source: str = "local",
    secrets: Optional[Dict[str, Any]] = None,
    local_model_path: str = DEFAULT_MODELS["local"],
    hf_repo_id: str = ""
) -> Any:
    """
    Initialize a language model based on specified source.

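    Args:
        model_source: Source of the model. Options are "local", "tensorrt",
            "hugging-face-local", or "hugging-face-cloud".
        secrets: Dictionary containing API keys for cloud services.
        local_model_path: Path to the local model file (or model directory for "tensorrt").
        hf_repo_id: Hugging Face repo ID. If empty, the Hugging Face sources fall back to
            DEFAULT_MODELS and "tensorrt" falls back to local_model_path.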

    Returns:
        Initialized language model object.

    Raises:
        ImportError: If required libraries are not installed.
        ValueError: If an unsupported model_source is provided.
    """
    # Check dependencies
    missing_deps = []
    for module in ["langchain_huggingface", "langchain_core.callbacks", "langchain_community.llms"]:
        if not importlib.util.find_spec(module):
            missing_deps.append(module)
    
    if missing_deps:
        raise ImportError(f"Missing required dependencies: {', '.join(missing_deps)}")
    
    # Import required libraries
    from langchain_huggingface import HuggingFacePipeline, HuggingFaceEndpoint
    from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
    from langchain_community.llms import LlamaCpp

    model = None
    context_window = None
    
    # Initialize based on model source
    if model_source == "hugging-face-cloud":
        if hf_repo_id == "":
            repo_id = DEFAULT_MODELS["hugging-face-cloud"]
        else:
            repo_id = hf_repo_id  
        if not secrets or "HUGGINGFACE_API_KEY" not in secrets:
            raise ValueError("HuggingFace API key is required for cloud model access")
            
        huggingfacehub_api_token = secrets["HUGGINGFACE_API_KEY"]
        # Get context window from our lookup table
        if repo_id in MODEL_CONTEXT_WINDOWS:
            context_window = MODEL_CONTEXT_WINDOWS[repo_id]

        model = HuggingFaceEndpoint(
            huggingfacehub_api_token=huggingfacehub_api_token,
            repo_id=repo_id,
        )

    elif model_source == "hugging-face-local":
        from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
        if secrets and "HUGGINGFACE_API_KEY" in secrets:
            os.environ["HF_TOKEN"] = secrets["HUGGINGFACE_API_KEY"]
        if hf_repo_id == "":
            model_id = DEFAULT_MODELS["hugging-face-local"]
        else:
            model_id = hf_repo_id        
        # Get context window from our lookup table
        if model_id in MODEL_CONTEXT_WINDOWS:
            context_window = MODEL_CONTEXT_WINDOWS[model_id]
        
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        hf_model = AutoModelForCausalLM.from_pretrained(model_id)

        # If tokenizer has model_max_length, that's our context window
        if hasattr(tokenizer, 'model_max_length') and tokenizer.model_max_length not in (None, -1):
            context_window = tokenizer.model_max_length

        pipe = pipeline("text-generation", model=hf_model, tokenizer=tokenizer, max_new_tokens=100, device=0)
        model = HuggingFacePipeline(pipeline=pipe)
        
    elif model_source == "tensorrt":
        #If a Hugging Face model is specified, it will be used - otherwise, it will try loading the model from local_path
        try:
            import tensorrt_llm
            sampling_params = tensorrt_llm.SamplingParams(temperature=0.1, top_p=0.95, max_tokens=512) 
            if hf_repo_id != "":
                return TensorRTLangchain(model_path = hf_repo_id, sampling_params = sampling_params)
            else:
                model_config = os.path.join(local_model_path, "config.json")
                if os.path.isdir(local_model_path) and os.path.isfile(model_config):
                    return TensorRTLangchain(model_path = local_model_path, sampling_params = sampling_params)
                else:
                    raise Exception("Model format incompatible with TensorRT LLM")
        except ImportError:
            raise ImportError(
                "Could not import tensorrt-llm library. "
                "Please make sure tensorrt-llm is installed properly, or "
                "consider using workspaces based on the NeMo Framework"
            )
    elif model_source == "local":
        callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
        # For LlamaCpp, get the context window from the filename
        model_filename = os.path.basename(local_model_path)
        if model_filename in MODEL_CONTEXT_WINDOWS:
            context_window = MODEL_CONTEXT_WINDOWS[model_filename]
        else:  
            # Default context window for LlamaCpp models (explicitly set)
            context_window = 4096

        model = LlamaCpp(
            model_path=local_model_path,
            n_gpu_layers=-1,                             
            n_batch=512,                                 
            n_ctx=32000,
            max_tokens=1024,
            f16_kv=True,
            use_mmap=False,                             
            low_vram=False,                            
            rope_scaling=None,
            temperature=0.0,
            repeat_penalty=1.0,
            streaming=False,
            stop=None,
            seed=42,
            num_threads=multiprocessing.cpu_count(),
            verbose=False
        )
    else:
        raise ValueError(f"Unsupported model source: {model_source}")

    # Store context window as model attribute for easy access
    if model and hasattr(model, '__dict__'):
        model.__dict__['_context_window'] = context_window

    return model

############################################
###--------------UTILS  END--------------###
############################################

############################################
###-----------LLM METRIC START-----------###
############################################

import os
import re
import numpy as np
import pandas as pd
import math
from typing import List, Optional
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import multiprocessing

# Initialize TF-IDF vectorizer for semantic similarity
tfidf_vectorizer = TfidfVectorizer(stop_words='english', max_features=5000) 

def semantic_similarity_eval_fn(predictions: List[str], targets: List[str]) -> float:
    """
    Compute a similarity score between predictions and targets using TF-IDF and cosine similarity.

    TF-IDF measures word overlap, so this is a lexical proxy for semantic similarity.

    Args:
        predictions (List[str]): Model-generated texts.
        targets (List[str]): Ground truth references.

    Returns:
        float: Mean cosine similarity between matched pairs, scaled to a 0-10 range.
    """
    all_texts = list(targets) + list(predictions)
    
    all_texts = [str(text) if text else "" for text in all_texts]
    
    if len(set(all_texts)) < 2:  
        return 10.0
    
    tfidf_matrix = tfidf_vectorizer.fit_transform(all_texts)
    
    n_targets = len(targets)
    target_vectors = tfidf_matrix[:n_targets]
    pred_vectors = tfidf_matrix[n_targets:]
    
    similarities = []
    for i in range(len(targets)):
        similarity = cosine_similarity(target_vectors[i:i+1], pred_vectors[i:i+1])[0][0]
        similarities.append(similarity)
    
    return float(np.mean(similarities) * 10)

def _count_syllables(word: str) -> int:
    """Rudimentary English syllable counter."""
    word = word.lower()
    groups = re.findall(r'[aeiouy]+', word)
    count = len(groups)
    if word.endswith('e'):
        count = max(1, count - 1)
    return max(count, 1)

def flesch_reading_ease(text: str) -> float:
    """Calculates Flesch Reading Ease score."""
    sentences = re.split(r'[.!?]+', text)
    sentences = [s for s in sentences if s.strip()]
    words = re.findall(r'\w+', text)
    if not sentences or not words:
        return 0.0
    syllables = sum(_count_syllables(w) for w in words)
    W = len(words)
    S = len(sentences)
    score = 206.835 - 1.015 * (W / S) - 84.6 * (syllables / W)
    return score

def readability_improvement_eval_fn(predictions: List[str], targets: List[str]) -> float:
    """
    Calculates the average signed change in Flesch Reading Ease from target to prediction.
    A score of 0.0 means no change in readability on average.
    Positive values indicate improvement, negative values indicate a decrease.
    """
    deltas = []
    for pred, target in zip(predictions, targets):
        # Ensure inputs are strings and handle if they are empty
        orig_text = str(target)
        pred_text = str(pred)

        if not orig_text.strip():
            orig_score = 0.0
        else:
            orig_score = flesch_reading_ease(orig_text)

        if not pred_text.strip():
            new_score = 0.0
        else:
            new_score = flesch_reading_ease(pred_text)
        
        # Calculate the simple difference (delta)
        deltas.append(new_score - orig_score)
        
    # If the list is empty for any reason, return 0
    if not deltas:
        return 0.0

    # Return the average of all the deltas directly.
    return float(np.mean(deltas))

def llm_judge_eval_fn_local(predictions: pd.Series, targets: pd.Series, llm) -> float:
    """
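    Use the main Llama 3 model to rate the grammar of predictions from 1 to 10.
    This function correctly handles Pandas Series as input from MLflow.
    `targets` is unused and kept only for a uniform metric signature.
    A response without a number scores 0.0; a runtime error scores a neutral 5.0.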
    """
    # Use the correct method to check if a Pandas Series is empty
    if predictions.empty:
        return 0.0

    scores = []
    # Iterating over a Series works as expected
    for pred in predictions:
        # This prompt is formatted for Llama 3 Instruct
        prompt = f"""<|start_header_id|>system<|end_header_id|>
You are a grammar judge. Rate the grammar of the provided text on a scale from 1 (very poor) to 10 (perfect).
Respond with a single integer number only. Do not add any explanation or punctuation.

Keep in mind the text is markdown text, so do not penalize the use of markdown syntax such as *, #, `, [].

Example:
Text: I has a apple.
Answer: 7

Text: The dog chased **the ball** across the yard.
Answer: 10<|eot_id|><|start_header_id|>user<|end_header_id|>
Text: {pred}
Answer:<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

        try:
            result = llm.invoke(prompt)
            text = result.strip()
            match = re.search(r"\d+", text)
            score = float(match.group(0)) if match else 0.0
            scores.append(score)
        except Exception as e:
            print(f"[LLM judge runtime error]: {e}")
            scores.append(5.0)

    return float(np.mean(scores)) if scores else 0.0
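
To make the readability metric easier to interpret, the following standalone sketch reimplements the same syllable heuristic and Flesch formula with the standard library only and compares an original sentence with its correction. The sample sentences and the helper names `flesch` and `readability_delta` are assumptions for illustration, not part of the notebook's modules.

```python
import re


def count_syllables(word: str) -> int:
    """Same rudimentary heuristic as `_count_syllables`: count vowel groups."""
    word = word.lower()
    count = len(re.findall(r"[aeiouy]+", word))
    if word.endswith("e"):
        count = max(1, count - 1)
    return max(count, 1)


def flesch(text: str) -> float:
    """Flesch Reading Ease: 206.835 - 1.015 * (W / S) - 84.6 * (syllables / W)."""
    sentences = [s for s in re.split(r"[.!?]+", text) if s.strip()]
    words = re.findall(r"\w+", text)
    if not sentences or not words:
        return 0.0
    syllables = sum(count_syllables(w) for w in words)
    return 206.835 - 1.015 * (len(words) / len(sentences)) - 84.6 * (syllables / len(words))


def readability_delta(original: str, corrected: str) -> float:
    """Signed change in reading ease; negative means the corrected text scores lower."""
    return flesch(corrected) - flesch(original)


# Hypothetical example: fixing the typo adds a vowel group, so the heuristic
# counts one more syllable and the score drops.
delta = readability_delta("This methd is bad.", "This method is bad.")
print(round(delta, 2))  # -21.15
```

Because the heuristic counts vowel groups, a misspelling that drops a vowel can score as "easier" than the correct word. Under this metric, a small negative change can therefore accompany a valid correction.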

## Define MLflow Model Class

In [8]:
class MarkdownCorrectorModel(mlflow.pyfunc.PythonModel):
    """An MLflow model that encapsulates the entire markdown correction pipeline."""

    def load_context(self, context: mlflow.pyfunc.PythonModelContext) -> None:
        """Initializes the model, loading configurations, secrets, and the LLM."""
        config_path = context.artifacts["config"]
        with open(config_path, "r") as f:
            self.config = yaml.safe_load(f)
        
        github_token = os.getenv("AIS_GITHUB_ACCESS_TOKEN")
        if github_token:
            self.secrets = {"AIS_GITHUB_ACCESS_TOKEN": github_token}
            logger.info("Loaded AIS_GITHUB_ACCESS_TOKEN from environment variable.")
        else:
            secrets_path = context.artifacts.get("secrets")
            if secrets_path and os.path.exists(secrets_path):
                with open(secrets_path, "r") as f:
                    self.secrets = yaml.safe_load(f)
                logger.info(f"Loaded secrets from {secrets_path}.")
            else:
                # If no token is found anywhere, initialize with an empty dict or handle as an error
                self.secrets = {}
                logger.warning("No AIS_GITHUB_ACCESS_TOKEN found in environment or secrets.yaml.")
        
        model_source = self.config.get("model_source", "local")
        local_model_path = context.artifacts.get("local_model")
        self.llm = initialize_llm(model_source, self.secrets, local_model_path)
        correction_prompt = get_markdown_correction_prompt()
        self.llm_chain = correction_prompt | self.llm
        
        # Define which metrics will be calculated
        self.metric_functions = {
            "semantic_similarity": semantic_similarity_eval_fn,
            "readability_improvement": readability_improvement_eval_fn,
            "grammar_quality_score": lambda predictions, targets: llm_judge_eval_fn_local(
                predictions, targets, llm=self.llm
            ),
        }

        logger.info("✅ Model context loaded with only instantaneous evaluation metrics.")

    @staticmethod
    def _safe_join_chunks(chunks: List[str]) -> str:
        """Concatenates corrected chunks, adding a space where one sentence ends and the next begins."""
        joined = ""
        for i, chunk in enumerate(chunks):
            if i == 0:
                joined += chunk
            else:
                prev = chunks[i - 1].rstrip()
                curr = chunk
                if prev.endswith('.') and re.match(r'^[A-Z\"]', curr.lstrip()):
                    joined += ' ' + curr.lstrip()
                else:
                    joined += curr
        return joined

    def predict(self, context, model_input: pd.DataFrame) -> pd.DataFrame:
        """Runs correction and evaluates the output using only the fastest metrics."""
        start_time = time.time()
        
        # Get the file content string from the first row of the input DataFrame;
        # only one file is processed per call.
        file_content = model_input["files"].iloc[0]
        
        # Create a dictionary for the single file. 
        markdowns = {"corrected_file.md": file_content}
            
        original_files = markdowns.copy()

        # Correction logic
        parsed_markdowns, placeholder_maps, all_chunks = {}, {}, {}
        for fn, content in markdowns.items():
            placeholder_map, proc = parse_md_for_grammar_correction(content)
            parsed_markdowns[fn], placeholder_maps[fn] = proc, placeholder_map
            all_chunks[fn] = chunk_markdown(proc)
        
        corrected_chunks_by_file = defaultdict(list)
        for fn, chunks in all_chunks.items():
            for chunk in chunks:
                resp = self.llm_chain.invoke({"markdown": chunk})
                content_to_append = resp.content if hasattr(resp, 'content') else resp
                corrected_chunks_by_file[fn].append(content_to_append)
        
        final_corrected = {}
        for fn, chunk_list in corrected_chunks_by_file.items():
            joined = self._safe_join_chunks(chunk_list)
            final_corrected[fn] = restore_placeholders(joined, placeholder_maps[fn])
            
        # Metric calculation loop
        evaluation_metrics = {}
        eval_data = [{"predictions": final_corrected.get(fn, ""), "targets": content} for fn, content in original_files.items()]
        eval_df = pd.DataFrame(eval_data)
        
        for name, metric_func in self.metric_functions.items():
            try:
                raw_score = metric_func(
                    predictions=eval_df["predictions"],
                    targets=eval_df["targets"]
                )
                evaluation_metrics[name] = float(raw_score)
            except Exception as e:
                logger.warning(f"Could not calculate metric '{name}'. Error: {e}")
                evaluation_metrics[name] = "N/A"

                
        response_time = time.time() - start_time

        row = {
            "corrected": final_corrected,
            "originals": original_files,
            "response_time": response_time,
            "evaluation_metrics": evaluation_metrics,
        }
        return pd.DataFrame([row])
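
The joining rule in `_safe_join_chunks` is easiest to see on small inputs. The sketch below is a standalone copy of that rule (named `safe_join_chunks` here) applied to hypothetical chunk strings. It adds a space only when the previous chunk ends with a period and the next one starts with a capital letter or a double quote.

```python
import re
from typing import List


def safe_join_chunks(chunks: List[str]) -> str:
    """Standalone copy of the join rule used by the model class."""
    joined = ""
    for i, chunk in enumerate(chunks):
        if i == 0:
            joined += chunk
            continue
        prev = chunks[i - 1].rstrip()
        if prev.endswith(".") and re.match(r'^[A-Z"]', chunk.lstrip()):
            # Sentence boundary: separate the two chunks with a single space
            joined += " " + chunk.lstrip()
        else:
            # Anything else (lists, placeholders, lowercase text) is joined as-is
            joined += chunk
    return joined


# Hypothetical chunks for illustration
sentences = safe_join_chunks(["First sentence.", "Second sentence."])
bullets = safe_join_chunks(["A list item\n", "- another item"])
print(repr(sentences))  # 'First sentence. Second sentence.'
print(repr(bullets))    # 'A list item\n- another item'
```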

## Log and Register Model

In [9]:
%%time

# Define artifacts to be packaged with the model
artifacts = {
    "config": str(CONFIG_PATH),
    "local_model": str(LOCAL_MODEL_PATH),
}

if SECRETS_PATH.exists():
    artifacts["secrets"] = str(SECRETS_PATH)

# Define the model's signature
input_schema = Schema([
    ColSpec("string", "repo_url"),
    ColSpec("string", "files")
])
output_schema = Schema([
    ColSpec("string", "corrected"),
    ColSpec("string", "originals"),
    ColSpec("double", "response_time"),
    ColSpec("string", "evaluation_metrics")
])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)

# Set up MLflow experiment
mlflow.set_tracking_uri('/phoenix/mlflow')
mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)

with mlflow.start_run(run_name="MarkdownCorrector_FastEval") as run:
    run_id = run.info.run_id
    logger.info(f"Starting MLflow run with ID: {run_id}")
    
    # Log the model to MLflow
    mlflow.pyfunc.log_model(
        artifact_path=MLFLOW_MODEL_NAME,
        python_model=MarkdownCorrectorModel(),
        artifacts=artifacts,
        pip_requirements=str(REQUIREMENTS_PATH),
        signature=signature,
        registered_model_name=MLFLOW_MODEL_NAME,
        input_example=pd.DataFrame([{"repo_url": None, "files": "This is a testt."}])
    )
    logger.info(f"✅ Model '{MLFLOW_MODEL_NAME}' logged successfully (fast eval version).")

2025-08-08 00:49:56 - INFO - Starting MLflow run with ID: 8a7223d2332e4a7eb2fb889b18e432d7
2025/08/08 00:49:56 INFO mlflow.pyfunc: Validating input example against model signature


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

                low_vram was transferred to model_kwargs.
                Please confirm that low_vram is what you intended.
  python_model.load_context(context=context)
                rope_scaling was transferred to model_kwargs.
                Please confirm that rope_scaling is what you intended.
  python_model.load_context(context=context)
                num_threads was transferred to model_kwargs.
                Please confirm that num_threads is what you intended.
  python_model.load_context(context=context)
2025-08-08 00:53:12 - INFO - ✅ Model context loaded with only instantaneous evaluation metrics.
Registered model 'markdown-corrector' already exists. Creating a new version of this model...
Created version '37' of model 'markdown-corrector'.
2025-08-08 00:55:54 - INFO - ✅ Model 'markdown-corrector' logged successfully (fast eval version).


CPU times: user 18 s, sys: 39.2 s, total: 57.2 s
Wall time: 5min 58s
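
The output schema above declares `corrected`, `originals`, and `evaluation_metrics` as string columns, while `predict` returns Python dictionaries in those fields. Whether that mismatch matters depends on how the serving layer treats non-string values. One option, shown here as an assumption rather than the notebook's method, is to serialize the dictionaries to JSON strings before returning them and decode them on the caller side. The helper `to_signature_row` and the sample values are hypothetical.

```python
import json
from typing import Any, Dict


def to_signature_row(
    corrected: Dict[str, str],
    originals: Dict[str, str],
    response_time: float,
    evaluation_metrics: Dict[str, Any],
) -> Dict[str, Any]:
    """Builds one output row whose values match the declared column types."""
    return {
        "corrected": json.dumps(corrected),
        "originals": json.dumps(originals),
        "response_time": float(response_time),
        "evaluation_metrics": json.dumps(evaluation_metrics),
    }


# Hypothetical values for illustration
row = to_signature_row(
    corrected={"corrected_file.md": "This is a test."},
    originals={"corrected_file.md": "This is a testt."},
    response_time=1.25,
    evaluation_metrics={"semantic_similarity": 9.5},
)

# The caller decodes the string columns back into dictionaries
decoded = json.loads(row["corrected"])
print(decoded["corrected_file.md"])  # This is a test.
```

If this approach were adopted, every consumer that reads these columns would need a matching `json.loads` call.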


## Prepare Eval Data

In [10]:
%%time

logger.info("Preparing data for evaluation…")

# Load the newly registered model
model_uri    = f"models:/{MLFLOW_MODEL_NAME}/latest"
loaded_model = mlflow.pyfunc.load_model(model_uri)

# Get original content from the test repo
test_repo_url = "https://github.com/hp-david/grammarbptest"
with open(CONFIG_PATH, "r") as f:
    config = yaml.safe_load(f)

# Load secrets
github_token = os.getenv("AIS_GITHUB_ACCESS_TOKEN")
if github_token:
    secrets = {"AIS_GITHUB_ACCESS_TOKEN": github_token}
    logger.info("Loaded AIS_GITHUB_ACCESS_TOKEN from environment variable.")
else:
    # `context` exists only inside the pyfunc model, so use the notebook's path constant here
    secrets_path = str(SECRETS_PATH)

    if secrets_path and os.path.exists(secrets_path):
        with open(secrets_path, "r") as f:
            secrets = yaml.safe_load(f)
        logger.info(f"Loaded secrets from {secrets_path}.")
    else:
        secrets = {}
        logger.warning("No AIS_GITHUB_ACCESS_TOKEN found in environment or secrets.yaml.")

# Instantiate and fetch raw markdowns
processor = GitHubMarkdownProcessor(
    repo_url=test_repo_url,
    access_token=secrets.get("AIS_GITHUB_ACCESS_TOKEN"),  # .get() returns None if the token is absent
)
original_markdowns = processor.run()

# Initialize a dictionary to store all corrected results
all_corrected = {}
logger.info(f"Starting prediction for {len(original_markdowns)} files...")

# Loop through each file and call predict, just like the frontend
for filename, content in original_markdowns.items():
    logger.info(f"  - Predicting for: {filename}")
    
    # Create a dataframe with a SINGLE file's content string
    test_input = pd.DataFrame([{"repo_url": None, "files": content}])
    
    # Run prediction for this single file
    prediction_df = loaded_model.predict(test_input)
    
    # Extract the corrected content dictionary from the prediction
    corrected_dict = prediction_df.loc[0, "corrected"]
    
    # Store the corrected text using the original filename as the key
    all_corrected[filename] = corrected_dict.get("corrected_file.md", content)

logger.info("✅ Prediction complete for all files.")

eval_data = []
for fn, orig in original_markdowns.items():
    if fn in all_corrected:
        eval_data.append({
            "original":  orig,
            "corrected": all_corrected[fn]
        })

evaluation_df = pd.DataFrame(eval_data)
logger.info(f"✅ Created evaluation DataFrame with {len(evaluation_df)} file(s).")
display(evaluation_df.head())

2025-08-08 00:55:54 - INFO - Preparing data for evaluation…
                low_vram was transferred to model_kwargs.
                Please confirm that low_vram is what you intended.
  python_model.load_context(context=context)
                rope_scaling was transferred to model_kwargs.
                Please confirm that rope_scaling is what you intended.
  python_model.load_context(context=context)
                num_threads was transferred to model_kwargs.
                Please confirm that num_threads is what you intended.
  python_model.load_context(context=context)
2025-08-08 00:57:15 - INFO - ✅ Model context loaded with only instantaneous evaluation metrics.
2025-08-08 00:57:15 - INFO - Repository visibility: public
2025-08-08 00:57:16 - INFO - Raw markdown extraction complete.
2025-08-08 00:57:16 - INFO - Starting prediction for 2 files...
2025-08-08 00:57:16 - INFO -   - Predicting for: README1.md
2025-08-08 01:02:13 - INFO -   - Predicting for: README2.md
2025-08-08 01:

Unnamed: 0,original,corrected
0,"<h1 style=""text-align: center; font-size: 40px...","<h1 style=""text-align: center; font-size: 40px..."
1,# 📜💬 Shakespeare text generation with RNN\n\n<...,# 📜💬 Shakespeare text generation with RNN\n\n<...


CPU times: user 18min 58s, sys: 20.8 s, total: 19min 18s
Wall time: 20min 29s


## Log Evaluation Metrics

In [11]:
%%time

with mlflow.start_run(run_id=run_id):
    logger.info(f"Logging metrics for run ID: {run_id}")

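    # prediction_df holds the result for the last file processed above,
    # so these metrics describe that file only.
    raw_metrics = prediction_df.loc[0, "evaluation_metrics"]

    # predict() stores "N/A" for any metric that failed; keep numeric values only,
    # since mlflow.log_metrics expects numbers.
    metrics_to_log = {k: v for k, v in raw_metrics.items() if isinstance(v, (int, float))}

    # Log the dictionary of metrics to MLflow
    mlflow.log_metrics(metrics_to_log)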

    logger.info("✅ Metrics logged successfully.")
    logger.info(json.dumps(metrics_to_log, indent=2))

2025-08-08 01:16:24 - INFO - Logging metrics for run ID: 8a7223d2332e4a7eb2fb889b18e432d7
2025-08-08 01:16:24 - INFO - ✅ Metrics logged successfully.
2025-08-08 01:16:24 - INFO - {
  "semantic_similarity": 9.96603608439988,
  "readability_improvement": -0.3435385604734904,
  "grammar_quality_score": 10.0
}


CPU times: user 30.6 ms, sys: 18.1 ms, total: 48.7 ms
Wall time: 413 ms
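
The cell above reads metrics from `prediction_df`, which after the prediction loop holds the result for the last file only. To report one value per metric across all files, one possible approach is to collect each file's `evaluation_metrics` dictionary during the loop and average the numeric entries. The sketch below assumes such a list exists; `average_metrics` and the sample values are hypothetical.

```python
import math
from typing import Dict, List, Union

Metric = Union[int, float, str]


def average_metrics(per_file: List[Dict[str, Metric]]) -> Dict[str, float]:
    """Averages each metric over files, skipping non-numeric entries such as "N/A"."""
    collected: Dict[str, List[float]] = {}
    for metrics in per_file:
        for name, value in metrics.items():
            is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
            if is_number and math.isfinite(value):
                collected.setdefault(name, []).append(float(value))
    return {name: sum(values) / len(values) for name, values in collected.items()}


# Hypothetical per-file results; the second file failed one metric
per_file = [
    {"semantic_similarity": 9.0, "grammar_quality_score": 10.0},
    {"semantic_similarity": 8.0, "grammar_quality_score": "N/A"},
]
averaged = average_metrics(per_file)
print(averaged)  # {'semantic_similarity': 8.5, 'grammar_quality_score': 10.0}
```

The resulting dictionary contains only floats, so it could be passed to `mlflow.log_metrics` directly.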


## Log Execution Time

In [12]:
end_time: float = time.time()
elapsed_time: float = end_time - start_time
elapsed_minutes: int = int(elapsed_time // 60)
elapsed_seconds: float = elapsed_time % 60

logger.info(f"⏱️ Total execution time: {elapsed_minutes}m {elapsed_seconds:.2f}s")
logger.info("✅ Notebook execution completed successfully.")

2025-08-08 01:16:24 - INFO - ⏱️ Total execution time: 57m 31.53s
2025-08-08 01:16:24 - INFO - ✅ Notebook execution completed successfully.
