In [4]:
import ultralytics
from ultralytics import YOLO
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
import torch    
import cv2
import numpy as np
import matplotlib.pyplot as plt
import requests
import torchvision.transforms as transforms
from PIL import Image
import io
import os
from yolov8_heatmap import yolov8_heatmap

In [1]:
def get_params(weight):
    params = {
        'weight': weight, # 现在只需要指定权重即可,不需要指定cfg
        'device': 'cuda:0',
        'method': 'EigenGradCAM', # GradCAMPlusPlus, GradCAM, XGradCAM, EigenCAM, HiResCAM, LayerCAM, RandomCAM, EigenGradCAM
        'layer': [15, 18, 21],
        'backward_type': 'class', # class, box, all
        'conf_threshold': 0.2, # 0.2
        'ratio': 0.02, # 0.02-0.1
        'show_box': False,
        'renormalize': False
    }
    return params

In [8]:
model_ori = r'E:\code\yolov8_result\udacity\no_mod\weights\best.pt'
model_se = r'E:\code\yolov8_result\udacity\se\weights\best.pt'
model_cbam = r'E:\code\yolov8_result\udacity\eca\weights\best.pt'
model_eca = r'E:\code\yolov8_result\udacity\cbam\weights\best.pt'
model_dw = r'E:\code\yolov8_result\udacity\lsma\weights\best.pt'

In [9]:
heatmap_ori = yolov8_heatmap(**get_params(model_ori))
heatmap_se = yolov8_heatmap(**get_params(model_se))
heatmap_cbam = yolov8_heatmap(**get_params(model_cbam))
heatmap_eca = yolov8_heatmap(**get_params(model_eca))
heatmap_dw = yolov8_heatmap(**get_params(model_dw))

YOLOv8 summary: 225 layers, 3011433 parameters, 0 gradients, 8.2 GFLOPs
YOLOv8-se summary: 285 layers, 3066601 parameters, 0 gradients, 8.4 GFLOPs
YOLOv8-eca summary: 267 layers, 3053185 parameters, 0 gradients, 8.4 GFLOPs
YOLOv8-cbam summary: 291 layers, 3161941 parameters, 0 gradients, 8.5 GFLOPs
YOLOv8-lsma summary: 285 layers, 3060583 parameters, 0 gradients, 8.6 GFLOPs


In [10]:
image_folder = r'E:\Datasets\udacity\images'
result_folder= 'E:\\code\\ultralytics\\result'
# 确保输出目录存在
os.makedirs(result_folder, exist_ok=True)

titles = ['Origin','YOLOv8', 'SE', 'CBAM', 'ECA', 'LSMA']
image_files = [f for f in os.listdir(image_folder) if f.endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp'))]
print(image_files)

['1479498371963069978.jpg', '1479498372942264998.jpg', '1479498373462797835.jpg', '1479498373962951201.jpg', '1479498374962942172.jpg', '1479498375942206592.jpg', '1479498376463086347.jpg', '1479498377463264578.jpg', '1479498377963597629.jpg', '1479498378965237962.jpg', '1479498379965419997.jpg', '1479498380466064740.jpg', '1479498381465380454.jpg', '1479498382449249006.jpg', '1479498382965792478.jpg', '1479498383466326914.jpg', '1479498384442876688.jpg', '1479498384963639932.jpg', '1479498385966154564.jpg', '1479498386466851996.jpg', '1479498387466479165.jpg', '1479498388466168072.jpg', '1479498389464606786.jpg', '1479498389966519477.jpg', '1479498390964153934.jpg', '1479498391966209986.jpg', '1479498392966162658.jpg', '1479498393466501007.jpg', '1479498394463918193.jpg', '1479498395464437392.jpg', '1479498395964946961.jpg', '1479498396964620740.jpg', '1479498397956718105.jpg', '1479498398464810598.jpg', '1479498398965062294.jpg', '1479498399466018917.jpg', '1479498400465050134.jpg', 

In [None]:
for image_file in image_files:
    image_path = os.path.join(image_folder, image_file)
    img = cv2.imread(image_path)
    images=[]
    images.append(img)
    
    heatmap_ori(image_path,images)
    heatmap_se(image_path,images)
    heatmap_cbam(image_path,images)
    heatmap_eca(image_path,images)
    heatmap_dw(image_path,images)
    
    fig, axs = plt.subplots(nrows=1, ncols=6, figsize=(50, 50))
    plt.tight_layout()
    plt.subplots_adjust(hspace=0) 
    # 遍历子图并显示图像
    for ax, title, image in zip(axs, titles, images):
        ax.set_title(title,fontsize=45)
        ax.imshow(image)
        ax.axis('off')  # 关闭坐标轴
    
    output_path = os.path.join(result_folder, image_file)
    plt.savefig(output_path, dpi=300, transparent=True)

 68%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                             | 74/109 [00:00<00:00, 5693.14it/s]
 67%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                               | 73/109 [00:00<00:00, 5214.40it/s]
 70%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                         | 76/109 [00:00<00:00, 5066.79it/s]
 68%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                             | 74/109 [00:00<00:00, 5867.49it/s]
 67%|███████████████████████████████████████████████████████████████████