1️⃣ Module Core - core_modules.py

In [16]:
%%writefile core_modules.py
# core_modules.py - Configuration optimisée
from dataclasses import dataclass
from typing import Dict, Any

@dataclass
class ClimateConfig:
    """Central configuration: fine-tuning hyper-parameters plus Q&A settings.

    Every field has a default, so ``ClimateConfig()`` is directly usable and
    individual values can be overridden by keyword.
    """
    # --- Sequence-classification fine-tuning (read by ModelManager) ---
    model_name: str = "distilbert-base-uncased"  # identifier passed to from_pretrained
    max_length: int = 128  # tokenizer truncation length
    batch_size: int = 16  # per-device train batch size (evaluation uses twice this)
    epochs: int = 3  # number of training epochs
    learning_rate: float = 2e-5
    lora_r: int = 8  # forwarded to LoraConfig(r=...)
    lora_alpha: int = 16  # forwarded to LoraConfig(lora_alpha=...)
    output_dir: str = "outputs/final_model"  # training output directory

    # Q&A configuration
    # NOTE(review): these three fields are not read by any code visible here;
    # presumably they are consumed by the application layer - TODO confirm.
    qa_model: str = "all-MiniLM-L6-v2"
    similarity_threshold: float = 0.3
    max_results: int = 5

Overwriting core_modules.py


2️⃣ Module Data Processing - data_modules.py

In [17]:
%%writefile data_modules.py
import pandas as pd
from datasets import Dataset
from sklearn.model_selection import train_test_split
import numpy as np
import re
from typing import Tuple

class DataProcessor:
    """Turn a raw labelled DataFrame into train/validation/test datasets.

    Attributes populated by :meth:`prepare_datasets`:
        text_col, label_col: names of the source columns that were used.
        label_mapping: label string -> integer id (ids follow sorted label order).
        reverse_label_mapping: integer id -> label string.
    """

    # Column-name fragments used to guess which columns hold text and labels.
    TEXT_KEYWORDS = ('text', 'content', 'message', 'comment', 'body', 'description', 'self_text')
    LABEL_KEYWORDS = ('label', 'sentiment', 'category', 'class', 'target', 'comment_sentiment')
    # Cleaned texts shorter than this many characters are discarded.
    MIN_TEXT_LENGTH = 10

    _HTML_ENTITIES = {'&gt;': '>', '&lt;': '<', '&amp;': '&'}
    _PLACEHOLDERS = ('nan', 'none', '', 'null')

    def __init__(self):
        self.text_col = None
        self.label_col = None
        self.label_mapping = {}
        self.reverse_label_mapping = {}

    def detect_columns(self, df: pd.DataFrame) -> Tuple[str, str]:
        """Guess the text and label columns of *df* from their names.

        The label column is the first one whose name contains a label keyword;
        the text column is the first *other* column whose name contains a text
        keyword. Without a keyword hit, the text falls back to the first
        object-dtype column and the label to the last remaining column.

        Returns:
            ``(text_col, label_col)`` - always two different columns.

        Raises:
            ValueError: if two distinct columns cannot be found.
        """
        columns = list(df.columns)

        def matches(column, keywords) -> bool:
            name = str(column).lower()
            return any(keyword in name for keyword in keywords)

        label_col = next((c for c in columns if matches(c, self.LABEL_KEYWORDS)), None)
        # Exclude the label column: a name such as "comment_sentiment" matches
        # both keyword lists and must not be used for text and label at once.
        text_col = next(
            (c for c in columns if c != label_col and matches(c, self.TEXT_KEYWORDS)),
            None,
        )

        if text_col is None:
            object_columns = [
                c for c in df.select_dtypes(include=['object']).columns if c != label_col
            ]
            if not object_columns:
                raise ValueError("No text column could be detected in the DataFrame.")
            text_col = object_columns[0]

        if label_col is None:
            remaining = [c for c in columns if c != text_col]
            if not remaining:
                raise ValueError("No label column could be detected in the DataFrame.")
            label_col = remaining[-1]

        return text_col, label_col

    def clean_text(self, text: str) -> "str | None":
        """Normalise one raw text value.

        Unescapes ``&gt;``, ``&lt;`` and ``&amp;``, removes ``http...`` URLs and
        collapses whitespace runs into single spaces.

        Returns:
            The cleaned text, or ``None`` when the value is missing, is a
            placeholder (``nan``/``none``/``null``/empty) or is empty after cleaning.
        """
        if pd.isna(text):
            return None
        text = str(text).strip()
        if text.lower() in self._PLACEHOLDERS:
            return None
        text = re.sub(r'&gt;|&lt;|&amp;', lambda m: self._HTML_ENTITIES[m.group()], text)
        text = re.sub(r'http\S+', '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        return text or None

    def _build_clean_frame(self, df: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Select, clean, subsample and label-encode the rows of *df*.

        Returns a frame with the columns ``text``, ``label`` and ``label_id`` and
        updates ``text_col``, ``label_col`` and both label mappings.
        """
        self.text_col, self.label_col = self.detect_columns(df)
        frame = df[[self.text_col, self.label_col]].copy()
        frame.columns = ['text', 'label']
        frame['text'] = frame['text'].apply(self.clean_text)

        # Drop missing values BEFORE casting labels to str: casting first would
        # turn NaN/None into the literal strings "nan"/"None" and create a
        # spurious class out of unlabeled rows.
        frame = frame.dropna(subset=['text', 'label']).reset_index(drop=True)
        frame['label'] = frame['label'].astype(str)
        # .copy() so the column assignment below does not target a slice view.
        frame = frame[frame['text'].str.len() >= self.MIN_TEXT_LENGTH].copy()

        if len(frame) > sample_size:
            frame = frame.sample(n=sample_size, random_state=42)

        unique_labels = sorted(frame['label'].unique())
        self.label_mapping = {str(label): idx for idx, label in enumerate(unique_labels)}
        self.reverse_label_mapping = {idx: label for label, idx in self.label_mapping.items()}
        frame['label_id'] = frame['label'].map(self.label_mapping).astype(int)
        return frame

    def prepare_datasets(self, df: pd.DataFrame, sample_size: int = 8000) -> "Tuple[Dataset, Dataset, Dataset]":
        """Clean *df* and split it 60/20/20 into train, validation and test sets.

        Args:
            df: raw data; text and label columns are auto-detected.
            sample_size: maximum number of cleaned rows kept (random sample, seed 42).

        Returns:
            ``(train, validation, test)`` datasets built from the ``text`` and
            ``label_id`` columns. Both splits are stratified on ``label_id``.

        Raises:
            ValueError: if no row survives cleaning or the columns cannot be detected.
        """
        frame = self._build_clean_frame(df, sample_size)
        if frame.empty:
            raise ValueError("❌ Aucune donnée valide après nettoyage.")

        train_df, holdout_df = train_test_split(
            frame, test_size=0.4, stratify=frame['label_id'], random_state=42
        )
        val_df, test_df = train_test_split(
            holdout_df, test_size=0.5, stratify=holdout_df['label_id'], random_state=42
        )

        columns = ['text', 'label_id']
        return (
            Dataset.from_pandas(train_df[columns]),
            Dataset.from_pandas(val_df[columns]),
            Dataset.from_pandas(test_df[columns]),
        )

Overwriting data_modules.py


3️⃣ Module Modèle - model_modules.py

In [18]:
%%writefile model_modules.py
import os
import logging
import warnings
import torch
import numpy as np
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
    EarlyStoppingCallback,
)
from peft import LoraConfig, get_peft_model, TaskType
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

logging.getLogger("transformers").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")

class ModelManager:
    """Builds the tokenizer, the LoRA-wrapped classifier and the ``Trainer``.

    ``config`` must expose the attributes read below: ``model_name``,
    ``max_length``, ``batch_size``, ``epochs``, ``learning_rate``, ``lora_r``,
    ``lora_alpha`` and ``output_dir``.
    """

    def __init__(self, config):
        # Tokenizer and model are created lazily by the setup_* methods.
        self.tokenizer = self.peft_model = None
        self.config = config

    def setup_tokenizer(self):
        """Load the tokenizer for ``config.model_name``, store it and return it."""
        self.tokenizer = tok = AutoTokenizer.from_pretrained(self.config.model_name)
        # Reuse the EOS token for padding when the tokenizer defines no pad token.
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        return tok

    def setup_model(self, num_labels: int):
        """Load the base classifier, wrap it with LoRA adapters and return it.

        Args:
            num_labels: number of output classes of the classification head.
        """
        cfg = self.config
        classifier = AutoModelForSequenceClassification.from_pretrained(
            cfg.model_name,
            num_labels=num_labels,
            torch_dtype=torch.float32,
            problem_type="single_label_classification",
        )
        adapter_settings = {
            "task_type": TaskType.SEQ_CLS,
            "r": cfg.lora_r,
            "lora_alpha": cfg.lora_alpha,
            "lora_dropout": 0.1,
            "target_modules": ["q_lin", "v_lin"],
            "bias": "none",
        }
        self.peft_model = get_peft_model(classifier, LoraConfig(**adapter_settings))
        return self.peft_model

    def tokenize_function(self, examples):
        """Tokenize ``examples["text"]`` with truncation and no padding."""
        texts = examples["text"]
        # Padding is deferred to the data collator used by the trainer.
        return self.tokenizer(
            texts, truncation=True, padding=False, max_length=self.config.max_length
        )

    def compute_metrics(self, eval_pred):
        """Return accuracy plus weighted F1 / precision / recall.

        Args:
            eval_pred: ``(predictions, labels)`` pair; the predicted class is the
                argmax of ``predictions`` along axis 1.
        """
        logits, labels = eval_pred
        predicted = np.argmax(logits, axis=1)
        weighted = {"average": "weighted", "zero_division": 0}
        metrics = {"accuracy": accuracy_score(labels, predicted)}
        for key, scorer in (
            ("f1_weighted", f1_score),
            ("precision", precision_score),
            ("recall", recall_score),
        ):
            metrics[key] = scorer(labels, predicted, **weighted)
        return metrics

    def setup_training_args(self):
        """Create ``config.output_dir`` if needed and return the ``TrainingArguments``."""
        cfg = self.config
        os.makedirs(cfg.output_dir, exist_ok=True)

        optimisation = {
            "num_train_epochs": cfg.epochs,
            "per_device_train_batch_size": cfg.batch_size,
            "per_device_eval_batch_size": cfg.batch_size * 2,
            "learning_rate": cfg.learning_rate,
            "warmup_steps": 200,
            "weight_decay": 0.01,
        }
        schedule = {
            "eval_strategy": "epoch",
            "save_strategy": "epoch",
            "logging_strategy": "steps",
            "logging_steps": 50,
            "save_steps": 500,
            "save_total_limit": 2,
        }
        model_selection = {
            "load_best_model_at_end": True,
            "metric_for_best_model": "eval_accuracy",
            "greater_is_better": True,
        }
        # Mixed precision is switched off everywhere (training and evaluation).
        precision = {
            "fp16": False,
            "bf16": False,
            "fp16_full_eval": False,
            "bf16_full_eval": False,
        }
        runtime = {
            "report_to": "none",
            "remove_unused_columns": False,
            "dataloader_pin_memory": False,
        }
        return TrainingArguments(
            output_dir=cfg.output_dir,
            **optimisation,
            **schedule,
            **model_selection,
            **precision,
            **runtime,
        )

    def setup_trainer(self, train_dataset, val_dataset):
        """Assemble a ``Trainer`` for the stored model with early stopping (patience 3)."""
        training_args = self.setup_training_args()
        collator = DataCollatorWithPadding(tokenizer=self.tokenizer)
        early_stopping = EarlyStoppingCallback(early_stopping_patience=3)
        return Trainer(
            model=self.peft_model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            tokenizer=self.tokenizer,
            compute_metrics=self.compute_metrics,
            data_collator=collator,
            callbacks=[early_stopping],
        )

Overwriting model_modules.py


4️⃣ Module Visualisation - visualization_modules.py

In [19]:
%%writefile visualization_modules.py
# visualization_modules.py
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import streamlit as st
import numpy as np
import os
import json
from sklearn.metrics import classification_report, confusion_matrix
from collections import Counter
import re
from typing import List, Dict

# --- NLTK / BLEU / ROUGE ---
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer

# Téléchargement silencieux des ressources NLTK
nltk.download('punkt_tab', quiet=True)
nltk.download('punkt', quiet=True)

plt.style.use('default')

class VisualizationManager:
    """Streamlit/matplotlib visualisations and BLEU/ROUGE helpers for Climate Analyzer.

    The ``plot_*`` / ``show_*`` methods render through ``st.pyplot`` and report
    failures with ``st.error`` instead of raising. Every figure is closed after
    rendering so repeated Streamlit reruns do not accumulate open figures.
    """

    # --------------------------------------------------
    # Internal BLEU / ROUGE tools
    # --------------------------------------------------
    _smoothie = SmoothingFunction().method4
    _rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)

    @staticmethod
    def _bleu(ref: str, hyp: str) -> float:
        """Return the smoothed sentence BLEU of *hyp* against *ref* (both lower-cased)."""
        ref_tok = nltk.word_tokenize(ref.lower())
        hyp_tok = nltk.word_tokenize(hyp.lower())
        return sentence_bleu([ref_tok], hyp_tok, smoothing_function=VisualizationManager._smoothie)

    @staticmethod
    def _rouge_score(ref: str, hyp: str) -> dict:
        """Return the ROUGE-1 and ROUGE-L F-measures under keys ``rouge-1`` / ``rouge-l``."""
        scores = VisualizationManager._rouge_scorer.score(ref.lower(), hyp.lower())
        return {'rouge-1': scores['rouge1'].fmeasure,
                'rouge-l': scores['rougeL'].fmeasure}

    @staticmethod
    def _text_scores(ref: str, hyp: str) -> tuple:
        """Return ``(bleu, rouge1, rougeL)`` for one pair, scoring ROUGE only once."""
        rouge = VisualizationManager._rouge_score(ref, hyp)
        return VisualizationManager._bleu(ref, hyp), rouge['rouge-1'], rouge['rouge-l']

    # --------------------------------------------------
    # 1) Training curves
    # --------------------------------------------------
    @staticmethod
    def plot_training_curves(log_dir: str = "outputs/final_model"):
        """Plot loss, accuracy and F1 curves from ``<log_dir>/trainer_state.json``.

        Entries of ``log_history`` containing ``eval_loss`` feed the evaluation
        curves; entries containing ``loss`` (the periodic training logs) feed the
        training-loss curve. Shows a Streamlit warning when the file or its
        history is missing.
        """
        try:
            log_file = os.path.join(log_dir, "trainer_state.json")
            if not os.path.exists(log_file):
                st.warning("📄 Aucun log d'entraînement trouvé.")
                return

            with open(log_file, 'r', encoding='utf-8') as f:
                logs = json.load(f)

            history = logs.get('log_history', [])
            if not history:
                st.warning("📉 Aucune donnée d'historique trouvée.")
                return

            train_epochs, train_loss = [], []
            eval_epochs, eval_loss, eval_acc, eval_f1 = [], [], [], []

            for entry in history:
                if 'eval_loss' in entry:
                    eval_epochs.append(entry.get('epoch', 0))
                    eval_loss.append(entry.get('eval_loss', 0))
                    eval_acc.append(entry.get('eval_accuracy', 0))
                    eval_f1.append(entry.get('eval_f1_weighted', 0))
                elif 'loss' in entry:
                    # Periodic training logs store the running loss under "loss";
                    # "train_loss" only appears once, in the final summary entry.
                    train_epochs.append(entry.get('epoch', 0))
                    train_loss.append(entry['loss'])

            fig, axes = plt.subplots(2, 2, figsize=(15, 10))
            try:
                fig.suptitle("📈 Évolution de l'entraînement", fontsize=16)

                loss_ax = axes[0, 0]
                if train_loss:
                    # Each training point is drawn at the epoch it was logged at.
                    loss_ax.plot(train_epochs, train_loss, 'b-', label='Train Loss', alpha=0.7)
                if eval_loss:
                    loss_ax.plot(eval_epochs, eval_loss, 'r-o', label='Eval Loss', markersize=4)
                if train_loss or eval_loss:
                    loss_ax.set_title('Loss Evolution')
                    loss_ax.legend()
                    loss_ax.grid(True, alpha=0.3)

                if eval_acc:
                    axes[0, 1].plot(eval_epochs, eval_acc, 'g-o', label='Accuracy', markersize=4)
                    axes[0, 1].set_title('Accuracy Evolution')
                    axes[0, 1].legend()
                    axes[0, 1].grid(True, alpha=0.3)

                if eval_f1:
                    axes[1, 0].plot(eval_epochs, eval_f1, 'm-o', label='F1-Weighted', markersize=4)
                    axes[1, 0].set_title('F1-Score Evolution')
                    axes[1, 0].legend()
                    axes[1, 0].grid(True, alpha=0.3)

                if eval_acc and eval_f1:
                    final_metrics = ['Accuracy', 'F1-Score']
                    final_values = [eval_acc[-1], eval_f1[-1]]
                    bars = axes[1, 1].bar(final_metrics, final_values, color=['green', 'purple'], alpha=0.7)
                    axes[1, 1].set_title('Final Metrics')
                    axes[1, 1].set_ylim(0, 1)
                    for bar, value in zip(bars, final_values):
                        axes[1, 1].text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                                        f'{value:.3f}', ha='center', va='bottom')

                plt.tight_layout()
                st.pyplot(fig)
            finally:
                plt.close(fig)
        except Exception as e:
            st.error(f"❌ Erreur lors de l'affichage des courbes : {e}")

    # --------------------------------------------------
    # 2) Confusion matrix
    # --------------------------------------------------
    @staticmethod
    def show_confusion_matrix(trainer, test_dataset, label_names: List[str]):
        """Show raw and row-normalised confusion matrices plus a classification report.

        Args:
            trainer: object whose ``predict(test_dataset)`` result exposes
                ``predictions`` and ``label_ids``.
            test_dataset: dataset forwarded to ``trainer.predict``.
            label_names: class names; position ``i`` names class id ``i``.
        """
        try:
            predictions_output = trainer.predict(test_dataset)
            predictions = predictions_output.predictions.argmax(axis=1)
            true_labels = predictions_output.label_ids

            # Pin the class ids so the matrix always has one row/column per name,
            # even when a class is absent from the test set.
            class_ids = list(range(len(label_names))) if label_names else None
            cm = confusion_matrix(true_labels, predictions, labels=class_ids)

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
            try:
                sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                            xticklabels=label_names, yticklabels=label_names, ax=ax1)
                ax1.set_title("Matrice de confusion")
                ax1.set_xlabel("Prédictions")
                ax1.set_ylabel("Vraies valeurs")

                # Rows without any true sample stay at 0 instead of becoming NaN.
                row_totals = cm.sum(axis=1, keepdims=True)
                cm_normalized = np.divide(cm.astype(float), row_totals,
                                          out=np.zeros(cm.shape, dtype=float),
                                          where=row_totals != 0)
                sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
                            xticklabels=label_names, yticklabels=label_names, ax=ax2)
                ax2.set_title("Matrice de confusion (normalisée)")
                plt.tight_layout()
                st.pyplot(fig)
            finally:
                plt.close(fig)

            report = classification_report(true_labels, predictions,
                                           labels=class_ids,
                                           target_names=label_names,
                                           output_dict=True, zero_division=0)
            report_df = pd.DataFrame(report).transpose()
            st.subheader("📊 Rapport de classification")
            st.dataframe(report_df.round(3))
        except Exception as e:
            st.error(f"❌ Erreur lors de l'affichage de la matrice de confusion : {e}")

    # --------------------------------------------------
    # 3) Class distribution
    # --------------------------------------------------
    @staticmethod
    def plot_class_distribution(labels, label_names: List[str] = None, title: str = "Distribution des classes"):
        """Draw a bar chart of the number of samples per class id.

        Args:
            labels: iterable of class ids (anything accepted by ``int``); objects
                with a ``tolist`` method are converted first.
            label_names: optional names indexed by class id; ids without a name
                are shown as ``Classe <id>``.
            title: chart title.
        """
        try:
            if hasattr(labels, 'tolist'):
                labels = labels.tolist()
            label_counts = Counter(int(x) for x in labels)
            if not label_counts:
                # Nothing to plot; max() on empty sequences below would fail.
                st.info("Aucune étiquette à afficher")
                return

            class_ids = sorted(label_counts)
            counts = [label_counts[i] for i in class_ids]
            names = label_names or []
            x_labels = [names[i] if 0 <= i < len(names) else f"Classe {i}" for i in class_ids]

            fig, ax = plt.subplots(figsize=(10, 6))
            try:
                bars = ax.bar(range(len(x_labels)), counts,
                              color=plt.cm.Set3(np.linspace(0, 1, len(x_labels))))
                ax.set_title(title, fontsize=14, fontweight='bold')
                ax.set_xlabel("Classes")
                ax.set_ylabel("Nombre d'échantillons")
                ax.set_xticks(range(len(x_labels)))
                # Tilt long tick labels so they do not overlap.
                ax.set_xticklabels(x_labels, rotation=45 if max(map(len, x_labels)) > 10 else 0)

                for bar, count in zip(bars, counts):
                    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + max(counts) * 0.01,
                            str(count), ha='center', va='bottom')
                total = sum(counts)
                ax.text(0.02, 0.98, f"Total: {total}\nClasses: {len(x_labels)}",
                        transform=ax.transAxes, va='top', ha='left',
                        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.7))
                plt.tight_layout()
                st.pyplot(fig)
            finally:
                plt.close(fig)
        except Exception as e:
            st.error(f"❌ Erreur lors de l'affichage de la distribution : {e}")

    # --------------------------------------------------
    # 4) Q&A result analysis
    # --------------------------------------------------
    @staticmethod
    def plot_qa_results_analysis(qa_results: List[Dict], question: str):
        """Plot score statistics for a list of Q&A results.

        Args:
            qa_results: dicts providing at least ``score``, ``rank`` and ``text``.
            question: the query, shown (truncated to 50 characters) in the title.
        """
        if not qa_results:
            st.info("Aucun résultat à analyser")
            return
        try:
            scores = [r['score'] for r in qa_results]
            ranks = [r['rank'] for r in qa_results]

            # Only add an ellipsis when the question was actually truncated.
            shown_question = question if len(question) <= 50 else f"{question[:50]}..."

            fig, axes = plt.subplots(2, 2, figsize=(15, 10))
            try:
                fig.suptitle(f"Analyse des résultats pour: '{shown_question}'", fontsize=14)

                mean_score = np.mean(scores)
                axes[0, 0].hist(scores, bins=min(10, len(scores)), alpha=0.7, color='skyblue', edgecolor='black')
                axes[0, 0].set_title("Distribution des scores de similarité")
                axes[0, 0].axvline(mean_score, color='red', linestyle='--', label=f'Moyenne: {mean_score:.3f}')
                axes[0, 0].legend()

                axes[0, 1].bar(ranks, scores, color='lightcoral', alpha=0.7)
                axes[0, 1].set_title("Scores par rang")
                axes[0, 1].set_xlabel("Rang")

                text_lengths = [len(r['text']) for r in qa_results]
                axes[1, 0].scatter(text_lengths, scores, alpha=0.6, color='green')
                axes[1, 0].set_title("Score vs Longueur du texte")
                axes[1, 0].set_xlabel("Longueur du texte")

                # The first five results in the order they were given.
                top_scores = scores[:5]
                top_ranks = ranks[:5]
                axes[1, 1].barh(range(len(top_scores)), top_scores, color='purple', alpha=0.7)
                axes[1, 1].set_title("Top 5 des scores")
                axes[1, 1].set_yticks(range(len(top_scores)))
                axes[1, 1].set_yticklabels([f"Rang {r}" for r in top_ranks])

                plt.tight_layout()
                st.pyplot(fig)
            finally:
                plt.close(fig)

            st.subheader("📈 Statistiques détaillées")
            stats_df = pd.DataFrame({
                "Métrique": ["Score moyen", "Score médian", "Score max", "Score min", "Écart-type"],
                "Valeur": [np.mean(scores), np.median(scores), np.max(scores), np.min(scores), np.std(scores)]
            })
            st.dataframe(stats_df.round(4))
        except Exception as e:
            st.error(f"❌ Erreur lors de l'analyse des résultats Q&A : {e}")

    # --------------------------------------------------
    # 5) BLEU / ROUGE public methods
    # --------------------------------------------------
    def calculate_bleu_score(self, reference: str, candidate: str) -> float:
        """Return the BLEU score of *candidate* against *reference*."""
        return self._bleu(reference, candidate)

    def calculate_rouge_score(self, reference: str, candidate: str) -> dict:
        """Return the ROUGE-1 / ROUGE-L F-measures of *candidate* against *reference*."""
        return self._rouge_score(reference, candidate)

    def visualize_bleu_rouge_scores(self, qa_results, references):
        """Plot BLEU, ROUGE-1 and ROUGE-L for each ``(reference, result)`` pair.

        Pairs are formed with ``zip``, so surplus items on either side are ignored.
        Each result must provide a ``text`` entry.
        """
        scored = [self._text_scores(ref, res['text']) for ref, res in zip(references, qa_results)]
        if not scored:
            st.info("Aucun résultat à analyser")
            return

        bleus, r1s, rls = (list(column) for column in zip(*scored))
        x = np.arange(1, len(bleus) + 1)
        width = 0.25

        fig, ax = plt.subplots(figsize=(10, 4))
        try:
            ax.bar(x - width, bleus, width, label='BLEU')
            ax.bar(x, r1s, width, label='ROUGE-1')
            ax.bar(x + width, rls, width, label='ROUGE-L')
            ax.set_xticks(x)
            ax.set_xlabel('Rang')
            ax.set_ylabel('Score')
            ax.set_title('BLEU & ROUGE vs références')
            ax.legend()
            fig.tight_layout()
            st.pyplot(fig)
        finally:
            plt.close(fig)

    def evaluate_qa_performance(self, qa_module, questions, references):
        """Show average BLEU / ROUGE of the top answer for each question.

        Args:
            qa_module: object providing ``query_with_fallback(question, top_k=1)``
                that returns a list of dicts with a ``text`` entry.
            questions: questions to ask.
            references: expected answers, paired with *questions* by position.

        Questions without any answer are skipped; when none is answered a
        warning is shown instead of ``nan`` metrics.
        """
        scored = []
        for q, ref in zip(questions, references):
            res = qa_module.query_with_fallback(q, top_k=1)
            if res:
                scored.append(self._text_scores(ref, res[0]['text']))

        st.write("### 📊 Global Q-A metrics")
        if not scored:
            st.warning("Aucune réponse trouvée : métriques indisponibles.")
            return

        avg_bleu, avg_r1, avg_rl = np.mean(scored, axis=0)
        col1, col2, col3 = st.columns(3)
        col1.metric("Avg BLEU", f"{avg_bleu:.4f}")
        col2.metric("Avg ROUGE-1", f"{avg_r1:.4f}")
        col3.metric("Avg ROUGE-L", f"{avg_rl:.4f}")

Overwriting visualization_modules.py


5️⃣ Module Q&A - qa_modules.py

In [20]:
%%writefile qa_modules.py
# qa_modules.py - Version corrigée et optimisée
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
import re
from typing import List, Dict, Any, Optional
import logging
from functools import lru_cache
import hashlib
import os

class OptimizedQAModule:
    """Semantic and keyword retrieval over a built-in knowledge base and an optional corpus.

    Texts are embedded with a sentence-transformers model and searched through
    inner-product FAISS indexes (embeddings are normalised at encoding time).
    Query results and query embeddings are cached per instance.
    """

    # Tokeniser shared by the query, the texts and the keywords in keyword_search.
    _WORD_RE = re.compile(r'\b\w+\b')
    # Maximum number of query embeddings kept per instance.
    _EMBEDDING_CACHE_SIZE = 100

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2', similarity_threshold: float = 0.3):
        """Load the embedding model on CPU and index the built-in knowledge base.

        Args:
            model_name: sentence-transformers model identifier.
            similarity_threshold: semantic hits must score strictly above this value.

        Raises:
            Exception: any initialisation error is logged and re-raised.
        """
        try:
            self.model = SentenceTransformer(model_name, device='cpu')
            self.similarity_threshold = similarity_threshold

            # Training corpus (filled by fit()).
            self.corpus_texts: List[str] = []
            self.corpus_labels: List[int] = []
            self.corpus_embeddings: Optional[np.ndarray] = None
            self.corpus_index: Optional[faiss.Index] = None

            # Knowledge base.
            self.knowledge_base: List[Dict[str, Any]] = []
            self.kb_embeddings: Optional[np.ndarray] = None
            self.kb_index: Optional[faiss.Index] = None

            # Caches: final results per (question, mode, top_k) and query embeddings.
            self._query_cache: Dict[Any, List[Dict]] = {}
            self._embedding_cache: Dict[str, np.ndarray] = {}

            self._setup_knowledge_base()
            logging.info("✅ Module Q&A initialisé avec succès")

        except Exception as e:
            logging.error(f"❌ Erreur initialisation Q&A: {e}")
            raise

    @classmethod
    def _tokenize(cls, text) -> set:
        """Return the set of lower-cased word tokens of *text*."""
        return set(cls._WORD_RE.findall(str(text).lower()))

    @staticmethod
    def _merge_ranked(items, top_k: int) -> List[Dict[str, Any]]:
        """Deduplicate *items* by text, sort by descending score and re-number ranks.

        Items are copied, so the inputs are left untouched. Ranks are rewritten
        as 1..n because ranks coming from different sources would collide.
        """
        seen = set()
        unique = []
        for item in items:
            if item["text"] not in seen:
                seen.add(item["text"])
                unique.append(dict(item))
        unique.sort(key=lambda r: r["score"], reverse=True)
        unique = unique[:top_k]
        for rank, item in enumerate(unique, 1):
            item["rank"] = rank
        return unique

    def _setup_knowledge_base(self):
        """Install the built-in knowledge entries and build their index."""
        self.knowledge_base = [
            {
                "text": "Le réchauffement climatique est principalement causé par les émissions de CO2 humaines.",
                "category": "causes",
                "keywords": ["CO2", "émissions", "humaines", "causes"],
                "weight": 1.0
            },
            {
                "text": "Les énergies renouvelables (solaire, éolien, hydro) réduisent drastiquement les émissions.",
                "category": "solutions",
                "keywords": ["renouvelables", "solaire", "éolien", "hydro"],
                "weight": 1.2
            },
            {
                "text": "La déforestation est responsable de 15% des émissions mondiales de CO2.",
                "category": "causes",
                "keywords": ["déforestation", "forêts", "15%"],
                "weight": 0.9
            },
            {
                "text": "Le transport représente 24% des émissions mondiales de gaz à effet de serre.",
                "category": "secteurs",
                "keywords": ["transport", "24%", "véhicules"],
                "weight": 1.1
            },
            {
                "text": "L'isolation thermique peut réduire la consommation énergétique jusqu'à 50%.",
                "category": "solutions",
                "keywords": ["isolation", "thermique", "50%"],
                "weight": 1.0
            }
        ]
        self._rebuild_kb_index()

    def _rebuild_kb_index(self) -> bool:
        """Re-encode the knowledge base and rebuild its FAISS index.

        Returns:
            ``True`` on success. On failure the error is logged, the previous
            embeddings and index are kept untouched and ``False`` is returned.
        """
        if not self.knowledge_base:
            return False

        try:
            texts = [item["text"] for item in self.knowledge_base]
            embeddings = self.model.encode(
                texts,
                normalize_embeddings=True,
                show_progress_bar=False,
                convert_to_numpy=True
            ).astype(np.float32)
            index = faiss.IndexFlatIP(embeddings.shape[1])
            index.add(embeddings)
        except Exception as e:
            logging.error(f"❌ Erreur reconstruction index: {e}")
            return False

        # Publish both together so embeddings and index never disagree.
        self.kb_embeddings = embeddings
        self.kb_index = index
        return True

    def fit(self, dataset: List[Dict[str, Any]]) -> bool:
        """Index a training corpus for semantic and keyword search.

        Args:
            dataset: iterable of mappings with a ``text`` entry and an optional
                ``label_id`` entry (defaults to 0).

        Returns:
            ``True`` when the corpus was indexed; ``False`` when *dataset* is
            empty or indexing failed (the previous corpus is then kept as is).
        """
        try:
            if not dataset:
                logging.warning("⚠️ Dataset vide, rien à indexer")
                return False

            texts = [d["text"] for d in dataset]
            labels = [d.get("label_id", 0) for d in dataset]

            embeddings = self.model.encode(
                texts,
                batch_size=32,
                normalize_embeddings=True,
                show_progress_bar=False,
                convert_to_numpy=True
            ).astype(np.float32)

            index = faiss.IndexFlatIP(embeddings.shape[1])
            index.add(embeddings)

            # Swap everything in only once all steps succeeded, so texts, labels
            # and index always describe the same corpus.
            self.corpus_texts = texts
            self.corpus_labels = labels
            self.corpus_embeddings = embeddings
            self.corpus_index = index

            # Cached answers may refer to the previous corpus.
            self._query_cache.clear()

            logging.info(f"✅ {len(dataset)} documents indexés avec succès")
            return True

        except Exception as e:
            logging.error(f"❌ Erreur indexation: {e}")
            return False

    def _get_query_embedding(self, query: str) -> np.ndarray:
        """Return the normalised embedding of *query*, cached per instance.

        The cache holds at most ``_EMBEDDING_CACHE_SIZE`` entries; the oldest
        inserted entry is evicted first.
        """
        cached = self._embedding_cache.get(query)
        if cached is None:
            cached = self.model.encode([query], normalize_embeddings=True)[0]
            if len(self._embedding_cache) >= self._EMBEDDING_CACHE_SIZE:
                # dicts keep insertion order: the first key is the oldest one.
                self._embedding_cache.pop(next(iter(self._embedding_cache)))
            self._embedding_cache[query] = cached
        return cached

    def _search_index(self, query: str, index: "faiss.Index", texts: List[str],
                      source: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Search *index* for *query* and return the hits above the threshold.

        Args:
            query: the question.
            index: FAISS index whose rows are aligned with *texts*.
            texts: the indexed texts.
            source: ``"knowledge_base"`` or ``"training_corpus"``; selects the
                metadata attached to each hit.
            top_k: maximum number of neighbours requested.

        Returns:
            Hits as dicts (``text``, ``score``, ``rank``, ``source`` plus
            metadata); an empty list on error or when nothing qualifies.
        """
        if index is None or not texts:
            return []

        try:
            query_vector = self._get_query_embedding(query).reshape(1, -1).astype(np.float32)
            scores, indices = index.search(query_vector, min(top_k, len(texts)))

            results = []
            for rank, (score, idx) in enumerate(zip(scores[0], indices[0]), 1):
                idx = int(idx)
                # Negative ids mark "no neighbour" and must not index from the end.
                if not 0 <= idx < len(texts) or score <= self.similarity_threshold:
                    continue

                result = {
                    "text": texts[idx],
                    "score": float(score),
                    "rank": rank,
                    "source": source
                }

                if source == "knowledge_base" and idx < len(self.knowledge_base):
                    result.update({
                        "category": self.knowledge_base[idx]["category"],
                        "keywords": self.knowledge_base[idx]["keywords"]
                    })
                elif source == "training_corpus" and idx < len(self.corpus_labels):
                    result["label_id"] = int(self.corpus_labels[idx])

                results.append(result)

            return results

        except Exception as e:
            logging.error(f"❌ Erreur recherche index: {e}")
            return []

    def query_knowledge_base(self, query: str, top_k: int = 3) -> List[Dict[str, Any]]:
        """Semantic search restricted to the knowledge base."""
        if not self.knowledge_base:
            return []
        return self._search_index(query, self.kb_index,
                                  [item["text"] for item in self.knowledge_base],
                                  "knowledge_base", top_k)

    def query_corpus(self, query: str, top_k: int = 3) -> List[Dict[str, Any]]:
        """Semantic search restricted to the corpus indexed by :meth:`fit`."""
        if not self.corpus_texts:
            return []
        return self._search_index(query, self.corpus_index, self.corpus_texts,
                                  "training_corpus", top_k)

    def keyword_search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Rank knowledge-base and corpus texts by word overlap with *query*.

        Knowledge entries score ``(0.7 * text overlap + 0.3 * keyword overlap)
        * weight``; corpus texts score their text overlap. Each overlap is the
        number of shared tokens divided by the token count of the text (or of
        the keywords). Only scores above 0.1 are kept.

        Returns:
            Up to *top_k* hits sorted by descending score, ranked from 1.
        """
        if not query:
            return []

        try:
            query_words = self._tokenize(query)
            if not query_words:
                return []

            scored = []

            for item in self.knowledge_base:
                text_words = self._tokenize(item["text"])
                # Keywords go through the same tokeniser as the query; otherwise
                # mixed-case entries such as "CO2" could never match.
                keyword_words = set()
                for keyword in item["keywords"]:
                    keyword_words |= self._tokenize(keyword)

                text_score = len(query_words & text_words) / max(len(text_words), 1)
                keyword_score = len(query_words & keyword_words) / max(len(keyword_words), 1)
                combined_score = (text_score * 0.7 + keyword_score * 0.3) * item.get("weight", 1.0)

                if combined_score > 0.1:
                    scored.append({
                        "text": item["text"],
                        "category": item["category"],
                        "keywords": item["keywords"],
                        "score": combined_score,
                        "source": "knowledge_base"
                    })

            for text, label in zip(self.corpus_texts, self.corpus_labels):
                text_words = self._tokenize(text)
                score = len(query_words & text_words) / max(len(text_words), 1)

                if score > 0.1:
                    scored.append({
                        "text": text,
                        "label_id": int(label),
                        "score": score,
                        "source": "training_corpus"
                    })

            scored.sort(key=lambda x: x["score"], reverse=True)
            top = scored[:top_k]
            for rank, result in enumerate(top, 1):
                result["rank"] = rank

            return top

        except Exception as e:
            logging.error(f"❌ Erreur recherche mots-clés: {e}")
            return []

    def query_with_fallback(self, question: str, top_k: int = 5,
                            search_mode: str = "hybrid") -> List[Dict[str, Any]]:
        """Answer *question*, falling back to keyword search when needed.

        Args:
            question: the query.
            top_k: maximum number of results.
            search_mode: ``"knowledge_only"``, ``"corpus_only"``, ``"keywords"``;
                any other value selects the hybrid mode (knowledge base + corpus).

        Returns:
            A fresh list of result dicts (cached results are copied, so callers
            may modify what they receive); an empty list on error.
        """
        cache_key = (question, search_mode, top_k)
        cached = self._query_cache.get(cache_key)
        if cached is not None:
            return [dict(item) for item in cached]

        try:
            if search_mode == "knowledge_only":
                results = self.query_knowledge_base(question, top_k)
            elif search_mode == "corpus_only":
                results = self.query_corpus(question, top_k)
            elif search_mode == "keywords":
                results = self.keyword_search(question, top_k)
            else:  # hybrid
                per_source = top_k // 2 + 1
                candidates = (self.query_knowledge_base(question, per_source)
                              + self.query_corpus(question, per_source))
                results = self._merge_ranked(candidates, top_k)

            # Keyword fallback when the search found nothing convincing. It is
            # skipped in "keywords" mode, where it would repeat the same search.
            weak = not results or max(r["score"] for r in results) < self.similarity_threshold
            if weak and search_mode != "keywords":
                results = self._merge_ranked(
                    results + self.keyword_search(question, top_k), top_k
                )

            self._query_cache[cache_key] = results
            return [dict(item) for item in results]

        except Exception as e:
            logging.error(f"❌ Erreur recherche avec fallback: {e}")
            return []

    def search_by_category(self, category: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Return up to *top_k* knowledge entries of *category* (case-insensitive).

        Every hit gets the score 1.0 and a rank following knowledge-base order.
        """
        try:
            wanted = category.lower()
            matching = [it for it in self.knowledge_base if it["category"].lower() == wanted]
            return [
                {
                    "text": item["text"],
                    "category": item["category"],
                    "keywords": item["keywords"],
                    "source": "knowledge_base",
                    "rank": rank,
                    "score": 1.0
                }
                for rank, item in enumerate(matching[:top_k], 1)
            ]
        except Exception as e:
            logging.error(f"❌ Erreur recherche catégorie: {e}")
            return []

    def add_knowledge(self, text: str, category: str = "custom",
                      keywords: List[str] = None) -> bool:
        """Add a knowledge entry and re-index the knowledge base.

        Returns:
            ``True`` when the entry was added and indexed; ``False`` for empty,
            non-string or duplicate text, or when re-indexing failed (the entry
            is then removed again so texts and index stay aligned).
        """
        try:
            if not text or not isinstance(text, str):
                return False

            cleaned = text.strip()
            if not cleaned:
                return False

            if any(item["text"].strip() == cleaned for item in self.knowledge_base):
                return False

            self.knowledge_base.append({
                "text": cleaned,
                "category": category,
                "keywords": keywords or [],
                "weight": 1.0
            })

            if not self._rebuild_kb_index():
                self.knowledge_base.pop()
                return False

            self._query_cache.clear()

            logging.info(f"✅ Nouvelle connaissance ajoutée: {text[:50]}...")
            return True

        except Exception as e:
            logging.error(f"❌ Erreur ajout connaissance: {e}")
            return False

    def get_categories(self) -> List[str]:
        """Return the distinct knowledge-base categories in sorted order."""
        try:
            return sorted({item["category"] for item in self.knowledge_base})
        except (KeyError, TypeError) as e:
            logging.error(f"❌ Erreur stats: {e}")
            return []

    def get_stats(self) -> Dict[str, Any]:
        """Return sizes, categories, average text lengths and cache size."""
        try:
            return {
                "knowledge_base_size": len(self.knowledge_base),
                "training_corpus_size": len(self.corpus_texts),
                "total_documents": len(self.knowledge_base) + len(self.corpus_texts),
                "categories": self.get_categories(),
                "avg_kb_length": np.mean([len(item["text"]) for item in self.knowledge_base]) if self.knowledge_base else 0,
                "avg_corpus_length": np.mean([len(t) for t in self.corpus_texts]) if self.corpus_texts else 0,
                "cache_size": len(self._query_cache)
            }
        except Exception as e:
            logging.error(f"❌ Erreur stats: {e}")
            return {}

    def clear_cache(self):
        """Empty the query-result cache and the query-embedding cache."""
        self._query_cache.clear()
        self._embedding_cache.clear()
        logging.info("🗑️ Cache vidé")

Overwriting qa_modules.py


6️⃣ Module Knowledge Base - knowledge_modules.py

In [21]:
%%writefile knowledge_modules.py

# knowledge_modules.py
import numpy as np
from typing import List, Optional
import re

class KnowledgeBase:
    """Keyword-based knowledge base that needs no sentence-transformers.

    Documents are plain strings; retrieval ranks them by the Jaccard
    similarity between the word set of the query and of each document.
    """

    # A document must score strictly above this Jaccard value to be returned.
    RELEVANCE_THRESHOLD = 0.1
    # Compiled once instead of on every query/document.
    _WORD_PATTERN = re.compile(r'\b\w+\b')

    def __init__(self):
        self.knowledge_base = []
        self.setup_knowledge_base()

    def setup_knowledge_base(self):
        """Load the built-in seed documents, replacing any current content."""
        self.knowledge_base = [
            "Le réchauffement climatique est principalement causé par les émissions de gaz à effet de serre d'origine humaine.",
            "Les énergies renouvelables comme le solaire et l'éolien sont essentielles pour décarboner notre économie.",
            "La déforestation massive contribue significativement au changement climatique.",
            "Le secteur des transports représente environ 24% des émissions mondiales de gaz à effet de serre.",
            "L'amélioration de l'efficacité énergétique des bâtiments peut réduire jusqu'à 50% de leur consommation.",
            "L'agriculture durable et régénératrice peut séquestrer du carbone tout en produisant de la nourriture.",
            "Les océans absorbent 25% du CO2 atmosphérique mais s'acidifient, menaçant les écosystèmes marins.",
            "Les politiques de taxation du carbone incitent les entreprises à réduire leurs émissions.",
            "L'adaptation au changement climatique est aussi cruciale que l'atténuation des émissions.",
            "Les technologies de capture et stockage du carbone pourraient permettre d'atteindre la neutralité carbone."
        ]
        print("✅ Base de connaissances initialisée avec recherche par mots-clés")

    @classmethod
    def _tokenize(cls, text: str) -> set:
        """Return the set of lower-cased word tokens found in ``text``."""
        return set(cls._WORD_PATTERN.findall(text.lower()))

    def find_context(self, query: str, top_k: int = 3) -> List[str]:
        """Return the documents most similar to ``query``.

        Args:
            query: Free-text question.
            top_k: Maximum number of documents to return. ``None`` means no
                limit; zero or a negative value yields an empty list.

        Returns:
            Up to ``top_k`` documents, best first, whose Jaccard score against
            the query is above ``RELEVANCE_THRESHOLD``. An empty list is
            returned for an empty query, an empty base, or a non-text input.
        """
        if not query or not self.knowledge_base:
            return []
        # A negative slice bound would otherwise return "all but the last n".
        if top_k is not None and top_k <= 0:
            return []

        try:
            query_words = self._tokenize(query)

            scored_docs = []
            for doc in self.knowledge_base:
                doc_words = self._tokenize(doc)
                union = len(query_words | doc_words)
                if union > 0:
                    jaccard_score = len(query_words & doc_words) / union
                    scored_docs.append((doc, jaccard_score))

            # Best matches first; the sort is stable so ties keep base order.
            scored_docs.sort(key=lambda pair: pair[1], reverse=True)

            return [
                doc for doc, score in scored_docs[:top_k]
                if score > self.RELEVANCE_THRESHOLD
            ]

        except (AttributeError, TypeError) as e:
            # Raised when the query or a stored document is not a string.
            print(f"⚠️ Erreur recherche contexte: {e}")
            return []

    def add_knowledge(self, new_knowledge: str) -> bool:
        """Add a document unless it is blank or already present.

        The text is stripped before the duplicate check and before storage,
        so entries differing only by surrounding whitespace count as equal.

        Returns:
            True when the document was added, False otherwise.
        """
        text = new_knowledge.strip() if isinstance(new_knowledge, str) else ""
        if not text or text in self.knowledge_base:
            return False

        self.knowledge_base.append(text)
        print(f"✅ Nouvelle connaissance ajoutée: {text[:50]}...")
        return True

    def get_stats(self):
        """Return the document count and the mean document length (0 if empty)."""
        return {
            "total_documents": len(self.knowledge_base),
            "avg_length": np.mean([len(doc) for doc in self.knowledge_base]) if self.knowledge_base else 0,
        }

Overwriting knowledge_modules.py


update_kb_rss.py

In [22]:
%%writefile update_kb_rss.py
import feedparser
from qa_modules import OptimizedQAModule
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
import schedule
import time
import datetime
import logging
import os
import sys

# Logging configuration: INFO level, each record formatted as
# "timestamp - level - message"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# RSS feed URLs iterated by RSSUpdater.update_all_feeds
RSS_FEEDS = [
    "https://www.carbonbrief.org/feed/",
    "https://climate.nasa.gov/news/rss.xml",
    "https://unfccc.int/news/rss.xml"
]

class RSSUpdater:
    """Pull climate news from RSS feeds into a Q&A module's knowledge base."""

    def __init__(self, qa_module):
        """Store the target Q&A module and configure the text splitter.

        Args:
            qa_module: Object exposing ``add_knowledge(text=, category=,
                keywords=)`` whose return value is truthy when the text was
                actually added.
        """
        self.qa = qa_module
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=400,
            chunk_overlap=50,
            separators=["\n", ". ", "? ", "! "]
        )

    def fetch_and_process_feed(self, feed_url, max_entries=5):
        """Fetch one RSS feed and add its newest entries to the knowledge base.

        Args:
            feed_url: URL of the feed to parse.
            max_entries: Maximum number of feed entries to process.

        Returns:
            Number of text chunks successfully added (0 on any feed-level
            failure). Errors on a single entry are logged and skipped.
        """
        try:
            feed = feedparser.parse(feed_url)

            # feedparser raises the bozo flag for any well-formedness issue
            # but may still return usable entries; only give up when the
            # feed produced nothing at all.
            if feed.bozo and not feed.entries:
                logging.warning(f"⚠️ Flux RSS mal formaté: {feed_url}")
                return 0
            if feed.bozo:
                logging.warning(
                    f"⚠️ Flux RSS partiellement mal formé, traitement des entrées disponibles: {feed_url}"
                )

            processed = 0
            for entry in feed.entries[:max_entries]:
                try:
                    # Extract the fields, with fallbacks for missing ones
                    title = getattr(entry, 'title', "Sans titre")
                    summary = getattr(entry, 'summary', "")
                    link = getattr(entry, 'link', "Lien manquant")

                    content = f"{title}. {summary}".strip()
                    if len(content) < 50:  # too short to be useful
                        continue

                    doc = Document(
                        page_content=content,
                        metadata={
                            "url": link,
                            "title": title,
                            "source": "rss_feed",
                            "date": datetime.datetime.now().isoformat()
                        }
                    )

                    # Split into chunks and add each one; add_knowledge
                    # reports False for chunks it did not store.
                    for chunk in self.text_splitter.split_documents([doc]):
                        success = self.qa.add_knowledge(
                            text=chunk.page_content,
                            category="rss_news",
                            keywords=["rss", "news", "climate", "update"]
                        )
                        if success:
                            processed += 1
                            logging.info(f"✅ Ajouté: {title[:60]}...")

                except Exception as e:
                    logging.error(f"❌ Erreur traitement article: {e}")
                    continue

            return processed

        except Exception as e:
            logging.error(f"❌ Erreur flux RSS {feed_url}: {e}")
            return 0

    def update_all_feeds(self):
        """Process every URL in ``RSS_FEEDS`` and return the total added."""
        total_added = 0
        logging.info(f"🔄 Début mise à jour RSS - {datetime.datetime.now()}")

        for feed_url in RSS_FEEDS:
            added = self.fetch_and_process_feed(feed_url)
            total_added += added
            logging.info(f"📊 {feed_url}: {added} articles ajoutés")

        logging.info(f"✅ Mise à jour terminée - Total: {total_added} nouveaux articles")
        return total_added

    def start_scheduler(self):
        """Run an immediate update, then loop forever running the daily job.

        This call never returns: it blocks the calling thread in the
        scheduling loop.
        """
        # Daily job at 09:00
        schedule.every().day.at("09:00").do(self.update_all_feeds)

        # Immediate first run
        self.update_all_feeds()

        logging.info("📅 Planificateur RSS démarré - mise à jour quotidienne à 09:00")

        while True:
            schedule.run_pending()
            # Pending jobs are only checked once per hour.
            time.sleep(3600)

# Fonction utilitaire pour Streamlit
def init_rss_updater(qa_module):
    """Build and return an RSSUpdater bound to ``qa_module`` (Streamlit helper)."""
    return RSSUpdater(qa_module)

def manual_rss_update(qa_module):
    """Run a one-off update of all RSS feeds (triggered from Streamlit).

    Returns:
        The total reported by ``RSSUpdater.update_all_feeds``, or 0 when
        building the updater or running the update raises (error is logged).
    """
    try:
        added = RSSUpdater(qa_module).update_all_feeds()
    except Exception as e:
        logging.error(f"❌ Erreur mise à jour RSS: {e}")
        return 0
    return added

Overwriting update_kb_rss.py


5️⃣ Module Streamlit - streamlit_app.py

In [23]:
%%writefile streamlit_app.py
# streamlit_app_fusion.py - Version fusionnée avec optimisations et RSS
import streamlit as st
import pandas as pd
import os
import torch
import json
import sys
import numpy as np
from datasets import Dataset
from contextlib import contextmanager
from qa_modules import OptimizedQAModule  # Module optimisé

# Streamlit page setup: title, icon, wide layout, sidebar expanded at start
st.set_page_config(
    page_title="🌍 Climate Analyzer Pro",
    page_icon="🌍",
    layout="wide",
    initial_sidebar_state="expanded"
)

# ---------------------------------------------------------
# Barre de progression persistante
# ---------------------------------------------------------
@contextmanager
def st_progress(title="Progress", max_value=100):
    """Yield a Streamlit progress bar that is removed when the block exits.

    ``title`` is shown as the bar's text. ``max_value`` is accepted but not
    used by this implementation.
    """
    progress_bar = st.progress(0, text=title)
    try:
        yield progress_bar
    finally:
        # Clear the widget whether or not the body raised
        progress_bar.empty()

# Cache optimisé pour le module Q&A
@st.cache_resource
def get_optimized_qa():
    """Create the optimized Q&A module; wrapped in ``st.cache_resource``."""
    qa_module = OptimizedQAModule()
    return qa_module

# Session-state persistence: seed every key the app reads, only if missing
_SESSION_DEFAULTS = (
    ("trainer", None),
    ("label_names", None),
    ("test_ds", None),
    ("training", False),
    ("raw_train_data", None),
)
for _key, _default in _SESSION_DEFAULTS:
    if _key not in st.session_state:
        st.session_state[_key] = _default
if "qa_module" not in st.session_state:
    # Requested only when the key is absent
    st.session_state.qa_module = get_optimized_qa()

sys.path.append('/content')
from core_modules import ClimateConfig
from data_modules import DataProcessor
from model_modules import ModelManager
from visualization_modules import VisualizationManager


class ClimateAnalyzerApp:
    def __init__(self):
        """Create the config, data/model managers, Q&A module and visualizer,
        then try to restore a saved model."""
        self.config = ClimateConfig()
        self.data_processor = DataProcessor()
        self.model_manager = ModelManager(self.config)

        # Reuse the session's Q&A module, creating it if the slot is empty
        qa_module = st.session_state.qa_module
        if qa_module is None:
            qa_module = get_optimized_qa()
            st.session_state.qa_module = qa_module
        self.qa_module = qa_module

        self.visualizer = VisualizationManager()
        self.load_saved_model()

    def load_saved_model(self):
        """Load the model saved under ``outputs/final_model`` if no trainer
        is in the session yet.

        Does nothing when a trainer already exists or when the saved
        ``config.json`` is absent. Failures are shown as a UI warning.
        """
        if st.session_state.trainer is not None:
            return
        if not os.path.exists("outputs/final_model/config.json"):
            return

        try:
            self.model_manager.setup_tokenizer()
            # Falls back to 2 labels when the processor has no mapping yet
            label_count = len(self.data_processor.label_mapping) or 2
            self.model_manager.setup_model(label_count)
            restored = self.model_manager.setup_trainer(None, None)
            restored.model = restored.model.from_pretrained("outputs/final_model")
            st.session_state.trainer = restored
            st.session_state.label_names = list(self.data_processor.label_mapping.keys())
            st.success("✅ Modèle chargé depuis le disque.")
        except Exception as e:
            st.warning(f"⚠️ Chargement impossible : {e}")

    def run(self):
        """Render the title and sidebar, then dispatch to the selected page."""
        st.title("🌍 Climate Sentiment Analyzer Pro")

        # Sidebar: Q&A statistics, cache control and navigation
        with st.sidebar:
            st.markdown("### 📊 Statistiques Système")

            # get_stats() returns an empty dict when it fails, so read the
            # counters with a default instead of indexing.
            stats = self.qa_module.get_stats()
            col1, col2 = st.columns(2)
            with col1:
                st.metric("Base de connaissances", stats.get("knowledge_base_size", 0))
            with col2:
                st.metric("Total documents", stats.get("total_documents", 0))

            # Quick controls
            if st.button("🗑️ Vider le cache Q&A"):
                self.qa_module.clear_cache()
                st.success("Cache vidé!")
                st.rerun()

            # Main navigation
            mode = st.selectbox(
                "Mode",
                ["🚀 Pipeline Complet", "❓ Q&A Avancée", "📚 Gestion des Connaissances",
                 "📰 Mise à jour RSS", "📈 Visualisations"]
            )

        # Route to the page matching the selected mode
        if mode == "🚀 Pipeline Complet":
            self.run_complete_pipeline()
        elif mode == "❓ Q&A Avancée":
            self.run_advanced_qa_interface()
        elif mode == "📚 Gestion des Connaissances":
            self.run_knowledge_management()
        elif mode == "📰 Mise à jour RSS":
            self.run_rss_integration()
        elif mode == "📈 Visualisations":
            self.run_visualizations()

    def run_complete_pipeline(self):
        """Render the training page: CSV upload, settings and launch button."""
        st.header("🚀 Pipeline Complet")

        uploaded_file = st.file_uploader("Téléchargez votre CSV", type=["csv"])
        if not uploaded_file:
            return

        df = pd.read_csv(uploaded_file)
        st.dataframe(df.head())

        # Training settings
        sample_size = st.slider("Taille échantillon", 1000, 10000, value=4000)
        self.config.epochs = st.slider("Epochs", 1, 5, value=3)

        # The launch button is disabled while the "training" flag is set
        training_in_progress = bool(st.session_state.get("training", False))
        launch_clicked = st.button(
            "🚀 Lancer l'entraînement",
            type="primary",
            disabled=training_in_progress
        )
        if not launch_clicked:
            return

        st.session_state.training = True
        try:
            self.train_pipeline(df, sample_size)
        finally:
            # Always release the flag, even if training raised
            st.session_state.training = False

    def train_pipeline(self, df: pd.DataFrame, sample_size: int):
        """Run the training flow: data prep, tokenizer, model, training, save.

        On success the trainer, label names and tokenised test split are
        stored in ``st.session_state`` and the raw training rows are indexed
        into the Q&A module. Any exception is displayed in the UI rather
        than propagated.
        """
        try:
            # Step 1/4: build the train/validation/test datasets
            with st_progress("1/4  Analyse des données …") as bar:
                train_ds, val_ds, test_ds = self.data_processor.prepare_datasets(df, sample_size)
                bar.progress(25)

            # Keep the raw training rows for Q&A indexing later on
            st.session_state.raw_train_data = [
                {"text": row["text"], "label_id": row["label_id"]}
                for row in train_ds
            ]

            # Step 2/4: tokenizer
            with st_progress("2/4  Chargement du tokenizer …") as bar:
                self.model_manager.setup_tokenizer()
                bar.progress(50)

            # Step 3/4: model, sized from the number of known labels
            with st_progress("3/4  Initialisation du modèle …") as bar:
                label_count = len(self.data_processor.label_mapping)
                self.model_manager.setup_model(label_count)
                bar.progress(75)

            # Step 4/4: tokenise a split and keep only the model inputs
            def tokenize_split(split):
                with st_progress("4/4  Tokenisation …") as bar:
                    split = split.map(
                        self.model_manager.tokenize_function,
                        batched=True,
                        desc="Tokenisation"
                    )
                    split = split.rename_column("label_id", "labels")
                    keep = {"input_ids", "attention_mask", "labels"}
                    extra_columns = [name for name in split.column_names if name not in keep]
                    for name in extra_columns:
                        split = split.remove_columns(name)
                    split.set_format(type="torch", columns=list(keep))
                    bar.progress(100)
                    return split

            train_ds_processed = tokenize_split(train_ds)
            val_ds_processed = tokenize_split(val_ds)
            test_ds_processed = tokenize_split(test_ds)

            trainer = self.model_manager.setup_trainer(train_ds_processed, val_ds_processed)

            with st.spinner("Entraînement en cours …"):
                trainer.train()

            trainer.save_model("outputs/final_model")
            trainer.state.save_to_json("outputs/final_model/trainer_state.json")

            # Index the raw training rows into the Q&A module
            if st.session_state.raw_train_data:
                self.qa_module.fit(st.session_state.raw_train_data)
                st.session_state.qa_module = self.qa_module

            st.session_state.trainer = trainer
            st.session_state.label_names = list(self.data_processor.label_mapping.keys())
            st.session_state.test_ds = test_ds_processed
            st.success("🎉 Entraînement terminé !")
            st.balloons()

        except Exception as e:
            import traceback
            st.error(f"❌ Erreur : {e}")
            st.error(f"Détail: {traceback.format_exc()}")
            st.session_state.training = False

    def run_advanced_qa_interface(self):
        """Render the advanced Q&A page: question box, options and results.

        The question input is bound to ``st.session_state["question"]`` so
        that the suggested-question buttons can fill it in.
        """
        st.header("❓ Interface Q&A Avancée")

        col1, col2 = st.columns([3, 1])

        with col1:
            question = st.text_input(
                "Posez votre question sur le climat :",
                placeholder="Ex: Quelles sont les causes du réchauffement climatique ?",
                key="question"
            )

        with col2:
            search_mode = st.selectbox(
                "Mode",
                ["hybrid", "knowledge_only", "corpus_only", "keywords"],
                format_func=lambda x: {
                    "hybrid": "🔀 Hybride",
                    "knowledge_only": "📚 Base de connaissances",
                    "corpus_only": "📊 Corpus d'entraînement",
                    "keywords": "🔍 Mots-clés"
                }[x]
            )

        # Advanced parameters, collapsed by default
        with st.expander("⚙️ Paramètres de recherche", expanded=False):
            col1, col2 = st.columns(2)
            with col1:
                top_k = st.slider("Nombre de résultats", 1, 10, value=5)
            with col2:
                show_details = st.checkbox("Afficher les détails", True)

        # Main search
        if question:
            try:
                with st.spinner("🔍 Recherche intelligente en cours..."):
                    results = self.qa_module.query_with_fallback(question, top_k, search_mode)

                self.display_qa_results(results, show_details, f"Résultats pour: '{question}'")

                # Extra analysis only when there is more than one result
                if results and len(results) > 1:
                    with st.expander("📊 Analyse des résultats", expanded=False):
                        self.visualizer.plot_qa_results_analysis(results, question)

            except Exception as e:
                st.error(f"❌ Erreur : {str(e)}")
                if st.checkbox("Afficher les détails techniques"):
                    st.exception(e)

        # Suggested questions
        st.markdown("### 💡 Questions suggérées")
        suggested_questions = [
            "Quelles sont les principales causes du réchauffement climatique ?",
            "Comment les énergies renouvelables peuvent-elles aider ?",
            "Quel est l'impact de la déforestation sur le climat ?",
            "Comment réduire les émissions de gaz à effet de serre ?"
        ]

        def _use_suggestion(text):
            # Executed as a button callback, i.e. before the rerun that
            # re-creates the keyed text input; assigning a widget's keyed
            # state after it has been instantiated in a run is rejected by
            # Streamlit.
            st.session_state["question"] = text

        cols = st.columns(2)
        for i, suggestion in enumerate(suggested_questions[:4]):
            cols[i % 2].button(
                suggestion[:50] + "..." if len(suggestion) > 50 else suggestion,
                key=f"suggestion_{i}",
                use_container_width=True,
                on_click=_use_suggestion,
                args=(suggestion,)
            )

    def display_qa_results(self, results: list, show_details: bool, title: str):
        """Render Q&A results, one expander per result.

        Args:
            results: Result dicts; each must contain ``"text"`` and may
                contain ``"source"``, ``"score"``, ``"rank"``,
                ``"category"`` and ``"keywords"``.
            show_details: When true, also show category, keywords and rank.
            title: Heading displayed above the results.
        """
        if not results:
            st.info("🔍 Aucun résultat trouvé.")
            return

        st.markdown(f"### {title}")
        st.caption(f"{len(results)} résultat(s) trouvé(s)")

        source_emojis = {
            "knowledge_base": "📚",
            "training_corpus": "📊",
            "keywords": "🔍"
        }

        for result in results:
            emoji = source_emojis.get(result.get("source"), "📄")

            # Traffic-light marker derived from the score
            score = result.get("score", 0)
            if score > 0.7:
                score_color = "🟢"
            elif score > 0.4:
                score_color = "🟡"
            else:
                score_color = "🔴"

            source_label = result.get('source', 'source').replace('_', ' ').title()
            is_top_ranked = result.get("rank", 0) == 1

            with st.expander(
                f"{emoji} Score: {score_color} {score:.3f} - {source_label}",
                expanded=is_top_ranked
            ):
                st.write("**Texte:**")
                st.write(result["text"])

                if not show_details:
                    continue

                st.divider()
                left, right = st.columns(2)

                with left:
                    if "category" in result:
                        st.caption(f"**Catégorie:** `{result['category']}`")
                    if result.get("keywords"):
                        st.caption("**Mots-clés:** " + " ".join([f"`{kw}`" for kw in result["keywords"][:3]]))

                with right:
                    st.caption(f"**Rang:** {result.get('rank', '?')}")

    def run_knowledge_management(self):
        """Render the knowledge-management page with its three tabs."""
        st.header("📚 Gestion des Connaissances")

        tabs = st.tabs(["📖 Consulter", "➕ Ajouter", "📊 Statistiques"])
        renderers = (
            self._render_knowledge_browser,
            self._render_knowledge_adder,
            self._render_knowledge_stats,
        )

        # One renderer per tab, in the same order as the tab labels
        for tab, render in zip(tabs, renderers):
            with tab:
                render()

    def _render_knowledge_browser(self):
        """Render the filterable, paginated view of the knowledge base."""
        st.subheader("📖 Consultation")

        # Filters
        col1, col2 = st.columns(2)
        with col1:
            categories = self.qa_module.get_categories()
            filter_category = st.selectbox("Filtrer par", ["Toutes"] + categories)
        with col2:
            search_text = st.text_input("Rechercher", placeholder="Mots-clés...")

        # Keep (original index, item) pairs matching both filters
        needle = search_text.lower() if search_text else ""
        knowledge_items = [
            (idx, item)
            for idx, item in enumerate(self.qa_module.knowledge_base)
            if (filter_category == "Toutes" or item["category"] == filter_category)
            and (not needle or needle in item["text"].lower())
        ]

        st.info(f"📄 {len(knowledge_items)} document(s) trouvé(s)")

        # Pagination. Ceiling division: an exact multiple of the page size
        # must not produce a trailing empty page.
        items_per_page = 5
        total_pages = max(1, (len(knowledge_items) + items_per_page - 1) // items_per_page)
        page = st.number_input("Page", min_value=1, max_value=total_pages, value=1)

        start_idx = (page - 1) * items_per_page
        end_idx = min(start_idx + items_per_page, len(knowledge_items))

        for orig_idx, item in knowledge_items[start_idx:end_idx]:
            with st.expander(f"📄 {item['category'].upper()} - {item['text'][:80]}..."):
                st.write(item["text"])
                col1, col2 = st.columns(2)
                with col1:
                    st.caption(f"ID: `{orig_idx}`")
                with col2:
                    if item["keywords"]:
                        st.caption("Keywords: " + ", ".join(item["keywords"][:3]))

    def _render_knowledge_adder(self):
        """Render the form used to add an entry to the knowledge base."""
        st.subheader("➕ Ajouter une connaissance")

        with st.form("add_knowledge_form"):
            new_text = st.text_area(
                "Texte",
                placeholder="Entrez le texte de la nouvelle connaissance...",
                height=100
            )

            left, right = st.columns(2)
            with left:
                existing = self.qa_module.get_categories()
                category_choice = st.selectbox("Catégorie", ["Nouvelle"] + existing)

                # "Nouvelle" lets the user type a category name instead
                final_category = (
                    st.text_input("Nouvelle catégorie", placeholder="technologies")
                    if category_choice == "Nouvelle"
                    else category_choice
                )

            with right:
                keywords = st.text_input(
                    "Mots-clés",
                    placeholder="tech, innovation, futur"
                )

            submitted = st.form_submit_button("✅ Ajouter", use_container_width=True)

            # Nothing to do until the form is submitted with text and category
            if not (submitted and new_text and final_category):
                return

            # Comma-separated keywords, blanks dropped
            keyword_list = [part.strip() for part in keywords.split(",") if part.strip()]

            added = self.qa_module.add_knowledge(new_text, final_category, keyword_list)
            if not added:
                st.error("❌ Erreur lors de l'ajout")
            else:
                st.success("✅ Ajouté avec succès!")
                st.rerun()

    def _render_knowledge_stats(self):
        """Render the headline metrics and the per-category bar chart."""
        st.subheader("📊 Statistiques")

        # get_stats() returns an empty dict when it fails, so every value is
        # read with a default rather than indexed.
        stats = self.qa_module.get_stats()

        col1, col2, col3, col4 = st.columns(4)
        col1.metric("Total", stats.get("total_documents", 0))
        col2.metric("KB", stats.get("knowledge_base_size", 0))
        col3.metric("Corpus", stats.get("training_corpus_size", 0))
        col4.metric("Catégories", len(stats.get("categories", [])))

        # Chart only when the knowledge base has entries
        if self.qa_module.knowledge_base:
            # Number of entries per category
            category_counts = {}
            for item in self.qa_module.knowledge_base:
                cat = item["category"]
                category_counts[cat] = category_counts.get(cat, 0) + 1

            df_categories = pd.DataFrame([
                {"Catégorie": cat, "Nombre": count}
                for cat, count in category_counts.items()
            ])

            st.bar_chart(df_categories.set_index("Catégorie"))

    def run_rss_integration(self):
        """Render the RSS page: feed list, manual update button, settings."""
        st.header("📰 Mise à jour RSS Automatique")

        feeds_col, action_col = st.columns([3, 1])

        with feeds_col:
            st.info("📡 Flux RSS configurés:")
            feed_names = ("Carbon Brief", "NASA Climate", "UNFCCC News")
            for feed_name in feed_names:
                st.write(f"• {feed_name}")

        with action_col:
            if st.button("🔄 Mise à jour manuelle", type="primary"):
                with st.spinner("Mise à jour en cours..."):
                    try:
                        # Imported on demand so a missing module only
                        # affects this button
                        from update_kb_rss import manual_rss_update
                        added = manual_rss_update(self.qa_module)
                        if added <= 0:
                            st.info("Aucun nouvel article trouvé")
                        else:
                            st.success(f"✅ {added} articles ajoutés")
                            st.rerun()
                    except ImportError:
                        st.error("❌ Module update_kb_rss non trouvé")
                    except Exception as e:
                        st.error(f"❌ Erreur lors de la mise à jour : {str(e)}")

        # Static description of the update settings
        with st.expander("⚙️ Paramètres RSS"):
            for line in (
                "Mise à jour automatique activée",
                "Fréquence: Quotidienne à 09:00",
                "Sources: Carbon Brief, NASA Climate, UNFCCC",
            ):
                st.write(line)

    def run_visualizations(self):
        """Render the visualisation page and dispatch to the chosen view.

        Views that need the tokenised test split or the trainer show an
        informational message when those are not in the session, instead of
        rendering nothing.
        """
        st.header("📈 Visualisations")

        viz = st.selectbox(
            "Choisir une visualisation",
            ["Distribution des classes", "Matrice de confusion", "Courbes d'entraînement",
             "📊 Métriques BLEU/ROUGE", "🔍 Évaluation Q&A", "📚 Analyse Knowledge Base"]
        )

        test_ds = st.session_state.get("test_ds")
        label_names = st.session_state.get("label_names")
        missing_data_msg = "ℹ️ Entraînez d'abord un modèle pour afficher cette visualisation."

        try:
            if viz == "Distribution des classes":
                if test_ds:
                    self.visualizer.plot_class_distribution(test_ds["labels"], label_names)
                else:
                    st.info(missing_data_msg)
            elif viz == "Matrice de confusion":
                if test_ds and st.session_state.trainer:
                    self.visualizer.show_confusion_matrix(st.session_state.trainer, test_ds, label_names)
                else:
                    st.info(missing_data_msg)
            elif viz == "Courbes d'entraînement":
                self.visualizer.plot_training_curves("outputs/final_model")
            elif viz == "📊 Métriques BLEU/ROUGE":
                self.run_bleu_rouge_analysis()
            elif viz == "🔍 Évaluation Q&A":
                self.run_qa_evaluation()
            elif viz == "📚 Analyse Knowledge Base":
                self.run_knowledge_base_analysis()
        except Exception as e:
            st.error(f"❌ Erreur : {e}")
            if st.checkbox("Détails techniques"):
                st.exception(e)

    def run_knowledge_base_analysis(self):
        """Show the most frequent keywords and text-length metrics of the
        knowledge base."""
        st.subheader("📚 Analyse de la Base de Connaissances")

        entries = self.qa_module.knowledge_base
        if not entries:
            st.info("La base de connaissances est vide.")
            return

        # Ten most frequent keywords across all entries
        all_keywords = []
        for entry in entries:
            all_keywords.extend(entry["keywords"])

        if all_keywords:
            from collections import Counter
            keyword_counts = Counter(all_keywords)
            df_keywords = pd.DataFrame(
                keyword_counts.most_common(10),
                columns=["Mot-clé", "Fréquence"]
            )
            st.bar_chart(df_keywords.set_index("Mot-clé"))

        # Text lengths in characters
        lengths = [len(entry["text"]) for entry in self.qa_module.knowledge_base]

        if lengths:
            col_min, col_mean, col_max = st.columns(3)
            col_min.metric("Min", f"{min(lengths)}")
            col_mean.metric("Moyenne", f"{np.mean(lengths):.0f}")
            col_max.metric("Max", f"{max(lengths)}")

    def run_bleu_rouge_analysis(self):
        """Let the user compare two texts and display BLEU / ROUGE scores."""
        st.subheader("📊 Analyse des métriques BLEU et ROUGE")

        try:
            # Reference and candidate texts side by side
            ref_col, cand_col = st.columns(2)
            with ref_col:
                reference_text = st.text_area(
                    "Texte de référence:",
                    "Le réchauffement climatique est un phénomène global causé par les activités humaines.",
                    height=100
                )

            with cand_col:
                candidate_text = st.text_area(
                    "Texte candidat:",
                    "Le changement climatique est un problème mondial dû aux actions humaines.",
                    height=100
                )

            if not st.button("Calculer les scores", use_container_width=True):
                return

            bleu_score = self.visualizer.calculate_bleu_score(reference_text, candidate_text)
            rouge_scores = self.visualizer.calculate_rouge_score(reference_text, candidate_text)

            bleu_col, rouge1_col, rougel_col = st.columns(3)
            bleu_col.metric("BLEU", f"{bleu_score:.4f}")
            rouge1_col.metric("ROUGE-1", f"{rouge_scores['rouge-1']:.4f}")
            rougel_col.metric("ROUGE-L", f"{rouge_scores['rouge-l']:.4f}")

        except Exception as e:
            st.error(f"Erreur dans l'analyse BLEU/ROUGE : {e}")

    def run_qa_evaluation(self):
        """Evaluate the Q&A system on default or user-supplied questions.

        For each question the top knowledge-base hit is used as the
        reference answer; questions and references are then passed to the
        visualizer's ``evaluate_qa_performance``.
        """
        st.subheader("🔍 Évaluation du système Q&A")

        default_questions = [
            "Quelles sont les principales causes du réchauffement climatique ?",
            "Comment les énergies renouvelables peuvent-elles aider ?",
            "Quel est l'impact de la déforestation sur le climat ?"
        ]

        if st.checkbox("Utiliser questions par défaut", value=True):
            test_questions = default_questions
        else:
            raw_questions = st.text_area(
                "Entrez vos questions (une par ligne):",
                height=100
            )
            # Drop blank lines: splitting an empty text area would otherwise
            # yield a single empty "question" that still gets evaluated.
            test_questions = [
                line.strip() for line in raw_questions.splitlines() if line.strip()
            ]

        if test_questions and st.button("🚀 Lancer l'évaluation", type="primary"):
            with st.spinner("Évaluation en cours..."):
                # Build the reference answers from the knowledge base
                reference_answers = []
                for question in test_questions:
                    kb_results = self.qa_module.query_knowledge_base(question, top_k=1)
                    ref = kb_results[0]['text'] if kb_results else "Référence manquante"
                    reference_answers.append(ref)

                self.visualizer.evaluate_qa_performance(
                    self.qa_module,
                    test_questions,
                    reference_answers
                )


# Script entry point: build the app and render the page for this run
if __name__ == "__main__":
    app = ClimateAnalyzerApp()
    app.run()

Overwriting streamlit_app.py


6️⃣ Script d'Installation - setup_pipeline.py

In [24]:
%%writefile setup_pipeline.py
import subprocess
import sys

def install_dependencies():
    """Install the pipeline's Python packages and fetch the NLTK tokenizer data.

    Each requirement is installed with its own ``pip install`` call using the
    current interpreter, so one failing package does not prevent the others
    from being attempted. A success or warning line is printed per package.

    After the installs, the NLTK ``punkt_tab`` and ``punkt`` resources are
    downloaded. If NLTK cannot be imported (for example because its install
    failed) or a download fails, a warning is printed instead of raising.
    """
    import importlib

    packages = [
        "transformers>=4.36.0",
        "datasets>=2.16.0",
        "torch>=2.1.0",
        "peft>=0.7.0",
        "sentence-transformers>=2.2.0",
        "faiss-cpu>=1.7.0",
        "streamlit>=1.29.0",
        "plotly>=5.17.0",
        "scikit-learn>=1.3.0",
        "matplotlib>=3.7.0",
        "seaborn>=0.12.0",
        "pandas>=1.5.0",
        "numpy>=1.24.0",
        "rouge-score",
        "nltk"
    ]

    for package in packages:
        cmd = [sys.executable, "-m", "pip", "install", package]
        try:
            subprocess.check_call(cmd)
        except subprocess.CalledProcessError as e:
            print(f"⚠️ Erreur avec {package}: {e}")
        else:
            print(f"✅ {package} installé")

    # Packages installed after interpreter start-up may be hidden by the
    # import system's directory caches; refresh them before importing nltk.
    importlib.invalidate_caches()

    # Download the NLTK resources. The import is guarded because the nltk
    # install above may have failed.
    try:
        import nltk
    except ImportError as e:
        print(f"⚠️ NLTK indisponible, ressources non téléchargées: {e}")
        return

    # nltk.download() reports failure through its boolean return value
    # rather than by raising, so the result must be checked explicitly.
    resources = ('punkt_tab', 'punkt')
    failed = [name for name in resources if not nltk.download(name, quiet=True)]
    for name in failed:
        print(f"⚠️ Ressource NLTK non téléchargée: {name}")

    # NOTE(review): presumably the two resources are alternatives for
    # different NLTK versions, so one successful download is treated as
    # sufficient — confirm against the tokenizer actually used.
    if len(failed) < len(resources):
        print("✅ Ressources NLTK prêtes")

# Run the installation only when executed as a script, not on import.
if __name__ == "__main__":
    install_dependencies()

Overwriting setup_pipeline.py


In [25]:
!python setup_pipeline.py

✅ transformers>=4.36.0 installé
✅ datasets>=2.16.0 installé
✅ torch>=2.1.0 installé
✅ peft>=0.7.0 installé
✅ sentence-transformers>=2.2.0 installé
✅ faiss-cpu>=1.7.0 installé
✅ streamlit>=1.29.0 installé
✅ plotly>=5.17.0 installé
✅ scikit-learn>=1.3.0 installé
✅ matplotlib>=3.7.0 installé
✅ seaborn>=0.12.0 installé
✅ pandas>=1.5.0 installé
✅ numpy>=1.24.0 installé
✅ rouge-score installé
✅ nltk installé
✅ Ressources NLTK prêtes


In [26]:
!pip install streamlit



In [27]:
!pip install pyngrok



In [28]:
!pip install faiss-cpu sentence-transformers



In [29]:
# 🔧 Launch Streamlit + ngrok
import getpass
import os
import socket
import subprocess
import time
from pyngrok import ngrok

# 1️⃣ ngrok token
# Credentials must not be hard-coded in the notebook: read the token from the
# NGROK_AUTHTOKEN environment variable, or prompt for it without echoing.
# NOTE: any token previously committed in this cell should be treated as
# compromised and revoked from the ngrok dashboard.
TOKEN = os.environ.get("NGROK_AUTHTOKEN") or getpass.getpass("ngrok authtoken: ")
if not TOKEN:
    raise RuntimeError("Token ngrok manquant : définissez NGROK_AUTHTOKEN.")
ngrok.set_auth_token(TOKEN)

# 2️⃣ Start the main application in the background
PORT = 8501
streamlit_proc = subprocess.Popen(
    ["streamlit", "run", "streamlit_app.py", f"--server.port={PORT}", "--server.address=0.0.0.0"],
    stdout=subprocess.DEVNULL,
    stderr=subprocess.DEVNULL
)

# 3️⃣ Wait for the server, then open the tunnel
# Poll the port instead of sleeping a fixed time, and fail loudly if the
# Streamlit process died (its output is discarded, so a crash would
# otherwise go unnoticed and the tunnel would point at nothing).
deadline = time.monotonic() + 30
while True:
    if streamlit_proc.poll() is not None:
        raise RuntimeError(
            f"Streamlit s'est arrêté prématurément (code {streamlit_proc.returncode})."
        )
    try:
        with socket.create_connection(("localhost", PORT), timeout=1):
            break
    except OSError:
        if time.monotonic() >= deadline:
            raise RuntimeError(
                f"Streamlit ne répond pas sur le port {PORT} après 30 secondes."
            ) from None
        time.sleep(0.5)

public_url = ngrok.connect(PORT)
print("🚀 Interface Streamlit disponible à :")
print(public_url)

Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml
🚀 Interface Streamlit disponible à :
NgrokTunnel: "https://0c1c5f32419e.ngrok-free.app" -> "http://localhost:8501"
