train_scannet.py
forked from dgriffiths3/pointnet2-tensorflow2
import os
import sys
import datetime

sys.path.insert(0, './')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
from tensorflow import keras

from models.sem_seg_model import SEM_SEG_Model

tf.random.set_seed(42)
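
# Dataset pipeline: parse TFRecord examples, shuffle scenes, shuffle the point
# order inside each cloud, and batch with a fixed batch size (remainder dropped).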
def load_dataset(in_file, batch_size):

    assert os.path.isfile(in_file), '[error] dataset path not found'

    n_points = 8192
    shuffle_buffer = 1000

    def _extract_fn(data_record):
        # Each serialised example holds a flattened point cloud (8192 xyz floats)
        # and one integer label per point.
        in_features = {
            'points': tf.io.FixedLenFeature([n_points * 3], tf.float32),
            'labels': tf.io.FixedLenFeature([n_points], tf.int64)
        }

        return tf.io.parse_single_example(data_record, in_features)

    def _preprocess_fn(sample):

        points = sample['points']
        labels = sample['labels']

        points = tf.reshape(points, (n_points, 3))
        labels = tf.reshape(labels, (n_points, 1))

        # Shuffle the point order within each cloud; points and labels stay aligned.
        shuffle_idx = tf.range(points.shape[0])
        shuffle_idx = tf.random.shuffle(shuffle_idx)
        points = tf.gather(points, shuffle_idx)
        labels = tf.gather(labels, shuffle_idx)

        return points, labels

    dataset = tf.data.TFRecordDataset(in_file)
    dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.map(_extract_fn)
    dataset = dataset.map(_preprocess_fn)
    dataset = dataset.batch(batch_size, drop_remainder=True)

    return dataset
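
# Training: builds the PointNet++ semantic segmentation model, logs to
# TensorBoard and keeps the best checkpoint (highest validation accuracy)
# under ./logs/<log_dir>. `config` is defined in the __main__ block below.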
def train():

    model = SEM_SEG_Model(config['batch_size'], config['num_classes'], config['bn'])

    train_ds = load_dataset(config['train_ds'], config['batch_size'])
    val_ds = load_dataset(config['val_ds'], config['batch_size'])

    callbacks = [
        keras.callbacks.TensorBoard(
            './logs/{}'.format(config['log_dir']), update_freq=50),
        keras.callbacks.ModelCheckpoint(
            './logs/{}/model/weights'.format(config['log_dir']),
            'val_sparse_categorical_accuracy', save_best_only=True)
    ]

    model.build((config['batch_size'], 8192, 3))  # (batch, points, xyz)
    model.summary()

    model.compile(
        optimizer=keras.optimizers.Adam(config['lr']),
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=[keras.metrics.SparseCategoricalAccuracy()]
    )

    model.fit(
        train_ds,
        validation_data=val_ds,
        validation_steps=10,
        validation_freq=1,
        callbacks=callbacks,
        epochs=100,
        verbose=1
    )

if __name__ == '__main__':

    config = {
        'train_ds' : 'data/scannet_train.tfrecord',
        'val_ds' : 'data/scannet_val.tfrecord',
        'log_dir' : 'scannet_1',
        'log_freq' : 10,
        'test_freq' : 100,
        'batch_size' : 4,
        'num_classes' : 21,
        'lr' : 0.001,
        'bn' : False,
    }

    train()
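
The feature spec in _extract_fn fixes the on-disk format the script expects. As a rough illustration only (the file name, output path and random dummy data below are made up and are not part of the repository), a record in that layout could be written roughly like this:

# make_dummy_record.py -- hypothetical example, not part of the repository.
# Writes a single synthetic scene in the schema parsed by _extract_fn above.
import numpy as np
import tensorflow as tf

n_points = 8192
points = np.random.rand(n_points, 3).astype(np.float32)   # dummy xyz coordinates
labels = np.random.randint(0, 21, size=(n_points,))       # dummy class ids (0-20)

example = tf.train.Example(features=tf.train.Features(feature={
    'points': tf.train.Feature(float_list=tf.train.FloatList(value=points.reshape(-1))),
    'labels': tf.train.Feature(int64_list=tf.train.Int64List(value=labels)),
}))

with tf.io.TFRecordWriter('data/dummy_train.tfrecord') as writer:
    writer.write(example.SerializeToString())

A file written this way loads through load_dataset, which yields batches of points with shape (batch_size, 8192, 3) and labels with shape (batch_size, 8192, 1); the preprocessed ScanNet TFRecords referenced in config are expected to follow the same layout.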