In [None]:
%cd ..
%env TOKENIZERS_PARALLELISM=false

In [None]:
import torch
from tqdm.auto import tqdm

from src import DataLoader, PredictionDataset, SpanPredictionModel
from src.scoring import compute_score

#### 設定各項參數

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU so the
# notebook still works on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = 'weights/sp.pt'             # trained SpanPredictionModel state_dict
dataset_path = 'data/splitted/test.csv'  # evaluation split
batch_size = 1
# Both values below are forwarded unchanged to `model.decode_answers`.
# NOTE(review): top_k presumably is the number of candidate answers kept per
# example, and max_tokens a length limit used while decoding — confirm against
# decode_answers.
top_k = 3
max_tokens = 10000

#### 載入模型

In [None]:
# Build the span-prediction model, restore its trained weights and prepare it
# for inference on `device`.
model = SpanPredictionModel()
# map_location='cpu' lets the checkpoint load even on a machine that lacks the
# GPU it was saved from. load_state_dict copies the tensors into the model's own
# parameters, and the whole model is moved to `device` below.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
model.load_state_dict(torch.load(model_path, map_location='cpu'))
# eval() switches layers such as dropout/batch-norm to inference behaviour;
# requires_grad_(False) only freezes parameters and does NOT do that. Without it,
# predictions would be stochastic if the model contains dropout.
model = model.requires_grad_(False).eval().to(device)

#### 載入資料集

In [None]:
# Evaluation split. It is handed the model's own tokenizer, presumably so inputs
# are encoded exactly the way the model expects.
dataset = PredictionDataset(dataset_path, model.tokenizer)
# NOTE(review): assumes src.DataLoader forwards these keyword arguments to
# torch.utils.data.DataLoader — confirm.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=4,    # background worker processes that produce batches
    pin_memory=True,  # page-locked host buffers for faster host-to-GPU copies
)

#### 進行預測

In [None]:
# Run the model over every batch and collect the decoded answers.
# `answers` is reset here so re-running this cell does not accumulate duplicates.
answers = []
with torch.no_grad():  # pure inference: never record an autograd graph
    for xs, encodings, state in tqdm(dataloader, desc='predicting'):
        # NOTE(review): the loader yields the encoding's state separately and it is
        # restored here. Presumably that state does not survive the trip through
        # worker processes intact; confirm against the project's DataLoader/collate.
        encodings.__setstate__(state)
        encodings = encodings.to(device)

        # The special-tokens mask is only used to build `valid_mask`. It is popped
        # so it is not part of what is passed to the model.
        special_tokens_mask = encodings.pop('special_tokens_mask')

        # Positions a predicted span may cover: non-padding tokens that are not
        # special tokens. Built out-of-place so it never aliases
        # `encodings.attention_mask`. If that mask were already bool, `.bool()`
        # would return the same tensor, and the in-place edits would corrupt the
        # model input.
        valid_mask: torch.Tensor = (
            encodings.attention_mask.bool() & ~special_tokens_mask.bool()
        )
        # Position 0 is force-enabled even though it is a special token. This is
        # presumably so the first token (e.g. [CLS]) can act as a "no answer" span;
        # confirm against `decode_answers`.
        valid_mask[:, 0] = True

        preds = model(encodings)
        ans = model.decode_answers(xs, encodings, valid_mask, preds, top_k, max_tokens)
        answers.extend(ans)

#### 計算分數

In [None]:
# Score the collected predictions; the cell displays whatever compute_score returns.
score = compute_score(answers)
score