In [22]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


# Shared preprocessing pipeline: convert the image to a tensor, then normalize
# every channel with mean 0.5 and std 0.5 (imshow_with_labels undoes this with
# `img / 2 + 0.5`).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])


def imshow_with_labels(images, labels, classes):
    """Show a batch of normalized images in a square grid, titled by class name.

    Args:
        images: Sequence of image tensors laid out channel-first (C, H, W);
            each one is transposed to (H, W, C) for plotting. Values are
            expected to be normalized with mean 0.5 / std 0.5.
        labels: Sequence of integer class indices, one per image.
        classes: Sequence mapping a class index to its display name.

    Returns:
        None. The figure is rendered with ``plt.show()``. Nothing is drawn
        when ``images`` is empty.
    """
    num_images = len(images)
    if num_images == 0:
        # A 0x0 subplot grid is invalid, so there is nothing to draw.
        return

    grid_size = int(np.ceil(np.sqrt(num_images)))
    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid; without it a
    # single image yields a bare Axes object that has no .flatten().
    _, axs = plt.subplots(grid_size, grid_size, figsize=(15, 15),
                          squeeze=False)
    axs = axs.flatten()

    for i, ax in enumerate(axs):
        if i < num_images:
            # detach/cpu so tensors on the GPU or tracking gradients also work
            img = images[i].detach().cpu().numpy().transpose((1, 2, 0))
            # undo normalization; clip away float round-off outside [0, 1]
            img = np.clip(img / 2 + 0.5, 0.0, 1.0)
            label = classes[int(labels[i])]
            ax.imshow(img)
            ax.set_title(f"Label: {label}")
        # Hide ticks and frames on used and unused cells alike.
        ax.axis("off")

    plt.tight_layout()
    plt.show()


def load_data(path, batch_size, DataType):
    """Build an ImageFolder dataset and a DataLoader over it.

    Args:
        path: Root directory in the layout expected by ``ImageFolder``.
        batch_size: Samples per batch; ``0`` means one batch holding the
            whole dataset.
        DataType: Split name. The callers in this notebook pass
            ``"training"`` or ``"evaluate"``.

    Returns:
        Tuple ``(dataset, dataloader)``.
    """
    # The module-level `transform` is a single Compose pipeline, which cannot
    # be subscripted. A dict keyed by split name is still honoured in case
    # per-split pipelines are introduced later.
    if isinstance(transform, dict):
        split_transform = transform[DataType]
    else:
        split_transform = transform

    Dataset = ImageFolder(path, transform=split_transform)
    if batch_size == 0:
        batch_size = len(Dataset)

    # Shuffle and drop the ragged last batch everywhere except evaluation,
    # where every sample must be seen and the order should be reproducible.
    is_eval = DataType == "evaluate"
    Dataloader = DataLoader(Dataset, batch_size=batch_size,
                            shuffle=not is_eval, drop_last=not is_eval)
    return Dataset, Dataloader


def loadTrain(path, batch_size):
    """Create the training loader and preview its first batch.

    Loads ``<path>/train``, prints the shape and min/max pixel values of one
    batch, plots that batch with class-name titles, and returns the loader
    together with the class names.

    Args:
        path: Dataset root containing a ``train`` sub-directory.
        batch_size: Batch size forwarded to ``load_data``.

    Returns:
        Tuple ``(train_loader, class_names)``.
    """
    dataset, loader = load_data(
        os.path.join(path, "train"), batch_size, "training")
    class_names = dataset.classes

    # Pull a single batch purely for inspection.
    sample_images, sample_labels = next(iter(loader))
    pixels = sample_images.data

    print("Data shapes (train/test):")
    print(pixels.shape)

    # Range of pixel intensity values after normalization.
    print("\nData value range:")
    print((pixels.min(), pixels.max()))

    # Visual sanity check of images against their labels.
    imshow_with_labels(sample_images, sample_labels, class_names)

    return loader, class_names


def loadTest(path, batch_size=0):
    """Create the loader for the ``<path>/validation`` split.

    Args:
        path: Dataset root containing a ``validation`` sub-directory.
        batch_size: Batch size forwarded to ``load_data``; the default ``0``
            makes ``load_data`` use the whole dataset as one batch.

    Returns:
        Tuple ``(test_loader, class_names)``.
    """
    validation_dir = os.path.join(path, "validation")
    dataset, loader = load_data(validation_dir, batch_size, "evaluate")
    return loader, dataset.classes


def function2trainModel(model, device, train_loader, lossFun, optimizer,
                        epochs=10, verbose=True):
    """Train ``model`` and record per-epoch loss and accuracy.

    Args:
        model: ``torch.nn.Module`` to train; it is moved to ``device``.
        device: Device that the model and every batch are moved to.
        train_loader: Sized iterable yielding ``(X, y)`` batches, where ``y``
            holds integer class indices.
        lossFun: Callable ``lossFun(yHat, y)`` returning a scalar loss tensor.
        optimizer: Optimizer already bound to ``model``'s parameters.
        epochs: Number of passes over ``train_loader`` (default 10, the
            previously hard-coded value).
        verbose: When True (default), print a progress line after every batch.

    Returns:
        Tuple ``(trainLoss, trainAcc, model)``. ``trainLoss`` and ``trainAcc``
        are NumPy arrays of length ``epochs`` holding the mean batch loss and
        the mean batch accuracy in percent.
    """
    model.to(device)

    # per-epoch metrics
    trainLoss = np.zeros(epochs)
    trainAcc = np.zeros(epochs)

    num_batches = len(train_loader)

    for epochi in range(epochs):
        model.train()  # switch to train mode (enables dropout etc.)
        batchLoss = []
        batchAcc = []

        for batch_idx, (X, y) in enumerate(train_loader):
            # push data to the training device
            X = X.to(device)
            y = y.to(device)

            # forward pass and loss
            yHat = model(X)
            loss = lossFun(yHat, y)

            # backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # loss and accuracy for this batch
            batchLoss.append(loss.item())
            batchAcc.append(torch.mean(
                (torch.argmax(yHat, dim=1) == y).float()).item())
            if verbose:
                print(
                    f"Epoch: {epochi+1}/{epochs}, Batch: {batch_idx}, {batch_idx+1}/{num_batches}")

        # average losses and accuracies across the batches of this epoch
        trainLoss[epochi] = np.mean(batchLoss)
        trainAcc[epochi] = 100 * np.mean(batchAcc)

    return trainLoss, trainAcc, model

In [23]:
%matplotlib inline
import torch.nn as nn
from torchvision import models


def makeMorpherNet(printtoggle=False):
    """Build the 3-class ResNet-18 classifier with its loss and optimizer.

    Args:
        printtoggle: Accepted for compatibility; not used by this function.

    Returns:
        Tuple ``(model, lossfun, optimizer)``: a ResNet-18 initialised with
        the IMAGENET1K_V1 weights whose ``fc`` head is replaced by
        Dropout(p=0.7) -> Linear(in_features, 3), a ``CrossEntropyLoss``, and
        an ``AdamW`` optimizer (lr=1e-4, weight_decay=1e-3) over all model
        parameters.
    """
    backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

    # New classification head: Dropout -> Linear (dropout sits before the
    # final projection).
    head = nn.Sequential(
        nn.Dropout(p=0.7),
        nn.Linear(backbone.fc.in_features, 3),
    )
    backbone.fc = head  # type: ignore

    criterion = nn.CrossEntropyLoss()
    adamw = torch.optim.AdamW(
        backbone.parameters(), lr=1e-4, weight_decay=1e-3)
    return backbone, criterion, adamw


In [None]:
import cv2
import torch
import torchvision.transforms as transforms
from ultralytics import YOLO
import threading
from functools import wraps

# Inference-time preprocessing; identical to the pipeline defined in the
# training cell: tensor conversion, then per-channel normalization with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])


# Singleton decorator (creation is guarded by a lock)
def singleton(cls):
    """Class decorator that makes ``cls(...)`` always return one instance.

    The first call constructs ``cls`` with the arguments it receives; every
    later call returns that same object and ignores its arguments. The
    decorated name becomes a factory function rather than a class.
    """
    _created = {}
    _guard = threading.Lock()

    @wraps(cls)
    def get_instance(*args, **kwargs):
        # Fast path: the instance already exists, no locking needed.
        if cls in _created:
            return _created[cls]

        with _guard:
            # Re-check under the lock: another thread may have created the
            # instance while this one was waiting.
            if cls not in _created:
                _created[cls] = cls(*args, **kwargs)
        return _created[cls]

    return get_instance


# Apply the singleton decorator so both models are loaded only once
@singleton
class Detector:
    """Detect flies with YOLO, classify each crop, and annotate the image.

    Attributes are set in ``__init__``:
        classes: Label names indexed by the classifier's output index.
        fly_detector: ``YOLO`` model used to locate flies.
        model: Classifier built by ``makeMorpherNet`` with loaded weights.
        device: Torch device the classifier runs on.
    """

    def __init__(self, ModelPath, MorpherModelPath):
        """Load the detector and the classifier.

        Args:
            ModelPath: Path to the YOLO weights.
            MorpherModelPath: Path to the classifier ``state_dict`` file.
        """
        # Maps the classifier's argmax index to a label name.
        # NOTE(review): presumably this matches ImageFolder's alphabetical
        # class order used at training time - confirm.
        self.classes = ["Ambiguous", "Long", "Short"]

        # Choose the device first (GPU when available) so the checkpoint can
        # be mapped onto it while loading.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        # Fly detector (the original comment names it YOLOv8n)
        self.fly_detector = YOLO(ModelPath)

        # Fly trait classifier
        self.model, _, _ = makeMorpherNet()
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # machine. NOTE: torch.load unpickles the file - only load
        # checkpoints from a trusted source.
        state_dict = torch.load(MorpherModelPath, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

    def process(self, img):
        """Detect, classify and annotate every fly in ``img``.

        Args:
            img: Image array in OpenCV layout (H, W, channels). It is assumed
                to be in BGR channel order, as returned by ``cv2.imread``.
                The array is annotated in place.

        Returns:
            The same ``img`` object: unchanged when nothing was detected,
            otherwise with one box and label per fly plus a count summary in
            the top-left corner.

        Raises:
            ValueError: If ``img`` is None (e.g. a failed ``cv2.imread``).
        """
        if img is None:
            raise ValueError(
                "img is None; check that the image was loaded correctly")

        # Locate flies with YOLO
        results = self.fly_detector.predict(img, conf=0.25, iou=0.6)
        cropped_flies, flies_pos = self._crop_detections(img, results)

        # Nothing detected: return the image untouched
        if len(cropped_flies) == 0:
            return img

        # Per-class counters
        counts = {c: 0 for c in self.classes}

        for cropped_fly, (x, y, w, h) in zip(cropped_flies, flies_pos):
            # OpenCV crops are BGR.
            # NOTE(review): assumes the classifier was trained on RGB images
            # (ImageFolder's default loader) - confirm for this checkpoint.
            rgb_fly = cv2.cvtColor(cropped_fly, cv2.COLOR_BGR2RGB)
            tensor_fly = self.transform2tensor(rgb_fly).to(self.device)
            with torch.no_grad():
                output = self.model(tensor_fly)

            max_idx = int(output.argmax().item())
            label = self.classes[max_idx]
            counts[label] += 1

            cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 8)
            # Write the label above the box, kept inside the image at the top
            cv2.putText(img, label, (x, max(50, y - 50)),
                        cv2.FONT_HERSHEY_SIMPLEX, 4, (0, 128, 0), 8)

        self._draw_summary(img, counts, len(cropped_flies))
        return img

    def _crop_detections(self, img, results):
        """Return 224x224 crops and (x, y, w, h) boxes for every detection."""
        img_h, img_w = img.shape[:2]
        cropped_flies = []
        flies_pos = []

        for result in results:
            boxes = result.boxes.cpu().numpy()
            for box in boxes:
                try:
                    x1, y1, x2, y2 = map(int, box.xyxy[0])
                except Exception:
                    # Best effort: skip boxes whose coordinates cannot be read
                    continue

                # Clamp to the image. Slice ends are exclusive, so the
                # right/bottom edge may equal the image width/height.
                x1 = max(0, x1)
                y1 = max(0, y1)
                x2 = min(img_w, x2)
                y2 = min(img_h, y2)
                if x2 <= x1 or y2 <= y1:
                    continue

                cropped_fly = img[y1:y2, x1:x2]
                if cropped_fly.size == 0:
                    continue

                # cv2.resize returns a new array, so drawing on `img` later
                # does not alter the stored crops.
                cropped_flies.append(cv2.resize(cropped_fly, (224, 224)))
                flies_pos.append((x1, y1, x2 - x1, y2 - y1))

        return cropped_flies, flies_pos

    def _draw_summary(self, img, counts, total):
        """Draw per-class counts and the total on a translucent box in place."""
        lines = [f"{k}: {counts[k]}" for k in self.classes]
        lines.append(f"Total: {total}")

        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 4
        thickness = 8
        padding = 50
        line_gap = 50

        # Size the background from the rendered text extents
        text_sizes = [cv2.getTextSize(line, font, font_scale, thickness)[0]
                      for line in lines]
        box_w = max(text_w for text_w, _ in text_sizes) + padding * 2
        box_h = sum(text_h + line_gap for _, text_h in text_sizes) + padding

        # Blend a filled black rectangle into the image
        x0, y0 = 10, 10
        overlay = img.copy()
        cv2.rectangle(overlay, (x0, y0),
                      (x0 + box_w, y0 + box_h), (0, 0, 0), -1)
        alpha = 0.55
        cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img)

        # Write the text lines, advancing the baseline after each one
        y_text = y0 + padding + text_sizes[0][1]
        for line, (_, text_h) in zip(lines, text_sizes):
            cv2.putText(img, line, (x0 + padding, y_text), font,
                        font_scale, (255, 255, 255), thickness)
            y_text += text_h + line_gap

    def transform2tensor(self, data):
        """Apply the module-level ``transform`` and add a batch dimension."""
        tensor_data = transform(data)
        return tensor_data.unsqueeze(0)  # add the batch dimension


if __name__ == "__main__":
    # Demo run: annotate demo.jpg and save the result next to it.
    img = cv2.imread("demo.jpg")
    # cv2.imread returns None instead of raising when the file cannot be read
    if img is None:
        raise FileNotFoundError("Could not read input image 'demo.jpg'")

    # Initialise the detector (YOLO weights, classifier weights)
    detector = Detector("best.pt", "fly_morpher_resnet_1210_224.pth")

    process_img = detector.process(img)

    # cv2.imwrite reports failure through its return value
    if not cv2.imwrite("result.jpg", process_img):
        raise OSError("Could not write output image 'result.jpg'")
    # For an interactive preview outside the notebook, use
    # cv2.imshow("result", process_img) followed by cv2.waitKey(0) and
    # cv2.destroyAllWindows().


0: 640x480 5 flys, 14.7ms
Speed: 2.2ms preprocess, 14.7ms inference, 1.5ms postprocess per image at shape (1, 3, 640, 480)
