In [5]:

## this file will be used to train road_map using HR Net and save .pt file

import os
import random

from collections import OrderedDict
import numpy as np
import pandas as pd

import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.figsize'] = [5, 5]
matplotlib.rcParams['figure.dpi'] = 200

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torchvision.models as models


from data_helper import UnlabeledDataset, LabeledDataset
from helper import collate_fn, draw_box, compute_ts_road_map
from hrnet import get_seg_model, get_config


In [2]:
image_folder = '../data'
annotation_csv = '../data/annotation.csv'

random.seed(0)
np.random.seed(0)
torch.manual_seed(0);

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The scenes from 106 - 133 are labeled
# You should devide the labeled_scene_index into two subsets (training and validation)
labeled_scene_index = np.arange(106, 134)

train_index = np.arange(106,110)
val_index = np.arange(128,130)

transform = torchvision.transforms.ToTensor()

labeled_trainset = LabeledDataset(
    image_folder=image_folder,
    annotation_file=annotation_csv,
    scene_index=train_index,
    transform=transform,
    extra_info=False
    )

labeled_valset = LabeledDataset(
    image_folder=image_folder,
    annotation_file=annotation_csv,
    scene_index=val_index,
    transform=transform,
    extra_info=False
    )

trainloader = torch.utils.data.DataLoader(labeled_trainset, batch_size=2, shuffle=True, num_workers=2, collate_fn=collate_fn)
valloader = torch.utils.data.DataLoader(labeled_valset, batch_size=2, shuffle=True, num_workers=2, collate_fn=collate_fn)

In [3]:
sample, target, road_image = iter(trainloader).next()
print(torch.stack(sample).shape)
print(road_image)

torch.Size([2, 6, 3, 256, 306])
(tensor([[False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]]), tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True]]))


In [6]:
model = get_seg_model(get_config()).to(device)

In [7]:
criterion = torch.nn.BCELoss()
#param_list = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    [{'params': filter(lambda p: p.requires_grad, model.parameters()),
    'lr': 0.0001}],
    lr=0.0001,
    momentum=0.9,
    weight_decay=0.0001,
    nesterov=False,
    )
best_val_loss = 100

In [11]:
epochs = 2
for epoch in range(epochs):

    #### train logic ####
    model.train()
    train_losses = []

    for i, (sample, target, road_img) in enumerate(trainloader):


        sample = torch.stack(sample).to(device)
        batch_size = sample.shape[0]
        sample = sample.view(batch_size, -1, 256, 306) # size: ([3, 18, 256, 306])
        road_img = torch.stack(road_img).float().to(device)

        optimizer.zero_grad()
        pred_map = model(sample)
        
        out_map = (pred_map > 0.5).float()
        print('current TS: {}'.format(compute_ts_road_map(out_map, road_img)))

        loss = criterion(pred_map, road_img)
        train_losses.append(loss.item())
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, i * len(sample), len(trainloader.dataset),
                100. * i / len(trainloader), loss.item()))
    print("\n Average Train Epoch Loss for epoch {} is {} ", epoch+1, np.mean(train_losses))
    
    model.eval()
    val_losses = []
    threat_score = []
    for i, (sample, target, road_img) in enumerate(valloader):
        sample = torch.stack(sample).to(device)
        batch_size = sample.shape[0]
        sample = sample.view(batch_size, -1, 256, 306) # size: ([3, 18, 256, 306])
        road_img = torch.stack(road_img).float().to(device)

        with torch.no_grad():
            pred_map = model(sample)
            loss = criterion(pred_map, road_img)
            val_losses.append(loss.item())

            out_map = (pred_map > 0.5).float()
            threat_score.append(compute_ts_road_map(out_map, road_img))

        print("Validation Epoch: {}, Average Validation Epoch Loss: {}".format(epoch, np.mean(val_losses)))
        print("Average Threat Score: {} ".format(np.mean(threat_score)))

        if np.mean(val_losses) < best_val_loss:
            best_val_loss = np.mean(val_losses)
            torch.save(model.state_dict(), 'HRNET_RM_labeled_data01.pt')

current TS: 0.21445535123348236
current TS: 0.20742225646972656
current TS: 0.2798864543437958
current TS: 0.25675374269485474
current TS: 0.31889617443084717
current TS: 0.34865251183509827
current TS: 0.2812550365924835
current TS: 0.2801552414894104
current TS: 0.3454919457435608
current TS: 0.32903602719306946
current TS: 0.2590380311012268
current TS: 0.28550106287002563
current TS: 0.3973533511161804
current TS: 0.29465651512145996
current TS: 0.2846889793872833
current TS: 0.41824913024902344
current TS: 0.27187296748161316
current TS: 0.273301362991333


Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


KeyboardInterrupt: 