![DataDunkers.ca Banner](https://github.com/Data-Dunkers/lessons/blob/main/images/top-banner.jpg?raw=true)

# Training an AI

We are going to train an AI system to recognize if an image is about basketball, baseball, or hockey.

For training data, we'll use images that are [public domain](https://en.wikipedia.org/wiki/Public_domain) or [Creative Commons](https://creativecommons.org/) because we are allowed to use them without purchasing a license.

In general, the more example images we provide for training, the better the AI system will be at telling basketball, baseball, and hockey apart.

## Getting Training Data

1. Create three folders on your computer, one called `basketball`, one called `baseball`, and one called `hockey`.
1. Find and download at least 10 images related to basketball from [Pexels](https://www.pexels.com/search/basketball/) or [Pixabay](https://pixabay.com/images/search/basketball/). Put them in your `basketball` folder.
1. Find and download at least 10 images related to baseball from [Pexels](https://www.pexels.com/search/baseball/) or [Pixabay](https://pixabay.com/images/search/baseball/). Put them in your `baseball` folder.
1. Find and download at least 10 images related to hockey from [Pexels](https://www.pexels.com/search/hockey/) or [Pixabay](https://pixabay.com/images/search/hockey/). Put them in your `hockey` folder.
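If Python is available on the computer where you saved the folders, you can check that each one has enough images before uploading. This is a minimal sketch, assuming the three folders sit in the current working directory; the `count_images` helper and the list of file extensions are illustrative assumptions, not part of the lesson.

```python
from pathlib import Path

# Assumed set of image file extensions; adjust as needed.
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.webp'}

def count_images(root='.', classes=('basketball', 'baseball', 'hockey')):
    """Return a dict mapping each class folder name to its image count."""
    counts = {}
    for name in classes:
        folder = Path(root) / name
        if folder.is_dir():
            counts[name] = sum(
                1 for p in folder.iterdir()
                if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
            )
        else:
            counts[name] = 0  # a missing folder counts as zero images
    return counts

for name, count in count_images().items():
    status = 'ok' if count >= 10 else 'needs more images'
    print(f'{name}: {count} images ({status})')
```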

Next, open [Teachable Machine image training](https://teachablemachine.withgoogle.com/train/image).

1. Rename **Class 1** as `basketball`, and **Class 2** as `baseball` by clicking on the pencil icons.
1. Click the **Add a class** button and rename that new class as `hockey`.
1. Upload your images to the correct class by clicking each of the **Upload** buttons.
1. Click the **Train Model** button.
1. After the training has finished, click the **Export Model** button, click the **Tensorflow Lite** tab on the right, then click the **Download my model** button. The button will change to **Converting model...** and the conversion will take a few minutes. Don't click away from that browser tab while it runs.
1. Your model should then download automatically as **converted_tflite.zip**.
1. Upload your **converted_tflite.zip** file to the folder that this notebook is in:

    * If you are using the [Callysto Hub](https://hub.callysto.ca), click [here](.) and then click the **↑Upload** button on the right.
    * If you are using [Colab](https://colab.research.google.com/), click the button on the left that looks like a folder (🗂️), then click the button that contains an **↑**.
    * If you are running in JupyterLab, the file browser is already on the left; use its upload button (**↑**).
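Optionally, you can confirm that the uploaded archive contains the two files the setup cell reads. This is a minimal sketch, assuming the archive is named `converted_tflite.zip` and sits in the notebook's folder; the `missing_model_files` helper is an illustrative name, not part of the lesson.

```python
from zipfile import ZipFile, BadZipFile

# File names that the setup cell in this notebook reads after extraction.
EXPECTED_FILES = ('model_unquant.tflite', 'labels.txt')

def missing_model_files(zip_path='converted_tflite.zip'):
    """Return the expected file names that are absent from the zip archive."""
    with ZipFile(zip_path, 'r') as zip_object:
        names = set(zip_object.namelist())
    return [name for name in EXPECTED_FILES if name not in names]

try:
    missing = missing_model_files()
    print('Archive looks complete' if not missing else f'Missing files: {missing}')
except FileNotFoundError:
    print('converted_tflite.zip was not found in this folder')
except BadZipFile:
    print('converted_tflite.zip exists but is not a valid zip file')
```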

After you have completed all of those steps, run the following cell to set up the image classifier.

In [None]:
from zipfile import ZipFile
from PIL import Image, ImageOps
import numpy as np
import requests, urllib.request, os
import pandas as pd
try:
    import tflite_runtime.interpreter as tflite
except ImportError:  # if tflite_runtime is not available, import and alias from tensorflow
    import tensorflow as tf
    tflite = tf.lite

try:
    with ZipFile('converted_tflite.zip', 'r') as zip_object:
        zip_object.extractall()
except Exception:  # the zip file is missing or unreadable
    print('Unable to find your converted_tflite.zip file, using Data Dunkers online version')
    r = requests.get('https://raw.githubusercontent.com/Data-Dunkers/data/refs/heads/main/converted_tflite.zip')
    r.raise_for_status()  # stop here if the download failed
    with open('converted_tflite.zip', 'wb') as f:
        f.write(r.content)
    with ZipFile('converted_tflite.zip', 'r') as zip_object:
        zip_object.extractall()

interpreter = tflite.Interpreter('model_unquant.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_shape = input_details[0]['shape']
with open('labels.txt', 'r') as f:  # close the file before it is removed below
    class_names = f.readlines()
os.remove('model_unquant.tflite')
os.remove('labels.txt')

def classify_image(image_url, show_image=False):
    """Download an image and return (class name, confidence, resized image)."""
    filename = image_url.split('/')[-1]
    r = requests.get(image_url)
    r.raise_for_status()  # stop here if the download failed
    with open(filename, 'wb') as f:
        f.write(r.content)
    image = Image.open(filename).convert('RGB')
    # the model input shape is (batch, height, width, channels); PIL expects (width, height)
    image = image.resize((int(input_shape[2]), int(input_shape[1])))
    if show_image:
        display(image)
    os.remove(filename)
    # add a batch dimension and scale pixel values to the range 0 to 1
    input_data = (np.expand_dims(np.array(image), axis=0) / 255.0).astype(np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    predicted_class = np.argmax(output_data)
    # each line of labels.txt is expected to look like '0 basketball', so drop the leading index
    predicted_class_name = class_names[predicted_class].strip()[2:]
    confidence_level = output_data[0][predicted_class]
    return predicted_class_name, confidence_level, image

print('Model imported and classify_image(image_url) function defined')

Now that we have set up the `classify_image()` function, we can load an image from a link and get its classification according to our trained AI. The function will return a classification category and confidence level, and a resized version of the image.

Change the string in the `image_url` variable to be a direct link to an online image.

Make sure you have copied the **image** address and that it is not a link to a webpage. The URL will usually end with something like **.jpg**, **.gif**, or **.png**.
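A quick way to spot a webpage link is to look at how the address ends. The following is a rough heuristic sketch; the `looks_like_image_url` helper and its list of extensions are assumptions made for illustration, not part of the lesson.

```python
from urllib.parse import urlparse

# Assumed list of common image extensions; not exhaustive.
IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png', '.gif', '.webp')

def looks_like_image_url(url):
    """Heuristic: True if the URL path ends with a common image extension."""
    path = urlparse(url).path.lower()  # ignores any ?query or #fragment part
    return path.endswith(IMAGE_SUFFIXES)

print(looks_like_image_url('https://free-images.com/sm/2c90/basketball_ball_orange_grass.jpg'))  # True
print(looks_like_image_url('https://www.pexels.com/search/basketball/'))  # False
```

Some direct image links have no file extension at all, so a `False` result is only a hint to double-check the address, not proof that it is a webpage.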

In [None]:
image_url = 'https://img.redbull.com/images/redbullcom/2023/11/9/auzqbcftt6nbqxhgsnjw/pascal-siakam-portrait'

results = classify_image(image_url)
results

The first value returned is the classification, in our case `basketball`, `baseball`, or `hockey`.

The second is the "confidence score", which is how sure the AI is of that classification; `1` means 100% confident.

The third value is the downloaded and resized image. Run the next cell to display the downloaded image.

In [None]:
results[2]
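To see how the label and confidence score come out of the model's raw output, here is a small sketch in plain Python. It mirrors the `np.argmax` step inside `classify_image()`; the scores are made-up example values, the label lines assume the `0 basketball` format, and `top_prediction` is an illustrative helper name.

```python
# Made-up example values: one score per class, in the same order as the label lines.
label_lines = ['0 basketball\n', '1 baseball\n', '2 hockey\n']
scores = [0.2, 0.75, 0.05]

def top_prediction(scores, label_lines):
    """Return (class name, confidence) for the highest-scoring class."""
    best_index = max(range(len(scores)), key=lambda i: scores[i])
    # Each label line is assumed to look like '1 baseball'; keep the text after the index.
    class_name = label_lines[best_index].strip().split(' ', 1)[1]
    return class_name, scores[best_index]

name, confidence = top_prediction(scores, label_lines)
print(f"I'm {int(confidence * 100)}% certain that is {name}.")
```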

We can even use this to categorize a list of online images. We'll try it with some artwork as well as photos and see how it performs.

In [None]:
urls = [
    'https://free-images.com/sm/2c90/basketball_ball_orange_grass.jpg',
    'https://free-images.com/sm/ba10/baseball_field_baseball_gravel.jpg',
    'https://free-images.com/sm/862b/hockey_puck_hockey_pucks.jpg',
    'https://free-images.com/md/c139/backyard_baseball_baseball_cards_1.jpg',
    'https://free-images.com/md/6de0/backyard_baseball_baseball_cards_8.jpg',
    'https://tinyurl.com/PS43skythlee',
    'https://collectionapi.metmuseum.org/api/collection/v1/iiif/421259/778912/main-image',
    'https://collectionapi.metmuseum.org/api/collection/v1/iiif/704860/1556049/main-image',
    'https://collectionapi.metmuseum.org/api/collection/v1/iiif/437192/795865/main-image',
    'https://raw.githubusercontent.com/Data-Dunkers/data/refs/heads/main/public-domain-images/Childe_Hassam_Ice_Hockey.jpg',
    'https://raw.githubusercontent.com/Data-Dunkers/data/refs/heads/main/public-domain-images/Egyptian_Basketball_Painting.jpg',
    'https://raw.githubusercontent.com/Data-Dunkers/data/refs/heads/main/public-domain-images/Fletcher_Ransom_Out_at_Home.jpg',
    'https://raw.githubusercontent.com/Data-Dunkers/data/refs/heads/main/public-domain-images/Gemini_Baseball_Slide.png',
    'https://raw.githubusercontent.com/Data-Dunkers/data/refs/heads/main/public-domain-images/Gemini_Basketball_Player.png',
    'https://raw.githubusercontent.com/Data-Dunkers/data/refs/heads/main/public-domain-images/Gemini_Hockey_Goalie.png',
    'https://raw.githubusercontent.com/Data-Dunkers/data/refs/heads/main/public-domain-images/ChatGPT_Renaissance_Basketball_Player.png',
    'https://raw.githubusercontent.com/Data-Dunkers/data/refs/heads/main/public-domain-images/Jean_Jacoby_Hockey.jpg',
    'https://raw.githubusercontent.com/Data-Dunkers/data/refs/heads/main/public-domain-images/Russ_Meyer_Baseball_Card.jpg',
]

data = pd.DataFrame(urls, columns=['url'])
labels = []
confidences = []
images = []

for url in urls:
    results = classify_image(url, show_image=True)
    print(f"I'm {int(results[1]*100)}% certain that is {results[0]}.")
    labels.append(results[0])
    confidences.append(results[1])
    images.append(results[2])
    print('---')

data['image'] = images
data['label'] = labels
data['confidence'] = confidences
data

We now have a dataframe of images, labels, and confidence values. To access a particular row we can use `.iloc[]`.

In [None]:
data.iloc[5]

And we can even display an image from the dataframe.

In [None]:
data.iloc[5]['image']

If the model is not accurately identifying the sports, go back to the start of this notebook and train it with more images.
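One way to judge whether the model needs more training is to compare its labels against labels you write down yourself. This is a minimal sketch with hypothetical example values; the `accuracy` helper and both lists are assumptions for illustration, not real model output.

```python
def accuracy(predicted, expected):
    """Fraction of predictions that match the expected labels."""
    if len(predicted) != len(expected):
        raise ValueError('predicted and expected must be the same length')
    if not expected:
        return 0.0
    matches = sum(p == e for p, e in zip(predicted, expected))
    return matches / len(expected)

# Hypothetical example values, not real model output.
predicted = ['basketball', 'baseball', 'hockey', 'basketball']
expected = ['basketball', 'baseball', 'hockey', 'baseball']
print(f'Accuracy: {accuracy(predicted, expected):.0%}')  # Accuracy: 75%
```

In this notebook, `predicted` could be `list(data['label'])`, with `expected` being a list you type in by hand, one label per URL in the same order.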

Of course, we can also use this same process to train an AI model to categorize other things, for example identifying whether an image shows soup, salad, or a sandwich.

[![Data Dunkers License](https://github.com/Data-Dunkers/lessons/blob/main/images/bottom-banner.jpg?raw=true)](https://github.com/Data-Dunkers/lessons/blob/main/LICENSE.md)