# Lane Detection using Vision Transformer

## Abstract

In this notebook, we perform Lane Detection by fine tuning a Vision Transformer Model on the OpenLaneV2 dataset. This task is essential for many downstream tasks like Bird's Eye View Representation, Motion Planning, 3D Occupancy Detection etc. 

## Dataset Description

The OpenLane dataset consists of multi-view images taken by a cameras mounted on an ego vehicle. The dataset provides annotations for Lane lines, Traffic elements,Area elements(Pedestrian Crossing, Sidewalk etc)

Dataset Link - [here](https://github.com/OpenDriveLab/OpenLane-V2/blob/master/data/README.md#download)

### RTDETR for Object Detection

In [2]:
import torch
import requests
from PIL import Image
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor

In [3]:
url = 'http://images.cocodataset.org/val2017/000000039769.jpg' 
image = Image.open(requests.get(url, stream=True).raw)

image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")

inputs = image_processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3)

for result in results:
    for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
        score, label = score.item(), label_id.item()
        box = [round(i, 2) for i in box.tolist()]
        print(f"{model.config.id2label[label]}: {score:.2f} {box}")


preprocessor_config.json:   0%|          | 0.00/841 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/5.11k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/172M [00:00<?, ?B/s]

sofa: 0.97 [0.14, 0.38, 640.13, 476.21]
cat: 0.96 [343.38, 24.28, 640.14, 371.5]
cat: 0.96 [13.23, 54.18, 318.98, 472.22]
remote: 0.95 [40.11, 73.44, 175.96, 118.48]
remote: 0.92 [333.73, 76.58, 369.97, 186.99]


### ResNet for Laneline Detection

In [47]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset
import os
import sys

In [61]:
import pandas as pd

data_dict_subset_A = pd.read_pickle('../../OpenLane-V2/data/OpenLane-V2/data_dict_subset_A.pkl')

In [62]:
data_dict_train = {}
data_dict_val = {}

for i in data_dict_subset_A.keys():
    if i[0] == 'train':
        data_dict_train[i] = data_dict_subset_A[i]

dict_keys([('train', '00000', '315967376899927209'), ('train', '00000', '315967377349927211'), ('train', '00000', '315967377849927217'), ('train', '00000', '315967378349927223'), ('train', '00000', '315967378849927221'), ('train', '00000', '315967379349927219'), ('train', '00000', '315967379849927218'), ('train', '00000', '315967380349927220'), ('train', '00000', '315967380849927218'), ('train', '00000', '315967381349927208'), ('train', '00000', '315967381849927221'), ('train', '00000', '315967382349927219'), ('train', '00000', '315967382849927217'), ('train', '00000', '315967383349927223'), ('train', '00000', '315967383849927213'), ('train', '00000', '315967384349927219'), ('train', '00000', '315967384849927217'), ('train', '00000', '315967385349927219'), ('train', '00000', '315967385849927217'), ('train', '00000', '315967386349927219'), ('train', '00000', '315967386849927220'), ('train', '00000', '315967387349927213'), ('train', '00000', '315967387849927213'), ('train', '00000', '315

In [42]:
# Resnet 50
from torchvision.models import resnet50, ResNet50_Weights

model = resnet50(weights=ResNet50_Weights.DEFAULT)


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /Users/akshaybharadwaj/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 97.8M/97.8M [00:29<00:00, 3.45MB/s]


### Project 3d to 2d 

In [36]:
# ==============================================================================
# Binaries and/or source for the following packages or projects 
# are presented under one or more of the following open source licenses:
# utils.py    The OpenLane-V2 Dataset Authors    Apache License, Version 2.0
#
# Contact wanghuijie@pjlab.org.cn if you have any issue.
#
# Copyright (c) 2023 The OpenLane-V2 Dataset Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np


THICKNESS = 4

COLOR_DEFAULT = (0, 0, 255)
COLOR_DICT = {
    0:  COLOR_DEFAULT,
    1:  (255, 0, 0),
    2:  (0, 255, 0),
    3:  (255, 255, 0),
    4:  (255, 0, 255),
    5:  (0, 128, 128),
    6:  (0, 128, 0),
    7:  (128, 0, 0),
    8:  (128, 0, 128),
    9:  (128, 128, 0),
    10: (0, 0, 128),
    11: (64, 64, 64),
    12: (192, 192, 192),
}


def interp_arc(points, t=1000):
    r'''
    Linearly interpolate equally-spaced points along a polyline, either in 2d or 3d.

    Parameters
    ----------
    points : List
        List of shape (N,2) or (N,3), representing 2d or 3d-coordinates.
    t : array_like
        Number of points that will be uniformly interpolated and returned.

    Returns
    -------
    array_like  
        Numpy array of shape (N,2) or (N,3)

    Notes
    -----
    Adapted from https://github.com/johnwlambert/argoverse2-api/blob/main/src/av2/geometry/interpolate.py#L120

    '''
    
    # filter consecutive points with same coordinate
    temp = []
    for point in points:
        point = point.tolist()
        if temp == [] or point != temp[-1]:
            temp.append(point)
    if len(temp) <= 1:
        return None
    points = np.array(temp, dtype=points.dtype)

    assert points.ndim == 2

    # the number of points on the curve itself
    n, _ = points.shape

    # equally spaced in arclength -- the number of points that will be uniformly interpolated
    eq_spaced_points = np.linspace(0, 1, t)

    # Compute the chordal arclength of each segment.
    # Compute differences between each x coord, to get the dx's
    # Do the same to get dy's. Then the hypotenuse length is computed as a norm.
    chordlen = np.linalg.norm(np.diff(points, axis=0), axis=1)  # type: ignore
    # Normalize the arclengths to a unit total
    chordlen = chordlen / np.sum(chordlen)
    # cumulative arclength

    cumarc = np.zeros(len(chordlen) + 1)
    cumarc[1:] = np.cumsum(chordlen)

    # which interval did each point fall in, in terms of eq_spaced_points? (bin index)
    tbins = np.digitize(eq_spaced_points, bins=cumarc).astype(int)  # type: ignore

    # #catch any problems at the ends
    tbins[np.where((tbins <= 0) | (eq_spaced_points <= 0))] = 1  # type: ignore
    tbins[np.where((tbins >= n) | (eq_spaced_points >= 1))] = n - 1

    s = np.divide((eq_spaced_points - cumarc[tbins - 1]), chordlen[tbins - 1])
    anchors = points[tbins - 1, :]
    # broadcast to scale each row of `points` by a different row of s
    offsets = (points[tbins, :] - points[tbins - 1, :]) * s.reshape(-1, 1)
    points_interp = anchors + offsets
#     print(f"points interp : {points_interp}")
    return points_interp



In [77]:
def _project(points, intrinsic, extrinsic):
    if points is None:
        return points
    
    points_in_cam_cor = np.linalg.pinv(np.array(extrinsic['rotation'])) \
        @ (points.T - np.array(extrinsic['translation']).reshape(3, -1))
    
#     print("points cam 1 : ", points_in_cam_cor)
    
    points_in_cam_cor = points_in_cam_cor[:, points_in_cam_cor[2, :] > 0]
#     print("points cam : ", points_in_cam_cor)
    
    if points_in_cam_cor.shape[1] > 1:
        points_on_image_cor = np.array(intrinsic['K']) @ points_in_cam_cor
        points_on_image_cor = points_on_image_cor / (points_on_image_cor[-1, :].reshape(1, -1))
        points_on_image_cor = points_on_image_cor[:2, :].T
    else:
        points_on_image_cor = None
#     print(" points_on_image_cor : ", points_on_image_cor)
    return points_on_image_cor

### Custom Dataset

In [82]:
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import os

class Lane3DDataset(Dataset):
    def __init__(self, image_dir, data_dict, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.annotations = self.load_annotations(data_dict)

    def load_annotations(self, data_dict):
        annotations = []
        try:
            for ann in data_dict:
    #             print(data_dict[ann].keys())
                frame = data_dict[ann]['sensor']['ring_front_center']
                image_name = frame['image_path']
                intrinsics = frame['intrinsic']
                extrinsics = frame['extrinsic']
                lane_points_3d = data_dict[ann]['annotation']['lane_segment'][0]['centerline']
                lane_points_2d = _project(interp_arc(np.array(lane_points_3d)), frame['intrinsic'],frame['extrinsic'])
                if type(center_points) == np.ndarray:
                    annotations.append((image_name, intrinsics, extrinsics, lane_points_2d))
            print(f"len of annotations : {len(annotations)}")
            return annotations
        except:
            print(data_dict[ann].keys())

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        image_name, intrinsics, extrinsics, lane_points_2d = self.annotations[idx]
        img_path = os.path.join(self.image_dir, image_name)
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        lane_points_2d = torch.tensor(lane_points_2d, dtype=torch.float32)

        return image, intrinsics, extrinsics, lane_points_2d


In [83]:
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

dataset = Lane3DDataset(image_dir='../../OpenLane-V2/data/OpenLane-V2/mapless_driving/', data_dict=data_dict_subset_A, transform=transform)
# dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)


dict_keys(['version', 'segment_id', 'meta_data', 'timestamp', 'sensor', 'pose'])


In [70]:
dataset.annotations

[]

### 3D Lane Detection Model

In [46]:
class ResNet3DLaneDetection(nn.Module):
    def __init__(self, num_points):
        super(ResNet3DLaneDetection, self).__init__()
        self.resnet = models.resnet50(pretrained=True)
        
        # Remove the fully connected layer
        self.resnet = nn.Sequential(*list(self.resnet.children())[:-2])
        
        # Add a global average pooling layer
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        
        # Add a fully connected layer to predict 3D lane points
        self.fc = nn.Linear(2048, num_points * 3)  # num_points * 3 for (x, y, z) coordinates

    def forward(self, x):
        x = self.resnet(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x.view(x.size(0), -1, 3)  # Reshape to (batch_size, num_points, 3)

# Number of points to detect per lane
num_points = 20
model = ResNet3DLaneDetection(num_points=num_points)


NameError: name 'nn' is not defined

In [None]:
criterion = nn.MSELoss()  # Suitable for regression
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


### Sample Analysis

In [41]:
# for a in annotations:
# #     print(a)
#     frame = annotations[a]['sensor']['ring_front_center']
#     lane_ann = annotations[a]['annotation']['lane_segment'][0]['centerline']
#     center_points = _project(interp_arc(np.array(lane_ann)), frame['intrinsic'],frame['extrinsic'])
#     print(center_points)
#     print(type(center_points))
#     try:
#         if type(center_points) == np.ndarray:
#             print("Break from loop")
#             break
#     except:
#         print("No center line ")
#     print("-------------------------------------------------------------")

 points_on_image_cor :  None
None
<class 'NoneType'>
-------------------------------------------------------------
 points_on_image_cor :  None
None
<class 'NoneType'>
-------------------------------------------------------------
 points_on_image_cor :  None
None
<class 'NoneType'>
-------------------------------------------------------------
 points_on_image_cor :  None
None
<class 'NoneType'>
-------------------------------------------------------------
 points_on_image_cor :  None
None
<class 'NoneType'>
-------------------------------------------------------------
 points_on_image_cor :  None
None
<class 'NoneType'>
-------------------------------------------------------------
 points_on_image_cor :  None
None
<class 'NoneType'>
-------------------------------------------------------------
 points_on_image_cor :  None
None
<class 'NoneType'>
-------------------------------------------------------------
 points_on_image_cor :  None
None
<class 'NoneType'>
---------------------------

In [21]:
# annotations[('train', '00000', '315967376899927209')]

{'version': 'OpenLaneV2_V2.0',
 'segment_id': '00000',
 'meta_data': {'source': 'ArgoverseV2',
  'source_id': '00a6ffc1-6ce9-3bc3-a060-6006e9893a1a'},
 'timestamp': 315967376899927209,
 'sensor': {'ring_front_center': {'image_path': 'train/00000/image/ring_front_center/315967376899927209.jpg',
   'extrinsic': {'rotation': array([[-9.57540534e-04,  5.28159325e-03,  9.99985594e-01],
           [-9.99999076e-01, -9.69974438e-04, -9.52430359e-04],
           [ 9.64930115e-04, -9.99985582e-01,  5.28251716e-03]]),
    'translation': array([1.63685016, 0.00213953, 1.40590799])},
   'intrinsic': {'K': array([[1.77733411e+03, 0.00000000e+00, 7.78222656e+02],
           [0.00000000e+00, 1.77733411e+03, 1.01621350e+03],
           [0.00000000e+00, 0.00000000e+00, 1.00000000e+00]]),
    'distortion': array([-0.24358925, -0.20579766,  0.32205908])}},
  'ring_front_left': {'image_path': 'train/00000/image/ring_front_left/315967376887425437.jpg',
   'extrinsic': {'rotation': array([[ 0.70601219, -0.0

In [None]:
# center_points = _project(interp_arc(lane_segment['centerline']), frame.get_intrinsic(cam),frame.get_extrinsic(cam))

In [71]:
annotations = []
for ann in data_dict_subset_A:
#     print(ann)
    frame = data_dict_subset_A[ann]['sensor']['ring_front_center']
    image_name = frame['image_path']
    intrinsics = frame['intrinsic']
    extrinsics = frame['extrinsic']
    lane_points_3d = data_dict_subset_A[ann]['annotation']['lane_segment'][0]['centerline']
    lane_points_2d = _project(interp_arc(np.array(lane_points_3d)), frame['intrinsic'],frame['extrinsic'])
    if type(center_points) == np.ndarray:
        annotations.append((image_name, intrinsics, extrinsics, lane_points_2d))

('train', '00000', '315967376899927209')
 points_on_image_cor :  None
('train', '00000', '315967377349927211')
 points_on_image_cor :  None
('train', '00000', '315967377849927217')
 points_on_image_cor :  None
('train', '00000', '315967378349927223')
 points_on_image_cor :  None
('train', '00000', '315967378849927221')
 points_on_image_cor :  None
('train', '00000', '315967379349927219')
 points_on_image_cor :  None
('train', '00000', '315967379849927218')
 points_on_image_cor :  None
('train', '00000', '315967380349927220')
 points_on_image_cor :  None
('train', '00000', '315967380849927218')
 points_on_image_cor :  None
('train', '00000', '315967381349927208')
 points_on_image_cor :  None
('train', '00000', '315967381849927221')
 points_on_image_cor :  None
('train', '00000', '315967382349927219')
 points_on_image_cor :  None
('train', '00000', '315967382849927217')
 points_on_image_cor :  None
('train', '00000', '315967383349927223')
 points_on_image_cor :  None
('train', '00000', '

 points_on_image_cor :  [[1982.65517147 1100.22307692]
 [1981.870632   1100.24113435]
 [1981.08610863 1100.25919142]
 ...
 [1208.54939046 1124.98504496]
 [1207.78096616 1125.01127509]
 [1207.01255761 1125.03750468]]
('train', '00019', '315971467449927222')
 points_on_image_cor :  [[2163.43223624 1111.4730386 ]
 [2162.53933712 1111.49705183]
 [2161.64647587 1111.52106403]
 ...
 [1291.91983773 1142.38882262]
 [1291.06382908 1142.42098678]
 [1290.20785586 1142.45314961]]
('train', '00019', '315971467949927217')
 points_on_image_cor :  [[2349.3910888  1094.92796346]
 [2348.3757829  1094.9661651 ]
 [2347.3605319  1095.00436467]
 ...
 [1364.10234461 1140.65121044]
 [1363.1401176  1140.69789351]
 [1362.17794113 1140.74457412]]
('train', '00019', '315971468449927219')
 points_on_image_cor :  [[2581.05154102 1069.62690986]
 [2579.87836272 1069.68370414]
 [2578.70525937 1069.74049479]
 ...
 [1447.97475047 1134.35962666]
 [1446.87362119 1134.42486415]
 [1445.77256028 1134.49009758]]
('train', '00

 points_on_image_cor :  [[   -997.48617101    1417.90528188]
 [  -1003.59767155    1419.16131763]
 [  -1009.75161564    1420.42607637]
 [  -1015.94844697    1421.69964931]
 [  -1022.18861543    1422.9821289 ]
 [  -1028.47257724    1424.27360889]
 [  -1034.80079502    1425.57418437]
 [  -1041.17373796    1426.88395177]
 [  -1047.59188187    1428.20300888]
 [  -1054.05570937    1429.53145487]
 [  -1060.56570994    1430.86939036]
 [  -1067.12238012    1432.21691739]
 [  -1073.72622355    1433.57413946]
 [  -1080.37775117    1434.94116159]
 [  -1087.07748132    1436.3180903 ]
 [  -1093.82593987    1437.70503368]
 [  -1100.6236604     1439.10210138]
 [  -1107.47118426    1440.50940466]
 [  -1114.3690608     1441.92705643]
 [  -1121.31784747    1443.35517125]
 [  -1128.31810996    1444.79386538]
 [  -1135.3704224     1446.24325681]
 [  -1142.47536747    1447.70346531]
 [  -1149.63353658    1449.17461241]
 [  -1156.84553006    1450.6568215 ]
 [  -1164.11195727    1452.15021781]
 [  -1171.4334

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



 points_on_image_cor :  [[ 10029.86069499 501707.02430207]
 [  1246.14668434  23585.97034186]
 [  1043.7013718   12566.33034858]
 [   974.12980377   8779.36376564]
 [   938.93847838   6863.80577227]
 [   917.69209015   5707.30801737]
 [   903.47257311   4933.30166085]
 [   893.28857405   4378.95934556]
 [   885.63564679   3962.39003985]
 [   879.63591738   3637.85500054]
 [   874.46386326   3377.41215921]
 [   870.22823003   3164.1235264 ]
 [   866.69581478   2986.24598644]
 [   863.70491397   2835.63683094]
 [   861.13987237   2706.47214939]
 [   858.91578094   2594.47628171]
 [   856.96889652   2496.43938995]
 [   855.25043346   2409.90483496]
 [   853.72241948   2332.9604929 ]
 [   852.35485699   2264.09581095]
 [   851.12373399   2202.10164715]
 [   850.00960214   2145.9986642 ]
 [   848.9965417    2094.9852111 ]
 [   848.07139599   2048.39877358]
 [   847.2231969    2005.68704317]
 [   846.44272795   1966.38591697]
 [   845.72218805   1930.10256522]
 [   845.05492985   1896.502255

 points_on_image_cor :  None
('val', '10138', '315970341849927217')
 points_on_image_cor :  None
('val', '10138', '315970342349927217')
 points_on_image_cor :  None
('val', '10138', '315970342849927215')
 points_on_image_cor :  None
('val', '10139', '315972300149927213')
 points_on_image_cor :  [[81556.29056564 55884.8368884 ]
 [39627.24804744 27407.16744406]
 [26352.88902486 18391.3925628 ]
 [19839.94137562 13967.88181407]
 [15971.03738988 11340.17170669]
 [13407.76083421  9599.22705203]
 [11584.61986492  8360.973022  ]
 [10221.49018306  7435.15277526]
 [ 9163.77222187  6716.76426326]
 [ 8319.16409584  6143.11722582]
 [ 7629.15400416  5674.47119615]
 [ 7054.8542686   5284.41416053]
 [ 6569.41068083  4954.70708409]
 [ 6153.68215436  4672.3495857 ]
 [ 5793.65702481  4427.8251221 ]
 [ 5478.84270846  4214.00726462]
 [ 5201.22615771  4025.45365249]
 [ 4954.58235107  3857.93633208]
 [ 4734.00186696  3708.12088715]
 [ 4535.56078765  3573.34224293]
 [ 4356.08523325  3451.44474077]
 [ 4192.980

KeyError: 'annotation'

In [72]:
len(annotations)

8005

### Model Training

In [None]:
import torch.optim as optim

# Model
num_points = 20
model = ResNet3DLaneDetection(num_points=num_points)
model = model.cuda()

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, intrinsics, extrinsics, lane_points_2d in dataloader:
        images = images.cuda()
        lane_points_2d = lane_points_2d.cuda()

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, lane_points_2d)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(dataloader)}")

# Save the trained model
torch.save(model.state_dict(), 'lane_detection_model.pth')
