In [1]:
import json
import os
from glob import glob

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class LaneDataset(Dataset):
    """Lane-marking segmentation dataset.

    Images are read from ``<root>/<split>/IMAGE/*.jpg``. Each image's
    annotation is the JSON file at the same path with ``IMAGE`` replaced by
    ``ANNOTATION`` and the ``.jpg`` extension replaced by ``.json``.

    Each sample is ``(img_tensor, target_map, img_path)``:

    - ``img_tensor`` is the bottom ``w // 2`` rows of the full-width image,
      resized to ``(I_H, I_W)``.
    - ``target_map`` is an ``(I_H, I_W)`` int64 tensor of per-pixel class ids
      for the same crop. 0 is background; see ``CLASS_IDS``.

    Args:
        root: dataset root directory.
        split: sub-directory of ``root`` to read (e.g. ``"train"``).
        img_h: output height of images and target maps.
        img_w: output width of images and target maps.
    """

    # Target-map value drawn for each lane colour. The colour is taken from
    # ``lane_type`` with its first five characters removed (presumably a
    # ``"lane_"`` prefix -- TODO confirm against the annotation spec).
    CLASS_IDS = {"white": 1, "blue": 2, "yellow": 3, "shoulder": 4}
    # Only lanes whose ``lane_attribute`` is one of these are rasterised.
    DRAWN_ATTRIBUTES = frozenset(
        {
            "single_solid",
            "single_dashed",
            "double_solid",
            "left_dashed_double",
            "right_dashed_double",
        }
    )

    def __init__(self, root, split="train", img_h=800, img_w=1333):
        self.I_H = img_h
        self.I_W = img_w

        # glob() order is filesystem-dependent; sort so index -> file is stable.
        self.img_list = sorted(glob(os.path.join(root, split, "IMAGE", "*.jpg")))
        self.label_path = [self._label_path_for(p) for p in self.img_list]

        self.len = len(self.img_list)

    @staticmethod
    def _label_path_for(img_path):
        """Map an image path to its annotation path (IMAGE -> ANNOTATION, .jpg -> .json)."""
        # Only the extension is swapped; a blanket str.replace("jpg", ...)
        # would also rewrite directory names that happen to contain "jpg".
        stem, _ = os.path.splitext(img_path.replace("IMAGE", "ANNOTATION"))
        return stem + ".json"

    def _load_image(self, index):
        """Open image ``index`` as RGB.

        Unreadable files are skipped forward, wrapping around to the start.

        Returns:
            ``(actual_index, PIL.Image)``. The caller must use
            ``actual_index`` to pick the matching annotation.

        Raises:
            RuntimeError: if no image in the dataset can be read.
        """
        for offset in range(self.len):
            i = (index + offset) % self.len
            img_path = self.img_list[i]
            try:
                # convert() forces a full decode, so truncated or corrupt files
                # fail here and get skipped instead of failing later in the
                # pipeline. It also guarantees 3 channels for to_tensor.
                with Image.open(img_path) as src:
                    return i, src.convert("RGB")
            except Exception:
                with open("error_files.txt", "a") as errlog:
                    errlog.write(str(i) + ": " + img_path + "\n")
        raise RuntimeError("LaneDataset: no readable image found (see error_files.txt)")

    def __getitem__(self, index):
        """Return ``(img_tensor, target_map, img_path)`` for sample ``index``.

        If the image cannot be read, it is logged to ``error_files.txt``.
        The next readable image (wrapping around) is returned together with
        that image's own annotation.

        Raises:
            IndexError: if ``index`` is outside ``[-len, len)``. This keeps
                the sequence-iteration protocol working.
            RuntimeError: if no image in the dataset can be read.
        """
        if not -self.len <= index < self.len:
            raise IndexError(
                f"index {index} out of range for dataset of size {self.len}"
            )
        index, img = self._load_image(index % self.len)
        img_path = self.img_list[index]

        w, h = img.size
        with open(self.label_path[index], "r") as f:
            json_data = json.load(f)

        crop_h = w // 2
        # resized_crop(img, top, left, height, width, size): keep the bottom
        # ``w // 2`` rows at full width, then resize to (I_H, I_W).
        img_tensor = transforms.functional.to_tensor(
            transforms.functional.resized_crop(
                img, h - crop_h, 0, crop_h, w, (self.I_H, self.I_W)
            )
        )
        target_map = self.make_gt_map(json_data, w, h)

        return img_tensor, torch.from_numpy(target_map).long(), img_path

    def __len__(self):
        return self.len

    def _to_map_coords(self, points, original_w, original_h):
        """Map annotation points (``{"x", "y"}`` dicts, original-image pixels) into target-map pixels.

        The mapping matches the crop in ``__getitem__``: y is shifted by the
        crop top, then both axes are scaled to ``(I_H, I_W)`` and truncated
        to int32.
        """
        y_offset = original_h - original_w // 2
        crop_h = original_h - y_offset
        return np.array(
            [
                [
                    pt["x"] * self.I_W / original_w,
                    (pt["y"] - y_offset) * self.I_H / crop_h,
                ]
                for pt in points
            ]
        ).astype(np.int32)

    def make_gt_map(self, json_data, original_w, original_h):
        """Rasterise lane polygons from one annotation into an ``(I_H, I_W)`` int32 class map.

        Args:
            json_data: parsed annotation. Objects are read from
                ``json_data["data_set_info"]["data"]``.
            original_w: width of the source image in pixels.
            original_h: height of the source image in pixels.

        Returns:
            np.ndarray of shape ``(I_H, I_W)``, dtype int32. Each drawn lane's
            polygon is filled with its ``CLASS_IDS`` value. Later objects
            overwrite earlier ones; everything else stays 0.
        """
        target_map = np.zeros((self.I_H, self.I_W), dtype=np.int32)

        for item in json_data["data_set_info"]["data"]:
            label = item["value"]["object_Label"]
            if "lane_type" not in label:
                continue
            class_id = self.CLASS_IDS.get(label["lane_type"][5:])
            if class_id is None:
                continue
            if label.get("lane_attribute") not in self.DRAWN_ATTRIBUTES:
                continue
            points = item["value"]["points"]
            if not points:
                # An empty point list cannot form a polygon.
                continue
            poly_points = self._to_map_coords(points, original_w, original_h)
            cv2.fillPoly(target_map, [poly_points], class_id)

        return target_map

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# class FocalLoss(nn.Module):
#     def __init__(self, gamma=0, alpha=None, size_average=True):
#         super(FocalLoss, self).__init__()
#         self.gamma = gamma
#         self.alpha = alpha
#         if isinstance(alpha, (float, int)):
#             self.alpha = torch.Tensor([alpha, 1 - alpha])
#         if isinstance(alpha, list):
#             self.alpha = torch.Tensor(alpha)
#         self.size_average = size_average

#     def forward(self, inp, target):
#         if inp.dim() > 2:
#             inp = inp.view(inp.size(0), inp.size(1), -1)  # N,C,H,W => N,C,H*W
#             inp = inp.transpose(1, 2)  # N,C,H*W => N,H*W,C
#             inp = inp.contiguous().view(-1, inp.size(2))  # N,H*W,C => N*H*W,C
#         target = target.view(-1, 1)

#         logpt = F.log_softmax(inp)
#         logpt = logpt.gather(1, target)
#         logpt = logpt.view(-1)
#         pt = Variable(logpt.data.exp())

#         if self.alpha is not None:
#             if self.alpha.type() != inp.data.type():
#                 self.alpha = self.alpha.type_as(inp.data)
#             at = self.alpha.gather(0, target.data.view(-1))
#             logpt = logpt * Variable(at)

#         loss = -1 * (1 - pt) ** self.gamma * logpt
#         if self.size_average:
#             return loss.mean()
#         else:
#             return loss.sum()


class FocalLoss(nn.modules.loss._WeightedLoss):
    """Multi-class focal loss built on ``F.cross_entropy``.

    Per element: ``FL = (1 - p_t) ** gamma * CE``. Here ``CE = -log(p_t)`` and
    ``p_t`` is the softmax probability of the target class. The modulation is
    applied per element *before* reduction.

    Args:
        weight: optional per-class weight tensor (the focal-loss "alpha").
            Each element's loss is scaled by the weight of its target class.
            ``p_t`` itself is always computed unweighted.
        gamma: focusing parameter. ``gamma=0`` reduces to (weighted) CE with
            a plain mean/sum reduction.
        reduction: ``"mean"``, ``"sum"`` or ``"none"``.
    """

    def __init__(self, weight=None, gamma=2, reduction="mean"):
        # _WeightedLoss registers ``weight`` as a buffer, so it follows .to(device).
        super(FocalLoss, self).__init__(weight, reduction=reduction)
        self.gamma = gamma

    def forward(self, input, target):
        # Unweighted per-element CE is exactly -log(p_t), so p_t is a true probability.
        ce_loss = F.cross_entropy(input, target, reduction="none")
        pt = torch.exp(-ce_loss)
        if self.weight is not None:
            # Class weights scale the loss but must not distort p_t.
            ce_loss = F.cross_entropy(
                input, target, weight=self.weight, reduction="none"
            )
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.reduction == "none":
            return focal_loss
        if self.reduction == "mean":
            return focal_loss.mean()
        if self.reduction == "sum":
            return focal_loss.sum()
        raise ValueError(f"{self.reduction} is not a valid value for reduction")

In [3]:
from torch import nn
from torchvision.models.segmentation import fcn_resnet50


class LaneSegModel(nn.Module):
    """Lane-marking segmentation network built on torchvision's FCN-ResNet50.

    The pretrained FCN is kept as-is except for its ``classifier`` head.
    That head is replaced with a freshly initialised one producing
    ``num_classes`` output channels.
    """

    def __init__(self, num_classes=21):
        super().__init__()
        self.fcn = fcn_resnet50(pretrained=True)

        # The classifier head receives 2048 feature channels and squeezes
        # them 4x before the final 1x1 projection.
        feat_ch = 2048
        mid_ch = feat_ch // 4
        self.num_lanes = num_classes

        head_layers = [
            nn.Conv2d(feat_ch, mid_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(),
            nn.Dropout(0.1),
            # One output channel per class.
            nn.Conv2d(mid_ch, num_classes, 1),
        ]
        self.fcn.classifier = nn.Sequential(*head_layers)

        # F1 bookkeeping slots; nothing in this class updates them.
        self.f1 = 0
        self.f1cnt = 0

    def forward(self, x):
        """Run the wrapped FCN and return its output unchanged."""
        return self.fcn(x)

In [4]:
import os

import matplotlib.pyplot as plt
import torch


def _confusion_matrix(pred, target, num_classes):
    """Return a ``(num_classes, num_classes)`` long tensor.

    Entry ``[p, t]`` counts the pixels predicted as ``p`` whose target is
    ``t``. Labels outside ``[0, num_classes)`` are ignored.
    """
    valid = (pred >= 0) & (pred < num_classes) & (target >= 0) & (target < num_classes)
    flat = pred[valid] * num_classes + target[valid]
    return torch.bincount(flat, minlength=num_classes * num_classes).view(
        num_classes, num_classes
    )


def _macro_f1(cm):
    """Compute an F1 score from one confusion matrix (rows = predictions, cols = targets).

    Precision and recall are macro-averaged over classes that are both
    predicted and present in the target. F1 is their harmonic mean. Returns
    0.0 when no class qualifies or when precision + recall is 0.
    """
    pred_totals = cm.sum(dim=1)
    target_totals = cm.sum(dim=0)
    present = (pred_totals > 0) & (target_totals > 0)
    if not bool(present.any()):
        return 0.0
    hits = cm.diagonal()[present].float()
    precision = (hits / pred_totals[present].float()).mean()
    recall = (hits / target_totals[present].float()).mean()
    denom = precision + recall
    if float(denom) == 0.0:
        return 0.0
    return (2 * precision * recall / denom).item()


def _save_prediction_images(prefix, image, pred, target, num_classes):
    """Write ``<prefix>input.png``, ``<prefix>output.png`` and ``<prefix>target.png``."""
    plt.imsave(prefix + "input.png", image.cpu().permute((1, 2, 0)).numpy())
    plt.imsave(prefix + "output.png", pred.cpu().int(), vmin=0, vmax=num_classes - 1)
    plt.imsave(
        prefix + "target.png", target.cpu().int(), vmin=0, vmax=num_classes - 1
    )


def test_step(
    batch, model, device, batch_idx, batch_size, show=False, file_output=False
):
    """Evaluate one batch of segmentation predictions.

    Args:
        batch: ``(x, y, img_path)``. ``x`` is presumably ``(B, 3, H, W)``
            images and ``y`` is ``(B, H, W)`` integer class maps -- confirm
            against the DataLoader.
        model: callable returning a dict whose ``"out"`` entry holds
            per-class logits ``(B, C, H, W)``. It must expose ``num_lanes``
            (the class count).
        device: device to run on.
        batch_idx: index of this batch. Used only to name output folders.
        batch_size: batch size. Used only to name output folders.
        show: debug mode. Overwrites ``input.png``, ``output.png`` and
            ``target.png`` in the working directory for each image and waits
            for Enter. No metrics are computed and ``None`` is returned.
        file_output: also save those three images per sample under
            ``./outputs/<global sample index>/``.

    Returns:
        ``(confusion_mat, f1_sum, f1_cnt, img_path)``:

        - ``confusion_mat`` is the batch's summed ``(C, C)`` numpy matrix
          (rows = predictions, cols = targets).
        - ``f1_sum`` is the sum of the per-image F1 scores.
        - ``f1_cnt`` is the number of images scored.
    """
    x, y, img_path = batch
    x = x.to(device)
    y = y.to(device)
    num_classes = model.num_lanes

    with torch.no_grad():
        # argmax over raw logits: a sigmoid first can saturate large logits
        # to exactly 1.0 and create spurious ties between classes.
        preds = torch.argmax(model(x)["out"], dim=1)

    if show:
        for i, final_out in enumerate(preds):
            _save_prediction_images("", x[i], final_out, y[i], num_classes)
            input()
        return None

    confusion_mat = torch.zeros(
        (num_classes, num_classes), device=device, dtype=torch.long
    )
    f1_sum = 0.0
    f1_cnt = 0
    for i, final_out in enumerate(preds):
        # F1 comes from this image's own matrix. Accumulating first would
        # leak earlier images' pixels into later images' scores.
        image_cm = _confusion_matrix(final_out, y[i], num_classes)
        confusion_mat += image_cm
        f1_sum += _macro_f1(image_cm)
        f1_cnt += 1

        if file_output:
            folder_path = "./outputs/" + str(batch_idx * batch_size + i)
            os.makedirs(folder_path, exist_ok=True)
            _save_prediction_images(
                folder_path + "/", x[i], final_out, y[i], num_classes
            )

    return confusion_mat.cpu().numpy(), f1_sum, f1_cnt, img_path


# Not a pytest test despite the ``test_`` prefix; opt out of collection.
test_step.__test__ = False


def test_epoch_end(outputs):
    """Aggregate the per-batch results returned by ``test_step``.

    Args:
        outputs: iterable of ``(confusion_mat, f1_sum, f1_cnt, img_path)``
            tuples; ``img_path`` is ignored. The iterable is consumed in a
            single pass, so a generator works.

    Returns:
        ``(mean_f1, confusion_mat)``:

        - ``mean_f1`` is the summed per-image F1 divided by the total image
          count.
        - ``confusion_mat`` is the element-wise sum of all batch confusion
          matrices.

    Raises:
        ValueError: if ``outputs`` contains no evaluated images.
    """
    sum_confusion_mat = 0
    total_f1 = 0
    total_f1_cnt = 0
    for confusion_mat, f1_sum, f1_cnt, _ in outputs:
        sum_confusion_mat += confusion_mat
        total_f1 += f1_sum
        total_f1_cnt += f1_cnt

    if total_f1_cnt == 0:
        raise ValueError("test_epoch_end: outputs contain no evaluated images")
    return total_f1 / total_f1_cnt, sum_confusion_mat


# Not a pytest test despite the ``test_`` prefix; opt out of collection.
test_epoch_end.__test__ = False