In [None]:
# Move the working directory up one level so that the relative paths used below
# ('weights/...', 'data/...') and the `src` package resolve — presumably this is
# the repository root; confirm against the project layout.
# NOTE(review): re-running this cell moves up one more level each time.
%cd ..
# Set the TOKENIZERS_PARALLELISM environment variable before any tokenizer is used.
# Presumably this avoids the HuggingFace tokenizers fork/parallelism warning when
# the DataLoader spawns worker processes — confirm.
%env TOKENIZERS_PARALLELISM=false

In [None]:
import torch
from tqdm.auto import tqdm

from src import DataLoader, PredictionDatasetForSiamese, SiameseSpanPredictionModel
from src.scoring import compute_score

#### 設定各項參數

In [None]:
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes end to end on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Checkpoint restored into the model, and the CSV file predictions are made on.
# Both paths are relative to the current working directory.
model_path = 'weights/s-sp.pt'
dataset_path = 'data/splitted/test.csv'

# Batch size handed to the DataLoader.
batch_size = 1

# Both values are forwarded to model.decode_answers.
# Presumably: number of candidate answers kept per example, and the maximum
# answer length in tokens — confirm against decode_answers.
top_k = 1
max_tokens = 48

#### 載入模型

In [None]:
# Instantiate the model and restore the trained weights.
# map_location='cpu' lets a checkpoint that was saved from a GPU process be
# deserialized on any machine; the model is moved to `device` afterwards.
model = SiameseSpanPredictionModel()
model.load_state_dict(torch.load(model_path, map_location='cpu'))

# Inference only: freeze every parameter and switch to eval mode so that
# train-time layers such as dropout are disabled, then move to the target device.
model = model.requires_grad_(False).eval().to(device)

#### 載入資料集

In [None]:
# Build the prediction dataset from the CSV at `dataset_path`, using the
# tokenizer that belongs to the loaded model.
dataset = PredictionDatasetForSiamese(
    dataset_path,
    model.tokenizer,
)

# Iterate over the dataset in batches; num_workers and pin_memory are passed
# straight through to the project's DataLoader.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=4,
    pin_memory=True,
)

#### 進行預測

In [None]:
def _content_mask(encoding) -> torch.Tensor:
    """Pop ``special_tokens_mask`` from ``encoding`` and build the valid-token mask.

    A position is True when it is attended to (``attention_mask``) and is not
    flagged as a special token. Column 0 is then forced to True for every row.

    Args:
        encoding: tokenizer output exposing ``attention_mask`` and
            ``special_tokens_mask``. The latter is removed from ``encoding``
            as a side effect, so it is not among the inputs later given to
            the model.

    Returns:
        A newly allocated boolean tensor.
    """
    special = encoding.pop('special_tokens_mask')
    # Out-of-place `&` always allocates a new tensor, so attention_mask itself
    # is never modified, even if it already has a boolean dtype.
    mask = encoding.attention_mask.bool() & ~special.bool()
    # Keep position 0 selectable — presumably the leading [CLS] slot;
    # confirm against decode_answers.
    mask[:, 0] = True
    return mask


answers = []
# requires_grad_(False) only covers the parameters; no_grad guarantees that the
# whole forward pass and decoding never record autograd state.
with torch.no_grad():
    for xs, q, r, s, q_state, r_state in tqdm(dataloader):
        # Re-apply the state that the dataset yields alongside each encoding —
        # presumably metadata that does not survive batching; confirm against
        # PredictionDatasetForSiamese.
        q.__setstate__(q_state)
        r.__setstate__(r_state)

        q = q.to(device)
        r = r.to(device)
        s = s.to(device)

        # Must run before the forward pass: this also strips
        # 'special_tokens_mask' from q and r.
        q_valid_mask = _content_mask(q)
        r_valid_mask = _content_mask(r)

        preds = model(q, r, s)
        ans = model.decode_answers(xs, q, r, q_valid_mask, r_valid_mask, preds, top_k, max_tokens)
        answers.extend(ans)

#### 計算分數

In [None]:
# Score the collected predictions. As the last expression of the cell, the
# value returned by compute_score is what the notebook displays.
compute_score(answers)