In [1]:
import shutil
from fastai.vision.all import *
from pathlib import Path

# 1. Setup paths
# We use the 'val' folder from Imagenette to ensure these are NEW images
path_random = untar_data(URLs.IMAGENETTE_160)
source_random_test = get_image_files(path_random/'val')[:10] # Grab 10 fresh ones

dest_test_folder = Path('./test_images_set')
dest_test_folder.mkdir(exist_ok=True)

# 2. Copy them
print(f"Adding {len(source_random_test)} NEW random images to your Test Set...")
for i, file in enumerate(source_random_test):
    # We rename them so you can easily spot them in the results
    shutil.copy(file, dest_test_folder / f'test_random_{i}.jpg')

print("✅ DONE! Your test folder now includes 'Unknown' challenges.")



Adding 10 NEW random images to your Test Set...
✅ DONE! Your test folder now includes 'Unknown' challenges.


In [2]:
# --- BLOCK B: TRAIN 4-CLASS MODEL (Clay, Grass, Hard, Unknown) ---
import shutil
from fastai.vision.all import *
from pathlib import Path
import os

# 1. Setup path
path = Path('./tennis_courts')

# 2. Safety Check: Ensure unknown folder exists
unknown_folder = path/'unknown'
if not unknown_folder.exists() or len(get_image_files(unknown_folder)) == 0:
    print("⚠️ WARNING: The 'unknown' folder is empty! Run the download cell above first.")

# 3. Load Data (Now finding 4 classes)
dls = ImageDataLoaders.from_folder(
    path,
    valid_pct=0.2,
    seed=42,
    item_tfms=Resize(224, method='squish'),
    batch_tfms=aug_transforms(mult=1.5),
    bs=16,
    num_workers=0,
    device=torch.device('cpu') 
)

print(f"Classes found: {dls.vocab}") 
# VERIFY: Should be ['clay', 'grass', 'hard', 'unknown']

# 4. Train
learn = vision_learner(dls, resnet18, metrics=error_rate)
learn.model.to('cpu')
learn.dls.to('cpu')

print("Training 4-class model...")
learn.fine_tune(4)

Classes found: ['clay', 'grass', 'hard']
Training 4-class model...


epoch,train_loss,valid_loss,error_rate,time
0,2.062885,1.284347,0.428571,00:08


epoch,train_loss,valid_loss,error_rate,time
0,0.872955,1.0682,0.333333,00:08
1,0.667944,0.920734,0.285714,00:08
2,0.575649,0.755137,0.333333,00:08
3,0.448299,0.661885,0.333333,00:09


In [None]:
import coremltools as ct
import torch

# 1. Get the trained model (ensure it is on CPU)
model = learn.model.eval().cpu()

# 2. Create dummy input (1 image, 3 channels, 224x224 pixels)
# This tells the converter what size image the iPhone will send to the AI
dummy_input = torch.rand(1, 3, 224, 224)

# 3. Trace the model
print("Tracing model...")
traced_model = torch.jit.trace(model, dummy_input)

# 4. Convert to Core ML format
print("Converting to Core ML...")
mlmodel = ct.convert(
    traced_model,
    inputs=[ct.ImageType(
        name="image", 
        shape=dummy_input.shape, 
        scale=1/255.0, # Convert 0-255 RGB values to 0-1
        bias=[-0.485/0.229, -0.456/0.224, -0.406/0.225] # Normalize colors for ResNet
    )],
    classifier_config=ct.ClassifierConfig(list(dls.vocab)) # Save the class names (hard, clay, grass)
)

# 5. Add Metadata and Save
mlmodel.short_description = "Classifies Tennis Court Surfaces"
mlmodel.author = "Luca"
mlmodel.save("TennisClassifier.mlpackage")

print("\n✅ SUCCESS! 'TennisClassifier.mlpackage' is saved in your project folder.")

In [None]:
# Show 6 examples
learn.show_results(max_n=6, figsize=(7, 8))

In [None]:
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

In [None]:
import ipywidgets as widgets

# Create the button
uploader = widgets.FileUpload()
print("Click the button below to upload an image:")
display(uploader)

In [None]:
# Check if an image was actually uploaded
if len(uploader.data) > 0:
    # 1. Get the image data from the uploader
    # (uploader.data[0] contains the raw bytes of the first file)
    img = PILImage.create(uploader.data[0])
    
    # 2. Display the image you picked
    display(img.to_thumb(256,256))

    # 3. Predict
    pred_class, pred_idx, probs = learn.predict(img)

    print(f"\nPREDICTION: {pred_class.upper()}")
    print(f"Confidence: {probs[pred_idx]*100:.2f}%")
else:
    print("⚠️ You haven't uploaded an image yet! Go back to the cell above and pick a file.")