<a href="https://colab.research.google.com/github/LimSeunghyeon1/Lanenet_with_Colab/blob/master/Lanenet_for_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive into the Colab runtime at /content/drive; the paths used
# later in this notebook (e.g. the default "save_path") live under that mount.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#! unzip -d /content/drive/Shareddrives/colab/ /content/drive/MyDrive/data_splits.zip

In [None]:
#! cd /content/drive/Shareddrives/colab/data_splits/illus_chg/Gen_LaneNet_ext/ 

In [None]:
#! tar -xvf /content/drive/Shareddrives/colab/data_splits/illus_chg/Gen_LaneNet_ext/model_best_epoch_29.pth.tar

In [None]:
! #unzip -d /content/drive/Shareddrives/colab/ /content/drive/MyDrive/Apollo_Sim_3D_Lane_Release.zip

In [2]:
"""
Utility functions and default settings
Author: Yuliang Guo (33yuliangguo@gmail.com)
Date: March, 2020
"""

import argparse
import errno
import os
import sys
import easydict

import cv2
import matplotlib
import numpy as np
import torch
import torch.nn.init as init
import torch.optim
from torch.optim import lr_scheduler
import os.path as ops
from mpl_toolkits.mplot3d import Axes3D
# Select the 'Agg' backend before pyplot is imported so that the backend
# choice is in effect when pyplot is first loaded.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
# Default size (width, height) for every figure created through pyplot.
plt.rcParams['figure.figsize'] = (35, 30)


def define_args():
    """
    Build the default run configuration.

    The notebook has no command line, so instead of an ``argparse`` parser the
    settings are returned as an ``easydict.EasyDict`` whose keys can be read
    and overwritten as attributes (``args.batch_size``).  The per-key comments
    below carry the help texts of the original command-line options.

    :return: EasyDict holding every option with its default value
    """
    defaults = {
        # ---- paths ----
        "dataset_name": "",  # the dataset name to be used in saving model names
        "data_dir": "",  # the path saving train.json and val.json files
        "dataset_dir": "",  # the path saving actual data
        "save_path": '/content/drive/Shareddrives/colab/data_splits',  # directory to save output
        # ---- dataset ----
        "org_h": 1080,  # height of the original image
        "org_w": 1920,  # width of the original image
        "crop_y": 0,  # crop from image
        "cam_height": 1.55,  # height of camera in meters
        "pitch": 3.0,  # pitch angle of camera to ground
        "fix_cam": False,  # if to use fix camera
        "no_3d": False,  # set when the dataset has no laneline 3D attributes
        "no_centerline": False,  # set when the dataset has no centerline
        # ---- 3DLaneNet ----
        "mod": '3DLaneNet',  # model to train
        "pretrained": True,  # use pretrained vgg model
        "batch_norm": True,  # apply batch norm
        "pred_cam": False,  # use network to predict camera online?
        "ipm_h": 208,  # height of inverse projective map (IPM)
        "ipm_w": 128,  # width of inverse projective map (IPM)
        "resize_h": 360,  # height the input image is resized to (presumably; verify against the data loader)
        "resize_w": 480,  # width the input image is resized to (presumably; verify against the data loader)
        "y_ref": 20.0,  # reference Y distance in meters from where lane association is determined
        "prob_th": 0.5,  # probability threshold for selecting output lanes
        # ---- general model settings ----
        "batch_size": 8,  # batch size
        "nepochs": 30,  # total numbers of epochs
        "learning_rate": 5 * 1e-4,  # learning rate
        "no_cuda": False,  # set to disable the GPU
        "nworkers": 0,  # num of threads
        "no_dropout": False,  # no dropout in network
        "pretrain_epochs": 20,  # number of epochs to perform segmentation pretraining
        "channels_in": 3,  # num channels of input image
        "flip_on": False,  # random flip input images on?
        "test_mode": False,  # prevents loading latest saved model
        "start_epoch": 0,  # presumably the epoch index to start from -- TODO confirm
        "evaluate": True,  # only perform evaluation
        "resume": "",  # resume latest saved run
        "vgg_mean": [0.485, 0.456, 0.406],  # mean of rgb used in pretrained model on ImageNet
        "vgg_std": [0.229, 0.224, 0.225],  # std of rgb used in pretrained model on ImageNet
        # ---- optimizer ----
        "optimizer": 'adam',  # adam or sgd
        "weight_init": "normal",  # normal, xavier, kaiming, orthogonal weights initialisation
        "weight_decay": 0.0,  # L2 weight decay/regularisation
        "lr_decay": False,  # decay learning rate with rule
        "niter": 50,  # number of iterations at starting learning rate
        "niter_decay": 400,  # number of iterations to linearly decay learning rate to zero
        "lr_policy": None,  # learning rate policy: lambda|step|plateau
        "lr_decay_iters": 30,  # multiply by a gamma every lr_decay_iters iterations
        "clip_grad_norm": 0,  # performs gradient clipping
        # ---- cudnn / tensorboard / logging ----
        "cudnn": True,  # cudnn optimization active
        "no_tb": False,  # tensorboard logging switch
        "print_freq": 500,
        "save_freq": 500,
        # ---- skip batch ----
        "list": [954, 2789],  # images you want to skip
    }
    return easydict.EasyDict(defaults)


def tusimple_config(args):
    """
    Overwrite the dataset, camera and model fields of ``args`` in place with
    the settings used for the TuSimple dataset.

    :param args: configuration object supporting attribute assignment
    :return: None, ``args`` is modified in place
    """
    # Model geometry actually in use.  The values presented in the paper were
    #   top_view_region = [[-10, 85], [10, 85], [-10, 5], [10, 5]]
    #   anchor_y_steps  = [5, 20, 40, 60, 80, 100]
    # and an earlier alternative was
    #   top_view_region = [[-10, 82], [10, 82], [-10, 2], [10, 2]]
    #   anchor_y_steps  = [2, 3, 5, 10, 15, 20, 30, 40, 60, 80]
    y_steps = np.array([5, 10, 15, 20, 30, 40, 50, 60, 80, 100])
    view_region = np.array([[-10, 103], [10, 103], [-10, 3], [10, 3]])

    # camera intrinsics for the test dataset
    intrinsics = np.array([[1000, 0, 640],
                           [0, 1000, 400],
                           [0, 0, 1]])

    overrides = (
        # dataset parameters
        ('org_h', 720),
        ('org_w', 1280),
        ('crop_y', 80),
        ('no_centerline', True),
        ('no_3d', True),
        ('fix_cam', True),
        ('pred_cam', False),
        # camera parameters
        ('K', intrinsics),
        ('cam_height', 1.6),
        ('pitch', 9),
        # model settings
        ('top_view_region', view_region),
        ('anchor_y_steps', y_steps),
        ('num_y_steps', len(y_steps)),
        # pre-trained vgg initialisation is switched off, batch norm is on
        ('pretrained', False),
        ('batch_norm', True),
    )
    for field, value in overrides:
        setattr(args, field, value)


def sim3d_config(args):
    """
    Overwrite the dataset, camera and model fields of ``args`` in place with
    the settings used for the synthetic 3D lane dataset.

    Unlike ``tusimple_config`` this leaves ``cam_height`` and ``pitch``
    untouched and enables both 3D attributes and centerlines.

    :param args: configuration object supporting attribute assignment
    :return: None, ``args`` is modified in place
    """
    settings = {}

    # dataset parameters
    settings['org_h'] = 1080
    settings['org_w'] = 1920
    settings['crop_y'] = 0
    settings['no_centerline'] = False
    settings['no_3d'] = False
    settings['fix_cam'] = False
    settings['pred_cam'] = False

    # camera intrinsics for the test datasets
    settings['K'] = np.array([[2015., 0., 960.],
                              [0., 2015., 540.],
                              [0., 0., 1.]])

    # Model geometry actually in use.  The values presented in the paper were
    #   top_view_region = [[-10, 85], [10, 85], [-10, 5], [10, 5]]
    #   anchor_y_steps  = [5, 20, 40, 60, 80, 100]
    # and an earlier alternative was
    #   top_view_region = [[-10, 83], [10, 83], [-10, 3], [10, 3]]
    #   anchor_y_steps  = [3, 5, 10, 20, 40, 60, 80, 100]
    settings['top_view_region'] = np.array([[-10, 103], [10, 103], [-10, 3], [10, 3]])
    settings['anchor_y_steps'] = np.array([5, 10, 15, 20, 30, 40, 50, 60, 80, 100])
    settings['num_y_steps'] = len(settings['anchor_y_steps'])

    # pre-trained vgg initialisation is switched off, batch norm is on
    settings['pretrained'] = False
    settings['batch_norm'] = True

    for key, val in settings.items():
        setattr(args, key, val)


class Visualizer:
    def __init__(self, args, vis_folder='val_vis'):
        """
        Cache the visualisation settings from ``args`` and precompute the
        anchor grid and the ground-to-IPM homography.

        :param args: configuration object; reads save_path, no_3d, no_centerline,
                     vgg_mean, vgg_std, ipm_w, ipm_h, num_y_steps, mod,
                     top_view_region, anchor_y_steps and prob_th
        :param vis_folder: name stored as ``self.vis_folder``
        """
        self.save_path = args.save_path
        self.vis_folder = vis_folder
        self.no_3d = args.no_3d
        self.no_centerline = args.no_centerline
        self.vgg_mean = args.vgg_mean
        self.vgg_std = args.vgg_std
        self.ipm_w = args.ipm_w
        self.ipm_h = args.ipm_h
        self.num_y_steps = args.num_y_steps

        # Number of values per lane type in one anchor row: one block of
        # num_y_steps values without 3D, two blocks with 3D, three blocks for
        # the 'ext' models, plus one trailing existence probability.
        if args.no_3d:
            self.anchor_dim = args.num_y_steps + 1
        elif 'ext' in args.mod:
            self.anchor_dim = 3 * args.num_y_steps + 1
        else:
            self.anchor_dim = 2 * args.num_y_steps + 1

        x_min = args.top_view_region[0, 0]
        x_max = args.top_view_region[1, 0]
        # One anchor column per 8 IPM pixels.  The builtin int is used because
        # the np.int alias was removed in NumPy 1.24.
        self.anchor_x_steps = np.linspace(x_min, x_max, int(args.ipm_w / 8), endpoint=True)
        self.anchor_y_steps = args.anchor_y_steps

        # transformation from ipm to ground region
        ipm_corners = np.float32([[0, 0],
                                  [self.ipm_w - 1, 0],
                                  [0, self.ipm_h - 1],
                                  [self.ipm_w - 1, self.ipm_h - 1]])
        H_ipm2g = cv2.getPerspectiveTransform(ipm_corners, np.float32(args.top_view_region))
        self.H_g2ipm = np.linalg.inv(H_ipm2g)

        # probability threshold for choosing visualize lanes
        self.prob_th = args.prob_th

    def draw_on_img(self, img, lane_anchor, P_g2im, draw_type='laneline', color=[0, 0, 1]):
        """
        Project the lanes of one type into the image and draw them as polylines.

        :param img: image in numpy array, each pixel in [0, 1] range
        :param lane_anchor: lane anchor in N X C numpy ndarray, dimension in agree with dataloader
        :param P_g2im: projection from ground 3D coordinates to image 2D coordinates;
                       a matrix with 3 columns is applied as a homography,
                       anything else as a full projection that also uses z
        :param draw_type: 'laneline' or 'centerline' deciding which to draw;
                          any other value draws nothing
        :param color: [r, g, b] color for line,  each range in [0, 1]
        :return: the image with the selected lanes drawn
        """
        # Every anchor row holds up to three consecutive groups of anchor_dim
        # values: 0 = laneline, 1 = centerline, 2 = extra centerline used for
        # the merging case.  The last value of a group is its probability.
        # Strings are compared with ==, identity (`is`) on literals is unreliable.
        if draw_type == 'laneline':
            groups = (0,)
        elif draw_type == 'centerline':
            groups = (1, 2)
        else:
            groups = ()

        for j in range(lane_anchor.shape[0]):
            for g in groups:
                start = g * self.anchor_dim
                prob_idx = start + self.anchor_dim - 1
                if not lane_anchor[j, prob_idx] > self.prob_th:
                    continue

                x_offsets = lane_anchor[j, start:start + self.num_y_steps]
                x_3d = x_offsets + self.anchor_x_steps[j]
                if P_g2im.shape[1] == 3:
                    x_2d, y_2d = homographic_transformation(P_g2im, x_3d, self.anchor_y_steps)
                else:
                    z_3d = lane_anchor[j, start + self.num_y_steps:prob_idx]
                    x_2d, y_2d = projective_transformation(P_g2im, x_3d, self.anchor_y_steps, z_3d)
                # np.int was removed in NumPy 1.24; the builtin int is the same dtype
                x_2d = x_2d.astype(int)
                y_2d = y_2d.astype(int)
                for k in range(1, x_2d.shape[0]):
                    # plain Python ints: some OpenCV builds reject NumPy scalars as point coordinates
                    img = cv2.line(img,
                                   (int(x_2d[k - 1]), int(y_2d[k - 1])),
                                   (int(x_2d[k]), int(y_2d[k])),
                                   color, 2)
        return img

    def draw_on_img_new(self, img, lane_anchor, P_g2im, draw_type='laneline', color=[0, 0, 1]):
        """
        Project the lanes of one type into the image and draw them, using the
        per-point visibility stored in the anchor to pick the segment color.

        :param img: image in numpy array, each pixel in [0, 1] range
        :param lane_anchor: lane anchor in N X C numpy ndarray, dimension in agree with dataloader
        :param P_g2im: projection from ground 3D coordinates to image 2D coordinates;
                       a matrix with 3 columns is applied as a homography (all
                       points treated as visible), anything else as a full
                       projection using the z and visibility blocks of the anchor
        :param draw_type: 'laneline' or 'centerline' deciding which to draw;
                          any other value draws nothing
        :param color: [r, g, b] color for line,  each range in [0, 1]
        :return: the image with the selected lanes drawn
        """
        # Group 0 = laneline, 1 = centerline, 2 = extra centerline (merging case).
        # Within a group the layout read here is
        # [x offsets | z | visibility | probability], each block num_y_steps long.
        if draw_type == 'laneline':
            groups = (0,)
        elif draw_type == 'centerline':
            groups = (1, 2)
        else:
            groups = ()

        n = self.num_y_steps
        for j in range(lane_anchor.shape[0]):
            for g in groups:
                start = g * self.anchor_dim
                if not lane_anchor[j, start + self.anchor_dim - 1] > self.prob_th:
                    continue

                x_offsets = lane_anchor[j, start:start + n]
                x_3d = x_offsets + self.anchor_x_steps[j]
                if P_g2im.shape[1] == 3:
                    x_2d, y_2d = homographic_transformation(P_g2im, x_3d, self.anchor_y_steps)
                    visibility = np.ones_like(x_2d)
                else:
                    z_3d = lane_anchor[j, start + n:start + 2 * n]
                    x_2d, y_2d = projective_transformation(P_g2im, x_3d, self.anchor_y_steps, z_3d)
                    visibility = lane_anchor[j, start + 2 * n:start + 3 * n]
                # np.int was removed in NumPy 1.24; the builtin int is the same dtype
                x_2d = x_2d.astype(int)
                y_2d = y_2d.astype(int)
                for k in range(1, x_2d.shape[0]):
                    # a segment ending in a non-visible point is drawn in black
                    seg_color = color if visibility[k] > self.prob_th else [0, 0, 0]
                    # plain Python ints: some OpenCV builds reject NumPy scalars as point coordinates
                    img = cv2.line(img,
                                   (int(x_2d[k - 1]), int(y_2d[k - 1])),
                                   (int(x_2d[k]), int(y_2d[k])),
                                   seg_color, 2)
        return img

    def draw_on_ipm(self, im_ipm, lane_anchor, draw_type='laneline', color=[0, 0, 1]):
        """
        Draw the lanes of one type onto the top-view (IPM) image.

        :param im_ipm: IPM image to draw on
        :param lane_anchor: lane anchor in N X C numpy ndarray
        :param draw_type: 'laneline' or 'centerline' deciding which to draw;
                          any other value draws nothing
        :param color: [r, g, b] color for line
        :return: the IPM image with the selected lanes drawn
        """
        # Group 0 = laneline, 1 = centerline, 2 = extra centerline (merging case);
        # each group spans anchor_dim values and ends with its probability.
        if draw_type == 'laneline':
            groups = (0,)
        elif draw_type == 'centerline':
            groups = (1, 2)
        else:
            groups = ()

        for j in range(lane_anchor.shape[0]):
            for g in groups:
                start = g * self.anchor_dim
                if not lane_anchor[j, start + self.anchor_dim - 1] > self.prob_th:
                    continue

                x_offsets = lane_anchor[j, start:start + self.num_y_steps]
                x_g = x_offsets + self.anchor_x_steps[j]

                # compute lanelines in ipm view
                x_ipm, y_ipm = homographic_transformation(self.H_g2ipm, x_g, self.anchor_y_steps)
                # np.int was removed in NumPy 1.24; the builtin int is the same dtype
                x_ipm = x_ipm.astype(int)
                y_ipm = y_ipm.astype(int)
                for k in range(1, x_g.shape[0]):
                    # plain Python ints: some OpenCV builds reject NumPy scalars as point coordinates
                    im_ipm = cv2.line(im_ipm,
                                      (int(x_ipm[k - 1]), int(y_ipm[k - 1])),
                                      (int(x_ipm[k]), int(y_ipm[k])),
                                      color, 1)
        return im_ipm

    def draw_on_ipm_new(self, im_ipm, lane_anchor, draw_type='laneline', color=[0, 0, 1], width=1):
        """
        Draw the lanes of one type onto the top-view (IPM) image, coloring each
        segment according to the per-point visibility stored in the anchor.

        :param im_ipm: IPM image to draw on
        :param lane_anchor: lane anchor in N X C numpy ndarray
        :param draw_type: 'laneline' or 'centerline' deciding which to draw;
                          any other value draws nothing
        :param color: [r, g, b] color for visible segments
        :param width: line thickness passed to ``cv2.line``
        :return: the IPM image with the selected lanes drawn
        """
        # Group 0 = laneline, 1 = centerline, 2 = extra centerline (merging case).
        if draw_type == 'laneline':
            groups = (0,)
        elif draw_type == 'centerline':
            groups = (1, 2)
        else:
            groups = ()

        n = self.num_y_steps
        for j in range(lane_anchor.shape[0]):
            for g in groups:
                start = g * self.anchor_dim
                if not lane_anchor[j, start + self.anchor_dim - 1] > self.prob_th:
                    continue

                x_offsets = lane_anchor[j, start:start + n]
                x_g = x_offsets + self.anchor_x_steps[j]
                if self.no_3d:
                    # without 3D attributes there is no visibility block: show everything
                    visibility = np.ones_like(x_g)
                else:
                    visibility = lane_anchor[j, start + 2 * n:start + 3 * n]

                # compute lanelines in ipm view
                x_ipm, y_ipm = homographic_transformation(self.H_g2ipm, x_g, self.anchor_y_steps)
                # np.int was removed in NumPy 1.24; the builtin int is the same dtype
                x_ipm = x_ipm.astype(int)
                y_ipm = y_ipm.astype(int)
                for k in range(1, x_g.shape[0]):
                    # a segment ending in a non-visible point is drawn in black
                    seg_color = color if visibility[k] > self.prob_th else [0, 0, 0]
                    # plain Python ints: some OpenCV builds reject NumPy scalars as point coordinates
                    im_ipm = cv2.line(im_ipm,
                                      (int(x_ipm[k - 1]), int(y_ipm[k - 1])),
                                      (int(x_ipm[k]), int(y_ipm[k])),
                                      seg_color, width)
        return im_ipm

    def draw_3d_curves(self, ax, lane_anchor, draw_type='laneline', color=[0, 0, 1]):
        for j in range(lane_anchor.shape[0]):
            # draw laneline
            if draw_type is 'laneline' and lane_anchor[j, self.anchor_dim - 1] > self.prob_th:
                x_offsets = lane_anchor[j, :self.num_y_steps]
                x_g = x_offsets + self.anchor_x_steps[j]
                if self.no_3d:
                    z_g = np.zeros_like(x_g)
                else:
                    z_g = lane_anchor[j, self.num_y_steps:2*self.num_y_steps]
                ax.plot(x_g, self.anchor_y_steps, z_g, color=color)

            # draw centerline
            if draw_type is 'centerline' and lane_anchor[j, 2*self.anchor_dim - 1] > self.prob_th:
                x_offsets = lane_anchor[j, self.anchor_dim:self.anchor_dim + self.num_y_steps]
                x_g = x_offsets + self.anchor_x_steps[j]
                if self.no_3d:
                    z_g = np.zeros_like(x_g)
                else:
                    z_g = lane_anchor[j, self.anchor_dim + self.num_y_steps:self.anchor_dim + 2*self.num_y_steps]
                ax.plot(x_g, self.anchor_y_steps, z_g, color=color)

            # draw the additional centerline for the merging case
            if draw_type is 'centerline' and lane_anchor[j, 3*self.anchor_dim - 1] > self.prob_th:
                x_offsets = lane_anchor[j, 2*self.anchor_dim:2*self.anchor_dim + self.num_y_steps]
                x_g = x_offsets + self.anchor_x_steps[j]
                if self.no_3d:
                    z_g = np.zeros_like(x_g)
                else:
                    z_g = lane_anchor[j, 2*self.anchor_dim + self.num_y_steps:2*self.anchor_dim + 2*self.num_y_steps]
                ax.plot(x_g, self.anchor_y_steps, z_g, color=color)

    def draw_3d_curves_new(self, ax, lane_anchor, h_cam, draw_type='laneline', color=[0, 0, 1]):
        """
        Plot the visible part of the lanes of one type as 3D curves, after
        converting them from flat-ground space to 3D ground space.

        :param ax: axis object whose ``plot(x, y, z, color=...)`` is called once per drawn lane
        :param lane_anchor: lane anchor in N X C numpy ndarray
        :param h_cam: camera height, forwarded to ``transform_lane_gflat2g``
        :param draw_type: 'laneline' or 'centerline' deciding which to draw;
                          any other value draws nothing
        :param color: [r, g, b] color for line
        :return: None
        """
        # Group 0 = laneline, 1 = centerline, 2 = extra centerline (merging case).
        # Within a group the layout read here is
        # [x offsets | z | visibility | probability], each block num_y_steps long.
        if draw_type == 'laneline':
            groups = (0,)
        elif draw_type == 'centerline':
            groups = (1, 2)
        else:
            groups = ()

        n = self.num_y_steps
        for j in range(lane_anchor.shape[0]):
            for g in groups:
                start = g * self.anchor_dim
                if not lane_anchor[j, start + self.anchor_dim - 1] > self.prob_th:
                    continue

                x_offsets = lane_anchor[j, start:start + n]
                x_gflat = x_offsets + self.anchor_x_steps[j]
                if self.no_3d:
                    z_g = np.zeros_like(x_gflat)
                    visibility = np.ones_like(x_gflat)
                else:
                    z_g = lane_anchor[j, start + n:start + 2 * n]
                    visibility = lane_anchor[j, start + 2 * n:start + 3 * n]

                # keep only the points predicted as visible (computed once, reused for x, y and z)
                keep = np.where(visibility > self.prob_th)
                x_gflat = x_gflat[keep]
                z_g = z_g[keep]
                if len(x_gflat) > 0:
                    # transform lane detected in flat ground space to 3d ground space
                    x_g, y_g = transform_lane_gflat2g(h_cam,
                                                      x_gflat,
                                                      self.anchor_y_steps[keep],
                                                      z_g)
                    ax.plot(x_g, y_g, z_g, color=color)

    def save_result(self, dataset, train_or_val, epoch, batch_i, idx, images, gt, pred, pred_cam_pitch, pred_cam_height, aug_mat=np.identity(3, dtype=float), evaluate=False):
        """
            Draw ground-truth and predicted lane anchors on the input image, on its IPM warp and
            (when 3D is enabled) in 3D axes, then save the composed figure as an image file.
            Ground truth is drawn with color [0, 0, 1], prediction with color [1, 0, 0].

            NOTE: the 1D non-maximum suppression is written back into `pred[i]`, i.e. the caller's
            `pred` is modified for every visualized sample.

        :param dataset: object providing `data_aug`, `transform_mats(idx)` and the intrinsics `K`
        :param train_or_val: sub-folder of <save_path>/example/ used when `evaluate` is False
        :param epoch: epoch number, only used in the output file name
        :param batch_i: batch number, only used in the output file name
        :param idx: per-sample dataset indices of this batch (needs `.shape[0]`)
        :param images: torch tensor of normalized images, permuted here from N x C x H x W to N x H x W x C
        :param gt: ground-truth anchors, indexed per sample
        :param pred: predicted anchors, indexed per sample (modified in place by the nms)
        :param pred_cam_pitch: per-sample camera pitch used to build the prediction projection
        :param pred_cam_height: per-sample camera height used to build the prediction projection
        :param aug_mat: augmentation matrices; a single 3 x 3 matrix is replicated over the batch
                        when `dataset.data_aug` is False, otherwise it is indexed as aug_mat[i, :, :]
        :param evaluate: if True every sample is saved into `self.vis_folder`,
                         otherwise only the first sample of the batch is saved
        :return:
        """
        if not dataset.data_aug:
            aug_mat = np.repeat(np.expand_dims(aug_mat, axis=0), idx.shape[0], axis=0)

        for i in range(idx.shape[0]):
            # during training, only visualize the first sample of this batch
            if i > 0 and not evaluate:
                break
            im = images.permute(0, 2, 3, 1).data.cpu().numpy()[i]
            # the vgg_std and vgg_mean are for images in [0, 1] range
            im = im * np.array(self.vgg_std)
            im = im + np.array(self.vgg_mean)
            im = np.clip(im, 0, 1)

            gt_anchors = gt[i]
            pred_anchors = pred[i]

            # apply nms to avoid output directly neighbored lanes
            # consider w/o centerline cases
            if self.no_centerline:
                pred_anchors[:, -1] = nms_1d(pred_anchors[:, -1])
            else:
                pred_anchors[:, self.anchor_dim - 1] = nms_1d(pred_anchors[:, self.anchor_dim - 1])
                pred_anchors[:, 2 * self.anchor_dim - 1] = nms_1d(pred_anchors[:, 2 * self.anchor_dim - 1])
                pred_anchors[:, 3 * self.anchor_dim - 1] = nms_1d(pred_anchors[:, 3 * self.anchor_dim - 1])

            H_g2im, P_g2im, H_crop, H_im2ipm = dataset.transform_mats(idx[i])
            if self.no_3d:
                # flat-ground case: 3 x 3 homographies
                P_gt = np.matmul(H_crop, H_g2im)
                H_g2im_pred = homograpthy_g2im(pred_cam_pitch[i],
                                               pred_cam_height[i], dataset.K)
                P_pred = np.matmul(H_crop, H_g2im_pred)
            else:
                # 3D case: 3 x 4 projections
                P_gt = np.matmul(H_crop, P_g2im)
                P_g2im_pred = projection_g2im(pred_cam_pitch[i],
                                              pred_cam_height[i], dataset.K)
                P_pred = np.matmul(H_crop, P_g2im_pred)

            # consider data augmentation
            P_gt = np.matmul(aug_mat[i, :, :], P_gt)
            P_pred = np.matmul(aug_mat[i, :, :], P_pred)

            # update transformation with image augmentation
            H_im2ipm = np.matmul(H_im2ipm, np.linalg.inv(aug_mat[i, :, :]))
            im_ipm = cv2.warpPerspective(im, H_im2ipm, (self.ipm_w, self.ipm_h))
            im_ipm = np.clip(im_ipm, 0, 1)

            # draw lanes on image
            im_laneline = im.copy()
            im_laneline = self.draw_on_img(im_laneline, gt_anchors, P_gt, 'laneline', [0, 0, 1])
            im_laneline = self.draw_on_img(im_laneline, pred_anchors, P_pred, 'laneline', [1, 0, 0])
            if not self.no_centerline:
                im_centerline = im.copy()
                im_centerline = self.draw_on_img(im_centerline, gt_anchors, P_gt, 'centerline', [0, 0, 1])
                im_centerline = self.draw_on_img(im_centerline, pred_anchors, P_pred, 'centerline', [1, 0, 0])

            # draw lanes on ipm
            ipm_laneline = im_ipm.copy()
            ipm_laneline = self.draw_on_ipm(ipm_laneline, gt_anchors, 'laneline', [0, 0, 1])
            ipm_laneline = self.draw_on_ipm(ipm_laneline, pred_anchors, 'laneline', [1, 0, 0])
            if not self.no_centerline:
                ipm_centerline = im_ipm.copy()
                ipm_centerline = self.draw_on_ipm(ipm_centerline, gt_anchors, 'centerline', [0, 0, 1])
                ipm_centerline = self.draw_on_ipm(ipm_centerline, pred_anchors, 'centerline', [1, 0, 0])

            # plot on a single figure
            if self.no_centerline and self.no_3d:
                fig = plt.figure()
                ax1 = fig.add_subplot(121)
                ax2 = fig.add_subplot(122)
                ax1.imshow(im_laneline)
                ax2.imshow(ipm_laneline)
            elif not self.no_centerline and self.no_3d:
                fig = plt.figure()
                ax1 = fig.add_subplot(221)
                ax2 = fig.add_subplot(222)
                ax3 = fig.add_subplot(223)
                ax4 = fig.add_subplot(224)
                ax1.imshow(im_laneline)
                ax2.imshow(ipm_laneline)
                ax3.imshow(im_centerline)
                ax4.imshow(ipm_centerline)
            elif not self.no_centerline and not self.no_3d:
                fig = plt.figure()
                ax1 = fig.add_subplot(231)
                ax2 = fig.add_subplot(232)
                ax3 = fig.add_subplot(233, projection='3d')
                ax4 = fig.add_subplot(234)
                ax5 = fig.add_subplot(235)
                ax6 = fig.add_subplot(236, projection='3d')
                ax1.imshow(im_laneline)
                ax2.imshow(ipm_laneline)
                self.draw_3d_curves(ax3, gt_anchors, 'laneline', [0, 0, 1])
                self.draw_3d_curves(ax3, pred_anchors, 'laneline', [1, 0, 0])
                ax3.set_xlabel('x axis')
                ax3.set_ylabel('y axis')
                ax3.set_zlabel('z axis')
                # keep at least the [-1, 1] range visible on the z axis
                bottom, top = ax3.get_zlim()
                ax3.set_zlim(min(bottom, -1), max(top, 1))
                ax3.set_xlim(-20, 20)
                ax3.set_ylim(0, 100)
                ax4.imshow(im_centerline)
                ax5.imshow(ipm_centerline)
                self.draw_3d_curves(ax6, gt_anchors, 'centerline', [0, 0, 1])
                self.draw_3d_curves(ax6, pred_anchors, 'centerline', [1, 0, 0])
                ax6.set_xlabel('x axis')
                ax6.set_ylabel('y axis')
                ax6.set_zlabel('z axis')
                bottom, top = ax6.get_zlim()
                ax6.set_zlim(min(bottom, -1), max(top, 1))
                ax6.set_xlim(-20, 20)
                ax6.set_ylim(0, 100)
            else:
                # no layout exists for this combination; fail with a clear message
                # instead of hitting an unbound `fig` below
                raise NotImplementedError(
                    'visualization with no_centerline=True and no_3d=False is not implemented')

            if evaluate:
                fig.savefig(self.save_path + '/example/' + self.vis_folder + '/infer_{}'.format(idx[i]))
            else:
                fig.savefig(self.save_path + '/example/{}/epoch-{}_batch-{}_idx-{}'.format(train_or_val,
                                                                                           epoch, batch_i, idx[i]))
            plt.clf()
            plt.close(fig)

    def save_result_new(self, dataset, train_or_val, epoch, batch_i, idx, images, gt, pred, pred_cam_pitch, pred_cam_height, aug_mat=np.identity(3, dtype=float), evaluate=False):
        """
            Variant of the result visualization that relies on the `*_new` drawing helpers
            (`draw_on_img_new`, `draw_on_ipm_new`, `draw_3d_curves_new`). The image projection is
            always built from the flat-ground homography, and the 3D curves are drawn with the
            predicted camera height for both ground truth and prediction.
            Ground truth is drawn with color [0, 0, 1], prediction with color [1, 0, 0].

            NOTE: the 1D non-maximum suppression is written back into `pred[i]`, i.e. the caller's
            `pred` is modified for every visualized sample.

        :param dataset: object providing `data_aug`, `transform_mats(idx)` and the intrinsics `K`
        :param train_or_val: sub-folder of <save_path>/example/ used when `evaluate` is False
        :param epoch: epoch number, only used in the output file name
        :param batch_i: batch number, only used in the output file name
        :param idx: per-sample dataset indices of this batch (needs `.shape[0]`)
        :param images: torch tensor of normalized images, permuted here from N x C x H x W to N x H x W x C
        :param gt: ground-truth anchors, indexed per sample
        :param pred: predicted anchors, indexed per sample (modified in place by the nms)
        :param pred_cam_pitch: per-sample camera pitch used to build the prediction homography
        :param pred_cam_height: per-sample camera height used for the prediction homography and the 3D curves
        :param aug_mat: augmentation matrices; a single 3 x 3 matrix is replicated over the batch
                        when `dataset.data_aug` is False, otherwise it is indexed as aug_mat[i, :, :]
        :param evaluate: if True every sample is saved into `self.vis_folder`,
                         otherwise only the first sample of the batch is saved
        :return:
        """
        if not dataset.data_aug:
            aug_mat = np.repeat(np.expand_dims(aug_mat, axis=0), idx.shape[0], axis=0)

        for i in range(idx.shape[0]):
            # during training, only visualize the first sample of this batch
            if i > 0 and not evaluate:
                break
            im = images.permute(0, 2, 3, 1).data.cpu().numpy()[i]
            # the vgg_std and vgg_mean are for images in [0, 1] range
            im = im * np.array(self.vgg_std)
            im = im + np.array(self.vgg_mean)
            im = np.clip(im, 0, 1)

            gt_anchors = gt[i]
            pred_anchors = pred[i]

            # apply nms to avoid output directly neighbored lanes
            # consider w/o centerline cases
            if self.no_centerline:
                pred_anchors[:, -1] = nms_1d(pred_anchors[:, -1])
            else:
                pred_anchors[:, self.anchor_dim - 1] = nms_1d(pred_anchors[:, self.anchor_dim - 1])
                pred_anchors[:, 2 * self.anchor_dim - 1] = nms_1d(pred_anchors[:, 2 * self.anchor_dim - 1])
                pred_anchors[:, 3 * self.anchor_dim - 1] = nms_1d(pred_anchors[:, 3 * self.anchor_dim - 1])

            # P_g2im is returned by the dataset but not needed here: only homographies are used
            H_g2im, P_g2im, H_crop, H_im2ipm = dataset.transform_mats(idx[i])
            P_gt = np.matmul(H_crop, H_g2im)
            H_g2im_pred = homograpthy_g2im(pred_cam_pitch[i],
                                           pred_cam_height[i], dataset.K)
            P_pred = np.matmul(H_crop, H_g2im_pred)

            # consider data augmentation
            P_gt = np.matmul(aug_mat[i, :, :], P_gt)
            P_pred = np.matmul(aug_mat[i, :, :], P_pred)

            # update transformation with image augmentation
            H_im2ipm = np.matmul(H_im2ipm, np.linalg.inv(aug_mat[i, :, :]))
            im_ipm = cv2.warpPerspective(im, H_im2ipm, (self.ipm_w, self.ipm_h))
            im_ipm = np.clip(im_ipm, 0, 1)

            # draw lanes on image
            im_laneline = im.copy()
            im_laneline = self.draw_on_img_new(im_laneline, gt_anchors, P_gt, 'laneline', [0, 0, 1])
            im_laneline = self.draw_on_img_new(im_laneline, pred_anchors, P_pred, 'laneline', [1, 0, 0])
            if not self.no_centerline:
                im_centerline = im.copy()
                im_centerline = self.draw_on_img_new(im_centerline, gt_anchors, P_gt, 'centerline', [0, 0, 1])
                im_centerline = self.draw_on_img_new(im_centerline, pred_anchors, P_pred, 'centerline', [1, 0, 0])

            # draw lanes on ipm
            ipm_laneline = im_ipm.copy()
            ipm_laneline = self.draw_on_ipm_new(ipm_laneline, gt_anchors, 'laneline', [0, 0, 1])
            ipm_laneline = self.draw_on_ipm_new(ipm_laneline, pred_anchors, 'laneline', [1, 0, 0])
            if not self.no_centerline:
                ipm_centerline = im_ipm.copy()
                ipm_centerline = self.draw_on_ipm_new(ipm_centerline, gt_anchors, 'centerline', [0, 0, 1])
                ipm_centerline = self.draw_on_ipm_new(ipm_centerline, pred_anchors, 'centerline', [1, 0, 0])

            # plot on a single figure
            if self.no_centerline and self.no_3d:
                fig = plt.figure()
                ax1 = fig.add_subplot(121)
                ax2 = fig.add_subplot(122)
                ax1.imshow(im_laneline)
                ax2.imshow(ipm_laneline)
            elif not self.no_centerline and self.no_3d:
                fig = plt.figure()
                ax1 = fig.add_subplot(221)
                ax2 = fig.add_subplot(222)
                ax3 = fig.add_subplot(223)
                ax4 = fig.add_subplot(224)
                ax1.imshow(im_laneline)
                ax2.imshow(ipm_laneline)
                ax3.imshow(im_centerline)
                ax4.imshow(ipm_centerline)
            elif not self.no_centerline and not self.no_3d:
                fig = plt.figure()
                ax1 = fig.add_subplot(231)
                ax2 = fig.add_subplot(232)
                ax3 = fig.add_subplot(233, projection='3d')
                ax4 = fig.add_subplot(234)
                ax5 = fig.add_subplot(235)
                ax6 = fig.add_subplot(236, projection='3d')
                ax1.imshow(im_laneline)
                ax2.imshow(ipm_laneline)
                # TODO:use separate gt_cam_height when ready
                self.draw_3d_curves_new(ax3, gt_anchors, pred_cam_height[i], 'laneline', [0, 0, 1])
                self.draw_3d_curves_new(ax3, pred_anchors, pred_cam_height[i], 'laneline', [1, 0, 0])
                ax3.set_xlabel('x axis')
                ax3.set_ylabel('y axis')
                ax3.set_zlabel('z axis')
                # keep at least the [-1, 1] range visible on the z axis
                bottom, top = ax3.get_zlim()
                ax3.set_xlim(-20, 20)
                ax3.set_ylim(0, 100)
                ax3.set_zlim(min(bottom, -1), max(top, 1))
                ax4.imshow(im_centerline)
                ax5.imshow(ipm_centerline)
                # TODO:use separate gt_cam_height when ready
                self.draw_3d_curves_new(ax6, gt_anchors, pred_cam_height[i], 'centerline', [0, 0, 1])
                self.draw_3d_curves_new(ax6, pred_anchors, pred_cam_height[i], 'centerline', [1, 0, 0])
                ax6.set_xlabel('x axis')
                ax6.set_ylabel('y axis')
                ax6.set_zlabel('z axis')
                bottom, top = ax6.get_zlim()
                ax6.set_xlim(-20, 20)
                ax6.set_ylim(0, 100)
                ax6.set_zlim(min(bottom, -1), max(top, 1))
            else:
                # no layout exists for this combination; fail with a clear message
                # instead of hitting an unbound `fig` below
                raise NotImplementedError(
                    'visualization with no_centerline=True and no_3d=False is not implemented')

            if evaluate:
                fig.savefig(self.save_path + '/example/' + self.vis_folder + '/infer_{}'.format(idx[i]))
            else:
                fig.savefig(self.save_path + '/example/{}/epoch-{}_batch-{}_idx-{}'.format(train_or_val,
                                                                                           epoch, batch_i, idx[i]))
            plt.clf()
            plt.close(fig)


def prune_3d_lane_by_visibility(lane_3d, visibility):
    """
        Keep only the lane points whose visibility value is strictly positive.
    :param lane_3d: ndarray of lane points, one row per point
    :param visibility: ndarray with one visibility value per row of lane_3d
    :return: the rows of lane_3d selected by visibility > 0
    """
    visible_rows = visibility > 0
    return lane_3d[visible_rows, ...]


def prune_3d_lane_by_range(lane_3d, x_min, x_max, y_min=0, y_max=200):
    """
        Remove lane points lying outside an open rectangular range in x and y.
        3D label may miss super long straight-line with only two points: the y limit does not
        have to be 200, gt needs a min-step.
        2D dataset requires this to rule out those points projected to ground, but out of
        meaningful range.
    :param lane_3d: N x 2 or N x 3 ndarray, one row per point with x in column 0 and y in column 1
    :param x_min: exclusive lower bound on x
    :param x_max: exclusive upper bound on x
    :param y_min: exclusive lower bound on y (defaults to the previously hard-coded 0)
    :param y_max: exclusive upper bound on y (defaults to the previously hard-coded 200)
    :return: the rows of lane_3d with x_min < x < x_max and y_min < y < y_max
    """
    # remove points with y out of range
    lane_3d = lane_3d[np.logical_and(lane_3d[:, 1] > y_min, lane_3d[:, 1] < y_max), ...]

    # remove lane points out of x range
    lane_3d = lane_3d[np.logical_and(lane_3d[:, 0] > x_min,
                                     lane_3d[:, 0] < x_max), ...]
    return lane_3d


def resample_laneline_in_y(input_lane, y_steps, out_vis=False):
    """
        Linearly interpolate (and extrapolate) the x and z values of a lane at the given y positions.
    :param input_lane: N x 2 or N x 3 ndarray, one row for a point (x, y, z-optional).
                       It requires y values of input lane in ascending order
    :param y_steps: a vector of steps in y
    :param out_vis: whether to also return a visibility indicator, which only depends on whether
                    each step lies within the input y range extended by 5 on both sides
    :return: (x_values, z_values) or, when out_vis is set, (x_values, z_values, visibility)
             where visibility is float32 with 1e-9 added to every entry
    """

    # at least two points are included
    assert(input_lane.shape[0] >= 2)

    lane_y = input_lane[:, 1]
    y_lower = np.min(lane_y) - 5
    y_upper = np.max(lane_y) + 5

    # a 2D lane is treated as lying at z = 0
    if input_lane.shape[1] < 3:
        zero_z = np.zeros([input_lane.shape[0], 1], dtype=np.float32)
        input_lane = np.concatenate([input_lane, zero_z], axis=1)

    x_values = interp1d(input_lane[:, 1], input_lane[:, 0], fill_value="extrapolate")(y_steps)
    z_values = interp1d(input_lane[:, 1], input_lane[:, 2], fill_value="extrapolate")(y_steps)

    if not out_vis:
        return x_values, z_values

    in_range = np.logical_and(y_steps >= y_lower, y_steps <= y_upper)
    return x_values, z_values, in_range.astype(np.float32) + 1e-9


def resample_laneline_in_y_with_vis(input_lane, y_steps, vis_vec):
    """
        Linearly interpolate (and extrapolate) x, z and visibility of a lane at the given y positions
        and keep only the resampled points whose interpolated visibility exceeds 0.5.
    :param input_lane: N x 2 or N x 3 ndarray, one row for a point (x, y, z-optional).
                       It requires y values of input lane in ascending order
    :param y_steps: a vector of steps in y
    :param vis_vec: one visibility value per input point
    :return: M x 3 ndarray of the retained points, one row for a point (x, y, z)
    """

    # at least two points are included
    assert(input_lane.shape[0] >= 2)

    # a 2D lane is treated as lying at z = 0
    if input_lane.shape[1] < 3:
        zero_z = np.zeros([input_lane.shape[0], 1], dtype=np.float32)
        input_lane = np.concatenate([input_lane, zero_z], axis=1)

    lane_y = input_lane[:, 1]
    x_values = interp1d(lane_y, input_lane[:, 0], fill_value="extrapolate")(y_steps)
    z_values = interp1d(lane_y, input_lane[:, 2], fill_value="extrapolate")(y_steps)
    vis_values = interp1d(lane_y, vis_vec, fill_value="extrapolate")(y_steps)

    keep = vis_values > 0.5
    return np.array([x_values[keep], y_steps[keep], z_values[keep]]).T


def homography_im2ipm_norm(top_view_region, org_img_size, crop_y, resize_img_size, cam_pitch, cam_height, K):
    """
        Compute the normalized transformations between the network input image and the top-view
        (IPM) image, such that the image projections of the top-view region corners are mapped to
        the four corners of the unit square.
        Ground coordinates: x-right, y-forward, z-up
        The purpose of applying normalized transformation: 1. invariance in scale change
                                                           2.Torch grid sample is based on normalized grids
    :param top_view_region: a 4 X 2 array of (X, Y) indicating the top-view region corners in order:
                            top-left, top-right, bottom-left, bottom-right
    :param org_img_size: the size of original image size: [h, w]
    :param crop_y: pixels croped from original img
    :param resize_img_size: the size of image as network input: [h, w]
    :param cam_pitch: camera pitch angle wrt ground plane
    :param cam_height: camera height wrt ground plane in meters
    :param K: camera intrinsic parameters
    :return: H_im2ipm_norm: the normalized transformation from image to IPM image
             H_ipm2im_norm: the normalized transformation from IPM image to image
    """

    # ground -> original image (the only part depending on cam_pitch and cam_height),
    # followed by original image -> cropped and resized network input
    H_ground2org = homograpthy_g2im(cam_pitch, cam_height, K)
    H_org2input = homography_crop_resize(org_img_size, crop_y, resize_img_size)
    H_ground2input = np.matmul(H_org2input, H_ground2org)

    # project the top-view corners into the network input image
    corner_x, corner_y = homographic_transformation(H_ground2input,
                                                    top_view_region[:, 0],
                                                    top_view_region[:, 1])

    # normalize the corner coordinates by the input width and height
    resize_h, resize_w = resize_img_size[0], resize_img_size[1]
    corners_norm = np.float32(np.concatenate([corner_x.reshape(-1, 1) / resize_w,
                                              corner_y.reshape(-1, 1) / resize_h], axis=1))
    unit_square = np.float32([[0, 0], [1, 0], [0, 1], [1, 1]])

    H_im2ipm_norm = cv2.getPerspectiveTransform(corners_norm, unit_square)
    H_ipm2im_norm = cv2.getPerspectiveTransform(unit_square, corners_norm)
    return H_im2ipm_norm, H_ipm2im_norm


def homography_ipmnorm2g(top_view_region):
    """
        Compute the perspective transformation mapping the unit-square corners
        (0, 0), (1, 0), (0, 1), (1, 1) onto the four corners of the top-view region.
    :param top_view_region: 4 X 2 corner coordinates, in the same order as the unit-square corners
    :return: H_ipmnorm2g: 3 x 3 transformation matrix
    """
    unit_square = np.float32([[0, 0], [1, 0], [0, 1], [1, 1]])
    region_corners = np.float32(top_view_region)
    return cv2.getPerspectiveTransform(unit_square, region_corners)


def homograpthy_g2im(cam_pitch, cam_height, K):
    """
        Build the 3 x 3 homography mapping flat-ground points (x, y, 1) to image coordinates.
    :param cam_pitch: camera pitch angle, fed to np.cos / np.sin (i.e. interpreted in radians)
    :param cam_height: camera height, used as the translation of the second camera axis
    :param K: camera intrinsic matrix, multiplied from the left
    :return: H_g2im
    """
    angle = np.pi / 2 + cam_pitch
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    # rotation from ground to camera coordinates about the x axis
    R_g2c = np.array([[1, 0, 0],
                      [0, cos_a, -sin_a],
                      [0, sin_a, cos_a]])
    # the z column is dropped (points lie on the ground) and replaced by the translation
    translation = [[0], [cam_height], [0]]
    extrinsic = np.concatenate([R_g2c[:, 0:2], translation], 1)
    return np.matmul(K, extrinsic)


def projection_g2im(cam_pitch, cam_height, K):
    """
        Build the 3 x 4 projection mapping 3D ground points (x, y, z, 1) to image coordinates.
    :param cam_pitch: camera pitch angle, fed to np.cos / np.sin (i.e. interpreted in radians)
    :param cam_height: camera height, used as the translation of the second camera axis
    :param K: camera intrinsic matrix, multiplied from the left
    :return: P_g2im
    """
    angle = np.pi / 2 + cam_pitch
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    # rotation about the x axis in the first three columns, translation in the last one
    P_g2c = np.array([[1, 0, 0, 0],
                      [0, cos_a, -sin_a, cam_height],
                      [0, sin_a, cos_a, 0]])
    return np.matmul(K, P_g2c)


def homography_crop_resize(org_img_size, crop_y, resize_img_size):
    """
        Compute the homography that maps pixels of the original image to the image obtained by
        cropping crop_y rows from the top and resizing the remainder.
    :param org_img_size: [org_h, org_w]
    :param crop_y: number of rows removed from the top before resizing
    :param resize_img_size: [resize_h, resize_w]
    :return: H_c: 3 x 3 matrix
    """
    org_h, org_w = org_img_size[0], org_img_size[1]
    resize_h, resize_w = resize_img_size[0], resize_img_size[1]

    scale_x = resize_w / org_w
    scale_y = resize_h / (org_h - crop_y)
    # scaling, with the vertical offset of the crop expressed in resized pixels
    return np.array([[scale_x, 0, 0],
                     [0, scale_y, -scale_y*crop_y],
                     [0, 0, 1]])


def homographic_transformation(Matrix, x, y):
    """
    Apply a homography to 2D points and return the de-homogenized result.
    Args:
            Matrix (multi dim - array): 3x3 homography matrix
            x (array): original x coordinates
            y (array): original y coordinates
    Returns:
            the transformed x coordinates and y coordinates, as two arrays
    """
    # stack the points as homogeneous column vectors (x, y, 1)
    homogeneous = np.vstack((x, y, np.ones((1, len(y)))))
    mapped = np.matmul(Matrix, homogeneous)

    scale = mapped[2, :]
    return mapped[0, :]/scale, mapped[1, :]/scale


def projective_transformation(Matrix, x, y, z):
    """
    Apply a projection to 3D points and return the de-homogenized 2D result.
    Args:
            Matrix (multi dim - array): 3x4 projection matrix
            x (array): original x coordinates
            y (array): original y coordinates
            z (array): original z coordinates
    Returns:
            the projected x coordinates and y coordinates, as two arrays
    """
    # stack the points as homogeneous column vectors (x, y, z, 1)
    homogeneous = np.vstack((x, y, z, np.ones((1, len(z)))))
    mapped = np.matmul(Matrix, homogeneous)

    scale = mapped[2, :]
    return mapped[0, :]/scale, mapped[1, :]/scale


def transform_lane_gflat2g(h_cam, X_gflat, Y_gflat, Z_g):
    """
        Given X, Y coordinates in flat ground space and Z coordinates in real 3D ground space,
        compute the X, Y coordinates in real 3D ground space.
    :param h_cam: camera height
    :param X_gflat: X coordinates in flat ground space
    :param Y_gflat: Y coordinates in flat ground space
    :param Z_g: Z coordinates in real 3D ground space
    :return: X_g, Y_g: X and Y coordinates in real 3D ground space
    """

    # each coordinate is pulled towards the origin proportionally to Z_g / h_cam
    x_shift = X_gflat * Z_g / h_cam
    y_shift = Y_gflat * Z_g / h_cam
    return X_gflat - x_shift, Y_gflat - y_shift


def transform_lane_g2gflat(h_cam, X_g, Y_g, Z_g):
    """
        Given X, Y, Z coordinates in real 3D ground space, compute the corresponding X, Y
        coordinates in flat ground space.
    :param h_cam: camera height
    :param X_g: X coordinates in real 3D ground space
    :param Y_g: Y coordinates in real 3D ground space
    :param Z_g: Z coordinates in real 3D ground space
    :return: X_gflat, Y_gflat: X and Y coordinates in flat ground space
    """

    # height of the camera above each point
    clearance = h_cam - Z_g
    return X_g * h_cam / clearance, Y_g * h_cam / clearance


def nms_1d(v):
    """
        1D non-maximum suppression: an entry is set to zero when either of its direct
        neighbors is strictly larger; entries equal to their neighbors are kept.
    :param v: a 1D numpy array
    :return: a new array with the suppressed entries set to 0, the input is not modified
    """
    v_out = v.copy()
    n = v.shape[0]
    if n < 2:
        return v_out

    # value comparisons on whole slices replace the former per-index loop, whose
    # `is not` identity tests on integers broke for arrays longer than 257 entries
    suppressed = np.zeros(n, dtype=bool)
    suppressed[1:] |= v[:-1] > v[1:]     # left neighbor is larger
    suppressed[:-1] |= v[1:] > v[:-1]    # right neighbor is larger
    v_out[suppressed] = 0.
    return v_out


def first_run(save_path):
    """
        Check the marker file <save_path>/first_run.txt.
        If the file does not exist it is created empty and '' is returned.
        If it exists, its content is returned; an empty file triggers a reminder message
        and '' is returned.
    :param save_path: directory holding the marker file
    :return: content of the marker file as a string, '' if it is missing or empty
    """
    txt_file = os.path.join(save_path, 'first_run.txt')
    if not os.path.exists(txt_file):
        # create the empty marker file
        with open(txt_file, 'w'):
            pass
        return ''

    with open(txt_file) as marker:
        saved_epoch = marker.read()
    # read() never returns None, so the empty content is what signals a stale marker
    if not saved_epoch:
        print('You forgot to delete [first run file]')
        return ''
    return saved_epoch


def mkdir_if_missing(directory):
    """
        Create a directory, including missing parents, unless it already exists.
    :param directory: path of the directory; an empty path (e.g. the dirname of a bare
                      file name) refers to the current directory and is a no-op
    :return:
    """
    if not directory:
        return
    # attempt the creation directly rather than testing for existence first,
    # so a concurrent creation between test and call cannot cause a failure
    try:
        os.makedirs(directory)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise


# trick from stackoverflow
def str2bool(argument):
    """
        Convert a command-line string into a boolean, case-insensitively.
    :param argument: one of 'yes', 'true', 't', 'y', '1' or 'no', 'false', 'f', 'n', '0'
    :return: the corresponding boolean
    :raises argparse.ArgumentTypeError: for any other string
    """
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')

    lowered = argument.lower()
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError('Wrong argument in argparse, should be a boolean')


class Logger(object):
    """
    File-like object duplicating everything written to it to the console stream
    (sys.stdout at construction time) and, when a path is given, to a log file.
    Source https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
    """
    def __init__(self, fpath=None):
        """
        :param fpath: path of the log file; it is truncated if it exists and its parent
                      directory is created when missing. None disables file logging.
        """
        self.console = sys.stdout
        self.file = None
        self.fpath = fpath
        if fpath is not None:
            log_dir = os.path.dirname(fpath)
            # a bare file name has no directory part, nothing to create then
            if log_dir:
                mkdir_if_missing(log_dir)
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        # return the logger so that `with Logger(path) as logger:` binds a usable object
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        """Write msg to the console and, if enabled, to the log file."""
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        """Flush the console and force the log file content to disk."""
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        """Close the log file. The console stream is not owned by the logger and stays open."""
        if self.file is not None:
            self.file.close()


class AverageMeter(object):
    """Tracks the latest value together with the running sum, count and weighted average."""
    def __init__(self):
        self.reset()

    def reset(self):
        """Set all statistics back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """
        Record a new value.
        :param val: the value to record
        :param n: the weight of the value, e.g. the number of samples it was averaged over
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def define_optim(optim, params, lr, weight_decay):
    """
        Create a torch optimizer by name.
    :param optim: 'adam', 'sgd' or 'rmsprop' (sgd and rmsprop use momentum 0.9)
    :param params: parameters to optimize
    :param lr: learning rate
    :param weight_decay: weight decay
    :return: the optimizer
    :raises KeyError: for any other optimizer name
    """
    if optim == 'adam':
        return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
    if optim == 'sgd':
        return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay)
    if optim == 'rmsprop':
        return torch.optim.RMSprop(params, lr=lr, momentum=0.9, weight_decay=weight_decay)
    raise KeyError("The requested optimizer: {} is not implemented".format(optim))


def define_scheduler(optimizer, args):
    """
        Create the learning rate scheduler selected by args.lr_policy.
    :param optimizer: the optimizer whose learning rate is scheduled
    :param args: settings object; reads lr_policy and, depending on it,
                 niter and niter_decay ('lambda'), lr_decay_iters and gamma ('step', 'plateau')
    :return: the scheduler, or None for the 'none' policy
    :raises NotImplementedError: for an unknown policy
    """
    if args.lr_policy == 'lambda':
        def lambda_rule(epoch):
            # factor 1 up to args.niter, then linear decay over the next niter_decay + 1 epochs
            lr_l = 1.0 - max(0, epoch + 1 - args.niter) / float(args.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif args.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=args.lr_decay_iters, gamma=args.gamma)
    elif args.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                   factor=args.gamma,
                                                   threshold=0.0001,
                                                   patience=args.lr_decay_iters)
    elif args.lr_policy == 'none':
        scheduler = None
    else:
        # raise (the exception used to be returned as if it were a scheduler)
        raise NotImplementedError(
            'learning rate policy [{}] is not implemented'.format(args.lr_policy))
    return scheduler


def define_init_weights(model, init_w='normal', activation='relu'):
    """
        Initialize all sub-modules of a model with the selected scheme via model.apply.
    :param model: the model to initialize
    :param init_w: 'normal', 'xavier', 'kaiming' or 'orthogonal'
    :param activation: not used by this function
    :return:
    :raises NotImplementedError: for any other initialization name
    """
    print('Init weights in network with [{}]'.format(init_w))
    # pick the per-module initializer first, apply it in a single place afterwards
    if init_w == 'normal':
        initializer = weights_init_normal
    elif init_w == 'xavier':
        initializer = weights_init_xavier
    elif init_w == 'kaiming':
        initializer = weights_init_kaiming
    elif init_w == 'orthogonal':
        initializer = weights_init_orthogonal
    else:
        raise NotImplementedError('initialization method [{}] is not implemented'.format(init_w))
    model.apply(initializer)


def weights_init_normal(m):
    """
        Per-module initializer, selected by the module's class name:
        Conv* / Linear weights ~ N(0, 0.02) with zero bias,
        BatchNorm2d weights ~ N(1, 0.02) with zero bias. Other modules are left untouched.
    :param m: the module to initialize
    :return:
    """
    layer_name = m.__class__.__name__
    # 'ConvTranspose' layers are matched by 'Conv' as well
    if 'Conv' in layer_name or 'Linear' in layer_name:
        init.normal_(m.weight.data, 0.0, 0.02)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'BatchNorm2d' in layer_name:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_xavier(m):
    """
        Per-module initializer, selected by the module's class name:
        Conv* / Linear weights use Xavier normal with gain 0.02 and zero bias,
        BatchNorm2d weights ~ N(1, 0.02) with zero bias. Other modules are left untouched.
    :param m: the module to initialize
    :return:
    """
    layer_name = m.__class__.__name__
    # 'ConvTranspose' layers are matched by 'Conv' as well
    if 'Conv' in layer_name or 'Linear' in layer_name:
        init.xavier_normal_(m.weight.data, gain=0.02)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'BatchNorm2d' in layer_name:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_kaiming(m):
    """
        Per-module initializer, selected by the module's class name:
        Conv* / Linear weights use Kaiming normal (fan_in, relu) with zero bias,
        BatchNorm2d weights ~ N(1, 0.02) with zero bias. Other modules are left untouched.
    :param m: the module to initialize
    :return:
    """
    layer_name = m.__class__.__name__
    # 'ConvTranspose' layers are matched by 'Conv' as well
    if 'Conv' in layer_name or 'Linear' in layer_name:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'BatchNorm2d' in layer_name:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_orthogonal(m):
    """
        Per-module initializer, selected by the module's class name:
        Conv* / Linear weights use orthogonal initialization with gain 1 and zero bias,
        BatchNorm2d weights ~ N(1, 0.02) with zero bias. Other modules are left untouched.
    :param m: the module to initialize
    :return:
    """
    classname = m.__class__.__name__
    # 'ConvTranspose' layers are matched by 'Conv' as well
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        # in-place variant, consistent with the other initializers
        # (the alias without trailing underscore is deprecated)
        init.orthogonal_(m.weight.data, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)

In [3]:
! pip install ortools

Collecting ortools
[?25l  Downloading https://files.pythonhosted.org/packages/6a/bd/75277072925d687aa35a6ea9e23e81a7f6b7c980b2a80949c5b9a3f98c79/ortools-9.0.9048-cp37-cp37m-manylinux1_x86_64.whl (14.4MB)
[K     |████████████████████████████████| 14.4MB 196kB/s 
[?25hCollecting protobuf>=3.15.8
[?25l  Downloading https://files.pythonhosted.org/packages/4c/53/ddcef00219f2a3c863b24288e24a20c3070bd086a1e77706f22994a7f6db/protobuf-3.17.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.0MB)
[K     |████████████████████████████████| 1.0MB 45.6MB/s 
Installing collected packages: protobuf, ortools
  Found existing installation: protobuf 3.12.4
    Uninstalling protobuf-3.12.4:
      Successfully uninstalled protobuf-3.12.4
Successfully installed ortools-9.0.9048 protobuf-3.17.3


In [4]:
"""
MinCostFlow solver adapted for matching two sets of contours. The implementation is based on google-ortools.
Author: Yuliang Guo (33yuliangguo@gmail.com)
Date: March, 2020
"""

from __future__ import print_function
import numpy as np
from ortools.graph import pywrapgraph
import time


def SolveMinCostFlow(adj_mat, cost_mat):
    """
        Solving an Assignment Problem with MinCostFlow.
        The flow network is: source (node 0) -> one node per row (1 .. cnt_1)
        -> one node per column (cnt_1 + 1 .. cnt_1 + cnt_2) -> sink (cnt_1 + cnt_2 + 1).
        Every arc has capacity 1, except row -> column arcs whose capacity is the adjacency value.
    :param adj_mat: adjacency matrix with binary values indicating possible matchings between two sets
    :param cost_mat: cost matrix recording the matching cost of every possible pair of items from two sets;
                     the costs are cast to integers before being handed to the solver
    :return: a list of [row index, column index, cost] for every matched pair,
             empty if the solver does not reach an optimal solution
    """

    # Instantiate a SimpleMinCostFlow solver.
    min_cost_flow = pywrapgraph.SimpleMinCostFlow()
    # Define the directed graph for the flow.

    cnt_1, cnt_2 = adj_mat.shape
    cnt_nonzero_row = int(np.sum(np.sum(adj_mat, axis=1) > 0))
    cnt_nonzero_col = int(np.sum(np.sum(adj_mat, axis=0) > 0))

    source = 0
    sink = cnt_1 + cnt_2 + 1
    row_nodes = list(range(1, cnt_1 + 1))
    col_nodes = list(range(cnt_1 + 1, cnt_1 + cnt_2 + 1))

    # prepare directed graph for the flow: source -> rows, rows -> columns (row-major), columns -> sink
    start_nodes = [source] * cnt_1 + \
                  [row for row in row_nodes for _ in range(cnt_2)] + \
                  col_nodes
    end_nodes = row_nodes + \
                col_nodes * cnt_1 + \
                [sink] * cnt_2
    # the builtin int replaces the np.int alias, which no longer exists in recent NumPy
    capacities = [1] * cnt_1 + adj_mat.flatten().astype(int).tolist() + [1] * cnt_2
    costs = [0] * cnt_1 + cost_mat.flatten().astype(int).tolist() + [0] * cnt_2
    # Define an array of supplies at each node: the amount of flow pushed from source to sink
    # is bounded by the number of rows / columns having at least one possible match
    total_flow = min(cnt_nonzero_row, cnt_nonzero_col)
    supplies = [total_flow] + [0] * (cnt_1 + cnt_2) + [-total_flow]

    # Add each arc.
    for i in range(len(start_nodes)):
        min_cost_flow.AddArcWithCapacityAndUnitCost(start_nodes[i], end_nodes[i],
                                                    capacities[i], costs[i])

    # Add node supplies.
    for i in range(len(supplies)):
        min_cost_flow.SetNodeSupply(i, supplies[i])

    match_results = []
    # Find the minimum cost flow between the source and the sink.
    if min_cost_flow.Solve() == min_cost_flow.OPTIMAL:
        for arc in range(min_cost_flow.NumArcs()):

            # Can ignore arcs leading out of source or into sink.
            if min_cost_flow.Tail(arc)!=source and min_cost_flow.Head(arc)!=sink:

                # Arcs in the solution have a flow value of 1. Their start and end nodes
                # give an assignment of a set A item (row) to a set B item (column).

                if min_cost_flow.Flow(arc) > 0:
                    # shift the node ids back to 0-based row / column indices
                    match_results.append([min_cost_flow.Tail(arc)-1,
                                          min_cost_flow.Head(arc)-cnt_1-1,
                                          min_cost_flow.UnitCost(arc)])
    else:
        print('There was an issue with the min cost flow input.')

    return match_results


def main():
    """Solve a small worker/task assignment problem with SimpleMinCostFlow.

    The demo graph has a source (node 0), four worker nodes (1-4), four task
    nodes (5-8) and a sink (node 9). Every arc has capacity 1; only the
    worker -> task arcs carry a non-zero unit cost. The source supplies 4
    units of flow and the sink absorbs them, so each unit of flow passing
    through a worker -> task arc represents one assignment.

    Prints the total cost and every assignment; returns nothing.
    """
    # Instantiate a SimpleMinCostFlow solver.
    min_cost_flow = pywrapgraph.SimpleMinCostFlow()

    # Define the directed graph for the flow:
    # source -> workers, workers -> tasks, tasks -> sink.
    start_nodes = [0, 0, 0, 0] + [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4] + [5, 6, 7, 8]
    end_nodes = [1, 2, 3, 4] + [5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8] + [9, 9, 9, 9]
    capacities = [1, 1, 1, 1] + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + [1, 1, 1, 1]
    costs = ([0, 0, 0, 0] + [90, 76, 75, 70, 35, 85, 55, 65, 125, 95, 90, 105, 45, 110, 95, 115] + [0, 0, 0, 0])
    # Supply at each node: +4 at the source, -4 at the sink, 0 elsewhere.
    supplies = [4, 0, 0, 0, 0, 0, 0, 0, 0, -4]
    source = 0
    sink = 9

    # Add each arc.
    for tail, head, capacity, unit_cost in zip(start_nodes, end_nodes, capacities, costs):
        min_cost_flow.AddArcWithCapacityAndUnitCost(tail, head, capacity, unit_cost)

    # Add node supplies.
    for node, supply in enumerate(supplies):
        min_cost_flow.SetNodeSupply(node, supply)

    # Find the minimum cost flow between the source (node 0) and the sink (node 9).
    if min_cost_flow.Solve() != min_cost_flow.OPTIMAL:
        print('There was an issue with the min cost flow input.')
        return

    print('Total cost = ', min_cost_flow.OptimalCost())
    print()
    for arc in range(min_cost_flow.NumArcs()):
        # Arcs leading out of the source or into the sink are not assignments.
        if min_cost_flow.Tail(arc) == source or min_cost_flow.Head(arc) == sink:
            continue

        # Arcs in the solution have a positive flow value. Their start and end
        # nodes give an assignment of worker to task.
        if min_cost_flow.Flow(arc) > 0:
            print('Worker %d assigned to task %d.  Cost = %d' % (
                min_cost_flow.Tail(arc),
                min_cost_flow.Head(arc),
                min_cost_flow.UnitCost(arc)))

'''
if __name__ == '__main__':
    start_time = time.clock()
    main()
    print()
    print("Time =", time.clock() - start_time, "seconds")
    '''

'\nif __name__ == \'__main__\':\n    start_time = time.clock()\n    main()\n    print()\n    print("Time =", time.clock() - start_time, "seconds")\n    '

In [5]:
! pip install ujson

Collecting ujson
[?25l  Downloading https://files.pythonhosted.org/packages/17/4e/50e8e4cf5f00b537095711c2c86ac4d7191aed2b4fffd5a19f06898f6929/ujson-4.0.2-cp37-cp37m-manylinux1_x86_64.whl (179kB)
[K     |█▉                              | 10kB 24.7MB/s eta 0:00:01[K     |███▋                            | 20kB 17.8MB/s eta 0:00:01[K     |█████▌                          | 30kB 15.0MB/s eta 0:00:01[K     |███████▎                        | 40kB 13.5MB/s eta 0:00:01[K     |█████████▏                      | 51kB 7.4MB/s eta 0:00:01[K     |███████████                     | 61kB 7.3MB/s eta 0:00:01[K     |████████████▉                   | 71kB 8.1MB/s eta 0:00:01[K     |██████████████▋                 | 81kB 8.8MB/s eta 0:00:01[K     |████████████████▌               | 92kB 9.1MB/s eta 0:00:01[K     |██████████████████▎             | 102kB 7.4MB/s eta 0:00:01[K     |████████████████████▏           | 112kB 7.4MB/s eta 0:00:01[K     |██████████████████████          | 122kB

In [6]:
"""
eval_3D_Lane.py
Description: This code evaluates 3D lane detection. The optimal matching between the ground-truth set and the
predicted set of lanes is sought by solving a min cost flow problem.
Evaluation metrics include:
    Average Precision (AP)
    Max F-scores
    x error close (0 - 40 m)
    x error far (40 - 100 m)
    z error close (0 - 40 m)
    z error far (40 - 100 m)
Reference: "Gen-LaneNet: Generalized and Scalable Approach for 3D Lane Detection". Y. Guo et al. 2020
Author: Yuliang Guo (33yuliangguo@gmail.com)
Date: March, 2020
"""

import numpy as np
import cv2
import os
import os.path as ops
import copy
import math
import ujson as json
from scipy.interpolate import interp1d
import matplotlib
#from tools.utils import *
#from tools.MinCostFlow import SolveMinCostFlow
from mpl_toolkits.mplot3d import Axes3D

matplotlib.use('Agg')
import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = (35, 30)
plt.rcParams.update({'font.size': 25})
plt.rcParams.update({'font.weight': 'semibold'})

# Palette of four colors. The per-entry labels match the triplets only when the
# channels are read in B, G, R order (e.g. [0, 0, 255] is labeled red).
color = [[0, 0, 255],  # red
         [0, 255, 0],  # green
         [255, 0, 255],  # purple
         [255, 255, 0]]  # cyan

# Lower / upper y-axis limits applied to the 3D plots saved by
# LaneEval.bench_one_submit when visualization is enabled.
vis_min_y = 5
vis_max_y = 80


class LaneEval(object):
    def __init__(self, args):
        """Store the evaluation settings read from ``args``.

        :param args: configuration object providing dataset_dir, K, no_centerline,
                     org_h, org_w, crop_y, resize_h, resize_w and top_view_region
                     (indexed as a 2D array).
        """
        # Fixed matching thresholds used by bench / bench_PR:
        #   close_range - y value that splits the "close" and "far" error statistics
        #   ratio_th    - minimum fraction of matched samples for a lane to count as matched
        #   dist_th     - per-sample distance below which a sample counts as matched
        self.close_range = 40
        self.ratio_th = 0.75
        self.dist_th = 1.5

        self.dataset_dir = args.dataset_dir
        self.K = args.K
        self.no_centerline = args.no_centerline
        self.resize_h = args.resize_h
        self.resize_w = args.resize_w

        # homography from the original image size to the cropped and resized image
        original_size = [args.org_h, args.org_w]
        resized_size = [args.resize_h, args.resize_w]
        self.H_crop = homography_crop_resize(original_size, args.crop_y, resized_size)

        # x / y extent of the evaluated region, read from the top-view region corners
        region = args.top_view_region
        self.x_min = region[0, 0]
        self.x_max = region[1, 0]
        self.y_min = region[2, 1]
        self.y_max = region[0, 1]

        # 100 evenly spaced y positions (y_max excluded) at which all lanes are resampled
        self.y_samples = np.linspace(self.y_min, self.y_max, num=100, endpoint=False)

    def bench(self, pred_lanes, gt_lanes, gt_visibility, raw_file, gt_cam_height, gt_cam_pitch, vis, ax1, ax2):
        """
            Match predicted lanes to ground-truth lanes for one sample and collect the errors of the matches.

            Both lane sets are resampled at self.y_samples. For every (gt, pred) pair the per-sample
            distance is computed from the x and z differences; samples where either lane is not visible
            are charged self.dist_th. The summed distances form the integer cost matrix handed to
            SolveMinCostFlow, which returns the one-to-one matching.
            The input of prediction and ground-truth lanes are in ground coordinate, x-right, y-forward, z-up
            The fundamental assumption is: 1. there are no two points from different lanes with identical x, y
                                              but different z's
                                           2. there are no two points from a single lane having identical x, y
                                              but different z's
            If the interest area is within the current drivable road, the above assumptions are almost always valid.

            NOTE: the entries of pred_lanes are overwritten in place with their resampled (x, z) arrays.
        :param pred_lanes: N X 2 or N X 3 lists depending on 2D or 3D
        :param gt_lanes: N X 2 or N X 3 lists depending on 2D or 3D
        :param gt_visibility: per-lane visibility flags, indexed in parallel with gt_lanes
        :param raw_file: file path rooted in dataset folder
        :param gt_cam_height: camera height given in ground-truth data
        :param gt_cam_pitch: camera pitch given in ground-truth data
        :param vis: when truthy, draw the lanes on the image (ax1) and in the 3D plot (ax2)
        :param ax1: matplotlib axes receiving the image overlay; only used when vis is truthy
        :param ax2: matplotlib 3D axes receiving the lane curves; only used when vis is truthy
        :return: r_lane, p_lane, cnt_gt, cnt_pred, x_error_close, x_error_far, z_error_close, z_error_far
                 where r_lane / p_lane count the gt / predicted lanes considered matched, cnt_gt / cnt_pred
                 are the numbers of lanes compared, and the four lists hold one average error per accepted match
        """

        num_samples = self.y_samples.shape[0]
        # index of the first y sample beyond the close range; samples before it are "close", the rest "far".
        # Raises IndexError when no sample exceeds self.close_range.
        close_range_idx = np.where(self.y_samples > self.close_range)[0][0]

        r_lane, p_lane = 0., 0.
        x_error_close = []
        x_error_far = []
        z_error_close = []
        z_error_far = []

        # only keep the visible portion
        gt_lanes = [prune_3d_lane_by_visibility(np.array(gt_lane), np.array(gt_visibility[k])) for k, gt_lane in
                    enumerate(gt_lanes)]
        gt_lanes = [lane for lane in gt_lanes if lane.shape[0] > 1]
        # only consider those gt lanes overlapping with sampling range
        gt_lanes = [lane for lane in gt_lanes if lane[0, 1] < self.y_samples[-1] and lane[-1, 1] > self.y_samples[0]]
        gt_lanes = [prune_3d_lane_by_range(np.array(gt_lane), 3 * self.x_min, 3 * self.x_max) for gt_lane in gt_lanes]
        gt_lanes = [lane for lane in gt_lanes if lane.shape[0] > 1]
        cnt_gt = len(gt_lanes)
        cnt_pred = len(pred_lanes)

        gt_visibility_mat = np.zeros((cnt_gt, num_samples))
        pred_visibility_mat = np.zeros((cnt_pred, num_samples))
        # resample gt and pred at y_samples
        for i in range(cnt_gt):
            min_y = np.min(np.array(gt_lanes[i])[:, 1])
            max_y = np.max(np.array(gt_lanes[i])[:, 1])
            x_values, z_values, visibility_vec = resample_laneline_in_y(np.array(gt_lanes[i]), self.y_samples,
                                                                        out_vis=True)
            gt_lanes[i] = np.vstack([x_values, z_values]).T
            # a sample is visible only inside the x range and inside the lane's own y extent
            gt_visibility_mat[i, :] = np.logical_and(x_values >= self.x_min,
                                                     np.logical_and(x_values <= self.x_max,
                                                                    np.logical_and(self.y_samples >= min_y,
                                                                                   self.y_samples <= max_y)))
            gt_visibility_mat[i, :] = np.logical_and(gt_visibility_mat[i, :], visibility_vec)

        for i in range(cnt_pred):
            min_y = np.min(np.array(pred_lanes[i])[:, 1])
            max_y = np.max(np.array(pred_lanes[i])[:, 1])
            x_values, z_values, visibility_vec = resample_laneline_in_y(np.array(pred_lanes[i]), self.y_samples,
                                                                        out_vis=True)
            pred_lanes[i] = np.vstack([x_values, z_values]).T
            pred_visibility_mat[i, :] = np.logical_and(x_values >= self.x_min,
                                                       np.logical_and(x_values <= self.x_max,
                                                                      np.logical_and(self.y_samples >= min_y,
                                                                                     self.y_samples <= max_y)))
            pred_visibility_mat[i, :] = np.logical_and(pred_visibility_mat[i, :], visibility_vec)

        # NOTE: the builtin int / float replace the np.int / np.float aliases removed in NumPy 1.24
        adj_mat = np.zeros((cnt_gt, cnt_pred), dtype=int)
        cost_mat = np.zeros((cnt_gt, cnt_pred), dtype=int)
        cost_mat.fill(1000)
        num_match_mat = np.zeros((cnt_gt, cnt_pred), dtype=float)
        x_dist_mat_close = np.zeros((cnt_gt, cnt_pred), dtype=float)
        x_dist_mat_close.fill(1000.)
        x_dist_mat_far = np.zeros((cnt_gt, cnt_pred), dtype=float)
        x_dist_mat_far.fill(1000.)
        z_dist_mat_close = np.zeros((cnt_gt, cnt_pred), dtype=float)
        z_dist_mat_close.fill(1000.)
        z_dist_mat_far = np.zeros((cnt_gt, cnt_pred), dtype=float)
        z_dist_mat_far.fill(1000.)
        # compute curve to curve distance
        for i in range(cnt_gt):
            for j in range(cnt_pred):
                x_dist = np.abs(gt_lanes[i][:, 0] - pred_lanes[j][:, 0])
                z_dist = np.abs(gt_lanes[i][:, 1] - pred_lanes[j][:, 1])
                euclidean_dist = np.sqrt(x_dist ** 2 + z_dist ** 2)

                # apply visibility to penalize different partial matching accordingly
                euclidean_dist[
                    np.logical_or(gt_visibility_mat[i, :] < 0.5, pred_visibility_mat[j, :] < 0.5)] = self.dist_th

                # every pair stays a candidate (no pruning) to encourage finding the perfect match
                num_match_mat[i, j] = np.sum(euclidean_dist < self.dist_th)
                adj_mat[i, j] = 1
                # ATTENTION: use the sum as int type to meet the requirements of min cost flow optimization (int type)
                cost_mat[i, j] = np.sum(euclidean_dist).astype(int)

                # use the both visible portion to calculate distance error
                both_visible_indices = np.logical_and(gt_visibility_mat[i, :] > 0.5, pred_visibility_mat[j, :] > 0.5)
                if np.sum(both_visible_indices[:close_range_idx]) > 0:
                    x_dist_mat_close[i, j] = np.sum(
                        x_dist[:close_range_idx] * both_visible_indices[:close_range_idx]) / np.sum(
                        both_visible_indices[:close_range_idx])
                    z_dist_mat_close[i, j] = np.sum(
                        z_dist[:close_range_idx] * both_visible_indices[:close_range_idx]) / np.sum(
                        both_visible_indices[:close_range_idx])
                else:
                    # no commonly visible close sample: fall back to the threshold value
                    x_dist_mat_close[i, j] = self.dist_th
                    z_dist_mat_close[i, j] = self.dist_th

                if np.sum(both_visible_indices[close_range_idx:]) > 0:
                    x_dist_mat_far[i, j] = np.sum(
                        x_dist[close_range_idx:] * both_visible_indices[close_range_idx:]) / np.sum(
                        both_visible_indices[close_range_idx:])
                    z_dist_mat_far[i, j] = np.sum(
                        z_dist[close_range_idx:] * both_visible_indices[close_range_idx:]) / np.sum(
                        both_visible_indices[close_range_idx:])
                else:
                    # no commonly visible far sample: fall back to the threshold value
                    x_dist_mat_far[i, j] = self.dist_th
                    z_dist_mat_far[i, j] = self.dist_th

        # solve bipartite matching via min cost flow solver; each result row is [gt index, pred index, cost]
        match_results = SolveMinCostFlow(adj_mat, cost_mat)
        match_results = np.array(match_results)

        # only a match with avg cost < self.dist_th is considered a valid one
        match_gt_ids = []
        match_pred_ids = []
        if match_results.shape[0] > 0:
            for gt_i, pred_i, match_cost in match_results:
                if match_cost < self.dist_th * num_samples:
                    # consider match when the matched points is above a ratio
                    if num_match_mat[gt_i, pred_i] / np.sum(gt_visibility_mat[gt_i, :]) >= self.ratio_th:
                        r_lane += 1
                        match_gt_ids.append(gt_i)
                    if num_match_mat[gt_i, pred_i] / np.sum(pred_visibility_mat[pred_i, :]) >= self.ratio_th:
                        p_lane += 1
                        match_pred_ids.append(pred_i)
                    x_error_close.append(x_dist_mat_close[gt_i, pred_i])
                    x_error_far.append(x_dist_mat_far[gt_i, pred_i])
                    z_error_close.append(z_dist_mat_close[gt_i, pred_i])
                    z_error_far.append(z_dist_mat_far[gt_i, pred_i])

        # visualize lanelines and matching results both in image and 3D
        if vis:
            P_g2im = projection_g2im(gt_cam_pitch, gt_cam_height, self.K)
            P_gt = np.matmul(self.H_crop, P_g2im)
            img = cv2.imread(ops.join(self.dataset_dir, raw_file))
            img = cv2.warpPerspective(img, self.H_crop, (self.resize_w, self.resize_h))
            img = img.astype(float) / 255

            for i in range(cnt_gt):
                x_values = gt_lanes[i][:, 0]
                z_values = gt_lanes[i][:, 1]
                x_2d, y_2d = projective_transformation(P_gt, x_values, self.y_samples, z_values)
                x_2d = x_2d.astype(int)
                y_2d = y_2d.astype(int)

                # matched and unmatched gt lanes get different colors
                if i in match_gt_ids:
                    line_color = [0, 0, 1]
                else:
                    line_color = [0, 1, 1]
                for k in range(1, x_2d.shape[0]):
                    # only draw the visible portion
                    if gt_visibility_mat[i, k - 1] and gt_visibility_mat[i, k]:
                        # the color triplet is reversed for the image drawing call
                        img = cv2.line(img, (x_2d[k - 1], y_2d[k - 1]), (x_2d[k], y_2d[k]), line_color[-1::-1], 3)
                ax2.plot(x_values[np.where(gt_visibility_mat[i, :])],
                         self.y_samples[np.where(gt_visibility_mat[i, :])],
                         z_values[np.where(gt_visibility_mat[i, :])], color=line_color, linewidth=5)

            for i in range(cnt_pred):
                x_values = pred_lanes[i][:, 0]
                z_values = pred_lanes[i][:, 1]
                x_2d, y_2d = projective_transformation(P_gt, x_values, self.y_samples, z_values)
                x_2d = x_2d.astype(int)
                y_2d = y_2d.astype(int)

                # matched and unmatched predicted lanes get different colors
                if i in match_pred_ids:
                    line_color = [1, 0, 0]
                else:
                    line_color = [1, 0, 1]
                for k in range(1, x_2d.shape[0]):
                    # only draw the visible portion
                    if pred_visibility_mat[i, k - 1] and pred_visibility_mat[i, k]:
                        img = cv2.line(img, (x_2d[k - 1], y_2d[k - 1]), (x_2d[k], y_2d[k]), line_color[-1::-1], 2)
                ax2.plot(x_values[np.where(pred_visibility_mat[i, :])],
                         self.y_samples[np.where(pred_visibility_mat[i, :])],
                         z_values[np.where(pred_visibility_mat[i, :])], color=line_color, linewidth=5)

            cv2.putText(img, 'Recall: {:.3f}'.format(r_lane / (cnt_gt + 1e-6)),
                        (5, 30), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.7, color=(0, 0, 1), thickness=2)
            cv2.putText(img, 'Precision: {:.3f}'.format(p_lane / (cnt_pred + 1e-6)),
                        (5, 60), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.7, color=(0, 0, 1), thickness=2)
            # swap the first and last channel before handing the image to matplotlib
            ax1.imshow(img[:, :, [2, 1, 0]])

        return r_lane, p_lane, cnt_gt, cnt_pred, x_error_close, x_error_far, z_error_close, z_error_far

    # compare predicted set and ground-truth set using a fixed lane probability threshold
    def bench_one_submit(self, pred_file, gt_file, prob_th=0.5, vis=False):
        if vis:
            save_path = pred_file[:pred_file.rfind('/')]
            save_path += '/vis'
            if vis and not os.path.exists(save_path):
                try:
                    os.makedirs(save_path)
                except OSError as e:
                    print(e.message)
        # try:
        pred_lines = open(pred_file).readlines()
        json_pred = [json.loads(line) for line in pred_lines]
        # except BaseException as e:
        #     raise Exception('Fail to load json file of the prediction.')
        json_gt = [json.loads(line) for line in open(gt_file).readlines()]
        if len(json_gt) != len(json_pred):
            raise Exception('We do not get the predictions of all the test tasks')
        gts = {l['raw_file']: l for l in json_gt}

        laneline_stats = []
        laneline_x_error_close = []
        laneline_x_error_far = []
        laneline_z_error_close = []
        laneline_z_error_far = []
        centerline_stats = []
        centerline_x_error_close = []
        centerline_x_error_far = []
        centerline_z_error_close = []
        centerline_z_error_far = []
        for i, pred in enumerate(json_pred):
            if 'raw_file' not in pred or 'laneLines' not in pred:
                raise Exception('raw_file or lanelines not in some predictions.')
            raw_file = pred['raw_file']

            # if raw_file != 'images/05/0000347.jpg':
            #     continue
            pred_lanelines = pred['laneLines']
            pred_laneLines_prob = pred['laneLines_prob']
            pred_lanelines = [pred_lanelines[ii] for ii in range(len(pred_laneLines_prob)) if
                              pred_laneLines_prob[ii] > prob_th]

            if raw_file not in gts:
                raise Exception('Some raw_file from your predictions do not exist in the test tasks.')
            gt = gts[raw_file]
            gt_cam_height = gt['cam_height']
            gt_cam_pitch = gt['cam_pitch']

            if vis:
                fig = plt.figure()
                ax1 = fig.add_subplot(221)
                ax2 = fig.add_subplot(222, projection='3d')
                ax3 = fig.add_subplot(223)
                ax4 = fig.add_subplot(224, projection='3d')
            else:
                ax1 = 0
                ax2 = 0
                ax3 = 0
                ax4 = 0

            # evaluate lanelines
            gt_lanelines = gt['laneLines']
            gt_visibility = gt['laneLines_visibility']
            # N to N matching of lanelines
            r_lane, p_lane, cnt_gt, cnt_pred, \
            x_error_close, x_error_far, \
            z_error_close, z_error_far = self.bench(pred_lanelines,
                                                    gt_lanelines,
                                                    gt_visibility,
                                                    raw_file,
                                                    gt_cam_height,
                                                    gt_cam_pitch,
                                                    vis, ax1, ax2)
            laneline_stats.append(np.array([r_lane, p_lane, cnt_gt, cnt_pred]))
            # consider x_error z_error only for the matched lanes
            # if r_lane > 0 and p_lane > 0:
            laneline_x_error_close.extend(x_error_close)
            laneline_x_error_far.extend(x_error_far)
            laneline_z_error_close.extend(z_error_close)
            laneline_z_error_far.extend(z_error_far)

            # evaluate centerlines
            if not self.no_centerline:
                pred_centerlines = pred['centerLines']
                pred_centerlines_prob = pred['centerLines_prob']
                pred_centerlines = [pred_centerlines[ii] for ii in range(len(pred_centerlines_prob)) if
                                    pred_centerlines_prob[ii] > prob_th]

                gt_centerlines = gt['centerLines']
                gt_visibility = gt['centerLines_visibility']

                # N to N matching of lanelines
                r_lane, p_lane, cnt_gt, cnt_pred, \
                x_error_close, x_error_far, \
                z_error_close, z_error_far = self.bench(pred_centerlines,
                                                        gt_centerlines,
                                                        gt_visibility,
                                                        raw_file,
                                                        gt_cam_height,
                                                        gt_cam_pitch,
                                                        vis, ax3, ax4)
                centerline_stats.append(np.array([r_lane, p_lane, cnt_gt, cnt_pred]))
                # consider x_error z_error only for the matched lanes
                # if r_lane > 0 and p_lane > 0:
                centerline_x_error_close.extend(x_error_close)
                centerline_x_error_far.extend(x_error_far)
                centerline_z_error_close.extend(z_error_close)
                centerline_z_error_far.extend(z_error_far)

            if vis:
                ax1.set_xticks([])
                ax1.set_yticks([])
                # ax2.set_xlabel('x axis')
                # ax2.set_ylabel('y axis')
                # ax2.set_zlabel('z axis')
                bottom, top = ax2.get_zlim()
                left, right = ax2.get_xlim()
                ax2.set_zlim(min(bottom, -0.1), max(top, 0.1))
                ax2.set_xlim(left, right)
                ax2.set_ylim(vis_min_y, vis_max_y)
                ax2.locator_params(nbins=5, axis='x')
                ax2.locator_params(nbins=5, axis='z')
                ax2.tick_params(pad=18)

                ax3.set_xticks([])
                ax3.set_yticks([])
                # ax4.set_xlabel('x axis')
                # ax4.set_ylabel('y axis')
                # ax4.set_zlabel('z axis')
                bottom, top = ax4.get_zlim()
                left, right = ax4.get_xlim()
                ax4.set_zlim(min(bottom, -0.1), max(top, 0.1))
                ax4.set_xlim(left, right)
                ax4.set_ylim(vis_min_y, vis_max_y)
                ax4.locator_params(nbins=5, axis='x')
                ax4.locator_params(nbins=5, axis='z')
                ax4.tick_params(pad=18)

                fig.subplots_adjust(wspace=0, hspace=0.01)
                fig.savefig(ops.join(save_path, raw_file.replace("/", "_")))
                plt.close(fig)
                print('processed sample: {}  {}'.format(i, raw_file))

        output_stats = []
        laneline_stats = np.array(laneline_stats)
        laneline_x_error_close = np.array(laneline_x_error_close)
        laneline_x_error_far = np.array(laneline_x_error_far)
        laneline_z_error_close = np.array(laneline_z_error_close)
        laneline_z_error_far = np.array(laneline_z_error_far)

        R_lane = np.sum(laneline_stats[:, 0]) / (np.sum(laneline_stats[:, 2]) + 1e-6)
        P_lane = np.sum(laneline_stats[:, 1]) / (np.sum(laneline_stats[:, 3]) + 1e-6)
        F_lane = 2 * R_lane * P_lane / (R_lane + P_lane + 1e-6)
        x_error_close_avg = np.average(laneline_x_error_close)
        x_error_far_avg = np.average(laneline_x_error_far)
        z_error_close_avg = np.average(laneline_z_error_close)
        z_error_far_avg = np.average(laneline_z_error_far)

        output_stats.append(F_lane)
        output_stats.append(R_lane)
        output_stats.append(P_lane)
        output_stats.append(x_error_close_avg)
        output_stats.append(x_error_far_avg)
        output_stats.append(z_error_close_avg)
        output_stats.append(z_error_far_avg)

        if not self.no_centerline:
            centerline_stats = np.array(centerline_stats)
            centerline_x_error_close = np.array(centerline_x_error_close)
            centerline_x_error_far = np.array(centerline_x_error_far)
            centerline_z_error_close = np.array(centerline_z_error_close)
            centerline_z_error_far = np.array(centerline_z_error_far)

            R_lane = np.sum(centerline_stats[:, 0]) / (np.sum(centerline_stats[:, 2]) + 1e-6)
            P_lane = np.sum(centerline_stats[:, 1]) / (np.sum(centerline_stats[:, 3]) + 1e-6)
            F_lane = 2 * R_lane * P_lane / (R_lane + P_lane + 1e-6)
            x_error_close_avg = np.average(centerline_x_error_close)
            x_error_far_avg = np.average(centerline_x_error_far)
            z_error_close_avg = np.average(centerline_z_error_close)
            z_error_far_avg = np.average(centerline_z_error_far)

            output_stats.append(F_lane)
            output_stats.append(R_lane)
            output_stats.append(P_lane)
            output_stats.append(x_error_close_avg)
            output_stats.append(x_error_far_avg)
            output_stats.append(z_error_close_avg)
            output_stats.append(z_error_far_avg)

        return output_stats

    def bench_PR(self, pred_lanes, gt_lanes, gt_visibility):
        """
            Match predicted lanes to ground-truth lanes for one sample and count the matched lanes.

            Both lane sets are resampled at self.y_samples. For every (gt, pred) pair the per-sample
            distance is computed from the x and z differences; samples where either lane is not visible
            are charged self.dist_th. The summed distances form the integer cost matrix handed to
            SolveMinCostFlow, which returns the one-to-one matching.
            The input of prediction and ground-truth lanes are in ground coordinate, x-right, y-forward, z-up
            The fundamental assumption is: 1. there are no two points from different lanes with identical x, y
                                              but different z's
                                           2. there are no two points from a single lane having identical x, y
                                              but different z's
            If the interest area is within the current drivable road, the above assumptions are almost always valid.

            NOTE: the entries of pred_lanes are overwritten in place with their resampled (x, z) arrays.
        :param pred_lanes: N X 2 or N X 3 lists depending on 2D or 3D
        :param gt_lanes: N X 2 or N X 3 lists depending on 2D or 3D
        :param gt_visibility: per-lane visibility flags, indexed in parallel with gt_lanes
        :return: r_lane, p_lane, cnt_gt, cnt_pred where r_lane / p_lane count the gt / predicted lanes
                 considered matched and cnt_gt / cnt_pred are the numbers of lanes compared
        """

        num_samples = self.y_samples.shape[0]
        r_lane, p_lane = 0., 0.

        # only keep the visible portion
        gt_lanes = [prune_3d_lane_by_visibility(np.array(gt_lane), np.array(gt_visibility[k])) for k, gt_lane in
                    enumerate(gt_lanes)]
        gt_lanes = [lane for lane in gt_lanes if lane.shape[0] > 1]
        # only consider those gt lanes overlapping with sampling range
        gt_lanes = [lane for lane in gt_lanes if lane[0, 1] < self.y_samples[-1] and lane[-1, 1] > self.y_samples[0]]
        gt_lanes = [prune_3d_lane_by_range(np.array(gt_lane), 3 * self.x_min, 3 * self.x_max) for gt_lane in gt_lanes]
        gt_lanes = [lane for lane in gt_lanes if lane.shape[0] > 1]
        cnt_gt = len(gt_lanes)
        cnt_pred = len(pred_lanes)

        gt_visibility_mat = np.zeros((cnt_gt, num_samples))
        pred_visibility_mat = np.zeros((cnt_pred, num_samples))
        # resample gt and pred at y_samples
        for i in range(cnt_gt):
            min_y = np.min(np.array(gt_lanes[i])[:, 1])
            max_y = np.max(np.array(gt_lanes[i])[:, 1])
            x_values, z_values, visibility_vec = resample_laneline_in_y(np.array(gt_lanes[i]), self.y_samples,
                                                                        out_vis=True)
            gt_lanes[i] = np.vstack([x_values, z_values]).T
            # a sample is visible only inside the x range and inside the lane's own y extent
            gt_visibility_mat[i, :] = np.logical_and(x_values >= self.x_min,
                                                     np.logical_and(x_values <= self.x_max,
                                                                    np.logical_and(self.y_samples >= min_y,
                                                                                   self.y_samples <= max_y)))
            gt_visibility_mat[i, :] = np.logical_and(gt_visibility_mat[i, :], visibility_vec)

        for i in range(cnt_pred):
            min_y = np.min(np.array(pred_lanes[i])[:, 1])
            max_y = np.max(np.array(pred_lanes[i])[:, 1])
            x_values, z_values, visibility_vec = resample_laneline_in_y(np.array(pred_lanes[i]), self.y_samples,
                                                                        out_vis=True)
            pred_lanes[i] = np.vstack([x_values, z_values]).T
            pred_visibility_mat[i, :] = np.logical_and(x_values >= self.x_min,
                                                       np.logical_and(x_values <= self.x_max,
                                                                      np.logical_and(self.y_samples >= min_y,
                                                                                     self.y_samples <= max_y)))
            pred_visibility_mat[i, :] = np.logical_and(pred_visibility_mat[i, :], visibility_vec)

        # NOTE: the builtin int / float replace the np.int / np.float aliases removed in NumPy 1.24
        adj_mat = np.zeros((cnt_gt, cnt_pred), dtype=int)
        cost_mat = np.zeros((cnt_gt, cnt_pred), dtype=int)
        cost_mat.fill(1000)
        num_match_mat = np.zeros((cnt_gt, cnt_pred), dtype=float)
        # compute curve to curve distance
        for i in range(cnt_gt):
            for j in range(cnt_pred):
                x_dist = np.abs(gt_lanes[i][:, 0] - pred_lanes[j][:, 0])
                z_dist = np.abs(gt_lanes[i][:, 1] - pred_lanes[j][:, 1])
                euclidean_dist = np.sqrt(x_dist ** 2 + z_dist ** 2)

                # apply visibility to penalize different partial matching accordingly
                euclidean_dist[
                    np.logical_or(gt_visibility_mat[i, :] < 0.5, pred_visibility_mat[j, :] < 0.5)] = self.dist_th

                # every pair stays a candidate (no pruning) to encourage finding the perfect match
                num_match_mat[i, j] = np.sum(euclidean_dist < self.dist_th)
                adj_mat[i, j] = 1
                # ATTENTION: use the sum as int type to meet the requirements of min cost flow optimization (int type)
                cost_mat[i, j] = np.sum(euclidean_dist).astype(int)

        # solve bipartite matching via min cost flow solver; each result row is [gt index, pred index, cost]
        match_results = SolveMinCostFlow(adj_mat, cost_mat)
        match_results = np.array(match_results)

        # only a match with avg cost < self.dist_th is considered a valid one
        if match_results.shape[0] > 0:
            for gt_i, pred_i, match_cost in match_results:
                if match_cost < self.dist_th * num_samples:
                    # consider match when the matched points is above a ratio
                    if num_match_mat[gt_i, pred_i] / np.sum(gt_visibility_mat[gt_i, :]) >= self.ratio_th:
                        r_lane += 1
                    if num_match_mat[gt_i, pred_i] / np.sum(pred_visibility_mat[pred_i, :]) >= self.ratio_th:
                        p_lane += 1

        return r_lane, p_lane, cnt_gt, cnt_pred

    # evaluate two dataset at varying lane probability threshold to calculate AP
    def bench_one_submit_varying_probs(self, pred_file, gt_file, eval_out_file=None, eval_fig_file=None):
        """
        Sweep the lane probability threshold and compute precision/recall statistics and AP.

        Both files are read as JSON lines (one sample per line) and paired through their
        'raw_file' entry. For each of 19 thresholds in [0.05, 0.95], the predicted lanes whose
        probability exceeds the threshold are passed to self.bench_PR, and the returned counts
        are accumulated over all samples.

        :param pred_file: path of the prediction file (JSON lines)
        :param gt_file: path of the ground-truth file (JSON lines)
        :param eval_out_file: optional path; when given, the result dict is written there as one JSON line
        :param eval_fig_file: optional path; when given, the PR curves are saved there as a figure
        :return: dict with 'laneline_R', 'laneline_P', 'laneline_F_max', 'laneline_max_i',
                 'laneline_AP' and 'max_F_prob_th'. The matching 'centerline_*' entries are added
                 unless self.no_centerline is set. The centerline F-measure is reported at the
                 threshold that maximizes the laneline F-measure.
        :raises Exception: if the two files hold a different number of samples, a prediction lacks
                           'raw_file' or 'laneLines', or a predicted 'raw_file' is not in the ground truth
        """
        varying_th = np.linspace(0.05, 0.95, 19)
        r_range = np.linspace(0.05, 0.95, 19)
        eval_centerline = not self.no_centerline

        # context managers close both files even when a line fails to parse
        with open(pred_file) as pred_fp:
            json_pred = [json.loads(line) for line in pred_fp]
        with open(gt_file) as gt_fp:
            json_gt = [json.loads(line) for line in gt_fp]
        if len(json_gt) != len(json_pred):
            raise Exception('We do not get the predictions of all the test tasks')
        gts = {sample['raw_file']: sample for sample in json_gt}

        def sweep_thresholds(pred_lanes, pred_probs, gt_lanes, gt_visibility):
            """Call bench_PR once per threshold; returns (r, p, gt count, pred count) lists."""
            counts = ([], [], [], [])
            for prob_th in varying_th:
                kept = [lane for lane, prob in zip(pred_lanes, pred_probs) if prob > prob_th]
                # bench_PR receives a deep copy, so the parsed predictions are never modified
                result = self.bench_PR(copy.deepcopy(kept), gt_lanes, gt_visibility)
                for vec, value in zip(counts, result):
                    vec.append(value)
            return counts

        def precision_recall(counts):
            """Turn per-sample count rows into per-threshold F-measure, recall and precision."""
            r_all, p_all, gt_cnt_all, pred_cnt_all = (np.array(c) for c in counts)
            recall = np.sum(r_all, axis=0) / (np.sum(gt_cnt_all, axis=0) + 1e-6)
            precision = np.sum(p_all, axis=0) / (np.sum(pred_cnt_all, axis=0) + 1e-6)
            f_measure = 2 * recall * precision / (recall + precision + 1e-6)
            return f_measure, recall, precision

        def pad_and_average(recall, precision):
            """Add the (1, 0) and (0, 1) end points and average precision over r_range."""
            recall = np.array([1.] + recall.tolist() + [0.])
            precision = np.array([0.] + precision.tolist() + [1.])
            return recall, precision, np.mean(interp1d(recall, precision)(r_range))

        # each tuple holds four lists (r, p, gt count, pred count) with one row per sample
        laneline_counts = ([], [], [], [])
        centerline_counts = ([], [], [], [])
        for i, pred in enumerate(json_pred):
            print('Evaluating sample {} / {}'.format(i, len(json_pred)))
            if 'raw_file' not in pred or 'laneLines' not in pred:
                raise Exception('raw_file or lanelines not in some predictions.')
            raw_file = pred['raw_file']
            if raw_file not in gts:
                raise Exception('Some raw_file from your predictions do not exist in the test tasks.')
            gt = gts[raw_file]

            # evaluate lanelines
            sample_counts = sweep_thresholds(pred['laneLines'], pred['laneLines_prob'],
                                             gt['laneLines'], gt['laneLines_visibility'])
            for acc, vec in zip(laneline_counts, sample_counts):
                acc.append(vec)

            # evaluate centerlines
            if eval_centerline:
                sample_counts = sweep_thresholds(pred['centerLines'], pred['centerLines_prob'],
                                                 gt['centerLines'], gt['centerLines_visibility'])
                for acc, vec in zip(centerline_counts, sample_counts):
                    acc.append(vec)

        # laneline metrics
        laneline_F, laneline_R, laneline_P = precision_recall(laneline_counts)
        laneline_F_max = np.max(laneline_F)
        laneline_max_i = np.argmax(laneline_F)
        laneline_R, laneline_P, laneline_AP = pad_and_average(laneline_R, laneline_P)

        # centerline metrics are only available when centerlines were evaluated
        if eval_centerline:
            centerline_F, centerline_R, centerline_P = precision_recall(centerline_counts)
            centerline_F_max = centerline_F[laneline_max_i]
            centerline_max_i = laneline_max_i
            centerline_R, centerline_P, centerline_AP = pad_and_average(centerline_R, centerline_P)

        if eval_fig_file is not None:
            # plot PR curve(s)
            curves = [('Lane Line', laneline_R, laneline_P, laneline_F_max)]
            if eval_centerline:
                curves.append(('Center Line', centerline_R, centerline_P, centerline_F_max))

            fig = plt.figure()
            for k, (title, recall, precision, f_max) in enumerate(curves):
                ax = fig.add_subplot(1, 2, k + 1)
                ax.plot(recall, precision, '-s')
                ax.set_xlim(0, 1)
                ax.set_ylim(0, 1)
                ax.set_title(title)
                ax.set_xlabel('Recall')
                ax.set_ylabel('Precision')
                ax.set_aspect('equal')
                # legend expects a sequence of labels; a bare string is split into characters
                ax.legend(['Max F-measure {:.3}'.format(f_max)])
            fig.savefig(eval_fig_file)
            plt.close(fig)

        # the padded end points are stripped again before reporting
        json_out = {}
        json_out['laneline_R'] = laneline_R[1:-1].astype(np.float32).tolist()
        json_out['laneline_P'] = laneline_P[1:-1].astype(np.float32).tolist()
        json_out['laneline_F_max'] = laneline_F_max
        json_out['laneline_max_i'] = laneline_max_i.tolist()
        json_out['laneline_AP'] = laneline_AP

        if eval_centerline:
            json_out['centerline_R'] = centerline_R[1:-1].astype(np.float32).tolist()
            json_out['centerline_P'] = centerline_P[1:-1].astype(np.float32).tolist()
            json_out['centerline_F_max'] = centerline_F_max
            json_out['centerline_max_i'] = centerline_max_i.tolist()
            json_out['centerline_AP'] = centerline_AP

        json_out['max_F_prob_th'] = varying_th[laneline_max_i]

        if eval_out_file is not None:
            with open(eval_out_file, 'w') as jsonFile:
                jsonFile.write(json.dumps(json_out))
                jsonFile.write('\n')
        return json_out

# Disabled script entry point: the whole driver below is wrapped in a string literal, so none of
# it is executed. It shows the intended evaluation flow: sweep the probability threshold with
# bench_one_submit_varying_probs, then run bench_one_submit at the threshold with max F-measure.
'''
if __name__ == '__main__':
    vis = False
    args = define_args()
    #args = parser.parse_args()

    # two method are compared: '3D_LaneNet' and 'Gen_LaneNet'
    method_name = 'Gen_LaneNet_ext'

    # Three different splits of datasets: 'standard', 'rare_subsit', 'illus_chg'
    data_split = 'illus_chg'

    # location where the original dataset is saved. Image will be loaded in case of visualization
    args.dataset_dir = '/content/drive/Shareddrives/colab/Apollo_Sim_3D_Lane_Release' ##YOU NEED TO EDIT THIS DIRECTORY

    # load configuration for certain dataset
    sim3d_config(args)

    # auto-file in dependent paths
    base_dir='/content/drive/Shareddrives/colab/'
    gt_file = base_dir+'data_splits/' + data_split + '/test.json'
    pred_folder = base_dir+'data_splits/' + data_split + '/' + method_name
    pred_file = pred_folder + '/test_pred_file.json'

    # Initialize evaluator
    evaluator = LaneEval(args)

    # evaluation at varying thresholds
    eval_stats_pr = evaluator.bench_one_submit_varying_probs(pred_file, gt_file)
    max_f_prob = eval_stats_pr['max_F_prob_th']

    # evaluate at the point with max F-measure. Additional eval of position error. Option to visualize matching result
    eval_stats = evaluator.bench_one_submit(pred_file, gt_file, prob_th=max_f_prob, vis=vis)

    print("Metrics: AP, F-score, x error (close), x error (far), z error (close), z error (far)")
    print(
        "Laneline:  {:.3}  {:.3}  {:.3}  {:.3}  {:.3}  {:.3}".format(eval_stats_pr['laneline_AP'], eval_stats[0],
                                                                     eval_stats[3], eval_stats[4],
                                                                     eval_stats[5], eval_stats[6]))
    print("Centerline:  {:.3}  {:.3}  {:.3}  {:.3}  {:.3}  {:.3}".format(eval_stats_pr['centerline_AP'], eval_stats[7],
                                                                         eval_stats[10], eval_stats[11],
                                                                         eval_stats[12], eval_stats[13]))
                                                                         '''

'\nif __name__ == \'__main__\':\n    vis = False\n    args = define_args()\n    #args = parser.parse_args()\n\n    # two method are compared: \'3D_LaneNet\' and \'Gen_LaneNet\'\n    method_name = \'Gen_LaneNet_ext\'\n\n    # Three different splits of datasets: \'standard\', \'rare_subsit\', \'illus_chg\'\n    data_split = \'illus_chg\'\n\n    # location where the original dataset is saved. Image will be loaded in case of visualization\n    args.dataset_dir = \'/content/drive/Shareddrives/colab/Apollo_Sim_3D_Lane_Release\' ##YOU NEED TO EDIT THIS DIRECTORY\n\n    # load configuration for certain dataset\n    sim3d_config(args)\n\n    # auto-file in dependent paths\n    base_dir=\'/content/drive/Shareddrives/colab/\'\n    gt_file = base_dir+\'data_splits/\' + data_split + \'/test.json\'\n    pred_folder = base_dir+\'data_splits/\' + data_split + \'/\' + method_name\n    pred_file = pred_folder + \'/test_pred_file.json\'\n\n    # Initialize evaluator\n    evaluator = LaneEval(args)\n\n   

In [7]:
"""
/dataloader/Load_Data_3DLane_ext.py /

Dataloader for networks integrated with the new geometry-guided anchor design proposed in Gen-LaneNet:
    "Gen-laneNet: a generalized and scalable approach for 3D lane detection", Y.Guo, etal., arxiv 2020
Author: Yuliang Guo (33yuliangguo@gmail.com)
Date: March, 2020
"""

import copy
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image, ImageOps
import json
import random
import warnings
import torchvision.transforms.functional as Q
#from tools.utils import *
warnings.simplefilter('ignore', np.RankWarning)
matplotlib.use('Agg')


class LaneDataset(Dataset):
    """
    Dataset with labeled lanes
        This implementation considers:
        w/o laneline 3D attributes
        w/o centerline annotations
        default considers 3D laneline, including centerlines
        This new version of data loader prepare ground-truth anchor tensor in flat ground space.
        It is assumed the dataset provides accurate visibility labels. Preparing ground-truth tensor depends on it.
    """
    def __init__(self, dataset_base_dir, json_file_path, args, data_aug=False, save_std=False):
        """
        Parse the label file and precompute the transforms and anchor layout used by __getitem__.

        :param dataset_base_dir: root directory that the 'raw_file' entries of the label file are joined to
        :param json_file_path: path of the label file, one JSON record per line
        :param args: configuration object; the attributes read here are vgg_mean, vgg_std, dataset_name,
                     no_3d, no_centerline, org_h, org_w, crop_y, resize_h, resize_w, ipm_h, ipm_w,
                     top_view_region, K, fix_cam (plus cam_height and pitch when fix_cam is set),
                     anchor_y_steps, y_ref, and data_dir when save_std is set
        :param data_aug: when True, __getitem__ applies rotation augmentation and also returns its matrix
        :param save_std: when True, write the x-offset and z statistics to 'geo_anchor_std.json' in args.data_dir
        """
        # define image pre-processor
        self.totensor = transforms.ToTensor()
        self.normalize = transforms.Normalize(args.vgg_mean, args.vgg_std)
        self.data_aug = data_aug

        # dataset parameters
        self.dataset_name = args.dataset_name
        self.no_3d = args.no_3d
        self.no_centerline = args.no_centerline

        self.h_org = args.org_h
        self.w_org = args.org_w
        self.h_crop = args.crop_y

        # parameters related to service network
        self.h_net = args.resize_h
        self.w_net = args.resize_w
        self.ipm_h = args.ipm_h
        self.ipm_w = args.ipm_w
        self.top_view_region = args.top_view_region

        self.K = args.K
        self.H_crop = homography_crop_resize([args.org_h, args.org_w], args.crop_y, [args.resize_h, args.resize_w])
        # transformation from ipm to ground region
        self.H_ipm2g = cv2.getPerspectiveTransform(np.float32([[0, 0],
                                                               [self.ipm_w-1, 0],
                                                               [0, self.ipm_h-1],
                                                               [self.ipm_w-1, self.ipm_h-1]]),
                                                   np.float32(args.top_view_region))

        if args.fix_cam:
            self.fix_cam = True
            # compute the homography between image and IPM, and crop transformation
            self.cam_height = args.cam_height
            self.cam_pitch = np.pi / 180 * args.pitch  # args.pitch is given in degrees
            self.P_g2im = projection_g2im(self.cam_pitch, self.cam_height, args.K)
            self.H_g2im = homograpthy_g2im(self.cam_pitch, self.cam_height, args.K)
            self.H_im2g = np.linalg.inv(self.H_g2im)
            self.H_im2ipm = np.linalg.inv(np.matmul(self.H_crop, np.matmul(self.H_g2im, self.H_ipm2g)))
        else:
            self.fix_cam = False

        # compute anchor steps
        x_min = self.top_view_region[0, 0]
        x_max = self.top_view_region[1, 0]
        self.x_min = x_min
        self.x_max = x_max
        # builtin int replaces the np.int alias, which was removed in NumPy 1.24
        self.anchor_x_steps = np.linspace(x_min, x_max, int(args.ipm_w / 8), endpoint=True)
        self.anchor_y_steps = args.anchor_y_steps
        self.num_y_steps = len(self.anchor_y_steps)

        if self.no_centerline:
            self.num_types = 1
        else:
            self.num_types = 3

        # per anchor: x offsets (plus z values and visibility in 3D mode) and one existence flag
        if self.no_3d:
            self.anchor_dim = self.num_y_steps + 1
        else:
            self.anchor_dim = 3 * self.num_y_steps + 1

        self.y_ref = args.y_ref
        # index of the anchor y step closest to y_ref; comparing against the scalar num_y_steps
        # would always yield 0
        self.ref_id = np.argmin(np.abs(np.asarray(self.anchor_y_steps) - self.y_ref))

        # parse ground-truth file
        if 'tusimple' in self.dataset_name:
            self._label_image_path,\
                self._label_laneline_all_org, \
                self._label_laneline_all, \
                self._laneline_ass_ids, \
                self._x_off_std,\
                self._gt_laneline_visibility_all = self.init_dataset_tusimple(dataset_base_dir, json_file_path)
        else:  # assume loading apollo sim 3D lane
            self._label_image_path, \
                self._label_laneline_all_org, \
                self._label_laneline_all, \
                self._label_centerline_all, \
                self._label_cam_height_all, \
                self._label_cam_pitch_all, \
                self._laneline_ass_ids, \
                self._centerline_ass_ids, \
                self._x_off_std, \
                self._y_off_std, \
                self._z_std, \
                self._gt_laneline_visibility_all, \
                self._gt_centerline_visibility_all = self.init_dataset_3D(dataset_base_dir, json_file_path)
        self.n_samples = self._label_image_path.shape[0]

        if save_std is True:
            with open(ops.join(args.data_dir, 'geo_anchor_std.json'), 'w') as jsonFile:
                json_out = {}
                json_out["x_off_std"] = self._x_off_std.tolist()
                json_out["z_std"] = self._z_std.tolist()
                json.dump(json_out, jsonFile)
                jsonFile.write('\n')
        # NOTE: label values are not normalized here; normalize_lane_label() is meant to be run
        # manually from the main function, in case the stds need to be overwritten first

    def __len__(self):
        """
        Conventional len method
        """
        return self.n_samples

    def __getitem__(self, idx):
        """
        Load one sample: the preprocessed image together with its training targets.

        :param idx: (int) index of the sample to load
        :return: tuple (image, seg_label, gt_anchor, idx, gt_cam_height, gt_cam_pitch), where
                 image is the cropped, resized and normalized image tensor,
                 seg_label is a float tensor of shape (1, h_net, w_net) with the projected lanelines drawn as 1,
                 gt_anchor is a float tensor with one row per x anchor holding all lane types,
                 gt_cam_height / gt_cam_pitch are float32 scalar tensors.
                 When data_aug is enabled the augmentation matrix is appended as a 7th element.
        """

        # fetch camera height and pitch
        if not self.fix_cam:
            gt_cam_height = self._label_cam_height_all[idx]
            gt_cam_pitch = self._label_cam_pitch_all[idx]
        else:
            gt_cam_height = self.cam_height
            gt_cam_pitch = self.cam_pitch

        img_name = self._label_image_path[idx]

        with open(img_name, 'rb') as f:
            image = (Image.open(f).convert('RGB'))

        # image preprocess with crop and resize: drop the top h_crop rows, then scale to the network size
        image = Q.crop(image, self.h_crop, 0, self.h_org-self.h_crop, self.w_org)
        image = Q.resize(image, size=(self.h_net, self.w_net), interpolation=Image.BILINEAR)

        n_steps = self.num_y_steps
        gt_anchor = np.zeros([np.int32(self.ipm_w / 8), self.num_types, self.anchor_dim], dtype=np.float32)

        def write_anchor(ass_id, type_id, lane, visibility):
            """Store one lane in slot (ass_id, type_id): x offsets, then z and visibility in 3D mode."""
            gt_anchor[ass_id, type_id, 0: n_steps] = lane[:, 0]
            if not self.no_3d:
                gt_anchor[ass_id, type_id, n_steps:2*n_steps] = lane[:, 1]
                gt_anchor[ass_id, type_id, 2*n_steps:3*n_steps] = visibility
            # the last entry flags that a lane is assigned to this slot
            gt_anchor[ass_id, type_id, -1] = 1.0

        # lanelines always use type 0
        gt_lanes = self._label_laneline_all[idx]
        gt_vis_inds = self._gt_laneline_visibility_all[idx]
        for i in range(len(gt_lanes)):
            write_anchor(self._laneline_ass_ids[idx][i], 0, gt_lanes[i], gt_vis_inds[i])

        # fetch centerlines when available
        if not self.no_centerline:
            gt_lanes = self._label_centerline_all[idx]
            gt_vis_inds = self._gt_centerline_visibility_all[idx]
            for i in range(len(gt_lanes)):
                ass_id = self._centerline_ass_ids[idx][i]
                # type 1 holds the first centerline of an anchor; type 2 takes a second one
                # (the case one splitting lane has already been assigned)
                type_id = 2 if gt_anchor[ass_id, 1, -1] > 0 else 1
                write_anchor(ass_id, type_id, gt_lanes[i], gt_vis_inds[i])

        if self.data_aug:
            img_rot, aug_mat = data_aug_rotate(image)
            image = Image.fromarray(img_rot)
        image = self.totensor(image).float()
        image = self.normalize(image)
        gt_anchor = gt_anchor.reshape([np.int32(self.ipm_w / 8), -1])
        gt_anchor = torch.from_numpy(gt_anchor)
        gt_cam_height = torch.tensor(gt_cam_height, dtype=torch.float32)
        gt_cam_pitch = torch.tensor(gt_cam_pitch, dtype=torch.float32)

        # prepare binary segmentation label map
        seg_label = np.zeros((self.h_net, self.w_net), dtype=np.int8)
        gt_lanes = self._label_laneline_all_org[idx]
        M = None  # 3D-to-image projection, built on first use and shared by all lanes of the sample
        for lane in gt_lanes:
            # project lane3d to image
            if self.no_3d:
                x_2d = lane[:, 0]
                y_2d = lane[:, 1]
                # update transformation with image augmentation
                if self.data_aug:
                    x_2d, y_2d = homographic_transformation(aug_mat, x_2d, y_2d)
            else:
                if M is None:
                    # transform_mats is queried with idx only, so one call serves every lane
                    _, P_g2im, H_crop, _ = self.transform_mats(idx)
                    M = np.matmul(H_crop, P_g2im)
                    # update transformation with image augmentation
                    if self.data_aug:
                        M = np.matmul(aug_mat, M)
                x_2d, y_2d = projective_transformation(M, lane[:, 0],
                                                       lane[:, 1], lane[:, 2])
            for j in range(len(x_2d) - 1):
                # plain int color replaces np.asscalar(np.array([1])), which was removed in NumPy 1.23
                seg_label = cv2.line(seg_label,
                                     (int(x_2d[j]), int(y_2d[j])), (int(x_2d[j+1]), int(y_2d[j+1])),
                                     color=1)
        seg_label = torch.from_numpy(seg_label.astype(np.float32))
        seg_label.unsqueeze_(0)

        if self.data_aug:
            aug_mat = torch.from_numpy(aug_mat.astype(np.float32))
            return image, seg_label, gt_anchor, idx, gt_cam_height, gt_cam_pitch, aug_mat
        return image, seg_label, gt_anchor, idx, gt_cam_height, gt_cam_pitch

    def init_dataset_3D(self, dataset_base_dir, json_file_path):
        """
        Parse a 3D lane label file and convert every labeled lane to the anchor representation.

        :param dataset_base_dir: directory that the 'raw_file' entries are joined to
        :param json_file_path: label file, one JSON record per line
        :return: tuple of
            image paths (np array),
            lanelines pruned by visibility, before anchor conversion (per image),
            laneline anchors (per image: list of arrays with columns x offset, z),
            centerline anchors (same layout; left empty when no_centerline is set),
            camera heights and camera pitches (empty arrays when fix_cam is set),
            laneline anchor ids and centerline anchor ids (per image),
            visibility-weighted root-mean-square of x offsets, y offsets and z (used as stds),
            laneline visibility vectors and centerline visibility vectors (per image)
        data processing:
        lanes are pruned by visibility and range, converted to flat ground space and
        resampled into anchor labels via convert_label_to_anchor
        """

        # load image path, and lane pts
        label_image_path = []
        gt_laneline_pts_all = []
        gt_centerline_pts_all = []
        gt_laneline_visibility_all = []
        gt_centerline_visibility_all = []
        gt_cam_height_all = []
        gt_cam_pitch_all = []

        assert ops.exists(json_file_path), '{:s} not exist'.format(json_file_path)

        with open(json_file_path, 'r') as file:
            for line in file:
                info_dict = json.loads(line)

                image_path = ops.join(dataset_base_dir, info_dict['raw_file'])
                assert ops.exists(image_path), '{:s} not exist'.format(image_path)

                label_image_path.append(image_path)

                # A GT lane can be either 2D or 3D
                # if a GT lane is 3D, the height is intact from 3D GT, so keep it intact here too
                gt_laneline_pts_all.append([np.array(lane) for lane in info_dict['laneLines']])
                gt_laneline_visibility_all.append([np.array(vis) for vis in info_dict['laneLines_visibility']])

                if not self.no_centerline:
                    gt_centerline_pts_all.append([np.array(lane) for lane in info_dict['centerLines']])
                    gt_centerline_visibility_all.append(
                        [np.array(vis) for vis in info_dict['centerLines_visibility']])

                if not self.fix_cam:
                    gt_cam_height_all.append(info_dict['cam_height'])
                    gt_cam_pitch_all.append(info_dict['cam_pitch'])

        label_image_path = np.array(label_image_path)
        gt_cam_height_all = np.array(gt_cam_height_all)
        gt_cam_pitch_all = np.array(gt_cam_pitch_all)
        gt_laneline_pts_all_org = copy.deepcopy(gt_laneline_pts_all)

        # statistics are collected over lanelines and centerlines of all images
        lane_x_off_all = []
        lane_z_all = []
        lane_y_off_all = []  # this is the offset of y when transformed back to 3D
        visibility_all_flat = []

        def build_anchors(lanes, visibilities, P_g2gflat, H_im2g):
            """Prune the lanes, move them to flat ground space and convert them to anchor labels."""
            # prune gt lanes by visibility labels
            lanes_visible = [prune_3d_lane_by_visibility(lane, visibilities[k]) for k, lane in enumerate(lanes)]
            # prune out-of-range points are necessary before transformation
            lanes_flat = [prune_3d_lane_by_range(lane, 3 * self.x_min, 3 * self.x_max) for lane in lanes_visible]
            lanes_flat = [lane for lane in lanes_flat if lane.shape[0] > 1]

            # convert 3d lanes to flat ground space
            self.convert_lanes_3d_to_gflat(lanes_flat, P_g2gflat)

            anchors = []
            ass_ids = []
            visibility_vectors = []
            for lane in lanes_flat:
                # convert gt label to anchor label
                # consider individual out-of-range interpolation still visible
                ass_id, x_off_values, z_values, visibility_vec = self.convert_label_to_anchor(lane, H_im2g)
                if ass_id >= 0:  # lanes with a negative anchor id are dropped
                    anchors.append(np.vstack([x_off_values, z_values]).T)
                    ass_ids.append(ass_id)
                    visibility_vectors.append(visibility_vec)
            return lanes_visible, anchors, ass_ids, visibility_vectors

        def collect_stats(anchors, visibility_vectors, cam_height):
            """Append the anchors of one image to the dataset-wide statistics lists."""
            for anchor in anchors:
                lane_x_off_all.append(anchor[:, 0])
                lane_z_all.append(anchor[:, 1])
                # compute y offset when transformed back to 3D space
                lane_y_off_all.append(-anchor[:, 1] * self.anchor_y_steps / cam_height)
            visibility_all_flat.extend(visibility_vectors)

        # convert labeled laneline to anchor format
        gt_laneline_ass_ids = []
        gt_centerline_ass_ids = []
        for idx in range(len(gt_laneline_pts_all)):
            # fetch camera height and pitch
            if self.fix_cam:
                # the per-image camera arrays are empty in this mode, so use the fixed camera setup
                gt_cam_height = self.cam_height
                P_g2im = self.P_g2im
                H_im2g = self.H_im2g
            else:
                gt_cam_height = gt_cam_height_all[idx]
                gt_cam_pitch = gt_cam_pitch_all[idx]
                P_g2im = projection_g2im(gt_cam_pitch, gt_cam_height, self.K)
                H_g2im = homograpthy_g2im(gt_cam_pitch, gt_cam_height, self.K)
                H_im2g = np.linalg.inv(H_g2im)
            P_g2gflat = np.matmul(H_im2g, P_g2im)

            lanes_visible, gt_anchors, ass_ids, visibility_vectors = build_anchors(
                gt_laneline_pts_all[idx], gt_laneline_visibility_all[idx], P_g2gflat, H_im2g)
            # the "org" copy keeps the lanes pruned by visibility only
            gt_laneline_pts_all_org[idx] = lanes_visible
            collect_stats(gt_anchors, visibility_vectors, gt_cam_height)
            gt_laneline_ass_ids.append(ass_ids)
            gt_laneline_pts_all[idx] = gt_anchors
            gt_laneline_visibility_all[idx] = visibility_vectors

            if not self.no_centerline:
                _, gt_anchors, ass_ids, visibility_vectors = build_anchors(
                    gt_centerline_pts_all[idx], gt_centerline_visibility_all[idx], P_g2gflat, H_im2g)
                collect_stats(gt_anchors, visibility_vectors, gt_cam_height)
                gt_centerline_ass_ids.append(ass_ids)
                gt_centerline_pts_all[idx] = gt_anchors
                gt_centerline_visibility_all[idx] = visibility_vectors

        lane_x_off_all = np.array(lane_x_off_all)
        lane_y_off_all = np.array(lane_y_off_all)
        lane_z_all = np.array(lane_z_all)
        visibility_all_flat = np.array(visibility_all_flat)

        # computed weighted std based on visibility
        lane_x_off_std = np.sqrt(np.average(lane_x_off_all**2, weights=visibility_all_flat, axis=0))
        lane_y_off_std = np.sqrt(np.average(lane_y_off_all**2, weights=visibility_all_flat, axis=0))
        lane_z_std = np.sqrt(np.average(lane_z_all**2, weights=visibility_all_flat, axis=0))
        return label_image_path, gt_laneline_pts_all_org,\
               gt_laneline_pts_all, gt_centerline_pts_all, gt_cam_height_all, gt_cam_pitch_all,\
               gt_laneline_ass_ids, gt_centerline_ass_ids, lane_x_off_std, lane_y_off_std, lane_z_std,\
               gt_laneline_visibility_all, gt_centerline_visibility_all

    def init_dataset_tusimple(self, dataset_base_dir, json_file_path):
        """
            Load a TuSimple-style label file and convert every labeled lane to the anchor representation.

            Every line of the label file is parsed as one JSON record providing 'raw_file', 'lanes'
            (x values per lane, negative entries are dropped as invalid) and 'h_samples' (the y values
            shared by all lanes of the record). Lanes left with fewer than two points are discarded.
        :param dataset_base_dir: directory every 'raw_file' entry is joined to
        :param json_file_path: path of the label file, one JSON record per line
        :return: label_image_path: numpy array of the image paths
                 gt_laneline_pts_all_org: deep copy of the per-image lane points taken before anchor conversion
                 gt_laneline_pts_all: per-image list of anchors, x offsets and z values stacked column-wise
                 gt_laneline_ass_ids: per-image list of associated anchor ids
                 lane_x_off_std: np.std (axis 0) over the x offsets of all kept lanes
                 gt_laneline_visibility_all: per-image list of visibility vectors
        """
        assert ops.exists(json_file_path), '{:s} not exist'.format(json_file_path)

        # pass 1: read image paths and raw lane points
        image_paths = []
        lanes_per_image = []
        with open(json_file_path, 'r') as label_file:
            for record_text in label_file:
                record = json.loads(record_text)

                image_path = ops.join(dataset_base_dir, record['raw_file'])
                assert ops.exists(image_path), '{:s} not exist'.format(image_path)
                image_paths.append(image_path)

                lanes_x = record['lanes']
                y_samples = np.array(record['h_samples'])
                kept_lanes = []
                for xs in lanes_x:
                    xs = np.array(xs)
                    pts = np.zeros([y_samples.shape[0], 2], dtype=np.float32)
                    pts[:, 0] = xs
                    pts[:, 1] = y_samples
                    # remove invalid samples
                    pts = pts[xs >= 0, :]
                    if pts.shape[0] >= 2:
                        kept_lanes.append(pts)
                lanes_per_image.append(kept_lanes)

        label_image_path = np.array(image_paths)
        gt_laneline_pts_all_org = copy.deepcopy(lanes_per_image)

        # pass 2: convert labeled lanelines to anchor format, replacing the raw points image by image
        H_im2g = self.H_im2g
        gt_laneline_ass_ids = []
        gt_laneline_visibility_all = []
        lane_x_off_all = []
        for img_idx, raw_lanes in enumerate(lanes_per_image):
            anchors = []
            anchor_ids = []
            visibilities = []
            for raw_lane in raw_lanes:
                ass_id, x_off_values, z_values, visibility_vec = self.convert_label_to_anchor(raw_lane, H_im2g)
                if ass_id < 0:
                    # lane could not be associated with any anchor
                    continue
                anchors.append(np.vstack([x_off_values, z_values]).T)
                anchor_ids.append(ass_id)
                lane_x_off_all.append(x_off_values)
                visibilities.append(visibility_vec)
            gt_laneline_ass_ids.append(anchor_ids)
            lanes_per_image[img_idx] = anchors
            gt_laneline_visibility_all.append(visibilities)

        lane_x_off_std = np.std(np.array(lane_x_off_all), axis=0)

        return label_image_path, gt_laneline_pts_all_org, lanes_per_image, gt_laneline_ass_ids,\
               lane_x_off_std, gt_laneline_visibility_all

    def set_x_off_std(self, x_off_std):
        """
            Store the scale that normalize_lane_label divides the lane x offsets by
        :param x_off_std: value assigned to self._x_off_std
        """
        self._x_off_std = x_off_std

    def set_y_off_std(self, y_off_std):
        """
            Store the given value as self._y_off_std
        :param y_off_std: value assigned to self._y_off_std
        """
        self._y_off_std = y_off_std

    def set_z_std(self, z_std):
        """
            Store the scale that normalize_lane_label divides the lane z values by (3D labels only)
        :param z_std: value assigned to self._z_std
        """
        self._z_std = z_std

    def normalize_lane_label(self):
        """
            Normalize the stored anchor labels in place.

            Column 0 (x offset) of every lane is divided by self._x_off_std and, unless self.no_3d is set,
            column 1 (z value) is divided by self._z_std. Lanelines are always processed; centerlines only
            when self.no_centerline is not set.
        :return: None
        """
        label_attrs = ['_label_laneline_all']
        if not self.no_centerline:
            label_attrs.append('_label_centerline_all')

        for attr_name in label_attrs:
            for lanes_in_image in getattr(self, attr_name):
                for lane in lanes_in_image:
                    lane[:, 0] = lane[:, 0] / self._x_off_std
                    if self.no_3d:
                        continue
                    lane[:, 1] = lane[:, 1] / self._z_std

    def convert_lanes_3d_to_gflat(self, lanes, P_g2gflat):
        """
            Project a set of lanes from 3D ground coordinates [X, Y, Z] to IPM-based flat ground
            coordinates [x_gflat, y_gflat, Z]. Every array is updated in place: its first two columns are
            overwritten with the projected x and y, the third column is left untouched.
        :param lanes: a list of N x 3 numpy arrays recording a set of 3d lanes
        :param P_g2gflat: projection matrix from 3D ground coordinates to flat ground coordinates
        :return: None
        """
        # TODO: this function can be simplified with the derived formula
        for pts in lanes:
            xs, ys, zs = pts[:, 0], pts[:, 1], pts[:, 2]
            # the projection is fully evaluated before either column is overwritten
            pts[:, 0], pts[:, 1] = projective_transformation(P_g2gflat, xs, ys, zs)

    def compute_visibility_lanes_gflat(self, lane_anchors, ass_ids):
        """
            Compute the visibility of each anchor point in flat ground space. The reasoning considers all
            the given lanes jointly.
        :param lane_anchors: A list of N x 2 numpy arrays where N equals to number of Y steps in anchor representation
                             x offset and z values are recorded for each lane
        :param ass_ids: the associated anchor id of each lane, which determines its base x value
        :return: vis_inds_lanes: one vector of N ones/zeros per lane, ordered like the returned lanes
                 lane_anchors: the given lanes sorted from left to right by their x value at the first anchor point
                 ass_ids: the associated ids, reordered the same way
                 Three empty lists are returned when no lane is given.
        """
        # compare by value: `len(...) is 0` tests object identity, which only works by accident of
        # CPython's small-integer cache and triggers a SyntaxWarning on Python >= 3.8
        if len(lane_anchors) == 0:
            return [], [], []

        vis_inds_lanes = []
        # sort the lane_anchors such that lanes are recorded from left to right
        # sort the lane_anchors based on the x value at the closest anchor
        # do NOT sort the lane_anchors by the order of ass_ids because there could be identical ass_ids

        x_refs = [lane_anchors[i][0, 0] + self.anchor_x_steps[ass_ids[i]] for i in range(len(lane_anchors))]
        sort_idx = np.argsort(x_refs)
        lane_anchors = [lane_anchors[i] for i in sort_idx]
        ass_ids = [ass_ids[i] for i in sort_idx]

        # absolute x of the left-most and right-most lane at every y step
        min_x_vec = lane_anchors[0][:, 0] + self.anchor_x_steps[ass_ids[0]]
        max_x_vec = lane_anchors[-1][:, 0] + self.anchor_x_steps[ass_ids[-1]]
        for i, lane in enumerate(lane_anchors):
            vis_inds = np.ones(lane.shape[0])
            for j in range(lane.shape[0]):
                x_value = lane[j, 0] + self.anchor_x_steps[ass_ids[i]]
                # A point outside three times the considered x range hides itself and everything farther
                if x_value < 3*self.x_min or x_value > 3*self.x_max:
                    vis_inds[j:] = 0
                # A point with x < the left most lane's current x is considered invisible
                # A point with x > the right most lane's current x is considered invisible
                if x_value < min_x_vec[j] - 0.01 or x_value > max_x_vec[j] + 0.01:
                    vis_inds[j:] = 0
                    break
                # A point with orientation close enough to horizontal is considered as invisible
                if j > 0:
                    dx = lane[j, 0] - lane[j-1, 0]
                    dy = self.anchor_y_steps[j] - self.anchor_y_steps[j-1]
                    if abs(dx/dy) > 10:
                        vis_inds[j:] = 0
                        break
            vis_inds_lanes.append(vis_inds)
        return vis_inds_lanes, lane_anchors, ass_ids

    def convert_label_to_anchor(self, laneline_gt, H_im2g):
        """
            Convert a set of ground-truth lane points to the format of network anchor representation.
            All the given laneline only include visible points. The interpolated points will be marked invisible
        :param laneline_gt: an array of lane points, [x, y] image points when self.no_3d is set (TuSimple),
                            otherwise [x, y, z] ground points
        :param H_im2g: homographic transformation only used for tusimple dataset
        :return: ass_id: the column id of current lane in anchor representation, -1 when the lane is rejected
                 x_off_values: current lane's x offset from it associated anchor column
                 z_values: current lane's z value in ground coordinates
                 visibility_vec: visibility of every anchor point as returned by resample_laneline_in_y
                 A rejected lane is returned as (-1, empty array, empty array, empty array).
        """
        if self.no_3d:  # For ground-truth in 2D image coordinates (TuSimple)
            gt_lane_2d = laneline_gt
            # project to ground coordinates
            gt_lane_grd_x, gt_lane_grd_y = homographic_transformation(H_im2g, gt_lane_2d[:, 0], gt_lane_2d[:, 1])
            gt_lane_3d = np.zeros_like(gt_lane_2d, dtype=np.float32)
            gt_lane_3d[:, 0] = gt_lane_grd_x
            gt_lane_3d[:, 1] = gt_lane_grd_y
        else:  # For ground-truth in ground coordinates (Apollo Sim)
            gt_lane_3d = laneline_gt

        # prune out points not in valid range, requires additional points to interpolate better
        # prune out-of-range points after transforming to flat ground space, update visibility vector
        valid_indices = np.logical_and(np.logical_and(gt_lane_3d[:, 1] > 0, gt_lane_3d[:, 1] < 200),
                                       np.logical_and(gt_lane_3d[:, 0] > 3 * self.x_min,
                                                      gt_lane_3d[:, 0] < 3 * self.x_max))
        gt_lane_3d = gt_lane_3d[valid_indices, ...]
        # use more restricted range to determine deletion or not
        if gt_lane_3d.shape[0] < 2 or np.sum(np.logical_and(gt_lane_3d[:, 0] > self.x_min,
                                                            gt_lane_3d[:, 0] < self.x_max)) < 2:
            return -1, np.array([]), np.array([]), np.array([])

        # compare by value: `is 'tusimple'` tests object identity and is False for an equal string that
        # was built at runtime, which would silently skip the reversal below
        if self.dataset_name == 'tusimple':
            # reverse the order of 3d points to make the first point the closest
            gt_lane_3d = gt_lane_3d[::-1, :]

        # only keep the portion y is monotonically increasing above a threshold, to prune those super close points
        gt_lane_3d = make_lane_y_mono_inc(gt_lane_3d)
        if gt_lane_3d.shape[0] < 2:
            return -1, np.array([]), np.array([]), np.array([])

        # ignore GT ends before y_ref, for those start at y > y_ref, use its interpolated value at y_ref for association
        # if gt_lane_3d[0, 1] > self.y_ref or gt_lane_3d[-1, 1] < self.y_ref:
        if gt_lane_3d[-1, 1] < self.y_ref:
            return -1, np.array([]), np.array([]), np.array([])

        # resample ground-truth laneline at anchor y steps
        x_values, z_values, visibility_vec = resample_laneline_in_y(gt_lane_3d, self.anchor_y_steps, out_vis=True)

        if np.sum(visibility_vec) < 2:
            return -1, np.array([]), np.array([]), np.array([])

        # decide association at y_ref: pick the anchor column closest to the lane's x at index ref_id
        ass_id = np.argmin((self.anchor_x_steps - x_values[self.ref_id]) ** 2)
        # compute offset values
        x_off_values = x_values - self.anchor_x_steps[ass_id]

        return ass_id, x_off_values, z_values, visibility_vec

    def transform_mats(self, idx):
        """
            Return the transform matrices associated with sample idx.

            With a fixed camera the matrices stored on the dataset are returned as they are. Otherwise
            they are rebuilt from the camera pitch and height recorded for the sample.
        :param idx: index of the sample
        :return: H_g2im, P_g2im, H_crop, H_im2ipm
        """
        if self.fix_cam:
            return self.H_g2im, self.P_g2im, self.H_crop, self.H_im2ipm

        cam_pitch = self._label_cam_pitch_all[idx]
        cam_height = self._label_cam_height_all[idx]
        H_g2im = homograpthy_g2im(cam_pitch, cam_height, self.K)
        P_g2im = projection_g2im(cam_pitch, cam_height, self.K)

        # chain ipm -> ground -> image -> crop, then invert the product
        H_ipm2im = np.matmul(self.H_crop, np.matmul(H_g2im, self.H_ipm2g))
        return H_g2im, P_g2im, self.H_crop, np.linalg.inv(H_ipm2im)


def make_lane_y_mono_inc(lane):
    """
        Due to the loss of the height dimension, lanes projected to the flat ground plane may not have
        monotonically increasing y. This function walks along the lane and keeps only the points whose y
        exceeds the last kept y by more than a fixed step.
    :param lane: numpy array with one point per row and y in column 1; needs at least one row
    :return: a new array holding the kept rows (the first row is always kept)
    """
    # hard-coded smallest step, so the far-away near horizontal tail can be pruned
    min_step = 3
    num_pts = lane.shape[0]
    keep = np.ones(num_pts, dtype=bool)
    last_kept_y = lane[0, 1]
    for row in range(1, num_pts):
        y_here = lane[row, 1]
        if y_here <= last_kept_y + min_step:
            keep[row] = False
        else:
            last_kept_y = y_here
    return lane[keep]


"""
    Data Augmentation: 
        idea 1: (currently in use)
            when initializing dataset, all labels will be prepared in 3D which do not need to be changed in image augmenting
            Image data augmentation would change the spatial transform matrix integrated in the network, provide 
            the transformation matrix related to random cropping, scaling and rotation
        idea 2:
            Introduce random sampling of cam_h, cam_pitch and their associated transformed image
            img2 = [R2[:, 0:2], T2] [R1[:, 0:2], T1]^-1 img1
            output augmented hcam, pitch, and img2 and untouched 3D anchor label value, Before forward pass, update spatial
            transform in network. However, image rotation is not considered, so additional cropping is still needed
"""


def data_aug_rotate(img):
    """
        Rotate an image about its center by a random angle.
    :param img: image in PIL image format (its width and height attributes are read)
    :return: img_rot: the rotated image as a numpy array of the same width and height
             rot_mat: the applied 2 x 3 affine matrix extended to 3 x 3 with the row [0, 0, 1]
    """
    # NOTE(review): the angle is drawn from [-pi/18, pi/18], which looks like a range in radians, but
    # cv2.getRotationMatrix2D takes its angle in degrees - confirm that this tiny rotation is intended
    angle = random.uniform(-np.pi/18, np.pi/18)
    width, height = img.width, img.height
    affine_2x3 = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1.0)
    img_rot = cv2.warpAffine(np.array(img), affine_2x3, (width, height), flags=cv2.INTER_LINEAR)
    rot_mat = np.vstack([affine_2x3, [0, 0, 1]])
    return img_rot, rot_mat


def get_loader(transformed_dataset, args):
    """
        Create a batch loader on top of a dataset built from ground-truth.

        Samples are drawn in random order. Trailing samples that do not fill a whole batch are left out,
        so every batch holds exactly args.batch_size samples.
    :param transformed_dataset: dataset exposing n_samples
    :param args: settings providing batch_size and nworkers
    :return: the DataLoader
    """
    n_total = transformed_dataset.n_samples
    batch_size = args.batch_size
    # keep the largest multiple of the batch size that fits into the dataset
    usable_idx = range(n_total)[:n_total // batch_size * batch_size]
    return DataLoader(transformed_dataset,
                      batch_size=batch_size,
                      sampler=torch.utils.data.sampler.SubsetRandomSampler(usable_idx),
                      num_workers=args.nworkers,
                      pin_memory=True)


def compute_2d_lanes(pred_anchor, h_samples, H_g2im, anchor_x_steps, anchor_y_steps, x_min, x_max, prob_th=0.5):
    """
        Convert anchor lanes to image lanes in tusimple format.

        Note: the last column of pred_anchor is overwritten in place with its nms_1d result.
    :param pred_anchor: 2D array with one row per anchor; the last column is the lane probability and the
                        preceding columns are the x offsets at anchor_y_steps
    :param h_samples: image y values every output lane is resampled at
    :param H_g2im: homography applied to the ground (x, y) points to obtain image points
    :param anchor_x_steps: base x value of every anchor, added to that anchor's x offsets
    :param anchor_y_steps: ground y value of every anchor point
    :param x_min: smallest valid x value, anything below is written as -2
    :param x_max: exclusive upper bound of the valid x values, anything at or above is written as -2
    :param prob_th: anchors whose probability does not exceed this threshold are skipped
    :return: a list holding, for every kept anchor, the list of integer x values resampled at h_samples
             in image coordinates
    """
    lanes_out = []

    # apply nms to output lanes
    pred_anchor[:, -1] = nms_1d(pred_anchor[:, -1])

    # need to resample network lane results at h_samples
    for j in range(pred_anchor.shape[0]):
        if pred_anchor[j, -1] > prob_th:
            x_3d = pred_anchor[j, :-1] + anchor_x_steps[j]
            # compute x, y in original image coordinates
            x_2d, y_2d = homographic_transformation(H_g2im, x_3d, anchor_y_steps)
            # reverse the order such that y_2d is ascending
            x_2d = x_2d[::-1]
            y_2d = y_2d[::-1]
            # resample at h_samples
            x_values, _ = resample_laneline_in_y(np.vstack([x_2d, y_2d]).T, h_samples)
            # `np.int` was only an alias of the builtin int and no longer exists in NumPy >= 1.24
            x_values = x_values.astype(int)
            # assign out-of-range x values to be -2
            x_values[np.logical_or(x_values < x_min, x_values >= x_max)] = -2
            # assign far side y values to be -2
            x_values[h_samples < y_2d[0]] = -2

            lanes_out.append(x_values.tolist())
    return lanes_out


def compute_3d_lanes(pred_anchor, anchor_dim, anchor_x_steps, anchor_y_steps, h_cam, prob_th=0.5):
    """
        Convert predicted anchors to 3D lanes, keeping only the anchors whose probability exceeds prob_th.

        Every row of pred_anchor is read as up to three consecutive groups of anchor_dim columns: the
        laneline, the centerline and the additional centerline for the merging case. Within a group with
        K = len(anchor_y_steps), the first K columns are x offsets, the next K are z values, the next K
        are visibilities, and the last column of the group is the lane probability.
        Groups missing from pred_anchor (e.g. predictions without centerlines) are skipped.

        Note: the probability column of every present group is overwritten in place with its nms_1d result.
    :param pred_anchor: 2D array with one row per anchor
    :param anchor_dim: number of columns of one group
    :param anchor_x_steps: base x value of every anchor, added to that anchor's x offsets
    :param anchor_y_steps: numpy array with the y value of every anchor point
    :param h_cam: camera height handed to transform_lane_gflat2g
    :param prob_th: probability threshold an anchor needs to exceed to be output
    :return: lanelines_out, centerlines_out: lists of lanes, each lane a nested list of points
    """
    lanelines_out = []
    centerlines_out = []
    num_y_steps = anchor_y_steps.shape[0]

    # group 0 feeds the lanelines, group 1 and 2 (centerline, merging-case centerline) the centerlines
    outputs = (lanelines_out, centerlines_out, centerlines_out)
    # only handle the groups actually present, instead of indexing past the end of narrower predictions
    num_types = min(len(outputs), pred_anchor.shape[1] // anchor_dim)

    # apply nms to output lanes probabilities
    for type_idx in range(num_types):
        prob_col = (type_idx + 1) * anchor_dim - 1
        pred_anchor[:, prob_col] = nms_1d(pred_anchor[:, prob_col])

    # An important process is to output lanes in the considered y-range. The visibility attributes are
    # interpolated to automatically determine whether to extend the lanes.
    for j in range(pred_anchor.shape[0]):
        for type_idx in range(num_types):
            base = type_idx * anchor_dim
            if pred_anchor[j, base + anchor_dim - 1] > prob_th:
                x_g = pred_anchor[j, base:base + num_y_steps] + anchor_x_steps[j]
                z_g = pred_anchor[j, base + num_y_steps:base + 2*num_y_steps]
                visibility = pred_anchor[j, base + 2*num_y_steps:base + 3*num_y_steps]
                line = np.vstack([x_g, anchor_y_steps, z_g]).T
                # convert to 3D ground space
                x_g, y_g = transform_lane_gflat2g(h_cam, line[:, 0], line[:, 1], line[:, 2])
                line[:, 0] = x_g
                line[:, 1] = y_g
                line = resample_laneline_in_y_with_vis(line, anchor_y_steps, visibility)
                # a lane needs at least two points
                if line.shape[0] >= 2:
                    outputs[type_idx].append(line.tolist())

    return lanelines_out, centerlines_out


def compute_3d_lanes_all_prob(pred_anchor, anchor_dim, anchor_x_steps, anchor_y_steps, h_cam):
    """
        Convert all predicted anchors to 3D lanes and report the probability of every output lane.

        Every row of pred_anchor is read as up to three consecutive groups of anchor_dim columns: the
        laneline, the centerline and the additional centerline for the merging case. Within a group with
        K = len(anchor_y_steps), the first K columns are x offsets, the next K are z values, the next K
        are visibilities, and the last column of the group is the lane probability.
        Groups missing from pred_anchor (e.g. predictions without centerlines) are skipped.

        Note: the probability column of every present group is overwritten in place with its nms_1d result.
    :param pred_anchor: 2D array with one row per anchor
    :param anchor_dim: number of columns of one group
    :param anchor_x_steps: base x value of every anchor, added to that anchor's x offsets
    :param anchor_y_steps: numpy array with the y value of every anchor point
    :param h_cam: camera height handed to transform_lane_gflat2g
    :return: lanelines_out, centerlines_out: lists of lanes, each lane a nested list of points
             lanelines_prob, centerlines_prob: the probability of every lane in the matching list
    """
    lanelines_out = []
    lanelines_prob = []
    centerlines_out = []
    centerlines_prob = []
    num_y_steps = anchor_y_steps.shape[0]

    # group 0 feeds the lanelines, group 1 and 2 (centerline, merging-case centerline) the centerlines
    outputs = ((lanelines_out, lanelines_prob),
               (centerlines_out, centerlines_prob),
               (centerlines_out, centerlines_prob))
    # only handle the groups actually present, instead of indexing past the end of narrower predictions
    num_types = min(len(outputs), pred_anchor.shape[1] // anchor_dim)

    # apply nms to output lanes probabilities
    for type_idx in range(num_types):
        prob_col = (type_idx + 1) * anchor_dim - 1
        pred_anchor[:, prob_col] = nms_1d(pred_anchor[:, prob_col])

    # An important process is to output lanes in the considered y-range. The visibility attributes are
    # interpolated to automatically determine whether to extend the lanes.
    for j in range(pred_anchor.shape[0]):
        for type_idx in range(num_types):
            lines_out, probs_out = outputs[type_idx]
            base = type_idx * anchor_dim
            x_g = pred_anchor[j, base:base + num_y_steps] + anchor_x_steps[j]
            z_g = pred_anchor[j, base + num_y_steps:base + 2*num_y_steps]
            visibility = pred_anchor[j, base + 2*num_y_steps:base + 3*num_y_steps]
            line = np.vstack([x_g, anchor_y_steps, z_g]).T
            # convert to 3D ground space
            x_g, y_g = transform_lane_gflat2g(h_cam, line[:, 0], line[:, 1], line[:, 2])
            line[:, 0] = x_g
            line[:, 1] = y_g
            line = resample_laneline_in_y_with_vis(line, anchor_y_steps, visibility)
            # a lane needs at least two points
            if line.shape[0] >= 2:
                lines_out.append(line.tolist())
                probs_out.append(pred_anchor[j, base + anchor_dim - 1].tolist())

    return lanelines_out, centerlines_out, lanelines_prob, centerlines_prob


def unormalize_lane_anchor(anchor, dataset):
    """
        Undo the label normalization of an anchor array in place.

        For each of the dataset.num_types groups of dataset.anchor_dim columns, the first
        dataset.num_y_steps columns (x offsets) are multiplied by dataset._x_off_std and, unless
        dataset.no_3d is set, the following dataset.num_y_steps columns (z values) by dataset._z_std.
    :param anchor: 2D array with one row per anchor, modified in place
    :param dataset: object providing num_y_steps, anchor_dim, num_types, no_3d, _x_off_std and _z_std
    :return: None
    """
    steps = dataset.num_y_steps
    group_width = dataset.anchor_dim
    for type_idx in range(dataset.num_types):
        x_start = type_idx * group_width
        z_start = x_start + steps
        anchor[:, x_start:z_start] = anchor[:, x_start:z_start] * dataset._x_off_std
        if dataset.no_3d:
            continue
        z_stop = z_start + steps
        anchor[:, z_start:z_stop] = anchor[:, z_start:z_stop] * dataset._z_std


# unit test
'''
if __name__ == '__main__':
    import sys
    #from tools.utils import define_args

    args = define_args()
    #args = parser.parse_args()

    # dataset_name: 'standard' / 'rare_subset' / 'illus_chg'
    args.dataset_name = 'illus_chg'
    args.dataset_dir = '/content/drive/Shareddrives/colab/Apollo_Sim_3D_Lane_Release'
    args.test_dataset_dir = '/content/drive/Shareddrives/colab/Apollo_Sim_3D_Lane_Release'
   # args.data_dir = ops.join('data_splits', args.dataset_name)
    args.data_dir='/content/drive/Shareddrives/colab/data_splits/illus_chg'
    # load configuration for certain dataset
    if 'tusimple' in args.dataset_name:
        tusimple_config(args)
    else:
        sim3d_config(args)
    args.y_ref = 5.0

    # set 3D ground area for visualization
    vis_border_3d = np.array([[-1.75, 100.], [1.75, 100.], [-1.75, 5.], [1.75, 5.]])
    print('visual area border:')
    print(vis_border_3d)

    # load data
    dataset = LaneDataset(args.dataset_dir, ops.join(args.data_dir, 'train.json'), args, data_aug=True, save_std=True)
    dataset.normalize_lane_label()
    loader = get_loader(dataset, args)
    anchor_x_steps = dataset.anchor_x_steps

    # initialize visualizer
    args.mod = 'ext'
    visualizer = Visualizer(args)
    Visualizer.anchor_dim = dataset.anchor_dim

    # get a batch of data/label pairs from loader
    for batch_ndx, (image_tensor, seg_labels, gt_tensor, idx, gt_cam_height, gt_cam_pitch, aug_mat) in enumerate(loader):
        print('batch id: {:d}, image tensor shape:'.format(batch_ndx))
        print(image_tensor.shape)
        print('batch id: {:d}, gt tensor shape:'.format(batch_ndx))
        print(gt_tensor.shape)

        # convert to BGR and numpy for visualization in opencv
        images = image_tensor.permute(0, 2, 3, 1).data.cpu().numpy()
        seg_labels = seg_labels.data.cpu().numpy()
        gt_anchors = gt_tensor.numpy()
        idx = idx.numpy()
        gt_cam_height = gt_cam_height.numpy()
        gt_cam_pitch = gt_cam_pitch.numpy()
        aug_mat = aug_mat.numpy()
        for i in range(args.batch_size):
            img = images[i]
            seg_label = seg_labels[i][0]
            img = img * np.array(args.vgg_std).astype(np.float32)
            img = img + np.array(args.vgg_mean).astype(np.float32)
            if img.min() < 0. or img.max() > 1.0:
                print('found an invalid normalized sample')
            img = np.clip(img, 0, 1)

            # if args.no_3d:
            H_g2im, P_g2im, H_crop, H_im2ipm = dataset.transform_mats(idx[i])
            M = np.matmul(H_crop, H_g2im)
            # update transformation with image augmentation
            M = np.matmul(aug_mat[i], M)
            x_2d, y_2d = homographic_transformation(M, vis_border_3d[:, 0], vis_border_3d[:, 1])

            # update transformation with image augmentation
            H_im2ipm = np.matmul(H_im2ipm, np.linalg.inv(aug_mat[i]))
            im_ipm = cv2.warpPerspective(img, H_im2ipm, (args.ipm_w, args.ipm_h))
            im_ipm = np.clip(im_ipm, 0, 1)

            # draw visual border on image to confirm calibration
            x_2d = x_2d.astype(np.int)
            y_2d = y_2d.astype(np.int)
            img = cv2.line(img, (x_2d[0], y_2d[0]), (x_2d[1], y_2d[1]), [1, 0, 0], 2)
            img = cv2.line(img, (x_2d[2], y_2d[2]), (x_2d[3], y_2d[3]), [1, 0, 0], 2)
            img = cv2.line(img, (x_2d[0], y_2d[0]), (x_2d[2], y_2d[2]), [1, 0, 0], 2)
            img = cv2.line(img, (x_2d[1], y_2d[1]), (x_2d[3], y_2d[3]), [1, 0, 0], 2)
            gt_anchor = gt_anchors[i, :, :]

            # un-normalize
            unormalize_lane_anchor(gt_anchor, dataset)

            # visualize ground-truth anchor lanelines by projecting them on the image
            img = visualizer.draw_on_img_new(img, gt_anchor, M, 'laneline', color=[0, 0, 1])
            if not args.no_centerline:
                img = visualizer.draw_on_img_new(img, gt_anchor, M, 'centerline', color=[0, 1, 0])

            cv2.putText(img, 'camara pitch: {:.3f}'.format(gt_cam_pitch[i]/np.pi*180),
                        (5, 30), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.7, color=(0, 0, 1), thickness=2)
            cv2.putText(img, 'camara height: {:.3f}'.format(gt_cam_height[i]),
                        (5, 60), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.7, color=(0, 0, 1), thickness=2)

            # visualize on ipm
            im_ipm = visualizer.draw_on_ipm_new(im_ipm, gt_anchor, 'laneline', color=[0, 0, 1])
            if not args.no_centerline:
                im_ipm = visualizer.draw_on_ipm_new(im_ipm, gt_anchor, 'centerline', color=[0, 1, 0])
            
            # convert image to BGR for opencv imshow
            cv2.imshow('image gt check', np.flip(img, axis=2))
            cv2.imshow('ipm gt check', np.flip(im_ipm, axis=2))
            cv2.imshow('seg label check', seg_label)
            cv2.waitKey()
            
            print('image: {:d} in batch: {:d}'.format(idx[i], batch_ndx))

    print('done')
    '''

"\nif __name__ == '__main__':\n    import sys\n    #from tools.utils import define_args\n\n    args = define_args()\n    #args = parser.parse_args()\n\n    # dataset_name: 'standard' / 'rare_subset' / 'illus_chg'\n    args.dataset_name = 'illus_chg'\n    args.dataset_dir = '/content/drive/Shareddrives/colab/Apollo_Sim_3D_Lane_Release'\n    args.test_dataset_dir = '/content/drive/Shareddrives/colab/Apollo_Sim_3D_Lane_Release'\n   # args.data_dir = ops.join('data_splits', args.dataset_name)\n    args.data_dir='/content/drive/Shareddrives/colab/data_splits/illus_chg'\n    # load configuration for certain dataset\n    if 'tusimple' in args.dataset_name:\n        tusimple_config(args)\n    else:\n        sim3d_config(args)\n    args.y_ref = 5.0\n\n    # set 3D ground area for visualization\n    vis_border_3d = np.array([[-1.75, 100.], [1.75, 100.], [-1.75, 5.], [1.75, 5.]])\n    print('visual area border:')\n    print(vis_border_3d)\n\n    # load data\n    dataset = LaneDataset(args.dataset

In [8]:
#Pytorch_Generalized_3D_Lane_Detection/networks/Loss_crit.py /
"""
Loss functions
Author: Yuliang Guo (33yuliangguo@gmail.com)
Date: March, 2020
"""

import numpy as np
import torch
import torch.nn as nn


class Laneline_loss_3D(nn.Module):
    """
    Loss between predicted lanelines and ground-truth lanelines in anchor representation.
    The anchor representation is based on real 3D X, Y, Z.
    loss = loss1 + loss2 + loss3
    loss1: cross entropy loss for lane type classification
    loss2: sum of geometric distance between 3D lane anchor points in X and Z offsets
    loss3: error in estimating pitch and camera heights (only added when pred_cam is set)
    """
    def __init__(self, num_types, anchor_dim, pred_cam):
        """
        :param num_types: number of lane types stored per anchor
        :param anchor_dim: number of values per lane type, the last one being the class probability
        :param pred_cam: whether the camera pitch/height error is added to the loss
        """
        super().__init__()
        self.num_types = num_types
        self.anchor_dim = anchor_dim
        self.pred_cam = pred_cam

    def forward(self, pred_3D_lanes, gt_3D_lanes, pred_hcam, gt_hcam, pred_pitch, gt_pitch):
        """
        :param pred_3D_lanes: predicted tensor with size N x (ipm_w/8) x 3*(2*K+1)
        :param gt_3D_lanes: ground-truth tensor with size N x (ipm_w/8) x 3*(2*K+1)
        :param pred_hcam: predicted camera height with size N
        :param gt_hcam: ground-truth camera height with size N
        :param pred_pitch: predicted pitch with size N
        :param gt_pitch: ground-truth pitch with size N
        :return: the summed loss as a scalar tensor
        """
        # split the last dimension into num_types groups of anchor_dim values
        grouped_shape = (pred_3D_lanes.shape[0], pred_3D_lanes.shape[1], self.num_types, self.anchor_dim)
        pred = pred_3D_lanes.reshape(*grouped_shape)
        gt = gt_3D_lanes.reshape(*grouped_shape)

        # the last value of a group is the class probability, everything before it the anchor values
        pred_class, pred_anchors = pred[:, :, :, -1:], pred[:, :, :, :-1]
        gt_class, gt_anchors = gt[:, :, :, -1:], gt[:, :, :, :-1]

        # binary cross entropy; the small constant keeps the logarithm finite
        eps = torch.tensor(1e-9)
        positive_term = gt_class * torch.log(pred_class + eps)
        negative_term = (torch.ones_like(gt_class) - gt_class) * \
            torch.log(torch.ones_like(pred_class) - pred_class + eps)
        loss1 = -torch.sum(positive_term + negative_term)

        # applying L1 norm does not need to separate X and Z
        # anchor errors only count where the ground truth holds a lane (gt_class weighting)
        weighted_error = gt_class * (pred_anchors - gt_anchors)
        loss2 = torch.sum(torch.norm(weighted_error, p=1, dim=3))

        if self.pred_cam:
            loss3 = torch.sum(torch.abs(gt_pitch - pred_pitch)) + torch.sum(torch.abs(gt_hcam - pred_hcam))
            return loss1 + loss2 + loss3
        return loss1 + loss2


class Laneline_loss_gflat(nn.Module):
    """
    Loss between predicted and ground-truth lanelines in the anchor representation with visibility.

    Every lane's anchor vector holds 2*K anchor values, then K visibility values, then one lane
    probability (K = num_y_steps), which gives anchor_dim = 3*K + 1.
    loss = loss0 + loss1 + loss2 (+ loss3 when pred_cam is set)
    loss0: cross entropy loss for lane point visibility, divided by K
    loss1: cross entropy loss for the lane probability
    loss2: L1 distance between anchor values, weighted by ground-truth probability and visibility
    loss3: error in estimating pitch and camera heights
    """
    def __init__(self, num_types, num_y_steps, pred_cam):
        """
        :param num_types: number of lane types stored per anchor column
        :param num_y_steps: number of y samples K per lane
        :param pred_cam: when True, forward() adds the camera pitch/height error to the loss
        """
        super().__init__()
        self.num_types = num_types
        self.num_y_steps = num_y_steps
        self.pred_cam = pred_cam
        self.anchor_dim = 3*self.num_y_steps + 1

    def forward(self, pred_3D_lanes, gt_3D_lanes, pred_hcam, gt_hcam, pred_pitch, gt_pitch):
        """
        :param pred_3D_lanes: predicted tensor, reshaped here to N x W x num_types x (3*K+1)
        :param gt_3D_lanes: ground-truth tensor with the same layout as pred_3D_lanes
        :param pred_hcam: predicted camera height
        :param gt_hcam: ground-truth camera height
        :param pred_pitch: predicted pitch
        :param gt_pitch: ground-truth pitch
        :return: scalar tensor; the camera terms are only included when self.pred_cam is set
        """
        eps = torch.tensor(1e-9)  # keeps log() finite when a probability is exactly 0 or 1
        num_y = self.num_y_steps
        layout = (pred_3D_lanes.shape[0], pred_3D_lanes.shape[1], self.num_types, self.anchor_dim)
        pred = pred_3D_lanes.reshape(*layout)
        gt = gt_3D_lanes.reshape(*layout)

        # split every anchor vector into [2K anchor values | K visibilities | 1 lane probability]
        pred_anchors, gt_anchors = pred[..., :2*num_y], gt[..., :2*num_y]
        pred_visibility, gt_visibility = pred[..., 2*num_y:3*num_y], gt[..., 2*num_y:3*num_y]
        pred_class, gt_class = pred[..., -1:], gt[..., -1:]

        # cross-entropy loss for visibility; the weight of the negative term carries an extra epsilon
        visible_term = gt_visibility * torch.log(pred_visibility + eps)
        hidden_term = (1 - gt_visibility + eps) * torch.log(1 - pred_visibility + eps)
        loss0 = -torch.sum(visible_term + hidden_term) / num_y

        # cross-entropy loss for lane probability
        lane_term = gt_class * torch.log(pred_class + eps)
        no_lane_term = (1 - gt_class) * torch.log(1 - pred_class + eps)
        loss1 = -torch.sum(lane_term + no_lane_term)

        # L1 distance on the anchor values; the visibility mask is repeated to cover both halves
        point_weight = gt_class * torch.cat((gt_visibility, gt_visibility), 3)
        loss2 = (point_weight * (pred_anchors - gt_anchors)).norm(p=1, dim=3).sum()

        total = loss0 + loss1 + loss2
        if self.pred_cam:
            loss3 = torch.abs(gt_pitch - pred_pitch).sum() + torch.abs(gt_hcam - pred_hcam).sum()
            total = total + loss3
        return total


class Laneline_loss_gflat_3D(nn.Module):
    """
    Compute the loss between predicted lanelines and ground-truth laneline in anchor representation.
    The anchor representation is in flat ground space X', Y' and real 3D Z. Visibility estimation is also included.
    The X' Y' and Z estimation will be transformed to real X, Y to compare with ground truth. An additional loss in
    X, Y space is expected to guide the learning of features to satisfy the geometry constraints between two spaces
    loss = loss0 + loss1 + loss2 + loss3 (+ loss4 when pred_cam is set)
    loss0: cross entropy loss for lane point visibility, divided by the number of y steps
    loss1: cross entropy loss for lane type classification
    loss2: L1 distance between anchor values in the flat-ground representation (X' offsets and Z)
    loss3: L1 distance between the X, Y offsets obtained after transforming with Z and the camera height
    loss4: error in estimating pitch and camera heights
    """
    def __init__(self, batch_size, num_types, anchor_x_steps, anchor_y_steps, x_off_std, y_off_std, z_std, pred_cam=False, no_cuda=False):
        """
        :param batch_size: batch size used to pre-expand the constant tensors; forward() accepts any batch size
        :param num_types: number of lane types stored per anchor column
        :param anchor_x_steps: numpy array with one x position per anchor column
        :param anchor_y_steps: numpy array with the K sampled y positions
        :param x_off_std: numpy array of length K used to normalise the anchor x positions
        :param y_off_std: numpy array of length K used to normalise the anchor y positions
        :param z_std: numpy array of length K that scales the Z values back in forward()
        :param pred_cam: when True, forward() adds the camera pitch/height error to the loss
        :param no_cuda: when False, the constant tensors used in forward() are moved to the GPU
        """
        super(Laneline_loss_gflat_3D, self).__init__()
        self.batch_size = batch_size
        self.num_types = num_types
        self.num_x_steps = anchor_x_steps.shape[0]
        self.num_y_steps = anchor_y_steps.shape[0]
        self.anchor_dim = 3*self.num_y_steps + 1
        self.pred_cam = pred_cam

        # prepare broadcast anchor_x_tensor, anchor_y_tensor, std_X, std_Y, std_Z
        # adding tmp_zeros expands each constant to batch_size x num_x_steps x num_types x num_y_steps
        tmp_zeros = torch.zeros(self.batch_size, self.num_x_steps, self.num_types, self.num_y_steps)
        self.x_off_std = torch.tensor(x_off_std.astype(np.float32)).reshape(1, 1, 1, self.num_y_steps) + tmp_zeros
        self.y_off_std = torch.tensor(y_off_std.astype(np.float32)).reshape(1, 1, 1, self.num_y_steps) + tmp_zeros
        self.z_std = torch.tensor(z_std.astype(np.float32)).reshape(1, 1, 1, self.num_y_steps) + tmp_zeros
        self.anchor_x_tensor = torch.tensor(anchor_x_steps.astype(np.float32)).reshape(1, self.num_x_steps, 1, 1) + tmp_zeros
        self.anchor_y_tensor = torch.tensor(anchor_y_steps.astype(np.float32)).reshape(1, 1, 1, self.num_y_steps) + tmp_zeros
        # anchor positions are stored in normalised units
        self.anchor_x_tensor = self.anchor_x_tensor/self.x_off_std
        self.anchor_y_tensor = self.anchor_y_tensor/self.y_off_std

        if not no_cuda:
            # only the tensors read by forward() are moved
            self.z_std = self.z_std.cuda()
            self.anchor_x_tensor = self.anchor_x_tensor.cuda()
            self.anchor_y_tensor = self.anchor_y_tensor.cuda()

    def forward(self, pred_3D_lanes, gt_3D_lanes, pred_hcam, gt_hcam, pred_pitch, gt_pitch):
        """
        :param pred_3D_lanes: predicted tensor, reshaped here to N x num_x_steps x num_types x (3*K+1)
        :param gt_3D_lanes: ground-truth tensor with the same layout as pred_3D_lanes
        :param pred_hcam: predicted camera height with N elements
        :param gt_hcam: ground-truth camera height with N elements
        :param pred_pitch: predicted pitch
        :param gt_pitch: ground-truth pitch
        :return: scalar tensor; the camera terms are only included when self.pred_cam is set

        N is taken from the input, so it does not have to equal the batch_size given to the constructor.
        """
        sizes = pred_3D_lanes.shape
        num_batch = sizes[0]
        # reshape to N x ipm_w/8 x 3 x (3K+1)
        pred_3D_lanes = pred_3D_lanes.reshape(num_batch, sizes[1], self.num_types, self.anchor_dim)
        gt_3D_lanes = gt_3D_lanes.reshape(num_batch, sizes[1], self.num_types, self.anchor_dim)
        # class prob N x ipm_w/8 x 3 x 1, anchor values N x ipm_w/8 x 3 x 2K, visibility N x ipm_w/8 x 3 x K
        pred_class = pred_3D_lanes[:, :, :, -1].unsqueeze(-1)
        pred_anchors = pred_3D_lanes[:, :, :, :2*self.num_y_steps]
        pred_visibility = pred_3D_lanes[:, :, :, 2*self.num_y_steps:3*self.num_y_steps]
        gt_class = gt_3D_lanes[:, :, :, -1].unsqueeze(-1)
        gt_anchors = gt_3D_lanes[:, :, :, :2*self.num_y_steps]
        gt_visibility = gt_3D_lanes[:, :, :, 2*self.num_y_steps:3*self.num_y_steps]

        # cross-entropy loss for visibility
        loss0 = -torch.sum(
            gt_visibility*torch.log(pred_visibility + torch.tensor(1e-9)) +
            (torch.ones_like(gt_visibility) - gt_visibility + torch.tensor(1e-9)) *
            torch.log(torch.ones_like(pred_visibility) - pred_visibility + torch.tensor(1e-9)))/self.num_y_steps
        # cross-entropy loss for lane probability
        loss1 = -torch.sum(
            gt_class*torch.log(pred_class + torch.tensor(1e-9)) +
            (torch.ones_like(gt_class) - gt_class) *
            torch.log(torch.ones_like(pred_class) - pred_class + torch.tensor(1e-9)))
        # applying L1 norm does not need to separate X and Z
        point_weight = gt_class*torch.cat((gt_visibility, gt_visibility), 3)
        loss2 = torch.sum(torch.norm(point_weight*(pred_anchors-gt_anchors), p=1, dim=3))

        # compute loss in real 3D X, Y space, the transformation considers offset to anchor and normalization by std
        pred_Xoff_g = pred_anchors[:, :, :, :self.num_y_steps]
        pred_Z = pred_anchors[:, :, :, self.num_y_steps:2*self.num_y_steps]
        gt_Xoff_g = gt_anchors[:, :, :, :self.num_y_steps]
        gt_Z = gt_anchors[:, :, :, self.num_y_steps:2*self.num_y_steps]
        pred_hcam = pred_hcam.reshape(num_batch, 1, 1, 1)
        gt_hcam = gt_hcam.reshape(num_batch, 1, 1, 1)

        # The constant tensors are identical along the batch axis (they were expanded by adding zeros),
        # so a single batch slice broadcasts against any N, e.g. a smaller last batch.
        z_std = self.z_std[:1]
        anchor_x = self.anchor_x_tensor[:1]
        anchor_y = self.anchor_y_tensor[:1]

        pred_Xoff = (1 - pred_Z * z_std / pred_hcam) * pred_Xoff_g - pred_Z * z_std / pred_hcam * anchor_x
        pred_Yoff = -pred_Z * z_std / pred_hcam * anchor_y
        gt_Xoff = (1 - gt_Z * z_std / gt_hcam) * gt_Xoff_g - gt_Z * z_std / gt_hcam * anchor_x
        gt_Yoff = -gt_Z * z_std / gt_hcam * anchor_y
        loss3 = torch.sum(
            torch.norm(
                point_weight *
                (torch.cat((pred_Xoff, pred_Yoff), 3) - torch.cat((gt_Xoff, gt_Yoff), 3)), p=1, dim=3))

        if not self.pred_cam:
            return loss0+loss1+loss2+loss3
        loss4 = torch.sum(torch.abs(gt_pitch-pred_pitch)) + torch.sum(torch.abs(gt_hcam-pred_hcam))
        return loss0+loss1+loss2+loss3+loss4


# unit test
'''
if __name__ == '__main__':
    num_types = 3

    # for Laneline_loss_3D
    print('Test Laneline_loss_3D')
    anchor_dim = 2*6 + 1
    pred_cam = True
    criterion = Laneline_loss_3D(num_types, anchor_dim, pred_cam)
    criterion = criterion.cuda()

    pred_3D_lanes = torch.rand(8, 26, num_types*anchor_dim).cuda()
    gt_3D_lanes = torch.rand(8, 26, num_types*anchor_dim).cuda()
    pred_pitch = torch.ones(8).float().cuda()
    gt_pitch = torch.ones(8).float().cuda()
    pred_hcam = torch.ones(8).float().cuda()
    gt_hcam = torch.ones(8).float().cuda()

    loss = criterion(pred_3D_lanes, gt_3D_lanes, pred_pitch, gt_pitch, pred_hcam, gt_hcam)
    print(loss)

    # for Laneline_loss_gflat
    print('Test Laneline_loss_gflat')
    num_y_steps = 6
    anchor_dim = 3*num_y_steps + 1
    pred_cam = True
    criterion = Laneline_loss_gflat(num_types, num_y_steps, pred_cam)
    criterion = criterion.cuda()

    pred_3D_lanes = torch.rand(8, 26, num_types*anchor_dim).cuda()
    gt_3D_lanes = torch.rand(8, 26, num_types*anchor_dim).cuda()
    pred_pitch = torch.ones(8).float().cuda()
    gt_pitch = torch.ones(8).float().cuda()
    pred_hcam = torch.ones(8).float().cuda()
    gt_hcam = torch.ones(8).float().cuda()

    loss = criterion(pred_3D_lanes, gt_3D_lanes, pred_pitch, gt_pitch, pred_hcam, gt_hcam)

    print(loss)

    # for Laneline_loss_gflat_3D
    print('Test Laneline_loss_gflat_3D')
    batch_size = 8
    anchor_x_steps = np.linspace(-10, 10, 26, endpoint=True)
    anchor_y_steps = np.array([3, 5, 10, 20, 30, 40, 50, 60, 80, 100])
    num_y_steps = anchor_y_steps.shape[0]
    x_off_std = np.ones(num_y_steps)
    y_off_std = np.ones(num_y_steps)
    z_std = np.ones(num_y_steps)
    pred_cam = True
    criterion = Laneline_loss_gflat_3D(batch_size, num_types, anchor_x_steps, anchor_y_steps, x_off_std, y_off_std, z_std, pred_cam, no_cuda=False)
    # criterion = criterion.cuda()

    anchor_dim = 3*num_y_steps + 1
    pred_3D_lanes = torch.rand(batch_size, 26, num_types*anchor_dim).cuda()
    gt_3D_lanes = torch.rand(batch_size, 26, num_types*anchor_dim).cuda()
    pred_pitch = torch.ones(batch_size).float().cuda()
    gt_pitch = torch.ones(batch_size).float().cuda()
    pred_hcam = torch.ones(batch_size).float().cuda()*1.5
    gt_hcam = torch.ones(batch_size).float().cuda()*1.5

    loss = criterion(pred_3D_lanes, gt_3D_lanes, pred_pitch, gt_pitch, pred_hcam, gt_hcam)

    print(loss)
    '''

"\nif __name__ == '__main__':\n    num_types = 3\n\n    # for Laneline_loss_3D\n    print('Test Laneline_loss_3D')\n    anchor_dim = 2*6 + 1\n    pred_cam = True\n    criterion = Laneline_loss_3D(num_types, anchor_dim, pred_cam)\n    criterion = criterion.cuda()\n\n    pred_3D_lanes = torch.rand(8, 26, num_types*anchor_dim).cuda()\n    gt_3D_lanes = torch.rand(8, 26, num_types*anchor_dim).cuda()\n    pred_pitch = torch.ones(8).float().cuda()\n    gt_pitch = torch.ones(8).float().cuda()\n    pred_hcam = torch.ones(8).float().cuda()\n    gt_hcam = torch.ones(8).float().cuda()\n\n    loss = criterion(pred_3D_lanes, gt_3D_lanes, pred_pitch, gt_pitch, pred_hcam, gt_hcam)\n    print(loss)\n\n    # for Laneline_loss_gflat\n    print('Test Laneline_loss_gflat')\n    num_y_steps = 6\n    anchor_dim = 3*num_y_steps + 1\n    pred_cam = True\n    criterion = Laneline_loss_gflat(num_types, num_y_steps, pred_cam)\n    criterion = criterion.cuda()\n\n    pred_3D_lanes = torch.rand(8, 26, num_types*an

In [9]:
"""
Pytorch_Generalized_3D_Lane_Detection/networks/GeoNet3D_ext.py /

3D-GeoNet with new anchor: predict 3D lanes from segmentation input. The geometry-guided anchor design is based on:
    "Gen-laneNet: a generalized and scalable approach for 3D lane detection"
New Anchor:
    1. Prediction head's lane representation is in X_g, Y_g in flat ground space and Z in real 3D ground space.
    Y_g is sampled equally, X_g, Z is regressed from network output.
    2. In addition, visibility of each point is added into the anchor representation and regressed from network.
Overall dimension of the output tensor would be: N * W * 3 *(3 * K + 1), where
    K          : number of y samples.
    (3 * K + 1): Each lane includes K attributes for X_g offset + K attributes for Z + K attributes for visibility + 1 lane probability
    3          : Each anchor column includes one laneline and two centerlines --> 3
    W          : Number of columns for the output tensor each corresponds to a IPM X_g location
    N          : batch size
This network must be used together with its corresponding data loader and loss criterion.
Author: Yuliang Guo (33yuliangguo@gmail.com)
Date: March, 2020
"""

import numpy as np
import torch
import torch.optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
#from tools.utils import define_args, define_init_weights, homography_im2ipm_norm, homography_crop_resize, homography_ipmnorm2g, tusimple_config, sim3d_config


def make_layers(cfg, in_channels=3, batch_norm=False):
    """
    Build a VGG-style stack from a layer specification.

    :param cfg: sequence whose entries are either 'M' (2x2 max pooling with stride 2) or an output
                channel count for a 3x3 convolution with padding 1 followed by ReLU
    :param in_channels: number of channels fed to the first convolution
    :param batch_norm: when True, a BatchNorm2d is inserted between each convolution and its ReLU
    :return: nn.Sequential holding the layers in cfg order
    """
    modules = []
    channels = in_channels
    for spec in cfg:
        if spec == 'M':
            # pooling leaves the channel count untouched
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        channels = spec
    return nn.Sequential(*modules)


def make_one_layer(in_channels, out_channels, kernel_size=3, padding=1, stride=1, batch_norm=False):
    """
    Build one convolution unit as a plain list so callers can concatenate several of them.

    :param in_channels: input channels of the convolution
    :param out_channels: output channels of the convolution
    :param kernel_size: forwarded to nn.Conv2d
    :param padding: forwarded to nn.Conv2d
    :param stride: forwarded to nn.Conv2d
    :param batch_norm: when True, a BatchNorm2d is placed between the convolution and the ReLU
    :return: list [Conv2d, (BatchNorm2d,) ReLU]
    """
    unit = [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, stride=stride)]
    if batch_norm:
        unit.append(nn.BatchNorm2d(out_channels))
    unit.append(nn.ReLU(inplace=True))
    return unit


# base_grid is built from the requested size_ipm, so a generator can be created for any IPM size
class ProjectiveGridGenerator(nn.Module):
    """
    Turns a batch of 3x3 transformation matrices into sampling grids for F.grid_sample.

    A fixed grid of homogeneous points (x, y, 1) with x and y in [0, 1) is stored at construction;
    forward() maps it through the given matrices and rescales the result to the (-1, 1) convention.
    """
    def __init__(self, size_ipm, M, no_cuda):
        """
        :param size_ipm: (N, H, W) of the grid to generate
        :param M: normalized transformation matrix between image view and IPM; only its type is used here
        :param no_cuda: when False, the stored grid is moved to the GPU
        """
        super().__init__()
        self.N, self.H, self.W = size_ipm
        # cell coordinates 0, 1/W, ..., (W-1)/W and 0, 1/H, ..., (H-1)/H
        xs = torch.linspace(0, 1 - 1/self.W, self.W)
        ys = torch.linspace(0, 1 - 1/self.H, self.H)

        # use M only to decide the type not value
        points = M.new(self.N, self.H, self.W, 3)
        points[:, :, :, 0] = xs.reshape(1, 1, self.W)   # x varies along the width axis
        points[:, :, :, 1] = ys.reshape(1, self.H, 1)   # y varies along the height axis
        points[:, :, :, 2] = 1                          # homogeneous coordinate

        self.base_grid = Variable(points)
        if not no_cuda:
            self.base_grid = self.base_grid.cuda()

    def forward(self, M):
        """
        Compute the grid mapping based on the input transformation matrix M.

        If base_grid is top-view, M should be the ipm-to-img homography transformation, and vice versa.

        :param M: batch of N 3x3 matrices
        :return: N x H x W x 2 grid of sampling locations for grid_sample, converted from the
                 normalized [0, 1] range to the (-1, 1) range that grid_sample expects
        """
        flat_points = self.base_grid.view(self.N, self.H * self.W, 3)
        mapped = torch.bmm(flat_points, M.transpose(1, 2))
        # perspective division by the homogeneous coordinate
        xy = torch.div(mapped[:, :, 0:2], mapped[:, :, 2:])
        grid = xy.reshape((self.N, self.H, self.W, 2))
        return (grid - 0.5) * 2


# Sub-network corresponding to the top view pathway
class TopViewPathway(nn.Module):
    """
    Three pooled convolution stages that merge an extra feature map after every stage.

    Channel bookkeeping that follows from the constructor: features1 maps 128 -> 128 channels,
    features2 expects 256 channels (so b must add 128) and outputs 256, features3 expects 512
    channels (so c must add 256) and outputs 256. Every stage starts with a 2x2 max pooling.
    """
    def __init__(self, batch_norm=False, init_weights=True):
        """
        :param batch_norm: forwarded to make_layers for all three stages
        :param init_weights: when True, the weights are initialised by _initialize_weights()
        """
        super(TopViewPathway, self).__init__()
        stage_cfg = ['M', 128, 128, 128]
        self.features1 = make_layers(stage_cfg, 128, batch_norm)
        stage_cfg = ['M', 256, 256, 256]
        self.features2 = make_layers(stage_cfg, 256, batch_norm)
        self.features3 = make_layers(stage_cfg, 512, batch_norm)

        if init_weights:
            self._initialize_weights()

    def forward(self, a, b, c, d):
        """
        :param a: input of the first stage
        :param b: map concatenated (channel axis) with the first stage output
        :param c: map concatenated (channel axis) with the second stage output
        :param d: map concatenated (channel axis) with the third stage output
        :return: (third stage output concatenated with d, feat_1, feat_2, feat_3)
        """
        feat_1 = self.features1(a)
        feat_2 = self.features2(torch.cat((feat_1, b), 1))
        feat_3 = self.features3(torch.cat((feat_2, c), 1))
        merged = torch.cat((feat_3, d), 1)
        return merged, feat_1, feat_2, feat_3

    def _initialize_weights(self):
        """Initialise convolutions with N(0, 0.02), linear layers with N(0, 0.01), and batch norm to identity."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight.data, 0.0, 0.02)
                if module.bias is not None:
                    module.bias.data.zero_()
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)


#  Lane Prediction Head: through a series of convolutions with no padding in the y dimension, the feature maps are
#  reduced in height, and finally the prediction layer size is N × 1 × 3 ·(3 · K + 1)
class LanePredictionHead(nn.Module):
    """
    Collapses the height of a 64-channel top-view feature map and regresses the lane anchors.

    The seven convolutions in self.features use no padding along the height, so the three 3x3
    layers remove 2 rows each and the four 5x5 layers remove 4 rows each. The remaining rows are
    folded into the channel axis, which self.dim_rt expects to hold 256 channels.
    """
    def __init__(self, num_lane_type, num_y_steps, batch_norm=False):
        """
        :param num_lane_type: number of lane types predicted per anchor column
        :param num_y_steps: number of y samples K per lane; one lane takes 3*K + 1 outputs
        :param batch_norm: forwarded to make_one_layer for every convolution unit
        """
        super(LanePredictionHead, self).__init__()
        self.num_lane_type = num_lane_type
        self.num_y_steps = num_y_steps
        self.anchor_dim = 3*self.num_y_steps + 1

        conv_units = []
        # (kernel size, width padding, number of repetitions); height padding is always 0
        for kernel, pad_w, repeats in ((3, 1, 3), (5, 2, 4)):
            for _ in range(repeats):
                conv_units += make_one_layer(64, 64, kernel_size=kernel, padding=(0, pad_w), batch_norm=batch_norm)
        self.features = nn.Sequential(*conv_units)

        # x suppose to be N X 64 X 4 X ipm_w/8, need to be reshaped to N X 256 X ipm_w/8 X 1
        # TODO: use large kernel_size in x or fc layer to estimate z with global parallelism
        reduce_units = make_one_layer(256, 128, kernel_size=(5, 1), padding=(2, 0), batch_norm=batch_norm)
        reduce_units.append(
            nn.Conv2d(128, self.num_lane_type*self.anchor_dim, kernel_size=(5, 1), padding=(2, 0)))
        self.dim_rt = nn.Sequential(*reduce_units)

    def forward(self, x):
        """
        :param x: feature map with 64 channels
        :return: tensor of size N x W x (num_lane_type * anchor_dim); within each lane the last
                 K + 1 entries (visibility and lane probability) have been passed through a sigmoid
        """
        feat = self.features(x)
        # fold the remaining rows into channels: N x C x H x W -> N x (C*H) x W x 1
        num_batch, channels, height, width = feat.shape
        feat = feat.reshape(num_batch, channels*height, width, 1)
        out = self.dim_rt(feat).squeeze(-1).transpose(1, 2)
        # apply sigmoid to the probability terms to make it in (0, 1)
        for lane in range(self.num_lane_type):
            start = lane*self.anchor_dim + 2*self.num_y_steps
            stop = (lane + 1)*self.anchor_dim
            out[:, :, start:stop] = torch.sigmoid(out[:, :, start:stop])
        return out


# The 3D-lanenet composed of image encoder, top view pathway, and lane prediction head
class Net(nn.Module):
    """
    3D-GeoNet: resamples the input map into the top view and regresses lane anchors from it.

    forward() warps the input with a projective grid built from self.M_inv, encodes the warped
    map with a small convolution stack and passes the features to LanePredictionHead.
    """
    def __init__(self, args, input_dim=1, debug=False):
        """
        :param args: configuration object. Attributes read here: no_cuda, pred_cam, batch_size,
                     no_centerline, num_y_steps, no_3d, org_h, org_w, resize_h, resize_w, pitch,
                     cam_height, crop_y, K (numpy 3x3), top_view_region, ipm_h, ipm_w, batch_norm
        :param input_dim: number of channels of the input map
        :param debug: when True, forward() additionally returns the warped input and the encoder features
        """
        super().__init__()

        self.no_cuda = args.no_cuda
        self.debug = debug
        self.pred_cam = args.pred_cam
        self.batch_size = args.batch_size
        if args.no_centerline:
            self.num_lane_type = 1
        else:
            self.num_lane_type = 3

        self.num_y_steps = args.num_y_steps
        if args.no_3d:
            self.anchor_dim = args.num_y_steps + 1
        else:
            self.anchor_dim = 3*args.num_y_steps + 1

        # define required transformation matrices
        # define homographic transformation between image and ipm
        org_img_size = np.array([args.org_h, args.org_w])
        resize_img_size = np.array([args.resize_h, args.resize_w])
        # args.pitch is converted from degrees to radians
        cam_pitch = np.pi / 180 * args.pitch

        self.cam_height = torch.tensor(args.cam_height).unsqueeze_(0).expand([self.batch_size, 1]).type(torch.FloatTensor)
        self.cam_pitch = torch.tensor(cam_pitch).unsqueeze_(0).expand([self.batch_size, 1]).type(torch.FloatTensor)
        self.cam_height_default = torch.tensor(args.cam_height).unsqueeze_(0).expand(self.batch_size).type(torch.FloatTensor)
        self.cam_pitch_default = torch.tensor(cam_pitch).unsqueeze_(0).expand(self.batch_size).type(torch.FloatTensor)

        # image scale matrix
        # (the builtin float replaces np.float, which was removed from NumPy)
        self.S_im = torch.from_numpy(np.array([[args.resize_w,              0, 0],
                                               [            0,  args.resize_h, 0],
                                               [            0,              0, 1]], dtype=np.float32))
        self.S_im_inv = torch.from_numpy(np.array([[1/float(args.resize_w),                      0, 0],
                                                   [                     0, 1/float(args.resize_h), 0],
                                                   [                     0,                      0, 1]], dtype=np.float32))
        # unsqueeze_ works in place, so self.S_im_inv itself is 1 x 3 x 3 from here on
        self.S_im_inv_batch = self.S_im_inv.unsqueeze_(0).expand([self.batch_size, 3, 3]).type(torch.FloatTensor)

        # image transform matrix
        H_c = homography_crop_resize(org_img_size, args.crop_y, resize_img_size)
        self.H_c = torch.from_numpy(H_c).unsqueeze_(0).expand([self.batch_size, 3, 3]).type(torch.FloatTensor)

        # camera intrinsic matrix
        self.K = torch.from_numpy(args.K).unsqueeze_(0).expand([self.batch_size, 3, 3]).type(torch.FloatTensor)

        # homograph ground to camera
        # sin(-pitch) and cos(-pitch) equal cos(pi/2 + pitch) and sin(pi/2 + pitch) respectively
        H_g2cam = np.array([[1,                             0,               0],
                            [0, np.sin(-cam_pitch), args.cam_height],
                            [0, np.cos(-cam_pitch),               0]])
        self.H_g2cam = torch.from_numpy(H_g2cam).unsqueeze_(0).expand([self.batch_size, 3, 3]).type(torch.FloatTensor)

        # transform from ipm normalized coordinates to ground coordinates
        H_ipmnorm2g = homography_ipmnorm2g(args.top_view_region)
        self.H_ipmnorm2g = torch.from_numpy(H_ipmnorm2g).unsqueeze_(0).expand([self.batch_size, 3, 3]).type(torch.FloatTensor)

        # compute the transformation from ipm norm coords to image norm coords
        M_ipm2im = torch.bmm(self.H_g2cam, self.H_ipmnorm2g)
        M_ipm2im = torch.bmm(self.K, M_ipm2im)
        M_ipm2im = torch.bmm(self.H_c, M_ipm2im)
        M_ipm2im = torch.bmm(self.S_im_inv_batch, M_ipm2im)
        # normalise every matrix so that its bottom-right entry is 1
        M_ipm2im = torch.div(M_ipm2im,  M_ipm2im[:, 2, 2].reshape([self.batch_size, 1, 1]).expand([self.batch_size, 3, 3]))
        self.M_inv = M_ipm2im

        if not self.no_cuda:
            self.M_inv = self.M_inv.cuda()
            self.S_im = self.S_im.cuda()
            self.S_im_inv = self.S_im_inv.cuda()
            self.S_im_inv_batch = self.S_im_inv_batch.cuda()
            self.H_c = self.H_c.cuda()
            self.K = self.K.cuda()
            self.H_g2cam = self.H_g2cam.cuda()
            self.H_ipmnorm2g = self.H_ipmnorm2g.cuda()
            self.cam_height_default = self.cam_height_default.cuda()
            self.cam_pitch_default = self.cam_pitch_default.cuda()

        # Define network
        # The layers must be created regardless of no_cuda, otherwise a CPU model has no layers
        # and forward() fails; they are built after the optional .cuda() moves so that the grid
        # generator sees M_inv on its final device.
        # the grid considers both src and dst grid normalized
        size_top = torch.Size([self.batch_size, int(args.ipm_h), int(args.ipm_w)])
        self.project_layer = ProjectiveGridGenerator(size_top, self.M_inv, args.no_cuda)

        # Conv layers to convert original resolution binary map to target resolution with high-dimension
        self.encoder = make_layers([8, 'M', 16, 'M', 32, 'M', 64], input_dim, batch_norm=args.batch_norm)

        self.lane_out = LanePredictionHead(self.num_lane_type, self.num_y_steps, args.batch_norm)

    def forward(self, input):
        """
        :param input: input map with input_dim channels; its batch size has to match the batch size
                      the projection grid was built for
        :return: (out, cam_height, cam_pitch), where cam_height and cam_pitch are the values currently
                 stored on the module; in debug mode the warped input and the encoder features are appended
        """
        cam_height = self.cam_height
        cam_pitch = self.cam_pitch

        # spatial transfer image features to IPM features
        grid = self.project_layer(self.M_inv)
        # NOTE(review): align_corners is left at the library default - confirm that this matches
        # the convention the pretrained weights were produced with
        x_proj = F.grid_sample(input, grid)

        # conv layers to convert original resolution binary map to target resolution with high-dimension
        x_feat = self.encoder(x_proj)

        # convert top-view features to anchor output
        out = self.lane_out(x_feat)

        if self.debug:
            return out, cam_height, cam_pitch, x_proj, x_feat

        return out, cam_height, cam_pitch

    def update_projection(self, args, cam_height, cam_pitch):
        """
            Update transformation matrix based on ground-truth cam_height and cam_pitch
            This function is "Mutually Exclusive" to the updates of M_inv from network prediction
        :param args: configuration object; top_view_region, org_h, org_w, crop_y, resize_h, resize_w and K are read
        :param cam_height: tensor indexed per sample, needs at least batch_size entries
        :param cam_pitch: tensor indexed per sample, needs at least batch_size entries
        :return: None; self.M_inv is overwritten row by row and the given tensors are stored on the module
        """
        for i in range(self.batch_size):
            M, M_inv = homography_im2ipm_norm(args.top_view_region, np.array([args.org_h, args.org_w]),
                                              args.crop_y, np.array([args.resize_h, args.resize_w]),
                                              cam_pitch[i].data.cpu().numpy(), cam_height[i].data.cpu().numpy(), args.K)
            self.M_inv[i] = torch.from_numpy(M_inv).type(torch.FloatTensor)
        self.cam_height = cam_height
        self.cam_pitch = cam_pitch

    def update_projection_for_data_aug(self, aug_mats):
        """
            update transformation matrix when data augmentation have been applied, and the image augmentation matrix are provided
            Need to consider both the cases of 1. when using ground-truth cam_height, cam_pitch, update M_inv
                                               2. when cam_height, cam_pitch are online estimated, update H_c for later use
        :param aug_mats: batch of 3x3 augmentation matrices; on the CPU path the caller's tensor is
                         modified in place, on the GPU path a copy on the GPU is modified
        """
        if not self.no_cuda:
            aug_mats = aug_mats.cuda()

        for i in range(aug_mats.shape[0]):
            # update H_c directly
            self.H_c[i] = torch.matmul(aug_mats[i], self.H_c[i])
            # augmentation need to be applied in unnormalized image coords for M_inv
            aug_mats[i] = torch.matmul(torch.matmul(self.S_im_inv, aug_mats[i]), self.S_im)
            # both updates multiply onto the stored matrices, so repeated calls accumulate
            self.M_inv[i] = torch.matmul(aug_mats[i], self.M_inv[i])


# unit test
'''
if __name__ == '__main__':
    import os
    from PIL import Image
    from torchvision import transforms
    import torchvision.transforms.functional as F2
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    global args
    parser = define_args()
    args = parser.parse_args()

    # args.dataset_name = 'tusimple'
    # tusimple_config(args)
    args.dataset_name = 'sim3d'
    sim3d_config(args)
    args.pred_cam = True
    args.batch_size = 1

    # construct model
    model = Net(args)
    print(model)

    # initialize model weights
    define_init_weights(model, args.weight_init)

    # load in vgg pretrained weights on ImageNet
    if args.pretrained:
        model.load_pretrained_vgg(args.batch_norm)
        print('vgg weights pretrained on ImageNet loaded!')
    model = model.cuda()

    # prepare input
    image = torch.randn(1, 1, args.resize_h, args.resize_w)
    image = image.cuda()

    # test update of camera height and pitch
    cam_height = torch.tensor(1.65).unsqueeze_(0).expand([args.batch_size, 1]).type(torch.FloatTensor)
    cam_pitch = torch.tensor(0.1).unsqueeze_(0).expand([args.batch_size, 1]).type(torch.FloatTensor)
    # model.update_projection(args, cam_height, cam_pitch)

    # inference the model
    output_net, pred_height, pred_pitch = model(image)

    print(output_net.shape)
    print(pred_height)
    print(pred_pitch)
    '''

'\nif __name__ == \'__main__\':\n    import os\n    from PIL import Image\n    from torchvision import transforms\n    import torchvision.transforms.functional as F2\n    os.environ["CUDA_VISIBLE_DEVICES"] = "0"\n\n    global args\n    parser = define_args()\n    args = parser.parse_args()\n\n    # args.dataset_name = \'tusimple\'\n    # tusimple_config(args)\n    args.dataset_name = \'sim3d\'\n    sim3d_config(args)\n    args.pred_cam = True\n    args.batch_size = 1\n\n    # construct model\n    model = Net(args)\n    print(model)\n\n    # initialize model weights\n    define_init_weights(model, args.weight_init)\n\n    # load in vgg pretrained weights on ImageNet\n    if args.pretrained:\n        model.load_pretrained_vgg(args.batch_norm)\n        print(\'vgg weights pretrained on ImageNet loaded!\')\n    model = model.cuda()\n\n    # prepare input\n    image = torch.randn(1, 1, args.resize_h, args.resize_w)\n    image = image.cuda()\n\n    # test update of camera height and pitch\n 

In [10]:
# ERFNET full network definition for Pytorch
# Sept 2017
# Eduardo Romera
#######################
#Pytorch_Generalized_3D_Lane_Detection/networks/erfnet.py /

"""
This code is modified from pytorch ERFNET implementation:
https://github.com/cardwing/Codes-for-Lane-Detection/tree/master/ERFNet-CULane-PyTorch
"""

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F


class DownsamplerBlock(nn.Module):
    """
    Halves the spatial resolution by running a strided convolution and a max pooling side by side.

    The convolution produces noutput - ninput channels and the pooling passes the ninput input
    channels through, so their concatenation has noutput channels.
    """
    def __init__(self, ninput, noutput):
        """
        :param ninput: number of input channels
        :param noutput: number of output channels after concatenation
        """
        super().__init__()

        conv_channels = noutput - ninput
        self.conv = nn.Conv2d(ninput, conv_channels, (3, 3), stride=2, padding=1, bias=True)
        self.pool = nn.MaxPool2d(2, stride=2)
        self.bn = nn.BatchNorm2d(noutput, eps=1e-3)

    def forward(self, input):
        """Concatenate both branches along the channel axis, then apply batch norm and ReLU."""
        branches = (self.conv(input), self.pool(input))
        merged = self.bn(torch.cat(branches, 1))
        return F.relu(merged)
    

class non_bottleneck_1d (nn.Module):
    """Residual block built from two factorised (3x1 then 1x3) convolutions.

    The first pair is undilated; the second pair uses dilation ``dilated``
    with matching padding so the spatial size is preserved.  Dropout with
    probability ``dropprob`` is applied before the residual sum unless the
    probability is zero.
    """

    def __init__(self, chann, dropprob, dilated):
        super().__init__()

        # first factorised 3x3 (no dilation)
        self.conv3x1_1 = nn.Conv2d(chann, chann, kernel_size=(3, 1), stride=1,
                                   padding=(1, 0), bias=True)
        self.conv1x3_1 = nn.Conv2d(chann, chann, kernel_size=(1, 3), stride=1,
                                   padding=(0, 1), bias=True)
        self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)

        # second factorised 3x3, dilated; padding equals the dilation rate
        self.conv3x1_2 = nn.Conv2d(chann, chann, kernel_size=(3, 1), stride=1,
                                   padding=(dilated, 0), bias=True,
                                   dilation=(dilated, 1))
        self.conv1x3_2 = nn.Conv2d(chann, chann, kernel_size=(1, 3), stride=1,
                                   padding=(0, dilated), bias=True,
                                   dilation=(1, dilated))
        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)

        self.dropout = nn.Dropout2d(dropprob)

    def forward(self, input):
        first = self.conv1x3_1(F.relu(self.conv3x1_1(input)))
        first = F.relu(self.bn1(first))

        second = self.conv1x3_2(F.relu(self.conv3x1_2(first)))
        second = self.bn2(second)

        if self.dropout.p != 0:
            second = self.dropout(second)

        # identity shortcut (residual connection)
        return F.relu(second + input)


class Encoder(nn.Module):
    """ERFNet encoder: three downsampling stages with residual blocks.

    Layout: 3->16 downsample, 16->64 downsample, five undilated residual
    blocks at 64 channels, 64->128 downsample, then two rounds of residual
    blocks with dilations 2, 4, 8, 16 at 128 channels.  ``output_conv`` is a
    1x1 classifier that is only applied when ``forward`` is called with
    ``predict=True``.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.initial_block = DownsamplerBlock(3, 16)

        stages = [DownsamplerBlock(16, 64)]
        stages.extend(non_bottleneck_1d(64, 0.1, 1) for _ in range(5))

        stages.append(DownsamplerBlock(64, 128))
        for _ in range(2):
            stages.extend(non_bottleneck_1d(128, 0.1, rate) for rate in (2, 4, 8, 16))

        self.layers = nn.ModuleList(stages)

        # only used in encoder-only ("predict") mode
        self.output_conv = nn.Conv2d(128, num_classes, kernel_size=1, stride=1,
                                     padding=0, bias=True)

    def forward(self, input, predict=False):
        features = self.initial_block(input)

        for stage in self.layers:
            features = stage(features)

        return self.output_conv(features) if predict else features


class UpsamplerBlock (nn.Module):
    """Double the spatial resolution: transposed 3x3 conv -> batch norm -> ReLU."""

    def __init__(self, ninput, noutput):
        super().__init__()
        # stride 2 with padding 1 / output_padding 1 gives exactly 2x the input size
        self.conv = nn.ConvTranspose2d(ninput, noutput, kernel_size=3, stride=2,
                                       padding=1, output_padding=1, bias=True)
        self.bn = nn.BatchNorm2d(noutput, eps=1e-3)

    def forward(self, input):
        upsampled = self.conv(input)
        return F.relu(self.bn(upsampled))

class Decoder (nn.Module):
    """ERFNet decoder: two upsample stages followed by a transposed-conv classifier.

    Each stage is an ``UpsamplerBlock`` (128->64, then 64->16) followed by two
    undilated, dropout-free residual blocks.  ``output_conv`` performs the final
    2x upsampling and maps 16 channels to ``num_classes``.
    """

    def __init__(self, num_classes):
        super().__init__()

        stages = []
        for in_ch, out_ch in ((128, 64), (64, 16)):
            stages.append(UpsamplerBlock(in_ch, out_ch))
            stages.append(non_bottleneck_1d(out_ch, 0, 1))
            stages.append(non_bottleneck_1d(out_ch, 0, 1))
        self.layers = nn.ModuleList(stages)

        self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=2, stride=2,
                                              padding=0, output_padding=0, bias=True)

    def forward(self, input):
        features = input
        for stage in self.layers:
            features = stage(features)
        return self.output_conv(features)

class Lane_exist (nn.Module):
    """Lane-existence head operating on the 128-channel encoder features.

    Pipeline: dilated 3x3 conv (128->32) -> batch norm -> ReLU -> dropout ->
    1x1 conv (32->5) -> channel softmax -> 2x2 max-pool -> flatten ->
    Linear(3965, 128) -> ReLU -> Linear(128, num_output) -> sigmoid.

    The first linear layer is fixed at 3965 input features, so the pooled
    5-channel map must flatten to exactly 3965 values per sample
    (e.g. 5 x 13 x 61); other feature-map sizes raise a ``RuntimeError``.

    Args:
        num_output: number of existence probabilities produced per sample.
    """

    def __init__(self, num_output):
        super().__init__()

        self.layers = nn.ModuleList()

        self.layers.append(nn.Conv2d(128, 32, (3, 3), stride=1, padding=(4,4), bias=False, dilation = (4,4)))
        self.layers.append(nn.BatchNorm2d(32, eps=1e-03))

        self.layers_final = nn.ModuleList()

        self.layers_final.append(nn.Dropout2d(0.1))
        self.layers_final.append(nn.Conv2d(32, 5, (1, 1), stride=1, padding=(0,0), bias=True))

        self.maxpool = nn.MaxPool2d(2, stride=2)
        self.linear1 = nn.Linear(3965, 128)
        # honour the constructor argument instead of a hard-coded 4
        self.linear2 = nn.Linear(128, num_output)

    def forward(self, input):
        """Return a ``(batch, num_output)`` tensor of sigmoid scores."""
        output = input

        for layer in self.layers:
            output = layer(output)

        output = F.relu(output)

        for layer in self.layers_final:
            output = layer(output)

        output = F.softmax(output, dim=1)
        output = self.maxpool(output)
        # Flatten per sample.  Keeping the batch axis explicit means a wrongly
        # sized feature map fails in linear1 instead of being silently
        # re-batched by view(-1, 3965).
        output = output.view(output.size(0), -1)
        output = self.linear1(output)
        output = F.relu(output)
        output = self.linear2(output)
        # torch.sigmoid replaces the deprecated F.sigmoid
        output = torch.sigmoid(output)

        return output

class ERFNet(nn.Module):
    """ERFNet segmentation network with an additional lane-existence head.

    Args:
        num_classes: number of segmentation classes produced by the decoder.
        partial_bn: when true, every ``BatchNorm2d`` is kept in eval mode and
            its affine parameters are frozen each time ``train()`` is called.
        encoder: optional pre-built encoder module; a fresh ``Encoder`` is
            created when omitted.
    """

    def __init__(self, num_classes, partial_bn=False, encoder=None):  #use encoder to pass pretrained encoder
        super().__init__()

        if encoder is None:
            self.encoder = Encoder(num_classes)
        else:
            self.encoder = encoder
        self.decoder = Decoder(num_classes)
        self.lane_exist = Lane_exist(4) # num_output
        self.input_mean = [103.939, 116.779, 123.68] # [0, 0, 0]
        self.input_std = [1, 1, 1]
        self._enable_pbn = partial_bn

        if partial_bn:
            self.partialBN(True)

    def train(self, mode=True):
        """
        Override the default train() to freeze the BN parameters
        :return: self, matching the ``nn.Module.train`` contract so that
                 ``model.train()`` / ``model.eval()`` can be chained or assigned
        """
        super(ERFNet, self).train(mode)
        if self._enable_pbn:
            print("Freezing BatchNorm2D.")
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    # shutdown update in frozen mode
                    m.weight.requires_grad = False
                    m.bias.requires_grad = False
        # nn.Module.eval() returns self.train(False); without this return both
        # train() and eval() would hand back None.
        return self

    def partialBN(self, enable):
        """Enable or disable BatchNorm freezing for subsequent ``train()`` calls."""
        self._enable_pbn = enable

    def get_optim_policies(self):
        """Group encoder/decoder parameters into optimizer parameter groups.

        Returns a list of six dicts (``params``, ``lr_mult``, ``decay_mult``,
        ``name``).  ``Conv2d`` weights and biases and ``BatchNorm2d`` parameters
        of the encoder and decoder go into the three "base" groups; the three
        "addtional" groups are always empty.  Only ``nn.Conv2d`` instances are
        matched, so modules of other types (for example ``ConvTranspose2d``)
        are not collected.
        """
        base_weight = []
        base_bias = []
        base_bn = []

        addtional_weight = []
        addtional_bias = []
        addtional_bn = []

        # encoder first, then decoder, so each group keeps the original ordering
        for sub_net in (self.encoder, self.decoder):
            for m in sub_net.modules():
                if isinstance(m, nn.Conv2d):
                    ps = list(m.parameters())
                    base_weight.append(ps[0])
                    if len(ps) == 2:
                        base_bias.append(ps[1])
                elif isinstance(m, nn.BatchNorm2d):
                    base_bn.extend(list(m.parameters()))

        return [
            {
                'params': addtional_weight,
                'lr_mult': 10,
                'decay_mult': 1,
                'name': "addtional weight"
            },
            {
                'params': addtional_bias,
                'lr_mult': 20,
                'decay_mult': 1,
                'name': "addtional bias"
            },
            {
                'params': addtional_bn,
                'lr_mult': 10,
                'decay_mult': 0,
                'name': "addtional BN scale/shift"
            },
            {
                'params': base_weight,
                'lr_mult': 1,
                'decay_mult': 1,
                'name': "base weight"
            },
            {
                'params': base_bias,
                'lr_mult': 2,
                'decay_mult': 0,
                'name': "base bias"
            },
            {
                'params': base_bn,
                'lr_mult': 1,
                'decay_mult': 0,
                'name': "base BN scale/shift"
            },
        ]

    def forward(self, input, only_encode=False, no_lane_exist=False):
        """Run encoder and decoder, optionally with the lane-existence head.

        Args:
            input: image batch fed to the encoder.
            only_encode: accepted for interface compatibility but currently
                ignored (the encoder-only path is disabled).
            no_lane_exist: when true, return only the decoder output.

        Returns:
            The decoder output, or a ``(decoder_output, lane_exist_output)``
            tuple when ``no_lane_exist`` is false.
        """
        output = self.encoder(input)    #predict=False by default
        if no_lane_exist:
            return self.decoder.forward(output)
        return self.decoder.forward(output), self.lane_exist(output)

In [11]:
! pip install tensorboardX

Collecting tensorboardX
[?25l  Downloading https://files.pythonhosted.org/packages/07/84/46421bd3e0e89a92682b1a38b40efc22dafb6d8e3d947e4ceefd4a5fabc7/tensorboardX-2.2-py2.py3-none-any.whl (120kB)
[K     |██▊                             | 10kB 18.9MB/s eta 0:00:01[K     |█████▍                          | 20kB 17.9MB/s eta 0:00:01[K     |████████▏                       | 30kB 14.9MB/s eta 0:00:01[K     |██████████▉                     | 40kB 13.6MB/s eta 0:00:01[K     |█████████████▋                  | 51kB 7.7MB/s eta 0:00:01[K     |████████████████▎               | 61kB 8.9MB/s eta 0:00:01[K     |███████████████████             | 71kB 8.5MB/s eta 0:00:01[K     |█████████████████████▊          | 81kB 9.2MB/s eta 0:00:01[K     |████████████████████████▌       | 92kB 8.7MB/s eta 0:00:01[K     |███████████████████████████▏    | 102kB 7.4MB/s eta 0:00:01[K     |██████████████████████████████  | 112kB 7.4MB/s eta 0:00:01[K     |████████████████████████████████| 122kB

In [12]:
## Training code -> requires 'pretrained/erfnet_model_sim3d.tar' to exist
"""
Pytorch_Generalized_3D_Lane_Detection/main_train_GenLaneNet_ext.py /

The training code for 'Gen-LaneNet' which is a two-stage framework composed of segmentation subnetwork (erfnet)
and 3D lane prediction subnetwork (3D-GeoNet). A new lane anchor is integrated in the 3D-GeoNet. The architecture and
new anchor design are based on:
    "Gen-LaneNet: a generalized and scalable approach for 3D lane detection", Y. Guo et al., arXiv 2020
The training of Gen-LaneNet is based on a pretrained ERFNet saved in ./pretrained folder. The training is on a
synthetic dataset for 3D lane detection proposed in the above paper.
Author: Yuliang Guo (33yuliangguo@gmail.com)
Date: March, 2020
"""

import numpy as np
import torch
import torch.optim
import torch.nn as nn
import glob
import time
import shutil
import torch.nn.functional as F
from tqdm import tqdm
from tensorboardX import SummaryWriter
#from dataloader.Load_Data_3DLane_ext import *
#from networks import Loss_crit, GeoNet3D_ext, erfnet
#from tools.utils import *
#from tools import eval_3D_lane


def load_my_state_dict(model, state_dict):  # custom function to load model when not all dict elements
    """Copy checkpoint tensors into ``model`` in place, tolerating partial matches.

    A leading ``'module.'`` (as added by ``nn.DataParallel``) is stripped from a
    checkpoint key when present; keys without the prefix are used unchanged.
    Checkpoint entries whose key does not exist in the model are skipped and
    reported instead of raising ``KeyError``.

    Args:
        model: module whose parameters/buffers are overwritten.
        state_dict: mapping from checkpoint key to tensor.

    Returns:
        The same ``model`` object.
    """
    own_state = model.state_dict()
    prefix = 'module.'
    skipped = []
    cnt = 0
    with torch.no_grad():
        for name, param in state_dict.items():
            # only strip the wrapper prefix when it is really there
            key = name[len(prefix):] if name.startswith(prefix) else name
            if key not in own_state:
                skipped.append(name)
                continue
            # state_dict() tensors share storage with the model, so an
            # in-place copy updates the model itself
            own_state[key].copy_(param)
            cnt += 1
    print('#reused param: {}'.format(cnt))
    if skipped:
        print('#skipped param (not in model): {}'.format(len(skipped)))
    return model


def train_net():
    """Train the second-stage 3D lane network on top of a frozen, pretrained ERFNet.

    All configuration is read from the module-level ``args`` and ``crit_string``
    globals, which must be set before calling.  The function:

    * builds training/validation ``LaneDataset`` loaders and passes the training
      set's std statistics to the validation set;
    * loads ``args.pretrained_feat_model`` into an ``ERFNet`` (``model1``) that is
      put in eval mode and is not handed to the optimizer;
    * trains ``Net`` (``model2``) for ``args.nepochs`` epochs, calling
      ``validate`` and ``save_checkpoint`` after every epoch;
    * when ``args.evaluate`` is set, loads the best checkpoint, runs
      ``validate`` once and returns.

    Side effects: rewrites ``args.save_path``, creates output directories,
    redirects ``sys.stdout`` to a ``Logger`` file, and sets the globals
    ``valid_set_labels`` and (unless ``args.no_tb``) ``writer``.

    Raises:
        Exception: if CUDA is requested (``args.no_cuda`` is false) but unavailable.
    """

    # Check GPU availability
    if not args.no_cuda and not torch.cuda.is_available():
        raise Exception("No gpu available for usage")
    torch.backends.cudnn.benchmark = args.cudnn

    # Define save path: results go into a sub-folder named after the model
    save_id = args.mod
    args.save_path = os.path.join(args.save_path, save_id)
    mkdir_if_missing(args.save_path)
    mkdir_if_missing(os.path.join(args.save_path, 'example/'))
    mkdir_if_missing(os.path.join(args.save_path, 'example/train'))
    mkdir_if_missing(os.path.join(args.save_path, 'example/valid'))

    # dataloader for training and validation set
    val_gt_file = ops.join(args.data_dir, 'test.json')
    train_dataset = LaneDataset(args.dataset_dir, ops.join(args.data_dir, 'train.json'), args, data_aug=True, save_std=True)
    train_dataset.normalize_lane_label()
    train_loader = get_loader(train_dataset, args)
    valid_dataset = LaneDataset(args.dataset_dir, val_gt_file, args)
    # assign std of valid dataset to be consistent with train dataset
    valid_dataset.set_x_off_std(train_dataset._x_off_std)
    if not args.no_3d:
        valid_dataset.set_z_std(train_dataset._z_std)
    valid_dataset.normalize_lane_label()
    valid_loader = get_loader(valid_dataset, args)

    # extract valid set labels for evaluation later
    global valid_set_labels
    # context manager so the label file handle is closed after reading
    with open(val_gt_file) as label_file:
        valid_set_labels = [json.loads(line) for line in label_file]

    # Define network
    model1 = ERFNet(args.num_class)
    model2 = Net(args, input_dim=args.num_class - 1)
    define_init_weights(model2, args.weight_init)

    if not args.no_cuda:
        # Load model on gpu before passing params to optimizer
        model1 = model1.cuda()
        model2 = model2.cuda()

    # load the pretrained ERFNet segmentation weights into model1
    checkpoint = torch.load(args.pretrained_feat_model)
    # args.start_epoch = checkpoint['epoch']
    model1 = load_my_state_dict(model1, checkpoint['state_dict'])
    model1.eval()  # do not back propagate to model1

    # Define optimizer and scheduler (only model2 is optimised)
    optimizer = define_optim(args.optimizer, model2.parameters(),
                             args.learning_rate, args.weight_decay)
    scheduler = define_scheduler(optimizer, args)

    # Define loss criteria
    if crit_string == 'loss_gflat_3D':
        criterion = Laneline_loss_gflat_3D(args.batch_size, train_dataset.num_types,
                                                     train_dataset.anchor_x_steps, train_dataset.anchor_y_steps,
                                                     train_dataset._x_off_std, train_dataset._y_off_std,
                                                     train_dataset._z_std, args.pred_cam, args.no_cuda)
    else:
        criterion = Laneline_loss_gflat(train_dataset.num_types, args.num_y_steps, args.pred_cam)

    if not args.no_cuda:
        criterion = criterion.cuda()

    # Logging setup
    best_epoch = 0
    lowest_loss = np.inf
    log_file_name = 'log_train_start_0.txt'

    # Tensorboard writer
    if not args.no_tb:
        global writer
        writer = SummaryWriter(os.path.join(args.save_path, 'Tensorboard/'))

    # initialize visual saver
    vs_saver = Visualizer(args)

    # Train, evaluate or resume
    args.resume = first_run(args.save_path)
    if args.resume and not args.test_mode and not args.evaluate:
        path = os.path.join(args.save_path, 'checkpoint_model_epoch_{}.pth.tar'.format(
            int(args.resume)))
        if os.path.isfile(path):
            log_file_name = 'log_train_start_{}.txt'.format(args.resume)
            # Redirect stdout
            sys.stdout = Logger(os.path.join(args.save_path, log_file_name))
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(path)
            args.start_epoch = checkpoint['epoch']
            lowest_loss = checkpoint['loss']
            best_epoch = checkpoint['best epoch']
            model2.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            log_file_name = 'log_train_start_0.txt'
            # Redirect stdout
            sys.stdout = Logger(os.path.join(args.save_path, log_file_name))
            print("=> no checkpoint found at '{}'".format(path))

    # Only evaluate
    elif args.evaluate:
        best_file_name = glob.glob(os.path.join(args.save_path, 'model_best*'))[0]
        if os.path.isfile(best_file_name):
            sys.stdout = Logger(os.path.join(args.save_path, 'Evaluate.txt'))
            print("=> loading checkpoint '{}'".format(best_file_name))
            checkpoint = torch.load(best_file_name)
            model2.load_state_dict(checkpoint['state_dict'])
        else:
            print("=> no checkpoint found at '{}'".format(best_file_name))
        mkdir_if_missing(os.path.join(args.save_path, 'example/val_vis'))
        losses_valid, eval_stats = validate(valid_loader, valid_dataset, model1, model2, criterion, vs_saver, val_gt_file)
        return

    # Start training from clean slate
    else:
        # Redirect stdout
        sys.stdout = Logger(os.path.join(args.save_path, log_file_name))

    # INIT MODEL
    print(40*"="+"\nArgs:{}\n".format(args)+40*"=")
    print("Init model: '{}'".format(args.mod))
    print("Number of parameters in model {} is {:.3f}M".format(
        args.mod, sum(tensor.numel() for tensor in model2.parameters())/1e6))

    # Start training and validation for nepochs
    for epoch in range(args.start_epoch, args.nepochs):
        print("\n => Start train set for EPOCH {}".format(epoch + 1))
        # Adjust learning rate
        # NOTE(review): the scheduler is stepped before this epoch's optimizer
        # steps; recent PyTorch versions warn about that order - confirm intended.
        if args.lr_policy is not None and args.lr_policy != 'plateau':
            scheduler.step()
            lr = optimizer.param_groups[0]['lr']
            print('lr is set to {}'.format(lr))

        # Define container objects to keep track of multiple losses/metrics
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()

        # Specify operation modules
        model2.train()

        # compute timing
        end = time.time()

        # Start training loop
        for i, (input, seg_maps, gt, idx, gt_hcam, gt_pitch, aug_mat) in tqdm(enumerate(train_loader)):

            # Time dataloader
            data_time.update(time.time() - end)

            # Put inputs on gpu if possible
            if not args.no_cuda:
                input, gt = input.cuda(non_blocking=True), gt.cuda(non_blocking=True)
                seg_maps = seg_maps.cuda(non_blocking=True)
                gt_hcam = gt_hcam.cuda()
                gt_pitch = gt_pitch.cuda()
            input = input.contiguous().float()

            if not args.fix_cam and not args.pred_cam:
                model2.update_projection(args, gt_hcam, gt_pitch)

            # update transformation for data augmentation (only for training)
            model2.update_projection_for_data_aug(aug_mat)

            # Run model
            optimizer.zero_grad()
            # Inference model
            try:
                # model1 is frozen: run the whole segmentation stage without
                # building an autograd graph
                with torch.no_grad():
                    output1 = model1(input, no_lane_exist=True)
                    output1 = output1.softmax(dim=1)
                    # rescale each channel by its per-sample spatial maximum
                    output1 = output1 / torch.max(torch.max(output1, dim=2, keepdim=True)[0], dim=3, keepdim=True)[0]
                # drop channel 0 before feeding the second stage
                output1 = output1[:, 1:, :, :]
                output_net, pred_hcam, pred_pitch = model2(output1)
            except RuntimeError as e:
                print("Batch with idx {} skipped due to inference error".format(idx.numpy()))
                print(e)
                continue

            # Compute losses on
            loss = criterion(output_net, gt, pred_hcam, gt_hcam, pred_pitch, gt_pitch)
            losses.update(loss.item(), input.size(0))

            # Setup backward pass
            loss.backward()

            # Clip gradients (usefull for instabilities or mistakes in ground truth).
            # Must run after backward() so there are gradients to clip, and
            # before the optimizer consumes them.
            if args.clip_grad_norm != 0:
                nn.utils.clip_grad_norm_(model2.parameters(), args.clip_grad_norm)

            optimizer.step()

            # Time trainig iteration
            batch_time.update(time.time() - end)
            end = time.time()

            pred_pitch = pred_pitch.data.cpu().numpy().flatten()
            pred_hcam = pred_hcam.data.cpu().numpy().flatten()
            aug_mat = aug_mat.data.cpu().numpy()
            output_net = output_net.data.cpu().numpy()
            gt = gt.data.cpu().numpy()

            # unormalize lane outputs
            num_el = input.size(0)
            for j in range(num_el):
                unormalize_lane_anchor(output_net[j], train_dataset)
                unormalize_lane_anchor(gt[j], train_dataset)

            # Print info
            if (i + 1) % args.print_freq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.8f} ({loss.avg:.8f})'.format(
                       epoch+1, i+1, len(train_loader), batch_time=batch_time,
                       data_time=data_time, loss=losses))

            # Plot curves in two views
            if (i + 1) % args.save_freq == 0:
                vs_saver.save_result_new(train_dataset, 'train', epoch, i, idx,
                                         input, gt, output_net, pred_pitch, pred_hcam, aug_mat)

        losses_valid, eval_stats = validate(valid_loader, valid_dataset, model1, model2, criterion, vs_saver, val_gt_file, epoch)

        print("===> Average {}-loss on training set is {:.8f}".format(crit_string, losses.avg))
        print("===> Average {}-loss on validation set is {:.8f}".format(crit_string, losses_valid))
        print("===> Evaluation laneline F-measure: {:3f}".format(eval_stats[0]))
        print("===> Evaluation laneline Recall: {:3f}".format(eval_stats[1]))
        print("===> Evaluation laneline Precision: {:3f}".format(eval_stats[2]))
        print("===> Evaluation centerline F-measure: {:3f}".format(eval_stats[7]))
        print("===> Evaluation centerline Recall: {:3f}".format(eval_stats[8]))
        print("===> Evaluation centerline Precision: {:3f}".format(eval_stats[9]))

        print("===> Last best {}-loss was {:.8f} in epoch {}".format(crit_string, lowest_loss, best_epoch))

        if not args.no_tb:
            writer.add_scalars('3D-Lane-Loss', {'Training': losses.avg}, epoch)
            writer.add_scalars('3D-Lane-Loss', {'Validation': losses_valid}, epoch)
            writer.add_scalars('Evaluation', {'laneline F-measure': eval_stats[0]}, epoch)
            writer.add_scalars('Evaluation', {'centerline F-measure': eval_stats[7]}, epoch)
        # model selection is driven by the average training loss
        total_score = losses.avg

        # Adjust learning_rate if loss plateaued
        if args.lr_policy == 'plateau':
            scheduler.step(total_score)
            lr = optimizer.param_groups[0]['lr']
            print('LR plateaued, hence is set to {}'.format(lr))

        # File to keep latest epoch
        with open(os.path.join(args.save_path, 'first_run.txt'), 'w') as f:
            f.write(str(epoch))
        # Save model
        to_save = False
        if total_score < lowest_loss:
            to_save = True
            best_epoch = epoch+1
            lowest_loss = total_score
        save_checkpoint({
            'epoch': epoch + 1,
            'best epoch': best_epoch,
            'arch': args.mod,
            'state_dict': model2.state_dict(),
            'loss': lowest_loss,
            'optimizer': optimizer.state_dict()}, to_save, epoch)
    if not args.no_tb:
        writer.close()


def validate(loader, dataset, model1, model2, criterion, vs_saver, val_gt_file, epoch=0):
    """Run one gradient-free pass over ``loader``, write predictions and score them.

    For every batch the segmentation network ``model1`` and the lane network
    ``model2`` are run, the loss is accumulated, and one JSON line per sample
    is written to ``<args.save_path>/test_pred_file.json``.  After the file is
    closed, ``evaluator.bench_one_submit`` compares it against ``val_gt_file``.
    Batches that raise ``RuntimeError`` during inference are reported and skipped.

    Relies on the module-level ``args``, ``valid_set_labels``, ``evaluator`` and
    ``crit_string``.  The entries of ``valid_set_labels`` are updated in place
    with the predicted lanes.

    Returns:
        ``(average_loss, eval_stats)`` where ``eval_stats`` is whatever
        ``evaluator.bench_one_submit`` returned.
    """

    # running average of the loss over all evaluated samples
    loss_meter = AverageMeter()
    lane_pred_file = ops.join(args.save_path, 'test_pred_file.json')

    model2.eval()

    # forward passes only, so no gradients are needed
    with torch.no_grad():
        with open(lane_pred_file, 'w') as jsonFile:
            for i, (input, seg_maps, gt, idx, gt_hcam, gt_pitch) in tqdm(enumerate(loader)):
                if not args.no_cuda:
                    input = input.cuda(non_blocking=True)
                    gt = gt.cuda(non_blocking=True)
                    seg_maps = seg_maps.cuda(non_blocking=True)
                    gt_hcam = gt_hcam.cuda()
                    gt_pitch = gt_pitch.cuda()
                input = input.contiguous().float()

                # use ground-truth camera pose unless it is fixed or predicted
                if not (args.fix_cam or args.pred_cam):
                    model2.update_projection(args, gt_hcam, gt_pitch)

                try:
                    seg_prob = model1(input, no_lane_exist=True).softmax(dim=1)
                    # per-sample, per-channel spatial maximum
                    row_max = torch.max(seg_prob, dim=2, keepdim=True)[0]
                    peak = torch.max(row_max, dim=3, keepdim=True)[0]
                    seg_prob = (seg_prob / peak)[:, 1:, :, :]
                    output_net, pred_hcam, pred_pitch = model2(seg_prob)
                except RuntimeError as e:
                    print("Batch with idx {} skipped due to inference error".format(idx.numpy()))
                    print(e)
                    continue

                num_el = input.size(0)
                loss = criterion(output_net, gt, pred_hcam, gt_hcam, pred_pitch, gt_pitch)
                loss_meter.update(loss.item(), num_el)

                # move results to numpy for un-normalisation and export
                pred_pitch = pred_pitch.data.cpu().numpy().flatten()
                pred_hcam = pred_hcam.data.cpu().numpy().flatten()
                output_net = output_net.data.cpu().numpy()
                gt = gt.data.cpu().numpy()

                for sample in range(num_el):
                    unormalize_lane_anchor(output_net[sample], dataset)
                    unormalize_lane_anchor(gt[sample], dataset)

                step = i + 1
                if step % args.print_freq == 0:
                    print('Test: [{0}/{1}]\t'
                          'Loss {loss.val:.8f} ({loss.avg:.8f})'.format(
                           step, len(loader), loss=loss_meter))

                # plot curves in two views
                if step % args.save_freq == 0 or args.evaluate:
                    vs_saver.save_result_new(dataset, 'valid', epoch, i, idx,
                                             input, gt, output_net, pred_pitch, pred_hcam, evaluate=args.evaluate)

                # one JSON line per sample, in the ground-truth record layout
                for sample in range(num_el):
                    im_id = idx[sample]
                    # matrices are not used below; the call and 4-way unpack are kept as-is
                    H_g2im, P_g2im, H_crop, H_im2ipm = dataset.transform_mats(im_id)
                    json_line = valid_set_labels[im_id]
                    lanes = compute_3d_lanes_all_prob(output_net[sample], dataset.anchor_dim,
                                                      dataset.anchor_x_steps, args.anchor_y_steps,
                                                      pred_hcam[sample])
                    lanelines_pred, centerlines_pred, lanelines_prob, centerlines_prob = lanes
                    json_line["laneLines"] = lanelines_pred
                    json_line["centerLines"] = centerlines_pred
                    json_line["laneLines_prob"] = lanelines_prob
                    json_line["centerLines_prob"] = centerlines_prob
                    json.dump(json_line, jsonFile)
                    jsonFile.write('\n')
        # prediction file is closed (flushed) here, so it can be read back
        eval_stats = evaluator.bench_one_submit(lane_pred_file, val_gt_file)

        if args.evaluate:
            print("===> Average {}-loss on validation set is {:.8}".format(crit_string, loss_meter.avg))
            report = ("===> Evaluation on validation set: \n"
                      "laneline F-measure {:.8} \n"
                      "laneline Recall  {:.8} \n"
                      "laneline Precision  {:.8} \n"
                      "laneline x error (close)  {:.8} m\n"
                      "laneline x error (far)  {:.8} m\n"
                      "laneline z error (close)  {:.8} m\n"
                      "laneline z error (far)  {:.8} m\n\n"
                      "centerline F-measure {:.8} \n"
                      "centerline Recall  {:.8} \n"
                      "centerline Precision  {:.8} \n"
                      "centerline x error (close)  {:.8} m\n"
                      "centerline x error (far)  {:.8} m\n"
                      "centerline z error (close)  {:.8} m\n"
                      "centerline z error (far)  {:.8} m\n")
            first_fourteen = [eval_stats[k] for k in range(14)]
            print(report.format(*first_fourteen))

        return loss_meter.avg, eval_stats


def save_checkpoint(state, to_copy, epoch):
    """Write the checkpoint for ``epoch`` and prune older files.

    ``state`` is saved to ``checkpoint_model_epoch_<epoch>.pth.tar`` inside
    ``args.save_path``.  When ``to_copy`` is true the file is also copied to
    ``model_best_epoch_<epoch>.pth.tar``; for ``epoch > 0`` the first existing
    ``model_best*`` file is removed beforehand.  For ``epoch > 0`` the previous
    epoch's checkpoint is deleted if it exists.
    """
    save_dir = args.save_path
    ckpt_pattern = 'checkpoint_model_epoch_{}.pth.tar'

    filepath = os.path.join(save_dir, ckpt_pattern.format(epoch))
    torch.save(state, filepath)

    past_first_epoch = epoch > 0

    if to_copy:
        if past_first_epoch:
            old_best = glob.glob(os.path.join(save_dir, 'model_best*'))
            if old_best:
                os.remove(old_best[0])
        best_path = os.path.join(save_dir, 'model_best_epoch_{}.pth.tar'.format(epoch))
        shutil.copyfile(filepath, best_path)
        print("Best model copied")

    if not past_first_epoch:
        return

    # keep only the newest per-epoch checkpoint on disk
    prev_checkpoint_filename = os.path.join(save_dir, ckpt_pattern.format(epoch - 1))
    if os.path.exists(prev_checkpoint_filename):
        os.remove(prev_checkpoint_filename)

'''
if __name__ == '__main__':
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    global args
    args = define_args()
    #args = parser.parse_args()

    # dataset_name: 'standard' / 'rare_subset' / 'illus_chg'
    args.dataset_name = 'illus_chg'
    args.dataset_dir = '/content/drive/Shareddrives/colab/Apollo_Sim_3D_Lane_Release'
    #args.data_dir = ops.join('data_splits', args.dataset_name)
    #args.save_path = ops.join('data_splits', args.dataset_name)
    args.data_dir='/content/drive/Shareddrives/colab/data_splits/illus_chg'
    args.save_path='/content/drive/Shareddrives/colab/data_splits/illus_chg'
    # load configuration for certain dataset
    global evaluator
    sim3d_config(args)
    # define evaluator
    #evaluator = eval_3D_lane.LaneEval(args)
    evaluator=LaneEval(args)
    args.prob_th = 0.5

    # define the network model
    args.num_class = 2  # 1 background + n lane labels
    args.pretrained_feat_model = '/content/drive/Shareddrives/colab/erfnet_model_sim3d.tar'
    args.mod = 'Gen_LaneNet_ext'
    args.y_ref = 5  # new anchor prefer closer range gt assign
    global crit_string
    crit_string = 'loss_gflat'

    # for the case only running evaluation
    args.evaluate = False

    # settings for save and visualize
    args.print_freq = 50
    args.save_freq = 50

    # run the training
    train_net()
    '''

'\nif __name__ == \'__main__\':\n    os.environ["CUDA_VISIBLE_DEVICES"] = "0"\n\n    global args\n    args = define_args()\n    #args = parser.parse_args()\n\n    # dataset_name: \'standard\' / \'rare_subset\' / \'illus_chg\'\n    args.dataset_name = \'illus_chg\'\n    args.dataset_dir = \'/content/drive/Shareddrives/colab/Apollo_Sim_3D_Lane_Release\'\n    #args.data_dir = ops.join(\'data_splits\', args.dataset_name)\n    #args.save_path = ops.join(\'data_splits\', args.dataset_name)\n    args.data_dir=\'/content/drive/Shareddrives/colab/data_splits/illus_chg\'\n    args.save_path=\'/content/drive/Shareddrives/colab/data_splits/illus_chg\'\n    # load configuration for certain dataset\n    global evaluator\n    sim3d_config(args)\n    # define evaluator\n    #evaluator = eval_3D_lane.LaneEval(args)\n    evaluator=LaneEval(args)\n    args.prob_th = 0.5\n\n    # define the network model\n    args.num_class = 2  # 1 background + n lane labels\n    args.pretrained_feat_model = \'/content/d

In [13]:
# __main__ is for batch testing
"""
Pytorch_Generalized_3D_Lane_Detection/main_test_GenLaneNet_ext.py 
Batch test code for Gen-LaneNet with new anchor extension. It predicts 3D lanes per image.
Author: Yuliang Guo (33yuliangguo@gmail.com)
Date: March, 2020
"""

import numpy as np
import torch
import torch.optim
import glob
from tqdm import tqdm
#from dataloader.Load_Data_3DLane_ext import *
#from networks import GeoNet3D_ext, erfnet
#from tools.utils import *
#from tools import eval_3D_lane


def load_my_state_dict(model, state_dict):  # custom function to load model when not all dict elements
    """Copy checkpoint tensors into ``model`` in place, tolerating unknown keys.

    Every checkpoint key is first looked up with its leading seven characters
    removed (presumably the ``module.`` prefix added when a model is saved
    from ``nn.DataParallel`` -- confirm against how the checkpoint was
    produced). If the stripped key is not part of ``model.state_dict()`` the
    unmodified key is tried. Entries that match neither way are skipped
    rather than raising ``KeyError``.

    Keys containing ``'output_conv'`` are copied like any other key, which is
    what the original code did with its ``continue`` commented out.

    :param model: module whose parameters/buffers are overwritten in place.
    :param state_dict: mapping from checkpoint parameter names to tensors.
    :return: the same ``model`` object that was passed in.
    """
    own_state = model.state_dict()
    skipped = []
    cnt = 0
    for name, param in state_dict.items():
        key = name[7:]
        if key not in own_state:
            if name not in own_state:
                # nothing in the model corresponds to this checkpoint entry
                skipped.append(name)
                continue
            key = name
        # state_dict() tensors share storage with the model, so copy_ updates it
        own_state[key].copy_(param)
        cnt += 1
    print('#reused param: {}'.format(cnt))
    if skipped:
        print('#skipped param (not in model): {}'.format(len(skipped)))
    return model


def deploy(args, loader, dataset, model_seg, model_geo, vs_saver, test_gt_file, vis=False, epoch=0):
    """Run batch inference, write predictions as JSON lines, then evaluate them.

    For every batch the segmentation network output is soft-maxed, normalized
    by its per-channel spatial maximum, stripped of the first (background)
    channel and fed to the geometry network. The predicted anchors are
    converted to lanes and appended to the prediction file, one JSON record
    per image, re-using the matching ground-truth record as a template.

    NOTE(review): this function reads the module-level names ``lane_pred_file``
    and ``evaluator`` (they are not parameters) and the helpers
    ``unormalize_lane_anchor`` / ``compute_3d_lanes_all_prob``; all of them
    must be defined before the call.

    :param args: configuration object; ``no_cuda``, ``vis_folder`` and
                 ``anchor_y_steps`` are read here.
    :param loader: iterable yielding ``(input, _, gt, idx, gt_hcam, gt_pitch)``.
    :param dataset: dataset providing ``transform_mats``, ``anchor_dim`` and
                    ``anchor_x_steps``.
    :param model_seg: segmentation network (not switched to eval mode here).
    :param model_geo: geometry network; put into eval mode by this function.
    :param vs_saver: visualizer used only when ``vis`` is true.
    :param test_gt_file: path of the ground-truth file, one JSON record per line.
    :param vis: whether to save visualizations for every batch.
    :param epoch: epoch number forwarded to the visualizer.
    :return: the list returned by ``evaluator.bench_one_submit``.
    """

    # model deploy mode
    model_geo.eval()

    # read ground-truth lanes for later evaluation (file is closed right away)
    with open(test_gt_file) as gt_file:
        test_set_labels = [json.loads(line) for line in gt_file]

    # Only forward pass, hence no gradients needed
    with torch.no_grad():
        with open(lane_pred_file, 'w') as jsonFile:
            # Start validation loop
            for i, (input, _, gt, idx, gt_hcam, gt_pitch) in tqdm(enumerate(loader)):
                if not args.no_cuda:
                    input, gt = input.cuda(non_blocking=True), gt.cuda(non_blocking=True)
                    input = input.float()
                input = input.contiguous()
                input = torch.autograd.Variable(input)

                # if not args.fix_cam and not args.pred_cam:
                # ATTENTION: here requires to update with test dataset args
                model_geo.update_projection(args, gt_hcam, gt_pitch)

                # Evaluate model
                try:
                    output_seg = model_seg(input, no_lane_exist=True)
                    # output1 = F.softmax(output1, dim=1)
                    output_seg = output_seg.softmax(dim=1)
                    # scale every channel so that its spatial maximum becomes 1
                    output_seg = output_seg / torch.max(torch.max(output_seg, dim=2, keepdim=True)[0], dim=3, keepdim=True)[0]
                    # drop channel 0 before passing the map to the geometry network
                    output_seg = output_seg[:, 1:, :, :]
                    output_geo, pred_hcam, pred_pitch = model_geo(output_seg)
                except RuntimeError as e:
                    print("Batch with idx {} skipped due to singular matrix".format(idx.numpy()))
                    print(e)
                    continue

                gt = gt.data.cpu().numpy()
                output_geo = output_geo.data.cpu().numpy()
                pred_pitch = pred_pitch.data.cpu().numpy().flatten()
                pred_hcam = pred_hcam.data.cpu().numpy().flatten()

                # unormalize lane outputs
                num_el = input.size(0)
                for j in range(num_el):
                    unormalize_lane_anchor(gt[j], dataset)
                    unormalize_lane_anchor(output_geo[j], dataset)

                if vis:
                    # Plot curves in two views
                    vs_saver.save_result_new(dataset, args.vis_folder, epoch, i, idx,
                                             input, gt, output_geo, pred_pitch, pred_hcam, evaluate=vis)

                # visualize and write results
                for j in range(num_el):
                    im_id = idx[j]
                    H_g2im, P_g2im, H_crop, H_im2ipm = dataset.transform_mats(idx[j])
                    # save results in test dataset format: the ground-truth record is
                    # reused and its lane fields are overwritten with the predictions
                    json_line = test_set_labels[im_id]
                    lane_anchors = output_geo[j]
                    # convert to json output format
                    lanelines_pred, centerlines_pred, lanelines_prob, centerlines_prob =\
                        compute_3d_lanes_all_prob(lane_anchors, dataset.anchor_dim,
                                                  dataset.anchor_x_steps, args.anchor_y_steps, pred_hcam[j])
                    json_line["laneLines"] = lanelines_pred
                    json_line["centerLines"] = centerlines_pred
                    json_line["laneLines_prob"] = lanelines_prob
                    json_line["centerLines_prob"] = centerlines_prob
                    json.dump(json_line, jsonFile)
                    jsonFile.write('\n')

        # evaluation at varying thresholds
        eval_stats_pr = evaluator.bench_one_submit_varying_probs(lane_pred_file, test_gt_file)
        max_f_prob = eval_stats_pr['max_F_prob_th']

        # evaluate at the point with max F-measure. Additional eval of position error.
        eval_stats = evaluator.bench_one_submit(lane_pred_file, test_gt_file, prob_th=max_f_prob)

        print("Metrics: AP, F-score, x error (close), x error (far), z error (close), z error (far)")
        print(
            "Laneline:  {:.3}  {:.3}  {:.3}  {:.3}  {:.3}  {:.3}".format(eval_stats_pr['laneline_AP'], eval_stats[0],
                                                                         eval_stats[3], eval_stats[4],
                                                                         eval_stats[5], eval_stats[6]))
        # bench_one_submit returns only the 7 laneline values when centerlines are
        # disabled; report centerlines only when their 7 values are present
        if len(eval_stats) > 13:
            print("Centerline:  {:.3}  {:.3}  {:.3}  {:.3}  {:.3}  {:.3}".format(eval_stats_pr['centerline_AP'], eval_stats[7],
                                                                                 eval_stats[10], eval_stats[11],
                                                                                 eval_stats[12], eval_stats[13]))

    return eval_stats

'''
if __name__ == '__main__':
  print("hi")
  
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    args = define_args()
    #args = parser.parse_args()

    # manual settings
    args.dataset_dir = '/content/drive/Shareddrives/colab/Apollo_Sim_3D_Lane_Release/'  # raw data dir
    args.dataset_name = 'illus_chg'  # choose a data split 'standard' / 'rare_subset' / 'illus_chg'
    args.mod = 'Gen_LaneNet_ext'  # model name
    test_name = 'test'  # test set name
    #pretrained_feat_model = 'pretrained/erfnet_model_sim3d.tar'
    pretrained_feat_model='/content/drive/Shareddrives/colab/erfnet_model_sim3d.tar'
    vis = False  # choose to save visualization result

    # generate relative paths
    args.data_dir = '/content/drive/Shareddrives/colab/data_splits/illus_chg'
   # args.save_path = os.path.join(ops.join('data_splits', args.dataset_name), args.mod)
    args.save_path ='/content/drive/Shareddrives/colab/data_splits/illus_chg/Gen_LaneNet_ext'
    args.vis_folder = test_name + '_vis'
    if vis:
        mkdir_if_missing(os.path.join(args.save_path, 'example/' + args.vis_folder))
    test_gt_file = ops.join(args.data_dir, test_name + '.json')
    lane_pred_file = ops.join(args.save_path, test_name + '_pred_file.json')

    # load configuration for certain dataset
    sim3d_config(args)
    args.y_ref = 5
    # define evaluator
    #evaluator = eval_3D_lane.LaneEval(args)
    evaluator = LaneEval(args)
    args.prob_th = 0.5

    # Check GPU availability
    if not args.no_cuda and not torch.cuda.is_available():
        raise Exception("No gpu available for usage")
    torch.backends.cudnn.benchmark = args.cudnn

    # Define network
    #model_seg = erfnet.ERFNet(2)  # 2-class model
    model_seg = ERFNet(2)
    #model_geo = GeoNet3D_ext.Net(args)
    model_geo = Net(args)
    define_init_weights(model_geo, args.weight_init)

    if not args.no_cuda:
        # Load model on gpu before passing params to optimizer
        model_seg = model_seg.cuda()
        model_geo = model_geo.cuda()

    # load segmentation model
    checkpoint = torch.load(pretrained_feat_model)
    model_seg = load_my_state_dict(model_seg, checkpoint['state_dict'])
    model_seg.eval()  # do not back propagate to model1

    # load geometry model
    best_test_name = glob.glob(os.path.join(args.save_path, 'model_best*'))[0]
    if os.path.isfile(best_test_name):
        sys.stdout = Logger(os.path.join(args.save_path, 'Evaluate.txt'))
        print("=> loading checkpoint '{}'".format(best_test_name))
        checkpoint = torch.load(best_test_name)
        model_geo.load_state_dict(checkpoint['state_dict'])
    else:
        print("=> no checkpoint found at '{}'".format(best_test_name))

    # Data loader
    test_dataset = LaneDataset(args.dataset_dir, test_gt_file, args)
    # assign std of valid dataset to be consistent with train dataset
    with open(ops.join(args.data_dir, 'geo_anchor_std.json')) as f:
        anchor_std = json.load(f)
    test_dataset.set_x_off_std(anchor_std['x_off_std'])
    if not args.no_3d:
        test_dataset.set_z_std(anchor_std['z_std'])
    test_dataset.normalize_lane_label()
    test_loader = get_loader(test_dataset, args)

    # initialize visual saver
    vs_saver = Visualizer(args, args.vis_folder)

    mkdir_if_missing(os.path.join(args.save_path, 'example/' + args.vis_folder))
    eval_stats = deploy(args, test_loader, test_dataset, model_seg, model_geo, vs_saver, test_gt_file, vis)
'''


'\nif __name__ == \'__main__\':\n  print("hi")\n  \n    os.environ["CUDA_VISIBLE_DEVICES"] = "0"\n\n    args = define_args()\n    #args = parser.parse_args()\n\n    # manual settings\n    args.dataset_dir = \'/content/drive/Shareddrives/colab/Apollo_Sim_3D_Lane_Release/\'  # raw data dir\n    args.dataset_name = \'illus_chg\'  # choose a data split \'standard\' / \'rare_subset\' / \'illus_chg\'\n    args.mod = \'Gen_LaneNet_ext\'  # model name\n    test_name = \'test\'  # test set name\n    #pretrained_feat_model = \'pretrained/erfnet_model_sim3d.tar\'\n    pretrained_feat_model=\'/content/drive/Shareddrives/colab/erfnet_model_sim3d.tar\'\n    vis = False  # choose to save visualization result\n\n    # generate relative paths\n    args.data_dir = \'/content/drive/Shareddrives/colab/data_splits/illus_chg\'\n   # args.save_path = os.path.join(ops.join(\'data_splits\', args.dataset_name), args.mod)\n    args.save_path =\'/content/drive/Shareddrives/colab/data_splits/illus_chg/Gen_LaneNet_

In [None]:
! whereis datalab

datalab:


In [14]:
#Pytorch_Generalized_3D_Lane_Detection/tools/eval_3D_lane.py /
"""
Description: This code evaluates 3D lane detection. The optimal matching between the ground-truth set and the
predicted set of lanes is sought by solving a min-cost flow problem.
Evaluation metrics include:
    Average Precision (AP)
    Max F-scores
    x error close (0 - 40 m)
    x error far (40 - 100 m)
    z error close (0 - 40 m)
    z error far (40 - 100 m)
Reference: "Gen-LaneNet: Generalized and Scalable Approach for 3D Lane Detection". Y. Guo et al. 2020
Author: Yuliang Guo (33yuliangguo@gmail.com)
Date: March, 2020
"""

import numpy as np
import cv2
import os
import os.path as ops
import copy
import math
import ujson as json
from scipy.interpolate import interp1d
import matplotlib
#from tools.utils import *
#from tools.MinCostFlow import SolveMinCostFlow
from mpl_toolkits.mplot3d import Axes3D

matplotlib.use('Agg')
import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = (35, 30)
plt.rcParams.update({'font.size': 25})
plt.rcParams.update({'font.weight': 'semibold'})

# Table of four colors as 3-element integer triples.
# NOTE(review): the color names in the trailing comments match the triples only
# when they are read in BGR (OpenCV) channel order -- confirm at the use site.
# LaneEval.bench assigns its own local `color`, so it does not read this table.
color = [[0, 0, 255],  # red
         [0, 255, 0],  # green
         [255, 0, 255],  # purple
         [255, 255, 0]]  # cyan

# y-axis limits applied to the 3D plots when evaluation results are visualized
vis_min_y = 5
vis_max_y = 80


class LaneEval(object):
    def __init__(self, args):
        """Cache the dataset, image and top-view settings used during evaluation.

        :param args: configuration object; reads dataset_dir, K, no_centerline,
                     org_h, org_w, crop_y, resize_h, resize_w and top_view_region.
        """
        self.dataset_dir = args.dataset_dir
        self.K = args.K
        self.no_centerline = args.no_centerline
        self.resize_h, self.resize_w = args.resize_h, args.resize_w

        # homography from the original image to the cropped and resized image
        original_size = [args.org_h, args.org_w]
        resized_size = [args.resize_h, args.resize_w]
        self.H_crop = homography_crop_resize(original_size, args.crop_y, resized_size)

        # x / y extents are picked from individual entries of the top-view region
        region = args.top_view_region
        self.x_min, self.x_max = region[0, 0], region[1, 0]
        self.y_min, self.y_max = region[2, 1], region[0, 1]
        # 100 evenly spaced y positions in [y_min, y_max) at which lanes are resampled
        self.y_samples = np.linspace(self.y_min, self.y_max, num=100, endpoint=False)

        # point-wise distance threshold used when matching lanes
        self.dist_th = 1.5
        # minimum ratio of matched points for a lane to count as matched
        self.ratio_th = 0.75
        # y value separating the "close" error statistics from the "far" ones
        self.close_range = 40

    def bench(self, pred_lanes, gt_lanes, gt_visibility, raw_file, gt_cam_height, gt_cam_pitch, vis, ax1, ax2):
        """
            Match predicted lanes to ground-truth lanes for one image and collect statistics.

            Both sets are resampled at self.y_samples. The cost of pairing a ground-truth lane with a
            predicted lane is the sum over the y samples of sqrt(x_dist^2 + z_dist^2); samples where either
            lane is not visible contribute self.dist_th. The pairing itself is solved by SolveMinCostFlow.
            The input of prediction and ground-truth lanes are in ground coordinate, x-right, y-forward, z-up
            The fundamental assumption is: 1. there are no two points from different lanes with identical x, y
                                              but different z's
                                           2. there are no two points from a single lane having identical x, y
                                              but different z's
            If the interest area is within the current drivable road, the above assumptions are almost always valid.

            Side effect: the entries of pred_lanes are replaced in place by their resampled (x, z) arrays.
        :param pred_lanes: N X 2 or N X 3 lists depending on 2D or 3D
        :param gt_lanes: N X 2 or N X 3 lists depending on 2D or 3D
        :param gt_visibility: per-lane visibility flags, one sequence per ground-truth lane
        :param raw_file: file path rooted in dataset folder
        :param gt_cam_height: camera height given in ground-truth data
        :param gt_cam_pitch: camera pitch given in ground-truth data
        :param vis: when true, draw the lanes on the image (ax1) and in 3D (ax2)
        :param ax1: matplotlib axes receiving the image; only used when vis is true
        :param ax2: matplotlib 3D axes receiving the lane curves; only used when vis is true
        :return: (r_lane, p_lane, cnt_gt, cnt_pred, x_error_close, x_error_far, z_error_close, z_error_far)
                 where r_lane / p_lane count the matches passing the recall / precision ratio test,
                 cnt_gt / cnt_pred are the numbers of lanes considered, and the four lists hold one
                 average error per accepted match.
        """

        num_y = self.y_samples.shape[0]

        # index of the first y sample beyond the close range; if no sample lies beyond it,
        # every sample is treated as "close" instead of failing with an IndexError
        far_indices = np.where(self.y_samples > self.close_range)[0]
        close_range_idx = far_indices[0] if far_indices.size > 0 else num_y

        r_lane, p_lane = 0., 0.
        x_error_close = []
        x_error_far = []
        z_error_close = []
        z_error_far = []

        # only keep the visible portion
        gt_lanes = [prune_3d_lane_by_visibility(np.array(gt_lane), np.array(gt_visibility[k])) for k, gt_lane in
                    enumerate(gt_lanes)]
        gt_lanes = [lane for lane in gt_lanes if lane.shape[0] > 1]
        # only consider those gt lanes overlapping with sampling range
        gt_lanes = [lane for lane in gt_lanes if lane[0, 1] < self.y_samples[-1] and lane[-1, 1] > self.y_samples[0]]
        gt_lanes = [prune_3d_lane_by_range(np.array(gt_lane), 3 * self.x_min, 3 * self.x_max) for gt_lane in gt_lanes]
        gt_lanes = [lane for lane in gt_lanes if lane.shape[0] > 1]
        cnt_gt = len(gt_lanes)
        cnt_pred = len(pred_lanes)

        gt_visibility_mat = np.zeros((cnt_gt, num_y))
        pred_visibility_mat = np.zeros((cnt_pred, num_y))
        # resample gt and pred at y_samples
        for i in range(cnt_gt):
            min_y = np.min(np.array(gt_lanes[i])[:, 1])
            max_y = np.max(np.array(gt_lanes[i])[:, 1])
            x_values, z_values, visibility_vec = resample_laneline_in_y(np.array(gt_lanes[i]), self.y_samples,
                                                                        out_vis=True)
            gt_lanes[i] = np.vstack([x_values, z_values]).T
            # a sample is visible only inside the x range and inside the lane's own y extent
            gt_visibility_mat[i, :] = np.logical_and(x_values >= self.x_min,
                                                     np.logical_and(x_values <= self.x_max,
                                                                    np.logical_and(self.y_samples >= min_y,
                                                                                   self.y_samples <= max_y)))
            gt_visibility_mat[i, :] = np.logical_and(gt_visibility_mat[i, :], visibility_vec)

        for i in range(cnt_pred):
            # # ATTENTION: ensure y mono increase before interpolation: but it can reduce size
            # pred_lanes[i] = make_lane_y_mono_inc(np.array(pred_lanes[i]))
            # pred_lane = prune_3d_lane_by_range(np.array(pred_lanes[i]), self.x_min, self.x_max)
            min_y = np.min(np.array(pred_lanes[i])[:, 1])
            max_y = np.max(np.array(pred_lanes[i])[:, 1])
            x_values, z_values, visibility_vec = resample_laneline_in_y(np.array(pred_lanes[i]), self.y_samples,
                                                                        out_vis=True)
            pred_lanes[i] = np.vstack([x_values, z_values]).T
            pred_visibility_mat[i, :] = np.logical_and(x_values >= self.x_min,
                                                       np.logical_and(x_values <= self.x_max,
                                                                      np.logical_and(self.y_samples >= min_y,
                                                                                     self.y_samples <= max_y)))
            pred_visibility_mat[i, :] = np.logical_and(pred_visibility_mat[i, :], visibility_vec)
            # pred_visibility_mat[i, :] = np.logical_and(x_values >= self.x_min, x_values <= self.x_max)

        # builtin int / float replace the np.int / np.float aliases removed in NumPy 1.24
        adj_mat = np.zeros((cnt_gt, cnt_pred), dtype=int)
        cost_mat = np.zeros((cnt_gt, cnt_pred), dtype=int)
        cost_mat.fill(1000)
        num_match_mat = np.zeros((cnt_gt, cnt_pred), dtype=float)
        x_dist_mat_close = np.zeros((cnt_gt, cnt_pred), dtype=float)
        x_dist_mat_close.fill(1000.)
        x_dist_mat_far = np.zeros((cnt_gt, cnt_pred), dtype=float)
        x_dist_mat_far.fill(1000.)
        z_dist_mat_close = np.zeros((cnt_gt, cnt_pred), dtype=float)
        z_dist_mat_close.fill(1000.)
        z_dist_mat_far = np.zeros((cnt_gt, cnt_pred), dtype=float)
        z_dist_mat_far.fill(1000.)
        # compute curve to curve distance
        for i in range(cnt_gt):
            for j in range(cnt_pred):
                x_dist = np.abs(gt_lanes[i][:, 0] - pred_lanes[j][:, 0])
                z_dist = np.abs(gt_lanes[i][:, 1] - pred_lanes[j][:, 1])
                euclidean_dist = np.sqrt(x_dist ** 2 + z_dist ** 2)

                # apply visibility to penalize different partial matching accordingly
                euclidean_dist[
                    np.logical_or(gt_visibility_mat[i, :] < 0.5, pred_visibility_mat[j, :] < 0.5)] = self.dist_th

                # if np.average(euclidean_dist) < 2*self.dist_th: # don't prune here to encourage finding perfect match
                num_match_mat[i, j] = np.sum(euclidean_dist < self.dist_th)
                adj_mat[i, j] = 1
                # ATTENTION: use the sum as int type to meet the requirements of min cost flow optimization (int type)
                # using num_match_mat as cost does not work?
                cost_mat[i, j] = np.sum(euclidean_dist).astype(int)
                # cost_mat[i, j] = num_match_mat[i, j]

                # use the both visible portion to calculate distance error
                both_visible_indices = np.logical_and(gt_visibility_mat[i, :] > 0.5, pred_visibility_mat[j, :] > 0.5)
                if np.sum(both_visible_indices[:close_range_idx]) > 0:
                    x_dist_mat_close[i, j] = np.sum(
                        x_dist[:close_range_idx] * both_visible_indices[:close_range_idx]) / np.sum(
                        both_visible_indices[:close_range_idx])
                    z_dist_mat_close[i, j] = np.sum(
                        z_dist[:close_range_idx] * both_visible_indices[:close_range_idx]) / np.sum(
                        both_visible_indices[:close_range_idx])
                else:
                    # no commonly visible sample in the close range: fall back to the threshold
                    x_dist_mat_close[i, j] = self.dist_th
                    z_dist_mat_close[i, j] = self.dist_th

                if np.sum(both_visible_indices[close_range_idx:]) > 0:
                    x_dist_mat_far[i, j] = np.sum(
                        x_dist[close_range_idx:] * both_visible_indices[close_range_idx:]) / np.sum(
                        both_visible_indices[close_range_idx:])
                    z_dist_mat_far[i, j] = np.sum(
                        z_dist[close_range_idx:] * both_visible_indices[close_range_idx:]) / np.sum(
                        both_visible_indices[close_range_idx:])
                else:
                    # no commonly visible sample in the far range: fall back to the threshold
                    x_dist_mat_far[i, j] = self.dist_th
                    z_dist_mat_far[i, j] = self.dist_th

        # solve bipartite matching vis min cost flow solver
        match_results = SolveMinCostFlow(adj_mat, cost_mat)
        match_results = np.array(match_results)

        # only a match with avg cost < self.dist_th is consider valid one
        match_gt_ids = []
        match_pred_ids = []
        if match_results.shape[0] > 0:
            for i in range(len(match_results)):
                if match_results[i, 2] < self.dist_th * self.y_samples.shape[0]:
                    gt_i = match_results[i, 0]
                    pred_i = match_results[i, 1]
                    # consider match when the matched points is above a ratio
                    if num_match_mat[gt_i, pred_i] / np.sum(gt_visibility_mat[gt_i, :]) >= self.ratio_th:
                        r_lane += 1
                        match_gt_ids.append(gt_i)
                    if num_match_mat[gt_i, pred_i] / np.sum(pred_visibility_mat[pred_i, :]) >= self.ratio_th:
                        p_lane += 1
                        match_pred_ids.append(pred_i)
                    x_error_close.append(x_dist_mat_close[gt_i, pred_i])
                    x_error_far.append(x_dist_mat_far[gt_i, pred_i])
                    z_error_close.append(z_dist_mat_close[gt_i, pred_i])
                    z_error_far.append(z_dist_mat_far[gt_i, pred_i])

        # visualize lanelines and matching results both in image and 3D
        if vis:
            P_g2im = projection_g2im(gt_cam_pitch, gt_cam_height, self.K)
            P_gt = np.matmul(self.H_crop, P_g2im)
            img = cv2.imread(ops.join(self.dataset_dir, raw_file))
            img = cv2.warpPerspective(img, self.H_crop, (self.resize_w, self.resize_h))
            img = img.astype(float) / 255

            for i in range(cnt_gt):
                x_values = gt_lanes[i][:, 0]
                z_values = gt_lanes[i][:, 1]
                x_2d, y_2d = projective_transformation(P_gt, x_values, self.y_samples, z_values)
                x_2d = x_2d.astype(int)
                y_2d = y_2d.astype(int)

                # matched and unmatched ground-truth lanes get different colors
                if i in match_gt_ids:
                    color = [0, 0, 1]
                else:
                    color = [0, 1, 1]
                for k in range(1, x_2d.shape[0]):
                    # only draw the visible portion
                    if gt_visibility_mat[i, k - 1] and gt_visibility_mat[i, k]:
                        # the color triple is reversed for the OpenCV drawing call
                        img = cv2.line(img, (x_2d[k - 1], y_2d[k - 1]), (x_2d[k], y_2d[k]), color[-1::-1], 3)
                ax2.plot(x_values[np.where(gt_visibility_mat[i, :])],
                         self.y_samples[np.where(gt_visibility_mat[i, :])],
                         z_values[np.where(gt_visibility_mat[i, :])], color=color, linewidth=5)

            for i in range(cnt_pred):
                x_values = pred_lanes[i][:, 0]
                z_values = pred_lanes[i][:, 1]
                x_2d, y_2d = projective_transformation(P_gt, x_values, self.y_samples, z_values)
                x_2d = x_2d.astype(int)
                y_2d = y_2d.astype(int)

                # matched and unmatched predicted lanes get different colors
                if i in match_pred_ids:
                    color = [1, 0, 0]
                else:
                    color = [1, 0, 1]
                for k in range(1, x_2d.shape[0]):
                    # only draw the visible portion
                    if pred_visibility_mat[i, k - 1] and pred_visibility_mat[i, k]:
                        img = cv2.line(img, (x_2d[k - 1], y_2d[k - 1]), (x_2d[k], y_2d[k]), color[-1::-1], 2)
                ax2.plot(x_values[np.where(pred_visibility_mat[i, :])],
                         self.y_samples[np.where(pred_visibility_mat[i, :])],
                         z_values[np.where(pred_visibility_mat[i, :])], color=color, linewidth=5)

            cv2.putText(img, 'Recall: {:.3f}'.format(r_lane / (cnt_gt + 1e-6)),
                        (5, 30), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.7, color=(0, 0, 1), thickness=2)
            cv2.putText(img, 'Precision: {:.3f}'.format(p_lane / (cnt_pred + 1e-6)),
                        (5, 60), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.7, color=(0, 0, 1), thickness=2)
            # swap the first and last channels before handing the image to matplotlib
            ax1.imshow(img[:, :, [2, 1, 0]])

        return r_lane, p_lane, cnt_gt, cnt_pred, x_error_close, x_error_far, z_error_close, z_error_far

    # compare predicted set and ground-truth set using a fixed lane probability threshold
    def bench_one_submit(self, pred_file, gt_file, prob_th=0.5, vis=False):
        if vis:
            save_path = pred_file[:pred_file.rfind('/')]
            save_path += '/vis'
            if vis and not os.path.exists(save_path):
                try:
                    os.makedirs(save_path)
                except OSError as e:
                    print(e.message)
        # try:
        pred_lines = open(pred_file).readlines()
        json_pred = [json.loads(line) for line in pred_lines]
        # except BaseException as e:
        #     raise Exception('Fail to load json file of the prediction.')
        json_gt = [json.loads(line) for line in open(gt_file).readlines()]
        if len(json_gt) != len(json_pred):
            raise Exception('We do not get the predictions of all the test tasks')
        gts = {l['raw_file']: l for l in json_gt}

        laneline_stats = []
        laneline_x_error_close = []
        laneline_x_error_far = []
        laneline_z_error_close = []
        laneline_z_error_far = []
        centerline_stats = []
        centerline_x_error_close = []
        centerline_x_error_far = []
        centerline_z_error_close = []
        centerline_z_error_far = []
        for i, pred in enumerate(json_pred):
            if 'raw_file' not in pred or 'laneLines' not in pred:
                raise Exception('raw_file or lanelines not in some predictions.')
            raw_file = pred['raw_file']

            # if raw_file != 'images/05/0000347.jpg':
            #     continue
            pred_lanelines = pred['laneLines']
            pred_laneLines_prob = pred['laneLines_prob']
            pred_lanelines = [pred_lanelines[ii] for ii in range(len(pred_laneLines_prob)) if
                              pred_laneLines_prob[ii] > prob_th]

            if raw_file not in gts:
                raise Exception('Some raw_file from your predictions do not exist in the test tasks.')
            gt = gts[raw_file]
            gt_cam_height = gt['cam_height']
            gt_cam_pitch = gt['cam_pitch']

            if vis:
                fig = plt.figure()
                ax1 = fig.add_subplot(221)
                ax2 = fig.add_subplot(222, projection='3d')
                ax3 = fig.add_subplot(223)
                ax4 = fig.add_subplot(224, projection='3d')
            else:
                ax1 = 0
                ax2 = 0
                ax3 = 0
                ax4 = 0

            # evaluate lanelines
            gt_lanelines = gt['laneLines']
            gt_visibility = gt['laneLines_visibility']
            # N to N matching of lanelines
            r_lane, p_lane, cnt_gt, cnt_pred, \
            x_error_close, x_error_far, \
            z_error_close, z_error_far = self.bench(pred_lanelines,
                                                    gt_lanelines,
                                                    gt_visibility,
                                                    raw_file,
                                                    gt_cam_height,
                                                    gt_cam_pitch,
                                                    vis, ax1, ax2)
            laneline_stats.append(np.array([r_lane, p_lane, cnt_gt, cnt_pred]))
            # consider x_error z_error only for the matched lanes
            # if r_lane > 0 and p_lane > 0:
            laneline_x_error_close.extend(x_error_close)
            laneline_x_error_far.extend(x_error_far)
            laneline_z_error_close.extend(z_error_close)
            laneline_z_error_far.extend(z_error_far)

            # evaluate centerlines
            if not self.no_centerline:
                pred_centerlines = pred['centerLines']
                pred_centerlines_prob = pred['centerLines_prob']
                pred_centerlines = [pred_centerlines[ii] for ii in range(len(pred_centerlines_prob)) if
                                    pred_centerlines_prob[ii] > prob_th]

                gt_centerlines = gt['centerLines']
                gt_visibility = gt['centerLines_visibility']

                # N to N matching of lanelines
                r_lane, p_lane, cnt_gt, cnt_pred, \
                x_error_close, x_error_far, \
                z_error_close, z_error_far = self.bench(pred_centerlines,
                                                        gt_centerlines,
                                                        gt_visibility,
                                                        raw_file,
                                                        gt_cam_height,
                                                        gt_cam_pitch,
                                                        vis, ax3, ax4)
                centerline_stats.append(np.array([r_lane, p_lane, cnt_gt, cnt_pred]))
                # consider x_error z_error only for the matched lanes
                # if r_lane > 0 and p_lane > 0:
                centerline_x_error_close.extend(x_error_close)
                centerline_x_error_far.extend(x_error_far)
                centerline_z_error_close.extend(z_error_close)
                centerline_z_error_far.extend(z_error_far)

            if vis:
                ax1.set_xticks([])
                ax1.set_yticks([])
                # ax2.set_xlabel('x axis')
                # ax2.set_ylabel('y axis')
                # ax2.set_zlabel('z axis')
                bottom, top = ax2.get_zlim()
                left, right = ax2.get_xlim()
                ax2.set_zlim(min(bottom, -0.1), max(top, 0.1))
                ax2.set_xlim(left, right)
                ax2.set_ylim(vis_min_y, vis_max_y)
                ax2.locator_params(nbins=5, axis='x')
                ax2.locator_params(nbins=5, axis='z')
                ax2.tick_params(pad=18)

                ax3.set_xticks([])
                ax3.set_yticks([])
                # ax4.set_xlabel('x axis')
                # ax4.set_ylabel('y axis')
                # ax4.set_zlabel('z axis')
                bottom, top = ax4.get_zlim()
                left, right = ax4.get_xlim()
                ax4.set_zlim(min(bottom, -0.1), max(top, 0.1))
                ax4.set_xlim(left, right)
                ax4.set_ylim(vis_min_y, vis_max_y)
                ax4.locator_params(nbins=5, axis='x')
                ax4.locator_params(nbins=5, axis='z')
                ax4.tick_params(pad=18)

                fig.subplots_adjust(wspace=0, hspace=0.01)
                fig.savefig(ops.join(save_path, raw_file.replace("/", "_")))
                plt.close(fig)
                print('processed sample: {}  {}'.format(i, raw_file))

        output_stats = []
        laneline_stats = np.array(laneline_stats)
        laneline_x_error_close = np.array(laneline_x_error_close)
        laneline_x_error_far = np.array(laneline_x_error_far)
        laneline_z_error_close = np.array(laneline_z_error_close)
        laneline_z_error_far = np.array(laneline_z_error_far)

        R_lane = np.sum(laneline_stats[:, 0]) / (np.sum(laneline_stats[:, 2]) + 1e-6)
        P_lane = np.sum(laneline_stats[:, 1]) / (np.sum(laneline_stats[:, 3]) + 1e-6)
        F_lane = 2 * R_lane * P_lane / (R_lane + P_lane + 1e-6)
        x_error_close_avg = np.average(laneline_x_error_close)
        x_error_far_avg = np.average(laneline_x_error_far)
        z_error_close_avg = np.average(laneline_z_error_close)
        z_error_far_avg = np.average(laneline_z_error_far)

        output_stats.append(F_lane)
        output_stats.append(R_lane)
        output_stats.append(P_lane)
        output_stats.append(x_error_close_avg)
        output_stats.append(x_error_far_avg)
        output_stats.append(z_error_close_avg)
        output_stats.append(z_error_far_avg)

        if not self.no_centerline:
            centerline_stats = np.array(centerline_stats)
            centerline_x_error_close = np.array(centerline_x_error_close)
            centerline_x_error_far = np.array(centerline_x_error_far)
            centerline_z_error_close = np.array(centerline_z_error_close)
            centerline_z_error_far = np.array(centerline_z_error_far)

            R_lane = np.sum(centerline_stats[:, 0]) / (np.sum(centerline_stats[:, 2]) + 1e-6)
            P_lane = np.sum(centerline_stats[:, 1]) / (np.sum(centerline_stats[:, 3]) + 1e-6)
            F_lane = 2 * R_lane * P_lane / (R_lane + P_lane + 1e-6)
            x_error_close_avg = np.average(centerline_x_error_close)
            x_error_far_avg = np.average(centerline_x_error_far)
            z_error_close_avg = np.average(centerline_z_error_close)
            z_error_far_avg = np.average(centerline_z_error_far)

            output_stats.append(F_lane)
            output_stats.append(R_lane)
            output_stats.append(P_lane)
            output_stats.append(x_error_close_avg)
            output_stats.append(x_error_far_avg)
            output_stats.append(z_error_close_avg)
            output_stats.append(z_error_far_avg)

        return output_stats

    def bench_PR(self, pred_lanes, gt_lanes, gt_visibility):
        """
            Matching predicted lanes and ground-truth lanes in their IPM projection, ignoring z attributes.
            The curve-to-curve cost is the summed point-wise distance in (x, z) after both lanes are
            resampled at self.y_samples; samples invisible in either lane are charged self.dist_th.
            The input of prediction and ground-truth lanes are in ground coordinate, x-right, y-forward, z-up
            The fundamental assumption is: 1. there are no two points from different lanes with identical x, y
                                              but different z's
                                           2. there are no two points from a single lane having identical x, y
                                              but different z's
            If the interest area is within the current drivable road, the above assumptions are almost always valid.
        :param pred_lanes: N X 2 or N X 3 lists depending on 2D or 3D.
                           NOTE: each entry is replaced IN PLACE by its resampled (x, z) array, so callers
                           that need the raw lanes afterwards must pass a copy.
        :param gt_lanes: N X 2 or N X 3 lists depending on 2D or 3D (the caller's list is not modified)
        :param gt_visibility: per-lane visibility flags, indexed in parallel with gt_lanes
        :return: (r_lane, p_lane, cnt_gt, cnt_pred): number of matches passing the recall ratio test,
                 number of matches passing the precision ratio test, number of gt lanes kept after pruning,
                 number of predicted lanes
        """

        r_lane, p_lane = 0., 0.
        # one visibility column per y sample (previously hard-coded to 100)
        num_samples = self.y_samples.shape[0]

        def resample_with_visibility(lane):
            # Resample one lane at self.y_samples; return its (x, z) samples and a boolean mask of the
            # samples that lie inside [x_min, x_max], inside the lane's own y extent, and are flagged
            # visible by the resampling helper.
            lane = np.array(lane)
            lane_min_y = np.min(lane[:, 1])
            lane_max_y = np.max(lane[:, 1])
            x_values, z_values, visibility_vec = resample_laneline_in_y(lane, self.y_samples, out_vis=True)
            within_range = np.logical_and(x_values >= self.x_min,
                                          np.logical_and(x_values <= self.x_max,
                                                         np.logical_and(self.y_samples >= lane_min_y,
                                                                        self.y_samples <= lane_max_y)))
            return np.vstack([x_values, z_values]).T, np.logical_and(within_range, visibility_vec)

        # only keep the visible portion
        gt_lanes = [prune_3d_lane_by_visibility(np.array(gt_lane), np.array(gt_visibility[k])) for k, gt_lane in
                    enumerate(gt_lanes)]
        gt_lanes = [lane for lane in gt_lanes if lane.shape[0] > 1]
        # only consider those gt lanes overlapping with sampling range
        gt_lanes = [lane for lane in gt_lanes if lane[0, 1] < self.y_samples[-1] and lane[-1, 1] > self.y_samples[0]]
        gt_lanes = [prune_3d_lane_by_range(np.array(gt_lane), 3 * self.x_min, 3 * self.x_max) for gt_lane in gt_lanes]
        gt_lanes = [lane for lane in gt_lanes if lane.shape[0] > 1]
        cnt_gt = len(gt_lanes)
        cnt_pred = len(pred_lanes)

        gt_visibility_mat = np.zeros((cnt_gt, num_samples))
        pred_visibility_mat = np.zeros((cnt_pred, num_samples))
        # resample gt and pred at y_samples
        for i in range(cnt_gt):
            gt_lanes[i], gt_visibility_mat[i, :] = resample_with_visibility(gt_lanes[i])
        for i in range(cnt_pred):
            pred_lanes[i], pred_visibility_mat[i, :] = resample_with_visibility(pred_lanes[i])

        # builtin int / float replace the np.int / np.float aliases, which were removed from NumPy
        adj_mat = np.zeros((cnt_gt, cnt_pred), dtype=int)
        cost_mat = np.full((cnt_gt, cnt_pred), 1000, dtype=int)
        num_match_mat = np.zeros((cnt_gt, cnt_pred), dtype=float)
        # compute curve to curve distance
        for i in range(cnt_gt):
            for j in range(cnt_pred):
                x_dist = np.abs(gt_lanes[i][:, 0] - pred_lanes[j][:, 0])
                z_dist = np.abs(gt_lanes[i][:, 1] - pred_lanes[j][:, 1])
                euclidean_dist = np.sqrt(x_dist ** 2 + z_dist ** 2)

                # apply visibility to penalize different partial matching accordingly
                euclidean_dist[
                    np.logical_or(gt_visibility_mat[i, :] < 0.5, pred_visibility_mat[j, :] < 0.5)] = self.dist_th

                # no pruning by average distance here, to encourage finding the perfect match
                num_match_mat[i, j] = np.sum(euclidean_dist < self.dist_th)
                adj_mat[i, j] = 1
                # ATTENTION: use the sum as int type to meet the requirements of min cost flow optimization (int type)
                cost_mat[i, j] = int(np.sum(euclidean_dist))

        # solve bipartite matching via min cost flow solver
        match_results = np.array(SolveMinCostFlow(adj_mat, cost_mat))

        # only a match with avg cost < self.dist_th is considered a valid one
        if match_results.shape[0] > 0:
            max_cost = self.dist_th * num_samples
            for i in range(len(match_results)):
                if match_results[i, 2] >= max_cost:
                    continue
                gt_i = match_results[i, 0]
                pred_i = match_results[i, 1]
                num_match = num_match_mat[gt_i, pred_i]
                # consider match when the matched points is above a ratio; a lane without any visible
                # sample can never qualify (and would otherwise trigger a 0/0 division warning)
                gt_visible = np.sum(gt_visibility_mat[gt_i, :])
                if gt_visible > 0 and num_match / gt_visible >= self.ratio_th:
                    r_lane += 1
                pred_visible = np.sum(pred_visibility_mat[pred_i, :])
                if pred_visible > 0 and num_match / pred_visible >= self.ratio_th:
                    p_lane += 1

        return r_lane, p_lane, cnt_gt, cnt_pred

    # evaluate two dataset at varying lane probability threshold to calculate AP
    def bench_one_submit_varying_probs(self, pred_file, gt_file, eval_out_file=None, eval_fig_file=None):
        """
            Evaluate a prediction file against a ground-truth file over 19 lane probability thresholds
            (0.05 ... 0.95) and summarize the resulting precision-recall curve.
            Both files hold one JSON record per line; records are paired by their 'raw_file' entry.
            Center lines are evaluated as well unless self.no_centerline is set.
        :param pred_file: path of the prediction file; each record needs 'raw_file', 'laneLines' and
                          'laneLines_prob' (plus 'centerLines' / 'centerLines_prob' for center lines)
        :param gt_file: path of the ground-truth file; each record needs 'raw_file', 'laneLines' and
                        'laneLines_visibility' (plus 'centerLines' / 'centerLines_visibility')
        :param eval_out_file: optional path; when given, the returned dict is written there as one JSON line
        :param eval_fig_file: optional path; when given, the PR curve(s) are plotted and saved there
        :return: dict with 'laneline_R', 'laneline_P', 'laneline_F_max', 'laneline_max_i', 'laneline_AP',
                 'max_F_prob_th' and, when center lines are evaluated, the matching 'centerline_*' entries.
                 The center-line F value is reported at the threshold index that maximizes the lane-line F.
        :raises Exception: if the two files differ in record count, a prediction lacks 'raw_file' or
                           'laneLines', or a predicted raw_file is missing from the ground truth
        """
        varying_th = np.linspace(0.05, 0.95, 19)

        def load_json_lines(path):
            # the context manager closes the handle as soon as the records are parsed
            with open(path) as json_lines:
                return [json.loads(line) for line in json_lines.readlines()]

        def sweep_thresholds(lanes, probs, gt_lanes, gt_visibility):
            # run bench_PR once per threshold on the lanes whose probability exceeds it
            r_vec, p_vec, cnt_gt_vec, cnt_pred_vec = [], [], [], []
            for prob_th in varying_th:
                kept_lanes = [lane for lane, prob in zip(lanes, probs) if prob > prob_th]
                # bench_PR overwrites its prediction list in place, so hand it a private copy
                r_lane, p_lane, cnt_gt, cnt_pred = self.bench_PR(copy.deepcopy(kept_lanes),
                                                                 gt_lanes,
                                                                 gt_visibility)
                r_vec.append(r_lane)
                p_vec.append(p_lane)
                cnt_gt_vec.append(cnt_gt)
                cnt_pred_vec.append(cnt_pred)
            return r_vec, p_vec, cnt_gt_vec, cnt_pred_vec

        def precision_recall(r_all, p_all, gt_cnt_all, pred_cnt_all):
            # accumulate over samples (axis 0) to get one F / recall / precision value per threshold
            recall = np.sum(np.array(r_all), axis=0) / (np.sum(np.array(gt_cnt_all), axis=0) + 1e-6)
            precision = np.sum(np.array(p_all), axis=0) / (np.sum(np.array(pred_cnt_all), axis=0) + 1e-6)
            f_measure = 2 * recall * precision / (recall + precision + 1e-6)
            return f_measure, recall, precision

        def close_pr_curve(recall, precision):
            # pad the curve with the end points (R=1, P=0) and (R=0, P=1), then average the
            # interpolated precision at 19 recall levels in [0.05, 0.95]
            recall_ext = np.array([1.] + recall.tolist() + [0.])
            precision_ext = np.array([0.] + precision.tolist() + [1.])
            r_range = np.linspace(0.05, 0.95, 19)
            avg_precision = np.mean(interp1d(recall_ext, precision_ext)(r_range))
            return recall_ext, precision_ext, avg_precision

        json_pred = load_json_lines(pred_file)
        json_gt = load_json_lines(gt_file)
        if len(json_gt) != len(json_pred):
            raise Exception('We do not get the predictions of all the test tasks')
        gts = {l['raw_file']: l for l in json_gt}

        eval_centerline = not self.no_centerline
        # per-sample lists of (r, p, cnt_gt, cnt_pred) vectors, one vector entry per threshold
        laneline_counts = ([], [], [], [])
        centerline_counts = ([], [], [], [])
        for i, pred in enumerate(json_pred):
            print('Evaluating sample {} / {}'.format(i, len(json_pred)))
            if 'raw_file' not in pred or 'laneLines' not in pred:
                raise Exception('raw_file or lanelines not in some predictions.')
            raw_file = pred['raw_file']

            pred_lanelines = pred['laneLines']
            pred_laneLines_prob = pred['laneLines_prob']
            if raw_file not in gts:
                raise Exception('Some raw_file from your predictions do not exist in the test tasks.')
            gt = gts[raw_file]

            # evaluate lanelines
            laneline_vecs = sweep_thresholds(pred_lanelines, pred_laneLines_prob,
                                             gt['laneLines'], gt['laneLines_visibility'])
            for collected, vec in zip(laneline_counts, laneline_vecs):
                collected.append(vec)

            # evaluate centerlines
            if eval_centerline:
                centerline_vecs = sweep_thresholds(pred['centerLines'], pred['centerLines_prob'],
                                                   gt['centerLines'], gt['centerLines_visibility'])
                for collected, vec in zip(centerline_counts, centerline_vecs):
                    collected.append(vec)

        # calculate metrics
        laneline_F, laneline_R, laneline_P = precision_recall(*laneline_counts)
        laneline_F_max = np.max(laneline_F)
        laneline_max_i = np.argmax(laneline_F)
        laneline_R, laneline_P, laneline_AP = close_pr_curve(laneline_R, laneline_P)

        if eval_centerline:
            centerline_F, centerline_R, centerline_P = precision_recall(*centerline_counts)
            # center lines are reported at the lane-line operating point, not at their own maximum
            centerline_F_max = centerline_F[laneline_max_i]
            centerline_max_i = laneline_max_i
            centerline_R, centerline_P, centerline_AP = close_pr_curve(centerline_R, centerline_P)

        if eval_fig_file is not None:
            # plot PR curve
            curves = [('Lane Line', laneline_R, laneline_P, laneline_F_max)]
            if eval_centerline:
                curves.append(('Center Line', centerline_R, centerline_P, centerline_F_max))
            fig = plt.figure()
            for idx, (title, recall, precision, f_max) in enumerate(curves):
                ax = fig.add_subplot(1, 2, idx + 1)
                ax.plot(recall, precision, '-s')
                ax.set_xlim(0, 1)
                ax.set_ylim(0, 1)
                ax.set_title(title)
                ax.set_xlabel('Recall')
                ax.set_ylabel('Precision')
                ax.set_aspect('equal')
                # labels must be a list: a bare string is split into one label per character
                ax.legend(['Max F-measure {:.3}'.format(f_max)])

            fig.savefig(eval_fig_file)
            plt.close(fig)

        json_out = {}
        # [1:-1] drops the two padded end points again
        json_out['laneline_R'] = laneline_R[1:-1].astype(np.float32).tolist()
        json_out['laneline_P'] = laneline_P[1:-1].astype(np.float32).tolist()
        json_out['laneline_F_max'] = laneline_F_max
        json_out['laneline_max_i'] = laneline_max_i.tolist()
        json_out['laneline_AP'] = laneline_AP

        if eval_centerline:
            json_out['centerline_R'] = centerline_R[1:-1].astype(np.float32).tolist()
            json_out['centerline_P'] = centerline_P[1:-1].astype(np.float32).tolist()
            json_out['centerline_F_max'] = centerline_F_max
            json_out['centerline_max_i'] = centerline_max_i.tolist()
            json_out['centerline_AP'] = centerline_AP

        json_out['max_F_prob_th'] = varying_th[laneline_max_i]

        if eval_out_file is not None:
            with open(eval_out_file, 'w') as jsonFile:
                jsonFile.write(json.dumps(json_out))
                jsonFile.write('\n')
        return json_out

'''
if __name__ == '__main__':
    import sys
    print("start")
    vis = True
    args = define_args()
    #args = parser.parse_args()

    # two method are compared: '3D_LaneNet' and 'Gen_LaneNet'
    method_name = 'Gen_LaneNet_ext'

    # Three different splits of datasets: 'standard', 'rare_subsit', 'illus_chg'
    data_split = 'illus_chg'

    # location where the original dataset is saved. Image will be loaded in case of visualization
    args.dataset_dir = '/content/drive/Shareddrives/colab/Apollo_Sim_3D_Lane_Release'

    # load configuration for certain dataset
    sim3d_config(args)

    # auto-file in dependent paths
    gt_file = '/content/drive/Shareddrives/colab/data_splits/' + data_split + '/test.json'
    pred_folder = '/content/drive/Shareddrives/colab/data_splits/' + data_split + '/' + method_name
    pred_file = pred_folder + '/test_pred_file.json'

    # Initialize evaluator
    evaluator = LaneEval(args)
    print("here")
    # evaluation at varying thresholds
    eval_stats_pr = evaluator.bench_one_submit_varying_probs(pred_file, gt_file)
    max_f_prob = eval_stats_pr['max_F_prob_th']
    print("done")
    # evaluate at the point with max F-measure. Additional eval of position error. Option to visualize matching result
    eval_stats = evaluator.bench_one_submit(pred_file, gt_file, prob_th=max_f_prob, vis=vis)

    print("Metrics: AP, F-score, x error (close), x error (far), z error (close), z error (far)")
    print(
        "Laneline:  {:.3}  {:.3}  {:.3}  {:.3}  {:.3}  {:.3}".format(eval_stats_pr['laneline_AP'], eval_stats[0],
                                                                     eval_stats[3], eval_stats[4],
                                                                     eval_stats[5], eval_stats[6]))
    print("Centerline:  {:.3}  {:.3}  {:.3}  {:.3}  {:.3}  {:.3}".format(eval_stats_pr['centerline_AP'], eval_stats[7],
                                                                         eval_stats[10], eval_stats[11],
                                                                         eval_stats[12], eval_stats[13]))
                                                                         '''

'\nif __name__ == \'__main__\':\n    import sys\n    print("start")\n    vis = True\n    args = define_args()\n    #args = parser.parse_args()\n\n    # two method are compared: \'3D_LaneNet\' and \'Gen_LaneNet\'\n    method_name = \'Gen_LaneNet_ext\'\n\n    # Three different splits of datasets: \'standard\', \'rare_subsit\', \'illus_chg\'\n    data_split = \'illus_chg\'\n\n    # location where the original dataset is saved. Image will be loaded in case of visualization\n    args.dataset_dir = \'/content/drive/Shareddrives/colab/Apollo_Sim_3D_Lane_Release\'\n\n    # load configuration for certain dataset\n    sim3d_config(args)\n\n    # auto-file in dependent paths\n    gt_file = \'/content/drive/Shareddrives/colab/data_splits/\' + data_split + \'/test.json\'\n    pred_folder = \'/content/drive/Shareddrives/colab/data_splits/\' + data_split + \'/\' + method_name\n    pred_file = pred_folder + \'/test_pred_file.json\'\n\n    # Initialize evaluator\n    evaluator = LaneEval(args)\n    pri

In [17]:
"""
Description: Visualization code to draw predicted lane-lines and center-lines in three views: image, virtual top,
             3D ego-car. Respectively, lane-lines are shown in the top row, center-lines are drawn in the bottom.
Author: Yuliang Guo (33yuliangguo@gmail.com)
Date: March, 2020
"""

import numpy as np
import cv2
import os
import os.path as ops
import math
import ujson as json
import matplotlib
#from tools.utils import *
matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (35, 30)
plt.rcParams.update({'font.size': 25})
plt.rcParams.update({'font.weight': 'semibold'})

# Range along y used to build the visualizer's y sample positions
# (see np.linspace(min_y, max_y, ...) in lane_visualizer.__init__).
min_y = 0
max_y = 80

# Palette cycled through per lane index. Entries are in RGB order: they are passed
# unchanged to matplotlib and reversed before being handed to cv2 drawing calls.
colors = [[1, 0, 0],  # red
          [0, 1, 0],  # green
          [0, 0, 1],  # blue
          [1, 0, 1],  # purple
          [0, 1, 1],  # cyan
          [1, 0.7, 0]]  # orange


class lane_visualizer(object):
    """
    Draw predicted lanes in three views: on the camera image, on the top-view (IPM) image and in a
    3D axis.
    """

    def __init__(self, args):
        """
        :param args: configuration object; the attributes read here are dataset_dir, K, no_centerline,
                     org_h, org_w, crop_y, ipm_h, ipm_w and top_view_region (indexed as a 2D array)
        """
        self.dataset_dir = args.dataset_dir
        self.K = args.K
        self.no_centerline = args.no_centerline
        # this visualizer uses a higher resolution than the network input for a better look:
        # the image keeps its original size and the top view is twice the configured IPM size
        self.resize_h = args.org_h
        self.resize_w = args.org_w
        self.ipm_w = 2*args.ipm_w
        self.ipm_h = 2*args.ipm_h
        self.H_crop = homography_crop_resize([args.org_h, args.org_w], args.crop_y, [self.resize_h, self.resize_w])
        # transformation from ipm to ground region
        self.H_ipm2g = cv2.getPerspectiveTransform(np.float32([[0, 0],
                                                              [self.ipm_w-1, 0],
                                                              [0, self.ipm_h-1],
                                                              [self.ipm_w-1, self.ipm_h-1]]),
                                                   np.float32(args.top_view_region))
        self.H_g2ipm = np.linalg.inv(self.H_ipm2g)

        self.x_min = args.top_view_region[0, 0]
        self.x_max = args.top_view_region[1, 0]
        self.y_samples = np.linspace(min_y, max_y, num=100, endpoint=False)

    def visualize_lanes(self, pred_lanes, raw_file, gt_cam_height, gt_cam_pitch, ax1, ax2, ax3):
        """
        Draw every lane of pred_lanes on the image (ax1), on the top view (ax2) and in 3D (ax3).
        Only the portion of a lane that lies inside [x_min, x_max], inside the lane's own y extent
        and is flagged visible by the resampling helper is drawn.
        :param pred_lanes: list of lanes, each convertible to an N X 3 array with y in column 1.
                           NOTE: each entry is replaced IN PLACE by its resampled (x, z) array.
        :param raw_file: image path relative to self.dataset_dir
        :param gt_cam_height: camera height used for the projections
        :param gt_cam_pitch: camera pitch used for the projections
        :param ax1: matplotlib axis receiving the image view
        :param ax2: matplotlib axis receiving the top view
        :param ax3: matplotlib 3D axis receiving the lanes as 3D curves
        :raises IOError: if the image cannot be read
        """
        P_g2im = projection_g2im(gt_cam_pitch, gt_cam_height, self.K)
        P_gt = np.matmul(self.H_crop, P_g2im)
        H_g2im = homograpthy_g2im(gt_cam_pitch, gt_cam_height, self.K)
        H_im2ipm = np.linalg.inv(np.matmul(self.H_crop, np.matmul(H_g2im, self.H_ipm2g)))

        image_path = ops.join(self.dataset_dir, raw_file)
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread signals an unreadable file by returning None rather than raising
            raise IOError('Fail to read image: {}'.format(image_path))
        img = cv2.warpPerspective(img, self.H_crop, (self.resize_w, self.resize_h))
        # np.float was an alias of the builtin float (float64) and has been removed from NumPy
        img = img.astype(np.float64) / 255
        im_ipm = cv2.warpPerspective(img, H_im2ipm, (self.ipm_w, self.ipm_h))
        im_ipm = np.clip(im_ipm, 0, 1)

        cnt_pred = len(pred_lanes)
        # one visibility column per y sample (previously hard-coded to 100)
        pred_visibility_mat = np.zeros((cnt_pred, self.y_samples.shape[0]))
        for i in range(cnt_pred):
            lane = np.array(pred_lanes[i])
            lane_min_y = np.min(lane[:, 1])
            lane_max_y = np.max(lane[:, 1])
            x_values, z_values, visibility_vec = resample_laneline_in_y(lane, self.y_samples, out_vis=True)
            pred_lanes[i] = np.vstack([x_values, z_values]).T
            within_range = np.logical_and(x_values >= self.x_min,
                                          np.logical_and(x_values <= self.x_max,
                                                         np.logical_and(self.y_samples >= lane_min_y,
                                                                        self.y_samples <= lane_max_y)))
            pred_visibility_mat[i, :] = np.logical_and(within_range, visibility_vec)

        # draw lanes in multiple color
        for i in range(cnt_pred):
            x_values = pred_lanes[i][:, 0]
            z_values = pred_lanes[i][:, 1]
            visible = pred_visibility_mat[i, :] > 0

            x_ipm_values, y_ipm_values = transform_lane_g2gflat(gt_cam_height, x_values, self.y_samples, z_values)
            # remove those points with z_values > gt_cam_height, this is only for visualization on top-view;
            # the visibility flags are filtered with the same mask so they stay aligned with the points
            below_camera = z_values < gt_cam_height
            x_ipm_values = x_ipm_values[below_camera]
            y_ipm_values = y_ipm_values[below_camera]
            ipm_visible = visible[below_camera]
            x_ipm_values, y_ipm_values = homographic_transformation(self.H_g2ipm, x_ipm_values, y_ipm_values)
            x_ipm_values = x_ipm_values.astype(int)
            y_ipm_values = y_ipm_values.astype(int)
            x_2d, y_2d = projective_transformation(P_gt, x_values, self.y_samples, z_values)
            x_2d = x_2d.astype(int)
            y_2d = y_2d.astype(int)

            color = colors[np.mod(i, len(colors))]
            # the palette is RGB; the drawing calls receive it reversed
            color_reversed = color[::-1]
            # draw on image
            for k in range(1, x_2d.shape[0]):
                # only draw the visible portion
                if visible[k - 1] and visible[k]:
                    img = cv2.line(img,
                                   (int(x_2d[k - 1]), int(y_2d[k - 1])),
                                   (int(x_2d[k]), int(y_2d[k])),
                                   color_reversed, 10)

            # draw on ipm
            for k in range(1, x_ipm_values.shape[0]):
                # only draw the visible portion
                if ipm_visible[k - 1] and ipm_visible[k]:
                    im_ipm = cv2.line(im_ipm,
                                      (int(x_ipm_values[k - 1]), int(y_ipm_values[k - 1])),
                                      (int(x_ipm_values[k]), int(y_ipm_values[k])),
                                      color_reversed, 3)

            # draw in 3d
            ax3.plot(x_values[visible], self.y_samples[visible], z_values[visible], color=color, linewidth=5)
        # swap the first and last channel for display with matplotlib
        ax1.imshow(img[:, :, [2, 1, 0]])
        ax2.imshow(im_ipm[:, :, [2, 1, 0]])

'''
if __name__ == '__main__':
    parser = define_args()
    args = parser.parse_args()

    # dataset_name: 'standard' / 'rare_subset' / 'illus_chg'
    args.dataset_name = 'illus_chg'
    args.dataset_dir = '/media/yuliangguo/DATA1/Datasets/Apollo_Sim_3D_Lane_Release/'

    # model name: 'Gen_LaneNet_ext' / '3D_LaneNet'
    model_name = 'Gen_LaneNet_ext'

    # load configuration for certain dataset
    sim3d_config(args)
    args.top_view_region = np.array([[-10, max_y], [10, max_y], [-10, 3], [10, 3]])
    vs = lane_visualizer(args)

    pred_file = '../data_splits/' + args.dataset_name + '/' + model_name + '/test_pred_file.json'
    gt_file = '../data_splits/' + args.dataset_name + '/test.json'

    save_path = pred_file[:pred_file.rfind('/')]
    save_path += '/example/test_vis_pred'
    if not os.path.exists(save_path):
        try:
            os.makedirs(save_path)
        except OSError as e:
            print(e.message)

    pred_lines = open(pred_file).readlines()
    json_pred = [json.loads(line) for line in pred_lines]
    # except BaseException as e:
    #     raise Exception('Fail to load json file of the prediction.')
    json_gt = [json.loads(line) for line in open(gt_file).readlines()]
    if len(json_gt) != len(json_pred):
        raise Exception('We do not get the predictions of all the test tasks')
    gts = {l['raw_file']: l for l in json_gt}

    for i, pred in enumerate(json_pred):
        raw_file = pred['raw_file']

        pred_lanelines = pred['laneLines']
        pred_centerlines = pred['centerLines']

        if raw_file not in gts:
            continue
        gt = gts[raw_file]
        gt_cam_height = gt['cam_height']
        gt_cam_pitch = gt['cam_pitch']

        fig = plt.figure()
        ax1 = fig.add_subplot(231)
        ax2 = fig.add_subplot(232)
        ax3 = fig.add_subplot(233, projection='3d')
        ax4 = fig.add_subplot(234)
        ax5 = fig.add_subplot(235)
        ax6 = fig.add_subplot(236, projection='3d')

        # draw lanes
        vs.visualize_lanes(pred_lanelines, raw_file, gt_cam_height, gt_cam_pitch, ax1, ax2, ax3)
        vs.visualize_lanes(pred_centerlines, raw_file, gt_cam_height, gt_cam_pitch, ax4, ax5, ax6)
        ax1.set_xticks([])
        ax1.set_yticks([])
        ax2.set_xticks([])
        ax2.set_yticks([])
        # ax2.set_xlabel('x axis')
        # ax2.set_ylabel('y axis')
        # ax2.set_zlabel('z axis')
        bottom, top = ax3.get_zlim()
        left, right = ax3.get_xlim()
        ax3.set_zlim(min(bottom, -0.1), max(top, 0.1))
        ax3.set_xlim(left, right)
        ax3.set_ylim(min_y, max_y)
        ax3.locator_params(nbins=5, axis='x')
        ax3.locator_params(nbins=5, axis='z')
        ax3.tick_params(pad=18)

        ax4.set_xticks([])
        ax4.set_yticks([])
        ax5.set_xticks([])
        ax5.set_yticks([])
        # ax4.set_xlabel('x axis')
        # ax4.set_ylabel('y axis')
        # ax4.set_zlabel('z axis')
        bottom, top = ax6.get_zlim()
        left, right = ax6.get_xlim()
        ax6.set_zlim(min(bottom, -0.1), max(top, 0.1))
        ax6.set_xlim(left, right)
        ax6.set_ylim(min_y, max_y)
        ax6.locator_params(nbins=5, axis='x')
        ax6.locator_params(nbins=5, axis='z')
        ax6.tick_params(pad=18)

        fig.subplots_adjust(wspace=0, hspace=0.01)
        fig.savefig(ops.join(save_path, raw_file.replace("/", "_")))
        plt.close(fig)
        print('processed sample: {}  {}'.format(i, raw_file))
        '''

'\nif __name__ == \'__main__\':\n    parser = define_args()\n    args = parser.parse_args()\n\n    # dataset_name: \'standard\' / \'rare_subset\' / \'illus_chg\'\n    args.dataset_name = \'illus_chg\'\n    args.dataset_dir = \'/media/yuliangguo/DATA1/Datasets/Apollo_Sim_3D_Lane_Release/\'\n\n    # model name: \'Gen_LaneNet_ext\' / \'3D_LaneNet\'\n    model_name = \'Gen_LaneNet_ext\'\n\n    # load configuration for certain dataset\n    sim3d_config(args)\n    args.top_view_region = np.array([[-10, max_y], [10, max_y], [-10, 3], [10, 3]])\n    vs = lane_visualizer(args)\n\n    pred_file = \'../data_splits/\' + args.dataset_name + \'/\' + model_name + \'/test_pred_file.json\'\n    gt_file = \'../data_splits/\' + args.dataset_name + \'/test.json\'\n\n    save_path = pred_file[:pred_file.rfind(\'/\')]\n    save_path += \'/example/test_vis_pred\'\n    if not os.path.exists(save_path):\n        try:\n            os.makedirs(save_path)\n        except OSError as e:\n            print(e.message

In [24]:

"""
A demo for Gen-LaneNet with new anchor extension. It predicts 3D lanes from a single image.
Author: Yuliang Guo (33yuliangguo@gmail.com)
Date: March, 2020
"""

import numpy as np
import torch
import torch.optim
import glob
from tqdm import tqdm
import torchvision.transforms.functional as Q
#from dataloader.Load_Data_3DLane_ext import *
#from networks import GeoNet3D_ext, erfnet
#from tools.utils import *
#from tools.visualize_pred import lane_visualizer


def unormalize_lane_anchor(anchor, num_y_steps, anchor_dim, x_off_std, z_std, num_types=3):
    """
    Scale the normalized x-offset and z entries of an anchor array back by their standard
    deviations. The array is modified in place; nothing is returned.

    The columns are treated as num_types consecutive groups of anchor_dim columns. Within each
    group the first num_y_steps columns are multiplied by x_off_std and the next num_y_steps
    columns by z_std; any remaining columns of the group are left untouched.
    """
    for type_idx in range(num_types):
        x_begin = type_idx * anchor_dim
        z_begin = x_begin + num_y_steps
        z_end = z_begin + num_y_steps

        scaled_x = np.multiply(anchor[:, x_begin:z_begin], x_off_std)
        anchor[:, x_begin:z_begin] = scaled_x

        scaled_z = np.multiply(anchor[:, z_begin:z_end], z_std)
        anchor[:, z_begin:z_end] = scaled_z


def load_my_state_dict(model, state_dict):  # custom function to load model when not all dict elements
    """Copy checkpoint tensors into ``model`` in place, tolerating extra keys.

    The first 7 characters of every checkpoint key are dropped before the
    lookup (presumably the ``module.`` prefix written by ``nn.DataParallel``;
    confirm against how the checkpoint was saved). Checkpoint entries whose
    stripped key does not exist in the model are skipped instead of raising
    ``KeyError``. All other entries, including ``output_conv`` ones, are
    copied.

    Args:
        model: module whose ``state_dict()`` tensors receive the values.
        state_dict: mapping of checkpoint parameter name to tensor.

    Returns:
        The same ``model`` object, updated in place.
    """
    own_state = model.state_dict()
    skipped = []
    cnt = 0
    for name, param in state_dict.items():
        key = name[7:]
        # Direct mapping lookup; no need to rebuild a list of keys per entry.
        if key not in own_state:
            # Nothing in the model to receive this tensor: skip it.
            skipped.append(name)
            continue
        own_state[key].copy_(param)
        cnt += 1
    print('#reused param: {}'.format(cnt))
    if skipped:
        print('#skipped param (not in model): {}'.format(len(skipped)))
    return model


if __name__ == '__main__':
    # Demo entry point: run the segmentation and geometry networks on a single
    # image, convert the output to 3D lanes and save a 2x3 figure to test.png.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    args = define_args()
    #args = parser.parse_args()

    # manual settings
    image_file = '/content/drive/Shareddrives/colab/sample_image.jpg'
    cam_file = '/content/drive/Shareddrives/colab/sample_image_cam.json'
    args.mod = 'Gen_LaneNet_ext'  # model name
    pretrained_feat_model = '/content/drive/Shareddrives/colab/erfnet_model_sim3d.tar'
    trained_geo_model = '/content/drive/Shareddrives/colab/gen_lanenet_geo_model.tar'
    anchor_std_file = '/content/drive/Shareddrives/colab/geo_anchor_std.json'

    # load configuration for the model
    sim3d_config(args)
    args.y_ref = 5
    args.batch_size = 1
    anchor_y_steps = args.anchor_y_steps
    num_y_steps = len(anchor_y_steps)
    anchor_dim = 3 * num_y_steps + 1
    x_min = args.top_view_region[0, 0]
    x_max = args.top_view_region[1, 0]
    # Builtin int() replaces np.int, which NumPy deprecated in 1.20 and
    # removed in 1.24; the truncation behaviour is the same.
    anchor_x_steps = np.linspace(x_min, x_max, int(args.ipm_w / 8), endpoint=True)

    # Check GPU availability
    if not args.no_cuda and not torch.cuda.is_available():
        raise Exception("No gpu available for usage")
    torch.backends.cudnn.benchmark = args.cudnn

    # Define network
    model_seg = ERFNet(2)  # 2-class model
    model_geo = Net(args)
    define_init_weights(model_geo, args.weight_init)

    if not args.no_cuda:
        # Move both models to the GPU before loading weights into them
        model_seg = model_seg.cuda()
        model_geo = model_geo.cuda()

    # When CUDA is disabled, remap checkpoint tensors to the CPU so that
    # checkpoints saved from a GPU can still be deserialized. None keeps the
    # default torch.load behaviour on the CUDA path.
    map_location = 'cpu' if args.no_cuda else None

    # load segmentation model
    checkpoint = torch.load(pretrained_feat_model, map_location=map_location)
    model_seg = load_my_state_dict(model_seg, checkpoint['state_dict'])
    model_seg.eval()  # inference mode for the segmentation network

    # load geometry model
    if os.path.isfile(trained_geo_model):
        print("=> loading checkpoint '{}'".format(trained_geo_model))
        checkpoint = torch.load(trained_geo_model, map_location=map_location)
        model_geo.load_state_dict(checkpoint['state_dict'])
        model_geo.eval()
    else:
        print("=> no checkpoint found at '{}'".format(trained_geo_model))

    # load anchor std saved from training
    with open(anchor_std_file) as f:
        anchor_std = json.load(f)
    x_off_std = np.array(anchor_std['x_off_std'])
    z_std = np.array(anchor_std['z_std'])

    #  load image
    with open(image_file, 'rb') as f:
        image = (Image.open(f).convert('RGB'))
    # image preprocess: crop, resize, convert to tensor, normalize
    w, h = image.size
    image = Q.crop(image, args.crop_y, 0, args.org_h - args.crop_y, w)
    image = Q.resize(image, size=(args.resize_h, args.resize_w), interpolation=Image.BILINEAR)
    image = transforms.ToTensor()(image).float()
    image = transforms.Normalize(args.vgg_mean, args.vgg_std)(image)
    image.unsqueeze_(0)
    # Repeat the single image along dim 0 so the batch has args.batch_size entries
    image = torch.cat(list(torch.split(image, 1, dim=0)) * args.batch_size)

    if not args.no_cuda:
        image = image.cuda()
    # image = image.contiguous()
    # image = torch.autograd.Variable(image)

    # update camera setting of the model
    with open(cam_file) as f:
        cam_params = json.load(f)
    gt_pitch = torch.tensor([cam_params['cameraPitch']], dtype=torch.float32)
    gt_hcam = torch.tensor([cam_params['cameraHeight']], dtype=torch.float32)
    model_geo.update_projection(args, gt_hcam, gt_pitch)

    with torch.no_grad():
        # deploy model
        try:
            output_seg = model_seg(image, no_lane_exist=True)
            # output1 = F.softmax(output1, dim=1)
            output_seg = output_seg.softmax(dim=1)
            # Divide every channel by its own maximum over dims 2 and 3
            output_seg = output_seg / torch.max(torch.max(output_seg, dim=2, keepdim=True)[0], dim=3, keepdim=True)[0]
            # Drop channel 0, keep the remaining channel(s) as input to the geometry net
            output_seg = output_seg[:, 1:, :, :]
            output_geo, pred_hcam, pred_pitch = model_geo(output_seg)
        except RuntimeError as e:
            print(e)
            # Re-raise: continuing would only fail below with a misleading
            # NameError because output_geo was never assigned.
            raise

    output_geo = output_geo[0].data.cpu().numpy()

    # unormalize lane outputs
    unormalize_lane_anchor(output_geo, num_y_steps, anchor_dim, x_off_std, z_std, num_types=3)

    # compute 3D lanes from network output, geometric transformation is involved
    lanelines_pred, centerlines_pred, lanelines_prob, centerlines_prob = \
        compute_3d_lanes_all_prob(output_geo, anchor_dim, anchor_x_steps, anchor_y_steps, cam_params['cameraHeight'])

    # visualize predicted lanes
    # args.top_view_region = np.array([[-10, 80], [10, 80], [-10, 3], [10, 3]])
    vs = lane_visualizer(args)
    vs.dataset_dir = './'

    # Top row (ax1-ax3) shows lane lines, bottom row (ax4-ax6) center lines;
    # the third column of each row is a 3D axis.
    fig = plt.figure()
    ax1 = fig.add_subplot(231)
    ax2 = fig.add_subplot(232)
    ax3 = fig.add_subplot(233, projection='3d')
    ax4 = fig.add_subplot(234)
    ax5 = fig.add_subplot(235)
    ax6 = fig.add_subplot(236, projection='3d')

    # draw lanes
    vs.visualize_lanes(lanelines_pred, image_file, cam_params['cameraHeight'], cam_params['cameraPitch'], ax1, ax2, ax3)
    vs.visualize_lanes(centerlines_pred, image_file, cam_params['cameraHeight'], cam_params['cameraPitch'], ax4, ax5, ax6)

    # The 2D panels carry no ticks
    for ax_2d in (ax1, ax2, ax4, ax5):
        ax_2d.set_xticks([])
        ax_2d.set_yticks([])

    # Both 3D panels share the same limits and tick formatting
    for ax_3d in (ax3, ax6):
        bottom, top = ax_3d.get_zlim()
        left, right = ax_3d.get_xlim()
        # Make sure the z range spans at least [-0.1, 0.1]
        ax_3d.set_zlim(min(bottom, -0.1), max(top, 0.1))
        ax_3d.set_xlim(left, right)
        ax_3d.set_ylim(0, 80)
        ax_3d.locator_params(nbins=5, axis='x')
        ax_3d.locator_params(nbins=5, axis='z')
        ax_3d.tick_params(pad=18)

    fig.subplots_adjust(wspace=0, hspace=0.01)
    fig.savefig('test.png')
    plt.close(fig)
    print("done!")

Init weights in network with [normal]
#reused param: 357
=> loading checkpoint '/content/drive/Shareddrives/colab/gen_lanenet_geo_model.tar'


  "Argument interpolation should be of type InterpolationMode instead of int. "
  "Default grid_sample and affine_grid behavior has changed "


done!
