In [9]:
%load_ext autoreload
%autoreload 2
import math
import os
from multiprocessing import reduction

import numpy as np
import sklearn
import sklearn.metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.ops import nms
from tqdm import tqdm

from dataset import KITTIBEV
from loss import FocalLoss
from post_processing import calculate_ap
from visualize_dataset_new import plot_bev
from WsdnnPIXOR import WSDDNPIXOR

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
def load_pretrained(model, filename='40epoch'):
    """Copy matching tensors from a saved state dict into ``model`` in place.

    Checkpoint entries whose name is not present in ``model.state_dict()``
    are skipped, so a checkpoint with extra entries can still be used.
    Entries whose name matches are copied with ``Tensor.copy_``; a shape
    mismatch therefore raises.

    Args:
        model: Module whose parameters/buffers are overwritten.
        filename: Path to a file holding a name -> tensor mapping
            (iterated with ``.items()``).
    """
    own_state = model.state_dict()
    # Load onto the CPU so a checkpoint written from a GPU can be read on any
    # machine; ``copy_`` below moves the values to wherever the model lives.
    state_dict = torch.load(filename, map_location="cpu")
    for name, param in state_dict.items():
        if name not in own_state:
            continue
        if isinstance(param, nn.Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        own_state[name].copy_(param)

In [12]:
def train(train_loader, 
          model, 
          loss_fn,
          optimizer, 
          test_loader,
          num_classes=2):
    """Run one pass over ``train_loader`` and return the mean loss per sample.

    For every batch the model output is summed over dim 0 and reshaped to a
    single row of per-class scores, clamped to [0, 1] and passed to
    ``loss_fn`` together with the batch labels. Every 500 iterations the
    running loss, a BCE diagnostic and the classification mAP (via
    ``map_classification``) over the samples seen so far are printed.

    Args:
        train_loader: Iterable of dict batches with keys ``'bev'``,
            ``'labels'`` and ``'proposals'``. Assumes batch size 1 (the
            leading dim of ``'proposals'`` is squeezed away) — TODO confirm
            for other loaders.
        model: Called as ``model(bev, proposals)``; inputs are moved to the
            GPU with ``.cuda()``, so the model must already live there.
        loss_fn: Called as ``loss_fn(preds, labels)``; must return a scalar
            tensor.
        optimizer: Optimizer stepped once per batch.
        test_loader: Unused; kept so existing callers keep working.
        num_classes: Number of columns of the accumulated score/label
            buffers.

    Returns:
        Sum of ``loss * batch_size`` divided by the number of samples, or
        ``nan`` when the loader yields no batches.
    """
    loss_bce_total = 0.0
    loss_total = 0.0
    data_count = 0.0
    total_target = torch.zeros((0, num_classes)).cuda()
    total_preds = torch.zeros((0, num_classes)).cuda()
    # Nothing inside the loop changes the mode, so set it once.
    model = model.train()
    for step, data in tqdm(enumerate(train_loader),
                           total=len(train_loader),
                           leave=False):
        bev = data['bev'].cuda()
        labels = data['labels'].cuda()
        # squeeze(0) only drops the batch dim; a bare squeeze() would also
        # collapse a single-proposal sample from (1, 4) to (4,).
        proposals = data['proposals'].squeeze(0).float().cuda()

        preds = model(bev, proposals)
        preds_class = preds.sum(dim=0).reshape(1, -1)
        # detach(): these buffers are only used for the mAP read-out. Keeping
        # them attached would hold every iteration's autograd graph alive for
        # the whole epoch.
        preds_class_sigmoid = torch.sigmoid(preds_class).detach()
        total_preds = torch.cat([total_preds, preds_class_sigmoid], dim=0)
        total_target = torch.cat([total_target, labels], dim=0)
        preds_class = torch.clamp(preds_class, 0, 1)
        loss = loss_fn(preds_class, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = bev.shape[0]
        with torch.no_grad():
            # NOTE(review): preds_class was clamped to [0, 1] above, yet it is
            # fed to a *with_logits* BCE here — confirm whether
            # F.binary_cross_entropy was intended for this diagnostic.
            loss_bce_total += F.binary_cross_entropy_with_logits(
                preds_class, labels, reduction='sum').item()
        loss_total += loss.item() * batch_size
        data_count += batch_size
        if step % 500 == 0 and step != 0:
            map_class = map_classification(total_preds, total_target)
            print("Focal Loss: ", loss_total / data_count, " BCE loss: ", loss_bce_total / data_count,  " mAP: ", map_class)
    if data_count == 0:
        # Empty loader: the mean is undefined, report nan instead of raising.
        return float('nan')
    return loss_total / data_count

In [4]:
def _wandb_box_data(boxes, class_ids, inv_class):
    """Build the wandb ``box_data`` list for (N, 4) ``boxes`` and their class ids.

    Columns 0/2 of each box are divided by 400 and reported as minY/maxY,
    columns 1/3 by 350 and reported as minX/maxX (presumably the BEV image
    height and width in pixels — TODO confirm).
    """
    box_data = []
    for box, class_id in zip(boxes.tolist(), class_ids):
        class_id = int(class_id)
        box_data.append({"position": {
            "minX": box[1] / 350,
            "minY": box[0] / 400,
            "maxX": box[3] / 350,
            "maxY": box[2] / 400},
            "class_id": class_id,
            "box_caption": inv_class[class_id],
            })
    return box_data


def validate(test_loader, 
             model, 
             loss_fn, 
             score_threshold=0.005,
             nms_iou_threshold=0.5,
             iou_list = [0.05, 0.1, 0.2, 0.3, 0.4],
             inv_class=None,
             direct_class=None):
    """Evaluate detection mAP of ``model`` over ``test_loader``.

    For every sample, proposals scoring at least ``score_threshold`` for a
    class are filtered with NMS and collected as rows
    ``[sample index, class, score, box(4)]``; ground-truth boxes are collected
    as ``[sample index, class, box(4)]``. ``calculate_ap`` is then run once
    per IoU threshold, and each mean AP is printed.

    wandb logging (images for a fixed random subset of sample indices, and
    the per-IoU mAP) only happens when a wandb run is active, so the function
    also works when ``wandb.init`` was never called. Image logging also
    requires ``inv_class``, which is used for box captions and class labels.

    Args:
        test_loader: Iterable of dict batches with keys ``'bev'``,
            ``'labels'``, ``'gt_boxes'``, ``'proposals'`` and
            ``'gt_class_list'``. Assumes batch size 1 — TODO confirm for
            other loaders.
        model: Called as ``model(bev, proposals)``; inputs are moved to the
            GPU with ``.cuda()``.
        loss_fn: Called as ``loss_fn(preds, labels)``; the value is
            accumulated but not reported.
        score_threshold: Minimum per-class score for a proposal to be kept.
        nms_iou_threshold: IoU threshold passed to ``torchvision.ops.nms``.
        iou_list: IoU thresholds at which AP is computed (never mutated here).
        inv_class: Mapping from integer class id to a caption; also forwarded
            to ``calculate_ap``.
        direct_class: Unused; kept so existing callers keep working.

    Returns:
        The mean AP for the last entry of ``iou_list`` (0 when ``iou_list``
        is empty).
    """
    # NOTE(review): this reseeds NumPy's *global* RNG on every call; kept
    # because callers may rely on it, but it affects any later np.random use.
    np.random.seed(2)
    num_classes = 2
    # Accumulated but not reported anywhere below.
    loss_total = 0.0
    data_count = 0.0
    # Rows are gathered in lists and concatenated once after the loop;
    # concatenating into a growing tensor per box is quadratic. The empty
    # seeds keep the result shape/dtype identical when nothing is collected.
    gt_rows = [torch.zeros((0, 6))]
    pred_rows = [torch.zeros((0, 7))]
    # NOTE(review): indices are drawn from [0, 500) regardless of the loader
    # length and may contain duplicates.
    plotting_idxs = np.random.randint(0, 500, (50))

    # wandb.log raises when no run was started; skip logging in that case.
    log_to_wandb = wandb.run is not None
    can_plot = log_to_wandb and inv_class is not None

    with torch.no_grad():
        for sample_idx, data in tqdm(enumerate(test_loader),
                                     total=len(test_loader),
                                     leave=False):
            bev = data['bev'].cuda()
            labels = data['labels'].cuda()
            gt_boxes = data['gt_boxes'].reshape(-1, 4)
            # squeeze(0) only drops the batch dim; a bare squeeze() would also
            # collapse a single-proposal sample from (1, 4) to (4,).
            proposals = data['proposals'].squeeze(0).float().cuda()
            gt_class_list = data['gt_class_list'].reshape(-1)

            cls_probs = model(bev, proposals)
            preds_class = cls_probs.sum(dim=0).reshape(1, -1)
            loss = loss_fn(preds_class, labels)
            loss_total += loss.item()
            data_count += bev.shape[0]

            for i in range(gt_boxes.shape[0]):
                gt_rows.append(
                    torch.cat([torch.tensor([sample_idx, gt_class_list[i]]),
                               gt_boxes[i]]).reshape(1, -1))

            kept_per_class = []
            for class_num in range(num_classes):
                curr_class_scores = cls_probs[:, class_num]
                valid_score_idx = torch.where(curr_class_scores >= score_threshold)
                valid_scores = curr_class_scores[valid_score_idx]
                valid_proposals = proposals[valid_score_idx]
                retained_idx = nms(valid_proposals, valid_scores, nms_iou_threshold)
                # One device-to-host transfer per class instead of one per box.
                retained_scores = valid_scores[retained_idx].detach().cpu()
                retained_proposals = valid_proposals[retained_idx].detach().cpu()
                kept_per_class.append((class_num, retained_proposals))

                num_kept = retained_proposals.shape[0]
                pred_rows.append(torch.cat(
                    [torch.full((num_kept, 1), float(sample_idx)),
                     torch.full((num_kept, 1), float(class_num)),
                     retained_scores.reshape(-1, 1),
                     retained_proposals], dim=1))

            if can_plot and sample_idx in plotting_idxs:
                raw_image = plot_bev(bev[0].detach().cpu())

                all_boxes = []
                for class_num, class_boxes in kept_per_class:
                    all_boxes.extend(_wandb_box_data(
                        class_boxes, [class_num] * class_boxes.shape[0], inv_class))
                all_gt_plotting_boxes = _wandb_box_data(
                    gt_boxes, gt_class_list.tolist(), inv_class)

                box_image = wandb.Image(raw_image, 
                                        boxes={"predictions":
                                        {"box_data": all_boxes,
                                        "class_labels": inv_class},
                                             "ground_truth":
                                        {"box_data": all_gt_plotting_boxes,
                                        "class_labels": inv_class}
                                        })
                wandb.log({"Image proposals " + str(sample_idx): box_image})
                # Second image shows only the ground truth (logged under the
                # "predictions" key, as before).
                box_image = wandb.Image(raw_image, 
                                        boxes= {"predictions":
                                        {"box_data": all_gt_plotting_boxes,
                                        "class_labels": inv_class}
                                        })
                wandb.log({"Image gt " + str(sample_idx): box_image})

    all_gt_boxes = torch.cat(gt_rows, dim=0)
    all_pred_boxes = torch.cat(pred_rows, dim=0)

    # Defined up front so an empty iou_list returns 0 instead of raising.
    mAP = 0
    for iou in iou_list:
        AP = calculate_ap(all_pred_boxes, all_gt_boxes, iou, inv_class=inv_class, total_cls_num=num_classes)
        mAP = 0 if len(AP) == 0 else sum(AP) / len(AP)
        if log_to_wandb:
            wandb.log({"map@ " + str(iou): mAP})
        print("Iou ", iou, " mAP ", mAP)
    return mAP

In [5]:
def map_classification(output, target):
    """Mean of the per-class average precision for multi-label scores.

    Args:
        output: Tensor of scores, one column per class.
        target: Tensor of labels with the same layout as ``output``.

    Returns:
        The mean AP over the classes that have a non-zero label sum and a
        non-NaN AP; ``0.0`` when no class qualifies.
    """
    labels = target.detach().cpu().numpy()
    scores = output.detach().cpu().numpy()

    per_class_ap = []
    for class_id in range(labels.shape[1]):
        label_col = labels[:, class_id].astype('float32')
        # Scores of positive samples are lowered by 1e-5 before ranking.
        score_col = scores[:, class_id].astype('float32') - 1e-5 * label_col
        if np.sum(label_col) == 0:
            # Class has no positives: it is left out of the mean.
            continue
        class_ap = sklearn.metrics.average_precision_score(label_col, score_col, average=None)
        if math.isnan(class_ap):
            continue
        per_class_ap.append(class_ap)

    if not per_class_ap:
        return 0.0
    return sum(per_class_ap) / len(per_class_ap)

In [7]:
# File passed to KITTIBEV as ``valid_data_list_filename``; resolved relative
# to the notebook's working directory.
valid_data_list_filename = "./valid_data_list_after_threshold.txt"
# Data root passed to KITTIBEV as ``lidar_folder_name``. Set the KITTI_ROOT
# environment variable to point at a different location; the default is the
# path this notebook was originally run with.
lidar_folder_name = os.environ.get("KITTI_ROOT", "/media/akshay/Data/KITTI/")
dataset = KITTIBEV(valid_data_list_filename=valid_data_list_filename, 
                        lidar_folder_name=lidar_folder_name)

Preloading Data


                                                   

336 74




In [13]:
if __name__ == '__main__':
    # NOTE(review): validate() calls wandb.log, but wandb.init is disabled
    # here — confirm the intended logging setup before running.
    #wandb.init(project="WSDNNPIXOR")
    epochs = 10
    validate_every = 1  # run validation after every `validate_every`-th epoch

    # Build the model and initialise it from the default checkpoint.
    model = WSDDNPIXOR()
    load_pretrained(model)

    # Freeze the backbone: its parameters receive no gradient.
    for backbone_param in model.backbone.parameters():
        backbone_param.requires_grad = False

    # 70/30 split of the dataset, with a fixed seed so the split is reproducible.
    train_dataset_length = int(0.70 * len(dataset))
    split_lengths = [train_dataset_length, len(dataset) - train_dataset_length]
    train_dataset, test_dataset = random_split(
        dataset,
        split_lengths,
        generator=torch.Generator().manual_seed(10),
    )
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    print(len(train_dataset), len(test_dataset))

    loss_fn = FocalLoss(alpha=0.25, gamma=2)
    model = model.cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=0.01,
        momentum=0.9,
        weight_decay=0.0001,
    )

    for epoch in range(epochs):
        model = model.train()
        loss = train(train_loader, model, loss_fn, optimizer, test_loader)
        print("Epoch average Loss: ", loss)

        # Checkpoints are overwritten after every epoch.
        torch.save(model.state_dict(), "model.pth")
        torch.save(optimizer.state_dict(), "opt.pth")

        if epoch % validate_every == 0:
            model = model.eval()
            mAP = validate(
                test_loader,
                model,
                loss_fn,
                inv_class=dataset.inv_class,
                direct_class=dataset.class_to_int,
            )

        

4083 1751


  0%|          | 0/4083 [00:00<?, ?it/s]

003767 (3, 4) (3,) (186, 4)


  0%|          | 1/4083 [00:14<16:06:08, 14.20s/it]

004855 (4, 4) (4,) (270, 4)


  0%|          | 2/4083 [00:14<6:54:39,  6.10s/it] 

007131 (7, 4) (7,) (240, 4)


  0%|          | 3/4083 [00:15<3:58:26,  3.51s/it]

000339 (10, 4) (10,) (244, 4)


  0%|          | 4/4083 [00:15<2:35:09,  2.28s/it]

005226 (16, 4) (16,) (282, 4)


  0%|          | 5/4083 [00:15<1:49:09,  1.61s/it]

003113 (12, 4) (12,) (236, 4)


  0%|          | 6/4083 [00:16<1:21:32,  1.20s/it]

003802 (7, 4) (7,) (254, 4)


  0%|          | 7/4083 [00:16<1:01:59,  1.10it/s]

004812 (3, 4) (3,) (252, 4)


  0%|          | 8/4083 [00:16<49:58,  1.36it/s]  

005037 (4, 4) (4,) (248, 4)


  0%|          | 9/4083 [00:17<40:58,  1.66it/s]

003716 (4, 4) (4,) (240, 4)


  0%|          | 10/4083 [00:17<35:12,  1.93it/s]

006713 (2, 4) (2,) (150, 4)


  0%|          | 12/4083 [00:18<27:28,  2.47it/s]

007259 (5, 4) (5,) (232, 4)


  0%|          | 13/4083 [00:18<24:06,  2.81it/s]

003270 (3, 4) (3,) (224, 4)


  0%|          | 14/4083 [00:18<20:56,  3.24it/s]

007260 (8, 4) (8,) (256, 4)


  0%|          | 15/4083 [00:18<19:11,  3.53it/s]

006204 (2, 4) (2,) (198, 4)
006333 (5, 4) (5,) (228, 4)


  0%|          | 16/4083 [00:19<21:12,  3.20it/s]

002390 (1, 4) (1,) (190, 4)


  0%|          | 17/4083 [00:19<28:03,  2.42it/s]

003380 (9, 4) (9,) (246, 4)


  0%|          | 18/4083 [00:20<31:06,  2.18it/s]

007145 (4, 4) (4,) (244, 4)


  0%|          | 19/4083 [00:21<44:50,  1.51it/s]

006932 (14, 4) (14,) (272, 4)


  0%|          | 20/4083 [00:22<58:40,  1.15it/s]

001029 (1, 4) (1,) (252, 4)


  1%|          | 21/4083 [00:48<9:26:06,  8.36s/it]

004450 (4, 4) (4,) (216, 4)


  1%|          | 22/4083 [01:23<18:14:36, 16.17s/it]

006627 (7, 4) (7,) (224, 4)


  1%|          | 24/4083 [01:33<11:21:02, 10.07s/it]

006356 (6, 4) (6,) (252, 4)


  1%|          | 25/4083 [01:33<8:01:58,  7.13s/it] 

002647 (4, 4) (4,) (224, 4)


  1%|          | 26/4083 [01:33<5:41:04,  5.04s/it]

005158 (1, 4) (1,) (184, 4)
007176 (2, 4) (2,) (238, 4)


  1%|          | 28/4083 [01:34<2:54:50,  2.59s/it]

000761 (4, 4) (4,) (220, 4)


  1%|          | 29/4083 [01:34<2:06:41,  1.88s/it]

001500 (4, 4) (4,) (214, 4)


  1%|          | 30/4083 [01:34<1:32:25,  1.37s/it]

007283 (7, 4) (7,) (248, 4)
002587 (1, 4) (1,) (184, 4)


  1%|          | 32/4083 [01:34<52:20,  1.29it/s]  

000539 (1, 4) (1,) (194, 4)
006781 (8, 4) (8,) (266, 4)


  1%|          | 34/4083 [01:35<32:24,  2.08it/s]

005664 (5, 4) (5,) (236, 4)


  1%|          | 35/4083 [01:35<27:07,  2.49it/s]

006258 (1, 4) (1,) (252, 4)


  1%|          | 36/4083 [01:35<23:41,  2.85it/s]

004021 (1, 4) (1,) (252, 4)


  1%|          | 37/4083 [01:36<20:42,  3.26it/s]

006360 (3, 4) (3,) (244, 4)
001823 (2, 4) (2,) (228, 4)


  1%|          | 38/4083 [01:36<28:58,  2.33it/s]

001966 (7, 4) (7,) (248, 4)


  1%|          | 39/4083 [01:37<43:50,  1.54it/s]

001034 (3, 4) (3,) (202, 4)


  1%|          | 40/4083 [01:38<42:44,  1.58it/s]

003584 (1, 4) (1,) (228, 4)


  1%|          | 41/4083 [01:44<2:31:55,  2.26s/it]

006342 (10, 4) (10,) (248, 4)


  1%|          | 42/4083 [01:49<3:17:46,  2.94s/it]

006597 (6, 4) (6,) (246, 4)


  1%|          | 43/4083 [01:56<4:41:58,  4.19s/it]

000233 (1, 4) (1,) (232, 4)


  1%|          | 44/4083 [02:04<6:03:07,  5.39s/it]

006835 (6, 4) (6,) (248, 4)


  1%|          | 45/4083 [02:34<14:23:02, 12.82s/it]

003837 (2, 4) (2,) (232, 4)


  1%|          | 46/4083 [02:35<10:15:37,  9.15s/it]

004154 (8, 4) (8,) (248, 4)


  1%|          | 47/4083 [02:35<7:18:15,  6.52s/it] 

005591 (6, 4) (6,) (208, 4)


  1%|          | 48/4083 [02:35<5:13:20,  4.66s/it]

004973 (1, 4) (1,) (236, 4)


  1%|          | 49/4083 [02:36<3:48:22,  3.40s/it]

005764 (10, 4) (10,) (200, 4)


  1%|          | 50/4083 [02:36<2:45:44,  2.47s/it]

002187 (10, 4) (10,) (234, 4)


  1%|          | 51/4083 [02:36<2:02:54,  1.83s/it]

000262 (6, 4) (6,) (236, 4)


  1%|▏         | 52/4083 [02:37<1:33:42,  1.39s/it]

001686 (2, 4) (2,) (186, 4)


  1%|▏         | 53/4083 [02:37<1:12:53,  1.09s/it]

005775 (3, 4) (3,) (236, 4)


  1%|▏         | 54/4083 [02:38<58:25,  1.15it/s]  

003848 (6, 4) (6,) (220, 4)


  1%|▏         | 55/4083 [02:38<48:57,  1.37it/s]

000442 (7, 4) (7,) (228, 4)


  1%|▏         | 56/4083 [02:38<41:33,  1.61it/s]

003315 (1, 4) (1,) (248, 4)


  1%|▏         | 57/4083 [02:39<36:20,  1.85it/s]

005345 (5, 4) (5,) (240, 4)


  1%|▏         | 58/4083 [02:39<33:01,  2.03it/s]

007378 (3, 4) (3,) (220, 4)


  1%|▏         | 59/4083 [02:40<33:24,  2.01it/s]

006926 (14, 4) (14,) (248, 4)


  1%|▏         | 60/4083 [02:40<39:16,  1.71it/s]

002332 (7, 4) (7,) (176, 4)


  1%|▏         | 61/4083 [02:45<2:05:47,  1.88s/it]

004716 (6, 4) (6,) (244, 4)


  2%|▏         | 62/4083 [02:47<2:12:01,  1.97s/it]

004031 (7, 4) (7,) (228, 4)


  2%|▏         | 63/4083 [02:55<4:15:00,  3.81s/it]

000435 (5, 4) (5,) (268, 4)


  2%|▏         | 64/4083 [03:07<6:45:25,  6.05s/it]

Error: Canceled future for execute_request message before replies were done