In [1]:
import os
import sys
import tqdm
import gc

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torch.nn.functional import ctc_loss, log_softmax
from torchvision.transforms import Compose
import editdistance

# segmentation
from segmentation.unet import UNet
from segmentation.maskrcnn import maskrcnn_resnet50_fpn
from segmentation.dataset import DetectionDataset
import segmentation.transform
import segmentation.routine
import segmentation_models_pytorch as smp
# the proper way to do this is relative import, one more nested package and main.py outside the package
# will sort this out
#sys.path.insert(0, os.path.abspath((os.path.dirname(__file__)) + '/../'))


from utils import get_logger, dice_coeff, dice_loss

In [2]:
torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

device(type='cpu')

#### Segmentation

In [3]:
# config hyper parametrs
#data_path = "C:\\Users\\Lisen\\Desktop\\CV\\data\\" #path to the data
#data_path = "//home//mayer//LocalRepository//JupyterProjects//MADE_2019_cv//02_CarPlatesOCR//data//" 
epochs = 13 #number of epochs
batch_size = 16 #batch size
image_size = 256 #input image size
lr = 1e-3 #learning rate
weight_decay = 5e-4 #weight decay
lr_step = 3 #learning rate step
lr_gamma = 0.3 #learning rate gamma
model = UNet()
#model = smp.Unet('resnext50_32x4d', encoder_weights='imagenet',classes=13)
#model = smp.FPN(encoder_name='resnext50_32x4d', encoder_weights='imagenet',classes=2)
#model = maskrcnn_resnet50_fpn()
weight_bce = 0.5 #weight BCE loss
load = False #load file model
val_split = 0.8 #train/val split
output_dir = "temp\\"#dir to save log and models
#output_dir = "//home//mayer//LocalRepository//JupyterProjects//MADE_2019_cv//02_CarPlatesOCR//temp//"
part = 1 # config which part of train dataset use
#segmentationFile =  'C:\\Users\\Lisen\\Desktop\\CV\\dataset\\segmentation.json'
segmentationFile = "/home/mayer/LocalRepository/JupyterProjects/DeepFashion2/dataset/segmentation.json"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# TODO: to use move novel arch or/and more lightweight blocks (mobilenet) to enlarge the batch_size
# TODO: img_size=256 is rather mediocre, try to optimize network for at least 512
if load:
    model.load_state_dict(torch.load(load))
model = model.to(device)
# model = nn.DataParallel(model)

In [4]:
os.makedirs(output_dir, exist_ok=True)
logger = get_logger(os.path.join(output_dir, 'segmentation_train.log'))
logger.info('Start training with params:')
logger.info("Argument %s: %r", "epochs", epochs)
logger.info("Argument %s: %r", "batch_size", batch_size)
logger.info("Argument %s: %r", "image_size",image_size )
logger.info("Argument %s: %r", "lr", lr)
logger.info("Argument %s: %r", "weight_decay",weight_decay )
logger.info("Argument %s: %r", "lr_step", lr_step)
logger.info("Argument %s: %r", "lr_gamma",lr_gamma )
logger.info("Argument %s: %r", "weight_bce", weight_bce)
logger.info("Argument %s: %r", "load", load)
logger.info("Argument %s: %r", "val_split", val_split)
logger.info("Argument %s: %r", "output_dir", output_dir)
logger.info("Argument %s: %r", "segmentationFile", segmentationFile)
logger.info('Model type: {}'.format(model.__class__.__name__))

2020-07-25 18:06:13 Start training with params:
2020-07-25 18:06:13 Argument epochs: 13
2020-07-25 18:06:13 Argument batch_size: 16
2020-07-25 18:06:13 Argument image_size: 256
2020-07-25 18:06:13 Argument lr: 0.001
2020-07-25 18:06:13 Argument weight_decay: 0.0005
2020-07-25 18:06:13 Argument lr_step: 3
2020-07-25 18:06:13 Argument lr_gamma: 0.3
2020-07-25 18:06:13 Argument weight_bce: 0.5
2020-07-25 18:06:13 Argument load: False
2020-07-25 18:06:13 Argument val_split: 0.8
2020-07-25 18:06:13 Argument output_dir: 'temp\\'
2020-07-25 18:06:13 Argument segmentationFile: '/home/mayer/LocalRepository/JupyterProjects/DeepFashion2/dataset/segmentation.json'
2020-07-25 18:06:13 Model type: UNet


In [5]:
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# TODO: loss experimentation, fight class imbalance, there're many ways you can tackle this challenge
criterion = lambda x, y: (weight_bce * nn.BCELoss()(x, y), (1. - weight_bce) * dice_loss(x, y))
# TODO: you can always try on plateau scheduler as a default option
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=lr_gamma) \
    if lr_step > 0 else None

# dataset
# TODO: to work on transformations a lot, look at albumentations package for inspiration
train_transforms = segmentation.transform.Compose([
    segmentation.transform.Crop(min_size=1 - 1 / 3., min_ratio=1.0, max_ratio=1.0, p=0.5),
    segmentation.transform.Flip(p=0.05),
    segmentation.transform.Pad(max_size=0.6, p=0.25),
    segmentation.transform.Resize(size=(image_size, image_size), keep_aspect=True)
])
# TODO: don't forget to work class imbalance and data cleansing
val_transforms = segmentation.transform.Resize(size=(image_size, image_size))

train_dataset = DetectionDataset(segmentationFile, transforms=train_transforms, part=part)
val_dataset = DetectionDataset(None, transforms=val_transforms, part=part)

# split dataset into train/val, don't try to do this at home ;)
train_size = int(len(train_dataset) * val_split)
val_dataset.image_names = train_dataset.image_names[train_size:]
val_dataset.mask_names = train_dataset.mask_names[train_size:]
train_dataset.image_names = train_dataset.image_names[:train_size]
train_dataset.mask_names = train_dataset.mask_names[:train_size]

# TODO: always work with the data: cleaning, sampling
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=8,
                              shuffle=True, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4,
                            shuffle=False, drop_last=False)
logger.info('Length of train/val=%d/%d', len(train_dataset), len(val_dataset))
logger.info('Number of batches of train/val=%d/%d', len(train_dataloader), len(val_dataloader))

try:
    segmentation.routine.train(model, optimizer, criterion, scheduler, epochs, train_dataloader, val_dataloader, saveto=output_dir,
          device=device, logger=logger, show_plots=True)

except KeyboardInterrupt:
    logger.info('Saved interrupt')
    sys.exit(0)
    
torch.cuda.empty_cache()
gc.collect()

2020-07-25 18:06:14 Length of train/val=6836/1709
2020-07-25 18:06:14 Number of batches of train/val=427/107
2020-07-25 18:06:14 Starting epoch 1/13.
mean loss: 1.3538:   1%|          | 5/427 [01:48<2:28:40, 21.14s/it]2020-07-25 18:08:14 Saved interrupt


SystemExit: 0

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
