-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
48 lines (29 loc) · 1.03 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#=========================================================================================================
#================================ 0. MODULE
# Base
import pandas as pd
import numpy as np
import sys
from math import floor
# Object detector
from object_detector_Unet import *
# from object_detector_SSD import *
# Ramp
sys.path.append('../')
import problem
#=========================================================================================================
#================================ 1. DATA
# Load the training split through the RAMP `problem` module. `data_path` is
# the same '../' directory that was appended to sys.path for `import problem`.
print("\nLoading data", end='...')
data_path = '../'
Xtrain, Ytrain = problem.get_train_data(data_path)
# Keep only the first half of the samples (presumably to shorten training
# runs -- confirm intent). Integer floor division yields the same count as
# floor(n / 2.) for a non-negative sample count, without a round trip
# through float.
SIZE = Xtrain.shape[0] // 2
Xtrain = Xtrain[:SIZE]
Ytrain = Ytrain[:SIZE]
print('done')
# Report the shapes of the subset that will actually be used for training.
print('>>', Xtrain.shape, Ytrain.shape, '\n')
#=========================================================================================================
#================================ 2. TRAINING
# Build the detector and train it on the training subset prepared in the
# DATA section (Xtrain / Ytrain after slicing).
# ObjectDetector is not defined in this file. It presumably comes from the
# star import of object_detector_Unet (the SSD variant's import is
# commented out) -- confirm there.
object_detector = ObjectDetector()
object_detector.fit(Xtrain, Ytrain)
# NOTE(review): nothing visible in this script saves or evaluates the
# fitted detector. If the trained model must be kept, check whether
# ObjectDetector.fit() stores it.