In [1]:
%load_ext autoreload
%autoreload 2
from multiprocessing import reduction
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from WsdnnPIXOR import WSDDNPIXOR
from dataset import KITTIBEV
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torchvision.ops import nms
from post_processing import calculate_ap
import wandb
import math
import sklearn
import sklearn.metrics
from visualize_dataset_new import plot_bev

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_pretrained(model, filename='40epoch'):
    """Copy matching entries of a saved state dict into ``model`` in place.

    Args:
        model: module whose ``state_dict()`` tensors are overwritten.
        filename: path handed to ``torch.load``; expected to hold a mapping
            from parameter name to tensor.

    Entries whose name is absent from the model's own state dict are ignored.
    Nothing is returned; the model is updated through ``copy_``.
    """
    target_state = model.state_dict()
    checkpoint = torch.load(filename)
    for key in checkpoint:
        value = checkpoint[key]
        if key in target_state:
            # Older checkpoints may have serialized nn.Parameter objects
            # instead of plain tensors; unwrap them before copying.
            if isinstance(value, nn.Parameter):
                value = value.data
            target_state[key].copy_(value)

In [3]:
def train(train_loader,
          model,
          loss_fn,
          optimizer,
          test_loader):
    """Run one training epoch and return the running average loss.

    Args:
        train_loader: iterable of dicts with keys ``'bev'``, ``'labels'`` and
            ``'proposals'`` (tensors); its ``len()`` is used for the progress bar.
        model: callable as ``model(bev, proposals)``; its output is summed over
            dim 0 to obtain one score per class for the whole sample.
        loss_fn: callable as ``loss_fn(preds_class, labels)``.
        optimizer: optimizer stepped once per batch.
        test_loader: unused here; kept so existing callers keep working.

    Returns:
        ``sum(loss * batch_size) / sum(batch_size)`` over the epoch, or NaN if
        the loader yielded no batches.

    Every 500 steps the running loss is sent to wandb and printed together
    with the classification mAP over all samples seen so far in this epoch.
    """
    loss_total = 0.0
    data_count = 0.0
    # Per-step predictions/targets are kept in lists and concatenated only when
    # the metric is needed; re-concatenating every step is quadratic.
    epoch_preds = []
    epoch_targets = []

    model = model.train()
    for step, data in tqdm(enumerate(train_loader),
                           total=len(train_loader),
                           leave=False):
        bev = data['bev'].cuda()
        labels = data['labels'].cuda()
        # NOTE(review): squeeze() drops *every* singleton dim, so a sample with
        # a single proposal would come out 1-D -- confirm that cannot happen.
        proposals = data['proposals'].squeeze().float().cuda()

        preds = model(bev, proposals)
        preds_class = preds.sum(dim=0).reshape(1, -1)

        # Detach before storing: these values are only used for the metric and
        # must not keep autograd nodes alive for the rest of the epoch.
        epoch_preds.append(torch.sigmoid(preds_class).detach())
        epoch_targets.append(labels.detach())

        # NOTE(review): the clamp has zero gradient outside [0, 1], and the
        # training cell in this notebook passes nn.BCEWithLogitsLoss, which
        # applies its own sigmoid on top. Confirm whether the model emits
        # probabilities (nn.BCELoss would be the match) or logits (the clamp
        # should then be removed). Behaviour is left unchanged here.
        preds_class = torch.clamp(preds_class, 0, 1)
        loss = loss_fn(preds_class, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_total += loss.item() * bev.shape[0]
        data_count += bev.shape[0]
        if step % 500 == 0 and step != 0:
            map_class = map_classification(torch.cat(epoch_preds, dim=0),
                                           torch.cat(epoch_targets, dim=0))
            wandb.log({"Loss": loss_total / data_count})
            print("Loss: ", loss_total / data_count, " mAP: ", map_class)

    if data_count == 0:
        # Empty loader: there is no meaningful average to report.
        return float('nan')
    return loss_total / data_count

In [4]:
def validate(test_loader,
             model,
             loss_fn,
             score_threshold=0.005,
             nms_iou_threshold=0.5,
             iou_list = [0.05, 0.1, 0.2, 0.3, 0.4],
             inv_class=None,
             direct_class=None):
    """Evaluate detection mAP on ``test_loader`` at several IoU thresholds.

    For every sample, per-class proposal scores are thresholded, reduced with
    NMS and collected; ground-truth boxes are collected alongside. The two
    tables are then handed to ``calculate_ap`` once per IoU threshold.

    Args:
        test_loader: iterable of dicts with keys ``'bev'``, ``'labels'``,
            ``'gt_boxes'``, ``'proposals'`` and ``'gt_class_list'``.
        model: callable as ``model(bev, proposals)`` returning per-proposal
            class scores indexed as ``[proposal, class]``.
        loss_fn: callable as ``loss_fn(preds_class, labels)``.
        score_threshold: proposals scoring below this for a class are dropped.
        nms_iou_threshold: IoU threshold passed to ``torchvision.ops.nms``.
        iou_list: IoU thresholds at which mAP is computed; never mutated.
        inv_class: mapping used for box captions and forwarded to
            ``calculate_ap``; presumably class id -> name (verify against the
            dataset). When ``None`` no images are logged.
        direct_class: unused; kept for interface compatibility.

    Returns:
        The mAP for the *last* entry of ``iou_list`` (0 if the list is empty).

    Side effects:
        Logs the mean validation loss, one ``"map@ <iou>"`` value per threshold
        and, for a fixed handful of sample indices, an annotated BEV image to
        wandb; prints the mAP values.
    """
    num_classes = 8  # NOTE(review): hard-coded; must match the model's class dimension.
    loss_total = 0.0
    data_count = 0
    gt_rows = []    # per-sample tensors with columns [sample_idx, class, x1, y1, x2, y2]
    pred_rows = []  # per-class tensors with columns [sample_idx, class, score, x1, y1, x2, y2]

    # A private RandomState yields the same indices as seeding the global
    # generator with 2, without clobbering numpy's global RNG state.
    plotting_idxs = {int(i) for i in np.random.RandomState(2).randint(0, 10, 5)}

    with torch.no_grad():
        for step, data in tqdm(enumerate(test_loader),
                               total=len(test_loader),
                               leave=False):
            bev = data['bev'].cuda()
            labels = data['labels'].cuda()
            gt_boxes = data['gt_boxes'].reshape(-1, 4)
            proposals = data['proposals'].squeeze().float().cuda()
            gt_class_list = data['gt_class_list'].reshape(-1)

            cls_probs = model(bev, proposals)
            preds_class = cls_probs.sum(dim=0).reshape(1, -1)
            loss = loss_fn(preds_class, labels)
            loss_total += loss.item()
            data_count += bev.shape[0]

            num_gt = gt_boxes.shape[0]
            gt_classes = gt_class_list[:num_gt].reshape(-1, 1).float()
            if num_gt:
                gt_rows.append(torch.cat([torch.full((num_gt, 1), float(step)),
                                          gt_classes,
                                          gt_boxes], dim=1))

            # Plotting needs a class-id -> caption mapping, so skip it without one.
            plot_this = inv_class is not None and step in plotting_idxs
            kept_for_plot = []

            for class_num in range(num_classes):
                class_scores = cls_probs[:, class_num]
                keep_mask = class_scores >= score_threshold
                valid_scores = class_scores[keep_mask]
                valid_proposals = proposals[keep_mask]
                retained_idx = nms(valid_proposals, valid_scores, nms_iou_threshold)
                # Move to CPU once per class; everything below is CPU-side
                # bookkeeping and must not mix devices in torch.cat.
                retained_scores = valid_scores[retained_idx].detach().cpu()
                retained_proposals = valid_proposals[retained_idx].detach().cpu()

                num_kept = retained_proposals.shape[0]
                if num_kept == 0:
                    continue
                class_col = torch.full((num_kept, 1), float(class_num))
                pred_rows.append(torch.cat([torch.full((num_kept, 1), float(step)),
                                            class_col,
                                            retained_scores.reshape(-1, 1),
                                            retained_proposals], dim=1))
                if plot_this:
                    kept_for_plot.append(torch.cat([retained_proposals, class_col], dim=1))

            if plot_this:
                raw_image = plot_bev(bev.detach().cpu())
                pred_plot_rows = torch.cat(kept_for_plot, dim=0).tolist() if kept_for_plot else []
                # Same [x1, y1, x2, y2, class] layout as the predictions.
                gt_plot_rows = torch.cat([gt_boxes.float(), gt_classes], dim=1).tolist()
                # NOTE(review): wandb interprets "position" values according to
                # its default domain unless "domain" is set; confirm the
                # proposal coordinates match the image returned by plot_bev.
                box_image = wandb.Image(raw_image,
                                        boxes={"predictions":
                                               {"box_data": _wandb_box_data(pred_plot_rows, inv_class),
                                                "class_labels": inv_class},
                                               "ground_truth":
                                               {"box_data": _wandb_box_data(gt_plot_rows, inv_class),
                                                "class_labels": inv_class}})
                # The image was previously built and then dropped; log it.
                wandb.log({"BEV detections": box_image})

    all_gt_boxes = torch.cat(gt_rows, dim=0) if gt_rows else torch.zeros((0, 6))
    all_pred_boxes = torch.cat(pred_rows, dim=0) if pred_rows else torch.zeros((0, 7))

    if data_count:
        wandb.log({"Validation Loss": loss_total / data_count})

    mAP = 0  # returned unchanged when iou_list is empty
    for iou in iou_list:
        AP = calculate_ap(all_pred_boxes, all_gt_boxes, iou, inv_class=inv_class)
        mAP = 0 if len(AP) == 0 else sum(AP) / len(AP)
        wandb.log({"map@ " + str(iou): mAP})
        print("Iou ", iou, " mAP ", mAP)
    return mAP


def _wandb_box_data(rows, inv_class):
    """Turn ``[minX, minY, maxX, maxY, class_id]`` rows into wandb box dicts.

    Values are converted to plain Python numbers so they can be serialized and
    so ``class_id`` can be used as a key into ``inv_class``.
    """
    box_data = []
    for min_x, min_y, max_x, max_y, class_id in rows:
        class_id = int(class_id)
        box_data.append({"position": {"minX": float(min_x),
                                      "minY": float(min_y),
                                      "maxX": float(max_x),
                                      "maxY": float(max_y)},
                         "class_id": class_id,
                         "box_caption": inv_class[class_id]})
    return box_data

In [5]:
def map_classification(output, target):
    """Mean of per-class average precision over the classes that have positives.

    Args:
        output: tensor of prediction scores, indexed as ``[sample, class]``.
        target: tensor of labels with the same indexing; a class whose label
            column sums to zero is skipped.

    Returns:
        The mean of the collected per-class AP values, or 0 when none were
        collected. NaN AP values are not counted.
    """
    labels_np = target.detach().cpu().numpy()
    scores_np = output.detach().cpu().numpy()

    per_class_ap = []
    for column in range(labels_np.shape[1]):
        class_scores = scores_np[:, column].astype('float32')
        class_labels = labels_np[:, column].astype('float32')
        # Slightly lower the score of labelled entries before ranking.
        class_scores = class_scores - 1e-5 * class_labels
        if np.sum(class_labels) == 0:
            continue
        ap_value = sklearn.metrics.average_precision_score(class_labels,
                                                           class_scores,
                                                           average=None)
        if math.isnan(ap_value):
            continue
        per_class_ap.append(ap_value)

    divisor = len(per_class_ap) if len(per_class_ap) > 0 else 1
    return sum(per_class_ap) / divisor

In [7]:
# Inputs for the KITTIBEV dataset; both paths are relative to the notebook's
# working directory.
lidar_folder_name = "./data"
valid_data_list_filename = "./valid_data_list_after_threshold.txt"

dataset = KITTIBEV(
    lidar_folder_name=lidar_folder_name,
    valid_data_list_filename=valid_data_list_filename,
)

Preloading Data


                                                    

In [8]:
if __name__ == '__main__':
    # Training driver: 70/30 split of `dataset`, SGD training for a fixed
    # number of epochs, checkpoint + validation after every epoch.

    # The name must be passed by keyword: the first positional parameter of
    # wandb.init is not `project`, so a bare string ends up in the wrong field.
    wandb.init(project="WSDNNPIXOR")
    epochs = 10
    model = WSDDNPIXOR()
    # Optional: warm-start from a checkpoint with load_pretrained(model).

    train_dataset_length = int(0.70 * len(dataset))
    # Fixed generator seed keeps the train/test split identical across runs.
    train_dataset, test_dataset = random_split(dataset, [train_dataset_length,
                                                        len(dataset) - train_dataset_length],
                                                        generator=torch.Generator().manual_seed(10))
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    print(len(train_dataset), len(test_dataset))

    # NOTE(review): train() clamps the summed scores to [0, 1] before calling
    # this loss, and BCEWithLogitsLoss then applies a sigmoid on top. The
    # epoch loss recorded in this notebook's output stays at ~5.3018 from the
    # second epoch on, which is consistent with that pairing providing little
    # or no gradient. Confirm whether the model outputs probabilities (then
    # nn.BCELoss fits) or logits (then the clamp should go).
    loss_fn = nn.BCEWithLogitsLoss(reduction='sum')
    model = model.cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    for i in range(epochs):
        # train() puts the model in training mode itself.
        loss = train(train_loader, model, loss_fn, optimizer, test_loader)
        print("Epoch average Loss: ", loss)
        # Checkpoints are overwritten every epoch; only the latest is kept.
        torch.save(model.state_dict(), "model.pth")
        torch.save(optimizer.state_dict(), "opt.pth")

        model = model.eval()
        mAP = validate(test_loader,
                       model,
                       loss_fn,
                       inv_class=dataset.inv_class,
                       direct_class=dataset.class_to_int)

        

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33makshayantony12[0m (use `wandb login --relogin` to force relogin)


4087 1753


 12%|█▏        | 501/4087 [02:48<20:24,  2.93it/s]

Loss:  5.325825215838698  mAP:  0.22669957982641573


 24%|██▍       | 1001/4087 [05:39<17:38,  2.91it/s]

Loss:  5.312571852585266  mAP:  0.22789770326149725


 37%|███▋      | 1501/4087 [08:30<14:45,  2.92it/s]

Loss:  5.3034846243891405  mAP:  0.23093919114897032


 49%|████▉     | 2001/4087 [11:22<12:11,  2.85it/s]

Loss:  5.303936238258655  mAP:  0.22994695160739279


 61%|██████    | 2501/4087 [14:14<09:07,  2.90it/s]

Loss:  5.299009358073982  mAP:  0.23136117098108483


 73%|███████▎  | 3001/4087 [17:06<06:20,  2.85it/s]

Loss:  5.303721556938808  mAP:  0.23178087167884623


 86%|████████▌ | 3501/4087 [19:59<03:24,  2.86it/s]

Loss:  5.305659636937453  mAP:  0.23201140152000904


 98%|█████████▊| 4001/4087 [22:52<00:29,  2.88it/s]

Loss:  5.305363758245663  mAP:  0.23175081023397476


                                                   

Epoch average Loss:  5.30510778257067


                                                   

Iou  0.05  mAP  tensor(0.0048)
Iou  0.1  mAP  tensor(0.0032)
Iou  0.2  mAP  tensor(0.0017)
Iou  0.3  mAP  tensor(0.0011)
Iou  0.4  mAP  tensor(0.0006)


 12%|█▏        | 501/4087 [03:05<22:30,  2.66it/s]

Loss:  5.291040490080661  mAP:  0.2292361879960516


 24%|██▍       | 1001/4087 [06:12<19:05,  2.69it/s]

Loss:  5.294163115161537  mAP:  0.23206351023653607


 37%|███▋      | 1501/4087 [09:19<16:20,  2.64it/s]

Loss:  5.2945391502793  mAP:  0.23319444701285097


 49%|████▉     | 2001/4087 [12:28<13:22,  2.60it/s]

Loss:  5.29522701559351  mAP:  0.2329242457855264


 61%|██████    | 2501/4087 [15:39<10:08,  2.61it/s]

Loss:  5.300038084744701  mAP:  0.23194196200741202


 73%|███████▎  | 3001/4087 [18:50<07:03,  2.56it/s]

Loss:  5.302579556112078  mAP:  0.2324184415509369


 86%|████████▌ | 3501/4087 [22:05<03:51,  2.53it/s]

Loss:  5.301253142809168  mAP:  0.23162401970726393


 98%|█████████▊| 4001/4087 [25:00<00:29,  2.88it/s]

Loss:  5.3015079345398854  mAP:  0.23196739794074384


                                                   

Epoch average Loss:  5.3018224507459895


                                                   

Iou  0.05  mAP  tensor(0.0052)
Iou  0.1  mAP  tensor(0.0035)
Iou  0.2  mAP  tensor(0.0018)
Iou  0.3  mAP  tensor(0.0012)
Iou  0.4  mAP  tensor(0.0007)


 12%|█▏        | 501/4087 [02:51<20:26,  2.92it/s]

Loss:  5.299024526140128  mAP:  0.2346344457843875


 24%|██▍       | 1001/4087 [05:44<17:50,  2.88it/s]

Loss:  5.291166130041765  mAP:  0.23312971650966297


 37%|███▋      | 1501/4087 [08:37<15:00,  2.87it/s]

Loss:  5.290541830523754  mAP:  0.2342648189422035


 49%|████▉     | 2001/4087 [11:30<12:06,  2.87it/s]

Loss:  5.299225026288636  mAP:  0.23262151087657393


 61%|██████    | 2501/4087 [14:23<09:03,  2.92it/s]

Loss:  5.300437932467699  mAP:  0.2334759110460061


 73%|███████▎  | 3001/4087 [17:17<06:18,  2.87it/s]

Loss:  5.302912781460503  mAP:  0.23178445276293974


 86%|████████▌ | 3501/4087 [20:11<03:27,  2.83it/s]

Loss:  5.303252570853221  mAP:  0.2330200739443815


 98%|█████████▊| 4001/4087 [23:05<00:30,  2.87it/s]

Loss:  5.300758120272656  mAP:  0.23257510779354407


                                                   

Epoch average Loss:  5.3018224493745665


                                                   

Iou  0.05  mAP  tensor(0.0049)
Iou  0.1  mAP  tensor(0.0033)
Iou  0.2  mAP  tensor(0.0017)
Iou  0.3  mAP  tensor(0.0011)
Iou  0.4  mAP  tensor(0.0006)


 12%|█▏        | 501/4087 [02:52<20:27,  2.92it/s]

Loss:  5.293036506250171  mAP:  0.23631911465730915


 24%|██▍       | 1001/4087 [05:45<17:59,  2.86it/s]

Loss:  5.302155119235926  mAP:  0.2353072410994453


 37%|███▋      | 1501/4087 [08:39<14:59,  2.88it/s]

Loss:  5.307863600210987  mAP:  0.233763837061354


 49%|████▉     | 2001/4087 [11:33<12:10,  2.85it/s]

Loss:  5.300224514881074  mAP:  0.23367129022296496


 61%|██████    | 2501/4087 [14:28<09:20,  2.83it/s]

Loss:  5.302437124013689  mAP:  0.2324594325798901


 73%|███████▎  | 3001/4087 [17:23<06:21,  2.85it/s]

Loss:  5.302912776159432  mAP:  0.2322491445705285


 86%|████████▌ | 3501/4087 [20:18<03:27,  2.83it/s]

Loss:  5.300967502987018  mAP:  0.2329863407108694


 98%|█████████▊| 4001/4087 [23:14<00:30,  2.80it/s]

Loss:  5.300758118949598  mAP:  0.2329969885920553


                                                   

Epoch average Loss:  5.301822447637335


                                                   

Iou  0.05  mAP  tensor(0.0049)
Iou  0.1  mAP  tensor(0.0033)
Iou  0.2  mAP  tensor(0.0018)
Iou  0.3  mAP  tensor(0.0011)
Iou  0.4  mAP  tensor(0.0007)


 12%|█▏        | 501/4087 [02:51<20:32,  2.91it/s]

Loss:  5.314992603803881  mAP:  0.23287882172338226


 24%|██▍       | 1001/4087 [05:43<17:57,  2.86it/s]

Loss:  5.310147136823429  mAP:  0.2337748438211644


 37%|███▋      | 1501/4087 [08:35<14:56,  2.89it/s]

Loss:  5.311194719469741  mAP:  0.23283528542781598


 49%|████▉     | 2001/4087 [11:28<12:07,  2.87it/s]

Loss:  5.3082205219888055  mAP:  0.23606285401757052


 61%|██████    | 2501/4087 [14:21<09:15,  2.85it/s]

Loss:  5.304436328174569  mAP:  0.23564090264714804


 73%|███████▎  | 3001/4087 [17:13<06:21,  2.85it/s]

Loss:  5.306245002531613  mAP:  0.23517361466018896


 86%|████████▌ | 3501/4087 [20:08<03:25,  2.85it/s]

Loss:  5.306680160207245  mAP:  0.23341261319417986


 98%|█████████▊| 4001/4087 [23:03<00:30,  2.80it/s]

Loss:  5.301757867939507  mAP:  0.23319840241018627


                                                   

Epoch average Loss:  5.301822446437536


                                                   

Iou  0.05  mAP  tensor(0.0054)
Iou  0.1  mAP  tensor(0.0036)
Iou  0.2  mAP  tensor(0.0019)
Iou  0.3  mAP  tensor(0.0012)
Iou  0.4  mAP  tensor(0.0007)


 12%|█▏        | 501/4087 [02:53<20:52,  2.86it/s]

Loss:  5.275072419397512  mAP:  0.23503405081415352


 24%|██▍       | 1001/4087 [05:45<18:06,  2.84it/s]

Loss:  5.296161105882663  mAP:  0.2336793671849838


 37%|███▋      | 1501/4087 [08:38<15:13,  2.83it/s]

Loss:  5.304532481745495  mAP:  0.23371154956348916


 49%|████▉     | 2001/4087 [11:32<12:07,  2.87it/s]

Loss:  5.299724758917723  mAP:  0.23504870540916056


 61%|██████    | 2501/4087 [14:25<09:24,  2.81it/s]

Loss:  5.298838559449758  mAP:  0.23597869807379532


 73%|███████▎  | 3001/4087 [17:19<06:15,  2.89it/s]

Loss:  5.298580882573869  mAP:  0.2350607429253045


 86%|████████▌ | 3501/4087 [20:13<03:29,  2.80it/s]

Loss:  5.301824397886711  mAP:  0.23397994642218933


 98%|█████████▊| 4001/4087 [23:09<00:30,  2.79it/s]

Loss:  5.302257740336443  mAP:  0.23316785129488335


                                                   

Epoch average Loss:  5.301822445107984


                                                   

Iou  0.05  mAP  tensor(0.0049)
Iou  0.1  mAP  tensor(0.0033)
Iou  0.2  mAP  tensor(0.0017)
Iou  0.3  mAP  tensor(0.0011)
Iou  0.4  mAP  tensor(0.0006)


 12%|█▏        | 501/4087 [02:53<20:37,  2.90it/s]

Loss:  5.314992585507362  mAP:  0.2324406107776119


 24%|██▍       | 1001/4087 [05:46<18:09,  2.83it/s]

Loss:  5.297160117755467  mAP:  0.23535719574986388


 37%|███▋      | 1501/4087 [08:40<14:58,  2.88it/s]

Loss:  5.305864934289778  mAP:  0.2371260475051214


 49%|████▉     | 2001/4087 [11:34<12:03,  2.88it/s]

Loss:  5.305222015531671  mAP:  0.23504584713607984


 61%|██████    | 2501/4087 [14:27<09:12,  2.87it/s]

Loss:  5.307635047495056  mAP:  0.23459555931103543


 73%|███████▎  | 3001/4087 [17:21<06:21,  2.84it/s]

Loss:  5.306244999065117  mAP:  0.23384102165594578


 86%|████████▌ | 3501/4087 [20:16<03:31,  2.78it/s]

Loss:  5.306108892988408  mAP:  0.2338083508977935


 98%|█████████▊| 4001/4087 [23:11<00:30,  2.81it/s]

Loss:  5.3032574908442625  mAP:  0.23308853376505412


                                                   

Epoch average Loss:  5.301822444300379


 12%|█▏        | 206/1753 [00:30<05:02,  5.11it/s]