# Transformer批量训练的代码及注释

将需要的库进行加载

In [8]:
#本地文件引入
from Utilities import Mytokenizer
import Transformer
#库引入
from torch.utils.data import Dataset,DataLoader
import torch

# Select the compute device: CUDA when it is available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cpu


## Dataset类实现

In [2]:
class Mydataset(Dataset):
    """Tab-separated sentence-pair dataset that yields ready-to-train tensors.

    Every non-blank line of the file is one sample holding a source and a
    target sentence in tab-separated columns. Indexing the dataset tokenizes
    the pair, wraps it in ``Transformer.Batch`` and repeats it once per target
    position (see ``__getitem__``).

    A step-by-step tutorial version lives in the companion notebook
    "数据的全流程解释.ipynb".

    Args:
        file_path: Path of the UTF-8 text file holding the sentence pairs.
        tokenizer: Instantiated tokenizer; ``ch_token_id`` is used for the
            source sentence and ``en_token_id`` for the target sentence.
        trg_index: Column index of the target sentence within a line. The
            source is read from column 1 when ``trg_index`` is 0 and from
            column 0 otherwise.
    """

    def __init__(self,file_path,tokenizer,trg_index=0):
        self.tokenizer = tokenizer
        # Load every pair up front. Blank lines carry no pair and are dropped:
        # indexing one would raise IndexError, which plain ``for x in dataset``
        # iteration interprets as "end of data" and silently stops early.
        with open(file_path,'r',encoding='utf8') as f:
            self.lines = [line for line in f if line.strip()]
        self.length = len(self.lines)
        self.trg_index = trg_index
        # The source column is whichever of the first two columns is not the target.
        self.src_index = 1 if trg_index == 0 else 0

    def _read_pair(self,index):
        """Return the ``(src, trg)`` sentence strings of sample ``index``.

        Raises:
            IndexError: if ``index`` is past the last sample.
            ValueError: if the line lacks one of the required columns.
        """
        line = self.lines[index]
        # Remove the line terminator before splitting so that neither column
        # keeps it, whichever of the two happens to be last on the line.
        fields = line.rstrip('\n').split('\t')
        try:
            return fields[self.src_index],fields[self.trg_index]
        except IndexError as err:
            # Re-raise as ValueError: an IndexError escaping __getitem__ would
            # be mistaken for the end of the dataset during iteration.
            raise ValueError(
                f"sample {index} has {len(fields)} tab-separated column(s); "
                f"cannot read src column {self.src_index} and trg column {self.trg_index}"
            ) from err

    def __getitem__(self,index):
        """Build the training tensors for the sentence pair at ``index``.

        Returns:
            A tuple ``(src, src_mask, trg, trg_mask, trg_y, ntokens)``, each
            element moved to the module-level ``device``. ``src``, ``src_mask``,
            ``trg`` and ``trg_y`` are repeated ``copy_times`` times along
            dim 0, where ``copy_times`` is ``Batch.trg.shape[1]``; ``trg_mask``
            and ``ntokens`` are returned as produced by ``Transformer.Batch``.

        Raises:
            IndexError: if ``index`` is past the last sample.
            ValueError: if the line lacks one of the required columns.
        """
        src,trg = self._read_pair(index)
        # The source is tokenized as Chinese, the target as English. The second
        # argument is the space-separated word count; the +2 on the target is
        # presumably room for start/end markers -- confirm against Mytokenizer.
        src_id = self.tokenizer.ch_token_id([src],len(src.split(' ')))
        trg_id = self.tokenizer.en_token_id([trg],len(trg.split(' '))+2)
        src_tensor = torch.LongTensor(src_id)
        trg_tensor = torch.LongTensor(trg_id)

        b = Transformer.Batch(src_tensor,trg_tensor)

        # One copy of the pair per target position, so that each row can act as
        # a separate prediction step (presumably one per decoded prefix).
        trg = b.trg
        copy_times = trg.shape[1]
        trg = trg.repeat(copy_times,1)
        src = b.src.repeat(copy_times,1)
        src_mask = b.src_mask.repeat(copy_times,1,1)
        trg_mask = b.trg_mask
        trg_y = b.trg_y.repeat(copy_times,1)
        # Zero every label at a position where subsequent_mask is 0.
        # NOTE(review): this relies on subsequent_mask(copy_times) broadcasting
        # against the (copy_times, copy_times) labels -- confirm its shape.
        trg_y = trg_y.masked_fill(Transformer.subsequent_mask(copy_times)==0,0)
        return src.to(device),src_mask.to(device),trg.to(device),trg_mask.to(device),trg_y.to(device),b.ntokens.to(device)

    def __len__(self):
        """Return the number of sentence pairs loaded from the file."""
        return self.length
    
        

## DataSet以及Tokenizer初始化

In [3]:
# Dataset locations -- change these to your own paths.
data_path = '/home/jovyan/input/anki2023_en_ch/cmn.txt'
train_path = '/home/jovyan/input/anki2023_en_ch/train.txt'
# Initialize the tokenizer; 'en' marks English as the target language.
tokenizer = Mytokenizer(data_path,'en')
# Wrap the training file; trg_index keeps its default of 0.
dataset = Mydataset(file_path=train_path,tokenizer=tokenizer)

中文字典字数 3643
英文字典字数 8349


## 模型、优化函数、损失函数初始化

In [4]:
# Sizes of the Chinese and English dictionaries, as reported by the tokenizer.
src_vocab,trg_vocab = tokenizer.get_vocab()
d_model = 512
# Build the Transformer, then move it to the selected device.
model = Transformer.make_model(src_vocab,trg_vocab,N=3,d_model=d_model,d_ff=2048,h=8)
model = model.to(device)
# Transformer generator, also placed on the device.
generater = Transformer.Generator(d_model,trg_vocab)
generater = generater.to(device)
# Custom optimizer wrapper.
# NOTE(review): the optimizer is built from `model` only, while the loss below
# uses the separate `generater` -- confirm that its parameters get updated.
opt = Transformer.get_std_opt(model)
label_smoothing = Transformer.LabelSmoothing(trg_vocab,0,0.1)
loss_compute = Transformer.SimpleLossCompute(generater,label_smoothing,opt)



## 开始训练

In [5]:
# Number of passes over the dataset.
epoch = 3
for i in range(epoch):
    # Running sum of the per-sample losses for this epoch.
    e_loss = 0
    # Each dataset item is already a small batch: one sentence pair repeated
    # once per target position, together with its masks and token count.
    for src,src_mask,trg,trg_mask,trg_y,tokens in dataset:
        # Call the module itself rather than .forward() so that any registered
        # forward hooks run as well.
        output = model(src,trg,src_mask,trg_mask)
        # loss_compute was constructed with the optimizer; presumably it also
        # runs the backward pass and the update -- confirm in SimpleLossCompute.
        loss = loss_compute(output,trg_y,tokens).cpu()
        e_loss+=loss
        print('loss:',loss)
    print(f"代数：{i+1} , 损失:{e_loss}")

loss: tensor(22.4320)
loss: tensor(23.9844)
loss: tensor(24.3766)
loss: tensor(22.7481)
loss: tensor(26.2118)
loss: tensor(25.8319)
loss: tensor(24.6913)
loss: tensor(24.7470)
loss: tensor(48.8272)
loss: tensor(50.4420)
loss: tensor(45.6334)
loss: tensor(21.9424)
loss: tensor(48.7035)
loss: tensor(49.0590)
loss: tensor(48.9000)
loss: tensor(45.7834)
loss: tensor(45.3217)
loss: tensor(46.4886)
loss: tensor(45.5641)
loss: tensor(47.5375)
loss: tensor(48.2569)
loss: tensor(46.6815)
loss: tensor(25.3270)
loss: tensor(46.6383)
loss: tensor(48.0010)
loss: tensor(22.4841)
loss: tensor(22.2173)
loss: tensor(48.3282)
loss: tensor(44.2221)
loss: tensor(48.9370)
loss: tensor(44.5197)
loss: tensor(22.1568)
loss: tensor(45.9320)
loss: tensor(46.6134)


KeyboardInterrupt: 