In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import pickle
import numpy as np
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
import cv2 as cv
import os 
import sys
import pandas as pd
sys.path.append(os.path.dirname(os.path.abspath('.')))
from unet import UNet
from AttnUNet import AttU_Net
from ResUNet import Res_UNet
from PIL import Image


Custom Dataset Class

In [2]:
class Road_Lane_Segmentation(Dataset):
    """Road segmentation dataset yielding ``(image, road_mask)`` tensor pairs.

    Args:
        image_paths: sequence of image file paths.
        label_paths: sequence of mask file paths, aligned index-by-index with
            ``image_paths``.
        size: ``(width, height)`` passed to ``PIL.Image.resize``; defaults to 512x512.

    Each item is
        * ``image``: float32 tensor ``(3, H, W)`` holding raw 0-255 RGB values
          (no normalisation is applied), and
        * ``road_mask``: float32 tensor ``(1, H, W)``, 1.0 on road pixels and
          0.0 elsewhere.

    Raises:
        ValueError: if the two path sequences differ in length.
    """

    # Road colour used for masks stored as RGB; single-channel masks mark road
    # with class index 1 instead (see ``_road_mask``).
    ROAD_COLOR = np.array([128, 0, 0])

    def __init__(self, image_paths, label_paths, size=(512, 512)):
        super().__init__()
        # Mismatched lists would silently pair images with the wrong masks.
        if len(image_paths) != len(label_paths):
            raise ValueError(
                f"got {len(image_paths)} images but {len(label_paths)} masks"
            )
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.size = tuple(size)

    def __len__(self):
        return len(self.image_paths)

    @classmethod
    def _road_mask(cls, mask):
        """Return a float32 ``(H, W)`` array that is 1.0 where ``mask`` marks road."""
        mask = np.asarray(mask)
        if mask.ndim == 2:
            # Class-index mask (e.g. PIL "P"/"L" mode): road pixels hold value 1.
            return (mask == 1).astype(np.float32)
        # Colour mask (RGB/RGBA): road pixels match ROAD_COLOR on the first 3 channels.
        return np.all(mask[..., :3] == cls.ROAD_COLOR, axis=-1).astype(np.float32)

    def __getitem__(self, idx):
        # Context managers close the file handles even if decoding fails.
        with Image.open(self.image_paths[idx]) as src:
            # convert("RGB") guarantees 3 channels for grayscale, palette or RGBA
            # files, which would otherwise break the permute below / the model input.
            img = np.array(src.convert("RGB").resize(self.size, Image.BILINEAR))
        with Image.open(self.label_paths[idx]) as src:
            # NEAREST keeps label values exact (no blending between classes).
            mask = np.array(src.resize(self.size, Image.NEAREST))

        img = torch.from_numpy(img).float().permute(2, 0, 1)               # (3, H, W)
        road_mask = torch.from_numpy(self._road_mask(mask)).unsqueeze(0)   # (1, H, W)
        return img, road_mask

In [3]:
# Split file lists: each CSV has "image" and "mask" columns with file names
# relative to the split's root directory.
train_csv = pd.read_csv("train_file_list.csv")
valid_csv = pd.read_csv("valid_file_list.csv")

# os.path.join instead of a raw r"data\train" literal: the backslash form is a
# directory path only on Windows; elsewhere it is one file name containing "\".
root_path_train = os.path.join("data", "train")
root_path_valid = os.path.join("data", "valid")

train_images = [os.path.join(root_path_train, fname) for fname in train_csv["image"]]
train_masks = [os.path.join(root_path_train, fname) for fname in train_csv["mask"]]

valid_images = [os.path.join(root_path_valid, fname) for fname in valid_csv["image"]]
valid_masks = [os.path.join(root_path_valid, fname) for fname in valid_csv["mask"]]

train_dataset = Road_Lane_Segmentation(train_images, train_masks)
valid_dataset = Road_Lane_Segmentation(valid_images, valid_masks)

BATCH_SIZE = 8
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Peek at the first batch of each loader to confirm tensor shapes:
# images should be (B, 3, 512, 512) and masks (B, 1, 512, 512).
x = next(iter(train_loader))
print(*(t.shape for t in x), sep="\n")
y = next(iter(val_loader))
print(*(t.shape for t in y), sep="\n")

torch.Size([8, 3, 512, 512])
torch.Size([8, 1, 512, 512])
torch.Size([8, 3, 512, 512])
torch.Size([8, 1, 512, 512])


Model

#### AttUnet Training

In [5]:
# Attention U-Net built with (3, 1) to match the dataset's 3-channel images and
# 1-channel road masks (assumed (in_ch, out_ch) argument order), trained with MSE.
AttU_model = AttU_Net(3, 1)
AttU_model = AttU_model.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=AttU_model.parameters(), lr=1e-4)

In [None]:
# Train the Attention U-Net. `criterion`/`optimizer` are notebook globals that the
# UNet and ResUNet setup cells rebind; if this cell is re-run after those, the
# optimizer would step another model's weights, so fail loudly instead.
assert {id(p) for g in optimizer.param_groups for p in g["params"]} <= {
    id(p) for p in AttU_model.parameters()
}, "optimizer is not bound to AttU_model; re-run the AttU_model setup cell first"

num_epochs = 1000
for epoch in range(num_epochs):
    AttU_model.train()
    running_loss = 0.0
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        optimizer.zero_grad()
        output = AttU_model(batch_x)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch {epoch+1}, loss: {running_loss/len(train_loader)}')
    # Checkpoint every epoch so an interrupted long run keeps its progress.
    torch.save(AttU_model.state_dict(), 'model.pth')

Epoch 1, loss: 0.05897468539248956


#### UNet training

In [7]:
# Plain U-Net built with (3, 1) to match the dataset's 3-channel images and
# 1-channel road masks (assumed (in_ch, out_ch) argument order), trained with MSE.
UNet_model = UNet(3, 1)
UNet_model = UNet_model.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=UNet_model.parameters(), lr=1e-4)

In [None]:
# Train the plain U-Net. `criterion`/`optimizer` are notebook globals that other
# setup cells rebind; if this cell is re-run after those, the optimizer would step
# another model's weights, so fail loudly instead.
assert {id(p) for g in optimizer.param_groups for p in g["params"]} <= {
    id(p) for p in UNet_model.parameters()
}, "optimizer is not bound to UNet_model; re-run the UNet_model setup cell first"

num_epochs = 1000
for epoch in range(num_epochs):
    UNet_model.train()
    running_loss = 0.0
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        optimizer.zero_grad()
        output = UNet_model(batch_x)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch {epoch+1}, loss: {running_loss/len(train_loader)}')
    # Checkpoint every epoch so an interrupted long run keeps its progress.
    torch.save(UNet_model.state_dict(), 'UNet_model.pth')

Epoch 1, loss: 0.06073854167602564


#### ResUNet training

In [9]:
# Residual U-Net built with (3, 1) to match the dataset's 3-channel images and
# 1-channel road masks (assumed (in_ch, out_ch) argument order), trained with MSE.
ResUNet_model = Res_UNet(3, 1)
ResUNet_model = ResUNet_model.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=ResUNet_model.parameters(), lr=1e-4)

In [None]:
# Train the residual U-Net. `criterion`/`optimizer` are notebook globals that other
# setup cells rebind; if this cell is re-run after those, the optimizer would step
# another model's weights, so fail loudly instead.
assert {id(p) for g in optimizer.param_groups for p in g["params"]} <= {
    id(p) for p in ResUNet_model.parameters()
}, "optimizer is not bound to ResUNet_model; re-run the ResUNet_model setup cell first"

num_epochs = 1000
for epoch in range(num_epochs):
    ResUNet_model.train()
    running_loss = 0.0
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        optimizer.zero_grad()
        output = ResUNet_model(batch_x)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch {epoch+1}, loss: {running_loss/len(train_loader)}')
    # Checkpoint every epoch so an interrupted long run keeps its progress.
    torch.save(ResUNet_model.state_dict(), 'ResUNet_model.pth')

Epoch 1, loss: 0.08006764517018669


In [None]:
# Reload the weights written by the UNet training cell ('UNet_model.pth');
# weights_only=True limits unpickling to tensors/containers, all a state_dict needs.
UNet_model.load_state_dict(torch.load('UNet_model.pth', map_location=device, weights_only=True))

Test

In [None]:
import cv2 as cv
from moviepy import VideoFileClip
import numpy as np

In [None]:

class Lanes:
    """Rolling history of road-mask predictions used to smooth video overlays.

    Attributes:
        recent_fit: list of the most recent per-frame predictions; ``road_lines``
            appends to it and trims it.
        avg_fit: element-wise mean of ``recent_fit``; starts as an empty list.
    """

    def __init__(self):
        # Spelled ``__init__`` so Python actually runs it on construction.
        self.recent_fit = []
        self.avg_fit = []

def road_lines(image, model, lanes, input_size=(512, 512), history=5):
    """Overlay the model's smoothed road mask on one video frame.

    The frame is resized to ``input_size`` and run through ``model``. The
    prediction is then averaged with the previous ones kept in
    ``lanes.recent_fit`` (at most ``history`` in total) to reduce flicker. The
    average is scaled back to the frame's resolution and added to its green
    channel.

    Args:
        image: (H, W, 3) uint8 frame. It is fed to the model unchanged, so its
            channel order and 0-255 range should match the training data
            (``Road_Lane_Segmentation`` yields unnormalised RGB).
        model: torch module mapping (N, 3, h, w) float input to (N, 1, h, w).
        lanes: object with a ``recent_fit`` list (e.g. ``Lanes()``); its
            ``recent_fit`` and ``avg_fit`` attributes are updated in place.
        input_size: (height, width) used for inference; defaults to the
            512x512 resolution the dataset produces.
        history: number of most recent predictions to average (>= 1).

    Returns:
        (H, W, 3) uint8 frame with the mask added, saturating at 255.

    Raises:
        ValueError: if ``history`` < 1.
    """
    if history < 1:
        raise ValueError(f"history must be >= 1, got {history}")

    frame = np.asarray(image)
    height, width = frame.shape[:2]

    # Run on whatever device the model lives on (CPU if it has no parameters).
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    batch = torch.from_numpy(frame.astype(np.float32)).permute(2, 0, 1).unsqueeze(0).to(device)
    batch = F.interpolate(batch, size=tuple(input_size), mode="bilinear", align_corners=False)

    model.eval()
    with torch.no_grad():
        prediction = model(batch)[0, 0].cpu().numpy()

    lanes.recent_fit.append(prediction)
    lanes.recent_fit = lanes.recent_fit[-history:]
    lanes.avg_fit = np.mean(np.stack(lanes.recent_fit), axis=0)

    # Upscale the averaged mask back to the frame size before blending.
    avg = torch.from_numpy(np.ascontiguousarray(lanes.avg_fit, dtype=np.float32))[None, None]
    avg = F.interpolate(avg, size=(height, width), mode="bilinear", align_corners=False)[0, 0].numpy()

    green = np.clip(avg * 255.0, 0, 255).astype(np.uint8)
    blanks = np.zeros_like(green)
    lane_image = np.dstack((blanks, green, blanks))

    # Saturating add, equivalent to cv.addWeighted(image, 1, lane_image, 1, 0) on uint8.
    return np.clip(frame.astype(np.int16) + lane_image, 0, 255).astype(np.uint8)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Predict on the first validation batch only; the plotting cell below reads
# batch_x, batch_y and output. Swap AttU_model for UNet_model or ResUNet_model
# to inspect the other networks.
AttU_model.eval()
with torch.no_grad():
    for batch_x, batch_y in val_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        output = AttU_model(batch_x)
        break

In [None]:
# Visual check of one validation sample: input, ground-truth road mask, prediction.
# The index must be < len(batch_x) (the loaders use batch_size=8).
sample_idx = 0

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

img = batch_x[sample_idx].detach().cpu()
# Dataset images are raw 0-255 floats; imshow expects floats in [0, 1].
img = img / 255.0 if img.dtype.is_floating_point else img
img = img.permute(1, 2, 0).numpy()

# Fixed vmin/vmax so masks are not auto-stretched and stay comparable.
mask_style = {"cmap": "gray", "vmin": 0, "vmax": 1}
panels = [
    (img, "Input", {}),
    (batch_y[sample_idx].cpu().squeeze(), "Ground truth", mask_style),
    (output[sample_idx].cpu().squeeze(), "Prediction", mask_style),
]
for ax, (data, title, style) in zip(axes, panels):
    ax.imshow(data, **style)
    ax.set_title(title)
    ax.axis('off')
plt.show()

In [None]:
# Render the lane overlay onto a video clip with the UNet model.
# Disabled by default (slow, and needs a real input file): set RUN_VIDEO = True to run.
RUN_VIDEO = False
VIDEO_INPUT = 'input_clip.mp4'
VIDEO_OUTPUT = 'output_video.mp4'

if RUN_VIDEO:
    UNet_model.eval()
    lanes = Lanes()

    def process_frame(frame):
        # moviepy yields RGB uint8 frames, the same channel order as the
        # PIL-loaded training images, so no BGR round-trip is needed.
        return road_lines(frame, UNet_model, lanes)

    with VideoFileClip(VIDEO_INPUT) as clip_input:
        # moviepy 2.x name for the per-frame transform (1.x called it fl_image).
        vid_clip = clip_input.image_transform(process_frame)
        vid_clip.write_videofile(VIDEO_OUTPUT, audio=False)