-
Notifications
You must be signed in to change notification settings - Fork 419
/
massachusetts_road_dataset_utils.py
113 lines (95 loc) · 4.44 KB
/
massachusetts_road_dataset_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
__author__ = 'Fabian Isensee'
import os
import sys
import fnmatch
import matplotlib.pyplot as plt
sys.path.append("../../modelzoo/")
from generators import *
from multiprocessing.dummy import Pool
from urllib import urlretrieve
def prep_folders():
    """Create the data/{training,validation,test}/{sat_img,map} directory tree.

    Parent directories come before their children in the list, so every
    mkdir has an existing parent. Directories that already exist are left
    untouched, which makes the function idempotent.
    """
    folders = [
        "data",
        "data/validation",
        "data/training",
        "data/test",
        "data/validation/sat_img",
        "data/validation/map",
        "data/training/sat_img",
        "data/training/map",
        "data/test/sat_img",
        "data/test/map",
    ]
    for folder in folders:
        if not os.path.isdir(folder):
            os.mkdir(folder)
def prep_urls():
    """Build the list of (url, local_filename) download tasks for all splits.

    File-name lists are read from mass_roads_{train,validation,test}.txt.
    The train list stores ".tif" names (the satellite image gets an extra
    "f" appended, giving ".tiff"), while the validation/test lists store
    ".tiff" names (the map target drops the final "f", giving ".tif").
    """
    base = "https://www.cs.toronto.edu/~vmnih/data/mass_roads/"
    tasks = []

    # training split: list holds ".tif" names; sat images need the extra "f"
    for name in np.loadtxt("mass_roads_train.txt", dtype=str):
        tasks.append((base + "train/sat/" + name + "f", "data/training/sat_img/%sf" % name))
        tasks.append((base + "train/map/" + name, "data/training/map/%s" % name))

    # validation split: list holds ".tiff" names; map targets drop the last "f"
    for name in np.loadtxt("mass_roads_validation.txt", dtype=str):
        tasks.append((base + "valid/sat/" + name, "data/validation/sat_img/%s" % name))
        tasks.append((base + "valid/map/" + name[:-1], "data/validation/map/%s" % name[:-1]))

    # test split: same naming convention as validation
    for name in np.loadtxt("mass_roads_test.txt", dtype=str):
        tasks.append((base + "test/sat/" + name, "data/test/sat_img/%s" % name))
        tasks.append((base + "test/map/" + name[:-1], "data/test/map/%s" % name[:-1]))

    return tasks
def download_dataset(all_tasks, num_workers=4):
    """Fetch every (url, filename) pair in ``all_tasks`` concurrently.

    Parameters
    ----------
    all_tasks : list of (url, local_filename) tuples, as built by prep_urls()
    num_workers : number of concurrent download threads.
        multiprocessing.dummy.Pool uses threads, not processes, which is
        appropriate here because the work is I/O-bound.
    """
    def urlretrieve_star(args):
        # pool.map hands each task over as a single tuple; unpack it into
        # urlretrieve(url, filename)
        return urlretrieve(*args)
    pool = Pool(num_workers)
    try:
        pool.map(urlretrieve_star, all_tasks)
    finally:
        # previously the pool leaked if a download raised; always clean up
        pool.close()
        pool.join()
def load_data(folder):
    """Load all satellite tiles and road-map masks under ``folder``.

    Expects the layout produced by prep_folders()/download_dataset():
    <folder>/sat_img holds the RGB satellite tiles and <folder>/map the
    matching road masks; pairs are matched up by sorted file name.

    Returns
    -------
    data : uint8 array, shape (n_images, 3, 1500, 1500) -- RGB, channels first
    target : uint8 array, shape (n_images, 1, 1500, 1500) -- road mask as 0/1
    """
    sat_dir = os.path.join(folder, "sat_img")
    map_dir = os.path.join(folder, "map")
    sat_names = sorted(f for f in os.listdir(sat_dir) if fnmatch.fnmatch(f, "*.tif*"))
    map_names = sorted(f for f in os.listdir(map_dir) if fnmatch.fnmatch(f, "*.tif*"))
    assert(len(sat_names) == len(map_names))
    # every tile in this dataset is 1500 x 1500 pixels
    data = np.zeros((len(sat_names), 3, 1500, 1500), dtype=np.uint8)
    target = np.zeros((len(sat_names), 1, 1500, 1500), dtype=np.uint8)
    for idx, (sat_name, map_name) in enumerate(zip(sat_names, map_names)):
        # reorder from (H, W, C) to channels-first (C, H, W)
        data[idx] = plt.imread(os.path.join(sat_dir, sat_name)).transpose((2, 0, 1))
        # map images hold values 0 and 255; rescale to 0 and 1
        target[idx, 0] = plt.imread(os.path.join(map_dir, map_name)) / 255
    return data, target
def prepare_dataset():
    """End-to-end preparation: create folders, download, cache as .npy files.

    Saves data_{train,valid,test}.npy and target_{train,valid,test}.npy in
    the current directory; loading those arrays is much faster than
    re-reading the individual image files on every run.
    """
    prep_folders()
    all_tasks = prep_urls()
    download_dataset(all_tasks)
    print("download done...")
    try:
        data_train, target_train = load_data("data/training")
        data_valid, target_valid = load_data("data/validation")
        data_test, target_test = load_data("data/test")
        np.save("data_train.npy", data_train)
        np.save("target_train.npy", target_train)
        np.save("data_valid.npy", data_valid)
        np.save("target_valid.npy", target_valid)
        np.save("data_test.npy", data_test)
        np.save("target_test.npy", target_test)
    except Exception as e:
        # was a bare `except:` -- that also swallowed KeyboardInterrupt and
        # SystemExit and hid the actual failure; catch real errors only and
        # report the cause so the user can tell what went wrong
        print("something went wrong, maybe the download? (%s)" % e)
# Script entry point: build the folder tree, download the Massachusetts roads
# dataset, and cache it as .npy arrays in the current directory.
if __name__ == "__main__":
    prepare_dataset()