<a href="https://colab.research.google.com/github/LimSeunghyeon1/Lanenet_with_Colab/blob/master/Lanenet_for_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive so that later cells can read and write
# files stored on the drive (Colab-only: relies on the google.colab module).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
! #unzip -d /content/drive/Shareddrives/colab/ /content/drive/MyDrive/Apollo_Sim_3D_Lane_Release.zip

In [2]:
"""
Utility functions and default settings
Author: Yuliang Guo (33yuliangguo@gmail.com)
Date: March, 2020
"""

import argparse
import errno
import os
import sys
import easydict

import cv2
import matplotlib
import numpy as np
import torch
import torch.nn.init as init
import torch.optim
from torch.optim import lr_scheduler
import os.path as ops
from mpl_toolkits.mplot3d import Axes3D
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
plt.rcParams['figure.figsize'] = (35, 30)


def define_args():
    """
    Build the default settings used by the rest of this file.

    The options mirror what a command-line parser would expose, but they are
    returned with fixed defaults as an ``easydict.EasyDict`` (nothing is read
    from the command line), so each option can be read and overwritten through
    attribute access, e.g. ``args.batch_size``.

    :return: ``easydict.EasyDict`` mapping every option name to its default value
    """
    defaults = {
        # ---- paths ----
        "dataset_name": "",   # dataset name, used when naming saved models
        "data_dir": "",       # path holding the train.json and val.json files
        "dataset_dir": "",    # path holding the actual data
        "save_path": '/content/drive/Shareddrives/colab/data_splits',  # output directory

        # ---- dataset ----
        "org_h": 1080,        # height of the original image
        "org_w": 1920,        # width of the original image
        "crop_y": 0,          # crop from image
        "cam_height": 1.55,   # height of camera in meters
        "pitch": 3.0,         # pitch angle of camera to ground
        "fix_cam": False,     # whether to use a fixed camera
        "no_3d": False,       # True when the dataset has no laneline 3D attributes
        "no_centerline": False,  # True when the dataset has no centerlines

        # ---- 3DLaneNet ----
        "mod": '3DLaneNet',   # model to train
        "pretrained": True,   # use pretrained vgg model
        "batch_norm": True,   # apply batch norm
        "pred_cam": False,    # use the network to predict the camera online
        "ipm_h": 208,         # height of the inverse projective map (IPM)
        "ipm_w": 128,         # width of the inverse projective map (IPM)
        "resize_h": 360,      # presumably the height images are resized to - TODO confirm
        "resize_w": 480,      # presumably the width images are resized to - TODO confirm
        "y_ref": 20.0,        # reference Y distance (meters) where lane association is determined
        "prob_th": 0.5,       # probability threshold for selecting output lanes

        # ---- general model settings ----
        "batch_size": 8,
        "nepochs": 30,            # total number of epochs
        "learning_rate": 5 * 1e-4,
        "no_cuda": False,
        "nworkers": 0,            # number of data-loading workers
        "no_dropout": False,      # no dropout in network
        "pretrain_epochs": 20,    # number of epochs of segmentation pretraining
        "channels_in": 3,         # number of channels of the input image
        "flip_on": False,         # random flip of input images
        "test_mode": False,       # prevents loading the latest saved model
        "start_epoch": 0,         # NOTE(review): presumably the epoch to start from - confirm
        "evaluate": True,         # only perform evaluation
        "resume": "",             # resume latest saved run
        "vgg_mean": [0.485, 0.456, 0.406],  # mean of rgb used in pretrained model on ImageNet
        "vgg_std": [0.229, 0.224, 0.225],   # std of rgb used in pretrained model on ImageNet

        # ---- optimizer ----
        "optimizer": 'adam',      # adam or sgd
        "weight_init": "normal",  # normal, xavier, kaiming or orthogonal initialisation
        "weight_decay": 0.0,      # L2 weight decay / regularisation
        "lr_decay": False,        # decay learning rate with rule
        "niter": 50,              # number of iterations at the starting learning rate
        "niter_decay": 400,       # number of iterations to linearly decay learning rate to zero
        "lr_policy": None,        # learning rate policy: lambda|step|plateau
        "lr_decay_iters": 30,     # multiply by a gamma every lr_decay_iters iterations
        "clip_grad_norm": 0,      # performs gradient clipping

        # ---- cudnn / tensorboard / logging ----
        "cudnn": True,            # cudnn optimization active
        "no_tb": False,           # tensorboard logging switch
        "print_freq": 500,        # NOTE(review): presumably a printing interval - confirm unit
        "save_freq": 500,         # NOTE(review): presumably a saving interval - confirm unit

        # ---- skipped batches ----
        "list": [954, 2789],      # images to skip
    }
    return easydict.EasyDict(defaults)


def tusimple_config(args):
    """
    Overwrite, in place, the settings of ``args`` with the values of the
    'tusimple' configuration: image size and crop, dataset flags, camera
    parameters, top-view region and anchor layout.

    :param args: settings object supporting attribute assignment
    :return: None
    """
    # paper presented params:
    #     top_view_region = [[-10, 85], [10, 85], [-10, 5], [10, 5]]
    #     anchor_y_steps = [5, 20, 40, 60, 80, 100]
    # previously tried alternative:
    #     top_view_region = [[-10, 82], [10, 82], [-10, 2], [10, 2]]
    #     anchor_y_steps = [2, 3, 5, 10, 15, 20, 30, 40, 60, 80]
    overrides = (
        # dataset parameters
        ('org_h', 720),
        ('org_w', 1280),
        ('crop_y', 80),
        ('no_centerline', True),
        ('no_3d', True),
        ('fix_cam', True),
        ('pred_cam', False),
        # camera parameters for the test dataset
        ('K', np.array([[1000, 0, 640],
                        [0, 1000, 400],
                        [0, 0, 1]])),
        ('cam_height', 1.6),
        ('pitch', 9),
        # model settings
        ('top_view_region', np.array([[-10, 103], [10, 103], [-10, 3], [10, 3]])),
        ('anchor_y_steps', np.array([5, 10, 15, 20, 30, 40, 50, 60, 80, 100])),
    )
    for attr_name, attr_value in overrides:
        setattr(args, attr_name, attr_value)

    # one y step per anchor sample position
    args.num_y_steps = len(args.anchor_y_steps)

    # pre-trained vgg weights are switched off, batch norm is switched on
    args.pretrained = False
    args.batch_norm = True


def sim3d_config(args):
    """
    Overwrite, in place, the settings of ``args`` with the values of the
    'sim3d' configuration: image size and crop, dataset flags, camera
    intrinsics, top-view region and anchor layout.

    :param args: settings object supporting attribute assignment
    :return: None
    """
    # camera parameters for the test datasets
    intrinsics = np.array([[2015., 0., 960.],
                           [0., 2015., 540.],
                           [0., 0., 1.]])

    # paper presented params:
    #     top_view_region = [[-10, 85], [10, 85], [-10, 5], [10, 5]]
    #     anchor_y_steps = [5, 20, 40, 60, 80, 100]
    # previously tried alternative:
    #     top_view_region = [[-10, 83], [10, 83], [-10, 3], [10, 3]]
    #     anchor_y_steps = [3, 5, 10, 20, 40, 60, 80, 100]
    region = np.array([[-10, 103], [10, 103], [-10, 3], [10, 3]])
    y_steps = np.array([5, 10, 15, 20, 30, 40, 50, 60, 80, 100])

    settings = {
        # dataset parameters
        'org_h': 1080,
        'org_w': 1920,
        'crop_y': 0,
        'no_centerline': False,
        'no_3d': False,
        'fix_cam': False,
        'pred_cam': False,
        # camera
        'K': intrinsics,
        # model settings
        'top_view_region': region,
        'anchor_y_steps': y_steps,
        'num_y_steps': len(y_steps),
        # pre-trained vgg weights are switched off, batch norm is switched on
        'pretrained': False,
        'batch_norm': True,
    }
    for key, value in settings.items():
        setattr(args, key, value)


class Visualizer:
    def __init__(self, args, vis_folder='val_vis'):
        """
        Cache the settings needed for drawing and precompute the anchor grid
        and the ground-to-IPM homography.

        :param args: settings object; the attributes read here are save_path,
                     no_3d, no_centerline, vgg_mean, vgg_std, ipm_w, ipm_h,
                     num_y_steps, mod, top_view_region, anchor_y_steps and prob_th
        :param vis_folder: folder name stored on the instance as ``vis_folder``
        """
        self.save_path = args.save_path
        self.vis_folder = vis_folder
        self.no_3d = args.no_3d
        self.no_centerline = args.no_centerline
        self.vgg_mean = args.vgg_mean
        self.vgg_std = args.vgg_std
        self.ipm_w = args.ipm_w
        self.ipm_h = args.ipm_h
        self.num_y_steps = args.num_y_steps

        # number of columns of one anchor group; the last column of a group is
        # the value compared against prob_th by the draw_* methods
        if args.no_3d:
            self.anchor_dim = args.num_y_steps + 1
        else:
            if 'ext' in args.mod:
                self.anchor_dim = 3 * args.num_y_steps + 1
            else:
                self.anchor_dim = 2 * args.num_y_steps + 1

        # x positions of the anchors, evenly spread between the x values of the
        # first two rows of top_view_region, one anchor per 8 IPM columns.
        # The builtin int replaces np.int, which was only an alias of it and no
        # longer exists in NumPy >= 1.24.
        x_min = args.top_view_region[0, 0]
        x_max = args.top_view_region[1, 0]
        self.anchor_x_steps = np.linspace(x_min, x_max, int(args.ipm_w / 8), endpoint=True)
        self.anchor_y_steps = args.anchor_y_steps

        # transformation from ipm to ground region: the four IPM image corners
        # are mapped onto the four points of top_view_region, then inverted
        H_ipm2g = cv2.getPerspectiveTransform(np.float32([[0, 0],
                                                          [self.ipm_w-1, 0],
                                                          [0, self.ipm_h-1],
                                                          [self.ipm_w-1, self.ipm_h-1]]),
                                              np.float32(args.top_view_region))
        self.H_g2ipm = np.linalg.inv(H_ipm2g)

        # probability threshold for choosing visualize lanes
        self.prob_th = args.prob_th

    def draw_on_img(self, img, lane_anchor, P_g2im, draw_type='laneline', color=[0, 0, 1]):
        """
        Draw the lanes stored in lane_anchor on an image.

        Each row of lane_anchor is one anchor. Its columns are organised in
        groups of ``self.anchor_dim``: the laneline group first, followed by
        the centerline group and the additional centerline group used for the
        merging case. A group is drawn only when its last column exceeds
        ``self.prob_th``.

        :param img: image in numpy array, each pixel in [0, 1] range
        :param lane_anchor: lane anchor in N X C numpy ndarray, dimension in agree with dataloader
        :param P_g2im: projection from ground 3D coordinates to image 2D coordinates;
                       a matrix with 3 columns is applied as a homography on (x, y),
                       any other matrix is applied to (x, y, z)
        :param draw_type: 'laneline' or 'centerline' deciding which to draw;
                          any other value draws nothing
        :param color: [r, g, b] color for line,  each range in [0, 1]
        :return: the image with the selected lanes drawn on it
        """
        # first column of every anchor group that has to be drawn
        if draw_type == 'laneline':
            group_starts = (0,)
        elif draw_type == 'centerline':
            group_starts = (self.anchor_dim, 2 * self.anchor_dim)
        else:
            group_starts = ()

        for j in range(lane_anchor.shape[0]):
            for start in group_starts:
                prob_col = start + self.anchor_dim - 1
                if not (lane_anchor[j, prob_col] > self.prob_th):
                    continue

                # x values are stored as offsets relative to the anchor position
                x_offsets = lane_anchor[j, start:start + self.num_y_steps]
                x_3d = x_offsets + self.anchor_x_steps[j]
                if P_g2im.shape[1] == 3:
                    x_2d, y_2d = homographic_transformation(P_g2im, x_3d, self.anchor_y_steps)
                else:
                    z_3d = lane_anchor[j, start + self.num_y_steps:prob_col]
                    x_2d, y_2d = projective_transformation(P_g2im, x_3d, self.anchor_y_steps, z_3d)

                # builtin int instead of the removed np.int alias
                x_2d = x_2d.astype(int)
                y_2d = y_2d.astype(int)
                for k in range(1, x_2d.shape[0]):
                    img = cv2.line(img, (x_2d[k - 1], y_2d[k - 1]), (x_2d[k], y_2d[k]), color, 2)
        return img

    def draw_on_img_new(self, img, lane_anchor, P_g2im, draw_type='laneline', color=[0, 0, 1]):
        """
        Draw the lanes stored in lane_anchor on an image, taking the
        per-point visibility into account.

        Each row of lane_anchor is one anchor. Its columns are organised in
        groups of ``self.anchor_dim``: the laneline group first, followed by
        the centerline group and the additional centerline group used for the
        merging case. Inside a group the x offsets, the z values and the
        visibility values each take ``self.num_y_steps`` columns, and the last
        column is compared against ``self.prob_th`` to decide whether the
        group is drawn. Segments whose end point has a visibility not above
        ``self.prob_th`` are drawn in black instead of ``color``.

        :param img: image in numpy array, each pixel in [0, 1] range
        :param lane_anchor: lane anchor in N X C numpy ndarray, dimension in agree with dataloader
        :param P_g2im: projection from ground 3D coordinates to image 2D coordinates;
                       a matrix with 3 columns is applied as a homography on (x, y)
                       and every point is then treated as visible
        :param draw_type: 'laneline' or 'centerline' deciding which to draw;
                          any other value draws nothing
        :param color: [r, g, b] color for line,  each range in [0, 1]
        :return: the image with the selected lanes drawn on it
        """
        # first column of every anchor group that has to be drawn
        if draw_type == 'laneline':
            group_starts = (0,)
        elif draw_type == 'centerline':
            group_starts = (self.anchor_dim, 2 * self.anchor_dim)
        else:
            group_starts = ()

        n_steps = self.num_y_steps
        for j in range(lane_anchor.shape[0]):
            for start in group_starts:
                if not (lane_anchor[j, start + self.anchor_dim - 1] > self.prob_th):
                    continue

                # x values are stored as offsets relative to the anchor position
                x_offsets = lane_anchor[j, start:start + n_steps]
                x_3d = x_offsets + self.anchor_x_steps[j]
                if P_g2im.shape[1] == 3:
                    x_2d, y_2d = homographic_transformation(P_g2im, x_3d, self.anchor_y_steps)
                    visibility = np.ones_like(x_2d)
                else:
                    z_3d = lane_anchor[j, start + n_steps:start + 2 * n_steps]
                    x_2d, y_2d = projective_transformation(P_g2im, x_3d, self.anchor_y_steps, z_3d)
                    visibility = lane_anchor[j, start + 2 * n_steps:start + 3 * n_steps]

                # builtin int instead of the removed np.int alias
                x_2d = x_2d.astype(int)
                y_2d = y_2d.astype(int)
                for k in range(1, x_2d.shape[0]):
                    segment_color = color if visibility[k] > self.prob_th else [0, 0, 0]
                    img = cv2.line(img, (x_2d[k - 1], y_2d[k - 1]), (x_2d[k], y_2d[k]), segment_color, 2)
        return img

    def draw_on_ipm(self, im_ipm, lane_anchor, draw_type='laneline', color=[0, 0, 1]):
        """
        Draw the lanes stored in lane_anchor on an IPM (top-view) image.

        Each row of lane_anchor is one anchor. Its columns are organised in
        groups of ``self.anchor_dim``: the laneline group first, followed by
        the centerline group and the additional centerline group used for the
        merging case. A group is drawn only when its last column exceeds
        ``self.prob_th``. Ground coordinates are mapped to IPM pixels with
        ``self.H_g2ipm``.

        :param im_ipm: IPM image to draw on
        :param lane_anchor: lane anchor in N X C numpy ndarray
        :param draw_type: 'laneline' or 'centerline' deciding which to draw;
                          any other value draws nothing
        :param color: [r, g, b] color for line
        :return: the IPM image with the selected lanes drawn on it
        """
        # first column of every anchor group that has to be drawn
        if draw_type == 'laneline':
            group_starts = (0,)
        elif draw_type == 'centerline':
            group_starts = (self.anchor_dim, 2 * self.anchor_dim)
        else:
            group_starts = ()

        for j in range(lane_anchor.shape[0]):
            for start in group_starts:
                if not (lane_anchor[j, start + self.anchor_dim - 1] > self.prob_th):
                    continue

                # x values are stored as offsets relative to the anchor position
                x_offsets = lane_anchor[j, start:start + self.num_y_steps]
                x_g = x_offsets + self.anchor_x_steps[j]

                # compute lanelines in ipm view
                x_ipm, y_ipm = homographic_transformation(self.H_g2ipm, x_g, self.anchor_y_steps)
                # builtin int instead of the removed np.int alias
                x_ipm = x_ipm.astype(int)
                y_ipm = y_ipm.astype(int)
                for k in range(1, x_g.shape[0]):
                    im_ipm = cv2.line(im_ipm, (x_ipm[k - 1], y_ipm[k - 1]),
                                      (x_ipm[k], y_ipm[k]), color, 1)
        return im_ipm

    def draw_on_ipm_new(self, im_ipm, lane_anchor, draw_type='laneline', color=[0, 0, 1], width=1):
        """
        Draw the lanes stored in lane_anchor on an IPM (top-view) image,
        taking the per-point visibility into account.

        Each row of lane_anchor is one anchor. Its columns are organised in
        groups of ``self.anchor_dim``: the laneline group first, followed by
        the centerline group and the additional centerline group used for the
        merging case. A group is drawn only when its last column exceeds
        ``self.prob_th``. Unless ``self.no_3d`` is set, the visibility values
        are read from the third run of ``self.num_y_steps`` columns of the
        group; segments whose end point has a visibility not above
        ``self.prob_th`` are drawn in black instead of ``color``.

        :param im_ipm: IPM image to draw on
        :param lane_anchor: lane anchor in N X C numpy ndarray
        :param draw_type: 'laneline' or 'centerline' deciding which to draw;
                          any other value draws nothing
        :param color: [r, g, b] color for line
        :param width: line thickness passed to cv2.line
        :return: the IPM image with the selected lanes drawn on it
        """
        # first column of every anchor group that has to be drawn
        if draw_type == 'laneline':
            group_starts = (0,)
        elif draw_type == 'centerline':
            group_starts = (self.anchor_dim, 2 * self.anchor_dim)
        else:
            group_starts = ()

        n_steps = self.num_y_steps
        for j in range(lane_anchor.shape[0]):
            for start in group_starts:
                if not (lane_anchor[j, start + self.anchor_dim - 1] > self.prob_th):
                    continue

                # x values are stored as offsets relative to the anchor position
                x_offsets = lane_anchor[j, start:start + n_steps]
                x_g = x_offsets + self.anchor_x_steps[j]
                if self.no_3d:
                    visibility = np.ones_like(x_g)
                else:
                    visibility = lane_anchor[j, start + 2 * n_steps:start + 3 * n_steps]

                # compute lanelines in ipm view
                x_ipm, y_ipm = homographic_transformation(self.H_g2ipm, x_g, self.anchor_y_steps)
                # builtin int instead of the removed np.int alias
                x_ipm = x_ipm.astype(int)
                y_ipm = y_ipm.astype(int)
                for k in range(1, x_g.shape[0]):
                    segment_color = color if visibility[k] > self.prob_th else [0, 0, 0]
                    im_ipm = cv2.line(im_ipm, (x_ipm[k - 1], y_ipm[k - 1]),
                                      (x_ipm[k], y_ipm[k]), segment_color, width)
        return im_ipm

    def draw_3d_curves(self, ax, lane_anchor, draw_type='laneline', color=[0, 0, 1]):
        for j in range(lane_anchor.shape[0]):
            # draw laneline
            if draw_type is 'laneline' and lane_anchor[j, self.anchor_dim - 1] > self.prob_th:
                x_offsets = lane_anchor[j, :self.num_y_steps]
                x_g = x_offsets + self.anchor_x_steps[j]
                if self.no_3d:
                    z_g = np.zeros_like(x_g)
                else:
                    z_g = lane_anchor[j, self.num_y_steps:2*self.num_y_steps]
                ax.plot(x_g, self.anchor_y_steps, z_g, color=color)

            # draw centerline
            if draw_type is 'centerline' and lane_anchor[j, 2*self.anchor_dim - 1] > self.prob_th:
                x_offsets = lane_anchor[j, self.anchor_dim:self.anchor_dim + self.num_y_steps]
                x_g = x_offsets + self.anchor_x_steps[j]
                if self.no_3d:
                    z_g = np.zeros_like(x_g)
                else:
                    z_g = lane_anchor[j, self.anchor_dim + self.num_y_steps:self.anchor_dim + 2*self.num_y_steps]
                ax.plot(x_g, self.anchor_y_steps, z_g, color=color)

            # draw the additional centerline for the merging case
            if draw_type is 'centerline' and lane_anchor[j, 3*self.anchor_dim - 1] > self.prob_th:
                x_offsets = lane_anchor[j, 2*self.anchor_dim:2*self.anchor_dim + self.num_y_steps]
                x_g = x_offsets + self.anchor_x_steps[j]
                if self.no_3d:
                    z_g = np.zeros_like(x_g)
                else:
                    z_g = lane_anchor[j, 2*self.anchor_dim + self.num_y_steps:2*self.anchor_dim + 2*self.num_y_steps]
                ax.plot(x_g, self.anchor_y_steps, z_g, color=color)

    def draw_3d_curves_new(self, ax, lane_anchor, h_cam, draw_type='laneline', color=[0, 0, 1]):
        """
        Plot the visible part of the lanes stored in lane_anchor as 3D curves
        on ``ax``.

        Each row of lane_anchor is one anchor. Its columns are organised in
        groups of ``self.anchor_dim``: the laneline group first, followed by
        the centerline group and the additional centerline group used for the
        merging case. A group is handled only when its last column exceeds
        ``self.prob_th``. When ``self.no_3d`` is set, z is zero and every
        point counts as visible; otherwise z and visibility are read from the
        second and third run of ``self.num_y_steps`` columns of the group.
        Only points whose visibility exceeds ``self.prob_th`` are kept, and a
        lane without any such point is not plotted.

        :param ax: object whose ``plot(x, y, z, color=...)`` method is called
                   once per drawn lane
        :param lane_anchor: lane anchor in N X C numpy ndarray
        :param h_cam: passed as first argument to transform_lane_gflat2g
                      (presumably the camera height - confirm against that helper)
        :param draw_type: 'laneline' or 'centerline' deciding which to draw;
                          any other value draws nothing
        :param color: [r, g, b] color for line
        :return: None
        """
        # first column of every anchor group that has to be drawn; the string
        # is compared by value so that any equal string selects the branch
        if draw_type == 'laneline':
            group_starts = (0,)
        elif draw_type == 'centerline':
            group_starts = (self.anchor_dim, 2 * self.anchor_dim)
        else:
            group_starts = ()

        n_steps = self.num_y_steps
        for j in range(lane_anchor.shape[0]):
            for start in group_starts:
                if not (lane_anchor[j, start + self.anchor_dim - 1] > self.prob_th):
                    continue

                # x values are stored as offsets relative to the anchor position
                x_offsets = lane_anchor[j, start:start + n_steps]
                x_gflat = x_offsets + self.anchor_x_steps[j]
                if self.no_3d:
                    z_g = np.zeros_like(x_gflat)
                    visibility = np.ones_like(x_gflat)
                else:
                    z_g = lane_anchor[j, start + n_steps:start + 2 * n_steps]
                    visibility = lane_anchor[j, start + 2 * n_steps:start + 3 * n_steps]

                # keep the visible points only; the selection is computed once
                # and reused for x, z and the y steps
                visible = np.where(visibility > self.prob_th)
                x_gflat = x_gflat[visible]
                z_g = z_g[visible]
                if len(x_gflat) > 0:
                    # transform lane detected in flat ground space to 3d ground space
                    x_g, y_g = transform_lane_gflat2g(h_cam,
                                                      x_gflat,
                                                      self.anchor_y_steps[visible],
                                                      z_g)
                    ax.plot(x_g, y_g, z_g, color=color)

    def save_result(self, dataset, train_or_val, epoch, batch_i, idx, images, gt, pred, pred_cam_pitch, pred_cam_height, aug_mat=np.identity(3, dtype=float), evaluate=False):
        """
        Draw ground-truth and predicted lanes for a batch and save one figure per sample.

        Ground truth is drawn with color [0, 0, 1] and predictions with [1, 0, 0], on the
        network input image, on its IPM (top-view) warp and, unless self.no_3d, in a 3D plot.

        :param dataset: provides `data_aug`, `K` and `transform_mats(idx)`
        :param train_or_val: sub-folder name used in the output path when not evaluating
        :param epoch: epoch number, only used in the output file name
        :param batch_i: batch number, only used in the output file name
        :param idx: per-sample indices of the batch; idx.shape[0] is the batch size
        :param images: normalized image tensor laid out N x C x H x W
        :param gt: per-sample ground-truth anchors
        :param pred: per-sample predicted anchors. NOTE: the class-score columns of pred[i] are
                     overwritten in place by nms_1d
        :param pred_cam_pitch: per-sample camera pitch used to project predictions
        :param pred_cam_height: per-sample camera height used to project predictions
        :param aug_mat: augmentation matrices indexed as aug_mat[i, :, :]; when dataset.data_aug is
                        false the single 3 x 3 matrix given is repeated for every sample
        :param evaluate: if True every sample is saved as 'infer_<idx>' under self.vis_folder,
                         otherwise only the first sample of the batch is saved
        :return: None, figures are written below self.save_path + '/example/'
        """
        # `float` replaces the `np.float` alias, which was removed in NumPy 1.24
        if not dataset.data_aug:
            aug_mat = np.repeat(np.expand_dims(aug_mat, axis=0), idx.shape[0], axis=0)

        def _format_3d_axis(ax):
            # label the axes and make sure the z range covers at least [-1, 1]
            ax.set_xlabel('x axis')
            ax.set_ylabel('y axis')
            ax.set_zlabel('z axis')
            bottom, top = ax.get_zlim()
            ax.set_zlim(min(bottom, -1), max(top, 1))
            ax.set_xlim(-20, 20)
            ax.set_ylim(0, 100)

        for i in range(idx.shape[0]):
            # during training, only visualize the first sample of this batch
            if i > 0 and not evaluate:
                break
            im = images.permute(0, 2, 3, 1).data.cpu().numpy()[i]
            # the vgg_std and vgg_mean are for images in [0, 1] range
            im = im * np.array(self.vgg_std)
            im = im + np.array(self.vgg_mean)
            im = np.clip(im, 0, 1)

            gt_anchors = gt[i]
            pred_anchors = pred[i]

            # apply nms to avoid output directly neighbored lanes
            # consider w/o centerline cases
            if self.no_centerline:
                pred_anchors[:, -1] = nms_1d(pred_anchors[:, -1])
            else:
                pred_anchors[:, self.anchor_dim - 1] = nms_1d(pred_anchors[:, self.anchor_dim - 1])
                pred_anchors[:, 2 * self.anchor_dim - 1] = nms_1d(pred_anchors[:, 2 * self.anchor_dim - 1])
                pred_anchors[:, 3 * self.anchor_dim - 1] = nms_1d(pred_anchors[:, 3 * self.anchor_dim - 1])

            H_g2im, P_g2im, H_crop, H_im2ipm = dataset.transform_mats(idx[i])
            if self.no_3d:
                # flat ground: a 3 x 3 homography is enough
                P_gt = np.matmul(H_crop, H_g2im)
                H_g2im_pred = homograpthy_g2im(pred_cam_pitch[i],
                                               pred_cam_height[i], dataset.K)
                P_pred = np.matmul(H_crop, H_g2im_pred)
            else:
                # 3D lanes: use the full 3 x 4 projection
                P_gt = np.matmul(H_crop, P_g2im)
                P_g2im_pred = projection_g2im(pred_cam_pitch[i],
                                              pred_cam_height[i], dataset.K)
                P_pred = np.matmul(H_crop, P_g2im_pred)

            # consider data augmentation
            P_gt = np.matmul(aug_mat[i, :, :], P_gt)
            P_pred = np.matmul(aug_mat[i, :, :], P_pred)

            # update transformation with image augmentation
            H_im2ipm = np.matmul(H_im2ipm, np.linalg.inv(aug_mat[i, :, :]))
            im_ipm = cv2.warpPerspective(im, H_im2ipm, (self.ipm_w, self.ipm_h))
            im_ipm = np.clip(im_ipm, 0, 1)

            # draw lanes on image
            im_laneline = im.copy()
            im_laneline = self.draw_on_img(im_laneline, gt_anchors, P_gt, 'laneline', [0, 0, 1])
            im_laneline = self.draw_on_img(im_laneline, pred_anchors, P_pred, 'laneline', [1, 0, 0])
            if not self.no_centerline:
                im_centerline = im.copy()
                im_centerline = self.draw_on_img(im_centerline, gt_anchors, P_gt, 'centerline', [0, 0, 1])
                im_centerline = self.draw_on_img(im_centerline, pred_anchors, P_pred, 'centerline', [1, 0, 0])

            # draw lanes on ipm
            ipm_laneline = im_ipm.copy()
            ipm_laneline = self.draw_on_ipm(ipm_laneline, gt_anchors, 'laneline', [0, 0, 1])
            ipm_laneline = self.draw_on_ipm(ipm_laneline, pred_anchors, 'laneline', [1, 0, 0])
            if not self.no_centerline:
                ipm_centerline = im_ipm.copy()
                ipm_centerline = self.draw_on_ipm(ipm_centerline, gt_anchors, 'centerline', [0, 0, 1])
                ipm_centerline = self.draw_on_ipm(ipm_centerline, pred_anchors, 'centerline', [1, 0, 0])

            # plot on a single figure
            fig = plt.figure()
            if self.no_centerline and self.no_3d:
                ax1 = fig.add_subplot(121)
                ax2 = fig.add_subplot(122)
                ax1.imshow(im_laneline)
                ax2.imshow(ipm_laneline)
            elif not self.no_centerline and self.no_3d:
                ax1 = fig.add_subplot(221)
                ax2 = fig.add_subplot(222)
                ax3 = fig.add_subplot(223)
                ax4 = fig.add_subplot(224)
                ax1.imshow(im_laneline)
                ax2.imshow(ipm_laneline)
                ax3.imshow(im_centerline)
                ax4.imshow(ipm_centerline)
            elif not self.no_centerline and not self.no_3d:
                ax1 = fig.add_subplot(231)
                ax2 = fig.add_subplot(232)
                ax3 = fig.add_subplot(233, projection='3d')
                ax4 = fig.add_subplot(234)
                ax5 = fig.add_subplot(235)
                ax6 = fig.add_subplot(236, projection='3d')
                ax1.imshow(im_laneline)
                ax2.imshow(ipm_laneline)
                self.draw_3d_curves(ax3, gt_anchors, 'laneline', [0, 0, 1])
                self.draw_3d_curves(ax3, pred_anchors, 'laneline', [1, 0, 0])
                _format_3d_axis(ax3)
                ax4.imshow(im_centerline)
                ax5.imshow(ipm_centerline)
                self.draw_3d_curves(ax6, gt_anchors, 'centerline', [0, 0, 1])
                self.draw_3d_curves(ax6, pred_anchors, 'centerline', [1, 0, 0])
                _format_3d_axis(ax6)
            else:
                # lanelines only, in 3D: this combination previously left `fig` undefined
                ax1 = fig.add_subplot(131)
                ax2 = fig.add_subplot(132)
                ax3 = fig.add_subplot(133, projection='3d')
                ax1.imshow(im_laneline)
                ax2.imshow(ipm_laneline)
                self.draw_3d_curves(ax3, gt_anchors, 'laneline', [0, 0, 1])
                self.draw_3d_curves(ax3, pred_anchors, 'laneline', [1, 0, 0])
                _format_3d_axis(ax3)

            if evaluate:
                fig.savefig(self.save_path + '/example/' + self.vis_folder + '/infer_{}'.format(idx[i]))
            else:
                fig.savefig(self.save_path + '/example/{}/epoch-{}_batch-{}_idx-{}'.format(train_or_val,
                                                                                           epoch, batch_i, idx[i]))
            plt.clf()
            plt.close(fig)

    def save_result_new(self, dataset, train_or_val, epoch, batch_i, idx, images, gt, pred, pred_cam_pitch, pred_cam_height, aug_mat=np.identity(3, dtype=float), evaluate=False):
        """
        Draw ground-truth and predicted lanes for a batch and save one figure per sample,
        using the `*_new` drawing helpers.

        Unlike save_result, lanes are always projected to the image with the 3 x 3 ground-to-image
        homography, regardless of self.no_3d. Ground truth is drawn with color [0, 0, 1] and
        predictions with [1, 0, 0].

        :param dataset: provides `data_aug`, `K` and `transform_mats(idx)`
        :param train_or_val: sub-folder name used in the output path when not evaluating
        :param epoch: epoch number, only used in the output file name
        :param batch_i: batch number, only used in the output file name
        :param idx: per-sample indices of the batch; idx.shape[0] is the batch size
        :param images: normalized image tensor laid out N x C x H x W
        :param gt: per-sample ground-truth anchors
        :param pred: per-sample predicted anchors. NOTE: the class-score columns of pred[i] are
                     overwritten in place by nms_1d
        :param pred_cam_pitch: per-sample camera pitch used to project predictions
        :param pred_cam_height: per-sample camera height; also used to draw the 3D ground truth
        :param aug_mat: augmentation matrices indexed as aug_mat[i, :, :]; when dataset.data_aug is
                        false the single 3 x 3 matrix given is repeated for every sample
        :param evaluate: if True every sample is saved as 'infer_<idx>' under self.vis_folder,
                         otherwise only the first sample of the batch is saved
        :return: None, figures are written below self.save_path + '/example/'
        """
        # `float` replaces the `np.float` alias, which was removed in NumPy 1.24
        if not dataset.data_aug:
            aug_mat = np.repeat(np.expand_dims(aug_mat, axis=0), idx.shape[0], axis=0)

        def _format_3d_axis(ax):
            # label the axes and make sure the z range covers at least [-1, 1]
            ax.set_xlabel('x axis')
            ax.set_ylabel('y axis')
            ax.set_zlabel('z axis')
            bottom, top = ax.get_zlim()
            ax.set_xlim(-20, 20)
            ax.set_ylim(0, 100)
            ax.set_zlim(min(bottom, -1), max(top, 1))

        for i in range(idx.shape[0]):
            # during training, only visualize the first sample of this batch
            if i > 0 and not evaluate:
                break
            im = images.permute(0, 2, 3, 1).data.cpu().numpy()[i]
            # the vgg_std and vgg_mean are for images in [0, 1] range
            im = im * np.array(self.vgg_std)
            im = im + np.array(self.vgg_mean)
            im = np.clip(im, 0, 1)

            gt_anchors = gt[i]
            pred_anchors = pred[i]

            # apply nms to avoid output directly neighbored lanes
            # consider w/o centerline cases
            if self.no_centerline:
                pred_anchors[:, -1] = nms_1d(pred_anchors[:, -1])
            else:
                pred_anchors[:, self.anchor_dim - 1] = nms_1d(pred_anchors[:, self.anchor_dim - 1])
                pred_anchors[:, 2 * self.anchor_dim - 1] = nms_1d(pred_anchors[:, 2 * self.anchor_dim - 1])
                pred_anchors[:, 3 * self.anchor_dim - 1] = nms_1d(pred_anchors[:, 3 * self.anchor_dim - 1])

            H_g2im, P_g2im, H_crop, H_im2ipm = dataset.transform_mats(idx[i])
            P_gt = np.matmul(H_crop, H_g2im)
            H_g2im_pred = homograpthy_g2im(pred_cam_pitch[i],
                                           pred_cam_height[i], dataset.K)
            P_pred = np.matmul(H_crop, H_g2im_pred)

            # consider data augmentation
            P_gt = np.matmul(aug_mat[i, :, :], P_gt)
            P_pred = np.matmul(aug_mat[i, :, :], P_pred)

            # update transformation with image augmentation
            H_im2ipm = np.matmul(H_im2ipm, np.linalg.inv(aug_mat[i, :, :]))
            im_ipm = cv2.warpPerspective(im, H_im2ipm, (self.ipm_w, self.ipm_h))
            im_ipm = np.clip(im_ipm, 0, 1)

            # draw lanes on image
            im_laneline = im.copy()
            im_laneline = self.draw_on_img_new(im_laneline, gt_anchors, P_gt, 'laneline', [0, 0, 1])
            im_laneline = self.draw_on_img_new(im_laneline, pred_anchors, P_pred, 'laneline', [1, 0, 0])
            if not self.no_centerline:
                im_centerline = im.copy()
                im_centerline = self.draw_on_img_new(im_centerline, gt_anchors, P_gt, 'centerline', [0, 0, 1])
                im_centerline = self.draw_on_img_new(im_centerline, pred_anchors, P_pred, 'centerline', [1, 0, 0])

            # draw lanes on ipm
            ipm_laneline = im_ipm.copy()
            ipm_laneline = self.draw_on_ipm_new(ipm_laneline, gt_anchors, 'laneline', [0, 0, 1])
            ipm_laneline = self.draw_on_ipm_new(ipm_laneline, pred_anchors, 'laneline', [1, 0, 0])
            if not self.no_centerline:
                ipm_centerline = im_ipm.copy()
                ipm_centerline = self.draw_on_ipm_new(ipm_centerline, gt_anchors, 'centerline', [0, 0, 1])
                ipm_centerline = self.draw_on_ipm_new(ipm_centerline, pred_anchors, 'centerline', [1, 0, 0])

            # plot on a single figure
            fig = plt.figure()
            if self.no_centerline and self.no_3d:
                ax1 = fig.add_subplot(121)
                ax2 = fig.add_subplot(122)
                ax1.imshow(im_laneline)
                ax2.imshow(ipm_laneline)
            elif not self.no_centerline and self.no_3d:
                ax1 = fig.add_subplot(221)
                ax2 = fig.add_subplot(222)
                ax3 = fig.add_subplot(223)
                ax4 = fig.add_subplot(224)
                ax1.imshow(im_laneline)
                ax2.imshow(ipm_laneline)
                ax3.imshow(im_centerline)
                ax4.imshow(ipm_centerline)
            elif not self.no_centerline and not self.no_3d:
                ax1 = fig.add_subplot(231)
                ax2 = fig.add_subplot(232)
                ax3 = fig.add_subplot(233, projection='3d')
                ax4 = fig.add_subplot(234)
                ax5 = fig.add_subplot(235)
                ax6 = fig.add_subplot(236, projection='3d')
                ax1.imshow(im_laneline)
                ax2.imshow(ipm_laneline)
                # TODO:use separate gt_cam_height when ready
                self.draw_3d_curves_new(ax3, gt_anchors, pred_cam_height[i], 'laneline', [0, 0, 1])
                self.draw_3d_curves_new(ax3, pred_anchors, pred_cam_height[i], 'laneline', [1, 0, 0])
                _format_3d_axis(ax3)
                ax4.imshow(im_centerline)
                ax5.imshow(ipm_centerline)
                # TODO:use separate gt_cam_height when ready
                self.draw_3d_curves_new(ax6, gt_anchors, pred_cam_height[i], 'centerline', [0, 0, 1])
                self.draw_3d_curves_new(ax6, pred_anchors, pred_cam_height[i], 'centerline', [1, 0, 0])
                _format_3d_axis(ax6)
            else:
                # lanelines only, in 3D: this combination previously left `fig` undefined
                ax1 = fig.add_subplot(131)
                ax2 = fig.add_subplot(132)
                ax3 = fig.add_subplot(133, projection='3d')
                ax1.imshow(im_laneline)
                ax2.imshow(ipm_laneline)
                # TODO:use separate gt_cam_height when ready
                self.draw_3d_curves_new(ax3, gt_anchors, pred_cam_height[i], 'laneline', [0, 0, 1])
                self.draw_3d_curves_new(ax3, pred_anchors, pred_cam_height[i], 'laneline', [1, 0, 0])
                _format_3d_axis(ax3)

            if evaluate:
                fig.savefig(self.save_path + '/example/' + self.vis_folder + '/infer_{}'.format(idx[i]))
            else:
                fig.savefig(self.save_path + '/example/{}/epoch-{}_batch-{}_idx-{}'.format(train_or_val,
                                                                                           epoch, batch_i, idx[i]))
            plt.clf()
            plt.close(fig)


def prune_3d_lane_by_visibility(lane_3d, visibility):
    """
        Keep only the lane points whose visibility value is positive
    :param lane_3d: ndarray with one row per lane point
    :param visibility: one value per row of lane_3d; rows with a value <= 0 are dropped
    :return: the remaining rows of lane_3d
    """
    visible_rows = visibility > 0
    return lane_3d[visible_rows, ...]


def prune_3d_lane_by_range(lane_3d, x_min, x_max):
    """
        Drop lane points that fall outside the region of interest
    :param lane_3d: ndarray with one row per point, column 0 holding x and column 1 holding y
    :param x_min: exclusive lower bound on x
    :param x_max: exclusive upper bound on x
    :return: rows with 0 < y < 200 and x_min < x < x_max
    """
    # TODO: solve hard coded range later
    # 3D label may miss super long straight-line with only two points: Not have to be 200, gt need a min-step
    # 2D dataset requires this to rule out those points projected to ground, but out of meaningful range
    ys = lane_3d[:, 1]
    within_y = np.logical_and(ys > 0, ys < 200)
    lane_3d = lane_3d[within_y, ...]

    # the x filter is applied to the rows that survived the y filter
    xs = lane_3d[:, 0]
    within_x = np.logical_and(xs > x_min, xs < x_max)
    return lane_3d[within_x, ...]


def resample_laneline_in_y(input_lane, y_steps, out_vis=False):
    """
        Linearly interpolate (and extrapolate) the x and z values of a lane at the given y positions
    :param input_lane: N x 2 or N x 3 ndarray, one row for a point (x, y, z-optional).
                       A missing z column is treated as all zeros
    :param y_steps: a vector of y positions to sample at
    :param out_vis: whether to also return a visibility indicator, which depends only on the
                    y range of the input lane
    :return: (x_values, z_values), or (x_values, z_values, visibility) when out_vis is set.
             visibility is float32, about 1 for steps within 5 of the input y range and
             about 0 (1e-9) elsewhere
    """

    # at least two points are included
    assert(input_lane.shape[0] >= 2)

    # visibility bounds come from the input y range, padded by 5 on both sides
    lower_y = np.min(input_lane[:, 1])-5
    upper_y = np.max(input_lane[:, 1])+5

    if input_lane.shape[1] < 3:
        zero_z = np.zeros([input_lane.shape[0], 1], dtype=np.float32)
        input_lane = np.concatenate([input_lane, zero_z], axis=1)

    lane_y = input_lane[:, 1]
    x_values = interp1d(lane_y, input_lane[:, 0], fill_value="extrapolate")(y_steps)
    z_values = interp1d(lane_y, input_lane[:, 2], fill_value="extrapolate")(y_steps)

    if not out_vis:
        return x_values, z_values

    inside = np.logical_and(y_steps >= lower_y, y_steps <= upper_y)
    return x_values, z_values, inside.astype(np.float32) + 1e-9


def resample_laneline_in_y_with_vis(input_lane, y_steps, vis_vec):
    """
        Linearly interpolate (and extrapolate) x, z and visibility at the given y positions and
        keep only the samples whose interpolated visibility exceeds 0.5
    :param input_lane: N x 2 or N x 3 ndarray, one row for a point (x, y, z-optional).
                       A missing z column is treated as all zeros
    :param y_steps: a vector of y positions to sample at
    :param vis_vec: visibility value of each input point, one per row of input_lane
    :return: M x 3 ndarray of the retained (x, y, z) samples
    """

    # at least two points are included
    assert(input_lane.shape[0] >= 2)

    if input_lane.shape[1] < 3:
        zero_z = np.zeros([input_lane.shape[0], 1], dtype=np.float32)
        input_lane = np.concatenate([input_lane, zero_z], axis=1)

    lane_y = input_lane[:, 1]
    x_values = interp1d(lane_y, input_lane[:, 0], fill_value="extrapolate")(y_steps)
    z_values = interp1d(lane_y, input_lane[:, 2], fill_value="extrapolate")(y_steps)
    vis_values = interp1d(lane_y, vis_vec, fill_value="extrapolate")(y_steps)

    keep = vis_values > 0.5
    return np.vstack((x_values[keep], y_steps[keep], z_values[keep])).T


def homography_im2ipm_norm(top_view_region, org_img_size, crop_y, resize_img_size, cam_pitch, cam_height, K):
    """
        Compute the normalized homographies between the network input image and the IPM (top-view)
        image, such that the image points of the top-view region corners map to the 4 corners of
        the unit square
        Ground coordinates: x-right, y-forward, z-up
        The purpose of applying normalized transformation: 1. invariance in scale change
                                                           2.Torch grid sample is based on normalized grids
    :param top_view_region: a 4 X 2 array of (X, Y) indicating the top-view region corners in order:
                            top-left, top-right, bottom-left, bottom-right
    :param org_img_size: the size of original image: [h, w]
    :param crop_y: pixels cropped from original img
    :param resize_img_size: the size of image as network input: [h, w]
    :param cam_pitch: camera pitch angle wrt ground plane
    :param cam_height: camera height wrt ground plane in meters
    :param K: camera intrinsic parameters
    :return: H_im2ipm_norm: the normalized transformation from image to IPM image
             H_ipm2im_norm: the normalized transformation from IPM image to image
    """

    # ground -> original image (the only part depending on camera pitch and height),
    # followed by original image -> cropped and resized network input
    ground_to_org = homograpthy_g2im(cam_pitch, cam_height, K)
    org_to_input = homography_crop_resize(org_img_size, crop_y, resize_img_size)
    ground_to_input = np.matmul(org_to_input, ground_to_org)

    # image coordinates of the top-view corners, normalized by input width and height
    u, v = homographic_transformation(ground_to_input, top_view_region[:, 0], top_view_region[:, 1])
    corners_norm = np.float32(np.stack([u / resize_img_size[1], v / resize_img_size[0]], axis=1))
    unit_square = np.float32([[0, 0], [1, 0], [0, 1], [1, 1]])

    im_to_ipm = cv2.getPerspectiveTransform(corners_norm, unit_square)
    ipm_to_im = cv2.getPerspectiveTransform(unit_square, corners_norm)
    return im_to_ipm, ipm_to_im


def homography_ipmnorm2g(top_view_region):
    """
        Compute the homography mapping normalized IPM coordinates (the unit square) to ground
        coordinates
    :param top_view_region: the 4 ground-plane corners matching the unit-square corners
                            (0, 0), (1, 0), (0, 1), (1, 1) in this order
    :return: 3 x 3 homography matrix
    """
    unit_square = np.float32([[0, 0], [1, 0], [0, 1], [1, 1]])
    ground_corners = np.float32(top_view_region)
    return cv2.getPerspectiveTransform(unit_square, ground_corners)


def homograpthy_g2im(cam_pitch, cam_height, K):
    """
        Compute the homography mapping flat-ground points (x, y, 1) to homogeneous image
        coordinates
    :param cam_pitch: camera pitch angle wrt ground plane, in radians
    :param cam_height: camera height wrt ground plane
    :param K: 3 x 3 camera intrinsic matrix
    :return: 3 x 3 homography matrix
    """
    angle = np.pi / 2 + cam_pitch
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rotation = np.array([[1, 0, 0],
                         [0, cos_a, -sin_a],
                         [0, sin_a, cos_a]])
    # ground points have z = 0, so the third rotation column is replaced by the translation
    translation = [[0], [cam_height], [0]]
    extrinsic = np.concatenate([rotation[:, 0:2], translation], 1)
    return np.matmul(K, extrinsic)


def projection_g2im(cam_pitch, cam_height, K):
    """
        Compute the projection mapping 3D ground points (x, y, z, 1) to homogeneous image
        coordinates
    :param cam_pitch: camera pitch angle wrt ground plane, in radians
    :param cam_height: camera height wrt ground plane
    :param K: 3 x 3 camera intrinsic matrix
    :return: 3 x 4 projection matrix
    """
    angle = np.pi / 2 + cam_pitch
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    # rotation about the x axis with the camera height as translation in the last column
    ground_to_cam = np.array([[1, 0, 0, 0],
                              [0, cos_a, -sin_a, cam_height],
                              [0, sin_a, cos_a, 0]])
    return np.matmul(K, ground_to_cam)


def homography_crop_resize(org_img_size, crop_y, resize_img_size):
    """
        Compute the homography that maps original image coordinates to the coordinates of the
        image after cropping crop_y rows from the top and resizing
    :param org_img_size: [org_h, org_w]
    :param crop_y: number of rows cropped from the top of the original image
    :param resize_img_size: [resize_h, resize_w]
    :return: 3 x 3 homography matrix
    """
    org_h, org_w = org_img_size[0], org_img_size[1]
    resize_h, resize_w = resize_img_size[0], resize_img_size[1]

    scale_x = resize_w / org_w
    # only the rows left after cropping are stretched to the new height
    scale_y = resize_h / (org_h - crop_y)
    return np.array([[scale_x, 0, 0],
                     [0, scale_y, -scale_y*crop_y],
                     [0, 0, 1]])


def homographic_transformation(Matrix, x, y):
    """
    Apply a homography to 2D points and return the de-homogenized coordinates
    Args:
            Matrix (multi dim - array): 3x3 homography matrix
            x (array): original x coordinates
            y (array): original y coordinates
    Returns:
            tuple of arrays: transformed x coordinates, transformed y coordinates
    """
    homogeneous = np.vstack((x, y, np.ones((1, len(y)))))
    mapped = np.matmul(Matrix, homogeneous)

    # divide by the third (scale) row
    scale = mapped[2, :]
    return mapped[0, :]/scale, mapped[1, :]/scale


def projective_transformation(Matrix, x, y, z):
    """
    Apply a projection to 3D points and return the de-homogenized 2D coordinates
    Args:
            Matrix (multi dim - array): 3x4 projection matrix
            x (array): original x coordinates
            y (array): original y coordinates
            z (array): original z coordinates
    Returns:
            tuple of arrays: projected x coordinates, projected y coordinates
    """
    homogeneous = np.vstack((x, y, z, np.ones((1, len(z)))))
    mapped = np.matmul(Matrix, homogeneous)

    # divide by the third (scale) row
    scale = mapped[2, :]
    return mapped[0, :]/scale, mapped[1, :]/scale


def transform_lane_gflat2g(h_cam, X_gflat, Y_gflat, Z_g):
    """
        Convert lane points from flat ground space to 3D ground space: given X, Y in flat ground
        space and the height Z in 3D ground space, compute X, Y in 3D ground space.
    :param h_cam: camera height
    :param X_gflat: X coordinates in flat ground space
    :param Y_gflat: Y coordinates in flat ground space
    :param Z_g: Z coordinates in real 3D ground space
    :return: (X_g, Y_g), the X and Y coordinates in 3D ground space
    """

    # each coordinate is reduced by its own share Z_g / h_cam
    return (X_gflat - X_gflat * Z_g / h_cam,
            Y_gflat - Y_gflat * Z_g / h_cam)


def transform_lane_g2gflat(h_cam, X_g, Y_g, Z_g):
    """
        Convert lane points from 3D ground space to flat ground space: given X, Y, Z in 3D ground
        space, compute the X, Y coordinates in flat ground space.
    :param h_cam: camera height
    :param X_g: X coordinates in real 3D ground space
    :param Y_g: Y coordinates in real 3D ground space
    :param Z_g: Z coordinates in real 3D ground space
    :return: (X_gflat, Y_gflat), the X and Y coordinates in flat ground space
    """

    # height of the camera above each point
    height_above_point = h_cam - Z_g
    return (X_g * h_cam / height_above_point,
            Y_g * h_cam / height_above_point)


def nms_1d(v):
    """
        1D non-maximum suppression: zero every element that has a strictly larger direct neighbor
    :param v: a 1D numpy array
    :return: a copy of v with the suppressed elements set to 0. Equal neighbors are both kept.
             An array with fewer than two elements is returned as is (not copied)
    """
    num = v.shape[0]
    if num < 2:
        return v

    # The neighbors are compared on the untouched input, so the result does not depend on the
    # order of processing. Vectorized comparisons also avoid the former `is not` identity tests
    # on integers, which ran past the end of arrays longer than CPython's small-int cache.
    suppressed = np.zeros(num, dtype=bool)
    suppressed[1:] |= v[:-1] > v[1:]
    suppressed[:-1] |= v[1:] > v[:-1]

    v_out = v.copy()
    v_out[suppressed] = 0.
    return v_out


def first_run(save_path):
    """
        Check the 'first_run.txt' marker file inside save_path
    :param save_path: directory holding the marker file (must already exist)
    :return: '' if the marker did not exist (it is created empty) or exists but is empty,
             otherwise the content of the marker file
    """
    txt_file = os.path.join(save_path, 'first_run.txt')
    if not os.path.exists(txt_file):
        # first run: leave an empty marker behind
        with open(txt_file, 'w'):
            pass
        return ''

    with open(txt_file) as marker:
        saved_epoch = marker.read()
    # read() returns '' for an empty file and never None, so test for emptiness
    if not saved_epoch:
        print('You forgot to delete [first run file]')
        return ''
    return saved_epoch


def mkdir_if_missing(directory):
    """
        Create a directory, including missing parents, unless the path already exists
    :param directory: path of the directory to create
    :return: None
    """
    if os.path.exists(directory):
        return
    try:
        os.makedirs(directory)
    except OSError as err:
        # the directory may have appeared between the check and the creation
        if err.errno == errno.EEXIST:
            return
        raise


# trick from stackoverflow
def str2bool(argument):
    """
        Parse a command-line string into a boolean (case-insensitive)
    :param argument: one of yes/true/t/y/1 or no/false/f/n/0
    :return: the corresponding bool
    :raises argparse.ArgumentTypeError: if the string is none of the accepted values
    """
    token = argument.lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Wrong argument in argparse, should be a boolean')


class Logger(object):
    """
    Tee-style writer: everything written goes to the console stream and, when a path is given,
    also to a log file.
    Source https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
    """
    def __init__(self, fpath=None):
        """
        :param fpath: path of the log file, opened for writing (truncated); None logs to the
                      console only
        """
        self.console = sys.stdout
        self.file = None
        self.fpath = fpath
        if fpath is not None:
            log_dir = os.path.dirname(fpath)
            # a bare file name has no directory part; os.makedirs('') would fail
            if log_dir:
                mkdir_if_missing(log_dir)
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        # return the logger so that `with Logger(path) as log:` binds a usable object
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        """Write msg to the console and, if open, to the log file."""
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        """Flush the console and force the log file content to disk."""
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        """
        Close the log file, if any. Safe to call more than once.

        The console stream is left open: it is the process-wide sys.stdout captured at
        construction, and closing it would break every later print in the program.
        """
        log_file = getattr(self, 'file', None)
        if log_file is not None:
            log_file.close()
            self.file = None


class AverageMeter(object):
    """Tracks the latest value and the running weighted average of a series of values"""
    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the average, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """
        Record a new value
        :param val: the value to record
        :param n: the weight of val, i.e. how many samples it stands for
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def define_optim(optim, params, lr, weight_decay):
    """
        Create a torch optimizer by name
    :param optim: 'adam', 'sgd' or 'rmsprop'; sgd and rmsprop use a momentum of 0.9
    :param params: parameters to optimize
    :param lr: learning rate
    :param weight_decay: weight decay
    :return: the optimizer
    :raises KeyError: if the optimizer name is not supported
    """
    if optim == 'adam':
        return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
    if optim == 'sgd':
        return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay)
    if optim == 'rmsprop':
        return torch.optim.RMSprop(params, lr=lr, momentum=0.9, weight_decay=weight_decay)
    raise KeyError("The requested optimizer: {} is not implemented".format(optim))


def define_scheduler(optimizer, args):
    """
        Create the learning rate scheduler selected by args.lr_policy
    :param optimizer: the optimizer whose learning rate is scheduled
    :param args: options object. Reads lr_policy and, depending on the policy:
                 'lambda': niter, niter_decay (factor 1 up to niter, then linear decay)
                 'step': lr_decay_iters, gamma
                 'plateau': lr_decay_iters (used as patience), gamma
                 'none': nothing
    :return: the scheduler, or None for the 'none' policy
    :raises NotImplementedError: if the policy is unknown
    """
    if args.lr_policy == 'lambda':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + 1 - args.niter) / float(args.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif args.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=args.lr_decay_iters, gamma=args.gamma)
    elif args.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                   factor=args.gamma,
                                                   threshold=0.0001,
                                                   patience=args.lr_decay_iters)
    elif args.lr_policy == 'none':
        scheduler = None
    else:
        # raise (not return) the error, with the policy name formatted into the message
        raise NotImplementedError('learning rate policy [%s] is not implemented' % args.lr_policy)
    return scheduler


def define_init_weights(model, init_w='normal', activation='relu'):
    """
        Initialize all sub-modules of a model with the selected scheme
    :param model: the network to initialize (via model.apply)
    :param init_w: 'normal', 'xavier', 'kaiming' or 'orthogonal'
    :param activation: not used by this function
    :raises NotImplementedError: if init_w is not one of the supported schemes
    """
    print('Init weights in network with [{}]'.format(init_w))
    if init_w == 'normal':
        initializer = weights_init_normal
    elif init_w == 'xavier':
        initializer = weights_init_xavier
    elif init_w == 'kaiming':
        initializer = weights_init_kaiming
    elif init_w == 'orthogonal':
        initializer = weights_init_orthogonal
    else:
        raise NotImplementedError('initialization method [{}] is not implemented'.format(init_w))
    model.apply(initializer)


def weights_init_normal(m):
    """
        Per-module initializer for model.apply, chosen by the module's class name:
        Conv* / Linear: weights ~ N(0, 0.02), bias (if any) zeroed
        BatchNorm2d: weights ~ N(1, 0.02), bias set to 0
        Any other module is left untouched
    :param m: the module to initialize in place
    """
    name = type(m).__name__
    if 'Conv' in name or 'Linear' in name:
        # 'Conv' also covers the ConvTranspose layers
        init.normal_(m.weight.data, 0.0, 0.02)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'BatchNorm2d' in name:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_xavier(m):
    """
        Per-module initializer for model.apply, chosen by the module's class name:
        Conv* / Linear: Xavier normal weights with gain 0.02, bias (if any) zeroed
        BatchNorm2d: weights ~ N(1, 0.02), bias set to 0
        Any other module is left untouched
    :param m: the module to initialize in place
    """
    name = type(m).__name__
    if 'Conv' in name or 'Linear' in name:
        # 'Conv' also covers the ConvTranspose layers
        init.xavier_normal_(m.weight.data, gain=0.02)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'BatchNorm2d' in name:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_kaiming(m):
    """
        Per-module initializer for model.apply, chosen by the module's class name:
        Conv* / Linear: Kaiming normal weights (fan_in, relu), bias (if any) zeroed
        BatchNorm2d: weights ~ N(1, 0.02), bias set to 0
        Any other module is left untouched
    :param m: the module to initialize in place
    """
    name = type(m).__name__
    if 'Conv' in name or 'Linear' in name:
        # 'Conv' also covers the ConvTranspose layers
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'BatchNorm2d' in name:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_orthogonal(m):
    """
        Per-module initializer for model.apply, chosen by the module's class name:
        Conv* / Linear: orthogonal weights with gain 1, bias (if any) zeroed
        BatchNorm2d: weights ~ N(1, 0.02), bias set to 0
        Any other module is left untouched
    :param m: the module to initialize in place
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        # use the in-place `orthogonal_`, like the sibling initializers; the un-suffixed
        # `init.orthogonal` is a deprecated alias
        init.orthogonal_(m.weight.data, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)

In [None]:
"""
MinCostFlow solver adapted for matching two sets of contours. The implementation is based on google-ortools.
Author: Yuliang Guo (33yuliangguo@gmail.com)
Date: March, 2020
"""

from __future__ import print_function
import numpy as np
from ortools.graph import pywrapgraph
import time


def SolveMinCostFlow(adj_mat, cost_mat):
    """
        Solving an Assignment Problem with MinCostFlow

        The flow network is: source -> every item of set A -> every item of set B -> sink.
        An A-B arc has capacity adj_mat[a, b] and unit cost cost_mat[a, b]; all other arcs have
        capacity 1 and cost 0.
    :param adj_mat: adjacency matrix with binary values indicating possible matchings between two sets
    :param cost_mat: cost matrix recording the matching cost of every possible pair of items from two sets.
                     NOTE: costs are converted to integers with astype(int), so fractional parts
                     are truncated
    :return: list of [index in set A, index in set B, cost] for every matched pair;
             empty if the solver does not report an optimal solution
    """

    # Instantiate a SimpleMinCostFlow solver.
    min_cost_flow = pywrapgraph.SimpleMinCostFlow()

    cnt_1, cnt_2 = adj_mat.shape
    cnt_nonzero_row = int(np.sum(np.sum(adj_mat, axis=1) > 0))
    cnt_nonzero_col = int(np.sum(np.sum(adj_mat, axis=0) > 0))

    # node numbering: 0 is the source, 1..cnt_1 is set A, the next cnt_2 nodes are set B,
    # and the last node is the sink
    source = 0
    sink = cnt_1 + cnt_2 + 1
    nodes_a = list(range(1, cnt_1 + 1))
    nodes_b = list(range(cnt_1 + 1, cnt_1 + cnt_2 + 1))

    # prepare directed graph for the flow; the A-B arcs are listed in row-major order so they
    # line up with the flattened matrices. The builtin `int` replaces the `np.int` alias,
    # which was removed in NumPy 1.24
    start_nodes = [source] * cnt_1 + [a for a in nodes_a for _ in range(cnt_2)] + nodes_b
    end_nodes = nodes_a + nodes_b * cnt_1 + [sink] * cnt_2
    capacities = [1] * cnt_1 + adj_mat.flatten().astype(int).tolist() + [1] * cnt_2
    costs = [0] * cnt_1 + cost_mat.flatten().astype(int).tolist() + [0] * cnt_2

    # Define an array of supplies at each node: the flow to push equals the largest number of
    # pairs that could possibly be matched
    total_flow = min(cnt_nonzero_row, cnt_nonzero_col)
    supplies = [total_flow] + [0] * (cnt_1 + cnt_2) + [-total_flow]

    # Add each arc.
    for tail, head, capacity, cost in zip(start_nodes, end_nodes, capacities, costs):
        min_cost_flow.AddArcWithCapacityAndUnitCost(tail, head, capacity, cost)

    # Add node supplies.
    for node, supply in enumerate(supplies):
        min_cost_flow.SetNodeSupply(node, supply)

    match_results = []
    # Find the minimum cost flow between the source and the sink.
    if min_cost_flow.Solve() == min_cost_flow.OPTIMAL:
        for arc in range(min_cost_flow.NumArcs()):

            # Can ignore arcs leading out of source or into sink.
            if min_cost_flow.Tail(arc)!=source and min_cost_flow.Head(arc)!=sink:

                # Arcs in the solution have a flow value of 1. Their start and end nodes
                # give an assignment of a set A item to a set B item.

                if min_cost_flow.Flow(arc) > 0:
                    match_results.append([min_cost_flow.Tail(arc)-1,
                                          min_cost_flow.Head(arc)-cnt_1-1,
                                          min_cost_flow.UnitCost(arc)])
    else:
        print('There was an issue with the min cost flow input.')

    return match_results


def main():
    """Solve a small worker-to-task assignment problem with SimpleMinCostFlow and print the result.

    Graph layout: node 0 is the source, nodes 1-4 are workers, nodes 5-8 are tasks and node 9 is the sink.
    Every arc has capacity 1 and only the worker->task arcs carry a cost, so routing four units of flow
    from the source to the sink selects one task per worker at minimum total cost.
    """
    # Instantiate a SimpleMinCostFlow solver.
    min_cost_flow = pywrapgraph.SimpleMinCostFlow()

    # Define the directed graph for the flow. Each list is laid out as
    # [source->worker arcs] + [worker->task arcs] + [task->sink arcs].
    start_nodes = [0, 0, 0, 0] + [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4] + [5, 6, 7, 8]
    end_nodes = [1, 2, 3, 4] + [5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8] + [9, 9, 9, 9]
    capacities = [1, 1, 1, 1] + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + [1, 1, 1, 1]
    costs = ([0, 0, 0, 0] + [90, 76, 75, 70, 35, 85, 55, 65, 125, 95, 90, 105, 45, 110, 95, 115] + [0, 0, 0, 0])
    # Supply at each node: the source emits 4 units, the sink absorbs 4, all other nodes are balanced.
    supplies = [4, 0, 0, 0, 0, 0, 0, 0, 0, -4]
    source = 0
    sink = 9

    # Add each arc.
    for tail, head, capacity, unit_cost in zip(start_nodes, end_nodes, capacities, costs):
        min_cost_flow.AddArcWithCapacityAndUnitCost(tail, head, capacity, unit_cost)

    # Add node supplies.
    for node, supply in enumerate(supplies):
        min_cost_flow.SetNodeSupply(node, supply)

    # Find the minimum cost flow between node 0 (source) and node 9 (sink).
    if min_cost_flow.Solve() == min_cost_flow.OPTIMAL:
        print('Total cost = ', min_cost_flow.OptimalCost())
        print()
        for arc in range(min_cost_flow.NumArcs()):
            # Arcs leaving the source or entering the sink do not describe an assignment.
            if min_cost_flow.Tail(arc) == source or min_cost_flow.Head(arc) == sink:
                continue

            # Arcs in the solution have a flow value of 1. Their start and end nodes
            # give an assignment of worker to task (printed as raw node ids).
            if min_cost_flow.Flow(arc) > 0:
                print('Worker %d assigned to task %d.  Cost = %d' % (
                    min_cost_flow.Tail(arc),
                    min_cost_flow.Head(arc),
                    min_cost_flow.UnitCost(arc)))
    else:
        print('There was an issue with the min cost flow input.')

'''
if __name__ == '__main__':
    start_time = time.clock()
    main()
    print()
    print("Time =", time.clock() - start_time, "seconds")
    '''

In [None]:
"""
Description: This code evaluates 3D lane detection. The optimal matching between the ground-truth set and the
predicted set of lanes is sought by solving a min-cost flow problem.
Evaluation metrics include:
    Average Precision (AP)
    Max F-scores
    x error close (0 - 40 m)
    x error far (40 - 100 m)
    z error close (0 - 40 m)
    z error far (40 - 100 m)
Reference: "Gen-LaneNet: Generalized and Scalable Approach for 3D Lane Detection". Y. Guo et al. 2020
Author: Yuliang Guo (33yuliangguo@gmail.com)
Date: March, 2020
"""

import numpy as np
import cv2
import os
import os.path as ops
import copy
import math
import ujson as json
from scipy.interpolate import interp1d
import matplotlib
#from tools.utils import *
#from tools.MinCostFlow import SolveMinCostFlow
from mpl_toolkits.mplot3d import Axes3D

# Select the non-interactive Agg backend so figures are rendered off-screen and written to files.
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# Global figure defaults: large canvas and large, semibold fonts for the saved evaluation figures.
plt.rcParams['figure.figsize'] = (35, 30)
plt.rcParams.update({'font.size': 25})
plt.rcParams.update({'font.weight': 'semibold'})

# Four-entry color palette. The color names in the trailing comments match the triplets when they are
# read in (B, G, R) order -- presumably intended for OpenCV drawing; TODO confirm where it is consumed.
# NOTE: LaneEval.bench assigns a local variable that is also called `color`, so inside that method this
# module-level list is shadowed and not used for drawing.
color = [[0, 0, 255],  # red
         [0, 255, 0],  # green
         [255, 0, 255],  # purple
         [255, 255, 0]]  # cyan

# Lower / upper y-axis limits applied with set_ylim to the 3D lane plots in LaneEval.bench_one_submit
# (presumably meters along the forward direction -- confirm).
vis_min_y = 5
vis_max_y = 80


class LaneEval(object):
    def __init__(self, args):
        """
            Cache the evaluation settings taken from the argument object.
        :param args: configuration object providing dataset_dir, K, no_centerline, org_h, org_w, crop_y,
                     resize_h, resize_w and top_view_region. top_view_region is indexed with 2-D
                     [row, col] subscripts, so it must support that (e.g. a numpy array); presumably its
                     rows are corner points of the top-view area stored as (x, y) -- confirm.
        """
        resized_size = [args.resize_h, args.resize_w]

        self.dataset_dir = args.dataset_dir
        self.K = args.K
        self.no_centerline = args.no_centerline
        self.resize_h, self.resize_w = resized_size
        # homography applied to the raw image before drawing (built by homography_crop_resize)
        self.H_crop = homography_crop_resize([args.org_h, args.org_w], args.crop_y, resized_size)

        # lateral (x) and longitudinal (y) extent of the evaluated region
        region = args.top_view_region
        self.x_min, self.x_max = region[0, 0], region[1, 0]
        self.y_min, self.y_max = region[2, 1], region[0, 1]
        # 100 evenly spaced y positions in [y_min, y_max) at which every lane is resampled
        self.y_samples = np.linspace(self.y_min, self.y_max, num=100, endpoint=False)

        # thresholds used by the matching:
        self.dist_th = 1.5     # per-sample distance below which a point counts as matched
        self.ratio_th = 0.75   # minimum fraction of matched points for a lane to count as matched
        self.close_range = 40  # y value separating the "close" from the "far" error statistics

    def bench(self, pred_lanes, gt_lanes, gt_visibility, raw_file, gt_cam_height, gt_cam_pitch, vis, ax1, ax2):
        """
            Match predicted lanes against ground-truth lanes of one image and collect the matching statistics.
            Every lane is resampled at the fixed positions self.y_samples. The distance between a ground-truth
            lane and a predicted lane is evaluated per sample from their x and z offsets; samples where either
            lane is not visible are charged self.dist_th. The summed distance is the cost of a bipartite
            matching that is solved with SolveMinCostFlow.
            The input of prediction and ground-truth lanes are in ground coordinate, x-right, y-forward, z-up
            The fundamental assumption is: 1. there are no two points from different lanes with identical x, y
                                              but different z's
                                           2. there are no two points from a single lane having identical x, y
                                              but different z's
            If the interest area is within the current drivable road, the above assumptions are almost always valid.
        :param pred_lanes: N X 2 or N X 3 lists depending on 2D or 3D (y is read from column 1).
                           NOTE: the list is modified in place, each entry is replaced by its resampled
                           [x, z] array.
        :param gt_lanes: N X 2 or N X 3 lists depending on 2D or 3D
        :param gt_visibility: per-lane visibility data, indexed like gt_lanes
        :param raw_file: file path rooted in dataset folder (only read when vis is set)
        :param gt_cam_height: camera height given in ground-truth data (only used when vis is set)
        :param gt_cam_pitch: camera pitch given in ground-truth data (only used when vis is set)
        :param vis: when true, draw the lanes onto the image shown in ax1 and plot them into ax2
        :param ax1: axes receiving the image (only used when vis is set)
        :param ax2: 3D axes receiving the lane plots (only used when vis is set)
        :return: r_lane, p_lane, cnt_gt, cnt_pred, x_error_close, x_error_far, z_error_close, z_error_far
                 r_lane / p_lane count the matches that pass the ratio test w.r.t. the ground-truth /
                 predicted lane, cnt_gt / cnt_pred are the numbers of evaluated lanes, the four lists hold one
                 mean error per accepted match.
        """
        num_samples = self.y_samples.shape[0]

        def resample(lane):
            # Resample one lane at self.y_samples; return its (num_samples, 2) [x, z] array together with
            # the mask of samples that are inside the x range, inside the lane's own y extent and visible.
            lane = np.array(lane)
            min_y = np.min(lane[:, 1])
            max_y = np.max(lane[:, 1])
            x_values, z_values, visibility_vec = resample_laneline_in_y(lane, self.y_samples, out_vis=True)
            in_region = np.logical_and(x_values >= self.x_min,
                                       np.logical_and(x_values <= self.x_max,
                                                      np.logical_and(self.y_samples >= min_y,
                                                                     self.y_samples <= max_y)))
            return np.vstack([x_values, z_values]).T, np.logical_and(in_region, visibility_vec)

        # index of the first sample beyond the close range: [:idx] is "close", [idx:] is "far"
        # NOTE: raises IndexError when no sample lies beyond self.close_range
        close_range_idx = np.where(self.y_samples > self.close_range)[0][0]
        close_section = slice(None, close_range_idx)
        far_section = slice(close_range_idx, None)

        r_lane, p_lane = 0., 0.
        x_error_close = []
        x_error_far = []
        z_error_close = []
        z_error_far = []

        # only keep the visible portion
        gt_lanes = [prune_3d_lane_by_visibility(np.array(gt_lane), np.array(gt_visibility[k])) for k, gt_lane in
                    enumerate(gt_lanes)]
        gt_lanes = [lane for lane in gt_lanes if lane.shape[0] > 1]
        # only consider those gt lanes overlapping with sampling range
        gt_lanes = [lane for lane in gt_lanes if lane[0, 1] < self.y_samples[-1] and lane[-1, 1] > self.y_samples[0]]
        gt_lanes = [prune_3d_lane_by_range(np.array(gt_lane), 3 * self.x_min, 3 * self.x_max) for gt_lane in gt_lanes]
        gt_lanes = [lane for lane in gt_lanes if lane.shape[0] > 1]
        cnt_gt = len(gt_lanes)
        cnt_pred = len(pred_lanes)

        # resample gt and pred at y_samples
        gt_visibility_mat = np.zeros((cnt_gt, num_samples))
        pred_visibility_mat = np.zeros((cnt_pred, num_samples))
        for i in range(cnt_gt):
            gt_lanes[i], gt_visibility_mat[i, :] = resample(gt_lanes[i])
        for i in range(cnt_pred):
            pred_lanes[i], pred_visibility_mat[i, :] = resample(pred_lanes[i])

        # builtin int / float are used as dtypes: the np.int / np.float aliases no longer exist in NumPy >= 1.24
        pair_shape = (cnt_gt, cnt_pred)
        adj_mat = np.zeros(pair_shape, dtype=int)
        cost_mat = np.full(pair_shape, 1000, dtype=int)
        num_match_mat = np.zeros(pair_shape, dtype=float)
        x_dist_mat_close = np.full(pair_shape, 1000., dtype=float)
        x_dist_mat_far = np.full(pair_shape, 1000., dtype=float)
        z_dist_mat_close = np.full(pair_shape, 1000., dtype=float)
        z_dist_mat_far = np.full(pair_shape, 1000., dtype=float)

        # compute curve to curve distance
        for i in range(cnt_gt):
            for j in range(cnt_pred):
                x_dist = np.abs(gt_lanes[i][:, 0] - pred_lanes[j][:, 0])
                z_dist = np.abs(gt_lanes[i][:, 1] - pred_lanes[j][:, 1])
                euclidean_dist = np.sqrt(x_dist ** 2 + z_dist ** 2)

                # apply visibility to penalize different partial matching accordingly
                euclidean_dist[
                    np.logical_or(gt_visibility_mat[i, :] < 0.5, pred_visibility_mat[j, :] < 0.5)] = self.dist_th

                # don't prune by average distance here, to encourage finding the perfect match
                num_match_mat[i, j] = np.sum(euclidean_dist < self.dist_th)
                adj_mat[i, j] = 1
                # ATTENTION: use the sum as int type to meet the requirements of min cost flow optimization (int type)
                cost_mat[i, j] = np.sum(euclidean_dist).astype(int)

                # use the both visible portion to calculate distance error, separately for close and far range
                both_visible_indices = np.logical_and(gt_visibility_mat[i, :] > 0.5, pred_visibility_mat[j, :] > 0.5)
                for section, x_dist_mat, z_dist_mat in ((close_section, x_dist_mat_close, z_dist_mat_close),
                                                        (far_section, x_dist_mat_far, z_dist_mat_far)):
                    visible = both_visible_indices[section]
                    num_visible = np.sum(visible)
                    if num_visible > 0:
                        x_dist_mat[i, j] = np.sum(x_dist[section] * visible) / num_visible
                        z_dist_mat[i, j] = np.sum(z_dist[section] * visible) / num_visible
                    else:
                        # no commonly visible sample in this range: fall back to the distance threshold
                        x_dist_mat[i, j] = self.dist_th
                        z_dist_mat[i, j] = self.dist_th

        # solve bipartite matching vis min cost flow solver
        match_results = np.array(SolveMinCostFlow(adj_mat, cost_mat))

        # only a match with avg cost < self.dist_th is consider valid one
        match_gt_ids = []
        match_pred_ids = []
        max_cost = self.dist_th * num_samples
        for gt_i, pred_i, match_cost in match_results:
            if match_cost >= max_cost:
                continue
            # consider match when the matched points is above a ratio
            if num_match_mat[gt_i, pred_i] / np.sum(gt_visibility_mat[gt_i, :]) >= self.ratio_th:
                r_lane += 1
                match_gt_ids.append(gt_i)
            if num_match_mat[gt_i, pred_i] / np.sum(pred_visibility_mat[pred_i, :]) >= self.ratio_th:
                p_lane += 1
                match_pred_ids.append(pred_i)
            x_error_close.append(x_dist_mat_close[gt_i, pred_i])
            x_error_far.append(x_dist_mat_far[gt_i, pred_i])
            z_error_close.append(z_dist_mat_close[gt_i, pred_i])
            z_error_far.append(z_dist_mat_far[gt_i, pred_i])

        # visualize lanelines and matching results both in image and 3D
        if vis:
            P_g2im = projection_g2im(gt_cam_pitch, gt_cam_height, self.K)
            P_gt = np.matmul(self.H_crop, P_g2im)
            img = cv2.imread(ops.join(self.dataset_dir, raw_file))
            img = cv2.warpPerspective(img, self.H_crop, (self.resize_w, self.resize_h))
            img = img.astype(float) / 255

            # (lanes, visibility, ids of matched lanes, color if matched, color otherwise, line thickness)
            # ground-truth lanes are drawn first, predictions on top of them
            draw_jobs = ((gt_lanes, gt_visibility_mat, match_gt_ids, [0, 0, 1], [0, 1, 1], 3),
                         (pred_lanes, pred_visibility_mat, match_pred_ids, [1, 0, 0], [1, 0, 1], 2))
            for lanes, visibility_mat, matched_ids, matched_color, unmatched_color, thickness in draw_jobs:
                for i, lane in enumerate(lanes):
                    x_values = lane[:, 0]
                    z_values = lane[:, 1]
                    x_2d, y_2d = projective_transformation(P_gt, x_values, self.y_samples, z_values)
                    x_2d = x_2d.astype(int)
                    y_2d = y_2d.astype(int)

                    lane_color = matched_color if i in matched_ids else unmatched_color
                    for k in range(1, x_2d.shape[0]):
                        # only draw the visible portion
                        if visibility_mat[i, k - 1] and visibility_mat[i, k]:
                            # the color list is reversed for the OpenCV call
                            img = cv2.line(img, (x_2d[k - 1], y_2d[k - 1]), (x_2d[k], y_2d[k]),
                                           lane_color[-1::-1], thickness)
                    visible_idx = np.where(visibility_mat[i, :])
                    ax2.plot(x_values[visible_idx],
                             self.y_samples[visible_idx],
                             z_values[visible_idx], color=lane_color, linewidth=5)

            cv2.putText(img, 'Recall: {:.3f}'.format(r_lane / (cnt_gt + 1e-6)),
                        (5, 30), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.7, color=(0, 0, 1), thickness=2)
            cv2.putText(img, 'Precision: {:.3f}'.format(p_lane / (cnt_pred + 1e-6)),
                        (5, 60), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.7, color=(0, 0, 1), thickness=2)
            # swap the first and last channel before handing the image to matplotlib
            ax1.imshow(img[:, :, [2, 1, 0]])

        return r_lane, p_lane, cnt_gt, cnt_pred, x_error_close, x_error_far, z_error_close, z_error_far

    # compare predicted set and ground-truth set using a fixed lane probability threshold
    def bench_one_submit(self, pred_file, gt_file, prob_th=0.5, vis=False):
        """
            Evaluate a prediction file against a ground-truth file at one lane probability threshold.
            Both files hold one JSON object per line. Predictions are paired with ground truth through their
            'raw_file' entry and only predicted lanes with probability > prob_th are evaluated.
        :param pred_file: path of the prediction file
        :param gt_file: path of the ground-truth file
        :param prob_th: probability threshold a predicted lane has to exceed to be kept
        :param vis: when true, a 2 x 2 figure per sample is saved into a 'vis' folder next to pred_file
        :return: list [F-score, recall, precision, x error close, x error far, z error close, z error far]
                 for lanelines, followed by the same seven values for centerlines unless self.no_centerline
        :raises Exception: if the two files differ in their number of lines, if a prediction lacks 'raw_file'
                           or 'laneLines', or if a predicted raw_file is missing from the ground truth
        """
        if vis:
            # dirname handles a bare file name correctly (slicing at rfind('/') would chop its last character)
            save_path = ops.join(ops.dirname(pred_file), 'vis')
            if not os.path.exists(save_path):
                try:
                    os.makedirs(save_path)
                except OSError as e:
                    # best effort: report and go on (OSError has no .message attribute in Python 3)
                    print(e)

        # context managers make sure both files are closed again
        with open(pred_file) as pred_fp:
            json_pred = [json.loads(line) for line in pred_fp]
        with open(gt_file) as gt_fp:
            json_gt = [json.loads(line) for line in gt_fp]
        if len(json_gt) != len(json_pred):
            raise Exception('We do not get the predictions of all the test tasks')
        gts = {l['raw_file']: l for l in json_gt}

        def summarize(stats, x_close, x_far, z_close, z_far):
            # Reduce per-sample [r_lane, p_lane, cnt_gt, cnt_pred] rows and per-match errors to the seven
            # reported numbers; the 1e-6 terms guard the divisions against zero denominators.
            stats = np.array(stats)
            recall = np.sum(stats[:, 0]) / (np.sum(stats[:, 2]) + 1e-6)
            precision = np.sum(stats[:, 1]) / (np.sum(stats[:, 3]) + 1e-6)
            f_score = 2 * recall * precision / (recall + precision + 1e-6)
            return [f_score, recall, precision,
                    np.average(np.array(x_close)), np.average(np.array(x_far)),
                    np.average(np.array(z_close)), np.average(np.array(z_far))]

        laneline_stats = []
        laneline_x_error_close = []
        laneline_x_error_far = []
        laneline_z_error_close = []
        laneline_z_error_far = []
        centerline_stats = []
        centerline_x_error_close = []
        centerline_x_error_far = []
        centerline_z_error_close = []
        centerline_z_error_far = []
        for i, pred in enumerate(json_pred):
            if 'raw_file' not in pred or 'laneLines' not in pred:
                raise Exception('raw_file or lanelines not in some predictions.')
            raw_file = pred['raw_file']

            pred_lanelines = pred['laneLines']
            pred_laneLines_prob = pred['laneLines_prob']
            pred_lanelines = [pred_lanelines[ii] for ii in range(len(pred_laneLines_prob)) if
                              pred_laneLines_prob[ii] > prob_th]

            if raw_file not in gts:
                raise Exception('Some raw_file from your predictions do not exist in the test tasks.')
            gt = gts[raw_file]
            gt_cam_height = gt['cam_height']
            gt_cam_pitch = gt['cam_pitch']

            if vis:
                # top row: lanelines (image, 3D), bottom row: centerlines (image, 3D)
                fig = plt.figure()
                ax1 = fig.add_subplot(221)
                ax2 = fig.add_subplot(222, projection='3d')
                ax3 = fig.add_subplot(223)
                ax4 = fig.add_subplot(224, projection='3d')
            else:
                # placeholders, bench does not touch the axes when vis is off
                ax1 = ax2 = ax3 = ax4 = 0

            # evaluate lanelines: N to N matching
            r_lane, p_lane, cnt_gt, cnt_pred, \
                x_error_close, x_error_far, \
                z_error_close, z_error_far = self.bench(pred_lanelines,
                                                        gt['laneLines'],
                                                        gt['laneLines_visibility'],
                                                        raw_file,
                                                        gt_cam_height,
                                                        gt_cam_pitch,
                                                        vis, ax1, ax2)
            laneline_stats.append(np.array([r_lane, p_lane, cnt_gt, cnt_pred]))
            # x_error and z_error are reported by bench only for the matched lanes
            laneline_x_error_close.extend(x_error_close)
            laneline_x_error_far.extend(x_error_far)
            laneline_z_error_close.extend(z_error_close)
            laneline_z_error_far.extend(z_error_far)

            # evaluate centerlines
            if not self.no_centerline:
                pred_centerlines = pred['centerLines']
                pred_centerlines_prob = pred['centerLines_prob']
                pred_centerlines = [pred_centerlines[ii] for ii in range(len(pred_centerlines_prob)) if
                                    pred_centerlines_prob[ii] > prob_th]

                # N to N matching of centerlines
                r_lane, p_lane, cnt_gt, cnt_pred, \
                    x_error_close, x_error_far, \
                    z_error_close, z_error_far = self.bench(pred_centerlines,
                                                            gt['centerLines'],
                                                            gt['centerLines_visibility'],
                                                            raw_file,
                                                            gt_cam_height,
                                                            gt_cam_pitch,
                                                            vis, ax3, ax4)
                centerline_stats.append(np.array([r_lane, p_lane, cnt_gt, cnt_pred]))
                centerline_x_error_close.extend(x_error_close)
                centerline_x_error_far.extend(x_error_far)
                centerline_z_error_close.extend(z_error_close)
                centerline_z_error_far.extend(z_error_far)

            if vis:
                # identical styling for the laneline row and the centerline row
                for image_ax, lane_ax in ((ax1, ax2), (ax3, ax4)):
                    image_ax.set_xticks([])
                    image_ax.set_yticks([])
                    bottom, top = lane_ax.get_zlim()
                    left, right = lane_ax.get_xlim()
                    # keep the z range at least [-0.1, 0.1]
                    lane_ax.set_zlim(min(bottom, -0.1), max(top, 0.1))
                    lane_ax.set_xlim(left, right)
                    lane_ax.set_ylim(vis_min_y, vis_max_y)
                    lane_ax.locator_params(nbins=5, axis='x')
                    lane_ax.locator_params(nbins=5, axis='z')
                    lane_ax.tick_params(pad=18)

                fig.subplots_adjust(wspace=0, hspace=0.01)
                fig.savefig(ops.join(save_path, raw_file.replace("/", "_")))
                plt.close(fig)
                print('processed sample: {}  {}'.format(i, raw_file))

        output_stats = summarize(laneline_stats,
                                 laneline_x_error_close, laneline_x_error_far,
                                 laneline_z_error_close, laneline_z_error_far)
        if not self.no_centerline:
            output_stats.extend(summarize(centerline_stats,
                                          centerline_x_error_close, centerline_x_error_far,
                                          centerline_z_error_close, centerline_z_error_far))

        return output_stats

    def bench_PR(self, pred_lanes, gt_lanes, gt_visibility):
        """
            Count matched lanes between a predicted set and a ground-truth set (recall / precision counts only).
            Every lane is resampled at the fixed positions self.y_samples. The distance between a ground-truth
            lane and a predicted lane is evaluated per sample from their x and z offsets; samples where either
            lane is not visible are charged self.dist_th. The summed distance is the cost of a bipartite
            matching that is solved with SolveMinCostFlow.
            The input of prediction and ground-truth lanes are in ground coordinate, x-right, y-forward, z-up
            The fundamental assumption is: 1. there are no two points from different lanes with identical x, y
                                              but different z's
                                           2. there are no two points from a single lane having identical x, y
                                              but different z's
            If the interest area is within the current drivable road, the above assumptions are almost always valid.
        :param pred_lanes: N X 2 or N X 3 lists depending on 2D or 3D (y is read from column 1).
                           NOTE: the list is modified in place, each entry is replaced by its resampled
                           [x, z] array.
        :param gt_lanes: N X 2 or N X 3 lists depending on 2D or 3D
        :param gt_visibility: per-lane visibility data, indexed like gt_lanes
        :return: r_lane, p_lane, cnt_gt, cnt_pred
                 r_lane / p_lane count the matches that pass the ratio test w.r.t. the ground-truth /
                 predicted lane, cnt_gt / cnt_pred are the numbers of evaluated lanes.
        """
        num_samples = self.y_samples.shape[0]

        def resample(lane):
            # Resample one lane at self.y_samples; return its (num_samples, 2) [x, z] array together with
            # the mask of samples that are inside the x range, inside the lane's own y extent and visible.
            lane = np.array(lane)
            min_y = np.min(lane[:, 1])
            max_y = np.max(lane[:, 1])
            x_values, z_values, visibility_vec = resample_laneline_in_y(lane, self.y_samples, out_vis=True)
            in_region = np.logical_and(x_values >= self.x_min,
                                       np.logical_and(x_values <= self.x_max,
                                                      np.logical_and(self.y_samples >= min_y,
                                                                     self.y_samples <= max_y)))
            return np.vstack([x_values, z_values]).T, np.logical_and(in_region, visibility_vec)

        r_lane, p_lane = 0., 0.

        # only keep the visible portion
        gt_lanes = [prune_3d_lane_by_visibility(np.array(gt_lane), np.array(gt_visibility[k])) for k, gt_lane in
                    enumerate(gt_lanes)]
        gt_lanes = [lane for lane in gt_lanes if lane.shape[0] > 1]
        # only consider those gt lanes overlapping with sampling range
        gt_lanes = [lane for lane in gt_lanes if lane[0, 1] < self.y_samples[-1] and lane[-1, 1] > self.y_samples[0]]
        gt_lanes = [prune_3d_lane_by_range(np.array(gt_lane), 3 * self.x_min, 3 * self.x_max) for gt_lane in gt_lanes]
        gt_lanes = [lane for lane in gt_lanes if lane.shape[0] > 1]
        cnt_gt = len(gt_lanes)
        cnt_pred = len(pred_lanes)

        # resample gt and pred at y_samples
        gt_visibility_mat = np.zeros((cnt_gt, num_samples))
        pred_visibility_mat = np.zeros((cnt_pred, num_samples))
        for i in range(cnt_gt):
            gt_lanes[i], gt_visibility_mat[i, :] = resample(gt_lanes[i])
        for i in range(cnt_pred):
            pred_lanes[i], pred_visibility_mat[i, :] = resample(pred_lanes[i])

        # builtin int / float are used as dtypes: the np.int / np.float aliases no longer exist in NumPy >= 1.24
        pair_shape = (cnt_gt, cnt_pred)
        adj_mat = np.zeros(pair_shape, dtype=int)
        cost_mat = np.full(pair_shape, 1000, dtype=int)
        num_match_mat = np.zeros(pair_shape, dtype=float)

        # compute curve to curve distance
        for i in range(cnt_gt):
            for j in range(cnt_pred):
                x_dist = np.abs(gt_lanes[i][:, 0] - pred_lanes[j][:, 0])
                z_dist = np.abs(gt_lanes[i][:, 1] - pred_lanes[j][:, 1])
                euclidean_dist = np.sqrt(x_dist ** 2 + z_dist ** 2)

                # apply visibility to penalize different partial matching accordingly
                euclidean_dist[
                    np.logical_or(gt_visibility_mat[i, :] < 0.5, pred_visibility_mat[j, :] < 0.5)] = self.dist_th

                # don't prune by average distance here, to encourage finding the perfect match
                num_match_mat[i, j] = np.sum(euclidean_dist < self.dist_th)
                adj_mat[i, j] = 1
                # ATTENTION: use the sum as int type to meet the requirements of min cost flow optimization (int type)
                cost_mat[i, j] = np.sum(euclidean_dist).astype(int)

        # solve bipartite matching vis min cost flow solver
        match_results = np.array(SolveMinCostFlow(adj_mat, cost_mat))

        # only a match with avg cost < self.dist_th is consider valid one
        max_cost = self.dist_th * num_samples
        for gt_i, pred_i, match_cost in match_results:
            if match_cost >= max_cost:
                continue
            # consider match when the matched points is above a ratio
            if num_match_mat[gt_i, pred_i] / np.sum(gt_visibility_mat[gt_i, :]) >= self.ratio_th:
                r_lane += 1
            if num_match_mat[gt_i, pred_i] / np.sum(pred_visibility_mat[pred_i, :]) >= self.ratio_th:
                p_lane += 1

        return r_lane, p_lane, cnt_gt, cnt_pred

    # evaluate two dataset at varying lane probability threshold to calculate AP
    def bench_one_submit_varying_probs(self, pred_file, gt_file, eval_out_file=None, eval_fig_file=None):
        varying_th = np.linspace(0.05, 0.95, 19)
        # try:
        pred_lines = open(pred_file).readlines()
        json_pred = [json.loads(line) for line in pred_lines]
        # except BaseException as e:
        #     raise Exception('Fail to load json file of the prediction.')
        json_gt = [json.loads(line) for line in open(gt_file).readlines()]
        if len(json_gt) != len(json_pred):
            raise Exception('We do not get the predictions of all the test tasks')
        gts = {l['raw_file']: l for l in json_gt}

        laneline_r_all = []
        laneline_p_all = []
        laneline_gt_cnt_all = []
        laneline_pred_cnt_all = []
        centerline_r_all = []
        centerline_p_all = []
        centerline_gt_cnt_all = []
        centerline_pred_cnt_all = []
        for i, pred in enumerate(json_pred):
            print('Evaluating sample {} / {}'.format(i, len(json_pred)))
            if 'raw_file' not in pred or 'laneLines' not in pred:
                raise Exception('raw_file or lanelines not in some predictions.')
            raw_file = pred['raw_file']

            pred_lanelines = pred['laneLines']
            pred_laneLines_prob = pred['laneLines_prob']
            if raw_file not in gts:
                raise Exception('Some raw_file from your predictions do not exist in the test tasks.')
            gt = gts[raw_file]
            gt_cam_height = gt['cam_height']
            gt_cam_pitch = gt['cam_pitch']

            # evaluate lanelines
            gt_lanelines = gt['laneLines']
            gt_visibility = gt['laneLines_visibility']
            r_lane_vec = []
            p_lane_vec = []
            cnt_gt_vec = []
            cnt_pred_vec = []

            for prob_th in varying_th:
                pred_lanelines = [pred_lanelines[ii] for ii in range(len(pred_laneLines_prob)) if
                                  pred_laneLines_prob[ii] > prob_th]
                pred_laneLines_prob = [prob for prob in pred_laneLines_prob if prob > prob_th]
                pred_lanelines_copy = copy.deepcopy(pred_lanelines)
                # N to N matching of lanelines
                r_lane, p_lane, cnt_gt, cnt_pred = self.bench_PR(pred_lanelines_copy,
                                                                 gt_lanelines,
                                                                 gt_visibility)
                r_lane_vec.append(r_lane)
                p_lane_vec.append(p_lane)
                cnt_gt_vec.append(cnt_gt)
                cnt_pred_vec.append(cnt_pred)

            laneline_r_all.append(r_lane_vec)
            laneline_p_all.append(p_lane_vec)
            laneline_gt_cnt_all.append(cnt_gt_vec)
            laneline_pred_cnt_all.append(cnt_pred_vec)

            # evaluate centerlines
            if not self.no_centerline:
                pred_centerlines = pred['centerLines']
                pred_centerLines_prob = pred['centerLines_prob']
                gt_centerlines = gt['centerLines']
                gt_visibility = gt['centerLines_visibility']
                r_lane_vec = []
                p_lane_vec = []
                cnt_gt_vec = []
                cnt_pred_vec = []

                for prob_th in varying_th:
                    pred_centerlines = [pred_centerlines[ii] for ii in range(len(pred_centerLines_prob)) if
                                        pred_centerLines_prob[ii] > prob_th]
                    pred_centerLines_prob = [prob for prob in pred_centerLines_prob if prob > prob_th]
                    pred_centerlines_copy = copy.deepcopy(pred_centerlines)
                    # N to N matching of lanelines
                    r_lane, p_lane, cnt_gt, cnt_pred = self.bench_PR(pred_centerlines_copy,
                                                                     gt_centerlines,
                                                                     gt_visibility)
                    r_lane_vec.append(r_lane)
                    p_lane_vec.append(p_lane)
                    cnt_gt_vec.append(cnt_gt)
                    cnt_pred_vec.append(cnt_pred)
                centerline_r_all.append(r_lane_vec)
                centerline_p_all.append(p_lane_vec)
                centerline_gt_cnt_all.append(cnt_gt_vec)
                centerline_pred_cnt_all.append(cnt_pred_vec)

        output_stats = []
        # compute precision, recall
        laneline_r_all = np.array(laneline_r_all)
        laneline_p_all = np.array(laneline_p_all)
        laneline_gt_cnt_all = np.array(laneline_gt_cnt_all)
        laneline_pred_cnt_all = np.array(laneline_pred_cnt_all)

        R_lane = np.sum(laneline_r_all, axis=0) / (np.sum(laneline_gt_cnt_all, axis=0) + 1e-6)
        P_lane = np.sum(laneline_p_all, axis=0) / (np.sum(laneline_pred_cnt_all, axis=0) + 1e-6)
        F_lane = 2 * R_lane * P_lane / (R_lane + P_lane + 1e-6)

        output_stats.append(F_lane)
        output_stats.append(R_lane)
        output_stats.append(P_lane)

        if not self.no_centerline:
            centerline_r_all = np.array(centerline_r_all)
            centerline_p_all = np.array(centerline_p_all)
            centerline_gt_cnt_all = np.array(centerline_gt_cnt_all)
            centerline_pred_cnt_all = np.array(centerline_pred_cnt_all)

            R_lane = np.sum(centerline_r_all, axis=0) / (np.sum(centerline_gt_cnt_all, axis=0) + 1e-6)
            P_lane = np.sum(centerline_p_all, axis=0) / (np.sum(centerline_pred_cnt_all, axis=0) + 1e-6)
            F_lane = 2 * R_lane * P_lane / (R_lane + P_lane + 1e-6)

            output_stats.append(F_lane)
            output_stats.append(R_lane)
            output_stats.append(P_lane)

        # calculate metrics
        laneline_F = output_stats[0]
        laneline_F_max = np.max(laneline_F)
        laneline_max_i = np.argmax(laneline_F)
        laneline_R = output_stats[1]
        laneline_P = output_stats[2]
        centerline_F = output_stats[3]
        centerline_F_max = centerline_F[laneline_max_i]
        centerline_max_i = laneline_max_i
        centerline_R = output_stats[4]
        centerline_P = output_stats[5]

        laneline_R = np.array([1.] + laneline_R.tolist() + [0.])
        laneline_P = np.array([0.] + laneline_P.tolist() + [1.])
        centerline_R = np.array([1.] + centerline_R.tolist() + [0.])
        centerline_P = np.array([0.] + centerline_P.tolist() + [1.])
        f_laneline = interp1d(laneline_R, laneline_P)
        f_centerline = interp1d(centerline_R, centerline_P)
        r_range = np.linspace(0.05, 0.95, 19)
        laneline_AP = np.mean(f_laneline(r_range))
        centerline_AP = np.mean(f_centerline(r_range))

        if eval_fig_file is not None:
            # plot PR curve
            fig = plt.figure()
            ax1 = fig.add_subplot(121)
            ax2 = fig.add_subplot(122)
            ax1.plot(laneline_R, laneline_P, '-s')
            ax2.plot(centerline_R, centerline_P, '-s')

            ax1.set_xlim(0, 1)
            ax1.set_ylim(0, 1)
            ax1.set_title('Lane Line')
            ax1.set_xlabel('Recall')
            ax1.set_ylabel('Precision')
            ax1.set_aspect('equal')
            ax1.legend('Max F-measure {:.3}'.format(laneline_F_max))

            ax2.set_xlim(0, 1)
            ax2.set_ylim(0, 1)
            ax2.set_title('Center Line')
            ax2.set_xlabel('Recall')
            ax2.set_ylabel('Precision')
            ax2.set_aspect('equal')
            ax2.legend('Max F-measure {:.3}'.format(centerline_F_max))

            # fig.subplots_adjust(wspace=0.1, hspace=0.01)
            fig.savefig(eval_fig_file)
            plt.close(fig)

        # print("===> Evaluation on validation set: \n"
        #       "laneline max F-measure {:.3} at Recall {:.3}, Precision {:.3} \n"
        #       "laneline AP: {:.3}\n"
        #       "centerline max F-measure {:.3} at Recall {:.3}, Precision {:.3} \n"
        #       "centerline AP: {:.3} \n".format(laneline_F_max,
        #                                        laneline_R[laneline_max_i + 1],
        #                                        laneline_P[laneline_max_i + 1],
        #                                        laneline_AP,
        #                                        centerline_F_max,
        #                                        centerline_R[centerline_max_i + 1],
        #                                        centerline_P[centerline_max_i + 1],
        #                                        centerline_AP))

        json_out = {}
        json_out['laneline_R'] = laneline_R[1:-1].astype(np.float32).tolist()
        json_out['laneline_P'] = laneline_P[1:-1].astype(np.float32).tolist()
        json_out['laneline_F_max'] = laneline_F_max
        json_out['laneline_max_i'] = laneline_max_i.tolist()
        json_out['laneline_AP'] = laneline_AP

        json_out['centerline_R'] = centerline_R[1:-1].astype(np.float32).tolist()
        json_out['centerline_P'] = centerline_P[1:-1].astype(np.float32).tolist()
        json_out['centerline_F_max'] = centerline_F_max
        json_out['centerline_max_i'] = centerline_max_i.tolist()
        json_out['centerline_AP'] = centerline_AP

        json_out['max_F_prob_th'] = varying_th[laneline_max_i]

        if eval_out_file is not None:
            with open(eval_out_file, 'w') as jsonFile:
                jsonFile.write(json.dumps(json_out))
                jsonFile.write('\n')
                jsonFile.close()
        return json_out


# Script entry point: evaluate one prediction file of one data split and print
# AP, F-score and the x/z errors for lanelines and centerlines.
if __name__ == '__main__':
    import os

    vis = False
    parser = define_args()
    # define_args() may return an argparse-style parser or a ready-made
    # settings object (e.g. an EasyDict, which has no parse_args()); only
    # parse when the method exists, otherwise use the object as the settings.
    args = parser.parse_args() if hasattr(parser, 'parse_args') else parser

    # two method are compared: '3D_LaneNet' and 'Gen_LaneNet'
    method_name = 'Gen_LaneNet_ext'

    # Three different splits of datasets: 'standard', 'rare_subsit', 'illus_chg'
    data_split = 'illus_chg'

    # location where the original dataset is saved. Image will be loaded in case of visualization.
    # '~' is not expanded by file-opening calls, so resolve it to the home directory here.
    args.dataset_dir = os.path.expanduser('~/Datasets/Apollo_Sim_3D_Lane_Release/')

    # load configuration for certain dataset
    sim3d_config(args)

    # auto-fill in dependent paths
    gt_file = 'data_splits/' + data_split + '/test.json'
    pred_folder = 'data_splits/' + data_split + '/' + method_name
    pred_file = pred_folder + '/test_pred_file.json'

    # Initialize evaluator
    evaluator = LaneEval(args)

    # evaluation at varying thresholds
    eval_stats_pr = evaluator.bench_one_submit_varying_probs(pred_file, gt_file)
    max_f_prob = eval_stats_pr['max_F_prob_th']

    # evaluate at the point with max F-measure. Additional eval of position error. Option to visualize matching result
    eval_stats = evaluator.bench_one_submit(pred_file, gt_file, prob_th=max_f_prob, vis=vis)

    # NOTE(review): the indices into eval_stats below follow the header line
    # printed first; confirm against bench_one_submit's return order.
    print("Metrics: AP, F-score, x error (close), x error (far), z error (close), z error (far)")
    print(
        "Laneline:  {:.3}  {:.3}  {:.3}  {:.3}  {:.3}  {:.3}".format(eval_stats_pr['laneline_AP'], eval_stats[0],
                                                                     eval_stats[3], eval_stats[4],
                                                                     eval_stats[5], eval_stats[6]))
    print("Centerline:  {:.3}  {:.3}  {:.3}  {:.3}  {:.3}  {:.3}".format(eval_stats_pr['centerline_AP'], eval_stats[7],
                                                                         eval_stats[10], eval_stats[11],
                                                                         eval_stats[12], eval_stats[13]))