## Fine-tune sign detector network (in semi-supervised case)

In [None]:
import numpy as np
import pandas as pd
from PIL import Image
from ast import literal_eval
import os.path
from tqdm import tqdm
import copy

import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.utils.data as data

from torchvision import transforms as trafos
import torchvision.transforms as transforms

In [None]:
import matplotlib.pyplot as plt

In [None]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

# Location of the project root, relative to the current working directory.
relative_path = '../../'

# Resolve the project root to a canonical absolute path and put it first on
# the module search path, so that the project packages imported below are found.
parent_path = os.path.realpath(os.path.join(os.getcwd(), relative_path))
sys.path.insert(0, parent_path)

In [None]:
from lib.datasets.cunei_dataset_ssd import CuneiformSSD

from lib.alignment.LineFragment import plot_boxes
from lib.utils.pytorch_utils import get_tensorboard_writer

In [None]:
from lib.models.mobilenetv2_mod03 import MobileNetV2
from lib.models.mobilenetv2_fpn import MobileNetV2FPN
from lib.models.trained_model_loader import get_fpn_ssd_net
from lib.utils.torchcv.models.net import FPNSSD
from lib.utils.torchcv.loss.ssd_loss import SSDLoss

In [None]:
import time
# optional start delay in hours; with hh = 0.001 the loop below runs
# int(0.36) = 0 times, i.e. there is effectively no delay
hh = 0.001
## time.sleep(60*60*hh)
# wait in 10-second steps (6*60 steps per hour) so that tqdm can show the progress
for i in tqdm(range(int(6*60*hh))):
    time.sleep(10)

### Config Basics

In [None]:
# identifier of the fine-tuned model; used further down in the log comment
# and in the file names of the saved weights
model_version = 'v001ft01'

# config pretrained detector (version handed to get_fpn_ssd_net when the network is loaded)
pretrained_model_version = 'v001'  # 'v191'  

# config datasets for training and testing (collection names handed to CuneiformSSD)
train_collections = ['train_D'] 
test_collections =  ['testEXT']  # ['test_full']

In [None]:
# Generated data: master switch first, then where the generated data comes from.
with_gen_data = False
gen_file_path = None

# Version of the model that produced the generated data, and its result folder.
gen_model_version = 'v001'
gen_folder = 'results_ssd/{}/'.format(gen_model_version)

# Collections for which generated data is used (only read when with_gen_data is set).
gen_collections = [
    'saa01', 'saa05', 'saa08', 'saa10', 'saa13', 'saa16',
    'train',
]

In [None]:
# config backbone architecture (all three values are passed on to get_fpn_ssd_net)
arch_opt = 1
arch_type = 'mobile'
width_mult = 0.625

# config detector
with_64 = False  # passed on to get_fpn_ssd_net
create_bg_class = False  # passed on to RetinaBoxCoder
img_size = 512  # side length the transforms resize to; also the size of the dummy input
num_classes = 240  # passed to get_fpn_ssd_net and SSDLoss

# config schedule
num_epochs = 11 
# epochs at which MultiStepLR scales the learning rate; note that 60 lies beyond
# num_epochs, so no decay is triggered within this run
lr_milestones = [60]

In [None]:
# Compose the remark that ends up in the log name: training collections and
# pretrained version, plus the generator version when generated data is used.
version_remark = (
    '{}_fpnssd_mobilenetv2_{}_gen_{}'.format(
        "_".join(train_collections), pretrained_model_version, gen_model_version)
    if with_gen_data else
    '{}_fpnssd_mobilenetv2_{}'.format(
        "_".join(train_collections), pretrained_model_version)
)

### Preparing Datasets

In [None]:
if with_gen_data:
    from lib.utils.torchcv.box_coder_retina_lm import RetinaBoxCoder
    from lib.utils.torchcv.transforms_lm.resize import resize_lm
    from lib.utils.torchcv.transforms_lm.random_crop_tile import random_crop_tile_lm
    from lib.utils.torchcv.transforms_lm.pad_gs import pad_lm
else:
    from lib.utils.torchcv.box_coder_retina import RetinaBoxCoder
    from lib.utils.torchcv.transforms.resize import resize
    from lib.utils.torchcv.transforms.random_crop_tile import random_crop_tile
    from lib.utils.torchcv.transforms.pad_gs import pad

In [None]:
# anchor box coder; its encode() is called by the train/test transforms below
# to turn boxes and labels into the targets the loss consumes
box_coder = RetinaBoxCoder(create_bg_class=create_bg_class)
# report the number of anchors and the square roots of the anchor areas
# (presumably the anchor side lengths - confirm against RetinaBoxCoder)
print('num_anchors', len(box_coder.anchor_boxes))
print('anchor areas', np.sqrt(box_coder.anchor_areas))

In [None]:
# The training transform exists in two variants that run the same steps; the
# generated-data variant additionally carries a line map through every step.
if with_gen_data:
    def transform_train(img, boxes, labels, linemap):
        """Augment one training sample (with line map) and encode its targets.

        Steps: random choice between a color jitter and the identity, pad_lm
        with (600, 600), random crop, resize to (img_size, img_size), tensor
        conversion with a mean shift of 0.5, and finally target encoding via
        ``box_coder.encode`` (which also receives the line map).

        Returns:
            Tuple ``(img, boxes, labels, linemap)`` where ``img`` and
            ``linemap`` are tensors and ``boxes``/``labels`` are the encoded
            targets.
        """
        jitter_or_identity = transforms.RandomChoice([
            transforms.ColorJitter(0.5, 0.5, 0, 0),
            transforms.Lambda(lambda x: x),  # identity
        ])
        to_normalized_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[1.0]),
        ])

        img = jitter_or_identity(img)
        img, linemap = pad_lm(img, linemap, (600, 600))
        img, boxes, labels, linemap = random_crop_tile_lm(
            img, boxes, labels, linemap,
            scale_range=[0.65, 1], max_aspect_ratio=1.35)
        img, boxes, linemap = resize_lm(
            img, boxes, linemap,
            size=(img_size, img_size), random_interpolation=True)
        img = to_normalized_tensor(img)
        boxes, labels = box_coder.encode(boxes, labels, linemap)

        linemap_tensor = transforms.ToTensor()(linemap)
        return img, boxes, labels, linemap_tensor
else:
    def transform_train(img, boxes, labels):
        """Augment one training sample and encode its targets.

        Steps: random choice between a color jitter and the identity, pad with
        (600, 600), random crop, resize to (img_size, img_size), tensor
        conversion with a mean shift of 0.5, and finally target encoding via
        ``box_coder.encode``.

        Returns:
            Tuple ``(img, boxes, labels)`` where ``img`` is a tensor and
            ``boxes``/``labels`` are the encoded targets.
        """
        jitter_or_identity = transforms.RandomChoice([
            transforms.ColorJitter(0.5, 0.5, 0, 0),
            transforms.Lambda(lambda x: x),  # identity
        ])
        to_normalized_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[1.0]),
        ])

        img = jitter_or_identity(img)
        img = pad(img, (600, 600))
        img, boxes, labels = random_crop_tile(
            img, boxes, labels,
            scale_range=[0.65, 1], max_aspect_ratio=1.35)
        img, boxes = resize(
            img, boxes,
            size=(img_size, img_size), random_interpolation=True)
        img = to_normalized_tensor(img)
        boxes, labels = box_coder.encode(boxes, labels)
        return img, boxes, labels

In [None]:
# Training dataset: with generated data the dataset additionally gets the
# generated collections/folder and serves line maps; otherwise only the
# annotated collections are used.
trainset = (
    CuneiformSSD(collections=train_collections, transform=transform_train,
                 gen_file_path=gen_file_path, gen_collections=gen_collections, gen_folder=gen_folder,
                 relative_path=relative_path, use_balanced_idx=False, use_linemaps=True,
                 remove_empty_tiles=False, min_align_ratio=0.2)
    if with_gen_data else
    CuneiformSSD(collections=train_collections, transform=transform_train,
                 gen_file_path=gen_file_path, relative_path=relative_path, use_linemaps=False)
)

In [None]:
# Evaluation transform, again in a line-map and a plain variant. Compared with
# the training transform there is no color jitter and no padding, and the crop
# parameters are confined to a very narrow range.
if with_gen_data:
    def transform_test(img, boxes, labels, linemap):
        """Prepare one evaluation sample (with line map) and encode its targets.

        Returns:
            Tuple ``(img, boxes, labels, linemap)`` where ``img`` and
            ``linemap`` are tensors and ``boxes``/``labels`` are the encoded
            targets.
        """
        to_normalized_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[1.0]),
        ])

        img, boxes, labels, linemap = random_crop_tile_lm(
            img, boxes, labels, linemap,
            scale_range=[0.85, 0.86], max_aspect_ratio=1.001)
        img, boxes, linemap = resize_lm(
            img, boxes, linemap,
            size=(img_size, img_size), random_interpolation=True)
        img = to_normalized_tensor(img)
        boxes, labels = box_coder.encode(boxes, labels, linemap)

        linemap_tensor = transforms.ToTensor()(linemap)
        return img, boxes, labels, linemap_tensor
else:
    def transform_test(img, boxes, labels):
        """Prepare one evaluation sample and encode its targets.

        Returns:
            Tuple ``(img, boxes, labels)`` where ``img`` is a tensor and
            ``boxes``/``labels`` are the encoded targets.
        """
        to_normalized_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[1.0]),
        ])

        img, boxes, labels = random_crop_tile(
            img, boxes, labels,
            scale_range=[0.85, 0.86], max_aspect_ratio=1.001)
        img, boxes = resize(
            img, boxes,
            size=(img_size, img_size), random_interpolation=True)
        img = to_normalized_tensor(img)
        boxes, labels = box_coder.encode(boxes, labels)
        return img, boxes, labels

In [None]:
# Evaluation dataset: never uses generated data; line maps are only requested
# when the generated-data transforms (which expect them) are active.
testset = CuneiformSSD(
    collections=test_collections,
    transform=transform_test,
    gen_file_path=None,
    relative_path=relative_path,
    use_linemaps=bool(with_gen_data),
)

In [None]:
# batch both datasets (32 samples per batch, 3 worker processes each);
# only the training data is reshuffled
trainloader = data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=3)
testloader = data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=3)

### Building Model

In [None]:
# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# load FPN model from pretrained detector model
# (architecture options come from the config cells above; the result is expected on `device`)
fpnssd_net = get_fpn_ssd_net(pretrained_model_version, device, arch_type, with_64, arch_opt, width_mult, 
                             relative_path, num_classes, num_c=1)
# put the network into training mode for the fine-tuning below
fpnssd_net.train()

# print model
print(fpnssd_net)

In [None]:
### Test net
# Smoke-test the forward pass on one random single-channel image and print the
# shapes of the two outputs. The check runs in eval mode and without autograd:
# in training mode a forward pass on random noise would update any statistics
# that layers accumulate during training (e.g. batch-norm running averages, if
# the network contains such layers) and would build a graph nobody uses.
fpnssd_net.eval()
with torch.no_grad():
    loc_preds, cls_preds = fpnssd_net(torch.randn(1, 1, img_size, img_size).to(device))
print(loc_preds.size(), cls_preds.size())
# back to training mode, as set right after loading the network
fpnssd_net.train()

### Optimization

In [None]:
# loss used for both the training and the evaluation pass
criterion = SSDLoss(num_classes=num_classes)
#criterion = FocalLoss(num_classes=num_classes)
# SGD with momentum and weight decay over all parameters of the network
optimizer = optim.SGD(fpnssd_net.parameters(), lr=0.0001, momentum=0.9, weight_decay=1e-4)

# lr policy: scale the learning rate by gamma at each epoch listed in lr_milestones
# scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.97)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=0.1)

In [None]:
# Set up the tensorboard writer. The run comment always carries the model
# version; the version remark is appended only when it is non-empty.
comment_str = '_{}'.format(model_version)
if version_remark != '':
    comment_str = '_{}_{}'.format(model_version, version_remark)
writer = get_tensorboard_writer(
    comment=comment_str,
    logs_folder='{}results/run_logs/detector'.format(relative_path),
)

In [None]:
# Training bookkeeping: lowest evaluation loss seen so far and the epoch it
# occurred in (both updated by test()), plus a snapshot of the initial weights.
best_loss, best_epoch = float('inf'), 0
best_model_wts = copy.deepcopy(fpnssd_net.state_dict())


def train(epoch):
    """Run one training epoch over ``trainloader`` and log the mean batch loss.

    Works on the notebook globals ``fpnssd_net``, ``criterion``, ``optimizer``,
    ``scheduler``, ``trainloader``, ``writer`` and ``device``.

    Args:
        epoch: Index of the current epoch; used as the step of the logged scalar.
    """
    fpnssd_net.train()
    train_loss = 0
    num_batches = len(trainloader)

    for batch_idx, batch in enumerate(trainloader):
        if with_gen_data:
            # the dataset also delivers a line map, which the loss does not use
            inputs, loc_targets, cls_targets, _ = batch
        else:
            inputs, loc_targets, cls_targets = batch

        inputs = inputs.to(device)
        loc_targets = loc_targets.to(device)
        cls_targets = cls_targets.to(device)

        optimizer.zero_grad()
        loc_preds, cls_preds = fpnssd_net(inputs)
        loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        print('train_loss: %.3f | avg_loss: %.3f [%d/%d]'
              % (loss.item(), train_loss/(batch_idx+1), batch_idx+1, num_batches))

    # write to logger
    phase = 'train'
    writer.add_scalar('data/{}/loss'.format(phase), train_loss / num_batches, epoch)

    # Advance the learning-rate schedule only after this epoch's optimizer
    # steps. Since PyTorch 1.1.0, stepping the scheduler before the optimizer
    # skips the first value of the schedule, so every milestone would take
    # effect one epoch too early.
    scheduler.step()

def test(epoch):
    """Evaluate on ``testloader``, log the mean loss and keep the best weights.

    The mean batch loss is written to the logger under the ``test`` phase. If
    it is lower than the best loss recorded so far and ``epoch`` is greater
    than 5, the current weights are saved to the ``*_best.pth`` file and the
    globals ``best_loss`` / ``best_epoch`` are updated.

    Args:
        epoch: Index of the current epoch; used as the step of the logged scalar.
    """
    global best_loss, best_epoch

    fpnssd_net.eval()
    num_batches = len(testloader)
    loss_sum = 0
    with torch.no_grad():
        for step, batch in enumerate(testloader, start=1):
            if with_gen_data:
                # the line map is part of the batch but plays no role here
                inputs, loc_targets, cls_targets, _ = batch
            else:
                inputs, loc_targets, cls_targets = batch

            inputs = inputs.to(device)
            loc_targets = loc_targets.to(device)
            cls_targets = cls_targets.to(device)

            loc_preds, cls_preds = fpnssd_net(inputs)
            batch_loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets).item()
            loss_sum += batch_loss
            print('test_loss: %.3f | avg_loss: %.3f [%d/%d]'
                  % (batch_loss, loss_sum/step, step, num_batches))

    mean_loss = loss_sum / num_batches

    # write to logger
    phase = 'test'
    writer.add_scalar('data/{}/loss'.format(phase), mean_loss, epoch)

    # nothing to store unless this is a new best result after the first epochs
    if not (mean_loss < best_loss and epoch > 5):
        return

    weights_path = '{}results/weights/fpn_net_{}_best.pth'.format(relative_path, model_version)
    torch.save(fpnssd_net.state_dict(), weights_path)
    best_epoch = epoch
    best_loss = mean_loss

In [None]:
# Main loop: train every epoch, evaluate on every even-numbered epoch.
for epoch in tqdm(range(num_epochs)):
    print('\nEpoch: %d' % epoch)
    train(epoch)
    if epoch % 2 != 0:
        continue
    print('\nTest')
    test(epoch)

In [None]:
# Report the best evaluation loss with four decimals and the epoch it was
# reached in (remains inf / 0 if test() never stored a "best" checkpoint).
# The former spec '{:4f}' set a minimum width of 4, not a precision of 4.
print('Best val Loss: {:.4f} at {}'.format(best_loss, best_epoch))

In [None]:
# choose model filename (no '_best' suffix: this file holds the weights as they
# are after the last epoch, independent of the checkpoint written by test())
weights_path = '{}results/weights/fpn_net_{}.pth'.format(relative_path, model_version)
# Save only the model parameters
torch.save(fpnssd_net.state_dict(), weights_path)