In [1]:
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from models.UW_EffSAM_tracker import UW_EffSSAM
from PIL import Image
import contextlib
import io
import time


def predict_mask(pred_masks, ignored_masks):
    """Collapse per-class mask scores into a single class-index map.

    Channel 0 of ``pred_masks`` is excluded from the argmax, so every pixel is
    first assigned one of the classes ``1 .. C-1``. Pixels flagged by
    ``ignored_masks`` are then reset to class 0.

    Args:
        pred_masks: tensor whose first dimension indexes classes, e.g.
            ``(C, H, W)`` (TODO confirm layout against the caller). It needs at
            least two entries along dim 0.
        ignored_masks: optional boolean or 0/1 map broadcastable to
            ``pred_masks.shape[1:]`` (tensor, array or nested list). Non-zero
            entries mark pixels to force to 0. ``None`` disables this step.
            NOTE(review): "non-zero means ignore" is inferred from the name;
            confirm against whoever builds this mask.

    Returns:
        Integer tensor of shape ``pred_masks.shape[1:]`` holding class indices.
    """
    # argmax over the non-background channels only, shifted by one so the
    # result indexes into the original channel numbering.
    class_map = torch.argmax(pred_masks[1:, ...], dim=0) + 1
    if ignored_masks is not None:
        # The previous `class_map[class_map == 0] = 0` was a no-op (argmax + 1
        # is never 0) and never consulted `ignored_masks`; apply it instead.
        ignore = torch.as_tensor(ignored_masks, device=class_map.device).bool()
        class_map = class_map.masked_fill(ignore, 0)
    return class_map


def _to_numpy(value):
    """Convert a model output (tensor, sequence of tensors, or array-like) to NumPy.

    ``np.array`` cannot read CUDA tensors or tensors that require grad, so
    tensors are detached and copied to host memory first. Everything else goes
    through ``np.array`` as before.
    """
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    if isinstance(value, (list, tuple)):
        return np.array([_to_numpy(item) for item in value])
    return np.array(value)


def process_frame(frame, model, device):
    """Run the model on a single frame and unpack its first result.

    Args:
        frame: the input frame, e.g. as returned by ``cv2.VideoCapture.read``.
        model: callable taking a list of frames and returning a sequence of
            per-image result dicts with ``'masks'``, ``'boxes'`` and
            ``'labels'`` keys.
        device: unused here; kept so existing call sites keep working.

    Returns:
        ``(frame, masks, boxes, labels)``: the input ``frame`` unchanged, masks
        and boxes converted to NumPy arrays, and labels exactly as the model
        returned them.
    """
    with torch.no_grad():
        output = model([frame])
    result = output[0]
    # NOTE(review): result['images'] is not used. If the model resizes its
    # input, that may be the image the caller treats as "frame_resized";
    # confirm which one should be returned.
    masks = _to_numpy(result['masks'])
    boxes = _to_numpy(result['boxes'])
    return frame, masks, boxes, result['labels']

def overlay_mask_on_frame(frame, mask_cls_pred, num_classes, alpha=0.5):
    """Alpha-blend a colour-coded class map onto a BGR frame.

    Args:
        frame: image in OpenCV's BGR channel order. It is assumed to be uint8,
            because the colour layer is built as uint8 and ``cv2.addWeighted``
            needs matching inputs.
        mask_cls_pred: per-pixel class indices with the same height and width
            as ``frame``.
        num_classes: number of classes; indices ``0 .. num_classes-1`` are
            mapped through matplotlib's ``tab10`` colormap.
        alpha: weight of the colour layer (0 = frame only, 1 = colours only).

    Returns:
        A new blended image; ``frame`` itself is not modified.
    """
    cmap = plt.get_cmap('tab10', num_classes)
    norm = mcolors.Normalize(vmin=0, vmax=num_classes - 1)
    # Colormaps return RGBA; drop alpha to get an RGB colour per pixel.
    colored_rgb = cmap(norm(mask_cls_pred))[..., :3]
    # The frame (and cv2.VideoWriter) use BGR. Flip the channels so class
    # colours are not rendered with red and blue swapped.
    colored_bgr = np.ascontiguousarray((colored_rgb[..., ::-1] * 255).astype(np.uint8))
    return cv2.addWeighted(frame, 1 - alpha, colored_bgr, alpha, 0)

def draw_boxes_and_labels_on_frame(frame, boxes, labels):
    """Draw each bounding box and its label onto ``frame``.

    Args:
        frame: image array; OpenCV's drawing functions modify it in place.
        boxes: iterable of ``(x1, y1, x2, y2)`` corner coordinates, cast to int.
        labels: one label per box, rendered with ``str()``. Because of
            ``zip``, boxes or labels beyond the shorter sequence are ignored.

    Returns:
        The same ``frame`` object, now annotated.
    """
    font, font_scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.9, 2
    color = (0, 255, 0)
    for box, label in zip(boxes, labels):
        x1, y1, x2, y2 = map(int, box)
        cv2.rectangle(frame, (x1, y1), (x2, y2), color=color, thickness=thickness)
        text = str(label)
        # putText anchors the text at its bottom-left corner. 10 px above a box
        # touching the top edge would push the label out of the image, so
        # clamp the baseline to at least the text height.
        (_, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
        cv2.putText(frame, text, (x1, max(y1 - 10, text_height)), font, font_scale, color, thickness)
    return frame


def video_processing(input_video_path, output_video_path, model, device):
    """Annotate every frame of a video with predicted masks and boxes.

    Each frame is passed through ``process_frame``. The returned masks are
    blended on with ``overlay_mask_on_frame``, and boxes and labels are drawn
    with ``draw_boxes_and_labels_on_frame``. Annotated frames are written to
    ``output_video_path`` with the ``mp4v`` codec at the input frame rate.
    The average per-frame processing time and FPS are printed at the end.

    Args:
        input_video_path: path of the source video.
        output_video_path: path of the annotated video to write.
        model: model forwarded to ``process_frame``.
        device: forwarded to ``process_frame``.

    Raises:
        IOError: if the input video cannot be opened.
    """
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        cap.release()
        raise IOError(f"Could not open input video: {input_video_path}")

    # Some containers/streams report 0 FPS; fall back to 30 so the writer opens.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = None  # created lazily once the size of the annotated frames is known

    frame_count = 0
    total_inference_time = 0.0

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # The timed span covers model inference plus overlay/box drawing.
            start_time = time.perf_counter()
            frame_resized, masks, boxes, labels = process_frame(frame, model, device)

            # Overlay the masks on the frame. Skip this when nothing was
            # predicted, because max() of an empty array raises.
            if masks.size:
                num_classes = int(masks.max() + 1)
                result_frame = overlay_mask_on_frame(frame_resized, masks, num_classes)
            else:
                result_frame = frame_resized.copy()
            # Draw bounding boxes and labels on the frame.
            result_frame = draw_boxes_and_labels_on_frame(result_frame, boxes, labels)

            total_inference_time += time.perf_counter() - start_time
            frame_count += 1

            if out is None:
                # cv2.VideoWriter silently drops frames whose size differs from
                # the size it was opened with. Size it from the real output
                # rather than a hard-coded resolution.
                out_height, out_width = result_frame.shape[:2]
                out = cv2.VideoWriter(output_video_path, fourcc, fps, (out_width, out_height))
            out.write(result_frame)
    finally:
        # Release even on errors or KeyboardInterrupt so the output file is
        # finalized.
        cap.release()
        if out is not None:
            out.release()

    if frame_count == 0:
        print("No frames were processed.")
        return

    avg_inference_time = total_inference_time / frame_count
    avg_fps = 1.0 / avg_inference_time

    print(f"Average Inference Time per Frame: {avg_inference_time:.4f} seconds")
    print(f"Average FPS: {avg_fps:.2f}")

# --- Inference configuration -------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Detector checkpoint and architecture tag passed to UW_EffSSAM.
yolo_ckp, yolo_type = "runs/detect/train9/weights/best.pt", "YOLOv8"
# EfficientSAM checkpoint and variant tag passed to UW_EffSSAM.
effsam_ckp, effsam_type = "checkpoint/eff_sam_l0.pt", "l0"
# Remaining UW_EffSSAM options (conf is presumably a detection confidence
# threshold; confirm in models/UW_EffSAM_tracker).
multimask_output, input_type, conf = False, "image", 0.6

model = UW_EffSSAM(
    yolo_ckp=yolo_ckp, yolo_type=yolo_type,
    effsam_ckp=effsam_ckp, effsam_type=effsam_type,
    multimask_output=multimask_output, input_type=input_type, conf=conf,
).to(device)

# Annotate the sample clip and report per-frame timing.
video_processing("underwater_1920_1080_30fps.mp4", "op_video.mp4", model, device)


  from .autonotebook import tqdm as notebook_tqdm
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x00000174E2F0B850>>
Traceback (most recent call last):
  File "c:\Anaconda\envs\Underwater\lib\site-packages\ipykernel\ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


In [None]:
import os
import cv2
import matplotlib.pyplot as plt

def _yolo_line_to_box(line, width, height):
    """Parse one YOLO label line into ``(class_id, x1, y1, x2, y2)`` pixel coords.

    The line format is ``class x_center y_center width height``, with the four
    geometry values normalised to the image size. Returns ``None`` for blank
    lines. Fields after the fifth are ignored.

    Raises:
        ValueError: if a non-blank line has fewer than five fields or
            non-numeric values.
    """
    parts = line.split()
    if not parts:
        return None
    if len(parts) < 5:
        raise ValueError(f"Malformed YOLO label line: {line!r}")
    class_id = int(parts[0])
    x_center = float(parts[1]) * width
    y_center = float(parts[2]) * height
    half_width = float(parts[3]) * width / 2
    half_height = float(parts[4]) * height / 2
    return (
        class_id,
        int(x_center - half_width),
        int(y_center - half_height),
        int(x_center + half_width),
        int(y_center + half_height),
    )


def visualize_yolo_labels(image_folder, label_folder, output_folder=None):
    """Draw YOLO-format boxes on each image and display (optionally save) it.

    For every ``.png``/``.jpg``/``.jpeg`` file in ``image_folder``
    (case-insensitive, processed in sorted order), the label file
    ``<stem>.txt`` is read from ``label_folder``. Each box is drawn in green
    with its class id, and the result is shown with matplotlib. Images
    without a label file, or that OpenCV cannot read, are skipped with a
    message.

    Args:
        image_folder: directory containing the images.
        label_folder: directory containing one YOLO ``.txt`` file per image.
        output_folder: if truthy, annotated images are also written there
            under their original filenames (the directory is created).
    """
    image_filenames = sorted(
        f for f in os.listdir(image_folder)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    for image_filename in image_filenames:
        label_filename = os.path.splitext(image_filename)[0] + '.txt'
        label_path = os.path.join(label_folder, label_filename)

        # Check for the label first so unlabeled images are never decoded.
        if not os.path.exists(label_path):
            print(f'Label file {label_filename} does not exist, skipping...')
            continue

        image = cv2.imread(os.path.join(image_folder, image_filename))
        if image is None:
            # cv2.imread returns None instead of raising on unreadable files.
            print(f'Could not read image {image_filename}, skipping...')
            continue
        height, width = image.shape[:2]

        with open(label_path, 'r') as label_file:
            for line in label_file:
                box = _yolo_line_to_box(line, width, height)
                if box is None:  # blank line, e.g. a trailing newline
                    continue
                class_id, x1, y1, x2, y2 = box
                cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(image, str(class_id), (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

        # OpenCV stores BGR; matplotlib expects RGB.
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        ax.set_title(image_filename)
        ax.axis('off')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across many images

        if output_folder:
            cv2.imwrite(os.path.join(output_folder, image_filename), image)

# Example usage
image_folder = 'datasets/SUIM/train_val/train_val/images'
label_folder = 'label'
output_folder = None  # Optional, set to None if you don't want to save the output

visualize_yolo_labels(image_folder, label_folder, output_folder)
