# It's time to...T-T-T TRAIN!

We've got a rudimentary pipeline that we think will be reasonably effective.

It's time to explore some training.

Before that, though, let's benchmark how the untrained program (module) performs.

In [1]:
# local imports
from modules import ClassifyByInclusionExclusion
from data import get_synergy_data, create_batched_devset, NON_BIOMED_SRs
from metrics import batch_sr_eval, f1_evaluate, validate_all_criteria

# pkg imports
import dspy

from imblearn.under_sampling import RandomUnderSampler



In [2]:
# we just need our development set to make sure things are working
non_biomeds_df = get_synergy_data(NON_BIOMED_SRs)
non_biomeds_df

| Unnamed: 0 | doi | SR_id | SR_title | title | abstract | label_included | relevant |
|---|---|---|---|---|---|---|---|
| 0 | https://doi.org/10.1109/indcon.2010.5712716 | Hall_2012 | A Systematic Literature Review on Fault Predic... | Computer vision based offset error computation... | The use of computer vision based approach has ... | 0 | False |
| 1 | https://doi.org/10.1109/induscon.2010.5740045 | Hall_2012 | A Systematic Literature Review on Fault Predic... | Design and development of a software for fault... | This paper presents an on-line fault diagnosis... | 0 | False |
| 2 | https://doi.org/10.1109/tpwrd.2005.848672 | Hall_2012 | A Systematic Literature Review on Fault Predic... | Analytical Approach to Internal Fault Simulati... | A new method for simulating faulted transforme... | 0 | False |
| 3 | https://doi.org/10.1109/icelmach.2008.4799852 | Hall_2012 | A Systematic Literature Review on Fault Predic... | Nonlinear equivalent circuit model of a tracti... | The paper presents the development of an equiv... | 0 | False |
| 4 | https://doi.org/10.1109/ipdps.2006.1639408 | Hall_2012 | A Systematic Literature Review on Fault Predic... | Fault tolerance with real-time Java | After having drawn up a state of the art on th... | 0 | False |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 66482 | https://doi.org/10.1109/ictai.2010.27 | Radjenovic_2013 | Software fault prediction metrics: A systemati... | Attribute Selection and Imbalanced Data: Probl... | The data mining and machine learning community... | 0 | False |
| 66483 | https://doi.org/10.1109/acc.2001.945656 | Radjenovic_2013 | Software fault prediction metrics: A systemati... | Benchmarking of advanced technologies for proc... | Global competition is forcing industrial plant... | 0 | False |
| 66484 | https://doi.org/10.1109/icsess.2010.5552438 | Radjenovic_2013 | Software fault prediction metrics: A systemati... | Queueing models based performance evaluation a... | Since queueing is a common behavior in compute... | 0 | False |
| 66485 | https://doi.org/10.1109/wicom.2011.6040617 | Radjenovic_2013 | Software fault prediction metrics: A systemati... | A New Face Detection Method with GA-BP Neural ... | In this paper, the BP neural network improved ... | 0 | False |


In [3]:
devset = create_batched_devset(non_biomeds_df, size=75) 

In [4]:
# configuring our local gemma3 model
lm = dspy.LM('ollama_chat/gemma3:4b-it-qat', api_base='http://localhost:11434', api_key='')
dspy.configure(lm=lm)
# testing out the LM
lm("Say 'Hello world!'", temperature=0.7) 

['Hello world!\n']

In [None]:
# testing the old model on the devset
# TODO we need to implement a Batch Module and a normal one for Training 
batch_sr_eval(ClassifyByInclusionExclusion(), devset)

Batch: Hall_2012


F1 Score: 0.500 100%|██████████████████████████████████████████████████████████████| 75/75 [05:38<00:00,  4.51s/it]


Confusion Matrix: Counter({'TN': 36, 'FN': 24, 'TP': 13, 'FP': 2})
Precision: 0.867
Recall: 0.351
F1: 0.500
MCC: 0.373
Specificity: 0.973
Batch: Smid_2020


F1 Score: 0.286 100%|██████████████████████████████████████████████████████████████| 75/75 [05:08<00:00,  4.12s/it]


Confusion Matrix: Counter({'TN': 69, 'FN': 3, 'FP': 2, 'TP': 1})
Precision: 0.333
Recall: 0.250
F1: 0.286
MCC: 0.254
Specificity: 0.973
Batch: Radjenovic_2013


F1 Score: 0.400 100%|██████████████████████████████████████████████████████████████| 75/75 [03:34<00:00,  2.86s/it]


Confusion Matrix: Counter({'TN': 71, 'FP': 2, 'TP': 1, 'FN': 1})
Precision: 0.333
Recall: 0.500
F1: 0.400
MCC: 0.389
Specificity: 0.973
Batch: Sep_2021


F1 Score: nan  16%|██████████▏                                                     | 12/75 [00:46<03:34,  3.40s/it]
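
The per-batch figures above can be derived from the confusion-matrix counts alone. Here is a minimal sketch, assuming the textbook definitions of precision, recall, F1, MCC, and specificity. `metrics_from_counts` is a hypothetical helper written for illustration; it is not part of the project's `metrics` module, which may compute some of these values differently.

```python
from collections import Counter
from math import sqrt


def metrics_from_counts(cm: Counter) -> dict[str, float]:
    """Derive standard binary-classification metrics from TP/FP/TN/FN counts."""
    tp, fp, tn, fn = cm["TP"], cm["FP"], cm["TN"], cm["FN"]

    def safe_div(num: float, den: float) -> float:
        # Return NaN instead of raising when a denominator is zero
        return num / den if den else float("nan")

    return {
        "precision": safe_div(tp, tp + fp),
        "recall": safe_div(tp, tp + fn),
        "f1": safe_div(2 * tp, 2 * tp + fp + fn),
        "mcc": safe_div(
            tp * tn - fp * fn,
            sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)),
        ),
        "specificity": safe_div(tn, tn + fp),
    }


# Counts copied from the Hall_2012 batch above
hall_2012 = Counter({"TN": 36, "FN": 24, "TP": 13, "FP": 2})
for name, value in metrics_from_counts(hall_2012).items():
    print(f"{name}: {value:.3f}")
```

Run on the Hall_2012 counts, this reproduces the precision, recall, F1, and MCC printed above. Specificity under this definition is 36 / 38, so it may be worth checking how `batch_sr_eval` derives the figure it reports. An F1 of `nan` is also what this definition gives while a batch contains no true positives, false positives, or false negatives yet.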

## Now, we train.

Let's get the dataset together before looking at optimisation strategies.

In [6]:
# SR_ids of biomed systematic reviews
biomed_srs = {
    'Appenzeller-Herzog_2019', 'Bos_2018',
    'Brouwer_2019', 'Chou_2003',
    'Chou_2004', 'Donners_2021',
    'Jeyaraman_2020', 'Leenaars_2019',
    'Leenaars_2020', 'Meijboom_2021',
    'Menon_2022', 'Moran_2021',
    'Muthu_2021', 'Nelson_2002',
    'Oud_2018', 'Walker_2018',
    'Wassenaar_2017', 'Wolters_2018',
    'van_Dis_2020', 'van_der_Valk_2021',
    'van_der_Waal_2022', 'van_de_Schoot_2018',

}

In [7]:
# get the synergy biomed SRs' data
biomeds_df = get_synergy_data(biomed_srs)

In [8]:
trainset = create_batched_devset(biomeds_df, size=50)

### BootstrapFewShotWithRandomSearch

This is where we start before moving on to MIPROv2 and potentially fine-tuning.

In [15]:
from dspy.teleprompt import BootstrapFewShotWithRandomSearch
from signatures import InclusionExclusionCriteria, CheckCriteria

In [22]:
config = dict(max_bootstrapped_demos=4, max_labeled_demos=4, num_candidate_programs=8)

In [23]:
def validate_all_criteria_match(example,
                                pred,
                                trace=None) -> bool:
    return example.relevant == all(pred.satisfied)
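
Because the metric reduces the per-criterion booleans with `all`, a citation only counts as predicted-relevant when every criterion is satisfied. The sketch below illustrates that behaviour with stand-in objects: `SimpleNamespace` replaces the real `dspy.Example` and `dspy.Prediction`, an assumption made purely so the snippet runs without a language model.

```python
from types import SimpleNamespace


def validate_all_criteria_match(example, pred, trace=None) -> bool:
    # Same logic as the metric above: predicted relevant iff every criterion holds
    return example.relevant == all(pred.satisfied)


# Stand-ins for dspy.Example objects (hypothetical, for illustration only)
relevant = SimpleNamespace(relevant=True)
irrelevant = SimpleNamespace(relevant=False)

print(validate_all_criteria_match(relevant, SimpleNamespace(satisfied=[True, True])))     # True
print(validate_all_criteria_match(relevant, SimpleNamespace(satisfied=[True, False])))    # False
print(validate_all_criteria_match(irrelevant, SimpleNamespace(satisfied=[True, False])))  # True
# Edge case: all([]) is True, so an empty criteria list reads as "relevant"
print(validate_all_criteria_match(irrelevant, SimpleNamespace(satisfied=[])))             # False
```

The last case is worth keeping in mind: if the model returns no criteria at all, the prediction is scored as if the citation were relevant.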

In [24]:
class NeoInclusionExclusionCriteria(dspy.Signature):
    """
    Output a set of inclusion/exclusion criteria for the screening of a systematic review.
    """

    systematic_review_title: str = dspy.InputField()
    criteria: list[str] = dspy.OutputField(desc="Inclusion/exclusion criteria and their descriptions.")

class NeoCheckCriteria(dspy.Signature):
    """Verify which criteria are satisfied by the title and abstract of a candidate citation."""

    criteria: list[str] = dspy.InputField()
    citation_title: str = dspy.InputField()
    citation_abstract: str = dspy.InputField()
    satisfied: list[bool] = dspy.OutputField(desc="Whether each criterion is satisfied or not.")

In [25]:
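class NeoClassifyByInclusionExclusion(dspy.Module):
    def __init__(self):
        super().__init__()
        # step 1: derive screening criteria from the review title
        self.generate_criteria = dspy.ChainOfThought(NeoInclusionExclusionCriteria)
        # step 2: check a candidate citation against those criteria
        self.evaluate_criteria = dspy.ChainOfThought(NeoCheckCriteria)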

    def forward(self, sr_title: str, citation_title: str, citation_abstract: str):
        criteria = self.generate_criteria(
            systematic_review_title=sr_title
        ).criteria
        return self.evaluate_criteria(criteria=criteria,
                                      citation_title=citation_title,
                                      citation_abstract=citation_abstract)

In [26]:
# flatten the batched trainset into one list of per-citation examples
neotrainset = []
for sr_id, batch in trainset.items():
    # each batch holds the SR title first, then its list of examples
    sr_title, examples = batch[0], batch[1]
    neotrainset += [
        dspy.Example(sr_title=sr_title,
                     citation_title=e.citation_title,
                     citation_abstract=e.citation_abstract,
                     relevant=e.relevant)
        .with_inputs('sr_title', 'citation_title', 'citation_abstract')
        for e in examples
    ]
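
Since the metric used for training is plain accuracy, it helps to know the class balance of the training examples before reading the optimiser's scores. The sketch below computes the accuracy of always predicting the most common label. `majority_baseline` is a hypothetical helper, and the toy data stands in for `neotrainset`.

```python
from collections import Counter
from types import SimpleNamespace


def majority_baseline(examples) -> float:
    """Accuracy obtained by always predicting the most common `relevant` label."""
    counts = Counter(bool(e.relevant) for e in examples)
    return max(counts.values()) / sum(counts.values())


# Hypothetical toy data; in the notebook you would pass `neotrainset` instead
toy = [SimpleNamespace(relevant=r) for r in (True, False, False, False)]
print(majority_baseline(toy))  # 0.75
```

Any optimised program should beat this number comfortably for the accuracy score to mean much.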

In [27]:
teleprompter = BootstrapFewShotWithRandomSearch(metric=validate_all_criteria_match, **config)

Going to sample between 1 and 4 traces per predictor.
Will attempt to bootstrap 8 candidate sets.


In [28]:
optimised_program = teleprompter.compile(NeoClassifyByInclusionExclusion(), trainset=neotrainset)

Average Metric: 763.00 / 1100 (69.4%): 100%|████████████████████████████████████████| 1100/1100 [00:11<00:00, 92.51it/s]

2025/04/22 16:22:04 INFO dspy.evaluate.evaluate: Average Metric: 763 / 1100 (69.4%)



New best score: 69.36 for seed -3
Scores so far: [np.float64(69.36)]
Best score so far: 69.36
Average Metric: 763.00 / 1100 (69.4%): 100%|███████████████████████████████████████| 1100/1100 [00:10<00:00, 104.84it/s]

2025/04/22 16:22:15 INFO dspy.evaluate.evaluate: Average Metric: 763 / 1100 (69.4%)



Scores so far: [np.float64(69.36), np.float64(69.36)]
Best score so far: 69.36


  0%|▎                                                                                 | 4/1100 [00:03<17:28,  1.05it/s]


Bootstrapped 4 full traces after 4 examples for up to 1 rounds, amounting to 4 attempts.
Average Metric: 814.00 / 1100 (74.0%): 100%|██████████████████████████████████████| 1100/1100 [1:31:31<00:00,  4.99s/it]

2025/04/22 17:53:50 INFO dspy.evaluate.evaluate: Average Metric: 814 / 1100 (74.0%)



New best score: 74.0 for seed -1
Scores so far: [np.float64(69.36), np.float64(69.36), np.float64(74.0)]
Best score so far: 74.0


  0%|▎                                                                                | 5/1100 [00:00<00:02, 393.82it/s]


Bootstrapped 4 full traces after 5 examples for up to 1 rounds, amounting to 5 attempts.
Average Metric: 794.00 / 1100 (72.2%): 100%|██████████████████████████████████████| 1100/1100 [1:07:45<00:00,  3.70s/it]

2025/04/22 19:01:35 INFO dspy.evaluate.evaluate: Average Metric: 794 / 1100 (72.2%)



Scores so far: [np.float64(69.36), np.float64(69.36), np.float64(74.0), np.float64(72.18)]
Best score so far: 74.0


  0%|▏                                                                                | 2/1100 [00:00<00:05, 197.02it/s]


Bootstrapped 2 full traces after 2 examples for up to 1 rounds, amounting to 2 attempts.
Average Metric: 735.00 / 1100 (66.8%): 100%|████████████████████████████████████████| 1100/1100 [47:23<00:00,  2.59s/it]

2025/04/22 19:48:59 INFO dspy.evaluate.evaluate: Average Metric: 735 / 1100 (66.8%)



Scores so far: [np.float64(69.36), np.float64(69.36), np.float64(74.0), np.float64(72.18), np.float64(66.82)]
Best score so far: 74.0


  0%|▏                                                                                | 2/1100 [00:00<00:03, 311.59it/s]


Bootstrapped 1 full traces after 2 examples for up to 1 rounds, amounting to 2 attempts.
Average Metric: 782.00 / 1100 (71.1%): 100%|██████████████████████████████████████| 1100/1100 [1:22:08<00:00,  4.48s/it]

2025/04/22 21:11:07 INFO dspy.evaluate.evaluate: Average Metric: 782 / 1100 (71.1%)



Scores so far: [np.float64(69.36), np.float64(69.36), np.float64(74.0), np.float64(72.18), np.float64(66.82), np.float64(71.09)]
Best score so far: 74.0


  0%|▏                                                                                | 2/1100 [00:00<00:04, 231.05it/s]

Bootstrapped 2 full traces after 2 examples for up to 1 rounds, amounting to 2 attempts.





Average Metric: 822.00 / 1100 (74.7%): 100%|██████████████████████████████████████| 1100/1100 [2:18:03<00:00,  7.53s/it]

2025/04/22 23:29:11 INFO dspy.evaluate.evaluate: Average Metric: 822 / 1100 (74.7%)



New best score: 74.73 for seed 3
Scores so far: [np.float64(69.36), np.float64(69.36), np.float64(74.0), np.float64(72.18), np.float64(66.82), np.float64(71.09), np.float64(74.73)]
Best score so far: 74.73


  0%|▏                                                                                | 2/1100 [00:00<00:05, 217.01it/s]

Bootstrapped 2 full traces after 2 examples for up to 1 rounds, amounting to 2 attempts.





Average Metric: 762.00 / 1100 (69.3%): 100%|██████████████████████████████████████| 1100/1100 [1:22:14<00:00,  4.49s/it]

2025/04/23 00:51:25 INFO dspy.evaluate.evaluate: Average Metric: 762 / 1100 (69.3%)



Scores so far: [np.float64(69.36), np.float64(69.36), np.float64(74.0), np.float64(72.18), np.float64(66.82), np.float64(71.09), np.float64(74.73), np.float64(69.27)]
Best score so far: 74.73


  0%|▎                                                                                | 4/1100 [00:00<00:03, 274.67it/s]

Bootstrapped 3 full traces after 4 examples for up to 1 rounds, amounting to 4 attempts.





Average Metric: 762.00 / 1100 (69.3%): 100%|██████████████████████████████████████| 1100/1100 [1:36:55<00:00,  5.29s/it]

2025/04/23 02:28:22 INFO dspy.evaluate.evaluate: Average Metric: 762 / 1100 (69.3%)



Scores so far: [np.float64(69.36), np.float64(69.36), np.float64(74.0), np.float64(72.18), np.float64(66.82), np.float64(71.09), np.float64(74.73), np.float64(69.27), np.float64(69.27)]
Best score so far: 74.73


  0%|                                                                                 | 1/1100 [00:00<00:03, 319.40it/s]

Bootstrapped 1 full traces after 1 examples for up to 1 rounds, amounting to 1 attempts.





Average Metric: 809.00 / 1100 (73.5%): 100%|██████████████████████████████████████| 1100/1100 [1:24:59<00:00,  4.64s/it]

2025/04/23 03:53:22 INFO dspy.evaluate.evaluate: Average Metric: 809 / 1100 (73.5%)



Scores so far: [np.float64(69.36), np.float64(69.36), np.float64(74.0), np.float64(72.18), np.float64(66.82), np.float64(71.09), np.float64(74.73), np.float64(69.27), np.float64(69.27), np.float64(73.55)]
Best score so far: 74.73


  0%|▏                                                                                | 3/1100 [00:00<00:02, 365.90it/s]


Bootstrapped 3 full traces after 3 examples for up to 1 rounds, amounting to 3 attempts.
Average Metric: 756.00 / 1100 (68.7%): 100%|██████████████████████████████████████| 1100/1100 [1:28:11<00:00,  4.81s/it]

2025/04/23 05:21:33 INFO dspy.evaluate.evaluate: Average Metric: 756 / 1100 (68.7%)



Scores so far: [np.float64(69.36), np.float64(69.36), np.float64(74.0), np.float64(72.18), np.float64(66.82), np.float64(71.09), np.float64(74.73), np.float64(69.27), np.float64(69.27), np.float64(73.55), np.float64(68.73)]
Best score so far: 74.73
11 candidate programs found.
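
The count of 11 follows from the configuration: as far as we can tell, the optimiser evaluates three baseline seeds (-3, -2, -1) before the `num_candidate_programs=8` randomised ones. The sketch below pairs the logged scores with seeds under that assumption. The seed numbering is inferred from the "New best score" lines in the log, not taken from the library's documentation.

```python
# Scores copied from the final "Scores so far" line of the log above
scores = [69.36, 69.36, 74.0, 72.18, 66.82, 71.09, 74.73, 69.27, 69.27, 73.55, 68.73]

num_candidate_programs = 8
# Assumption: seeds run from -3 (baselines) up to num_candidate_programs - 1
seeds = list(range(-3, num_candidate_programs))
assert len(seeds) == len(scores)

best_seed, best_score = max(zip(seeds, scores), key=lambda pair: pair[1])
print(f"{len(scores)} candidates, best score {best_score} for seed {best_seed}")
# prints "11 candidates, best score 74.73 for seed 3"
```

The best candidate is about 5.4 points above the first baseline (69.36). No separate validation set was passed to `compile`, and each evaluation covers 1100 examples, the same size as the training set, so these scores are presumably measured on the training data itself.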


In [33]:
optimised_program.save("first_classify_by_inclusion_exclusion_train", save_program=True)