# Blackjack Card Detection — YOLO Training Pipeline

This notebook:
1. Trains a YOLOv5-nano **object detection** model on the synthetic card dataset (10 epochs)
2. Evaluates on the test split (precision, recall, mAP, confusion matrix)
3. Extracts ground-truth crops (Option B) for per-class analysis
4. Exports `best.pt` ready to drop into the blackjack bot

**Prerequisites:** `data.yaml` and `train/`, `valid/`, `test/` folders in the working directory.

## 0 — Imports & GPU Setup

In [None]:
import os

# ── GPU selection ────────────────────────────────────────────
# Set these BEFORE importing torch / ultralytics: CUDA reads them when it first
# initialises, and a library import that touches CUDA early could lock in the
# default device mapping, silently ignoring the choice below.
# Change "1" to whichever GPU index is free on your server.
# Run `nvidia-smi` in a terminal to check availability.
os.environ['CUDA_DEVICE_ORDER']    = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import glob, random, shutil, pathlib, warnings
warnings.filterwarnings('ignore')

import torch
import numpy as np
import pandas as pd
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from ultralytics import YOLO

sns.set(style='darkgrid')

# ── Reproducibility ─────────────────────────────────────────
# Seeds the random previews / samples drawn in later cells so a
# Restart-&-Run-All shows the same images each time.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

if torch.cuda.is_available():
    # Index 0 is the first *visible* GPU, i.e. the one selected above.
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    # The device-properties field is `total_memory` (bytes); `total_mem` does not exist.
    print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
else:
    print('No GPU detected — training will use CPU (slow but works).')

## 1 — Verify Dataset Structure

Quick sanity checks before training: does the data exist, how many images per split, how many classes?

In [None]:
# ── Paths — adjust if your layout differs ───────────────────
DATA_YAML  = 'data.yaml'
TRAIN_IMGS = './train/images'
TRAIN_LBLS = './train/labels'
VAL_IMGS   = './valid/images'
TEST_IMGS  = './test/images'

# Existence check for every path the training run depends on.
for required in (TRAIN_IMGS, TRAIN_LBLS, VAL_IMGS, TEST_IMGS, DATA_YAML):
    found = os.path.exists(required)
    marker = '✅' if found else '❌'
    print(f"{marker}  {required}")
    if not found:
        print("   ⚠ Missing! Fix this before training.")

print()

# File counts per split (reported as 0 when the folder is absent).
split_dirs = {'Train': TRAIN_IMGS, 'Valid': VAL_IMGS, 'Test': TEST_IMGS}
for split_label, split_dir in split_dirs.items():
    n_files = 0
    if os.path.isdir(split_dir):
        n_files = len(glob.glob(os.path.join(split_dir, '*')))
    print(f'{split_label}: {n_files:,} images')

In [None]:
# ── Parse data.yaml to get class names ──────────────────────
import yaml

with open(DATA_YAML) as f:
    data_cfg = yaml.safe_load(f)

CLASS_NAMES = data_cfg['names']           # list or dict of class names
if isinstance(CLASS_NAMES, dict):         # ultralytics sometimes uses {0: 'name', ...}
    CLASS_NAMES = [CLASS_NAMES[k] for k in sorted(CLASS_NAMES.keys())]

NUM_CLASSES = len(CLASS_NAMES)
print(f'{NUM_CLASSES} classes: {CLASS_NAMES[:10]} ... {CLASS_NAMES[-5:]}')

## 2 — Preview Training Images

In [None]:
# Only real image files: the folder may also contain other files that plt.imread cannot read.
PREVIEW_EXTS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
train_image_files = sorted(
    f for f in os.listdir(TRAIN_IMGS) if f.lower().endswith(PREVIEW_EXTS)
)  # sorted → sample is reproducible under a fixed seed
sample_files = random.sample(train_image_files, min(9, len(train_image_files)))

fig, axes = plt.subplots(3, 3, figsize=(12, 12))
for ax in axes.flat:
    ax.axis('off')  # hide every panel first so unused ones stay blank
for ax, fname in zip(axes.flat, sample_files):
    ax.imshow(plt.imread(os.path.join(TRAIN_IMGS, fname)))
    ax.set_title(fname, fontsize=8)
plt.suptitle('Random Training Samples', fontsize=14)
plt.tight_layout()
plt.show()

## 3 — Train YOLO Detection Model

Fine-tune YOLOv5-nano (`yolov5nu.pt`, Ultralytics' anchor-free YOLOv5u variant, pre-trained on COCO) on the card dataset.  
10 epochs for speed — bump to 50–100 for a serious run.

In [None]:
# ── Hyperparameters ─────────────────────────────────────────
BASE_MODEL = 'yolov5nu.pt'   # nano — fast. Try 'yolov8s.pt' for better accuracy.
EPOCHS     = 10
BATCH      = 16
IMGSZ      = 416
PATIENCE   = 5               # epochs without validation-fitness improvement before stopping
# 0 = first *visible* GPU, i.e. the one selected via CUDA_VISIBLE_DEVICES in section 0.
DEVICE     = 0 if torch.cuda.is_available() else 'cpu'

model = YOLO(BASE_MODEL)

results = model.train(
    data=DATA_YAML,
    epochs=EPOCHS,
    batch=BATCH,
    imgsz=IMGSZ,
    optimizer='auto',
    device=DEVICE,
    cache=False,
    patience=PATIENCE,   # Ultralytics early-stops on validation *fitness* (a weighted mAP), not on val loss
    save=True,
    plots=True,          # auto-generate confusion matrix, PR curves, etc.
    verbose=True,
)

# Take the run directory from the trainer itself: `results` is the returned
# metrics object, which may be None in some setups, whereas the trainer
# always knows where it saved its outputs.
TRAIN_DIR = pathlib.Path(model.trainer.save_dir)
BEST_PT   = TRAIN_DIR / 'weights' / 'best.pt'
if not BEST_PT.exists():
    raise FileNotFoundError(f'Training finished but {BEST_PT} was not written.')
print(f'\n✅ Training complete.  Best weights: {BEST_PT}')

## 4 — Training Curves

In [None]:
# One row per epoch; Ultralytics pads the header names with spaces, so strip them.
curves = pd.read_csv(TRAIN_DIR / 'results.csv')
curves.columns = curves.columns.str.strip()

# (panel title, [(results.csv column, legend label), ...])
panels = [
    ('Box Loss',            [('train/box_loss', 'train'), ('val/box_loss', 'val')]),
    ('Classification Loss', [('train/cls_loss', 'train'), ('val/cls_loss', 'val')]),
    ('Validation mAP',      [('metrics/mAP50(B)', 'mAP@50'),
                             ('metrics/mAP50-95(B)', 'mAP@50-95')]),
]

fig, axes = plt.subplots(1, len(panels), figsize=(16, 4))
for ax, (title, series) in zip(axes, panels):
    for column, label in series:
        ax.plot(curves['epoch'], curves[column], label=label)
    ax.set_title(title)
    ax.legend()
    ax.set_xlabel('Epoch')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 5 — Confusion Matrix

In [None]:
# The PNG is produced by the training run (plots=True) inside the run directory.
cm_png = TRAIN_DIR / 'confusion_matrix.png'
if not cm_png.exists():
    print(f'Confusion matrix not found at {cm_png}')
else:
    plt.figure(figsize=(14, 14))
    plt.imshow(Image.open(cm_png))
    plt.axis('off')
    plt.title('Confusion Matrix (validation set)')
    plt.show()

## 6 — Evaluate on Test Set

In [None]:
best_model = YOLO(str(BEST_PT))

# Fail early with a clear message if the dataset defines no test split.
if not data_cfg.get('test'):
    raise KeyError("data.yaml has no 'test' entry — add one pointing at the test images.")

# Pass dataset and input size explicitly so evaluation matches the training
# configuration (and section 7's predictions) rather than whatever the
# checkpoint happens to carry.
test_stats = best_model.val(data=DATA_YAML, split='test', imgsz=IMGSZ)

metrics = test_stats.results_dict
print('─' * 40)
print(f"Precision : {metrics['metrics/precision(B)']:.3f}")
print(f"Recall    : {metrics['metrics/recall(B)']:.3f}")
print(f"mAP@50    : {metrics['metrics/mAP50(B)']:.3f}")
print(f"mAP@50-95 : {metrics['metrics/mAP50-95(B)']:.3f}")
print('─' * 40)
print('─' * 40)

## 7 — Visual Predictions on Test Images

In [None]:
# Run predictions and save annotated images
pred_results = best_model.predict(
    source=TEST_IMGS,
    save=True,
    imgsz=IMGSZ,
    conf=0.25,
    verbose=False,
)
pred_dir = pathlib.Path(pred_results[0].save_dir)

# Display a sample
pred_images = [f for f in os.listdir(pred_dir) if f.lower().endswith(('.jpg','.jpeg','.png'))]
samples = random.sample(pred_images, min(10, len(pred_images)))

fig, axes = plt.subplots(2, 5, figsize=(22, 9))
for ax, fname in zip(axes.flat, samples):
    img = Image.open(pred_dir / fname)
    ax.imshow(img)
    ax.set_title(fname, fontsize=7)
    ax.axis('off')
plt.suptitle('Test Predictions (annotated)', fontsize=14)
plt.tight_layout()
plt.show()

## 8 — Option B: Extract Ground-Truth Crops for Per-Class Analysis

This section reads the YOLO-format `.txt` label files, crops each card from the
source images using ground-truth bounding boxes, and saves them into an
`ImageFolder` structure (`crops/<split>/<class_name>/`).  

Useful for:
- Inspecting which classes your detector confuses
- Training a standalone classification model later if needed
- Counting per-class sample distribution

In [None]:
CROP_ROOT = pathlib.Path('./crops')

def extract_crops(images_dir, labels_dir, split_name, class_names):
    """
    Read YOLO-format label files, crop cards from images using
    ground-truth bounding boxes, and save into crops/<split>/<class>/.
    
    Returns a Counter of {class_name: count}.
    """
    out_root = CROP_ROOT / split_name
    counts = Counter()
    
    img_files = sorted(glob.glob(os.path.join(images_dir, '*')))
    
    for img_path in img_files:
        # Derive label path: .../images/foo.jpg → .../labels/foo.txt
        stem = pathlib.Path(img_path).stem
        label_path = os.path.join(labels_dir, stem + '.txt')
        if not os.path.exists(label_path):
            continue
        
        img = cv2.imread(img_path)
        if img is None:
            continue
        h, w = img.shape[:2]
        
        with open(label_path) as f:
            for i, line in enumerate(f):
                parts = line.strip().split()
                if len(parts) < 5:
                    continue
                
                cls_id = int(parts[0])
                cx, cy, bw, bh = map(float, parts[1:5])
                
                # Convert YOLO normalized (cx, cy, w, h) → pixel (x1, y1, x2, y2)
                x1 = int((cx - bw / 2) * w)
                y1 = int((cy - bh / 2) * h)
                x2 = int((cx + bw / 2) * w)
                y2 = int((cy + bh / 2) * h)
                
                # Clamp to image bounds
                x1, y1 = max(0, x1), max(0, y1)
                x2, y2 = min(w, x2), min(h, y2)
                
                if x2 <= x1 or y2 <= y1:
                    continue
                
                crop = img[y1:y2, x1:x2]
                
                cls_name = class_names[cls_id] if cls_id < len(class_names) else f'unknown_{cls_id}'
                out_dir = out_root / cls_name
                out_dir.mkdir(parents=True, exist_ok=True)
                
                out_file = out_dir / f'{stem}_{i}.jpg'
                cv2.imwrite(str(out_file), crop)
                counts[cls_name] += 1
    
    return counts


# ── Run extraction for each split ────────────────────────────
splits = [
    ('train', './train/images', './train/labels'),
    ('valid', './valid/images', './valid/labels'),
    ('test',  './test/images',  './test/labels'),
]

all_counts = {}
for split_name, img_dir, lbl_dir in splits:
    if os.path.isdir(img_dir) and os.path.isdir(lbl_dir):
        print(f'Extracting crops for {split_name}...')
        counts = extract_crops(img_dir, lbl_dir, split_name, CLASS_NAMES)
        all_counts[split_name] = counts
        print(f'  → {sum(counts.values()):,} crops across {len(counts)} classes')
    else:
        print(f'⚠ Skipping {split_name} — directories not found')

print(f'\nCrops saved to: {CROP_ROOT.resolve()}')

### 8b — Class Distribution (from crops)

In [None]:
if 'train' in all_counts:
    train_counts = all_counts['train']

    # Plot in class-id order (as listed in data.yaml) and include classes with
    # zero crops so gaps are visible; any unknown_<id> buckets go at the end.
    known = set(CLASS_NAMES)
    sorted_cls = list(CLASS_NAMES) + sorted(c for c in train_counts if c not in known)
    vals = [train_counts.get(c, 0) for c in sorted_cls]

    plt.figure(figsize=(20, 5))
    plt.bar(range(len(sorted_cls)), vals, color='steelblue')
    plt.xticks(range(len(sorted_cls)), sorted_cls, rotation=90, fontsize=7)
    plt.ylabel('Number of Crops')
    plt.title('Training Set — Crops per Class (ground truth)')
    plt.tight_layout()
    plt.show()

    # A class with no crops never becomes a Counter key, so the min/max summary
    # below cannot see it — report such classes explicitly.
    missing = [c for c in CLASS_NAMES if train_counts.get(c, 0) == 0]
    if missing:
        print(f'⚠ {len(missing)} class(es) have no training crops: {missing}')

    # Flag any imbalances among the classes that do have crops
    if train_counts:
        min_cls = min(train_counts, key=train_counts.get)
        max_cls = max(train_counts, key=train_counts.get)
        print(f'Fewest samples: {min_cls} ({train_counts[min_cls]})')
        print(f'Most samples:   {max_cls} ({train_counts[max_cls]})')
        print(f'Ratio: {train_counts[max_cls] / max(1, train_counts[min_cls]):.1f}x')
    else:
        print('No training crops were extracted.')

### 8c — Preview Random Crops

In [None]:
# Look only under crops/train/ (the original '**/train/**' pattern would also
# match any nested folder named 'train'); sorted → reproducible sample.
crop_files = sorted((CROP_ROOT / 'train').rglob('*.jpg'))
if crop_files:
    samples = random.sample(crop_files, min(16, len(crop_files)))
    fig, axes = plt.subplots(2, 8, figsize=(20, 5))
    for ax in axes.flat:
        ax.axis('off')  # unused panels stay blank when there are < 16 crops
    for ax, fp in zip(axes.flat, samples):
        ax.imshow(plt.imread(str(fp)))
        ax.set_title(fp.parent.name, fontsize=9)  # class name is the folder
    plt.suptitle('Random Ground-Truth Crops', fontsize=14)
    plt.tight_layout()
    plt.show()
else:
    print('No crop files found.')

## 9 — Per-Class Accuracy (via model predictions on crops)

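Run the trained **detection** model on the individual ground-truth crops from the test split to see which card classes
it recognises well and which it struggles with. A crop counts as correct only when the model's highest-confidence
detection has the crop's true class; crops with no detection count as misses. This is a quick proxy for
per-class recall on isolated cards.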

In [None]:
test_crop_dir = CROP_ROOT / 'test'
CROP_CONF = 0.2  # minimum confidence for a detection on a crop to count

if test_crop_dir.exists():
    per_class_correct = Counter()
    per_class_total   = Counter()

    # Each class folder holds the ground-truth crops of that class.
    for cls_dir in sorted(test_crop_dir.iterdir()):
        if not cls_dir.is_dir():
            continue
        true_cls = cls_dir.name

        for img_path in sorted(cls_dir.glob('*.jpg')):
            preds = best_model.predict(str(img_path), imgsz=IMGSZ,
                                       conf=CROP_CONF, verbose=False)
            boxes = preds[0].boxes if preds else None

            # Choose the top detection explicitly by confidence instead of
            # relying on the order in which boxes are returned.
            if boxes is not None and len(boxes) > 0:
                top = int(boxes.conf.argmax())
                pred_cls = best_model.names[int(boxes.cls[top])]
            else:
                pred_cls = 'NO_DETECTION'

            per_class_total[true_cls] += 1
            if pred_cls == true_cls:
                per_class_correct[true_cls] += 1

    correct = sum(per_class_correct.values())
    total   = sum(per_class_total.values())
    print(f'Overall crop accuracy: {correct}/{total} = {correct/max(1,total):.1%}')
    print()

    # Show worst classes
    per_class_acc = {
        cls: per_class_correct[cls] / max(1, per_class_total[cls])
        for cls in sorted(per_class_total)
    }

    worst = sorted(per_class_acc.items(), key=lambda x: x[1])[:10]
    print('Worst 10 classes by crop accuracy:')
    for cls, acc in worst:
        print(f'  {cls:>5s}: {acc:.0%}  ({per_class_correct[cls]}/{per_class_total[cls]})')
else:
    print('No test crops found — run section 8 first.')

## 10 — Export `best.pt` for the Blackjack Bot

In [None]:
# Copy best.pt to a convenient location
EXPORT_PATH = pathlib.Path('./best.pt')
shutil.copy2(BEST_PT, EXPORT_PATH)

size_mb = EXPORT_PATH.stat().st_size / 1e6
print(f'✅ Exported: {EXPORT_PATH.resolve()}  ({size_mb:.1f} MB)')
print()
print('To use in the blackjack bot:')
print('  1. Copy best.pt into the blackjack-bot/ folder')
print('  2. Verify config/settings.py has  MODEL_PATH = "best.pt"')
print('  3. Run:  python -m bot.main')

## 11 — Quick Smoke Test

Confirm the exported model loads and can predict on a single test image.

In [None]:
# Load the exported weights fresh (simulates what the bot does)
smoke_model = YOLO(str(EXPORT_PATH))

test_files = glob.glob(os.path.join(TEST_IMGS, '*'))
if test_files:
    sample = random.choice(test_files)
    res = smoke_model.predict(sample, conf=0.25, verbose=False)
    
    # Draw results
    annotated = res[0].plot()  # returns BGR numpy array with boxes drawn
    
    plt.figure(figsize=(10, 8))
    plt.imshow(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
    plt.title(f'Smoke test — {len(res[0].boxes)} detections')
    plt.axis('off')
    plt.show()
    
    # Print detections
    for box in res[0].boxes:
        cls_name = smoke_model.names[int(box.cls[0])]
        conf = float(box.conf[0])
        print(f'  {cls_name:>5s}  conf={conf:.2f}  bbox={box.xyxy[0].tolist()}')
    
    print(f'\n✅ Model works. Ready for the bot.')
else:
    print('No test images found.')