In [1]:
import os
import random

from collections import OrderedDict
import numpy as np
import pandas as pd

import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.figsize'] = [5, 5]
matplotlib.rcParams['figure.dpi'] = 200

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torchvision.models as models


from data_helper import UnlabeledDataset, LabeledDataset
from helper import collate_fn, draw_box, compute_ts_road_map
from hrnet import get_seg_model, get_config

In [7]:
def extract_boxes(gt_bboxes, scale=10, offset=400):
    """Convert ground-truth corner boxes to axis-aligned boxes in map pixels.

    Each corner set is mapped into the road-map pixel frame with
    ``x_pix = x * scale + offset`` and ``y_pix = -y * scale + offset`` (the y
    axis is flipped), and the axis-aligned extent of its corners is taken.

    Args:
        gt_bboxes: iterable of corner tensors/arrays. Each is indexed as
            ``corners[0]`` (x coordinates) and ``corners[1]`` (y coordinates);
            presumably shape (2, 4) — TODO confirm against the dataset.
        scale: multiplier applied to both coordinates (default 10).
        offset: value added after scaling (default 400).

    Returns:
        list with one ``[xmin, ymin, xmax, ymax]`` entry per box; the entries
        are 0-d tensors (or NumPy scalars for array input).
    """
    new_boxes = []
    for corners in gt_bboxes:
        # Only the extent matters, so reduce over the corners directly rather
        # than building the closed drawing polygon and using Python's builtin
        # min/max, which walk the tensor one element at a time.
        x_pix = corners[0] * scale + offset
        y_pix = -corners[1] * scale + offset
        new_boxes.append([x_pix.min(), y_pix.min(), x_pix.max(), y_pix.max()])
    return new_boxes



def generate_mask(gt_bboxes, h=800, w=800):
    """Rasterize ground-truth boxes into one binary mask channel per box.

    Args:
        gt_bboxes: corner boxes accepted by ``extract_boxes``.
        h, w: output mask height and width (default 800). NOTE(review):
            ``extract_boxes`` maps coordinates with a fixed offset of 400, so
            only the default 800x800 canvas is centred — confirm before
            changing these.

    Returns:
        uint8 array of shape [h, w, number of boxes]; channel ``i`` is 1 inside
        box ``i`` (clipped to the canvas) and 0 elsewhere.
    """
    boxes = extract_boxes(gt_bboxes)
    masks = np.zeros([h, w, len(boxes)], dtype='uint8')  # [h, w, number of bbox]
    for i, box in enumerate(boxes):
        xmin, ymin, xmax, ymax = box[0], box[1], box[2], box[3]
        # Clamp to the canvas: a negative start index would otherwise wrap
        # around to the far edge of the array and silently give an empty (or
        # wrong) slice for boxes that extend past the map border.
        row_s = min(max(int(ymin), 0), h)
        row_e = min(max(int(ymax), 0), h)
        col_s = min(max(int(xmin), 0), w)
        col_e = min(max(int(xmax), 0), w)
        masks[row_s:row_e, col_s:col_e, i] = 1

    return masks


def extract_bboxes(mask):
    """Derive a tight bounding box for every instance channel of a mask stack.

    mask: [height, width, num_instances] array of 0/1 pixels.

    Returns an int32 array [num_instances, (y1, x1, y2, x2)] where y2/x2 are
    exclusive (one past the last set row/column). Channels with no set pixel
    get an all-zero box.
    """
    n_instances = mask.shape[-1]
    result = np.zeros([n_instances, 4], dtype=np.int32)
    for idx in range(n_instances):
        plane = mask[:, :, idx]
        cols = np.where(plane.any(axis=0))[0]
        rows = np.where(plane.any(axis=1))[0]
        if not cols.shape[0]:
            # Empty channel (e.g. lost to resizing/cropping): keep the zero row.
            continue
        top, bottom = rows[0], rows[-1] + 1
        left, right = cols[0], cols[-1] + 1
        result[idx] = np.array([top, left, bottom, right])
    return result.astype(np.int32)



def get_coor(boxes):
    """Expand (y1, x1, y2, x2) boxes into per-corner x and y coordinate rows.

    boxes: iterable of [y1, x1, y2, x2] entries (the layout returned by
    extract_bboxes).

    Returns a tensor of shape (num_instances, 2, 4): row 0 holds the x values
    and row 1 the y values, in the order ['fl', 'fr', 'bl', 'br'] used by this
    notebook: fl = (x2, y1), fr = (x2, y2), bl = (x1, y1), br = (x1, y2).
    NOTE(review): values stay in the same frame as the input; no inverse of a
    pixel scale/offset is applied here.
    """
    def corners_of(box):
        top, left, bottom, right = box[0], box[1], box[2], box[3]
        return [[right, right, left, left], [top, bottom, top, bottom]]

    return torch.as_tensor([corners_of(box) for box in boxes])

In [2]:
# Prefer the first GPU when one is visible; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

print(device)

cpu


In [3]:
SEED = 0
# Seed every RNG in use (Python, NumPy, PyTorch) so runs are reproducible.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

In [4]:
# Data root and the annotation CSV consumed by LabeledDataset below.
image_folder, annotation_csv = '../data', '../data/annotation.csv'

In [5]:
labeled_scene_index = np.arange(106, 134)

# NOTE(review): only two scenes each for train and val — this looks like a
# debugging subset of labeled_scene_index; widen the ranges for a real run.
train_index = np.arange(106, 108)
val_index = np.arange(128, 130)

transform = torchvision.transforms.ToTensor()


def make_labeled_dataset(scene_index):
    """Build a LabeledDataset over `scene_index` with this notebook's shared settings."""
    return LabeledDataset(
        image_folder=image_folder,
        annotation_file=annotation_csv,
        scene_index=scene_index,
        transform=transform,
        extra_info=False,
    )


labeled_trainset = make_labeled_dataset(train_index)
labeled_valset = make_labeled_dataset(val_index)

trainloader = torch.utils.data.DataLoader(labeled_trainset, batch_size=2, shuffle=True, num_workers=2, collate_fn=collate_fn)
# Validation order need not be randomized; a fixed order keeps runs comparable.
valloader = torch.utils.data.DataLoader(labeled_valset, batch_size=1, shuffle=False, num_workers=2, collate_fn=collate_fn)

model = get_seg_model(get_config()).to(device)

In [6]:
# NOTE(review): BCELoss expects probabilities in [0, 1]. If the model returns
# raw logits, BCEWithLogitsLoss is the numerically stable choice — confirm.
criterion = torch.nn.BCELoss()

# Optimize only the trainable parameters (frozen layers are skipped).
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    trainable_params,
    lr=0.0001,
    momentum=0.9,
    weight_decay=0.0001,
    nesterov=False,
)
# Start at +inf so the first validation loss always counts as an improvement.
# BCELoss clamps its log terms at -100, so a loss of exactly 100 is attainable
# and would never beat the old sentinel of 100.
best_val_loss = float('inf')

In [None]:
epochs = 10
for epoch in range(epochs):
    #### train logic ####
    model.train()
    train_losses = []

    for i, (sample, target, road_img) in enumerate(trainloader):
        # collate_fn yields per-sample sequences; stack them into a batch.
        sample = torch.stack(sample).to(device)
        batch_size = sample.shape[0]
        # Fold every per-sample leading dim into channels:
        # -> (batch_size, C, 256, 306); C was 18 per the original note — TODO confirm.
        sample = sample.view(batch_size, -1, 256, 306)

        # road_img presumably arrives as a tuple of per-sample tensors, like
        # `sample` — confirm against collate_fn. BCELoss needs a float target
        # on the same device as the prediction.
        road_map = torch.stack(road_img).float().to(device)

        optimizer.zero_grad()
        pred_map = model(sample)

        # NOTE(review): pred_map must already be probabilities shaped like
        # road_map for BCELoss — confirm the model's output head/resolution.
        loss = criterion(pred_map, road_map)
        train_losses.append(loss.item())
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch + 1, i * len(sample), len(trainloader.dataset),
                100. * i / len(trainloader), loss.item()))
    print("\n Average Train Epoch Loss for epoch {} is {} ".format(epoch + 1, np.mean(train_losses)))