# English Correction with LangChain

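# Overview

This notebook pulls Markdown files from a GitHub repository, swaps selected Markdown content for placeholders, splits the text into chunks, corrects each chunk's English with an LLM through LangChain, and writes the placeholder-restored files to `corrected/` along with side-by-side HTML diffs in `corrected_diffs/`. Cells for registering and evaluating the model with MLflow are included but currently disabled.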

## Step 0: Configuring the Environment 

In [None]:
# Required Libraries
%pip install -r ../requirements.txt --quiet

In [None]:
# Standard Libraries
import os
os.environ['LLAMA_CPP_LOG_LEVEL'] = '0'
import sys
import logging
from collections import defaultdict

# Add src directory to system path
sys.path.append(os.path.abspath('../src'))

# Internal Modules
from github_extractor import GitHubMarkdownProcessor
from utils import load_config_and_secrets
from utils import (
    load_config_and_secrets,
    initialize_llm,
)
from parser import parse_md_for_grammar_correction, restore_placeholders
from chunker import chunk_markdown
from core.prompt_templates import get_markdown_correction_prompt
from core.markdown_correction_service import MarkdownCorrectionService

# Other modules
import mlflow
from mlflow.models import evaluate

### Define Constants and Paths

In [None]:
# Project configuration and secrets files (paths relative to this notebook).
CONFIG_PATH = "../configs/configs.yaml"
SECRETS_PATH = "../configs/secrets.yaml"

# Local GGUF weights handed to initialize_llm (and to the disabled MLflow
# registration cell). Previously used alternative:
#   /home/jovyan/datafabric/llama2-7b/ggml-model-f16-Q5_K_M.gguf
LOCAL_MODEL_PATH = (
    "/home/jovyan/datafabric/llama3.1-8b-instruct/"
    "Meta-Llama-3.1-8B-Instruct-Q8_0.gguf"
)

### Configuration and Secrets Loading

In [None]:
# Load project configuration and secrets from the YAML paths defined above.
# Both results are used as dicts further down (e.g. config["model_source"],
# secrets.get("GITHUB_ACCESS_TOKEN")).
config, secrets = load_config_and_secrets(CONFIG_PATH, SECRETS_PATH)

In [None]:
# Create a dedicated logger for this notebook.
logger = logging.getLogger("english-correction-notebook")
logger.setLevel(logging.INFO)

# logging.getLogger returns the same object for the same name, so re-running
# this cell would otherwise attach one more StreamHandler each time and print
# every message repeatedly. Only configure the handler once.
if not logger.handlers:
    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

# Do not forward records to the root logger (which would print them again).
logger.propagate = False

## Step 1: Extracting and Parsing Markdown Files From GitHub Repositories

### Extract Markdown Files

In [None]:
# Target repository; the GitHub token is read from the loaded secrets.
repo_url = "https://github.com/hp-david/test"
access_token = secrets.get("GITHUB_ACCESS_TOKEN")

# Instantiate the extractor and execute its preprocessing workflow. The
# returned `markdowns` is consumed below as (file name, Markdown text) pairs.
processor = GitHubMarkdownProcessor(
    repo_url=repo_url,
    access_token=access_token,
)
markdowns = processor.run()

### Parse Markdown Files with Placeholders

In [None]:
# For every file keep two things under the same key: the placeholder-substituted
# text (what the chunker and LLM see) and the placeholder map needed to restore
# the original spans after correction.
parsed_markdowns, placeholder_maps = {}, {}

for md_name, raw_markdown in markdowns.items():
    mapping, masked_text = parse_md_for_grammar_correction(raw_markdown)
    parsed_markdowns[md_name] = masked_text
    placeholder_maps[md_name] = mapping

logger.info(f"Parsed {len(parsed_markdowns)} files successfully")

### Chunk Markdown Content

In [None]:
# Split each parsed file into chunks of at most 100 tokens (per chunk_markdown).
all_chunks = {
    name: chunk_markdown(text, max_tokens=100)
    for name, text in parsed_markdowns.items()
}

# Print chunks during testing
separator = "\n" + "-" * 40 + "\n"
for name, file_chunks in all_chunks.items():
    logger.info(f"\n===== {name} =====\n")
    for idx, piece in enumerate(file_chunks, start=1):
        logger.info(f"\n--- Chunk {idx} ---\n")
        logger.info(piece)
        logger.info(separator)

## Step 2: Correct Markdown Files with LLM

In [None]:
# Get markdown correction prompt from prompt_templates module.
# It is piped into the LLM below (`correction_prompt | llm`) and the resulting
# chain is invoked with a {"markdown": <chunk>} input.
correction_prompt = get_markdown_correction_prompt()

### Initialize Model

In [None]:
#from core.prompt_templates import get_response_from_llm, MARKDOWN_CORRECTION_SYSTEM_PROMPT, MARKDOWN_CORRECTION_USER_PROMPT

# The config decides which LLM backend initialize_llm loads. Fail loudly if it
# is missing instead of falling through to an undefined (or stale, left over
# from an earlier run of this cell) `model_source`.
if "model_source" not in config:
    raise KeyError(
        f"'model_source' is not set in {CONFIG_PATH}; "
        "it selects which LLM initialize_llm loads"
    )
model_source = config["model_source"]

# Initialize llm (LOCAL_MODEL_PATH is passed for the local-model case)
llm = initialize_llm(model_source, secrets, LOCAL_MODEL_PATH)

# Create the LLM chain with the correction prompt
llm_chain = correction_prompt | llm

### Invoke Model on Each Chunk

In [None]:

# Run the correction chain over every chunk, tagging each result with the file
# it came from so the chunks can be reassembled per file later.
results = []
count = 0

for file_name, chunks in all_chunks.items():
    for chunk in chunks:
        response = llm_chain.invoke({"markdown": chunk})
        # Depending on the backend, the chain returns either a plain string
        # (completion-style LLMs) or a message object exposing `.content`
        # (chat models). Keep only the text: later cells call `.encode()` on
        # it and serialize it with json.dump.
        corrected = getattr(response, "content", response)
        results.append({
            "file": file_name,
            "original": chunk,
            "corrected": corrected,
        })
        logger.info(f"chunk {count} done")
        count += 1

In [None]:
# Print results during testing: original vs. corrected text for every chunk,
# each labelled with its token count.
for result in results:
    original_text = result["original"]
    corrected_text = result["corrected"]

    # Use LangChain's generic `get_num_tokens` rather than reaching into
    # `llm.client.tokenize`, which only exists on llama.cpp-backed models.
    original_tokens = llm.get_num_tokens(original_text)
    corrected_tokens = llm.get_num_tokens(corrected_text)

    print(f"\n===== {result['file']} =====\n")
    print(f"--- Original ({original_tokens} tokens) ---\n")
    print(original_text)
    print(f"\n--- Corrected ({corrected_tokens} tokens) ---\n")
    print(corrected_text)
    print("\n" + "=" * 60 + "\n")


In [None]:
from collections import defaultdict
import os
import re

# Helper: Safe chunk joiner
def safe_join_chunks(chunks):
    """Concatenate text chunks back into a single document.

    Chunks are joined verbatim, except that one space is inserted between
    two neighbouring chunks when the previous chunk ends with ``.`` and the
    next one starts *directly* (no leading whitespace) with an uppercase
    ASCII letter or a double quote. This restores a sentence boundary that
    would otherwise be glued together ("end.Next" -> "end. Next").

    Whitespace already present at a boundary (newlines, indentation) is
    never removed or replaced, so paragraph breaks and indented blocks
    survive the join.

    Args:
        chunks: Iterable of text chunks in document order.

    Returns:
        The joined text ("" when ``chunks`` is empty).
    """
    parts = []
    prev = ""
    for chunk in chunks:
        # Only add a space when neither side already has whitespace at the
        # boundary: `prev` must end exactly with "." and `chunk` must begin
        # exactly with a capital or quote.
        if prev.endswith(".") and re.match(r'[A-Z"]', chunk):
            parts.append(" ")
        parts.append(chunk)
        prev = chunk
    return "".join(parts)


# Collect each file's corrected chunks, preserving the order they were produced.
corrected_chunks_by_file = defaultdict(list)
for entry in results:
    corrected_chunks_by_file[entry["file"]].append(entry["corrected"])

# Stitch every file back together from its corrected chunks.
rebuilt_corrected_files = {}
for name, pieces in corrected_chunks_by_file.items():
    rebuilt_corrected_files[name] = safe_join_chunks(pieces)

# All corrected files are written beneath this directory.
output_dir = "corrected"
os.makedirs(output_dir, exist_ok=True)

# Put the original placeholder content back, then write each file to the same
# relative path under 'corrected/', creating subdirectories as needed.
for name, body in rebuilt_corrected_files.items():
    final_text = restore_placeholders(body, placeholder_maps.get(name, {}))

    target = os.path.join(output_dir, name)
    os.makedirs(os.path.dirname(target), exist_ok=True)

    with open(target, "w", encoding="utf-8") as fh:
        fh.write(final_text)


In [None]:
import difflib

diff_output_dir = "corrected_diffs"
os.makedirs(diff_output_dir, exist_ok=True)

# A single HtmlDiff instance can render any number of files; build it once
# instead of once per loop iteration.
differ = difflib.HtmlDiff(tabsize=4, wrapcolumn=80)

for file_name, corrected_content in rebuilt_corrected_files.items():
    # Get original content from markdowns dict; skip files we cannot compare
    # against before spending time on placeholder restoration.
    original_content = markdowns.get(file_name)
    if original_content is None:
        logger.warning(f"No original content for file {file_name}")
        continue

    placeholder_map = placeholder_maps.get(file_name, {})
    restored_content = restore_placeholders(corrected_content, placeholder_map)

    # Side-by-side HTML diff showing only changed regions (context=True)
    # with 3 lines of surrounding context.
    diff_html = differ.make_file(
        original_content.splitlines(),
        restored_content.splitlines(),
        fromdesc=f"Original: {file_name}",
        todesc=f"Corrected: {file_name}",
        context=True,
        numlines=3,
    )

    # Write diff HTML file, mirroring the file's relative path
    diff_path = os.path.join(diff_output_dir, file_name + ".html")
    os.makedirs(os.path.dirname(diff_path), exist_ok=True)
    with open(diff_path, "w", encoding="utf-8") as f:
        f.write(diff_html)

In [None]:
import json

# Save the per-chunk correction results next to the notebook.
results_path = "results.json"
with open(results_path, mode="w") as out_file:
    json.dump(results, out_file)

## MLflow Logging and Evaluation

### Register the Model with MLflow

In [None]:
# Disabled: this cell is wrapped in a string literal, so nothing below runs.
# When enabled it logs MarkdownCorrectionService as an MLflow model inside a
# run, registers it under the name "MarkdownCorrector", and defines
# `model_uri`, which the (also disabled) evaluation cell further down uses.
'''
mlflow.set_experiment("markdown-correction-experiment")

with mlflow.start_run(run_name="markdown-correction-run") as run:
    MarkdownCorrectionService.log_model(
        llm_artifact=LOCAL_MODEL_PATH,
        config_yaml=CONFIG_PATH,
        secrets_yaml=SECRETS_PATH,
    )

    model_uri = f"runs:/{run.info.run_id}/markdown_corrector"
    mlflow.register_model(model_uri, "MarkdownCorrector")

    logger.info(f"Model registered: MarkdownCorrector")
'''

### MLflow LLM Evaluation

In [None]:
# Disabled: this cell is wrapped in a string literal, so nothing below runs.
# When enabled it asks GPT for a reference correction of every original chunk
# (the "gold standard"), evaluates the registered model against those
# references with mlflow.evaluate plus the listed metrics, and logs the
# resulting metrics to MLflow.
# NOTE(review): before enabling, be aware that
#   - it needs `model_uri` from the (also disabled) registration cell, and
#   - `results = mlflow.evaluate(...)` overwrites the list of chunk
#     corrections that earlier cells also call `results`; consider a distinct
#     name such as `eval_results`.
'''
import pandas as pd

from mlflow.metrics import (
    ari_grade_level,
    flesch_kincaid_grade_level,
    exact_match,
    rouge1,
    rougeL
)
from core.llm_metrics import (
    semantic_similarity_metric,
    grammar_error_count_metric,
    grammar_error_rate_metric,
    grammar_improvement_metric,
    grammar_score_metric,
    readability_improvement_metric,
    llm_judge_metric,
    llm_judge_metric_local,
    generate_gpt_gold_standards
)

# Generate GPT gold standards
print("Generating GPT gold standards...")
original_texts = [item["original"] for item in results]

# Pass API key to the function
api_key = secrets.get("OPEN_AI_API_KEY") if secrets else None
gpt_gold_standards = generate_gpt_gold_standards(original_texts, api_key)

# Create evaluation DataFrame
eval_df = pd.DataFrame([
    {
        "markdown": original,
        "gpt_corrected": gpt_gold  # GPT's correction as gold standard
    }
    for original, gpt_gold in zip(original_texts, gpt_gold_standards)
])

# Run evaluation
results = mlflow.evaluate(
    model=model_uri,
    data=eval_df,
    targets="gpt_corrected",
    feature_names=["markdown"],
    extra_metrics=[
        ari_grade_level(),
        flesch_kincaid_grade_level(),
        exact_match(),
        rouge1(),
        rougeL(),
        semantic_similarity_metric,
        grammar_error_count_metric,
        grammar_error_rate_metric,
        grammar_improvement_metric,
        grammar_score_metric,
        readability_improvement_metric,
        llm_judge_metric,
        llm_judge_metric_local
    ]
)

logger.info("Evaluation results:")
logger.info(results.metrics)
mlflow.log_metrics(results.metrics)
'''