In [1]:
# https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/Pytorch/object_detection/YOLO

In [2]:
import torch
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.transforms.functional as FT
from tqdm import tqdm
from torch.utils.data import DataLoader
import numpy as np
import time
import sys
sys.path.append('../')

from model import Yolov1
from dataset import VOCDataset, PapsDataset, train_transforms, val_transforms
from train import Compose, train_fn
from scheduler import CosineAnnealingWarmUpRestarts

from utils import (
    non_max_suppression,
    mean_average_precision,
    intersection_over_union,
    cellboxes_to_boxes,
    get_bboxes,
    plot_image,
    save_checkpoint,
    load_checkpoint,
)
from loss import YoloLoss

In [3]:
partition = np.load('../data/partition.npy', allow_pickle=True, encoding='latin1').item()
label = np.load('../data/labels_info.npy', allow_pickle=True, encoding='latin1').item()


In [4]:
train_dataset = PapsDataset(
    label,
#     partition['train'][0:100],
    partition['train'],
    transform=train_transforms,
)

test_dataset = PapsDataset(
    label,
#     partition['train'][0:100],
    partition['test'],
    transform=val_transforms,
)

In [5]:
# for image, matrix in train_dataset :
# #     print(image.shape)
# #     print(matrix.shape)
# #     print(path)
# #     print(boxes)
#     pass

In [6]:
GPU_NUM = 0 # 원하는 GPU 번호 입력
device = torch.device(f'cuda:{GPU_NUM}' if torch.cuda.is_available() else 'cpu')
model = Yolov1(split_size=25, num_boxes=2, num_classes=1).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0005)
loss_fn = YoloLoss()
# scheduler = CosineAnnealingWarmUpRestarts(optimizer, T_0=10, T_mult=2, eta_max=0.0008,  T_up=5, gamma=0.5)

In [7]:
# model.load_state_dict(torch.load('../trained_model/cifar_net.pth'),strict=False)

In [8]:
def collate_fn(batch):
    return tuple(zip(*batch))

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=12,
    num_workers=12,
#     pin_memory=False,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn
)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=12,
    num_workers=12,
    shuffle=False,
    drop_last=True,
    collate_fn=collate_fn
)
print(len(train_loader))
print(len(test_loader))

1248
312


In [9]:
# # image, label, bbox = next(iter(test_dataset))
# for images, labels, bboxes in train_loader :
# #     image, label, bbox = data
# #     print(images[0].shape)
#     c, h, w = images[0].shape
#     images = torch.cat(images).view(-1, c, h, w)
#     lc, lh, lw = labels[0].shape
#     labels = torch.cat(labels).view(-1, lc, lh, lw)
# #     print(images.shape)
# #     print((labels.shape))



In [10]:
# saved_dir = '../trained_model/'
# state = torch.load(saved_dir + 'epoch_' + str(12) +'_model.pt')
# epoch = state['epoch']
# model.load_state_dict(state['state_dict'], strict=False)
# optimizer.load_state_dict(state['optimizer'])
# scheduler.load_state_dict(state['scheduler'])

In [None]:
epochs = 120
for epoch in range(0, epochs):
    stime = time.time()
    epoch_loss = train_fn(train_loader, model, optimizer, loss_fn)
#     scheduler.step()    
    mean_avg_prec = 0
    if epoch % 10 == 9 :
        pred_boxes, target_boxes = get_bboxes(
            test_loader, model, iou_threshold=0.4, threshold=0.4
        )

        mean_avg_prec = mean_average_precision(
            pred_boxes, target_boxes, iou_threshold=0.5, box_format="midpoint"
        )
    print('{} epoch_loss {} mean_avg_prec {} time {} lr {}'.format(
            epoch, epoch_loss, mean_avg_prec, (time.time() - stime)/60, optimizer.param_groups[0]["lr"]))
    

100%|██████████| 1248/1248 [07:49<00:00,  2.66it/s, box_loss=609, cls_loss=81.3, loss=873, noob_loss=101, ob_loss=81.5]       
  0%|          | 0/1248 [00:00<?, ?it/s]

0 epoch_loss (597.5217124865605, 366.9678306762989, 43.44648382908259, 143.04071557216156, 44.066683360399345) mean_avg_prec 0 time 7.832221178213755 lr 0.0005


100%|██████████| 1248/1248 [07:54<00:00,  2.63it/s, box_loss=401, cls_loss=50.4, loss=578, noob_loss=59.2, ob_loss=67.3]  
  0%|          | 0/1248 [00:00<?, ?it/s]

1 epoch_loss (476.4904223466531, 305.69534182548523, 43.032542270727646, 89.00185124079387, 38.76068765307084) mean_avg_prec 0 time 7.90875198841095 lr 0.0005


100%|██████████| 1248/1248 [07:53<00:00,  2.63it/s, box_loss=106, cls_loss=17, loss=197, noob_loss=30.7, ob_loss=43]      
  0%|          | 0/1248 [00:00<?, ?it/s]

2 epoch_loss (291.39267404262836, 183.39739100138345, 48.03803637547371, 38.772906347727165, 21.184340484631367) mean_avg_prec 0 time 7.899644315242767 lr 0.0005


100%|██████████| 1248/1248 [07:54<00:00,  2.63it/s, box_loss=134, cls_loss=17.8, loss=232, noob_loss=13.1, ob_loss=67.3]  
  0%|          | 0/1248 [00:00<?, ?it/s]

3 epoch_loss (192.52915218548898, 108.44653601218492, 51.04620569485884, 17.72496200295595, 15.311448426773914) mean_avg_prec 0 time 7.910334901014964 lr 0.0005


100%|██████████| 1248/1248 [07:57<00:00,  2.61it/s, box_loss=38.1, cls_loss=3.6, loss=70.3, noob_loss=12.9, ob_loss=15.6] 
  0%|          | 0/1248 [00:00<?, ?it/s]

4 epoch_loss (173.776489340342, 95.63639808006776, 51.42402762862352, 14.344730599950521, 12.37133281888106) mean_avg_prec 0 time 7.958641131718953 lr 0.0005


100%|██████████| 1248/1248 [07:59<00:00,  2.60it/s, box_loss=75.2, cls_loss=11.5, loss=151, noob_loss=13.1, ob_loss=50.8] 
  0%|          | 0/1248 [00:00<?, ?it/s]

5 epoch_loss (167.27236590629968, 89.98885441284914, 51.51152601532447, 13.75915631881127, 12.012829242130884) mean_avg_prec 0 time 7.989784598350525 lr 0.0005


100%|██████████| 1248/1248 [07:51<00:00,  2.65it/s, box_loss=81.6, cls_loss=11.8, loss=154, noob_loss=15.2, ob_loss=45.5] 
  0%|          | 0/1248 [00:00<?, ?it/s]

6 epoch_loss (164.73752876428458, 87.96676909465056, 51.44701002652828, 13.492848796722216, 11.830900526868227) mean_avg_prec 0 time 7.859638472398122 lr 0.0005


100%|██████████| 1248/1248 [07:49<00:00,  2.66it/s, box_loss=119, cls_loss=14.8, loss=231, noob_loss=12.6, ob_loss=84.3]  
  0%|          | 0/1248 [00:00<?, ?it/s]

7 epoch_loss (163.5340750553669, 87.08086914282579, 51.254602280182716, 13.533846275164532, 11.664757402470478) mean_avg_prec 0 time 7.829375815391541 lr 0.0005


100%|██████████| 1248/1248 [07:49<00:00,  2.66it/s, box_loss=74.8, cls_loss=13.1, loss=168, noob_loss=18.6, ob_loss=61.8] 
  0%|          | 0/1248 [00:00<?, ?it/s]

8 epoch_loss (161.84828136211786, 85.63498578774623, 51.12658165815549, 13.48886196124248, 11.597851791251928) mean_avg_prec 0 time 7.827183302243551 lr 0.0005


100%|██████████| 1248/1248 [07:44<00:00,  2.68it/s, box_loss=88.3, cls_loss=15, loss=171, noob_loss=16.3, ob_loss=50.9]   
100%|██████████| 312/312 [01:34<00:00,  3.31it/s]
  0%|          | 0/1248 [00:00<?, ?it/s]

9 epoch_loss (160.17675488423077, 84.3844037988247, 50.92154856522878, 13.492712683402575, 11.378089707870132) mean_avg_prec 0.0 time 9.341512338320415 lr 0.0005


100%|██████████| 1248/1248 [07:45<00:00,  2.68it/s, box_loss=55.5, cls_loss=7.92, loss=109, noob_loss=7.77, ob_loss=37.5] 
  0%|          | 0/1248 [00:00<?, ?it/s]

10 epoch_loss (157.6054642383869, 81.92145395584596, 51.2171600766671, 12.822420926430286, 11.644429235599745) mean_avg_prec 0 time 7.762132962544759 lr 0.0005


100%|██████████| 1248/1248 [07:53<00:00,  2.64it/s, box_loss=49.6, cls_loss=5.07, loss=85.6, noob_loss=9.87, ob_loss=21.1]
  0%|          | 0/1248 [00:00<?, ?it/s]

11 epoch_loss (154.18552015377924, 79.82139355861224, 51.28463838956295, 12.31288645703059, 10.766601698377576) mean_avg_prec 0 time 7.888526074091593 lr 0.0005


100%|██████████| 1248/1248 [07:51<00:00,  2.65it/s, box_loss=51.7, cls_loss=8.74, loss=106, noob_loss=11.6, ob_loss=34.1] 
  0%|          | 0/1248 [00:00<?, ?it/s]

12 epoch_loss (152.47197829454373, 79.12124035297296, 51.16593348827117, 12.373282838326235, 9.811521326979765) mean_avg_prec 0 time 7.853991234302521 lr 0.0005


100%|██████████| 1248/1248 [07:52<00:00,  2.64it/s, box_loss=99.8, cls_loss=8.37, loss=198, noob_loss=15.2, ob_loss=74.9] 
  0%|          | 0/1248 [00:00<?, ?it/s]

13 epoch_loss (150.24029055314188, 77.55798246310307, 51.03641716104288, 12.437541815714958, 9.208348898145442) mean_avg_prec 0 time 7.871839785575867 lr 0.0005


100%|██████████| 1248/1248 [07:52<00:00,  2.64it/s, box_loss=62.9, cls_loss=7.48, loss=134, noob_loss=13.9, ob_loss=50.1] 
  0%|          | 0/1248 [00:00<?, ?it/s]

14 epoch_loss (149.18418274475977, 76.85466042390236, 50.992064846631806, 12.386806220962452, 8.950651493186179) mean_avg_prec 0 time 7.873667260011037 lr 0.0005


100%|██████████| 1248/1248 [07:52<00:00,  2.64it/s, box_loss=47.3, cls_loss=4.1, loss=98.1, noob_loss=13.3, ob_loss=33.4] 
  0%|          | 0/1248 [00:00<?, ?it/s]

15 epoch_loss (148.33179912200342, 76.18986822473697, 50.80377803169764, 12.390132849033062, 8.948019925170602) mean_avg_prec 0 time 7.876154498259226 lr 0.0005


100%|██████████| 1248/1248 [07:53<00:00,  2.63it/s, box_loss=67.1, cls_loss=9.67, loss=150, noob_loss=10.5, ob_loss=62.3]  
  0%|          | 0/1248 [00:00<?, ?it/s]

16 epoch_loss (146.57719049086938, 74.79488156544856, 50.68195634316175, 12.247554223124798, 8.852798369163887) mean_avg_prec 0 time 7.89692364136378 lr 0.0005


100%|██████████| 1248/1248 [07:58<00:00,  2.61it/s, box_loss=73.6, cls_loss=13, loss=158, noob_loss=15, ob_loss=56.2]     
  0%|          | 0/1248 [00:00<?, ?it/s]

17 epoch_loss (145.79828945795694, 74.21514783302943, 50.54896894097328, 12.269035262939257, 8.765137382472554) mean_avg_prec 0 time 7.978245468934377 lr 0.0005


100%|██████████| 1248/1248 [07:57<00:00,  2.61it/s, box_loss=55.9, cls_loss=7.24, loss=119, noob_loss=11.3, ob_loss=44.4]  
  0%|          | 0/1248 [00:00<?, ?it/s]

18 epoch_loss (145.8818563345151, 74.43361109800829, 50.41174008907416, 12.289286794952858, 8.747218160555722) mean_avg_prec 0 time 7.964896063009898 lr 0.0005


100%|██████████| 1248/1248 [07:56<00:00,  2.62it/s, box_loss=66.1, cls_loss=8.18, loss=125, noob_loss=11.6, ob_loss=39.5]  
100%|██████████| 312/312 [01:40<00:00,  3.10it/s]
  0%|          | 0/1248 [00:00<?, ?it/s]

19 epoch_loss (144.14023366952554, 73.05648411848607, 50.24991475313138, 12.202234031298222, 8.631601097492071) mean_avg_prec 4.423346763360314e-05 time 9.6990851521492 lr 0.0005


100%|██████████| 1248/1248 [07:58<00:00,  2.61it/s, box_loss=56.5, cls_loss=5.8, loss=98.5, noob_loss=10.5, ob_loss=25.7]  
  0%|          | 0/1248 [00:00<?, ?it/s]

20 epoch_loss (143.59410685147995, 72.64649786628209, 50.067823691245835, 12.318059545296888, 8.56172570416656) mean_avg_prec 0 time 7.970161441961924 lr 0.0005


100%|██████████| 1248/1248 [08:01<00:00,  2.59it/s, box_loss=95.3, cls_loss=10.3, loss=189, noob_loss=12.8, ob_loss=70.4]  
  0%|          | 0/1248 [00:00<?, ?it/s]

21 epoch_loss (142.9047537308473, 72.18304272951224, 49.860507366748955, 12.398043238199675, 8.463160263613249) mean_avg_prec 0 time 8.02374451160431 lr 0.0005


100%|██████████| 1248/1248 [08:44<00:00,  2.38it/s, box_loss=77.8, cls_loss=4.85, loss=130, noob_loss=12.3, ob_loss=35.2] 
  0%|          | 0/1248 [00:00<?, ?it/s]

22 epoch_loss (141.5223791874372, 70.92246416898874, 49.65023581569012, 12.527471843820353, 8.422207277656222) mean_avg_prec 0 time 8.736085891723633 lr 0.0005


100%|██████████| 1248/1248 [07:51<00:00,  2.65it/s, box_loss=31.5, cls_loss=3.76, loss=71.7, noob_loss=11.3, ob_loss=25.1] 
  0%|          | 0/1248 [00:00<?, ?it/s]

23 epoch_loss (140.66797151015356, 70.34130905377559, 49.52220965195925, 12.487898003214445, 8.316554738375812) mean_avg_prec 0 time 7.861367189884186 lr 0.0005


100%|██████████| 1248/1248 [08:01<00:00,  2.59it/s, box_loss=102, cls_loss=9.03, loss=188, noob_loss=10.7, ob_loss=66.7]   
  0%|          | 0/1248 [00:00<?, ?it/s]

24 epoch_loss (138.06248376308343, 68.89382357933583, 49.640011629997154, 11.51972949428436, 8.008919243975424) mean_avg_prec 0 time 8.02362491687139 lr 0.0005


100%|██████████| 1248/1248 [07:56<00:00,  2.62it/s, box_loss=29.7, cls_loss=3, loss=70.1, noob_loss=9.42, ob_loss=28]     
  0%|          | 0/1248 [00:00<?, ?it/s]

25 epoch_loss (130.00330583560162, 63.47784723226841, 49.87298530110946, 9.312569141005858, 7.339904567871529) mean_avg_prec 0 time 7.949328585465749 lr 0.0005


100%|██████████| 1248/1248 [07:51<00:00,  2.64it/s, box_loss=75.1, cls_loss=8.62, loss=139, noob_loss=7.62, ob_loss=47.3] 
  0%|          | 0/1248 [00:00<?, ?it/s]

26 epoch_loss (128.57314718686618, 62.66759807024247, 49.56121297906606, 9.251336033145586, 7.092999975602979) mean_avg_prec 0 time 7.864538208643595 lr 0.0005


100%|██████████| 1248/1248 [07:46<00:00,  2.68it/s, box_loss=23.6, cls_loss=4.76, loss=61.2, noob_loss=5.82, ob_loss=27.1] 
  0%|          | 0/1248 [00:00<?, ?it/s]

27 epoch_loss (127.36907153863173, 61.72578579645891, 49.335895743125526, 9.22580508620311, 7.081584983815749) mean_avg_prec 0 time 7.770960875352224 lr 0.0005


100%|██████████| 1248/1248 [07:44<00:00,  2.69it/s, box_loss=95.5, cls_loss=12.9, loss=188, noob_loss=7.36, ob_loss=72.2]  
  0%|          | 0/1248 [00:00<?, ?it/s]

28 epoch_loss (126.54367381487137, 61.23867972691854, 49.22681651054285, 9.158669314323328, 6.919508167288791) mean_avg_prec 0 time 7.741100645065307 lr 0.0005


100%|██████████| 1248/1248 [07:44<00:00,  2.69it/s, box_loss=53.4, cls_loss=7.29, loss=107, noob_loss=7.6, ob_loss=39]     
100%|██████████| 312/312 [01:35<00:00,  3.26it/s]
  0%|          | 0/1248 [00:00<?, ?it/s]

29 epoch_loss (125.87764277060826, 60.803485468412056, 49.02001112775925, 9.184880582185892, 6.869265432517307) mean_avg_prec 0.0010445857187733054 time 9.42216746409734 lr 0.0005


100%|██████████| 1248/1248 [07:45<00:00,  2.68it/s, box_loss=32.5, cls_loss=1.86, loss=64.6, noob_loss=6.98, ob_loss=23.2] 
  0%|          | 0/1248 [00:00<?, ?it/s]

30 epoch_loss (125.18629345221397, 60.19715871872046, 48.79328944744208, 9.324145801174335, 6.871699609841483) mean_avg_prec 0 time 7.761116262276968 lr 0.0005


100%|██████████| 1248/1248 [07:44<00:00,  2.69it/s, box_loss=76.1, cls_loss=8.92, loss=167, noob_loss=10.9, ob_loss=70.8]  
  0%|          | 0/1248 [00:00<?, ?it/s]

31 epoch_loss (124.4215003099197, 59.706407504203995, 48.58897845638104, 9.389232833416035, 6.7368815037398) mean_avg_prec 0 time 7.7403308590253195 lr 0.0005


100%|██████████| 1248/1248 [07:44<00:00,  2.69it/s, box_loss=55.6, cls_loss=2.53, loss=110, noob_loss=10.6, ob_loss=41.1]  
  0%|          | 0/1248 [00:00<?, ?it/s]

32 epoch_loss (124.25248140249496, 59.555101927274315, 48.39907371080839, 9.454743040677828, 6.843562845737697) mean_avg_prec 0 time 7.738003126780192 lr 0.0005


100%|██████████| 1248/1248 [07:44<00:00,  2.69it/s, box_loss=46.3, cls_loss=6.45, loss=102, noob_loss=10.2, ob_loss=39.5] 
  0%|          | 0/1248 [00:00<?, ?it/s]

33 epoch_loss (123.05506374591437, 58.751827593797294, 48.23064635502986, 9.348391994833946, 6.724197895910877) mean_avg_prec 0 time 7.739375766118368 lr 0.0005


100%|██████████| 1248/1248 [07:44<00:00,  2.69it/s, box_loss=27.8, cls_loss=5.04, loss=66.9, noob_loss=8, ob_loss=26]      
  0%|          | 0/1248 [00:00<?, ?it/s]

34 epoch_loss (122.32635529530354, 58.27854369466122, 48.02908255962225, 9.29637473095686, 6.722354224297958) mean_avg_prec 0 time 7.734206195672353 lr 0.0005


100%|██████████| 1248/1248 [07:46<00:00,  2.68it/s, box_loss=47.8, cls_loss=6.51, loss=92.9, noob_loss=6.51, ob_loss=32.1] 
  0%|          | 0/1248 [00:00<?, ?it/s]

35 epoch_loss (121.62924040892186, 57.70087373944429, 47.845460821420716, 9.443811637086746, 6.639094237596369) mean_avg_prec 0 time 7.771936635176341 lr 0.0005


100%|██████████| 1248/1248 [07:45<00:00,  2.68it/s, box_loss=134, cls_loss=16.3, loss=279, noob_loss=11.7, ob_loss=118]    
  0%|          | 0/1248 [00:00<?, ?it/s]

36 epoch_loss (121.0952576811497, 57.37364632349748, 47.742504445406105, 9.370943679259373, 6.608162949936321) mean_avg_prec 0 time 7.756834896405538 lr 0.0005


100%|██████████| 1248/1248 [07:45<00:00,  2.68it/s, box_loss=69.2, cls_loss=5.93, loss=143, noob_loss=10.1, ob_loss=57.4]  
  0%|          | 0/1248 [00:00<?, ?it/s]

37 epoch_loss (120.8555827446473, 57.25290557130789, 47.51778843616828, 9.537994763790033, 6.546893864655151) mean_avg_prec 0 time 7.753116385142008 lr 0.0005


100%|██████████| 1248/1248 [07:45<00:00,  2.68it/s, box_loss=40, cls_loss=4.08, loss=80.7, noob_loss=8, ob_loss=28.6]      
  0%|          | 0/1248 [00:00<?, ?it/s]

38 epoch_loss (120.47639733094435, 56.92682691911856, 47.36939912881606, 9.575968252924772, 6.604202786000828) mean_avg_prec 0 time 7.766108282407125 lr 0.0005


100%|██████████| 1248/1248 [07:46<00:00,  2.68it/s, box_loss=84.9, cls_loss=8.44, loss=179, noob_loss=10.6, ob_loss=74.6]  
100%|██████████| 312/312 [01:35<00:00,  3.28it/s]
  0%|          | 0/1248 [00:00<?, ?it/s]

39 epoch_loss (119.4828572303821, 56.163681391722115, 47.16777620330835, 9.598176733805584, 6.553223071297487) mean_avg_prec 0.009153717197477818 time 9.48758156299591 lr 0.0005


100%|██████████| 1248/1248 [07:47<00:00,  2.67it/s, box_loss=38.5, cls_loss=5.66, loss=86.4, noob_loss=7.64, ob_loss=34.6] 
  0%|          | 0/1248 [00:00<?, ?it/s]

40 epoch_loss (118.74121714555301, 55.690646149409126, 46.901204984157516, 9.636028468608856, 6.51333753273894) mean_avg_prec 0 time 7.794916109244029 lr 0.0005


100%|██████████| 1248/1248 [07:47<00:00,  2.67it/s, box_loss=12.3, cls_loss=0.149, loss=31.8, noob_loss=6.63, ob_loss=12.7]
  0%|          | 0/1248 [00:00<?, ?it/s]

41 epoch_loss (118.28010684251785, 55.39803075026243, 46.74696898995302, 9.737177635232607, 6.397929266333962) mean_avg_prec 0 time 7.785055454572042 lr 0.0005


100%|██████████| 1248/1248 [07:47<00:00,  2.67it/s, box_loss=62.6, cls_loss=6.65, loss=125, noob_loss=9.69, ob_loss=45.7]  
  0%|          | 0/1248 [00:00<?, ?it/s]

42 epoch_loss (117.73578535593472, 55.17625711285151, 46.64378578005693, 9.588328369152851, 6.327413945500619) mean_avg_prec 0 time 7.783508865038554 lr 0.0005


100%|██████████| 1248/1248 [07:47<00:00,  2.67it/s, box_loss=79.3, cls_loss=6.98, loss=174, noob_loss=12.1, ob_loss=75.9]  
  0%|          | 0/1248 [00:00<?, ?it/s]

43 epoch_loss (113.9325314882474, 52.635855077933044, 46.74985812795468, 8.553842851748833, 5.992975604039832) mean_avg_prec 0 time 7.793253886699676 lr 0.0005


100%|██████████| 1248/1248 [07:47<00:00,  2.67it/s, box_loss=37.4, cls_loss=2.12, loss=76.6, noob_loss=6.42, ob_loss=30.7]  
  0%|          | 0/1248 [00:00<?, ?it/s]

44 epoch_loss (111.00354490677516, 51.00117074908354, 46.77179597509213, 7.980048763064238, 5.250529425451532) mean_avg_prec 0 time 7.797309521834055 lr 0.0005


100%|██████████| 1248/1248 [07:47<00:00,  2.67it/s, box_loss=58.1, cls_loss=5.24, loss=130, noob_loss=8.13, ob_loss=58.8]   
  0%|          | 0/1248 [00:00<?, ?it/s]

45 epoch_loss (109.3041353913454, 49.93654859753755, 46.61958076785772, 7.792039945339545, 4.955966213461346) mean_avg_prec 0 time 7.795338280995687 lr 0.0005


100%|██████████| 1248/1248 [07:48<00:00,  2.67it/s, box_loss=46.6, cls_loss=6.77, loss=100, noob_loss=7.38, ob_loss=39.6]   
  0%|          | 0/1248 [00:00<?, ?it/s]

46 epoch_loss (108.69643403322269, 49.483415856575355, 46.377006338192864, 7.840990105118507, 4.995021468374687) mean_avg_prec 0 time 7.802592265605926 lr 0.0005


100%|██████████| 1248/1248 [07:48<00:00,  2.67it/s, box_loss=54.2, cls_loss=4.21, loss=106, noob_loss=6.79, ob_loss=40.3]   
  0%|          | 0/1248 [00:00<?, ?it/s]

47 epoch_loss (108.1446033349404, 49.239480131711716, 46.22229534234756, 7.795388533327824, 4.887439193269715) mean_avg_prec 0 time 7.801549895604452 lr 0.0005


100%|██████████| 1248/1248 [07:47<00:00,  2.67it/s, box_loss=38.4, cls_loss=3.49, loss=94.3, noob_loss=11.1, ob_loss=41.3] 
  0%|          | 0/1248 [00:00<?, ?it/s]

48 epoch_loss (107.76116807949849, 49.05224074614354, 45.98229381824151, 7.915237629642854, 4.811395988913062) mean_avg_prec 0 time 7.799468839168549 lr 0.0005


100%|██████████| 1248/1248 [07:48<00:00,  2.66it/s, box_loss=41, cls_loss=6.81, loss=108, noob_loss=9.27, ob_loss=51.1]    
100%|██████████| 312/312 [01:36<00:00,  3.23it/s]
  0%|          | 0/1248 [00:00<?, ?it/s]

49 epoch_loss (106.84724494585625, 48.393034521967934, 45.83892460587697, 7.856293211762722, 4.758992731836863) mean_avg_prec 0.02823001518845558 time 9.618205845355988 lr 0.0005


100%|██████████| 1248/1248 [07:49<00:00,  2.66it/s, box_loss=97.5, cls_loss=12.8, loss=239, noob_loss=13, ob_loss=115]     
  0%|          | 0/1248 [00:00<?, ?it/s]

50 epoch_loss (106.50001771939107, 48.16776663905535, 45.742481242387726, 7.812180978365434, 4.777588714857418) mean_avg_prec 0 time 7.8244433999061584 lr 0.0005


100%|██████████| 1248/1248 [07:48<00:00,  2.66it/s, box_loss=38.4, cls_loss=4.04, loss=89.8, noob_loss=7.38, ob_loss=40]   
  0%|          | 0/1248 [00:00<?, ?it/s]

51 epoch_loss (105.55126277758525, 47.58536787980642, 45.44485364586879, 7.817969008133962, 4.703072331439799) mean_avg_prec 0 time 7.806011323134104 lr 0.0005


100%|██████████| 1248/1248 [07:48<00:00,  2.66it/s, box_loss=98.9, cls_loss=13.1, loss=222, noob_loss=13.4, ob_loss=96.1]   
  0%|          | 0/1248 [00:00<?, ?it/s]

52 epoch_loss (105.2084955832897, 47.41796765037072, 45.22938703879332, 7.89663014962123, 4.664510518975126) mean_avg_prec 0 time 7.810119724273681 lr 0.0005


100%|██████████| 1248/1248 [07:48<00:00,  2.66it/s, box_loss=36.6, cls_loss=3.12, loss=78.7, noob_loss=7.61, ob_loss=31.3] 
  0%|          | 0/1248 [00:00<?, ?it/s]

53 epoch_loss (104.84337344383582, 47.17632438509892, 45.13643242609807, 7.83565635043077, 4.6949601346337335) mean_avg_prec 0 time 7.807390451431274 lr 0.0005


100%|██████████| 1248/1248 [07:48<00:00,  2.66it/s, box_loss=19.2, cls_loss=1.92, loss=48.1, noob_loss=6.81, ob_loss=20.2] 
  0%|          | 0/1248 [00:00<?, ?it/s]

54 epoch_loss (104.20634025029646, 46.581480557337784, 45.02330248936629, 7.9184499849111605, 4.683107280327628) mean_avg_prec 0 time 7.811026986440023 lr 0.0005


100%|██████████| 1248/1248 [07:48<00:00,  2.66it/s, box_loss=33.5, cls_loss=3.73, loss=87, noob_loss=7.23, ob_loss=42.5]   
  0%|          | 0/1248 [00:00<?, ?it/s]

55 epoch_loss (103.39696976160391, 46.05488740633695, 44.75677593167011, 7.959318119745988, 4.62598853214429) mean_avg_prec 0 time 7.805975004037221 lr 0.0005


100%|██████████| 1248/1248 [07:48<00:00,  2.66it/s, box_loss=50, cls_loss=5.29, loss=119, noob_loss=7.84, ob_loss=55.4]    
  0%|          | 0/1248 [00:00<?, ?it/s]

56 epoch_loss (102.78833958888666, 45.80674035885395, 44.534171402072296, 7.874584158643698, 4.57284387086148) mean_avg_prec 0 time 7.807083547115326 lr 0.0005


100%|██████████| 1248/1248 [07:47<00:00,  2.67it/s, box_loss=36.4, cls_loss=8.17, loss=95.6, noob_loss=9.03, ob_loss=42]   
  0%|          | 0/1248 [00:00<?, ?it/s]

57 epoch_loss (102.679611768478, 45.66847997903824, 44.34298001115139, 8.011446577998308, 4.656705313314421) mean_avg_prec 0 time 7.799535965919494 lr 0.0005


100%|██████████| 1248/1248 [07:47<00:00,  2.67it/s, box_loss=47.8, cls_loss=6.55, loss=109, noob_loss=7.81, ob_loss=46.7]  
  0%|          | 0/1248 [00:00<?, ?it/s]

58 epoch_loss (101.76932426752188, 45.17189234265914, 44.20336294938357, 7.971110387872427, 4.422958582777005) mean_avg_prec 0 time 7.799850122133891 lr 0.0005


100%|██████████| 1248/1248 [07:47<00:00,  2.67it/s, box_loss=52.6, cls_loss=4.41, loss=113, noob_loss=6.27, ob_loss=49.9]  
100%|██████████| 312/312 [01:35<00:00,  3.25it/s]
  0%|          | 0/1248 [00:00<?, ?it/s]

59 epoch_loss (101.56400781105727, 44.71746767331393, 44.168905927966804, 8.058868008164259, 4.618766251411169) mean_avg_prec 0.041421547532081604 time 9.55809118350347 lr 0.0005


 43%|████▎     | 537/1248 [03:23<04:26,  2.67it/s, box_loss=49, cls_loss=3.77, loss=111, noob_loss=9.63, ob_loss=48.1]     

In [None]:
pred_boxes, target_boxes = get_bboxes(
    test_loader, model, iou_threshold=0.4, threshold=0.4
)

mean_avg_prec = mean_average_precision(
    pred_boxes, target_boxes, iou_threshold=0.5, box_format="midpoint"
)

In [None]:
print(mean_avg_prec)

In [None]:
# model.eval()

In [None]:
# test_loader = DataLoader(
#     dataset=test_dataset,
#     batch_size=1,
#     num_workers=1,
#     shuffle=False,
#     drop_last=True,
# )
# pred_boxes, target_boxes = get_bboxes(
#     test_loader, model, iou_threshold=0.5, threshold=0.4
# )

# mean_avg_prec = mean_average_precision(
#     pred_boxes, target_boxes, iou_threshold=0.5, box_format="midpoint"
# )

In [None]:
saved_dir = '../trained_model/'
state = {
    'epoch': epochs,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
#     'scheduler' : scheduler.state_dict()
}

torch.save(state, saved_dir + 'epoch_' + str(90) +'_model.pt')

In [None]:
def convert_cellboxes(predictions, S=23):
    """
    Converts bounding boxes output from Yolo with
    an image split size of S into entire image ratios
    rather than relative to cell ratios. Tried to do this
    vectorized, but this resulted in quite difficult to read
    code... Use as a black box? Or implement a more intuitive,
    using 2 for loops iterating range(S) and convert them one
    by one, resulting in a slower but more readable implementation.
    """

    predictions = predictions.to("cpu")
    batch_size = predictions.shape[0]
    print('batch_size', batch_size)
    predictions = predictions.reshape(batch_size, S, S, 11)
    bboxes1 = predictions[..., 2:6]
    bboxes2 = predictions[..., 7:11]
#     print('box1 {} box2 {}'.format(bboxes1, bboxes2))
    scores = torch.cat(
        (predictions[..., 1].unsqueeze(0), predictions[..., 6].unsqueeze(0)), dim=0
    )
    print('scores', scores.shape)
    best_box = scores.argmax(0).unsqueeze(-1)
#     print(best_box[0][0])
    best_boxes = bboxes1 * (1 - best_box) + best_box * bboxes2
    cell_indices = torch.arange(S).repeat(batch_size, S, 1).unsqueeze(-1)
    x = 1 / S * (best_boxes[..., :1] + cell_indices)
    y = 1 / S * (best_boxes[..., 1:2] + cell_indices.permute(0, 2, 1, 3))
    w_y = 1 / S * best_boxes[..., 2:4]
    converted_bboxes = torch.cat((x, y, w_y), dim=-1)
    predicted_class = predictions[..., :1].argmax(-1).unsqueeze(-1)
    best_confidence = torch.max(predictions[..., 1], predictions[..., 6]).unsqueeze(
        -1
    )
    converted_preds = torch.cat(
        (predicted_class, best_confidence, converted_bboxes), dim=-1
    )

    return converted_preds


def cellboxes_to_boxes(out, S=23):
    print('out', out.shape)
    converted_pred = convert_cellboxes(out).reshape(out.shape[0], S * S, -1)
    print('converted_pred', converted_pred.shape)
    converted_pred[..., 0] = converted_pred[..., 0].long()
    all_bboxes = []

    for ex_idx in range(out.shape[0]):
        bboxes = []

        for bbox_idx in range(S * S):
            bboxes.append([x.item() for x in converted_pred[ex_idx, bbox_idx, :]])
        all_bboxes.append(bboxes)

    return all_bboxes


In [None]:
image, label, bboxes = next(iter(test_dataset))
image = image.unsqueeze(dim=0).to(device)
label = label.unsqueeze(dim=0)
print(image.shape)
model.eval()
with torch.no_grad():
    prediction = model(image)
    print('prediction',prediction.shape)
print('label', label.shape)    
true_bbox = cellboxes_to_boxes(label)
print('true_bbox', true_bbox[0][0])
bbox = cellboxes_to_boxes(prediction)
# print(bboxes)
print(len(true_bbox[0][0]))
print(len(bboxes))

