## BERT Embedding

In [1]:
import torch.nn as nn
import torch
from transformer_layers import Encoder
 
class BertInputEmbedding(nn.Module):
    """BERT input embedding: token + absolute position + segment embeddings.

    The three embeddings are summed, then passed through LayerNorm and dropout.

    Args:
        d_model: embedding / hidden size.
        dropout: dropout probability applied after the LayerNorm.
        vocab_size: number of token ids in the vocabulary.
        max_position_embeddings: size of the learned position table, i.e. the
            longest sequence that can be embedded with default position ids.
        type_vocab_size: number of segment (token type) ids, e.g. 2 for A/B.
    """

    def __init__(self, d_model, dropout=0.1, vocab_size=30522, max_position_embeddings=512, type_vocab_size=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)                 # token embeddings: token id -> vector
        self.position = nn.Embedding(max_position_embeddings, d_model)     # learned *absolute* position embeddings
        self.token_type = nn.Embedding(type_vocab_size, d_model)           # segment embeddings (sentence A / B)
        self.norm = nn.LayerNorm(d_model, eps=1e-12)                       # LayerNorm over the summed embeddings
        self.dropout = nn.Dropout(dropout)                                 # regularization

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        """Embed a batch of token ids.

        Args:
            input_ids: integer token ids indexed ``[batch, seq_len]``.
            token_type_ids: optional segment ids with the same shape as
                ``input_ids``. Defaults to all zeros (every token in segment A).
            position_ids: optional position ids with the same shape as
                ``input_ids``. Defaults to ``0 .. seq_len-1`` for every row.

        Returns:
            Tensor of shape ``[batch, seq_len, d_model]``.

        Raises:
            ValueError: if default position ids are requested and ``seq_len``
                exceeds the size of the position table.
        """
        seq_length = input_ids.size(1)
        if position_ids is None:
            max_positions = self.position.num_embeddings
            # Fail with a clear message instead of an opaque index error
            # (a device-side assert on CUDA) from the position lookup.
            if seq_length > max_positions:
                raise ValueError(
                    f"sequence length {seq_length} exceeds max_position_embeddings={max_positions}"
                )
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device).unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Sum token, position and segment embeddings, then normalize and apply dropout.
        x = self.embedding(input_ids) + self.position(position_ids) + self.token_type(token_type_ids)
        x = self.norm(x)
        x = self.dropout(x)
        return x

# BertPooler

In [None]:
class BertPooler(nn.Module):
    """Reduce a sequence of hidden states to one vector taken from position 0.

    Position 0 is the [CLS] token in BERT-style inputs. The pooled vector is
    ``tanh(W @ h_0 + b)``.

    Args:
        d_model: hidden size of the incoming states and of the pooled output.
    """

    def __init__(self, d_model):
        super().__init__()
        self.dense = nn.Linear(d_model, d_model)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Pool ``hidden_states`` by projecting its first sequence position.

        Args:
            hidden_states: encoder output; assumed batch-first,
                ``[batch, seq_len, d_model]`` (TODO confirm against Encoder).

        Returns:
            Tensor of shape ``[batch, d_model]``.
        """
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))

# BERT 本身

In [None]:
class BertModel(nn.Module):
    """BERT backbone: input embeddings -> Transformer encoder -> [CLS] pooler.

    Args:
        d_model: hidden size.
        nhead: number of attention heads (forwarded to ``Encoder``).
        d_ff: feed-forward inner size (forwarded to ``Encoder``).
        num_layers: number of encoder layers (forwarded to ``Encoder``).
        dropout: dropout probability for the embedding block and the encoder.
        vocab_size: token vocabulary size.
        max_position_embeddings: size of the learned position table.
        type_vocab_size: number of segment (token type) ids.
    """

    def __init__(self, d_model=768, nhead=12, d_ff=3072, num_layers=12, dropout=0.1,
                 vocab_size=30522, max_position_embeddings=512, type_vocab_size=2):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size

        # Core components.
        self.embedding = BertInputEmbedding(d_model, dropout, vocab_size, max_position_embeddings, type_vocab_size)
        # NOTE(review): Encoder comes from transformer_layers; argument meanings
        # are assumed from their names — confirm against its definition.
        self.encoder = Encoder(d_model, nhead, d_ff, num_layers, dropout)
        self.pooler = BertPooler(d_model)

        # Re-initialize every submodule (recursively, via nn.Module.apply).
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize ``module`` in place with BERT's scheme.

        The scheme by module type:

        * ``nn.Linear``: weight drawn from a plain normal N(0, 0.02^2) (not
          truncated); bias zeroed.
        * ``nn.Embedding``: weight drawn from N(0, 0.02^2); the ``padding_idx``
          row, if any, is zeroed.
        * ``nn.LayerNorm``: weight set to 1 and bias to 0, when those
          parameters exist.

        Modules of any other type (and raw Parameters inside custom layers)
        keep their default initialization.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                # Keep the padding vector at zero, matching nn.Embedding's own convention.
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # weight/bias are None when elementwise_affine=False (or bias=False).
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, return_dict=False):
        """Encode a batch of token ids.

        Args:
            input_ids: integer token ids indexed ``[batch, seq_len]``.
            attention_mask: optional mask with the same shape as ``input_ids``.
                Positions equal to 0 are treated as padding.
            token_type_ids: optional segment ids. The embedding layer defaults
                them to zeros.
            return_dict: if True, return a dict with ``'last_hidden_state'``
                and ``'pooler_output'``. Otherwise return the tuple
                ``(sequence_output, pooled_output)``.
        """
        embedding_output = self.embedding(input_ids, token_type_ids)

        # Convert the 1 = token / 0 = padding mask into a boolean key-padding
        # mask (True at padding positions). Presumably Encoder follows PyTorch's
        # convention that True means "ignore" — confirm in transformer_layers.
        src_key_padding_mask = None if attention_mask is None else attention_mask == 0

        sequence_output = self.encoder(embedding_output, mask=None, src_key_padding_mask=src_key_padding_mask)
        pooled_output = self.pooler(sequence_output)

        if return_dict:
            return {
                'last_hidden_state': sequence_output,
                'pooler_output': pooled_output
            }
        return sequence_output, pooled_output

# 預訓練頭

In [None]:
class BertNSPHead(nn.Module):
    """Next-sentence-prediction head: one linear layer from pooled vector to 2 logits.

    Args:
        d_model: size of the pooled input vector.
    """

    def __init__(self, d_model):
        super().__init__()
        self.seq_relationship = nn.Linear(d_model, 2)

    def forward(self, pooled_output):
        """Return the 2-way logits for ``pooled_output`` (last dim ``d_model``)."""
        return self.seq_relationship(pooled_output)
    
class BertMLMHead(nn.Module):
    """Masked-language-model head.

    The pipeline is dense -> GELU -> LayerNorm -> projection to vocabulary
    logits.

    Args:
        d_model: hidden size of the incoming token states.
        vocab_size: number of output classes (vocabulary entries).
    """

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.dense = nn.Linear(d_model, d_model)
        self.activation = nn.GELU()
        # BERT's prediction-head transform normalizes with LayerNorm, using the
        # same eps as the embedding LayerNorm. This also keeps the head
        # covered by BertModel._init_weights and usable on torch < 2.4.
        self.norm = nn.LayerNorm(d_model, eps=1e-12)
        self.decoder = nn.Linear(d_model, vocab_size, bias=False)
        # A standalone output bias is attached to the decoder, so `self.bias`
        # and `self.decoder.bias` are the same Parameter.
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Map token states (last dim ``d_model``) to logits (last dim ``vocab_size``)."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.norm(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states

# 預訓練模型

In [None]:
class BertPreTrainingHeads(nn.Module):
    """Bundle of the two BERT pre-training heads.

    The heads are the MLM head (``predictions``) and the NSP head
    (``seq_relationship``).

    Args:
        d_model: hidden size of the backbone outputs.
        vocab_size: vocabulary size for the MLM logits.
    """

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.predictions = BertMLMHead(d_model, vocab_size)
        self.seq_relationship = BertNSPHead(d_model)

    def forward(self, sequence_output, pooled_output):
        """Apply both heads.

        Returns:
            ``(prediction_scores, seq_relationship_score)``: the MLM logits
            from ``sequence_output`` and the NSP logits from ``pooled_output``.
        """
        return (
            self.predictions(sequence_output),
            self.seq_relationship(pooled_output),
        )
    

class BertForPreTraining(nn.Module):
    """BERT backbone plus masked-LM and next-sentence-prediction heads.

    Constructor arguments are forwarded to ``BertModel``; ``d_model`` and
    ``vocab_size`` also size the heads.
    """

    def __init__(self, d_model=768, nhead=12, d_ff=3072, num_layers=12, dropout=0.1,
                 vocab_size=30522, max_position_embeddings=512, type_vocab_size=2):
        super().__init__()
        self.bert = BertModel(d_model, nhead, d_ff, num_layers, dropout,
                              vocab_size, max_position_embeddings, type_vocab_size)
        self.cls = BertPreTrainingHeads(d_model, vocab_size)
        # BertModel's constructor only re-initializes its own submodules.
        # Apply the same scheme to the heads so the whole model shares one init.
        self.cls.apply(self.bert._init_weights)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                labels=None, next_sentence_label=None):
        """Run the backbone and both heads, optionally computing the loss.

        Args:
            input_ids, attention_mask, token_type_ids: forwarded to ``BertModel``.
            labels: optional MLM targets, one class index per token position.
                Entries equal to -100 (CrossEntropyLoss's default
                ``ignore_index``) are excluded from the loss.
            next_sentence_label: optional NSP targets, one class index in
                {0, 1} per example.

        Returns:
            A dict with these keys:

            * ``'loss'``: the sum of the MLM and/or NSP losses for whichever
              label sets were given, or None if neither was.
            * ``'prediction_logits'``, ``'seq_relationship_logits'``.
            * ``'hidden_states'``: the last encoder layer.
            * ``'pooler_output'``.
        """
        outputs = self.bert(input_ids, attention_mask, token_type_ids, return_dict=True)
        sequence_output = outputs['last_hidden_state']
        pooled_output = outputs['pooler_output']

        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        loss_fct = nn.CrossEntropyLoss()
        losses = []
        if labels is not None:
            # MLM loss over every (flattened) token position.
            losses.append(loss_fct(prediction_scores.view(-1, self.bert.vocab_size), labels.reshape(-1)))
        if next_sentence_label is not None:
            # NSP loss: binary "is B the actual next sentence" classification.
            losses.append(loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.reshape(-1)))
        total_loss = sum(losses) if losses else None

        return {
            'loss': total_loss,
            'prediction_logits': prediction_scores,
            'seq_relationship_logits': seq_relationship_score,
            'hidden_states': sequence_output,
            'pooler_output': pooled_output
        }

# 分類模型

In [6]:
class BertForSequenceClassification(nn.Module):
    """BERT backbone with a linear head on the pooled [CLS] vector.

    Args:
        d_model, nhead, d_ff, num_layers, dropout, vocab_size,
        max_position_embeddings, type_vocab_size: forwarded to ``BertModel``.
            ``dropout`` is also applied to the pooled vector before the
            classifier.
        num_labels: number of output logits. With ``num_labels == 1`` the
            head is treated as a regressor and trained with MSE.
    """

    def __init__(self, d_model=768, nhead=12, d_ff=3072, num_layers=12, dropout=0.1,
                 vocab_size=30522, max_position_embeddings=512, type_vocab_size=2, num_labels=2):
        super().__init__()
        self.num_labels = num_labels
        self.bert = BertModel(d_model, nhead, d_ff, num_layers, dropout,
                              vocab_size, max_position_embeddings, type_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(d_model, num_labels)
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        """Classify (or regress) each input sequence.

        Args:
            input_ids, attention_mask, token_type_ids: forwarded to ``BertModel``.
            labels: optional targets, one per example. Use class indices when
                ``num_labels > 1``, or real values when ``num_labels == 1``.

        Returns:
            A dict with these keys:

            * ``'loss'``: None when no labels are given.
            * ``'logits'``: shape ``[batch, num_labels]``.
            * ``'hidden_states'``: the last encoder layer.
            * ``'pooler_output'``: the backbone's pooled vector, before the
              classifier dropout.
        """
        outputs = self.bert(input_ids, attention_mask, token_type_ids, return_dict=True)
        pooled_output = outputs['pooler_output']

        # Dropout is applied only on the classifier input, so the returned
        # pooler_output stays identical to the backbone's.
        logits = self.classifier(self.dropout(pooled_output))

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # Cross-entropy over a single class is identically zero, so a
                # single output is trained as regression instead.
                loss = nn.MSELoss()(logits.view(-1), labels.reshape(-1).to(logits.dtype))
            else:
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.reshape(-1))

        return {
            'loss': loss,
            'logits': logits,
            'hidden_states': outputs['last_hidden_state'],
            'pooler_output': pooled_output
        }

# 測試結果

In [None]:
from transformers import BertTokenizer

# Smoke test: run each model once on a single tokenized sentence and print
# the output shapes. The models are randomly initialized and only shapes are
# inspected, so they run in eval mode (dropout off) without autograd.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Basic BERT backbone
print("=== 測試基本BERT模型 ===")
sentence = "this is a test"
tokens = tokenizer(sentence, return_tensors='pt', padding=True, truncation=True)
input_ids = tokens['input_ids']
attention_mask = tokens['attention_mask']
token_type_ids = tokens['token_type_ids']

model = BertModel().eval()
with torch.no_grad():
    sequence_output, pooled_output = model(input_ids, attention_mask=attention_mask,
                                           token_type_ids=token_type_ids)
print(f"Sequence output shape: {sequence_output.shape}")  # [batch_size, seq_len, d_model]
print(f"Pooled output shape: {pooled_output.shape}")      # [batch_size, d_model]

# Pre-training model (MLM + NSP heads)
print("\n=== 測試預訓練模型 ===")
pretraining_model = BertForPreTraining().eval()
with torch.no_grad():
    outputs = pretraining_model(input_ids, attention_mask=attention_mask,
                                token_type_ids=token_type_ids)
print(f"Prediction logits shape: {outputs['prediction_logits'].shape}")      # [batch_size, seq_len, vocab_size]
print(f"NSP logits shape: {outputs['seq_relationship_logits'].shape}")       # [batch_size, 2]

# Sequence-classification model
print("\n=== 測試分類模型 ===")
classification_model = BertForSequenceClassification(num_labels=3).eval()
with torch.no_grad():
    outputs = classification_model(input_ids, attention_mask=attention_mask,
                                   token_type_ids=token_type_ids)
print(f"Classification logits shape: {outputs['logits'].shape}")             # [batch_size, num_labels]

=== 測試基本BERT模型 ===
Sequence output shape: torch.Size([1, 6, 768])
Pooled output shape: torch.Size([1, 768])

=== 測試預訓練模型 ===
Prediction logits shape: torch.Size([1, 6, 30522])
NSP logits shape: torch.Size([1, 2])

=== 測試分類模型 ===
Classification logits shape: torch.Size([1, 3])
