In [2]:
import asyncio
import json
import os
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from json.decoder import JSONDecodeError
from typing import Dict, List, Optional, Tuple, Type, Union

import requests
from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field, ValidationError

import prompts
from logger_module import logger
from prompts import (
    MAX_VALIDATION_ERROR_TRY,
    PROMPT_ROLE,
    PROMPT_VALIDATION_RESOLVE_ROLE,
    SUMMARY_ROLE,
    SUMMARY_VALIDATION_RESOLVE_ROLE,
)
from utils import Chunker, Tokenizer



In [3]:
# Load variables from a local .env file into os.environ without overriding
# variables that are already set; GROQ_API / HF_API are read from there later.
load_dotenv(override=False)


## HeadersSchema
class HeadersSchema(BaseModel):
    """HTTP headers for a chat-completions request.

    Fields serialise under their HTTP header names (``Authorization``,
    ``Content-Type``) when dumped with ``by_alias=True``; ``populate_by_name``
    also lets callers pass the Python field names.
    """

    # pydantic v2 configuration (the nested ``class Config`` form is deprecated in v2).
    model_config = ConfigDict(populate_by_name=True)

    # repr=False keeps the bearer token out of reprs/log lines.
    authorization: str = Field(..., alias="Authorization", repr=False)
    content_type: str = Field(default="application/json", alias="Content-Type")

    @classmethod
    def create(cls, api_key: str) -> "HeadersSchema":
        """Build headers that send ``api_key`` as a Bearer token."""
        return cls(Authorization=f"Bearer {api_key}")


class MessageSchema(BaseModel):
    # A single chat message; lists of these form the payload's ``messages``.
    role: str = Field(...)
    content: str = Field(...)


## SummarySchema
class SummaryPayloadSchema(BaseModel):
    # Request body for a summary chat-completions call. Summary.get fills in
    # the four required fields; the sampling settings below keep their defaults.
    model: str
    messages: List[MessageSchema]
    temperature: float
    stream: bool
    max_completion_tokens: int = Field(default=2048)
    # Asks the API for JSON-mode output.
    response_format: Dict[str, str] = Field(default={"type": "json_object"})
    top_p: float = Field(default=0.8)
    frequency_penalty: float = Field(default=1.0)
    presence_penalty: float = Field(default=1.5)


class SummaryResponseSchema(BaseModel):
    # JSON object the model is asked to return for a summary request.
    # model_json_schema() of this class is interpolated into Summary.role and
    # Summary.validation_role, so renaming, reordering or retyping fields here
    # changes the prompt text. Comments are used instead of a class docstring
    # because a docstring would likely be copied into the schema's
    # "description" and thereby alter the prompt as well.
    summary: str
    # NOTE(review): presumably name -> short description; confirm against SUMMARY_ROLE.
    characters: Dict[str, str]
    # NOTE(review): presumably place name -> short description; confirm against SUMMARY_ROLE.
    places: Dict[str, str]


class SummaryOutputSchema(SummaryResponseSchema):
    # A model summary tagged with the id of the chunk it was produced for.
    id: str = Field(...)


class SummaryContentSchema(BaseModel):
    # User-message body for a summary request: Summary.get_messages serialises
    # it with model_dump_json, so the field order below is the JSON key order.
    past_context: str = Field(...)
    current_chapter: str = Field(...)
    character_list: Dict[str, str] = Field(...)
    places_list: Dict[str, str] = Field(...)


##PromptSchema
class PromptPayloadSchema(BaseModel):
    # Request body for a prompt-generation chat-completions call. Prompt.get
    # fills in the four required fields; the sampling settings keep their defaults.
    model: str
    messages: List[MessageSchema]
    temperature: float
    stream: bool
    max_completion_tokens: int = Field(default=2048)
    # Asks the API for JSON-mode output.
    response_format: Dict[str, str] = Field(default={"type": "json_object"})
    top_p: float = Field(default=0.8)
    frequency_penalty: float = Field(default=1.0)
    presence_penalty: float = Field(default=1.5)


class PromptResponseSchema(BaseModel):
    # JSON object the model is asked to return for a prompt request.
    # model_json_schema() of this class is interpolated into Prompt.role and
    # Prompt.validation_role, so renaming, reordering or changing defaults here
    # changes the prompt text. No class docstring on purpose: it would likely
    # be copied into the schema's "description".
    # Both fields default to "" so a partial model answer still validates.
    scene_title: str = ""
    prompt: str = ""


class PromptOutputSchema(PromptResponseSchema):
    # A generated prompt plus an ``id``, intended to hold the source chunk id.
    id: Optional[str] = None

    def __eq__(self, value: object) -> bool:
        """Equality by ``id``.

        Against a plain string this is a direct id match. Against another
        PromptOutputSchema, ``str.__eq__`` (or ``None.__eq__``) returns
        NotImplemented, so Python retries with the other object's ``__eq__``,
        which ends up comparing the two ids. Because ``__eq__`` is defined
        without ``__hash__``, instances are unhashable.
        """
        return self.id == value


class PromptContentSchema(BaseModel):
    # User-message body for a prompt request: Prompt.get_messages serialises
    # it with model_dump_json, so the field order below is the JSON key order.
    input_text: str = Field(...)
    character_list: Dict[str, str] = Field(...)
    places_list: Dict[str, str] = Field(...)



In [4]:

##Requests


class LLM_API(ABC):
    """Abstract interface for a chat-completions client.

    NOTE(review): ``Summary`` and ``Prompt`` implement methods with these names
    but do not subclass this ABC, and their ``get_messages`` signatures differ
    from the one declared here.
    """

    @abstractmethod
    def get_messages(
        self, content: str, character: Dict[str, str], places: Dict[str, str]
    ) -> List[MessageSchema]:
        """Build the system/user messages for a request."""

    @abstractmethod
    def validation_messages(self, input_text: str) -> List[MessageSchema]:
        """Build messages that ask the model to fix ``input_text`` to match the schema."""

    @abstractmethod
    def get(self, messages: List[MessageSchema]) -> Tuple[int, str]:
        """Send ``messages`` and return ``(status_code, response_text)``."""

    @abstractmethod
    def validate_json(
        self, raw_data: str, schema: Type[BaseModel]
    ) -> Union[BaseModel, bool]:
        """Validate the JSON text ``raw_data`` against ``schema``.

        :return: a validated ``schema`` *instance*, or ``False`` on failure.
        """



In [30]:

@dataclass
class Summary:
    """Chat-completions client that summarises one chunk of a book.

    Builds system/user messages, POSTs them to ``url`` and validates the JSON
    the model returns. ``role`` / ``validation_role`` embed
    ``SummaryResponseSchema``'s JSON schema so the model knows the shape to return.
    """

    # repr=False keeps the key out of reprs (e.g. when a containing model is displayed).
    api_key: str = field(repr=False)
    url: str = "https://api.groq.com/openai/v1/chat/completions"
    role: str = f"{SUMMARY_ROLE} follow given schema: {SummaryResponseSchema.model_json_schema()}"
    validation_role: str = f"{SUMMARY_VALIDATION_RESOLVE_ROLE} Schema :{SummaryResponseSchema.model_json_schema()}"
    model: str = "llama-3.1-8b-instant"
    temperature: float = 0.4
    stream: bool = False
    # NOTE(review): not included in the payload built by ``get``.
    repetition_penalty: float = 1.5
    # Used by SummaryLoop as the chunker's max length; not sent to the API.
    max_tokens: int = 6000

    def get_messages(
        self,
        content: str,
        previous_summary: str,
        characters: Dict[str, str],
        places: Dict[str, str],
    ) -> List[MessageSchema]:
        """Build the [system, user] messages for summarising ``content``.

        :param content: text of the current chunk.
        :param previous_summary: summary carried over from the previous chunk.
        :param characters: characters known so far.
        :param places: places known so far.
        """
        user_body = SummaryContentSchema(
            past_context=previous_summary,
            current_chapter=content,
            character_list=characters,
            places_list=places,
        ).model_dump_json(by_alias=True)
        return [
            MessageSchema(role="system", content=self.role),
            MessageSchema(role="user", content=user_body),
        ]

    def validation_messages(self, input_text: str) -> List[MessageSchema]:
        """Build messages asking the model to fix ``input_text`` to match the schema."""
        return [
            MessageSchema(role="system", content=self.validation_role),
            MessageSchema(role="user", content=input_text),
        ]

    def get(self, messages: List[MessageSchema]) -> Tuple[int, str]:
        """POST ``messages`` to the chat-completions endpoint.

        :return: ``(status_code, text)``:
            * ``(200, assistant_message)`` on success;
            * ``(422, failed_generation)`` when the API rejects the output with
              error code ``json_validate_failed`` (the rejected text can be sent
              back through ``validation_messages``);
            * ``(status_code, "ERROR_API_CALL")`` for any other failure.
        """
        payload = SummaryPayloadSchema(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
            stream=self.stream,
        ).model_dump(by_alias=True)
        headers = HeadersSchema.create(api_key=self.api_key).model_dump(by_alias=True)
        response = requests.post(url=self.url, headers=headers, json=payload)
        code = response.status_code

        if code == 200:
            response_data = response.json()
            assistant_message = response_data["choices"][0]["message"]["content"]
            # .get(): a missing usage field must not turn a successful call into a crash.
            usage = response_data.get("usage") or {}
            logger.info(
                f"200 : Input_Tokens={usage.get('prompt_tokens')} | Output_Tokens={usage.get('completion_tokens')}  | Time={usage.get('total_time')}"
            )
            return code, assistant_message

        try:
            error = response.json()["error"]
        except (ValueError, KeyError, TypeError):
            # Body is not the expected {"error": {...}} JSON; log it verbatim.
            # ``response.text`` is a property, not a method.
            logger.warning(f"Error: {response.text}")
            return code, "ERROR_API_CALL"

        if (
            isinstance(error, dict)
            and error.get("code") == "json_validate_failed"
            and "failed_generation" in error
        ):
            return 422, error["failed_generation"]

        logger.warning(f"Error: {error}")
        return code, "ERROR_API_CALL"

    def validate_json(
        self, raw_data: str, schema: Type[SummaryResponseSchema]
    ) -> Union[SummaryResponseSchema, bool]:
        """Parse ``raw_data`` as JSON and validate it against ``schema``.

        :param raw_data: JSON text returned by the model.
        :param schema: pydantic model class to validate against.
        :return: the validated ``schema`` instance, or ``False`` if the text is
                 missing, not valid JSON, or does not match ``schema``.
        """
        try:
            return schema.model_validate(json.loads(raw_data))
        except (ValidationError, JSONDecodeError, TypeError) as exc:
            # TypeError: raw_data was not a str (e.g. the API returned null content).
            logger.warning(f"ValidationError: {exc}")
            return False


@dataclass
class Prompt:
    """Chat-completions client that turns a chunk of text plus character/place
    lists into a ``PromptResponseSchema`` (``scene_title`` + ``prompt``).

    ``role`` / ``validation_role`` embed ``PromptResponseSchema``'s JSON schema
    so the model knows the shape to return.
    """

    # repr=False keeps the key out of reprs (e.g. when the client is displayed).
    api_key: str = field(repr=False)
    url: str = "https://api.groq.com/openai/v1/chat/completions"
    role: str = (
        f"{PROMPT_ROLE} follow given schema: {PromptResponseSchema.model_json_schema()}"
    )
    validation_role: str = f"{PROMPT_VALIDATION_RESOLVE_ROLE} Schema :{PromptResponseSchema.model_json_schema()}"
    model: str = "llama-3.1-8b-instant"
    temperature: float = 0.4
    stream: bool = False
    # NOTE(review): not included in the payload built by ``get``.
    repetition_penalty: float = 1.5
    # NOTE(review): not read by this class.
    max_tokens: int = 6000

    def get_messages(
        self,
        input_text: str,
        characters: Dict[str, str],
        places: Dict[str, str],
    ) -> List[MessageSchema]:
        """Build the [system, user] messages for generating a prompt from ``input_text``."""
        user_body = PromptContentSchema(
            input_text=input_text,
            character_list=characters,
            places_list=places,
        ).model_dump_json(by_alias=True)
        return [
            MessageSchema(role="system", content=self.role),
            MessageSchema(role="user", content=user_body),
        ]

    def validation_messages(self, input_text: str) -> List[MessageSchema]:
        """Build messages asking the model to fix ``input_text`` to match the schema."""
        return [
            MessageSchema(role="system", content=self.validation_role),
            MessageSchema(role="user", content=input_text),
        ]

    def get(self, messages: List[MessageSchema]) -> Tuple[int, str]:
        """POST ``messages`` to the chat-completions endpoint.

        :return: ``(status_code, text)``:
            * ``(200, assistant_message)`` on success;
            * ``(422, failed_generation)`` when the API rejects the output with
              error code ``json_validate_failed``;
            * ``(status_code, "ERROR_API_CALL")`` for any other failure.
        """
        payload = PromptPayloadSchema(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
            stream=self.stream,
        ).model_dump(by_alias=True)
        headers = HeadersSchema.create(api_key=self.api_key).model_dump(by_alias=True)
        response = requests.post(url=self.url, headers=headers, json=payload)
        code = response.status_code

        if code == 200:
            response_data = response.json()
            assistant_message = response_data["choices"][0]["message"]["content"]
            # .get(): a missing usage field must not turn a successful call into a crash.
            usage = response_data.get("usage") or {}
            logger.info(
                f"200 : Input_Tokens={usage.get('prompt_tokens')} | Output_Tokens={usage.get('completion_tokens')}  | Time={usage.get('total_time')}"
            )
            return code, assistant_message

        try:
            error = response.json()["error"]
        except (ValueError, KeyError, TypeError):
            # Body is not the expected {"error": {...}} JSON. Log the raw text;
            # calling response.json() again here would just re-raise.
            logger.warning(f"Error: {response.text}")
            return code, "ERROR_API_CALL"

        if (
            isinstance(error, dict)
            and error.get("code") == "json_validate_failed"
            and "failed_generation" in error
        ):
            return 422, error["failed_generation"]

        logger.warning(f"Error: {error}")
        return code, "ERROR_API_CALL"

    def validate_json(
        self, raw_data: str, schema: Type[PromptResponseSchema]
    ) -> Union[PromptResponseSchema, bool]:
        """Parse ``raw_data`` as JSON and validate it against ``schema``.

        :param raw_data: JSON text returned by the model.
        :param schema: pydantic model class to validate against.
        :return: the validated ``schema`` instance, or ``False`` if the text is
                 missing, not valid JSON, or does not match ``schema``.
        """
        try:
            return schema.model_validate(json.loads(raw_data))
        except (ValidationError, JSONDecodeError, TypeError) as exc:
            # TypeError: raw_data was not a str (e.g. the API returned null content).
            logger.warning(f"ValidationError: {exc}")
            return False


class SummaryLoop(BaseModel):
    """Summarise chunked chapters sequentially, feeding each summary forward.

    After ``run``, ``summary_pool[0]`` is the seed entry created by
    ``initialize`` and ``summary_pool[i + 1]`` belongs to ``chunked_content[i]``.
    """

    content: List[Tuple[str, str]]
    summary: Summary
    summary_pool: List[SummaryOutputSchema] = Field(default_factory=list)
    chunked_content: List[Tuple[str, str, str]] = Field(default_factory=list)

    def initialize(self) -> Optional["SummaryLoop"]:
        """Seed ``summary_pool`` and chunk ``content``.

        The tokenizer used for chunking is built from the ``HF_API`` environment
        variable; chunks are capped at ``summary.max_tokens``.

        :return: ``self`` on success, ``None`` if ``HF_API`` is not set.
        """
        hf_api = os.environ.get("HF_API")
        if not hf_api:
            logger.error("[SummaryLoop] HF_API Not defined ")
            return None

        # Seed entry: gives the first chunk an empty "previous" context.
        self.summary_pool = [
            SummaryOutputSchema(
                summary="This is The first chapter There is No context",
                places={},
                characters={},
                id="",
            ),
        ]
        tokenizer = Tokenizer(api_key=hf_api)
        logger.trace("tokenizer set")
        chunker = Chunker(max_len=self.summary.max_tokens, tokenizer=tokenizer)
        logger.trace("Chunker set")
        self.chunked_content = chunker.chunk(content=self.content)
        logger.trace("Chapters Chunked")

        return self

    def run(self) -> None:
        """Summarise every chunk of ``chunked_content`` in order.

        Each chunk is unpacked as ``(chunk_id, title, text)``. Exactly one entry
        is appended to ``summary_pool`` per chunk; if a chunk cannot be
        summarised, the previous entry is carried forward so later chunks keep
        their context and the pool stays aligned with the chunks.
        """
        for idx, (chunk_id, _title, chunk_text) in enumerate(self.chunked_content):
            past_context = self.summary_pool[idx]
            messages = self.summary.get_messages(
                content=chunk_text,
                previous_summary=past_context.summary,
                characters=past_context.characters,
                places=past_context.places,
            )
            status_code, response = self.summary.get(messages=messages)

            result: Optional[SummaryResponseSchema] = None
            if status_code == 200:
                validated = self.summary.validate_json(response, SummaryResponseSchema)
                result = validated if validated else self.handle_validation_error(response)
            elif status_code == 422:
                # The API rejected the generation; try to repair the rejected text.
                result = self.handle_validation_error(response)

            if result is not None:
                logger.trace(f"Chunk_id={chunk_id!r} Done")
                # Re-tag with the chunk id: the model output never carries one.
                self.summary_pool.append(
                    SummaryOutputSchema(**result.model_dump(by_alias=True), id=chunk_id)
                )
            else:
                self.summary_pool.append(past_context)
                logger.warning(f"{status_code=} error getting id={chunk_id!r}")

    def handle_validation_error(self, input_text: str) -> Optional[SummaryResponseSchema]:
        """Ask the model to repair ``input_text``, up to MAX_VALIDATION_ERROR_TRY times.

        :param input_text: raw model output that failed validation.
        :return: the validated ``SummaryResponseSchema``, or ``None`` if every
                 attempt failed.
        """
        messages = self.summary.validation_messages(input_text)
        for attempt in range(MAX_VALIDATION_ERROR_TRY):
            status_code, response = self.summary.get(messages=messages)
            if status_code == 200:
                # Validate against the response schema: the model is never asked
                # for ``id`` (the caller attaches it), so SummaryOutputSchema
                # would always reject the output.
                validated = self.summary.validate_json(response, SummaryResponseSchema)
                if validated:
                    logger.info("Validation error resolved")
                    return validated
            elif status_code == 422:
                # Retry on the latest rejected generation.
                messages = self.summary.validation_messages(response)
            logger.warning(f"Validation Unresolved on try {attempt + 1}")
        logger.error("COULDNT VALIDATE THE CHUNK, SKIPPING...")
        return None

    @property
    def get_summary_pool(self) -> List[SummaryOutputSchema]:
        """The accumulated summaries, including the seed entry at index 0."""
        return self.summary_pool


class PromptLoop(BaseModel):
    """Generate a ``PromptResponseSchema`` for every chunk summarised by a SummaryLoop.

    Call ``initialize`` (after the SummaryLoop has run), then ``run``. After
    ``run``, ``prompt_pool[i]`` belongs to ``input_prompt_list[i]``.
    """

    summary: SummaryLoop
    # NOTE(review): not read by any method of this class.
    content: List[Tuple[str, str]]
    prompt: Prompt
    prompt_pool: List[PromptOutputSchema] = Field(default_factory=list)
    # Filled by ``initialize``. A default is required: ``Field(init=False)`` only
    # applies to dataclasses and left this field mandatory at construction.
    input_prompt_list: List[PromptContentSchema] = Field(default_factory=list)
    # Chunk ids aligned with ``input_prompt_list``; filled by ``initialize``.
    chunk_ids: List[str] = Field(default_factory=list)

    def initialize(self) -> "PromptLoop":
        """Pair each chunk's text with the characters/places from its summary.

        ``summary.summary_pool[0]`` is the seed entry, so ``summary_pool[i + 1]``
        belongs to ``chunked_content[i]``.

        :return: ``self`` so construction and initialisation can be chained.
        """
        chunk_summaries = self.summary.summary_pool[1:]
        pairs = list(zip(self.summary.chunked_content, chunk_summaries))
        self.chunk_ids = [chunk[0] for chunk, _ in pairs]
        self.input_prompt_list = [
            PromptContentSchema(
                input_text=chunk[2],
                places_list=chunk_summary.places,
                character_list=chunk_summary.characters,
            )
            for chunk, chunk_summary in pairs
        ]
        return self

    def run(self) -> None:
        """Generate one prompt per entry of ``input_prompt_list``.

        Exactly one ``PromptOutputSchema`` is appended to ``prompt_pool`` per
        input; when generation fails, an entry with empty ``scene_title`` /
        ``prompt`` is appended so the pool stays aligned with the chunks.
        """
        chunk_ids: List[Optional[str]] = list(self.chunk_ids)
        # Pad with None if input_prompt_list was set without initialize().
        chunk_ids += [None] * (len(self.input_prompt_list) - len(chunk_ids))

        for chunk_id, prompt_input in zip(chunk_ids, self.input_prompt_list):
            messages = self.prompt.get_messages(
                input_text=prompt_input.input_text,
                characters=prompt_input.character_list,
                places=prompt_input.places_list,
            )
            status_code, response = self.prompt.get(messages=messages)

            result: Optional[PromptResponseSchema] = None
            if status_code == 200:
                validated = self.prompt.validate_json(response, PromptResponseSchema)
                result = validated if validated else self.handle_validation_error(response)
            elif status_code == 422:
                # The API rejected the generation; try to repair the rejected text.
                result = self.handle_validation_error(response)

            if result is not None:
                logger.trace(f"Chunk_id={chunk_id!r} Done")
                self.prompt_pool.append(
                    PromptOutputSchema(**result.model_dump(by_alias=True), id=chunk_id)
                )
            else:
                # Placeholder keeps prompt_pool aligned with the chunks.
                self.prompt_pool.append(PromptOutputSchema(id=chunk_id))
                logger.warning(f"{status_code=} error getting id={chunk_id!r}")

    def handle_validation_error(self, input_text: str) -> Optional[PromptResponseSchema]:
        """Ask the model to repair ``input_text``, up to MAX_VALIDATION_ERROR_TRY times.

        :param input_text: raw model output that failed validation.
        :return: the validated ``PromptResponseSchema``, or ``None`` if every
                 attempt failed.
        """
        messages = self.prompt.validation_messages(input_text)
        for attempt in range(MAX_VALIDATION_ERROR_TRY):
            status_code, response = self.prompt.get(messages=messages)
            if status_code == 200:
                validated = self.prompt.validate_json(response, PromptResponseSchema)
                if validated:
                    logger.info("Validation error resolved")
                    return validated
            elif status_code == 422:
                # Retry on the latest rejected generation (SummaryLoop has no
                # validation_messages; the Prompt client is the right builder).
                messages = self.prompt.validation_messages(response)
            logger.warning(f"Validation Unresolved on try {attempt + 1}")
        logger.error("COULDNT VALIDATE THE CHUNK, SKIPPING...")
        return None

    @property
    def get_prompt_pool(self) -> List[PromptOutputSchema]:
        """The generated prompts, one per processed input."""
        return self.prompt_pool




In [None]:
from reader import ebook

# Groq key for the chat-completions calls (loaded from .env by load_dotenv above).
api = os.environ.get("GROQ_API")
if not api:
    raise RuntimeError("API NOT SET IN .env, GROQ_API=None")

book = ebook("./exp_book/stranger.pdf")

# Only the chapters at indices 1-5 are processed in this experiment.
chapter_content = book.get_chapters()[1:6]

# Named summary_client (not ``sum``) so the built-in ``sum`` is not shadowed
# for every later cell.
summary_client = Summary(api_key=api)
looper = SummaryLoop(content=chapter_content, summary=summary_client).initialize()
if looper is None:
    raise RuntimeError("Looper is none: SummaryLoop.initialize failed (is HF_API set?)")

looper.run()



--------------------------------------------------
Chapter:$0#1
Summary
	 - The narrator, Meursault, travels to Marengo to attend his mother's funeral. He meets the warden and the keeper, who explain the arrangements for the funeral. Meursault visits his mother's body in the mortuary and meets the old people who are keeping vigil with him. They are quiet and seem to be in a state of shock. Meursault is struck by their appearance and their behavior. The keeper brings them coffee and they drink it in silence. Meursault falls asleep and wakes up to find the old people sleeping. He wakes up again and sees the old men coughing and spitting. The keeper tells them it's time to leave and they all get up and shake hands with Meursault. The narrator reflects on the strange and uncomfortable atmosphere of the vigil. He then meets the warden, who explains that there will be a funeral procession to the church in the village. Meursault, the warden, and the nurse will be the only mourners. They are j

In [None]:
# ``prompt`` was referenced here without ever being defined (NameError on a
# fresh kernel). Build the prompt-generation client instead; it is assigned
# rather than displayed so the API key does not end up in the cell output.
prompt = Prompt(api_key=api)

In [11]:
# Pretty-print each summary entry; entries with no places or no characters
# (such as the seed entry) are skipped.
for entry in looper.get_summary_pool:
    if not (entry.places and entry.characters):
        continue
    print("-" * 50)
    print(f"Chapter:{entry.id}")
    print("Summary")
    print(f"\t - {entry.summary}")
    for heading, mapping in (("characters", entry.characters), ("places", entry.places)):
        print()
        print(heading)
        for key, value in mapping.items():
            print(f"\t - {key} : {value}")
    print("\n\n")


--------------------------------------------------
Chapter:$0#1
Summary
	 - The narrator, Meursault, travels to Marengo to attend his mother's funeral. He meets the warden and the keeper, who explain the arrangements for the funeral. Meursault visits his mother's body in the mortuary and meets the old people who are keeping vigil with him. They are quiet and seem to be in a state of shock. Meursault is struck by their appearance and their behavior. The keeper brings them coffee and they drink it in silence. Meursault falls asleep and wakes up to find the old people sleeping. He wakes up again and sees the old men coughing and spitting. The keeper tells them it's time to leave and they all get up and shake hands with Meursault. The narrator reflects on the strange and uncomfortable atmosphere of the vigil. He then meets the warden, who explains that there will be a funeral procession to the church in the village. Meursault, the warden, and the nurse will be the only mourners. They are j

In [29]:
# Pair every chunk's text with the characters/places from its summary.
# summary_pool[0] is the seed entry, so it is dropped to line the lists up.
chunks = looper.chunked_content
sum_pool = looper.summary_pool[1:]
prm = [
    PromptContentSchema(
        input_text=chunk_row[2],
        places_list=chunk_summary.places,
        character_list=chunk_summary.characters,
    )
    for chunk_row, chunk_summary in zip(chunks, sum_pool)
]
prm

[PromptContentSchema(input_text='Part One I MOTHER died today. Or, maybe, yesterday; I canâ\x80\x99t be sure. The telegram from the Home says: YOUR MOTHER PASSED AWAY. FUNERAL TOMORROW. DEEP SYMPATHY. Which leaves the matter doubtful; it could have been yesterday. The Home for Aged Persons is at Marengo, some fifty miles from Algiers. With the two oâ\x80\x99clock bus I should get there well before nightfall. Then I can spend the night there, keeping the usual vigil beside the body, and be back here by tomorrow evening. I have fixed up with my employer for two daysâ\x80\x99 leave; obviously, under the circumstances, he couldnâ\x80\x99t refuse. Still, I had an idea he looked annoyed, and I said, without thinking: â\x80\x9cSorry, sir, but itâ\x80\x99s not my fault, you know.â\x80\x9d Afterwards it struck me I neednâ\x80\x99t have said that. I had no reason to excuse myself; it was up to him to express his sympathy and so forth. Probably he will do so the day after tomorrow, when he sees m