In [None]:
import csv
import random

def reservoir_sample(file_path, sample_size):
    """Randomly select up to ``sample_size`` data rows from a large CSV file.

    The file is streamed once and sampled with reservoir sampling, so at most
    ``sample_size`` rows are held in memory at any time.

    Args:
        file_path: Path to a UTF-8 encoded CSV file whose first row is the header.
        sample_size: Maximum number of data rows to keep.

    Returns:
        A ``(header, rows)`` tuple. ``rows`` holds fewer than ``sample_size``
        entries when the file has fewer data rows. A completely empty file
        yields ``([], [])``.
    """
    reservoir = []
    # newline='' lets the csv module handle line endings itself, so newlines
    # embedded inside quoted fields are preserved unchanged.
    with open(file_path, mode='r', newline='', encoding='utf-8') as file:
        reader = csv.reader(file)
        header = next(reader, None)  # Keep the header row aside
        if header is None:
            # Empty file: nothing to sample (avoids leaking StopIteration).
            return [], []
        for i, row in enumerate(reader):
            if i < sample_size:
                # Fill the reservoir with the first sample_size rows.
                reservoir.append(row)
            else:
                # Row i replaces a random slot with probability sample_size / (i + 1).
                j = random.randint(0, i)
                if j < sample_size:
                    reservoir[j] = row
    return header, reservoir

# Parameters
input_file = '../data/raw/train.csv'
output_file = '../data/raw/new_dataset.csv'
sample_size = 50_000

# Run the sampling
header, sample_rows = reservoir_sample(input_file, sample_size)

# Write the sample to the new file
with open(output_file, mode='w', newline='', encoding='utf-8') as f_out:
    writer = csv.writer(f_out)
    writer.writerow(header)
    writer.writerows(sample_rows)

# Report the number of rows actually written: it is smaller than sample_size
# when the input file has fewer data rows.
print(f"✅ {len(sample_rows)} lignes aléatoires écrites dans {output_file}")


✅ 50000 lignes aléatoires écrites dans ../data/raw/new_dataset.csv


In [None]:
# Installer les dépendances en utilisant un miroir PyPI pour éviter les erreurs réseau
!pip install pyyaml accelerate>=0.26.0 transformers[torch] --index-url https://mirrors.aliyun.com/pypi/simple/

# Importer les bibliothèques
import pandas as pd
import numpy as np
import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from torch.utils.data import Dataset
import yaml
import logging
import os
from sklearn.model_selection import train_test_split

# Configurer le logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("/content/drive/MyDrive/MLOps/output.log")
    ]
)
logger = logging.getLogger(__name__)

# Vérifier le GPU
logger.info(f"GPU disponible : {torch.cuda.is_available()}")
if torch.cuda.is_available():
    logger.info(f"Nom du GPU : {torch.cuda.get_device_name(0)}")

# Tester la connectivité réseau
logger.info("Test de la connectivité réseau")
!ping pypi.org

# Définir les classes
class SentimentDataset(Dataset):
    """Map-style dataset that tokenizes one text each time an item is read.

    Args:
        texts: Indexable collection of texts; every entry is passed through ``str()``.
        labels: Indexable collection of labels aligned with ``texts``.
        tokenizer: Callable used to encode a single text.
        max_length: Value forwarded to the tokenizer as ``max_length``
            (together with ``padding='max_length'`` and ``truncation=True``).
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_length = tokenizer, max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for item ``idx``."""
        raw_text = str(self.texts[idx])
        target = self.labels[idx]
        tokenizer_options = {
            'add_special_tokens': True,
            'max_length': self.max_length,
            'return_token_type_ids': False,
            'padding': 'max_length',
            'truncation': True,
            'return_attention_mask': True,
            'return_tensors': 'pt',
        }
        encoded = self.tokenizer(raw_text, **tokenizer_options)
        # Tokenizer tensors are flattened to 1-D before being returned.
        return {
            'input_ids': encoded['input_ids'].flatten(),
            'attention_mask': encoded['attention_mask'].flatten(),
            'labels': torch.tensor(target, dtype=torch.long),
        }

class DataLoader:
    """Load the raw CSV and write a cleaned copy, processing it chunk by chunk.

    Args:
        config: Mapping with ``config['data']['raw_path']`` and
            ``config['data']['processed_path']``; ``chunk_size`` is optional
            and defaults to 100000 rows.
    """

    def __init__(self, config):
        self.raw_path = config['data']['raw_path']
        self.processed_path = config['data']['processed_path']
        self.chunk_size = config.get('data', {}).get('chunk_size', 100000)

    def load_raw_data(self):
        """Read the whole raw CSV into a single DataFrame."""
        logger.info("Chargement des données brutes")
        return pd.read_csv(self.raw_path)

    def preprocess_chunk(self, df_chunk):
        """Drop incomplete rows, cast text to ``str`` and map labels 1/2 to 0/1.

        The input frame is not modified; a new DataFrame is returned.
        """
        logger.info("Prétraitement d'un chunk de données")
        # Work on an explicit copy so the assignments below never target a
        # view of the caller's frame (and do not trigger SettingWithCopyWarning).
        df_chunk = df_chunk.dropna(subset=['text', 'label']).copy()
        df_chunk['text'] = df_chunk['text'].astype(str)
        # Shift the labels to 0/1 instead of 1/2. A chunk holding only label 2
        # has a minimum of 2 but still uses the 1/2 scheme, so it must be
        # shifted as well; otherwise chunks end up with inconsistent labels.
        labels = df_chunk['label']
        if not labels.empty and (labels.min() == 1 or labels.eq(2).all()):
            logger.info("Ajustement des labels de 1,2 à 0,1")
            df_chunk['label'] = labels - 1
        return df_chunk

    def process_and_save_chunks(self):
        """Preprocess the raw CSV chunk by chunk and save the result as CSV.

        Does nothing when the processed file already exists.
        """
        logger.info("Début du prétraitement des données")
        if os.path.exists(self.processed_path):
            logger.info("Fichier traité existe déjà")
            return
        chunks = pd.read_csv(self.raw_path, chunksize=self.chunk_size)
        processed_chunks = []
        for i, chunk in enumerate(chunks):
            logger.info(f"Traitement du chunk {i+1}")
            processed_chunks.append(self.preprocess_chunk(chunk))
        processed_df = pd.concat(processed_chunks, ignore_index=True)
        # os.makedirs('') raises, so only create a directory when the path has one.
        output_dir = os.path.dirname(self.processed_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        processed_df.to_csv(self.processed_path, index=False)
        logger.info("Prétraitement terminé, fichier sauvegardé")

class BertTextClassifier:
    """Fine-tune a BERT sequence classifier described by a YAML config file.

    The config must provide ``model_name``, ``num_labels``, ``epochs`` and
    ``batch_size`` under ``config['model']['bert']``.
    """

    def __init__(self, config_path):
        with open(config_path, 'r') as file:
            self.config = yaml.safe_load(file)
        bert_cfg = self.config['model']['bert']
        self.model_name = bert_cfg['model_name']
        self.num_labels = bert_cfg['num_labels']
        self.tokenizer = BertTokenizer.from_pretrained(self.model_name)
        self.model = BertForSequenceClassification.from_pretrained(
            self.model_name, num_labels=self.num_labels
        )
        # Use the GPU when one is available, otherwise stay on the CPU.
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)
        self.model.to(self.device)
        logger.info(f"Modèle BERT chargé sur {self.device}")

    def train(self, df):
        """Train on ``df`` (columns ``text`` and ``label``) with an 80/20 split."""
        logger.info(f"Entraînement BERT sur {len(df)} lignes")
        # Shift the labels if needed
        labels = df['label'].values
        if labels.min() == 1:
            logger.info("Ajustement des labels de 1,2 à 0,1")
            labels = labels - 1

        split = train_test_split(
            df['text'].values, labels, test_size=0.2, random_state=42
        )
        train_texts, val_texts, train_labels, val_labels = split
        train_dataset = SentimentDataset(train_texts, train_labels, self.tokenizer)
        val_dataset = SentimentDataset(val_texts, val_labels, self.tokenizer)

        def accuracy_metric(p):
            # Fraction of predictions whose arg-max class equals the label.
            hits = p.predictions.argmax(-1) == p.label_ids
            return {'accuracy': hits.mean()}

        bert_cfg = self.config['model']['bert']
        epochs = bert_cfg['epochs']
        batch_size = bert_cfg['batch_size']
        training_args = TrainingArguments(
            output_dir='/content/drive/MyDrive/MLOps/results',
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir='/content/drive/MyDrive/MLOps/logs',
            logging_steps=10,
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            dataloader_pin_memory=torch.cuda.is_available(),
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=accuracy_metric,
        )
        trainer.train()
        logger.info("Entraînement BERT terminé")

    def save_model(self, output_dir):
        """Save the model and the tokenizer into ``output_dir``."""
        os.makedirs(output_dir, exist_ok=True)
        for artifact in (self.model, self.tokenizer):
            artifact.save_pretrained(output_dir)
        logger.info(f"Modèle BERT sauvegardé à {output_dir}")

# Check the data
logger.info("Vérification du chemin des données")
# !ls /kaggle/input/your-dataset
# logger.info("Vérification des premières lignes")
# !head /kaggle/input/your-dataset/raw.csv

# Define the configuration
config = {
    'data': {
        # Replace with the actual path.
        # NOTE(review): data_path is not defined in this cell; it must be set
        # before running it, otherwise this line raises NameError.
        'raw_path': data_path,
        'processed_path': '/content/drive/MyDrive/MLOps/data/processed.csv',
        'chunk_size': 100000
    },
    'model': {
        'bert': {
            'model_name': 'bert-base-uncased',
            'num_labels': 2,
            'epochs': 3,
            'batch_size': 16
        }
    }
}

# Save the configuration
os.makedirs('/content/drive/MyDrive/MLOps/config', exist_ok=True)
with open('/content/drive/MyDrive/MLOps/config/config.yaml', 'w') as file:
    yaml.safe_dump(config, file)

# Run the training
try:
    logger.info("Début du prétraitement des données")
    data_loader = DataLoader(config)
    data_loader.process_and_save_chunks()

    logger.info("Début de l'entraînement BERT")
    bert_classifier = BertTextClassifier('/content/drive/MyDrive/MLOps/config/config.yaml')
    df = pd.read_csv(data_loader.processed_path)
    # DataFrame.sample raises ValueError when n exceeds the number of rows,
    # so cap the sample size at the size of the processed file.
    df = df.sample(n=min(10000, len(df)), random_state=42)
    bert_classifier.train(df)
    bert_classifier.save_model('/content/drive/MyDrive/MLOps/models/bert_v1')
    logger.info("Entraînement terminé")
except Exception as e:
    logger.error(f"Erreur pendant l'exécution : {str(e)}", exc_info=True)
    raise

In [None]:
import sys
import os
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="_distutils_hack")
# __file__ is undefined when this code runs as a notebook cell or in an
# interactive session; fall back to the current working directory there.
_module_dir = os.path.dirname(__file__) if '__file__' in globals() else os.getcwd()
sys.path.append(os.path.abspath(os.path.join(_module_dir, '../..')))

import yaml
import mlflow
import mlflow.pytorch
from sklearn.metrics import accuracy_score
from sklearn.utils.class_weight import compute_class_weight
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
import torch
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
import logging
from sklearn.model_selection import train_test_split
import re
import emoji

# Configure logging: INFO level, timestamped "time - level - message" lines
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# Définir les classes
class SentimentDataset(Dataset):
    """Dataset of (text, label) pairs tokenized lazily, one item at a time.

    ``texts`` and ``labels`` are indexable and aligned; ``tokenizer`` is the
    callable used to encode each text; ``max_length`` is forwarded to it.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Encode item ``idx`` and return its tensors in a dict."""
        sample_text = str(self.texts[idx])
        sample_label = self.labels[idx]
        encoded = self.tokenizer(
            sample_text,
            add_special_tokens=True,
            max_length=self.max_length,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # Flatten each tokenizer tensor to 1-D, then attach the label.
        item = {name: encoded[name].flatten() for name in ('input_ids', 'attention_mask')}
        item['labels'] = torch.tensor(sample_label, dtype=torch.long)
        return item

class TextClassificationModel:
    """Base class for the classifiers: loads the YAML config and a data loader.

    ``train`` and ``save_model`` raise ``NotImplementedError`` here and are
    meant to be overridden by subclasses.
    """

    def __init__(self, config_path: str):
        with open(config_path, 'r') as handle:
            loaded_config = yaml.safe_load(handle)
        self.config = loaded_config
        # NOTE(review): DataLoader is neither imported nor defined in this
        # cell, and the DataLoader classes defined in other cells of this
        # notebook take a config dict rather than a path — confirm which
        # DataLoader is intended here.
        self.data_loader = DataLoader(config_path)

    def train(self, df: pd.DataFrame):
        """Train the model on ``df``; must be implemented by subclasses."""
        raise NotImplementedError

    def save_model(self, model, path: str):
        """Persist the model; must be implemented by subclasses."""
        raise NotImplementedError

class BertTextClassifier(TextClassificationModel):
    """BERT fine-tuning with a class-weighted loss and MLflow tracking.

    Reads ``model_name``, ``num_labels``, ``max_length``, ``epochs`` and
    ``batch_size`` from ``config['model']['bert']`` and ``output_dir`` /
    ``logging_dir`` from ``config['training']``.
    """

    def __init__(self, config_path: str):
        super().__init__(config_path)
        self.model_name = self.config['model']['bert']['model_name']
        self.num_labels = self.config['model']['bert']['num_labels']
        self.max_length = self.config['model']['bert']['max_length']
        self.tokenizer = BertTokenizer.from_pretrained(self.model_name)
        self.model = BertForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=self.num_labels
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        logger.info(f"Modèle BERT chargé ({self.model_name}) sur {self.device}")

    def preprocess_data(self, df: pd.DataFrame):
        """BERT-specific preprocessing.

        Drops rows with missing/blank text or fewer than 3 words, collapses
        whitespace, maps labels 1/2 to 0/1, and logs token-length and label
        statistics to the logger and to MLflow.

        Returns:
            A new, cleaned DataFrame; the input frame is not modified.
        """
        logger.info(f"Prétraitement de {len(df)} lignes")

        # Cleaning
        df = df.dropna(subset=['text', 'label'])
        df = df[df['text'].str.strip() != '']
        # Minimum of 3 words. The explicit copy makes the column assignments
        # below act on an independent frame instead of a filtered slice.
        df = df[df['text'].str.split().str.len() >= 3].copy()
        df['text'] = df['text'].apply(lambda x: re.sub(r'\s+', ' ', x.strip()))

        # Shift the labels (from 1,2 to 0,1)
        if df['label'].min() == 1:
            logger.info("Ajustement des labels de 1,2 à 0,1")
            df['label'] = df['label'] - 1

        # Token-length statistics
        lengths = [len(self.tokenizer.encode(text, add_special_tokens=True)) for text in df['text']]
        logger.info(f"Longueur moyenne : {np.mean(lengths)}, Max : {np.max(lengths)}, 90e percentile : {np.percentile(lengths, 90)}")
        mlflow.log_metric("mean_text_length", np.mean(lengths))
        mlflow.log_metric("max_text_length", np.max(lengths))

        # Label distribution
        label_dist = df['label'].value_counts().to_dict()
        logger.info(f"Distribution des labels : {label_dist}")
        mlflow.log_metric("label_0_count", label_dist.get(0, 0))
        mlflow.log_metric("label_1_count", label_dist.get(1, 0))

        return df

    # NOTE(review): timer_decorator is not imported or defined in this cell;
    # it must already be in scope when the class is created.
    @timer_decorator
    def train(self, df: pd.DataFrame):
        """Preprocess ``df``, fine-tune the model and log everything to MLflow.

        The data is split 80/20 into train/validation sets, the loss is
        weighted with balanced class weights, and the validation accuracy and
        the trained model are logged to the MLflow run.
        """
        logger.info(f"Entraînement BERT sur {len(df)} lignes")

        # MLflow: create the experiment or reuse the existing one
        experiment_name = "TextClassificationExperiment"
        experiment = mlflow.get_experiment_by_name(experiment_name)
        if experiment is None:
            experiment_id = mlflow.create_experiment(experiment_name)
        else:
            experiment_id = experiment.experiment_id

        # The run is opened before preprocessing because preprocess_data logs
        # metrics: logging without an active run makes MLflow start one
        # implicitly, and the explicit start_run below would then fail
        # because a run is already active.
        with mlflow.start_run(experiment_id=experiment_id):
            mlflow.log_param("model_type", "BERT")
            mlflow.log_param("model_name", self.model_name)
            mlflow.log_param("epochs", self.config['model']['bert']['epochs'])
            mlflow.log_param("batch_size", self.config['model']['bert']['batch_size'])
            mlflow.log_param("max_length", self.max_length)

            # Preprocessing
            df = self.preprocess_data(df)

            # Class weights
            class_weights = compute_class_weight('balanced', classes=np.array([0, 1]), y=df['label'])
            class_weights = torch.tensor(class_weights, dtype=torch.float).to(self.device)
            logger.info(f"Poids des classes : {class_weights.tolist()}")

            # Train/validation split
            train_texts, val_texts, train_labels, val_labels = train_test_split(
                df['text'].values, df['label'].values, test_size=0.2, random_state=42
            )
            train_dataset = SentimentDataset(train_texts, train_labels, self.tokenizer, self.max_length)
            val_dataset = SentimentDataset(val_texts, val_labels, self.tokenizer, self.max_length)

            # Training configuration
            training_args = TrainingArguments(
                output_dir=self.config['training']['output_dir'],
                num_train_epochs=self.config['model']['bert']['epochs'],
                per_device_train_batch_size=self.config['model']['bert']['batch_size'],
                per_device_eval_batch_size=self.config['model']['bert']['batch_size'],
                warmup_steps=500,
                weight_decay=0.01,
                logging_dir=self.config['training']['logging_dir'],
                logging_steps=10,
                eval_strategy="epoch",
                save_strategy="epoch",
                load_best_model_at_end=True,
                dataloader_pin_memory=torch.cuda.is_available(),
                # Mixed precision saves memory but needs a GPU; enabling it
                # unconditionally makes TrainingArguments fail on CPU.
                fp16=torch.cuda.is_available(),
            )

            # Custom trainer applying the class weights to the loss
            class CustomTrainer(Trainer):
                def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
                    # **kwargs absorbs extra arguments (such as
                    # num_items_in_batch) that newer Trainer versions pass.
                    labels = inputs.pop('labels')
                    outputs = model(**inputs)
                    logits = outputs.logits
                    loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights)
                    loss = loss_fct(logits, labels)
                    return (loss, outputs) if return_outputs else loss

            trainer = CustomTrainer(
                model=self.model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=val_dataset,
                compute_metrics=lambda p: {'accuracy': (p.predictions.argmax(-1) == p.label_ids).mean()}
            )

            # Training
            trainer.train()

            # Evaluation
            predictions = trainer.predict(val_dataset).predictions.argmax(-1)
            accuracy = accuracy_score(val_labels, predictions)
            mlflow.log_metric("val_accuracy", accuracy)
            mlflow.pytorch.log_model(self.model, "bert_model")
            logger.info(f"Précision sur validation : {accuracy:.4f}")

    def save_model(self, output_dir: str):
        """Save the model and the tokenizer into ``output_dir``."""
        os.makedirs(output_dir, exist_ok=True)
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
        logger.info(f"Modèle BERT sauvegardé à {output_dir}")

if __name__ == "__main__":
    # __file__ is undefined when this code runs as a notebook cell; fall back
    # to the current working directory in that case.
    base_dir = os.path.dirname(__file__) if '__file__' in globals() else os.getcwd()
    config_path = os.path.join(base_dir, "../../config/config.yaml")
    config = {
        'data': {
            # NOTE(review): placeholder path — replace with the real raw CSV.
            'raw_path': '/path/to/raw.csv',
            'processed_path': '/content/drive/MyDrive/MLOps/data/processed.parquet',
            'chunk_size': 100000
        },
        'model': {
            'bert': {
                'model_name': 'bert-base-uncased',
                'num_labels': 2,
                'epochs': 3,
                'batch_size': 16,
                'max_length': 128
            }
        },
        'training': {
            'output_dir': '/content/drive/MyDrive/MLOps/results',
            'logging_dir': '/content/drive/MyDrive/MLOps/logs'
        }
    }

    # Save the configuration
    os.makedirs(os.path.dirname(config_path), exist_ok=True)
    with open(config_path, 'w') as file:
        yaml.safe_dump(config, file)

    try:
        logger.info("Début du prétraitement des données")
        data_loader = DataLoader(config_path)
        data_loader.process_and_save_chunks()

        logger.info("Début de l'entraînement BERT")
        bert_classifier = BertTextClassifier(config_path)
        # Load every processed row (no sub-sampling)
        df = pd.read_parquet(data_loader.processed_path)
        bert_classifier.train(df)
        bert_classifier.save_model('/content/drive/MyDrive/MLOps/models/bert_v1')
        logger.info("Entraînement terminé")
    except Exception as e:
        logger.error(f"Erreur pendant l'exécution : {str(e)}", exc_info=True)
        raise

In [None]:
# ====================== IMPORTS ======================
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, TrainerCallback
import yaml
import os
import re
import emoji
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

# ====================== CONFIGURATION ======================
# data_path = "/content/new_dataset.csv"  # <-- adjust to your own path
# Data locations/chunking plus the BERT hyper-parameters used by this cell.
config = dict(
    data=dict(
        raw_path='/content/new_dataset.csv',
        processed_path='/content/processed.csv',
        chunk_size=10_000,
    ),
    model=dict(
        bert=dict(
            model_name='bert-base-uncased',
            num_labels=2,
            max_length=128,
            epochs=5,
            weight_decay=0.01,
            batch_size=16,
        ),
    ),
)
# Persist the configuration so the classes below can reload it from disk.
os.makedirs('/content', exist_ok=True)
with open('/content/config.yaml', 'w') as file:
    file.write(yaml.safe_dump(config))

# ====================== DATA LOADER ======================
class DataLoader:
    """Stream the raw CSV in chunks and clean each chunk for BERT training.

    Args:
        config: Mapping providing ``raw_path``, ``processed_path`` and
            ``chunk_size`` under ``config['data']``.
    """

    def __init__(self, config):
        self.raw_path = config['data']['raw_path']
        self.processed_path = config['data']['processed_path']
        self.chunk_size = config['data']['chunk_size']

    def preprocess_chunk(self, df_chunk):
        """Clean one chunk and return it as a new DataFrame.

        Removes rows with missing text/label, strips URLs, mentions, hashtags,
        emojis and punctuation other than ``!`` and ``?``, lower-cases the
        text, drops texts of at most 10 characters or at most 2 words, and
        maps labels 1/2 to 0/1.
        """
        # Explicit copy: the assignments below must not target a view of the
        # caller's frame (this also avoids SettingWithCopyWarning).
        df_chunk = df_chunk.dropna(subset=['text', 'label']).copy()
        df_chunk['text'] = df_chunk['text'].astype(str)
        df_chunk['text'] = df_chunk['text'].apply(lambda x: re.sub(r'http\S+|www\S+|@\w+|#\w+', '', x))
        df_chunk['text'] = df_chunk['text'].apply(lambda x: emoji.replace_emoji(x, replace=''))
        df_chunk['text'] = df_chunk['text'].apply(lambda x: re.sub(r'[^\w\s!?]', '', x.lower()))
        df_chunk = df_chunk[df_chunk['text'].str.len() > 10]
        df_chunk = df_chunk[df_chunk['text'].str.split().str.len() > 2].copy()
        # A chunk holding only label 2 has a minimum of 2 but still uses the
        # 1/2 scheme, so it is shifted too; otherwise labels would be
        # inconsistent from one chunk to the next.
        labels = df_chunk['label']
        if not labels.empty and (labels.min() == 1 or labels.eq(2).all()):
            df_chunk['label'] = labels - 1
        return df_chunk

    def chunk_generator(self):
        """Yield the preprocessed chunks of the raw CSV one at a time."""
        for chunk in pd.read_csv(self.raw_path, chunksize=self.chunk_size):
            yield self.preprocess_chunk(chunk)

# ====================== FEATURE ENGINEER ======================
class FeatureEngineer:
    """Tokenize texts with the BERT tokenizer named in the YAML config."""

    def __init__(self, config_path):
        # Use a context manager so the config file handle is always closed.
        with open(config_path, 'r') as config_file:
            self.config = yaml.safe_load(config_file)
        self.tokenizer = BertTokenizer.from_pretrained(self.config['model']['bert']['model_name'])

    def transform(self, texts, max_length=128):
        """Tokenize ``texts`` with padding and truncation to ``max_length``.

        Returns:
            The tokenizer output, requested as PyTorch tensors
            (``return_tensors='pt'``) without token type ids.
        """
        return self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_token_type_ids=False,
            return_tensors='pt'
        )

# ====================== DATASET ======================
class SentimentDataset(Dataset):
    """Dataset over pre-tokenized encodings and their labels.

    ``encodings`` is a mapping whose values are indexable per sample;
    ``labels`` is the aligned, indexable collection of labels.
    """

    def __init__(self, encodings, labels):
        self.encodings, self.labels = encodings, labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return every encoding field for sample ``idx`` plus its label tensor."""
        sample = {}
        for field_name, field_values in self.encodings.items():
            sample[field_name] = field_values[idx]
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

# ====================== MODEL ======================
# Evaluation results collected across chunks (appended to by BertTrainer).
global_metrics = []


def compute_metrics(eval_pred):
    """Compute accuracy and F1 from a ``(logits, labels)`` evaluation pair.

    The predicted class is the arg-max of the logits over the last axis.
    """
    logits, labels = eval_pred
    predicted_classes = np.argmax(logits, axis=-1)
    scores = {'accuracy': accuracy_score(labels, predicted_classes)}
    scores['f1'] = f1_score(labels, predicted_classes)
    return scores

class BertTrainer:
    """Fine-tune one BERT classifier incrementally, one data chunk at a time.

    The model is created once in ``__init__`` and reused by every
    ``train_on_chunk`` call, so training accumulates across chunks.
    """

    def __init__(self, config_path):
        self.config_path = config_path
        with open(config_path, 'r') as f:
            self.config = yaml.safe_load(f)

        self.model = BertForSequenceClassification.from_pretrained(
            self.config['model']['bert']['model_name'],
            num_labels=self.config['model']['bert']['num_labels']
        )
        self.feature_engineer = FeatureEngineer(config_path)
        self.output_dir = "/content/checkpoint"

    def train_on_chunk(self, df, chunk_num):
        """Train and evaluate on one chunk, then record the evaluation results.

        Args:
            df: DataFrame with ``text`` and ``label`` columns.
            chunk_num: Chunk number stored alongside the metrics appended to
                ``global_metrics``.
        """
        bert_cfg = self.config['model']['bert']
        # Honour the configured sequence length instead of the tokenizer
        # helper's default (falls back to that default, 128, when unset).
        max_length = bert_cfg.get('max_length', 128)

        train_texts, val_texts, train_labels, val_labels = train_test_split(
            df['text'].tolist(),
            df['label'].tolist(),
            test_size=0.2,
            random_state=42
        )

        train_enc = self.feature_engineer.transform(train_texts, max_length=max_length)
        val_enc = self.feature_engineer.transform(val_texts, max_length=max_length)

        train_dataset = SentimentDataset(train_enc, train_labels)
        val_dataset = SentimentDataset(val_enc, val_labels)

        # Resolve the optional resume checkpoint once; None means "start fresh".
        last_checkpoint = os.path.join(self.output_dir, "checkpoint-last")
        checkpoint_path = last_checkpoint if os.path.exists(last_checkpoint) else None

        training_args = TrainingArguments(
            output_dir=self.output_dir,
            num_train_epochs=bert_cfg['epochs'],
            per_device_train_batch_size=bert_cfg['batch_size'],
            per_device_eval_batch_size=32,
            weight_decay=bert_cfg['weight_decay'],
            # eval_strategy is the current name of the deprecated
            # evaluation_strategy argument and matches the other training
            # code in this notebook.
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            logging_dir='./logs',
            logging_steps=10,
            save_total_limit=2,
            # This argument expects a path (or None), not a boolean.
            resume_from_checkpoint=checkpoint_path,
            overwrite_output_dir=False
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=compute_metrics
        )

        trainer.train(resume_from_checkpoint=checkpoint_path)

        eval_results = trainer.evaluate()
        eval_results["chunk"] = chunk_num
        global_metrics.append(eval_results)


# ====================== MAIN ======================
if __name__ == "__main__":
    config_path = '/content/config.yaml'
    data_loader = DataLoader(config)
    trainer = BertTrainer(config_path)

    for i, chunk_df in enumerate(data_loader.chunk_generator()):
        # Preprocessing filters rows out; a chunk left with fewer than two
        # rows cannot be split into train/validation sets, so skip it instead
        # of crashing the whole run.
        if len(chunk_df) < 2:
            print(f"=== Chunk {i+1} ignoré : pas assez de lignes après prétraitement ===")
            continue
        print(f"=== Entraînement sur le chunk {i+1} ===")
        trainer.train_on_chunk(chunk_df, chunk_num=i+1)

    # Final save of the model and its tokenizer
    trainer.model.save_pretrained("/content/models/final_bert")
    trainer.feature_engineer.tokenizer.save_pretrained("/content/models/final_bert")


In [None]:
#  code used to train model import sys
import os
import warnings
import yaml
import mlflow
import mlflow.sklearn
import mlflow.pytorch
import pandas as pd
import numpy as np
import logging
import torch
from torch.utils.data import Dataset
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.utils.class_weight import compute_class_weight
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from src.data.load_data import DataLoader
from src.features.feature_engineering import FeatureEngineer
from src.utils.helper_functions import timer_decorator

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Ignore specific warnings
warnings.filterwarnings("ignore", category=UserWarning, module="_distutils_hack")

# Add the project path to sys.path
# sys is imported here because the import list of this cell does not include it.
import sys
# __file__ is undefined when this code runs as a notebook cell; fall back to
# the current working directory in that case.
_module_dir = os.path.dirname(__file__) if '__file__' in globals() else os.getcwd()
# NOTE(review): this runs after the `from src...` imports above, so it cannot
# help resolve them — confirm whether it should be moved before those imports.
sys.path.append(os.path.abspath(os.path.join(_module_dir, '..')))


# Définir les classes
class SentimentDataset(Dataset):
    """Lazily tokenized (text, label) dataset.

    Items are encoded on access with ``tokenizer`` using ``max_length``
    padding/truncation, and returned as a dict of tensors.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def _encode(self, text):
        # Single place holding the tokenizer settings used for every item.
        return self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for ``idx``."""
        text_value = str(self.texts[idx])
        label_value = self.labels[idx]
        tokens = self._encode(text_value)
        input_ids = tokens['input_ids'].flatten()
        attention_mask = tokens['attention_mask'].flatten()
        label_tensor = torch.tensor(label_value, dtype=torch.long)
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': label_tensor,
        }

class TextClassificationModel:
    """Common base for the classifiers in this module.

    Loads the YAML configuration and builds a ``DataLoader`` from the same
    config path. ``train`` and ``save_model`` are placeholders that raise
    ``NotImplementedError`` until a subclass overrides them.
    """

    def __init__(self, config_path: str):
        with open(config_path, 'r') as config_file:
            parsed = yaml.safe_load(config_file)
        self.config = parsed
        self.data_loader = DataLoader(config_path)

    def train(self, df: pd.DataFrame):
        """Train on ``df``; to be provided by subclasses."""
        raise NotImplementedError

    def save_model(self, model, path: str):
        """Persist the model; to be provided by subclasses."""
        raise NotImplementedError

class RandomForestTextClassifier(TextClassificationModel):
    """Incrementally trained linear text classifier on TF-IDF features.

    Despite its name, the underlying model is an ``SGDClassifier`` with
    logistic loss, updated with ``partial_fit`` so it can be trained over
    several ``train`` calls. Labels are expected to be 1 and 2.
    """

    def __init__(self, config_path: str):
        super().__init__(config_path)
        self.feature_engineer = FeatureEngineer(config_path)
        self.model = SGDClassifier(loss='log_loss', max_iter=1000, tol=1e-3)
        # Tracks whether the TF-IDF vectorizer has already been fitted.
        self._tfidf_fitted = False

    @timer_decorator
    def train(self, df: pd.DataFrame):
        """Update the model with the rows of ``df`` and log the run to MLflow.

        Batches containing a single class are skipped with a warning.
        """
        logging.info(f"Entraînement sur un lot de {len(df)} lignes")
        texts, labels = df['text'].tolist(), df['label'].tolist()

        # Check which classes are present
        unique_labels = set(labels)
        if len(unique_labels) < 2:
            logging.warning(f"Lot ignoré : contient seulement les labels {unique_labels}")
            return

        # Fit TF-IDF on a sample if the batch is too large. The vectorizer is
        # fitted only once: the SGD model is updated incrementally across
        # train() calls, and refitting would presumably change the feature
        # space its coefficients refer to — confirm against FeatureEngineer.
        sample_size = min(10000, len(texts))
        if not getattr(self, '_tfidf_fitted', False):
            self.feature_engineer.fit_tfidf(texts[:sample_size])
            self._tfidf_fitted = True

        # Compute the class weights on a sample
        sample_labels = labels[:sample_size]
        class_weights = compute_class_weight('balanced', classes=np.array([1, 2]), y=sample_labels)
        class_weight_dict = {1: class_weights[0], 2: class_weights[1]}

        # Mini-batch training
        batch_size = 10000
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            batch_labels = np.array(labels[i:i + batch_size])
            X_batch = self.feature_engineer.transform_tfidf(batch_texts)
            sample_weights = np.array([class_weight_dict[label] for label in batch_labels])
            self.model.partial_fit(X_batch, batch_labels, classes=np.array([1, 2]), sample_weight=sample_weights)
            logging.info(f"Lot {i//batch_size + 1} entraîné")

        # Evaluation on a subset.
        # NOTE(review): these rows were also used for training above, so this
        # is a training accuracy, not a held-out estimate.
        eval_texts = texts[:1000]
        eval_labels = labels[:1000]
        X_eval = self.feature_engineer.transform_tfidf(eval_texts)
        predictions = self.model.predict(X_eval)
        accuracy = accuracy_score(eval_labels, predictions)
        logging.info(f"Précision sur le sous-ensemble d'évaluation : {accuracy:.4f}")

        # Create or fetch the MLflow experiment
        experiment_name = "TextClassificationExperiment"
        experiment = mlflow.get_experiment_by_name(experiment_name)
        if experiment is None:
            experiment_id = mlflow.create_experiment(experiment_name)
        else:
            experiment_id = experiment.experiment_id

        with mlflow.start_run(experiment_id=experiment_id):
            mlflow.log_param("model_type", "SGDClassifier")
            mlflow.log_param("batch_size", batch_size)
            mlflow.log_metric("accuracy", accuracy)
            mlflow.sklearn.log_model(self.model, "sgd_model")
            logging.info("Métriques et modèle loggés dans MLflow")

    def save_model(self, path):
        """Pickle the model together with the TF-IDF vectorizer to ``path``."""
        import pickle
        # os.makedirs('') raises, so only create a directory when path has one.
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump({'model': self.model, 'vectorizer': self.feature_engineer.tfidf}, f)
        logging.info(f"Modèle sauvegardé à {path}")

class BertTextClassifier(TextClassificationModel):
    """BERT sequence classifier trained with the Hugging Face ``Trainer``.

    Reads ``model_name``, ``num_labels``, ``max_length``, ``epochs`` and
    ``batch_size`` from ``config['model']['bert']``; the tokenizer comes from
    ``FeatureEngineer.bert_tokenizer``. Each ``train`` call is tracked as an
    MLflow run.
    """

    def __init__(self, config_path: str):
        super().__init__(config_path)
        self.model_name = self.config['model']['bert']['model_name']
        self.num_labels = self.config['model']['bert']['num_labels']
        self.max_length = self.config['model']['bert']['max_length']
        self.feature_engineer = FeatureEngineer(config_path)  # Provides the BERT tokenizer
        self.model = BertForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=self.num_labels
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        logging.info(f"Modèle BERT chargé ({self.model_name}) sur {self.device}")

    def train(self, df):
        """Fine-tune on ``df`` (columns ``text`` and ``label``), 80/20 split.

        Logs the parameters, the validation accuracy and the trained model to
        MLflow.
        """
        logging.info(f"Entraînement BERT sur {len(df)} lignes")
        # Shift the labels if needed
        labels = df['label'].values
        if labels.min() == 1:
            logging.info("Ajustement des labels de 1,2 à 0,1")
            labels = labels - 1
        train_texts, val_texts, train_labels, val_labels = train_test_split(
            df['text'].values, labels, test_size=0.2, random_state=42
        )
        # The datasets tokenize each text on access with the FeatureEngineer
        # tokenizer. The previous up-front transform_bert() calls tokenized
        # both splits a second time and their results were never used, so
        # they have been removed.
        train_dataset = SentimentDataset(train_texts, train_labels, self.feature_engineer.bert_tokenizer, self.max_length)
        val_dataset = SentimentDataset(val_texts, val_labels, self.feature_engineer.bert_tokenizer, self.max_length)

        training_args = TrainingArguments(
            output_dir='/results',
            num_train_epochs=self.config['model']['bert']['epochs'],
            per_device_train_batch_size=self.config['model']['bert']['batch_size'],
            per_device_eval_batch_size=self.config['model']['bert']['batch_size'],
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir='/logs',
            logging_steps=10,
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            dataloader_pin_memory=torch.cuda.is_available(),
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=lambda p: {'accuracy': (p.predictions.argmax(-1) == p.label_ids).mean()}
        )

        # MLflow: create the experiment or reuse the existing one
        experiment_name = "TextClassificationExperiment"
        experiment = mlflow.get_experiment_by_name(experiment_name)
        if experiment is None:
            experiment_id = mlflow.create_experiment(experiment_name)
        else:
            experiment_id = experiment.experiment_id

        with mlflow.start_run(experiment_id=experiment_id):
            mlflow.log_param("model_type", "BERT")
            mlflow.log_param("model_name", self.model_name)
            mlflow.log_param("epochs", self.config['model']['bert']['epochs'])
            mlflow.log_param("batch_size", self.config['model']['bert']['batch_size'])
            mlflow.log_param("max_length", self.max_length)

            # Training
            trainer.train()

            # Evaluation
            predictions = trainer.predict(val_dataset).predictions.argmax(-1)
            accuracy = accuracy_score(val_labels, predictions)
            mlflow.log_metric("val_accuracy", accuracy)
            mlflow.pytorch.log_model(self.model, "bert_model")
            logging.info(f"Précision sur validation : {accuracy:.4f}")

    def save_model(self, output_dir):
        """Save the model and the BERT tokenizer into ``output_dir``."""
        os.makedirs(output_dir, exist_ok=True)
        self.model.save_pretrained(output_dir)
        self.feature_engineer.bert_tokenizer.save_pretrained(output_dir)
        logging.info(f"Modèle BERT sauvegardé à {output_dir}")

if __name__ == "__main__":
    # __file__ is undefined when this code runs as a notebook cell; fall back
    # to the current working directory in that case.
    base_dir = os.path.dirname(__file__) if '__file__' in globals() else os.getcwd()
    config_path = os.path.join(base_dir, "../../config/config.yaml")
    data_loader = DataLoader(config_path)

    try:
        # Disabled steps (preprocessing and SGD training), kept for reference:
        # logging.info("Début du prétraitement des données")
        # data_loader.process_and_save_chunks()

        # logging.info("Début de l'entraînement Random Forest (SGDClassifier)")
        # rf_classifier = RandomForestTextClassifier(config_path)
        # for texts, labels in data_loader.data_generator(batch_size=10000):
        #     df_chunk = pd.DataFrame({'text': texts, 'label': labels})
        #     rf_classifier.train(df_chunk)

        # rf_classifier.save_model("models/random_forest_v1.pkl")

        logging.info("Début de l'entraînement BERT")
        bert_classifier = BertTextClassifier(config_path)
        # Train on the data batch by batch, 10000 rows at a time.
        for texts, labels in data_loader.data_generator(batch_size=10000):
            df_chunk = pd.DataFrame({'text': texts, 'label': labels})
            bert_classifier.train(df_chunk)

        bert_classifier.save_model("models/bert_v1")
        logging.info("Entraînement terminé")
    except Exception as e:
        logging.error(f"Erreur pendant l'exécution : {str(e)}", exc_info=True)
        raise