In [1]:
import torch
import torch.nn as nn
from transformer_layers import MultiHeadAttention, FeedForward

class GPTBlock(nn.Module):
    """One pre-LayerNorm transformer block (GPT-2 style) with scaled residuals.

    Each sub-layer (self-attention, then feed-forward) is applied to a
    LayerNorm-ed copy of the running hidden state, multiplied by
    ``residual_scale`` and added back onto the un-normalised hidden state.

    Args:
        d_model: Size of the hidden dimension normalised by both LayerNorms.
        nhead: Passed through to ``MultiHeadAttention``.
        d_ff: Passed through to ``FeedForward``.
        dropout: Passed through to both sub-layers.
        residual_scale: Factor applied to each sub-layer output before it is
            added to the residual stream. Defaults to 0.1, the value this
            block previously hard-coded.
    """

    def __init__(self, d_model, nhead, d_ff, dropout=0.1, residual_scale=0.1):
        super(GPTBlock, self).__init__()

        # Pre-LayerNorm architecture (GPT-2 style): normalise before each sub-layer.
        self.ln_1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, nhead, dropout)
        self.ln_2 = nn.LayerNorm(d_model)
        self.mlp = FeedForward(d_model, d_ff, dropout, activation='gelu')
        # Plain Python number, not a Parameter/buffer: it is neither trained
        # nor stored in the state dict.
        self.residual_scale = residual_scale

    def forward(self, x, mask=None, key_padding_mask=None):
        """Apply attention and feed-forward sub-layers with scaled residuals.

        Args:
            x: Hidden states; the last dimension must be ``d_model``.
            mask: Forwarded unchanged to the attention layer.
            key_padding_mask: Forwarded unchanged to the attention layer.

        Returns:
            The updated hidden states.
        """
        # Self-attention: query, key and value are all the normalised input.
        x_ln = self.ln_1(x)
        x = x + self.residual_scale * self.attn(x_ln, x_ln, x_ln, mask, key_padding_mask)
        # Feed-forward on a freshly normalised copy of the updated state.
        x = x + self.residual_scale * self.mlp(self.ln_2(x))
        return x

class GPTEmbeddings(nn.Module):
    """Sum of learned token and position embeddings, followed by dropout.

    Args:
        vocab_size: Number of rows in the token embedding table (``wte``).
        d_model: Width of every embedding vector.
        max_seq_len: Number of rows in the position embedding table (``wpe``).
        dropout: Dropout probability applied to the summed embeddings.
    """

    def __init__(self, vocab_size, d_model, max_seq_len, dropout):
        super().__init__()
        # Token table first, position table second (creation order kept stable).
        self.wte = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        self.wpe = nn.Embedding(num_embeddings=max_seq_len, embedding_dim=d_model)
        self.drop = nn.Dropout(p=dropout)

    def forward(self, input_ids, position_ids=None):
        """Embed ``input_ids`` and add position information.

        Args:
            input_ids: 2-D integer tensor of token ids.
            position_ids: Optional integer tensor of position indices. When
                omitted, positions ``0 .. seq_len - 1`` are used for every row.

        Returns:
            Dropout applied to ``wte(input_ids) + wpe(position_ids)``.
        """
        n_rows, n_tokens = input_ids.shape
        if position_ids is None:
            # Same 0..n_tokens-1 ramp for every row, on the input's device.
            steps = torch.arange(n_tokens, device=input_ids.device)
            position_ids = steps.unsqueeze(0).expand(n_rows, -1)
        summed = self.wte(input_ids) + self.wpe(position_ids)
        return self.drop(summed)

class GPTOutputHead(nn.Module):
    """Final LayerNorm plus a bias-free projection onto the vocabulary.

    The projection's own weight is replaced by ``embedding_weight``, so the
    projection and whoever owns ``embedding_weight`` share a single parameter.

    Args:
        d_model: Width of the incoming hidden states.
        vocab_size: Number of output logits per position.
        embedding_weight: Parameter to use as the projection weight.
    """

    def __init__(self, d_model, vocab_size, embedding_weight):
        super().__init__()
        self.ln_f = nn.LayerNorm(d_model)
        projection = nn.Linear(d_model, vocab_size, bias=False)
        # Swap in the shared parameter before registering the layer.
        projection.weight = embedding_weight
        self.lm_head = projection

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` and project them to logits.

        Returns:
            A ``(logits, normalised_hidden_states)`` tuple.
        """
        normed = self.ln_f(hidden_states)
        return self.lm_head(normed), normed

In [2]:
from transformers import GPT2Tokenizer

class GPT2Model(nn.Module):
    """Decoder-only, GPT-2 style language model.

    Combines ``GPTEmbeddings``, a stack of ``num_layers`` ``GPTBlock`` layers
    and a ``GPTOutputHead`` whose projection weight is the token embedding
    weight (weight tying).

    Args:
        vocab_size: Vocabulary size.
        d_model: Model (hidden) dimension.
        nhead: Number of attention heads.
        num_layers: Number of transformer blocks.
        d_ff: Hidden dimension of the feed-forward network.
        max_seq_len: Maximum sequence length (size of the position table).
        dropout: Dropout rate.
        pad_token_id: Padding token id, used to derive a padding mask when
            ``forward`` is called without ``attention_mask`` and as the fill
            value for finished rows in ``generate``.
    """

    def __init__(self,
                 vocab_size,
                 d_model=768,
                 nhead=12,
                 num_layers=12,
                 d_ff=3072,
                 max_seq_len=1024,
                 dropout=0.1,
                 pad_token_id=0):
        super(GPT2Model, self).__init__()

        self.d_model = d_model
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.pad_token_id = pad_token_id

        # Embedding layer
        self.embeddings = GPTEmbeddings(vocab_size, d_model, max_seq_len, dropout)

        # Transformer blocks
        self.h = nn.ModuleList([
            GPTBlock(d_model, nhead, d_ff, dropout)
            for _ in range(num_layers)
        ])

        # Output layer; shares its projection weight with the token embedding.
        self.output_head = GPTOutputHead(d_model, vocab_size, self.embeddings.wte.weight)

        # Initialise weights of every sub-module.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one sub-module; meant to be used via ``self.apply``.

        Linear and Embedding weights are drawn from N(0, 0.02^2), Linear
        biases are zeroed, and LayerNorm is reset to weight 1 / bias 0.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def _generate_causal_mask(self, seq_len, device):
        """Return a ``(seq_len, seq_len)`` bool mask, True strictly above the diagonal.

        Entry ``[i, j]`` is True when ``j > i``. The mask is handed to every
        block as ``mask``.
        NOTE(review): presumably the attention layer treats True as "do not
        attend" -- confirm against ``MultiHeadAttention``.
        """
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
        return mask.bool()

    def _create_padding_mask(self, input_ids):
        """Return a bool tensor that is True wherever ``input_ids`` equals ``pad_token_id``."""
        return (input_ids == self.pad_token_id)

    def forward(self, input_ids, attention_mask=None, position_ids=None, return_dict=False):
        """Run the model on a batch of token ids.

        Args:
            input_ids: 2-D integer tensor of token ids.
            attention_mask: Optional tensor where non-zero marks a real token
                and zero marks padding. When omitted, every position whose id
                equals ``pad_token_id`` is treated as padding -- note that this
                also hides genuine tokens that share that id (e.g. when the
                pad token is the EOS token).
            position_ids: Optional position indices forwarded to the embeddings.
            return_dict: When True, return a dict instead of the bare logits.

        Returns:
            The logits tensor, or ``{'logits': ..., 'hidden_states': ...}``
            when ``return_dict`` is True, where ``hidden_states`` is the final
            LayerNorm output produced by the output head.
        """
        _, seq_len = input_ids.shape
        device = input_ids.device

        # Embedding layer
        hidden_states = self.embeddings(input_ids, position_ids)

        # Causal (autoregressive) mask
        causal_mask = self._generate_causal_mask(seq_len, device)

        # Key padding mask: True at padded positions.
        if attention_mask is not None:
            key_padding_mask = ~attention_mask.bool()
        else:
            key_padding_mask = self._create_padding_mask(input_ids)

        # Run every transformer block in order.
        for block in self.h:
            hidden_states = block(
                hidden_states,
                mask=causal_mask,
                key_padding_mask=key_padding_mask
            )

        # Output layer
        logits, hidden_states = self.output_head(hidden_states)

        if return_dict:
            return {
                'logits': logits,
                'hidden_states': hidden_states
            }
        return logits

    def generate(self, input_ids, attention_mask=None, max_new_tokens=50,
                 pad_token_id=None, eos_token_id=None):
        """Greedy autoregressive decoding.

        The model is switched to eval mode for the duration of the call and
        its previous train/eval mode is restored before returning.

        Args:
            input_ids: 2-D integer tensor holding the prompts.
            attention_mask: Optional mask for the prompts (non-zero = real
                token). Defaults to all ones.
            max_new_tokens: Upper bound on the number of appended tokens.
            pad_token_id: Id written for rows that have already produced EOS.
                Defaults to ``self.pad_token_id``, then to ``eos_token_id``.
            eos_token_id: When given, a row stops generating once it emits
                this id; decoding ends when every row has stopped.

        Returns:
            ``input_ids`` with the generated tokens appended along dim 1. The
            result never grows beyond ``self.max_seq_len`` columns.

        Note:
            The next token is always read from the last column, so prompts of
            unequal length that were padded on the right are continued from a
            padding position. NOTE(review): left-padded prompts would avoid
            this, but position ids are not derived from the mask here --
            confirm before relying on batched generation.
        """
        was_training = self.training
        self.eval()

        generated = input_ids.clone()
        if attention_mask is not None:
            attn_mask = attention_mask.clone()
        else:
            attn_mask = torch.ones_like(generated)

        if pad_token_id is None:
            pad_token_id = self.pad_token_id
        if pad_token_id is None:
            pad_token_id = eos_token_id

        # One flag per row: True while the row has not emitted EOS yet.
        unfinished = torch.ones(
            (generated.shape[0], 1), dtype=torch.bool, device=generated.device
        )

        try:
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    # Checked before the forward pass so the sequence can
                    # never outgrow the position embedding table.
                    if generated.shape[1] >= self.max_seq_len:
                        break

                    # Forward pass; keep only the logits of the last position.
                    outputs = self(generated, attention_mask=attn_mask)
                    next_token_logits = outputs[:, -1, :]

                    # Greedy decoding: pick the highest-scoring token.
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                    if eos_token_id is not None:
                        # Rows that already finished only receive padding.
                        next_token = torch.where(
                            unfinished,
                            next_token,
                            torch.full_like(next_token, pad_token_id),
                        )

                    # Append the new token.
                    generated = torch.cat([generated, next_token], dim=1)

                    # Real tokens are attendable; padding after EOS is not.
                    new_mask = unfinished.to(dtype=attn_mask.dtype, device=attn_mask.device)
                    attn_mask = torch.cat([attn_mask, new_mask], dim=1)

                    if eos_token_id is not None:
                        unfinished = unfinished & (next_token != eos_token_id)
                        # Stop as soon as every row has produced EOS.
                        if not unfinished.any():
                            break
        finally:
            # Do not leak the eval() switch to the caller.
            self.train(was_training)

        return generated





# Smoke test: batched forward pass and greedy generation with a randomly
# initialised (untrained) model, so the decoded text is expected to be noise.

# Seed the RNG so the random initialisation -- and therefore the printed
# output -- is the same on every Restart & Run All.
torch.manual_seed(0)

# Load the GPT-2 tokenizer.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Reuse the EOS token as the pad token so prompts can be batched.
tokenizer.pad_token = tokenizer.eos_token

# Prepare the input text (batch test + padding).
# NOTE(review): GPT2Tokenizer pads on the right by default, so generate()
# continues the shorter prompt from a padding position. Left padding is the
# usual choice for decoder-only generation, but GPT2Model does not derive
# position ids from the attention mask -- confirm before switching.
texts = ["Once upon a time", "A quick brown fox jumps"]
encodings = tokenizer(texts, return_tensors='pt', padding=True)

input_ids_tensor = encodings['input_ids']
attention_mask_tensor = encodings['attention_mask']

# Build the model. The embedding table is sized from the tokenizer so every
# id it can emit is valid; the remaining dimensions match the published
# GPT-2 XL configuration.
model = GPT2Model(
    vocab_size=len(tokenizer),
    d_model=1600,
    nhead=25,
    num_layers=48,
    d_ff=6400,
    max_seq_len=1024,
    dropout=0.1,
    pad_token_id=tokenizer.pad_token_id
)

# A freshly constructed nn.Module is in training mode; switch to eval so
# dropout is disabled during the forward-pass check below.
model.eval()

# Test the forward pass.
with torch.no_grad():
    outputs = model(input_ids_tensor, attention_mask=attention_mask_tensor, return_dict=True)
    print("Logits shape:", outputs['logits'].shape)
    print("Hidden states shape:", outputs['hidden_states'].shape)

# Generate text with the model (greedy decoding only).
output_ids = model.generate(
    input_ids=input_ids_tensor,
    attention_mask=attention_mask_tensor,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id
)

# Decode and print the results.
for ids in output_ids:
    output_text = tokenizer.decode(ids, skip_special_tokens=True)
    print(output_text)

  from .autonotebook import tqdm as notebook_tqdm


Logits shape: torch.Size([2, 5, 50257])
Hidden states shape: torch.Size([2, 5, 1600])
Once upon a timeinkiinkiinkiinkiinkiinkiinkiinkiDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDEC Carl ric ric ric ric ric ric ric ric ric ric ric ric ric ric ric ric ric
A quick brown fox jumps ranc ranc ranc ranc rancDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDECDEC Apply walked walked walked walked walked gland gland gland gland gland gland gland gland gland gland gland gland
