In [27]:
! pip install mistralai

Collecting mistralai
  Downloading mistralai-1.5.0-py3-none-any.whl.metadata (29 kB)
Collecting eval-type-backport>=0.2.0 (from mistralai)
  Using cached eval_type_backport-0.2.2-py3-none-any.whl.metadata (2.2 kB)
Collecting jsonpath-python>=1.0.6 (from mistralai)
  Using cached jsonpath_python-1.0.6-py3-none-any.whl.metadata (12 kB)
Collecting typing-inspect>=0.9.0 (from mistralai)
  Using cached typing_inspect-0.9.0-py3-none-any.whl.metadata (1.5 kB)
Collecting mypy-extensions>=0.3.0 (from typing-inspect>=0.9.0->mistralai)
  Using cached mypy_extensions-1.0.0-py3-none-any.whl.metadata (1.1 kB)
Downloading mistralai-1.5.0-py3-none-any.whl (271 kB)
Using cached eval_type_backport-0.2.2-py3-none-any.whl (5.8 kB)
Using cached jsonpath_python-1.0.6-py3-none-any.whl (7.6 kB)
Using cached typing_inspect-0.9.0-py3-none-any.whl (8.8 kB)
Using cached mypy_extensions-1.0.0-py3-none-any.whl (4.7 kB)
Installing collected packages: mypy-extensions, jsonpath-python, eval-type-backport, typing-inspe

In [28]:
from openai import OpenAI
import pandas as pd
from tqdm import tqdm
import os
import time
from mistralai import Mistral

In [9]:
# Load the sample of texts to classify from the CSV in the working directory.
df_chambre_int = pd.read_csv("df_chambre_sample.csv")

In [None]:
# Read the API key from the environment so the secret is never stored in the
# notebook; falls back to the empty string when MISTRAL_API_KEY is unset.
api_key = os.environ.get("MISTRAL_API_KEY", "")
# Identifier of the Mistral chat model used in this notebook.
model = "mistral-small-latest"

# Shared SDK client used by classify_text.
client = Mistral(api_key=api_key)

def classify_text(text, model_name="mistral-small-latest", max_retries=3, retry_delay=2):
    """Ask the Mistral chat model whether ``text`` deals with international politics.

    The (French) prompt asks the model to answer with a single ``0`` or ``1``.
    The reply is accepted as-is when it is exactly one of those digits;
    otherwise the first ``0``/``1`` character found in the reply is used.

    Args:
        text: Passage to classify. Must be a ``str``; anything else (for
            example the float NaN pandas uses for an empty cell) is rejected
            without calling the API.
        model_name: Model identifier passed to ``client.chat.complete``.
        max_retries: Number of attempts made when the call or the reply
            handling raises an exception.
        retry_delay: Seconds to sleep between failed attempts.

    Returns:
        int: ``1`` or ``0`` as answered by the model, or ``-1`` when ``text``
        is not a string, the reply contains neither digit, or every attempt
        raised an exception.
    """
    if not isinstance(text, str):
        # The log line below slices ``text``; on a non-string that raises only
        # after the request has been sent, so the call would be repeated
        # ``max_retries`` times before ending in -1 anyway. Fail fast instead.
        print(f"Error, text is not a string: {text!r}")
        return -1

    prompt = (
        "Veuillez déterminer si le texte suivant traite des relations entre la France et d'autres pays, "
        "ou d'événements spécifiques de la politique internationale en dehors de la France. Veuillez juger avec rigueur. "
        "Si c'est le cas, retournez 1 ; sinon, ou si le texte ne fait que mentionner des noms liés, retournez 0.\n"
        "Vous devez strictement respecter le format requis : vous ne pouvez répondre que par 0 ou 1, rien d'autre.\n"
        f"Texte : {text}\n"
        "Votre réponse (uniquement 0 ou 1) :"
        )

    for attempt in range(max_retries):
        try:
            response = client.chat.complete(
                model=model_name,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt}],
                stream=False,
            )

            llm_reply = response.choices[0].message.content.strip()

            print(f"Text: {text[:50]}...\nLLM Reply: {llm_reply}\n")
            # Short pause after every answered request (presumably to stay
            # under the API rate limit -- TODO confirm).
            time.sleep(0.5)

            if llm_reply in ("0", "1"):
                return int(llm_reply)

            print(f"Error in LLM's response: {llm_reply}, trying to extract numbers")

            # Fall back to the first 0/1 character anywhere in the reply.
            digit = next((char for char in llm_reply if char in "01"), None)
            if digit is not None:
                return int(digit)

            print("Can't extract number, use -1 to represent anormaly")
            return -1

        except Exception as e:
            if attempt < max_retries - 1:
                print(f"Erro (Retry in {attempt+1}/{max_retries}): {e}，{retry_delay}")
                time.sleep(retry_delay)
            else:
                print(f"Error, max time reached: {e}")
                return -1

    # Only reachable when max_retries <= 0; keep the "anomaly" contract
    # instead of silently returning None.
    return -1

def process_dataframe_with_checkpoint(df, text_column, result_column, checkpoint_file="classification_checkpoint.csv"):
    """Classify every row of ``df`` with ``classify_text``, checkpointing each result.

    After each row is classified, one line (``index``, the text, the result)
    is written to ``checkpoint_file``. If that file already exists and has
    both an ``index`` and a ``result_column`` column, its results are copied
    back into ``df`` and only the remaining rows are classified; new rows are
    appended so earlier progress is kept. Otherwise the run starts from
    scratch and the file is rewritten.

    Rows are addressed by position (0 .. len(df) - 1), which is also what the
    checkpoint's ``index`` column stores, so the frame's own index labels do
    not matter.

    Args:
        df: Frame to process. It is modified in place (``result_column`` is
            added or updated) and also returned.
        text_column: Name of the column holding the text to classify.
        result_column: Name of the column receiving the classification.
            Rows not yet processed hold ``-2``.
        checkpoint_file: Path of the CSV checkpoint.

    Returns:
        The same ``df`` object with ``result_column`` filled in.
    """
    n_rows = len(df)
    resume = False

    if os.path.exists(checkpoint_file):
        print(f"Find checkppoint: {checkpoint_file}, Resuming")
        try:
            df_checkpoint = pd.read_csv(checkpoint_file)
        except pd.errors.EmptyDataError:
            # A zero-byte checkpoint carries no progress; treat it as unusable.
            df_checkpoint = pd.DataFrame()

        if 'index' not in df_checkpoint.columns:
            print("No index column, start over")
        elif result_column not in df_checkpoint.columns:
            print(f"No {result_column} column, start over")
        else:
            resume = True
    else:
        print("No checkpoint, start over")

    # -2 marks "not processed yet". Reset on a fresh start; when resuming only
    # create the column if the frame does not have it yet.
    if not resume or result_column not in df.columns:
        df[result_column] = -2

    text_pos = df.columns.get_loc(text_column)
    result_pos = df.columns.get_loc(result_column)

    processed = set()
    if resume:
        # Ignore incomplete rows; if a position was recorded more than once,
        # the most recent line wins.
        recorded = (
            df_checkpoint
            .dropna(subset=['index', result_column])
            .drop_duplicates(subset='index', keep='last')
        )
        for pos, value in zip(recorded['index'], recorded[result_column]):
            pos = int(pos)
            if 0 <= pos < n_rows:
                df.iat[pos, result_pos] = value
                processed.add(pos)

    unprocessed_indices = [pos for pos in range(n_rows) if pos not in processed]

    # A fresh run (or an unusable checkpoint) rewrites the file with a header
    # on the first row; a resumed run must only ever append.
    write_header = not resume
    for idx in tqdm(unprocessed_indices, desc="Processing"):
        text = df.iat[idx, text_pos]
        result = classify_text(text)
        df.iat[idx, result_pos] = result

        checkpoint_row = pd.DataFrame({
            'index': [idx],
            text_column: [text],
            result_column: [result]
        })
        checkpoint_row.to_csv(
            checkpoint_file,
            mode='w' if write_header else 'a',
            header=write_header,
            index=False,
        )
        write_header = False

    return df


# Classify the "Texte" column, writing a per-row checkpoint as it goes.
df_chambre_int = process_dataframe_with_checkpoint(
    df_chambre_int,
    "Texte",
    "is_international_politics",
    checkpoint_file="chambre_int_checkpoint.csv",
)

# Save the fully labelled frame (without the pandas index column).
df_chambre_int.to_csv("chambre_int_processed.csv", index=False)



In [None]:
# Label distribution, then a peek at the labelled frame.
# ("\A" was an invalid escape sequence; a newline matches the line below.)
print("\nAbstract:")
print(df_chambre_int["is_international_politics"].value_counts())
print("\nFive first rows:")
print(df_chambre_int.head())