In [None]:
# Notebook Init
# Ensure that you have cloned the Road-Network-Inference repo into /content/drive/MyDrive/
# Make sure the notebook is running using a GPU
# Mount Google Drive so the cloned repository is reachable under /content/drive.
from google.colab import drive
drive.mount('/content/drive')
# Switch into the DBNet directory before the project imports below (networks, framework,
# loss, data); presumably they are resolved relative to this directory - TODO confirm.
%cd /content/drive/MyDrive/Road-Network-Inference/ScRoadExtractor/DBNet
!pip install PyMaxflow
# NOTE(review): `os` is imported twice in this cell; the second import is redundant.
import os
from tqdm import tqdm 
import os
import pickle
import numpy as np
from time import time
import random
import torch
import torch.utils.data as data
from networks.dinknet import ResNet34_EdgeNet
from framework import MyFrame
from loss import Regularized_Loss
from data import ImageFolder
import gc

# Build and install the bilateralfilter package from its setup.py, then move to the
# ScRoadExtractor root: the later cells use paths relative to it (./data, ./logs, ./weights).
%cd /content/drive/MyDrive/Road-Network-Inference/ScRoadExtractor/DBNet/wrapper/bilateralfilter
%run /content/drive/MyDrive/Road-Network-Inference/ScRoadExtractor/DBNet/wrapper/bilateralfilter/setup.py build
%run /content/drive/MyDrive/Road-Network-Inference/ScRoadExtractor/DBNet/wrapper/bilateralfilter/setup.py install
%cd /content/drive/MyDrive/Road-Network-Inference/ScRoadExtractor

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Road-Network-Inference/ScRoadExtractor/DBNet
/content/drive/MyDrive/Road-Network-Inference/ScRoadExtractor/DBNet/wrapper/bilateralfilter
running build
running build_py
running build_ext
running install
running build
running build_py
running build_ext
running install_lib
running install_egg_info
Removing /usr/local/lib/python3.7/dist-packages/bilateralfilter-0.1.egg-info
Writing /usr/local/lib/python3.7/dist-packages/bilateralfilter-0.1.egg-info
/content/drive/MyDrive/Road-Network-Inference/ScRoadExtractor


In [None]:
# Set sat_dir to the location of the satellite images generated by the dataset creation notebook.
# Keep the trailing "/": the preprocessing cell builds per-city paths as sat_dir + city + "/".
sat_dir = "/content/drive/MyDrive/Road-Network-Inference/satelliteImage3RGB/"
# Set centerline_dir to the directory of the centerline images output by the dataset creation notebook.
# The trailing "/" is required here as well (centerline_dir + city + "/").
centerline_dir = "/content/drive/MyDrive/Road-Network-Inference/centerLineImage3/"


In [None]:
# Collect the city names: every sub-directory of sat_dir, in ascending order.
import math
cities = sorted(
    entry
    for entry in os.listdir(sat_dir)
    if os.path.isdir(os.path.join(sat_dir, entry))
)
print(cities)


['austin', 'baltimore', 'denver', 'new_york_city', 'philadelphia', 'portland', 'san_francisco', 'san_jose', 'seattle', 'washington']
['austin', 'baltimore', 'denver', 'new_york_city', 'philadelphia', 'portland', 'san_francisco', 'san_jose', 'seattle', 'washington']


In [None]:
# Generate training data from the satellite image and centerline image for each example.
# NOTE(review): the data-loading cell further down rebinds sat_dir to './data/train/sat/',
# so re-running this cell after that one would iterate over the wrong directory.
for city in cities:
  # Per-city input directories; sat_dir and centerline_dir must end with "/" for this to work.
  img_path = sat_dir + city + "/"
  osm_path = centerline_dir + city + "/"
  # IPython substitutes $img_path / $osm_path with the variables' values before running each script.
  # Presumably these scripts write the rough-edge maps and proposal masks read by the
  # data-loading cell - TODO confirm against the scripts themselves.
  %run /content/drive/MyDrive/Road-Network-Inference/ScRoadExtractor/boundary_detect/run.py --img_root $img_path
  %run /content/drive/MyDrive/Road-Network-Inference/ScRoadExtractor/road_label_propagation.py --img_root $img_path --osm_root $osm_path

In [None]:

# Prefix shared by the model's output files: the training cell writes its log to
# ./logs/<NAME>.log and loads/saves the weights at ./weights/<NAME>.th, so reusing a
# NAME resumes from the weights already stored under it.
NAME = 'DBNet_10Cities_zoomed_2'
# Batch size per GPU; the effective batch size is this value times torch.cuda.device_count().
BATCHSIZE_PER_CARD = 12
# Maximum number of epochs to train for (the training loop may stop earlier).
total_epoch = 300

In [None]:
# Build the solver and the training dataset.
# Original author's note: the first run of this cell loads the data and saves it as a
# pickle file, after which the notebook has to be restarted to clear the memory; on the
# next run the data is loaded from the memory-mapped pickle file instead.
# NOTE(review): none of that caching is visible in this cell - presumably it happens
# inside ImageFolder; confirm before relying on it.
SHAPE = (512, 512)
# NOTE(review): this rebinds sat_dir, which the config cell set to the raw satellite-image
# directory; re-run the config cell before re-running the preprocessing cells.
sat_dir = './data/train/sat/'
lab_dir = './data/train/proposal_mask/'
hed_dir = './data/train/rough_edge/'

# Number of visible GPUs. If this prints 0, batchsize below is 0 as well and the
# DataLoader in the training cell will reject it.
print(torch.cuda.device_count())

batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD
solver = MyFrame(ResNet34_EdgeNet, Regularized_Loss, 2e-4)
imagelist = os.listdir(lab_dir)
# Strip the last 9 characters of every mask file name to get the sample id
# (presumably a fixed suffix such as "_mask.png" - TODO confirm).
# A real list is used instead of map(): in Python 3 map() is a one-shot iterator
# without len() or indexing, both of which a dataset needs from its id collection.
trainlist = [filename[:-9] for filename in imagelist]

print("Pre Loading Data")
dataset = ImageFolder(trainlist, sat_dir, lab_dir, hed_dir)
print("Loading Complete")

In [None]:
# Train the model and save the outputs.
#
# Reads:  dataset, batchsize, solver, NAME, total_epoch, SHAPE (defined in earlier cells).
# Writes: ./logs/<NAME>.log (appended to) and ./weights/<NAME>.th (best epoch so far).

# Batches are drawn in dataset order (shuffle=False) and loaded in the main process
# (num_workers=0).
# NOTE(review): shuffle=False is unusual for training - confirm it is intentional.
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batchsize,
    shuffle=False,
    num_workers=0,
    pin_memory=True)

log_path = './logs/' + NAME + '.log'
weights_path = './weights/' + NAME + '.th'

# Create the output directories up front so that neither open() nor solver.save()
# fails on a fresh checkout.
os.makedirs('./logs/', exist_ok=True)
os.makedirs('./weights/', exist_ok=True)

# Append rather than truncate, so that the log of a previous session survives a resume
# (each session is delimited by the "NEW SESSION" banner). The with-block guarantees the
# file is closed even if training raises or is interrupted.
with open(log_path, 'a') as mylog:
    print('==============================================================================', file=mylog)
    print('==============================================================================', file=mylog)
    print('====================================NEW SESSION===============================', file=mylog)
    print('==============================================================================', file=mylog)
    print('==============================================================================', file=mylog)

    tic = time()
    no_optim = 0  # consecutive epochs without an improvement of the epoch loss
    train_epoch_best_loss = 100.

    # Resume from an existing checkpoint; on the very first run there is none yet,
    # so training starts from the freshly constructed solver.
    if os.path.exists(weights_path):
        solver.load(weights_path)

    for epoch in range(1, total_epoch + 1):
        data_loader_iter = iter(data_loader)
        train_epoch_loss = 0
        for img, mask, hed in tqdm(data_loader_iter):
            solver.set_input(img, mask, hed)
            train_loss = solver.optimize()
            train_epoch_loss += train_loss

        # Mean loss over the batches of this epoch.
        train_epoch_loss /= len(data_loader_iter)
        print('********', file=mylog)
        print('epoch:', epoch, '    time:', int(time() - tic), file=mylog)
        print('train_loss:', train_epoch_loss, file=mylog)
        print('SHAPE:', SHAPE, file=mylog)
        print('********')
        print('epoch:', epoch, '    time:', int(time() - tic))
        print('train_loss:', train_epoch_loss)
        print('SHAPE:', SHAPE)

        if train_epoch_loss >= train_epoch_best_loss:
            no_optim += 1
        else:
            # New best epoch: remember its loss and checkpoint the weights.
            no_optim = 0
            train_epoch_best_loss = train_epoch_loss
            solver.save(weights_path)
        # More than 6 epochs without improvement: give up.
        if no_optim > 6:
            print('early stop at %d epoch' % epoch, file=mylog)
            print('early stop at %d epoch' % epoch)
            break
        # More than 3 epochs without improvement: stop if the learning rate is already
        # below 5e-7, otherwise go back to the best checkpoint and call update_lr
        # (presumably this divides the learning rate by 5 - TODO confirm in MyFrame).
        if no_optim > 3:
            if solver.old_lr < 5e-7:
                break
            solver.load(weights_path)
            solver.update_lr(5.0, factor=True, mylog=mylog)
        mylog.flush()
    print('Finish!', file=mylog)
print('Finish!')

