inference.py
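"""Inference script for the DPCD change detection model.

Computes mean/std normalization statistics from the train split, builds a test
DataLoader, loads a trained checkpoint from the path given by ph.load, and
reports accuracy, precision, recall and F1 on the test split via torchmetrics.
"""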
import logging
import sys

import torch
from torch.utils.data import DataLoader
from torchmetrics import MetricCollection, Accuracy, Precision, Recall, F1Score
from tqdm import tqdm

from models.Models import DPCD
from utils.data_loading import BasicDataset
from utils.dataset_process import compute_mean_std
from utils.path_hyperparameter import ph

def train_net(dataset_name, load_checkpoint=True):
    """Despite its name, this runs inference: evaluate a trained DPCD model on the test split."""
    # 1. Create dataset
    # Compute mean and std of the train dataset to normalize the train/val/test datasets.
    t1_mean, t1_std = compute_mean_std(images_dir=f'./{dataset_name}/train/t1/')
    t2_mean, t2_std = compute_mean_std(images_dir=f'./{dataset_name}/train/t2/')

    dataset_args = dict(t1_mean=t1_mean, t1_std=t1_std, t2_mean=t2_mean, t2_std=t2_std)
    test_dataset = BasicDataset(t1_images_dir=f'./{dataset_name}/test/t1/',
                                t2_images_dir=f'./{dataset_name}/test/t2/',
                                labels_dir=f'./{dataset_name}/test/label/',
                                train=False, **dataset_args)
    # 2. Create data loaders
    # A larger batch size (ph.batch_size * ph.inference_ratio) is used at inference time.
    loader_args = dict(num_workers=8,
                       prefetch_factor=5,
                       persistent_workers=True)
    test_loader = DataLoader(test_dataset, shuffle=False, drop_last=False,
                             batch_size=ph.batch_size * ph.inference_ratio, **loader_args)
    # 3. Initialize logging
    logging.basicConfig(level=logging.INFO)

    # 4. Set up device, model, metric calculator
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    net = DPCD()
    net.to(device=device)

    assert ph.load, 'ph.load must point to a trained checkpoint to run inference'
    load_model = torch.load(ph.load, map_location=device)
    if load_checkpoint:
        # Training checkpoint: a dict holding the model weights under the 'net' key.
        net.load_state_dict(load_model['net'])
    else:
        # Plain state_dict file.
        net.load_state_dict(load_model)
    logging.info(f'Model loaded from {ph.load}')

    # Re-save the loaded weights as a plain state_dict for later reuse.
    torch.save(net.state_dict(), f'{dataset_name}_best_model.pth')
    # Metric calculator. Note: torchmetrics >= 0.11 requires an explicit task
    # argument here, e.g. Accuracy(task='binary').
    metric_collection = MetricCollection({
        'accuracy': Accuracy().to(device=device),
        'precision': Precision().to(device=device),
        'recall': Recall().to(device=device),
        'f1score': F1Score().to(device=device)
    })

    net.eval()
    logging.info('Set model to eval mode for testing.')
    with torch.no_grad():
        for batch_img1, batch_img2, labels, name in tqdm(test_loader):
            batch_img1 = batch_img1.float().to(device)
            batch_img2 = batch_img2.float().to(device)
            labels = labels.float().to(device)

            cd_preds = net(batch_img1, batch_img2, log=True, img_name=name)
            cd_preds = torch.sigmoid(cd_preds[0])

            # Accumulate batch metrics.
            cd_preds = cd_preds.float()
            labels = labels.int().unsqueeze(1)
            metric_collection.update(cd_preds, labels)

            # Clear batch variables from memory.
            del batch_img1, batch_img2, labels

    test_metrics = metric_collection.compute()
    print(f'Metrics on all data: {test_metrics}')
    metric_collection.reset()

    print('Inference finished.')

if __name__ == '__main__':
    try:
        train_net(dataset_name='data_njds_crop_8', load_checkpoint=False)
    except KeyboardInterrupt:
        logging.info('Interrupted by user')
        sys.exit(0)
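# Example invocation (a sketch): the dataset directory name hard-coded above must
# follow the layout expected by compute_mean_std and BasicDataset, i.e.
# ./<dataset_name>/train/{t1,t2}/ and ./<dataset_name>/test/{t1,t2,label}/.
#
#   python inference.py
#
# To evaluate a different dataset, change the dataset_name argument passed to
# train_net() in the __main__ block above.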