In [None]:
import ast
import math
import os, gc
import time
import random
from tqdm import tqdm
from pathlib import Path

import numpy as np
import pandas as pd

import cv2
from PIL import Image
from skimage.measure import label, regionprops

import matplotlib
matplotlib.use('Agg')

import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.colors as colors
import matplotlib.patches as patches
from matplotlib.colors import LinearSegmentedColormap, Normalize
from mpl_toolkits.axes_grid1 import make_axes_locatable


import torch
import torch.backends.cudnn as cudnn
import torch.utils.data as data
import torch.nn.functional as F

from dataset import DeployDataset, SynthData
from network.textnet import TextNet
from cfglib.config import config as cfg, update_config, print_config
from cfglib.option import BaseOptions
from util.augmentation import BaseTransform, Augmentation
from util.visualize import visualize_detection, visualize_gt
from util.misc import to_device, mkdirs,rescale_result, AverageMeter

import multiprocessing
# Select the "spawn" start method for child processes. force=True allows the
# method to be set even if one was already chosen (e.g. when this cell is re-run).
# NOTE(review): presumably this is for the DataLoader worker processes created
# later in the notebook — confirm.
multiprocessing.set_start_method("spawn", force=True)


# Custom colormap named 'jet' built from six colour stops (black -> blue -> green
# -> yellow -> red -> white) with 256 levels. The colour list differs from
# matplotlib's built-in 'jet'; only the name is shared.
jet_cmap = LinearSegmentedColormap.from_list('jet', ['black', 'blue', 'green', 'yellow', 'red', 'white'], N = 256)

In [None]:
# Run-specific overrides that are merged into the shared project config `cfg`.
# The meaning of each key is defined by cfglib.config, which is not visible here.
# NOTE(review): update_config receives a plain dict — confirm it accepts a
# mapping rather than an argparse-style namespace.
args = dict(
    exp_name='TD500',
    checkepoch=1135,
    test_size=(640, 960),
    dis_threshold=0.35,
    cls_threshold=0.9,
    gpu="0",
    resume=False,
    save_dir='./model/',
)

# Apply the overrides, then print the resulting configuration for the record.
update_config(cfg, args)
print_config(cfg)

In [None]:
def osmkdir(out_dir):
    """Create ``out_dir`` as a fresh, empty directory.

    WARNING: destructive. Whatever currently lives at ``out_dir`` is deleted
    first: a directory is removed recursively, while a regular file or a
    symbolic link (including a dangling one) is unlinked. A symlink is only
    unlinked; its target is never touched.

    Args:
        out_dir: Path (str or os.PathLike) of the directory to (re)create.
            Missing parent directories are created as well.

    Raises:
        OSError: If the existing entry cannot be removed or the directory
            cannot be created (e.g. insufficient permissions).
    """
    import shutil

    # shutil.rmtree refuses symlinks and fails on regular files, and
    # os.path.exists() is False for a dangling symlink, so each kind of
    # pre-existing entry is handled explicitly.
    if os.path.islink(out_dir) or os.path.isfile(out_dir):
        os.remove(out_dir)
    elif os.path.isdir(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

In [None]:
def _parse_data(inputs):
    """Label the positional fields of one loader batch.

    Every element of ``inputs`` is passed through ``to_device`` (presumably a
    host-to-device transfer — confirm against util.misc), and the first nine
    results are stored, in order, under fixed keys.

    Args:
        inputs: Iterable with at least nine elements, in the order listed in
            ``field_names`` below. Extra trailing elements are converted but
            not returned.

    Returns:
        dict mapping each field name to the corresponding converted element.

    Raises:
        IndexError: If ``inputs`` has fewer than nine elements.
    """
    field_names = (
        'img',
        'train_mask',
        'tr_mask',
        'distance_field',
        'direction_field',
        'weight_matrix',
        'gt_points',
        'proposal_points',
        'ignore_tags',
    )
    moved = [to_device(item) for item in inputs]
    # Positional lookup (not zip) so that a short batch still raises IndexError.
    return {name: moved[pos] for pos, name in enumerate(field_names)}

In [None]:
# Dataset location. Set the SYNTH_DATA_ROOT environment variable to point the
# notebook at another copy of the data; the default keeps the original path so
# existing setups behave exactly as before.
SYNTH_DATA_ROOT = os.environ.get(
    'SYNTH_DATA_ROOT',
    '/home/lkhagvadorj/Temuujin/SynthData/synthetic_data/v2/',
)
# Loader settings gathered in one place so they are easy to tune.
TRAIN_BATCH_SIZE = 12
TRAIN_NUM_WORKERS = 4

# Training dataset: annotations are read from `gt_file_name` under the data
# root, and samples go through the project's Augmentation transform configured
# from cfg.
trainset = SynthData(
    data_root = SYNTH_DATA_ROOT,
    gt_file_name = 'train_det_v4.txt',
    is_training = True,
    load_memory = cfg.load_memory,
    transform = Augmentation(size = cfg.input_size,
                             mean = cfg.means,
                             std = cfg.stds)
)

# Shuffled loader; the final partial batch is kept (drop_last = False).
train_loader = data.DataLoader(trainset,
                               shuffle = True,
                               drop_last = False,
                               batch_size = TRAIN_BATCH_SIZE,
                               num_workers = TRAIN_NUM_WORKERS,
                               pin_memory = True)

In [None]:
# Measure how long it takes to fetch and unpack the first few batches.
# The original `i > 5` break let batches 0..6 through, i.e. seven batches.
MAX_TIMED_BATCHES = 7

train_step = 0
batch_time = AverageMeter()   # time per batch, including _parse_data
data_time = AverageMeter()    # time spent waiting for the loader only
batch_time_dict = {
    'idx': [],
    'time': [],
}

end = time.time()
# Cap the bar's total at the number of batches actually processed so the
# progress display is not measured against the whole dataset.
progress = tqdm(enumerate(train_loader),
                total = min(len(train_loader), MAX_TIMED_BATCHES))
for i, inputs in progress:
    data_time.update(time.time() - end)
    train_step += 1
    input_dict = _parse_data(inputs)

    elapsed = time.time() - end
    batch_time.update(elapsed)
    end = time.time()

    batch_time_dict['idx'].append(i)
    # Record the elapsed seconds for this batch. Appending the AverageMeter
    # itself would store the same object on every iteration, so all entries
    # would alias one meter instead of holding per-batch timings.
    batch_time_dict['time'].append(elapsed)

    if i + 1 >= MAX_TIMED_BATCHES:
        break