# YOLOv8 Face Recognition Training Notebook

This notebook trains a YOLOv8 (classification mode) model for celebrity face recognition using the folder structure under `dataset/` (`train/` and `test/` splits, one subfolder per class). It:

- Installs required packages
- Detects GPU (CUDA) automatically
- Loads images via `torchvision.datasets.ImageFolder`
- Applies resizing & normalization transforms compatible with YOLOv8 classification (224x224)
- Trains for 20 epochs tracking Accuracy, Precision, Recall, F1, Loss
- Saves per-epoch model checkpoints & plots into a timestamped subfolder inside `results/`
- Selects & saves the best model by training macro-F1 (`best_model.pt`)
- Evaluates on the test set and produces a percentage confusion matrix

Run the cells in order. Reduce the batch size (`BATCH_SIZE`) if you run out of memory.
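
Before running the cells, it can help to confirm that `dataset/` has the layout `ImageFolder` expects. The helper below is a minimal sketch and not part of the original notebook: the names `count_images_per_class` and `check_splits` and the extension set are hypothetical, and it assumes each split holds one subfolder per class.

```python
from pathlib import Path

# Assumed set of image extensions; adjust to match your files.
IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}

def count_images_per_class(split_dir):
    """Return {class_name: image_count} for one split folder (e.g. dataset/train)."""
    split_dir = Path(split_dir)
    counts = {}
    for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1 for f in class_dir.rglob('*')
            if f.is_file() and f.suffix.lower() in IMAGE_EXTS
        )
    return counts

def check_splits(data_root='dataset'):
    """Compare the class folders of train/ and test/ and list empty training classes."""
    train = count_images_per_class(Path(data_root) / 'train')
    test = count_images_per_class(Path(data_root) / 'test')
    return {
        'only_in_train': sorted(set(train) - set(test)),
        'only_in_test': sorted(set(test) - set(train)),
        'empty_train_classes': sorted(c for c, n in train.items() if n == 0),
    }
```

`ImageFolder` assigns class indices from the sorted folder names of each split independently, so a class folder that exists in only one split would shift the indices and make test labels disagree with training labels. Calling `check_splits('dataset')` and checking that both `only_in_*` lists are empty guards against that.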

In [1]:
# Install required packages (run once; skip if already installed).
%pip install ultralytics torch torchvision torchaudio tqdm seaborn scikit-learn matplotlib

Note: you may need to restart the kernel to use updated packages.





In [2]:
# Imports, configuration, CUDA check, results directory setup
import datetime
import json
import random
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, accuracy_score
from tqdm.auto import tqdm

# Reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
if device.type == 'cuda':
    print('GPU:', torch.cuda.get_device_name(0))

# Create timestamped run directory
RESULTS_ROOT = Path('results')
RESULTS_ROOT.mkdir(exist_ok=True)
run_timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
RUN_DIR = RESULTS_ROOT / run_timestamp
RUN_DIR.mkdir(parents=True, exist_ok=True)
print('Run directory:', RUN_DIR)

# Hyperparameters
IMG_SIZE = 224
BATCH_SIZE = 32  # adjust if out of memory
EPOCHS = 20
LR = 1e-3

Using device: cpu
Run directory: results\20250819_192427




In [3]:
# Dataset & DataLoaders
DATA_ROOT = Path('dataset')
train_dir = DATA_ROOT / 'train'
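val_dir = DATA_ROOT / 'test'  # the provided test split; used only for the final evaluation

# Transforms: resize straight to IMG_SIZE x IMG_SIZE (no crop, so aspect ratio is not preserved),
# then normalize with ImageNet mean/std. This is this notebook's choice and may differ from the
# preprocessing Ultralytics applies by default for classification.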
train_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])
val_tfms = train_tfms

train_ds = datasets.ImageFolder(root=str(train_dir), transform=train_tfms)
val_ds = datasets.ImageFolder(root=str(val_dir), transform=val_tfms)

class_names = train_ds.classes
num_classes = len(class_names)
print(f'Classes ({num_classes}):', class_names)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=device.type=='cuda')
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=device.type=='cuda')

with open(RUN_DIR / 'classes.json', 'w') as f:
    json.dump(class_names, f, indent=2)
print('Saved class names to', RUN_DIR / 'classes.json')

Classes (31): ['Akshay Kumar', 'Alexandra Daddario', 'Alia Bhatt', 'Amitabh Bachchan', 'Andy Samberg', 'Anushka Sharma', 'Billie Eilish', 'Brad Pitt', 'Camila Cabello', 'Charlize Theron', 'Claire Holt', 'Courtney Cox', 'Dwayne Johnson', 'Elizabeth Olsen', 'Ellen Degeneres', 'Henry Cavill', 'Hrithik Roshan', 'Hugh Jackman', 'Jessica Alba', 'Kashyap', 'Lisa Kudrow', 'Margot Robbie', 'Marmik', 'Natalie Portman', 'Priyanka Chopra', 'Robert Downey Jr', 'Roger Federer', 'Tom Cruise', 'Vijay Deverakonda', 'Virat Kohli', 'Zac Efron']
Saved class names to results\20250819_192427\classes.json


In [4]:
# Model preparation (YOLOv8 classification)
from ultralytics import YOLO

# Load pretrained YOLOv8n classification weights
base_model = YOLO('yolov8n-cls.pt')

# Access the underlying torch ClassificationModel
model = base_model.model  # ClassificationModel

# --- Flexible classifier adaptation strategies ---
# Strategy 1: Use reset_classifier if available (some versions implement this)
reset_done = False
if hasattr(model, 'reset_classifier'):
    try:
        model.reset_classifier(num_classes=num_classes)
        print('Used model.reset_classifier to set num_classes =', num_classes)
        reset_done = True
    except Exception as e:
        print('reset_classifier failed:', e)

# Strategy 2: If model.model is a Sequential/ModuleList, replace the last direct-child nn.Linear whose out_features differs from num_classes
if not reset_done and hasattr(model, 'model') and isinstance(model.model, (nn.Sequential, nn.ModuleList)):
    candidate_indices = [i for i, m in reversed(list(enumerate(model.model))) if isinstance(m, nn.Linear)]
    replaced = False
    for idx in candidate_indices:
        lin = model.model[idx]
        in_features = lin.in_features
        if lin.out_features != num_classes:
            try:
                model.model[idx] = nn.Linear(in_features, num_classes)
                print(f'Replaced Linear at index {idx}: {in_features} -> {num_classes}')
                replaced = True
                break
            except Exception as e:
                print(f'Failed replacing Linear at index {idx}:', e)
    if replaced:
        reset_done = True

# Strategy 3: Find the last-registered nn.Linear among all submodules and replace it if its out_features differs from num_classes
if not reset_done:
    term_linear = None
    for name, module_sub in reversed(list(model.named_modules())):
        if isinstance(module_sub, nn.Linear):
            term_linear = (name, module_sub)
            break
    if term_linear is not None:
        name, lin = term_linear
        if lin.out_features != num_classes:
            in_features = lin.in_features
            # Replace via attribute traversal
            parent = model
            name_parts = name.split('.')
            for p in name_parts[:-1]:
                parent = getattr(parent, p)
            try:
                setattr(parent, name_parts[-1], nn.Linear(in_features, num_classes))
                print(f'Replaced terminal Linear {name}: {in_features} -> {num_classes}')
                reset_done = True
            except Exception as e:
                print(f'Failed replacing terminal Linear {name}:', e)
        else:
            print('Existing terminal Linear already matches num_classes.')
            reset_done = True

# Strategy 4: Wrap model with a new head if no internal Linear layer was found/replaced.
if not reset_done:
    print('No suitable internal Linear layer found; wrapping with new head.')
    class YOLOClassifierWrapper(nn.Module):
        def __init__(self, backbone, num_classes, img_size):
            super().__init__()
            self.backbone = backbone
            self.num_classes = num_classes
            self.img_size = img_size
            with torch.no_grad():
                dummy = torch.randn(1, 3, img_size, img_size)
                feat = self.backbone(dummy)
                if isinstance(feat, (list, tuple)):
                    feat = feat[0]
                # If 4D, global average pool to (N, C)
                if feat.ndim == 4:
                    feat = feat.mean(dim=(2,3))
                self.feat_dim = feat.shape[1]
            self.head = nn.Linear(self.feat_dim, num_classes)
        def forward(self, x):
            out = self.backbone(x)
            if isinstance(out, (list, tuple)):
                out = out[0]
            if out.ndim == 4:
                out = out.mean(dim=(2,3))
            # If out already equals desired num_classes, assume backbone handled it
            if out.shape[1] != self.num_classes:
                out = self.head(out)
            return out
    model = YOLOClassifierWrapper(model, num_classes, IMG_SIZE)

# Move to device
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

print('Model ready with', num_classes, 'classes. Head replacement strategy success =', reset_done)

Replaced terminal Linear model.9.linear: 1280 -> 31
Model ready with 31 classes. Head replacement strategy success = True
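
The output above shows that the third strategy was the one that applied (`model.9.linear`, 1280 -> 31). The idea behind it can be sketched on a toy network. This is an illustration only: `replace_last_linear` and the toy model are hypothetical stand-ins, not the YOLOv8 architecture or an Ultralytics API.

```python
import torch
import torch.nn as nn

def replace_last_linear(model, num_classes):
    """Swap the last-registered nn.Linear for a fresh one with num_classes outputs.

    Returns the dotted name of the replaced layer.
    """
    last_name = None
    for name, module in model.named_modules():
        if name and isinstance(module, nn.Linear):
            last_name = name  # keep overwriting so the final match wins
    if last_name is None:
        raise ValueError('model has no nn.Linear submodule')

    # Walk down to the parent module that owns the layer.
    *parent_parts, leaf = last_name.split('.')
    parent = model
    for part in parent_parts:
        parent = getattr(parent, part)

    old = getattr(parent, leaf)
    # The new head is randomly initialised; only in_features is carried over.
    setattr(parent, leaf, nn.Linear(old.in_features, num_classes))
    return last_name

# Toy stand-in for a pretrained classifier with a 1000-class head.
toy = nn.Sequential(
    nn.Flatten(),
    nn.Sequential(nn.Linear(3 * 8 * 8, 16), nn.ReLU()),
    nn.Sequential(nn.Dropout(0.0), nn.Linear(16, 1000)),
)
replaced = replace_last_linear(toy, num_classes=31)
logits = toy(torch.randn(2, 3, 8, 8))
print(replaced, tuple(logits.shape))
```

Because the replacement layer is created on the CPU, the model should be moved to the target device after the swap, which is what the cell above does with `model.to(device)`.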


In [5]:
# Training loop
metrics_history = {
    'epoch': [],
    'train_loss': [],
    'train_acc': [],
    'train_precision': [],
    'train_recall': [],
    'train_f1': []
}

best_f1 = -1.0
best_path = RUN_DIR / 'best_model.pt'

for epoch in range(1, EPOCHS+1):
    model.train()
    epoch_losses = []
    all_preds = []
    all_targets = []

    pbar = tqdm(train_loader, desc=f'Epoch {epoch}/{EPOCHS}', leave=False)
    for images, targets in pbar:
        images = images.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        if isinstance(outputs, (list, tuple)):
            outputs = outputs[0]
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())
        preds = outputs.argmax(dim=1)
        all_preds.append(preds.detach().cpu())
        all_targets.append(targets.detach().cpu())

        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    all_preds = torch.cat(all_preds).numpy()
    all_targets = torch.cat(all_targets).numpy()

    train_loss = float(np.mean(epoch_losses))
    train_acc = accuracy_score(all_targets, all_preds)
    train_precision = precision_score(all_targets, all_preds, average='macro', zero_division=0)
    train_recall = recall_score(all_targets, all_preds, average='macro', zero_division=0)
    train_f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0)

    metrics_history['epoch'].append(epoch)
    metrics_history['train_loss'].append(train_loss)
    metrics_history['train_acc'].append(train_acc)
    metrics_history['train_precision'].append(train_precision)
    metrics_history['train_recall'].append(train_recall)
    metrics_history['train_f1'].append(train_f1)

    # Save checkpoint
    epoch_ckpt = RUN_DIR / f'trained_epoch_{epoch}.pt'
    torch.save({'epoch': epoch, 'model_state': model.state_dict(), 'optimizer_state': optimizer.state_dict(), 'metrics': metrics_history}, epoch_ckpt)

    # Plot metrics (acc/prec/recall/f1)
    fig1, ax1 = plt.subplots(figsize=(8,5))
    ax1.plot(metrics_history['epoch'], metrics_history['train_acc'], label='Accuracy')
    ax1.plot(metrics_history['epoch'], metrics_history['train_precision'], label='Precision')
    ax1.plot(metrics_history['epoch'], metrics_history['train_recall'], label='Recall')
    ax1.plot(metrics_history['epoch'], metrics_history['train_f1'], label='F1')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Score')
    ax1.set_title('Training Metrics')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    fig1.tight_layout()
    fig1.savefig(RUN_DIR / f'trained_epoch_{epoch}_metrics.png')
    plt.close(fig1)

    # Plot loss
    fig2, ax2 = plt.subplots(figsize=(8,5))
    ax2.plot(metrics_history['epoch'], metrics_history['train_loss'], label='Loss', color='red')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Loss')
    ax2.set_title('Training Loss')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    fig2.tight_layout()
    fig2.savefig(RUN_DIR / f'trained_epoch_{epoch}_loss.png')
    plt.close(fig2)

    # Update best model (selected on training macro-F1; no held-out validation pass runs during training)
    if train_f1 > best_f1:
        best_f1 = train_f1
        torch.save({'epoch': epoch, 'model_state': model.state_dict(), 'metrics': metrics_history}, best_path)

    print(f'Epoch {epoch}: loss={train_loss:.4f} acc={train_acc:.4f} prec={train_precision:.4f} rec={train_recall:.4f} f1={train_f1:.4f} | best_f1={best_f1:.4f}')

print('Training complete. Best model saved to', best_path)

                                                                        

Epoch 1: loss=2.5482 acc=0.3556 prec=0.3791 rec=0.3239 f1=0.3206 | best_f1=0.3206
Epoch 2: loss=1.5230 acc=0.6526 prec=0.6621 rec=0.6268 f1=0.6298 | best_f1=0.6298
Epoch 3: loss=1.1249 acc=0.7557 prec=0.7595 rec=0.7448 f1=0.7479 | best_f1=0.7479
Epoch 4: loss=0.9130 acc=0.8177 prec=0.8285 rec=0.8117 f1=0.8168 | best_f1=0.8168
Epoch 5: loss=0.7547 acc=0.8493 prec=0.8554 rec=0.8478 f1=0.8502 | best_f1=0.8502
Epoch 6: loss=0.6354 acc=0.8790 prec=0.8836 rec=0.8807 f1=0.8814 | best_f1=0.8814
Epoch 7: loss=0.5597 acc=0.9020 prec=0.9063 rec=0.9014 f1=0.9033 | best_f1=0.9033
Epoch 8: loss=0.5100 acc=0.9094 prec=0.9141 rec=0.9113 f1=0.9118 | best_f1=0.9118
Epoch 9: loss=0.4588 acc=0.9227 prec=0.9330 rec=0.9229 f1=0.9265 | best_f1=0.9265
Epoch 10: loss=0.3783 acc=0.9477 prec=0.9516 rec=0.9514 f1=0.9511 | best_f1=0.9511
Epoch 11: loss=0.3414 acc=0.9586 prec=0.9614 rec=0.9597 f1=0.9602 | best_f1=0.9602
Epoch 12: loss=0.3147 acc=0.9586 prec=0.9606 rec=0.9610 f1=0.9604 | best_f1=0.9604
Epoch 13: loss=0.2733 acc=0.9696 prec=0.9713 rec=0.9702 f1=0.9705 | best_f1=0.9705
Epoch 14: loss=0.2486 acc=0.9738 prec=0.9741 rec=0.9745 f1=0.9741 | best_f1=0.9741
Epoch 15: loss=0.2331 acc=0.9770 prec=0.9776 rec=0.9780 f1=0.9777 | best_f1=0.9777
Epoch 16: loss=0.2212 acc=0.9805 prec=0.9820 rec=0.9819 f1=0.9818 | best_f1=0.9818
Epoch 17: loss=0.1859 acc=0.9856 prec=0.9862 rec=0.9865 f1=0.9863 | best_f1=0.9863
Epoch 18: loss=0.1736 acc=0.9902 prec=0.9909 rec=0.9908 f1=0.9907 | best_f1=0.9907
Epoch 19: loss=0.1641 acc=0.9926 prec=0.9932 rec=0.9933 f1=0.9932 | best_f1=0.9932
Epoch 20: loss=0.1682 acc=0.9902 prec=0.9907 rec=0.9910 f1=0.9908 | best_f1=0.9932
Training complete. Best model saved to results\20250819_192427\best_model.pt
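
The `prec`, `rec`, and `f1` columns above come from scikit-learn with `average='macro'` and `zero_division=0`. As a rough illustration of what that averaging means, the sketch below recomputes the three scores in plain Python. The function name `macro_prf` is hypothetical, and the code is meant to mirror the idea rather than replace the scikit-learn calls.

```python
def macro_prf(y_true, y_pred):
    """Macro-averaged precision, recall and F1 over the labels that occur.

    Each class gets its own score and all classes weigh equally in the mean,
    regardless of how many samples they have. Undefined ratios count as 0.
    """
    labels = sorted(set(y_true) | set(y_pred))
    precisions, recalls, f1s = [], [], []
    for c in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    n = len(labels)
    return sum(precisions) / n, sum(recalls) / n, sum(f1s) / n

# Three classes, six samples, two mistakes.
p, r, f = macro_prf([0, 0, 1, 1, 2, 2], [0, 1, 1, 1, 2, 0])
print(f'precision={p:.4f} recall={r:.4f} f1={f:.4f}')
```

Because every class counts equally, a class with few images that is often misclassified pulls the macro scores down more than it pulls accuracy down, which is why the two can diverge on imbalanced datasets.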


In [7]:
# Test evaluation & confusion matrix
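# Load the best checkpoint (selected by training macro-F1)
# Note: on PyTorch versions where torch.load defaults to weights_only=True, a checkpoint that also
# stores non-tensor metrics may need weights_only=False; only use that for files you trust.
ckpt = torch.load(best_path, map_location=device)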
model.load_state_dict(ckpt['model_state'])
model.eval()

all_preds = []
all_targets = []

with torch.no_grad():
    for images, targets in tqdm(val_loader, desc='Testing', leave=False):
        images = images.to(device)
        targets = targets.to(device)
        outputs = model(images)
        if isinstance(outputs, (list, tuple)):
            outputs = outputs[0]
        preds = outputs.argmax(dim=1)
        all_preds.append(preds.cpu())
        all_targets.append(targets.cpu())

all_preds = torch.cat(all_preds).numpy()
all_targets = torch.cat(all_targets).numpy()

overall_acc = accuracy_score(all_targets, all_preds)
overall_precision = precision_score(all_targets, all_preds, average='macro', zero_division=0)
overall_recall = recall_score(all_targets, all_preds, average='macro', zero_division=0)
overall_f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0)

print(f'Test: acc={overall_acc:.4f} precision={overall_precision:.4f} recall={overall_recall:.4f} f1={overall_f1:.4f}')

cm = confusion_matrix(all_targets, all_preds, labels=list(range(num_classes)))
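# Row-normalize so each row sums to 1; rows with no true samples stay at 0 instead of dividing by zero
row_sums = cm.sum(axis=1, keepdims=True)
cm_percent = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)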

fig, ax = plt.subplots(figsize=(max(8, num_classes*0.4), max(6, num_classes*0.4)))
sns.heatmap(cm_percent*100, annot=False, cmap='Blues', cbar=True, ax=ax, fmt='.1f')
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix (%)')
ax.set_xticks(np.arange(num_classes)+0.5)
ax.set_yticks(np.arange(num_classes)+0.5)
ax.set_xticklabels(class_names, rotation=90)
ax.set_yticklabels(class_names, rotation=0)
fig.tight_layout()
cm_path = RUN_DIR / 'confusion_matrix.png'
fig.savefig(cm_path, dpi=150)
plt.close(fig)
print('Saved confusion matrix to', cm_path)

# Save final metrics JSON
final_metrics = {
    'overall_acc': overall_acc,
    'overall_precision': overall_precision,
    'overall_recall': overall_recall,
    'overall_f1': overall_f1,
    'best_epoch': ckpt['epoch']
}
with open(RUN_DIR / 'final_metrics.json', 'w') as f:
    json.dump(final_metrics, f, indent=2)
print('Saved final metrics JSON.')

                                                        

Test: acc=0.9984 precision=0.9986 recall=0.9985 f1=0.9985
Saved confusion matrix to results\20250819_192427\confusion_matrix.png
Saved final metrics JSON.
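
To make the "percentage per true class" matrix easier to read, here is a small sketch of the row normalization on a hand-made 3x3 count matrix. The function name `row_normalize_percent` and the counts are made up for illustration.

```python
import numpy as np

def row_normalize_percent(cm):
    """Convert a count confusion matrix into percentages of each true class (row)."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    out = np.zeros_like(cm)
    # Rows with no samples are left at 0 rather than dividing by zero.
    np.divide(cm, row_sums, out=out, where=row_sums != 0)
    return out * 100.0

# Rows are true classes, columns are predictions; class 2 has no test samples.
counts = np.array([
    [8, 2, 0],
    [1, 9, 0],
    [0, 0, 0],
])
pct = row_normalize_percent(counts)
print(pct)
```

Each non-empty row sums to 100, and the diagonal entry of a row is that class's recall in percent, so off-diagonal cells show where a given identity is being mistaken for another.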


## Training Run Summary

Artifacts saved under the timestamped folder in `results/` (value stored in `RUN_DIR`).

Contents include:
- `trained_epoch_{n}.pt`: per-epoch checkpoints (model state, optimizer state, metrics so far)
- `trained_epoch_{n}_metrics.png`: line plots of training accuracy/precision/recall/F1
- `trained_epoch_{n}_loss.png`: training loss curve
- `best_model.pt`: best checkpoint by training macro-F1
- `classes.json`: class label order
- `confusion_matrix.png`: final confusion matrix (percentage per true class)
- `final_metrics.json`: overall test metrics and best epoch

You can start another training run by re-running from the imports cell (a new timestamped folder will be created). Adjust hyperparameters as needed (e.g., `BATCH_SIZE`, `LR`, `EPOCHS`).
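
Since every run writes into its own timestamped folder, a small helper can locate the most recent one and read back its summary files. This is a sketch under the assumption that folder names keep the `YYYYMMDD_HHMMSS` pattern used above, so sorting by name is also chronological. The function names `latest_run_dir` and `load_run_summary` are hypothetical.

```python
import json
from pathlib import Path

def latest_run_dir(results_root='results'):
    """Return the newest run folder; YYYYMMDD_HHMMSS names sort chronologically."""
    runs = sorted(p for p in Path(results_root).iterdir() if p.is_dir())
    if not runs:
        raise FileNotFoundError(f'no run folders found under {results_root}')
    return runs[-1]

def load_run_summary(run_dir):
    """Read final_metrics.json and classes.json from one run folder."""
    run_dir = Path(run_dir)
    with open(run_dir / 'final_metrics.json') as f:
        metrics = json.load(f)
    with open(run_dir / 'classes.json') as f:
        classes = json.load(f)
    return metrics, classes
```

For example, `metrics, classes = load_run_summary(latest_run_dir())` would give the test metrics of the latest run together with the label order needed to map predicted indices back to names.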