Training
Gaurav14cs17 edited this page Jun 21, 2026 · 1 revision
FlashFusion trains the fusion layers (heads, necks, learned weights) while keeping the base models' backbones frozen.
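Backbone freezing is usually implemented in two steps. First, gradients are disabled on the base models' parameters. Second, the optimizer receives only the parameters that are still trainable. The sketch below assumes PyTorch; the `.pt` checkpoint files suggest it, but this page does not say so. `ToyFusionModel` and `freeze_backbones` are hypothetical illustrations and are not FlashFusion's internals.

```python
import torch
from torch import nn


class ToyFusionModel(nn.Module):
    """Hypothetical stand-in: two pretrained 'backbones' plus a trainable fusion head."""

    def __init__(self, in_dim: int = 8, feat_dim: int = 16, num_classes: int = 4):
        super().__init__()
        # Stand-ins for the base models' backbones (e.g. a detector and a classifier).
        self.det_backbone = nn.Linear(in_dim, feat_dim)
        self.cls_backbone = nn.Linear(in_dim, feat_dim)
        # Fusion layer: the only part we intend to train.
        self.fusion_head = nn.Linear(2 * feat_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = torch.cat([self.det_backbone(x), self.cls_backbone(x)], dim=-1)
        return self.fusion_head(feats)


def freeze_backbones(model: ToyFusionModel) -> None:
    """Turn off gradients for every backbone parameter."""
    for backbone in (model.det_backbone, model.cls_backbone):
        for p in backbone.parameters():
            p.requires_grad_(False)


torch.manual_seed(0)
model = ToyFusionModel()
freeze_backbones(model)

# Give the optimizer only the parameters that still require gradients.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=5e-4)  # lr mirrors learning_rate in the example config

# One training step: gradients flow only into the fusion head.
det_before = model.det_backbone.weight.detach().clone()
loss = model(torch.randn(2, 8)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()

num_trainable = sum(p.numel() for p in trainable)       # 132 (fusion head only)
num_total = sum(p.numel() for p in model.parameters())  # 420
```

Disabling gradients does not stop batch-norm layers from updating their running statistics in training mode. For that reason, frozen backbones are often also kept in `eval()` mode.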
Start training from the command line:

```bash
flashfusion train --config configs/flashfusion_det_cls_320.yaml
```

An example configuration that fuses a frozen detection model and a frozen classification model:

```yaml
# configs/flashfusion_det_cls_320.yaml
fusion:
  strategy: weighted_box_fusion
  input_size: [320, 320]
models:
  - type: detection
    source: flashdet
    model_size: m
    frozen: true
  - type: classification
    source: flashcls
    model_size: m
    frozen: true
train:
  epochs: 50
  batch_size: 16
  learning_rate: 0.0005
  scheduler: cosine
  warmup_epochs: 3
  save_dir: workspace/flashfusion_det_cls
```

The same run can be started from Python:

```python
from flashfusion import Trainer
from flashfusion.cfg import get_config

config = get_config("configs/flashfusion_det_cls_320.yaml")  # parse the YAML config
trainer = Trainer(config, device="cuda")
results = trainer.train()  # run the full training schedule
```

To resume an interrupted run, pass the latest checkpoint with `--resume`:

```bash
flashfusion train --config configs/flashfusion_det_cls_320.yaml --resume workspace/flashfusion_det_cls/last.pt
```

FlashFusion expects the following directory structure:
```text
dataset_root/
├── images/
│   ├── train/
│   │   ├── img_001.jpg
│   │   └── ...
│   └── val/
│       ├── img_100.jpg
│       └── ...
├── annotations/
│   ├── train/
│   │   ├── img_001.json
│   │   └── ...
│   └── val/
│       ├── img_100.json
│       └── ...
└── masks/            (optional)
    ├── train/
    └── val/
```
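In this layout, images and annotations can be matched by file stem (`img_001.jpg` ↔ `img_001.json`). The following is a minimal sketch that assumes stem-based matching and the listed image extensions. `pair_split` is a hypothetical helper and is not part of FlashFusion. The optional `masks/` directory is ignored because its file naming is not documented here.

```python
from pathlib import Path
from typing import List, Tuple

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}  # assumption: illustrative set of accepted extensions


def pair_split(dataset_root: str, split: str) -> List[Tuple[Path, Path]]:
    """Return (image, annotation) pairs for one split, matched by file stem."""
    root = Path(dataset_root)
    img_dir = root / "images" / split
    ann_dir = root / "annotations" / split
    pairs = []
    for img in sorted(img_dir.iterdir()):
        if img.suffix.lower() not in IMAGE_EXTS:
            continue  # skip non-image files
        ann = ann_dir / f"{img.stem}.json"
        if not ann.is_file():
            # Fail early instead of silently training on unlabeled images.
            raise FileNotFoundError(f"missing annotation for {img.name}: {ann}")
        pairs.append((img, ann))
    return pairs
```

For example, `pair_split("dataset_root", "train")` would return `[(.../images/train/img_001.jpg, .../annotations/train/img_001.json), ...]`.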
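Each annotation file is a JSON object. `boxes` gives corner coordinates `[x1, y1, x2, y2]`, and `labels` holds one class index per box (the detection targets). `class_label` is the image-level classification target:

```json
{
  "boxes": [[x1, y1, x2, y2], ...],
  "labels": [0, 1, ...],
  "class_label": 3
}
```

FlashFusion uses a multi-task loss that combines detection, classification, and inter-model consistency terms, each with its own weight: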
```python
from flashfusion.losses import FusionLoss

criterion = FusionLoss(
    det_weight=1.0,          # detection loss weight
    cls_weight=0.5,          # classification loss weight
    consistency_weight=0.1,  # inter-model consistency loss weight
)
```

Checkpoints are saved automatically:
- `workspace/<name>/last.pt`: latest checkpoint
- `workspace/<name>/best.pt`: checkpoint with the best validation metric
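The last/best convention amounts to two rules. `last.pt` is refreshed after every evaluation, and `last.pt` is the file that `--resume` reads. `best.pt` is overwritten only when the validation metric improves. The following is a minimal sketch of that policy under two assumptions: a higher metric is better, and `pickle` stands in for `torch.save` so the example stays dependency-free. `CheckpointSaver` is hypothetical and is not FlashFusion's implementation.

```python
import pickle
from pathlib import Path
from typing import Optional


class CheckpointSaver:
    """Hypothetical helper mirroring the last.pt / best.pt convention.

    Assumes a higher validation metric is better (e.g. mAP or accuracy).
    """

    def __init__(self, save_dir: str):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.best_metric: Optional[float] = None

    def _write(self, name: str, payload: dict) -> None:
        # pickle keeps the sketch dependency-free; a PyTorch run would use torch.save.
        with (self.save_dir / name).open("wb") as f:
            pickle.dump(payload, f)

    def update(self, epoch: int, state: dict, val_metric: float) -> bool:
        """Always refresh last.pt; overwrite best.pt only when the metric improves."""
        payload = {"epoch": epoch, "val_metric": val_metric, "state": state}
        self._write("last.pt", payload)
        if self.best_metric is None or val_metric > self.best_metric:
            self.best_metric = val_metric
            self._write("best.pt", payload)
            return True
        return False
```

Keeping `last.pt` separate from `best.pt` lets a resumed run continue from the most recent epoch. It does not lose the best-scoring weights even if later epochs regress.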
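For parameter-efficient training, apply LoRA adapters so that only a small number of parameters are updated:

```python
from flashfusion.models.lora import apply_lora

# `model` must already be defined (the model being fine-tuned).
model = apply_lora(model, rank=8, alpha=16)
# Only the LoRA parameters are trained.
```

Use the logger and metrics utilities: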
```python
from flashfusion.utils import setup_logger, AverageMeter

logger = setup_logger("train", log_file="train.log")  # logger named "train", logging to train.log
loss_meter = AverageMeter("loss")  # running average of the training loss
```

FlashFusion — Multi-model vision fusion | PyPI | MIT License