In [2]:
import traintools
import matplotlib.pyplot as plt
import torch
import collections
import math
from torch import nn

2025-08-28 00:32:08,236 - INFO - NumExpr defaulting to 16 threads.


In [3]:
# Register the English-French archive: (download URL, checksum string —
# presumably SHA-1, given its 40 hex digits; verify against traintools).
traintools.DATA_HUB['fra-eng'] = (
	traintools.DATA_URL + 'fra-eng.zip',
	'94646ad1522d915e7b0f9296181140edcf86a4f5',
)


def read_data_nmt():
	"""Return the raw text of the English-French translation dataset.

	The archive is obtained via ``traintools.download_extract('fra-eng')``,
	and ``fra.txt`` inside the returned directory is read as UTF-8.

	Returns:
		str: the full file contents.
	"""
	dataset_path = traintools.download_extract('fra-eng') + '/fra.txt'
	with open(dataset_path, 'r', encoding='utf-8') as handle:
		contents = handle.read()
	return contents


raw_text = read_data_nmt()

In [4]:
def preprocess_nmt(text):
	"""Normalize raw parallel-corpus text before tokenization.

	Replaces narrow no-break spaces (U+202F) and no-break spaces (U+00A0)
	with ordinary spaces, lowercases the text, and inserts a space before
	each of ``, . ! ?`` that is not already preceded by a space. Punctuation
	then becomes its own token under ``str.split(' ')``.

	Args:
		text: raw corpus text.

	Returns:
		str: the normalized text.
	"""
	# Built once per call; the previous nested helper rebuilt this set for
	# every single character of the corpus.
	punctuation = frozenset(',.!?')
	text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
	out = []
	for i, char in enumerate(text):
		# The first character never gets a leading space.
		if i > 0 and char in punctuation and text[i - 1] != ' ':
			out.append(' ' + char)
		else:
			out.append(char)
	return ''.join(out)

text = preprocess_nmt(text=raw_text)  # normalized corpus used by the cells below

In [5]:
def tokenize_nmt(text, num_examples=None):
	"""Split tab-separated sentence pairs into word-level token lists.

	Each line of ``text`` should be ``source<TAB>target``. Lines that do not
	split into exactly two tab-separated fields (e.g. a trailing empty line)
	are skipped. Tokens are produced with ``str.split(' ')``.

	Args:
		text: preprocessed corpus text, one sentence pair per line.
		num_examples: maximum number of sentence pairs to return; ``None``
			(the default) returns every pair.

	Returns:
		tuple[list[list[str]], list[list[str]]]: ``(source, target)``.
	"""
	source = []
	target = []
	for line in text.split('\n'):
		# Stop once exactly num_examples pairs are collected. Counting kept
		# pairs (not raw line indices) avoids returning one extra example and
		# keeps skipped lines from eating into the limit.
		if num_examples is not None and len(source) >= num_examples:
			break
		parts = line.split('\t')
		if len(parts) == 2:
			source.append(parts[0].split(' '))
			target.append(parts[1].split(' '))
	return source, target

source, target = tokenize_nmt(text)

# Preview the first six tokenized sentences on each side, one list per line.
print(*(side[:6] for side in (source, target)), sep='\n')

[['go', '.'], ['hi', '.'], ['run', '!'], ['run', '!'], ['who', '?'], ['wow', '!']]
[['va', '!'], ['salut', '!'], ['cours', '!'], ['courez', '!'], ['qui', '?'], ['ça', 'alors', '!']]


In [6]:
# Source-side vocabulary with the special tokens reserved up front.
# min_freq=2 presumably excludes tokens seen only once (confirm against
# traintools.Vocab). Note: `src_vocab` is reassigned later by load_data_nmt.
src_vocab = traintools.Vocab(
	source,
	min_freq=2,
	reserved_tokens=['<pad>', '<bos>', '<eos>'],
)
print(src_vocab.token_freqs)



In [7]:
def truncate_pad(line, num_steps, padding_token):
	"""Force a sequence to exactly ``num_steps`` elements.

	Longer sequences are truncated; shorter ones are right-padded with
	``padding_token``. A new list is always returned.

	Args:
		line: list of tokens or token indices.
		num_steps: desired length.
		padding_token: value appended to fill short sequences.

	Returns:
		list: the truncated or padded sequence.
	"""
	shortfall = num_steps - len(line)
	if shortfall >= 0:
		# Pad (a zero shortfall just copies the list).
		return line + [padding_token] * shortfall
	# Truncate.
	return line[:num_steps]

# Show the first source sentence, its vocabulary indices, and those indices
# padded to 10 time steps with the <pad> index (displayed as the cell output).
print(source[0])
print(src_vocab[source[0]])
truncate_pad(
	src_vocab[source[0]],
	num_steps=10,
	padding_token=src_vocab['<pad>'],
)

['go', '.']
[47, 4]


[47, 4, 1, 1, 1, 1, 1, 1, 1, 1]

In [8]:
def build_array_nmt(lines, vocab, num_steps):
	"""Convert tokenized sentences into a fixed-length index tensor.

	Each sentence is mapped to vocabulary indices, terminated with ``<eos>``,
	then truncated or padded with ``<pad>`` to exactly ``num_steps`` positions.

	Args:
		lines: list of token lists.
		vocab: vocabulary supporting ``vocab[token]`` and
			``vocab[list_of_tokens]`` lookups.
		num_steps: target sequence length.

	Returns:
		tuple: ``(array, valid_len)``. ``array`` is an int64 tensor of shape
		``(len(lines), num_steps)``. ``valid_len[i]`` is the number of
		non-padding positions in row ``i`` (sentence plus ``<eos>``, capped
		at ``num_steps``).
	"""
	pad_id = vocab['<pad>']
	eos_id = vocab['<eos>']
	indexed = [vocab[l] + [eos_id] for l in lines]
	padded = [truncate_pad(seq, num_steps, pad_id) for seq in indexed]
	# reshape keeps a 2-D (0, num_steps) shape when `lines` is empty; a bare
	# torch.tensor([]) would be 1-D and break downstream indexing.
	array = torch.tensor(padded, dtype=torch.long).reshape(len(padded), num_steps)
	# Derive valid lengths from the real sequence lengths instead of counting
	# entries != pad_id, so a token that maps to the pad id cannot shorten it.
	valid_len = torch.tensor(
		[min(len(seq), num_steps) for seq in indexed], dtype=torch.long)
	return array, valid_len

In [9]:
def load_data_nmt(batch_size, num_steps, num_examples=600):
	"""Build a minibatch iterator and vocabularies for the English-French data.

	Args:
		batch_size: number of sentence pairs per minibatch.
		num_steps: fixed sequence length after truncation/padding.
		num_examples: how many sentence pairs to read from the corpus.

	Returns:
		tuple: ``(data_iter, src_vocab, tgt_vocab)``. ``data_iter`` comes from
		``traintools.load_array`` over ``(src_array, src_valid_len, tgt_array,
		tgt_valid_len)``; batches are presumably yielded in that order.
	"""
	reserved_tokens = ['<pad>', '<bos>', '<eos>']
	corpus = preprocess_nmt(read_data_nmt())
	source, target = tokenize_nmt(text=corpus, num_examples=num_examples)

	def make_vocab(tokens):
		# Both sides share the same frequency cutoff and reserved tokens.
		return traintools.Vocab(tokens, min_freq=2, reserved_tokens=reserved_tokens)

	src_vocab = make_vocab(source)
	tgt_vocab = make_vocab(target)

	src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
	tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)

	data_iter = traintools.load_array(
		(src_array, src_valid_len, tgt_array, tgt_valid_len), batch_size)
	return data_iter, src_vocab, tgt_vocab

# Build a tiny loader (2 pairs per batch, 8 steps) and inspect one minibatch.
train_iter, src_vocab, tgt_vocab = load_data_nmt(
	batch_size=2,
	num_steps=8,
)
for X, X_valid_len, Y, Y_valid_len in train_iter:
	print('X:', X)
	print('X的有效长度:', X_valid_len)
	print('Y:', Y)
	print('Y的有效长度:', Y_valid_len)
	break  # only the first minibatch is needed

X: tensor([[ 9, 56,  4,  3,  1,  1,  1,  1],
        [87, 22,  4,  3,  1,  1,  1,  1]])
X的有效长度: tensor([4, 4])
Y: tensor([[127,  37,   0,   4,   3,   1,   1,   1],
        [177, 178,  25,   4,   3,   1,   1,   1]])
Y的有效长度: tensor([5, 5])


In [10]:
class Encoder(nn.Module):
	"""Abstract base for the encoder half of an encoder-decoder model.

	Subclasses must override :meth:`forward`.
	"""

	def __init__(self, **kwargs):
		# Keyword arguments are forwarded unchanged to nn.Module.
		super(Encoder, self).__init__(**kwargs)

	def forward(self, X, *args):
		"""Encode ``X``; extra positional args (for example valid lengths) are optional."""
		raise NotImplementedError
	
class Decoder(nn.Module):
	"""Abstract base for the decoder half of an encoder-decoder model.

	Subclasses must override both :meth:`init_state` and :meth:`forward`.
	"""

	def __init__(self, **kwargs):
		super(Decoder, self).__init__(**kwargs)

	def init_state(self, enc_out, *args):
		"""Convert the encoder output into the decoder's initial state.

		Extra arguments may carry, for example, the valid lengths of the
		input sequences.
		"""
		raise NotImplementedError

	def forward(self, X, state):
		"""Decode ``X`` given the current decoder ``state``."""
		raise NotImplementedError
	
class EncoderDecoder(nn.Module):
	"""Base encoder-decoder architecture combining an encoder and a decoder.

	Extra positional arguments to :meth:`forward` (for example the valid
	lengths of the encoder input) go to the encoder and to
	``decoder.init_state``. They are not passed to the decoder's forward,
	whose interface is ``forward(X, state)``.
	"""

	def __init__(self, encoder, decoder, *args, **kwargs):
		super().__init__(*args, **kwargs)
		self.encoder = encoder
		self.decoder = decoder

	def forward(self, enc_X, dec_X, *args):
		"""Encode ``enc_X``, initialize the decoder state, and decode ``dec_X``.

		Returns:
			Whatever ``self.decoder(dec_X, state)`` returns.
		"""
		enc_out = self.encoder(enc_X, *args)
		dec_state = self.decoder.init_state(enc_out, *args)
		# The decoder's forward takes only (X, state); forwarding *args here
		# raised TypeError for any decoder following that interface.
		return self.decoder(dec_X, dec_state)


In [11]:
class Seq2SeqEncoder(Encoder):
	"""RNN encoder for sequence-to-sequence learning (embedding + stacked GRU).

	Args:
		vocab_size: number of source-vocabulary entries.
		embed_size: embedding dimension per token.
		num_hiddens: GRU hidden size.
		num_layers: number of stacked GRU layers.
		dropout: dropout between GRU layers (passed to nn.GRU).
	"""

	def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
		super().__init__(**kwargs)
		# Lookup table mapping each token index to an embed_size vector.
		self.embedding = nn.Embedding(vocab_size, embed_size)
		self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)

	def forward(self, X, *args):
		"""Encode a batch of token-index sequences.

		Args:
			X: token indices of shape (batch_size, num_steps).
			*args: accepted for the Encoder interface; unused here.

		Returns:
			tuple: ``(output, state)`` from nn.GRU. ``output`` has shape
			(num_steps, batch_size, num_hiddens); ``state`` holds the final
			hidden state of every layer, (num_layers, batch_size, num_hiddens).
		"""
		# nn.GRU is time-major by default, so move the time axis first:
		# (batch_size, num_steps, embed_size) -> (num_steps, batch_size, embed_size).
		embedded = self.embedding(X).permute(1, 0, 2)
		return self.rnn(embedded)
	
# Smoke test: a batch of 4 all-zero sequences with 7 steps each.
encoder = Seq2SeqEncoder(
	vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder.eval()  # evaluation mode (affects dropout layers)
X = torch.zeros((4, 7), dtype=torch.long)
output, state = encoder(X)
print(output.shape)  # expected (num_steps, batch_size, num_hiddens)

torch.Size([7, 4, 16])


In [None]:
class Seq2SeqDecoder(Decoder):
	"""RNN decoder for sequence-to-sequence learning.

	At every time step the token embedding is concatenated with a context
	vector: the top-layer hidden state of the incoming ``state``. The GRU
	input size is therefore ``embed_size + num_hiddens``. ``num_hiddens``
	and ``num_layers`` must match the encoder whose final state initializes
	this decoder.

	Args:
		vocab_size: number of target-vocabulary entries.
		embed_size: embedding dimension per token.
		num_hiddens: GRU hidden size.
		num_layers: number of stacked GRU layers.
		dropout: dropout between GRU layers (passed to nn.GRU).
	"""

	def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
		super().__init__(**kwargs)
		# Lookup table mapping each token index to an embed_size vector.
		self.embedding = nn.Embedding(vocab_size, embed_size)
		# Per-step input = embedding concatenated with the context vector, so
		# the GRU must accept embed_size + num_hiddens features (it previously
		# declared embed_size and failed on every forward pass).
		self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
		# Projects each hidden state to unnormalized vocabulary scores.
		self.dense = nn.Linear(num_hiddens, vocab_size)

	def init_state(self, enc_out, *args):
		"""Use the encoder's final hidden state as the initial decoder state.

		``enc_out`` is the ``(output, state)`` pair returned by the encoder's
		forward; extra arguments (e.g. valid lengths) are ignored.
		"""
		return enc_out[1]

	def forward(self, X, state):
		"""Decode a batch of target-token sequences.

		Args:
			X: token indices of shape (batch_size, num_steps).
			state: GRU hidden state of shape (num_layers, batch_size, num_hiddens).

		Returns:
			tuple: ``(output, state)``. ``output`` has shape
			(batch_size, num_steps, vocab_size), batch-major like ``X``;
			``state`` is the updated hidden state.
		"""
		# (batch_size, num_steps, embed_size) -> (num_steps, batch_size, embed_size)
		X = self.embedding(X).permute(1, 0, 2)
		# Repeat the top-layer state across time:
		# (batch_size, num_hiddens) -> (num_steps, batch_size, num_hiddens).
		# NOTE: the context comes from the incoming state, so when decoding one
		# step at a time it reflects the latest decoder state.
		context = state[-1].repeat(X.size(0), 1, 1)
		# Concatenate along the feature dimension.
		X_and_context = torch.cat((X, context), dim=2)
		output, state = self.rnn(X_and_context, state)
		# (num_steps, batch_size, vocab_size) -> (batch_size, num_steps, vocab_size)
		output = self.dense(output).permute(1, 0, 2)
		return output, state