# Flower Classification - Training and Prediction

This notebook demonstrates how to train a deep learning model and make predictions by calling the project's command-line scripts (`train.py` and `predict.py`) through `subprocess`.

## Project Structure
- `train.py` - Training script
- `predict.py` - Prediction script
- `workspace_utils.py` - Helper functions
- `flowers/` - Dataset directory
- `cat_to_name.json` - Category names mapping

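The category mapping is passed to `predict.py` through `--category_names`. As a rough sketch of how such a mapping can be used, the snippet below assumes that `cat_to_name.json` is a flat JSON object mapping class ids (strings such as `"1"`) to flower names. The helper names `load_category_names` and `names_for`, and the sample entries, are illustrative and not taken from the project.

```python
import json
import os
import tempfile

def load_category_names(path):
    """Load a JSON object that maps class ids (strings) to readable names."""
    with open(path, encoding='utf-8') as f:
        mapping = json.load(f)
    if not isinstance(mapping, dict):
        raise ValueError(f"{path} should contain a JSON object")
    return mapping

def names_for(class_ids, mapping):
    """Translate class ids to names, falling back to the raw id when unknown."""
    return [mapping.get(str(cid), str(cid)) for cid in class_ids]

# Demo with a made-up mapping written to a temporary file (hypothetical data)
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, 'demo_cat_to_name.json')
    with open(demo_path, 'w', encoding='utf-8') as f:
        json.dump({'1': 'example flower a', '2': 'example flower b'}, f)
    demo_mapping = load_category_names(demo_path)
    print(names_for(['2', '1', '99'], demo_mapping))
```

Falling back to the raw id keeps the output readable even if the mapping is missing an entry.
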
## 1. Setup and Imports

First, let's verify that all necessary files are in place.

In [None]:
import os
import subprocess

# Check if required files exist
required_files = ['train.py', 'predict.py', 'workspace_utils.py', 'cat_to_name.json']
for file in required_files:
    if os.path.exists(file):
        print(f"✅ {file} found")
    else:
        print(f"❌ {file} not found")

# Check if flowers directory exists
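if os.path.isdir('flowers'):
    print("✅ flowers/ directory found")
    # Check the expected split subdirectories
    for subdir in ['train', 'valid', 'test']:
        path = os.path.join('flowers', subdir)
        if os.path.isdir(path):
            # Count only class folders, ignoring stray files such as .DS_Store
            num_classes = sum(os.path.isdir(os.path.join(path, d)) for d in os.listdir(path))
            print(f"   - {subdir}: {num_classes} classes")
        else:
            print(f"   - {subdir}: ❌ not found")
else:
    print("❌ flowers/ directory not found")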

## 2. Train the Model

Train a VGG19 model on the flower dataset. Expect the run to take at least several minutes; the exact time depends on your hardware, and CPU-only runs are much slower.

### Training Parameters:
- Architecture: VGG19
- Learning Rate: 0.001
- Hidden Units: 512
- Dropout: 0.3
- Epochs: 5
- GPU: Enabled (if available)

In [None]:
# Training command
train_command = [
    'python3', 'train.py',
    'flowers',                    # data directory
    '--save_dir', 'checkpoint.pth',  # where to save the model
    '--arch', 'vgg19',            # architecture
    '--learning_rate', '0.001',   # learning rate
    '--hidden_units', '512',      # hidden units
    '--dropout', '0.3',           # dropout rate
    '--epochs', '5',              # number of epochs
    '--gpu',                      # use GPU
    '--verbose'                   # verbose output
]

print("Starting training...")
print(f"Command: {' '.join(train_command)}")
print("\n" + "="*70 + "\n")

# Run the training script
result = subprocess.run(train_command, capture_output=False, text=True)

if result.returncode == 0:
    print("\n" + "="*70)
    print("✅ Training completed successfully!")
else:
    print("\n" + "="*70)
    print(f"❌ Training failed with exit code {result.returncode}")

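Because the same flags are reused later in this notebook for another architecture, a small helper can build the argument list from keyword arguments. This is only a sketch: `build_train_command` is a hypothetical name, the flags are the ones used in the cell above, and it assumes `train.py` accepts them exactly as shown there. It uses `sys.executable` so the script runs with the same interpreter as the notebook kernel.

```python
import sys

def build_train_command(data_dir, save_path, arch='vgg19', learning_rate=0.001,
                        hidden_units=512, dropout=0.3, epochs=5,
                        gpu=True, verbose=False):
    """Return the argument list for train.py (flags as used in this notebook)."""
    command = [
        sys.executable, 'train.py',   # run with the kernel's own interpreter
        data_dir,
        '--save_dir', save_path,
        '--arch', arch,
        '--learning_rate', str(learning_rate),
        '--hidden_units', str(hidden_units),
        '--dropout', str(dropout),
        '--epochs', str(epochs),
    ]
    # Boolean switches are added only when requested
    if gpu:
        command.append('--gpu')
    if verbose:
        command.append('--verbose')
    return command

vgg_command = build_train_command('flowers', 'checkpoint.pth', verbose=True)
print(' '.join(vgg_command))
```

The returned list can be passed straight to `subprocess.run`, exactly like `train_command`.
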
## 3. Verify Checkpoint

Check if the training checkpoint was created successfully.

In [None]:
checkpoint_path = 'checkpoint.pth'

if os.path.exists(checkpoint_path):
    size_mb = os.path.getsize(checkpoint_path) / (1024 * 1024)
    print(f"✅ Checkpoint saved: {checkpoint_path}")
    print(f"   Size: {size_mb:.2f} MB")
else:
    print(f"❌ Checkpoint not found: {checkpoint_path}")

## 4. Make Predictions

Use the trained model to predict flower classes for test images.

### Prediction Parameters:
- Top K: 5 (show top 5 predictions)
- GPU: Enabled (if available)
- Category names: Load from JSON file

In [None]:
# Find a test image
test_image = None
test_dir = 'flowers/test'

if os.path.isdir(test_dir):
    # Pick the first image from the first class folder (sorted for reproducibility)
    for class_dir in sorted(os.listdir(test_dir)):
        class_path = os.path.join(test_dir, class_dir)
        if os.path.isdir(class_path):
            # Match extensions case-insensitively so files like IMAGE.JPG are included
            images = sorted(f for f in os.listdir(class_path)
                            if f.lower().endswith(('.jpg', '.jpeg', '.png')))
            if images:
                test_image = os.path.join(class_path, images[0])
                print(f"Using test image: {test_image}")
                break

if not test_image:
    print("❌ No test images found!")
else:
    # Prediction command
    predict_command = [
        'python3', 'predict.py',
        test_image,                          # image to predict
        'checkpoint.pth',                    # model checkpoint
        '--category_names', 'cat_to_name.json',  # category names
        '--top_k', '5',                      # top 5 predictions
        '--gpu',                             # use GPU
        '--verbose'                          # verbose output
    ]

    print(f"\nCommand: {' '.join(predict_command)}")
    print("\n" + "="*70 + "\n")

    # Run the prediction script
    result = subprocess.run(predict_command, capture_output=False, text=True)

    if result.returncode == 0:
        print("\n" + "="*70)
        print("✅ Prediction completed successfully!")
    else:
        print("\n" + "="*70)
        print(f"❌ Prediction failed with exit code {result.returncode}")

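The training and prediction cells follow the same pattern: run a command, then report success or failure from its exit code. A minimal sketch of a shared helper is shown below. The name `run_step` is hypothetical, and the demo uses harmless stand-in commands rather than `train.py` or `predict.py`.

```python
import subprocess
import sys

def run_step(command, label):
    """Run a command, let its output stream through, and report the outcome."""
    print(f"Command: {' '.join(command)}")
    result = subprocess.run(command)
    succeeded = result.returncode == 0
    if succeeded:
        print(f"{label} completed successfully")
    else:
        print(f"{label} failed with exit code {result.returncode}")
    return succeeded

# Stand-in commands: one that exits with 0 and one that exits with 3
ok_demo = run_step([sys.executable, '-c', 'print("hello")'], 'Demo step')
fail_demo = run_step([sys.executable, '-c', 'import sys; sys.exit(3)'], 'Failing step')
```

Returning a boolean lets later cells skip work, such as prediction, when an earlier step failed.
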
## 5. Display Prediction Results

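Display the most recent `*_prediction.png` file in the working directory. This assumes `predict.py` saves its plot under a name ending in `_prediction.png`.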

In [None]:
from IPython.display import Image, display
import glob

# Find the most recent prediction image
prediction_images = glob.glob('*_prediction.png')

if prediction_images:
    # Pick the most recently modified file
    latest_prediction = max(prediction_images, key=os.path.getmtime)
    print(f"Displaying prediction: {latest_prediction}\n")
    display(Image(filename=latest_prediction))
else:
    print("❌ No prediction images found")

## 6. Batch Predictions (Optional)

Run predictions on multiple test images.

In [None]:
import random

# Collect multiple test images
test_images = []
test_dir = 'flowers/test'

if os.path.exists(test_dir):
    for class_dir in os.listdir(test_dir):
        class_path = os.path.join(test_dir, class_dir)
        if os.path.isdir(class_path):
            images = [os.path.join(class_path, f) for f in os.listdir(class_path)
                      if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            if images:
                # Add one random image from this class
                test_images.append(random.choice(images))

# Limit to 5 images
test_images = test_images[:5]

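print(f"Running predictions on {len(test_images)} images...\n")

failed = []  # images whose prediction command exited with a non-zero code

for idx, image_path in enumerate(test_images, 1):
    print(f"\n{'='*70}")
    print(f"Image {idx}/{len(test_images)}: {os.path.basename(image_path)}")
    print('='*70)

    predict_command = [
        'python3', 'predict.py',
        image_path,
        'checkpoint.pth',
        '--category_names', 'cat_to_name.json',
        '--top_k', '3',
        '--gpu'
    ]

    result = subprocess.run(predict_command, capture_output=False, text=True)
    if result.returncode != 0:
        failed.append(image_path)

print(f"\n{'='*70}")
if not test_images:
    print("❌ No test images found!")
elif failed:
    print(f"❌ {len(failed)} of {len(test_images)} predictions failed")
else:
    print("✅ Batch predictions completed!")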

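The batch cell keeps the first five entries of the collected list, and `os.listdir` returns classes in arbitrary order, so the chosen classes can differ between machines. As a hedged alternative, the sketch below draws the classes at random with a fixed seed so the selection is repeatable. The function name `sample_test_images` and the demo directory layout are assumptions made for illustration.

```python
import os
import random
import tempfile

IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

def sample_test_images(test_dir, n_images=5, seed=0):
    """Pick up to n_images files, at most one per class, from randomly chosen classes."""
    rng = random.Random(seed)  # local generator, leaves the global random state alone
    per_class = []
    for class_dir in sorted(os.listdir(test_dir)):
        class_path = os.path.join(test_dir, class_dir)
        if not os.path.isdir(class_path):
            continue
        images = sorted(f for f in os.listdir(class_path)
                        if f.lower().endswith(IMAGE_EXTENSIONS))
        if images:
            # One random image represents this class
            per_class.append(os.path.join(class_path, rng.choice(images)))
    return rng.sample(per_class, min(n_images, len(per_class)))

# Demo on a throwaway directory tree (hypothetical class ids and file names)
with tempfile.TemporaryDirectory() as tmp:
    for class_id in ('1', '2', '3'):
        os.makedirs(os.path.join(tmp, class_id))
        for name in ('a.jpg', 'b.PNG', 'notes.txt'):
            open(os.path.join(tmp, class_id, name), 'w').close()
    first = sample_test_images(tmp, n_images=2, seed=42)
    second = sample_test_images(tmp, n_images=2, seed=42)
    print(len(first), first == second)
```

Sorting the directory listings before sampling is what makes the same seed give the same images on every run.
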
## 7. Alternative: Train with Different Architectures

To compare architectures, train a second model (here ResNet18) and save it to a separate checkpoint.

In [None]:
# Train with ResNet18 (faster, smaller model)
train_command_resnet = [
    'python3', 'train.py',
    'flowers',
    '--save_dir', 'checkpoint_resnet18.pth',
    '--arch', 'resnet18',         # Different architecture
    '--learning_rate', '0.001',
    '--hidden_units', '256',      # Smaller hidden layer
    '--dropout', '0.3',
    '--epochs', '5',
    '--gpu'
]

print("ResNet18 training command (not run by default):")
print(' '.join(train_command_resnet))
print("\nNote: This is optional. Uncomment the line below to run.")

# Uncomment to run:
# subprocess.run(train_command_resnet, capture_output=False, text=True)

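Once more than one checkpoint exists, file size is one simple point of comparison between architectures. The sketch below lists the size of every file matching `checkpoint*.pth`. The helper name `checkpoint_sizes` is hypothetical, and the demo uses dummy files rather than real checkpoints, so the sizes printed say nothing about VGG19 or ResNet18.

```python
import glob
import os
import tempfile

def checkpoint_sizes(directory='.', pattern='checkpoint*.pth'):
    """Return {file name: size in MB} for checkpoint files in a directory."""
    sizes = {}
    for path in sorted(glob.glob(os.path.join(directory, pattern))):
        sizes[os.path.basename(path)] = os.path.getsize(path) / (1024 * 1024)
    return sizes

# Demo with dummy files standing in for real checkpoints
with tempfile.TemporaryDirectory() as tmp:
    for name, n_bytes in (('checkpoint.pth', 2 * 1024 * 1024),
                          ('checkpoint_resnet18.pth', 1024 * 1024)):
        with open(os.path.join(tmp, name), 'wb') as f:
            f.write(b'\0' * n_bytes)
    demo_sizes = checkpoint_sizes(tmp)
    for name, size_mb in demo_sizes.items():
        print(f"{name}: {size_mb:.2f} MB")
```

In the notebook itself, calling `checkpoint_sizes()` with no arguments would inspect the working directory.
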
## 8. Clean Up (Optional)

Remove generated files if needed.

In [None]:
# Uncomment to clean up generated files

# import glob
# import os

# # Remove checkpoint files
# for checkpoint in glob.glob('checkpoint*.pth'):
#     os.remove(checkpoint)
#     print(f"Removed: {checkpoint}")

# # Remove prediction images
# for pred_img in glob.glob('*_prediction.png'):
#     os.remove(pred_img)
#     print(f"Removed: {pred_img}")

print("Uncomment the code above to clean up files.")
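
As an alternative to uncommenting code, cleanup can be wrapped in a function with a dry-run switch, so the files that would be deleted can be reviewed first. This is a sketch only: `clean_generated_files` is a hypothetical name, the glob patterns are the ones used in the cell above, and the demo runs in a temporary directory with made-up file names.

```python
import glob
import os
import tempfile

def clean_generated_files(directory='.', dry_run=True):
    """List (dry_run=True) or delete generated checkpoints and prediction plots."""
    patterns = ('checkpoint*.pth', '*_prediction.png')
    matched = []
    for pattern in patterns:
        matched.extend(sorted(glob.glob(os.path.join(directory, pattern))))
    for path in matched:
        if dry_run:
            print(f"Would remove: {path}")
        else:
            os.remove(path)
            print(f"Removed: {path}")
    return matched

# Demo in a throwaway directory so no real files are touched
with tempfile.TemporaryDirectory() as tmp:
    for name in ('checkpoint.pth', 'rose_prediction.png', 'keep_me.txt'):
        open(os.path.join(tmp, name), 'w').close()
    previewed = clean_generated_files(tmp, dry_run=True)
    remaining_after_preview = sorted(os.listdir(tmp))
    clean_generated_files(tmp, dry_run=False)
    remaining_after_delete = sorted(os.listdir(tmp))
```

Defaulting `dry_run` to `True` means an accidental call only prints the candidates and deletes nothing.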