In [None]:
"""
## 12. 简单神经网络 (Simple Neural Network)

### 网络结构
- 输入层 -> 隐藏层（Sigmoid激活）-> 输出层（Sigmoid激活）
- 适用于二分类问题

### 前向传播
1. 隐藏层: Z1 = W1*X + b1, A1 = sigmoid(Z1)
2. 输出层: Z2 = W2*A1 + b2, A2 = sigmoid(Z2)

### 反向传播（链式法则）
1. 输出层梯度: dZ2 = A2 - y
2. 隐藏层梯度: dZ1 = (W2^T * dZ2) * sigmoid'(Z1)
3. 权重更新: W -= learning_rate * dW

### 应用场景
- XOR问题（需隐藏层）
- 二分类问题
- 神经网络基础学习

### 复杂度
- 训练: O(epochs * samples * hidden_size * (features + output_size))
- 预测: O(samples * hidden_size * (features + output_size))
"""

import numpy as np

class SimpleNN:
    """Two-layer fully connected network with sigmoid hidden and output units.

    Samples are stored column-wise: ``forward`` multiplies ``W1`` (shape
    ``(hidden_size, input_size)``) by ``X``, so ``X`` is laid out as
    ``(input_size, num_samples)`` and labels as ``(output_size, num_samples)``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Allocate weights and biases for both layers.

        Weights are standard-normal draws multiplied by a fixed 0.1 (a plain
        small-random initialisation, not Xavier/Glorot scaling); biases are 0.
        """
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        # W1 is drawn before W2, which fixes the result under a given seed.
        self.W1 = 0.1 * np.random.randn(hidden_size, input_size)
        self.b1 = np.zeros((hidden_size, 1))
        self.W2 = 0.1 * np.random.randn(output_size, hidden_size)
        self.b2 = np.zeros((output_size, 1))

    def sigmoid(self, z):
        """Return the logistic function of ``z``."""
        # Clipping keeps np.exp from overflowing on extreme inputs.
        bounded = np.clip(z, -500, 500)
        return 1 / (1 + np.exp(-bounded))

    def sigmoid_derivative(self, z):
        """Return sigma'(z) = sigma(z) * (1 - sigma(z))."""
        activation = self.sigmoid(z)
        return activation * (1 - activation)

    def forward(self, X):
        """Run a forward pass, caching Z1/A1/Z2/A2 on the instance.

        Returns:
            The output activations ``A2``.
        """
        # Hidden layer.
        self.Z1 = self.W1.dot(X) + self.b1
        self.A1 = self.sigmoid(self.Z1)

        # Output layer.
        self.Z2 = self.W2.dot(self.A1) + self.b2
        self.A2 = self.sigmoid(self.Z2)
        return self.A2

    def backward(self, X, y):
        """Compute gradients from the activations cached by ``forward``.

        Stores ``dW1``, ``db1``, ``dW2`` and ``db2`` on the instance; each is
        averaged over the ``X.shape[1]`` samples.
        """
        scale = 1 / X.shape[1]

        # For sigmoid output + binary cross-entropy the output delta
        # simplifies to (prediction - label).
        delta_out = self.A2 - y
        self.dW2 = scale * np.dot(delta_out, self.A1.T)
        self.db2 = scale * np.sum(delta_out, axis=1, keepdims=True)

        # Chain rule back through the hidden sigmoid.
        delta_hidden = np.dot(self.W2.T, delta_out) * self.sigmoid_derivative(self.Z1)
        self.dW1 = scale * np.dot(delta_hidden, X.T)
        self.db1 = scale * np.sum(delta_hidden, axis=1, keepdims=True)

    def update_parameters(self, learning_rate):
        """Apply one plain gradient-descent step using the stored gradients."""
        self.W1 -= learning_rate * self.dW1
        self.b1 -= learning_rate * self.db1
        self.W2 -= learning_rate * self.dW2
        self.b2 -= learning_rate * self.db2

    def train(self, X, y, epochs, learning_rate, print_cost=True):
        """Run full-batch gradient descent for ``epochs`` iterations.

        When ``print_cost`` is true the binary cross-entropy is printed every
        1000th epoch (starting with epoch 0).
        """
        for i in range(epochs):
            predictions = self.forward(X)

            # Binary cross-entropy; 1e-8 guards against log(0).
            positive_term = y * np.log(predictions + 1e-8)
            negative_term = (1 - y) * np.log(1 - predictions + 1e-8)
            cost = -np.mean(positive_term + negative_term)

            self.backward(X, y)
            self.update_parameters(learning_rate)

            if print_cost and i % 1000 == 0:
                print(f"Cost after epoch {i}: {cost:.6f}")

    def predict(self, X):
        """Return hard 0/1 predictions by thresholding the output at 0.5."""
        return (self.forward(X) > 0.5).astype(int)

# Demo: fit the network to XOR, which is not linearly separable.
if __name__ == "__main__":
    # One column per sample: (0,0), (0,1), (1,0), (1,1).
    X_train = np.array([
        [0, 0, 1, 1],
        [0, 1, 0, 1],
    ])
    y_train = np.array([[0, 1, 1, 0]])

    # Build and train the model.
    nn = SimpleNN(input_size=2, hidden_size=4, output_size=1)
    nn.train(X_train, y_train, epochs=10000, learning_rate=0.5)

    # Evaluate on the training set itself (there are only four points).
    predictions = nn.predict(X_train)
    accuracy = np.mean(predictions == y_train) * 100
    print(f"Accuracy: {accuracy:.2f}%")


In [None]:
"""
## 11. Transformer 注意力机制详解

### 五种注意力机制对比

#### 1. 多头注意力 (MHA - Multi-Head Attention)
- **特点**: 完整的自注意力，每个Query对应一个完整的K-V序列
- **参数**: num_heads = 8（标准）
- **计算**: Q = [B,L,D], K = [B,L,D], V = [B,L,D] -> Output = [B,L,D]
- **复杂度**: O(n²) 其中n是序列长度

#### 2. 多查询注意力 (MQA - Multi-Query Attention)
- **特点**: 所有Query共享一套K-V
- **参数**: 只有一组K和V（不分头）
- **优势**: 内存占用少，KV缓存小，适合生成任务
- **应用**: PaLM, Falcon等

#### 3. 分组查询注意力 (GQA - Grouped-Query Attention)
- **特点**: Query分组，每组共享一套K-V
- **参数**: num_kv_heads < num_q_heads（如n_heads=8, kv_heads=2）
- **优势**: 平衡MHA和MQA，介于两者之间
- **应用**: Gemini, Llama 3.1等

#### 4. 线性注意力 (Linear Attention)
- **特点**: 将Softmax改为其他激活函数（如ELU），实现O(n)复杂度
- **原理**: 利用矩阵乘法结合律，先计算 φ(K)^T·V，再左乘 φ(Q)，从而避免显式构造 n×n 注意力矩阵
- **应用**: 长序列处理

#### 5. Flash Attention
- **特点**: 优化的CUDA内存访问模式
- **优势**: 不改变算法，只改进内存利用，速度快3-4倍
- **核心**: 分块处理和在线Softmax
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MHA(nn.Module):
    """Multi-head attention: every head has its own Q, K and V projection."""

    def __init__(self, embed_dim, num_heads):
        """Create the projection layers.

        Args:
            embed_dim: Model width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
        """
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.w_q = nn.Linear(embed_dim, embed_dim)
        self.w_k = nn.Linear(embed_dim, embed_dim)
        self.w_v = nn.Linear(embed_dim, embed_dim)
        self.w_o = nn.Linear(embed_dim, embed_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, key, value, attn_mask=None):
        """Compute scaled dot-product attention.

        Args:
            query: ``[B, L_q, embed_dim]``.
            key: ``[B, L_k, embed_dim]``; ``L_k`` may differ from ``L_q``.
            value: ``[B, L_k, embed_dim]``.
            attn_mask: Optional additive mask that is summed onto the scores
                before the softmax. Accepted shapes are ``[L_q, L_k]``,
                ``[B, L_q, L_k]`` (broadcast over heads) or a 4-D tensor
                broadcastable to ``[B, H, L_q, L_k]``.

        Returns:
            Tensor of shape ``[B, L_q, embed_dim]``.
        """
        batch_size, q_len, _ = query.shape
        # Keys/values carry their own length so cross-attention works.
        kv_len = key.shape[1]

        # Project, then split into heads -> [B, H, L, D_h].
        q = self.w_q(query).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.w_k(key).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.w_v(value).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention scores -> [B, H, L_q, L_k].
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            if attn_mask.dim() == 3:
                # Per-batch mask: insert the head axis so it broadcasts over H.
                attn_mask = attn_mask.unsqueeze(1)
            scores = scores + attn_mask

        attn_weights = self.softmax(scores)
        context = torch.matmul(attn_weights, v)  # [B, H, L_q, D_h]

        # Merge heads back -> [B, L_q, embed_dim].
        context = context.transpose(1, 2).contiguous().view(batch_size, q_len, self.embed_dim)
        return self.w_o(context)

class MQA(nn.Module):
    """Multi-query attention: per-head queries share one key/value head."""

    def __init__(self, embed_dim, num_heads):
        """Create the projection layers.

        Args:
            embed_dim: Model width; must be divisible by ``num_heads``.
            num_heads: Number of query heads.
        """
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.w_q = nn.Linear(embed_dim, embed_dim)      # queries are still split per head
        self.w_k = nn.Linear(embed_dim, self.head_dim)  # a single shared key head
        self.w_v = nn.Linear(embed_dim, self.head_dim)  # a single shared value head
        self.w_o = nn.Linear(embed_dim, embed_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, key, value, attn_mask=None):
        """Compute attention with one K/V head broadcast across all query heads.

        Args:
            query: ``[B, L_q, embed_dim]``.
            key: ``[B, L_k, embed_dim]``.
            value: ``[B, L_k, embed_dim]``.
            attn_mask: Optional additive mask that is summed onto the scores
                before the softmax. Accepted shapes are ``[L_q, L_k]``,
                ``[B, L_q, L_k]`` (broadcast over heads) or a 4-D tensor
                broadcastable to ``[B, H, L_q, L_k]``.

        Returns:
            Tensor of shape ``[B, L_q, embed_dim]``.
        """
        batch_size, q_len, _ = query.shape

        # Queries are split into heads; K/V keep a singleton head axis that
        # matmul broadcasts against every query head.
        q = self.w_q(query).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.w_k(key).unsqueeze(1)    # [B, 1, L_k, D_h]
        v = self.w_v(value).unsqueeze(1)  # [B, 1, L_k, D_h]

        # Attention scores -> [B, H, L_q, L_k].
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            if attn_mask.dim() == 3:
                # Per-batch mask: insert the head axis so it broadcasts over H.
                attn_mask = attn_mask.unsqueeze(1)
            scores = scores + attn_mask

        attn_weights = self.softmax(scores)
        context = torch.matmul(attn_weights, v)  # [B, H, L_q, D_h]
        context = context.transpose(1, 2).contiguous().view(batch_size, q_len, self.embed_dim)
        return self.w_o(context)

class GQA(nn.Module):
    """Grouped-query attention: groups of query heads share one K/V head."""

    def __init__(self, embed_dim, num_q_heads, num_kv_heads):
        """Create the projection layers.

        Args:
            embed_dim: Model width; must be divisible by ``num_q_heads``.
            num_q_heads: Number of query heads.
            num_kv_heads: Number of key/value heads; must divide ``num_q_heads``.
        """
        super().__init__()
        assert embed_dim % num_q_heads == 0, "embed_dim must be divisible by num_q_heads"
        assert num_q_heads % num_kv_heads == 0, "num_q_heads must be divisible by num_kv_heads"
        self.embed_dim = embed_dim
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.num_groups = num_q_heads // num_kv_heads  # query heads served by each K/V head
        self.head_dim = embed_dim // num_q_heads

        self.w_q = nn.Linear(embed_dim, embed_dim)
        self.w_k = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.w_v = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.w_o = nn.Linear(embed_dim, embed_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, key, value, attn_mask=None):
        """Compute grouped-query attention.

        Args:
            query: ``[B, L_q, embed_dim]``.
            key: ``[B, L_k, embed_dim]``; ``L_k`` may differ from ``L_q``.
            value: ``[B, L_k, embed_dim]``.
            attn_mask: Optional additive mask that is summed onto the scores
                before the softmax. Accepted shapes are ``[L_q, L_k]``,
                ``[B, L_q, L_k]`` (broadcast over heads) or a 4-D tensor
                broadcastable to ``[B, H_q, L_q, L_k]``.

        Returns:
            Tensor of shape ``[B, L_q, embed_dim]``.
        """
        batch_size, q_len, _ = query.shape
        # Keys/values carry their own length so cross-attention works.
        kv_len = key.shape[1]

        q = self.w_q(query).view(batch_size, q_len, self.num_q_heads, self.head_dim).transpose(1, 2)
        k = self.w_k(key).view(batch_size, kv_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.w_v(value).view(batch_size, kv_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Repeat each K/V head for the consecutive query heads of its group,
        # giving [B, H_q, L_k, D_h]; query head h uses K/V head h // num_groups.
        k = k.repeat_interleave(self.num_groups, dim=1)
        v = v.repeat_interleave(self.num_groups, dim=1)

        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            if attn_mask.dim() == 3:
                # Per-batch mask: insert the head axis so it broadcasts over H_q.
                attn_mask = attn_mask.unsqueeze(1)
            scores = scores + attn_mask

        attn_weights = self.softmax(scores)
        context = torch.matmul(attn_weights, v)  # [B, H_q, L_q, D_h]
        context = context.transpose(1, 2).contiguous().view(batch_size, q_len, self.embed_dim)
        return self.w_o(context)


In [None]:
"""
## 13. Transformer 位置编码 (Positional Encoding)

### 为什么需要位置编码？
- Transformer没有循环或卷积，无法捕捉序列顺序信息
- 需要显式地编码位置信息

### 两种位置编码方式对比

#### 1. 绝对位置编码 (APE - Absolute Positional Encoding)
- **方式**: input = word_embedding + PE（位置编码与词向量逐元素相加）
- **公式**:
  - PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
  - PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
- **特点**: 
  - 静态的位置编码
  - 加法融合
  - 每个位置都有一个固定的向量
- **应用**: 原始Transformer, BERT

#### 2. 旋转位置编码 (RoPE - Rotary Position Embedding)
- **方式**: 通过旋转矩阵作用于Q和K
- **优势**:
  - 相对位置信息编码
  - 距离衰减自然出现
  - 长序列外推能力强
- **应用**: GPT-NeoX, LLaMA, PaLM等模型

### 公式解释
- θ_i = 10000^(-2i/d_model): 不同维度的基频率
- m * θ_i: 位置m与基频率的乘积
- 使用sin/cos对进行编码：易于捕捉周期性模式
"""

import numpy as np

class TransformerPositionalEncodings:
    """Sinusoidal absolute encoding (APE) and rotary encoding (RoPE).

    Both encodings are derived from one precomputed table of angles
    ``m * theta_i`` with ``theta_i = 10000 ** (-2i / d_model)``.
    """

    def __init__(self, d_model, max_position=1024):
        """Precompute the angle table and its cosine/sine.

        Args:
            d_model: Model dimension; must be even.
            max_position: Number of positions to precompute.

        Raises:
            ValueError: If ``d_model`` is odd.
        """
        if d_model % 2 != 0:
            raise ValueError("d_model必须是偶数")
        self.d_model = d_model
        self.max_position = max_position

        # Angle table shared by APE and RoPE, plus its cos/sin.
        self.frequencies = self._precompute_frequencies()
        self.freqs_cos = np.cos(self.frequencies)
        self.freqs_sin = np.sin(self.frequencies)

    def _precompute_frequencies(self):
        """Build the angle matrix ``m * theta_i``.

        Returns:
            np.ndarray of shape ``[max_position, d_model / 2]``.
        """
        # theta_i = 1 / 10000 ** (2i / d_model)
        theta = 1.0 / (10000 ** (np.arange(0, self.d_model, 2, dtype=np.float32) / self.d_model))

        # Position indices m = 0 .. max_position - 1.
        position = np.arange(self.max_position, dtype=np.float32)

        # Outer product -> [max_position, d_model / 2].
        return np.outer(position, theta)

    def get_absolute_encoding(self):
        """Return the static sinusoidal table that is added to embeddings.

        Even feature indices hold ``sin(m * theta_i)`` and odd indices hold
        ``cos(m * theta_i)``.

        Returns:
            np.ndarray of shape ``[1, max_position, d_model]``.
        """
        pe = np.zeros((self.max_position, self.d_model))
        pe[:, 0::2] = self.freqs_sin
        pe[:, 1::2] = self.freqs_cos
        return pe[np.newaxis, :, :]

    def apply_rotary_encoding(self, x):
        """Rotate consecutive feature pairs of ``x`` by position-dependent angles.

        Features are grouped as ``(x_0, x_1), (x_2, x_3), ...`` and pair ``i``
        at position ``m`` is multiplied by the 2-D rotation matrix for angle
        ``m * theta_i``::

            [cos  -sin] [x_0]
            [sin   cos] [x_1]

        which equals ``x * cos + (-x_1, x_0, -x_3, x_2, ...) * sin``.

        Args:
            x: Array of shape ``[..., seq_len, d_model]``, for example
                ``[batch, seq_len, d_model]`` or
                ``[batch, heads, seq_len, d_model]``. The second-to-last axis
                is taken as the position axis.

        Returns:
            Array with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` has fewer than two axes, its last axis is not
                ``d_model`` wide, or ``seq_len`` exceeds ``max_position``.
        """
        x = np.asarray(x)
        if x.ndim < 2:
            raise ValueError("x must have shape [..., seq_len, d_model]")
        seq_len, width = x.shape[-2], x.shape[-1]
        if width != self.d_model:
            raise ValueError(
                f"last dimension of x is {width}, expected d_model={self.d_model}"
            )
        if seq_len > self.max_position:
            raise ValueError(
                f"seq_len={seq_len} exceeds max_position={self.max_position}"
            )

        # Slice to the needed positions first, then duplicate every angle so
        # features 2i and 2i+1 share it -> [seq_len, d_model]. These broadcast
        # over any leading axes of x.
        cos = np.repeat(self.freqs_cos[:seq_len], 2, axis=-1)
        sin = np.repeat(self.freqs_sin[:seq_len], 2, axis=-1)

        # Build (-x_1, x_0, -x_3, x_2, ...).
        pairs = x.reshape(x.shape[:-1] + (-1, 2))
        x_rotated = np.stack([-pairs[..., 1], pairs[..., 0]], axis=-1).reshape(x.shape)

        return x * cos + x_rotated * sin

# Side-by-side demo of the two encodings.
if __name__ == "__main__":
    d_model = 128
    batch_size = 2
    seq_len = 50

    pos_encoder = TransformerPositionalEncodings(d_model, max_position=1024)

    # APE: add the first seq_len rows of the static table to the embeddings.
    print("=== 绝对位置编码 (APE) ===")
    ape_matrix = pos_encoder.get_absolute_encoding()
    word_embeddings = np.random.rand(batch_size, seq_len, d_model)
    final_input_ape = word_embeddings + ape_matrix[:, :seq_len, :]
    print(f"APE矩阵形状: {ape_matrix.shape}")
    print(f"融合后形状: {final_input_ape.shape}")

    # RoPE: rotate queries and keys instead of adding to the embeddings.
    # The header starts with a real newline; the previous "\\n" printed the
    # two characters backslash + n.
    print("\n=== 旋转位置编码 (RoPE) ===")
    query = np.random.rand(batch_size, seq_len, d_model)
    key = np.random.rand(batch_size, seq_len, d_model)
    query_with_rope = pos_encoder.apply_rotary_encoding(query)
    key_with_rope = pos_encoder.apply_rotary_encoding(key)
    print(f"RoPE后Query形状: {query_with_rope.shape}")
    print(f"RoPE后Key形状: {key_with_rope.shape}")


In [None]:
"""
## 9. 字节对编码 (Byte Pair Encoding, BPE)

### 概述
BPE是一种数据压缩技术，经常用于NLP的分词。它通过迭代合并最频繁的连续字节对来构建词汇表。

### 算法步骤
1. **初始化**: 词汇表包含所有单个字符
2. **统计**: 计算所有相邻字符对的频率
3. **合并**: 选择频率最高的字符对进行合并
4. **重复**: 重复步骤2-3直到达到目标词汇表大小

### 关键概念
- **字符对**: 文本中相邻的两个符号
- **合并规则**: 记录所有的合并操作，用于编码新文本
- **贪心策略**: 总是选择当前频率最高的字符对

### 应用
- GPT-2, GPT-3 使用BPE分词
- 处理OOV（词表外）问题
- 平衡词汇量和处理效率
"""

from collections import defaultdict, Counter
import re

def get_vocabulary(corpus):
    """Build the initial BPE vocabulary from a corpus.

    Every whitespace-separated word is broken into single characters, an
    end-of-word marker ``'</w>'`` is appended, and the symbols are joined with
    spaces to form the key.

    Args:
        corpus: Iterable of text lines.

    Returns:
        dict: ``{space-joined symbol sequence: occurrence count}``.
    """
    counts = Counter(
        ' '.join([*word, '</w>'])
        for line in corpus
        for word in line.strip().split()
    )
    return dict(counts)

def get_stats(vocab):
    """Count adjacent symbol pairs across the vocabulary.

    Args:
        vocab: ``{space-joined symbol sequence: frequency}``.

    Returns:
        defaultdict(int): ``{(left, right): summed frequency}``.
    """
    pairs = defaultdict(int)
    for token_seq, freq in vocab.items():
        symbols = token_seq.split()
        # Walk each symbol together with its right-hand neighbour.
        for left, right in zip(symbols, symbols[1:]):
            pairs[left, right] += freq
    return pairs

def merge_pair(pair, v_in):
    """Merge every occurrence of an adjacent symbol pair in the vocabulary.

    Args:
        pair: ``(symbol1, symbol2)`` to fuse into one symbol.
        v_in: ``{space-joined symbol sequence: frequency}``.

    Returns:
        dict: New vocabulary with the pair replaced by its concatenation.
        If two input sequences become identical after merging, their
        frequencies are summed.
    """
    merged = ''.join(pair)
    bigram = re.escape(' '.join(pair))
    # Match the pair only as whole symbols: no non-space character directly
    # before it (lookbehind) and none directly after it (lookahead). The
    # trailing assertion must be a lookahead; a lookbehind there can never
    # succeed because the pair itself ends in a non-space character.
    pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')

    v_out = {}
    for token_seq, freq in v_in.items():
        # A callable replacement keeps backslashes in symbols literal.
        new_seq = pattern.sub(lambda _match: merged, token_seq)
        v_out[new_seq] = v_out.get(new_seq, 0) + freq
    return v_out

def learn_bpe(corpus, num_merges, verbose=True):
    """Learn BPE merge rules from a corpus.

    Args:
        corpus: Iterable of text lines.
        num_merges: Maximum number of merge operations to perform.
        verbose: When true (the default), print each merge as it is chosen.

    Returns:
        tuple: ``(final vocabulary, list of merged pairs in learning order)``.
    """
    vocab = get_vocabulary(corpus)
    merges = []

    for i in range(num_merges):
        # Pair frequencies are counted here rather than through the
        # module-level ``get_stats``: that name is redefined further down this
        # file by the WordPiece helper, which returns a (token_counts,
        # pair_counts) tuple and would break this loop once it is bound.
        pairs = Counter()
        for token_seq, freq in vocab.items():
            symbols = token_seq.split()
            for left, right in zip(symbols, symbols[1:]):
                pairs[(left, right)] += freq

        if not pairs:
            # Every word is already a single symbol; nothing left to merge.
            break

        # Greedy choice: the most frequent pair (first seen wins ties).
        best = max(pairs, key=pairs.get)
        merges.append(best)
        vocab = merge_pair(best, vocab)
        if verbose:
            print(f"Merge {i+1}: {best[0]} + {best[1]} -> {''.join(best)}")

    return vocab, merges

def encode_word(word, merges):
    """Tokenize one word by replaying learned merge rules in order.

    Args:
        word: A single word (expected to contain no whitespace).
        merges: Sequence of ``(left, right)`` symbol pairs, in the order they
            were learned.

    Returns:
        list: Resulting symbols; the last one carries the ``'</w>'`` marker.
    """
    symbols = list(word) + ['</w>']

    for left, right in merges:
        fused = left + right
        merged = []
        i = 0
        count = len(symbols)
        # Compare whole symbols; substring replacement on the joined string
        # would also match the tail of a longer symbol (e.g. ('o', 'w')
        # against 'lo w').
        while i < count:
            if i + 1 < count and symbols[i] == left and symbols[i + 1] == right:
                merged.append(fused)
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged

    return symbols

# Demo: learn merges on a tiny corpus, then tokenize a few words.
if __name__ == '__main__':
    corpus = [
        'low lower lowest',
        'newer wider',
    ]
    vocab, merges = learn_bpe(corpus, num_merges=10)

    # 'newest' does not occur in the corpus, so it exercises unseen input.
    for w in ('low', 'lower', 'newest'):
        enc = encode_word(w, merges)
        print(f"{w} -> {enc}")


In [None]:
"""
## 10. WordPiece 分词算法

### 概述
WordPiece 是 Google 开发的分词技术，用于 BERT 等预训练模型。
与 BPE 不同，WordPiece 使用**概率分数**而非频率选择合并对。

### 关键区别（vs BPE）
- **合并评分**: 
  - BPE: 按频率 count(pair) 排序
  - WordPiece: 按 count(pair) / (count(token1) * count(token2)) 排序
  
- **编码方式**:
  - BPE: 顺序应用所有合并规则
  - WordPiece: 贪心最长匹配（从左到右，选择最长的有效token）

### 算法步骤
1. 初始化词汇表为所有字符
2. 计算所有token对的概率分数
3. 选择分数最高的对进行合并
4. 重复直到达到目标词汇表大小

### 应用
- BERT, ALBERT等模型
- 处理未知词的标准方法
- 非起始位置token加##前缀（如##ing）
"""

import collections

def get_stats(word_freqs):
    """Count single tokens and adjacent token pairs, weighted by frequency.

    NOTE: this shares its name with the BPE pair counter defined earlier in
    this file and replaces it once this definition executes.

    Args:
        word_freqs: ``{token sequence: frequency}``.

    Returns:
        tuple: ``(token_counts, pair_counts)``, both ``defaultdict(int)``.
    """
    token_counts = collections.defaultdict(int)
    pair_counts = collections.defaultdict(int)

    for word_tokens, freq in word_freqs.items():
        for token in word_tokens:
            token_counts[token] += freq
        # Pair every token with its right-hand neighbour.
        for adjacent in zip(word_tokens, word_tokens[1:]):
            pair_counts[adjacent] += freq

    return token_counts, pair_counts

def merge_vocab(best_pair, word_freqs):
    """Fuse every adjacent occurrence of ``best_pair`` into one token.

    Args:
        best_pair: ``(left, right)`` tokens to merge.
        word_freqs: ``{token tuple: frequency}``.

    Returns:
        defaultdict(int): ``{merged token tuple: frequency}``. If two inputs
        collapse to the same tuple, their frequencies are summed.
    """
    new_word_freqs = collections.defaultdict(int)
    new_token = "".join(best_pair)
    left, right = best_pair[0], best_pair[1]

    for word_tokens, freq in word_freqs.items():
        new_word_tokens = []
        i = 0
        count = len(word_tokens)
        while i < count:
            # Explicit bounds check instead of catching IndexError on the
            # last token.
            if i + 1 < count and word_tokens[i] == left and word_tokens[i + 1] == right:
                new_word_tokens.append(new_token)
                i += 2
            else:
                new_word_tokens.append(word_tokens[i])
                i += 1

        new_word_freqs[tuple(new_word_tokens)] += freq

    return new_word_freqs

def train_wordpiece(corpus, vocab_size):
    """Train a WordPiece-style vocabulary by score-driven pair merging.

    Each round merges the adjacent pair with the highest
    ``count(pair) / (count(left) * count(right))`` and adds the fused token
    to the vocabulary.

    Args:
        corpus: ``{word: frequency}`` mapping.
        vocab_size: Target vocabulary size; the number of merge rounds is
            ``vocab_size`` minus the size of the initial character vocabulary.

    Returns:
        set: The final vocabulary.
    """
    # Start from the unknown marker plus every character seen in the corpus.
    vocab = {'[UNK]'}
    for word in corpus:
        vocab.update(word)

    # Pre-tokenize: characters followed by the end-of-word marker.
    word_freqs = collections.defaultdict(int)
    for word, freq in corpus.items():
        word_freqs[tuple(word) + ('</w>',)] = freq

    num_merges = vocab_size - len(vocab)
    for i in range(num_merges):
        token_counts, pair_counts = get_stats(word_freqs)
        if not pair_counts:
            break

        # Pick the best-scoring pair; the first one seen wins ties.
        best_pair, max_score = None, -1
        for pair, count in pair_counts.items():
            left_total = token_counts[pair[0]]
            if not (left_total > 0 and token_counts[pair[1]] > 0):
                continue
            score = count / (left_total * token_counts[pair[1]])
            if score > max_score:
                best_pair, max_score = pair, score

        if best_pair is None:
            break

        # Apply the merge and record the new token.
        word_freqs = merge_vocab(best_pair, word_freqs)
        new_token = "".join(best_pair)
        vocab.add(new_token)

        print(f"Merge {i+1}/{num_merges}: {best_pair} -> {new_token} (Score: {max_score:.4f})")

    return vocab

def encode_wordpiece(text, vocab):
    """Tokenize ``text`` by greedy longest-match against ``vocab``.

    Scanning left to right, the longest substring found in ``vocab`` is taken
    at each position. Pieces that do not start at index 0 are emitted with a
    ``'##'`` prefix. If no substring matches, ``'[UNK]'`` is emitted and the
    scan advances by one character.

    Args:
        text: The word to tokenize.
        vocab: Container of known subwords (stored without the ``'##'`` prefix).

    Returns:
        list: The token strings; empty for empty input.
    """
    tokens = []
    start = 0
    while start < len(text):
        # Try candidate end positions from longest to shortest.
        for end in range(len(text), start, -1):
            piece = text[start:end]
            if piece in vocab:
                break
        else:
            # Not even the single character is known.
            tokens.append('[UNK]')
            start += 1
            continue

        # Continuation pieces are marked with '##'.
        tokens.append(piece if start == 0 else '##' + piece)
        start += len(piece)

    return tokens

# Demo: train on a toy corpus, then tokenize two words not seen in training.
if __name__ == "__main__":
    corpus = {
        'low': 5,
        'lower': 2,
        'newest': 6,
        'widest': 3,
    }
    final_vocab = train_wordpiece(corpus, vocab_size=30)

    # The encoder matches raw substrings, so drop the end-of-word marker
    # from the trained tokens before lookup.
    clean_vocab = {token.replace('</w>', '') for token in final_vocab}
    clean_vocab.add('[UNK]')

    test_words = ["lowest", "newer"]
    for word in test_words:
        encoded = encode_wordpiece(word, clean_vocab)
        print(f"{word} -> {encoded}")
