# ✨ Text Generation

Générer du texte : Greedy, Top-k, Top-p, Temperature

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Seed NumPy's global RNG so the sampling demos below give repeatable results.
np.random.seed(42)

## Stratégies de Sampling

### 1. Greedy Decoding
Toujours choisir le token le plus probable.

In [None]:
def greedy_sample(logits):
    """
    Greedy decoding: pick the single highest-scoring token.

    Args:
        logits: (vocab_size,) unnormalized scores, one per token.

    Returns:
        token_id: integer index of the largest logit (the first one
            when several logits tie for the maximum).
    """
    scores = np.asarray(logits)
    return scores.argmax()

# Quick check: index 1 holds the largest logit, so greedy decoding picks it.
logits = np.array([1.0, 3.5, 2.0, 0.5])
token = greedy_sample(logits)
print(f"Logits: {logits}", f"Token choisi (greedy): {token}", sep="\n")

### 2. Temperature Sampling
Contrôler la "créativité" du modèle.

In [None]:
def temperature_sample(logits, temperature=1.0):
    """
    Sample a token from the temperature-scaled softmax of the logits.

    Args:
        logits: (vocab_size,) unnormalized scores; any array-like of numbers.
        temperature: float, must be > 0
            - T < 1.0: more conservative (peaked distribution)
            - T = 1.0: plain softmax of the logits
            - T > 1.0: more creative (flatter distribution)

    Returns:
        token_id: int

    Raises:
        ValueError: if temperature is not strictly positive.
    """
    # T == 0 would divide by zero (NaN probabilities) and T < 0 would
    # silently invert the ranking, so reject both up front. Written as
    # `not > 0` so that NaN is rejected as well.
    if not temperature > 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    # Accept lists and integer arrays as well as float arrays.
    scaled = np.asarray(logits, dtype=float) / temperature

    # Softmax; subtracting the max keeps exp() from overflowing.
    exp_logits = np.exp(scaled - np.max(scaled))
    probs = exp_logits / np.sum(exp_logits)

    # Draw one index according to probs (uses NumPy's global RNG).
    token = np.random.choice(len(probs), p=probs)
    return token

# Compare the empirical token frequencies at several temperatures.
logits = np.array([1.0, 3.5, 2.0, 0.5])
n_samples = 100

print("Échantillonnage avec différentes températures:")
for temp in [0.5, 1.0, 2.0]:
    samples = [temperature_sample(logits, temp) for _ in range(n_samples)]
    # Tally every token in one pass; minlength keeps never-drawn tokens at 0
    # and ties the number of rows to the vocabulary size instead of a literal 4.
    counts = np.bincount(samples, minlength=len(logits))
    print(f"\nT={temp}:")
    for i, count in enumerate(counts):
        # Report a true percentage so the label stays correct if n_samples changes.
        print(f"  Token {i}: {100 * count / n_samples:.0f}%")

### 3. Top-k Sampling
Ne considérer que les k tokens les plus probables.

In [None]:
def top_k_sample(logits, k=10, temperature=1.0):
    """
    Top-k sampling: sample from the softmax restricted to the k largest logits.

    Args:
        logits: (vocab_size,) unnormalized scores; any array-like of numbers.
        k: int - number of top tokens to keep (>= 1). Values larger than
            the vocabulary size keep every token.
        temperature: float, must be > 0

    Returns:
        token_id: int

    Raises:
        ValueError: if k < 1 or temperature is not strictly positive.
    """
    # k == 0 would make the slice [-0:] select the WHOLE vocabulary, and a
    # negative k would drop the best tokens instead of keeping them.
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    # T == 0 divides by zero; T < 0 inverts the ranking. `not > 0` also
    # rejects NaN.
    if not temperature > 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    logits = np.asarray(logits, dtype=float)
    k = min(k, logits.size)

    # argsort is ascending, so the last k entries are the k largest logits.
    top_k_indices = np.argsort(logits)[-k:]
    top_k_logits = logits[top_k_indices] / temperature

    # Softmax over the kept logits only; subtracting the max avoids overflow.
    exp_logits = np.exp(top_k_logits - np.max(top_k_logits))
    probs = exp_logits / np.sum(exp_logits)

    # Sample a position within the top-k set, then map it back to a token id.
    sampled_index = np.random.choice(len(probs), p=probs)
    token = top_k_indices[sampled_index]

    return token

# Demo: draw one token from 100 random logits and report where it ranks.
logits = np.random.randn(100)
token = top_k_sample(logits, k=10)
print(f"Token choisi (top-k): {token}")

# Rank 1 is the largest logit; the drawn token should land in the top 10.
descending_ids = np.argsort(logits)[::-1].tolist()
rank = descending_ids.index(token) + 1
print(f"Rang du token: {rank}")

### 4. Top-p (Nucleus) Sampling
Considérer les tokens jusqu'à ce que la probabilité cumulée atteigne p.

In [None]:
def top_p_sample(logits, p=0.9, temperature=1.0):
    """
    Top-p (nucleus) sampling.

    Samples from the smallest set of most-probable tokens whose cumulative
    probability reaches p, after renormalizing that set.

    Args:
        logits: (vocab_size,) unnormalized scores; any array-like of numbers.
        p: float - target cumulative probability (0 < p <= 1)
        temperature: float, must be > 0

    Returns:
        token_id: int

    Raises:
        ValueError: if p is outside (0, 1] or temperature is not
            strictly positive.
    """
    if not 0 < p <= 1:
        raise ValueError(f"p must be in (0, 1], got {p}")
    # T == 0 divides by zero; T < 0 inverts the ranking. `not > 0` also
    # rejects NaN.
    if not temperature > 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    # Temperature-scaled softmax; subtracting the max avoids overflow.
    scaled = np.asarray(logits, dtype=float) / temperature
    exp_logits = np.exp(scaled - np.max(scaled))
    probs = exp_logits / np.sum(exp_logits)

    # Order tokens from most to least probable.
    sorted_indices = np.argsort(probs)[::-1]
    sorted_probs = probs[sorted_indices]
    cumsum_probs = np.cumsum(sorted_probs)

    # searchsorted gives the first position whose cumulative mass is >= p;
    # +1 turns it into a slice length so that token is included. Rounding can
    # leave the final cumulative sum just below 1.0, which would push the
    # result one past the end for p == 1, hence the clamp.
    cutoff_index = min(np.searchsorted(cumsum_probs, p) + 1, len(sorted_probs))

    # Keep only the nucleus and renormalize it to sum to 1.
    nucleus_indices = sorted_indices[:cutoff_index]
    nucleus_probs = sorted_probs[:cutoff_index]
    nucleus_probs = nucleus_probs / np.sum(nucleus_probs)

    # Sample a position within the nucleus, then map it back to a token id.
    sampled_index = np.random.choice(len(nucleus_probs), p=nucleus_probs)
    token = nucleus_indices[sampled_index]

    return token

# Demo: nucleus sampling with p=0.9 over 100 random logits.
logits = np.random.randn(100)
token = top_p_sample(logits=logits, p=0.9)
print(f"Token choisi (top-p): {token}")

## Génération de Texte Complète

In [None]:
def generate_text(model, prompt, max_length=100, strategy='top_p', **kwargs):
    """
    Generate text autoregressively, one sampled token at a time.

    Args:
        model: GPT model; its forward() is called with np.array([context])
            and the result is indexed as [0, -1, :] to get the scores for
            the next token.
        prompt: str - starting text, turned into tokens by encode()
        max_length: int - maximum number of tokens to generate
        strategy: str - 'greedy', 'temperature', 'top_k', or 'top_p'
        **kwargs: forwarded to the chosen sampling function (ignored for
            'greedy')

    Returns:
        generated_text: str - decode() of the prompt tokens plus every
            generated token

    Raises:
        ValueError: if strategy is not one of the four names above (raised
            on the first generation step).
    """
    def _pick(scores):
        # Route the next-token scores to the sampler named by `strategy`.
        if strategy == 'greedy':
            return greedy_sample(scores)
        if strategy == 'temperature':
            return temperature_sample(scores, **kwargs)
        if strategy == 'top_k':
            return top_k_sample(scores, **kwargs)
        if strategy == 'top_p':
            return top_p_sample(scores, **kwargs)
        raise ValueError(f"Unknown strategy: {strategy}")

    # Token sequence so far; starts as the encoded prompt and grows in place.
    tokens = encode(prompt)

    for _step in range(max_length):
        # Run the model on the whole sequence and keep only the scores
        # at the last position.
        all_logits = model.forward(np.array([tokens]))
        chosen = _pick(all_logits[0, -1, :])
        tokens.append(chosen)

        # Stop early once a sentence-ending punctuation token is produced.
        if decode([chosen]) in ('.', '!', '?'):
            break

    return decode(tokens)

# Example usage (pseudo-code)
# generated = generate_text(model, "To be or not to be", 
#                          max_length=50, 
#                          strategy='top_p', 
#                          p=0.9, 
#                          temperature=0.8)
# print(generated)

## Comparaison des Stratégies

In [None]:
# Visualize how temperature reshapes the softmax distribution.
logits = np.array([1.0, 3.5, 2.0, 0.5, 1.5])
temperatures = [0.5, 1.0, 2.0]

# One panel per temperature. squeeze=False plus ravel() keeps `axes` a 1-D
# array for any number of temperatures, including a single one.
fig, axes = plt.subplots(
    1,
    len(temperatures),
    figsize=(5 * len(temperatures), 4),
    squeeze=False,
)
axes = axes.ravel()

for ax, temp in zip(axes, temperatures):
    # Temperature-scaled softmax; subtracting the max avoids overflow.
    scaled_logits = logits / temp
    exp_logits = np.exp(scaled_logits - np.max(scaled_logits))
    probs = exp_logits / np.sum(exp_logits)

    ax.bar(range(len(probs)), probs)
    ax.set_title(f'Temperature = {temp}')
    ax.set_xlabel('Token ID')
    ax.set_ylabel('Probabilité')
    # Same y-range on every panel so the bar heights are directly comparable.
    ax.set_ylim([0, 1])

plt.tight_layout()
plt.show()

print("\n🔥 Température basse (0.5): Plus conservateur, peaked")
print("😐 Température normale (1.0): Distribution originale")
print("🎲 Température haute (2.0): Plus créatif, flat")