# xview.py — Raster Vision object detection example (forked from azavea/raster-vision).
import os
from os.path import join
from rastervision.core.rv_pipeline import (ObjectDetectionConfig,
ObjectDetectionChipOptions,
ObjectDetectionPredictOptions)
from rastervision.core.data import (
ClassConfig, ClassInferenceTransformerConfig, DatasetConfig,
GeoJSONVectorSourceConfig, ObjectDetectionLabelSourceConfig,
RasterioSourceConfig, SceneConfig)
from rastervision.pytorch_backend import PyTorchObjectDetectionConfig
from rastervision.pytorch_learner import (
Backbone, GeoDataWindowMethod, ObjectDetectionGeoDataConfig,
ObjectDetectionGeoDataWindowConfig, ObjectDetectionImageDataConfig,
ObjectDetectionModelConfig, SolverConfig)
from rastervision.pytorch_backend.examples.utils import (get_scene_info,
save_image_crop)
def get_config(runner,
               raw_uri: str,
               processed_uri: str,
               root_uri: str,
               nochip: bool = True,
               test: bool = False) -> ObjectDetectionConfig:
    """Generate the pipeline config for this task.

    This function will be called by RV, with arguments from the command
    line, when this example is run.

    Args:
        runner (Runner): Runner for the pipeline. Will be provided by RV.
        raw_uri (str): Directory where the raw data resides.
        processed_uri (str): Directory for storing processed data.
            E.g. crops for testing.
        root_uri (str): Directory where all the output will be written.
        nochip (bool, optional): If True, read directly from the TIFF during
            training instead of from pre-generated chips. The analyze and
            chip commands should not be run, if this is set to True.
            Defaults to True.
        test (bool, optional): If True, does the following simplifications:
            (1) Uses only the first training scene and first validation
                scene.
            (2) Uses only a 2000x2000 crop of the scenes.
            (3) Trains for only 2 epochs.
            Defaults to False.

    Returns:
        ObjectDetectionConfig: A pipeline config.
    """
    train_scene_info = get_scene_info(join(processed_uri, 'train-scenes.csv'))
    val_scene_info = get_scene_info(join(processed_uri, 'val-scenes.csv'))
    if test:
        # Keep only the first scene of each split for a quick smoke test.
        train_scene_info = train_scene_info[0:1]
        val_scene_info = val_scene_info[0:1]

    def make_scene(scene_info):
        """Build a SceneConfig from a (raster_uri, label_uri) pair."""
        (raster_uri, label_uri) = scene_info
        raster_uri = join(raw_uri, raster_uri)
        label_uri = join(processed_uri, label_uri)

        if test:
            # Work on a small crop (containing at least a few labeled
            # features) instead of the full image to speed up the run.
            crop_uri = join(processed_uri, 'crops',
                            os.path.basename(raster_uri))
            save_image_crop(raster_uri, crop_uri, size=2000, min_features=5)
            raster_uri = crop_uri

        # Use the file stem as the scene id. (Local renamed from `id` to
        # avoid shadowing the builtin.)
        scene_id = os.path.splitext(os.path.basename(raster_uri))[0]
        raster_source = RasterioSourceConfig(
            uris=[raster_uri], channel_order=[0, 1, 2])
        label_source = ObjectDetectionLabelSourceConfig(
            vector_source=GeoJSONVectorSourceConfig(
                uris=label_uri,
                ignore_crs_field=True,
                transformers=[
                    # Every GeoJSON feature is assigned to class 0
                    # ('vehicle'), the only class in this example.
                    ClassInferenceTransformerConfig(default_class_id=0)
                ]))
        return SceneConfig(
            id=scene_id,
            raster_source=raster_source,
            label_source=label_source)

    train_scenes = [make_scene(info) for info in train_scene_info]
    val_scenes = [make_scene(info) for info in val_scene_info]

    class_config = ClassConfig(names=['vehicle'], colors=['red'])
    scene_dataset = DatasetConfig(
        class_config=class_config,
        train_scenes=train_scenes,
        validation_scenes=val_scenes)

    chip_sz = 300
    img_sz = chip_sz
    chip_options = ObjectDetectionChipOptions(neg_ratio=1.0, ioa_thresh=0.8)

    if nochip:
        # Sample random fixed-size windows directly from the source rasters
        # at training time instead of reading pre-generated chips.
        window_opts = ObjectDetectionGeoDataWindowConfig(
            method=GeoDataWindowMethod.random,
            size=chip_sz,
            size_lims=(chip_sz, chip_sz + 1),
            max_windows=200,
            clip=True,
            neg_ratio=chip_options.neg_ratio,
            ioa_thresh=chip_options.ioa_thresh)
        data = ObjectDetectionGeoDataConfig(
            scene_dataset=scene_dataset,
            window_opts=window_opts,
            img_sz=img_sz,
            augmentors=[])
    else:
        data = ObjectDetectionImageDataConfig(img_sz=img_sz, num_workers=4)

    predict_options = ObjectDetectionPredictOptions(
        merge_thresh=0.1, score_thresh=0.5)

    backend = PyTorchObjectDetectionConfig(
        data=data,
        model=ObjectDetectionModelConfig(backbone=Backbone.resnet50),
        solver=SolverConfig(
            lr=1e-4,
            # Shortened run in test mode.
            num_epochs=10 if not test else 2,
            batch_sz=16,
            one_cycle=True,
        ),
        log_tensorboard=True,
        run_tensorboard=False,
    )

    return ObjectDetectionConfig(
        root_uri=root_uri,
        dataset=scene_dataset,
        backend=backend,
        train_chip_sz=chip_sz,
        predict_chip_sz=chip_sz,
        chip_options=chip_options,
        predict_options=predict_options)