-
Notifications
You must be signed in to change notification settings - Fork 9.6k
/
Copy pathdownload_checkpoints.py
83 lines (68 loc) · 2.76 KB
/
download_checkpoints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import math
import os
import os.path as osp
from multiprocessing import Pool
import torch
from mmengine.config import Config
from mmengine.utils import mkdir_or_exist
def download(url, out_file, min_bytes=math.pow(1024, 2), progress=True):
    """Download ``url`` to ``out_file``, retrying once with ``curl``.

    The file is first fetched with ``torch.hub.download_url_to_file``. If
    that call raises, or the result is missing or smaller than
    ``min_bytes``, any partial file is removed and the download is
    attempted again with the ``curl`` command-line tool. Whatever is left on
    disk below ``min_bytes`` afterwards is deleted. Failed downloads are
    reported on stdout instead of being raised.

    Args:
        url: Address of the file to fetch.
        out_file: Destination path (``str`` or path-like).
        min_bytes: Smallest size, in bytes, accepted as a complete
            download. The default ``math.pow(1024, 2)`` is 1 MiB.
        progress: Forwarded to ``torch.hub.download_url_to_file``.
    """
    # Local import keeps this function self-contained; the module header
    # does not import subprocess.
    import subprocess

    assert_msg = f"Downloaded url '{url}' does not exist " \
                 f'or size is < min_bytes={min_bytes}'
    try:
        print(f'Downloading {url} to {out_file}...')
        torch.hub.download_url_to_file(url, str(out_file), progress=progress)
        # Explicit check instead of `assert`, which is stripped under
        # `python -O`. A file of exactly `min_bytes` is accepted, matching
        # both the error message and the cleanup test in `finally`.
        if not (osp.exists(out_file)
                and osp.getsize(out_file) >= min_bytes):
            raise RuntimeError(assert_msg)
    except Exception as e:
        if osp.exists(out_file):
            os.remove(out_file)
        print(f'ERROR: {e}\nRe-attempting {url} to {out_file} ...')
        # curl download, retry and resume on fail. Arguments are passed as
        # a list (no shell), so quotes or shell metacharacters in the URL
        # or path cannot break or hijack the command.
        curl_cmd = [
            'curl', '-L', url, '-o',
            str(out_file), '--retry', '3', '-C', '-'
        ]
        try:
            subprocess.run(curl_cmd, check=False)
        except OSError as curl_error:
            # e.g. curl is not installed; stay best-effort and let the
            # `finally` clause report the missing file.
            print(f'ERROR: unable to run curl: {curl_error}')
    finally:
        if osp.exists(out_file) and osp.getsize(out_file) < min_bytes:
            os.remove(out_file)  # remove partial downloads
        if not osp.exists(out_file):
            print(f'ERROR: {assert_msg}\n')
        print('=========================================\n')
def parse_args():
    """Parse the command-line arguments of the checkpoint downloader.

    Returns:
        argparse.Namespace: Holds ``config`` and ``out`` (positional),
        ``nproc`` (int, default 16) and ``intranet`` (flag, default False).
    """
    cli = argparse.ArgumentParser(description='Download checkpoints')

    # Positional arguments: config file first, then the output directory.
    cli.add_argument('config', help='test config file path')
    cli.add_argument(
        'out', help='output dir of checkpoints to be stored', type=str)

    # Optional switches.
    cli.add_argument(
        '--nproc', help='num of Processes', default=16, type=int)
    cli.add_argument(
        '--intranet',
        help='switch to internal network url',
        action='store_true')

    return cli.parse_args()
if __name__ == '__main__':
    # Script entry point: read the config, collect every checkpoint that is
    # not yet present in the output directory, and download the missing
    # ones in parallel.
    args = parse_args()
    mkdir_or_exist(args.out)

    cfg = Config.fromfile(args.config)

    # Parallel lists: checkpoint_url_list[i] is saved to
    # checkpoint_out_list[i].
    checkpoint_url_list = []
    checkpoint_out_list = []

    for model in cfg:
        model_infos = cfg[model]
        # An entry may be a single model description or a list of them.
        if not isinstance(model_infos, list):
            model_infos = [model_infos]
        for model_info in model_infos:
            checkpoint = model_info['checkpoint']
            out_file = osp.join(args.out, checkpoint)
            if osp.exists(out_file):
                continue  # already on disk, nothing to fetch

            url = model_info['url']
            if args.intranet:
                # Rewrite the public URL to its internal-network form.
                url = url.replace('.com', '.sensetime.com')
                url = url.replace('https', 'http')

            checkpoint_url_list.append(url)
            checkpoint_out_list.append(out_file)

    if checkpoint_url_list:
        # os.cpu_count() may return None, so fall back to 1. Never start
        # more workers than there are downloads, and always at least one.
        num_workers = max(
            1,
            min(os.cpu_count() or 1, args.nproc, len(checkpoint_url_list)))
        # The context manager releases the worker processes once starmap
        # has returned.
        with Pool(num_workers) as pool:
            pool.starmap(download,
                         zip(checkpoint_url_list, checkpoint_out_list))
    else:
        print('No files to download!')