In [1]:
import yaml
import os
os.environ['CUDA_VISIBLE_DEVICES']='0'
import torch
import torch.nn as nn
import sys
sys.path.append('/data/ouyuan/yolov5')
from tqdm import tqdm
from utils.dataloaders import create_dataloader
from moe import ModelSuffix, ModelName
from pathlib import Path
import copy
import numpy as np
import random
from utils.general import (
    LOGGER,
    TQDM_BAR_FORMAT,
    Profile,
    check_amp,
    check_dataset,
    check_file,
    check_git_info,
    check_git_status,
    check_img_size,
    check_requirements,
    check_suffix,
    check_yaml,
    colorstr,
    get_latest_run,
    increment_path,
    init_seeds,
    intersect_dicts,
    labels_to_class_weights,
    labels_to_image_weights,
    methods,
    one_cycle,
    print_args,
    print_mutation,
    strip_optimizer,
    yaml_save,
    increment_path,
    non_max_suppression,
    scale_boxes,
    xywh2xyxy,
    xyxy2xywh,
)
from utils.torch_utils import (
    EarlyStopping,
    ModelEMA,
    de_parallel,
    select_device,
    smart_DDP,
    smart_optimizer,
    smart_resume,
    torch_distributed_zero_first,
)
from utils.loss import ComputeLoss
from utils.metrics import ConfusionMatrix, ap_per_class, box_iou
from utils.plots import output_to_target, plot_images, plot_val_study

In [2]:
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))
RANK = int(os.getenv("RANK", -1))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
torch.backends.cudnn.enabled = False
data = 'cityscapes.yaml'
imgsz = 640
batch_size = 8
gs = 32
workers = 8
label_smoothing = 0.0
seed = 19260817
single_cls = False
epochs = 10
nc = 80
data = check_yaml(data)
hyp = '../data/hyps/hyp.no-augmentation.yaml'
with open(hyp, errors="ignore") as f:
    hyp = yaml.safe_load(f)
data_dict = check_dataset(data)
train_path, val_path = data_dict["train"], data_dict["val"]
train_loader, dataset = create_dataloader(
        train_path,
        imgsz,
        batch_size,
        gs,
        single_cls,
        hyp=hyp,
        rank=LOCAL_RANK,
        prefix=colorstr("train: "),
        shuffle=True,
        seed=seed,
    )

device = 'cuda' if torch.cuda.is_available() else 'cpu'

[34m[1mtrain: [0mScanning /data/ouyuan/datasets/cityscapes/train.cache... 2975 images, 7 backgrounds, 0 corrupt: 100%|██████████| 2975/2975 [00:00<?, ?it/s]


In [3]:
experts_list = []
optimizer= []
optimizer_type = 'Adam'
for i in range(len(ModelSuffix)):
        model_name = str.replace(ModelName, 'tmp', ModelSuffix[i])
        #model = Model(os.path.join('../models', model_name + '.yaml'), ch=3, nc=80)
        print(f'Adding expert {model_name}')
        model = torch.load(os.path.join('..', model_name + '.pt'), map_location="cpu")["model"].float()
        experts_list.append(model)
        freeze = 10 if i < 5 else 12
        freeze = [f'model.{x}.' for x in range(freeze)]  # layers to freeze
        for k, v in experts_list[i].named_parameters():
                v.requires_grad = True  # train all layers
                if any(x in k for x in freeze):
                        print(f'freezing {model_name}\'s {k}')
                        v.requires_grad = False
                
        optimizer.append(smart_optimizer(experts_list[i], optimizer_type, hyp["lr0"], hyp["momentum"], hyp["weight_decay"]))
        experts_list[i].to(device)
        
nl = de_parallel(experts_list[0]).model[-1].nl  # number of detection layers (to scale hyps)
hyp["box"] *= 3 / nl  # scale to layers
hyp["cls"] *= nc / 80 * 3 / nl  # scale to classes and layers
hyp["obj"] *= (imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
hyp["label_smoothing"] = label_smoothing
experts_list[0].nc = nc  # attach number of classes to model
experts_list[0].hyp = hyp  # attach hyperparameters to model
compute_loss1 = ComputeLoss(experts_list[0])

if len(experts_list) > 5:
        experts_list[5].nc = nc  # attach number of classes to model
        experts_list[5].hyp = hyp  # attach hyperparameters to model
        compute_loss2 = ComputeLoss(experts_list[5])
        
for i in range(len(experts_list)):
        experts_list[i].cpu()

[34m[1moptimizer:[0m Adam(lr=0.001) with parameter groups 57 weight(decay=0.0), 60 weight(decay=0.0005), 60 bias


Adding expert yolov5n
freezing yolov5n's model.0.conv.weight
freezing yolov5n's model.0.bn.weight
freezing yolov5n's model.0.bn.bias
freezing yolov5n's model.1.conv.weight
freezing yolov5n's model.1.bn.weight
freezing yolov5n's model.1.bn.bias
freezing yolov5n's model.2.cv1.conv.weight
freezing yolov5n's model.2.cv1.bn.weight
freezing yolov5n's model.2.cv1.bn.bias
freezing yolov5n's model.2.cv2.conv.weight
freezing yolov5n's model.2.cv2.bn.weight
freezing yolov5n's model.2.cv2.bn.bias
freezing yolov5n's model.2.cv3.conv.weight
freezing yolov5n's model.2.cv3.bn.weight
freezing yolov5n's model.2.cv3.bn.bias
freezing yolov5n's model.2.m.0.cv1.conv.weight
freezing yolov5n's model.2.m.0.cv1.bn.weight
freezing yolov5n's model.2.m.0.cv1.bn.bias
freezing yolov5n's model.2.m.0.cv2.conv.weight
freezing yolov5n's model.2.m.0.cv2.bn.weight
freezing yolov5n's model.2.m.0.cv2.bn.bias
freezing yolov5n's model.3.conv.weight
freezing yolov5n's model.3.bn.weight
freezing yolov5n's model.3.bn.bias
freezi

[34m[1moptimizer:[0m Adam(lr=0.001) with parameter groups 57 weight(decay=0.0), 60 weight(decay=0.0005), 60 bias


Adding expert yolov5s
freezing yolov5s's model.0.conv.weight
freezing yolov5s's model.0.bn.weight
freezing yolov5s's model.0.bn.bias
freezing yolov5s's model.1.conv.weight
freezing yolov5s's model.1.bn.weight
freezing yolov5s's model.1.bn.bias
freezing yolov5s's model.2.cv1.conv.weight
freezing yolov5s's model.2.cv1.bn.weight
freezing yolov5s's model.2.cv1.bn.bias
freezing yolov5s's model.2.cv2.conv.weight
freezing yolov5s's model.2.cv2.bn.weight
freezing yolov5s's model.2.cv2.bn.bias
freezing yolov5s's model.2.cv3.conv.weight
freezing yolov5s's model.2.cv3.bn.weight
freezing yolov5s's model.2.cv3.bn.bias
freezing yolov5s's model.2.m.0.cv1.conv.weight
freezing yolov5s's model.2.m.0.cv1.bn.weight
freezing yolov5s's model.2.m.0.cv1.bn.bias
freezing yolov5s's model.2.m.0.cv2.conv.weight
freezing yolov5s's model.2.m.0.cv2.bn.weight
freezing yolov5s's model.2.m.0.cv2.bn.bias
freezing yolov5s's model.3.conv.weight
freezing yolov5s's model.3.bn.weight
freezing yolov5s's model.3.bn.bias
freezi

[34m[1moptimizer:[0m Adam(lr=0.001) with parameter groups 79 weight(decay=0.0), 82 weight(decay=0.0005), 82 bias
[34m[1moptimizer:[0m Adam(lr=0.001) with parameter groups 101 weight(decay=0.0), 104 weight(decay=0.0005), 104 bias


freezing yolov5m's model.0.conv.weight
freezing yolov5m's model.0.bn.weight
freezing yolov5m's model.0.bn.bias
freezing yolov5m's model.1.conv.weight
freezing yolov5m's model.1.bn.weight
freezing yolov5m's model.1.bn.bias
freezing yolov5m's model.2.cv1.conv.weight
freezing yolov5m's model.2.cv1.bn.weight
freezing yolov5m's model.2.cv1.bn.bias
freezing yolov5m's model.2.cv2.conv.weight
freezing yolov5m's model.2.cv2.bn.weight
freezing yolov5m's model.2.cv2.bn.bias
freezing yolov5m's model.2.cv3.conv.weight
freezing yolov5m's model.2.cv3.bn.weight
freezing yolov5m's model.2.cv3.bn.bias
freezing yolov5m's model.2.m.0.cv1.conv.weight
freezing yolov5m's model.2.m.0.cv1.bn.weight
freezing yolov5m's model.2.m.0.cv1.bn.bias
freezing yolov5m's model.2.m.0.cv2.conv.weight
freezing yolov5m's model.2.m.0.cv2.bn.weight
freezing yolov5m's model.2.m.0.cv2.bn.bias
freezing yolov5m's model.2.m.1.cv1.conv.weight
freezing yolov5m's model.2.m.1.cv1.bn.weight
freezing yolov5m's model.2.m.1.cv1.bn.bias
free

[34m[1moptimizer:[0m Adam(lr=0.001) with parameter groups 123 weight(decay=0.0), 126 weight(decay=0.0005), 126 bias


freezing yolov5x's model.0.conv.weight
freezing yolov5x's model.0.bn.weight
freezing yolov5x's model.0.bn.bias
freezing yolov5x's model.1.conv.weight
freezing yolov5x's model.1.bn.weight
freezing yolov5x's model.1.bn.bias
freezing yolov5x's model.2.cv1.conv.weight
freezing yolov5x's model.2.cv1.bn.weight
freezing yolov5x's model.2.cv1.bn.bias
freezing yolov5x's model.2.cv2.conv.weight
freezing yolov5x's model.2.cv2.bn.weight
freezing yolov5x's model.2.cv2.bn.bias
freezing yolov5x's model.2.cv3.conv.weight
freezing yolov5x's model.2.cv3.bn.weight
freezing yolov5x's model.2.cv3.bn.bias
freezing yolov5x's model.2.m.0.cv1.conv.weight
freezing yolov5x's model.2.m.0.cv1.bn.weight
freezing yolov5x's model.2.m.0.cv1.bn.bias
freezing yolov5x's model.2.m.0.cv2.conv.weight
freezing yolov5x's model.2.m.0.cv2.bn.weight
freezing yolov5x's model.2.m.0.cv2.bn.bias
freezing yolov5x's model.2.m.1.cv1.conv.weight
freezing yolov5x's model.2.m.1.cv1.bn.weight
freezing yolov5x's model.2.m.1.cv1.bn.bias
free

[34m[1moptimizer:[0m Adam(lr=0.001) with parameter groups 75 weight(decay=0.0), 79 weight(decay=0.0005), 79 bias
[34m[1moptimizer:[0m Adam(lr=0.001) with parameter groups 75 weight(decay=0.0), 79 weight(decay=0.0005), 79 bias


freezing yolov5n6's model.0.conv.weight
freezing yolov5n6's model.0.bn.weight
freezing yolov5n6's model.0.bn.bias
freezing yolov5n6's model.1.conv.weight
freezing yolov5n6's model.1.bn.weight
freezing yolov5n6's model.1.bn.bias
freezing yolov5n6's model.2.cv1.conv.weight
freezing yolov5n6's model.2.cv1.bn.weight
freezing yolov5n6's model.2.cv1.bn.bias
freezing yolov5n6's model.2.cv2.conv.weight
freezing yolov5n6's model.2.cv2.bn.weight
freezing yolov5n6's model.2.cv2.bn.bias
freezing yolov5n6's model.2.cv3.conv.weight
freezing yolov5n6's model.2.cv3.bn.weight
freezing yolov5n6's model.2.cv3.bn.bias
freezing yolov5n6's model.2.m.0.cv1.conv.weight
freezing yolov5n6's model.2.m.0.cv1.bn.weight
freezing yolov5n6's model.2.m.0.cv1.bn.bias
freezing yolov5n6's model.2.m.0.cv2.conv.weight
freezing yolov5n6's model.2.m.0.cv2.bn.weight
freezing yolov5n6's model.2.m.0.cv2.bn.bias
freezing yolov5n6's model.3.conv.weight
freezing yolov5n6's model.3.bn.weight
freezing yolov5n6's model.3.bn.bias
free

[34m[1moptimizer:[0m Adam(lr=0.001) with parameter groups 103 weight(decay=0.0), 107 weight(decay=0.0005), 107 bias


freezing yolov5m6's model.0.conv.weight
freezing yolov5m6's model.0.bn.weight
freezing yolov5m6's model.0.bn.bias
freezing yolov5m6's model.1.conv.weight
freezing yolov5m6's model.1.bn.weight
freezing yolov5m6's model.1.bn.bias
freezing yolov5m6's model.2.cv1.conv.weight
freezing yolov5m6's model.2.cv1.bn.weight
freezing yolov5m6's model.2.cv1.bn.bias
freezing yolov5m6's model.2.cv2.conv.weight
freezing yolov5m6's model.2.cv2.bn.weight
freezing yolov5m6's model.2.cv2.bn.bias
freezing yolov5m6's model.2.cv3.conv.weight
freezing yolov5m6's model.2.cv3.bn.weight
freezing yolov5m6's model.2.cv3.bn.bias
freezing yolov5m6's model.2.m.0.cv1.conv.weight
freezing yolov5m6's model.2.m.0.cv1.bn.weight
freezing yolov5m6's model.2.m.0.cv1.bn.bias
freezing yolov5m6's model.2.m.0.cv2.conv.weight
freezing yolov5m6's model.2.m.0.cv2.bn.weight
freezing yolov5m6's model.2.m.0.cv2.bn.bias
freezing yolov5m6's model.2.m.1.cv1.conv.weight
freezing yolov5m6's model.2.m.1.cv1.bn.weight
freezing yolov5m6's mode

[34m[1moptimizer:[0m Adam(lr=0.001) with parameter groups 131 weight(decay=0.0), 135 weight(decay=0.0005), 135 bias


freezing yolov5l6's model.0.conv.weight
freezing yolov5l6's model.0.bn.weight
freezing yolov5l6's model.0.bn.bias
freezing yolov5l6's model.1.conv.weight
freezing yolov5l6's model.1.bn.weight
freezing yolov5l6's model.1.bn.bias
freezing yolov5l6's model.2.cv1.conv.weight
freezing yolov5l6's model.2.cv1.bn.weight
freezing yolov5l6's model.2.cv1.bn.bias
freezing yolov5l6's model.2.cv2.conv.weight
freezing yolov5l6's model.2.cv2.bn.weight
freezing yolov5l6's model.2.cv2.bn.bias
freezing yolov5l6's model.2.cv3.conv.weight
freezing yolov5l6's model.2.cv3.bn.weight
freezing yolov5l6's model.2.cv3.bn.bias
freezing yolov5l6's model.2.m.0.cv1.conv.weight
freezing yolov5l6's model.2.m.0.cv1.bn.weight
freezing yolov5l6's model.2.m.0.cv1.bn.bias
freezing yolov5l6's model.2.m.0.cv2.conv.weight
freezing yolov5l6's model.2.m.0.cv2.bn.weight
freezing yolov5l6's model.2.m.0.cv2.bn.bias
freezing yolov5l6's model.2.m.1.cv1.conv.weight
freezing yolov5l6's model.2.m.1.cv1.bn.weight
freezing yolov5l6's mode

[34m[1moptimizer:[0m Adam(lr=0.001) with parameter groups 159 weight(decay=0.0), 163 weight(decay=0.0005), 163 bias


freezing yolov5x6's model.0.conv.weight
freezing yolov5x6's model.0.bn.weight
freezing yolov5x6's model.0.bn.bias
freezing yolov5x6's model.1.conv.weight
freezing yolov5x6's model.1.bn.weight
freezing yolov5x6's model.1.bn.bias
freezing yolov5x6's model.2.cv1.conv.weight
freezing yolov5x6's model.2.cv1.bn.weight
freezing yolov5x6's model.2.cv1.bn.bias
freezing yolov5x6's model.2.cv2.conv.weight
freezing yolov5x6's model.2.cv2.bn.weight
freezing yolov5x6's model.2.cv2.bn.bias
freezing yolov5x6's model.2.cv3.conv.weight
freezing yolov5x6's model.2.cv3.bn.weight
freezing yolov5x6's model.2.cv3.bn.bias
freezing yolov5x6's model.2.m.0.cv1.conv.weight
freezing yolov5x6's model.2.m.0.cv1.bn.weight
freezing yolov5x6's model.2.m.0.cv1.bn.bias
freezing yolov5x6's model.2.m.0.cv2.conv.weight
freezing yolov5x6's model.2.m.0.cv2.bn.weight
freezing yolov5x6's model.2.m.0.cv2.bn.bias
freezing yolov5x6's model.2.m.1.cv1.conv.weight
freezing yolov5x6's model.2.m.1.cv1.bn.weight
freezing yolov5x6's mode

In [4]:
conf_thres = 0.001
iou_thres = 0.6
max_det = 300
iouv = torch.linspace(0.5, 0.95, 10, device=device)  # iou vector for mAP@0.5:0.95
niou = iouv.numel()
seen = 0
plots = True
save_dir = Path('./exp_tmp')
save_dir.mkdir(exist_ok=True)
confusion_matrix = ConfusionMatrix(nc=nc)
names = experts_list[0].names if hasattr(experts_list[0], "names") else experts_list[0].module.names  # get class names
if isinstance(names, (list, tuple)):  # old format
    names = dict(enumerate(names))
task = 'val'
verbose = False
training = False

dt = Profile(device=device), Profile(device=device), Profile(device=device)  # profiling times

In [5]:
experts = []
experts_id = []
pbar = enumerate(train_loader)
LOGGER.info(("\n" + "%11s" * 7) % ("Epoch", "GPU_mem", "box_loss", "obj_loss", "cls_loss", "Instances", "Size"))
if RANK in {-1, 0}:
    pbar = tqdm(pbar, total=len(train_loader), bar_format=TQDM_BAR_FORMAT)  # progress bar

mloss = torch.zeros(10, 3, device=device)  # mean losses
cnt = torch.zeros(10, device=device)  # mean losses

for batch_i, (imgs, targets, paths, _) in pbar:  # window -------------------------------------------------------------
    id = random.randint(0, len(experts_list) - 1)
    experts_id.append(id)
    experts.append(copy.deepcopy(experts_list[id]).cpu())
    torch.cuda.empty_cache()
    experts_list[id].to(device)
    experts_list[id].train()
    
    if batch_i == len(train_loader) - 1:
        continue
    
    imgs = imgs.to(device, non_blocking=True).float() / 255  # uint8 to float32, 0-255 to 0.0-1.0
    targets = targets.to(device)
    
    optimizer[id].zero_grad()
    for epoch in range(epochs):  # retrain ------------------------------------------------------------------
        # Update mosaic border (optional)
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        
        # if RANK != -1:
        #     train_loader.sampler.set_epoch(epoch)
        
        # Forward
        #with torch.cuda.amp.autocast(amp):
        pred = experts_list[id](imgs)  # forward
        # print(len(pred))
        if id < 5:
            loss, loss_items = compute_loss1(pred, targets)  # loss scaled by batch_size
        else:
            loss, loss_items = compute_loss2(pred, targets)  # loss scaled by batch_size
            
        if RANK != -1:
            loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
        
        if torch.any(torch.isnan(loss_items)):
            torch.cuda.empty_cache()
            nb, _, height, width = imgs.shape
            experts_list[id].eval()
            preds, train_out = experts_list[id](imgs)
            preds = preds.detach()
            targets[:, 2:] *= torch.tensor((width, height, width, height), device=device)  # to pixels
            lb = [] # for autolabelling
            preds = non_max_suppression(
                preds, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls, max_det=max_det
            )
            plot_images(imgs, targets, paths, save_dir / f"val_batch{batch_i}_id{id}_labels.jpg", names)  # labels
            plot_images(imgs, output_to_target(preds), paths, save_dir / f"val_batch{batch_i}_id{id}_pred.jpg", names)  # pred
            assert False
            
        
        #print(loss, latency)
        # Backward
        #scaler.scale(loss).backward()
        optimizer[id].zero_grad()
        loss.backward()
        optimizer[id].step()
        # Log
        if RANK in {-1, 0}:
            mloss[id] = (mloss[id] * cnt[id] + loss_items) / (cnt[id] + 1)  # update mean losses
            mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"  # (GB)
            pbar.set_description(
                ("%11s" * 2 + "%11.4g" * 5)
                % (f"{epoch}/{epochs - 1}", mem, *mloss[id], targets.shape[0], imgs.shape[-1])
            )
            cnt[id] += 1
        # end retrain ------------------------------------------------------------------------------------------------
    experts_list[id].cpu()
    imgs, targets = imgs.cpu(), targets.cpu()
    # end window ----------------------------------------------------------------------------------------------------
# end training -----------------------------------------------------------------------------------------------------


      Epoch    GPU_mem   box_loss   obj_loss   cls_loss  Instances       Size
        9/9       3.1G    0.05829    0.03521   0.007979        152        640: 100%|██████████| 372/372 [17:18<00:00,  2.79s/it]


In [6]:
print(len(experts))
for i in range(len(experts_list)):
    experts_list[i].cpu()
torch.cuda.empty_cache()
save_tmp = './exp_tmp'
save_dir = './exp_1'
cnt = 0
while(1):
    cnt += 1
    save_dir = save_tmp.replace('tmp', f'{cnt}')
    # print(f'Testing {save_dir}')
    if not os.path.exists(save_dir):
        print(f'Saving experts to {save_dir}')
        os.makedirs(save_dir)
        break

for i, expert in enumerate(experts):
    expert_dir = os.path.join(save_dir, 'random')
    if not os.path.exists(expert_dir):
        os.makedirs(expert_dir)
    torch.save(expert, os.path.join(expert_dir, 'expert_%d_id_%d.ckpt'%(i, experts_id[i])))

372
Saving experts to ./exp_2


In [7]:
def process_batch(detections, labels, iouv):
    """
    Return correct prediction matrix.

    Arguments:
        detections (array[N, 6]), x1, y1, x2, y2, conf, class
        labels (array[M, 5]), class, x1, y1, x2, y2
    Returns:
        correct (array[N, 10]), for 10 IoU levels
    """
    correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool)
    iou = box_iou(labels[:, 1:], detections[:, :4])
    correct_class = labels[:, 0:1] == detections[:, 5]
    for i in range(len(iouv)):
        x = torch.where((iou >= iouv[i]) & correct_class)  # IoU > threshold and classes match
        if x[0].shape[0]:
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()  # [label, detect, iou]
            if x[0].shape[0] > 1:
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                # matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
            correct[matches[:, 1].astype(int), i] = True
    return torch.tensor(correct, dtype=torch.bool, device=iouv.device)

In [8]:
jdict, stats, ap, ap_class = [], [], [], []
save_dir = Path(save_dir)
loss = torch.zeros(3, device=device)
pbar = enumerate(train_loader)
LOGGER.info(("\n" + "%11s" * 7) % ("Epoch", "GPU_mem", "box_loss", "obj_loss", "cls_loss", "Instances", "Size"))
if RANK in {-1, 0}:
    pbar = tqdm(pbar, total=len(train_loader), bar_format=TQDM_BAR_FORMAT)  # progress bar
    
for batch_i, (im, targets, paths, shapes) in pbar:
    experts[batch_i].to(device)
    experts[batch_i].eval()
    with dt[0]:
        im = im.to(device, non_blocking=True)
        targets = targets.to(device)
        im = im.float()  # uint8 to fp16/32
        im /= 255  # 0 - 255 to 0.0 - 1.0
        bs, _, height, width = im.shape  # batch size, channels, height, width

    with dt[1]:
        preds, train_out = experts[batch_i](im)
    preds = preds.detach()

    id = experts_id[batch_i]
    # Loss
    if id < 5:
        loss += compute_loss1(train_out, targets)[1]  # box, obj, cls
    else:
        loss += compute_loss2(train_out, targets)[1]  # box, obj, cls

    # NMS
    targets[:, 2:] *= torch.tensor((width, height, width, height), device=device)  # to pixels
    lb = [] # for autolabelling

    with dt[2]:
        preds = non_max_suppression(
            preds, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls, max_det=max_det
        )

    # Metrics
    for si, pred in enumerate(preds):
        labels = targets[targets[:, 0] == si, 1:]
        nl, npr = labels.shape[0], pred.shape[0]  # number of labels, predictions
        path, shape = Path(paths[si]), shapes[si][0]
        correct = torch.zeros(npr, niou, dtype=torch.bool, device=device)  # init
        seen += 1

        if npr == 0:
            if nl:
                stats.append((correct, *torch.zeros((2, 0), device=device), labels[:, 0]))
                if plots:
                    confusion_matrix.process_batch(detections=None, labels=labels[:, 0])
            continue

        # Predictions
        if single_cls:
            pred[:, 5] = 0
        predn = pred.clone()
        scale_boxes(im[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred

        # Evaluate
        if nl:
            tbox = xywh2xyxy(labels[:, 1:5])  # target boxes
            scale_boxes(im[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
            labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
            correct = process_batch(predn, labelsn, iouv)
            if plots:
                confusion_matrix.process_batch(predn, labelsn)
        stats.append((correct, pred[:, 4], pred[:, 5], labels[:, 0]))  # (correct, conf, pcls, tcls)


    # Plot images
    if plots and batch_i < 3:
        plot_images(im, targets, paths, save_dir / f"val_batch{batch_i}_labels.jpg", names)  # labels
        plot_images(im, output_to_target(preds), paths, save_dir / f"val_batch{batch_i}_pred.jpg", names)  # pred
        
    experts[batch_i].cpu()

# Compute metrics
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]  # to numpy
if len(stats) and stats[0].any():
    tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
    ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
    mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
nt = np.bincount(stats[3].astype(int), minlength=nc)  # number of targets per class

# Print results
pf = "%22s" + "%11i" * 2 + "%11.3g" * 4  # print format
LOGGER.info(pf % ("all", seen, nt.sum(), mp, mr, map50, map))
if nt.sum() == 0:
    LOGGER.warning(f"WARNING ⚠️ no labels found in {task} set, can not compute metrics without labels")

# Print results per class
if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats):
    for i, c in enumerate(ap_class):
        LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))

# Print speeds
t = tuple(x.t / seen * 1e3 for x in dt)  # speeds per image
if not training:
    shape = (batch_size, 3, imgsz, imgsz)
    LOGGER.info(f"Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}" % t)

# Plots
if plots:
    confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))

# Return results
if not training:
    s = ""
    LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
maps = np.zeros(nc) + map
for i, c in enumerate(ap_class):
    maps[c] = ap[i]


      Epoch    GPU_mem   box_loss   obj_loss   cls_loss  Instances       Size
100%|██████████| 372/372 [02:06<00:00,  2.93it/s]
                   all       2975      60488      0.111      0.129     0.0579     0.0214
Speed: 0.1ms pre-process, 15.1ms inference, 1.8ms NMS per image at shape (8, 3, 640, 640)
Results saved to [1mexp_2[0m


In [9]:
print((mp, mr, map50, map, *(loss.cpu() / len(train_loader)).tolist()), maps, t)

(0.11072238504389358, 0.1292027821863751, 0.057901606793976775, 0.021387073665842252, nan, nan, 74299613184.0) [   0.057862   0.0082329    0.089041    0.002533    0.021387    0.003452  0.00043706   0.0011021    0.021387   0.0084368    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387
    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387
    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.021387    0.02138