In [1]:
import torch
import numpy as np
import json
import os
import shutil
import torch
import torch.nn as nn
import torch.utils.data
from tqdm import tqdm

from logger import logger
from params import params
from vocab import Vocab
from dataset import Dataset, collate_fn
from beam import Generator
from transformer import Model

In [2]:
torch.__version__, np.__version__

('1.3.1', '1.15.4')

In [3]:
data_filename = "../data/squad/data_9.pt"
model_filename = "../checkpoint/squad/checkpoint_9.pt"
model_statistics_filename = "../data/squad/model_statistics_9.pt"

In [4]:
logger = logger()

In [5]:
data = torch.load(data_filename)
vocab = data['vocab']
params = data['params']
model_statistics = torch.load(model_statistics_filename)

In [15]:
def prepare_dataloaders(params, data):
    '''
    作用:
    将模型训练集/验证集的输入输出数据构造为batch

    输入参数:
    params: 参数集合
    data: 输入的数据

    输出参数:
    train_loader: 训练集的dataloader
    dev_loader: 验证集的dataloader
    '''

    logger.info('正在从{}中读取数据'.format(params.dataset_dir))

    # 构造train_loader
    train_dataset = Dataset(params, data, mode='train')
    train_loader = torch.utils.data.DataLoader(
        dataset = train_dataset,
        num_workers = params.num_workers,
        batch_size = params.batch_size,
        collate_fn = collate_fn,
        shuffle = True
    )
    logger.info('正在构造train_loader,共有{}个batch'.format(len(train_dataset)))

    # 构造dev_loader
    dev_dataset = Dataset(params, data, mode='dev')
    dev_loader = torch.utils.data.DataLoader(
        dataset = dev_dataset,
        num_workers = params.num_workers,
        batch_size = params.batch_size,
        collate_fn = collate_fn,
        shuffle = False
    )
    logger.info('正在构造dev_loader,共有{}个batch'.format(len(dev_dataset)))

    return train_loader, dev_loader


In [16]:
train_loader, dev_loader = prepare_dataloaders(params, data)

2020/03/04 06:16:00 PM:  正在从squad中读取数据 
2020/03/04 06:16:00 PM:  正在构造train_loader,共有75722个batch 
2020/03/04 06:16:00 PM:  正在构造dev_loader,共有18015个batch 


In [10]:
model = Model(params, vocab).to(params.device)

In [12]:
model_params = torch.load(model_filename, map_location=params.device)
model.load_state_dict(model_params)

<All keys matched successfully>

In [None]:
model.eval()

In [21]:
# predict dev using training method
for batch_index, batch in enumerate(dev_loader):
    
    print(batch_index, batch[0].shape, batch[1].shape)
    
    break

0 torch.Size([32, 50]) torch.Size([32, 18])


In [None]:
# predict dev using inference method
