In [None]:
import gc
import os
import sys

import editdistance
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from torch import optim
from torch.nn.functional import ctc_loss, log_softmax
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

# detection
from detection.unet import UNet
from detection.dataset import DetectionDataset
import detection.transform
import detection.routine
# the proper way to do this is relative import, one more nested package and main.py outside the package
# will sort this out
#sys.path.insert(0, os.path.abspath((os.path.dirname(__file__)) + '/../'))

# recognition
from recognition.model import RecognitionModel
from recognition.dataset import RecognitionDataset
import recognition.transform #import Compose, Resize, Pad, Rotate
import recognition.routine
import recognition.common

from utils import get_logger, dice_coeff, dice_loss

In [None]:
# Display the compute device available here: the first CUDA GPU when CUDA is usable, otherwise the CPU.
torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

#### Segmentation

In [None]:
# Segmentation configuration: hyperparameters, paths and model construction.
# data_path = "C:\\Users\\Lisen\\Desktop\\CV\\data\\"  # path to the data (Windows layout)
data_path = "//home//mayer//LocalRepository//JupyterProjects//MADE_2019_cv//02_CarPlatesOCR//data//"
epochs = 23  # number of epochs
batch_size = 16  # batch size
image_size = 256  # input image size (used for both width and height when resizing)
lr = 0.0001  # learning rate
weight_decay = 5e-4  # weight decay
lr_step = 8  # learning rate scheduler step; a value <= 0 disables the scheduler
lr_gamma = 0.5  # learning rate scheduler gamma
model = UNet()
weight_bce = 0.5  # weight of the BCE term; the dice term gets (1 - weight_bce)
# Checkpoint to restore weights from. Despite the boolean default this is used as a
# path: any truthy value is handed to torch.load, so set it to a file path (not True).
load = False
val_split = 0.8  # fraction of the samples kept for training; the remainder is used for validation
# output_dir = "C:\\Users\\Lisen\\Desktop\\CV\\baseline\\temp\\"  # dir to save log and models (Windows layout)
output_dir = "//home//mayer//LocalRepository//JupyterProjects//MADE_2019_cv//02_CarPlatesOCR//temp//"
part = 0.01  # which part of the train dataset to use (passed to the dataset as `part`)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# TODO: use a more novel arch and/or more lightweight blocks (mobilenet) to enlarge the batch_size
# TODO: img_size=256 is rather mediocre, try to optimize network for at least 512
if load:
    # map_location remaps the stored tensors onto the active device, so a checkpoint
    # written on a GPU machine can still be restored when only the CPU is available.
    model.load_state_dict(torch.load(load, map_location=device))
model = model.to(device)
# model = nn.DataParallel(model)

In [None]:
# Make sure the output directory exists, then open the segmentation log and
# record the configuration this run was started with.
os.makedirs(output_dir, exist_ok=True)
logger = get_logger(os.path.join(output_dir, 'segmentation_train.log'))
logger.info('Start training with params:')

# One log line per hyperparameter, in a fixed order.
_logged_params = (
    ("data_path", data_path),
    ("epochs", epochs),
    ("batch_size", batch_size),
    ("image_size", image_size),
    ("lr", lr),
    ("weight_decay", weight_decay),
    ("lr_step", lr_step),
    ("lr_gamma", lr_gamma),
    ("weight_bce", weight_bce),
    ("load", load),
    ("val_split", val_split),
    ("output_dir", output_dir),
)
for _param_name, _param_value in _logged_params:
    logger.info("Argument %s: %r", _param_name, _param_value)

logger.info('Model type: {}'.format(type(model).__name__))

In [None]:
# Segmentation training: optimizer, loss, scheduler, datasets, loaders and the training run.
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)


# TODO: loss experimentation, fight class imbalance, there're many ways you can tackle this challenge
def criterion(x, y):
    """Return the pair (weighted BCE loss, weighted dice loss) for predictions x and targets y.

    The BCE term is scaled by ``weight_bce`` and the dice term by ``1 - weight_bce``.
    The two terms are returned separately as a tuple, not summed.
    ``F.binary_cross_entropy`` is the functional form of ``nn.BCELoss`` with the same
    default mean reduction, so no loss module has to be constructed on every call.
    """
    return weight_bce * F.binary_cross_entropy(x, y), (1. - weight_bce) * dice_loss(x, y)


# TODO: you can always try on plateau scheduler as a default option
# A non-positive lr_step disables learning-rate scheduling.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=lr_gamma) \
    if lr_step > 0 else None

# dataset
# TODO: to work on transformations a lot, look at albumentations package for inspiration
train_transforms = detection.transform.Compose([
    detection.transform.Crop(min_size=1 - 1 / 3., min_ratio=1.0, max_ratio=1.0, p=0.5),
    detection.transform.Flip(p=0.05),
    detection.transform.Pad(max_size=0.6, p=0.25),
    detection.transform.Resize(size=(image_size, image_size), keep_aspect=True)
])
# TODO: don't forget to work class imbalance and data cleansing
val_transforms = detection.transform.Resize(size=(image_size, image_size))

train_dataset = DetectionDataset(data_path, os.path.join(data_path, 'train_segmentation.json'),
                                 transforms=train_transforms, part=part)
# The validation dataset is created without a config file; its sample lists are filled in below.
val_dataset = DetectionDataset(data_path, None, transforms=val_transforms, part=part)

# split dataset into train/val, don't try to do this at home ;)
# The first `val_split` fraction of the samples stays in train, the tail goes to val.
# The validation slices must be taken before the train lists are truncated.
train_size = int(len(train_dataset) * val_split)
val_dataset.image_names = train_dataset.image_names[train_size:]
val_dataset.mask_names = train_dataset.mask_names[train_size:]
train_dataset.image_names = train_dataset.image_names[:train_size]
train_dataset.mask_names = train_dataset.mask_names[:train_size]

# TODO: always work with the data: cleaning, sampling
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=8,
                              shuffle=True, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4,
                            shuffle=False, drop_last=False)
logger.info('Length of train/val=%d/%d', len(train_dataset), len(val_dataset))
logger.info('Number of batches of train/val=%d/%d', len(train_dataloader), len(val_dataloader))

try:
    detection.routine.train(model, optimizer, criterion, scheduler, epochs, train_dataloader, val_dataloader, saveto=output_dir,
          device=device, logger=logger, show_plots=True)
except KeyboardInterrupt:
    # Keep the weights reached so far when the run is stopped by hand.
    torch.save(model.state_dict(), os.path.join(output_dir, 'INTERRUPTED.pth'))
    logger.info('Saved interrupt')
    sys.exit(0)

# Release cached GPU memory and collect garbage before the next stage.
# `gc` must be imported at the top of the notebook; without it this line raises NameError.
torch.cuda.empty_cache()
gc.collect()

#### Recognition

In [None]:
# Recognition configuration: hyperparameters, paths and model construction.
# These names overwrite the segmentation settings defined earlier in the notebook.
# data_path = "C:\\Users\\Lisen\\Desktop\\CV\\data\\"  # path to the data (Windows layout)
data_path = "//home//mayer//LocalRepository//JupyterProjects//MADE_2019_cv//02_CarPlatesOCR//data//"
epochs = 40  # number of train epochs
batch_size = 128  # batch size
weight_decay = 5e-4  # weight decay
lr = 1e-4  # learning rate
lr_step = None  # learning rate scheduler step; None disables the scheduler
lr_gamma = None  # learning rate scheduler gamma factor; must be set whenever lr_step is set
input_wh = '320x32'  # model input size as "<width>x<height>"
rnn_dropout = 0.1  # rnn dropout p
# NOTE(review): the original comment here was "bi", but the value 1 is passed as the
# number of directions - confirm whether 2 (bidirectional) was intended.
rnn_num_directions = 1
augs = 0  # degree of geometric augs (scales the rotation angle and the padding size)
load = None  # path to pretrained weights; None trains from scratch
val_split = 0.8  # fraction of the samples kept for training; the remainder is used for validation
# output_dir = "C:\\Users\\Lisen\\Desktop\\CV\\baseline\\temp\\"  # dir to save log and models (Windows layout)
output_dir = "//home//mayer//LocalRepository//JupyterProjects//MADE_2019_cv//02_CarPlatesOCR//temp//"
part = 0.01  # which part of the train dataset to use (passed to the dataset as `part`)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = RecognitionModel(rnn_dropout, rnn_num_directions)
if load is not None:
    # map_location remaps the stored tensors onto the active device, so a checkpoint
    # written on a GPU machine can still be restored when only the CPU is available.
    model.load_state_dict(torch.load(load, map_location=device))
model = model.to(device)

In [None]:
# Make sure the output directory exists, then open the recognition log and
# record the configuration this run was started with.
os.makedirs(output_dir, exist_ok=True)

logger = get_logger(os.path.join(output_dir, 'recognition_train.log'))
logger.info('Start training with params:')

# One log line per hyperparameter, in a fixed order.
_logged_params = (
    ("data_path", data_path),
    ("epochs", epochs),
    ("batch_size", batch_size),
    ("weight_decay", weight_decay),
    ("lr", lr),
    ("lr_step", lr_step),
    ("lr_gamma", lr_gamma),
    ("input_wh", input_wh),
    ("rnn_dropout", rnn_dropout),
    ("rnn_num_directions", rnn_num_directions),
    ("augs", augs),
    ("load", load),
    ("val_split", val_split),
    ("output_dir", output_dir),
)
for _param_name, _param_value in _logged_params:
    logger.info("Argument %s: %r", _param_name, _param_value)

logger.info('Model type: {}'.format(type(model).__name__))

In [None]:
# Recognition training: loss, optimizer, scheduler, datasets, loaders and the training run.
criterion = ctc_loss

# TODO: try other optimizers and schedulers
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# lr_step=None disables learning-rate scheduling.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=lr_gamma) \
    if lr_step is not None else None

# dataset
# input_wh is "<width>x<height>", e.g. '320x32' -> w=320, h=32
w, h = map(int, input_wh.split('x'))
# TODO: again, augmentations is the key for many tasks
# With augs=0 the rotation angle and the padding size passed below are both zero.
train_transforms = recognition.transform.Compose([
    recognition.transform.Rotate(max_angle=augs * 7.5, p=0.5),  # 5 -> 7.5
    recognition.transform.Pad(max_size=augs / 10, p=0.1),
    recognition.transform.Resize(size=(w, h)),
])
val_transforms = recognition.transform.Resize(size=(w, h))
# TODO: don't forget to work on data cleansing
train_dataset = RecognitionDataset(data_path, os.path.join(data_path, 'train_recognition.json'),
                                   abc=recognition.common.abc, transforms=train_transforms, part=part)
# The validation dataset is created without a config file; its sample lists are filled in below.
val_dataset = RecognitionDataset(data_path, None, abc=recognition.common.abc, transforms=val_transforms, part=part)
# split dataset into train/val, don't try to do this at home ;)
# The first `val_split` fraction of the samples stays in train, the tail goes to val.
# The validation slices must be taken before the train lists are truncated.
train_size = int(len(train_dataset) * val_split)
val_dataset.image_names = train_dataset.image_names[train_size:]
val_dataset.texts = train_dataset.texts[train_size:]
train_dataset.image_names = train_dataset.image_names[:train_size]
train_dataset.texts = train_dataset.texts[:train_size]

# TODO: maybe implement batch_sampler for tackling imbalance, which is obviously huge in many respects
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8,
                              collate_fn=train_dataset.collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=8,
                            collate_fn=val_dataset.collate_fn)
logger.info('Length of train/val=%d/%d', len(train_dataset), len(val_dataset))
logger.info('Number of batches of train/val=%d/%d', len(train_dataloader), len(val_dataloader))

try:
    recognition.routine.train(model, optimizer, criterion, scheduler, epochs, train_dataloader, val_dataloader, saveto=output_dir,
          device=device, logger=logger, show_plots=True)
except KeyboardInterrupt:
    # Keep the weights reached so far when the run is stopped by hand.
    # NOTE(review): the segmentation stage writes its interrupt checkpoint to the same
    # output_dir/INTERRUPTED.pth path, so the two stages overwrite each other's file.
    torch.save(model.state_dict(), os.path.join(output_dir, 'INTERRUPTED.pth'))
    logger.info('Saved interrupt')
    sys.exit(0)

# Release cached GPU memory and collect garbage once training is done.
# `gc` must be imported at the top of the notebook; without it this line raises NameError.
torch.cuda.empty_cache()
gc.collect()