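"""Evaluation entry point (test.py).

Loads a fine-tuned BertForClassification checkpoint, runs it over the test
split, and logs the loss and the configured metrics. The helper modules
(config, dataset, logs, model, utils) are project-local.
"""
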
import torch
from config import Config
from dataset.dataloader import CtaDataLoader
from logs.logger import Logger
from model.metric import multiple_f1_score
from model.model import BertForClassification
from transformers import BertTokenizer, BertConfig
from utils.functions import (
    collate,
    prepare_device,
    get_token_logits,
    set_rs,
    get_map_location,
    filter_model_state_dict,
    get_dataset_type,
)


def test(
    config,
    model,
    dataloader,
    device,
    tokenizer,
    loss_fn,
    metric_fn,
    num_labels,
):
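    """Run one evaluation pass and return the averaged loss and metrics.

    Assumes loss_fn returns a per-batch mean (as torch.nn.CrossEntropyLoss
    does by default), so the returned loss is the mean over batches.
    """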
    set_rs(config["random_seed"])
    _logits, _targets = [], []
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for batch in dataloader:
            data = batch["data"].to(device)
            labels = batch["labels"].to(device)
            # Mask out padding tokens (token id 0).
            attention_mask = (data != 0).long()
            probs = model(data, attention_mask=attention_mask)
            # Some wrappers (e.g. HF models, DataParallel) return a tuple;
            # the logits tensor is its first element.
            if isinstance(probs, tuple):
                probs = probs[0]
            cls_probs = get_token_logits(device, data, probs, tokenizer.cls_token_id)
            loss = loss_fn(cls_probs, labels)
            running_loss += loss.item()
            _logits.append(cls_probs.argmax(1).cpu().tolist())
            _targets.append(labels.cpu().tolist())
    # Average the accumulated per-batch losses over the number of batches.
    return {
        "loss": running_loss / len(dataloader),
        "metrics": metric_fn(_logits, _targets, num_labels),
    }
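

# The config keys below are inferred from their usage in this script; the
# actual config.json may contain more. A minimal sketch:
#
# {
#     "random_seed": 42,
#     "pretrained_model_name": "bert-base-uncased",
#     "table_serialization_type": "...",
#     "num_labels": ...,
#     "batch_size": ...,
#     "num_gpu": 1,
#     "dataset": {"num_rows": ..., "data_dir": "...", "test_path": "..."},
#     "dataloader": {"num_workers": ...},
#     "checkpoint_dir": "...",
#     "checkpoint_name": "...",
#     "test_log_filename": "...",
#     "metrics": ["..."]
# }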


if __name__ == "__main__":
    conf = Config(config_path="config.json")
    tokenizer = BertTokenizer.from_pretrained(conf["pretrained_model_name"])

    dataset_type = get_dataset_type(conf["table_serialization_type"])
    dataset = dataset_type(
        tokenizer=tokenizer,
        num_rows=conf["dataset"]["num_rows"],
        data_dir=conf["dataset"]["data_dir"] + conf["dataset"]["test_path"],
    )
    dataloader = CtaDataLoader(
        dataset,
        batch_size=conf["batch_size"],
        num_workers=conf["dataloader"]["num_workers"],
        collate_fn=collate,
    )

    model = BertForClassification(
        BertConfig.from_pretrained(conf["pretrained_model_name"], num_labels=conf["num_labels"])
    )
    checkpoint = torch.load(conf["checkpoint_dir"] + conf["checkpoint_name"], map_location=get_map_location())
    model.load_state_dict(filter_model_state_dict(checkpoint["model_state_dict"]))

    device, device_ids = prepare_device(conf["num_gpu"])
    model = model.to(device)
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids)
    loss_metrics = test(
        conf,
        model,
        dataloader,
        device,
        tokenizer,
        torch.nn.CrossEntropyLoss(),
        multiple_f1_score,
        conf["num_labels"],
    )

    # Logging results
    logger = Logger(conf["test_log_filename"])
    logger.info("--- --- ---", "TEST")
    logger.info(f"Loss: {loss_metrics['loss']};", "LOSS")
    for metric in conf["metrics"]:
        logger.info(f"{metric} = {loss_metrics['metrics'][metric]}", "METRIC")