In [1]:
import os
import numpy as np

import torch
import torch.optim as optim

from datasets.datagen import load_data
from models.retinanet import RetinaNet
from models.fpn import FPN50, FPN101
from models.focal_loss import FocalLoss

from PIL import Image, ImageDraw

# Model init

In [2]:
OPTIMIZER = 'sgd'
OPTIM_BASE_LR = 0.001
OPTIM_MOMENTUM = 0.9
OPTIM_ALPHA = 0.5
OPTIM_EPS = 1e-8
OPTIM_WEIGHT_DECAY = 1e-4
OPTIM_BETA = (0.9, 0.999)

In [3]:
def create(conv_body, num_classes):
    networks_map = {
        'ResNet50_FPN': FPN50,
        'ResNet101_FPN': FPN101
    }

    model = RetinaNet(networks_map[conv_body], num_classes)
    return model

def config_optimizer(param):
    print(f"using {OPTIMIZER}: base_learning_rate = {OPTIM_BASE_LR}, momentum = {OPTIM_MOMENTUM}, weight_decay = {OPTIM_WEIGHT_DECAY}")
    if OPTIMIZER == 'sgd':
        optimizer = optim.SGD(param, lr=OPTIM_BASE_LR, momentum=OPTIM_MOMENTUM, weight_decay=OPTIM_WEIGHT_DECAY)
    elif OPTIMIZER == 'rmsprop':
        optimizer = optim.RMSprop(param, lr=OPTIM_BASE_LR, momentum=OPTIM_MOMENTUM, alpha=OPTIM_ALPHA, eps=OPTIM_EPS, weight_decay=OPTIM_WEIGHT_DECAY)
    elif OPTIMIZER == 'adam':
        optimizer = optim.Adam(param, lr=OPTIM_BASE_LR, betas=OPTIM_BETA, eps=OPTIM_EPS, weight_decay=OPTIM_WEIGHT_DECAY)
    else:
        AssertionError('optimizer can not be recognized.')
    return optimizer

# Train

In [4]:
MODEL_CONV_BODY = 'ResNet50_FPN'
MODEL_NUM_CLASSES = 1
MODEL_CHECKPOINT_DIR = 'D:/ForME/3_Data/shrimp/checkpoint'

TRAIN_AUTO_RESUME = False
TRAIN_RESUME_FILE = ''
TRAIN_DATASET = ('D:/ForME/3_Data/shrimp/train', (600, 600), 1) # dir_path, image_size(800, ), batch_size(8)
TRAIN_MAX_ITER = 100

VALID_DATASET = ('D:/ForME/3_Data/shrimp/val', (600, 600), 1)

In [5]:
class EarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint', loss=None):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = loss
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
    
    def __call__(self, val_loss, model, epoch):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, epoch):
        if self.verbose:
            print(f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...") 
        state = {
            'net': model.state_dict(),
            'loss': self.best_score,
            'epoch': epoch
        }
        torch.save(state, os.path.join(self.path, f"0_fpn50_b1_600_{epoch}_{val_loss:.3f}.pkl"))
        self.val_loss_min = val_loss

In [6]:
def create_train_model():
    start_iter = 0
    min_loss = None
    if os.path.exists(os.path.join(MODEL_CHECKPOINT_DIR, f"FPN50.pkl")):
        if TRAIN_AUTO_RESUME:
            checkpoints = torch.load(os.path.join(MODEL_CHECKPOINT_DIR, TRAIN_RESUME_FILE))
            start_iter = checkpoints['epoch']
            if start_iter > 0:
                weights = checkpoints['net']
                min_loss = checkpoints['loss']
    else:
        weights = None
    model = create(MODEL_CONV_BODY, MODEL_NUM_CLASSES)
    return model, weights, start_iter, min_loss

def setup_train_model(model, weights, train=False):
    if weights is not None:
        model.load_state_dict(weights)
    else:
        # init_weight(model)
        print('weight init')
    if not torch.cuda.is_available():
        raise print(f"You could use GPU for train model")
    model.cuda()

In [7]:
def train_model():
    model, weight, start_iter, min_loss = create_train_model()
    setup_train_model(model, weight, train=True)

    trainloader = load_data(TRAIN_DATASET)
    validloader = load_data(VALID_DATASET)
    optimizer = config_optimizer(param=model.parameters())
    criterion = FocalLoss(num_classes=MODEL_NUM_CLASSES)

    train_losses = []
    valid_losses = []
    avg_train_losses = []
    avg_valid_losses = []

    early_stopping = EarlyStopping(patience=10, verbose=True, path=MODEL_CHECKPOINT_DIR, loss=min_loss)

    for cur_iter in range(start_iter, TRAIN_MAX_ITER):
        print(f"Epoch: {cur_iter}")

        ###################
        # train the model #
        ###################
        model.train()
        model.freeze_bn()
        for batch_idx, (inputs, loc_targets, cls_targets) in enumerate(trainloader):
            inputs = torch.autograd.Variable(inputs.cuda())
            loc_targets = torch.autograd.Variable(loc_targets.cuda())
            cls_targets = torch.autograd.Variable(cls_targets.cuda())

            # print(loc_targets.shape)
            optimizer.zero_grad()
            loc_preds, cls_preds = model(inputs)
            # print(loc_preds.shape)
            loc_loss, cls_loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets)
            loss = loc_loss + cls_loss
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
        
            # show status
            if batch_idx % 10 == 0:
                print(f"batch idx: {batch_idx} => loc_loss: {loc_loss.item()} || cls_loss: {cls_loss.item()}  || train_loss: {loss.item()}")
        
        
        ######################    
        # validate the model #
        ######################
        model.eval()
        for batch_idx, (inputs, loc_targets, cls_targets) in enumerate(validloader):
            inputs = torch.autograd.Variable(inputs.cuda())
            loc_target = torch.autograd.Variable(loc_target.cuda())
            cls_target = torch.autograd.Variable(cls_target.cuda())

            loc_preds, cls_preds = model(inputs)
            loc_loss, cls_loss = criterion(loc_preds, loc_target, cls_preds, cls_target)
            loss = loc_loss + cls_loss
            valid_losses.append(loss.item())
        
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        print(f"[epoch: {cur_iter} / {TRAIN_MAX_ITER} || train_loss: {train_loss:.5f} || valid_loss: {valid_loss:.5f}]")

        train_losses = []
        valid_losses = []

        early_stopping(valid_loss, model, cur_iter)

        if early_stopping.early_stop:
            print('Early stopping')
            break
    
    return  model, avg_train_losses, avg_valid_losses

In [8]:
model, train_loss, valid_loss = train_model()

weight init
using sgd: base_learning_rate = 0.001, momentum = 0.9, weight_decay = 0.0001
Epoch: 0




batch idx: 0 => loc_loss: 1.0443321466445923 || cls_loss: 140.2132568359375  || train_loss: 141.25758361816406
batch idx: 10 => loc_loss: 1.1017905473709106 || cls_loss: 0.9683222770690918  || train_loss: 2.070112705230713
batch idx: 20 => loc_loss: 1.0324469804763794 || cls_loss: 4.159339427947998  || train_loss: 5.191786289215088
batch idx: 30 => loc_loss: 1.0849738121032715 || cls_loss: 3.2628188133239746  || train_loss: 4.347792625427246
batch idx: 40 => loc_loss: 1.112111210823059 || cls_loss: 2.0044779777526855  || train_loss: 3.116589069366455
batch idx: 50 => loc_loss: 1.1228498220443726 || cls_loss: 1.53373384475708  || train_loss: 2.656583786010742
batch idx: 60 => loc_loss: 1.0883634090423584 || cls_loss: 1.0899275541305542  || train_loss: 2.178290843963623
batch idx: 70 => loc_loss: 1.0747389793395996 || cls_loss: 0.7310431003570557  || train_loss: 1.8057820796966553
batch idx: 80 => loc_loss: 1.0168358087539673 || cls_loss: 0.7084653973579407  || train_loss: 1.725301265716

UnboundLocalError: local variable 'loc_target' referenced before assignment