# 1、词元划分示例：

In [1]:
tex_file = "../data/diagramCode/0_adequate.tex"
from dataprocess import texProcess

# 读取文件
with open(tex_file, 'r') as f:
    content = f.readlines()
    print(content)
        
    split = texProcess.tex_split(content)
    print(split)


['\\documentclass{article} \\usepackage[all]{xy} \\begin{document}\n', '$$\n', '\\xymatrix{\n', "A^{\\oplus n} \\ar[r] \\ar[d]_{m_1, \\ldots, m_n} & N' \\ar[d] \\\\\n", 'M \\ar[r] & N\n', '}\n', '$$\n', '\\end{document}\n']
['\\xymatrix', '{', 'enter', 'A', '^', '{', '\\oplus', 'n', '}', '\\ar', '[', 'r', ']', '\\ar', '[', 'd', ']', '_', '{', 'm', '_', '1', ',', '\\ldots', ',', 'm', '_', 'n', '}', '&', 'N', "'", '\\ar', '[', 'd', ']', '\\\\', 'enter', 'M', '\\ar', '[', 'r', ']', '&', 'N', 'enter', '}']


# 2、索引分配示例：

In [2]:
from dataprocess import vocab
import pickle

# tex经过词元划分后存放在dataprocess/target_all.pkl
with open('dataprocess/target_all.pkl', 'rb') as f:
    target = pickle.load(f)

# 统计词元个数
counter = vocab.count_corpus(target)  
print(f"词表里共{len(counter)}个词元")  # 1492,共1492个词被收入词典并编码

src_vocab = vocab.Vocab(target, min_freq=5, reserved_tokens=['<pad>', '<bos>', '<eos>'])
# 返回前十个词元及其对应的索引
print("前十个词元及其对应的索引:")
print(list(src_vocab.token_to_idx.items())[:10])

词表里共1492个词元
前十个词元及其对应的索引:
[('<unk>', 0), ('<pad>', 1), ('<bos>', 2), ('<eos>', 3), ('{', 4), ('}', 5), ('[', 6), (']', 7), ('enter', 8), ('\\ar', 9)]


# 3、数据集加载示例

In [3]:
import torch
from torch.utils.data import random_split
from dataload import ImageTextDataset

image_folder = "../data/diagram"
dataset = ImageTextDataset(image_folder)

# 数据划分，得到train_set, val_set, test_set
# 大约3500个训练样本，150验证样本，150测试样本
lengths = [int(len(dataset) * 0.9), int(len(dataset) * 0.05),
           len(dataset) - int(len(dataset) * 0.9) - int(len(dataset) * 0.05)]
train_set, val_set, test_set = random_split(dataset, lengths)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=16, shuffle=True,drop_last=True)
# 打印返回的图像和字幕数据
for image,tex in  train_loader:
    print(image[0])
    print(tex[0])
    print(image[0].shape)
    print(tex[0].shape)
    break

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]])
tensor([ 21,   4,   8,  18,   4,  20,   5, 102,   6,  12,   7,  13,  31,  11,
          8,  18,   4,  20,   5,  17,   9,   6,  12,   7,  13, 126, 102,   6,
         35,   7,  13, 144,  11,  18,   4,  24,   5,   8,   5,   3,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1])
torch.Size([1, 200, 270])
torch.Size([12

# 4、编码器提取特征示例

In [4]:
from cnnEncoder import Encoder
encoder = Encoder()
# 切换为测试模式
encoder.eval()

for image, text in train_loader:
    print(image.shape)
    x = encoder(image)
    print(x.shape)
    print(x)
    break


torch.Size([16, 1, 200, 270])
torch.Size([16, 1000])
tensor([[0.0222, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0178],
        [0.0219, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0157],
        [0.0200, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0193],
        ...,
        [0.0246, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0218],
        [0.0216, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0204],
        [0.0205, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0164]],
       grad_fn=<ReluBackward0>)


# 5、解码器预测示例

In [5]:
from decoder import Decoder
decoder = Decoder()
decoder.eval()
for image, text in train_loader:
    x = encoder(image)
    y = decoder(x, text)
    print(y.shape)  # [1,16,1491]
    break
    
# 解码器的输出为（16，125，1491），它代表了125个时间步，每个时间步预测词元在词表上（1491）的概率分布

torch.Size([16, 125, 1491])


# 6、模型训练

In [18]:
import model_train
# 导入python文件会自动执行其中的可执行部分，故直接导入model_train便可开始训练
# 有训练过程截图，不再次训练

Epoch [1/20], Loss: 7.3021368980407715
Epoch [1/20], Loss: 0.03369196876883507
Epoch [1/20], Loss: 0.02008460834622383
Epoch [1/20], Loss: 0.007105511613190174
Epoch [1/20], Loss: 0.00177188019733876
Epoch [2/20], Loss: 0.00034841158776544034


KeyboardInterrupt: 

# 7、使用模型进行预测示例