In [1]:
from multiprocessing import reduction
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from WsdnnPIXOR import WSDDNPIXOR
from dataset import KITTIBEV
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torchvision.ops import nms
from post_processing import calculate_ap
import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_pretrained(model, filename='40epoch'):
    """Copy matching tensors from a saved state dict into ``model`` in place.

    Args:
        model: module whose ``state_dict()`` tensors are overwritten.
        filename: path handed to ``torch.load``; the loaded object is
            iterated as a ``name -> tensor`` mapping.

    Entries of the checkpoint whose name does not exist in the model are
    skipped, and model entries missing from the checkpoint keep their
    current values. Nothing is returned.
    """
    target_tensors = model.state_dict()
    checkpoint = torch.load(filename)
    # Only names known to the model are loaded; everything else is ignored.
    shared_keys = (key for key in checkpoint if key in target_tensors)
    for key in shared_keys:
        value = checkpoint[key]
        # Serialized nn.Parameter objects are unwrapped to their raw tensor
        # (backwards compatibility with older checkpoints).
        source = value.data if isinstance(value, nn.Parameter) else value
        target_tensors[key].copy_(source)

In [3]:
def train(train_loader, model, loss_fn, optimizer, test_loader):
    """Run one pass over ``train_loader`` and return the running mean loss.

    Args:
        train_loader: iterable of dict batches providing the keys ``'bev'``,
            ``'labels'`` and ``'proposals'``.
        model: network called as ``model(bev, proposals)``; its output is
            summed over dim 0 to obtain one score per class.
        loss_fn: callable applied to ``(class_scores, labels)``.
        optimizer: optimizer stepped once per batch.
        test_loader: loader forwarded to ``validate`` every 1000 steps.

    Returns:
        ``loss_total / data_count`` accumulated over the pass, or ``0.0``
        when the loader yields no batches.

    Side effects: logs the running loss to wandb and stdout every 100 steps,
    and leaves the model in training mode.
    """
    loss_total = 0.0
    data_count = 0.0
    model.train()
    for step, data in tqdm(enumerate(train_loader),
                           total=len(train_loader),
                           leave=False):
        bev = data['bev'].cuda()
        labels = data['labels'].cuda()
        # reshape(-1, 4) rather than squeeze(): squeeze() would also drop the
        # box dimension when a sample has a single proposal. Assumes boxes
        # have 4 coordinates, as the nms call in validate requires.
        # The tensor is already a float CUDA tensor here, so no further
        # torch.cuda.FloatTensor(...) conversion is needed.
        proposals = data['proposals'].reshape(-1, 4).float().cuda()

        preds = model(bev, proposals)
        # Aggregate per-proposal scores into a single (1, num_classes) row.
        preds_class = preds.sum(dim=0).reshape(1, -1)
        loss = loss_fn(preds_class, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_total += loss.item() * bev.shape[0]
        data_count += bev.shape[0]
        if step % 100 == 0:
            wandb.log({"Loss": loss_total / data_count})
            print("Loss: ", loss_total / data_count)
        if step % 1000 == 0 and step != 0:
            # Mid-epoch evaluation; switch back to training mode afterwards.
            model.eval()
            validate(test_loader, model, loss_fn)
            model.train()

    if data_count == 0:
        # Empty loader: avoid ZeroDivisionError.
        return 0.0
    return loss_total / data_count

In [4]:
def validate(test_loader,
             model,
             loss_fn,
             score_threshold=0.05,
             nms_iou_threshold=0.5,
             iou_list=(0.05, 0.1, 0.2, 0.3, 0.4),
             num_classes=8):
    """Evaluate ``model`` on ``test_loader`` and report mAP per IoU threshold.

    Args:
        test_loader: iterable of dict batches providing ``'bev'``,
            ``'gt_boxes'``, ``'gt_class_list'`` and ``'proposals'``.
        model: network called as ``model(bev, proposals)``; column ``c`` of
            its output is used as the per-proposal score of class ``c``.
        loss_fn: unused; kept so existing callers do not break.
        score_threshold: proposals scoring below this are dropped before NMS.
        nms_iou_threshold: IoU threshold passed to ``torchvision.ops.nms``.
        iou_list: IoU thresholds at which ``calculate_ap`` is evaluated.
        num_classes: number of score columns to evaluate.

    Returns:
        The mAP for the last entry of ``iou_list`` (``0`` if it is empty).

    Side effects: logs ``"map@ <iou>"`` to wandb and prints each mAP.
    """
    # Row layouts handed to calculate_ap:
    #   ground truth: [image index, class id, 4 box coordinates]        -> 6 columns
    #   predictions:  [image index, class id, score, 4 box coordinates] -> 7 columns
    # Chunks are collected in lists and concatenated once at the end; growing
    # a tensor with torch.cat on every row is quadratic in the number of rows.
    gt_chunks = [torch.zeros((0, 6))]
    pred_chunks = [torch.zeros((0, 7))]
    with torch.no_grad():
        for image_idx, data in tqdm(enumerate(test_loader),
                                    total=len(test_loader),
                                    leave=False):
            bev = data['bev'].cuda()
            gt_boxes = data['gt_boxes'].reshape(-1, 4)
            gt_class_list = data['gt_class_list'].reshape(-1)
            # reshape(-1, 4) rather than squeeze() so a single proposal keeps
            # its (1, 4) shape, which nms requires.
            proposals = data['proposals'].reshape(-1, 4).float().cuda()

            cls_probs = model(bev, proposals)

            num_gt = gt_boxes.shape[0]
            if num_gt > 0:
                image_col = torch.full((num_gt, 1), image_idx, dtype=torch.long)
                class_col = gt_class_list[:num_gt].reshape(-1, 1)
                gt_chunks.append(torch.cat([image_col, class_col, gt_boxes], dim=1))

            for class_num in range(num_classes):
                class_scores = cls_probs[:, class_num]
                keep_mask = class_scores >= score_threshold
                if not bool(keep_mask.any()):
                    continue
                valid_scores = class_scores[keep_mask]
                valid_proposals = proposals[keep_mask]
                retained_idx = nms(valid_proposals, valid_scores, nms_iou_threshold)
                num_kept = retained_idx.numel()
                if num_kept == 0:
                    continue

                # Build all rows of this class in one shot and move them to
                # the CPU together instead of one GPU sync per row.
                prefix = torch.tensor([[image_idx, class_num]],
                                      dtype=torch.float32).expand(num_kept, 2)
                score_col = valid_scores[retained_idx].detach().cpu().float().reshape(-1, 1)
                kept_boxes = valid_proposals[retained_idx].detach().cpu()
                pred_chunks.append(torch.cat([prefix, score_col, kept_boxes], dim=1))

    all_gt_boxes = torch.cat(gt_chunks, dim=0)
    all_pred_boxes = torch.cat(pred_chunks, dim=0)

    mAP = 0  # defined up front so an empty iou_list still returns a value
    for iou in iou_list:
        AP = calculate_ap(all_pred_boxes, all_gt_boxes, iou)
        mAP = 0 if len(AP) == 0 else sum(AP) / len(AP)
        wandb.log({"map@ " + str(iou): mAP})
        print("Iou ", iou, " mAP ", mAP)
    return mAP

In [5]:
# Input locations: the folder holding the lidar data and the text file that
# lists the samples to use.
lidar_folder_name = "./data"
valid_data_list_filename = "./data/valid_data_list_after_threshold.txt"

# Built once at notebook level; the training cell below splits this object
# into train and test subsets.
dataset = KITTIBEV(
    lidar_folder_name=lidar_folder_name,
    valid_data_list_filename=valid_data_list_filename,
)

Preloading Data


                                                    

In [6]:
if __name__ == '__main__':
    # Pass the project by keyword: the first positional parameter of
    # wandb.init is not `project`, so a bare string does not select it.
    wandb.init(project="WSDNNPIXOR")
    epochs = 10
    model = WSDDNPIXOR()
    load_pretrained(model)

    # Freeze the backbone; only the remaining parameters receive gradients.
    for params in model.backbone.parameters():
        params.requires_grad = False

    # Fixed-seed 70/30 split so the train/test partition is reproducible.
    train_dataset_length = int(0.7 * len(dataset))
    train_dataset, test_dataset = random_split(dataset, [train_dataset_length,
                                                        len(dataset) - train_dataset_length],
                                                        generator=torch.Generator().manual_seed(42))
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    print(len(train_dataset), len(test_dataset))

    # NOTE(review): validate() thresholds the model output as probabilities
    # (`cls_probs`), while BCEWithLogitsLoss expects raw logits. If the model
    # really emits probabilities, nn.BCELoss is the matching loss - confirm
    # against WSDDNPIXOR before changing.
    loss_fn = nn.BCEWithLogitsLoss(reduction='sum')
    model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for i in range(epochs):
        model = model.train()
        loss = train(train_loader, model, loss_fn, optimizer, test_loader)
        print("Epoch average Loss: ", loss)
        # Checkpoints are overwritten every epoch.
        torch.save(model.state_dict(), "model.pth")
        torch.save(optimizer.state_dict(), "opt.pth")
        # Evaluate after every epoch.
        model = model.eval()
        mAP = validate(test_loader, model, loss_fn)

        

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33makshayantony12[0m (use `wandb login --relogin` to force relogin)


4626 1983


  0%|          | 1/4626 [00:00<42:59,  1.79it/s]

Loss:  6.3397697121836245


  2%|▏         | 101/4626 [00:34<25:22,  2.97it/s]

Loss:  6.180341551112696


  4%|▍         | 201/4626 [01:07<24:40,  2.99it/s]

Loss:  6.149544719321601


  7%|▋         | 301/4626 [01:41<24:03,  3.00it/s]

Loss:  6.1524955854418115


  9%|▊         | 401/4626 [02:15<23:58,  2.94it/s]

Loss:  6.156470541511504


 11%|█         | 501/4626 [02:50<23:52,  2.88it/s]

Loss:  6.15087785625711


 13%|█▎        | 601/4626 [03:24<22:50,  2.94it/s]

Loss:  6.145482327606368


 15%|█▌        | 701/4626 [03:58<22:04,  2.96it/s]

Loss:  6.154460011568965


 17%|█▋        | 801/4626 [04:32<21:37,  2.95it/s]

Loss:  6.14622204923626


 19%|█▉        | 901/4626 [05:06<21:10,  2.93it/s]

Loss:  6.144248480941279


 22%|██▏       | 1000/4626 [05:39<20:43,  2.92it/s]

Loss:  6.142669240883828


 22%|██▏       | 1001/4626 [10:35<89:40:56, 89.06s/it]

Iou  0.05  mAP  tensor(0.)
Iou  0.1  mAP  tensor(0.)
Iou  0.2  mAP  tensor(0.)
Iou  0.3  mAP  tensor(0.)
Iou  0.4  mAP  tensor(0.)


 24%|██▍       | 1101/4626 [11:10<20:11,  2.91it/s]   

Loss:  6.145917350216537


 26%|██▌       | 1201/4626 [11:44<19:38,  2.91it/s]

Loss:  6.160291767737494


 28%|██▊       | 1301/4626 [12:18<19:00,  2.92it/s]

Loss:  6.1647611175884


 30%|███       | 1401/4626 [12:52<18:17,  2.94it/s]

Loss:  6.165737371086585


 32%|███▏      | 1501/4626 [13:26<17:48,  2.92it/s]

Loss:  6.165917328724143


 35%|███▍      | 1601/4626 [14:00<17:31,  2.88it/s]

Loss:  6.161702607301105


 37%|███▋      | 1701/4626 [14:34<16:48,  2.90it/s]

Loss:  6.157395560343221


 39%|███▉      | 1801/4626 [15:08<15:52,  2.96it/s]

Loss:  6.160229674065281


 41%|████      | 1901/4626 [15:43<15:31,  2.93it/s]

Loss:  6.158031339143561


 43%|████▎     | 2000/4626 [16:16<14:57,  2.93it/s]

Loss:  6.151555042269742


 43%|████▎     | 2001/4626 [21:13<65:00:17, 89.15s/it]

Iou  0.05  mAP  tensor(0.)
Iou  0.1  mAP  tensor(0.)
Iou  0.2  mAP  tensor(0.)
Iou  0.3  mAP  tensor(0.)
Iou  0.4  mAP  tensor(0.)


 45%|████▌     | 2101/4626 [21:47<14:32,  2.89it/s]   

Loss:  6.150930768944397


 48%|████▊     | 2201/4626 [22:21<13:46,  2.94it/s]

Loss:  6.149000222805704


 50%|████▉     | 2301/4626 [22:56<13:13,  2.93it/s]

Loss:  6.145499119833938


 52%|█████▏    | 2401/4626 [23:30<12:49,  2.89it/s]

Loss:  6.146871025927082


 54%|█████▍    | 2501/4626 [24:04<12:16,  2.89it/s]

Loss:  6.145334374856499


 56%|█████▌    | 2601/4626 [24:39<11:45,  2.87it/s]

Loss:  6.146222660385589


 58%|█████▊    | 2701/4626 [25:13<11:04,  2.90it/s]

Loss:  6.148526079813276


 61%|██████    | 2801/4626 [25:48<10:43,  2.84it/s]

Loss:  6.147808947820666


 63%|██████▎   | 2901/4626 [26:22<09:56,  2.89it/s]

Loss:  6.1488647660564055


 65%|██████▍   | 3000/4626 [26:56<09:26,  2.87it/s]

Loss:  6.147517710778363


 65%|██████▍   | 3001/4626 [31:55<40:35:21, 89.92s/it]

Iou  0.05  mAP  tensor(0.)
Iou  0.1  mAP  tensor(0.)
Iou  0.2  mAP  tensor(0.)
Iou  0.3  mAP  tensor(0.)
Iou  0.4  mAP  tensor(0.)


 67%|██████▋   | 3101/4626 [32:30<08:48,  2.89it/s]   

Loss:  6.149482225782738


 69%|██████▉   | 3201/4626 [33:05<08:14,  2.88it/s]

Loss:  6.150699176468995


 71%|███████▏  | 3301/4626 [33:39<07:38,  2.89it/s]

Loss:  6.148813152982704


 74%|███████▎  | 3401/4626 [34:14<07:07,  2.86it/s]

Loss:  6.149096173158704


 76%|███████▌  | 3501/4626 [34:49<06:35,  2.84it/s]

Loss:  6.147078026120124


 78%|███████▊  | 3601/4626 [35:23<05:57,  2.86it/s]

Loss:  6.147948894274303


 80%|████████  | 3701/4626 [35:58<05:20,  2.89it/s]

Loss:  6.1485024952343235


 82%|████████▏ | 3801/4626 [36:32<04:43,  2.91it/s]

Loss:  6.149026964867713


 84%|████████▍ | 3901/4626 [37:07<04:11,  2.88it/s]

Loss:  6.149524512527702


 86%|████████▋ | 4000/4626 [37:41<03:36,  2.90it/s]

Loss:  6.151182403342015


 86%|████████▋ | 4001/4626 [42:40<15:35:14, 89.78s/it]

Iou  0.05  mAP  tensor(0.)
Iou  0.1  mAP  tensor(0.)
Iou  0.2  mAP  tensor(0.)
Iou  0.3  mAP  tensor(0.)
Iou  0.4  mAP  tensor(0.)


 89%|████████▊ | 4101/4626 [43:14<03:02,  2.88it/s]   

Loss:  6.150627824807404


 91%|█████████ | 4201/4626 [43:49<02:26,  2.89it/s]

Loss:  6.150575725442435


 93%|█████████▎| 4301/4626 [44:23<01:51,  2.91it/s]

Loss:  6.150991057203647


 95%|█████████▌| 4401/4626 [44:58<01:17,  2.89it/s]

Loss:  6.151841956405589


 97%|█████████▋| 4501/4626 [45:33<00:42,  2.92it/s]

Loss:  6.151099836413272


 99%|█████████▉| 4601/4626 [46:07<00:08,  2.91it/s]

Loss:  6.15234607189186


                                                   

Epoch average Loss:  6.152703254556953


                                                   

Iou  0.05  mAP  tensor(0.)
Iou  0.1  mAP  tensor(0.)
Iou  0.2  mAP  tensor(0.)
Iou  0.3  mAP  tensor(0.)
Iou  0.4  mAP  tensor(0.)


  0%|          | 1/4626 [00:00<27:18,  2.82it/s]

Loss:  5.858439117670059


  2%|▏         | 101/4626 [00:35<26:03,  2.89it/s]

Loss:  6.145567847950624


  4%|▍         | 201/4626 [01:09<25:20,  2.91it/s]

Loss:  6.201722719034746


  7%|▋         | 301/4626 [01:44<25:12,  2.86it/s]

Loss:  6.187342791858305


  9%|▊         | 401/4626 [02:19<24:53,  2.83it/s]

Loss:  6.170159835842185


 11%|█         | 501/4626 [02:53<23:52,  2.88it/s]

Loss:  6.165824366722278


 13%|█▎        | 601/4626 [03:28<23:15,  2.88it/s]

Loss:  6.169587223531799


 15%|█▌        | 701/4626 [04:03<22:47,  2.87it/s]

Loss:  6.165143843469198


 17%|█▋        | 801/4626 [04:37<22:01,  2.90it/s]

Loss:  6.160561483376631


 19%|█▉        | 901/4626 [05:12<21:22,  2.91it/s]

Loss:  6.163655562452418


 22%|██▏       | 1000/4626 [05:46<21:04,  2.87it/s]

Loss:  6.172125449487856


 22%|██▏       | 1001/4626 [10:44<90:06:05, 89.48s/it]

Iou  0.05  mAP  tensor(0.)
Iou  0.1  mAP  tensor(0.)
Iou  0.2  mAP  tensor(0.)
Iou  0.3  mAP  tensor(0.)
Iou  0.4  mAP  tensor(0.)


 24%|██▍       | 1101/4626 [11:18<20:17,  2.89it/s]   

Loss:  6.171790634714182


                                                   

KeyboardInterrupt: 