# Part 2: SwiGLU Activation & RMSNorm

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

## RMSNorm Module

In [22]:
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x / rms * self.weight

## SwiGLU FFN

In [23]:
class SwiGLUFFN(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff*2, bias=False)    # Gate, Value
        self.w2 = nn.Linear(d_ff, d_model, bias=False)    # Output

    def forward(self, x):
        # gate, value = torch.split(self.w1(x), dim=-1)
        gate, value = self.w1(x).chunk(2, dim=-1)
        output = self.w2(F.silu(gate) * value)
        return output

## Experiment

In [24]:
class StandardFFN(nn.Module):
    '''FFN with GeLU'''
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        
    def forward(self, x):
        return self.w2(F.gelu(self.w1(x)))
    
class CausalSelfAttention(nn.Module):
    '''Attention Module'''
    def __init__(self, d_model, n_head):
        super().__init__()
        assert d_model % n_head == 0
        self.n_head = n_head
        self.d_head = d_model // n_head
        
        self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        
    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        
        q = q.view(B, T, self.n_head, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.d_head).transpose(1, 2)
        
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
        attn = attn.masked_fill(torch.tril(torch.ones(T, T)).to(x.device) == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        
        out = attn @ v
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(out)
    
class TransformerBlock(nn.Module):
    '''Transformer Block Basic Class'''
    def __init__(self, d_model, n_head, d_ff):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.d_ff = d_ff
        
    def forward(self, x):
        raise NotImplementedError
    
class StandardTransformerBlock(TransformerBlock):
    '''Config 1: GeLU + Layernorm - nanoGPT'''
    def __init__(self, d_model, n_head, d_ff):
        super().__init__(d_model, n_head, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_head)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = StandardFFN(d_model, d_ff)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x
    
class SwiGLUTransformerBlock(TransformerBlock):
    '''Config 2: SwiGLU + RMSNorm - PaLM/Llama'''
    def __init__(self, d_model, n_head, d_ff):
        super().__init__(d_model, n_head, d_ff)
        self.norm1 = RMSNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_head)
        self.norm2 = RMSNorm(d_model)
        self.ffn = SwiGLUFFN(d_model, d_ff)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x
    
class HybridTransformerBlock(TransformerBlock):
    '''Config 3: SwiGLU + LayerNorm'''
    def __init__(self, d_model, n_head, d_ff):
        super().__init__(d_model, n_head, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_head)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = SwiGLUFFN(d_model, d_ff)
        
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm1(x))
        return x
    
class TinyTransformer(nn.Module):
    '''Tiny Transformer Module'''
    def __init__(self, vocab_size, d_model, n_head, n_layer, d_ff, 
                 block_type='standard', max_seq_len=1024):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model

        # Token & Postion Embeding
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)

        # Transformer Block
        block_classes = {
            'standard': StandardTransformerBlock,
            'swiglu': SwiGLUTransformerBlock,
            'hybrid': HybridTransformerBlock
        }
        assert block_type in block_classes, f"Unknown block type: {block_type}"
        block_class = block_classes[block_type]

        self.blocks = nn.Sequential(*[
            block_class(d_model, n_head, d_ff) for _ in range(n_layer)
        ])

        # Norm and Linear Layer
        self.norm_f = nn.LayerNorm(d_model) if block_type != 'swiglu' else RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, RMSNorm):
            torch.nn.init.ones_(module.weight)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # Tokenizer
        token_emb = self.token_emb(idx)     # (B, T, d_model)
        pos = torch.arange(T, device=idx.device)
        pos_emb = self.pos_emb(pos)         # (T, d_model)
        x = token_emb + pos_emb

        # Transformer Block
        x = self.blocks(x)

        # Norm & Linear
        x = self.norm_f(x)
        logits = self.lm_head(x)

        # Loss
        loss = None
        if targets is not None:
            B, T, C = logits.shape
            inputs = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(inputs, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            # Truncate to max seq length
            idx_cond = idx[:, -self.pos_emb.num_embeddings:]

            # Forward
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]/temperature

            # Top-K
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1)
                logits[logits < v[:, [-1]]] = -float('Inf')

            # Sample
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            idx = torch.cat((idx, idx_next), dim=1)
        
        return idx

### Test in Tiny Transformer

In [25]:
# 测试模型
def test_models():
    # 参数
    vocab_size = 1000
    d_model = 256
    n_head = 8
    n_layer = 6
    d_ff = 512
    batch_size = 4
    seq_len = 32
    
    # 测试三种配置
    configs = ['standard', 'swiglu', 'hybrid']
    
    for config in configs:
        print(f"\n=== 测试 {config} 配置 ===")
        model = TinyTransformer(
            vocab_size=vocab_size,
            d_model=d_model,
            n_head=n_head,
            n_layer=n_layer,
            d_ff=d_ff,
            block_type=config
        )
        
        # 创建随机输入
        x = torch.randint(0, vocab_size, (batch_size, seq_len))
        
        # 前向传播
        logits, loss = model(x, x)  # 使用输入作为目标（仅用于测试）
        
        print(f"输入形状: {x.shape}")
        print(f"Logits形状: {logits.shape}")
        print(f"损失值: {loss.item():.4f}")
        
        # 参数计数
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"总参数: {total_params:,}")
        print(f"可训练参数: {trainable_params:,}")

if __name__ == "__main__":
    test_models()


=== 测试 standard 配置 ===
输入形状: torch.Size([4, 32])
Logits形状: torch.Size([4, 32, 1000])
损失值: 6.9787
总参数: 3,926,528
可训练参数: 3,926,528

=== 测试 swiglu 配置 ===
输入形状: torch.Size([4, 32])
Logits形状: torch.Size([4, 32, 1000])
损失值: 6.9252
总参数: 4,709,632
可训练参数: 4,709,632

=== 测试 hybrid 配置 ===
输入形状: torch.Size([4, 32])
Logits形状: torch.Size([4, 32, 1000])
损失值: 6.9384
总参数: 4,709,888
可训练参数: 4,709,888


### Training Comparison

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import time
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import os

# 设置随机种子
torch.manual_seed(42)
np.random.seed(42)

# 数据集类（模拟 OpenWebText 子集）
class TextDataset(Dataset):
    def __init__(self, num_samples=10000, seq_len=128, vocab_size=1000):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        
        # 生成随机文本数据
        self.data = torch.randint(0, vocab_size, (num_samples, seq_len + 1))
        
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        x = self.data[idx, :-1]
        y = self.data[idx, 1:]
        return x, y


# 训练函数
def train_model(config_name, block_type, num_steps=10000, 
                batch_size=32, lr=1e-3, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """训练单个配置的模型"""
    
    print(f"\n{'='*50}")
    print(f"开始训练配置: {config_name}")
    print(f"{'='*50}")
    
    # 模型参数
    vocab_size = 1000
    d_model = 256
    n_head = 8
    n_layer = 6
    d_ff = 512
    seq_len = 128
    
    # 创建模型
    model = TinyTransformer(
        vocab_size=vocab_size,
        d_model=d_model,
        n_head=n_head,
        n_layer=n_layer,
        d_ff=d_ff,
        block_type=block_type,
        max_seq_len=seq_len
    ).to(device)
    
    # 创建数据集和DataLoader
    dataset = TextDataset(num_samples=5000, seq_len=seq_len, vocab_size=vocab_size)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # 优化器
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    
    # 学习率调度器
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)
    
    # 训练记录
    losses = []
    step_times = []
    throughputs = []
    
    # 训练循环
    model.train()
    data_iter = iter(dataloader)
    global_step = 0
    
    pbar = tqdm(total=num_steps, desc=f"训练 {config_name}")
    
    while global_step < num_steps:
        try:
            x, y = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)
            x, y = next(data_iter)
        
        x, y = x.to(device), y.to(device)
        
        # 前向传播
        start_time = time.time()
        optimizer.zero_grad()
        _, loss = model(x, y)
        
        # 反向传播
        loss.backward()
        
        # 梯度裁剪
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        # 优化步骤
        optimizer.step()
        scheduler.step()
        
        # 记录时间
        step_time = time.time() - start_time
        throughput = batch_size / step_time
        
        # 记录指标
        losses.append(loss.item())
        step_times.append(step_time)
        throughputs.append(throughput)
        
        # 更新进度条
        pbar.update(1)
        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'throughput': f'{throughput:.1f} samples/sec'
        })
        
        global_step += 1
    
    pbar.close()
    
    # 计算平均指标
    avg_loss = np.mean(losses[-100:])  # 最后100步的平均损失
    avg_throughput = np.mean(throughputs[-100:])
    
    print(f"\n训练完成 - {config_name}:")
    print(f"最终损失: {losses[-1]:.4f}")
    print(f"平均损失 (最后100步): {avg_loss:.4f}")
    print(f"平均吞吐量: {avg_throughput:.1f} samples/sec")
    
    return {
        'config_name': config_name,
        'losses': losses,
        'throughputs': throughputs,
        'step_times': step_times,
        'final_loss': losses[-1],
        'avg_loss': avg_loss,
        'avg_throughput': avg_throughput,
        'model': model
    }


# 主训练比较函数
def run_training_comparison():
    # 训练配置
    configs = [
        ('GeLU+LayerNorm', 'standard'),
        ('SwiGLU+RMSNorm', 'swiglu'),
        ('SwiGLU+LayerNorm (hybrid)', 'hybrid')
    ]
    
    # 训练参数
    num_steps = 10000
    batch_size = 32
    lr = 1e-3
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    print(f"使用设备: {device}")
    print(f"训练步数: {num_steps}")
    print(f"批次大小: {batch_size}")
    print(f"学习率: {lr}")
    
    # 训练所有配置
    results = {}
    for config_name, block_type in configs:
        result = train_model(
            config_name=config_name,
            block_type=block_type,
            num_steps=num_steps,
            batch_size=batch_size,
            lr=lr,
            device=device
        )
        results[config_name] = result
    
    # 绘制损失曲线
    plt.figure(figsize=(15, 5))
    
    # 子图1: 损失曲线
    plt.subplot(1, 2, 1)
    for config_name in results:
        losses = results[config_name]['losses']
        # 平滑损失曲线
        window = 50
        smoothed_losses = np.convolve(losses, np.ones(window)/window, mode='valid')
        steps = np.arange(len(smoothed_losses))
        
        plt.plot(steps, smoothed_losses, label=config_name, linewidth=2)
    
    plt.xlabel('训练步数')
    plt.ylabel('损失 (平滑)')
    plt.title('训练损失曲线 (滑动窗口=50)')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # 子图2: 吞吐量比较
    plt.subplot(1, 2, 2)
    config_names = []
    avg_throughputs = []
    
    for config_name in results:
        config_names.append(config_name)
        avg_throughputs.append(results[config_name]['avg_throughput'])
    
    bars = plt.bar(config_names, avg_throughputs, color=['skyblue', 'lightcoral', 'lightgreen'])
    plt.xlabel('模型配置')
    plt.ylabel('平均吞吐量 (samples/sec)')
    plt.title('不同配置的吞吐量比较')
    plt.grid(True, alpha=0.3, axis='y')
    
    # 在柱状图上添加数值标签
    for bar, throughput in zip(bars, avg_throughputs):
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 5,
                f'{throughput:.1f}', ha='center', va='bottom')
    
    plt.tight_layout()
    plt.savefig('training_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    # 打印详细比较结果
    print("\n" + "="*60)
    print("训练结果总结")
    print("="*60)
    
    print(f"{'配置':<30} {'最终损失':<12} {'平均损失':<12} {'吞吐量':<12}")
    print("-"*60)
    
    for config_name in configs:
        config_name = config_name[0]
        result = results[config_name]
        print(f"{config_name:<30} {result['final_loss']:<12.4f} {result['avg_loss']:<12.4f} {result['avg_throughput']:<12.1f}")
    
    # 分析内存使用
    print("\n" + "="*60)
    print("内存使用分析")
    print("="*60)
    
    for config_name in configs:
        config_name = config_name[0]
        model = results[config_name]['model']
        
        # 计算参数数量
        total_params = sum(p.numel() for p in model.parameters())
        
        # 估计内存使用
        param_memory = total_params * 4 / (1024**2)  # 假设float32，转换为MB
        print(f"{config_name:<30} {total_params:,} 参数 | {param_memory:.1f} MB")
    
    # 生成示例文本
    print("\n" + "="*60)
    print("生成示例文本")
    print("="*60)
    
    for config_name in configs:
        config_name = config_name[0]
        model = results[config_name]['model']
        
        # 设置模型为评估模式
        model.eval()
        
        # 生成起始token
        start_tokens = torch.tensor([[1, 2, 3, 4, 5]], device=device)
        
        # 生成文本
        with torch.no_grad():
            generated = model.generate(start_tokens, max_new_tokens=20, temperature=0.8)
        
        print(f"\n{config_name}:")
        print(f"输入: {start_tokens[0].cpu().numpy()}")
        print(f"生成: {generated[0].cpu().numpy()}")
        
        # 切换回训练模式
        model.train()
    
    return results


# 运行训练比较
if __name__ == "__main__":
    results = run_training_comparison()
    
    # 保存结果
    print("\n保存结果到文件...")
    torch.save(results, 'training_results.pth')
    print("完成！")