<a href="https://colab.research.google.com/github/Harshal292004/KGPDSH/blob/master/StormArchitecture.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [15]:
from langchain.prompts.chat import ChatPromptTemplate
from pydantic import BaseModel, Field, validator, PrivateAttr
from langchain_groq import ChatGroq
from langchain.chains.combine_documents import create_stuff_documents_chain
from google.colab import userdata
from typing import List, Dict

class ScoreDetails(BaseModel):
    score: float = Field(description="Score of the aspect (0-10)")
    justification: str = Field(description="Explanation for the score")

    @validator('score')
    def validate_score(cls, v: float) -> float:
        if v < 0 or v > 10:
            raise ValueError('Score must be between 0 and 10')
        return round(v, 1)

# Evaluation Categories
class Significance(BaseModel):
    theoretical_contribution: ScoreDetails = Field(description="Assessment of theoretical contribution to the field")
    practical_impact: ScoreDetails = Field(description="Evaluation of practical applications and real-world impact")
    novelty: ScoreDetails = Field(description="Originality of the research and its contribution")

class MethodologyEvaluation(BaseModel):
    research_design: ScoreDetails = Field(description="Appropriateness and quality of research design")
    reproducibility: ScoreDetails = Field(description="Clarity and completeness of methods for reproduction")

class PresentationQuality(BaseModel):
    writing_clarity: ScoreDetails = Field(description="Clarity and effectiveness of writing")
    organization: ScoreDetails = Field(description="Logical flow and structure of the paper")
    literature_review: ScoreDetails = Field(description="Comprehensiveness of literature review")

class PaperEvaluation(BaseModel):
    paper_id: str = Field(description="Unique identifier for the paper")
    significance: Significance = Field(description="Evaluation of research significance")
    methodology: MethodologyEvaluation = Field(description="Detailed methodology assessment")
    presentation: PresentationQuality = Field(description="Quality of presentation")
    confidence_level: float = Field(description="Overall confidence in this review (0-10)")
    _overall_score: float = PrivateAttr(default=0.0)
    major_strengths: List[str] = Field(description="Major strengths of the paper")
    major_weaknesses: List[str] = Field(description="Major weaknesses of the paper")
    detailed_feedback: str = Field(description="Comprehensive feedback and suggestions")
    summary_of_paper: str = Field(description="Relevant summary for retrieval tasks")

    @property
    def overall_score(self) -> float:
        if self._overall_score == 0.0:
            weights = {
                'significance': 0.3,
                'methodology': 0.5,
                'presentation': 0.2,
            }

            sig_score = (self.significance.theoretical_contribution.score +
                         self.significance.practical_impact.score +
                         self.significance.novelty.score) / 3

            meth_score = (self.methodology.research_design.score +
                          self.methodology.reproducibility.score) / 2

            pres_score = (self.presentation.writing_clarity.score +
                          self.presentation.organization.score +
                          self.presentation.literature_review.score) / 3

            self._overall_score = round(
                weights['significance'] * sig_score +
                weights['methodology'] * meth_score +
                weights['presentation'] * pres_score ,
                3
            )
        return self._overall_score

    @validator('confidence_level')
    def validate_confidence(cls, v: float) -> float:
        if v < 0 or v > 10:
            raise ValueError('Confidence level must be between 0 and 10')
        return round(v, 1)

    def generate_summary_report(self) -> str:
        return f"""
        Paper Review Summary (ID: {self.paper_id})
        ==========================================
        Review Confidence: {self.confidence_level}/10
        Overall Score: {self.overall_score}/10
        Key Strengths:
        {chr(10).join(f"- {strength}" for strength in self.major_strengths)}

        Areas for Improvement:
        {chr(10).join(f"- {weakness}" for weakness in self.major_weaknesses)}

        Detailed Feedback:
        {self.detailed_feedback}

        Summary of the Paper:
        {self.summary_of_paper}
        """

# Implementing STORM Architecture with LangGraph
class ConferenceAgent:
    def __init__(self, name: str, chat_model: ChatGroq, target_conference: str):
        self.name = name
        self.chat_model = chat_model
        self.target_conference = target_conference

    def evaluate_paper(self, paper_summary: str) -> Dict:
        prompt = ChatPromptTemplate.from_template(
            f"""
            You are a representative for the conference {self.target_conference}. Evaluate the research paper based on:
            - Relevance to conference themes
            - Quality of methodology
            - Novelty of contribution
            Provide a score (0-10) and a justification for why the paper fits this conference.
            Paper summary: {paper_summary}
            """
        )
        response = self.chat_model.generate(prompt)
        return {
            "score": float(response["score"]),
            "justification": response["justification"],
        }

class STORMSystem:
    def __init__(self, agents: List[ConferenceAgent]):
        self.agents = agents

    def discuss_and_decide(self, paper_summary: str) -> Dict:
        evaluations = []
        for agent in self.agents:
            evaluation = agent.evaluate_paper(paper_summary)
            evaluations.append({
                "conference": agent.target_conference,
                "score": evaluation["score"],
                "justification": evaluation["justification"],
            })

        # Determine the best conference based on scores
        best_conference = max(evaluations, key=lambda x: x["score"])
        return {
            "best_conference": best_conference["conference"],
            "justification": best_conference["justification"],
        }

# Example Usage
if __name__ == "__main__":
    chat_model = ChatGroq(groq_api_key = userdata.get("GROQ_API_KEY"), model="llama-3.3-70b-specdec")

    # Define agents for each conference
    agents = [
        ConferenceAgent(name="Agent CVPR", chat_model=chat_model, target_conference="CVPR"),
        ConferenceAgent(name="Agent NeurIPS", chat_model=chat_model, target_conference="NeurIPS"),
        ConferenceAgent(name="Agent EMNLP", chat_model=chat_model, target_conference="EMNLP"),
        ConferenceAgent(name="Agent KDD", chat_model=chat_model, target_conference="KDD"),
    ]

    storm_system = STORMSystem(agents=agents)

    # Input a research paper summary
    paper_summary = "paper"
    decision = storm_system.discuss_and_decide(paper_summary)

    print(f"Best Conference: {decision['best_conference']}")
    print(f"Justification: {decision['justification']}")


<ipython-input-15-d49af6add6a1>:12: PydanticDeprecatedSince20: Pydantic V1 style `@validator` validators are deprecated. You should migrate to Pydantic V2 style `@field_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
  @validator('score')
<ipython-input-15-d49af6add6a1>:73: PydanticDeprecatedSince20: Pydantic V1 style `@validator` validators are deprecated. You should migrate to Pydantic V2 style `@field_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
  @validator('confidence_level')


TypeError: Got unknown type name

In [7]:
pip install langchain langchain_community langgraph langchain_groq langgraph.agents

[31mERROR: Ignored the following versions that require a different python version: 0.1.0 Requires-Python <4.0,>=3.11[0m[31m
[0m[31mERROR: Could not find a version that satisfies the requirement langgraph.agents (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for langgraph.agents[0m[31m
[0m

In [6]:
pip install --upgrade langgraph

