In [None]:
import csv
from typing import List, Optional
from pathlib import Path
from pydantic import BaseModel, Field
from tqdm.auto import tqdm
from datetime import date

from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from langchain.output_parsers.fix import OutputFixingParser
from langchain_core.runnables import RunnableLambda


In [5]:
# Define the output models using Pydantic v2
class MICClassificationResult(BaseModel):
    """Schema for the Militarized Interstate Confrontation classification result."""
    # NOTE: pydantic puts the class docstring and each Field description into the
    # model's JSON schema, which PydanticOutputParser.get_format_instructions()
    # renders into the classification prompt -- treat that text as prompt content.

    # Required (``...``): the model's yes/no verdict.
    is_mic: bool = Field(
        ...,
        description="True if the article describes a Militarized Interstate Confrontation (MIC), False otherwise",
    )
    # Required (``...``): free-text justification for the verdict.
    explanation: str = Field(
        ...,
        description="A brief explanation of why the article was classified as MIC or not",
    )

class MICDetailedInfo(BaseModel):
    """Schema for detailed information about a Militarized Interstate Confrontation."""
    # NOTE: the class docstring and every Field description become part of the JSON
    # schema that PydanticOutputParser.get_format_instructions() renders into the
    # extraction prompt, so editing them changes what the LLM is told.
    #
    # Values the model may be unable to determine are Optional with a None default.
    # The extraction prompt tells the model to answer null when it cannot determine
    # a value; a required, non-null field would make that answer fail validation and
    # push every such article through the output-fixing retry (or drop it entirely).
    MICdate: Optional[date] = Field(
        default=None,
        description="The date when the MIC occurred (YYYY-MM-DD format), or null if it cannot be determined",
    )
    fatality_min: Optional[int] = Field(
        default=None,
        description="The minimum number of fatalities (use same number as max if precise)",
    )
    fatality_max: Optional[int] = Field(
        default=None,
        description="The maximum number of fatalities (use same number as min if precise)",
    )
    countries_involved: List[str] = Field(
        description="List of countries involved in the confrontation",
    )
    # In pydantic v2 ``Optional[...]`` alone still means "required, may be None";
    # the explicit default lets the model omit these keys as well.
    initiator_country: Optional[str] = Field(
        default=None,
        description="The country that initiated the confrontation, if identifiable",
    )
    target_country: Optional[str] = Field(
        default=None,
        description="The country that was targeted in the confrontation, if identifiable",
    )

def init_llm(
    base_url: str = "http://localhost:1234/v1",
    api_key: str = "LMStudio",
    model_name: str = "qwen2.5-7b-instruct-1m",
    temperature: float = 0.1,
) -> ChatOpenAI:
    """Create the chat-model client used by the classification and extraction chains.

    The defaults target an OpenAI-compatible server on localhost:1234 (apparently
    LM Studio, judging by the placeholder key) serving a Qwen 2.5 7B instruct model.
    Override them to point at another endpoint or model without editing this function.

    Args:
        base_url: Root URL of the OpenAI-compatible API.
        api_key: Key sent to the server. The default is a placeholder, not a secret;
            presumably the local server does not validate it -- confirm before
            pointing this at a hosted API (use an environment variable there).
        model_name: Model identifier as exposed by the server.
        temperature: Sampling temperature; kept low so labels are near-deterministic.

    Returns:
        A configured ``ChatOpenAI`` instance.
    """
    return ChatOpenAI(
        base_url=base_url,
        api_key=api_key,
        model_name=model_name,
        temperature=temperature,
    )

def create_classification_chain(llm: Optional[ChatOpenAI] = None):
    """Build the LCEL chain that labels an article as MIC / not MIC.

    Args:
        llm: Chat model used both to answer the prompt and to repair malformed
            output. When omitted, a new client is created with ``init_llm()``;
            pass an existing one to share a single client across chains.

    Returns:
        The runnable ``prompt | llm | fixing_parser``. Invoke it with
        ``{"article": <text>}``; it returns a ``MICClassificationResult``.
    """
    if llm is None:
        llm = init_llm()
    parser = PydanticOutputParser(pydantic_object=MICClassificationResult)
    # If the raw reply does not parse, OutputFixingParser asks the LLM to repair it.
    fixing_parser = OutputFixingParser.from_llm(parser=parser, llm=llm)

    format_instructions = parser.get_format_instructions()

    # PromptTemplate uses f-string syntax: literal JSON braces in the few-shot
    # examples are doubled ({{ }}); {article} and {format_instructions} are the
    # only template variables, the latter pre-filled via partial_variables.
    prompt = PromptTemplate(
        template="""You are an expert analyst of international relations and military conflicts.
        Your task is to determine whether a news article describes a Militarized Interstate Confrontation (MIC).

        A Militarized Interstate Confrontation (MIC) is defined as:
        - A direct confrontation between two or more countries
        - Involving military forces (army, navy, air force, etc.)
        - Where there is a threat, display, or use of military force

        The article must describe an actual military interaction, not just diplomatic tensions or discussions about potential conflicts.

        Analyze the following article carefully and determine if it describes a MIC.
        Output your answer in the specified JSON format with two fields:
        1. is_mic: true if it's a MIC, false if it's not
        2. explanation: A short explanation of your reasoning

        Here are some examples to guide you:
        Example 1:
        Article: Russian troops opened fire on Ukrainian soldiers near the border, killing three and wounding seven others. The Ukrainian government condemned the attack as a violation of its sovereignty.
        Output: {{"is_mic": true, "explanation": "This article describes a direct military confrontation between Russian and Ukrainian forces with fatalities, which is a clear case of a Militarized Interstate Confrontation."}}

        Example 2:
        Article: China and Taiwan held diplomatic talks aimed at easing tensions in the region. Both sides agreed to maintain open lines of communication to prevent misunderstandings.
        Output: {{"is_mic": false, "explanation": "This article describes diplomatic talks rather than a military confrontation. No military forces were involved, and there was no threat or use of force."}}

        Example 3:
        Article: North Korean forces fired artillery shells into South Korean waters as a show of force during joint US-South Korean military exercises. No casualties were reported.
        Output: {{"is_mic": true, "explanation": "This article describes a militarized action (artillery fire) by North Korea directed at South Korea, which constitutes a Militarized Interstate Confrontation even without casualties."}}

        Example 4:
        Article: The United Nations Security Council met to discuss increasing tensions between India and Pakistan but no military actions were reported.
        Output: {{"is_mic": false, "explanation": "This article only mentions diplomatic discussions about tensions. It does not describe any actual military confrontation, threat, or use of force between countries."}}

        Article:
        {article}

        {format_instructions}
        """,
        input_variables=['article'],
        partial_variables={'format_instructions': format_instructions}
    )

    # LCEL pipeline: render prompt -> call model -> parse (with repair fallback).
    return prompt | llm | fixing_parser

def create_extraction_chain(llm: Optional[ChatOpenAI] = None):
    """Build the LCEL chain that extracts structured details from a MIC article.

    Intended to run only on articles already classified as MICs.

    Args:
        llm: Chat model used both to answer the prompt and to repair malformed
            output. When omitted, a new client is created with ``init_llm()``;
            pass an existing one to share a single client across chains.

    Returns:
        The runnable ``prompt | llm | fixing_parser``. Invoke it with
        ``{"article": <text>}``; it returns a ``MICDetailedInfo``.
    """
    if llm is None:
        llm = init_llm()
    parser = PydanticOutputParser(pydantic_object=MICDetailedInfo)
    # If the raw reply does not parse, OutputFixingParser asks the LLM to repair it.
    fixing_parser = OutputFixingParser.from_llm(parser=parser, llm=llm)

    format_instructions = parser.get_format_instructions()

    # Unknown dates are requested as null: MICdate is parsed into a datetime.date,
    # so a sentinel such as "0000-00-00" (year 0) can never validate.
    # Literal JSON braces in the examples are doubled ({{ }}) for PromptTemplate.
    prompt = PromptTemplate(
        template="""You are an expert analyst of international relations and military conflicts.
        
        This article has been identified as describing a Militarized Interstate Confrontation (MIC).
        
        Extract the following specific details about this confrontation:
        1. The date when the confrontation occurred (be as precise as possible; use YYYY-MM-DD format if the date is known, otherwise use null)
        2. The number of fatalities (provide a range with minimum and maximum values; use the same number for both if precise)
        3. All countries involved in the confrontation
        4. If possible, identify which country initiated the confrontation and which was the target
        
        If any information is not explicitly stated in the article, make your best estimate based on context clues.
        If you cannot determine a piece of information at all, use null for that field.

        Example 1:
        Article: "On March 15, 2022, tensions escalated between Nation A and Nation B, leading to armed skirmishes. Reports confirm at least 50 casualties."
        
        Extracted Details:
        ```
        {{
            "MICdate": "2022-03-15",
            "fatality_min": 50,
            "fatality_max": 50,
            "countries_involved": ["Nation A", "Nation B"],
            "initiator_country": "Nation A",
            "target_country": "Nation B"
        }}
        ```
        
        Example 2:
        Article: "In early 1998, a naval conflict arose between Country X and Country Y. The exact number of casualties remains unknown."
        
        Extracted Details:
        ```
        {{
            "MICdate": "1998-01-01",
            "fatality_min": 0,
            "fatality_max": 0,
            "countries_involved": ["Country X", "Country Y"],
            "initiator_country": null,
            "target_country": null
        }}
        ```
        
        Article:
        {article}
        
        {format_instructions}
        """,
        input_variables=['article'],
        partial_variables={'format_instructions': format_instructions}
    )

    # LCEL pipeline: render prompt -> call model -> parse (with repair fallback).
    return prompt | llm | fixing_parser

def read_article_file(
    file_path: Path, encodings: tuple[str, ...] = ("utf-8", "latin-1")
) -> Optional[str]:
    """Read an article as text, trying each encoding in turn.

    Args:
        file_path: Path of the article file.
        encodings: Encodings to try, in order. A failure to decode moves on to the
            next one. With the default list the Latin-1 fallback always decodes,
            because every byte value maps to a character in Latin-1.

    Returns:
        The decoded text, or ``None`` (after printing a message) if the file cannot
        be opened/read or no encoding in ``encodings`` decodes it. Errors are
        reported rather than raised so that one bad file does not abort a batch run.
    """
    for encoding in encodings:
        try:
            return file_path.read_text(encoding=encoding)
        except UnicodeDecodeError:
            continue  # wrong encoding -- try the next candidate
        except OSError as e:
            # Missing file, permission problem, directory, etc.: retrying with a
            # different encoding cannot help.
            print(f"Error reading {file_path.name}: {e}")
            return None
    print(f"Error reading {file_path.name}: could not decode with {', '.join(encodings)}")
    return None

def create_unified_chain():
    """Combine classification and extraction into a single runnable.

    The returned runnable takes ``{"article_text": str, "file_index": ...}``,
    always runs the classification chain, and runs the extraction chain only
    when the article is classified as a MIC. It returns a dict with keys
    ``file_index``, ``classification_result``, ``extraction_result`` (``None``
    for non-MICs) and ``is_mic``. Exceptions from either chain propagate.
    """
    classify = create_classification_chain()
    extract = create_extraction_chain()

    # Name kept as-is: RunnableLambda derives the runnable's name from it.
    def process_classification_and_branch(inputs):
        text = inputs["article_text"]
        index = inputs["file_index"]

        verdict = classify.invoke({"article": text})
        mic_detected = True if verdict.is_mic else False
        # Extraction costs a second LLM call, so only spend it on MICs.
        details = extract.invoke({"article": text}) if mic_detected else None

        return {
            "file_index": index,
            "classification_result": verdict,
            "extraction_result": details,
            "is_mic": mic_detected,
        }

    return RunnableLambda(process_classification_and_branch)

def save_to_csv(file_index, result, csv_file, is_extraction=False):
    """Append one classification or extraction result to a CSV, at most once per index.

    Args:
        file_index: Article identifier written to the ``Index`` column. It is also
            used for de-duplication, compared as a string because that is how it is
            read back from the CSV.
        result: For classification, an object with ``is_mic`` and ``explanation``.
            For extraction, an object with ``MICdate``, ``fatality_min``,
            ``fatality_max``, ``countries_involved``, ``initiator_country`` and
            ``target_country``.
        csv_file: Destination path (``Path`` or ``str``). A header row is written
            first if the file is missing or empty.
        is_extraction: Use the extraction column layout instead of the
            classification layout.

    Returns:
        True if a row was written, False if ``file_index`` was already present.
    """
    csv_file = Path(csv_file)
    index_key = str(file_index)

    # A zero-byte file (e.g. created by an interrupted run) still needs a header.
    needs_header = not csv_file.exists() or csv_file.stat().st_size == 0

    # De-duplicate against rows already on disk. This rescans the whole file on
    # every call, which is O(rows) per call.
    if not needs_header:
        with open(csv_file, mode="r", newline="", encoding="utf-8") as f:
            reader = csv.reader(f)
            next(reader, None)  # skip header
            if any(row and row[0] == index_key for row in reader):
                return False

    # Build header and row before touching the file, so a malformed ``result``
    # raises without leaving a half-written file behind.
    if is_extraction:
        header = [
            'Index', 'MICdate', 'Fatality_Min', 'Fatality_Max',
            'Countries_Involved', 'Initiator_Country', 'Target_Country'
        ]
        row = [
            file_index,
            result.MICdate,
            result.fatality_min,
            result.fatality_max,
            ', '.join(result.countries_involved) if result.countries_involved else '',
            result.initiator_country or "null",
            result.target_country or "null",
        ]
    else:
        header = ['Index', 'Label', 'Explanation']
        row = [file_index, int(result.is_mic), result.explanation]

    # Append mode also creates the file when it does not exist yet.
    with open(csv_file, mode="a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        if needs_header:
            writer.writerow(header)
        writer.writerow(row)

    return True

def _load_processed_indices(csv_file: Path) -> set:
    """Return the first-column values (article indices) of the data rows in ``csv_file``.

    Returns an empty set if the file does not exist. The first row is treated as
    a header and skipped; blank rows are ignored.
    """
    if not csv_file.exists():
        return set()
    with open(csv_file, 'r', newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header
        return {row[0] for row in reader if row}


def process_articles(
    base_dir: Path,
    output_dir: Path,
    years_to_process: list[str],
    max_article_chars: Optional[int] = None,
) -> None:
    """Classify every article for the given years and extract details for MICs.

    For each year, reads ``base_dir/<year>/**/*.txt``, runs each article through
    the unified classification/extraction chain, and appends rows to
    ``output_dir/classified_files/<year>_classification.csv`` and, for MICs,
    ``output_dir/detailed_files/<year>_mic_details.csv``.

    The run is resumable: articles whose index already appears in the year's
    classification CSV are skipped. Per-article failures are reported and
    skipped, not raised.

    Args:
        base_dir: Directory containing one sub-directory per year.
        output_dir: Root under which the two output directories are created.
        years_to_process: Year directory names; values are passed through ``str``.
        max_article_chars: If set, article text is truncated to this many
            characters before it is sent to the LLM. This is a rough guard against
            prompts that overflow the model's context window. ``None`` (the
            default) sends the full text.
    """
    classified_dir = output_dir / "classified_files"
    detailed_dir = output_dir / "detailed_files"
    classified_dir.mkdir(parents=True, exist_ok=True)
    detailed_dir.mkdir(parents=True, exist_ok=True)

    unified_chain = create_unified_chain()

    for year in map(str, years_to_process):
        year_dir = base_dir / year
        if not year_dir.is_dir():
            print(f"Directory for year {year} not found. Skipping.")
            continue

        classification_file = classified_dir / f"{year}_classification.csv"
        extraction_file = detailed_dir / f"{year}_mic_details.csv"

        # Sorted by bare file name so repeated runs visit articles in a stable order.
        article_files = sorted(year_dir.glob("**/*.txt"), key=lambda p: p.name)
        print(f"Found {len(article_files)} articles in {year}")

        processed_entries = _load_processed_indices(classification_file)
        if classification_file.exists():
            print(f"Found {len(processed_entries)} already processed articles")

        new_mic_count = 0
        new_processed_count = 0

        # Iterating the tqdm wrapper advances the bar once per file, including skips.
        with tqdm(article_files, desc=f"Processing {year}", position=0) as pbar:
            for article_file in pbar:
                # NOTE: the index uses only the file name, so identically named files
                # in different sub-directories collide; only the first one is processed.
                file_index = f"{year}_{article_file.name}"
                if file_index in processed_entries:
                    continue

                content = read_article_file(article_file)
                if content is None:
                    continue
                if max_article_chars is not None and len(content) > max_article_chars:
                    content = content[:max_article_chars]

                try:
                    result = unified_chain.invoke({
                        "article_text": content,
                        "file_index": file_index
                    })

                    if save_to_csv(file_index, result["classification_result"], classification_file):
                        new_processed_count += 1
                    # Track in memory too, so a colliding file name later in this run is
                    # skipped before any LLM calls are spent on it.
                    processed_entries.add(file_index)

                    if result["is_mic"] and result["extraction_result"]:
                        if save_to_csv(
                            file_index,
                            result["extraction_result"],
                            extraction_file,
                            is_extraction=True,
                        ):
                            new_mic_count += 1
                except Exception as e:
                    # Best-effort batch: report and move on. An article whose
                    # classification row was not written is retried on the next run.
                    # A MIC whose extraction failed after classification was saved is
                    # NOT retried.
                    # tqdm.write keeps the message from breaking the progress bar.
                    tqdm.write(f"Error processing {file_index}: {str(e)[:100]}...")
                    continue

                # Show running totals next to the bar.
                pbar.set_postfix({"MICs": new_mic_count, "Processed": new_processed_count})

        print(f"Year {year} processing complete. New articles processed: {new_processed_count}, New MICs identified: {new_mic_count}")

In [8]:
# Articles are read from ../processed_files/<year>/; result CSVs are written under
# ../classified_files and ../detailed_files (relative to the notebook's directory).
base_dir = Path.cwd().parent / "processed_files"
output_dir = Path.cwd().parent
years_to_process = ["2008"]

print("Starting unified MIC classification and extraction pipeline")
process_articles(base_dir, output_dir, years_to_process)
print("Processing complete.")

Starting unified MIC classification and extraction pipeline
Found 10218 articles in 2008
Found 1472 already processed articles


Processing 2008:   0%|          | 0/10218 [00:00<?, ?it/s]

Error processing 2008_article_2584.txt: Error code: 400 - {'error': 'Trying to keep the first 7122 tokens when context the overflows. Howeve...


KeyboardInterrupt: 