In [None]:
import torch
from pathlib import Path
import os
from ultralytics import YOLO
from ultralytics.models.yolo.model import DetectionModel # Ensure this import is correct

In [None]:
# --- 1. Configuration ---
# All paths are built from the current working directory at the time this cell runs,
# and are kept as plain strings because the later cells pass them to os.path / str().
CONFIG_NAME = 'yolo11s-dspan.yaml'  # YAML file describing the modified (LCBHAM) model

# File Paths
current_dir = os.getcwd()
target_yaml_path = os.path.join(current_dir, 'ultralytics', 'cfg', 'models', '11', CONFIG_NAME)
source_weights_path = os.path.join(current_dir, 'yolo11s.pt')  # Pre-trained standard YOLOv11s
# Path.stem strips only the final extension, so a config name that contains extra
# dots (e.g. 'model.v2.yaml') keeps its full base name in the output file name.
output_weights_path = os.path.join(current_dir, 'runs', 'models', f'{Path(CONFIG_NAME).stem}.pt')  # Transferred weights

# Data Paths
data_path = os.path.join(current_dir, 'datasets', 'AblationDataset', 'wt-flow-taco', 'wt-flow-taco.yaml')
test_data_path = os.path.join(current_dir, 'datasets', 'TestDataset', 'flowtest', 'flowtest.yaml')
test_data_path_wt = os.path.join(current_dir, 'datasets', 'TestDataset', 'watertrashtest', 'watertrashtest.yaml')

# Model Parameters (ensure these match your dataset and model)
nc = 1  # Number of object classes the target model is built with; must match the dataset

# LCBHAM Layer Indices (confirm these match the layer numbering in target_yaml_path)
lcbham_layers_indices = {17}

In [None]:
# --- 2. Perform Weight Transfer (if necessary) ---
# Builds the modified architecture from its YAML, copies every tensor whose (possibly
# remapped) name and shape match the pre-trained checkpoint, and saves the result as a
# checkpoint dict. The whole step is skipped when the output file already exists.

# (substring in target key, replacement for the source key, label used in the summary)
_LCBHAM_KEY_RENAMES = (
    ("conv_block.0.", "conv.", "conv"),
    ("conv_block.1.", "bn.", "bn"),
)


def _extract_state_dict(ckpt):
    """Return a ``{name: tensor}`` state dict from a loaded checkpoint object.

    Supported layouts:
      * a dict whose ``'model'`` entry exposes ``state_dict()`` (cast to float first),
      * a dict whose ``'model'`` entry is itself a dict,
      * a dict without a ``'model'`` key, which is returned unchanged.

    Raises:
        ValueError: if ``ckpt`` matches none of the layouts above (including the
            case where it is not a dict at all).
    """
    if isinstance(ckpt, dict):
        model_entry = ckpt.get('model')
        if hasattr(model_entry, 'state_dict'):
            return model_entry.float().state_dict()
        if isinstance(model_entry, dict):
            return model_entry
        if 'model' not in ckpt:
            return ckpt
    print("Available keys in source_ckpt:", ckpt.keys() if isinstance(ckpt, dict) else "Not a dict")
    raise ValueError("Could not extract state_dict from source checkpoint.")


def _resolve_source_key(target_key, remap_layers):
    """Map a target state-dict key to the key it should be read from in the source.

    Args:
        target_key: Key from the target model's state dict; the second dot-separated
            component is treated as the layer index (e.g. ``'17'``).
        remap_layers: Collection of integer layer indices whose ``conv_block.0.`` /
            ``conv_block.1.`` sub-keys are renamed to ``conv.`` / ``bn.``.

    Returns:
        ``(source_key, layer_index, part)``. For keys that are not remapped,
        ``source_key`` equals ``target_key`` and the other two values are ``None``;
        otherwise ``part`` is ``'conv'`` or ``'bn'``.
    """
    parts = target_key.split('.')
    # Keys without a numeric second component cannot belong to an indexed layer.
    if len(parts) < 2 or not parts[1].isdigit():
        return target_key, None, None
    layer_index = int(parts[1])
    if layer_index not in remap_layers:
        return target_key, None, None
    for target_fragment, source_fragment, part in _LCBHAM_KEY_RENAMES:
        if f".{parts[1]}.{target_fragment}" in target_key:
            return target_key.replace(target_fragment, source_fragment), layer_index, part
    return target_key, None, None


if not os.path.exists(output_weights_path):
    print(f"Transfer weights file '{output_weights_path}' not found. Performing weight transfer...")

    # --- Load Source Weights ---
    print(f"Loading source weights from {source_weights_path}...")
    try:
        # NOTE: weights_only=False unpickles arbitrary Python objects; only load
        # checkpoint files that come from a trusted source.
        source_ckpt = torch.load(source_weights_path, map_location=torch.device('cpu'), weights_only=False)
        source_state_dict = _extract_state_dict(source_ckpt)
    except Exception as e:
        print(f"Error loading source weights: {e}")
        # Re-raise so the notebook stops at this cell with the full traceback.
        raise
    print("Source weights loaded and state_dict extracted.")

    # --- Build Target Model Structure ---
    print(f"Building target model structure from {target_yaml_path}...")
    try:
        # Ensure your custom LCBHAM module is defined/imported before this line
        target_model = DetectionModel(cfg=str(target_yaml_path), ch=3, nc=nc)
    except Exception as e:
        print(f"Error building target model: {e}")
        print(f"Ensure LCBHAM is defined, the YAML path is correct, and nc={nc} is appropriate.")
        raise
    target_state_dict = target_model.state_dict()
    print("Target model structure built.")

    # --- Weight Transfer Logic ---
    print("Starting weight transfer...")
    new_state_dict = {}
    transferred_count = 0
    skipped_count = 0
    lcbham_conv_transferred = set()
    lcbham_bn_transferred = set()

    for k_target, v_target in target_state_dict.items():
        k_source, layer_index, part = _resolve_source_key(k_target, lcbham_layers_indices)

        if k_source in source_state_dict and source_state_dict[k_source].shape == v_target.shape:
            new_state_dict[k_target] = source_state_dict[k_source]
            transferred_count += 1
            # Record a remapped layer only once a tensor has really been copied, so
            # the summary below never reports a mapping that did not happen.
            if part == 'conv':
                lcbham_conv_transferred.add(layer_index)
            elif part == 'bn':
                lcbham_bn_transferred.add(layer_index)
        else:
            # Missing in the source or shape mismatch: keep the target's own initial value.
            new_state_dict[k_target] = v_target
            skipped_count += 1

    print("\n--- Transfer Summary ---")
    print(f"Total keys in target model: {len(target_state_dict)}")
    print(f"Weights transferred: {transferred_count}")
    print(f"Weights skipped/kept from target: {skipped_count}")
    for idx in sorted(lcbham_layers_indices):
        if idx in lcbham_conv_transferred:
            print(f"Successfully mapped Conv weights for LCBHAM layer {idx}.")
        else:
            print(f"Warning: no Conv weights were transferred for LCBHAM layer {idx}.")
        if idx in lcbham_bn_transferred:
            print(f"Successfully mapped BN weights for LCBHAM layer {idx}.")
        else:
            print(f"Warning: no BN weights were transferred for LCBHAM layer {idx}.")

    # --- Load New State Dict and Save Checkpoint ---
    target_model.load_state_dict(new_state_dict, strict=False)
    print("\nLoaded transferred weights into target model structure.")

    output_ckpt = {
        'epoch': -1,
        'best_fitness': None,
        'model': target_model,
        'ema': None,
        'updates': None,
        'train_args': {},  # Use empty dict for compatibility
        'date': None,
    }

    # Make sure the destination directory exists before saving.
    output_directory = os.path.dirname(output_weights_path)
    if output_directory:
        os.makedirs(output_directory, exist_ok=True)

    torch.save(output_ckpt, output_weights_path)
    print(f"Saved model with transferred weights to {output_weights_path}")
    print("\nWeight transfer complete.")

else:
    print(f"Found existing transfer weights file: '{output_weights_path}'. Skipping transfer step.")

In [None]:
# --- 3. Fine-Tuning ---
print("\n--- Starting Fine-Tuning ---")

# Start from the checkpoint written by the weight-transfer step.
model = YOLO(output_weights_path)

print(f"Training on data: {data_path}")

# Keyword arguments forwarded unchanged to model.train(), grouped by purpose.
train_settings = dict(
    # Paths and run bookkeeping
    data=str(data_path),        # dataset YAML file
    name='wt-flow-taco-uav',    # name of the training run directory
    exist_ok=False,             # do not reuse an existing run directory of the same name
    save=True,                  # save checkpoints and the final model

    # Core hyperparameters
    epochs=300,                 # number of training epochs
    batch=16,                   # batch size
    imgsz=640,                  # input image size

    # Augmentation
    # NOTE(review): the value is False, although the previous comment on this line said it
    # "MUST be True" for mosaic — confirm which is intended before relying on it.
    augment=False,
)

# --- Train Model ---
results = model.train(**train_settings)
print("\n--- Training Finished ---")

In [None]:
# --- 4. Validation ---
print("\n--- Starting Validation ---")


def _validate_on(label, dataset_yaml):
    """Run ``model.val`` on one dataset YAML and print the outcome.

    Args:
        label: Short dataset name used in the printed header.
        dataset_yaml: Path of the dataset YAML passed as ``model.val(data=...)``.

    Returns:
        The object returned by ``model.val``, or ``None`` if validation raised.
        Errors are printed rather than propagated so that the remaining datasets
        are still evaluated.
    """
    try:
        metrics = model.val(data=str(dataset_yaml))  # Use the trained model object
    except Exception as e:
        print(f"Error during validation on {dataset_yaml}: {e}")
        return None
    print(f"Validation Results ({label}):")
    # Print the compact summary when the returned object exposes `results_dict`,
    # otherwise fall back to printing the object itself.
    print(getattr(metrics, "results_dict", metrics))
    return metrics


print(f"Validating on flowtest: {test_data_path}")
res_flow = _validate_on("flowtest", test_data_path)

print(f"\nValidating on watertrashtest: {test_data_path_wt}")
res_wt = _validate_on("watertrashtest", test_data_path_wt)