# Salt Crystal Purity Classification - YOLOv8 Training

This notebook trains a YOLOv8 object-detection model that locates salt crystals in an image and labels each one as **pure** or **impure**.

## Requirements
- Dataset labeled in Label Studio (YOLO format), zipped
- Google Colab with GPU runtime

## Before Starting
1. Go to **Runtime > Change runtime type**
2. Select **T4 GPU** (or any available GPU)
3. Click **Save**

---
## Step 1: Check GPU & Install Dependencies

In [None]:
# Check GPU availability by running `nvidia-smi` in the runtime's shell.
# NOTE(review): if this command errors, the runtime most likely has no GPU
# attached - revisit the "Before Starting" steps above.
!nvidia-smi

In [None]:
# Install Ultralytics (YOLOv8)
!pip install ultralytics -q

# Verify installation
import ultralytics
ultralytics.checks()

---
## Step 2: Mount Google Drive & Load Dataset

This notebook expects your zipped dataset in Google Drive at `MyDrive/salt-crystal/data.zip`. If yours is stored elsewhere, update `zip_path` in the cell below.

In [None]:
from google.colab import drive
import zipfile
import os

# Mount Google Drive
print("Mounting Google Drive...")
drive.mount('/content/drive')

# Path to your dataset in Google Drive
zip_path = '/content/drive/MyDrive/salt-crystal/data.zip'

# Verify the file exists
if os.path.exists(zip_path):
    print(f"\nDataset found: {zip_path}")
else:
    print(f"\nERROR: Dataset not found at {zip_path}")
    print("Please check the path and try again.")

# Extract the dataset
print("\nExtracting dataset...")
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall('/content/dataset')

print("Dataset extracted successfully!")
print("\nExtracted contents:")
!ls -la /content/dataset

In [None]:
# Explore the dataset structure to find images and labels folders
import os

def list_directory(path, indent=0, max_depth=2, max_items=10):
    """Print a truncated, indented tree of the directory at ``path``.

    Entries are printed in sorted order so the output is the same on every
    run. Directories are prefixed with ``[DIR]`` and listed recursively.

    Args:
        path: Directory to list.
        indent: Nesting level of ``path`` (0 for the root). It controls both
            the printed indentation and the recursion cut-off.
        max_depth: Deepest nesting level that is still listed. The default of
            2 shows the root plus two levels of subdirectories.
        max_items: Maximum number of entries printed per directory; the
            remainder is summarised in a single "... and N more items" line.

    Returns:
        None. Everything is written to stdout. A directory that cannot be
        read is reported as ``Error: ...`` instead of raising.
    """
    if indent > max_depth:
        return

    try:
        items = sorted(os.listdir(path))
    except OSError as e:
        # Best effort: a missing or unreadable directory must not abort the listing
        print(f"Error: {e}")
        return

    prefix = "  " * indent
    for item in items[:max_items]:
        full_path = os.path.join(path, item)
        if os.path.isdir(full_path):
            print(prefix + f"[DIR] {item}/")
            list_directory(full_path, indent + 1, max_depth, max_items)
        else:
            print(prefix + f"      {item}")

    # The remainder may contain directories as well, hence "items" not "files"
    if len(items) > max_items:
        print(prefix + f"      ... and {len(items) - max_items} more items")

# Header line and divider, followed by the tree of the extracted dataset
print("Dataset structure:", "=" * 50, sep="\n")
list_directory('/content/dataset')

---
## Step 3: Verify Dataset Paths

A Label Studio YOLO export contains an `images/` folder, a `labels/` folder, and a `classes.txt` file. Let's verify that all three are where the notebook expects them.

In [None]:
# Label Studio YOLO export structure:
# - images/    (contains all images)
# - labels/    (contains YOLO format .txt files)
# - classes.txt (contains class names)

SOURCE_IMAGES = '/content/dataset/images'
SOURCE_LABELS = '/content/dataset/labels'

# Verify paths exist
import os

# Same extensions the train/validation split cell (Step 4) accepts, so the
# count reported here matches the number of images that will actually be used.
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')

# isdir (not exists): a stray file with this name should produce the warning
# below rather than a crash inside os.listdir.
if os.path.isdir(SOURCE_IMAGES):
    num_images = sum(
        1 for f in os.listdir(SOURCE_IMAGES) if f.lower().endswith(image_extensions)
    )
    print(f"Images folder found: {SOURCE_IMAGES}")
    print(f"  Contains {num_images} images")
else:
    print(f"WARNING: Images folder NOT found at {SOURCE_IMAGES}")

if os.path.isdir(SOURCE_LABELS):
    # Case-sensitive on purpose: Step 4 looks labels up as "<image stem>.txt"
    num_labels = sum(1 for f in os.listdir(SOURCE_LABELS) if f.endswith('.txt'))
    print(f"Labels folder found: {SOURCE_LABELS}")
    print(f"  Contains {num_labels} label files")
else:
    print(f"WARNING: Labels folder NOT found at {SOURCE_LABELS}")

# Check classes.txt (one class name per line; blank lines are ignored)
classes_file = '/content/dataset/classes.txt'
if os.path.isfile(classes_file):
    # Explicit encoding so non-ASCII class names do not depend on the locale
    with open(classes_file, 'r', encoding='utf-8') as f:
        classes = [line.strip() for line in f if line.strip()]
    print("\nClasses found in classes.txt:")
    for i, cls in enumerate(classes):
        print(f"  {i}: {cls}")
else:
    print(f"\nWARNING: classes.txt not found at {classes_file}")

---
## Step 4: Organize Dataset (Train/Validation Split)

In [None]:
import os
import shutil
import random

# Root folder that receives the train/ and valid/ splits
DATASET_ROOT = '/content/dataset'


def _copy_split(files, split_name):
    """Copy ``files`` and their label files into the ``split_name`` split.

    Args:
        files: Image file names (relative to SOURCE_IMAGES) to copy.
        split_name: Name of the split folder under DATASET_ROOT ('train' or 'valid').

    Returns:
        Number of images that had no matching ``<stem>.txt`` in SOURCE_LABELS.
        Those images are still copied, just without a label file.
    """
    images_dst = os.path.join(DATASET_ROOT, split_name, 'images')
    labels_dst = os.path.join(DATASET_ROOT, split_name, 'labels')
    missing_labels = 0
    for img in files:
        # Copy image
        shutil.copy(os.path.join(SOURCE_IMAGES, img), images_dst)
        # Copy corresponding label
        label_path = os.path.join(SOURCE_LABELS, os.path.splitext(img)[0] + '.txt')
        if os.path.exists(label_path):
            shutil.copy(label_path, labels_dst)
        else:
            missing_labels += 1
    return missing_labels


# (Re)create train/valid directories from scratch. Without the cleanup, files
# from an earlier run would survive a re-run and could end up in both splits.
for split_name in ('train', 'valid'):
    split_root = os.path.join(DATASET_ROOT, split_name)
    shutil.rmtree(split_root, ignore_errors=True)
    os.makedirs(os.path.join(split_root, 'images'), exist_ok=True)
    os.makedirs(os.path.join(split_root, 'labels'), exist_ok=True)

# Get all image files. os.listdir returns names in arbitrary order, so sort
# first - otherwise the fixed seed below would not give a reproducible split.
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
image_files = sorted(
    f for f in os.listdir(SOURCE_IMAGES) if f.lower().endswith(image_extensions)
)

print(f"Found {len(image_files)} images")

# Shuffle for random split
random.seed(42)  # For reproducibility
random.shuffle(image_files)

# Split 90% train, 10% validation
split_idx = int(len(image_files) * 0.9)
train_files = image_files[:split_idx]
valid_files = image_files[split_idx:]

print(f"Training set: {len(train_files)} images")
print(f"Validation set: {len(valid_files)} images")

# Copy files to train folder
print("\nCopying training files...")
missing_train = _copy_split(train_files, 'train')

# Copy files to valid folder
print("Copying validation files...")
missing_valid = _copy_split(valid_files, 'valid')

# Surface unlabeled images instead of skipping them silently
if missing_train or missing_valid:
    print(
        f"\nNOTE: {missing_train} training and {missing_valid} validation "
        "images have no label file."
    )

print("\nDataset organization complete!")
print(f"Train images: {len(os.listdir('/content/dataset/train/images'))}")
print(f"Train labels: {len(os.listdir('/content/dataset/train/labels'))}")
print(f"Valid images: {len(os.listdir('/content/dataset/valid/images'))}")
print(f"Valid labels: {len(os.listdir('/content/dataset/valid/labels'))}")

---
## Step 5: Create Dataset Configuration (YAML)

In [None]:
# Read class names from Label Studio's classes.txt
import json

classes_file = '/content/dataset/classes.txt'

# Explicit encoding so non-ASCII class names do not depend on the locale
with open(classes_file, 'r', encoding='utf-8') as f:
    classes = [line.strip() for line in f if line.strip()]

# An empty "names:" mapping would only fail later, inside training
if not classes:
    raise ValueError(f"No class names found in {classes_file}")

print(f"Found {len(classes)} classes: {classes}")

# Build YAML configuration dynamically
yaml_lines = [
    "path: /content/dataset",
    "train: train/images",
    "val: valid/images",
    "",
    "names:"
]

for i, cls in enumerate(classes):
    # Quote every name: a JSON string is also a valid YAML double-quoted
    # scalar, so names such as "no", "true", "1" or ones containing ':' / '#'
    # are read back as the exact string instead of a bool, number or comment.
    yaml_lines.append(f"  {i}: {json.dumps(cls, ensure_ascii=False)}")

yaml_content = "\n".join(yaml_lines)

# Write YAML file
with open('/content/dataset/salt_crystal.yaml', 'w', encoding='utf-8') as f:
    f.write(yaml_content)

print("\nDataset configuration file created!")
print("="*50)
print(yaml_content)

In [None]:
# Verify class labels in your dataset
# Check a few label files to confirm class IDs match
import os

label_dir = '/content/dataset/train/labels'
label_files = os.listdir(label_dir)[:3]  # only the first three entries are shown

# Header describing the column layout of a YOLO label line
header_lines = [
    "Sample label files content:",
    "(Format: class_id x_center y_center width height)",
    "=" * 50,
]
print("\n".join(header_lines))

# Dump each sampled label file, flagging files that contain no annotations
for label_name in label_files:
    print(f"\n{label_name}:")
    with open(os.path.join(label_dir, label_name), 'r') as handle:
        label_text = handle.read().strip()
    print(label_text or "  (empty file)")

---
## Step 6: Train YOLOv8 Model

### Model Options:
| Model | Size | Speed | Accuracy |
|-------|------|-------|----------|
| yolov8n.pt | Nano | Fastest | Good |
| yolov8s.pt | Small | Fast | Better |
| yolov8m.pt | Medium | Moderate | High |
| yolov8l.pt | Large | Slower | Highest |

In [None]:
from ultralytics import YOLO

# Start from pretrained YOLOv8 weights.
# Swap in 'yolov8s.pt' or 'yolov8m.pt' for better accuracy
model = YOLO('yolov8n.pt')

# All training options in one place so they are easy to review and tweak
train_settings = {
    'data': '/content/dataset/salt_crystal.yaml',
    'epochs': 100,              # Number of training epochs
    'imgsz': 640,               # Image size
    'batch': 16,                # Batch size (reduce to 8 if memory error)
    'patience': 20,             # Early stopping patience
    'save': True,               # Save checkpoints
    'project': '/content/runs',
    'name': 'salt_crystal_model',
    'exist_ok': True,           # Overwrite if exists
    'pretrained': True,         # Use pretrained weights
    'optimizer': 'auto',        # Automatic optimizer selection
    'verbose': True,            # Print training progress
    'seed': 42,                 # For reproducibility
}

# Train the model
results = model.train(**train_settings)

---
## Step 7: Evaluate Model Performance

In [None]:
from ultralytics import YOLO

# Reload the best checkpoint written by the training run
model = YOLO('/content/runs/salt_crystal_model/weights/best.pt')

# Run validation and keep the resulting metrics object
metrics = model.val()

# Assemble the whole report first, then emit it with a single print
divider = "=" * 50
report_lines = [
    "\n" + divider,
    "MODEL PERFORMANCE METRICS",
    divider,
    f"mAP50:      {metrics.box.map50:.4f}  (Mean Average Precision @ IoU 50%)",
    f"mAP50-95:   {metrics.box.map:.4f}  (Mean AP across IoU thresholds)",
    f"Precision:  {metrics.box.mp:.4f}  (How many detections are correct)",
    f"Recall:     {metrics.box.mr:.4f}  (How many objects were detected)",
    divider,
]
print("\n".join(report_lines))

# Turn mAP50 into a coarse verdict: above 0.7 good, above 0.5 acceptable
map50 = metrics.box.map50
if map50 > 0.7:
    verdict = "\nModel performance is GOOD!"
elif map50 > 0.5:
    verdict = "\nModel performance is ACCEPTABLE. Consider more training data for improvement."
else:
    verdict = "\nModel performance needs improvement. Try:\n- More labeled images\n- Larger model (yolov8s.pt)\n- More epochs"
print(verdict)

In [None]:
# View training results plots
from IPython.display import Image, display
import os

results_dir = '/content/runs/salt_crystal_model'

# (heading, file name, display width) for every plot shown, in display order
plots_to_show = [
    ("Confusion Matrix:", 'confusion_matrix.png', 600),
    ("\nTraining Results:", 'results.png', 800),
]

# Show each plot that exists; missing plots are skipped without a message
for heading, plot_name, plot_width in plots_to_show:
    plot_path = f'{results_dir}/{plot_name}'
    if not os.path.exists(plot_path):
        continue
    print(heading)
    display(Image(filename=plot_path, width=plot_width))

---
## Step 8: Test Predictions on Sample Images

In [None]:
from ultralytics import YOLO
from IPython.display import Image, display
import glob

# Load trained model from the best checkpoint of the training run
best_weights = '/content/runs/salt_crystal_model/weights/best.pt'
model = YOLO(best_weights)

# Inference settings: annotated images are saved under /content/runs/predictions
predict_settings = {
    'source': '/content/dataset/valid/images',  # Run inference on validation images
    'save': True,
    'conf': 0.5,                                # Confidence threshold
    'project': '/content/runs',
    'name': 'predictions',
    'exist_ok': True,
}
results = model.predict(**predict_settings)

print("Predictions complete!")

In [None]:
# Display prediction results
import glob
import os  # used for os.path.basename below; imported here so the cell is self-contained

from IPython.display import Image, display

# Get prediction images. glob returns matches in arbitrary order, so sort them
# to show the same images on every run.
pred_dir = '/content/runs/predictions'
result_images = sorted(glob.glob(f'{pred_dir}/*.jpg') + glob.glob(f'{pred_dir}/*.png'))

print(f"Showing {min(6, len(result_images))} prediction results:\n")

# Show at most six annotated images, each followed by a divider line
for img_path in result_images[:6]:
    print(f"Image: {os.path.basename(img_path)}")
    display(Image(filename=img_path, width=500))
    print("-" * 50)

---
## Step 9: Download Trained Model

In [None]:
from google.colab import files

# Send the best checkpoint of the training run to the browser as a download
best_weights_path = '/content/runs/salt_crystal_model/weights/best.pt'
print("Downloading best.pt (your trained model)...")
files.download(best_weights_path)

In [None]:
# Optional: Download last checkpoint as backup
from google.colab import files

# Send the last checkpoint of the training run to the browser as a download
last_weights_path = '/content/runs/salt_crystal_model/weights/last.pt'
print("Downloading last.pt (backup checkpoint)...")
files.download(last_weights_path)

---
## Local Deployment Guide

After downloading `best.pt`, use these code snippets on your local machine:

### Install Requirements
```bash
pip install ultralytics opencv-python
```

### Run Inference
```python
from ultralytics import YOLO

# Load model
model = YOLO('best.pt')

# Predict on image
results = model.predict('salt_sample.jpg', conf=0.5)
results[0].show()

# Predict on webcam
results = model.predict(source=0, show=True)
```

---
## Troubleshooting

### GPU Memory Error
Reduce batch size in training cell: `batch=8` or `batch=4`

### Low Accuracy
- Add more labeled images (500+)
- Use larger model: `yolov8s.pt`
- Increase epochs: `epochs=200`

### Runtime Disconnects
- Enable background execution
- Use Colab Pro for longer sessions