In [None]:
import os
from dataclasses import dataclass
from typing import Literal

import dspy
import pandas as pd
from dspy import ChainOfThought, InputField, OutputField, Signature
from dspy import Example
from dspy import LM, configure
from dspy.evaluate import Evaluate
from dspy.evaluate.metrics import answer_exact_match
from dspy.teleprompt import BootstrapFewShot
from pydantic import BaseModel, Field
from sklearn.model_selection import train_test_split
from dspy import Prediction

In [None]:
# Build the language-model client. The API key is read from the environment so
# that no credential is stored in the notebook itself.
llm = LM(
    model="deepseek/deepseek-chat",
    api_base="https://api.deepseek.com",
    api_key=os.environ.get("DEEPSEEK_API_KEY"),
)
# Register the client as the default LM for the DSPy modules created below.
configure(lm=llm)


class QA(Signature):
    # Deliberately no docstring: DSPy sends a signature's docstring to the model
    # as the task instruction, and this signature relies on the auto-generated
    # one. Both fields are plain strings, which is also DSPy's default type.
    question: str = InputField()
    answer: str = OutputField(desc="concise answer")


qa = ChainOfThought(QA)

In [None]:
response = qa(question="Explain photosynthesis in simple terms")
# The returned Prediction carries the output fields (`answer`, plus the
# `reasoning` added by ChainOfThought); the input question is not part of it,
# so only the answer is printed.
print(f"Answer: {response.answer}")

Answer: Photosynthesis is how plants make their own food. They use sunlight, water from the soil, and carbon dioxide from the air. Inside their leaves, a green substance called chlorophyll captures sunlight energy to turn these ingredients into sugar (which is food for the plant) and oxygen, which is released into the air.


In [None]:
class MultiQueryGenerator(Signature):
    """Generate multiple research queries from a single user query."""

    # The docstring above is sent to the model as the task instruction, so it
    # is kept verbatim. All fields are strings (DSPy's default field type),
    # including `num_queries`, which the prompt presents as a string field.
    query: str = InputField(desc="Original user query")
    num_queries: str = InputField(desc="Number of queries to generate")
    research_queries: str = OutputField(
        desc="Numbered list of 6-12 word research queries that preserve specific details"
    )


# Create the module
query_optimizer = ChainOfThought(MultiQueryGenerator)

original_query = "Gaza war impact on children mental health 2019-2024"
num_queries = 5

result = query_optimizer(query=original_query, num_queries=num_queries)

print(f"Original query: {original_query}")
print(f"Optimized queries:\n{result.research_queries}")

Original query: Gaza war impact on children mental health 2019-2024
Optimized queries:
1. Gaza children mental health trauma effects 2019-2024.
2. PTSD prevalence in Gaza children during war 2019-2024.
3. Long-term psychological impacts on Gaza youth 2019-2024.
4. Humanitarian interventions for child mental health Gaza 2019-2024.
5. Educational disruptions and mental health in Gaza children 2019-2024.


In [None]:
llm.inspect_history(n=1)





[34m[2025-09-01T21:35:03.441228][0m

[31mSystem message:[0m

Your input fields are:
1. `query` (str): Original user query
2. `num_queries` (str): Number of queries to generate
Your output fields are:
1. `reasoning` (str): 
2. `research_queries` (str): Numbered list of 6-12 word research queries that preserve specific details
All interactions will be structured in the following way, with the appropriate values filled in.

[[ ## query ## ]]
{query}

[[ ## num_queries ## ]]
{num_queries}

[[ ## reasoning ## ]]
{reasoning}

[[ ## research_queries ## ]]
{research_queries}

[[ ## completed ## ]]
In adhering to this structure, your objective is: 
        Generate multiple research queries from a single user query.


[31mUser message:[0m

[[ ## query ## ]]
Gaza war impact on children mental health 2019-2024

[[ ## num_queries ## ]]
5

Respond with the corresponding output fields, starting with the field `[[ ## reasoning ## ]]`, then `[[ ## research_queries ## ]]`, and then ending wit

In [None]:
class COT(Signature):
    # Same shape as a plain question -> answer signature. No docstring on
    # purpose, so DSPy keeps using its auto-generated task instruction.
    question: str = InputField()
    answer: str = OutputField(desc="concise answer")


multistep_question = "What is the birth place of the person who provided the assist to Man united's only goal in the uefa champions league final?"
multistep = ChainOfThought(COT)
ans = multistep(question=multistep_question)

In [None]:
ans

Prediction(
    reasoning='The question asks for the birthplace of the person who provided the assist for Manchester United\'s only goal in the UEFA Champions League final. Manchester United has won the Champions League (formerly European Cup) three times: in 1968, 1999, and 2008. The mention of "only goal" suggests a final where they scored exactly one goal, which applies to the 2008 final against Chelsea, which ended 1-1 in regular time and was won on penalties. The only goal in regular time was scored by Cristiano Ronaldo, assisted by Wes Brown. Wes Brown was born in Longsight, Manchester, England.',
    answer='Longsight, Manchester, England'
)

# Modules

In [None]:
class DoubleCOT(dspy.Module):
    """Two-stage chain of thought: think step by step, then condense.

    The first predictor turns the question into a step-by-step thought; the
    second receives the question together with that thought and produces a
    concise answer.

    This is intentionally a plain class rather than a ``@dataclass``: the
    decorator would generate an empty ``__init__`` (there are no annotated
    fields) and an ``__eq__`` that makes instances unhashable.
    """

    def __init__(self):
        super().__init__()
        # Create the predictors per instance so that each DoubleCOT owns its
        # own sub-modules instead of sharing class-level objects.
        self.cot1 = ChainOfThought("question -> step_by_step_thought")
        self.cot2 = ChainOfThought("question, thought -> concise_and_accurate_answer")

    def forward(self, question):
        """Answer ``question`` in two LM calls.

        Args:
            question: The question to answer.

        Returns:
            dspy.Prediction with ``answer`` (the concise answer) and
            ``thought`` (the intermediate step-by-step reasoning).
        """
        thought = self.cot1(question=question).step_by_step_thought
        answer = self.cot2(
            question=question, thought=thought
        ).concise_and_accurate_answer
        return dspy.Prediction(answer=answer, thought=thought)

In [None]:
multistep_question = """How many trailing zeros does 100! have in base 12 and whats the general method to cmopute trailing zeros of n! in base b?"""
doublecot = DoubleCOT()
ans = doublecot(question=multistep_question)

In [None]:
ans.answer

'100! has 48 trailing zeros in base 12. The general method to compute trailing zeros of n! in base b is:\n1. Factorize b into primes: b = p1^e1 * p2^e2 * ... * pk^ek.\n2. For each prime pi, compute its exponent in n! as sum_{k=1}^{∞} floor(n / (pi^k)).\n3. For each pi, compute vi = floor( (exponent of pi in n!) / ei ).\n4. The number of trailing zeros is min(v1, v2, ..., vk).'

In [None]:
llm.inspect_history(n=2)





[34m[2025-09-01T21:35:04.088933][0m

[31mSystem message:[0m

Your input fields are:
1. `question` (str):
Your output fields are:
1. `reasoning` (str): 
2. `step_by_step_thought` (str):
All interactions will be structured in the following way, with the appropriate values filled in.

[[ ## question ## ]]
{question}

[[ ## reasoning ## ]]
{reasoning}

[[ ## step_by_step_thought ## ]]
{step_by_step_thought}

[[ ## completed ## ]]
In adhering to this structure, your objective is: 
        Given the fields `question`, produce the fields `step_by_step_thought`.


[31mUser message:[0m

[[ ## question ## ]]
How many trailing zeros does 100! have in base 12 and whats the general method to cmopute trailing zeros of n! in base b?

Respond with the corresponding output fields, starting with the field `[[ ## reasoning ## ]]`, then `[[ ## step_by_step_thought ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`.


[31mResponse:[0m

[32m[[ ## reasoning ## ]]
To find the num

# Enforcing typed output

In [None]:
# Legacy import, no longer needed: `from dspy.functional import TypedChainOfThought`.
# The plain ChainOfThought used below works with this pydantic-typed field.
#
# Pydantic model used as the type of the signature's `answer` output field.
# Its JSON schema (field descriptions and the 0..1 bounds) is embedded in the
# prompt. It is documented with `#` comments rather than a class docstring
# because pydantic would add a docstring to that schema as its description.
class StructuredOutput(BaseModel):
    # Short textual answer; the length hint is only conveyed via the description.
    answer: str = Field(description="Answer in 3-5 words")
    # Self-reported confidence, validated by pydantic to lie within [0, 1].
    confidence: float = Field(
        description="Confidence score between 0 and 1", ge=0, le=1
    )


class QAwithConfidence(Signature):
    # No docstring on purpose: DSPy would send it to the model as the task
    # instruction; this signature keeps the auto-generated one.
    question: str = InputField()
    # Typed output: the model's reply is parsed into a StructuredOutput object.
    answer: StructuredOutput = OutputField()


predict = ChainOfThought(QAwithConfidence)

In [None]:
question = "What is the capital of France?"
response = predict(question=question)

In [None]:
response.answer

StructuredOutput(answer='Paris', confidence=1.0)

In [None]:
llm.inspect_history(n=1)





[34m[2025-09-01T21:35:04.663880][0m

[31mSystem message:[0m

Your input fields are:
1. `question` (str):
Your output fields are:
1. `reasoning` (str): 
2. `answer` (StructuredOutput):
All interactions will be structured in the following way, with the appropriate values filled in.

[[ ## question ## ]]
{question}

[[ ## reasoning ## ]]
{reasoning}

[[ ## answer ## ]]
{answer}        # note: the value you produce must adhere to the JSON schema: {"type": "object", "properties": {"answer": {"type": "string", "description": "Answer in 3-5 words", "title": "Answer"}, "confidence": {"type": "number", "description": "Confidence score between 0 and 1", "maximum": 1, "minimum": 0, "title": "Confidence"}}, "required": ["answer", "confidence"], "title": "StructuredOutput"}

[[ ## completed ## ]]
In adhering to this structure, your objective is: 
        Given the fields `question`, produce the fields `answer`.


[31mUser message:[0m

[[ ## question ## ]]
What is the capital of France?

R

# RAG

In [None]:
colbertv2_wiki17_abstracts = dspy.ColBERTv2(
    url="http://20.102.90.50:2017/wiki17_abstracts"
)
dspy.settings.configure(rm=colbertv2_wiki17_abstracts)
# Define Retrieve Module
retriever = dspy.Retrieve(k=3)

query = "Portugal"

# Call the retriever on a particular query.
topK_passages = retriever(query).passages

for idx, passage in enumerate(topK_passages):
    print(f"{idx + 1}]", passage, "\n")

1] Portugal | Portugal (Portuguese: ] ), officially the Portuguese Republic (Portuguese: "República Portuguesa" ] ), is a sovereign state located on the Iberian Peninsula in southwestern Europe. It is the westernmost country of mainland Europe, being bordered to the west and south by the Atlantic Ocean and to the north and east by Spain. The Portugal–Spain border is 1214 km long, making it the longest uninterrupted border within the European Union. The republic also includes the Atlantic archipelagos of the Azores and Madeira, both autonomous regions with their own regional governments. 

2] Portugal (disambiguation) | Portugal is a country in southwestern Europe. 

3] Geography of Portugal | Portugal is a coastal nation in southwestern Europe, located at the western end of the Iberian Peninsula, bordering Spain (on its northern and eastern frontiers: a total of 1214 km ). The Portuguese territory also includes a series of archipelagos in the Atlantic Ocean (the Açores and Madeira), wh

In [None]:
class GenerateAnswer(Signature):
    """Answer the question with a short factoid answer, using the provided context."""

    # NOTE: DSPy sends the docstring above to the LM verbatim as the task
    # instruction, so it must describe the task rather than the Python class.
    # Maintainer notes therefore live in comments:
    #   context  - retrieved passages that may (or may not) contain the answer
    #   question - the question to answer
    #   answer   - the generated answer, expected to be short
    context = InputField(desc="May contain relevant information")
    question = InputField(desc="Question to answer")
    answer = OutputField(desc="Answer to the question often between 2-7 words")


class RAG(dspy.Module):
    """Single-hop retrieval-augmented question answering.

    The question itself is used as the retrieval query; the retrieved
    passages and the question are then handed to a chain-of-thought answer
    generator.

    Args:
        num_messages (int): Number of passages to retrieve per question.
    """

    def __init__(self, num_messages=3):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=num_messages)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)

    def forward(self, question):
        """Retrieve passages for ``question`` and return the generated prediction."""
        retrieved = self.retrieve(question)
        return self.generate_answer(context=retrieved.passages, question=question)

In [None]:
question = "What country did Pele play for?"
rag = RAG()
output = rag(question=question)

In [None]:
output

Prediction(
    reasoning="The question asks about the country Pele played for, but there are multiple footballers named Pelé in the context. The first Pelé (born 1991) plays for the Guinea-Bissau national team. The second mention is about Waldemar de Brito, who discovered Pelé (the famous one, not detailed here). The third Pelé (born 1978) is Cape Verdean. Since the question does not specify which Pelé, and the context provides two distinct national teams (Guinea-Bissau and Cape Verde), it is ambiguous. However, the most famous Pelé (Edson Arantes do Nascimento) is not in this context, so based solely on the provided information, both are valid but the question is unclear. Given the context entries, the answer should reflect the ambiguity or note the possibilities. Since the instruction is to provide an answer between 2-7 words, I'll list the countries mentioned for the Pelés in the context.",
    answer='Guinea-Bissau or Cape Verde.'
)

In [None]:
llm.inspect_history()





[34m[2025-09-01T21:35:05.243732][0m

[31mSystem message:[0m

Your input fields are:
1. `context` (str): May contain relevant information
2. `question` (str): Question to answer
Your output fields are:
1. `reasoning` (str): 
2. `answer` (str): Answer to the question often between 2-7 words
All interactions will be structured in the following way, with the appropriate values filled in.

[[ ## context ## ]]
{context}

[[ ## question ## ]]
{question}

[[ ## reasoning ## ]]
{reasoning}

[[ ## answer ## ]]
{answer}

[[ ## completed ## ]]
In adhering to this structure, your objective is: 
        Represents the structure and functionality for generating an answer.
        
        This class is used to define the input fields related to context and
        question, and an output field to hold the generated answer. It serves
        as a framework for handling the flow of data involved in generating
        answers based on an input question and its related context.
        
        A

# Multihop RAG

In [None]:
def deduplicate(passages):
    """Remove repeated passages while keeping first occurrences in order.

    Two passages count as the same when they are equal after stripping
    surrounding whitespace and lower-casing. The first spelling encountered is
    the one that is kept, unmodified.

    Args:
        passages: Iterable of passage strings.

    Returns:
        A new list containing the first occurrence of each distinct passage.
    """
    first_seen = {}
    for passage in passages:
        # setdefault stores only the first passage for a given normalized key;
        # dict insertion order preserves the original ordering.
        first_seen.setdefault(passage.strip().lower(), passage)
    return list(first_seen.values())

In [None]:
class GenerateQuery(Signature):
    """Write a simple search query that will help answer a complex question."""

    # NOTE: DSPy sends the docstring above to the LM as the task instruction,
    # so it states the task instead of describing the Python class.
    #   context  - passages gathered so far (may be empty on the first hop)
    #   question - the question that ultimately has to be answered
    #   query    - the search query to send to the retriever next
    context = InputField(desc="May contain relevant information")
    question = InputField(desc="Question to answer")
    query = OutputField()


class MultihopRAG(dspy.Module):
    """Multi-hop retrieval-augmented question answering.

    On every hop a search query is generated from the question and the context
    collected so far, passages are retrieved for that query, and the
    de-duplicated passages are added to the context. After the last hop a
    single answer is generated from the accumulated context.

    This is a plain class rather than a ``@dataclass``: the decorator's
    generated ``__eq__`` would make instances unhashable, and its fields could
    not be set through the hand-written ``__init__`` anyway.

    Args:
        passages_per_hop (int): Number of passages retrieved on each hop.
        max_hops (int): Number of query/retrieve rounds to perform.
    """

    def __init__(self, passages_per_hop=3, max_hops=3):
        super().__init__()
        self.passages_per_hop = passages_per_hop
        self.max_hops = max_hops
        # One query generator per hop, so each hop can be tuned independently.
        self.generate_query = [
            dspy.ChainOfThought(GenerateQuery) for _ in range(max_hops)
        ]
        self.retrieve = dspy.Retrieve(k=passages_per_hop)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)

    def forward(self, question):
        """Answer ``question`` using up to ``max_hops`` retrieval rounds.

        Args:
            question: The question to answer.

        Returns:
            dspy.Prediction with ``answer`` and the accumulated ``context``
            (list of unique passages).
        """
        context = []
        for generate_query in self.generate_query:
            query = generate_query(context=context, question=question).query
            passages = self.retrieve(query).passages
            context = deduplicate(context + passages)
        # Generate the answer once, from the final context. Calling the answer
        # generator inside the loop spent an LM call per hop on results that
        # were discarded, and left `pred` undefined when there were no hops.
        pred = self.generate_answer(context=context, question=question)
        return dspy.Prediction(answer=pred.answer, context=context)

In [None]:
question = "What is the capital city of the country that won the 2014 world cup?"
multihop_rag = MultihopRAG()
output = multihop_rag(question=question)

In [None]:
output

Prediction(
    answer='Berlin',
    context=["2014 FIFA World Cup | The 2014 FIFA World Cup was the 20th FIFA World Cup, the quadrennial world championship for men's national football teams organized by FIFA. It took place in Brazil from 12 June to 13 July 2014, after the country was awarded the hosting rights in 2007. It was the second time that Brazil staged the competition, the first being in 1950, and the fifth time that it was held in South America.", "2014 FIFA World Cup Final | The 2014 FIFA World Cup Final was a football match that took place on 13 July 2014 at the Maracanã Stadium in Rio de Janeiro, Brazil to determine the 2014 FIFA World Cup champion. Germany defeated Argentina 1–0 in extra time, with the only goal being scored by Mario Götze, who collected André Schürrle's cross from the left on his chest before volleying a high left-footed shot into the net. The match was the third final between the two countries, a World Cup record, after their 1986 and 1990 matches, and 

In [None]:
llm.inspect_history()





[34m[2025-09-01T21:35:31.255446][0m

[31mSystem message:[0m

Your input fields are:
1. `context` (str): May contain relevant information
2. `question` (str): Question to answer
Your output fields are:
1. `reasoning` (str): 
2. `answer` (str): Answer to the question often between 2-7 words
All interactions will be structured in the following way, with the appropriate values filled in.

[[ ## context ## ]]
{context}

[[ ## question ## ]]
{question}

[[ ## reasoning ## ]]
{reasoning}

[[ ## answer ## ]]
{answer}

[[ ## completed ## ]]
In adhering to this structure, your objective is: 
        Represents the structure and functionality for generating an answer.
        
        This class is used to define the input fields related to context and
        question, and an output field to hold the generated answer. It serves
        as a framework for handling the flow of data involved in generating
        answers based on an input question and its related context.
        
        A

# Zero-shot Prompting

In [None]:
df = pd.read_csv("football_rules_dataset.csv")

In [None]:
df.head()

Unnamed: 0,Question,Answer
0,Player spits on an opponent,Red card
1,Player deliberately handles the ball to preven...,Red card and penalty
2,Player commits a reckless tackle,Yellow card
3,Player shows dissent by word or action,Yellow card
4,Player delays the restart of play,Yellow card


In [None]:
# Shuffle the rows reproducibly (fixed seed), then split into train and test.
# NOTE(review): test_size=0.8 sends 80% of the rows to `testset` and leaves
# only 20% in `trainset` - confirm that this small training split is intended.
df = df.sample(frac=1, random_state=42).reset_index(drop=True)
trainset, testset = train_test_split(df, test_size=0.8, random_state=42)

In [None]:
def _to_examples(frame):
    """Convert each row of ``frame`` into a DSPy Example.

    The ``Question`` column becomes the example's ``question`` (marked as the
    only input) and the ``Answer`` column becomes its ``answer`` label.
    """
    return [
        Example(question=row["Question"], answer=row["Answer"]).with_inputs("question")
        for row in frame.to_dict(orient="records")
    ]


train_set = _to_examples(trainset)
test_set = _to_examples(testset)

In [None]:
train_set

[Example({'question': "Player commits flagrant elbow to opponent's face", 'answer': 'Red card and criminal charges'}) (input_keys={'question'}),
 Example({'question': 'Player throws punch at opponent after final whistle', 'answer': 'Red card and criminal charges'}) (input_keys={'question'}),
 Example({'question': 'Player commits violent conduct off the ball', 'answer': 'Red card'}) (input_keys={'question'}),
 Example({'question': 'Player impedes an opponent without playing the ball', 'answer': 'Indirect free kick'}) (input_keys={'question'}),
 Example({'question': 'Player commits soft contact during aerial duel', 'answer': 'Play on'}) (input_keys={'question'}),
 Example({'question': 'Player appeals for penalty without aggression', 'answer': 'Play on'}) (input_keys={'question'}),
 Example({'question': 'Player makes death threats to opponent', 'answer': 'Red card and criminal charges'}) (input_keys={'question'}),
 Example({'question': 'Player deliberately handles the ball to prevent a go

In [None]:
# Closed set of rulings a referee answer may take. The same strings are
# repeated in the RefereeAnswer field description and in reward_fn's allowed
# list, so any change here has to be mirrored in both places.
RefereeOptions = Literal[
    "Red card and criminal charges",
    "Yellow card",
    "Red card",
    "Warning",
    "Warning and retake",
    "Warning or yellow card",
    "Play on",
    "Red card and lengthy ban",
    "Indirect free kick",
    "Free kick",
    "Penalty and red card",
    "Yellow card or penalty if in box",
    "Red card and penalty",
    "Red card and penalty if in box",
]


class RefereeAnswer(Signature):
    """Act as a football referee: decide the correct ruling for the described player action. Answer with exactly one of the allowed options, copied verbatim."""

    # NOTE: DSPy sends the docstring above to the LM as the task instruction,
    # so it states the task instead of describing the Python class.
    #
    # The output type is declared as an annotation. `type=` is not a DSPy
    # field argument, so passing it to OutputField left `answer` as a plain
    # string and the allowed options were only suggested by the description.
    # The option list inside `desc` mirrors RefereeOptions; keep them in sync.
    question = InputField()
    answer: RefereeOptions = OutputField(
        desc="MANDATORY: Select one exact option from this list without any modification: ['Red card and criminal charges', 'Yellow card', 'Red card', 'Warning', 'Warning and retake', 'Warning or yellow card', 'Play on', 'Red card and lengthy ban', 'Indirect free kick', 'Free kick', 'Penalty and red card', 'Yellow card or penalty if in box', 'Red card and penalty', 'Red card and penalty if in box']",
    )

In [None]:
class PredictModel(dspy.Module):
    """Zero-shot referee: a chain-of-thought predictor over RefereeAnswer."""

    def __init__(self):
        # Initialise the dspy.Module base class first, as the other modules in
        # this notebook do, before attaching sub-modules.
        super().__init__()
        self.predict = ChainOfThought(RefereeAnswer)

    def forward(self, question):
        """Return the full prediction (reasoning and answer) for ``question``."""
        return self.predict(question=question)

In [None]:
predict = PredictModel()

In [None]:
# Evaluator for the held-out examples, scored with DSPy's exact-match metric.
# It runs single-threaded and shows a progress bar plus the first 10 rows.
evaluate_program = Evaluate(
    metric=answer_exact_match,
    devset=test_set,
    num_threads=1,
    display_progress=True,
    display_table=10,
)

In [None]:
eval = evaluate_program(predict)
eval

Average Metric: 46.00 / 83 (55.4%): 100%|██████████| 83/83 [00:03<00:00, 24.99it/s]

2025/09/01 21:35:36 INFO dspy.evaluate.evaluate: Average Metric: 46 / 83 (55.4%)





Unnamed: 0,question,example_answer,reasoning,pred_answer,answer_exact_match
0,Player commits sexual assault during play,Red card and criminal charges,Sexual assault is a severe criminal offense that goes far beyond t...,Red card and criminal charges,✔️ [True]
1,Player makes questionable appeal for foul,Play on,"The scenario describes a player making a ""questionable appeal for ...",Yellow card,
2,Player throws bottle at referee from bench,Red card and criminal charges,"Throwing an object, especially a bottle, at the referee from the b...",Red card and criminal charges,✔️ [True]
3,Player blocks goalkeeper's release of the ball,Yellow card,Blocking the goalkeeper's release of the ball is considered an off...,Yellow card,✔️ [True]
4,Player pulls opponent's hair during play,Yellow card,Pulling an opponent's hair is a serious offense that is considered...,Red card,
5,Player simulates contact to deceive referee,Yellow card,Simulating contact to deceive the referee is an act of unsporting ...,Yellow card,✔️ [True]
6,Player pushes an opponent during corner kick,Warning or yellow card,Pushing an opponent is a direct physical foul that is unsporting a...,Yellow card,
7,Player headbutts an opponent,Red card and criminal charges,A headbutt is a violent and aggressive act that is considered seri...,Red card,
8,Player commits a late tackle after ball has gone,Yellow card,"A late tackle, especially after the ball has gone, is considered r...",Yellow card,✔️ [True]
9,Player enters opponent's penalty area before penalty kick,Warning and retake,"According to the Laws of the Game (Law 14), if an attacking player...",Warning and retake,✔️ [True]


EvaluationResult(score=55.42, results=<list of 83 results>)

In [None]:
df = pd.DataFrame(
    [
        {
            "question": ex.question,
            "expected_answer": ex.answer,
            "predicted_answer": pred.answer,
            "reasoning": pred.reasoning,
            "correct": correct,
        }
        for ex, pred, correct in eval.results
    ]
)

In [None]:
df.tail()

Unnamed: 0,question,expected_answer,predicted_answer,reasoning,correct
78,Player engages in time-wasting tactics,Yellow card,Yellow card,Time-wasting tactics are considered unsporting behavior under the ...,True
79,Player makes obscene gesture to family section,Red card and criminal charges,Red card,"Making an obscene gesture, especially directed toward a family sec...",False
80,Player makes contact with referee accidentally,Warning,Play on,The Laws of the Game (IFAB) state that physical contact with a mat...,False
81,Player removes shirt during goal celebration,Yellow card,Yellow card,"According to the Laws of the Game (Law 12), removing the shirt dur...",True
82,Player slightly delays throw-in restart,Play on,Warning,A slight delay in taking a throw-in is considered unsporting behav...,False


# Few-shot Prompting

In [None]:
teleprompter = BootstrapFewShot(metric=answer_exact_match, max_labeled_demos=10)
compiled_predictor = teleprompter.compile(predict, trainset=train_set)

 55%|█████▌    | 11/20 [00:02<00:02,  3.96it/s]

Bootstrapped 4 full traces after 11 examples for up to 1 rounds, amounting to 11 attempts.





In [None]:
evaluate_program = Evaluate(
    devset=test_set,
    metric=answer_exact_match,
    display_table=10,
    num_threads=8,
    display_progress=True,
)

In [None]:
eval_compiled = evaluate_program(compiled_predictor)

Average Metric: 50.00 / 83 (60.2%): 100%|██████████| 83/83 [00:04<00:00, 18.43it/s]

2025/09/01 21:35:45 INFO dspy.evaluate.evaluate: Average Metric: 50 / 83 (60.2%)





Unnamed: 0,question,example_answer,reasoning,pred_answer,answer_exact_match
0,Player commits sexual assault during play,Red card and criminal charges,Sexual assault is an extremely serious criminal act that has no pl...,Red card and criminal charges,✔️ [True]
1,Player makes questionable appeal for foul,Play on,Questionable appeals for a foul are considered acts of unsporting ...,Warning,
2,Player throws bottle at referee from bench,Red card and criminal charges,Throwing a bottle at the referee is an act of violent conduct and ...,Red card and criminal charges,✔️ [True]
3,Player blocks goalkeeper's release of the ball,Yellow card,Blocking the goalkeeper's release of the ball is an offense that p...,Indirect free kick,
4,Player pulls opponent's hair during play,Yellow card,"Pulling an opponent's hair is an act of violent conduct, as it is ...",Red card,
5,Player simulates contact to deceive referee,Yellow card,Simulating contact (diving) to deceive the referee is an act of un...,Yellow card,✔️ [True]
6,Player pushes an opponent during corner kick,Warning or yellow card,"Pushing an opponent during a corner kick is a foul, as it is a for...",Warning,
7,Player headbutts an opponent,Red card and criminal charges,"A headbutt is an act of violent conduct, which is a serious foul p...",Red card and criminal charges,✔️ [True]
8,Player commits a late tackle after ball has gone,Yellow card,A late tackle after the ball has gone is a reckless challenge that...,Red card,
9,Player enters opponent's penalty area before penalty kick,Warning and retake,Entering the opponent's penalty area before a penalty kick is take...,Warning and retake,✔️ [True]


# Refine and BestOfN

In [None]:
class PredictModel(dspy.Module):
    """Baseline referee predictor that returns only the answer.

    NOTE: this re-defines (and, from here on, shadows) the `PredictModel`
    declared earlier in the notebook. Unlike that version, it drops the
    reasoning and returns a Prediction holding just ``answer``.
    """

    def __init__(self):
        # Initialise the dspy.Module base class before attaching sub-modules,
        # consistent with the other modules in this notebook.
        super().__init__()
        self.predict = ChainOfThought(RefereeAnswer)

    def forward(self, question):
        """Run the predictor on ``question`` and return ``Prediction(answer=...)``."""
        output = self.predict(question=question)
        return dspy.Prediction(answer=output.answer)


predict_model = PredictModel()
predict_model(
    question="Player makes questionable appeal during foul"
)  # Shoulda been yellow card

Prediction(
)

In [None]:
"""
By default, Refine will try to run the module up to N times until the threshold is met. If the module encounters an error, it will keep going up to N failed attempts
"""


# Rulings accepted as valid. Built once at module level instead of on every
# reward call. Mirrors the `RefereeOptions` literal; keep the two in sync.
_ALLOWED_ANSWERS = (
    "Red card and criminal charges",
    "Yellow card",
    "Red card",
    "Warning",
    "Warning and retake",
    "Warning or yellow card",
    "Play on",
    "Red card and lengthy ban",
    "Indirect free kick",
    "Free kick",
    "Penalty and red card",
    "Yellow card or penalty if in box",
    "Red card and penalty",
    "Red card and penalty if in box",
)


# Reward function: check if answer is valid
def reward_fn(example, prediction, trace=None):
    """Score a prediction by whether its answer is one of the allowed rulings.

    Args:
        example: Unused; accepted so the function matches the reward signature.
        prediction: Object expected to expose an ``answer`` attribute.
        trace: Unused.

    Returns:
        1.0 if ``prediction.answer`` exactly equals an allowed ruling,
        otherwise 0.0. A prediction without an ``answer`` also scores 0.0
        rather than raising.
    """
    # getattr with a default: a malformed prediction is scored as invalid
    # instead of aborting the attempt with an AttributeError.
    answer = getattr(prediction, "answer", None)
    return 1.0 if answer in _ALLOWED_ANSWERS else 0.0


class PredictModelRefine(dspy.Module):
    """Referee predictor wrapped in ``dspy.Refine``.

    The chain-of-thought predictor is re-run (up to 5 attempts) until
    ``reward_fn`` scores a prediction at or above the 0.5 threshold.
    """

    def __init__(self):
        super().__init__()

        # The predictor that Refine will re-run.
        referee = ChainOfThought(RefereeAnswer)
        self.predict = dspy.Refine(
            module=referee,
            N=5,  # at most 5 attempts
            reward_fn=reward_fn,  # validity check defined in this notebook
            threshold=0.5,  # accept as soon as the reward is >= 0.5
        )

    def forward(self, question):
        """Return a Prediction holding only the refined ``answer``."""
        refined = self.predict(question=question)
        return Prediction(answer=refined.answer)

In [None]:
predict_model = PredictModelRefine()
predict_model(
    question="Player makes questionable appeal during foul"
)  # Outputs the right thing

Prediction(
    answer='Yellow card'
)

In [None]:
"""
Runs a module up to N times with different rollout IDs (bypassing cache) and returns the best prediction, as defined by the reward_fn, or the first prediction that passes the threshold.
"""


class PredictModelBestofN(dspy.Module):
    """Referee predictor wrapped in ``dspy.BestOfN``.

    The chain-of-thought predictor is run up to 5 times; the result is the
    first prediction whose ``reward_fn`` score reaches the 0.5 threshold, or
    otherwise the best-scoring one.
    """

    def __init__(self):
        super().__init__()

        # The predictor that BestOfN will sample from.
        referee = ChainOfThought(RefereeAnswer)
        self.predict = dspy.BestOfN(
            module=referee,
            N=5,  # at most 5 candidates
            reward_fn=reward_fn,  # validity check defined in this notebook
            threshold=0.5,  # accept as soon as the reward is >= 0.5
        )

    def forward(self, question):
        """Return a Prediction holding only the selected ``answer``."""
        best = self.predict(question=question)
        return Prediction(answer=best.answer)


predict_model = PredictModelBestofN()
predict_model(
    question="Player makes questionable appeal during foul"
)  # Outputs the right thing

Prediction(
    answer='Yellow card'
)