In [40]:
import os
import time

import albumentations as A
import cv2
import mss
import numpy as np
import torch
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from inference import get_model

In [41]:
# Roboflow API key, read from the environment so it is never committed with the notebook.
# Set it before running, e.g. `%env ROBOFLOW_API_KEY=...` in a cell or `export` in the shell.
# Evaluates to None when the variable is not set.
ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY")
# Model ID as shown on the Roboflow site
ROBOFLOW_MODEL_ID = "hollow-knight/3"

# Weights (a state_dict) of the attack/action classification model
ACTION_MODEL_PATH = 'best_action_model.pth'

# Input size the classifier expects, as (height, width)
CLASSIFIER_IMG_SIZE = (512, 512)
# Classifier output labels, in logit order
ACTION_CLASSES = ['hornet', 'hornet_drill', 'hornet_ram', 'hornet_silk', 'hornet_throw']
# Screen region grabbed by mss every frame: a 1920x1080 area anchored at the top-left corner
MONITOR_CONFIG = {"top": 0, "left": 0, "width": 1920, "height": 1080}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [42]:
class ResidualBlock(nn.Module):
    """Two-convolution residual block: conv3x3-BN-ReLU-conv3x3-BN, add shortcut, ReLU.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both 3x3 convolutions.
        stride: stride of the first convolution and of the auto-built projection shortcut.
        downsample: optional custom shortcut module. When None, a 1x1 conv + BatchNorm
            projection is built if the channel count or stride changes, otherwise the
            identity is used.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)
        self.batch1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, stride=1, bias=False)
        self.batch2 = nn.BatchNorm2d(out_channels)
        # The attribute name `down_sample` (and Sequential indices 0/1) are state_dict keys
        # of saved checkpoints, so they must not change.
        if downsample is not None:
            # Honor a caller-supplied shortcut instead of silently ignoring it.
            self.down_sample = downsample
        elif in_channels != out_channels or stride != 1:
            self.down_sample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=1, padding=0, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.down_sample = nn.Identity()

    def forward(self, x):
        # `down_sample` is always a module (Identity when shapes already match).
        identity = self.down_sample(x)
        out = self.relu(self.batch1(self.conv1(x)))
        out = self.batch2(self.conv2(out))
        out += identity
        return self.relu(out)

class CustomLightNet(nn.Module):
    """Lightweight ResNet-style image classifier.

    Layout: 3x3 conv stem (16 channels) -> eight ResidualBlocks widening
    16 -> 32 -> 64 -> 128 channels (stride 2 at basic3, basic5, basic7) ->
    global average pooling -> dropout(0.5) -> linear head.

    Args:
        num_clusses: number of output logits (parameter name kept for compatibility).
    """

    def __init__(self, num_clusses=5):
        super().__init__()
        # Stem
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.LeakyReLU(inplace=True)

        # Residual trunk, one (in_channels, out_channels, stride) triple per block.
        # Blocks are registered as basic1..basic8 in this order; those names are the
        # state_dict prefixes of saved checkpoints.
        trunk_spec = (
            (16, 16, 1), (16, 16, 1),
            (16, 32, 2), (32, 32, 1),
            (32, 64, 2), (64, 64, 1),
            (64, 128, 2), (128, 128, 1),
        )
        for position, (c_in, c_out, step) in enumerate(trunk_spec, start=1):
            setattr(self, f"basic{position}", ResidualBlock(c_in, c_out, stride=step))

        # Head
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=0.5)
        self.fc = nn.Linear(128, num_clusses)

    def forward(self, x):
        features = self.relu(self.bn1(self.conv1(x)))
        stages = (self.basic1, self.basic2, self.basic3, self.basic4,
                  self.basic5, self.basic6, self.basic7, self.basic8)
        for stage in stages:
            features = stage(features)
        pooled = self.avg_pool(features)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(self.dropout(flat))

In [43]:
# Object detector hosted on Roboflow (model ID and key come from the config cell).
detector = get_model(model_id=ROBOFLOW_MODEL_ID, api_key=ROBOFLOW_API_KEY)

# Action classifier: one logit per entry of ACTION_CLASSES.
classifier = CustomLightNet(num_clusses=len(ACTION_CLASSES))
try:
    # The checkpoint is a plain state_dict, so refuse arbitrary pickled objects.
    state_dict = torch.load(ACTION_MODEL_PATH, map_location=device, weights_only=True)
except FileNotFoundError as err:
    # Raise instead of exit(): exit() shuts down a Jupyter kernel (and exits a script with
    # status 0). Raising stops "Run All" with a visible error.
    raise FileNotFoundError(
        f"ОШИБКА: Не найден файл {ACTION_MODEL_PATH}. Положите его рядом со скриптом."
    ) from err
classifier.load_state_dict(state_dict)
classifier.to(device)
classifier.eval()  # inference mode: disables dropout, uses BatchNorm running stats
print("Классификатор атак загружен.")

# Preprocessing for boss crops: resize to the classifier input size, ImageNet
# mean/std normalization, HWC numpy -> CHW tensor.
classifier_transform = A.Compose([
    A.Resize(height=CLASSIFIER_IMG_SIZE[0], width=CLASSIFIER_IMG_SIZE[1]),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])





Классификатор атак загружен.


In [44]:
def _box_corners(pred, frame_w, frame_h):
    """Convert a center/size detection into integer (x1, y1, x2, y2) clipped to the frame."""
    half_w = pred.width / 2
    half_h = pred.height / 2
    x1 = max(0, int(pred.x - half_w))
    y1 = max(0, int(pred.y - half_h))
    x2 = min(frame_w, int(pred.x + half_w))
    y2 = min(frame_h, int(pred.y + half_h))
    return x1, y1, x2, y2


def _classify_action(crop_rgb):
    """Classify one RGB crop with the global classifier; return (action name, softmax probability)."""
    transformed = classifier_transform(image=crop_rgb)
    img_tensor = transformed['image'].unsqueeze(0).to(device)
    with torch.no_grad():
        probs = torch.nn.functional.softmax(classifier(img_tensor), dim=1)
        action_conf, action_idx = torch.max(probs, 1)
    return ACTION_CLASSES[action_idx.item()], action_conf.item()


def _annotate_frame(frame_np, conf_threshold, target_class):
    """Detect `target_class` on one BGRA screenshot, classify each hit and return the annotated BGR frame."""
    frame_rgb = cv2.cvtColor(frame_np, cv2.COLOR_BGRA2RGB)
    frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_BGRA2BGR)
    h_img, w_img = frame_bgr.shape[:2]

    # Roboflow `inference` treats numpy input as OpenCV-style BGR (its examples pass
    # cv2.imread output straight to infer), so feed the BGR frame, not the RGB one.
    results = detector.infer(frame_bgr)
    predictions = results[0].predictions if isinstance(results, list) else results.predictions

    boss_found = False
    for pred in predictions:
        # Skip other classes and low-confidence detections.
        if pred.class_name != target_class or pred.confidence <= conf_threshold:
            continue
        boss_found = True

        x1, y1, x2, y2 = _box_corners(pred, w_img, h_img)
        cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2)

        # Crop from the untouched RGB copy so overlays drawn on frame_bgr never reach the classifier.
        boss_crop = frame_rgb[y1:y2, x1:x2]
        if boss_crop.size == 0:
            continue

        action_name, action_score = _classify_action(boss_crop)
        label_text = f"ATTACK: {action_name} ({action_score:.0%})"
        color = (0, 165, 255) if action_name != 'hornet' else (255, 255, 255)
        # Clamp so the label stays visible when the box touches the top edge.
        cv2.putText(frame_bgr, label_text, (x1, max(y1 - 10, 20)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)

    if not boss_found:
        cv2.putText(frame_bgr, f"Looking for {target_class}...", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
    return frame_bgr


def main(conf_threshold=0.5, target_class="Enemy", display_scale=0.5):
    """Run the live capture -> detect -> classify -> display loop until 'q' is pressed.

    Each iteration:
    - grabs MONITOR_CONFIG with mss;
    - runs the Roboflow detector;
    - for every `target_class` box with confidence > `conf_threshold`, draws the box and
      overlays the action predicted by the classifier on the crop;
    - shows the frame resized by `display_scale`.

    Args:
        conf_threshold: detections must have confidence strictly above this value.
        target_class: detector class name to process; other classes are ignored.
        display_scale: preview window scale factor (0.5 turns 1920x1080 into 960x540).
    """
    # `with` releases the mss handle; `finally` closes the window even on Ctrl+C or an error.
    with mss.mss() as sct:
        print("Бот запущен. Нажмите 'q' для выхода.")
        try:
            while True:
                start_time = time.perf_counter()

                frame_np = np.array(sct.grab(MONITOR_CONFIG))
                frame_bgr = _annotate_frame(frame_np, conf_threshold, target_class)

                # FPS of capture + detection + classification (display time excluded).
                elapsed = time.perf_counter() - start_time
                fps = 1.0 / elapsed if elapsed > 0 else 0.0
                cv2.putText(frame_bgr, f"FPS: {int(fps)}", (10, 60),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 1)

                display_frame = cv2.resize(frame_bgr, None, fx=display_scale, fy=display_scale)
                cv2.imshow('Hollow Knight Bot', display_frame)

                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            cv2.destroyAllWindows()

if __name__ == "__main__":
    main()

Бот запущен. Нажмите 'q' для выхода.
