In [1]:
import cv2 
import torch
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import yaml
from dataset import TrafficLightDataset, collate_fn
from model import trafficLightDetectionModel
from train import train_one_epoch
from os.path import exists
from torchvision.ops import box_iou

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


cuda:0


In [2]:
img_dir = "/home/alexlin/traffic_net/dataset_train_rgb/"
labels_dir = "/home/alexlin/traffic_net/dataset_train_rgb/train.yaml"
classes = ['background', 'GreenLeft', 'RedStraightLeft', 'RedLeft', 'off', 'GreenStraight', 'GreenStraightRight',
             'GreenStraightLeft', 'RedStraight', 'GreenRight', 'Green', 'Yellow', 'RedRight', 'Red']

dataset = TrafficLightDataset(img_dir, labels_dir, classes)
train_data = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=16, 
    collate_fn=collate_fn
)
data_list = list(train_data)

        

In [3]:
print(data_list[2][1][0])

{'boxes': tensor([[244.2500, 140.6000, 246.3000, 143.4500],
        [253.5500, 136.9000, 255.2500, 140.4000],
        [259.8000, 140.1500, 262.0000, 144.3000]]), 'labels': tensor([11, 11, 11])}


In [4]:
model = trafficLightDetectionModel(num_classes=len(classes)).to(device)


In [5]:
start_epoch = 0
epoch = 8
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 15, 20, 25, 30], gamma=0.5)
losses = []
if exists('checkpoints/last_checkpoint.pt'):
    checkpoints = torch.load('checkpoints/last_checkpoint.pt')
    start_epoch = checkpoints['epoch']
    model.load_state_dict(checkpoints['weights'])
    optimizer.load_state_dict(checkpoints['optimizer'])
    lr_scheduler.load_state_dict(checkpoints['lr_scheduler'])

for e in range(start_epoch+1, epoch):
    train_one_epoch(model, device, train_data, optimizer, e, losses)
    lr_scheduler.step()


In [6]:
states = {
            'epoch': e,
            'weights': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'losses' : losses
        }
torch.save(states,'checkpoints/last_checkpoint.pt')

NameError: name 'e' is not defined

In [6]:
checkpoints = torch.load('checkpoints/last_checkpoint.pt')
start_epoch = checkpoints['epoch']
model.load_state_dict(checkpoints['weights'])
optimizer.load_state_dict(checkpoints['optimizer'])
lr_scheduler.load_state_dict(checkpoints['lr_scheduler'])

In [21]:
from predict import Predictor
p1 = Predictor(model, device)
img = p1.read_img('/home/alexlin/traffic_net/dataset_train_rgb/rgb/train/2015-10-05-10-52-01_bag/27860.png')
x = p1.process_img(img).to(device)
predictions = p1.predict(x)
p1.draw_image(img, predictions, classes)
while True:
    cv2.imshow('test', img)
    if cv2.waitKey(1) == 27: 
        break

cv2.destroyAllWindows()

KeyboardInterrupt: 

In [89]:

model.eval()
x = data_list[20][0][0].to(device)
with torch.no_grad():
    predictions = model([x])
predictions = predictions = {k: v.to(device)
                           for k, v in predictions[0].items()}
y= data_list[20][1]
print(y)
print(predictions)
labels = [{k: v.to(device) for k, v in t.items()}
                  for t in y]
print(accurracy(labels[0], predictions))

({'boxes': tensor([[196.6000, 117.0000, 208.6500, 141.2500],
        [253.1000, 114.9000, 258.5500, 126.0000],
        [320.0500,  81.9500, 326.7000,  95.3500],
        [374.9000, 112.3000, 381.4000, 125.7000],
        [415.6500,  99.4000, 422.4500, 114.2000]]), 'labels': tensor([ 3,  3, 13, 13, 13])},)
{'boxes': tensor([[196.9693, 118.3943, 207.8124, 141.6452],
        [375.1334, 113.3391, 381.3925, 126.5752],
        [320.2568,  82.7693, 326.9871,  98.9937],
        [416.8027,  99.9976, 421.9222, 112.7409],
        [253.4191, 115.2037, 258.4132, 126.6794],
        [109.5234, 136.0441, 112.8075, 144.1707],
        [252.7802, 114.9851, 259.0619, 127.1016],
        [196.7881, 117.6098, 209.2438, 144.4737],
        [378.1739, 128.4862, 381.4451, 135.7050],
        [378.1874, 128.9198, 381.4590, 135.5824],
        [109.5551, 136.3233, 112.8839, 144.1271],
        [251.9949, 114.0122, 259.7973, 129.1276],
        [253.8330, 115.7157, 257.6562, 125.6403],
        [249.2918, 113.8223, 260.94

In [81]:
def accurracy(gt, pred):
    prob_threshold = .5
    IOU_threshold = .5
    true_Positive_Counter = 0
    false_Positive_Counter = 0
    false_Negative_Counter = 0
    # pred = pred[0]
    # gt = gt[0]
    for i, box in enumerate(pred["boxes"]):
        if pred["scores"][i] > prob_threshold:
            pred_x_max = box[0]
            pred_x_min = box[1]
            pred_y_max = box[2]
            pred_y_min = box[3]
            for j, box_gt in enumerate(gt["boxes"]):
                gt_x_max = box_gt[0]
                gt_x_min = box_gt[1]
                gt_y_max = box_gt[2]
                gt_y_min = box_gt[3]

                xA = max(pred_x_min, gt_x_min)
                yA = max(pred_y_min, gt_y_min)
                xB = min(pred_x_max, gt_x_max)
                yB = min(pred_y_max, gt_y_max)

                interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)

                predArea =  (pred_x_max - pred_x_min + 1) * (pred_y_max - pred_y_min + 1)
                gtArea = (gt_x_max - gt_x_min + 1) * (gt_y_max - gt_y_min + 1)

                iou = interArea / (predArea + gtArea - interArea)


                if iou > IOU_threshold:
                    if pred["labels"][i] == gt["labels"][j]:
                        true_Positive_Counter += 1
                    else :
                        false_Positive_Counter += 1
    
    tag = False
    precision_score = true_Positive_Counter / (true_Positive_Counter + false_Positive_Counter)
    recall_score = true_Positive_Counter / (true_Positive_Counter + false_Negative_Counter)
    F1_score = precision_score * recall_score / (precision_score + recall_score)

    return precision_score, recall_score, F1_score

In [None]:
exists('checkpoints/last_checkpoint.pt')