In [1]:
import os
import sys
import time
import argparse
import numpy as np
import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn
import pandas as pd
from torch.utils.data import DataLoader
from models.efficientdet import EfficientDet
from models.losses import FocalLoss
from datasets import Spine_dataset, get_augumentation, detection_collate
from utils import EFFICIENTDET
from tqdm.notebook import tqdm as tqdm
from sklearn.metrics import average_precision_score

In [16]:
# --- Run / checkpoint settings ------------------------------------------
# Path of a checkpoint to resume from; None starts from the pretrained backbone.
resume = None
# Key into EFFICIENTDET that selects the BiFPN / head sizes for the model.
network = 'efficientdet-d3'
num_classes = 1
save_folder = 'weights/'

# --- Data locations -------------------------------------------------------
image_root = 'boostnet_labeldata/data/'
csv_root = 'boostnet_labeldata/labels/'

# --- Loader / hardware settings -----------------------------------------
batch_size = 1
num_worker = 4          # forwarded to DataLoader(num_workers=...)
device = [0]            # list of GPU indices handed to prepare_device()

# --- Optimisation settings ----------------------------------------------
num_epochs = 50
grad_accumulation_steps = 1
learning_rate = 1e-4
# NOTE(review): the three values below are not referenced by the cells shown
# here (the optimizer is Adam created with lr only) -- confirm before removing.
momentum = 0.9
weight_decay = 5e-4
gamma = 0.1

In [17]:
# Create the checkpoint directory if needed. makedirs(exist_ok=True) avoids the
# check-then-create race of exists()+mkdir() and also handles nested paths.
os.makedirs(save_folder, exist_ok=True)

In [18]:
def prepare_device(device):
    """Resolve the requested GPU ids against the GPUs actually present.

    Args:
        device: sequence of CUDA device indices the caller would like to use
            (may be empty to force CPU).

    Returns:
        A ``(torch.device, list_ids)`` pair. ``torch.device`` is the first
        usable requested GPU, or ``cpu`` when none can be used. ``list_ids``
        holds the requested ids trimmed to the number of GPUs available, so it
        is empty on a CPU-only machine and never longer than
        ``torch.cuda.device_count()``.
    """
    requested_ids = list(device)
    n_gpu_use = len(requested_ids)
    n_gpu = torch.cuda.device_count()
    if n_gpu_use > 0 and n_gpu == 0:
        print("Warning: There's no GPU available on this machine, training will be performed on CPU.")
        n_gpu_use = 0
    if n_gpu_use > n_gpu:
        print("Warning: The number of GPU's configured to use is {}, but only {} are available on this machine.".format(
            n_gpu_use, n_gpu))
        n_gpu_use = n_gpu
    # Trim the id list to what can really be used; previously the untrimmed
    # request was returned, so callers could still try to build DataParallel
    # over GPUs that do not exist.
    list_ids = requested_ids[:n_gpu_use]
    primary = torch.device(
        'cuda:{}'.format(list_ids[0]) if list_ids else 'cpu')

    return primary, list_ids

In [19]:
def get_state_dict(model):
    """Return the model's state dict without any parallel-wrapper prefix.

    When ``model`` is wrapped in ``DataParallel`` / ``DistributedDataParallel``
    (or a subclass of either) the state dict of the wrapped ``model.module`` is
    returned, so the saved keys do not carry the ``module.`` prefix and can be
    loaded into an unwrapped model.
    """
    parallel_wrappers = (
        torch.nn.DataParallel,
        torch.nn.parallel.DistributedDataParallel,
    )
    # isinstance (rather than an exact type comparison) also unwraps subclasses.
    if isinstance(model, parallel_wrappers):
        return model.module.state_dict()
    return model.state_dict()

In [20]:
# Optionally restore run metadata from a checkpoint written by the training
# cell. `checkpoint` stays an empty list when no resume path is configured.
checkpoint = []
if resume is not None:
    resume_path = str(resume)
    print("Loading checkpoint: {} ...".format(resume_path))
    # Load onto CPU; the model is moved to the target device later.
    checkpoint = torch.load(resume_path, map_location='cpu')
    # The training cell stores the class count under 'num_class'; reading
    # 'num_classes' here raised KeyError on every resume. Accept both spellings.
    if 'num_class' in checkpoint:
        num_classes = checkpoint['num_class']
    else:
        num_classes = checkpoint['num_classes']
    network = checkpoint['network']

In [21]:
# Training split label files. The landmark and filename CSVs have no header row.
corner_df_train = pd.read_csv(
    os.path.join(csv_root, 'training/landmarks.csv'), header=None)
filename_df_train = pd.read_csv(
    os.path.join(csv_root, 'training/filenames.csv'), header=None)
boxes_df_train = pd.read_csv(os.path.join(csv_root, 'training/train.csv'))
# Put every box in class 0 (the config uses num_classes = 1).
# Item assignment is used instead of `df.label = 0`: attribute assignment only
# updates an existing column and would silently set a plain attribute otherwise.
boxes_df_train['label'] = 0

In [22]:
# Test split label files. The landmark and filename CSVs have no header row.
corner_df_test = pd.read_csv(
    os.path.join(csv_root, 'test/landmarks.csv'), header=None)
filename_df_test = pd.read_csv(
    os.path.join(csv_root, 'test/filenames.csv'), header=None)
boxes_df_test = pd.read_csv(os.path.join(csv_root, 'test/test.csv'))
# Put every box in class 0 (the config uses num_classes = 1).
# Item assignment is used instead of `df.label = 0`: attribute assignment only
# updates an existing column and would silently set a plain attribute otherwise.
boxes_df_test['label'] = 0

In [23]:
# Training dataset: box annotations + file names, using the 'train' augmentations.
train_dataset = Spine_dataset.SPINEDetection(
    image_root,
    boxes_df_train,
    filename_df_train,
    transform=get_augumentation('train'),
)

In [24]:
# Test dataset: box annotations + file names, using the 'test' augmentations.
test_dataset = Spine_dataset.SPINEDetection(
    image_root,
    boxes_df_test,
    filename_df_test,
    transform=get_augumentation('test'),
    image_set='test',
)

In [25]:
# Shuffled loader over the training set; detection_collate builds each batch.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_worker,
    shuffle=True,
    collate_fn=detection_collate,
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only
    # machines (which prepare_device explicitly supports).
    pin_memory=torch.cuda.is_available(),
)

In [26]:
# Fixed-order loader over the test set; detection_collate builds each batch.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    num_workers=num_worker,
    shuffle=False,
    collate_fn=detection_collate,
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only
    # machines (which prepare_device explicitly supports).
    pin_memory=torch.cuda.is_available(),
)

In [27]:
# Sanity check: pull a single batch and show the tensor shapes it produces.
first_batch = next(iter(train_dataloader), None)
if first_batch is not None:
    idx = 0
    images, annotations = first_batch
    print(idx, images.shape, annotations.shape)

0 torch.Size([1, 3, 1408, 768]) torch.Size([1, 17, 5])


In [28]:
# Build the detector; the BiFPN width/depth and class-head depth come from the
# EFFICIENTDET entry for the selected network.
model = EfficientDet(
    num_classes=num_classes,
    network=network,
    **{key: EFFICIENTDET[network][key]
       for key in ('W_bifpn', 'D_bifpn', 'D_class')}
)

Loaded pretrained weights for efficientnet-b3


In [29]:
# Restore weights (when resuming) into the bare model. On a re-run of this cell
# the model may already be wrapped in DataParallel, so unwrap it first.
if resume is not None:
    bare_model = model.module if isinstance(
        model, torch.nn.DataParallel) else model
    bare_model.load_state_dict(checkpoint['state_dict'])

# `device` starts out as the configured list of GPU ids and is rebound to a
# torch.device below. When the cell is executed again, fall back to the ids
# resolved the first time instead of calling len() on a torch.device.
requested_ids = device if isinstance(device, (list, tuple)) else device_ids
device, device_ids = prepare_device(requested_ids)
model = model.to(device)
# Only wrap once, and only when more than one GPU is in use.
if len(device_ids) > 1 and not isinstance(model, torch.nn.DataParallel):
    model = torch.nn.DataParallel(model, device_ids=device_ids)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Reduce the learning rate when the monitored value (the validation loss passed
# to scheduler.step in the training cell) stops improving for 3 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, patience=3, verbose=True)
criterion = FocalLoss()

In [None]:
# Weight of the classification term in the total loss, and the max gradient
# norm used for clipping (both were inline literals).
cls_loss_weight = 0.1
grad_clip_norm = 0.1


def _mean_or_nan(values):
    """Return the mean of ``values`` as a float, or NaN when the list is empty."""
    return float(np.mean(values)) if values else float('nan')


def _apply_optimizer_step():
    """Clip the accumulated gradients, take one optimizer step, reset gradients."""
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
    optimizer.step()
    optimizer.zero_grad()


def _checkpoint_state(arch):
    """Build the dict that is written to every checkpoint file."""
    return {
        'arch': arch,
        'num_class': num_classes,
        'network': network,
        'state_dict': get_state_dict(model)
    }


def _run_epoch(dataloader, training):
    """Run one pass over ``dataloader`` and return its averaged losses.

    Args:
        dataloader: loader yielding ``(images, annotations_bboxes)`` batches.
        training: when True, back-propagate and update the model every
            ``grad_accumulation_steps`` batches; when False only compute losses
            (the caller is expected to wrap the call in ``torch.no_grad()``).

    Returns:
        dict with keys ``time``, ``loss``, ``bbox_loss`` and ``cls_loss``; the
        three losses are means over the batches that had a non-zero loss.
    """
    start = time.time()
    total_loss = []
    bbox_losses = []
    cls_losses = []
    # Number of backward() calls whose gradients have not been applied yet.
    pending_backward = 0
    if training:
        optimizer.zero_grad()
    tk0 = tqdm(dataloader, total=len(dataloader))
    for idx, (images, annotations_bboxes) in enumerate(tk0):
        images = images.to(device)
        annotations_bboxes = annotations_bboxes.to(device)
        classification, regression, anchors = model(images)
        classification_loss, regression_loss = criterion(
            classification, regression, anchors, annotations_bboxes)
        classification_loss = classification_loss.mean()
        regression_loss = regression_loss.mean()
        loss = cls_loss_weight * classification_loss + regression_loss
        # Batches whose total loss is exactly zero are skipped entirely and do
        # not contribute to the reported means.
        if bool(loss == 0):
            print('loss equal zero(0)')
            continue
        if training:
            # NOTE(review): the loss is not divided by grad_accumulation_steps,
            # so gradients are summed (not averaged) across accumulated batches.
            loss.backward()
            pending_backward += 1
            if (idx + 1) % grad_accumulation_steps == 0:
                _apply_optimizer_step()
                pending_backward = 0
        total_loss.append(loss.item())
        bbox_losses.append(regression_loss.item())
        cls_losses.append(classification_loss.item())
        tk0.set_postfix(loss=(np.mean(total_loss)))
    # Apply gradients left over from an incomplete accumulation window; they
    # used to be discarded by the zero_grad() at the start of the next epoch.
    if training and pending_backward:
        _apply_optimizer_step()
    result = {
        'time': time.time() - start,
        'loss': _mean_or_nan(total_loss),
        'bbox_loss': _mean_or_nan(bbox_losses),
        'cls_loss': _mean_or_nan(cls_losses)
    }
    for key, value in result.items():
        print('    {:15s}: {}'.format(str(key), value))
    return result


# NOTE(review): the model is put in train mode once and never switched to
# eval mode, so the validation pass below runs with any BatchNorm/Dropout
# layers behaving as in training. model.eval() is probably wanted there, but
# EfficientDet.forward is not visible from this notebook -- confirm it still
# returns (classification, regression, anchors) in eval mode before changing.
# NOTE(review): in the recorded run the logged cls_loss sits at ~2.3021 from
# epoch 1 onwards for both train and validation; worth checking whether the
# classification head is learning at all.
model.train()
# Resolved up front so the final save also works when num_epochs is 0.
arch = type(model).__name__
# Per-epoch loss history, rewritten to CSV after every epoch.
df = pd.DataFrame(np.zeros((num_epochs, 4)), columns=[
                  "train_cls", "train_bbox_loss", "val_cls", "val_bbox_loss"])
for epoch in range(num_epochs):
    print("{} epoch: \t start training....".format(epoch))

    train_result = _run_epoch(train_dataloader, training=True)
    df.iloc[epoch, :2] = [train_result['cls_loss'], train_result['bbox_loss']]
    torch.cuda.empty_cache()

    with torch.no_grad():
        val_result = _run_epoch(test_dataloader, training=False)
    # ReduceLROnPlateau monitors the mean validation loss.
    scheduler.step(val_result['loss'])
    df.iloc[epoch, 2:] = [val_result['cls_loss'], val_result['bbox_loss']]
    df.to_csv('d3-weight-0.1.csv')
    torch.cuda.empty_cache()

    # Write into the configured save_folder rather than a hard-coded './weights/'.
    state = _checkpoint_state(arch)
    torch.save(state, os.path.join(
        save_folder, 'checkpoint_{}_{}.pth'.format(network, epoch)))
state = _checkpoint_state(arch)
torch.save(state, os.path.join(save_folder, 'Final_{}.pth'.format(network)))

0 epoch: 	 start training....


HBox(children=(FloatProgress(value=0.0, max=481.0), HTML(value='')))


    time           : 492.515177488327
    loss           : 1.6776648554137739
    bbox_loss      : 0.9583265475324683
    cls_loss       : 7.193383050807548


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


    time           : 41.032285928726196
    loss           : 1.143049256876111
    bbox_loss      : 0.912834987975657
    cls_loss       : 2.3021426498889923
1 epoch: 	 start training....


HBox(children=(FloatProgress(value=0.0, max=481.0), HTML(value='')))


    time           : 492.28648352622986
    loss           : 1.041887520752429
    bbox_loss      : 0.8116745372571965
    cls_loss       : 2.302129803477107


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


    time           : 40.87297224998474
    loss           : 0.9805522337555885
    bbox_loss      : 0.7503394465893507
    cls_loss       : 2.3021277841180563
2 epoch: 	 start training....


HBox(children=(FloatProgress(value=0.0, max=481.0), HTML(value='')))


    time           : 492.15407156944275
    loss           : 0.9120997435833461
    bbox_loss      : 0.6818872206300312
    cls_loss       : 2.3021251937207956


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


    time           : 41.05888104438782
    loss           : 0.8530002990737557
    bbox_loss      : 0.6227878357749432
    cls_loss       : 2.302124598994851
3 epoch: 	 start training....


HBox(children=(FloatProgress(value=0.0, max=481.0), HTML(value='')))


    time           : 492.0605158805847
    loss           : 0.8182770199686474
    bbox_loss      : 0.5880645425552637
    cls_loss       : 2.302124748111019


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


    time           : 41.10697960853577
    loss           : 0.7949369316920638
    bbox_loss      : 0.5647244756110013
    cls_loss       : 2.3021244946867228
4 epoch: 	 start training....


HBox(children=(FloatProgress(value=0.0, max=481.0), HTML(value='')))


    time           : 491.30279088020325
    loss           : 0.7500151549456273
    bbox_loss      : 0.5198027000853525
    cls_loss       : 2.302124508205422


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


    time           : 40.92837572097778
    loss           : 0.7380597926676273
    bbox_loss      : 0.5078473414760083
    cls_loss       : 2.302124463021755
5 epoch: 	 start training....


HBox(children=(FloatProgress(value=0.0, max=481.0), HTML(value='')))


    time           : 492.3437087535858
    loss           : 0.710500317403036
    bbox_loss      : 0.4802878636580247
    cls_loss       : 2.302124504735713


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


    time           : 41.02900958061218
    loss           : 0.7570001562125981
    bbox_loss      : 0.5267877038568258
    cls_loss       : 2.3021244574338198
6 epoch: 	 start training....


HBox(children=(FloatProgress(value=0.0, max=481.0), HTML(value='')))


    time           : 492.334431886673
    loss           : 0.6733532873235968
    bbox_loss      : 0.4431408162919994
    cls_loss       : 2.302124691108656


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


    time           : 40.823899030685425
    loss           : 0.7148118037730455
    bbox_loss      : 0.4845993535127491
    cls_loss       : 2.3021244648844004
7 epoch: 	 start training....


HBox(children=(FloatProgress(value=0.0, max=481.0), HTML(value='')))


    time           : 490.5356729030609
    loss           : 0.6507780816475716
    bbox_loss      : 0.420565626291624
    cls_loss       : 2.302124511179458


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


    time           : 40.93932867050171
    loss           : 0.6708507174625993
    bbox_loss      : 0.4406382627785206
    cls_loss       : 2.302124524489045
8 epoch: 	 start training....


HBox(children=(FloatProgress(value=0.0, max=481.0), HTML(value='')))


    time           : 491.2181808948517
    loss           : 0.630984740296917
    bbox_loss      : 0.400772288100883
    cls_loss       : 2.3021244868914947


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


    time           : 41.0474956035614
    loss           : 0.6802959553897381
    bbox_loss      : 0.45008350838907063
    cls_loss       : 2.3021244294941425
9 epoch: 	 start training....


HBox(children=(FloatProgress(value=0.0, max=481.0), HTML(value='')))


    time           : 489.7923629283905
    loss           : 0.6095069665423054
    bbox_loss      : 0.3792945178159805
    cls_loss       : 2.3021244695429495


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


    time           : 490.9729423522949
    loss           : 0.5933412700086027
    bbox_loss      : 0.3631288212512982
    cls_loss       : 2.302124460125168


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


    time           : 40.83329796791077
    loss           : 0.6441432021092623
    bbox_loss      : 0.41393075638916343
    cls_loss       : 2.3021244294941425
11 epoch: 	 start training....


HBox(children=(FloatProgress(value=0.0, max=481.0), HTML(value='')))


    time           : 489.69088315963745
    loss           : 0.5842077744477999
    bbox_loss      : 0.3539953262481273
    cls_loss       : 2.3021244482290224


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


    time           : 40.86313223838806
    loss           : 0.6405955841764808
    bbox_loss      : 0.41038313461467624
    cls_loss       : 2.3021244443953037
12 epoch: 	 start training....


HBox(children=(FloatProgress(value=0.0, max=481.0), HTML(value='')))


    time           : 490.6182291507721
    loss           : 0.5649061481818836
    bbox_loss      : 0.3346936988669473
    cls_loss       : 2.302124459133822


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


    time           : 40.94229054450989
    loss           : 0.6263770745135844
    bbox_loss      : 0.3961646326351911
    cls_loss       : 2.3021244164556265
13 epoch: 	 start training....


HBox(children=(FloatProgress(value=0.0, max=481.0), HTML(value='')))


    time           : 489.68210458755493
    loss           : 0.5486440362404885
    bbox_loss      : 0.3184315882266931
    cls_loss       : 2.3021244477333496


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


    time           : 40.92071986198425
    loss           : 0.6822777753695846
    bbox_loss      : 0.4520653309300542
    cls_loss       : 2.3021244294941425
14 epoch: 	 start training....


HBox(children=(FloatProgress(value=0.0, max=481.0), HTML(value='')))


    time           : 490.3424382209778
    loss           : 0.5415944441564365
    bbox_loss      : 0.3113819970720275
    cls_loss       : 2.3021244422809497


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


    time           : 40.99216628074646
    loss           : 0.6158272987231612
    bbox_loss      : 0.3856148550985381
    cls_loss       : 2.302124412730336
15 epoch: 	 start training....


HBox(children=(FloatProgress(value=0.0, max=481.0), HTML(value='')))


    time           : 489.75138235092163
    loss           : 0.529839815753909
    bbox_loss      : 0.2996273669346454
    cls_loss       : 2.3021244497160405


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


    time           : 40.969804763793945
    loss           : 0.6182255803141743
    bbox_loss      : 0.38801313645672053
    cls_loss       : 2.3021244183182716
16 epoch: 	 start training....


HBox(children=(FloatProgress(value=0.0, max=481.0), HTML(value='')))


    time           : 489.53794288635254
    loss           : 0.5068284046377313
    bbox_loss      : 0.2766159584827086
    cls_loss       : 2.3021244393069136


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


    time           : 40.962238073349
    loss           : 0.5957900711800903
    bbox_loss      : 0.365577622782439
    cls_loss       : 2.302124420180917
18 epoch: 	 start training....


HBox(children=(FloatProgress(value=0.0, max=481.0), HTML(value='')))


    time           : 491.02102160453796
    loss           : 0.4974411737150561
    bbox_loss      : 0.26722872511264933
    cls_loss       : 2.3021244422809497


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


    time           : 41.00560927391052
    loss           : 0.6525176470167935
    bbox_loss      : 0.42230519896838814
    cls_loss       : 2.302124420180917
19 epoch: 	 start training....


HBox(children=(FloatProgress(value=0.0, max=481.0), HTML(value='')))