## Lane Shape Modeling
1) Load the trained instance segmentation model.
2) Predict lane masks.
3) Fit mathematical equations (polynomials; splines are a possible future addition) to the detected lanes.

Possible improvements to lane shape modeling:
- Better filtering
- Apply the fitting to the ground truth
- More metrics and comparisons
- Possibly add a spline-based fit for comparison

In [None]:
import os
import json
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt


# Configuration
BASE_DIR = os.getcwd()
# Dataset root; can be overridden with the TUSIMPLE_DIR environment variable.
DATA_DIR = os.environ.get('TUSIMPLE_DIR', r'C:\ADAS_Project\TUSimple_Small')
TRAIN_SET_DIR = os.path.join(DATA_DIR, 'train_set')
# Pass sub-directories as separate components: a literal backslash inside a
# single component is not a path separator on Linux/macOS.
PROCESSED_DATA_DIR = os.path.join(DATA_DIR, 'processed', 'instance')
CHECKPOINT_DIR = os.path.join(BASE_DIR, 'checkpoints', 'instance')
# Checkpoint evaluated by this notebook (chosen independently of EPOCHS below).
eval_filename = 'best_model_instance_e20.pth'
EVAL_MODEL_PATH = os.path.join(CHECKPOINT_DIR, eval_filename)
# Polynomial degree used when fitting x = f(y) to each lane
polynomial_degree = 2

NUM_CLASSES = 6  # class 0 is background; classes 1..5 are lane instances
IMG_HEIGHT = 288
IMG_WIDTH = 512
# BATCH_SIZE and LEARNING_RATE are not referenced elsewhere in this notebook;
# EPOCHS only determines BEST_MODEL_PATH.
EPOCHS = 40
BATCH_SIZE = 8
LEARNING_RATE = 5e-4
model_filename = f'best_model_instance_e{EPOCHS}.pth'
BEST_MODEL_PATH = os.path.join(CHECKPOINT_DIR, model_filename)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Model & Dataset Definitions
Identical to the definitions used for lane detection.

In [None]:
class TuSimpleDataset(Dataset):
    """Image / mask pairs listed in TuSimple label files.

    Each label file is read line by line. Every non-blank line is a JSON object
    whose ``raw_file`` key gives the image path relative to ``root_dir``. The
    matching mask is looked up under ``processed_dir`` at the same relative
    path, with ``.jpg`` replaced by ``.png``.

    Args:
        root_dir: Directory containing the label files and the images.
        processed_dir: Directory containing the mask images.
        json_files: Label file names relative to ``root_dir``. Files that do
            not exist are skipped silently.
        transform: Stored on the instance but not applied by ``__getitem__``.
    """

    def __init__(self, root_dir, processed_dir, json_files, transform=None):
        self.root_dir = root_dir
        self.processed_dir = processed_dir
        self.transform = transform
        # (image path, mask path) pairs, relative to root_dir / processed_dir.
        self.samples = []

        for json_file in json_files:
            json_path = os.path.join(root_dir, json_file)
            if not os.path.exists(json_path):
                continue

            with open(json_path, 'r') as f:
                for line in f:
                    # Skip empty lines (e.g. extra newlines at end of file);
                    # json.loads would raise on them.
                    if not line.strip():
                        continue
                    raw_file = json.loads(line)['raw_file']
                    mask_file = raw_file.replace('.jpg', '.png')
                    self.samples.append((raw_file, mask_file))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` resized to (IMG_HEIGHT, IMG_WIDTH).

        Returns:
            image: float32 tensor (3, IMG_HEIGHT, IMG_WIDTH), RGB scaled to [0, 1].
            mask: int64 tensor (IMG_HEIGHT, IMG_WIDTH) of grayscale mask values.

        Raises:
            FileNotFoundError: if the image or mask file cannot be read.
        """
        img_rel_path, mask_rel_path = self.samples[idx]
        img_path = os.path.join(self.root_dir, img_rel_path)
        mask_path = os.path.join(self.processed_dir, mask_rel_path)

        # cv2.imread returns None (instead of raising) for missing/unreadable
        # files, which would otherwise surface as an opaque cvtColor error.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (IMG_WIDTH, IMG_HEIGHT))
        # Nearest-neighbour so mask values stay discrete labels.
        mask = cv2.resize(mask, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_NEAREST)

        image = image.astype(np.float32) / 255.0
        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        image = torch.from_numpy(image).float()
        mask = torch.from_numpy(mask).long()
        return image, mask

In [None]:
# U-Net Architecture
class DoubleConv(nn.Module):
    """Two stages of (3x3 conv -> BatchNorm -> ReLU); spatial size is preserved.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels produced by the first convolution. Defaults to
            ``out_channels``, which is how every call in this file uses it.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if mid_channels is None:
            mid_channels = out_channels
        # bias=False: each conv feeds straight into BatchNorm, whose learned
        # shift makes a conv bias redundant. Layer order (indices 0..5) is kept
        # stable so state_dict keys stay the same.
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Encoder step: 2x2 max-pool (floor for odd sizes) followed by DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        downsampled = self.maxpool_conv(x)
        return downsampled

class Up(nn.Module):
    """Decoder step: upsample, pad to the skip tensor's size, concatenate, DoubleConv.

    The 2x2 / stride-2 transposed convolution doubles H and W and halves the
    channel count. The DoubleConv expects ``in_channels`` channels after
    concatenation, so the skip tensor is assumed to carry ``in_channels // 2``
    channels (as wired in UNet).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        half_channels = in_channels // 2
        self.up = nn.ConvTranspose2d(in_channels, half_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        upsampled = self.up(x1)

        # MaxPool floors odd sizes, so the upsampled map can be a row/column
        # smaller than the skip; pad it (extra pixel goes bottom/right).
        pad_h = x2.shape[2] - upsampled.shape[2]
        pad_w = x2.shape[3] - upsampled.shape[3]
        top = pad_h // 2
        left = pad_w // 2
        upsampled = F.pad(upsampled, [left, pad_w - left, top, pad_h - top])

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)

class OutConv(nn.Module):
    """1x1 convolution mapping feature channels to output channels (class scores in UNet)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        scores = self.conv(x)
        return scores

class UNet(nn.Module):
    """4-level U-Net producing ``n_classes`` score maps.

    Input ``(N, n_channels, H, W)`` -> output ``(N, n_classes, H, W)``: each
    decoder step pads up to its skip connection, so the output matches the
    input's spatial size. Submodule names are what the checkpoint's
    state_dict keys refer to.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        # Encoder: 64 -> 128 -> 256 -> 512 -> 1024 channels
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        # Decoder: back down to 64 channels, one skip connection per step
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep every level's output for the skip connections.
        skips = [self.inc(x)]
        for encoder_stage in (self.down1, self.down2, self.down3, self.down4):
            skips.append(encoder_stage(skips[-1]))

        # Decoder: start at the bottleneck, consume skips deepest-first.
        y = skips.pop()
        for decoder_stage in (self.up1, self.up2, self.up3, self.up4):
            y = decoder_stage(y, skips.pop())
        return self.outc(y)

## 2. Execution & Visualization

In [None]:
def _format_poly(coeffs):
    """Render np.polyfit coefficients (highest power first) as 'x = a*y^n + ... + c'."""
    degree = len(coeffs) - 1
    terms = []
    for power, coef in zip(range(degree, -1, -1), coeffs):
        if power > 1:
            terms.append(f'{coef:.6f}*y^{power}')
        elif power == 1:
            terms.append(f'{coef:.6f}*y')
        else:
            terms.append(f'{coef:.6f}')
    return 'x = ' + ' + '.join(terms)


def _largest_lane_blob(lane_mask, min_area=1000, min_height=50):
    """Keep only the largest connected component of a 0/1 uint8 mask if it is lane-like.

    Returns a 0/1 uint8 mask of that component, or None when there is no
    foreground or the largest blob is smaller than ``min_area`` pixels or its
    bounding box is shorter than ``min_height`` rows.
    """
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(lane_mask, connectivity=8)
    if num_labels <= 1:  # label 0 is the background
        return None
    largest = np.argmax(stats[1:, cv2.CC_STAT_AREA]) + 1
    if stats[largest, cv2.CC_STAT_AREA] < min_area or stats[largest, cv2.CC_STAT_HEIGHT] < min_height:
        return None
    return (labels == largest).astype(np.uint8)


def _fit_lane(y_coords, x_coords, polydegree, height):
    """Fit x = f(y) to lane pixels; return {'fit': coeffs, 'y_range': rows to draw}.

    y_range has one sample per row from the topmost lane pixel down to the
    last image row, i.e. the curve is extended to the bottom of the image.
    """
    fit = np.polyfit(y_coords, x_coords, polydegree)
    min_y = int(np.min(y_coords))
    y_range = np.linspace(min_y, height - 1, num=(height - min_y))
    return {'fit': fit, 'y_range': y_range}


def _poly_points(poly):
    """Sample a fitted lane into an int32 point array shaped for cv2.polylines."""
    plot_x = np.polyval(poly['fit'], poly['y_range'])
    pts = np.array([np.transpose(np.vstack([plot_x, poly['y_range']]))])
    return pts.astype(np.int32)


def _curve_iou(poly_a, poly_b, height, width, thickness=10):
    """IoU of two fitted curves, each rasterised as a polyline of ``thickness`` px."""
    mask_a = np.zeros((height, width), dtype=np.uint8)
    mask_b = np.zeros((height, width), dtype=np.uint8)
    cv2.polylines(mask_a, _poly_points(poly_a), isClosed=False, color=1, thickness=thickness)
    cv2.polylines(mask_b, _poly_points(poly_b), isClosed=False, color=1, thickness=thickness)
    intersection = np.sum(mask_a & mask_b)
    union = np.sum(mask_a | mask_b)
    return intersection / union if union > 0 else 0.0


def evaluate_shapes(model_path, data_dir, processed_dir, json_files, polydegree, num_visualize=20):
    """Fit per-lane polynomials to predicted and ground-truth masks and visualise them.

    For up to ``num_visualize`` randomly shuffled samples (qualitative inspection):
      1. run the UNet checkpoint at ``model_path`` to get a per-pixel class map;
      2. per lane class, keep the largest lane-like blob and fit x = f(y) of
         degree ``polydegree``;
      3. fit the same polynomial to each ground-truth lane class (> 20 pixels);
      4. print the equations and, per lane, the IoU between the predicted and
         GT curves (both drawn 10 px thick, i.e. ~5 px tolerance), flagging
         missed detections and predictions without a GT lane;
      5. show a 4-panel figure: GT mask + fits, raw prediction, prediction
         fits, and an overlay of both fits.

    Args:
        model_path: Path to a UNet state_dict checkpoint.
        data_dir: Directory with the label JSON files and images.
        processed_dir: Directory with the ground-truth masks.
        json_files: Label file names relative to ``data_dir``.
        polydegree: Degree of the fitted polynomial.
        num_visualize: Maximum number of samples to process.

    Returns:
        None; results are printed and plotted.
    """
    dataset = TuSimpleDataset(data_dir, processed_dir, json_files)
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

    model = UNet(n_channels=3, n_classes=NUM_CLASSES).to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # RGB colors for visualization (up to 5 lanes); converted to BGR for OpenCV.
    colors = [
        (0, 0, 0),       # Background
        (0, 255, 0),     # Lane 1
        (255, 0, 0),     # Lane 2
        (0, 0, 255),     # Lane 3
        (255, 255, 0),   # Lane 4
        (255, 0, 255)    # Lane 5
    ]

    def lane_color_bgr(cls_idx):
        r, g, b = colors[cls_idx if cls_idx < len(colors) else 1]
        return (b, g, r)

    with torch.no_grad():
        for count, (images, masks) in enumerate(tqdm(loader, desc="Fitting Shapes")):
            if count >= num_visualize:
                break

            images = images.to(device)
            # argmax over logits equals argmax over softmax (softmax is monotonic).
            pred_mask = torch.argmax(model(images), dim=1).cpu().numpy()[0]

            # Image tensor (RGB in [0, 1], CHW) -> uint8 BGR for OpenCV drawing
            img_np = (images[0].cpu().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
            img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
            height, width = pred_mask.shape

            # Predicted lanes: blob filtering, fitting, drawing
            viz_img = img_np.copy()
            pred_polys = {}
            for cls_idx in range(1, NUM_CLASSES):
                lane_mask = _largest_lane_blob((pred_mask == cls_idx).astype(np.uint8))
                if lane_mask is None:
                    continue
                y_coords, x_coords = np.where(lane_mask == 1)
                poly = _fit_lane(y_coords, x_coords, polydegree, height)
                pred_polys[cls_idx] = poly
                print(f'Lane {cls_idx} Equation: {_format_poly(poly["fit"])}')
                cv2.polylines(viz_img, _poly_points(poly), isClosed=False,
                              color=lane_color_bgr(cls_idx), thickness=5)

            # Ground truth: colorize mask pixels and fit each lane
            gt_viz = img_np.copy()
            gt_mask_val = masks[0].cpu().numpy()
            gt_polys = {}
            print(f"\n--- Ground Truth Fits for Sample {count+1} ---")
            for cls_idx in range(1, NUM_CLASSES):
                lane_bool = (gt_mask_val == cls_idx)
                if not np.any(lane_bool):
                    continue
                gt_viz[lane_bool] = lane_color_bgr(cls_idx)
                y_coords_gt, x_coords_gt = np.where(lane_bool)
                if len(y_coords_gt) <= 20:  # too few pixels for a meaningful fit
                    continue
                poly = _fit_lane(y_coords_gt, x_coords_gt, polydegree, height)
                gt_polys[cls_idx] = poly
                print(f'GT Lane {cls_idx}: {_format_poly(poly["fit"])}')
                cv2.polylines(gt_viz, _poly_points(poly), isClosed=False,
                              color=(255, 255, 255), thickness=2)

            # Compare fits (white = GT, lane color = prediction) and report IoU
            comp_viz = img_np.copy()
            print("--- Overlap Metrics (IoU with 5px Tolerance) ---")
            for cls_idx in range(1, NUM_CLASSES):
                gt_poly = gt_polys.get(cls_idx)
                pred_poly = pred_polys.get(cls_idx)

                if gt_poly is not None:
                    cv2.polylines(comp_viz, _poly_points(gt_poly), isClosed=False,
                                  color=(255, 255, 255), thickness=2)
                if pred_poly is not None:
                    cv2.polylines(comp_viz, _poly_points(pred_poly), isClosed=False,
                                  color=lane_color_bgr(cls_idx), thickness=2)

                if gt_poly is None:
                    if pred_poly is not None:
                        print(f"Lane {cls_idx}: predicted but not in ground truth (False Positive)")
                elif pred_poly is None:
                    print(f"Lane {cls_idx} Poly Fit Overlap (IoU): 0.00% (Missed Detection)")
                else:
                    iou = _curve_iou(gt_poly, pred_poly, height, width)
                    print(f"Lane {cls_idx} Poly Fit Overlap (IoU): {iou*100:.2f}%")

            # Raw (unfiltered) predicted class map, colorized
            pred_viz = img_np.copy()
            for cls_idx in range(1, NUM_CLASSES):
                lane_bool = (pred_mask == cls_idx)
                if np.any(lane_bool):
                    pred_viz[lane_bool] = lane_color_bgr(cls_idx)

            fig, ax = plt.subplots(1, 4, figsize=(40, 8))
            panels = [
                ("Ground Truth (Mask + Poly Fits)", gt_viz),
                ("Instance Segmentation Prediction", pred_viz),
                ("Lane Shape Modeling (Polynomial Fit)", viz_img),
                ("Fit Comparison (White=GT, Color=Pred)", comp_viz),
            ]
            for axis, (title, bgr_img) in zip(ax, panels):
                axis.set_title(title)
                axis.imshow(cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB))
                axis.axis('off')
            plt.show()
            # Free the figure so figures do not accumulate across iterations.
            plt.close(fig)


In [None]:

# Run shape fitting on the three TuSimple training label files
json_files = [f'label_data_{suffix}.json' for suffix in ('0313', '0531', '0601')]
evaluate_shapes(
    EVAL_MODEL_PATH,
    TRAIN_SET_DIR,
    PROCESSED_DATA_DIR,
    json_files,
    polynomial_degree,
)