In [None]:
# ============================================================
# 🧠 FACE MASK DETECTION (3 Classes) using YOLOv8 + PyTorch
# ============================================================

# ✅ 1. Install dependencies
!pip install ultralytics==8.2.0 torch torchvision torchaudio opencv-python matplotlib seaborn tqdm --quiet

# ✅ 2. Import libraries
import os
import random
import shutil
import yaml
import torch
from ultralytics import YOLO
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import cv2
import numpy as np
from IPython.display import display, Image

# ============================================================
# ⚙️ 3. Check device
# Use CUDA when PyTorch reports an available GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

# ============================================================
# 📁 4. Dataset structure
# Expected layout before splitting:
# FMD_DATASET/
# ├── images/
# │   ├── img1.jpg ...
# ├── labels/
# │   ├── img1.txt ...
# Every image needs a .txt label file with the same stem. Each line of a label
# file describes one object in YOLO format: `class x_center y_center width height`,
# with coordinates normalized to [0, 1] and class ids indexing the `names` list
# written to data.yaml.
DATASET_PATH = "/content/FMD_DATASET"

# ============================================================
# 🧩 5. Split data (train:70, val:10, test:20)
def split_dataset(base_path, train_ratio=0.7, val_ratio=0.1, test_ratio=0.2, seed=None):
    """Split a flat YOLO dataset into train/val/test subfolders.

    Images are read from ``base_path/images`` (.jpg/.jpeg/.png, any case), and
    labels from ``base_path/labels/<image stem>.txt``. Each image and its label
    are copied to ``images/<subset>/`` and ``labels/<subset>/``. This layout lets
    YOLO find a label by swapping ``images`` for ``labels`` in the image path.

    Args:
        base_path: Dataset root containing ``images/`` and ``labels/``.
        train_ratio: Fraction of images assigned to ``train`` (rounded down).
        val_ratio: Fraction of images assigned to ``val`` (rounded down).
        test_ratio: Nominal test fraction, kept for documentation only. The
            ``test`` subset receives every image left after train and val, so
            rounding never drops images.
        seed: Optional seed for a reproducible split. ``None`` shuffles with
            the global ``random`` state, as before.

    Returns:
        Dict mapping ``"train"``, ``"val"`` and ``"test"`` to the lists of image
        filenames copied into each subset.

    Raises:
        ValueError: If the ratios are invalid or no images are found.
    """
    image_exts = (".jpg", ".jpeg", ".png")
    img_dir = os.path.join(base_path, "images")
    lbl_dir = os.path.join(base_path, "labels")

    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError("train_ratio and val_ratio must be >= 0 and sum to at most 1")

    # Only take loose files at the top level. The subset folders created below
    # live inside img_dir and must not be picked up when the cell is re-run.
    images = sorted(
        name for name in os.listdir(img_dir)
        if name.lower().endswith(image_exts) and os.path.isfile(os.path.join(img_dir, name))
    )
    if not images:
        raise ValueError(f"No images found in {img_dir}")

    # Sorting first makes a seeded shuffle independent of os.listdir order.
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(images)

    n_total = len(images)
    n_train = int(train_ratio * n_total)
    n_val = int(val_ratio * n_total)

    subsets = {
        "train": images[:n_train],
        "val": images[n_train:n_train + n_val],
        "test": images[n_train + n_val:],
    }

    missing_labels = []
    for subset, names in subsets.items():
        sub_img_dir = os.path.join(img_dir, subset)
        sub_lbl_dir = os.path.join(lbl_dir, subset)
        # Start from empty folders. Stale copies from a previous shuffle would
        # otherwise leak the same image into several splits.
        for folder in (sub_img_dir, sub_lbl_dir):
            shutil.rmtree(folder, ignore_errors=True)
            os.makedirs(folder, exist_ok=True)
        for name in names:
            shutil.copy(os.path.join(img_dir, name), os.path.join(sub_img_dir, name))
            # splitext handles every extension (including .jpeg) and only
            # touches the real suffix, unlike str.replace.
            lbl_name = os.path.splitext(name)[0] + ".txt"
            src_lbl = os.path.join(lbl_dir, lbl_name)
            if os.path.isfile(src_lbl):
                shutil.copy(src_lbl, os.path.join(sub_lbl_dir, lbl_name))
            else:
                missing_labels.append(name)

    if missing_labels:
        print(f"⚠️ {len(missing_labels)} image(s) have no label file "
              f"(YOLO treats them as background), e.g. {missing_labels[0]}")
    print("✅ Dataset split complete!")
    return subsets

split_dataset(DATASET_PATH)

# ============================================================
# 📄 6. Create YOLO dataset.yaml
# List index == class id used in the label .txt files.
CLASS_NAMES = ['mask', 'no_mask', 'incorrect_mask']

data_yaml = {
    'train': f'{DATASET_PATH}/images/train',
    'val': f'{DATASET_PATH}/images/val',
    'test': f'{DATASET_PATH}/images/test',
    'nc': len(CLASS_NAMES),  # derived from names so the two can never disagree
    'names': CLASS_NAMES,
}

yaml_path = f"{DATASET_PATH}/data.yaml"
with open(yaml_path, 'w', encoding='utf-8') as f:
    # sort_keys=False keeps the human-friendly train/val/test/nc/names order.
    yaml.dump(data_yaml, f, default_flow_style=False, sort_keys=False)
print(f"✅ Created YAML file at: {yaml_path}")

# ============================================================
# 🚀 7. Train YOLOv8 model
# Start from the pretrained YOLOv8-nano checkpoint ('yolov8s.pt' is larger and usually more accurate).
model = YOLO('yolov8n.pt')

train_config = dict(
    data=yaml_path,
    epochs=80,
    imgsz=640,
    batch=32,
    device=device,
    patience=10,       # early stopping after 10 epochs without validation improvement
    project="runs_fmd",
    name="exp_mask",   # run artifacts go under runs_fmd/exp_mask
    exist_ok=True,     # reuse that folder instead of creating a new numbered run
    # NOTE(review): Ultralytics documents `augment` as a prediction-time (TTA)
    # option; training augmentation is controlled by hyperparameters such as
    # mosaic/fliplr/hsv_*. Confirm this flag has any effect during training.
    augment=True,
)
results = model.train(**train_config)

# ============================================================
# 📉 8. Plot loss curves
# Ultralytics logs one row per epoch to <save_dir>/results.csv. The
# `model.trainer.metrics` dict only holds the latest epoch's scalar metrics
# and has no "train/loss" / "val/loss" keys, so it cannot drive a loss curve.
import csv


def _load_loss_curves(results_csv):
    """Return ``(epochs, train_loss, val_loss)`` read from an Ultralytics ``results.csv``.

    Header names and values are stripped because some Ultralytics versions pad
    them with spaces. Each total is the sum of every ``*_loss`` column under
    the ``train/`` or ``val/`` prefix (for detection, e.g. box/cls/dfl losses).
    """
    with open(results_csv, newline="", encoding="utf-8") as fh:
        rows = [
            {key.strip(): value.strip() for key, value in row.items() if key and value}
            for row in csv.DictReader(fh)
        ]

    def _total(row, prefix):
        return sum(float(v) for k, v in row.items()
                   if k.startswith(prefix) and k.endswith("_loss"))

    epochs = [int(float(row["epoch"])) for row in rows]
    train_loss = [_total(row, "train/") for row in rows]
    val_loss = [_total(row, "val/") for row in rows]
    return epochs, train_loss, val_loss


epochs, train_loss, val_loss = _load_loss_curves(
    os.path.join(model.trainer.save_dir, "results.csv")
)
plt.figure(figsize=(8, 5))
plt.plot(epochs, train_loss, label='Train Loss')
plt.plot(epochs, val_loss, label='Val Loss')
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training vs Validation Loss")
plt.legend()
plt.grid()
plt.show()

# ============================================================
# 🧾 9. Evaluate model on test set
# Score the trained weights on the held-out test split and print the summary.
metrics = model.val(split="test")
print(metrics)

# ============================================================
# 🔍 10. Confusion Matrix
# `model.val` returns aggregate detection metrics, not per-box predictions, so
# there is no `.boxes` to build a matrix from. Comparing predictions with
# themselves would only give a diagonal matrix anyway. The validator already
# accumulates a confusion matrix against the ground-truth labels; plot that.
preds = model.val(split='test', save_json=True, conf=0.25)
conf_mat = getattr(preds, "confusion_matrix", None)
if conf_mat is None:
    print("⚠️ Confusion matrix is not available on the returned metrics object.")
else:
    # Ultralytics stores float counts indexed [predicted, true], with an extra
    # trailing 'background' class for missed / spurious detections. Transpose
    # so rows are true labels, and cast to int so fmt='d' works.
    cm = np.asarray(conf_mat.matrix).T.astype(int)
    tick_labels = list(data_yaml['names'])
    if cm.shape[0] == len(tick_labels) + 1:
        tick_labels.append('background')
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=tick_labels, yticklabels=tick_labels)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.show()

# ============================================================
# 🎥 11. Real-time Inference (Webcam)
def realtime_detection(model_path="runs_fmd/exp_mask/weights/best.pt", source=0):
    """Run live face-mask detection and display annotated frames.

    Press ``q`` in the preview window to stop. The loop also ends when the
    capture stops returning frames.

    Args:
        model_path: Path to trained YOLO weights. Defaults to the best
            checkpoint of the ``runs_fmd/exp_mask`` training run.
        source: Anything ``cv2.VideoCapture`` accepts. ``0`` is the default
            webcam; a video file path also works.

    Raises:
        RuntimeError: If the capture source cannot be opened.
    """
    model = YOLO(model_path)
    cap = cv2.VideoCapture(source)
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f"Could not open video source {source!r}")
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            # verbose=False avoids printing a log line for every frame.
            for result in model(frame, stream=True, verbose=False):
                cv2.imshow('Face Mask Detection', result.plot())
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Free the camera and close windows even if inference raises or the
        # notebook cell is interrupted.
        cap.release()
        cv2.destroyAllWindows()

# Uncomment to test:
# realtime_detection()
