In [9]:
import os
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import classification_report, f1_score
import matplotlib.pyplot as plt
from ultralytics import YOLO  # YOLOv8
import zipfile

# # === Распаковка датасета ===
dataset_path = "./dataset"
os.makedirs(dataset_path, exist_ok=True)

for zip_file in ["Images.zip", "Masks.zip", "Test.zip", "Training.zip"]:
    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        zip_ref.extractall(dataset_path)


In [10]:
# === Гиперпараметры ===
BATCH_SIZE = 32
LEARNING_RATE = 0.001
EPOCHS = 10

# === Предобработка ===
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])



In [11]:
# === Датасет ===
class FireDataset(Dataset):
    def __init__(self, image_folder, labels):
        self.image_folder = image_folder
        self.labels = labels
        self.image_files = os.listdir(image_folder)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_folder, self.image_files[idx])
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = self.labels.get(self.image_files[idx], 0)
        image = transform(image)
        return image, label



In [12]:
# === Классификация CNN ===
class FireClassifier(torch.nn.Module):
    def __init__(self):
        super(FireClassifier, self).__init__()
        self.model = models.resnet18(pretrained=True)
        self.model.fc = torch.nn.Linear(512, 2)  # 2 класса: Огонь / Нет огня

    def forward(self, x):
        return self.model(x)



In [13]:
from tqdm import tqdm

# === Обучение CNN с прогрессом ===
def train_classifier(model, train_loader, val_loader):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    model.train()
    train_losses, val_losses = [], []

    for epoch in range(EPOCHS):
        epoch_loss = 0
        # Добавим прогресс-бар для обучения
        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", unit="batch") as pbar:
            for images, labels in pbar:
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                
                # Обновляем прогресс-бар с потерями
                pbar.set_postfix(loss=epoch_loss / (pbar.n + 1))

        train_losses.append(epoch_loss / len(train_loader))

        # Валидация с прогресс-баром
        model.eval()
        val_loss = 0
        with torch.no_grad():
            with tqdm(val_loader, desc="Validation", unit="batch") as pbar_val:
                for images, labels in pbar_val:
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                    val_loss += loss.item()
                    pbar_val.set_postfix(val_loss=val_loss / (pbar_val.n + 1))

        val_losses.append(val_loss / len(val_loader))
        print(f'Epoch {epoch+1}: Train Loss={train_losses[-1]:.4f}, Val Loss={val_losses[-1]:.4f}')
    
    return train_losses, val_losses


In [14]:
# === YOLOv8 для обнаружения огня ===
def detect_fire_yolo(image_path, model_path='yolov8s.pt'):
    model = YOLO(model_path)
    results = model(image_path)
    results.show()

# === Графики ===
def plot_metrics(train_losses, val_losses):
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()



In [15]:
# === F-Score ===
def evaluate_model(model, test_loader):
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            y_true.extend(labels.numpy())
            y_pred.extend(preds.numpy())
    print(classification_report(y_true, y_pred))
    print("F1 Score:", f1_score(y_true, y_pred, average='weighted'))

# === Анализ видео ===
def process_video(video_path, model):
    cap = cv2.VideoCapture(video_path)
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        detect_fire_yolo(frame)
        cv2.imshow('Fire Detection', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cap.release()
    cv2.destroyAllWindows()



In [16]:
# === Основной код ===
if __name__ == "__main__":
    # Подготовка данных
    fire_folder = os.path.join(dataset_path, 'Training/fire')
    no_fire_folder = os.path.join(dataset_path, 'Training/no_fire')
    
    train_labels = {}
    for file in os.listdir(fire_folder):
        train_labels[file] = 1  # Метка 1 — Огонь
    for file in os.listdir(no_fire_folder):
        train_labels[file] = 0  # Метка 0 — Нет огня
    
    train_dataset = FireDataset(fire_folder, train_labels)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    


In [20]:
# Обучение CNN
model = FireClassifier()
train_losses, val_losses = train_classifier(model, train_loader, train_loader)
plot_metrics(train_losses, val_losses)
evaluate_model(model, train_loader)
    


Epoch 1/10:   0%|          | 1/782 [00:03<46:12,  3.55s/batch, loss=0.929]


KeyboardInterrupt: 

In [21]:
# Тестирование YOLO
# detect_fire_yolo(os.path.join(dataset_path, '/test_image.jpg'))
    
    # Анализ видео
process_video('fire_video.mp4', model)


0: 384x640 (no detections), 120.1ms
Speed: 3.0ms preprocess, 120.1ms inference, 1.4ms postprocess per image at shape (1, 3, 384, 640)


AttributeError: 'list' object has no attribute 'show'

In [19]:
# === Сохранение обученной модели ===
model_save_path = "./fire_classifier.pth"
torch.save(model.state_dict(), model_save_path)
print(f"Модель сохранена по пути: {model_save_path}")

Модель сохранена по пути: ./fire_classifier.pth
