In [None]:
from __future__ import division

import os
import random
import numpy as np
import argparse
from copy import deepcopy

# ----------------- Torch Components -----------------
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# ----------------- Extra Components -----------------
from utils import distributed_utils
from utils.misc import compute_flops

# ----------------- Config Components -----------------
from config import build_dataset_config, build_model_config, build_trans_config

# ----------------- Model Components -----------------
from models.detectors import build_model

# ----------------- Train Components -----------------
from engine import build_trainer

In [None]:
class Args():
    """Hard-coded training settings standing in for command-line arguments.

    An instance is read directly by the cells below (e.g. ``args.distributed``,
    ``args.sybn``, ``args.eval_first``) and passed whole to the ``build_*``
    helpers and the trainer.

    Every attribute has a default set in ``__init__``; any of them can be
    overridden with keyword arguments, e.g. ``Args(cuda=True, batch_size=8)``.
    An unknown keyword raises ``TypeError`` so a misspelled option cannot
    silently create a new, unused attribute.
    """

    def __init__(self, **overrides):
        # ---- Basic ----
        self.seed = 42
        self.cuda = False

        self.img_size = 640
        self.eval_first = False

        # ---- Logging / outputs ----
        self.tfboard = False
        self.save_folder = 'weights/'
        self.vis_tgt = False
        self.vis_aux_loss = False
        self.fp16 = False
        self.batch_size = 16

        # ---- Schedule (epochs) ----
        self.max_epoch = 150
        self.wp_epoch = 1
        self.eval_epoch = 10
        self.no_aug_epoch = 20

        # ---- Model / post-processing ----
        self.model = 'yolov8_n'
        self.conf_thresh = 0.001
        self.nms_thresh = 0.7
        self.topk = 1000
        self.pretrained = None
        self.resume = None
        self.no_multi_labels = False
        self.nms_class_agnostic = False

        # ---- Dataset ----
        # NOTE(review): absolute, machine-specific Windows path; on another
        # machine pass Args(root=...) instead of editing this default.
        self.root = 'D:\\Number Plate Region\\Demo'
        self.dataset = 'character'
        self.load_cache = False
        self.num_workers = 1

        # ---- Training tricks ----
        self.multi_scale = False
        self.ema = False
        self.min_box_size = 8.0
        self.mosaic = None
        self.mixup = None
        self.grad_accumulate = 1

        # ---- Distributed ----
        self.distributed = False
        self.dist_url = ""
        self.world_size = 1
        # SyncBatchNorm toggle; the attribute is spelled `sybn` because that is
        # the name the DDP cell reads.
        self.sybn = False
        self.find_unused_parameters = False
        self.debug = False

        # Apply caller overrides last; only existing settings may be replaced.
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError("Args() got an unexpected setting {!r}".format(name))
            setattr(self, name, value)

    def __repr__(self):
        """Show every setting so that printing ``args`` is informative."""
        fields = ', '.join('{}={!r}'.format(key, value) for key, value in vars(self).items())
        return 'Args({})'.format(fields)

In [None]:
args = Args()
# Print the individual settings rather than the bare object.
print("Setting Arguments.. : ", vars(args))
# -1 means "not running under torch.distributed".
local_rank = local_process_rank = -1
if args.distributed:
    distributed_utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(distributed_utils.get_sha()))
    try:
        # Multiple machines & multiple GPUs (world size > 8)
        local_rank = torch.distributed.get_rank()
        local_process_rank = int(os.getenv('LOCAL_PROCESS_RANK', '0'))
    except ValueError:
        # Single machine & multiple GPUs (world size <= 8).
        # Only int() can fail recoverably here (a non-integer LOCAL_PROCESS_RANK);
        # a failing get_rank() would fail again below, so it is not caught.
        local_rank = local_process_rank = torch.distributed.get_rank()
world_size = distributed_utils.get_world_size()
print("LOCAL RANK: ", local_rank)
print("LOCAL_PROCESS_RANK: ", local_process_rank)
print('WORLD SIZE: {}'.format(world_size))

In [None]:
# Run on the GPU only when it was requested in Args AND torch can see one;
# otherwise fall back to the CPU.
if not (args.cuda and torch.cuda.is_available()):
    device = torch.device("cpu")
else:
    print('use cuda')
    device = torch.device("cuda")

In [None]:
def fix_random_seed(args):
    """Seed the torch, NumPy and Python ``random`` generators.

    The seed is ``args.seed`` plus ``distributed_utils.get_rank()``, so
    processes with different ranks are seeded differently (presumably the
    rank is 0 in a single-process run; confirm in utils.distributed_utils).
    """
    rank_seed = args.seed + distributed_utils.get_rank()
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(rank_seed)

fix_random_seed(args)

In [None]:
# Dataset and model configs are derived from `args`; the transform config is
# selected by the model config's 'trans_type' entry.
data_cfg, model_cfg = build_dataset_config(args), build_model_config(args)
trans_cfg = build_trans_config(model_cfg['trans_type'])

In [None]:
# Build the detector together with its loss (criterion).
# NOTE(review): the trailing positional True is passed through unchanged;
# presumably a "trainable" flag — confirm in models.detectors.build_model.
model, criterion = build_model(args, model_cfg, device, data_cfg['num_classes'], True)
model = model.to(device)
model.train()
# Until (and unless) the model is wrapped with DDP, the plain model is the
# "unwrapped" one used for saving/evaluation.
model_without_ddp = model

In [None]:
if args.distributed:
    # SyncBatchNorm conversion must happen BEFORE wrapping with DDP: DDP
    # registers the module's parameters/buffers at construction time, and the
    # PyTorch docs require converting BatchNorm layers first.
    if args.sybn:
        print('use SyncBatchNorm ...')
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    # NOTE(review): args.gpu is not defined in Args; presumably it is set by
    # distributed_utils.init_distributed_mode — confirm.
    model = DDP(model, device_ids=[args.gpu], find_unused_parameters=args.find_unused_parameters)
    # Keep a handle on the underlying (possibly SyncBN-converted) module.
    model_without_ddp = model.module

In [None]:
## Calculate Params & GFLOPs (main process only)
# is_main_process must be CALLED: the bare function object is always truthy,
# which made every rank deep-copy and profile the model.
if distributed_utils.is_main_process():
    # Profile a detached copy so the training model's mode/flags are untouched.
    model_copy = deepcopy(model_without_ddp)
    model_copy.trainable = False
    model_copy.eval()
    compute_flops(model=model_copy,
                  img_size=args.img_size,
                  device=device)
    del model_copy
if args.distributed:
    # Other ranks wait here until the main process has finished profiling.
    dist.barrier()

In [None]:
# Assemble the trainer. It receives the unwrapped model (model_without_ddp);
# the possibly DDP-wrapped `model` is handed over later via trainer.train().
trainer = build_trainer(
    args, data_cfg, model_cfg, trans_cfg,
    device, model_without_ddp, criterion, world_size,
)

In [None]:
## Eval before training
if args.eval_first and distributed_utils.is_main_process():
    # Sanity-check that the evaluator works before a long training run.
    # A notebook cell cannot `return` the way a script can, so training below
    # still runs after this evaluation; interrupt the kernel if only the
    # check was wanted.
    trainer.eval(model_without_ddp)

## Start Training
trainer.train(model)