# Generate Test Set

In this notebook, we generate a test set using model version v0.1.0. The goal is to create a representative and balanced dataset. To achieve this, we classify a large number of patents, shuffle the results, and then sample a few instances from each class (separately for each language). Finally, we manually verify and correct the labels for each class to ensure accuracy.

## Step 1: Generate Initial Predictions

We classify all patents using the model from `02-classify_patent_v0.1.0.ipynb`.  
For each patent description, we store the following information:

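- `patent_number`: Patent number  
- `description_number`: Description number  
- `description_text`: Patent description text  
- `sdg`: Predicted SDG (Sustainable Development Goal) code — `"SDG1"` … `"SDG17"`, `"None"` (description too short or score below the threshold) or `"Error"`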

The results are saved in `classified_patents_raw.jsonl`.


### 🔹 Step 1.1. Load all Patents

In [None]:
from tqdm.notebook import tqdm
from api.services.patent_service import get_all_patents

def get_all_patents_number():
    """
    Collect the number of every patent stored in the database.

    Patents are fetched page by page (100 per request) and a tqdm bar
    tracks the progress.

    Returns:
        list: Patent numbers, in the order the batches were returned.
    """
    batch_size = 100
    # The first call is only used to learn how many patents exist
    total = get_all_patents().total_count

    numbers = []
    for start in tqdm(range(0, total, batch_size), desc="Fetching patent numbers by batch"):
        page = get_all_patents(first=start, last=start + batch_size)
        numbers.extend(p.number for p in page.patents)
    return numbers


# Example usage
all_patents_number = get_all_patents_number()

In [None]:
# Sanity check: compare the collected patent numbers with the database total
patents_metadata = get_all_patents()
total_expected = patents_metadata.total_count
unique_patent_numbers = set(all_patents_number)

print(f"Expected total patents: {total_expected}")
print(f"Total patent numbers collected: {len(all_patents_number)}")
print(f"Retrieved unique patent numbers: {len(unique_patent_numbers)}")
# Repeated numbers would otherwise be classified and exported more than once
print(f"Duplicate patent numbers: {len(all_patents_number) - len(unique_patent_numbers)}")

### 🔹 Step 1.2. Analyse patents


#### 1.2.1. Define model

In [None]:
from tqdm.notebook import tqdm
from transformers import pipeline
from api.config.ai_config import ai_huggingface_token
from api.models.Patent import FullPatent
from api.repositories.patent_repository import update_full_patent

# Dict of SDG candidate labels: maps an SDG code (e.g. "SDG1") to its goal statement.
# The goal statements (values) are what the classifier scores against; the code is
# recovered afterwards by a reverse lookup (see get_sdg_code_from_label).
sdg_labels_dict = {
    "SDG1": "End poverty in all its forms everywhere", 
    "SDG2": "End hunger, achieve food security and improved nutrition and promote sustainable agriculture", 
    "SDG3": "Ensure healthy lives and promote well-being for all at all ages", 
    "SDG4": "Ensure inclusive and equitable quality education and promote lifelong learning opportunities for all", 
    "SDG5": "Achieve gender equality and empower all women and girls", 
    "SDG6": "Ensure availability and sustainable management of water and sanitation for all", 
    "SDG7": "Ensure access to affordable, reliable, sustainable and modern energy for all", 
    "SDG8": "Promote sustained, inclusive and sustainable economic growth, full and productive employment and decent work for all", 
    "SDG9": "Build resilient infrastructure, promote inclusive and sustainable industrialization and foster innovation", 
    "SDG10": "Reduce inequality within and among countries", 
    "SDG11": "Make cities and human settlements inclusive, safe, resilient and sustainable", 
    "SDG12": "Ensure sustainable consumption and production patterns", 
    "SDG13": "Take urgent action to combat climate change and its impacts", 
    "SDG14": "Conserve and sustainably use the oceans, seas and marine resources for sustainable development", 
    "SDG15": "Protect, restore and promote sustainable use of terrestrial ecosystems, sustainably manage forests, combat desertification, and halt and reverse land degradation and halt biodiversity loss", 
    "SDG16": "Promote peaceful and inclusive societies for sustainable development, provide access to justice for all and build effective, accountable and inclusive institutions at all levels", 
    "SDG17": "Strengthen the means of implementation and revitalize the Global Partnership for Sustainable Development"
}

# Candidate labels passed to the classifier: the goal statements, in SDG1..SDG17 order.
candidate_label_values = list(sdg_labels_dict.values())

# Initialize the classifier
# NOTE(review): no `task=` is given, so the pipeline task is inferred from the model;
# presumably zero-shot-classification for facebook/bart-large-mnli (callers pass
# `candidate_labels`) — confirm.
classifier = pipeline(model="facebook/bart-large-mnli", token=ai_huggingface_token)


In [None]:
def get_sdg_code_from_label(label: str, label_dict: dict) -> str:
    """Return the SDG code whose label text equals ``label``.

    Reverse lookup on ``label_dict`` (code -> text): the first matching code,
    in iteration order, wins. The string "None" is returned when nothing matches.
    """
    return next(
        (code for code, text in label_dict.items() if label == text),
        "None",
    )


def classify_full_patent_description(patent: FullPatent,
                                     classifier=classifier,
                                     candidate_labels=candidate_label_values,
                                     label_dict=sdg_labels_dict,
                                     treshold: float = 0.18) -> FullPatent:
    """
    Classify every description block of a FullPatent with an SDG code and persist it.

    Blocks with more than 20 whitespace-separated words are sent to the
    classifier in one batch; the top-ranked label is kept only if its score
    reaches ``treshold``. Shorter blocks are not classified and get "None".

    Args:
        patent (FullPatent): The patent to analyze; its description blocks are
            modified in place.
        classifier: HuggingFace classification pipeline, called with
            ``candidate_labels``.
        candidate_labels (list): SDG goal texts offered to the classifier.
        label_dict (dict): Map from SDG code to SDG goal text, used to turn the
            winning label text back into its code.
        treshold (float): Minimum top score needed to accept a prediction
            (misspelled name kept for backward compatibility).

    Returns:
        FullPatent: The same object, with ``desc.sdg`` set on every block
        ("SDGn", "None" or "Error") and ``is_analyzed`` set to True.

    Side effects:
        Writes the enriched patent to the database via ``update_full_patent``.
    """
    min_words = 20

    # Step 1: Split blocks into those long enough to classify and the rest.
    # Short blocks are labelled "None" right away (word count computed once).
    valid_descriptions = []
    for desc in patent.description:
        if len(desc.description_text.split()) > min_words:
            valid_descriptions.append(desc)
        else:
            desc.sdg = "None"

    # Step 2: Run the classifier once on the whole batch. Skip the call when
    # nothing qualifies. Some pipeline versions return a single dict instead of
    # a one-element list for a single input; normalize that case.
    results = []
    if valid_descriptions:
        results = classifier([desc.description_text for desc in valid_descriptions],
                             candidate_labels=candidate_labels)
        if isinstance(results, dict):
            results = [results]

    # Step 3: Assign the top label back to each classified block
    for desc, result in zip(valid_descriptions, results):
        try:
            # Labels/scores are sorted best-first by the pipeline
            if result["scores"][0] >= treshold:
                desc.sdg = get_sdg_code_from_label(result["labels"][0], label_dict)
            else:
                desc.sdg = "None"
        except Exception as e:
            print(f"Error on description {desc.description_number}: {e}")
            desc.sdg = "Error"

    patent.is_analyzed = True

    # Update the Patent in Database
    update_full_patent(patent.model_dump())

    return patent

#### 1.2.2. Run analysis

In [14]:
import os
import json
import random
from typing import List, Dict
from api.services.patent_service import get_full_patent_by_number

def load_already_classified_patents(file_path: str) -> set:
    """Return the set of patent numbers already present in a JSONL export file.

    Each non-blank line is expected to be a JSON object with a
    ``"patent_number"`` key. Blank lines are ignored. Malformed lines are
    reported and skipped, so one bad line does not abort a resumed run.

    Args:
        file_path (str): Path to the JSONL file written by
            ``save_classified_descriptions``.

    Returns:
        set: Patent numbers found in the file (empty if the file does not exist).
    """
    if not os.path.exists(file_path):
        return set()

    classified = set()
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # A blank line (e.g. a trailing newline added by hand) is not an error
                continue
            try:
                classified.add(json.loads(line)["patent_number"])
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                # Invalid JSON, missing key, or a non-object / unhashable value
                print(f"Error reading line: {e}")
    return classified


def save_classified_descriptions(data: List[Dict], file_path: str):
    """Append every dict in ``data`` to ``file_path`` as one JSON line (UTF-8, non-ASCII kept as is)."""
    serialized = (json.dumps(record, ensure_ascii=False) + "\n" for record in data)
    with open(file_path, "a", encoding="utf-8") as out:
        out.writelines(serialized)


def analyze_patents_and_save_descriptions(patent_numbers: List[str], export_file: str = "classified_patents_raw_test.jsonl"):
    """
    Classify every patent not yet exported and append its descriptions to a JSONL file.

    The export file acts as a checkpoint: patents already present in it are
    skipped, so the function can be re-run after an interruption. Remaining
    patents are processed in random order (``random.shuffle``).

    Args:
        patent_numbers (List[str]): Candidate patent numbers. A number listed
            several times is processed only once.
        export_file (str): JSONL file to append to, one line per description
            with keys patent_number, description_number, description_text, sdg.
    """
    already_classified = load_already_classified_patents(export_file)
    # dict.fromkeys drops repeated numbers while keeping first-seen order; without
    # it a patent listed twice would be classified and exported twice.
    to_process = list(dict.fromkeys(pn for pn in patent_numbers if pn not in already_classified))

    print(f"Total patents to process: {len(to_process)}")
    random.shuffle(to_process)

    failed = 0
    for patent_number in to_process:
        try:
            print(f"\nProcessing patent {patent_number}...")
            patent = get_full_patent_by_number(patent_number)
            enriched_patent = classify_full_patent_description(patent)

            output_data = [
                {
                    "patent_number": patent_number,
                    "description_number": desc.description_number,
                    "description_text": desc.description_text,
                    "sdg": desc.sdg,
                }
                for desc in enriched_patent.description
            ]

            save_classified_descriptions(output_data, export_file)

        except Exception as e:
            # Best-effort: report the failure and continue with the next patent
            failed += 1
            print(f"Error processing {patent_number}: {e}")

    if failed:
        print(f"\n{failed} patent(s) failed and were not saved.")
    print("\nProcessing completed and data saved.")

In [None]:
# Analyse all patents; patents already present in the export file are skipped,
# so this cell can be re-run to resume an interrupted run.
analyze_patents_and_save_descriptions(all_patents_number, export_file="classified_patents_raw.jsonl" )

#### 1.2.3. Remove Duplicates

To ensure the quality of the dataset, we remove duplicate entries, i.e. lines with the same `(patent_number, description_text)` pair, from the classified results.

In [None]:
import json

def analyze_and_save_clean_jsonl(file_path, output_path):
    """
    Copy a JSONL file while dropping repeated (patent_number, description_text) pairs.

    The first occurrence of each pair is written to ``output_path``; later
    occurrences are counted as duplicates and not written. Lines that fail to
    parse are reported and skipped, but still count towards the total.

    Args:
        file_path: Source JSONL file.
        output_path: Destination JSONL file (overwritten).

    Returns:
        tuple: (total_lines, duplicate_lines, duplicate_patents), where
        duplicate_patents lists the patent_number of every dropped line.
    """
    kept_keys = set()
    n_lines = 0
    dropped = []

    with open(file_path, "r", encoding="utf-8") as src, \
         open(output_path, "w", encoding="utf-8") as dst:
        for raw in src:
            try:
                n_lines += 1
                record = json.loads(raw.strip())
                key = (record.get("patent_number"), record.get("description_text"))
                if key not in kept_keys:
                    kept_keys.add(key)
                    dst.write(json.dumps(record, ensure_ascii=False) + "\n")
                    continue
                dropped.append(record.get("patent_number"))
            except Exception as e:
                print(f"Erreur parsing JSON : {e}")

    n_dupes = len(dropped)
    print(f"Total lines: {n_lines}")
    print(f"Duplicate (patent_number, description_text) pairs: {n_dupes}")
    print(f"Unique (patent_number, description_text) pairs saved: {len(kept_keys)}")

    return n_lines, n_dupes, dropped


# Deduplicate the raw classification export into a separate clean file
input_path = "classified_patents_raw.jsonl"
output_path = "classified_patents_raw_clean.jsonl"
analyze_and_save_clean_jsonl(input_path, output_path)

#### 1.2.4. Check for remaining duplicates

In [None]:
def analyze_jsonl_by_patent_and_description(file_path):
    """
    Count repeated (patent_number, description_text) pairs in a JSONL file, read-only.

    Uses the same key as the deduplication step, so it can confirm that a
    cleaned file contains no duplicates. Unparsable lines are reported and
    skipped but still counted in the total.

    Args:
        file_path: JSONL file to inspect.

    Returns:
        tuple: (total_lines, duplicate_lines, duplicate_patents), where
        duplicate_patents lists the patent_number of every repeated line.
    """
    seen_pairs = set()
    total_lines = 0
    duplicate_lines = 0
    duplicate_patents = []

    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                total_lines += 1
                item = json.loads(line.strip())
                pair = (item.get("patent_number"), item.get("description_text"))

                if pair in seen_pairs:
                    duplicate_lines += 1
                    duplicate_patents.append(item.get("patent_number"))
                else:
                    seen_pairs.add(pair)
            except Exception as e:
                print(f"Erreur parsing JSON : {e}")

    # The key is (patent_number, description_text); report it under that name
    print(f"Total lines: {total_lines}")
    print(f"Duplicate (patent_number, description_text) pairs: {duplicate_lines}")
    print(f"Unique (patent_number, description_text) pairs: {len(seen_pairs)}")

    return total_lines, duplicate_lines, duplicate_patents

# Example usage: verify that the cleaned file has no remaining duplicates
analyze_jsonl_by_patent_and_description('classified_patents_raw_clean.jsonl')

## Step 2 – Create a balanced test set

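To ensure class balance:
- Detect the language of each patent (French, English or German) and group its descriptions by language
- Within each language, group descriptions by predicted SDG (`sdg`)
- Shuffle each group
- Sample up to 10 items per class (`max_per_sdg`)

The selected items are saved in `testset_[version]_[language]_raw.jsonl`.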

#### 2.1. Generate

In [None]:
import json
from collections import defaultdict
import random
from typing import List
from langdetect import detect, DetectorFactory
from langdetect.lang_detect_exception import LangDetectException

DetectorFactory.seed = 0  # Make langdetect results reproducible


def detect_language(text: str) -> "str | None":
    """
    Detect the language of ``text`` and keep it only if it is supported.

    Args:
        text (str): Text to analyze.

    Returns:
        "fr", "en" or "de" when langdetect identifies one of them; None for any
        other language or when langdetect raises LangDetectException.
    """
    try:
        lang = detect(text)
    except LangDetectException:
        return None
    return lang if lang in {"fr", "en", "de"} else None

def _load_items_by_patent(jsonl_file: str) -> dict:
    """Read a JSONL file and group its records by patent_number (records without one are skipped)."""
    items_by_patent = defaultdict(list)
    with open(jsonl_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # blank lines are not parse errors
            try:
                item = json.loads(line)
                patent_number = item.get("patent_number")
                if patent_number:
                    items_by_patent[patent_number].append(item)
            except Exception as e:
                print(f"Erreur parsing JSON : {e}")
    return items_by_patent


def _patent_language(items: list):
    """Return the detected language of the first non-empty description_text in ``items``, or None."""
    for item in items:
        # `or ""` guards against an explicit null, which would crash .strip()
        text = item.get("description_text") or ""
        if isinstance(text, str) and text.strip():
            return detect_language(text.strip())
    return None


def _sdg_sort_key(sdg):
    """Sort key listing SDG1..SDG17 in numeric order, followed by any other label."""
    text = str(sdg)
    digits = text[3:] if text.startswith("SDG") else ""
    return (0, int(digits), text) if digits.isdigit() else (1, 0, text)


def generate_classified_patents_raw(
    jsonl_file: str,
    version: str = "v0",
    max_per_sdg: int = 10,
    seed=None,
):
    """
    Build one class-balanced raw test set per language from a classified JSONL file.

    Steps:
      1. Group records by ``patent_number`` (records without one are ignored).
      2. Detect the language once per patent from its first non-empty
         description text; all its records inherit that language. Patents whose
         language is not fr/en/de (or cannot be detected) are dropped.
      3. Per language, group records by ``sdg``, shuffle each group and keep at
         most ``max_per_sdg`` records per SDG.
      4. Write ``testset_{version}_{lang}_raw.jsonl`` and print per-SDG counts.

    Args:
        jsonl_file (str): Input JSONL with keys patent_number, description_text, sdg.
        version (str): Tag inserted in the output file names.
        max_per_sdg (int): Maximum number of records sampled per SDG and language.
        seed (int | None): Optional seed for reproducible sampling. None keeps
            using the global ``random`` state (previous behaviour).
    """
    rng = random.Random(seed) if seed is not None else random

    # 1. Group records by patent so language detection runs once per patent
    items_by_patent = _load_items_by_patent(jsonl_file)

    # 2. Route each patent's records to its language, grouped by SDG
    grouped_by_lang = {
        "fr": defaultdict(list),
        "en": defaultdict(list),
        "de": defaultdict(list),
    }
    for items in items_by_patent.values():
        lang = _patent_language(items)
        if lang not in grouped_by_lang:
            continue
        for item in items:
            sdg = item.get("sdg")
            if sdg is not None:
                grouped_by_lang[lang][sdg].append(item)

    # 3. Sample up to max_per_sdg records per SDG and write one file per language
    for lang, grouped in grouped_by_lang.items():
        testset: List[dict] = []
        per_class_count = {}
        for sdg, items in grouped.items():
            rng.shuffle(items)
            selected = items[:max_per_sdg]
            testset.extend(selected)
            per_class_count[sdg] = len(selected)

        output_file = f"testset_{version}_{lang}_raw.jsonl"
        with open(output_file, "w", encoding="utf-8") as out:
            for item in testset:
                out.write(json.dumps(item, ensure_ascii=False) + "\n")

        print(f"\n{output_file} généré avec {len(testset)} éléments.")
        print("Répartition par SDG :")
        # Numeric order (SDG3 before SDG10); the label already contains "SDG"
        for sdg in sorted(per_class_count, key=_sdg_sort_key):
            print(f"  {sdg:>5}: {per_class_count[sdg]} éléments")

In [None]:
# Writes testset_v1_fr_raw.jsonl, testset_v1_en_raw.jsonl and testset_v1_de_raw.jsonl
generate_classified_patents_raw(jsonl_file="classified_patents_raw_clean.jsonl", version="v1", max_per_sdg=10)


testset_v3_fr_raw.jsonl généré avec 98 éléments.
Répartition par SDG :
  SDG None: 10 éléments
  SDG SDG1: 1 éléments
  SDG SDG10: 10 éléments
  SDG SDG12: 10 éléments
  SDG SDG13: 6 éléments
  SDG SDG14: 10 éléments
  SDG SDG15: 10 éléments
  SDG SDG16: 3 éléments
  SDG SDG17: 6 éléments
  SDG SDG3: 10 éléments
  SDG SDG4: 2 éléments
  SDG SDG6: 10 éléments
  SDG SDG9: 10 éléments

testset_v3_en_raw.jsonl généré avec 122 éléments.
Répartition par SDG :
  SDG None: 10 éléments
  SDG SDG10: 10 éléments
  SDG SDG11: 10 éléments
  SDG SDG12: 10 éléments
  SDG SDG13: 10 éléments
  SDG SDG14: 8 éléments
  SDG SDG15: 10 éléments
  SDG SDG16: 4 éléments
  SDG SDG17: 7 éléments
  SDG SDG3: 10 éléments
  SDG SDG4: 10 éléments
  SDG SDG6: 10 éléments
  SDG SDG8: 3 éléments
  SDG SDG9: 10 éléments

testset_v3_de_raw.jsonl généré avec 78 éléments.
Répartition par SDG :
  SDG None: 10 éléments
  SDG SDG10: 10 éléments
  SDG SDG12: 10 éléments
  SDG SDG13: 2 éléments
  SDG SDG14: 9 éléments
  SDG S

#### 2.2. Check duplicates

In [None]:
import json
from collections import defaultdict

def trouver_doublons_jsonl(fichier_jsonl):
    """
    Find duplicates in a JSONL file based on 'patent_number' and 'description_text'.

    Every record whose (patent_number, description_text) pair appears more
    than once is returned, including the first occurrence, so each duplicated
    group appears in full. Blank lines are ignored; unparsable lines are
    reported and skipped.

    Args:
        fichier_jsonl (str): Path to the JSONL file.

    Returns:
        List[dict]: All records belonging to a duplicated group, grouped by key
        in first-seen order (an empty list when there are no duplicates).
    """
    records_by_key = defaultdict(list)

    with open(fichier_jsonl, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # blank lines are not parse errors
            try:
                item = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Skipped line due to parsing error : {e}")
                continue
            key = (item.get('patent_number'), item.get('description_text'))
            records_by_key[key].append(item)

    return [item for items in records_by_key.values() if len(items) > 1 for item in items]


In [None]:
# An empty list means the English test set has no duplicated descriptions
trouver_doublons_jsonl("testset_v1_en_raw.jsonl")

[]

In [None]:
# An empty list means the French test set has no duplicated descriptions
trouver_doublons_jsonl("testset_v1_fr_raw.jsonl")

In [None]:
# An empty list means the German test set has no duplicated descriptions
trouver_doublons_jsonl("testset_v1_de_raw.jsonl")

## Optional Functions

### Retrieve all patent numbers from the database

If you want to retrieve the labeled data from the database instead of regenerating it, you can use the functions below. They only export patents already marked as analyzed; no classification is run again.

In [3]:
from tqdm.notebook import tqdm
from api.services.patent_service import get_all_patents

def get_all_patents_number_if_analysed(batch_size: int = 100):
    """
    Retrieve the numbers of the patents already analyzed (``is_analyzed`` truthy).

    Patents are fetched from the database in batches with a tqdm progress
    bar; patents that have not been analyzed are skipped.

    Args:
        batch_size (int): Number of patents requested per database call.

    Returns:
        list: Numbers of the analyzed patents, in batch order.
    """
    all_patents_number = []

    # Step 1: Get total number of patents
    total_patents = get_all_patents().total_count

    # Step 2: Walk through all patents batch by batch
    for i in tqdm(range(0, total_patents, batch_size), desc="Fetching patent numbers by batch"):
        patents_batch = get_all_patents(first=i, last=i + batch_size)
        # Keep only the patents that have already been classified
        all_patents_number.extend(
            patent.number for patent in patents_batch.patents if patent.is_analyzed
        )

    return all_patents_number


# Example usage
all_patents_number_if_analysed = get_all_patents_number_if_analysed()

Fetching patent numbers by batch:   0%|          | 0/234 [00:00<?, ?it/s]

In [5]:
from typing import List

def db_get_patents_and_save_descriptions(patent_numbers: List[str], export_file: str = "classified_patents_raw_test.jsonl"):
    """
    Export the SDG labels already stored in the database to a JSONL file (no re-classification).

    Patents already present in ``export_file`` are skipped, so the export can
    be resumed. Patents whose ``is_analyzed`` flag is false are not exported.

    Args:
        patent_numbers (List[str]): Candidate patent numbers. A number listed
            several times is exported only once.
        export_file (str): JSONL file to append to, one line per description
            with keys patent_number, description_number, description_text, sdg.
    """
    already_classified = load_already_classified_patents(export_file)
    # dict.fromkeys drops repeated numbers (first-seen order kept) so that no
    # patent is exported twice in the same run.
    to_process = list(dict.fromkeys(pn for pn in patent_numbers if pn not in already_classified))

    print(f"Total patents to process: {len(to_process)}")

    for patent_number in tqdm(to_process, desc="Patent recovered"):
        try:
            patent = get_full_patent_by_number(patent_number)
            # Only analysed patents carry meaningful SDG labels
            if not patent.is_analyzed:
                continue

            output_data = [
                {
                    "patent_number": patent_number,
                    "description_number": desc.description_number,
                    "description_text": desc.description_text,
                    "sdg": desc.sdg,
                }
                for desc in patent.description
            ]

            save_classified_descriptions(output_data, export_file)

        except Exception as e:
            # Best-effort: report and move on to the next patent
            print(f"Error processing {patent_number}: {e}")

    print("\nProcessing completed and data saved.")

In [None]:
# Export the labels of all analysed patents stored in the database
db_get_patents_and_save_descriptions(all_patents_number_if_analysed, export_file="classified_patents_raw_db.jsonl" )

Total patents to process: 0


Patent recovered: 0it [00:00, ?it/s]


Processing completed and data saved.


In [None]:
#  Remove duplicates from the database export into a separate clean file
input_path = "classified_patents_raw_db.jsonl"
output_path = "classified_patents_raw_db_clean.jsonl"
analyze_and_save_clean_jsonl(input_path, output_path)

In [None]:
# Verify the file produced by the previous cell (the database export), not the main-flow file
analyze_jsonl_by_patent_and_description('classified_patents_raw_db_clean.jsonl')

In [None]:
# Writes testset_db_v1_fr_raw.jsonl, testset_db_v1_en_raw.jsonl and testset_db_v1_de_raw.jsonl
generate_classified_patents_raw(jsonl_file="classified_patents_raw_db_clean.jsonl", version="db_v1", max_per_sdg=10)


testset_v2_fr_raw.json généré avec 98 éléments.
Répartition par SDG :
  SDG None: 10 éléments
  SDG SDG1: 1 éléments
  SDG SDG10: 10 éléments
  SDG SDG12: 10 éléments
  SDG SDG13: 6 éléments
  SDG SDG14: 10 éléments
  SDG SDG15: 10 éléments
  SDG SDG16: 3 éléments
  SDG SDG17: 6 éléments
  SDG SDG3: 10 éléments
  SDG SDG4: 2 éléments
  SDG SDG6: 10 éléments
  SDG SDG9: 10 éléments

testset_v2_en_raw.json généré avec 122 éléments.
Répartition par SDG :
  SDG None: 10 éléments
  SDG SDG10: 10 éléments
  SDG SDG11: 10 éléments
  SDG SDG12: 10 éléments
  SDG SDG13: 10 éléments
  SDG SDG14: 8 éléments
  SDG SDG15: 10 éléments
  SDG SDG16: 4 éléments
  SDG SDG17: 7 éléments
  SDG SDG3: 10 éléments
  SDG SDG4: 10 éléments
  SDG SDG6: 10 éléments
  SDG SDG8: 3 éléments
  SDG SDG9: 10 éléments

testset_v2_de_raw.json généré avec 78 éléments.
Répartition par SDG :
  SDG None: 10 éléments
  SDG SDG10: 10 éléments
  SDG SDG12: 10 éléments
  SDG SDG13: 2 éléments
  SDG SDG14: 9 éléments
  SDG SDG1

In [None]:
# The generator writes .jsonl files (testset_{version}_{lang}_raw.jsonl)
trouver_doublons_jsonl("testset_db_v1_en_raw.jsonl")

[]

In [None]:
# The generator writes .jsonl files (testset_{version}_{lang}_raw.jsonl)
trouver_doublons_jsonl("testset_db_v1_fr_raw.jsonl")

[]

In [None]:
# The generator writes .jsonl files (testset_{version}_{lang}_raw.jsonl)
trouver_doublons_jsonl("testset_db_v1_de_raw.jsonl")

[]

### Generate SDG justification summaries for a FullPatent.

This function generates justification summaries for patents that have already been classified. These summaries can then be stored in the database so that the front-end displays the most up-to-date and relevant information.

In [28]:
from api.models.Patent import FullPatent
from api.models.SDGSummary import SDGSummary

def generate_summary(patent: FullPatent,
                     ai_client,
                     ai_model: str,
                     sdg_labels_dict: dict) -> list[SDGSummary]:
    """
    Generate SDG justification summaries for a FullPatent.

    Description blocks are grouped by their ``sdg`` code. Blocks without a
    usable code are ignored: Python None or empty (never classified), "None"
    (below threshold / too short) and "Error". For each remaining code, the
    grouped texts are sent to the chat model, which is asked to explain how
    the patent supports that SDG.

    Args:
        patent (FullPatent): The patent object containing the description blocks.
        ai_client: AI client exposing ``chat(model=..., messages=..., options=...)``.
        ai_model (str): The identifier of the AI model to be used.
        sdg_labels_dict (dict): Mapping from SDG codes to their textual descriptions.

    Returns:
        list[SDGSummary]: One summary per SDG for which generation succeeded;
        failures (e.g. a code missing from ``sdg_labels_dict``) are printed and skipped.
    """

    summaries = []

    # Group descriptions by SDG code. Unclassified blocks (sdg None or empty)
    # are excluded as well: they have no entry in sdg_labels_dict.
    excluded_codes = {"None", "Error"}
    sdg_to_descriptions = {}
    for desc in patent.description:
        if desc.sdg and desc.sdg not in excluded_codes:
            sdg_to_descriptions.setdefault(desc.sdg, []).append(desc)

    for sdg_code, descriptions in tqdm(sdg_to_descriptions.items(), desc=f"Generating summaries for patent {patent.number}"):
        try:
            sdg_description = sdg_labels_dict[sdg_code]
            combined_text = "\n".join(desc.description_text for desc in descriptions)

            system_prompt = f"""
            You are an AI specialized in sustainable development and patents. Read the following patent excerpt and explain how it contributes to this Sustainable Development Goal (SDG): {sdg_code} - {sdg_description}.

            Focus on:
            - The main innovation or idea in the patent.
            - How it supports the targets of the SDG.
            - Any positive impact (social, environmental, or economic) it may have.

            Patent text:
            {combined_text}

            Write a short, clear summary showing the link between the patent and the SDG.
            """

            user_prompt = f"""
            Summarize how this patent helps achieve the SDG: {sdg_code} - {sdg_description}."""

            # num_predict caps the length of the generated answer
            response = ai_client.chat(
                model=ai_model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}],
                options = {"num_predict":512}
            )

            summary_text = response["message"]["content"].strip()

            summaries.append(SDGSummary(
                patent_number=str(patent.number),
                sdg=str(sdg_code),
                sdg_description=summary_text
            ))

        except Exception as e:
            # Best-effort: one failing SDG must not block the others
            print(f"Error generating summary for SDG {sdg_code} in patent {patent.number}: {e}")

    return summaries


# Mapping from SDG code to goal statement; same content as the dict defined in
# Step 1.2.1. NOTE(review): presumably redefined so this optional section can run
# without executing the classifier cell — keep both copies in sync.
sdg_labels_dict = {
    "SDG1": "End poverty in all its forms everywhere", 
    "SDG2": "End hunger, achieve food security and improved nutrition and promote sustainable agriculture", 
    "SDG3": "Ensure healthy lives and promote well-being for all at all ages", 
    "SDG4": "Ensure inclusive and equitable quality education and promote lifelong learning opportunities for all", 
    "SDG5": "Achieve gender equality and empower all women and girls", 
    "SDG6": "Ensure availability and sustainable management of water and sanitation for all", 
    "SDG7": "Ensure access to affordable, reliable, sustainable and modern energy for all", 
    "SDG8": "Promote sustained, inclusive and sustainable economic growth, full and productive employment and decent work for all", 
    "SDG9": "Build resilient infrastructure, promote inclusive and sustainable industrialization and foster innovation", 
    "SDG10": "Reduce inequality within and among countries", 
    "SDG11": "Make cities and human settlements inclusive, safe, resilient and sustainable", 
    "SDG12": "Ensure sustainable consumption and production patterns", 
    "SDG13": "Take urgent action to combat climate change and its impacts", 
    "SDG14": "Conserve and sustainably use the oceans, seas and marine resources for sustainable development", 
    "SDG15": "Protect, restore and promote sustainable use of terrestrial ecosystems, sustainably manage forests, combat desertification, and halt and reverse land degradation and halt biodiversity loss", 
    "SDG16": "Promote peaceful and inclusive societies for sustainable development, provide access to justice for all and build effective, accountable and inclusive institutions at all levels", 
    "SDG17": "Strengthen the means of implementation and revitalize the Global Partnership for Sustainable Development"
}

In [34]:
from api.repositories.sdg_summary_repository import get_sdg_summary_by_patent_number
from api.services.patent_service import get_full_patent_by_number
from api.repositories.sdg_summary_repository import create_sdg_summary
from api.config.ai_config import ai_model, ai_client
from tqdm.notebook import tqdm


def generate_summaries_for_patent_in_db_analysed(patent_numbers: List[str]):
    """
    Create and store SDG summaries for each patent that has none yet.

    Patents that already have summaries in the database are skipped. Errors
    are printed per patent and do not stop the loop.

    Args:
        patent_numbers (List[str]): Numbers of the patents to process.
    """
    for patent_number in tqdm(patent_numbers, desc="Patent recovered"):
        try:
            if get_sdg_summary_by_patent_number(patent_number):
                continue  # summaries already stored for this patent

            patent = get_full_patent_by_number(patent_number)
            for summary in generate_summary(patent, ai_client, ai_model, sdg_labels_dict):
                create_sdg_summary(summary.model_dump())  # persist to the database
        except Exception as e:
            print(f"Error processing {patent_number}: {e}")

    print("\nProcessing completed and summaries generated.")

In [None]:
# Generate and store summaries for every analysed patent that has none yet
generate_summaries_for_patent_in_db_analysed(all_patents_number_if_analysed)