In [None]:
!pip install groq tiktoken
!wget https://www.gutenberg.org/files/1342/1342-0.txt -O pride_and_prejudice.txt
# Download the book. For quick tests, trim the file down to about two chapters.

import json
import os
import re
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import tiktoken
from google.colab import userdata
from groq import Groq

class LLMCharacterTracker:
    """Track character mentions in a book by asking an LLM (via Groq) to
    resolve names, aliases and pronouns, one text chunk at a time.

    Pipeline: ``preprocess_text`` -> ``chunk_text`` -> one LLM call per chunk
    (``create_prompt`` / ``process_llm_response``) -> ``merge_mentions`` ->
    ``write_output``. ``track_sequential_mentions`` runs the whole pipeline.
    """

    def __init__(self, api_key: str, model: str = "mixtral-8x7b-32768",
                 chunk_size: int = 500):
        """Create the Groq client and the tokenizer used for chunk sizing.

        Args:
            api_key: Groq API key.
            model: Groq chat model id used for every chunk. The default keeps
                the original behaviour. NOTE(review): Groq deprecates models
                over time; check its deprecations page and pass a currently
                served id if this one is rejected.
            chunk_size: Target maximum number of tokens per chunk.
        """
        self.client = Groq(api_key=api_key)
        # Only used by chunk_text to *estimate* chunk sizes, so counts are
        # approximate for models whose own tokenizer differs from cl100k_base.
        self.tokenizer = tiktoken.get_encoding("cl100k_base")
        self.chunk_size = chunk_size  # tokens per chunk
        self.model = model

    def preprocess_text(self, input_file: str, output_file: str) -> str:
        """Clean a Project Gutenberg text and write the result to ``output_file``.

        Steps, in order:
          1. Drop text inside square brackets (e.g. ``[Illustration: ...]``),
             including spans that cross lines.
          2. Strip quotes, apostrophes, brackets, parentheses and braces.
             Sentence-ending punctuation (``.``, ``!``, ``?``) is kept.
          3. Drop lines that start with six or more spaces (centred headings
             and other formatted, non-body text). They are printed first.
          4. Discard everything before the first paragraph that looks like a
             full sentence (capitalised, ends in ``.``/``!``/``?``) of more
             than five words.
          5. Cut the text at the first run of at least six newlines, unless
             "chapter" appears within 100 characters of that run.
          6. Remove ``*** ... ***`` markers, collapse blank lines to a single
             paragraph break, turn hyphens into spaces, collapse runs of
             spaces and delete underscores.

        Args:
            input_file: Path of the raw UTF-8 book text.
            output_file: Path where the cleaned text is written (overwritten).

        Returns:
            ``output_file``.
        """
        with open(input_file, 'r', encoding='utf-8') as f:
            text = f.read()

        # Remove text enclosed in square brackets, including multiline content (e.g., [Illustration: ...])
        text = re.sub(r'\[.*?\]', '', text, flags=re.DOTALL)

        # Remove unwanted punctuation but keep sentence-ending punctuation (., !, ?)
        text = re.sub(r"[\"'“”‘’\]\[\(\){}]", '', text)

        # Drop indented lines. Only lines that *start* with six spaces are
        # removed (internal runs of spaces are ordinary text and are collapsed
        # later). Whitespace-only lines are kept as empty lines so that the
        # paragraph break they represent is not lost.
        kept_lines = []
        removed_lines = []
        for line in text.splitlines():
            if not line.strip():
                kept_lines.append('')
            elif line.startswith('      '):
                removed_lines.append(line)
            else:
                kept_lines.append(line)

        print("Lines being removed due to leading spaces:")
        for line in removed_lines:
            print(line)

        text = '\n'.join(kept_lines)

        # Skip front matter: keep content from the first paragraph that looks
        # like a complete sentence of more than five words.
        paragraphs = text.split('\n\n')
        for i, paragraph in enumerate(paragraphs):
            # MULTILINE lets `$` match at a line end inside the paragraph.
            if (re.match(r'^[A-Z][^?!.]*[.?!]$', paragraph.strip(), re.MULTILINE)
                    and len(paragraph.split()) > 5):
                text = '\n\n'.join(paragraphs[i:])
                break

        # Treat the first run of at least six newlines as the end-of-book
        # marker, unless "chapter" appears within 100 characters of it
        # (then it is assumed to be a gap between chapters).
        match = re.search(r'(\n\s*\n\s*){3,}', text)
        if match:
            surrounding_text = text[max(0, match.start() - 100):match.start() + 100].lower()
            if 'chapter' not in surrounding_text:
                text = text[:match.start()]

        # Remove specific project markers (e.g., "*** END OF THE PROJECT GUTENBERG ... ***")
        text = re.sub(r'\*\*\*.*?\*\*\*', '', text, flags=re.DOTALL)

        text = text.strip()
        text = re.sub(r'\n\s*\n', '\n\n', text)  # one paragraph break per blank run
        text = text.replace('-', ' ')            # hyphens become word separators
        text = re.sub(r'[ ]+', ' ', text)        # collapse runs of spaces
        text = text.replace('_', '')             # drop _emphasis_ markers, keep words

        with open(output_file, 'w', encoding='utf-8') as out_f:
            out_f.write(text)

        return output_file

    def chunk_text(self, text: str) -> List[Tuple[int, str]]:
        """Split ``text`` into token-bounded chunks on paragraph boundaries.

        Paragraphs are separated by ``'\\n\\n'``. The current chunk is flushed
        before a paragraph that would push its token count past
        ``self.chunk_size``. A single paragraph larger than ``chunk_size``
        therefore becomes its own, oversized chunk.

        Returns:
            ``(chunk_index, chunk_text)`` pairs, numbered from 0.
        """
        chunks = []
        current_chunk = []
        current_tokens = 0

        for para in text.split('\n\n'):
            para_tokens = len(self.tokenizer.encode(para))

            # Finalize the current chunk if this paragraph would overflow it.
            if current_chunk and current_tokens + para_tokens > self.chunk_size:
                chunks.append((len(chunks), '\n\n'.join(current_chunk)))
                current_chunk = []
                current_tokens = 0

            current_chunk.append(para)
            current_tokens += para_tokens

        if current_chunk:
            chunks.append((len(chunks), '\n\n'.join(current_chunk)))

        return chunks

    def extract_last_sentences(self, text: str, num_sentences: int = 3) -> str:
        """Return the last ``num_sentences`` sentences of ``text``, space-joined.

        Sentences are split after ``.``, ``!`` or ``?`` followed by whitespace,
        and are taken from the final paragraphs backwards. The result is used
        as the "previous context" passed into the next chunk's prompt.
        """
        collected_sentences = []

        for paragraph in reversed(text.split('\n\n')):
            sentences = re.split(r'(?<=[.!?])\s+', paragraph.strip())
            for sentence in reversed(sentences):
                if sentence.strip():
                    collected_sentences.insert(0, sentence.strip())
                    if len(collected_sentences) >= num_sentences:
                        break
            if len(collected_sentences) >= num_sentences:
                break

        return ' '.join(collected_sentences[:num_sentences])

    def create_prompt(self, chunk: str, previous_context: str = "") -> str:
        """Build the user prompt that asks the LLM to list and resolve every
        character mention in ``chunk``, in the
        ``[Paragraph N]`` / ``[Sentence N]: mention[Full Name], ...`` format
        that ``process_llm_response`` parses.
        """
        return f"""Given the following text chunk from a book, identify all character mentions, including pronouns and aliases. Resolve each mention to the character's full name. Consider the previous context if provided.

Previous context (if any):
{previous_context}

Text chunk:
{chunk}

For each paragraph and sentence in the chunk, list all character mentions in the following format:
[Paragraph NUMBER]
[Sentence NUMBER]: CharacterMention1[FullName1], CharacterMention2[FullName2], ...

Rules:
1. Resolve all pronouns (he/she/I/we) to the correct character
2. Resolve all aliases and nicknames to the character's full name
3. If a mention is already the full name, use the same name for both mention and full name
4. Number paragraphs and sentences sequentially
5. Include only character mentions, ignore other entities

Example output:
[Paragraph 1]
[Sentence 1]: Lizzy[Elizabeth Bennet], she[Elizabeth Bennet]
[Sentence 2]: Mr. Darcy[Fitzwilliam Darcy], they[Fitzwilliam Darcy, Elizabeth Bennet]

Please process the text chunk now:"""

    def process_llm_response(self, response: str) -> Dict:
        """Parse the LLM's reply into ``{paragraph: {sentence: [mentions]}}``.

        Each mention is ``{'character': full_name, 'mention': surface_text}``.
        Leading/trailing whitespace on reply lines is ignored, and
        ``[Sentence N] :`` spacing variants are accepted. Lines that are
        neither paragraph nor sentence headers are ignored. A group mention
        such as ``they[A, B]`` is stored with ``character`` set to the whole
        string ``"A, B"``.

        Args:
            response: Raw reply text; ``None`` or ``""`` yields an empty result.

        Returns:
            Nested ``defaultdict`` keyed by paragraph and sentence numbers
            (ints) as written by the LLM.
        """
        chronological_mentions = defaultdict(lambda: defaultdict(list))
        if not response:
            return chronological_mentions

        # Sentences seen before any [Paragraph N] header are filed under 1.
        current_paragraph = 1

        for raw_line in response.strip().split('\n'):
            line = raw_line.strip()
            if line.startswith('[Paragraph'):
                # Without a number, keep filing under the current paragraph.
                paragraph_match = re.search(r'\d+', line)
                if paragraph_match:
                    current_paragraph = int(paragraph_match.group())
            elif line.startswith('[Sentence'):
                sentence_match = re.match(r'\[Sentence\s*(\d+)\]\s*:\s*(.*)', line)
                if not sentence_match:
                    continue
                sentence_num = int(sentence_match.group(1))
                mentions_text = sentence_match.group(2)

                # "mention[Full Name]" pairs separated by commas.
                for mention, full_name in re.findall(r'([^,\[\]]+)\[([^\[\]]+)\]', mentions_text):
                    chronological_mentions[current_paragraph][sentence_num].append({
                        'character': full_name.strip(),
                        'mention': mention.strip()
                    })

        return chronological_mentions

    def merge_mentions(self, all_chunks_mentions: List[Dict],
                       paragraph_counts: Optional[List[int]] = None) -> Dict:
        """Merge per-chunk mentions into one book-wide paragraph numbering.

        Each chunk's paragraph numbers (as written by the LLM, starting at 1)
        are shifted by the number of paragraphs in all preceding chunks.

        Args:
            all_chunks_mentions: One ``process_llm_response`` result per chunk,
                in chunk order. ``None`` paragraph keys are skipped.
            paragraph_counts: Optional number of paragraphs actually present in
                each chunk. When given, a chunk advances the offset by
                ``max(count, highest paragraph number with mentions)``, so
                trailing paragraphs without mentions still count. When
                omitted (or shorter than the chunk list), only the highest
                paragraph number with mentions is used, as before.

        Returns:
            ``defaultdict`` mapping ``paragraph -> sentence -> [mentions]``.
        """
        merged_mentions = defaultdict(lambda: defaultdict(list))
        paragraph_offset = 0  # paragraphs consumed by earlier chunks

        for index, chunk_mentions in enumerate(all_chunks_mentions):
            para_numbers = [int(p) for p in chunk_mentions if p is not None]

            for para_num, sentences in chunk_mentions.items():
                if para_num is None:
                    continue
                adjusted_para_num = int(para_num) + paragraph_offset
                for sent_num, mentions in sentences.items():
                    merged_mentions[adjusted_para_num][sent_num].extend(mentions)

            chunk_span = max(para_numbers, default=0)
            if paragraph_counts is not None and index < len(paragraph_counts):
                chunk_span = max(chunk_span, paragraph_counts[index])
            paragraph_offset += chunk_span

        return merged_mentions

    @staticmethod
    def _preprocessed_path(input_file: str) -> str:
        """Return ``input_file``'s path with the file name prefixed by
        ``preprocessed_`` (same directory), e.g. ``data/b.txt`` ->
        ``data/preprocessed_b.txt``."""
        directory, name = os.path.split(input_file)
        return os.path.join(directory, f"preprocessed_{name}")

    def _query_llm(self, prompt: str) -> str:
        """Send ``prompt`` to the Groq chat API at temperature 0 and return
        the reply text (``""`` if the API returned no content)."""
        response = self.client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    "content": "You are a literary analysis assistant specialized in identifying and tracking characters in narrative text."
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            model=self.model,
            temperature=0.0
        )
        return response.choices[0].message.content or ""

    def track_sequential_mentions(self, input_file: str, output_dir: str):
        """Run the full pipeline on ``input_file``.

        Writes the cleaned text next to the input (``preprocessed_<name>``),
        queries the LLM once per chunk (printing each raw reply), and writes
        ``sequential_mentions.txt`` inside ``output_dir`` (created if needed).
        The last five sentences of each chunk are passed to the next chunk's
        prompt as context.
        """
        os.makedirs(output_dir, exist_ok=True)

        preprocessed_file = self.preprocess_text(input_file, self._preprocessed_path(input_file))
        with open(preprocessed_file, 'r', encoding='utf-8') as f:
            text = f.read()

        all_chunks_mentions = []
        paragraph_counts = []
        previous_context = ""

        for _chunk_num, chunk in self.chunk_text(text):
            content = self._query_llm(self.create_prompt(chunk, previous_context))
            print(content)

            all_chunks_mentions.append(self.process_llm_response(content))
            # Real paragraph count keeps numbering aligned across chunks.
            paragraph_counts.append(len(chunk.split('\n\n')))

            previous_context = self.extract_last_sentences(chunk, num_sentences=5)

        merged_mentions = self.merge_mentions(all_chunks_mentions, paragraph_counts)
        # os.path.join works whether or not output_dir ends with a separator.
        self.write_output(merged_mentions, os.path.join(output_dir, "sequential_mentions.txt"))

    def write_output(self, chronological_mentions: Dict, output_file: str):
        """Write mentions to ``output_file`` in the prompt's text format.

        Paragraphs and sentences are written in ascending numeric order; every
        paragraph header is preceded by a newline. A mention identical to its
        resolved name is written once (``Jane Bennet``); otherwise it is
        written as ``mention[Full Name]``.

        Args:
            chronological_mentions: ``{paragraph: {sentence: [mention dicts]}}``
                as produced by ``merge_mentions``; ``None`` paragraph keys are
                skipped and all other keys are converted with ``int``.
            output_file: Destination path (overwritten).
        """
        mentions_dict = {
            int(para_num): {
                int(sent_num): mentions
                for sent_num, mentions in sentences.items()
            }
            for para_num, sentences in chronological_mentions.items()
            if para_num is not None
        }

        with open(output_file, 'w', encoding='utf-8') as f:
            for paragraph in sorted(mentions_dict):
                f.write(f"\n[Paragraph {paragraph}]\n")
                for sentence in sorted(mentions_dict[paragraph]):
                    mention_texts = [
                        m['mention'] if m['mention'] == m['character']
                        else f"{m['mention']}[{m['character']}]"
                        for m in mentions_dict[paragraph][sentence]
                    ]
                    f.write(f"[Sentence {sentence}]: " + ", ".join(mention_texts) + "\n")


if __name__ == "__main__":
    # The API key comes from Colab's secret store (google.colab.userdata),
    # so this entry point is meant to run inside Google Colab.
    groq_key = userdata.get("GROQ_API_KEY")
    LLMCharacterTracker(api_key=groq_key).track_sequential_mentions(
        input_file="pride_and_prejudice.txt",
        output_dir="pride_and_prejudice_llm/",
    )