In [None]:
# Run all cells in order to generate a CSV that contains, for each sample in the generated dataset: all OSM data + the predicted road widths (if they were already measured) + the predicted lane counts

In [None]:
# notebook init
# Ensure that you have cloned the road-network-inference repo into /content/drive/MyDrive/
# Make sure the notebook is running using a GPU
# Colab-only: mount Google Drive so the repo, dataset CSV, images and weights are reachable.
from google.colab import drive
drive.mount('/content/drive')

# NOTE(review): these installs are unpinned (and upgrade torch), so results can depend on
# when the notebook is run; consider pinning versions. No visible cell imports PyMaxflow.
!pip install PyMaxflow
!pip install --upgrade torch torchvision

import requests
from PIL import Image
import math
from math import pi
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import csv
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import WeightedRandomSampler
import torchvision
from torchvision import datasets, models, transforms
from skimage import io, transform
from torch.autograd import Variable as V
from tqdm import tqdm 
from skimage.draw import line
import cv2
from time import time


%matplotlib inline

# Work from the repo root so relative paths (e.g. the dataset CSV) resolve against it.
%cd /content/drive/MyDrive/Road-Network-Inference

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Road-Network-Inference


In [None]:
# All paths are anchored at the repository root on Drive, so they don't depend on the
# current working directory (previously the dataset CSV only resolved after the %cd above).
repo_dir = '/content/drive/MyDrive/Road-Network-Inference'

# Path to the output CSV. We highly recommend keeping this the same for road-width and lane
# detection so that both sets of outputs end up in the same file.
output_csv = os.path.join(repo_dir, 'samples_out.csv')

# Seed for torch.
torch_seed = 42

# Location of the CSV dataset (csvPath is kept as an alias of data_path).
data_path = csvPath = os.path.join(repo_dir, 'data_backup.csv')

# Folder containing the satellite images; the CSV's sat_name paths are joined onto it.
base_satellite_image_path = os.path.join(repo_dir, '')

# Trained model weights and training checkpoint.
best_model_path = os.path.join(repo_dir, 'bestModel_weighted.pt')
checkpoint_path = os.path.join(repo_dir, 'chkPoint_weighted.pt')

In [None]:
# Seed torch's global RNG so weight initialisation and dropout are reproducible; the trailing
# semicolon hides the Generator object the call returns.
torch.manual_seed(seed=torch_seed);

<torch._C.Generator at 0x7f90dbeb5870>

In [None]:
# Full label table, plus the sorted distinct lane counts present in it.
df = pd.read_csv(data_path)
classes = np.sort(df['lanes'].unique())

# Define Dataset Structure

In [None]:
class RoadsDataset(Dataset):
    """Satellite-image samples with lane-count labels and extra tabular features.

    Each item is ``(image, lanes, numerical_data)``:
      * image          -- RGB satellite image of the row (``sat_name`` column, joined onto
                          ``base_satellite_image_path``), passed through ``transform`` if set.
      * lanes          -- the row's ``lanes`` value minus 1 (so labels 1-9 become 0-8).
      * numerical_data -- int64 tensor of the columns selected by ``NUMERICAL_COLUMNS``.

    Args:
        indices: row positions in the label CSV that make up this dataset.
        augmentSize: how many times each row is repeated. With a random transform (flips,
            rotations) the copies differ but share labels, since the road at the image centre
            is the same regardless of rotation/flip. E.g. 2 yields 2 samples per image.
        transform: callable applied to the PIL image.
        csv_path: label CSV to read; defaults to the notebook-level ``data_path``.
    """

    # Positional columns fed to the network as extra features. In the CSV layout shown in this
    # notebook's output these are the one-hot highway-type columns ('living_street' ...
    # 'unclassified') plus 'oneway' -- 13 columns. NOTE(review): confirm against data_backup.csv.
    NUMERICAL_COLUMNS = slice(3, 16)

    def __init__(self,
                 indices,
                 augmentSize = 1,
                 transform=None,
                 csv_path=None):
        self.indices = indices
        self.augmentSize = augmentSize
        self.labelDf = pd.read_csv(data_path if csv_path is None else csv_path)
        self.transform = transform

    def __len__(self):
        return self.augmentSize * len(self.indices)

    def _numerical_features(self, row):
        """Return the extra features of CSV row position ``row`` as an int64 tensor."""
        try:
            values = self.labelDf.iloc[row, self.NUMERICAL_COLUMNS].astype(int).to_numpy()
        except (ValueError, TypeError) as err:
            # An empty tensor here would only fail later in the DataLoader's collate step with
            # an unrelated size-mismatch error, so report the offending row now instead.
            raise ValueError(f'numerical data exception idx: {row}') from err
        return torch.tensor(values)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Map the (augmented) dataset index back to a row position in the label CSV.
        row = self.indices[idx // self.augmentSize]
        sample = self.labelDf.iloc[row]

        imgName = os.path.join(base_satellite_image_path, sample['sat_name'])
        # Source images were RGBA PNGs (4 channels); drop alpha. `with` closes the file handle.
        with Image.open(imgName) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        # -1 maps labels 1-9 to class indices 0-8 (out-of-range targets trigger CUDA asserts:
        # https://stackoverflow.com/questions/51691563/cuda-runtime-error-59-device-side-assert-triggered)
        lanes = sample['lanes'] - 1
        return image, lanes, self._numerical_features(row)

# Load Dataset

In [None]:
# Evaluate every row of the label CSV, in file order.
indices = list(range(len(df)))
totalCount = len(indices)

In [None]:
# PIL image -> float tensor, then normalise each RGB channel with mean 0.5 / std 0.5.
transformMean = np.full(3, 0.5)
transformStd = np.full(3, 0.5)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(transformMean, transformStd),
])

In [None]:
# Evaluation dataset: each CSV row appears exactly once (augmentSize=1).
testDataset = RoadsDataset(indices=indices, augmentSize=1, transform=transform)

In [None]:
# DataLoader settings; the environment complains about more than 2 workers.
batchSize, numWorkers = 2, 2

In [None]:
# Iterate in file order (shuffle=False) so predictions line up with CSV rows when saved,
# loading images in `numWorkers` background workers (configured above).
dataloaders = {
    'test': DataLoader(testDataset, batch_size=batchSize, shuffle=False, num_workers=numWorkers)
}

In [None]:
# Number of samples per split (augmented length of each dataset).
datasetSizes = dict(test=len(testDataset))

# Prepare to test

In [None]:
# Modified VGG16 that takes extra data from the dataset in addition to the image.
NUM_NUMERICAL_FEATURES = 13
class VGG16(nn.Module):
    """VGG16-style classifier that concatenates extra per-sample features after fc1.

    forward(x, num_data):
        x        -- image batch (batch, 3, H, W). fc1 expects 131072 = 512 * 16 * 16 features
                    after five 2x2 max-pools, e.g. 512x512 inputs -- TODO confirm image size.
        num_data -- (batch, NUM_NUMERICAL_FEATURES) extra features of any numeric dtype; cast
                    to the activations' dtype before concatenation.
    Returns (batch, 10) unnormalised class scores.

    The layer attribute names (conv1_1 ... fc3) are the state_dict keys, so keep them stable
    to stay compatible with saved weights.
    """

    def __init__(self):
        super(VGG16, self).__init__()
        self.conv1_1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)

        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)

        self.conv3_1 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)

        self.conv4_1 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)

        self.conv5_1 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(131072 , 4096)
        self.fc2 = nn.Linear(4096 + NUM_NUMERICAL_FEATURES, 4096)
        self.fc3 = nn.Linear(4096, 10)

    def stackcat(self, x, y):
        """Concatenate two (batch, features) tensors along the feature dimension."""
        return torch.cat((x, y), 1)

    def _conv_stages(self):
        """Conv layers grouped into the five VGG stages; each stage ends with a max-pool."""
        return (
            (self.conv1_1, self.conv1_2),
            (self.conv2_1, self.conv2_2),
            (self.conv3_1, self.conv3_2, self.conv3_3),
            (self.conv4_1, self.conv4_2, self.conv4_3),
            (self.conv5_1, self.conv5_2, self.conv5_3),
        )

    def forward(self, x, num_data):
        for stage in self._conv_stages():
            for conv in stage:
                x = F.relu(conv(x))
            x = self.maxpool(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        # Dropout combats overfitting. F.dropout defaults to training=True, so it must be tied
        # to the module's mode; otherwise it also fires under model.eval() and predictions
        # become random.
        x = F.dropout(x, 0.5, training=self.training)
        # Add dataset values to the linear layer (explicit cast: the features may be integer).
        x = self.stackcat(x, num_data.to(x.dtype))
        x = F.relu(self.fc2(x))
        x = F.dropout(x, 0.5, training=self.training)
        return self.fc3(x)

In [None]:
# Run on the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Build the network and load its trained weights. A freshly constructed VGG16 is randomly
# initialised, so predicting without loading `best_model_path` yields meaningless lane counts.
model = VGG16()
state = torch.load(best_model_path, map_location='cpu')
# NOTE(review): the saved format isn't visible from this notebook. Accept a bare state_dict,
# a checkpoint dict wrapping one, or a whole pickled module -- confirm against the training code.
if isinstance(state, nn.Module):
    state = state.state_dict()
elif isinstance(state, dict):
    for key in ('state_dict', 'model_state_dict'):
        if key in state:
            state = state[key]
            break
model.load_state_dict(state)
model_ft = model.to(device)
model_ft.eval();  # inference only; the semicolon suppresses the module repr

In [None]:
# Run the model over every sample, collecting predictions (saved to CSV below) and per-class
# accuracy. Labels are `lanes - 1` (see RoadsDataset), so label k means k + 1 lanes.
# Counters are keyed by label value, not by position in `classes`: positional lists broke
# (IndexError / mislabelled rows) whenever the lane counts present in the CSV had gaps.
class_correct = defaultdict(float)
class_total = defaultdict(float)
all_preds = []
all_labels = []
model_ft.eval()  # inference mode (affects layers/calls that honour `self.training`)
with torch.no_grad():
    for inputs, labels, num_data in tqdm(dataloaders['test']):
        inputs = inputs.to(device)
        labels = labels.to(device)
        num_data = num_data.to(device)
        outputs = model_ft(inputs, num_data)
        preds = outputs.argmax(dim=1)
        # One 0-d tensor per sample (as before), but on the CPU so the list holds no GPU memory.
        all_preds.extend(preds.cpu())
        all_labels.extend(labels.cpu())
        for label, correct in zip(labels.tolist(), (preds == labels).tolist()):
            class_correct[label] += correct
            class_total[label] += 1

for lane in classes:
    label = int(lane) - 1
    total = class_total.get(label, 0)
    correct = class_correct.get(label, 0)
    if total == 0:
        print('0 total of %6s : %2d %% | %d' % (lane, correct, total))
    else:
        print('Accuracy of %5s : %2d %% | %d' % (lane, 100 * correct / total, total))

overall_total = sum(class_total.values())
if overall_total:
    print('Overall test accuracy : %2d %% | %d' % (
        100 * sum(class_correct.values()) / overall_total, overall_total))
else:
    print('No samples were evaluated')

  0%|          | 1/5500 [00:00<09:38,  9.51it/s]

satelliteImage3/washington/washington_0_sat.png
satelliteImage3/washington/washington_1_sat.png
satelliteImage3/washington/washington_2_sat.png
satelliteImage3/washington/washington_3_sat.png
satelliteImage3/washington/washington_4_sat.png
satelliteImage3/washington/washington_5_sat.png


  0%|          | 5/5500 [00:00<07:18, 12.54it/s]

satelliteImage3/washington/washington_6_sat.png
satelliteImage3/washington/washington_7_sat.png
satelliteImage3/washington/washington_8_sat.png
satelliteImage3/washington/washington_9_sat.png
satelliteImage3/washington/washington_10_sat.png
satelliteImage3/washington/washington_11_sat.png


  0%|          | 7/5500 [00:00<07:05, 12.90it/s]

satelliteImage3/washington/washington_12_sat.png
satelliteImage3/washington/washington_13_sat.png
satelliteImage3/washington/washington_14_sat.png
satelliteImage3/washington/washington_15_sat.png
satelliteImage3/washington/washington_16_sat.png
satelliteImage3/washington/washington_17_sat.png


  0%|          | 11/5500 [00:00<06:47, 13.47it/s]

satelliteImage3/washington/washington_18_sat.png
satelliteImage3/washington/washington_19_sat.png
satelliteImage3/washington/washington_20_sat.png
satelliteImage3/washington/washington_21_sat.png
satelliteImage3/washington/washington_22_sat.png
satelliteImage3/washington/washington_23_sat.png


  0%|          | 13/5500 [00:00<06:42, 13.62it/s]

satelliteImage3/washington/washington_24_sat.png
satelliteImage3/washington/washington_25_sat.png
satelliteImage3/washington/washington_26_sat.png
satelliteImage3/washington/washington_27_sat.png
satelliteImage3/washington/washington_28_sat.png
satelliteImage3/washington/washington_29_sat.png


  0%|          | 17/5500 [00:01<06:41, 13.67it/s]

satelliteImage3/washington/washington_30_sat.png
satelliteImage3/washington/washington_31_sat.png
satelliteImage3/washington/washington_32_sat.png
satelliteImage3/washington/washington_33_sat.png
satelliteImage3/washington/washington_34_sat.png
satelliteImage3/washington/washington_35_sat.png


  0%|          | 19/5500 [00:01<06:46, 13.49it/s]

satelliteImage3/washington/washington_36_sat.png
satelliteImage3/washington/washington_37_sat.png
satelliteImage3/washington/washington_38_sat.png
satelliteImage3/washington/washington_39_sat.png
satelliteImage3/washington/washington_40_sat.png
satelliteImage3/washington/washington_41_sat.png


  0%|          | 21/5500 [00:01<06:43, 13.57it/s]

satelliteImage3/washington/washington_42_sat.png
satelliteImage3/washington/washington_43_sat.png


  0%|          | 23/5500 [00:02<11:09,  8.18it/s]

satelliteImage3/washington/washington_44_sat.png
satelliteImage3/washington/washington_45_sat.png
satelliteImage3/washington/washington_46_sat.png
satelliteImage3/washington/washington_47_sat.png
satelliteImage3/washington/washington_48_sat.png
satelliteImage3/washington/washington_49_sat.png


  0%|          | 27/5500 [00:02<08:54, 10.24it/s]

satelliteImage3/washington/washington_50_sat.png
satelliteImage3/washington/washington_51_sat.png
satelliteImage3/washington/washington_52_sat.png
satelliteImage3/washington/washington_53_sat.png
satelliteImage3/washington/washington_54_sat.png
satelliteImage3/washington/washington_55_sat.png


  1%|          | 28/5500 [00:02<08:05, 11.27it/s]

satelliteImage3/washington/washington_56_sat.png
satelliteImage3/washington/washington_57_sat.png





KeyboardInterrupt: ignored

In [None]:
# Sanity check: predicted class index of the first sample.
print(int(all_preds[0]))

7


# Save to CSV



In [None]:
# Write the predictions next to the existing per-sample data. If `output_csv` already exists
# (e.g. from the road-width notebook), extend it; otherwise start from the dataset CSV.
# Predictions are matched to rows by position: the test DataLoader walks `data_path` in order
# (shuffle=False), so `output_csv` must list the samples in that same order.
csv_dir = output_csv if os.path.exists(output_csv) else data_path
df_out = pd.read_csv(csv_dir)
if len(all_preds) != len(df_out):
    # e.g. the evaluation cell was interrupted, or the two CSVs have different row counts
    raise ValueError(
        f'{len(all_preds)} predictions for {len(df_out)} rows of {csv_dir}; '
        're-run the evaluation cell to completion before saving')
# Assign (rather than append) so re-running this cell doesn't add a duplicate column.
df_out['predicted_lane_nums'] = [int(p) for p in all_preds]
# to_csv quotes fields that contain commas (e.g. in road names), keeping the file parseable.
df_out.to_csv(output_csv, index=False)


[['idx', 'sat_name', 'cen_name', 'living_street', 'motorway', 'motorway_link', 'primary', 'primary_link', 'residential', 'secondary', 'secondary_link', 'tertiary', 'trunk', 'trunk_link', 'unclassified', 'oneway', 'grade', 'n1', 'n1_ele', 'n1_lat', 'n1_lon', 'n1_x', 'n1_y', 'n2', 'n2_ele', 'n2_lat', 'n2_lon', 'n2_x', 'n2_y', 'perp1_x', 'perp1_y', 'perp2_x', 'perp2_y', 'name', 'city_name', 'city_idx', 'osmid', 'lanes', 'road_width', 'road_width_meters', 'auto_lane_nums'], [10997, 'satelliteImage3/salt_lake_city/salt_lake_city_997_sat.png', 'centerLineImage3/salt_lake_city/salt_lake_city_997_osm.png', 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, False, -0.001, 3927749979, 1287.932, 40.772147, -111.9196967, 255.92543460428715, 81.28201495110989, 7011437388, 1287.908, 40.7717921, -111.9196965, 256.074565410614, 430.7175184339285, 128.0000116568334, 256.0546273669424, 383.9999883431666, 255.94537263305764, '1000 West', 'salt_lake_city', 997, 486579598, 4, 26.004048246685613, 26.004048246685613, 7]]
[