In [1]:
import torch, os
import numpy as np
import time
import sys
sys.path.append('model')

from model.parsingnet import parsingNet
from model.dataloader import get_train_loader
from model.loss import SoftmaxFocalLoss, ParsingRelationLoss, ParsingRelationDis
from model.metrics import MultiLabelAcc, AccTopk, Metric_mIoU, update_metrics, reset_metrics

In [2]:
def inference(net, data_label, use_aux ,device):
    # 用于将参数送入网络并计算结果，并返回一个字典供计算损失
    if use_aux:
        img, cls_label, seg_label = data_label
        cls_label = cls_label.long()
        seg_label = seg_label.long()
        if device==torch.device('cuda:0'):
            img, cls_label, seg_label = img.cuda(), cls_label.cuda(), seg_label.cuda()
        cls_out, seg_out = net(img)
        return {'cls_out': cls_out, 'cls_label': cls_label, 'seg_out':seg_out, 'seg_label': seg_label}
    else:
        img, cls_label = data_label
        if device == torch.device('cuda:0'):
            img, cls_label = img.cuda(), cls_label.cuda()
        cls_out = net(img)
        return {'cls_out': cls_out, 'cls_label': cls_label}

In [3]:
def resolve_val_data(results, use_aux):
    # 取出结果的最大值位置，供计算损失用
    results['cls_out'] = torch.argmax(results['cls_out'], dim=1)
    if use_aux:
        results['seg_out'] = torch.argmax(results['seg_out'], dim=1)
    return results

In [4]:
def get_loss_dict(use_aux = True):
    # 生成一个用于计算损失的字典
    if use_aux:
        loss_dict = {
            'name': ['cls_loss', 'relation_loss', 'aux_loss', 'relation_dis'],
            'op': [SoftmaxFocalLoss(2), ParsingRelationLoss(), torch.nn.CrossEntropyLoss(), ParsingRelationDis()],
            'weight': [1.0, 1.0, 1.0, 0.0],
            'data_src': [('cls_out', 'cls_label'), ('cls_out',), ('seg_out', 'seg_label'), ('cls_out',)]
        }

    else:
        loss_dict = {
            'name': ['cls_loss', 'relation_loss', 'relation_dis'],
            'op': [SoftmaxFocalLoss(2), ParsingRelationLoss(), ParsingRelationDis()],
            'weight': [1.0, 1.0, 0.0],
            'data_src': [('cls_out', 'cls_label'), ('cls_out',), ('cls_out',)]
        }

    return loss_dict

In [5]:
def calc_loss(loss_dict, results):
    # 计算损失
    loss = 0
    for i in range(len(loss_dict['name'])):
        data_src = loss_dict['data_src'][i]
        datas = [results[src] for src in data_src]
        loss_cur = loss_dict['op'][i](*datas)
        loss += loss_cur * loss_dict['weight'][i]
    return loss


def get_metric_dict(use_aux = True):
    # 生成计算指标的字典
    if use_aux:
        metric_dict = {
            'name': ['top1', 'top2', 'top3', 'iou'],
            'op': [MultiLabelAcc(), AccTopk(100, 2), AccTopk(100, 3), Metric_mIoU(5)],
            'data_src': [('cls_out', 'cls_label'), ('cls_out', 'cls_label'), ('cls_out', 'cls_label'),
                         ('seg_out', 'seg_label')]
        }
    else:
        metric_dict = {
            'name': ['top1', 'top2', 'top3'],
            'op': [MultiLabelAcc(), AccTopk(100, 2), AccTopk(100, 3)],
            'data_src': [('cls_out', 'cls_label'), ('cls_out', 'cls_label'), ('cls_out', 'cls_label')]
        }

    return metric_dict

In [6]:
def train(epoches, net, loss_dict, metric_dict, train_loader, optimizer, use_aux,device,save_path):
    train_loader_size = len(train_loader)
    for epoch in range(epoches):
        print(f"Epoch {epoch + 1}")
        step = 0
        reset_metrics(metric_dict)
        for batch_idx, data_label in enumerate(train_loader):
            start = time.time()
            step += 1
            results = inference(net, data_label, use_aux, device)
            loss = calc_loss(loss_dict, results)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            end = time.time()
            batch_time = end - start
            results = resolve_val_data(results, use_aux)
            update_metrics(metric_dict, results)

            if step % 10 == 0:
                print(
                    "Epoch {ep} Step {st} |({batch}/{size})| ETA: {et:.2f}|MultiLabelAcc:{MultiLabelAcc:.5f}|AccTop2:{AccTop2:.5f}|AccTop3:{AccTop3:.5f}|Metric_mIoU:{Metric_mIoU:.5f}".format(
                        ep=epoch + 1,
                        st=step,
                        batch=batch_idx + 1,
                        size=train_loader_size,
                        et=batch_time,
                        MultiLabelAcc=metric_dict['op'][0].get(),
                        AccTop2=metric_dict['op'][1].get(),
                        AccTop3=metric_dict['op'][2].get(),
                        Metric_mIoU=metric_dict['op'][3].get(),
                    ))
        if (epoch + 1) % 5 == 0:
            print("should save model")
            torch.save(net.state_dict(), os.path.join(save_path, ('parsenet_'+str(epoch+1)+'.pth')))
    return metric_dict

In [7]:
torch.cuda.empty_cache()
epoches = 100
batch_size = 16
data_root = r'./tusimple_0531'
griding_num = 100
dataset = 'Tusimple'
use_aux = True
distributed = False
backbone = '50'
num_lanes = 4
learning_rate = 0.001
save_path = './save_pth'
if not os.path.isdir(save_path):
    os.mkdir(save_path)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
train_loader, cls_num_per_lane = get_train_loader(batch_size, data_root, griding_num, dataset, use_aux,distributed)
net = parsingNet(pretrained=True, backbone=backbone,cls_dim = (griding_num+1,cls_num_per_lane, num_lanes),use_aux=use_aux)
net = net.to(device)
optimizer = torch.optim.Adam(net.parameters(), learning_rate)
loss_dict = get_loss_dict(use_aux)
metric_dict = get_metric_dict(use_aux)

In [8]:
train(epoches, net, loss_dict, metric_dict, train_loader, optimizer, use_aux,device,save_path)

Epoch 1
Epoch 1 Step 10 |(10/23)| ETA: 1.99|MultiLabelAcc:0.44565|AccTop2:0.46099|AccTop3:0.47497|Metric_mIoU:0.17521
Epoch 1 Step 20 |(20/23)| ETA: 1.99|MultiLabelAcc:0.48126|AccTop2:0.49513|AccTop3:0.50878|Metric_mIoU:0.19384
Epoch 2
Epoch 2 Step 10 |(10/23)| ETA: 2.00|MultiLabelAcc:0.51596|AccTop2:0.52896|AccTop3:0.54166|Metric_mIoU:0.23854
Epoch 2 Step 20 |(20/23)| ETA: 1.99|MultiLabelAcc:0.51621|AccTop2:0.53006|AccTop3:0.54375|Metric_mIoU:0.25549
Epoch 3
Epoch 3 Step 10 |(10/23)| ETA: 2.00|MultiLabelAcc:0.52277|AccTop2:0.53979|AccTop3:0.55759|Metric_mIoU:0.34253
Epoch 3 Step 20 |(20/23)| ETA: 2.00|MultiLabelAcc:0.51996|AccTop2:0.53962|AccTop3:0.55963|Metric_mIoU:0.36997
Epoch 4
Epoch 4 Step 10 |(10/23)| ETA: 2.00|MultiLabelAcc:0.53030|AccTop2:0.55650|AccTop3:0.58491|Metric_mIoU:0.45907
Epoch 4 Step 20 |(20/23)| ETA: 2.00|MultiLabelAcc:0.52383|AccTop2:0.55172|AccTop3:0.58015|Metric_mIoU:0.47376
Epoch 5
Epoch 5 Step 10 |(10/23)| ETA: 1.99|MultiLabelAcc:0.52871|AccTop2:0.56362|AccTop

Epoch 36 Step 10 |(10/23)| ETA: 2.00|MultiLabelAcc:0.64983|AccTop2:0.81373|AccTop3:0.86526|Metric_mIoU:0.76378
Epoch 36 Step 20 |(20/23)| ETA: 2.00|MultiLabelAcc:0.64674|AccTop2:0.80900|AccTop3:0.86334|Metric_mIoU:0.76001
Epoch 37
Epoch 37 Step 10 |(10/23)| ETA: 2.00|MultiLabelAcc:0.64874|AccTop2:0.81217|AccTop3:0.86738|Metric_mIoU:0.77059
Epoch 37 Step 20 |(20/23)| ETA: 2.00|MultiLabelAcc:0.65372|AccTop2:0.81896|AccTop3:0.87187|Metric_mIoU:0.76278
Epoch 38
Epoch 38 Step 10 |(10/23)| ETA: 2.00|MultiLabelAcc:0.64699|AccTop2:0.81303|AccTop3:0.86691|Metric_mIoU:0.75918
Epoch 38 Step 20 |(20/23)| ETA: 2.00|MultiLabelAcc:0.64886|AccTop2:0.81504|AccTop3:0.86840|Metric_mIoU:0.75672
Epoch 39
Epoch 39 Step 10 |(10/23)| ETA: 2.00|MultiLabelAcc:0.64994|AccTop2:0.81607|AccTop3:0.86961|Metric_mIoU:0.74879
Epoch 39 Step 20 |(20/23)| ETA: 2.00|MultiLabelAcc:0.65460|AccTop2:0.81939|AccTop3:0.87217|Metric_mIoU:0.75227
Epoch 40
Epoch 40 Step 10 |(10/23)| ETA: 1.99|MultiLabelAcc:0.65363|AccTop2:0.81878|A

Epoch 71
Epoch 71 Step 10 |(10/23)| ETA: 2.00|MultiLabelAcc:0.69618|AccTop2:0.86858|AccTop3:0.91264|Metric_mIoU:0.83627
Epoch 71 Step 20 |(20/23)| ETA: 2.00|MultiLabelAcc:0.69282|AccTop2:0.86484|AccTop3:0.90932|Metric_mIoU:0.82672
Epoch 72
Epoch 72 Step 10 |(10/23)| ETA: 1.99|MultiLabelAcc:0.69612|AccTop2:0.86724|AccTop3:0.90915|Metric_mIoU:0.83561
Epoch 72 Step 20 |(20/23)| ETA: 2.00|MultiLabelAcc:0.69523|AccTop2:0.86825|AccTop3:0.91041|Metric_mIoU:0.82891
Epoch 73
Epoch 73 Step 10 |(10/23)| ETA: 2.00|MultiLabelAcc:0.69576|AccTop2:0.87215|AccTop3:0.91459|Metric_mIoU:0.83078
Epoch 73 Step 20 |(20/23)| ETA: 2.00|MultiLabelAcc:0.69618|AccTop2:0.87235|AccTop3:0.91381|Metric_mIoU:0.82656
Epoch 74
Epoch 74 Step 10 |(10/23)| ETA: 1.99|MultiLabelAcc:0.70519|AccTop2:0.87215|AccTop3:0.91297|Metric_mIoU:0.82913
Epoch 74 Step 20 |(20/23)| ETA: 2.00|MultiLabelAcc:0.70276|AccTop2:0.87193|AccTop3:0.91189|Metric_mIoU:0.83251
Epoch 75
Epoch 75 Step 10 |(10/23)| ETA: 2.00|MultiLabelAcc:0.70248|AccTop2:

{'data_src': [('cls_out', 'cls_label'),
  ('cls_out', 'cls_label'),
  ('cls_out', 'cls_label'),
  ('seg_out', 'seg_label')],
 'name': ['top1', 'top2', 'top3', 'iou'],
 'op': [<model.metrics.MultiLabelAcc at 0x7f5bc0285208>,
  <model.metrics.AccTopk at 0x7f5bc0285358>,
  <model.metrics.AccTopk at 0x7f5bc0285390>,
  <model.metrics.Metric_mIoU at 0x7f5bc02853c8>]}