# Training the Model

In this notebook, we train a YOLOv8 model (any variant works: YOLOv8n, YOLOv8s, YOLOv8m, YOLOv8l, etc.) on a custom dataset of about 21k annotated images of dangerous objects.

Before starting, download the dataset from Roboflow Universe: https://universe.roboflow.com/startup-zn0ol/dangerous-objects-dq94u

> Note: You can use any dataset of your choice, but make sure to update the paths in the code below.

Verify and update the `data.yaml` file with your class names and the paths to the training, validation, and test images. Make sure to put the `dangerous-objects` folder in the `datasets/` directory.


```yaml
# Dataset paths
train: dangerous-objects/train/images
val: dangerous-objects/valid/images
test: dangerous-objects/test/images

# Number of classes
nc: 6

# Class names
names: ['ammo', 'firearm', 'grenade', 'knife', 'pistol', 'rocket']
```
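Before training, it can help to check `data.yaml` programmatically. The sketch below uses a hypothetical helper, `check_data_config` (not part of Ultralytics). It assumes, as the note above says, that the relative paths in `data.yaml` resolve against the `datasets/` directory, and it only checks that `nc` (when present) matches the number of class names and that each listed image folder exists.

```python
from pathlib import Path


def check_data_config(cfg: dict, datasets_dir: str = "datasets") -> list:
    """Return human-readable problems found in a YOLO-style data config (empty list = OK)."""
    problems = []
    names = cfg.get("names", [])
    # nc is optional in some configs; only compare it when it is given
    if "nc" in cfg and cfg["nc"] != len(names):
        problems.append(f"nc={cfg['nc']} but {len(names)} class names are listed")
    root = Path(datasets_dir)
    for split in ("train", "val", "test"):
        if split not in cfg:
            continue  # a split may legitimately be omitted (e.g. no test set)
        images = root / cfg[split]
        if not images.is_dir():
            problems.append(f"{split}: image folder not found: {images}")
    return problems
```

To run it on the real file, load it with PyYAML (an Ultralytics dependency), e.g. `check_data_config(yaml.safe_load(open("data.yaml")))`; an empty list means nothing obvious is wrong. Images without a label file are deliberately not flagged, because Ultralytics counts them as background images (the scan log below reports 34 in the training split).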

In [21]:
%pip install --upgrade pip
%pip install -r requirements.txt

# Optional: Set up Comet.ml to log experiment data
import os
from dotenv import load_dotenv

from comet_ml import Experiment

load_dotenv()  # read COMET_API_KEY from a local .env file into the environment

experiment = Experiment(
    api_key=os.getenv("COMET_API_KEY"),
    project_name="enhancing-home-security",
    workspace="chiragagg5k",
)

# Imports
from ultralytics import YOLO
from ultralytics.utils.metrics import ConfusionMatrix
import torch
import matplotlib.pyplot as plt
import numpy as np

# Clear notebook output
from IPython.display import clear_output
clear_output(wait=False)

print("Notebook setup completed.")

Notebook setup completed.
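The cell above labels Comet as optional, but it always creates an `Experiment`. One way to make it genuinely optional is a small wrapper; `maybe_start_comet` below is a hypothetical helper, sketched under the assumption that the key comes from the `COMET_API_KEY` environment variable as in the cell above. It only imports `comet_ml` and creates an `Experiment` when a key is available.

```python
import os
from typing import Optional


def maybe_start_comet(project_name: str, workspace: str, api_key: Optional[str] = None):
    """Return a comet_ml Experiment if an API key is available, otherwise None."""
    if api_key is None:
        api_key = os.getenv("COMET_API_KEY")  # e.g. loaded from .env by load_dotenv()
    if not api_key:
        print("COMET_API_KEY not set; continuing without Comet logging.")
        return None
    # Import lazily so comet_ml is only required when logging is actually enabled
    from comet_ml import Experiment
    return Experiment(api_key=api_key, project_name=project_name, workspace=workspace)
```

With this, `experiment = maybe_start_comet("enhancing-home-security", "chiragagg5k")` would replace the unconditional call, and later cells would need an `if experiment is not None:` guard before logging or calling `experiment.end()`.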


In [1]:
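# Load a pretrained model (COCO weights) to fine-tune on the custom dataset
model_name = "yolov8n"  # or "yolov8s", "yolov8m", "yolov8l", "yolov8x"
model = YOLO(f"{model_name}.pt")  # f"{model_name}.yaml" would build an untrained model from scratch

# If CUDA is available, move the model to the GPU
if torch.cuda.is_available():
    print("Using GPU")
    model = model.to("cuda")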

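The CUDA check never selects the GPU on Apple silicon, and the training log below indeed reports `CPU (Apple M2)`. A minimal sketch of a device choice that also considers Apple's Metal (MPS) backend follows. `pick_device` is a hypothetical helper; the string it returns is meant for the `device=` argument of `model.train()`, which accepts values such as `"0"`, `"mps"` and `"cpu"`.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Pick a value for Ultralytics' `device=` argument: CUDA GPU first, then Apple MPS, then CPU."""
    if cuda_available:
        return "0"    # index of the first CUDA GPU
    if mps_available:
        return "mps"  # Apple silicon GPU via Metal Performance Shaders
    return "cpu"
```

In the notebook it would be called as `pick_device(torch.cuda.is_available(), torch.backends.mps.is_available())` and passed to `model.train(..., device=...)`. Taking booleans instead of calling `torch` inside keeps the decision testable on any machine. Whether MPS actually speeds up training depends on the PyTorch and Ultralytics versions, so it is worth timing a few iterations first.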

In [11]:
results = model.train(data="data.yaml", epochs=100, imgsz=640, verbose=True, plots=True)

New https://pypi.org/project/ultralytics/8.2.90 available 😃 Update with 'pip install -U ultralytics'
Ultralytics YOLOv8.2.89 🚀 Python-3.11.4 torch-2.4.0 CPU (Apple M2)
engine/trainer: task=detect, mode=train, model=yolov8n.pt, data=data.yaml, epochs=100, time=None, patience=100, batch=16, imgsz=640, save=True, save_period=-1, cache=False, device=None, workers=8, project=None, name=train2, exist_ok=False, pretrained=True, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, vid_stride=1, stream_buffer=False, visualize=False, augment=False, agnostic_nms=False, classes=None, retina_masks=False, embed=None, show=False, save_frames=False, save_txt=F

train: Scanning /Users/chiragagg5k/Desktop/Coding_Stuff/Enhancing-Home-Security/datasets/dangerous-objects/train/labels.cache... 18855 images, 34 backgrounds, 0 corrupt: 100%|██████████| 18855/18855 [00:00<?, ?it/s]

val: Scanning /Users/chiragagg5k/Desktop/Coding_Stuff/Enhancing-Home-Security/datasets/dangerous-objects/valid/labels.cache... 1749 images, 1 backgrounds, 0 corrupt: 100%|██████████| 1749/1749 [00:00<?, ?it/s]

Plotting labels to runs/detect/train2/labels.jpg...

optimizer: 'optimizer=auto' found, ignoring 'lr0=0.01' and 'momentum=0.937' and determining best 'optimizer', 'lr0' and 'momentum' automatically...
optimizer: SGD(lr=0.01, momentum=0.9) with parameter groups 57 weight(decay=0.0), 64 weight(decay=0.0005), 63 bias(decay=0.0)
TensorBoard: model graph visualization added ✅
Image sizes 640 train, 640 val
Using 0 dataloader workers
Logging results to runs/detect/train2
Starting training for 100 epochs...

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size

      1/100         0G      1.627      3.918      1.776         43        640:   2%|▏         | 28/1179 [03:41<2:32:05,  7.93s/it]

KeyboardInterrupt:

COMET ERROR: Due to connectivity issues, there's an error in processing the heartbeat. The experiment's status updates might be inaccurate until the connection issues are resolved.
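The progress bar suggests why this run was interrupted. Using the numbers from the log (18,855 training images, `batch=16`, about 7.93 s per iteration on the CPU), a back-of-the-envelope estimate of the full run is sketched below. `estimate_training_time` is a hypothetical helper, and the estimate ignores the validation pass at the end of each epoch.

```python
import math


def estimate_training_time(num_images: int, batch: int, sec_per_iter: float, epochs: int):
    """Return (iterations per epoch, hours per epoch, total hours), ignoring validation time."""
    iters_per_epoch = math.ceil(num_images / batch)  # the last, partial batch is still one iteration
    hours_per_epoch = iters_per_epoch * sec_per_iter / 3600
    return iters_per_epoch, hours_per_epoch, hours_per_epoch * epochs


# Values taken from the training log above
iters, per_epoch, total = estimate_training_time(num_images=18855, batch=16, sec_per_iter=7.93, epochs=100)
print(f"{iters} iterations/epoch, ~{per_epoch:.1f} h/epoch, ~{total:.0f} h for 100 epochs")
```

The 1179 iterations per epoch match the `28/1179` in the progress bar, and the total comes to roughly 260 hours (about 11 days) of CPU time, which is why a GPU or fewer epochs is the practical way to finish a run.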


In [20]:
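# Validate the model on the `val` split from data.yaml
metrics = model.val(data="data.yaml", imgsz=640, save_json=True, plots=True)
print("mAP50-95:", metrics.box.map)
print("mAP50:", metrics.box.map50)
print("mAP75:", metrics.box.map75)

# The validator accumulates the confusion matrix itself; with plots=True it is also
# saved to the run directory. Rows are predicted classes, columns are true classes,
# and the last row/column is "background".
conf_matrix = metrics.confusion_matrix
matrix = conf_matrix.matrix
labels = list(model.names.values()) + ["background"]

# Normalize each column so a cell shows the fraction of that true class
norm = matrix / (matrix.sum(0, keepdims=True) + 1e-9)
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(norm, cmap="Blues", vmin=0, vmax=1)
ax.set_xticks(range(len(labels)), labels, rotation=45, ha="right")
ax.set_yticks(range(len(labels)), labels)
ax.set_xlabel("True")
ax.set_ylabel("Predicted")
ax.set_title("Confusion Matrix")
fig.colorbar(im, ax=ax)
fig.tight_layout()
fig.savefig("confusion_matrix.png")
plt.show()

# Per-class counts, background excluded
tp, fp = conf_matrix.tp_fp()
fn = matrix.sum(0)[:-1] - tp  # true boxes of each class not predicted as that class

precision = tp / (tp + fp + 1e-9)
recall = tp / (tp + fn + 1e-9)
avg_precision = np.mean(precision)
avg_recall = np.mean(recall)
print(f"Average Precision: {avg_precision}")
print(f"Average Recall: {avg_recall}")

f1_score = 2 * (precision * recall) / (precision + recall + 1e-9)
avg_f1_score = np.mean(f1_score)
print(f"Average F1 Score: {avg_f1_score}")

# Object detection has no true negatives (every unlabelled image region is
# "background"), so a false positive rate cannot be computed from this matrix.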

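The three mAP values printed above differ only in how strictly a predicted box must overlap its ground-truth box, measured as intersection over union (IoU): mAP50 uses an IoU threshold of 0.5, mAP75 uses 0.75, and mAP50-95 averages AP over ten thresholds from 0.50 to 0.95 in steps of 0.05. The sketch below computes IoU for two boxes in `(x1, y1, x2, y2)` corner format; `box_iou` is an illustrative helper, not the Ultralytics implementation.

```python
def box_iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2) corner coordinates."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


# mAP50-95 averages AP over these 10 IoU thresholds: 0.50, 0.55, ..., 0.95
IOU_THRESHOLDS = [0.5 + 0.05 * i for i in range(10)]

pred = (10, 10, 110, 110)  # predicted box
gt = (30, 10, 130, 110)    # ground-truth box, shifted 20 px to the right
iou = box_iou(pred, gt)
matched = sum(iou >= t for t in IOU_THRESHOLDS)
print(f"IoU = {iou:.2f}, correct at {matched} of {len(IOU_THRESHOLDS)} thresholds")
```

Here the two 100×100 boxes overlap by 80×100 pixels, so IoU = 8000 / 12000 ≈ 0.67: the prediction counts as correct at the 0.50 to 0.65 thresholds but not from 0.70 upward, which is why mAP50-95 is typically lower than mAP50.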


In [22]:
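# Save the model and end the Comet experiment
# (Ultralytics also writes best.pt / last.pt under runs/detect/<run>/weights/)
os.makedirs("models", exist_ok=True)
model_path = f"models/{model_name}_threat_detection.pt"
model.save(model_path)
# log_artifact() expects a comet_ml Artifact object; log_model() uploads a file directly
experiment.log_model(f"{model_name}_threat_detection", model_path)
experiment.end()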

COMET INFO: ---------------------------------------------------------------------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------------------------------------------------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     name                  : envious_chicken_3994
COMET INFO:     url                   : https://www.comet.com/chiragagg5k/enhancing-home-security/af02068495d34f2e93ee98454d6d856b
COMET INFO:   Uploads:
COMET INFO:     environment details      : 1
COMET INFO:     filename                 : 1
COMET INFO:     git metadata             : 1
COMET INFO:     git-patch (uncompressed) : 1 (1.20 KB)
COMET INFO:     installed packages       : 1
COMET INFO:     notebook