-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
56 lines (52 loc) · 1.89 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
###libs
import os
import argparse
import torch
from torch.utils.data import DataLoader
import json
###files
from config import Config as cfg
from dataProcessing import VOC_dataset as dataset
from dataProcessing import Testset
from models.network import NetAPI
from trainer import Trainer
import warnings
from loss_funcs import LossAPI
# Ignore every warning (all categories, all modules) from this point on.
# NOTE: the filter is installed after the imports above, so warnings emitted
# while importing torch or the project modules are not affected by it.
warnings.filterwarnings('ignore')
def get_imgs(path, out_path="data/test.json"):
    """Collect every ``.jpg`` file under *path* and save the list as JSON.

    The directory tree is walked recursively (symlinked directories are
    followed) and each match is recorded as ``os.path.join(dirname, filename)``.

    Args:
        path: Root directory to search.
        out_path: Destination JSON file. Defaults to ``data/test.json``, the
            file ``main`` hands to ``Testset``. Its parent directory is
            created if it does not exist yet.

    Returns:
        The list of collected image paths, in ``os.walk`` order.
    """
    imgs = []
    for dirname, _dirs, filenames in os.walk(path, followlinks=True):
        for filename in filenames:
            # Test the extension only; a substring check would also accept
            # names such as "x.jpg.txt".
            if filename.endswith(".jpg"):
                imgs.append(os.path.join(dirname, filename))
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # Context manager guarantees the JSON is flushed and the handle closed.
    with open(out_path, "w") as f:
        json.dump(imgs, f)
    return imgs
def main(args, cfgs):
    """Run the test pass of a network on the images listed in data/test.json.

    Args:
        args: Parsed command-line namespace. Reads ``exp``, ``net``,
            ``loss`` and ``resume``; ``bs`` overrides the configured batch
            size when it is present and not ``None``.
        cfgs: Mapping whose ``'test'`` entry is the test config object. That
            object is mutated in place (``file``, ``bs``, ``exp_name``,
            ``device``).
    """
    config = cfgs['test']
    # Image list consumed by Testset. Regenerate it with
    # get_imgs(<image root>) when the set of test images changes.
    config.file = "data/test.json"
    # --bs only takes effect when given explicitly; otherwise keep the
    # batch size coming from the config.
    bs = getattr(args, "bs", None)
    if bs is not None:
        config.bs = bs
    test_set = Testset(config)
    test_loader = DataLoader(
        test_set, batch_size=config.bs, shuffle=False, pin_memory=False
    )
    datasets = {'test': test_loader}
    config.exp_name = args.exp
    # Fall back to the CPU instead of failing later on machines without CUDA.
    config.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.empty_cache()
    # network
    network = NetAPI(config, args.net, init=not args.resume)
    loss = LossAPI(config, args.loss)
    torch.cuda.empty_cache()
    # NOTE(review): the meaning of the constant 1 in this tuple is defined by
    # Trainer — confirm there.
    det = Trainer(config, datasets, network, loss, (args.resume, 1))
    det.test()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume", type=str, default=None, help="start from epoch?")
    parser.add_argument("--exp", type=str, default='exp', help="name of exp")
    parser.add_argument("--net", type=str, default='yolo', help="network type:yolo")
    # default=None so the config's batch size is used unless --bs is given.
    parser.add_argument("--bs", type=int, default=None, help="batchsize (overrides the config value)")
    parser.add_argument("--loss", type=str, default='yolo', help="loss type")
    args = parser.parse_args()
    cfgs = {'test': cfg('test')}
    main(args, cfgs)