In [16]:
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Train a YOLOv5 model on a custom dataset

Usage:
    $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640
"""
from train_with_bcc import VOL_ID_MAP, SINGLE_VOL_ID_MAP
from train_with_bcc import convert_target_volunteers_yolo2bcc
from train_with_bcc import get_file_volunteers_dict
from label_converter import BACKGROUND_CLASS_ID
from timeit import default_timer as timer
from collections import defaultdict
import pdb
import argparse
from label_converter import qt2yolo_optimized, qt2yolo
from GPUtil import showUtilization as gpu_usage

from label_filter import filter_qt

from PIL.ImageFont import truetype
from label_converter import yolo2bcc_new, find_union_cstargets, targetize, yolo2bcc_newer
from train_with_bcc import convert_yolo2bcc, nn_predict, convert_cs_yolo2bcc
from train_with_bcc import read_crowdsourced_labels, init_bcc_params, \
    init_nn_output, compute_param_confusion_matrices, init_metrics, update_bcc_metrics
from lib.BCCNet.VariationalInference.VB_iteration_yolo import VB_iteration as VBi_yolo
import logging
import math
import os
import random
import sys
import time
from copy import deepcopy
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import yaml
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam, SGD, lr_scheduler
from tqdm import tqdm

# FILE = Path(__file__).absolute()
# sys.path.append(FILE.parents[0].as_posix())  # add yolov5/ to path

import val  # for end-of-epoch mAP
from models.experimental import attempt_load
from models.yolo import Model
from utils.autoanchor import check_anchors
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
    strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
    check_requirements, print_mutation, set_logging, one_cycle, colorstr, methods
from utils.downloads import attempt_download
from utils.loss import ComputeLoss
from utils.plots import plot_labels, plot_evolve
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, intersect_dicts, select_device, \
    torch_distributed_zero_first
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.metrics import fitness
from utils.loggers import Loggers
from utils.callbacks import Callbacks

# Module-level logger for this notebook/script.
LOGGER = logging.getLogger(__name__)

# Process-topology values read from the environment
# (see https://pytorch.org/docs/stable/elastic/run.html).
# When a variable is unset the fallbacks are LOCAL_RANK=-1, RANK=-1, WORLD_SIZE=1.
LOCAL_RANK, RANK, WORLD_SIZE = (
    int(os.environ.get(env_name, fallback))
    for env_name, fallback in (('LOCAL_RANK', -1), ('RANK', -1), ('WORLD_SIZE', 1))
)

In [17]:
def parse_opt(known=False):
    """Build the training option namespace.

    Args:
        known: When True, parse the process command line (``sys.argv``) and
            silently ignore arguments the parser does not recognise. When False
            (the default), parse an explicit empty argument list, so the result
            contains only the defaults declared below and the command line of
            the hosting process (e.g. a notebook kernel) is never read.

    Returns:
        argparse.Namespace holding one attribute per option declared below
        (dashes in option names become underscores, e.g. ``--batch-size`` ->
        ``batch_size``).
    """
    parser = argparse.ArgumentParser()
    # BCC / qt-filter options
    parser.add_argument('--bcc_epoch', type=int, default=0, help='start-epoch for BCC+YOLO run; use -1 for no BCC.')
    parser.add_argument('--qtfilter_epoch', type=int, default=-1, help='start-epoch for qt-filter; use -1 for no qt-filter.')
    parser.add_argument('--qt_thres_mode', type=str, default='', help="one of '', 'conf-count', 'entropy', 'conf-val'")
    parser.add_argument('--qt_thres', type=float, default=0.0, help="the threshold value.")
    parser.add_argument('--hybrid_entropy_thres', type=float, default=0.0, help="the entropy threshold value (only to be used when running the hybrid filter).")
    parser.add_argument('--hybrid_conf_thres', type=float, default=0.0, help="the confidence threshold value (only to be used when running the hybrid filter).")
    # Standard YOLOv5 training options
    parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
    parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='dataset.yaml path')
    parser.add_argument('--hyp', type=str, default='data/hyps/hyp.scratch.yaml', help='hyperparameters path')
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
    parser.add_argument('--rect', action='store_true', help='rectangular training')
    parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    parser.add_argument('--noval', action='store_true', help='only validate final epoch')
    parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
    parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
    parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
    parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
    parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
    parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
    parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
    parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
    parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
    parser.add_argument('--project', default='runs/train', help='save to project/name')
    parser.add_argument('--entity', default=None, help='W&B entity')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--quad', action='store_true', help='quad dataloader')
    parser.add_argument('--linear-lr', action='store_true', help='linear LR')
    parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
    parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
    parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
    parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
    parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
    parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
    parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
    parser.add_argument('--patience', type=int, default=1100, help='EarlyStopping patience (epochs)')
    # argparse expects a list of strings; pass an explicit empty list (rather than
    # an empty string that merely happens to iterate to nothing) to get the defaults.
    opt = parser.parse_known_args()[0] if known else parser.parse_args([])
    return opt

## Main

In [18]:
# Notebook-level run configuration: start from the parser defaults and override
# the options this experiment needs, then perform the resume / DDP checks.
data_name = 'single_toy_bcc'
opt = parse_opt()
# NOTE(review): the original note here read "Does not make any difference!", but the
# value is still passed to check_file() below and is later tested for the substring
# 'single' to pick the volunteer-id map - confirm before removing.
opt.data = f'data/{data_name}.yaml'
data_dict_path = f'../../datasets/{data_name}'  # dataset root used to build data_dict in the training-setup cell
opt.exist_ok = False
opt.cache = None
opt.workers = 0
opt.batch_size = 20 # Change this to number of train images
opt.epochs = 10
opt.bcc_epoch = 5  # BCC+YOLO updates start this many epochs after start_epoch (-1 disables BCC)

# Checks
set_logging(RANK)
if RANK in [-1, 0]:
    # Echo every option as key=value so the run is self-describing in the cell output.
    print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
#     check_git_status()
#     check_requirements(requirements=FILE.parent / 'requirements.txt', exclude=['thop'])

# Resume
if opt.resume and not check_wandb_resume(opt) and not opt.evolve:  # resume an interrupted run
    ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
    assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
    # The options saved with the interrupted run replace the ones configured above.
    with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
        opt = argparse.Namespace(**yaml.safe_load(f))  # replace
    opt.cfg, opt.weights, opt.resume = '', ckpt, True  # reinstate
    LOGGER.info(f'Resuming training from {ckpt}')
else:
    opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp)  # check files
    assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
    if opt.evolve:
        opt.project = 'runs/evolve'
        opt.exist_ok = opt.resume
    # Output directory for this run; increment_path receives exist_ok to decide whether to reuse it.
    opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))

# DDP mode
device = select_device(opt.device, batch_size=opt.batch_size)
if LOCAL_RANK != -1:
    # NOTE(review): timedelta is imported here but not used anywhere in this cell.
    from datetime import timedelta
    assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
    assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
    assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
    assert not opt.evolve, '--evolve argument is not compatible with DDP training'
    torch.cuda.set_device(LOCAL_RANK)
    device = torch.device('cuda', LOCAL_RANK)
    dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
hyp = opt.hyp  # hyperparameters; loaded from YAML in the training-setup cell when this is a str
callbacks=Callbacks()  # callback registry; logger methods are registered on it in the training-setup cell

YOLOv5 🚀 8615cb4 torch 1.9.0 CPU



[34m[1mtrain: [0mbcc_epoch=5, qtfilter_epoch=-1, qt_thres_mode=, qt_thres=0.0, hybrid_entropy_thres=0.0, hybrid_conf_thres=0.0, weights=yolov5s.pt, cfg=, data=data/single_toy_bcc.yaml, hyp=data/hyps/hyp.scratch.yaml, epochs=10, batch_size=20, imgsz=640, rect=False, resume=False, nosave=False, noval=False, noautoanchor=False, evolve=None, bucket=, cache=None, image_weights=False, device=, multi_scale=False, single_cls=False, adam=False, sync_bn=False, workers=0, project=runs/train, entity=None, name=exp, exist_ok=False, quad=False, linear_lr=False, label_smoothing=0.0, upload_dataset=False, bbox_interval=-1, save_period=-1, artifact_alias=latest, local_rank=-1, freeze=0, patience=1100


## Train

In [19]:
# When True, the BCC confusion matrices are held as torch tensors and the BCC
# helpers are called with torchMode=True.
torchMode = True

# Copy the run options off `opt` into plain names used by the rest of the notebook.
save_dir = Path(opt.save_dir)
epochs = opt.epochs
batch_size = opt.batch_size
weights = opt.weights
single_cls = opt.single_cls
evolve = opt.evolve
data = opt.data
cfg = opt.cfg
resume = opt.resume
noval = opt.noval
nosave = opt.nosave
workers = opt.workers
freeze = opt.freeze

# Volunteer-id map: the single-volunteer variant is chosen when the dataset
# path contains the substring 'single'.
if 'single' in data:
    vol_id_map = SINGLE_VOL_ID_MAP
else:
    vol_id_map = VOL_ID_MAP

# BCC and qt-filter settings
bcc_epoch = opt.bcc_epoch
qtfilter_epoch = opt.qtfilter_epoch
qt_thres_mode = opt.qt_thres_mode
qt_thres = opt.qt_thres
hybrid_entropy_thres = opt.hybrid_entropy_thres
hybrid_conf_thres = opt.hybrid_conf_thres

# Directories: checkpoints are written under <save_dir>/weights
w = save_dir / 'weights'
w.mkdir(parents=True, exist_ok=True)
last = w / 'last.pt'
best = w / 'best.pt'

# Hyperparameters: load the YAML file when only a path was supplied
if isinstance(hyp, str):
    with open(hyp) as f:
        hyp = yaml.safe_load(f)
LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))

# Save run settings next to the results
for settings_name, settings in (('hyp.yaml', hyp), ('opt.yaml', vars(opt))):
    with open(save_dir / settings_name, 'w') as f:
        yaml.safe_dump(settings, f, sort_keys=False)

# Dataset description built in code; the image directories are joined onto the dataset root.
data_dict = {
    'path': data_dict_path,  # dataset root dir
    'train': os.path.join(data_dict_path, 'images/train'),  # train images
    'val': os.path.join(data_dict_path, 'images/val'),  # val images
    'test': os.path.join(data_dict_path, 'images/test'),  # test images (optional)
    'nc': 2,  # number of classes
    'names': ['bone-loss', 'dental-caries'],  # class names
}

# Loggers
# Only created when RANK is -1 or 0; each method name returned by methods(loggers)
# is registered on `callbacks` so the training loop can trigger it by name.
if RANK in [-1, 0]:
    loggers = Loggers(save_dir, weights, opt, hyp, LOGGER)  # loggers instance
    if loggers.wandb:
        # When a W&B logger is active its data_dict replaces the hand-built one.
        data_dict = loggers.wandb.data_dict
        if resume:
            weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp

    # Register actions
    for k in methods(loggers):
        callbacks.register_action(k, callback=getattr(loggers, k))

# Config
plots = not evolve  # create plots
cuda = device.type != 'cpu'
init_seeds(1 + RANK)  # seed depends on RANK
with torch_distributed_zero_first(RANK):
    # check_dataset(data) is only consulted when data_dict is falsy (None or empty).
    data_dict = data_dict or check_dataset(data)  # check if None
train_path, val_path = data_dict['train'], data_dict['val']
nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}'  # check
is_coco = data.endswith('coco.yaml') and nc == 80  # COCO dataset

# Model
# Build the network either from a '.pt' checkpoint (transferring every matching
# weight) or from scratch using the model YAML in `cfg`.
# NOTE(review): both constructors pass ch=1, while the captured cell output lists the
# first layer as Focus [3, 32, 3] - confirm the intended number of input channels.
pretrained = weights.endswith('.pt')
if pretrained:
    with torch_distributed_zero_first(RANK):
        weights = attempt_download(weights)  # download if not found locally
    # torch.load unpickles the file, so only load checkpoints from trusted sources.
    ckpt = torch.load(weights, map_location=device)  # load checkpoint
    model = Model(cfg or ckpt['model'].yaml, ch=1, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
    # Anchor entries are skipped during transfer when cfg or hyp supplies anchors and this is not a resume.
    exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
    csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
    csd = intersect_dicts(csd, model.state_dict(), exclude=exclude)  # intersect
    model.load_state_dict(csd, strict=False)  # load
    LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}')  # report
else:
    model = Model(cfg, ch=1, nc=nc, anchors=hyp.get('anchors')).to(device)  # create

# Freeze
# `freeze` is rebound from a layer count to the list of parameter-name prefixes
# 'model.0.', 'model.1.', ...; any parameter whose name contains one of them is frozen.
freeze = [f'model.{x}.' for x in range(freeze)]  # layers to freeze
for k, v in model.named_parameters():
    v.requires_grad = True  # train all layers
    if any(x in k for x in freeze):
        print(f'freezing {k}')
        v.requires_grad = False

# Optimizer
# Gradients are accumulated over `accumulate` batches so that the effective batch
# size approaches the nominal batch size `nbs`; weight decay is rescaled to match.
nbs = 64  # nominal batch size
accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay (mutates hyp in place)
LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")

# Three parameter groups:
#   g0 - BatchNorm2d weights, optimised without weight decay
#   g1 - every other module weight, optimised with weight decay
#   g2 - biases
g0, g1, g2 = [], [], []  # optimizer parameter groups
for v in model.modules():
    if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias
        g2.append(v.bias)
    if isinstance(v, nn.BatchNorm2d):  # weight (no decay)
        g0.append(v.weight)
    elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
        g1.append(v.weight)

# The optimizer is created with g0 only; g1 and g2 are appended as extra groups.
if opt.adam:
    optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
else:
    optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)

optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']})  # add g1 with weight_decay
optimizer.add_param_group({'params': g2})  # add g2 (biases)
# Labels follow the groups above: g0 is the no-decay group, g1 the decayed one
# (the message previously had the two labels swapped).
LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
            f"{len(g0)} weight (no decay), {len(g1)} weight, {len(g2)} bias")
del g0, g1, g2

# Scheduler
# `lf` maps an epoch index to a multiplier applied to each group's initial LR.
if opt.linear_lr:
    # NOTE(review): divides by (epochs - 1), so epochs == 1 with --linear-lr raises ZeroDivisionError.
    lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf']  # linear
else:
    lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)  # plot_lr_scheduler(optimizer, scheduler, epochs)

# EMA
# Only kept when RANK is -1 or 0; otherwise `ema` is None.
ema = ModelEMA(model) if RANK in [-1, 0] else None

# Resume
# Restore optimizer state, EMA state and the epoch counter from the checkpoint
# that was loaded when the model was built.
start_epoch, best_fitness = 0, 0.0
if pretrained:
    # Optimizer
    if ckpt['optimizer'] is not None:
        optimizer.load_state_dict(ckpt['optimizer'])
        best_fitness = ckpt['best_fitness']

    # EMA
    if ema and ckpt.get('ema'):
        ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
        ema.updates = ckpt['updates']

    # Epochs
    start_epoch = ckpt['epoch'] + 1
    if resume:
        assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
    if epochs < start_epoch:
        # The checkpoint already covers more epochs than requested: treat `epochs` as additional epochs.
        LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
        epochs += ckpt['epoch']  # finetune additional epochs

    # Drop the references to the checkpoint and the transferred state dict.
    del ckpt, csd

# Image sizes
gs = max(int(model.stride.max()), 32)  # grid size (max stride)
nl = model.model[-1].nl  # number of detection layers (used for scaling hyp['obj'])
imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2)  # verify imgsz is gs-multiple

# DP mode
# Used when CUDA is available, RANK is -1 and more than one GPU is visible.
if cuda and RANK == -1 and torch.cuda.device_count() > 1:
    logging.warning('DP not recommended, instead use torch.distributed.run for best DDP Multi-GPU results.\n'
                    'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.')
    model = torch.nn.DataParallel(model)

# SyncBatchNorm
if opt.sync_bn and cuda and RANK != -1:
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
    LOGGER.info('Using SyncBatchNorm()')

# Trainloader
# Augmentation is switched off (augment=False); the per-process batch size is batch_size // WORLD_SIZE.
train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
                                          hyp=hyp, augment=False, cache=opt.cache, rect=opt.rect, rank=RANK,
                                          workers=workers, image_weights=opt.image_weights, quad=opt.quad,
                                          prefix=colorstr('train: '))
if bcc_epoch != -1:
    # Looked up by label-file name in the training loop to get each image's volunteer ids.
    file_volunteers_dict = get_file_volunteers_dict(data_dict)

# Detection-head geometry passed to the YOLO <-> BCC label converters.
# NOTE(review): model.model[-1] is read after the optional DataParallel wrap above;
# presumably a wrapped model would need to be unwrapped first - confirm for multi-GPU runs.
n_grid_choices, n_anchor_choices = model.model[-1].nl, model.model[-1].na
grid_ratios = model.model[-1].stride.cpu().detach().numpy() / imgsz  # detection-layer strides divided by the image size
# if bcc_epoch != -1:
#     cstargets_all = read_crowdsourced_labels(data)
#     cstargets_union = find_union_cstargets(cstargets_all['train'])
#     cstargets_all_bcc = convert_cs_yolo2bcc(cstargets_all, n_anchor_choices, nc, grid_ratios)
#     cstargets_bcc = torch.tensor(cstargets_all_bcc['train']) if torchMode else cstargets_all_bcc['train']
#     print("*** GPU Usage after reading crowdsourced labels")
#     gpu_usage()
# Sanity check: the largest class id in the training labels must be below nc.
mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max())  # max label class
nb = len(train_loader)  # number of batches
assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'

# Process 0
# Validation loader, label plots and the anchor check only run when RANK is -1 or 0.
if RANK in [-1, 0]:
    # Validation uses twice the per-process training batch size and rectangular batches.
    val_loader, val_dataset = create_dataloader(val_path, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls,
                                   hyp=hyp, cache=None if noval else opt.cache, rect=True, rank=-1,
                                   workers=workers, pad=0.5,
                                   prefix=colorstr('val: '))

    if not resume:
        labels = np.concatenate(dataset.labels, 0)  # all training labels stacked into one array
        if plots:
            plot_labels(labels, names, save_dir)

        # Anchors
        if not opt.noautoanchor:
            check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
        model.half().float()  # pre-reduce anchor precision

    callbacks.on_pretrain_routine_end()

# DDP mode
# Wrap the model for distributed training when CUDA is in use and RANK is set.
if cuda and RANK != -1:
    model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)

# Model parameters
# NOTE(review): the three gains below are scaled in place; `hyp` is only reloaded from
# YAML while it is still a str, so re-running this cell without re-running the earlier
# configuration cell would apply the scaling again - confirm.
hyp['box'] *= 3. / nl  # scale to layers
hyp['cls'] *= nc / 80. * 3. / nl  # scale to classes and layers
hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl  # scale to image size and layers
hyp['label_smoothing'] = opt.label_smoothing
model.nc = nc  # attach number of classes to model
model.hyp = hyp  # attach hyperparameters to model
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
model.names = names

# Start training
t0 = time.time()
# Warmup lasts at least 1000 iterations regardless of how few batches an epoch has.
nw = max(round(hyp['warmup_epochs'] * nb), 1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
last_opt_step = -1  # iteration index of the most recent optimizer step
maps = np.zeros(nc)  # mAP per class
results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
scheduler.last_epoch = start_epoch - 1  # do not move
scaler = amp.GradScaler(enabled=cuda)  # gradient scaling is only enabled when running on CUDA
stopper = EarlyStopping(patience=opt.patience)
compute_loss = ComputeLoss(model)  # init loss class
if bcc_epoch != -1:
    # BCC state: one entry per volunteer in vol_id_map (K), confusion-matrix parameters
    # (indexed later by the keys 'prior' and 'variational') and a metrics container.
    bcc_params = init_bcc_params(K=len(vol_id_map))
    bcc_params['n_epoch'] = epochs
    batch_pcm = {k: torch.tensor(v).to(device) if torchMode else v for k, v in compute_param_confusion_matrices(bcc_params).items()}
    # pred0_bcc = init_nn_output(dataset.n, grid_ratios, n_anchor_choices, bcc_params)
    bcc_metrics = init_metrics(bcc_params['n_epoch'])
LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
            f'Using {train_loader.num_workers} dataloader workers\n'
            f"Logging results to {colorstr('bold', save_dir)}\n"
            f'Starting training for {epochs} epochs...')

# Timing accumulators keyed by epoch and by batch index (dataset.batch holds a batch index per image).
times = defaultdict(float)
epoch_times = {epoch: defaultdict(float) for epoch in range(start_epoch, epochs)}
batch_times = {i: defaultdict(float) for i in range(1 + np.max(dataset.batch))}
epoch_batch_times = {epoch: {i: defaultdict(float) for i in batch_times} for epoch in epoch_times}
# Previous lower bound for the BCC convergence test; infinite until the first BCC epoch has run.
old_lb = float('Inf')

[34m[1mhyperparameters: [0mlr0=0.01, lrf=0.2, momentum=0.937, weight_decay=0.0005, warmup_epochs=3.0, warmup_momentum=0.8, warmup_bias_lr=0.1, box=0.05, cls=0.5, cls_pw=10.0, obj=1.0, obj_pw=1.0, iou_t=0.2, anchor_t=4.0, fl_gamma=0, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, flipud=0.0, fliplr=0.5, mosaic=1.0, mixup=0.0, copy_paste=0.0
[34m[1mTensorBoard: [0mStart with 'tensorboard --logdir runs/train', view at http://localhost:6006/
Overriding model.yaml nc=80 with nc=2

                 from  n    params  module                                  arguments                     
  0                -1  1      3520  models.common.Focus                     [3, 32, 3]                    
  1                -1  1     18560  models.common.Conv                      [32, 64, 3, 2]                
  2                -1  1     18816  models.common.C3                        [64, 64, 1]                   
  3                -1  1     7

Plotting labels... 

[34m[1mautoanchor: [0mAnalyzing anchors... anchors/target = 5.96, Best Possible Recall (BPR) = 1.0000


Image sizes 640 train, 640 val
Using 0 dataloader workers
Logging results to [1mruns/train/exp42[0m
Starting training for 10 epochs...


In [20]:
DATA = {}
for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
    epoch_key = f'epoch_{epoch}'
    DATA[epoch_key] = {}
    bcc_flag = False if bcc_epoch == -1 else (epoch - start_epoch >= bcc_epoch)
    qtfilter_flag = False if qtfilter_epoch == -1 else (epoch - start_epoch >= qtfilter_epoch)
    if not bcc_flag and qtfilter_flag:
        qtfilter_flag = False
    LBs = []
    model.train()
    # Update image weights (optional, single-GPU only)
    if opt.image_weights:
        cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
        iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
        dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx

    # Update mosaic border (optional)

    mloss = torch.zeros(3, device=device)  # mean losses
    if RANK != -1:
        train_loader.sampler.set_epoch(epoch)
    pbar = enumerate(train_loader)
    LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
    LOGGER.info('\n' + ('YOLO+BCC' if bcc_flag else 'Only YOLO'))
    if RANK in [-1, 0]:
        pbar = tqdm(pbar, total=nb)  # progress bar
    optimizer.zero_grad()

    for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
        batch_key = f'batch_{i}'
        DATA[epoch_key][batch_key] = {}
        DATA[epoch_key][batch_key]['imgs_unn'] = imgs.cpu().detach().numpy()
        DATA[epoch_key][batch_key]['targets'] = targets.cpu().detach().numpy()
        if bcc_epoch != -1:
            batch_filenames = [dataset.label_files[x].split(os.sep)[-1] for x in np.where(dataset.batch==i)[0]]
            batch_volunteers_list = [file_volunteers_dict[fn] for fn in batch_filenames]
            batch_volunteers = torch.cat(batch_volunteers_list)
            target_volunteers = torch.cat([targets, batch_volunteers.unsqueeze(-1)], axis=1)
            batch_size = np.where(dataset.batch==i)[0].shape[0]
            target_volunteers_bcc, vigcwh = convert_target_volunteers_yolo2bcc(target_volunteers, n_anchor_choices, nc, grid_ratios, batch_size, vol_id_map)
            # batch_cstargets_bcc = (cstargets_bcc[dataset.batch == i]).to(device)
            DATA[epoch_key][batch_key]['target_volunteers_bcc'] = target_volunteers_bcc.cpu().detach().numpy()
            DATA[epoch_key][batch_key]['batch_volunteers'] = batch_volunteers.cpu().detach().numpy()
            DATA[epoch_key][batch_key]['target_volunteers'] = target_volunteers.cpu().detach().numpy()
        ni = i + nb * epoch  # number integrated batches (since train start)
        DATA[epoch_key][batch_key]['imgs_unn'] = imgs.cpu().detach().numpy()
        DATA[epoch_key][batch_key]['targets'] = targets.cpu().detach().numpy()
        imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0

        # Warmup
        if ni <= nw:
            xi = [0, nw]  # x interp
            accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
            for j, x in enumerate(optimizer.param_groups):
                # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                if 'momentum' in x:
                    x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

        # Multi-scale
        if opt.multi_scale:
            sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs  # size
            sf = sz / max(imgs.shape[2:])  # scale factor
            if sf != 1:
                ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

        # Forward
        with amp.autocast(enabled=cuda):
            pred = model(imgs)
            DATA[epoch_key][batch_key]['pred'] = [temp.cpu().detach().numpy() for temp in pred]
            if bcc_flag:
                model.eval()
                batch_pred_yolo = nn_predict(model, imgs, imgsz, transform_format_flag=False) # y_hat_yolo
                DATA[epoch_key][batch_key]['batch_pred_yolo'] = batch_pred_yolo.cpu().detach().numpy()
                
                batch_pred_bcc, _, batch_conf = yolo2bcc_newer(batch_pred_yolo, imgsz, silent=False) # y_hat_bcc
                DATA[epoch_key][batch_key]['batch_pred_bcc'] = batch_pred_bcc.cpu().detach().numpy()
                
                batch_qtargets, batch_pcm['variational'], batch_lb = VBi_yolo(target_volunteers_bcc, batch_pred_bcc, batch_pcm['variational'], batch_pcm['prior'], torchMode = torchMode, device=device)
                DATA[epoch_key][batch_key]['batch_qtargets'] = batch_qtargets.cpu().detach().numpy()
                DATA[epoch_key][batch_key]['batch_pcm_var'] = batch_pcm['variational'].cpu().detach().numpy()
                DATA[epoch_key][batch_key]['batch_lb'] = batch_lb.cpu().detach().numpy()

                LBs.append(batch_lb)
                with torch.no_grad():
                    batch_qtargets_yolo = qt2yolo_optimized(batch_qtargets, grid_ratios, n_anchor_choices, vigcwh, torchMode = torchMode, device=device).half().float()
                    DATA[epoch_key][batch_key]['batch_qtargets_yolo_full'] = batch_qtargets_yolo.cpu().detach().numpy()
                    batch_qtargets_yolo = batch_qtargets_yolo[batch_qtargets_yolo[:,1] != BACKGROUND_CLASS_ID, :]
                    DATA[epoch_key][batch_key]['batch_qtargets_yolo'] = batch_qtargets_yolo.cpu().detach().numpy()
                model.train()
                loss, loss_items = compute_loss(pred, batch_qtargets_yolo)
                DATA[epoch_key][batch_key]['loss'] = loss.cpu().detach().numpy()
                DATA[epoch_key][batch_key]['loss_items'] = loss_items.cpu().detach().numpy()
            else:
                loss, loss_items = compute_loss(pred, targets.to(device))
                DATA[epoch_key][batch_key]['loss'] = loss.cpu().detach().numpy()
                DATA[epoch_key][batch_key]['loss_items'] = loss_items.cpu().detach().numpy()
            if RANK != -1:
                loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
                DATA[epoch_key][batch_key]['loss*WORLD_SIZE'] = loss.cpu().detach().numpy()
            if opt.quad:
                loss *= 4.
                DATA[epoch_key][batch_key]['loss*4'] = loss.cpu().detach().numpy()

        # Backward
        scaler.scale(loss).backward()

        # Optimize
        if ni - last_opt_step >= accumulate:
            scaler.step(optimizer)  # optimizer.step
            scaler.update()
            optimizer.zero_grad()
            if ema:
                ema.update(model)
            last_opt_step = ni

        # Log
        if RANK in [-1, 0]:
            mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
            DATA[epoch_key][batch_key]['mloss'] = mloss.cpu().detach().numpy()
            mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
            pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
                f'{epoch}/{epochs - 1}', mem, *mloss, (targets).shape[0], imgs.shape[-1]))
            callbacks.on_train_batch_end(ni, model, imgs, (targets), paths, plots, opt.sync_bn)
        # end batch ------------------------------------------------------------------------------------------------

    # Scheduler
    lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
    
    scheduler.step()

    # End-of-epoch validation, logging, checkpointing and early stopping.
    # Runs only when RANK is -1 or 0 (presumably the non-distributed run or the main DDP process — confirm).
    if RANK in [-1, 0]:
        # mAP
        callbacks.on_train_epoch_end(epoch=epoch)
        # Presumably copies the listed attributes from the live model onto the EMA wrapper
        # so the EMA model that is validated/saved below carries them — confirm against ModelEMA.
        ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
        final_epoch = epoch + 1 == epochs  # True only on the last scheduled epoch
        # Validation is skipped when `noval` is set, except on the final epoch where it always runs.
        if not noval or final_epoch:  # Calculate mAP
            # Validate the EMA weights (ema.ema) on val_loader with twice the per-process batch size.
            # JSON export, verbose output and plots are only enabled on the final epoch
            # (and additionally require is_coco / nc < 50 / plots respectively).
            results, maps, _ = val.run(data_dict,
                                       batch_size=batch_size // WORLD_SIZE * 2,
                                       imgsz=imgsz,
                                       model=ema.ema,
                                       single_cls=single_cls,
                                       dataloader=val_loader,
                                       save_dir=save_dir,
                                       save_json=is_coco and final_epoch,
                                       verbose=nc < 50 and final_epoch,
                                       plots=plots and final_epoch,
                                       callbacks=callbacks,
                                       compute_loss=compute_loss)
            # yhat_train = pred
            # y_train = dataset.labels
            # yhat_test = results
            # y_test = val_dataset.labels
            # update_bcc_metrics(bcc_metrics, qtargets, yhat_train, y_train, yhat_test, y_test, epoch)
            # Record this epoch's validation results in the run-level DATA store.
            # NOTE(review): `epoch_key` / `batch_key` are set earlier in the loop (outside this view);
            # `batch_key` is presumably whatever the last processed batch left behind — confirm.
            DATA[epoch_key][batch_key]['results'] = results
        # Update best mAP
        # NOTE(review): when validation was skipped above, `results` still holds its previous value,
        # so it must be initialised before the epoch loop — confirm.
        fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
        if fi > best_fitness:
            best_fitness = fi
        # Values handed to the loggers: mean training losses, validation results, then learning rates.
        log_vals = list(mloss) + list(results) + lr
        callbacks.on_fit_epoch_end(log_vals, epoch, best_fitness, fi)

        # Save model
        # A checkpoint is written every epoch unless `nosave` is set; with `nosave`, it is still
        # written on the final epoch as long as this is not an `evolve` run.
        if (not nosave) or (final_epoch and not evolve):  # if save
            # Deep copies are converted to half precision, leaving the live model/EMA untouched.
            ckpt = {'epoch': epoch,
                    'best_fitness': best_fitness,
                    'model': deepcopy(de_parallel(model)).half(),
                    'ema': deepcopy(ema.ema).half(),
                    'updates': ema.updates,
                    'optimizer': optimizer.state_dict(),
                    'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None}

            # Save last, best and delete
            torch.save(ckpt, last)
            # Equality holds when this epoch set (or tied) the best fitness so far.
            if best_fitness == fi:
                torch.save(ckpt, best)
            del ckpt  # drop the extra model copies right away
            callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)

        # Stop Single-GPU
        # Leave the enclosing loop as soon as the stopper reports that training should end.
        if stopper(epoch=epoch, fitness=fi):
            break
    # Convergence-based early stop (runs on every rank, outside the RANK guard above):
    # compare the most recent entry of `LBs` with the value remembered from the previous epoch and
    # stop once the relative change drops below bcc_params['convergence_threshold'].
    # `LBs` presumably holds the BCC variational lower-bound values — confirm.
    try:
        lb = LBs[-1]  # raises IndexError while LBs is still empty
        # NOTE(review): `old_lb` must already be bound the first time the right-hand side is evaluated
        # (epoch > start_epoch with a non-empty LBs); only IndexError is caught below — confirm it is
        # initialised before the loop. torch.abs also requires the ratio to be a tensor.
        if epoch > start_epoch and torch.abs((lb - old_lb) / old_lb) < bcc_params['convergence_threshold']:
            print('Convergence reached!')
            break
        old_lb = lb  # remembered for the next epoch's comparison
    except IndexError:
        # No lower-bound value recorded yet; nothing to compare.
        pass

    # Best-effort cleanup: unbind the per-batch prediction/target names so the objects they
    # reference can be released before the next epoch.
    # Note: `del a, b, ...` removes names left to right and aborts at the first one that is not
    # bound, so names listed after a missing one stay bound for this epoch.
    try:

        del batch_pred_yolo, batch_pred_bcc, batch_pred_yolo_wh, batch_qtargets, batch_qtargets_yolo
    except (UnboundLocalError, NameError):  # UnboundLocalError is a subclass of NameError
        # Some of these names were never assigned in this epoch; ignore.
        pass


#     torch.cuda.empty_cache()

    # end epoch ----------------------------------------------------------------------------------------------------

# end training -----------------------------------------------------------------------------------------------------


     Epoch   gpu_mem       box       obj       cls    labels  img_size

Only YOLO
       0/9        0G    0.1308   0.03327    0.2097        26       640: 100%|██████████| 1/1 [00:19<00:00, 19.41s/it]
               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|██████████| 1/1 [00:01<00:00,  1.74s/it]


                 all          2          9    0.00343      0.081     0.0014    0.00014



     Epoch   gpu_mem       box       obj       cls    labels  img_size

Only YOLO
       1/9        0G    0.1306   0.03325    0.2048        26       640: 100%|██████████| 1/1 [00:07<00:00,  7.42s/it]
               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|██████████| 1/1 [00:01<00:00,  1.38s/it]


                 all          2          9    0.00281      0.125    0.00136   0.000136



     Epoch   gpu_mem       box       obj       cls    labels  img_size

Only YOLO
       2/9        0G    0.1303   0.03325    0.1965        26       640: 100%|██████████| 1/1 [00:07<00:00,  7.43s/it]
               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|██████████| 1/1 [00:01<00:00,  1.45s/it]


                 all          2          9    0.00568      0.125    0.00341   0.000732



     Epoch   gpu_mem       box       obj       cls    labels  img_size

Only YOLO
       3/9        0G    0.1301   0.03327    0.1862        26       640: 100%|██████████| 1/1 [00:07<00:00,  7.44s/it]
               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|██████████| 1/1 [00:01<00:00,  1.43s/it]


                 all          2          9    0.00505      0.125    0.00346   0.000755



     Epoch   gpu_mem       box       obj       cls    labels  img_size

Only YOLO
       4/9        0G    0.1298    0.0333    0.1766        26       640: 100%|██████████| 1/1 [00:07<00:00,  7.46s/it]
               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|██████████| 1/1 [00:01<00:00,  1.40s/it]


                 all          2          9    0.00514      0.225    0.00361   0.000789



     Epoch   gpu_mem       box       obj       cls    labels  img_size

YOLO+BCC
  0%|          | 0/1 [00:00<?, ?it/s]

Minimum probs (c1, c2, bkgd): [3e-05, 2e-05, 0.925257]
Maximum probs (c1, c2, bkgd): [0.034463, 0.040424, 0.999858]


       5/9        0G         0   0.02288         0        26       640: 100%|██████████| 1/1 [00:10<00:00, 10.03s/it]
               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|██████████| 1/1 [00:01<00:00,  1.36s/it]


                 all          2          9    0.00501      0.225    0.00359   0.000821



     Epoch   gpu_mem       box       obj       cls    labels  img_size

YOLO+BCC
  0%|          | 0/1 [00:00<?, ?it/s]

Minimum probs (c1, c2, bkgd): [3.1e-05, 1.8e-05, 0.926873]
Maximum probs (c1, c2, bkgd): [0.033895, 0.03944, 0.999868]


       6/9        0G         0   0.02265         0        26       640: 100%|██████████| 1/1 [00:10<00:00, 10.07s/it]
               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|██████████| 1/1 [00:01<00:00,  1.34s/it]


                 all          2          9    0.00521      0.181    0.00365   0.000866



     Epoch   gpu_mem       box       obj       cls    labels  img_size

YOLO+BCC
  0%|          | 0/1 [00:00<?, ?it/s]

Minimum probs (c1, c2, bkgd): [3.2e-05, 1.6e-05, 0.928628]
Maximum probs (c1, c2, bkgd): [0.033494, 0.038802, 0.999876]


       7/9        0G    0.1345   0.02278    0.1918        26       640: 100%|██████████| 1/1 [00:09<00:00,  9.97s/it]
               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|██████████| 1/1 [00:01<00:00,  1.39s/it]


                 all          2          9    0.00697      0.225    0.00393   0.000963



     Epoch   gpu_mem       box       obj       cls    labels  img_size

YOLO+BCC
  0%|          | 0/1 [00:00<?, ?it/s]

Minimum probs (c1, c2, bkgd): [3.4e-05, 1.5e-05, 0.929678]
Maximum probs (c1, c2, bkgd): [0.033186, 0.038349, 0.99988]


       8/9        0G    0.1286   0.02901    0.2172        26       640: 100%|██████████| 1/1 [00:10<00:00, 10.89s/it]
               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|██████████| 1/1 [00:01<00:00,  1.30s/it]


                 all          2          9    0.00667      0.225    0.00409    0.00102



     Epoch   gpu_mem       box       obj       cls    labels  img_size

YOLO+BCC
  0%|          | 0/1 [00:00<?, ?it/s]

Minimum probs (c1, c2, bkgd): [3.4e-05, 1.2e-05, 0.932384]
Maximum probs (c1, c2, bkgd): [0.032546, 0.037451, 0.999885]


       9/9        0G    0.1295   0.03476    0.2165        26       640: 100%|██████████| 1/1 [00:10<00:00, 10.30s/it]
               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|██████████| 1/1 [00:01<00:00,  1.35s/it]


                 all          2          9     0.0116       0.35    0.00975    0.00143
           bone-loss          2          5     0.0125        0.2    0.00469   0.000937
       dental-caries          2          4     0.0107        0.5     0.0148    0.00192


In [21]:
# Keep the DATA collected by the run above under a name specific to its training regime, so that
# results from several runs in the same kernel can be compared side by side:
#   bcc_epoch == -1          -> YOLO_DATA
#   bcc_epoch == 0           -> YOLOBCC_DATA
#   any other bcc_epoch      -> YOLO_PT_BCC_DATA
# (presumably: YOLO only / BCC from the first epoch / YOLO pre-training followed by BCC — confirm)
if bcc_epoch not in (-1, 0):
    YOLO_PT_BCC_DATA = DATA
elif bcc_epoch == 0:
    YOLOBCC_DATA = DATA
else:
    YOLO_DATA = DATA

In [25]:
# Sanity check: number of top-level entries stored for the run stashed as YOLO_PT_BCC_DATA.
# NOTE(review): DATA is indexed as DATA[epoch_key][batch_key] during training, so this is presumably
# the number of epochs recorded — confirm. The name only exists if the stash cell took its final
# `else` branch in this kernel session; otherwise this raises NameError.
len(YOLO_PT_BCC_DATA)

10

In [26]:
import pickle
# Optional snapshot of the three stashed runs (disabled; uncomment to write the file).
# Requires YOLOBCC_DATA, YOLO_DATA and YOLO_PT_BCC_DATA to all exist, i.e. the training cell and
# the stash cell must have been run once per bcc_epoch setting in this kernel session.
# with open('SANITY_CHECK_DATA.pkl', 'wb') as f:
#     pickle.dump((YOLOBCC_DATA, YOLO_DATA, YOLO_PT_BCC_DATA), f)