# 🧂 Salt Crystal Purity Classification - YOLOv8 Training
### With Data Augmentation for Better Results

This notebook trains a YOLOv8 object detection model to locate salt crystals in images and classify each one as **pure** or **impure**.

**What's New:** Added a data augmentation step that creates 3 augmented copies of every image. This grows the dataset to about 4x its original size and should help the model generalize.

## Requirements
- Dataset labeled in Label Studio (YOLO format), zipped
- Google Colab with GPU runtime

## Before Starting
1. Go to **Runtime > Change runtime type**
2. Select **T4 GPU** (or any available GPU)
3. Click **Save**

---
## Step 1: Check GPU & Install Dependencies

In [None]:
# Check GPU availability
!nvidia-smi

In [None]:
# Install Ultralytics (YOLOv8) and Albumentations (for augmentation)
!pip install ultralytics albumentations -q

# Verify installation
import ultralytics
ultralytics.checks()

import albumentations as A
print(f"\nAlbumentations version: {A.__version__}")

---
## Step 2: Mount Google Drive & Load Dataset

The notebook expects the zipped dataset in Google Drive at `MyDrive/salt-crystal/data.zip`. If yours is elsewhere, change `zip_path` in the cell below.
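Step 3 assumes that `images/`, `labels/`, and `classes.txt` sit at the top level of the archive. A zip that was re-created with an extra parent folder breaks those paths. Below is a minimal sketch of a pre-extraction check that uses only the standard `zipfile` module. `check_zip_layout` is a hypothetical helper and is not part of this notebook.

```python
import zipfile

# Entries that Step 3 expects at the top level of the extracted dataset
EXPECTED = {"images", "labels", "classes.txt"}

def check_zip_layout(zip_path):
    """Return the expected top-level entries that are missing from the archive."""
    with zipfile.ZipFile(zip_path, "r") as zf:
        # Zip member names always use "/", e.g. "images/a.jpg" -> top-level "images"
        top_level = {name.split("/", 1)[0] for name in zf.namelist() if name}
    return sorted(EXPECTED - top_level)

# Usage with this notebook's path:
# missing = check_zip_layout('/content/drive/MyDrive/salt-crystal/data.zip')
# print("Missing:", missing or "nothing")
```

If the result lists all three entries, the archive most likely has an extra parent folder. In that case, point `SOURCE_IMAGES` and `SOURCE_LABELS` at the nested folder after extraction.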

In [None]:
from google.colab import drive
import zipfile
import os

# Mount Google Drive
print("Mounting Google Drive...")
drive.mount('/content/drive')

# Path to your dataset in Google Drive
zip_path = '/content/drive/MyDrive/salt-crystal/data.zip'

# Stop early with a clear error if the file is missing
if not os.path.exists(zip_path):
    raise FileNotFoundError(f"❌ Dataset not found at {zip_path}. Please check the path and try again.")
print(f"\n✅ Dataset found: {zip_path}")

# Extract the dataset
print("\nExtracting dataset...")
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall('/content/dataset')

print("✅ Dataset extracted successfully!")
print("\nExtracted contents:")
!ls -la /content/dataset

In [None]:
# Explore the dataset structure
import os

def list_directory(path, indent=0):
    """List directory contents recursively (up to 3 levels deep, first 10 items per folder)"""
    if indent > 2:
        return
    try:
        items = sorted(os.listdir(path))
        for item in items[:10]:
            full_path = os.path.join(path, item)
            if os.path.isdir(full_path):
                print("  " * indent + f"📁 {item}/")
                list_directory(full_path, indent + 1)
            else:
                print("  " * indent + f"   📄 {item}")
        if len(items) > 10:
            print("  " * indent + f"   ... and {len(items) - 10} more items")
    except Exception as e:
        print(f"Error: {e}")

print("Dataset structure:")
print("=" * 50)
list_directory('/content/dataset')

---
## Step 3: Verify Dataset Paths

Label Studio exports data with `images/` and `labels/` folders. Let's verify the paths.
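YOLO matches each image to its label file by filename stem (for example `crystal_01.jpg` and `crystal_01.txt`). Ultralytics treats an image that has no label file as a background image with no objects. The hypothetical helper below lists mismatches in both directions, so you can separate intentional background images from export mistakes. It is a sketch and is not used by the rest of the notebook.

```python
import os

IMAGE_EXTS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')

def find_unpaired(images_dir, labels_dir):
    """Return (image stems without a label, label stems without an image)."""
    image_stems = {os.path.splitext(f)[0] for f in os.listdir(images_dir)
                   if f.lower().endswith(IMAGE_EXTS)}
    label_stems = {os.path.splitext(f)[0] for f in os.listdir(labels_dir)
                   if f.endswith('.txt')}
    return sorted(image_stems - label_stems), sorted(label_stems - image_stems)

# Usage:
# no_label, no_image = find_unpaired('/content/dataset/images', '/content/dataset/labels')
# print("Images without labels:", no_label)
# print("Labels without images:", no_image)
```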

In [None]:
# Label Studio YOLO export structure:
# - images/    (contains all images)
# - labels/    (contains YOLO format .txt files)
# - classes.txt (contains class names)

SOURCE_IMAGES = '/content/dataset/images'
SOURCE_LABELS = '/content/dataset/labels'

import os

# Same extensions as used in Steps 4 and 5
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')

if os.path.exists(SOURCE_IMAGES):
    num_images = len([f for f in os.listdir(SOURCE_IMAGES) if f.lower().endswith(image_extensions)])
    print(f"✅ Images folder found: {SOURCE_IMAGES}")
    print(f"   Contains {num_images} images")
else:
    print(f"❌ WARNING: Images folder NOT found at {SOURCE_IMAGES}")

if os.path.exists(SOURCE_LABELS):
    num_labels = len([f for f in os.listdir(SOURCE_LABELS) if f.endswith('.txt')])
    print(f"✅ Labels folder found: {SOURCE_LABELS}")
    print(f"   Contains {num_labels} label files")
else:
    print(f"❌ WARNING: Labels folder NOT found at {SOURCE_LABELS}")

# Check classes.txt
classes_file = '/content/dataset/classes.txt'
if os.path.exists(classes_file):
    with open(classes_file, 'r') as f:
        classes = [line.strip() for line in f.readlines() if line.strip()]
    print(f"\n✅ Classes found in classes.txt:")
    for i, cls in enumerate(classes):
        print(f"   {i}: {cls}")
else:
    print(f"\n❌ WARNING: classes.txt not found at {classes_file}")

---
## 🆕 Step 4: Data Augmentation

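This is the **NEW STEP** that enlarges your dataset!

**What it does:**
- Takes your original images (360 in this dataset)
- Creates 3 augmented versions of each image (`NUM_AUGMENTATIONS`)
- Automatically adjusts the bounding box coordinates to match each transformed image
- Results in ~1,440 total images (360 original + 1,080 augmented)

**Augmentations applied** (each one fires with its own probability, so every augmented copy gets a different random combination):
- Horizontal Flip
- Random Brightness & Contrast
- Slight Rotation (±10°)
- Gaussian Noise
- Motion or Gaussian Blur
- CLAHE (contrast enhancement)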
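Albumentations updates the boxes for you through `bbox_params`. To show what "adjusts bounding box coordinates" means, here is the arithmetic for the simplest case, a horizontal flip. In normalized YOLO coordinates, only the horizontal center changes. `hflip_yolo_bbox` is an illustrative helper and is not an Albumentations function.

```python
def hflip_yolo_bbox(bbox):
    """Mirror a normalized YOLO box (x_center, y_center, width, height) left-to-right."""
    x_center, y_center, width, height = bbox
    # Only the horizontal center moves; the box size and vertical position are unchanged
    return [1.0 - x_center, y_center, width, height]

print(hflip_yolo_bbox([0.25, 0.40, 0.10, 0.20]))  # [0.75, 0.4, 0.1, 0.2]
```

Rotation is harder. The rotated box is usually enlarged to enclose the rotated object, which is one reason the pipeline keeps only boxes that remain at least 30% visible (`min_visibility=0.3`).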

In [None]:
import albumentations as A
import cv2
import os
import glob
from tqdm import tqdm
import numpy as np

# Configuration
SOURCE_IMAGES = '/content/dataset/images'
SOURCE_LABELS = '/content/dataset/labels'
AUG_IMAGES = '/content/dataset/images_augmented'
AUG_LABELS = '/content/dataset/labels_augmented'
NUM_AUGMENTATIONS = 3  # Number of augmented versions per image

# Create output directories
os.makedirs(AUG_IMAGES, exist_ok=True)
os.makedirs(AUG_LABELS, exist_ok=True)

# Define augmentation pipeline
transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.Rotate(limit=10, p=0.4, border_mode=cv2.BORDER_CONSTANT),
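    # Noise std as a fraction of 255, roughly equivalent to variance 10 to 50.
    # This is the std_range API of newer Albumentations (2.x); older versions took var_limit=(10, 50)
    A.GaussNoise(std_range=(0.012, 0.028), p=0.3),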
    A.OneOf([
        A.MotionBlur(blur_limit=3, p=0.5),
        A.GaussianBlur(blur_limit=3, p=0.5),
    ], p=0.2),
    A.CLAHE(clip_limit=2.0, p=0.2),  # Improve contrast
], bbox_params=A.BboxParams(
    format='yolo',
    label_fields=['class_labels'],
    min_visibility=0.3  # Keep boxes that are at least 30% visible
))

print("🔄 Starting Data Augmentation...")
print(f"   Creating {NUM_AUGMENTATIONS} augmented versions per image")
print("=" * 50)

In [None]:
import shutil

# Get all image files
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
image_files = [f for f in os.listdir(SOURCE_IMAGES) if f.lower().endswith(image_extensions)]

print(f"Found {len(image_files)} original images")

# First, copy all original images and labels to the augmented folders unchanged.
# A byte-for-byte copy avoids the JPEG re-compression that cv2.imread/imwrite would cause.
print("\n📋 Copying original images...")
for img_file in tqdm(image_files, desc="Copying originals"):
    # Copy original image
    shutil.copy2(os.path.join(SOURCE_IMAGES, img_file), os.path.join(AUG_IMAGES, img_file))

    # Copy original label (images without a label file are kept as background images)
    name = os.path.splitext(img_file)[0]
    label_path = os.path.join(SOURCE_LABELS, f"{name}.txt")
    if os.path.exists(label_path):
        shutil.copy2(label_path, os.path.join(AUG_LABELS, f"{name}.txt"))

print(f"✅ Copied {len(image_files)} original images")

In [None]:
# Now create augmented versions
print("\n🎨 Creating augmented versions...")
augmented_count = 0
skipped_count = 0

for img_file in tqdm(image_files, desc="Augmenting"):
    img_path = os.path.join(SOURCE_IMAGES, img_file)
    name = os.path.splitext(img_file)[0]
    ext = os.path.splitext(img_file)[1]
    label_path = os.path.join(SOURCE_LABELS, f"{name}.txt")
    
    # Read image
    image = cv2.imread(img_path)
    if image is None:
        skipped_count += 1
        continue
    
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
    # Read labels (YOLO format: class x_center y_center width height)
    bboxes = []
    class_labels = []
    
    if os.path.exists(label_path):
        with open(label_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
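                if len(parts) >= 5:
                    # YOLO format: x_center, y_center, width, height (normalized)
                    xc, yc, w, h = (float(x) for x in parts[1:5])
                    # Clip the box's corners, not each value on its own, so the whole box
                    # lies inside the image (Albumentations may reject boxes that do not)
                    eps = 1e-6
                    x1, y1 = max(eps, xc - w / 2), max(eps, yc - h / 2)
                    x2, y2 = min(1 - eps, xc + w / 2), min(1 - eps, yc + h / 2)
                    if x2 <= x1 or y2 <= y1:
                        continue  # box collapsed to nothing after clipping; drop it
                    class_labels.append(parts[0])
                    bboxes.append([(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1])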
    
    # Create multiple augmented versions
    for aug_idx in range(NUM_AUGMENTATIONS):
        try:
            # Apply augmentation
            augmented = transform(
                image=image,
                bboxes=bboxes,
                class_labels=class_labels
            )
            
            aug_image = augmented['image']
            aug_bboxes = augmented['bboxes']
            aug_class_labels = augmented['class_labels']
            
            # Save augmented image
            aug_name = f"{name}_aug{aug_idx + 1}"
            aug_image_bgr = cv2.cvtColor(aug_image, cv2.COLOR_RGB2BGR)
            cv2.imwrite(os.path.join(AUG_IMAGES, f"{aug_name}{ext}"), aug_image_bgr)
            
            # Save augmented labels
            with open(os.path.join(AUG_LABELS, f"{aug_name}.txt"), 'w') as f:
                for cls, bbox in zip(aug_class_labels, aug_bboxes):
                    # Ensure values are within valid range
                    bbox = [max(0.001, min(0.999, x)) for x in bbox]
                    f.write(f"{cls} {' '.join(f'{x:.6f}' for x in bbox)}\n")
            
            augmented_count += 1
            
        except Exception as e:
            # Skip if augmentation fails (e.g., bbox goes out of bounds)
            skipped_count += 1
            continue

print(f"\n" + "=" * 50)
print(f"✅ AUGMENTATION COMPLETE!")
print(f"=" * 50)
print(f"   Original images:  {len(image_files)}")
print(f"   Augmented images: {augmented_count}")
print(f"   Total images:     {len(image_files) + augmented_count}")
if skipped_count > 0:
    print(f"   Skipped:          {skipped_count} (due to errors)")

In [None]:
# Visualize some augmented samples
import matplotlib.pyplot as plt
import random

# Get a random original image
sample_img = random.choice(image_files)
sample_name = os.path.splitext(sample_img)[0]
sample_ext = os.path.splitext(sample_img)[1]

# One panel for the original plus one per augmented copy
n_panels = 1 + NUM_AUGMENTATIONS
fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)
axes = axes[0]
for ax in axes:
    ax.axis('off')  # hide axes, including panels whose image is missing

# Original
orig_path = os.path.join(AUG_IMAGES, sample_img)
if os.path.exists(orig_path):
    img = cv2.imread(orig_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    axes[0].imshow(img)
    axes[0].set_title('Original', fontsize=12)

# Augmented versions
for i in range(NUM_AUGMENTATIONS):
    aug_path = os.path.join(AUG_IMAGES, f"{sample_name}_aug{i+1}{sample_ext}")
    if os.path.exists(aug_path):
        img = cv2.imread(aug_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        axes[i+1].imshow(img)
        axes[i+1].set_title(f'Augmented {i+1}', fontsize=12)

plt.suptitle(f'Sample: {sample_name}', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print(f"\n👆 Above shows one original image and its {NUM_AUGMENTATIONS} augmented versions")

In [None]:
# Update source paths to use augmented data for the rest of the pipeline
SOURCE_IMAGES = '/content/dataset/images_augmented'
SOURCE_LABELS = '/content/dataset/labels_augmented'

print("✅ Source paths updated to use augmented dataset")
print(f"   Images: {SOURCE_IMAGES}")
print(f"   Labels: {SOURCE_LABELS}")
print(f"\n   Total images available: {len(os.listdir(SOURCE_IMAGES))}")
print(f"   Total labels available: {len(os.listdir(SOURCE_LABELS))}")

---
## Step 5: Organize Dataset (Train/Validation Split)

Now we split the augmented dataset into training and validation sets.
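Every augmented file is named `<original>_aug<N>`, so you can check a split for leakage by mapping each filename back to its original. If the same original contributes files to both sets, near-duplicate images end up in training and in validation, and the validation metrics look better than they should. The sketch below uses hypothetical helper names. Run it on `train_files` and `valid_files` after the split cell.

```python
import os
import re

def source_stem(filename):
    """Map 'img01_aug2.jpg' (or 'img01.jpg') to the original image's stem 'img01'."""
    stem = os.path.splitext(filename)[0]
    return re.sub(r'_aug\d+$', '', stem)

def leaked_sources(train_files, valid_files):
    """Return original stems that have files in both the training and validation lists."""
    train_sources = {source_stem(f) for f in train_files}
    valid_sources = {source_stem(f) for f in valid_files}
    return sorted(train_sources & valid_sources)

# Usage after the split cell:
# print(leaked_sources(train_files, valid_files) or "No leakage")
```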

In [None]:
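import os
import re
import shutil
import random
from tqdm import tqdm

# Use augmented data paths
SOURCE_IMAGES = '/content/dataset/images_augmented'
SOURCE_LABELS = '/content/dataset/labels_augmented'

# (Re)create empty train/valid directories so re-running this cell never mixes old and new splits
for split in ('train', 'valid'):
    shutil.rmtree(f'/content/dataset/{split}', ignore_errors=True)
    os.makedirs(f'/content/dataset/{split}/images', exist_ok=True)
    os.makedirs(f'/content/dataset/{split}/labels', exist_ok=True)

# Get all image files from augmented folder (sorted, so the seeded shuffle is reproducible)
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
image_files = sorted(f for f in os.listdir(SOURCE_IMAGES) if f.lower().endswith(image_extensions))

print(f"Found {len(image_files)} total images (original + augmented)")

# Group files by original image: "img01.jpg", "img01_aug1.jpg", ... all share the key "img01".
# Splitting by group keeps an image and its augmented copies in the same set; splitting
# individual files would put near-duplicates of validation images into training.
def source_stem(filename):
    return re.sub(r'_aug\d+$', '', os.path.splitext(filename)[0])

sources = sorted({source_stem(f) for f in image_files})

# Shuffle the originals (not individual files) for a random split
random.seed(42)  # For reproducibility
random.shuffle(sources)

# Split 90% train, 10% validation by original image
split_idx = int(len(sources) * 0.9)
train_sources = set(sources[:split_idx])
train_files = [f for f in image_files if source_stem(f) in train_sources]
valid_files = [f for f in image_files if source_stem(f) not in train_sources]

print(f"\n📊 Dataset Split ({len(sources)} original images):")
print(f"   Training set:   {len(train_files)} images (~90%)")
print(f"   Validation set: {len(valid_files)} images (~10%)")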

In [None]:
# Copy files to train folder
print("\n📋 Copying training files...")
for img in tqdm(train_files, desc="Training set"):
    # Copy image
    shutil.copy(os.path.join(SOURCE_IMAGES, img), '/content/dataset/train/images/')
    # Copy corresponding label
    label = os.path.splitext(img)[0] + '.txt'
    label_path = os.path.join(SOURCE_LABELS, label)
    if os.path.exists(label_path):
        shutil.copy(label_path, '/content/dataset/train/labels/')

# Copy files to valid folder
print("\n📋 Copying validation files...")
for img in tqdm(valid_files, desc="Validation set"):
    # Copy image
    shutil.copy(os.path.join(SOURCE_IMAGES, img), '/content/dataset/valid/images/')
    # Copy corresponding label
    label = os.path.splitext(img)[0] + '.txt'
    label_path = os.path.join(SOURCE_LABELS, label)
    if os.path.exists(label_path):
        shutil.copy(label_path, '/content/dataset/valid/labels/')

print("\n" + "=" * 50)
print("✅ Dataset organization complete!")
print("=" * 50)
print(f"   Train images: {len(os.listdir('/content/dataset/train/images'))}")
print(f"   Train labels: {len(os.listdir('/content/dataset/train/labels'))}")
print(f"   Valid images: {len(os.listdir('/content/dataset/valid/images'))}")
print(f"   Valid labels: {len(os.listdir('/content/dataset/valid/labels'))}")

---
## Step 6: Create Dataset Configuration (YAML)

In [None]:
# Read class names from Label Studio's classes.txt
classes_file = '/content/dataset/classes.txt'

with open(classes_file, 'r') as f:
    classes = [line.strip() for line in f.readlines() if line.strip()]

print(f"Found {len(classes)} classes: {classes}")

# Build YAML configuration dynamically
yaml_lines = [
    "path: /content/dataset",
    "train: train/images",
    "val: valid/images",
    "",
    "names:"
]

for i, cls in enumerate(classes):
    yaml_lines.append(f"  {i}: {cls}")

yaml_content = "\n".join(yaml_lines)

# Write YAML file
with open('/content/dataset/salt_crystal.yaml', 'w') as f:
    f.write(yaml_content)

print("\n✅ Dataset configuration file created!")
print("=" * 50)
print(yaml_content)

In [None]:
# Verify class labels in your dataset
import os

label_dir = '/content/dataset/train/labels'
label_files = os.listdir(label_dir)[:3]

print("Sample label files content:")
print("(Format: class_id x_center y_center width height)")
print("=" * 50)

for lf in label_files:
    print(f"\n{lf}:")
    with open(os.path.join(label_dir, lf), 'r') as f:
        content = f.read().strip()
        print(content if content else "  (empty file)")

---
## Step 7: Train YOLOv8 Model

### Model Options:
| Model | Size | Speed | Accuracy |
|-------|------|-------|----------|
| yolov8n.pt | Nano | Fastest | Good |
| yolov8s.pt | Small | Fast | Better |
| yolov8m.pt | Medium | Moderate | High |
| yolov8l.pt | Large | Slower | Higher |
| yolov8x.pt | Extra-large | Slowest | Highest |

In [None]:
from ultralytics import YOLO

# Load a pretrained YOLOv8 model
# Using yolov8s.pt (Small) for better accuracy with augmented data
model = YOLO('yolov8s.pt')

# Train the model
results = model.train(
    data='/content/dataset/salt_crystal.yaml',
    epochs=100,           # Number of training epochs
    imgsz=640,            # Image size
    batch=16,             # Batch size (reduce to 8 if memory error)
    patience=20,          # Early stopping patience
    save=True,            # Save checkpoints
    project='/content/runs',
    name='salt_crystal_model',
    exist_ok=True,        # Overwrite if exists
    pretrained=True,      # Use pretrained weights
    optimizer='auto',     # Automatic optimizer selection
    verbose=True,         # Print training progress
    seed=42,              # For reproducibility
    
    # Built-in augmentation settings (Ultralytics applies these automatically during
    # training; they complement our offline augmentation)
    hsv_h=0.015,          # HSV-Hue augmentation
    hsv_s=0.7,            # HSV-Saturation augmentation
    hsv_v=0.4,            # HSV-Value augmentation
    degrees=0.0,          # Rotation (we already did this offline)
    translate=0.1,        # Translation
    scale=0.5,            # Scale
    fliplr=0.5,           # Horizontal flip
    mosaic=1.0,           # Mosaic augmentation
)

---
## Step 8: Evaluate Model Performance
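The metrics below are built on Intersection over Union (IoU). mAP50 counts a predicted box as correct when it has the right class and an IoU of at least 0.5 with a ground-truth box. mAP50-95 averages mAP over IoU thresholds from 0.50 to 0.95 in steps of 0.05. Below is a plain-Python sketch of IoU for two normalized YOLO boxes. It is for illustration only; Ultralytics uses its own vectorized implementation.

```python
def yolo_to_corners(bbox):
    """Convert normalized (x_center, y_center, w, h) to corners (x1, y1, x2, y2)."""
    xc, yc, w, h = bbox
    return xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2

def iou(box_a, box_b):
    """Intersection over Union of two YOLO-format boxes (0.0 when they do not overlap)."""
    ax1, ay1, ax2, ay2 = yolo_to_corners(box_a)
    bx1, by1, bx2, by2 = yolo_to_corners(box_b)
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    return inter / union if union > 0 else 0.0

# Same-size boxes shifted by half a width: overlap is 1/3 of the union,
# so this prediction would NOT count as correct at the 0.5 threshold
print(round(iou([0.5, 0.5, 0.2, 0.2], [0.6, 0.5, 0.2, 0.2]), 3))  # 0.333
```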

In [None]:
from ultralytics import YOLO

# Load the best trained model
model = YOLO('/content/runs/salt_crystal_model/weights/best.pt')

# Validate on validation set
metrics = model.val()

# Print metrics
print("\n" + "=" * 50)
print("📊 MODEL PERFORMANCE METRICS")
print("=" * 50)
print(f"mAP50:      {metrics.box.map50:.4f}  (Mean Average Precision @ IoU 50%)")
print(f"mAP50-95:   {metrics.box.map:.4f}  (Mean AP across IoU thresholds)")
print(f"Precision:  {metrics.box.mp:.4f}  (How many detections are correct)")
print(f"Recall:     {metrics.box.mr:.4f}  (How many objects were detected)")
print("=" * 50)

if metrics.box.map50 > 0.8:
    print("\n🎉 Model performance is EXCELLENT!")
elif metrics.box.map50 > 0.7:
    print("\n✅ Model performance is GOOD!")
elif metrics.box.map50 > 0.5:
    print("\n⚠️ Model performance is ACCEPTABLE. Consider more training data for improvement.")
else:
    print("\n❌ Model performance needs improvement. Try:\n- More labeled images\n- Larger model (yolov8m.pt)\n- More epochs")

In [None]:
# View training results plots
from IPython.display import Image, display
import os

results_dir = '/content/runs/salt_crystal_model'

# Display confusion matrix
if os.path.exists(f'{results_dir}/confusion_matrix.png'):
    print("📊 Confusion Matrix:")
    display(Image(filename=f'{results_dir}/confusion_matrix.png', width=600))

# Display training results
if os.path.exists(f'{results_dir}/results.png'):
    print("\n📈 Training Results:")
    display(Image(filename=f'{results_dir}/results.png', width=800))

---
## Step 9: Test Predictions on Sample Images

In [None]:
from ultralytics import YOLO
from IPython.display import Image, display
import glob

# Load trained model
model = YOLO('/content/runs/salt_crystal_model/weights/best.pt')

# Run inference on validation images
results = model.predict(
    source='/content/dataset/valid/images',
    save=True,
    conf=0.5,  # Confidence threshold
    project='/content/runs',
    name='predictions',
    exist_ok=True
)

print("‚úÖ Predictions complete!")

In [None]:
# Display prediction results
from IPython.display import Image, display
import glob
import os

# Get prediction images
pred_dir = '/content/runs/predictions'
result_images = glob.glob(f'{pred_dir}/*.jpg') + glob.glob(f'{pred_dir}/*.png')

print(f"Showing {min(6, len(result_images))} prediction results:\n")

for img_path in result_images[:6]:
    print(f"🖼️ Image: {os.path.basename(img_path)}")
    display(Image(filename=img_path, width=500))
    print("-" * 50)

---
## Step 10: Download Trained Model

In [None]:
from google.colab import files

# Download the best model weights
print("📥 Downloading best.pt (your trained model)...")
files.download('/content/runs/salt_crystal_model/weights/best.pt')

In [None]:
# Optional: Download last checkpoint as backup
from google.colab import files

print("📥 Downloading last.pt (backup checkpoint)...")
files.download('/content/runs/salt_crystal_model/weights/last.pt')

---
## 📋 Summary: What Changed with Augmentation

| Metric | Before (No Augmentation) | After (With Augmentation) |
|--------|--------------------------|---------------------------|
| Original Images | 360 | 360 |
| Total Training Images | ~324 | ~1,296 |
| Data Variety | Low | High |
| Overfitting Risk | High | Lower |
| Expected Accuracy | Moderate | Higher |
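The image counts in the table follow from simple arithmetic. Each original yields `1 + NUM_AUGMENTATIONS` images, and about 90% of the images go to training. The small sketch below reproduces the table. `expected_counts` is a hypothetical helper, and you can use it to try other settings, such as `NUM_AUGMENTATIONS = 5` from the troubleshooting tips.

```python
def expected_counts(n_original, n_augmentations, train_fraction=0.9):
    """Return (total images, approximate training images) after offline augmentation."""
    total = n_original * (1 + n_augmentations)  # each original plus its augmented copies
    return total, int(total * train_fraction)

print(expected_counts(360, 0))  # (360, 324)   no augmentation
print(expected_counts(360, 3))  # (1440, 1296) NUM_AUGMENTATIONS = 3
print(expected_counts(360, 5))  # (2160, 1944) NUM_AUGMENTATIONS = 5
```

Keep in mind that augmented copies are variations of the same photos, so they add less information than the same number of newly labeled images.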

### Augmentations Applied:
- ✅ Horizontal Flip (50% chance)
- ✅ Brightness & Contrast adjustment
- ✅ Rotation (±10°)
- ✅ Gaussian Noise
- ✅ Motion/Gaussian Blur
- ✅ CLAHE (Contrast enhancement)

---
## 🖥️ Local Deployment Guide

After downloading `best.pt`, use these code snippets on your local machine:

### Install Requirements
```bash
pip install ultralytics opencv-python
```

### Run Inference
```python
from ultralytics import YOLO

# Load model
model = YOLO('best.pt')

# Predict on image
results = model.predict('salt_sample.jpg', conf=0.5)
results[0].show()

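# Predict on webcam (stream=True yields one result per frame instead of
# accumulating every frame's result in memory)
for result in model.predict(source=0, show=True, stream=True):
    pass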
```
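To turn a prediction into per-class counts (for example, how many crystals in a photo were detected as `impure`), read the class ids and confidences from `results[0].boxes.cls` and `results[0].boxes.conf`, then map the ids through `model.names`. The counting helper below is a hypothetical sketch written in plain Python, so it can be tested without a model. Its `conf_threshold` is an extra filter on top of the `conf=0.5` already passed to `predict`.

```python
from collections import Counter

def count_detections(class_ids, confidences, names, conf_threshold=0.5):
    """Count detections per class name, keeping only those at or above conf_threshold."""
    counts = Counter()
    for cls_id, conf in zip(class_ids, confidences):
        if conf >= conf_threshold:
            counts[names[int(cls_id)]] += 1  # class ids arrive as floats from .tolist()
    return dict(counts)

# Usage with an Ultralytics result:
# r = model.predict('salt_sample.jpg', conf=0.5)[0]
# print(count_detections(r.boxes.cls.tolist(), r.boxes.conf.tolist(), model.names))
```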

---
## 🔧 Troubleshooting

### GPU Memory Error
Reduce batch size in training cell: `batch=8` or `batch=4`

### Low Accuracy
- Add more labeled images (500+)
- Use larger model: `yolov8m.pt`
- Increase epochs: `epochs=150`
- Increase `NUM_AUGMENTATIONS` to 5

### Runtime Disconnects
- Enable background execution
- Use Colab Pro for longer sessions

### Augmentation Errors
- Some images may be skipped if bounding boxes go out of bounds
- This is normal and handled automatically
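A label box can sit partly outside the image even when each of its four numbers is between 0 and 1. For example, `x_center = 0.95` with `width = 0.2` puts the right edge at 1.05. Albumentations validates boxes and typically raises an error for boxes like this, which is one way an image ends up skipped. The hypothetical check below can be run over your label files to find such boxes.

```python
def yolo_box_in_bounds(bbox, eps=1e-9):
    """True if a normalized YOLO box (x_center, y_center, w, h) lies fully inside the image."""
    xc, yc, w, h = bbox
    if w <= 0 or h <= 0:
        return False
    x_min, y_min = xc - w / 2, yc - h / 2
    x_max, y_max = xc + w / 2, yc + h / 2
    # A small tolerance absorbs floating-point rounding at the image edges
    return x_min >= -eps and y_min >= -eps and x_max <= 1 + eps and y_max <= 1 + eps

print(yolo_box_in_bounds([0.5, 0.5, 0.2, 0.2]))   # True
print(yolo_box_in_bounds([0.95, 0.5, 0.2, 0.2]))  # False: right edge at 1.05
```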