In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file shipped in the (read-only) Kaggle input directory.
input_paths = (
    os.path.join(root, name)
    for root, _subdirs, names in os.walk('/kaggle/input')
    for name in names
)
for path in input_paths:
    print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
# %pip install wandb

In [3]:
import wandb
import os
# Never hard-code credentials in a notebook: read the key from the environment
# (e.g. a Kaggle secret exported as WANDB_API_KEY). With key=None, wandb.login
# falls back to its normal lookup (netrc / interactive prompt).
WANDB_API_KEY = os.environ.get("WANDB_API_KEY")
wandb.login(
    key=WANDB_API_KEY
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mjinkaitao-comm[0m ([33mjinkaitao-comm-central-university-of-finance-and-economics[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

## Load Dataset

In [4]:
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, AutoTokenizer
import torch.nn.functional as F
from tqdm import tqdm

# Load the WMT17 Chinese-English parallel corpus (train / validation / test splits).
WMT_DATASET, WMT_PAIR = "wmt/wmt17", "zh-en"
TOKENIZER_NAME = 'bert-base-uncased'

ds = load_dataset(WMT_DATASET, WMT_PAIR)

# One WordPiece vocabulary is used for both source and target sides.
tokenizer = BertTokenizer.from_pretrained(TOKENIZER_NAME)
vocab_size = len(tokenizer.vocab)  # 30522 for 'bert-base-uncased'

# Alternative: the MarianMT zh->en tokenizer (vocab_size 65001).
# model_checkpoint = "Helsinki-NLP/opus-mt-zh-en"
# tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# vocab_size = tokenizer.vocab_size
# print(f"vocab size:{vocab_size}")

class TranslationDataset(Dataset):
    """Map-style dataset that tokenizes zh->en pairs on the fly.

    Each item of `dataset` is expected to look like
    ``{'translation': {'zh': <str>, 'en': <str>}}``. Both sides are padded /
    truncated to `max_length` tokens.

    Returns per item a dict with the raw texts plus 1-D tensors:
    'input_ids' / 'attention_mask' for the Chinese source and
    'labels' (token ids) for the English target.
    """

    def __init__(self, dataset, tokenizer, max_length=128):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def _encode(self, text):
        # Fixed-length encoding so that default collation can stack items into a batch.
        return self.tokenizer(text, padding='max_length', truncation=True,
                              max_length=self.max_length, return_tensors='pt')

    def __getitem__(self, idx):
        pair = self.dataset[idx]['translation']  # fetch the row once
        source, target = pair['zh'], pair['en']

        src_enc = self._encode(source)
        trg_enc = self._encode(target)

        # return_tensors='pt' yields a leading batch dim of 1; drop it.
        return {
            'source_text': source,
            'target_text': target,
            'input_ids': src_enc['input_ids'].squeeze(0),
            'attention_mask': src_enc['attention_mask'].squeeze(0),
            'labels': trg_enc['input_ids'].squeeze(0),
        }

README.md:   0%|          | 0.00/11.6k [00:00<?, ?B/s]

train-00000-of-00013.parquet:   0%|          | 0.00/286M [00:00<?, ?B/s]

train-00001-of-00013.parquet:   0%|          | 0.00/272M [00:00<?, ?B/s]

train-00002-of-00013.parquet:   0%|          | 0.00/281M [00:00<?, ?B/s]

train-00003-of-00013.parquet:   0%|          | 0.00/278M [00:00<?, ?B/s]

train-00004-of-00013.parquet:   0%|          | 0.00/277M [00:00<?, ?B/s]

train-00005-of-00013.parquet:   0%|          | 0.00/281M [00:00<?, ?B/s]

train-00006-of-00013.parquet:   0%|          | 0.00/282M [00:00<?, ?B/s]

train-00007-of-00013.parquet:   0%|          | 0.00/281M [00:00<?, ?B/s]

train-00008-of-00013.parquet:   0%|          | 0.00/294M [00:00<?, ?B/s]

train-00009-of-00013.parquet:   0%|          | 0.00/272M [00:00<?, ?B/s]

train-00010-of-00013.parquet:   0%|          | 0.00/191M [00:00<?, ?B/s]

train-00011-of-00013.parquet:   0%|          | 0.00/327M [00:00<?, ?B/s]

train-00012-of-00013.parquet:   0%|          | 0.00/254M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/394k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/362k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25134743 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/2002 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2001 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [5]:
from torch.utils.data import DataLoader, RandomSampler

def get_dataloader(dataset, tokenizer, batch_size, is_train=True, max_length=128, num_samples=None,
                   num_workers=2, pin_memory=None):
    """Build a DataLoader over the train or validation split.

    Args:
        dataset: HF DatasetDict with 'train' and 'validation' splits.
        tokenizer: tokenizer passed to TranslationDataset.
        batch_size: batch size.
        is_train: use the 'train' split (random sampling) or 'validation' (fixed order).
        max_length: pad/truncate length for both sides.
        num_samples: training -> number of examples randomly drawn per epoch
            (a fresh draw every epoch); validation -> evaluate only the first
            `num_samples` examples, always the same ones and in the same order.
            None -> use the whole split.
        num_workers: DataLoader worker processes.
        pin_memory: pin host memory; None -> only when CUDA is available.

    Returns:
        torch.utils.data.DataLoader
    """
    from torch.utils.data import Subset

    split = dataset['train'] if is_train else dataset['validation']
    sub_dataset = TranslationDataset(split, tokenizer, max_length)

    if is_train:
        if num_samples is not None:
            sampler = RandomSampler(sub_dataset, replacement=False, num_samples=num_samples)
        else:
            sampler = RandomSampler(sub_dataset)
    else:
        # Validation must be deterministic so loss/BLEU are comparable across epochs
        # (a RandomSampler here re-drew a different subset and order on every pass).
        if num_samples is not None:
            sub_dataset = Subset(sub_dataset, range(min(num_samples, len(sub_dataset))))
        sampler = None  # sequential order

    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    loader = DataLoader(
        dataset=sub_dataset,
        batch_size=batch_size,
        shuffle=False,  # ordering is controlled by the sampler
        pin_memory=pin_memory,
        num_workers=num_workers,
        sampler=sampler
    )
    return loader

# Smoke-test the DataLoader: pull a single batch and inspect the teacher-forcing split.
testdata_loader = get_dataloader(ds, tokenizer, batch_size=2, is_train=True, num_samples=3)
# testdata_loader = get_dataloader(ds, tokenizer, batch_size=2, is_train=False, num_samples=3)
batch = next(iter(testdata_loader), None)
if batch is not None:
    source_texts, target_texts = batch['source_text'], batch['target_text']
    src, trg = batch['input_ids'], batch['labels']
    # Decoder input drops the last token; the prediction target drops the first.
    trg_input, target = trg[:, :-1], trg[:, 1:]

    print(source_texts[0])
    print(target_texts[0])

    print(batch.keys())
    for label, tensor in (("src", src), ("trg", trg), ("trg_input", trg_input), ("target", target)):
        print(f"{label} shape: {tensor.shape}")

友谊常以爱情而告终，而爱情常不能以友谊而结束。
Friendship often ends in love, but love in friendship, never.
dict_keys(['source_text', 'target_text', 'input_ids', 'attention_mask', 'labels'])
src shape: torch.Size([2, 128])
trg shape: torch.Size([2, 128])
trg_input shape: torch.Size([2, 127])
target shape: torch.Size([2, 127])


In [6]:
# Peek at the first two raw training pairs.
first_pairs = ds['train'][:2]
print(first_pairs)

{'translation': [{'en': '1929 or 1989?', 'zh': '1929年还是1989年?'}, {'en': 'PARIS – As the economic crisis deepens and widens, the world has been searching for historical analogies to help us understand what has been happening.', 'zh': '巴黎-随着经济危机不断加深和蔓延，整个世界一直在寻找历史上的类似事件希望有助于我们了解目前正在发生的情况。'}]}


## Build the model

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader
from transformers import AdamW
from tqdm import tqdm
import math

# PE
class PositionalEncoder(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., 2017).

    PE[pos, 2k]   = sin(pos / 10000^(2k / d_model))
    PE[pos, 2k+1] = cos(pos / 10000^(2k / d_model))

    forward() scales the embeddings by sqrt(d_model) and adds the encoding.

    Args:
        d_model: embedding width (odd values are supported).
        max_seq_len: longest sequence the precomputed table covers.
    """

    def __init__(self, d_model, max_seq_len=128):
        super().__init__()
        self.d_model = d_model  # dimension of model
        position = torch.arange(max_seq_len, dtype=torch.float).unsqueeze(1)  # (L, 1)
        # 1 / 10000^(2k/d_model) for every even column index 2k. The previous loop
        # multiplied the (already even) column index by 2 again and used i+1 for the
        # cos term, giving wrong frequencies.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cos column than sin column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_seq_len, d_model)

    def forward(self, x):
        """x: (batch, seq_len, d_model) embeddings -> same shape with positions added."""
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {self.pe.size(1)}"
            )
        x = x * math.sqrt(self.d_model)
        return x + self.pe[:, :seq_len]

# MHA
class MultiHeadAttention(nn.Module):
    """Scaled dot-product multi-head attention.

    Args:
        heads: number of attention heads; must divide d_model.
        d_model: model width.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        if d_model % heads != 0:
            # Otherwise the per-head view() below silently produces mis-shaped heads.
            raise ValueError(f"d_model ({d_model}) must be divisible by heads ({heads})")
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads

        # Linear projections for Q, K, V
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)

        # Dropout and final linear projection
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def attention(self, q, k, v, d_k, mask=None, dropout=None):
        """Attend q over k/v, each (batch, heads, len, d_k).

        `mask` may be 3-D (batch, 1, len_k) or 4-D, e.g. (batch, 1, 1, len_k) or
        (batch, 1, len_q, len_k); zero/False entries are blocked.
        """
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # add the head dimension
            elif mask.dim() != 4:
                raise ValueError(f"Unexpected mask shape {mask.shape}")
            # Broadcast over heads rather than materialising h copies with repeat().
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        return torch.matmul(scores, v)

    def forward(self, q, k, v, mask=None):
        """q: (batch, len_q, d_model); k, v: (batch, len_k, d_model) -> (batch, len_q, d_model)."""
        bs = q.size(0)
        # Linear projections and split into heads: (bs, h, seq_len, d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k).transpose(1, 2)

        heads_out = self.attention(q, k, v, self.d_k, mask, self.dropout)

        # Concatenate heads and project back to d_model
        concat = heads_out.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        return self.out(concat)


# FF
class FeedForward(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: input/output width.
        d_ff: hidden width (2048 by default).
        dropout: dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        hidden = F.relu(self.linear_1(x))
        return self.linear_2(self.dropout(hidden))


# Norm
class NormLayer(nn.Module):
    """Layer normalisation over the last dimension with learnable gain and shift.

    Uses torch.std's default (unbiased) estimator and adds `eps` to the standard
    deviation rather than to the variance, so results differ slightly from
    nn.LayerNorm.
    """

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.size = d_model
        self.alpha = nn.Parameter(torch.ones(d_model))   # learnable gain
        self.bias = nn.Parameter(torch.zeros(d_model))   # learnable shift
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        scale = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centered / scale + self.bias

### Local Attention

In [8]:
class LocalAttention(nn.Module):
    """Multi-head attention restricted to a local band of positions.

    Query i may only attend to keys j with |i - j| <= window_size. Any padding /
    causal mask passed to forward() is applied on top of the band.

    Args:
        heads: number of attention heads.
        d_model: model width.
        dropout: dropout probability applied to the attention weights.
        window_size: half-width of the attention band.
    """

    def __init__(self, heads, d_model, dropout=0.1, window_size=64):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.window_size = window_size

        # Linear projections for Q, K, V
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)

        # Dropout and final linear projection
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def attention(self, q, k, v, d_k, mask=None, dropout=None):
        """Banded scaled dot-product attention over (batch, heads, len, d_k) inputs."""
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

        # Out-of-window logits must be pushed to a huge negative value *before* softmax.
        # Multiplying them by 0 (the previous approach) still gives them weight exp(0).
        window_mask = self.create_local_mask(scores, scores.size(-2))
        scores = scores.masked_fill(window_mask == 0, -1e9)

        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # add the head dimension
            elif mask.dim() != 4:
                raise ValueError(f"Unexpected mask shape {mask.shape}")
            # Broadcasts over heads; no repeat() copy needed.
            scores = scores.masked_fill(mask == 0, -1e9)

        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        return torch.matmul(scores, v)

    def create_local_mask(self, scores, seq_len=None):
        """Return a {0., 1.} band mask with the same shape and device as `scores`.

        Entry [.., i, j] is 1 iff |i - j| <= window_size. `seq_len` is unused (kept
        for backward compatibility); both lengths are read from `scores`.
        """
        batch_size, num_heads, len_q, len_k = scores.shape
        band = torch.ones(len_q, len_k, device=scores.device)
        band = band.tril(diagonal=self.window_size).triu(diagonal=-self.window_size)
        # expand() yields the (batch, heads, len_q, len_k) shape without copying data.
        return band.view(1, 1, len_q, len_k).expand(batch_size, num_heads, len_q, len_k)

    def forward(self, q, k, v, mask=None):
        """q: (batch, len_q, d_model); k, v: (batch, len_k, d_model) -> (batch, len_q, d_model)."""
        bs = q.size(0)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k).transpose(1, 2)

        heads_out = self.attention(q, k, v, self.d_k, mask, self.dropout)

        concat = heads_out.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        return self.out(concat)

### Sparse Attention

In [9]:
class SparseAttention(nn.Module):
    """Multi-head attention with random key sparsification during training.

    While the module is in training mode, each attention logit is independently
    dropped with probability `sparsity` before the softmax. In eval mode the full
    attention is used, so evaluation/inference is deterministic (the random mask
    used to be applied in eval mode too).

    Args:
        heads: number of attention heads.
        d_model: model width.
        sparsity: probability of dropping each logit while training.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, heads, d_model, sparsity=0.1, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.sparsity = sparsity

        # Linear projections for Q, K, V
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)

        # Dropout and final linear projection
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def attention(self, q, k, v, d_k, mask=None, dropout=None):
        """Scaled dot-product attention over (batch, heads, len, d_k) inputs."""
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

        if self.training and self.sparsity > 0:
            keep = torch.rand_like(scores) > self.sparsity
            scores = scores.masked_fill(~keep, -1e9)

        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # add the head dimension
            elif mask.dim() != 4:
                raise ValueError(f"Unexpected mask shape {mask.shape}")
            # Broadcasts over heads; no repeat() copy needed.
            scores = scores.masked_fill(mask == 0, -1e9)

        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        return torch.matmul(scores, v)

    def forward(self, q, k, v, mask=None):
        """q: (batch, len_q, d_model); k, v: (batch, len_k, d_model) -> (batch, len_q, d_model)."""
        bs = q.size(0)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k).transpose(1, 2)

        heads_out = self.attention(q, k, v, self.d_k, mask, self.dropout)

        concat = heads_out.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        return self.out(concat)

### Convolutional Attention

In [31]:
class ConvolutionalAttention(nn.Module):
    """Multi-head attention whose Q/K/V projections are 1-D convolutions over time.

    Args:
        heads: number of attention heads; must divide d_model.
        d_model: model width.
        kernel_size: temporal receptive field of each projection (odd or even).
        dropout: dropout probability applied to the attention weights.
        causal: if True, convolutions only look at the current and previous
            positions (left padding). Use this for decoder-side inputs: with
            symmetric padding, position i mixes in position i+1, which leaks the
            next target token past the look-ahead mask.
    """

    def __init__(self, heads, d_model, kernel_size=3, dropout=0.1, causal=False):
        super().__init__()
        if d_model % heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by heads ({heads})")
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.kernel_size = kernel_size
        self.causal = causal

        # Padding is applied explicitly in _project() so that the output length always
        # equals the input length (also for even kernels) and causal mode is possible.
        self.q_conv = nn.Conv1d(d_model, d_model, kernel_size)
        self.k_conv = nn.Conv1d(d_model, d_model, kernel_size)
        self.v_conv = nn.Conv1d(d_model, d_model, kernel_size)

        # Dropout and final linear projection
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def _project(self, conv, x):
        """Apply `conv` along the sequence axis of x (batch, len, d_model), preserving len."""
        total = self.kernel_size - 1
        if self.causal:
            left, right = total, 0
        else:
            # Same split as padding='same': extra pad (even kernels) goes on the right.
            left = total // 2
            right = total - left
        x = F.pad(x.transpose(1, 2), (left, right))
        return conv(x).transpose(1, 2)

    def attention(self, q, k, v, d_k, mask=None, dropout=None):
        """Scaled dot-product attention over (batch, heads, len, d_k) inputs."""
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # add the head dimension
            elif mask.dim() != 4:
                raise ValueError(f"Unexpected mask shape {mask.shape}")
            # Broadcasts over heads; no repeat() copy needed.
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        return torch.matmul(scores, v)

    def forward(self, q, k, v, mask=None):
        """q: (batch, len_q, d_model); k, v: (batch, len_k, d_model) -> (batch, len_q, d_model).

        len_q and len_k may differ (cross-attention). The old code reshaped k and v
        with the query length, which crashed whenever they differed.
        """
        bs = q.size(0)

        q = self._project(self.q_conv, q)
        k = self._project(self.k_conv, k)
        v = self._project(self.v_conv, v)

        # Split into heads: (bs, h, len, d_k); each tensor keeps its own length.
        q = q.view(bs, -1, self.h, self.d_k).transpose(1, 2)
        k = k.view(bs, -1, self.h, self.d_k).transpose(1, 2)
        v = v.view(bs, -1, self.h, self.d_k).transpose(1, 2)

        heads_out = self.attention(q, k, v, self.d_k, mask, self.dropout)

        # Concatenate heads and project back to d_model
        concat = heads_out.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        return self.out(concat)

In [32]:
# Encoder text -> vector
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward.

    Each sub-layer is applied as residual add -> dropout -> NormLayer (post-norm).

    NOTE(review): DecoderLayer normalises *before* each sub-layer (pre-norm), and
    Encoder applies one more NormLayer to this layer's already-normalised output.
    The mixed convention is left unchanged here to avoid altering the architecture.
    """

    def __init__(self, d_model, heads, dropout=0.1):
        super().__init__()
        self.norm_1 = NormLayer(d_model)
        self.norm_2 = NormLayer(d_model)
        # Baseline variant: MultiHeadAttention(heads, d_model, dropout=dropout)
        self.attn = ConvolutionalAttention(heads, d_model, dropout=dropout)
        self.ff = FeedForward(d_model, dropout=dropout)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        x = self.norm_1(x + self.dropout_1(self.attn(x, x, x, mask)))
        return self.norm_2(x + self.dropout_2(self.ff(x)))


class Encoder(nn.Module):
    """Embed source ids, add positional encoding, run N EncoderLayers, then a final NormLayer."""

    def __init__(self, vocab_size, d_model, N, heads, dropout):
        super().__init__()
        self.N = N
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pe = PositionalEncoder(d_model)
        self.layers = nn.ModuleList(EncoderLayer(d_model, heads, dropout) for _ in range(N))
        self.norm = NormLayer(d_model)

    def forward(self, src, mask):
        hidden = self.pe(self.embed(src))
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return self.norm(hidden)
    
# Decoder
class DecoderLayer(nn.Module):
    """One pre-norm decoder block: masked self-attention, encoder-decoder attention,
    feed-forward, each as x + dropout(sublayer(norm(x))).

    NOTE(review): ConvolutionalAttention, constructed here in its default (non-causal)
    form, pads its convolutions on both sides. The projection at decoder position i
    therefore mixes in position i+1, i.e. the token being predicted at i, and
    trg_mask cannot block that. Decoder-side projections need a left-padded
    (causal) convolution.
    """

    def __init__(self, d_model, heads, dropout=0.1):
        super().__init__()
        self.norm_1 = NormLayer(d_model)
        self.norm_2 = NormLayer(d_model)
        self.norm_3 = NormLayer(d_model)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        self.dropout_3 = nn.Dropout(dropout)
        # Baseline variant: MultiHeadAttention(heads, d_model, dropout=dropout) for both.
        self.attn_1 = ConvolutionalAttention(heads, d_model, dropout=dropout)  # self-attention
        self.attn_2 = ConvolutionalAttention(heads, d_model, dropout=dropout)  # cross-attention
        self.ff = FeedForward(d_model, dropout=dropout)

    def forward(self, x, e_outputs, src_mask, trg_mask):
        normed = self.norm_1(x)
        x = x + self.dropout_1(self.attn_1(normed, normed, normed, trg_mask))
        normed = self.norm_2(x)
        x = x + self.dropout_2(self.attn_2(normed, e_outputs, e_outputs, src_mask))
        return x + self.dropout_3(self.ff(self.norm_3(x)))

class Decoder(nn.Module):
    """Embed target ids, add positional encoding, run N DecoderLayers, then a final NormLayer."""

    def __init__(self, vocab_size, d_model, N, heads, dropout):
        super().__init__()
        self.N = N
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pe = PositionalEncoder(d_model)
        self.layers = nn.ModuleList(DecoderLayer(d_model, heads, dropout) for _ in range(N))
        self.norm = NormLayer(d_model)

    def forward(self, trg, e_outputs, src_mask, trg_mask):
        hidden = self.pe(self.embed(trg))
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, e_outputs, src_mask, trg_mask)
        return self.norm(hidden)

class Transformer(nn.Module):
    """Encoder-decoder model returning unnormalised logits over `trg_vocab` per target position."""

    def __init__(self, src_vocab, trg_vocab, d_model, N, heads, dropout):
        super().__init__()
        self.encoder = Encoder(src_vocab, d_model, N, heads, dropout)
        self.decoder = Decoder(trg_vocab, d_model, N, heads, dropout)
        self.out = nn.Linear(d_model, trg_vocab)

    def forward(self, src, trg, src_mask, trg_mask):
        memory = self.encoder(src, src_mask)
        decoded = self.decoder(trg, memory, src_mask, trg_mask)
        return self.out(decoded)

### Pre-process data
Helper functions that turn a DataLoader batch into model inputs and build the padding / look-ahead masks.

In [33]:
def prepare_batch_for_transformer(batch, tokenizer, device):
    """
    Prepare a batch for the Transformer (training or validation).

    - Moves tensors to `device`.
    - Shifts the target: the decoder input drops the last token, the labels drop the first.
    - Builds the source padding mask and the target padding + look-ahead mask.

    Parameters:
    - batch (dict): 'input_ids' (source), 'attention_mask' (source) and 'labels' (target),
      each (batch, seq_len).
    - tokenizer: supplies `pad_token_id` (0 is assumed if it has none).
    - device: torch device for all returned tensors.

    Returns:
    - src (Tensor): (batch, S) source ids, padded positions set to the pad id.
    - trg_input (Tensor): (batch, T-1) decoder input.
    - target (Tensor): (batch, T-1) ground-truth labels.
    - src_mask (Tensor): (batch, 1, 1, S) bool, True = real token.
    - trg_mask (Tensor): (batch, 1, T-1, T-1) bool, True = allowed (non-pad and not in the future).
    """
    src = batch['input_ids'].to(device)
    trg = batch['labels'].to(device)
    att_mask = batch['attention_mask'].to(device)

    pad_id = getattr(tokenizer, 'pad_token_id', None)
    if pad_id is None:
        pad_id = 0

    # Force padded source positions to the pad id. The previous `src & att_mask` was a
    # *bitwise* AND that replaced every token id with its lowest bit (0 or 1).
    src = src.masked_fill(att_mask == 0, pad_id)

    trg_input = trg[:, :-1]  # all tokens except the last
    target = trg[:, 1:]      # all tokens except the first

    # Source mask straight from the tokenizer's attention mask.
    src_mask = (att_mask != 0).unsqueeze(1).unsqueeze(2)

    trg_len = trg_input.size(-1)
    look_ahead = torch.tril(torch.ones(trg_len, trg_len, device=device)).bool().unsqueeze(0)
    trg_mask = (trg_input != pad_id).unsqueeze(1).unsqueeze(2) & look_ahead

    return src, trg_input, target, src_mask, trg_mask

# Function to create padding mask
def create_padding_mask(seq, pad_idx=0):
    """
    Generates a mask to hide padding tokens in the input sequences.

    Parameters:
    seq (torch.Tensor): Token ids of shape (batch_size, seq_length).
    pad_idx (int): Id of the padding token (0 for BERT tokenizers; other
        tokenizers may differ, e.g. via `tokenizer.pad_token_id`).

    Returns:
    torch.Tensor: Bool mask of shape (batch_size, 1, 1, seq_length); False marks
    padding positions that should be masked.
    """
    return (seq != pad_idx).unsqueeze(1).unsqueeze(2)

# Function to create subsequent mask for decoder
def create_subsequent_mask(size):
    """
    Build a look-ahead mask so each target position sees only itself and earlier positions.

    Parameters:
    size (int): Length of the target sequence.

    Returns:
    torch.Tensor: Bool tensor of shape (1, size, size); entry [0, i, j] is True
    exactly when j <= i.
    """
    positions = torch.arange(size)
    visible = positions.unsqueeze(0) <= positions.unsqueeze(1)  # visible[i, j] = j <= i
    return visible.unsqueeze(0)

In [34]:
def log_translation_table(source_texts, groundtruths, predicted, tokenizer):
    """Log a wandb.Table of (source, groundtruth, predicted) rows.

    `predicted` holds token ids per example and is decoded with `tokenizer`,
    skipping special tokens. The table is staged with commit=False, so it is
    sent together with the next committed wandb.log call.
    """
    table = wandb.Table(columns=["source", "groundtruth", "predicted"])
    predicted_cpu = predicted.to("cpu")

    for source, reference, token_ids in zip(source_texts, groundtruths, predicted_cpu):
        table.add_data(source, reference, tokenizer.decode(token_ids, skip_special_tokens=True))

    wandb.log({"translation_table": table}, commit=False)


from nltk.translate.bleu_score import sentence_bleu

def validate_model(model, valid_dl, loss_func, tokenizer, device, batch_idx=0):
    """Compute validation loss and sentence-BLEU, and log one batch as a wandb.Table.

    Predictions are the argmax of teacher-forced outputs (the decoder is fed the
    ground-truth target prefix), so BLEU here is more optimistic than free-running
    decoding would be.

    Args:
        model: Transformer taking (src, trg_input, src_mask, trg_mask).
        valid_dl: DataLoader yielding TranslationDataset-style batches.
        loss_func: criterion applied to (N, vocab) logits and (N,) targets.
        tokenizer: used to build masks and decode predictions.
        device: torch device.
        batch_idx: index of the batch whose translations are logged.

    Returns:
        (avg_loss, avg_bleu), both averaged per example; (nan, nan) if valid_dl is empty.
    """
    model.eval()
    val_loss = 0.
    total_bleu = 0.
    num_examples = 0  # sentences actually evaluated (last batch may be partial)
    with torch.inference_mode():
        for i, batch in enumerate(valid_dl):
            source_texts = batch['source_text']
            target_texts = batch['target_text']
            # Masks come from prepare_batch_for_transformer; re-deriving them here used
            # to silently override them with a different construction.
            src, trg_input, target, src_mask, trg_mask = prepare_batch_for_transformer(batch, tokenizer, device)

            output = model(src, trg_input, src_mask, trg_mask)

            batch_size = target.size(0)
            # Flatten with the model's own output width rather than a global vocab_size.
            val_loss += loss_func(output.reshape(-1, output.size(-1)), target.reshape(-1)).item() * batch_size

            predicted = output.argmax(dim=-1)

            # NOTE(review): with an uncased tokenizer the decoded hypothesis is lower-case
            # while target_text is not, which depresses BLEU; consider normalising case.
            for predicted_ids, target_text in zip(predicted, target_texts):
                predicted_text = tokenizer.decode(predicted_ids, skip_special_tokens=True)
                total_bleu += sentence_bleu([target_text.split()], predicted_text.split())

            if i == batch_idx:
                log_translation_table(source_texts, target_texts, predicted, tokenizer)

            num_examples += batch_size

    if num_examples == 0:
        return float('nan'), float('nan')
    return val_loss / num_examples, total_bleu / num_examples

## Start training

In [None]:
# initialise a wandb run
import os
# NOTE: CUDA_LAUNCH_BLOCKING is a debugging aid (synchronous kernels, slower training)
# and only takes effect if set before CUDA is initialised in this kernel.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

RUN_NAME = "transformer-attention-CA"
wandb.init(
    project="attention-mechanism-code",
    job_type='train',
    name=RUN_NAME,
    config={
        "epochs": 20,
        "batch_size": 32,
        "lr": 1e-3,
        "d_model": 512,
        "N": 6,
        "heads": 8,
        "dropout": 0.1
        })

# Copy your config
config = wandb.config
print(config)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


# Get the data: a fresh random subset of 20k training pairs per epoch (10k ~ 50k works).
train_loader = get_dataloader(ds, tokenizer, batch_size=config.batch_size, is_train=True, num_samples=20000)
valid_loader = get_dataloader(ds, tokenizer, batch_size=2*config.batch_size, is_train=False, num_samples=2000)

# Steps actually taken per epoch. The sampler only draws num_samples examples, so this
# must come from the loader, not from len(train_loader.dataset) (the full ~25M-pair split).
n_steps_per_epoch = len(train_loader)

model = Transformer(vocab_size,
                    vocab_size,
                    config.d_model,
                    config.N,
                    config.heads,
                    config.dropout
                   ).to(device)

# Padded target positions must not contribute to the loss; otherwise the loss is
# dominated by (trivially predictable) padding tokens.
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100
loss_func = nn.CrossEntropyLoss(ignore_index=pad_id)
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

# Training
example_ct = 0
step_ct = 0
torch.cuda.empty_cache()
for epoch in range(config.epochs):
    model.train()
    print(f"Epoch {epoch+1}/{config.epochs} - Start training")
    for step, batch in enumerate(train_loader):
        src, trg_input, target, src_mask, trg_mask = prepare_batch_for_transformer(batch, tokenizer, device)

        output = model(src, trg_input, src_mask, trg_mask)

        # Flatten to (batch * seq_len, vocab) and (batch * seq_len,) for CrossEntropyLoss;
        # reshape() copies if the tensor is not contiguous.
        reshaped_output = output.reshape(-1, output.size(-1))
        reshaped_target = target.reshape(-1)
        assert reshaped_output.shape[0] == reshaped_target.shape[0], "Batch sizes do not match"

        train_loss = loss_func(reshaped_output, reshaped_target)

        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        example_ct += src.size(0)
        metrics = {"train/train_loss": train_loss.item(),
                   "train/epoch": (step + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch,
                   "train/example_ct": example_ct}

        # The last step of each epoch is logged together with the validation metrics below.
        if step + 1 < n_steps_per_epoch:
            wandb.log(metrics)

        step_ct += 1
    print(f"          - Start validating")
    val_loss, avg_bleu = validate_model(model, valid_loader, loss_func, tokenizer, device)

    # Log train and validation metrics to wandb
    val_metrics = {"val/val_loss": val_loss,
                   "val/val_avg_bleu": avg_bleu}
    wandb.log({**metrics, **val_metrics})

    # Save the model checkpoint to wandb
    torch.save(model, "transformer_model.pt")
    wandb.log_model("./transformer_model.pt", "transformer_model", aliases=[f"epoch-{epoch+1}_dropout-{round(wandb.config.dropout, 4)}"])

    print(f"Epoch: {epoch+1}, Train Loss: {train_loss.item():.3f}, Valid Loss: {val_loss:.3f}, BLEU: {avg_bleu:.2f}")

# No test-set evaluation is run here, so nothing is written to wandb.summary.
# (A hard-coded placeholder value would be recorded in the run as if it were a real result.)

# Close your wandb run
wandb.finish()

## Debugging
Scratch cells for testing and inspecting individual components.

In [None]:
# test code here

In [None]:
from transformers import BertTokenizer, AutoTokenizer

# Compare against 'bert-base-uncased' (vocab_size 30522).
# Load into a separate name: rebinding the global `tokenizer` would silently change
# what the dataloader / training cells above use if they are re-run afterwards.
model_checkpoint = "Helsinki-NLP/opus-mt-zh-en"
opus_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# The model config reports "vocab_size": 65001
print(opus_tokenizer.vocab_size)

In [None]:
from transformers import AutoConfig

# This is the *model* configuration (vocab_size, d_model, ...), not a tokenizer config.
opus_model_config = AutoConfig.from_pretrained(model_checkpoint)
opus_model_config