# TP : LSTM pour Génération de Texte

Pipeline : Tokenisation → Vocabulaire → Dataset → Entraînement → Génération → Évaluation

## Imports et Configuration

In [10]:
import math
import json
import re
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

Device: cpu


## Tokenisation

In [11]:
def clean_text(text):
    if not isinstance(text, str):
        return ""
    text = text.replace('@-@', '-')
    text = re.sub(r'=+\s*.*?\s*=+', '', text)
    text = re.sub(r'http[s]?://\S+|www\.\S+|\S+@\S+', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def tokenize(text):
    text = clean_text(text)
    punctuations = ['.', ',', ';', ':', '!', '?', ')', '(', '[', ']', '{', '}']
    all_words = []
    
    for word in text.split():
        word = word.lower()
        while word and word[-1] in punctuations:
            if word[:-1]:
                all_words.append(word[:-1])
            word = word[-1]
            if word in punctuations:
                all_words.append(word)
                word = ''
        if word:
            all_words.append(word)
    
    return [w for w in all_words if w]

## 3. Vocabulaire

In [12]:
class Vocabulary:
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.vocab_size = 0
        
    def build_vocab(self, texts, min_freq=2):
        word_freq = {}
        for text in texts:
            for word in tokenize(text):
                word_freq[word] = word_freq.get(word, 0) + 1
        
        self.all_words = sorted([w for w, freq in word_freq.items() if freq >= min_freq])
        
        # Tokens spéciaux
        for idx, token in enumerate(['<pad>', '<unk>', '<bos>', '<eos>']):
            self.word2idx[token] = idx
            self.idx2word[idx] = token
        
        for i, word in enumerate(self.all_words, start=4):
            self.word2idx[word] = i
            self.idx2word[i] = word

        self.vocab_size = len(self.all_words) + 4
        print(f"Vocab: {self.vocab_size} mots (min_freq={min_freq})")

    def encode(self, word):
        return self.word2idx.get(word, self.word2idx['<unk>'])
    
    def decode(self, idx):
        return self.idx2word.get(idx, '<unk>')

## Dataset et Modèle

In [13]:
class TextDataset(Dataset):
    def __init__(self, sentences, vocab, seq_len):
        self.vocab = vocab
        self.seq_len = seq_len
        self.pairs = []

        for sentence in sentences:
            words = ['<bos>'] + tokenize(sentence) + ['<eos>']
            encoded_words = [vocab.encode(word) for word in words]
            for i in range(len(encoded_words) - seq_len):
                self.pairs.append((encoded_words[i:i+seq_len], encoded_words[i+seq_len]))
                
    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, idx):
        input_seq, target = self.pairs[idx]
        return torch.tensor(input_seq), torch.tensor(target)

In [14]:
class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout=0.3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.dropout_layer = nn.Dropout(p=dropout)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True, 
                           dropout=dropout if num_layers > 1 else 0)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        x = self.dropout_layer(self.embed(x))
        output, _ = self.lstm(x)
        return self.linear(self.dropout_layer(output[:, -1, :]))

## Entraînement

In [15]:
def train(model, dataset, epochs=10, batch_size=32, lr=0.001):
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5)
    
    best_loss = float('inf')
    patience_counter = 0
    early_stop_patience = 7
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for inputs, targets in dataloader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()
        
        avg_loss = total_loss / len(dataloader)
        perp = math.exp(min(avg_loss, 10))  # Cap pour éviter overflow
        
        print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f} - Perplexité: {perp:.2f}")
        
        scheduler.step(avg_loss)
        
        if avg_loss < best_loss:
            best_loss = avg_loss
            best_epoch = epoch + 1
            patience_counter = 0
        else:
            patience_counter += 1
            
        if patience_counter >= early_stop_patience:
            print(f"Early stopping (epoch {epoch+1})")
            break
    
    final_perp = math.exp(min(best_loss, 10))
    print(f"\nMeilleur loss: {best_loss:.4f} - Perplexité: {final_perp:.2f} (epoch {best_epoch})")

## Génération

In [16]:
def generate_words(model, vocab, prompt, max_length=100, mode='sampling', temp=0.7, top_k=50, seq_len=20):
    model.eval()
    prompt_tokens = tokenize(prompt)
    tokens = [vocab.encode('<bos>')] + [vocab.encode(word) for word in prompt_tokens]
    prompt_length = len(tokens)
    unk_id = vocab.encode('<unk>')
    
    with torch.no_grad():
        for _ in range(max_length):
            if len(tokens) >= seq_len:
                recent = tokens[-seq_len:]
            else:
                recent = [vocab.encode('<pad>')] * (seq_len - len(tokens)) + tokens
                
            logits = model(torch.tensor([recent]))
            logits[0, unk_id] -= 10.0
        
            if mode == 'greedy':
                next_token = torch.argmax(logits, dim=1).item()
            else:  # sampling
                if top_k > 0:
                    top_vals, top_ids = torch.topk(logits, top_k, dim=1)
                    probs = torch.softmax(top_vals / temp, dim=1)
                    next_token = top_ids[0, torch.multinomial(probs, 1).item()].item()
                else:
                    probs = torch.softmax(logits / temp, dim=1)
                    next_token = torch.multinomial(probs, 1).item()
        
            tokens.append(next_token)
            if next_token == vocab.encode('<eos>'):
                break
    
    return [vocab.decode(t) for t in tokens[prompt_length:]]

def clean_generated_text(words):
    return ' '.join([w for w in words if w not in ['<bos>', '<eos>', '<pad>', '<unk>']])

## Chargement et Configuration

In [17]:
def load_wiki_tokens(file_path, num_samples=None):
    sentences = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line and len(line) > 20:
                sentences.append(line)
            if num_samples and len(sentences) >= num_samples:
                break
    print(f"Chargé: {len(sentences)} phrases")
    return sentences

sentences = load_wiki_tokens('wiki.train.tokens', num_samples=None)

# Configuration optimisée pour perplexité < 80
config = {
    'embed_dim': 400,          # Augmenté de 256 à 400
    'hidden_dim': 800,         # Augmenté de 512 à 800
    'num_layers': 3,
    'epochs': 40,              # Augmenté de 25 à 40
    'seq_len': 35,             # Augmenté de 20 à 35
    'batch_size': 128,         # Augmenté de 64 à 128
    'min_word_freq': 5,        # Augmenté de 2 à 5 pour réduire vocab
    'max_gen_length': 150,
    'temperature': 0.7,
    'top_k': 50,
    'lr': 0.001,
    'dropout': 0.5             # Augmenté de 0.3 à 0.5
}

print(f"Configuration optimisée pour perplexité < 80:")
print(f"  Vocab min_freq: {config['min_word_freq']} (vocabulaire réduit)")
print(f"  Embed: {config['embed_dim']}, Hidden: {config['hidden_dim']}, Layers: {config['num_layers']}")
print(f"  Seq_len: {config['seq_len']}, Batch: {config['batch_size']}, Epochs: {config['epochs']}")

Chargé: 20741 phrases
Configuration optimisée pour perplexité < 80:
  Vocab min_freq: 5 (vocabulaire réduit)
  Embed: 400, Hidden: 800, Layers: 3
  Seq_len: 35, Batch: 128, Epochs: 40


## Construction et Entraînement

In [18]:
vocab = Vocabulary()
vocab.build_vocab(sentences, min_freq=config['min_word_freq'])

dataset = TextDataset(sentences, vocab, seq_len=config['seq_len'])
print(f"Dataset: {len(dataset)} paires")

model = LSTMModel(vocab.vocab_size, config['embed_dim'], config['hidden_dim'], 
                config['num_layers'], dropout=config['dropout'])
print(f"Modèle: {sum(p.numel() for p in model.parameters()):,} paramètres\n")

train(model, dataset, epochs=config['epochs'], batch_size=config['batch_size'], lr=config['lr'])

Vocab: 20397 mots (min_freq=5)
Dataset: 1491998 paires
Modèle: 38,595,997 paramètres

Epoch 1/40 - Loss: 6.4792 - Perplexité: 651.42
Epoch 2/40 - Loss: 5.9009 - Perplexité: 365.36
Epoch 3/40 - Loss: 5.6555 - Perplexité: 285.85
Epoch 4/40 - Loss: 5.4960 - Perplexité: 243.72
Epoch 5/40 - Loss: 5.3817 - Perplexité: 217.40
Epoch 6/40 - Loss: 5.2910 - Perplexité: 198.54
Epoch 7/40 - Loss: 5.2156 - Perplexité: 184.13
Epoch 8/40 - Loss: 5.1509 - Perplexité: 172.58
Epoch 9/40 - Loss: 5.0951 - Perplexité: 163.22
Epoch 10/40 - Loss: 5.0448 - Perplexité: 155.21
Epoch 11/40 - Loss: 4.9991 - Perplexité: 148.29
Epoch 12/40 - Loss: 4.9582 - Perplexité: 142.34
Epoch 13/40 - Loss: 4.9208 - Perplexité: 137.11
Epoch 14/40 - Loss: 4.8873 - Perplexité: 132.59
Epoch 15/40 - Loss: 4.8579 - Perplexité: 128.75
Epoch 16/40 - Loss: 4.8299 - Perplexité: 125.20
Epoch 17/40 - Loss: 4.8061 - Perplexité: 122.25
Epoch 18/40 - Loss: 4.7865 - Perplexité: 119.88
Epoch 19/40 - Loss: 4.7654 - Perplexité: 117.38
Epoch 20/40

## Test de Génération

In [19]:
for prompt in ["The government announced", "In recent years", "According to reports"]:
    print(f"\nPrompt: '{prompt}'")
    
    gen = generate_words(model, vocab, prompt, max_length=config['max_gen_length'], 
                        mode='sampling', temp=config['temperature'], 
                        top_k=config['top_k'], seq_len=config['seq_len'])
    print(f"Génération: {clean_generated_text(gen)[:200]}...")


Prompt: 'The government announced'
Génération: that the decision to be a " more than one " of the a game . the team 's first international team , which included a number of different players , was at a time in the development of his final appearan...

Prompt: 'In recent years'
Génération: ....

Prompt: 'According to reports'
Génération: of the game . the game was originally created on the next two - game games ....


## Évaluation sur Datasets

In [20]:
def load_eval_prompts(file_path, num_samples=50):
    prompts = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                if i >= num_samples:
                    break
                try:
                    data = json.loads(line)
                    item = data[0] if isinstance(data, list) else data
                    prompts.append({'prompt': item.get('prompt', ''), 'gold_ref': item.get('gold_ref', '')})
                except:
                    continue
    except FileNotFoundError:
        print(f"Fichier non trouvé: {file_path}")
    return prompts

EVAL_DATASETS = {
    'wikinews': 'wikinews_typical-0.95_gpt2-xl_256.jsonl',
    'wikitext': 'wikitext_typical-0.95_gpt2-xl_256.jsonl',
    'book': 'book_typical-0.95_gpt2-xl_256.jsonl'
}

for dataset_name, fpath in EVAL_DATASETS.items():
    prompts = load_eval_prompts(fpath, num_samples=50)
    print(f"\n[{dataset_name}] {len(prompts)} prompts")
    
    results = []
    for idx, item in enumerate(prompts):
        generated_results = {}
        for pred_idx in range(5):
            gen = generate_words(model, vocab, item['prompt'], 
                               max_length=config['max_gen_length'], 
                               mode='sampling', temp=config['temperature'],
                               top_k=config['top_k'], seq_len=config['seq_len'])
            generated_results[str(pred_idx)] = ' '.join(gen)
        
        results.append({
            'prefix_text': item['prompt'],
            'reference_text': item['gold_ref'],
            'generated_result': generated_results
        })
        
        if (idx + 1) % 10 == 0:
            print(f"  {idx + 1}/{len(prompts)}")
    
    output_file = f'{dataset_name}_lstm_predictions2.json'
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"  Sauvegardé: {output_file}")

print("\nFichiers prêts pour évaluation")


[wikinews] 50 prompts
  10/50
  20/50
  30/50
  40/50
  50/50
  Sauvegardé: wikinews_lstm_predictions2.json

[wikitext] 50 prompts
  10/50
  20/50
  30/50
  40/50
  50/50
  Sauvegardé: wikitext_lstm_predictions2.json

[book] 50 prompts
  10/50
  20/50
  30/50
  40/50
  50/50
  Sauvegardé: book_lstm_predictions2.json

Fichiers prêts pour évaluation


In [22]:
# Sauvegarder uniquement les poids (state_dict)
torch.save(model.state_dict(), 'mon_modele_weights.pth')

# Charger les poids
model.load_state_dict(torch.load('mon_modele_weights.pth'))
model.eval()


LSTMModel(
  (embed): Embedding(20397, 400)
  (dropout_layer): Dropout(p=0.5, inplace=False)
  (lstm): LSTM(400, 800, num_layers=3, batch_first=True, dropout=0.5)
  (linear): Linear(in_features=800, out_features=20397, bias=True)
)

In [24]:
# Sauvegarder votre modèle déjà entraîné
torch.save({
    'model_state_dict': model.state_dict(),
    'vocab_size': vocab.vocab_size,
    'embedding_dim': 400,
    'hidden_dim': 800,
    'num_layers': 3,
    'dropout': 0.5
}, 'lstm_wikitext_model.pth')

print("Modèle sauvegardé avec succès!")


Modèle sauvegardé avec succès!


In [25]:
# Charger le modèle
checkpoint = torch.load('lstm_wikitext_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print("Modèle chargé avec succès!")


Modèle chargé avec succès!
