In [None]:
import os
import sys
import time
import argparse
import numpy as np
import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn
import pandas as pd
from torch.utils.data import DataLoader

from models.efficientdet import EfficientDet
from models.losses import FocalLoss
from datasets import Spine_dataset, get_augumentation, detection_collate
from utils import EFFICIENTDET

In [None]:
# Training configuration: the values the rest of the notebook reads.
resume = None                  # checkpoint path to resume from, or None to start from scratch
network = 'efficientdet-d0'    # key into EFFICIENTDET; also used in checkpoint file names
num_epochs = 100
batch_size = 32
num_worker = 4                 # passed to DataLoader as num_workers
num_classes = 1                # replaced by the checkpoint's value when resuming
device = [0]                   # requested GPU ids, consumed by prepare_device
# The training loop reads `grad_accumulation_steps`. The original misspelled
# name is kept as an alias so any code that still uses it keeps working.
grad_accumulation_steps = 1
grad_accumalation_steps = grad_accumulation_steps
learning_rate = 1e-4
# NOTE(review): momentum, weight_decay and gamma are not referenced by any
# cell visible in this notebook - confirm whether they are still needed.
momentum = 0.9
weight_decay = 5e-4
gamma = 0.1
save_folder = 'weights/'
image_root = '../boostnet_labeldata/data/'
csv_root = '../boostnet_labeldata/labels/training/'

In [None]:
# Make sure the checkpoint directory exists. `exist_ok=True` removes the
# check-then-create race, and makedirs also creates missing parent directories.
os.makedirs(save_folder, exist_ok=True)

In [None]:
def prepare_device(device):
    """Choose the torch device to run on and the GPU ids that can actually be used.

    Args:
        device: sequence of requested GPU indices; an empty sequence selects the CPU.

    Returns:
        tuple: ``(torch_device, list_ids)`` where ``torch_device`` is
        ``cuda:<first usable id>`` or ``cpu``, and ``list_ids`` holds at most as
        many of the requested ids as there are GPUs on this machine (empty on a
        CPU-only machine), so callers can size DataParallel from it safely.
    """
    n_gpu_use = len(device)
    n_gpu = torch.cuda.device_count()
    if n_gpu_use > 0 and n_gpu == 0:
        print("Warning: There's no GPU available on this machine, training will be performed on CPU.")
        n_gpu_use = 0
    if n_gpu_use > n_gpu:
        print("Warning: The number of GPU's configured to use is {}, but only {} are available on this machine.".format(
            n_gpu_use, n_gpu))
        n_gpu_use = n_gpu
    # Only hand back the ids that survived the checks above; returning the full
    # request would make callers wrap the model in DataParallel with GPUs that
    # are not there.
    list_ids = list(device[:n_gpu_use])
    target = torch.device('cuda:{}'.format(list_ids[0]) if n_gpu_use > 0 else 'cpu')

    return target, list_ids

In [None]:
def get_state_dict(model):
    """Return the model's state dict, unwrapping a DataParallel container first.

    Args:
        model: a ``torch.nn.Module``, optionally wrapped in ``torch.nn.DataParallel``
            (or a subclass of it).

    Returns:
        The state dict of the underlying module, so the keys carry no
        ``module.`` prefix and can be loaded into an unwrapped model.
    """
    # isinstance (rather than an exact type comparison) also unwraps subclasses
    # of DataParallel.
    if isinstance(model, torch.nn.DataParallel):
        return model.module.state_dict()
    return model.state_dict()

In [None]:
# Optionally restore a previous run. `checkpoint` stays empty when not resuming.
checkpoint = {}
if resume is not None:
    resume_path = str(resume)
    print("Loading checkpoint: {} ...".format(resume_path))
    # Keep tensors on the CPU while loading; the model is moved to its device later.
    # NOTE(review): torch.load unpickles the file - only resume from trusted checkpoints.
    checkpoint = torch.load(
        resume, map_location=lambda storage, loc: storage)
    # The training loop stores the class count under 'num_class'. Reading
    # 'num_classes' alone raised KeyError on resume, so accept either spelling.
    for class_key in ('num_class', 'num_classes'):
        if class_key in checkpoint:
            num_classes = checkpoint[class_key]
            break
    else:
        raise KeyError("checkpoint has neither 'num_class' nor 'num_classes'")
    network = checkpoint['network']

In [None]:
# Load the annotation tables. os.path.join keeps this working whether or not
# `csv_root` ends with a path separator.
# landmarks.csv and filenames.csv are read without a header row; train.csv has one.
corner_df = pd.read_csv(os.path.join(csv_root, 'landmarks.csv'), header=None)
filename_df = pd.read_csv(os.path.join(csv_root, 'filenames.csv'), header=None)
boxes_df = pd.read_csv(os.path.join(csv_root, 'train.csv'))

In [None]:
# Drop every row belonging to an image that has at least one box with a
# negative xmin or ymin. The previous `boxes_df[boxes_df != img_id]` compared
# the whole frame element-wise, which only blanks matching cells to NaN and
# never removes a row.
negative_box_mask = (boxes_df['xmin'] < 0) | (boxes_df['ymin'] < 0)
invalid_image_ids = boxes_df.loc[negative_box_mask, 'image_id'].unique()
boxes_df = boxes_df[~boxes_df['image_id'].isin(invalid_image_ids)]

In [None]:
# Build the training dataset from the image directory and the three annotation
# tables, using the transform returned by get_augumentation('train').
train_dataset = Spine_dataset.SPINEDetection(
    image_root,
    boxes_df,
    corner_df,
    filename_df,
    transform=get_augumentation('train'),
)

In [None]:
# Shuffled batches of the training set, assembled by the project's
# detection_collate function and placed in pinned host memory.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_worker,
    shuffle=True,
    collate_fn=detection_collate,
    pin_memory=True,
)

In [None]:
# Sanity check: pull a single batch and print the shapes of its three tensors.
first_batch = next(iter(train_dataloader), None)
if first_batch is not None:
    idx = 0
    images, annotations, corners = first_batch
    print(idx, images.shape, annotations.shape, corners.shape)

In [None]:
# Look up the settings for the selected network once, then build the detector.
network_cfg = EFFICIENTDET[network]
model = EfficientDet(
    num_classes=num_classes,
    network=network,
    W_bifpn=network_cfg['W_bifpn'],
    D_bifpn=network_cfg['D_bifpn'],
    D_class=network_cfg['D_class'],
)

In [None]:
# Restore weights first (when resuming), while the model is still unwrapped.
if resume is not None:
    model.load_state_dict(checkpoint['state_dict'])

# prepare_device returns the torch.device and the GPU id list. Note that this
# rebinds `device` from the configured id list to a torch.device, so re-running
# this cell without re-running the configuration cell will fail.
device, device_ids = prepare_device(device)
model = model.to(device)
multi_gpu = len(device_ids) > 1
if multi_gpu:
    model = torch.nn.DataParallel(model, device_ids=device_ids)

# Loss, optimizer and a plateau-based learning-rate schedule.
criterion = FocalLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    patience=3,
    verbose=True,
)

In [None]:
# Training loop: runs `num_epochs` epochs, logs every 100 iterations, steps the
# LR scheduler on the epoch's mean loss, writes a checkpoint after each epoch
# and the final weights at the end.
model.train()
# The configuration cell defines this setting under the spelling
# `grad_accumalation_steps`; read that name so the loop does not raise NameError.
accumulation_steps = grad_accumalation_steps
# Computed up front so the final save below also works when num_epochs == 0.
arch = type(model).__name__


def _checkpoint_state():
    """Snapshot of the model metadata and weights written to every checkpoint file."""
    return {
        'arch': arch,
        # Key stays 'num_class' so previously written checkpoints keep the same layout.
        'num_class': num_classes,
        'network': network,
        'state_dict': get_state_dict(model)
    }


iteration = 1
for epoch in range(num_epochs):
    print("{} epoch: \t start training....".format(epoch))
    start = time.time()
    result = {}
    total_loss = []
    optimizer.zero_grad()
    for idx, (images, annotations_bboxes, annotations_corners) in enumerate(train_dataloader):
        images = images.to(device)
        annotations_bboxes = annotations_bboxes.to(device)
        annotations_corners = annotations_corners.to(device)
        classification, regression, corners, anchors = model(images)
        classification_loss, regression_loss, corner_loss = criterion(
            classification, regression, corners, anchors, annotations_bboxes, annotations_corners)
        classification_loss = classification_loss.mean()
        regression_loss = regression_loss.mean()
        corner_loss = corner_loss.mean()
        loss = classification_loss + regression_loss + corner_loss
        if bool(loss == 0):
            # Nothing to back-propagate for this batch; skip it entirely.
            print('loss equal zero(0)')
            continue
        loss.backward()
        # Gradients accumulate across batches and are applied every
        # `accumulation_steps` batches, after clipping the gradient norm to 0.1.
        if (idx + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()
            optimizer.zero_grad()

        total_loss.append(loss.item())
        if iteration % 100 == 0:
            print('{} iteration: training ...'.format(iteration))
            ans = {
                'epoch': epoch,
                'iteration': iteration,
                'cls_loss': classification_loss.item(),
                'reg_loss': regression_loss.item(),
                'creg_loss': corner_loss.item(),
                'mean_loss': np.mean(total_loss)
            }
            for key, value in ans.items():
                print('    {:15s}: {}'.format(str(key), value))

        iteration += 1

    # If every batch was skipped there is no loss to average; report NaN and
    # leave the scheduler alone instead of feeding it the mean of an empty list.
    if total_loss:
        epoch_loss = np.mean(total_loss)
        scheduler.step(epoch_loss)
    else:
        epoch_loss = float('nan')
    result = {
        'time': time.time() - start,
        'loss': epoch_loss
    }
    for key, value in result.items():
        print('    {:15s}: {}'.format(str(key), value))
    state = _checkpoint_state()
    # Write into the configured `save_folder` rather than a hard-coded directory.
    torch.save(
        state, os.path.join(save_folder, 'checkpoint_{}_{}.pth'.format(network, epoch)))
state = _checkpoint_state()
torch.save(state, os.path.join(save_folder, 'Final_{}.pth'.format(network)))
