## Lane Shape Modeling
1) load the trained Instance Segmentation model
2) predict lane masks
3) fit mathematical equations (Polynomials/Splines) to the detected lanes.

Planned improvements to lane shape modeling:
- better filtering of the detected lane pixels before fitting
- applying the same fitting procedure to the ground-truth masks
- more metrics and comparisons
- possibly adding a spline-based method for comparison

## 1. Configuration

In [None]:
import os
import json
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.linear_model import RANSACRegressor, LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline


# Configuration
# NOTE: DATA_DIR is a machine-specific absolute path; point it at the local
# TuSimple folder before running.
BASE_DIR = os.getcwd()
DATA_DIR = r'C:\Users\Alex\Documents\Clase\Italia\Segundo_ano\ADAS\Project\TU_Simple_folder\TUSimple'
TRAIN_SET_DIR = os.path.join(DATA_DIR, 'train_set')
# Sub-directories are passed as separate components so the separator is chosen
# by os.path.join rather than being hard-coded as a backslash.
PROCESSED_DATA_DIR = os.path.join(DATA_DIR, 'processed', 'instance')
CHECKPOINT_DIR = os.path.join(BASE_DIR, 'checkpoints', 'instance')
# Checkpoint evaluated by this notebook.
eval_filename = 'best_model_instance_e40_dropout_filter.pth'
EVAL_MODEL_PATH = os.path.join(CHECKPOINT_DIR, eval_filename)

# Number of segmentation classes: the evaluation treats class 0 as background
# and classes 1..NUM_CLASSES-1 as lane instances.
NUM_CLASSES = 6
# Size every image and mask is resized to before being fed to the network.
IMG_HEIGHT = 288
IMG_WIDTH = 512
EPOCHS = 40
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
# Checkpoint name derived from the epoch count above.
model_filename = f'best_model_instance_e{EPOCHS}_dropout_filter.pth'
BEST_MODEL_PATH = os.path.join(CHECKPOINT_DIR, model_filename)

# Prefer the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Model & Dataset Definitions
These are the same dataset and model definitions used for lane detection.

In [None]:
class TuSimpleDataset(Dataset):
    """Dataset of (image, instance-mask) pairs listed in TuSimple label files.

    Each label file is read as JSON Lines; every record must contain a
    ``raw_file`` entry holding the image path relative to ``root_dir``. The
    mask path is the same relative path with ``.jpg`` replaced by ``.png``,
    resolved against ``processed_dir``.

    Args:
        root_dir: Directory containing the label files and the raw images.
        processed_dir: Directory containing the pre-computed mask images.
        json_files: Label file names relative to ``root_dir``. Files that do
            not exist are skipped.
        transform: Stored on the instance; ``__getitem__`` does not apply it.
    """

    def __init__(self, root_dir, processed_dir, json_files, transform=None):
        self.root_dir = root_dir
        self.processed_dir = processed_dir
        self.transform = transform
        self.samples = []

        for json_file in json_files:
            json_path = os.path.join(root_dir, json_file)
            if not os.path.exists(json_path):
                continue

            with open(json_path, 'r') as f:
                for line in f:
                    # Tolerate blank lines (e.g. a trailing newline at EOF),
                    # which json.loads would reject.
                    if not line.strip():
                        continue
                    raw_file = json.loads(line)['raw_file']
                    mask_file = raw_file.replace('.jpg', '.png')
                    self.samples.append((raw_file, mask_file))

    def __len__(self):
        """Return the number of (image, mask) pairs collected."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, resize and convert one sample.

        Returns:
            A tuple ``(image, mask)``: ``image`` is a float tensor in CHW
            layout scaled to [0, 1]; ``mask`` is a long tensor of shape
            (IMG_HEIGHT, IMG_WIDTH) holding the grayscale mask values.

        Raises:
            FileNotFoundError: If the image or the mask cannot be read.
        """
        img_rel_path, mask_rel_path = self.samples[idx]
        img_path = os.path.join(self.root_dir, img_rel_path)
        mask_path = os.path.join(self.processed_dir, mask_rel_path)

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        image = cv2.resize(image, (IMG_WIDTH, IMG_HEIGHT))
        # Nearest-neighbour keeps mask values as valid class ids (no blending).
        mask = cv2.resize(mask, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_NEAREST)

        image = image.astype(np.float32) / 255.0
        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        image = torch.from_numpy(image).float()
        mask = torch.from_numpy(mask).long()
        return image, mask

In [None]:
# U-Net Architecture
class DoubleConv(nn.Module):
    """Two consecutive stages of 3x3 convolution -> batch norm -> ReLU.

    Both convolutions use padding 1 and no bias, so spatial size is preserved.
    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        stage_in = in_channels
        for _ in range(2):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
            stage_in = out_channels
        # Attribute name and layer order determine the checkpoint keys.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        features = self.double_conv(x)
        return features

class Down(nn.Module):
    """Encoder step: 2x2 max pooling followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Attribute name and layer order determine the checkpoint keys.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Downsample ``x`` by two and convolve it."""
        downsampled = self.maxpool_conv(x)
        return downsampled

class Up(nn.Module):
    """Decoder step: transposed-conv upsampling, skip concatenation, DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        halved = in_channels // 2
        self.up = nn.ConvTranspose2d(in_channels, halved, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s spatial size, and fuse the two.

        Args:
            x1: Feature map coming from the deeper layer.
            x2: Skip-connection feature map to concatenate with.
        """
        upsampled = self.up(x1)

        # Size difference between the skip tensor and the upsampled tensor.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        left = gap_w // 2
        top = gap_h // 2

        # F.pad takes the last dimension first: (left, right, top, bottom).
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        fused = torch.cat([x2, upsampled], dim=1)
        return self.conv(fused)

class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` to ``out_channels`` channels per pixel."""
        projected = self.conv(x)
        return projected

class UNet(nn.Module):
    """U-Net with four encoder steps, four decoder steps and skip connections.

    Args:
        n_channels: Number of channels of the input image.
        n_classes: Number of output channels (one score map per class).
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Encoder: channel width doubles at every step.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)

        # Decoder: channel width halves at every step.
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)

        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return the per-class score maps for input batch ``x``."""
        # Collect encoder outputs; the last one is the bottleneck.
        skips = [self.inc(x)]
        for encoder in (self.down1, self.down2, self.down3, self.down4):
            skips.append(encoder(skips[-1]))

        # Walk back up, pairing each decoder with the matching skip tensor.
        x = skips.pop()
        for decoder in (self.up1, self.up2, self.up3, self.up4):
            x = decoder(x, skips.pop())

        return self.outc(x)

## 3. Evaluation function

In [None]:
def _format_poly_equation(coeffs):
    """Render polynomial coefficients (highest power first) as ``x=...`` text."""
    degree = len(coeffs) - 1
    terms = []
    for offset, coeff in enumerate(coeffs):
        power = degree - offset
        if power >= 2:
            suffix = f"y^{power}"
        elif power == 1:
            suffix = "y"
        else:
            suffix = ""
        terms.append(f"{coeff:.4f}{suffix}")
    return "x=" + "+".join(terms)


def _curve_points(fit_info, y_values):
    """Evaluate a fitted lane at ``y_values`` and return cv2.polylines points."""
    if fit_info.get('type') == 'ransac':
        features = fit_info['poly_transformer'].transform(y_values.reshape(-1, 1))
        x_values = fit_info['fit'].predict(features)
    else:
        x_values = np.polyval(fit_info['fit'], y_values)
    return np.array([np.transpose(np.vstack([x_values, y_values]))]).astype(np.int32)


def _rasterize_curve(points, shape, thickness):
    """Draw an open polyline of the given thickness into a fresh uint8 mask."""
    canvas = np.zeros(shape, dtype=np.uint8)
    cv2.polylines(canvas, points, False, 1, thickness)
    return canvas


def _overlap_percent(region, thin_curve, tolerant_curve):
    """Overlap score in percent between a region and a rasterized curve.

    The intersection is taken with the thicker (tolerant) curve and the union
    with the thin curve. Returns 0 when the union is empty.
    """
    inter = np.sum(np.logical_and(region > 0, tolerant_curve > 0))
    uni = np.sum(np.logical_or(region > 0, thin_curve > 0))
    return (inter / uni * 100) if uni > 0 else 0


def _fit_predicted_lane(ys, xs, height, polydegree, method):
    """Fit x = f(y) to predicted lane pixels and return the fit description."""
    y_top = np.min(ys)
    # Curve is evaluated from the top-most lane pixel down to the image bottom.
    plot_y_range = np.linspace(y_top, height-1, num=(height-int(y_top)))

    if method == 'ransac':
        try:
            poly = PolynomialFeatures(degree=polydegree, include_bias=False)
            features = poly.fit_transform(ys.reshape(-1, 1))
            ransac = RANSACRegressor(random_state=42)
            ransac.fit(features, xs)
        except ValueError as err:
            # RANSAC raises ValueError when no valid consensus set is found;
            # fall back to ordinary least squares with the requested degree.
            print(f"RANSAC fit failed ({err}); falling back to np.polyfit")
            p_fit = np.polyfit(ys, xs, polydegree)
            return {'fit': p_fit, 'y_range': plot_y_range, 'type': 'poly'}
        return {'fit': ransac, 'poly_transformer': poly, 'y_range': plot_y_range, 'type': 'ransac'}

    # Lanes spanning fewer than 150 rows are fitted with a straight line.
    deg = 1 if (np.max(ys) - y_top) < 150 else polydegree
    p_fit = np.polyfit(ys, xs, deg)
    return {'fit': p_fit, 'y_range': plot_y_range, 'type': 'poly'}


def _mean_or_nan(values):
    """Mean of ``values``, or NaN (without a NumPy warning) when empty."""
    return float(np.mean(values)) if values else float('nan')


def _visualize_sample(sample_idx, img_np, gt_mask_val, pred_mask, gt_polys, pred_polys, comparison, colors):
    """Print per-lane comparison scores and show the five-panel figure."""
    print(f"\n--- Sample {sample_idx+1} Overlap Metrics ---")
    v_gt_m = img_np.copy()
    v_pr_m = img_np.copy()
    v_gt_f = img_np.copy()
    v_pr_f = img_np.copy()
    v_comp = img_np.copy()
    gt_eqs = []
    pr_eqs = []

    for cls_idx in range(1, NUM_CLASSES):
        color_rgb = colors[cls_idx]
        color_bgr = (color_rgb[2], color_rgb[1], color_rgb[0])
        v_gt_m[gt_mask_val == cls_idx] = color_bgr
        v_pr_m[pred_mask == cls_idx] = color_bgr

        if cls_idx in gt_polys:
            g = gt_polys[cls_idx]
            ptsg = _curve_points(g, g['y_range'])
            cv2.polylines(v_gt_f, ptsg, False, (255, 255, 255), 4)
            cv2.polylines(v_comp, ptsg, False, (255, 255, 255), 4)
            gt_eqs.append(f"GT L{cls_idx}: {_format_poly_equation(g['fit'])}")

        if cls_idx in pred_polys:
            p = pred_polys[cls_idx]
            if p['type'] == 'ransac':
                estimator = p['fit'].estimator_
                # coef_ is ordered lowest power first (no bias column), so it
                # is reversed and the intercept appended to match polyval order.
                coeffs = list(estimator.coef_[::-1]) + [estimator.intercept_]
                pr_eqs.append(f"L{cls_idx} (RSC): {_format_poly_equation(coeffs)}")
            else:
                n_coeffs = len(p['fit'])
                if n_coeffs == 3:
                    tag = "Q"
                elif n_coeffs == 2:
                    tag = "L"
                else:
                    tag = f"P{n_coeffs - 1}"
                pr_eqs.append(f"L{cls_idx} ({tag}): {_format_poly_equation(p['fit'])}")
            ptsp = _curve_points(p, p['y_range'])
            cv2.polylines(v_pr_f, ptsp, False, color_bgr, 4)
            cv2.polylines(v_comp, ptsp, False, color_bgr, 4)

    # Console Printing
    comp_parts = []
    for cid, val in comparison:
        comp_parts.append(f"L{cid} Pred v GT: {val:.1f}%")
        print(f"Lane {cid} GT Fit vs Pred Fit IoU: {val:.2f}%")

    comp_title = "Fit Comparison (White=GT, Color=Pred)\n" + " | ".join(comp_parts)

    plt.figure(figsize=(24, 18))
    plt.suptitle("Lane 1: Green | Lane 2: Red | Lane 3: Blue | Lane 4: Yellow | Lane 5: Magenta", fontsize=16)
    ax = [
        plt.subplot2grid((3, 2), (0, 0)),
        plt.subplot2grid((3, 2), (0, 1)),
        plt.subplot2grid((3, 2), (1, 0)),
        plt.subplot2grid((3, 2), (1, 1)),
        plt.subplot2grid((3, 2), (2, 0), colspan=2),
    ]
    ax[0].set_title('Ground Truth Mask')
    ax[0].imshow(cv2.cvtColor(v_gt_m, cv2.COLOR_BGR2RGB))
    ax[1].set_title('Instance Segmentation Prediction')
    ax[1].imshow(cv2.cvtColor(v_pr_m, cv2.COLOR_BGR2RGB))
    ax[2].set_title('Poly Fit (Ground Truth)\n' + "\n".join(gt_eqs), fontsize=9)
    ax[2].imshow(cv2.cvtColor(v_gt_f, cv2.COLOR_BGR2RGB))
    ax[3].set_title('Poly Fit (Prediction)\n' + "\n".join(pr_eqs), fontsize=9)
    ax[3].imshow(cv2.cvtColor(v_pr_f, cv2.COLOR_BGR2RGB))
    ax[4].set_title(comp_title, fontsize=10)
    ax[4].imshow(cv2.cvtColor(v_comp, cv2.COLOR_BGR2RGB))
    for a in ax:
        a.axis('off')
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()


def evaluate_shapes(model_path, data_dir, processed_dir, json_files, polydegree, method='poly',
                    num_visualize=20, iou_thickness=3, tolerance=2):
    """Fit curves to ground-truth and predicted lane masks and report overlap.

    For every sample the U-Net prediction is computed, then for each lane
    class (1..NUM_CLASSES-1) with more than 20 pixels a curve x = f(y) is
    fitted to the ground-truth mask (always ``np.polyfit``) and to the
    predicted mask (``np.polyfit`` or RANSAC). Three overlap scores, in
    percent, are accumulated and their means printed at the end:

    * ground-truth mask vs. ground-truth fit,
    * predicted mask vs. predicted fit,
    * ground-truth fit vs. predicted fit (0 when the GT lane has no predicted
      counterpart; lanes without a GT fit are not scored).

    Args:
        model_path: Path of the checkpoint holding the model ``state_dict``.
        data_dir: Directory with the label files and raw images.
        processed_dir: Directory with the ground-truth mask images.
        json_files: Label file names relative to ``data_dir``.
        polydegree: Degree of the fitted polynomial.
        method: ``'ransac'`` for a RANSAC polynomial fit of the prediction;
            any other value uses ``np.polyfit`` (degree 1 for lanes spanning
            fewer than 150 rows).
        num_visualize: Number of leading samples to print and plot.
        iou_thickness: Line thickness used to rasterize fitted curves.
        tolerance: Extra pixels added on each side of the curve for the
            intersection term of the score.

    Returns:
        None. Results are printed and plotted.
    """
    dataset = TuSimpleDataset(data_dir, processed_dir, json_files)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    model = UNet(n_channels=3, n_classes=NUM_CLASSES).to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    colors = [(0, 0, 0), (0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]
    tolerant_thickness = iou_thickness + 2*tolerance

    all_gt_mask_vs_fit = []
    all_pred_mask_vs_fit = []
    all_gt_fit_vs_pred_fit = []

    print(f"Starting evaluation... (Visualizing first {num_visualize} samples)")
    with torch.no_grad():
        for i, (images, masks) in enumerate(tqdm(loader, desc="Evaluation")):
            images = images.to(device)
            output = model(images)
            probs = torch.softmax(output, dim=1)
            pred_mask = torch.argmax(probs, dim=1).cpu().numpy()[0]
            gt_mask_val = masks[0].cpu().numpy()
            height, width = pred_mask.shape
            shape = (height, width)

            gt_polys = {}
            pred_polys = {}

            # Fitting Loop
            for cls_idx in range(1, NUM_CLASSES):
                # GT Fitting
                y_gt, x_gt = np.where(gt_mask_val == cls_idx)
                if len(y_gt) > 20:
                    y_top = np.min(y_gt)
                    gt_polys[cls_idx] = {
                        'fit': np.polyfit(y_gt, x_gt, polydegree),
                        'y_min': y_top,
                        'y_max': np.max(y_gt),
                        'y_range': np.linspace(y_top, height-1, num=(height-int(y_top))),
                    }

                # Pred Fitting
                y_p, x_p = np.where(pred_mask == cls_idx)
                if len(y_p) > 20:
                    pred_polys[cls_idx] = _fit_predicted_lane(y_p, x_p, height, polydegree, method)

            # Metrics Calculation
            comparison = []
            for cls_idx in range(1, NUM_CLASSES):
                # GT Mask vs GT Fit (curve limited to the rows the GT lane covers)
                gt_thin = None
                if cls_idx in gt_polys:
                    g = gt_polys[cls_idx]
                    y_g = np.linspace(g['y_min'], g['y_max'], num=int(g['y_max'] - g['y_min'] + 1))
                    pts_g = _curve_points(g, y_g)
                    gt_thin = _rasterize_curve(pts_g, shape, iou_thickness)
                    gt_tol = _rasterize_curve(pts_g, shape, tolerant_thickness)
                    all_gt_mask_vs_fit.append(_overlap_percent(gt_mask_val == cls_idx, gt_thin, gt_tol))

                # Pred Mask vs Pred Fit
                pred_thin = None
                pred_tol = None
                if cls_idx in pred_polys:
                    p = pred_polys[cls_idx]
                    pts_p = _curve_points(p, p['y_range'])
                    pred_thin = _rasterize_curve(pts_p, shape, iou_thickness)
                    pred_tol = _rasterize_curve(pts_p, shape, tolerant_thickness)
                    all_pred_mask_vs_fit.append(_overlap_percent(pred_mask == cls_idx, pred_thin, pred_tol))

                # GT Fit vs Pred Fit (a GT lane with no prediction scores 0)
                if gt_thin is not None:
                    if pred_thin is not None:
                        iou = _overlap_percent(gt_thin, pred_thin, pred_tol)
                    else:
                        iou = 0.0
                    all_gt_fit_vs_pred_fit.append(iou)
                    comparison.append((cls_idx, iou))

            # Plotting
            if i < num_visualize:
                img_np = (images[0].cpu().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
                _visualize_sample(i, img_np, gt_mask_val, pred_mask, gt_polys, pred_polys, comparison, colors)

    # Mean IoU Metrics
    print("\n" + "="*40)
    print("GLOBAL MEAN IoU METRICS")
    print(f"Mean GT Mask vs GT Fit IoU:   {_mean_or_nan(all_gt_mask_vs_fit):.2f}%")
    print(f"Mean Pred Mask vs Pred Fit IoU: {_mean_or_nan(all_pred_mask_vs_fit):.2f}%")
    print(f"Mean GT Fit vs Pred Fit IoU:    {_mean_or_nan(all_gt_fit_vs_pred_fit):.2f}%")
    print("="*40 + "\n")

In [None]:
# Run the evaluation: degree-2 curves fitted with RANSAC on the listed label files.
polynomial_degree = 2
json_files = [
    'label_data_0313.json',
    'label_data_0531.json',
    'label_data_0601.json',
]
evaluate_shapes(
    model_path=EVAL_MODEL_PATH,
    data_dir=TRAIN_SET_DIR,
    processed_dir=PROCESSED_DATA_DIR,
    json_files=json_files,
    polydegree=polynomial_degree,
    method='ransac',
)
