In [9]:
import cv2 
import torch
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import yaml
from dataset import TrafficLightDataset, collate_fn
from model import trafficLightDetectionModel
from train import train_one_epoch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


cuda:0


In [13]:
img_dir = "/home/alexlin/traffic_net/dataset_train_rgb/"
labels_dir = "/home/alexlin/traffic_net/dataset_train_rgb/train.yaml"
classes = ['background', 'GreenLeft', 'RedStraightLeft', 'RedLeft', 'off', 'GreenStraight', 'GreenStraightRight',
             'GreenStraightLeft', 'RedStraight', 'GreenRight', 'Green', 'Yellow', 'RedRight', 'Red']

dataset = TrafficLightDataset(img_dir, labels_dir, classes)
train_data = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=16, 
    collate_fn=collate_fn
)
data_list = list(train_data)

        

In [23]:
print(data_list[2][0][0].shape)

torch.Size([3, 288, 512])


In [4]:
model = trafficLightDetectionModel(num_classes=len(classes)).to(device)


In [5]:
start_epoch = 0
epoch = 10
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 15, 20, 25, 30], gamma=0.5)
losses = []
for e in range(start_epoch+1, epoch):
    train_one_epoch(model, device, train_data, optimizer, e, losses)
    lr_scheduler.step()


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[Train]Epoch: [0][10/5093]	Loss_sum:  1.157192( 3.612841)	Cls:  0.086427( 1.502535)	Box:  0.006731( 0.001734)	Obj:  0.608177( 0.677314)	RPN  0.455856( 1.431258)
[Train]Epoch: [0][20/5093]	Loss_sum:  0.993073( 2.549130)	Cls:  0.214922( 0.841104)	Box:  0.015407( 0.005028)	Obj:  0.699233( 0.634882)	RPN  0.063512( 1.068116)
[Train]Epoch: [0][30/5093]	Loss_sum:  0.789345( 1.978990)	Cls:  0.100664( 0.585539)	Box:  0.031837( 0.006146)	Obj:  0.581751( 0.649352)	RPN  0.075093( 0.737953)
[Train]Epoch: [0][40/5093]	Loss_sum:  2.601005( 1.790591)	Cls:  0.003736( 0.458502)	Box:  0.000000( 0.007925)	Obj:  0.621233( 0.615667)	RPN  1.976035( 0.708497)
[Train]Epoch: [0][50/5093]	Loss_sum:  1.040406( 1.778031)	Cls:  0.002560( 0.367213)	Box:  0.000000( 0.006340)	Obj:  0.541960( 0.617752)	RPN  0.495886( 0.786726)
[Train]Epoch: [0][60/5093]	Loss_sum:  0.796837( 1.643425)	Cls:  0.002155( 0.306102)	Box:  0.000000( 0.005283)	Obj:  0.531878( 0.607899)	RPN  0.262804( 0.724141)
[Train]Epoch: [0][70/5093]	Loss_su

In [6]:
states = {
            'epoch': e,
            'weights': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict()
        }
torch.save(states,'checkpoints/last_checkpoint.pt')

In [19]:
checkpoints = torch.load('checkpoints/last_checkpoint.pt')
start_epoch = checkpoints['epoch']
model.load_state_dict(checkpoints['weights'])
optimizer.load_state_dict(checkpoints['optimizer'])
lr_scheduler.load_state_dict(checkpoints['lr_scheduler'])

In [35]:
model.eval()
x = data_list[6][0][0].to(device)
with torch.no_grad():
    predictions = model([x])
print(data_list[6][1][0])
print(predictions)

{'boxes': tensor([[244.6000, 140.7000, 246.5500, 144.4500],
        [253.2500, 136.1000, 255.9000, 140.1000],
        [262.0500, 140.6500, 264.2000, 144.2500]]), 'labels': tensor([ 3, 13, 13])}
[{'boxes': tensor([[261.3960, 140.5317, 264.2551, 144.8651],
        [253.6525, 136.7338, 256.2070, 141.7906],
        [253.3901, 136.5048, 255.5030, 140.5381],
        [253.9260, 136.1562, 256.0099, 139.8969],
        [261.5974, 139.7525, 264.0901, 144.0281],
        [253.6632, 136.0107, 256.0873, 140.7885],
        [261.4073, 140.7559, 264.2865, 145.1864],
        [254.2957, 136.1245, 256.9283, 141.3715],
        [253.6626, 137.0677, 256.1863, 142.1774]], device='cuda:0'), 'labels': tensor([13, 13, 13, 13, 11, 11,  3, 13,  3], device='cuda:0'), 'scores': tensor([0.6349, 0.5779, 0.3207, 0.2503, 0.1234, 0.1016, 0.0995, 0.0939, 0.0862],
       device='cuda:0')}]
