In [2]:
import os
from datetime import datetime
import traceback

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
import torchvision
import numpy as np

from tqdm import tqdm
from tensorboardX import SummaryWriter

from dataset import voc
from retinanet import model, val
from retinanet import transforms as aug


In [3]:
# consts
tag = 'debug'
split_name = 'voc-1'
# dataset root; set the VOC_ROOT environment variable to avoid editing this
# machine-specific default
root_path = os.environ.get('VOC_ROOT', '/home/voyager/data/root/voc/')

device_name = 'cpu'
batch_size = 2
depth = 50  # backbone depth; the model cell accepts 34, 50, 101 or 152

epochs = 500
lr = 1e-5
patience = 3  # ReduceLROnPlateau patience, in epochs (scheduler steps once per epoch)

image_size = 512
num_classes = 2
num_workers = 8

# info and deps
now = datetime.now()

# one run directory per (tag, depth, split, start time)
result_path = './result/{}_{}_{}_{}'.format(
    tag,
    depth,
    split_name,
    now.strftime('%Y%m%d_%H%M%S')
)

# create the run directory (and ./result) explicitly: checkpoints are saved
# here by the training loop, so don't rely on SummaryWriter creating it
os.makedirs(result_path, exist_ok=True)

summary_writer = SummaryWriter(result_path)


In [4]:
# data loader -- training split

# Training-time augmentation, applied in list order and ending with ToTensor.
train_trans = aug.Compose([
    aug.RandomCropAndPad(),
    aug.Pad(),  # pad to square image
    aug.Resize(image_size, image_size),
    aug.RandomFlipLeftRight(0.5),
    aug.RandomFlipUpDown(0.5),
    aug.RandomRotate(5),
    aug.RandomTranslatePc(50, 50),
    aug.AutoLevel(min_level_rate=1, max_level_rate=1),
    aug.AutoContrast(),
    aug.RandomContrast(0.5),
    aug.Contrast(1.25),
    # three salt-and-pepper settings wrapped in RandomChoice
    aug.RandomChoice([
        aug.RandomSaltPepperNoise(noise_param, 0.5)
        for noise_param in (0.9, 0.95, 0.99)
    ]),
    aug.ToTensor(),
    # normalization with the pretrained model's mean/std is disabled:
    # aug.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])

# VOC-style dataset restricted to the "<split_name>_train" image set
train_set = voc.VOCDetection(
    root_path,
    image_set='{}_train'.format(split_name),
    transforms=train_trans,
)

# shuffled batches; voc.collate builds each batch from the samples
train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=voc.collate,
)


In [5]:
# model, loss and optimizer

device = torch.device(device_name)

# Backbone depths with a matching retinanet.model.resnet<depth> constructor.
# Fail fast on anything else instead of leaving `net` undefined.
supported_depths = (34, 50, 101, 152)
if depth not in supported_depths:
    raise ValueError('unsupported depth {}; expected one of {}'.format(
        depth,
        supported_depths
    ))

net = getattr(model, 'resnet{}'.format(depth))(num_classes, pretrained=True)
net = net.to(device)
# net = torch.nn.DataParallel(net).to(device)

optimizer = optim.Adam(net.parameters(), lr=lr)
# lowers the lr when the value passed to scheduler.step() (the per-epoch mean
# loss in the training loop) stops improving for `patience` steps
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    patience=patience,
    verbose=True
)

# nn.Module.train() sets training mode on the module and all submodules;
# freeze_bn is defined in retinanet.model (presumably keeps BatchNorm frozen)
net.train()
net.freeze_bn()


In [None]:
# train-val loop
# NOTE(review): only training runs in this cell; `val` is imported but unused here.


def _to_focal_annotations(targets, device):
    """Build the padded annotation tensor passed to the RetinaNet forward pass.

    Args:
        targets: per-image targets from the loader; ``target[0]`` holds the
            labels and ``target[1]`` the boxes. Assumes boxes are (n, 4) and
            labels are (n,) -- TODO confirm against voc.VOCDetection/voc.collate.
        device: device the result is moved to, so it matches the images.

    Returns:
        float32 tensor of shape (batch, max_n, 5) whose rows are
        [x1, y1, x2, y2, cls]. Images with fewer than max_n boxes are padded
        with [0, 0, 0, 0, -1] rows (presumably ignored by the loss because
        cls == -1 -- TODO confirm in retinanet.model).
    """
    annos = []
    for target in targets:
        # keep coordinates as float: casting to an integer type would
        # truncate fractional box coordinates
        labels = target[0].to(dtype=torch.float32)
        bboxes = target[1].to(dtype=torch.float32)
        annos.append(torch.cat((bboxes, labels.unsqueeze(1)), dim=1))

    max_count = max((anno.shape[0] for anno in annos), default=0)
    dummy_anno = torch.tensor([[0, 0, 0, 0, -1]], dtype=torch.float32)
    padded = [
        torch.cat((anno, dummy_anno.repeat(max_count - anno.shape[0], 1)), dim=0)
        for anno in annos
    ]
    return torch.stack(padded).to(device)


for epoch in range(epochs):
    print('training epoch {}:'.format(epoch))

    # train mode, then freeze_bn (from retinanet.model) on top of it
    net.train()
    net.freeze_bn()

    epoch_loss = []

    with tqdm(total=len(train_loader)) as pbar:
        for data in train_loader:
            try:
                optimizer.zero_grad()

                # data[0]: image batch; data[1]: per-image (labels, bboxes)
                annotations = _to_focal_annotations(data[1], device)

                # forward: in training mode the net returns the two loss terms
                classification_loss, regression_loss = net([
                    data[0].to(device),
                    annotations
                ])

                loss = classification_loss.mean() + regression_loss.mean()
                loss_value = loss.item()
                epoch_loss.append(loss_value)

                # a zero loss has nothing to backpropagate, but the batch is
                # still counted on the progress bar
                if loss_value != 0:
                    loss.backward()
                    # clip the total gradient norm to 0.1
                    torch.nn.utils.clip_grad_norm_(net.parameters(), 0.1)
                    optimizer.step()

                pbar.update(1)
            except Exception:
                # log the failure and abandon the rest of this epoch
                traceback.print_exc()
                break

    if not epoch_loss:
        # no batch finished: skip the bookkeeping so a NaN mean is never fed to
        # the scheduler, written to TensorBoard, or used in a checkpoint name
        print('epoch {} produced no loss values; skipping scheduler/checkpoint'.format(epoch))
        continue

    # epoch-wise work and record
    mean_loss = np.mean(epoch_loss)
    print('epoch avg loss: {}'.format(mean_loss))

    scheduler.step(mean_loss)

    # save checkpoint (pickles the whole module, so loading needs the same
    # retinanet code importable)
    torch.save(net, os.path.join(result_path, '{:0>3}_{:1.4f}.pth'.format(
        epoch,
        mean_loss
    )))

    # write summary for tensorboardX
    summary_writer.add_scalar(
        '/train/loss',
        mean_loss,
        epoch
    )
    

  0%|          | 0/634 [00:00<?, ?it/s]

training epoch 0:


  1%|          | 4/634 [00:23<1:02:49,  5.98s/it]