In [2]:
import os
import csv
import math
import random
import logging
import argparse
import itertools
from collections import namedtuple

import torch
import numpy as np
from tqdm import tqdm
from tqdm.auto import tqdm
import time
from sklearn.metrics import precision_score
import numpy as np
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import (
    BertModel, BertConfig, BertPreTrainedModel,
    BertTokenizer, AdamW, get_linear_schedule_with_warmup
)

In [3]:
# Logging: timestamped INFO-level messages for the whole notebook.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Data containers.
# One CSV row: an identifier, the context sentence, candidate sense keys,
# their glosses, and the indices of the correct glosses.
GlossSelectionRecord = namedtuple(
    "GlossSelectionRecord",
    "guid sentence sense_keys glosses targets",
)
# One (sentence, gloss) pair already encoded for BERT, plus its 0/1 label.
BertInput = namedtuple(
    "BertInput",
    "input_ids input_mask segment_ids label_id",
)

In [4]:
class WSDDataset(Dataset):
    """Map-style torch Dataset over an in-memory sequence of precomputed features.

    Indexing and length are delegated straight to ``features``. In this notebook,
    each item is expected to be the list of pair inputs built for one instance.
    """

    def __init__(self, features):
        # Kept as-is (no copy); the dataset is a thin view over this sequence.
        self.features = features

    def __len__(self):
        return len(self.features)

    def __getitem__(self, index):
        return self.features[index]

In [5]:
class BertWSD(BertPreTrainedModel):
    """BERT encoder plus a scalar ranking head, used for gloss selection.

    No ``forward`` is defined here. Callers score inputs by chaining
    ``bert`` -> ``dropout`` -> ``ranking_linear`` themselves
    (``forward_gloss_selection`` does this).
    """

    def __init__(self, config):
        super(BertWSD, self).__init__(config)
        # Registration order (bert, dropout, ranking_linear) is intentionally
        # unchanged: it fixes the parameter order seen by named_parameters()
        # and the keys order in state_dict().
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)
        # Maps each pooled vector to a single relevance score.
        self.ranking_linear = torch.nn.Linear(in_features=config.hidden_size, out_features=1)
        self.init_weights()

def _compute_weighted_loss(loss, weighting_factor):
    """Calcul d'une perte pondérée"""
    squared_factor = weighting_factor ** 2
    return 1 / (2 * squared_factor) * loss + math.log(1 + squared_factor)

def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Tronque une paire de séquences à la longueur maximale"""
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()

In [6]:
def load_dataset(
    csv_path, 
    tokenizer, 
    max_sequence_length, 
    max_samples=None
):
    """
    Load a gloss-selection dataset from a CSV file, optionally keeping only the first rows.

    The first CSV row is treated as a header and skipped. Columns 2-4 of every
    other row are parsed with ``ast.literal_eval``. Unlike ``eval``, it only
    accepts Python literals, so a malformed or malicious file raises
    ``ValueError``/``SyntaxError`` instead of executing code.

    Args:
        csv_path (str): Path to the CSV file.
        tokenizer (BertTokenizer): BERT tokenizer used to build the features.
        max_sequence_length (int): Maximum (padded) sequence length.
        max_samples (int, optional): Maximum number of data rows to load
            (``None`` loads every row).

    Returns:
        WSDDataset: Dataset ready for training or evaluation.
    """
    import ast  # literal_eval: parse list literals without executing code

    def _parse_literal(cell):
        # Cells presumably hold Python list literals such as "['a', 'b']".
        # Expressions (function calls, names, ...) are rejected, not run.
        return ast.literal_eval(cell)

    def _deserialize_csv_record(row):
        return GlossSelectionRecord(
            row[0],  # guid
            row[1],  # sentence
            _parse_literal(row[2]),  # sense_keys
            _parse_literal(row[3]),  # glosses
            [int(t) for t in _parse_literal(row[4])]  # targets (gloss indices)
        )

    def _create_records_from_csv(csv_path, deserialize_fn, max_samples=None):
        """
        Build records from a CSV file, stopping after ``max_samples`` data rows.

        Args:
            csv_path (str): Path to the CSV file.
            deserialize_fn (callable): Converts one CSV row into a record.
            max_samples (int, optional): Maximum number of data rows (None = all).

        Returns:
            list: The deserialized records (empty for an empty file).
        """
        records = []
        with open(csv_path, 'r', encoding='utf-8', newline='') as f:
            reader = csv.reader(f)
            # Skip the header; a completely empty file simply yields no records
            # instead of raising StopIteration.
            if next(reader, None) is None:
                return records

            # islice(..., None) reads everything; otherwise it stops early
            # without reading the rest of the file.
            for row in itertools.islice(reader, max_samples):
                records.append(deserialize_fn(row))

        return records

    records = _create_records_from_csv(
        csv_path, 
        _deserialize_csv_record, 
        max_samples
    )
    
    features = _create_features_from_records(
        records, 
        max_sequence_length, 
        tokenizer
    )
    
    logger.info(f"Chargé {len(features)} échantillons depuis {csv_path}")
    
    return WSDDataset(features)

In [7]:
def _create_features_from_records(records, max_seq_length, tokenizer):
    """Convertit les enregistrements en features pour BERT"""
    features = []
    for record in tqdm(records, desc="Conversion des données"):
        tokens_a = tokenizer.tokenize(record.sentence)
        sequences = [(gloss, 1 if i in record.targets else 0) for i, gloss in enumerate(record.glosses)]

        pairs = []
        for seq, label in sequences:
            tokens_b = tokenizer.tokenize(seq)
            _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)

            tokens = tokens_a + ['[SEP]']
            segment_ids = [0] * len(tokens)

            tokens += tokens_b + ['[SEP]']
            segment_ids += [1] * (len(tokens_b) + 1)

            tokens = ['[CLS]'] + tokens
            segment_ids = [0] + segment_ids

            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            input_mask = [1] * len(input_ids)

            padding_length = max_seq_length - len(input_ids)
            input_ids += [0] * padding_length
            input_mask += [0] * padding_length
            segment_ids += [0] * padding_length

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            pairs.append(
                BertInput(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_id=label)
            )

        features.append(pairs)

    return features

def collate_batch(batch):
    """Turn a list of instances (each a list of pair features) into tensors.

    Returns one ``[input_ids, input_mask, segment_ids, labels]`` list per instance.
    The first three are ``(n_pairs, seq_len)`` long tensors and ``labels`` is
    ``(n_pairs,)``. ``seq_len`` is read from the first pair of the first instance.
    """
    seq_len = len(batch[0][0].input_ids)

    def _collate_instance(pairs):
        # Allocate zero-filled tensors, then copy each pair's values row by row.
        n_pairs = len(pairs)
        ids, mask, segments = (
            torch.zeros([n_pairs, seq_len], dtype=torch.long) for _ in range(3)
        )
        labels = torch.zeros([n_pairs], dtype=torch.long)
        for row, pair in enumerate(pairs):
            ids[row] = torch.tensor(pair.input_ids, dtype=torch.long)
            mask[row] = torch.tensor(pair.input_mask, dtype=torch.long)
            segments[row] = torch.tensor(pair.segment_ids, dtype=torch.long)
            labels[row] = torch.tensor(pair.label_id, dtype=torch.long)
        return [ids, mask, segments, labels]

    return [_collate_instance(pairs) for pairs in batch]

def forward_gloss_selection(model, batches, device):
    """Score every candidate gloss of each instance and compute the ranking loss.

    For each instance, the pooled BERT output of every (sentence, gloss) pair
    goes through dropout and ``ranking_linear`` to get one score per gloss.
    Cross-entropy is then taken over those scores. The gold class is the first
    gloss whose label is maximal.

    Returns:
        (mean loss over instances, list of per-instance 1-D score tensors)
    """
    criterion = torch.nn.CrossEntropyLoss()
    instance_losses = []
    scores = []

    for instance in batches:
        moved = [t.to(device) for t in instance]
        input_ids, input_mask, segment_ids, label_ids = moved[0], moved[1], moved[2], moved[3]

        pooled = model.bert(input_ids=input_ids, attention_mask=input_mask, token_type_ids=segment_ids)[1]
        instance_scores = model.ranking_linear(model.dropout(pooled)).squeeze(-1)
        gold = label_ids.max(dim=-1).indices.detach()

        # Treat the glosses of one instance as the classes of a single example.
        instance_losses.append(criterion(instance_scores.unsqueeze(dim=0), gold.unsqueeze(dim=-1)))
        scores.append(instance_scores)

    return sum(instance_losses) / len(batches), scores

In [8]:
def _top1_hits(logits_list, batches):
    """Count instances whose highest-scoring gloss is labelled correct.

    Args:
        logits_list: one 1-D score tensor per instance (``forward_gloss_selection`` output).
        batches: the matching collated instances (``collate_batch`` output);
            element ``[3]`` of each holds the 0/1 gloss labels.

    Returns:
        tuple(int, int): (correct top-1 predictions, number of instances).
        Their ratio is the precision of one-hot argmax predictions.
    """
    hits = 0
    for instance_logits, instance in zip(logits_list, batches):
        best = int(instance_logits.detach().argmax().item())
        hits += int(instance[3][best].item() == 1)
    return hits, len(logits_list)


def _evaluate_gloss_selection(model, dataloader, device):
    """Return (average loss, top-1 precision) of ``model`` on ``dataloader``.

    Runs without gradients with the model in eval mode (dropout disabled), then
    switches the model back to train mode.
    """
    model.eval()
    total_loss = 0.0
    hits = 0
    seen = 0
    with torch.no_grad():
        for batches in tqdm(dataloader, desc="Évaluation", unit="batch"):
            loss, logits_list = forward_gloss_selection(model, batches, device)
            total_loss += loss.item()
            batch_hits, batch_seen = _top1_hits(logits_list, batches)
            hits += batch_hits
            seen += batch_seen
    model.train()
    avg_loss = total_loss / max(len(dataloader), 1)
    precision = hits / seen if seen else 0.0
    return avg_loss, precision


def train_wsd(train_path, eval_path, output_dir='./results', max_train_samples=None, max_eval_samples=None):
    """
    Train the BERT gloss-selection (WSD) model, optionally on a subsample.

    Precision is the top-1 precision: for each instance, the highest-scoring
    gloss is predicted. This matches the softmax-over-glosses loss used in
    training. It is updated incrementally per step and reported per epoch.

    Args:
        train_path (str): Path to the training CSV.
        eval_path (str): Path to the evaluation CSV. If truthy, it is loaded
            (capped at ``max_eval_samples`` rows when given) and scored once after training.
        output_dir (str, optional): Directory where the model and tokenizer are saved.
        max_train_samples (int, optional): Maximum number of training rows (None = all).
        max_eval_samples (int, optional): Maximum number of evaluation rows (None = all).
    """
    # Hyper-parameters
    max_seq_length = 128
    batch_size = 8
    num_train_epochs = 3
    learning_rate = 5e-5
    seed = 42

    # Seed every RNG in use for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Utilisation du périphérique : {device}")

    # Model and tokenizer
    model_name = 'bert-base-cased'
    config = BertConfig.from_pretrained(model_name, num_labels=2)
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertWSD.from_pretrained(model_name, config=config)

    # Register the [TGT] marker token and grow the embedding matrix to match
    if '[TGT]' not in tokenizer.additional_special_tokens:
        tokenizer.add_special_tokens({'additional_special_tokens': ['[TGT]']})
        model.resize_token_embeddings(len(tokenizer))

    model.to(device)

    train_dataset = load_dataset(
        train_path,
        tokenizer,
        max_sequence_length=max_seq_length,
        max_samples=max_train_samples
    )
    train_dataloader = DataLoader(
        train_dataset,
        sampler=RandomSampler(train_dataset),
        batch_size=batch_size,
        collate_fn=collate_batch
    )

    eval_dataloader = None
    if eval_path:
        eval_dataset = load_dataset(
            eval_path,
            tokenizer,
            max_sequence_length=max_seq_length,
            max_samples=max_eval_samples
        )
        eval_dataloader = DataLoader(
            eval_dataset,
            sampler=SequentialSampler(eval_dataset),
            batch_size=batch_size,
            collate_fn=collate_batch
        )

    # No weight decay on biases and LayerNorm weights (standard BERT fine-tuning)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.01
        },
        {
            'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }
    ]

    total_steps = len(train_dataloader) * num_train_epochs
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )

    logger.info("🚀 Entrainement Model Word Sense Disambiguation ")
    logger.info(f"Nombre total de lots de formation (Batches): {len(train_dataloader)}")
    logger.info(f"Device: {device}")

    for epoch in range(num_train_epochs):
        model.train()
        total_loss = 0.0
        epoch_hits = 0
        epoch_seen = 0
        start_time = time.time()

        with tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_train_epochs}",
                  unit="batch", colour="green") as epoch_iterator:
            for step, batches in enumerate(epoch_iterator):
                loss, logits_list = forward_gloss_selection(model, batches, device)

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

                total_loss += loss.item()

                # Running top-1 precision, updated in O(batch) instead of
                # recomputing over every prediction seen so far.
                batch_hits, batch_seen = _top1_hits(logits_list, batches)
                epoch_hits += batch_hits
                epoch_seen += batch_seen
                precision = epoch_hits / epoch_seen if epoch_seen else 0.0

                epoch_iterator.set_postfix({
                    'Loss': f'{loss.item():.4f}',
                    'Avg Loss': f'{total_loss/(step+1):.4f}',
                    'Precision': f'{precision:.4f}',
                    'Learning Rate': f'{scheduler.get_last_lr()[0]:.6f}'
                })

                if step % 100 == 0:
                    logger.info(f"Époque {epoch}, Étape {step}, Perte : {loss.item()}")

        # End-of-epoch summary: logged once per epoch, after the step loop
        epoch_duration = time.time() - start_time
        final_precision = epoch_hits / epoch_seen if epoch_seen else 0.0
        logger.info(f"Epoch {epoch+1} completed in {epoch_duration:.2f} seconds. "
                    f"Average Loss: {total_loss/max(len(train_dataloader), 1):.4f}, "
                    f"Precision: {final_precision:.4f}")

    logger.info("✅ Entrainement terminé avec succès")

    if eval_dataloader is not None:
        eval_loss, eval_precision = _evaluate_gloss_selection(model, eval_dataloader, device)
        logger.info(f"Évaluation - Perte moyenne : {eval_loss:.4f}, Précision top-1 : {eval_precision:.4f}")

    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    logger.info(f"Modèle entraîné et sauvegardé dans {output_dir}")

In [None]:
# Quick run on a subsample of the corpus (first rows of each CSV).
TRAIN_CSV = '/kaggle/input/dataset/corpus_dir-max_num_gloss5-augmented.csv'
EVAL_CSV = '/kaggle/input/dataset/semeval2007-max_num_gloss5-augmented.csv'
MAX_TRAIN_SAMPLES = 50000  # author's note: 40 000 training samples took about 4 hours
MAX_EVAL_SAMPLES = 10000   # author's note: "8000 == 20%" (presumably 20% of 40 000 train samples)

train_wsd(
    train_path=TRAIN_CSV,
    eval_path=EVAL_CSV,
    max_train_samples=MAX_TRAIN_SAMPLES,
    max_eval_samples=MAX_EVAL_SAMPLES,
)

# To use every row, pass max_train_samples=None and max_eval_samples=None (the defaults),
# e.g. train_wsd(train_path='/path/to/train.csv', eval_path='/path/to/eval.csv').