In [1]:
import torch
# from torch.utils.data import DataLoader

import sys, os
# Path of the repository root, relative to the notebook's working directory.
# It is added to sys.path ahead of the project imports further down this cell
# (presumably `get_parameters`, `get_datasets` and `modules.*` live there --
# TODO confirm).
rootdir = '../..'
# Only append once: re-running this cell must not pile up duplicate entries.
if rootdir not in sys.path:
    sys.path.append(rootdir)

# import config_dataset

# from get_parameters import bdd_parameters, kitti_parameters, net_config, print_parameters, get_device
# from modules.dataset_utils.bdd_dataset_utils.remapped_bdd_utils import load_ground_truths
# from modules.dataset_utils.bdd_dataset_and_dataloader import BerkeleyDeepDriveDataset
# from modules.plot.viz_annotation import vizualize_bbox_resized, draw_bbox_on_img_data
# from modules.dataset_utils.bdd_dataset_and_dataloader import inverse_norm
# from modules.proposal.box_association import greedy_association, identify_prominent_objects
# from modules.second_stage.generate_gt import gen_training_gt

from get_parameters import print_parameters, get_device, reset_seed
from modules.second_stage.get_param import net_config_stage2, bdd_parameters_stage2, kitti_parameters_stage2
from get_datasets import BDD_dataset, KITTI_dataset, DATSET_Selector

from modules.neural_net.backbone.backbone_v2 import net_backbone
from modules.neural_net.bifpn.bifpn_nblks_v2 import BiFPN
from modules.neural_net.head.shared_head_v5 import SharedNet
from modules.neural_net.detector.detector_v1 import FCOS
from modules.neural_net.fcos import FCOS_train
from modules.loss.fcos_loss import FCOS_Loss

from modules.second_stage.proposal_extraction import proposal_extractor
from modules.second_stage.roi_embedding import query_embedding, featmap_embedding
from modules.second_stage.attention import attention_network
from modules.second_stage.detector import second_stage_detector, second_stage_detector_train
from modules.second_stage.second_stage_loss import second_stage_loss

In [2]:
# Seed is reset before any dataset object is constructed in this cell.
reset_seed(0)
BATCH_SIZE = 4
DEVICE = get_device()

# Stage-2 configuration objects: one for the network, one per dataset.
net_config_obj = net_config_stage2()
bdd_param_obj = bdd_parameters_stage2()
kitti_param_obj = kitti_parameters_stage2()
print_parameters(net_config_obj, bdd_param_obj, kitti_param_obj, DEVICE)

# Keyword arguments that are identical for the BDD and the KITTI wrapper.
_shared_dataset_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_samples_val=500,
    device=DEVICE,
    perform_augmentation_train=True,
    augmentation_prob_train=0.999,
)

# BDD wrapper -- note that shuffling is switched off here.
bdd_dataloader = BDD_dataset(
    bdd_param_obj=bdd_param_obj,
    shuffle_dataset=False,
    **_shared_dataset_kwargs,
)

# KITTI wrapper -- shuffling is switched on.
kitti_dataloader = KITTI_dataset(
    kitti_param_obj=kitti_param_obj,
    shuffle_dataset=True,
    **_shared_dataset_kwargs,
)

# Selector that is later asked for one training sample per iteration; it is
# configured for 1000 training iterations with a BDD weight of 0.8.
dataloader_selector = DATSET_Selector(
    bdd_dataset_obj=bdd_dataloader,
    kitti_dataset_obj=kitti_dataloader,
    max_training_iter=1000,
    bdd_dataset_weight=0.8,
)

GPU is available. Good to go!
printing model config parameters
----------------------------------------------------------------------------------------------------
backbone                        : efficientnet_b4
num_backbone_nodes              : 4
num_extra_blocks                : 1
num_levels                      : 5
extra_blocks_feat_dim           : 512
num_fpn_blocks                  : 2
fpn_feat_dim                    : 128
prediction head stem_channels   : [128, 128, 128, 128]
activation                      : swish
image dimension BDD (H, W, D)   : (360, 640, 3)
image dimension KITTI (H, W, D) : (263, 873, 3)
num_classes                     : 2
DEVICE                          : cuda
****************************************************************************************************
 
Load JSON file .. please wait
annotations from 69863/69863 aggregated : Aggregation COMPLETE
Load JSON file .. please wait
annotations from 10000/10000 aggregated : Aggregation COMPLETE
Loading JSO

In [3]:
# ---------------------------------------------------------------------------
# Saved weights of the first-stage (anchor-free) detector
# ---------------------------------------------------------------------------
weights_path = 'model_weights/1705990924432/anchor_free_detector.pt'
# weights_path = os.path.join(rootdir, weights_path)

# ---------------------------------------------------------------------------
# First-stage network: backbone -> BiFPN -> shared head, wrapped by FCOS
# ---------------------------------------------------------------------------
backbone = net_backbone(net_config_obj)
bifpn = BiFPN(net_config_obj, bdd_param_obj.feat_pyr_shapes)
shared_head = SharedNet(net_config_obj, bdd_param_obj.out_feat_shape)
fcos = FCOS(backbone, bifpn, shared_head)

loss = FCOS_Loss(net_config_obj, DEVICE)
detector = FCOS_train(fcos, loss, bdd_param_obj, DEVICE)
# The checkpoint is deserialised on the CPU and the module is moved afterwards.
detector.load_state_dict(torch.load(weights_path, map_location="cpu"))
detector = detector.to(DEVICE)

# ---------------------------------------------------------------------------
# Feature / proposal extractor built from the loaded first-stage sub-modules
# ---------------------------------------------------------------------------
_first_stage = detector.detector
prop_extractor = proposal_extractor(
    backbone=_first_stage.backbone,
    feataggregator=_first_stage.feataggregator,
    sharednet=_first_stage.sharednet,
    netconfig_obj=net_config_obj,
    param_obj=bdd_param_obj,
    device=DEVICE,
)

# ---------------------------------------------------------------------------
# Second-stage detector: two embedding networks plus an attention network
# ---------------------------------------------------------------------------
feat_embedding_net = featmap_embedding(net_config_obj)
query_embedding_net = query_embedding(net_config_obj)
attention_net = attention_network(net_config_obj)

detector_second_stage = second_stage_detector(
    feat_embedding_net=feat_embedding_net,
    query_embedding_net=query_embedding_net,
    attention_net=attention_net,
)

loss_second_stage = second_stage_loss(net_config=net_config_obj, device=DEVICE)

detector_second_stage_train = second_stage_detector_train(
    detector=detector_second_stage,
    loss_obj=loss_second_stage,
    param_obj=bdd_param_obj,
    device=DEVICE,
)
detector_second_stage_train = detector_second_stage_train.to(DEVICE)

# ---------------------------------------------------------------------------
# Optimisation settings
# ---------------------------------------------------------------------------
learning_rate = 8e-3
weight_decay = 1e-4
max_iters = 400000

# Only the trainable parameters of the second-stage training wrapper are
# handed to the optimizer.
params = list(filter(lambda p: p.requires_grad,
                     detector_second_stage_train.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=learning_rate,
    momentum=0.9,
    weight_decay=weight_decay,
)

# When training is resumed after an interruption, set init_start to the
# iteration to continue from; the milestones are shifted by that amount.
init_start = 0
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[int(frac * max_iters - init_start) for frac in (0.65, 0.95)],
    gamma=0.1,
)

In [4]:
from modules.second_stage.generate_gt import gen_training_gt
from modules.second_stage.prop_functions import compute_bbox_from_offsets_normalized
from modules.dataset_utils.bdd_dataset_and_dataloader import inverse_norm
from modules.plot.viz_annotation import draw_bbox_on_img_data
from config_neuralnet_stage2 import CLS_LOSS_WT, BOX_LOSS_WT, OBJ_LOSS_WT

# Iteration window for this run: [iter_start_offset, iter_start_offset + num_samples).
iter_start_offset = 0
num_samples = 1000
# NOTE: this rebinds `max_iters`, which the optimizer cell also assigns. The
# LR-scheduler milestones were already computed from that earlier value and
# are not affected by the reassignment here.
max_iters = iter_start_offset + num_samples

# The first-stage proposal extractor only produces inputs for the second
# stage; none of its parameters are passed to the optimizer.
prop_extractor.eval()

for iter_train in range(iter_start_offset, max_iters):

    detector_second_stage_train.train()
    images, labels, param_obj = dataloader_selector.get_training_sample(iter_train)

    # The selector returns a parameter object alongside each sample; both
    # networks are re-initialised with it before the sample is processed.
    prop_extractor.reinit_const_parameters(param_obj)
    detector_second_stage_train.reinit_const_parameters(param_obj)

    img_path = labels['img_path']
    bboxes = labels['bbox_batch']
    clslabels = labels['obj_class_label']

    # Run the first stage without autograd. Its parameters are not optimized,
    # so building a graph through it only costs memory and makes backward()
    # accumulate gradients that optimizer.zero_grad() never clears.
    # (No-op if proposal_extractor already disables gradients internally.)
    with torch.no_grad():
        roi_features = prop_extractor(images)
    features = roi_features['features']
    queries = roi_features['queries']
    pred_boxes = roi_features['pred_boxes']
    pred_clsidx = roi_features['pred_clsidx']

    # predictions = detector_train.detector((features, queries))

    losses = detector_second_stage_train(
        x = (features, queries),
        gtbboxes = bboxes,
        gtclslabels = clslabels,
        proposals = pred_boxes)

    # Weighted sum of the three second-stage loss terms.
    total_loss = \
        CLS_LOSS_WT * losses['loss_cls'] + \
        BOX_LOSS_WT * losses['loss_box'] + \
        OBJ_LOSS_WT * losses['loss_obj']

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    # The scheduler is stepped once per iteration, so its milestones are
    # expressed in iterations.
    lr_scheduler.step()

    for key, loss_item in losses.items():
        print(f"{key} ==> {loss_item}")
    print('-' * 100)


    # groundtruths = gen_training_gt(
    #     gt_boxes = bboxes, 
    #     gt_class = clslabels,
    #     pred_boxes = pred_boxes,
    #     deltas_mean = torch.tensor(param_obj.deltas_mean_stage2, dtype=torch.float32, device=DEVICE), 
    #     deltas_std = torch.tensor(param_obj.deltas_std_stage2, dtype=torch.float32, device=DEVICE), 
    #     iou_threshold = param_obj.iou_threshold_stage2,
    #     ignored_classId = param_obj.ignored_classId_stage2)
    
    # matched_gt_class = groundtruths.class_logits.clone()
    # matched_gt_deltas = groundtruths.boxreg_deltas
    # matched_gt_objness = groundtruths.objness_logits
    # # matched_gt_boxes = groundtruths.bbox

    # flag = [matched_gt_class >= 0]
    # matched_gt_boxes = compute_bbox_from_offsets_normalized(
    #     (torch.concat(pred_boxes, dim=0))[flag],
    #     matched_gt_deltas[flag],
    #     torch.tensor(param_obj.deltas_mean_stage2, dtype=torch.float32, device=DEVICE),
    #     torch.tensor(param_obj.deltas_std_stage2, dtype=torch.float32, device=DEVICE))
    
    # for b in range(BATCH_SIZE):
    #     # boxes = matched_gt_boxes[matched_gt_class >= 0].cpu().numpy()
    #     # boxes = matched_gt_boxes.cpu().numpy()
    #     boxes = pred_boxes[b].cpu().numpy()
    #     img_inv = inverse_norm(images[b])
    #     img_inv = (img_inv.permute(1,2,0).cpu().numpy() * 255).astype('uint8')
    #     draw_bbox_on_img_data(img_inv, boxes, figsize=(10,8))

    

loss_cls ==> 0.5009363293647766
loss_box ==> 0.32726627588272095
loss_obj ==> 1.8360141515731812
----------------------------------------------------------------------------------------------------
loss_cls ==> 0.3078814446926117
loss_box ==> 0.5104168653488159
loss_obj ==> 2.187551975250244
----------------------------------------------------------------------------------------------------
loss_cls ==> 0.636344313621521
loss_box ==> 0.2733536660671234
loss_obj ==> 2.192546844482422
----------------------------------------------------------------------------------------------------
loss_cls ==> 0.07204307615756989
loss_box ==> 0.2330779731273651
loss_obj ==> 1.4127815961837769
----------------------------------------------------------------------------------------------------
loss_cls ==> 0.755911111831665
loss_box ==> 0.2671700119972229
loss_obj ==> 1.842073678970337
----------------------------------------------------------------------------------------------------
loss_cls ==> 0.280