# LEGO Bricks ML Vision - Training Pipeline Demonstration

This notebook provides a step-by-step demonstration of training the YOLO-based computer vision models for detecting LEGO bricks and studs. The training pipeline showcases how to:

1. Prepare and process LEGO brick/stud datasets
2. Configure and train YOLOv8 models 
3. Evaluate model performance
4. Export models for inference

## Project Overview

The LEGO Bricks ML Vision project uses two distinct computer vision models:
- **Brick Detection Model**: Identifies complete LEGO bricks in images
- **Stud Detection Model**: Identifies individual studs on LEGO bricks

Together, these models enable the full classification pipeline that can identify brick dimensions based on stud patterns.

![Training Pipeline Overview](../docs/assets/images/train_pipeline_diagram.png)

## 1. Setup and Environment Configuration

In [1]:
# Check for "LEGO_BRICKS_ML_VISION" folder in the cwd folder branch

import os
import sys
from pathlib import Path
import subprocess
import logging
import rich.logging as rlog

# Notebook-wide rich logger (markup + emoji in messages). It is configured only
# once, so re-running this cell does not attach a second handler.
logger = logging.getLogger("notebook_logger")
if not logger.handlers:
    handler = rlog.RichHandler(
        markup=True,
        rich_tracebacks=True,
        show_time=False,
    )
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)
    logger.info("✅ [bold green]Logger initialized for LEGO ML Vision notebook[/bold green]")

def check_repo_clone():
    """
    Locate the LEGO_Bricks_ML_Vision repository, cloning it only if needed.

    Search order:
      1. the current working directory or the nearest ancestor directory whose
         name equals the repository name;
      2. an existing ``<cwd>/LEGO_Bricks_ML_Vision`` checkout (e.g. left by a
         previous run of this cell before a kernel restart);
      3. otherwise ``git clone`` the repository into the current directory.

    Side effects: the repository root is appended to ``sys.path`` (as a
    string, once) and becomes the current working directory.

    Returns:
        Path: the local repository root.

    Raises:
        RuntimeError: if cloning fails (git error or git not installed).
    """
    # Rich logger for this step; only attach a handler the first time.
    logger = logging.getLogger("repo_setup")
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        logger.addHandler(rlog.RichHandler(rich_tracebacks=True, markup=True))

    github_user = "MiguelDiLalla"
    repo_name = "LEGO_Bricks_ML_Vision"
    repo_url = f"https://github.com/{github_user}/{repo_name}.git"

    cwd = Path.cwd()
    logger.info(f"Checking for repository: [bold blue]{repo_name}[/bold blue]")

    # Nearest directory named after the repo: cwd first, then its ancestors.
    repo_path = next((p for p in (cwd, *cwd.parents) if p.name == repo_name), None)

    # A clone made by an earlier run sits next to the notebook; cloning again
    # would fail because the destination already exists.
    if repo_path is None and (cwd / repo_name).is_dir():
        repo_path = cwd / repo_name

    if repo_path is not None:
        logger.info(f"Repository [bold blue]{repo_name}[/bold blue] already available at [bold green]{repo_path}[/bold green]")
    else:
        logger.info("Repository not found in current path or its parent directories")
        logger.info(f"Cloning from [green]{repo_url}[/green]...")
        try:
            subprocess.run(["git", "clone", repo_url], check=True)
        except (subprocess.CalledProcessError, FileNotFoundError) as e:
            # FileNotFoundError: the `git` executable itself is missing.
            logger.error(f"Failed to clone repository: {e}")
            logger.error(f"Please clone manually with: git clone {repo_url}")
            raise RuntimeError(f"Repository setup failed: {e}") from e
        logger.info("Repository successfully cloned")
        repo_path = cwd / repo_name

    # The import system ignores non-str sys.path entries, so store a string,
    # and avoid piling up duplicates when the cell is re-run.
    repo_path_str = str(repo_path)
    if repo_path_str not in sys.path:
        sys.path.append(repo_path_str)

    logger.info(f"Changing working directory to: [bold green]{repo_path}[/bold green]")
    os.chdir(repo_path)
    return repo_path

repo_clone_path = check_repo_clone()



In [2]:
def _normalize_dist_name(name):
    """Return the PEP 503 normalised form of a distribution name."""
    import re

    return re.sub(r"[-_.]+", "-", name).lower()


def _requirement_entries(text):
    """Return ``(normalised_name, requirement_line)`` for each package line in requirements text."""
    import re

    entries = []
    for raw_line in text.splitlines():
        # '#' starts a comment only at line start or after whitespace, so URL
        # fragments such as '...#egg=pkg' are left intact.
        line = re.sub(r"(^|\s)#.*$", "", raw_line).strip()
        # Skip blank lines and pip options ('-r other.txt', '--index-url ...', '-e .').
        if not line or line.startswith("-"):
            continue
        # Leading project name, before any extras/version specifier/marker.
        match = re.match(r"[A-Za-z0-9](?:[A-Za-z0-9._-]*[A-Za-z0-9])?", line)
        if match:
            entries.append((_normalize_dist_name(match.group(0)), line))
    return entries


def setup_requirements(repo_path):
    """
    Install the packages from the repo's requirements.txt that are not installed.

    Package names are compared against the distributions installed in the
    interpreter running this notebook, using PEP 503 name normalisation
    (e.g. ``opencv_python`` matches ``opencv-python``). Only presence is
    checked, not the installed version. Missing requirement lines are passed
    to pip verbatim, so their version specifiers are honoured.

    Args:
        repo_path: repository root (str or Path) that contains requirements.txt.

    Returns:
        None. Logs which packages were missing and installs them.

    Raises:
        FileNotFoundError: if requirements.txt does not exist.
        subprocess.CalledProcessError: if the pip install fails.
    """
    from importlib import metadata

    # Rich logger for this step; only attach a handler the first time.
    logger = logging.getLogger("requirements_setup")
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        logger.addHandler(rlog.RichHandler(rich_tracebacks=True, markup=True))

    requirements_path = Path(repo_path) / "requirements.txt"
    logger.info(f"Checking for requirements file: [bold blue]{requirements_path}[/bold blue]")

    if not requirements_path.exists():
        logger.error("Requirements file not found")
        logger.error("Please create a requirements.txt file with the required packages")
        raise FileNotFoundError(f"Requirements file not found: {requirements_path}")

    logger.info("Requirements file found")
    requirements = _requirement_entries(requirements_path.read_text(encoding="utf-8"))

    logger.info("Checking for installed packages...")
    # Query this kernel's environment directly; a bare `pip` on PATH may
    # belong to a different interpreter.
    installed = {
        _normalize_dist_name(name)
        for name in (dist.metadata.get("Name") for dist in metadata.distributions())
        if name
    }
    missing_packages = [line for name, line in requirements if name not in installed]

    if missing_packages:
        logger.info(f"Missing packages: [bold red]{missing_packages}[/bold red]")
        logger.info("Installing missing packages...")
        # `python -m pip` installs into the same interpreter the notebook runs on.
        subprocess.run([sys.executable, "-m", "pip", "install", *missing_packages], check=True)
        logger.info("Requirements successfully installed")
    else:
        logger.info("All requirements are already installed")

setup_requirements(repo_clone_path)



In [3]:
# Import required libraries

import matplotlib.pyplot as plt
import numpy as np
import cv2
from pathlib import Path
from IPython.display import Image, display
from ultralytics import YOLO
from PIL import Image as PILImage
import random


# Import project modules
from train import setup_logging, detect_hardware, load_config, cleanup_training_sessions
from train import unzip_dataset, validate_dataset, create_dataset_structure, dataset_split , augment_data
from train import select_model, train_model, display_last_training_session



In [7]:
# Configure project-wide logging, then load the training configuration.
setup_logging()
config = load_config()

# Let the training module pick the device to train on and report it.
device = detect_hardware()
print("Training will use device:", device)

# Clean up earlier training sessions under the repository root
# (see train.cleanup_training_sessions for exactly what it removes).
cleanup_training_sessions(repo_clone_path)







## 2. Dataset Exploration

Let's examine our dataset before training. We have separate datasets for brick detection and stud detection.

In [8]:
# Extract and prepare datasets
brick_dataset_path = unzip_dataset("bricks")
stud_dataset_path = unzip_dataset("studs")

# Display dataset information
def display_dataset_info(mode):
    """
    Print a summary of one dataset and check that images and labels pair up.

    Pairing is checked by file stem: every ``<stem>.jpg`` image must have a
    ``<stem>.txt`` label and vice versa. Equal counts alone are not enough.

    Args:
        mode: dataset name forwarded to ``validate_dataset`` ("bricks" or "studs").

    Returns:
        (images_path, labels_path) as returned by ``validate_dataset``.

    Raises:
        ValueError: if some images lack a label or some labels lack an image.
    """
    images_path, labels_path = validate_dataset(mode)
    image_stems = {p.stem for p in Path(images_path).glob("*.jpg")}
    label_stems = {p.stem for p in Path(labels_path).glob("*.txt")}

    print(f"\n{mode.capitalize()} Dataset:")
    print(f"- Images path: {images_path}")
    print(f"- Labels path: {labels_path}")
    print(f"- Total images: {len(image_stems)}")
    print(f"- Total labels: {len(label_stems)}")

    unlabeled = sorted(image_stems - label_stems)
    orphans = sorted(label_stems - image_stems)
    if unlabeled or orphans:
        print("⚠️  Warning: images and labels do not match")
        if unlabeled:
            print(f"  - {len(unlabeled)} image(s) without a label, e.g. {unlabeled[:5]}")
        if orphans:
            print(f"  - {len(orphans)} label(s) without an image, e.g. {orphans[:5]}")
        raise ValueError(
            f"{mode} dataset: {len(unlabeled)} image(s) without labels, "
            f"{len(orphans)} label(s) without images"
        )

    return images_path, labels_path

bricks_images_path, bricks_labels_path = display_dataset_info("bricks")
studs_images_path, studs_labels_path = display_dataset_info("studs")






### Visualize Sample Images with Annotations

In [10]:
def _find_label_file(img_path, labels_path):
    """Return the YOLO label file for ``img_path``, or None when none exists."""
    labels_dir = Path(labels_path)
    candidates = [labels_dir / f"{img_path.stem}.txt"]
    # Fallback for split layouts: <labels_path>/../{train,val,test}/labels/<stem>.txt
    candidates.extend(
        labels_dir.parent / split / "labels" / f"{img_path.stem}.txt"
        for split in ("train", "val", "test")
    )
    return next((c for c in candidates if c.exists()), None)


def _draw_yolo_annotations(ax, label_file, width, height, color):
    """Draw each well-formed YOLO line of ``label_file`` on ``ax`` as a labelled box."""
    with open(label_file, "r", encoding="utf-8") as f:
        for ann in f:
            parts = ann.split()
            if len(parts) < 5:  # skip blank or malformed lines
                continue
            class_id = int(float(parts[0]))
            x_center, y_center, w, h = map(float, parts[1:5])

            # Convert normalized YOLO (center, size) to pixel corner coordinates
            x1 = int((x_center - w / 2) * width)
            y1 = int((y_center - h / 2) * height)
            x2 = int((x_center + w / 2) * width)
            y2 = int((y_center + h / 2) * height)

            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                       fill=False, edgecolor=color["box"], linewidth=2))
            ax.text(x1, y1 - 5, f"Class {class_id}", color='white', fontsize=8,
                    bbox=dict(facecolor=color["text_bg"], alpha=0.7))


def visualize_sample_with_annotations(images_path, labels_path, mode, num_samples=3):
    """
    Show a random sample of dataset images with their YOLO boxes drawn on top.

    Args:
        images_path: directory containing ``*.jpg`` images.
        labels_path: directory containing YOLO ``*.txt`` labels. If a label is
            not found there, ``<labels_path>/../{train,val,test}/labels`` is tried.
        mode: "bricks" or "studs"; selects the box colours (green otherwise).
        num_samples: maximum number of images to show (capped at the number
            available).

    Returns:
        None. Prints a message and draws nothing when there is nothing to show.
    """
    image_files = sorted(Path(images_path).glob("*.jpg"))
    num_samples = min(num_samples, len(image_files))
    if num_samples <= 0:
        # plt.subplots cannot build a figure with zero columns.
        print(f"No .jpg images to visualize in {images_path}")
        return
    samples = random.sample(image_files, num_samples)

    # squeeze=False always returns a 2-D array of axes, even for one sample.
    _, axes = plt.subplots(1, num_samples, figsize=(5 * num_samples, 5), squeeze=False)

    # Define colors based on mode
    colors = {
        "bricks": {"box": "red", "text_bg": "darkred"},
        "studs": {"box": "blue", "text_bg": "darkblue"}
    }
    color = colors.get(mode, {"box": "green", "text_bg": "darkgreen"})

    for i, (img_path, ax) in enumerate(zip(samples, axes[0])):
        ax.set_title(f"{mode.capitalize()} Sample {i+1}\n{img_path.name}", fontsize=10)
        ax.axis('off')  # Hide axis for cleaner visualization

        img = cv2.imread(str(img_path))
        if img is None:
            # cv2.imread signals unreadable/corrupt files by returning None.
            ax.text(0.5, 0.5, "Unreadable image", color='white', fontsize=10,
                    ha="center", va="center", transform=ax.transAxes,
                    bbox=dict(facecolor='red', alpha=0.7))
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        height, width = img.shape[:2]
        ax.imshow(img)

        label_file = _find_label_file(img_path, labels_path)
        if label_file is None:
            ax.text(10, 30, "No labels found", color='white', fontsize=10,
                    bbox=dict(facecolor='red', alpha=0.7))
        else:
            _draw_yolo_annotations(ax, label_file, width, height, color)

    plt.tight_layout()
    plt.show()

# Display dataset samples with annotations.
# Samples are chosen with random.sample and no fixed seed, so each run shows
# a different set of images.
print("\n===== Brick Detection Dataset Samples =====")
visualize_sample_with_annotations(bricks_images_path, bricks_labels_path, "bricks")

print("\n===== Stud Detection Dataset Samples =====")
visualize_sample_with_annotations(studs_images_path, studs_labels_path, "studs")





### Create YOLO training structure:

In [None]:
# Build the YOLO training layout for each dataset under the repository root.
# NOTE(review): the exact folders created are defined in
# train.create_dataset_structure, which is not visible here.
create_dataset_structure("bricks", repo_clone_path)
create_dataset_structure("studs", repo_clone_path)



### Split into train, validation, and test sets:

In [None]:
# Split each dataset into train/validation/test sets. The returned paths are
# what the augmentation and training cells below operate on.
bricks_split_path = dataset_split("bricks", repo_clone_path)
studs_split_path = dataset_split("studs", repo_clone_path)



### Augment Training Data using Albumentations:

In [None]:
# Augment the split datasets in place (via train.augment_data).
# NOTE(review): the second argument (2) presumably sets how many augmented
# copies are generated per image — confirm against train.augment_data.
augment_data(bricks_split_path, 2)
augment_data(studs_split_path, 2)



## 3. Training the Brick Detection Model

Now let's train the YOLOv8 model for brick detection. We'll use the training pipeline from `train.py`.

In [None]:
# Brick-detection training run. One epoch with batch size 2 keeps the demo
# short; raise both for a real training run.
brick_training_params = dict(
    dataset_path=bricks_split_path,
    model_path=select_model("bricks", use_pretrained=False),
    device=device,
    epochs=1,
    batch_size=2,
    repo_root=repo_clone_path,
)

print("Starting Brick Detection Model Training...")
print(f"Parameters: {brick_training_params}")

# Training failures are reported instead of raised, so later cells still run.
try:
    brick_results_dir = train_model(**brick_training_params)
except Exception as err:
    print(f"Training error: {err}")
else:
    print(f"\nTraining completed. Results saved to: {brick_results_dir}")













































### Visualize Brick Detection Training Results

In [None]:
# Display training session output files:
#   pick the training-session *directory* under results/bricks with the newest
#   modification time (plain files in that folder are ignored).

brick_results_dir = Path(repo_clone_path) / "results" / "bricks"
last_bricks_results_dir = max(
    (p for p in brick_results_dir.glob("*") if p.is_dir()),
    key=lambda p: p.stat().st_mtime,
    default=None,
)
if last_bricks_results_dir is None:
    # Clearer than max()'s "empty sequence" error when training produced nothing.
    raise FileNotFoundError(f"No brick training sessions found in {brick_results_dir}")

display_last_training_session(last_bricks_results_dir)



## 4. Training the Stud Detection Model

Now we'll train the YOLOv8 model for detecting studs on LEGO bricks.

In [None]:
# Training configuration for the stud detection model
studs_training_params = {
    "dataset_path": studs_split_path,
    # The model mode must match the dataset: request the *studs* model.
    "model_path": select_model("studs", use_pretrained=False),
    "device": device,
    "epochs": 1,
    "batch_size": 2,
    "repo_root": repo_clone_path
}

print("Starting Studs Detection Model Training...")
print(f"Parameters: {studs_training_params}")

# Run the training pipeline (set to shorter epochs for demo purposes).
# Errors are printed rather than raised so the remaining cells can still run.
try:
    studs_results_dir = train_model(**studs_training_params)
    print(f"\nTraining completed. Results saved to: {studs_results_dir}")
except Exception as e:
    print(f"Training error: {e}")



### Visualize Stud Detection Training Results

In [None]:
# Display training session output files:
#   pick the training-session *directory* under results/studs with the newest
#   modification time (plain files in that folder are ignored).

studs_results_dir = Path(repo_clone_path) / "results" / "studs"
last_studs_results_dir = max(
    (p for p in studs_results_dir.glob("*") if p.is_dir()),
    key=lambda p: p.stat().st_mtime,
    default=None,
)
if last_studs_results_dir is None:
    # Clearer than max()'s "empty sequence" error when training produced nothing.
    raise FileNotFoundError(f"No stud training sessions found in {studs_results_dir}")

display_last_training_session(last_studs_results_dir)



## 5. Model Evaluation on Test Images

Let's test our trained models on some unseen images to see how they perform.

In [None]:
# Load the best checkpoint (weights/best.pt) of the most recent brick and stud
# training sessions. Any failure is printed rather than raised.
try:
    brick_model_path = Path(last_bricks_results_dir).joinpath("weights", "best.pt")
    stud_model_path = Path(last_studs_results_dir).joinpath("weights", "best.pt")

    brick_model = YOLO(str(brick_model_path))
    stud_model = YOLO(str(stud_model_path))
except Exception as err:
    print(f"Error loading models: {err}")
else:
    print("Models loaded successfully")



### Load Inference Modules:

In [None]:
# Core detection utilities
from utils.detection_utils import detect_bricks, detect_studs

# Visualization utilities
from utils.visualization_utils import draw_detection_visualization

# Configuration (needed for model loading)
from utils.config_utils import config

In [None]:
# Folders with the test images available inside the presentation folder
Bricks_Presentation_folder = Path(repo_clone_path).joinpath("presentation", "Test_images", "BricksPics")
Studs_Presentation_folder = Path(repo_clone_path).joinpath("presentation", "Test_images", "StudsPics")

### Evaluate Brick Detection Model:

In [None]:
# Randomly select up to 6 images from the Bricks folder
brick_images = sorted(Bricks_Presentation_folder.glob("*.jpg"))
selected_images = random.sample(brick_images, min(6, len(brick_images)))
if not selected_images:
    print(f"No .jpg images found in {Bricks_Presentation_folder}")

# 2x3 grid; every cell starts with its axis hidden so slots left unused
# (fewer than 6 images) do not show empty frames.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for ax in axes.flat:
    ax.axis('off')

# axes.flat walks the grid row by row, matching (idx // 3, idx % 3).
for ax, image_path in zip(axes.flat, selected_images):
    # Run brick detection
    # NOTE(review): detect_bricks resolves its own model (presumably via
    # utils config), not necessarily the brick_model loaded above — confirm.
    brick_results = detect_bricks(
        str(image_path),
        conf=0.25,
        save_annotated=False
    )

    # Load the image; the context manager closes the file once it is decoded.
    with PILImage.open(image_path) as pil_img:
        img = np.array(pil_img)

    # Draw detections and show them in this grid cell
    annotated_img = draw_detection_visualization(img, brick_results)
    ax.imshow(annotated_img)
    ax.set_title(f'Brick Detection: {image_path.name}')

plt.tight_layout()
plt.show()

### Evaluate Stud Detection Model:

In [None]:
# Randomly select up to 6 images from the Studs folder
stud_images = sorted(Studs_Presentation_folder.glob("*.jpg"))
selected_images = random.sample(stud_images, min(6, len(stud_images)))
if not selected_images:
    print(f"No .jpg images found in {Studs_Presentation_folder}")

# 2x3 grid; every cell starts with its axis hidden so slots left unused
# (fewer than 6 images) do not show empty frames.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for ax in axes.flat:
    ax.axis('off')

# axes.flat walks the grid row by row, matching (idx // 3, idx % 3).
for ax, image_path in zip(axes.flat, selected_images):
    # Run stud detection
    # NOTE(review): detect_studs resolves its own model (presumably via
    # utils config), not necessarily the stud_model loaded above — confirm.
    stud_results = detect_studs(
        str(image_path),
        conf=0.25,
        save_annotated=False
    )

    # Load the image; the context manager closes the file once it is decoded.
    with PILImage.open(image_path) as pil_img:
        img = np.array(pil_img)

    # Draw detections and show them in this grid cell
    annotated_img = draw_detection_visualization(img, stud_results)
    ax.imshow(annotated_img)
    ax.set_title(f'Stud Detection: {image_path.name}')

plt.tight_layout()
plt.show()