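Модель классификатора должна соответствовать списку классов действий (ACTION_CLASSES / ATTACK_TO_ID): число и порядок классов должны совпадать с теми, на которых она обучалась.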

In [3]:
import torch
import torch.nn as nn
import cv2
import numpy as np
import mss
import time
import albumentations as A
from albumentations.pytorch import ToTensorV2
from ultralytics import YOLO

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
import cv2
import numpy as np
import time
from ultralytics import YOLO 

# Paths to the trained YOLO detector weights and the action-classifier state dict.
YOLO_MODEL_PATH = 'runs/detect/hollow_knight_detector2/weights/best.pt'
CLASSIFIER_MODEL_PATH = "best_action_model_multi_class_last.pth"


# Classifier output labels; the position in this sequence is the class id the
# network emits, so the order must match the one used during training.
ATTACK_TO_ID = {
    label: idx
    for idx, label in enumerate((
        # ids 0-3
        'Crystal_Guard', 'jump_move', 'hand_laser', 'scream_beams',
        # ids 4-8
        'False_Knight', 'trowing_wave', 'hit_from_heaven', 'left-right-smashing',
        'just_jumpmove',
        # ids 9-13
        'hornet', 'hornet_ram', 'hornet_drill', 'hornet_throw', 'hornet_silk',
    ))
}
# Reverse lookup: class id -> label.
ID_TO_ATTACK = dict(zip(ATTACK_TO_ID.values(), ATTACK_TO_ID))


# For each enemy, the contiguous range of classifier ids that enemy can produce.
# predict_action_masked looks up the detector's class name here to restrict the
# softmax to these ids.
# NOTE(review): keys are matched exactly (case-sensitive) against
# yolo_model.names; a name that is not a key silently disables masking — confirm
# the detector's class names are 'hornet', 'false_knight' and 'crystal'.
ENEMY_ATTACKS_INDICES = dict(
    hornet=list(range(9, 14)),
    false_knight=list(range(4, 9)),
    crystal=list(range(0, 4)),
)

# Run inference on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Enemy detector; the main loop reads yolo_model.names to name each detected box.
yolo_model = YOLO(YOLO_MODEL_PATH)

def get_classifier(num_classes):
    """Build a ResNet-18 whose final layer outputs ``num_classes`` logits.

    The backbone is created with randomly initialised weights (nothing is
    downloaded); trained weights have to be loaded into it separately.

    Args:
        num_classes: Number of outputs of the replacement ``fc`` layer.

    Returns:
        A torchvision ResNet-18 ``nn.Module`` whose ``fc`` is
        ``nn.Linear(<original in_features>, num_classes)``.
    """
    # ``weights=None`` is the torchvision >= 0.13 replacement for the deprecated
    # ``pretrained=False`` flag (which now only emits a deprecation warning).
    model = models.resnet18(weights=None)
    # Swap the ImageNet head for one sized to our action classes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

# Build the action classifier with one output per entry of ATTACK_TO_ID and load
# the trained weights onto the selected device in inference mode.
classifier = get_classifier(len(ATTACK_TO_ID))
# The checkpoint is a plain state dict, so restricting unpickling to tensors and
# primitive containers is sufficient (and matches PyTorch's default from 2.6).
_state_dict = torch.load(CLASSIFIER_MODEL_PATH, map_location=device, weights_only=True)
try:
    classifier.load_state_dict(_state_dict)
except RuntimeError as err:
    # The usual cause is a head-size/key mismatch between the checkpoint and the
    # class table; say so explicitly instead of only dumping tensor shapes.
    raise RuntimeError(
        f"Classifier checkpoint {CLASSIFIER_MODEL_PATH!r} does not match a "
        f"{len(ATTACK_TO_ID)}-class model built from ATTACK_TO_ID; check the "
        "class list used during training."
    ) from err
del _state_dict
classifier.to(device)
classifier.eval()

# Classifier preprocessing: ndarray crop -> PIL image -> 224x224 resize ->
# float tensor -> per-channel ImageNet mean/std normalisation.
# NOTE(review): the crops fed in come from an OpenCV BGR frame, while
# ToPILImage/Normalize treat the channels as RGB — confirm the classifier was
# trained with the same channel order.
clf_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


# INFERENCE

def predict_action_masked(crop, enemy_name):
    """Predict the action shown in ``crop``, optionally limited to one enemy's attacks.

    Args:
        crop: Image region passed through ``clf_transform``. In the main loop it
            is a slice of the OpenCV frame (H x W x 3, uint8, BGR order) — confirm
            this matches the channel order used for training.
        enemy_name: Detector class name. If it is a key of
            ``ENEMY_ATTACKS_INDICES``, every logit outside that enemy's id list is
            forced to ``-inf`` so only its attacks can win; any other name falls
            back to all classes.

    Returns:
        ``(action_name, confidence)``: the arg-max label from ``ID_TO_ATTACK`` and
        its softmax probability (renormalised over the allowed ids when masking
        applies).
    """
    batch = clf_transform(crop).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = classifier(batch)

    if enemy_name in ENEMY_ATTACKS_INDICES:
        allowed = ENEMY_ATTACKS_INDICES[enemy_name]
        # Start from all -inf and copy back only the allowed logits; softmax then
        # assigns exactly zero probability to every other class.
        restricted = torch.full_like(logits, float('-inf'))
        restricted[:, allowed] = logits[:, allowed]
        logits = restricted

    probs = F.softmax(logits, dim=1)
    best_prob, best_idx = probs.max(dim=1)
    return ID_TO_ATTACK[best_idx.item()], best_prob.item()



import mss

# Screen region to capture; adjust it to where the game window is on your screen.
monitor = {"top": 40, "left": 0, "width": 1280, "height": 720}

# ``with`` releases the mss screenshot handle, and ``finally`` closes the OpenCV
# window even if the loop is interrupted (e.g. KeyboardInterrupt in a notebook)
# or inference raises.
with mss.mss() as sct:
    try:
        while True:
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()

            # mss yields BGRA pixels; drop alpha to get the BGR layout OpenCV uses.
            frame = cv2.cvtColor(np.array(sct.grab(monitor)), cv2.COLOR_BGRA2BGR)
            h, w = frame.shape[:2]  # same for every box in this frame

            results = yolo_model.predict(frame, conf=0.5, verbose=False)[0]

            # Draw on a copy so the crops given to the classifier carry no overlays.
            vis_frame = frame.copy()

            # ``or ()`` skips both a missing and an empty set of boxes.
            for box in results.boxes or ():
                x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
                enemy_name = yolo_model.names[int(box.cls[0])]

                cv2.rectangle(vis_frame, (x1, y1), (x2, y2), (255, 0, 0), 2)

                # Clamp to the frame so slicing never goes out of bounds.
                crop = frame[max(0, y1):min(h, y2), max(0, x1):min(w, x2)]
                if crop.size == 0:
                    continue

                action_name, action_conf = predict_action_masked(crop, enemy_name)

                # Enemy + action label, kept at least 20 px below the top edge.
                label = f"{enemy_name}: {action_name} ({action_conf:.2f})"
                cv2.putText(vis_frame, label, (x1, max(20, y1 - 10)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

            # Guard a zero elapsed time, which would raise ZeroDivisionError.
            elapsed = time.perf_counter() - start_time
            fps = 1.0 / elapsed if elapsed > 0 else 0.0
            cv2.putText(vis_frame, f"FPS: {int(fps)}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)

            cv2.imshow("Hollow Knight Bot", vis_frame)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cv2.destroyAllWindows()

  # NOTE(review): this appears to be a stray line left over from another notebook
  # cell or a pasted traceback. Its leading indentation at top level makes it an
  # IndentationError when the file is run as a plain script, and it repeats the
  # classifier weight load performed before the capture loop — consider removing it.
  classifier.load_state_dict(torch.load(CLASSIFIER_MODEL_PATH, map_location=device))
