# 🤖 Building a Robust Computer Vision Model! 🧠

Welcome! In this workshop, you'll take the model you trained on Teachable Machine and make it even smarter and more robust.

**Our Journey:**
1.  **Upload & Setup:** We'll get our environment ready and upload the model and image samples you created.
2.  **Sanity Check:** We'll test your original model on your original images to make sure it works. 
3.  **The Break Test:** We'll try to "trick" your model with rotated and noisy images to see its weaknesses.
4.  **The Training Montage:** We'll teach your model to be tougher using **Data Augmentation** and visualize its learning progress.
5.  **The Final Showdown:** You'll test your new, super-smart model in real-time with an enhanced webcam UI!

### Step 1: Setup and Installations

First, we need to make sure our environment has all the tools we need. We'll be using TensorFlow (for the model), OpenCV (for image processing), and a few other helpful libraries.

Just run the cell below to get everything ready.

In [None]:
#
# STEP 1: SETUP AND INSTALLATIONS
#

# Import all the libraries we'll need
import tensorflow as tf
import numpy as np
import cv2  # This is the OpenCV library
from google.colab import files
from google.colab.patches import cv2_imshow # for displaying images in Colab
import os
import zipfile
import random
from PIL import Image
import matplotlib.pyplot as plt

# Imports for the live webcam feed
from IPython.display import display, Javascript, HTML
from google.colab.output import eval_js
from base64 import b64decode, b64encode
import io
import shutil

print("✅ All libraries imported successfully!")
print(f"Using TensorFlow version: {tf.__version__}")

### Step 2: Upload Your Files

Now it's time to bring in your hard work from Teachable Machine. You will need to upload several files:

1.  **Your Model ZIP:** The converted Keras model zip file. **Important:** Please make sure it is named `model.zip` before uploading.
2.  **Your Sample ZIPs:** The zip file of image samples for **each** of your classes. For example, you might have `Class1.zip`, `Class2.zip`, `Class3.zip`, etc.

**Action:** Run the cell below. A "Choose Files" button will appear. Select your `model.zip` AND all of your sample class zip files at the same time and upload them.
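If renaming the file is easy to forget, the model zip could instead be recognized by what it contains. Below is a minimal sketch. It assumes the Teachable Machine Keras export contains `keras_model.h5` and `labels.txt` at its top level, which are the two files Step 4 loads after unzipping. `find_model_zip` is a hypothetical helper and is not part of Colab or Teachable Machine:

```python
import os
import zipfile

# Assumed contents of the model export (the files Step 4 reads after unzipping)
REQUIRED_MODEL_FILES = {"keras_model.h5", "labels.txt"}

def find_model_zip(zip_names):
    """Return the first zip whose top level holds the model files, else None.

    Hypothetical helper: sample zips contain only images, so they will not match.
    """
    for name in zip_names:
        if not name.lower().endswith(".zip") or not os.path.exists(name):
            continue
        with zipfile.ZipFile(name) as zf:
            members = set(zf.namelist())
        if REQUIRED_MODEL_FILES <= members:  # subset check: all required files present
            return name
    return None

# Possible use in Step 2, right after files.upload():
# model_zip_path = find_model_zip(uploaded.keys())
# sample_zip_paths = [n for n in uploaded if n != model_zip_path]
```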

In [None]:
#
# STEP 2: UPLOAD FILES (MODEL ZIP + ONE ZIP PER CLASS)
#

print("Please upload your `model.zip` file and all of your sample class zip files.")
print("You can select them all at once in the file dialog.")

# Clean up any previous uploads
if os.path.exists('model'):
  shutil.rmtree('model')
if os.path.exists('samples'):
  shutil.rmtree('samples')
if os.path.exists('model.zip'):
  os.remove('model.zip')

# Upload all files
uploaded = files.upload()

# Separate the model zip from the sample zips
if 'model.zip' not in uploaded:
    print("\n\n❌ ERROR: `model.zip` was not found. Please rename your model zip file and run this cell again.")
else:
    model_zip_path = 'model.zip'
    sample_zip_paths = [name for name in uploaded.keys() if name != 'model.zip']
    print(f"\n✅ Uploaded '{model_zip_path}' successfully.")
    if sample_zip_paths:
        print("✅ Uploaded the following sample files:")
        for name in sample_zip_paths:
            print(f"- {name}")
    else:
        print("⚠️ Warning: No sample zip files were uploaded.")

### Step 3: Unzip Everything

Our files are uploaded, but they're still zipped up. Let's unpack them.

The cell below looks at the name of each sample `.zip` file (e.g., `My Face.zip`), creates a folder with that name (`My Face/`), and extracts the images into it. Each class ends up in its own folder, which is the layout the rest of this notebook expects.
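Step 7 later looks up each folder name in `labels.txt`, so it helps if folder names match label names exactly. Here is a sketch of a name-cleaning helper. It assumes sample zips may be named like `Class 1-samples.zip` and may pick up a ` (1)` suffix when a file is downloaded or uploaded twice; these are the same extras the Step 7 code tries to strip. `class_name_from_zip` is a hypothetical name:

```python
import os
import re

def class_name_from_zip(zip_name):
    """Derive a class name from a sample zip file name (hypothetical helper).

    'Class 1-samples (1).zip' -> 'Class 1', 'My Face.zip' -> 'My Face'
    """
    stem = os.path.splitext(os.path.basename(zip_name))[0]
    stem = re.sub(r"\s*\(\d+\)$", "", stem)  # drop a trailing ' (1)', ' (2)', ...
    stem = re.sub(r"-samples$", "", stem)    # drop an assumed '-samples' suffix
    return stem.strip()
```

It could replace the `class_name = ...` line in the cell below, so the folders are created with clean names from the start.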

In [None]:
#
# STEP 3: UNZIP FILES (ONE FOLDER PER SAMPLE ZIP)
#

# --- Unzip the model ---
print("Unzipping the model...")
model_extract_path = 'model'
# Clean up previous runs
if os.path.exists(model_extract_path):
  shutil.rmtree(model_extract_path)
os.makedirs(model_extract_path, exist_ok=True)
with zipfile.ZipFile(model_zip_path, 'r') as zip_ref:
    zip_ref.extractall(model_extract_path)
print("Model files:")
!ls {model_extract_path}

# --- Unzip all the samples ---
# Each class comes as its own zip, so we create one sub-folder per zip file,
# named after the zip file itself.
print("\nUnzipping samples into class-specific folders...")
samples_extract_path = 'samples'
# Clean up previous runs
if os.path.exists(samples_extract_path):
  shutil.rmtree(samples_extract_path)
os.makedirs(samples_extract_path, exist_ok=True)

for sample_zip in sample_zip_paths:
    # 1. Get the class name from the zip filename (e.g., "Class 1.zip" -> "Class 1")
    # We use os.path.basename to be safe, then os.path.splitext to remove the extension.
    class_name = os.path.splitext(os.path.basename(sample_zip))[0]

    # 2. Create a specific directory for this class inside the 'samples' folder
    class_dir_path = os.path.join(samples_extract_path, class_name)
    os.makedirs(class_dir_path, exist_ok=True)
    print(f"  - Extracting '{sample_zip}' into folder '{class_dir_path}'")

    # 3. Extract the contents of this zip file into its dedicated class folder
    with zipfile.ZipFile(sample_zip, 'r') as zip_ref:
        zip_ref.extractall(class_dir_path)

print("\nSample folders created:")
# Using !ls -R on a variable path in Colab to show directories and some contents
!ls -R {samples_extract_path}

### Step 4: Load Your Original Model

Time to wake up our model! We'll load the `keras_model.h5` file and the `labels.txt` file.

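Teachable Machine models can fail to load in newer versions of TensorFlow (like the one in Colab). Passing `compile=False` to `load_model` skips restoring the saved training configuration (optimizer and loss), which avoids a common source of these errors. We only need the model for predictions at first, and Step 7 compiles it again before retraining.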
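The `labels.txt` file pairs each model output with a class name, one per line (for example `0 Class 1`). The cell below parses it with a one-liner. Below is a slightly more defensive sketch that skips blank lines and checks that the indices run 0, 1, 2, … so that list position `i` really matches model output `i`. `parse_labels` is a hypothetical helper:

```python
def parse_labels(lines):
    """Turn lines like '0 Class 1' into a list of names ordered by index.

    Hypothetical helper: blank lines are skipped, and the indices must run
    0, 1, 2, ... so that position i in the list matches model output i.
    """
    labels = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue  # e.g. an empty last line
        index_text, _, name = line.partition(" ")
        labels[int(index_text)] = name.strip()
    if sorted(labels) != list(range(len(labels))):
        raise ValueError(f"Unexpected label indices: {sorted(labels)}")
    return [labels[i] for i in range(len(labels))]

# Possible use in Step 4:
# with open(labels_path) as f:
#     class_labels = parse_labels(f)
```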

In [None]:
#
# STEP 4: LOAD THE ORIGINAL MODEL (WITH FIX)
#

# --- Load the Keras model ---
model_path = 'model/keras_model.h5'
labels_path = 'model/labels.txt'

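# compile=False skips restoring the saved training configuration (optimizer, loss),
# which avoids common compatibility errors; Step 7 compiles the model again before training
original_model = tf.keras.models.load_model(model_path, compile=False)

# --- Load the labels ---
# Each line looks like "0 Class 1": an index, a space, then the class name
with open(labels_path, 'r') as f:
    # Skip blank lines (e.g. a trailing newline), then drop the leading index number
    class_labels = [line.strip().split(' ', 1)[1] for line in f if line.strip()]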

print("🤖 Model Loaded Successfully!")
print("🧠 These are the classes the model knows:")
for i, label in enumerate(class_labels):
    print(f"{i}: {label}")

### Step 5: Sanity Check - Test the Original Model

Does the model remember what you taught it? Let's find out!

We'll take one random image from each class folder in `samples` and see if the model can correctly guess what it is. This is our "sanity check" to make sure everything is working as expected.

Look at the **confidence bars**. Since these are the same images you trained with, the predictions should be accurate and confident!
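The `preprocess_image` helper in the cell below rescales pixels with `(x / 127.5) - 1`, mapping the range 0..255 onto -1..1. This is the input range the model is fed throughout this notebook. A tiny worked example of that arithmetic in NumPy (`to_model_input` is just an illustrative name):

```python
import numpy as np

def to_model_input(image_array):
    """Scale a uint8 RGB image from [0, 255] to [-1, 1] and add a batch axis."""
    scaled = image_array.astype(np.float32) / 127.5 - 1.0
    return np.expand_dims(scaled, axis=0)

# 0 maps to -1.0 and 255 maps to 1.0 (the midpoint 127.5 would map to 0.0)
pixels = np.array([[[0, 0, 0], [255, 255, 255]]], dtype=np.uint8)  # shape (1, 2, 3)
batch = to_model_input(pixels)
print(batch.shape)        # (1, 1, 2, 3)
print(batch[0, 0, :, 0])  # [-1.  1.]
```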

In [None]:
#
# STEP 5: SANITY CHECK
#

# --- Helper function to show nice confidence bars ---
def display_prediction_bars(prediction, labels):
    # ANSI escape codes for colors
    GREEN = '\033[92m'
    RESET = '\033[0m'
    
    top_prediction_index = np.argmax(prediction)
    
    for i, (label, confidence) in enumerate(zip(labels, prediction[0])):
        confidence_percent = confidence * 100
        bar = '█' * int(confidence_percent / 4)
        if i == top_prediction_index:
            print(f'{GREEN}{label:>20}: [{bar:<25}] {confidence_percent:6.2f}%{RESET}')
        else:
            print(f'{label:>20}: [{bar:<25}] {confidence_percent:6.2f}%')


print("--- Running Sanity Check on Original Images ---")

# A little helper function to prepare the image for the model
def preprocess_image(img_path):
    # convert('RGB') handles grayscale, palette, and transparent (alpha) images in one step
    image = Image.open(img_path).convert('RGB')
    # The model expects images to be 224x224
    image_array = np.array(image.resize((224, 224)))
    # Normalize pixel values from [0, 255] to [-1, 1]
    normalized_image_array = (image_array.astype(np.float32) / 127.5) - 1
    # Add a "batch" dimension: (224, 224, 3) -> (1, 224, 224, 3)
    return np.expand_dims(normalized_image_array, axis=0)

# Go through each class folder in your samples
samples_dir = 'samples'
class_folders = [f for f in os.listdir(samples_dir) if os.path.isdir(os.path.join(samples_dir, f))]

for class_name in class_folders:
    class_dir = os.path.join(samples_dir, class_name)
    
    # Pick one random image from the folder
    image_files = [f for f in os.listdir(class_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    if not image_files:
        print(f"\nNo images found in folder: {class_name}. Skipping.")
        continue

    random_image_name = random.choice(image_files)
    image_path = os.path.join(class_dir, random_image_name)

    # Display the image being tested
    print(f"\n--- Testing image from folder: {class_name} ---")
    display(Image.open(image_path).resize((150, 150)))

    # Preprocess the image and make a prediction
    processed_image = preprocess_image(image_path)
    prediction = original_model.predict(processed_image, verbose=0)
    
    # Display the results with bars
    print("🤖 Model Prediction:")
    display_prediction_bars(prediction, class_labels)


### Step 6: The Break Test! Is Your Model Fragile?

Real-world video isn't perfect. Sometimes the lighting is bad (creating "noise"), or you might tilt your head (a "rotation"). Let's see how our model handles these situations.

We'll pick images from the same class folders as before, but this time we'll **add random noise** and **rotate them** a bit before showing them to the model.

**Prediction:** Do you think the model will still be accurate? Watch the confidence bars.
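The `add_noise` helper in the cell below changes pixels one at a time in a Python loop. This is easy to read but slow on larger images. The same salt-and-pepper idea can be expressed with NumPy boolean masks. Below is a sketch; `add_salt_and_pepper` and its `amount` parameter are illustrative names. It roughly matches the 5% white plus 5% black used below, but unlike the loop, salt and pepper pixels never overlap:

```python
import numpy as np

def add_salt_and_pepper(img_array, amount=0.05, rng=None):
    """Return a noisy copy of an (H, W, 3) uint8 image.

    About `amount` of the pixels become white (salt) and another `amount` black (pepper).
    """
    rng = np.random.default_rng() if rng is None else rng
    noisy = img_array.copy()
    # One random number per pixel decides: salt, pepper, or unchanged
    draws = rng.random(noisy.shape[:2])
    noisy[draws < amount] = 255                            # salt: all three channels white
    noisy[(draws >= amount) & (draws < 2 * amount)] = 0    # pepper: all three channels black
    return noisy
```

With a PIL image it could be used as `Image.fromarray(add_salt_and_pepper(np.array(rotated_img)))`.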

In [None]:
#
# STEP 6: THE BREAK TEST (NOISE AND ROTATION)
#

# --- Helper functions to mess up our images ---
def add_noise(img):
    """Adds random 'salt and pepper' noise to an image."""
    img_array = np.array(img)
    rows, cols, _ = img_array.shape
    # Add salt
    salt_pixels = int(0.05 * img_array.size / 3) # 5% of pixels
    for _ in range(salt_pixels):
        y, x = random.randint(0, rows - 1), random.randint(0, cols - 1)
        img_array[y, x] = (255, 255, 255) # white pixel
    # Add pepper
    pepper_pixels = int(0.05 * img_array.size / 3)
    for _ in range(pepper_pixels):
        y, x = random.randint(0, rows - 1), random.randint(0, cols - 1)
        img_array[y, x] = (0, 0, 0) # black pixel
    return Image.fromarray(img_array)

def rotate_image(img):
    """Rotates an image by a random angle between -25 and 25 degrees."""
    angle = random.uniform(-25, 25)
    return img.rotate(angle, expand=True, fillcolor='black').resize(img.size)


# --- The Test ---
print("--- Running Break Test with NOISE and ROTATION ---")
print("Watch how the model's confidence and accuracy might drop!")

# Go through each class folder again
for class_name in class_folders:
    class_dir = os.path.join(samples_dir, class_name)
    image_files = [f for f in os.listdir(class_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    if not image_files:
        continue
    random_image_name = random.choice(image_files)
    image_path = os.path.join(class_dir, random_image_name)

    # Open the image
    original_pil_image = Image.open(image_path).convert('RGB')

    # Mess it up!
    rotated_img = rotate_image(original_pil_image)
    noisy_and_rotated_img = add_noise(rotated_img)

    # Display the messed up image
    print(f"\n--- Testing messed up image from folder: {class_name} ---")
    display(noisy_and_rotated_img.resize((150, 150)))

    # Preprocess the MESSED UP image for the model
    image_array = np.array(noisy_and_rotated_img.resize((224, 224)))
    normalized_image_array = (image_array.astype(np.float32) / 127.5) - 1
    processed_image = np.expand_dims(normalized_image_array, axis=0)

    # Make a prediction with the original model
    prediction = original_model.predict(processed_image, verbose=0)

    # Display the results with bars
    print("🤖 Model Prediction:")
    display_prediction_bars(prediction, class_labels)


**Observation:** What happened? Most likely, the model was less confident, and it may even have guessed the wrong class. This is because it was only ever trained on "perfect" pictures and never saw tilted or noisy ones.
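A single random image per class can be lucky or unlucky, so one run of the Break Test is only a rough signal. Comparing accuracy over many images, clean versus noisy-and-rotated, gives a steadier picture. Below is a sketch in which `accuracy_over_trials` is a hypothetical helper. `predict_fn` is assumed to map a batch to per-class scores, which is what `original_model.predict` does:

```python
import numpy as np

def accuracy_over_trials(predict_fn, images, true_indices):
    """Fraction of images whose highest-scoring class equals the true class index.

    predict_fn: takes a batch of shape (N, H, W, 3) and returns scores (N, num_classes).
    """
    scores = np.asarray(predict_fn(np.asarray(images)))
    predicted = scores.argmax(axis=1)
    return float((predicted == np.asarray(true_indices)).mean())
```

For example, you could preprocess the same image paths twice, once as-is and once after `rotate_image` and `add_noise`, and pass both batches with the same labels. The drop then shows up as a number instead of by eye.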

---

### Step 7: The Solution! Data Augmentation & Retraining

To fix this, we need to teach our model what tilted and noisy images look like. We'll do this by creating a new, bigger dataset. For every image you gave it, we'll keep the original and create *five* extra versions, each with a 50% chance of being rotated and a 50% chance of getting noise. This is called **Data Augmentation**.

Then, we'll retrain our model for a short time on this new, tougher dataset. This is like a "training montage" in a movie!

This part might take a minute or two to run.
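Data augmentation multiplies the dataset, and therefore the memory it needs. With five augmented copies per image, as in the cell below, every sample becomes six training images. Each one is a 224×224×3 array stored as float32 (4 bytes per value) after normalization. A quick back-of-the-envelope sketch (`augmented_dataset_size` is an illustrative helper):

```python
def augmented_dataset_size(num_images, copies_per_image=5, side=224, channels=3,
                           bytes_per_value=4):
    """Return (number of training images, approximate float32 memory in MiB)."""
    total_images = num_images * (1 + copies_per_image)  # original + augmented copies
    total_bytes = total_images * side * side * channels * bytes_per_value
    return total_images, total_bytes / (1024 ** 2)

count, mib = augmented_dataset_size(100)
print(count, round(mib, 1))  # 600 344.5
```

The uint8 array and the intermediate copies made during normalization add to this, so peak usage is higher. If the estimate approaches the RAM in your Colab session, lowering the number of copies per image is the simplest lever.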

In [None]:
#
# STEP 7: DATA AUGMENTATION AND RETRAINING
#

print("--- Starting Data Augmentation ---")
print("Creating new, tougher training images. This might take a moment...")

X_train = [] # To hold the image data
y_train = [] # To hold the labels

import re  # used to clean up folder names

# Map each label name to its index in labels.txt
# This is important to ensure the label index matches the model's outputs
class_name_to_index = {name: i for i, name in enumerate(class_labels)}

# Loop through all your original samples
for class_folder_name in class_folders:
    # Folder names can carry extras such as "-samples" or "(1)" from downloads and
    # re-uploads, e.g. "Class 1-samples (1)"; strip them before looking up the label
    cleaned_name = re.sub(r'-samples|\(\d+\)', '', class_folder_name).strip()
    class_index = class_name_to_index.get(cleaned_name)
    if class_index is None:
        # Fallback for simple folder names
        class_index = class_name_to_index.get(class_folder_name)
        if class_index is None:
            print(f"Warning: Folder '{class_folder_name}' does not match any label in labels.txt. Skipping.")
            continue

    print(f"Augmenting images for class: {class_folder_name} (index: {class_index})")
    class_dir = os.path.join(samples_dir, class_folder_name)
    
    image_files = [f for f in os.listdir(class_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    for image_name in image_files:
        image_path = os.path.join(class_dir, image_name)
        try:
            original_pil_image = Image.open(image_path).convert('RGB')
        except Exception as e:
            print(f"Could not open {image_path}, skipping. Error: {e}")
            continue

        # Add the original image first
        img_array = np.array(original_pil_image.resize((224, 224)))
        X_train.append(img_array)
        y_train.append(class_index)
        
        # Now create 5 augmented versions of this image
        for _ in range(5):
            augmented_img = original_pil_image
            # Apply transformations randomly
            if random.random() > 0.5:
                augmented_img = rotate_image(augmented_img)
            if random.random() > 0.5:
                augmented_img = add_noise(augmented_img)
            
            img_array = np.array(augmented_img.resize((224, 224)))
            X_train.append(img_array)
            y_train.append(class_index)

# Convert Python lists to NumPy arrays
X_train = np.array(X_train)
y_train = np.array(y_train)

# Normalize the image data (just like before)
X_train = (X_train.astype(np.float32) / 127.5) - 1

# Convert labels to a "one-hot" format (e.g., class 1 of 3 becomes [0, 1, 0])
num_classes = len(class_labels)
y_train_one_hot = tf.keras.utils.to_categorical(y_train, num_classes=num_classes)

print(f"\nOriginal dataset had ~{len(X_train)//6} images.")
print(f"New augmented dataset has {len(X_train)} images!")
print("\n--- Now Retraining the Model ---")

# We need to "compile" the model to get it ready for training.
# We'll use a standard "Adam" optimizer and "Categorical Crossentropy" for the loss function.
original_model.compile(optimizer='adam',
                       loss='categorical_crossentropy',
                       metrics=['accuracy'])

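# Shuffle BEFORE fitting: validation_split always takes the LAST 10% of the arrays
# (before Keras' own shuffling), and our arrays are ordered class by class, so
# without this the validation set would come almost entirely from the last class.
shuffle_order = np.random.permutation(len(X_train))
X_train = X_train[shuffle_order]
y_train_one_hot = y_train_one_hot[shuffle_order]

# Let's train! An "epoch" is one full pass through the dataset.
# We'll do a few epochs.
history = original_model.fit(X_train, y_train_one_hot, epochs=5, batch_size=32, shuffle=True, validation_split=0.1)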

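# Give the retrained model a clearer name. Note: this is the same object, not a copy,
# so original_model now holds the retrained weights too.
robust_model = original_model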
print("\n✅ Training complete! Your model is now more robust!")

### Visualizing the Training Process

How do we know the model actually learned? We can plot its **accuracy** and **loss** from the training process. 
- **Accuracy** should go **UP** (we want it to be more correct).
- **Loss** (which is like "error") should go **DOWN** (we want it to make fewer mistakes).

The graphs below should show our model improving with each pass (epoch) through our new, tougher, augmented dataset.
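The validation curves are only informative if the validation images are genuinely new to the model. In Step 7, each sample contributes six near-identical versions. Any split that ignores where images came from can put copies of the same photo on both sides, which makes validation accuracy look better than it would be on fresh images. Below is a sketch of splitting by original image instead, with `split_by_source` as a hypothetical helper:

```python
import numpy as np

def split_by_source(source_ids, val_fraction=0.1, seed=0):
    """Split sample indices into train/validation sets by original image.

    All augmented copies of one original share a source id, so they land on
    the same side of the split and cannot leak into validation.
    """
    rng = np.random.default_rng(seed)
    source_ids = np.asarray(source_ids)
    unique_ids = rng.permutation(np.unique(source_ids))  # shuffle the originals
    num_val = max(1, int(round(len(unique_ids) * val_fraction)))
    is_val = np.isin(source_ids, unique_ids[:num_val])
    return np.flatnonzero(~is_val), np.flatnonzero(is_val)
```

To use it, the Step 7 loop would need to record a source id (for example, a counter that increases once per original image) next to each `y_train` entry. Training could then use `validation_data=(X_train[val_idx], y_train_one_hot[val_idx])` in place of `validation_split`.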

In [None]:
# Plotting the training history
plt.figure(figsize=(12, 4))

# Plot Training & Validation Accuracy
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

# Plot Training & Validation Loss
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.show()

### Step 8: The Final Showdown! Live Webcam Test

This is the moment of truth. Let's test our new **robust model** with a live webcam feed and our new, beautiful UI.

**Action:**
1.  Run the cell below.
2.  A video feed of your webcam will appear. You must **Allow** your browser to use the camera.
3.  Try moving your head, tilting it, and making different faces. The retrained model should cope better than the original with imperfect images like these!
4.  To stop the feed, **click on the video frame itself**.
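Two details matter when feeding webcam frames to the model. First, OpenCV's `cv2.imdecode` returns pixels in BGR order, while the model was trained on RGB images. Second, the captured frame is 640×480, so resizing it straight to 224×224 distorts its proportions. Below is a NumPy sketch that handles both (`frame_to_rgb_square` is a hypothetical helper). Its output could then go through `cv2.resize(..., (224, 224))` and the usual normalization:

```python
import numpy as np

def frame_to_rgb_square(frame_bgr):
    """Convert an OpenCV-style BGR frame to RGB and center-crop it to a square."""
    rgb = frame_bgr[:, :, ::-1]  # reverse the channel axis: BGR -> RGB
    height, width = rgb.shape[:2]
    side = min(height, width)
    top = (height - side) // 2
    left = (width - side) // 2
    # Make a contiguous copy: OpenCV functions may reject arrays with negative strides
    return np.ascontiguousarray(rgb[top:top + side, left:left + side])
```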

In [None]:
#
# STEP 8: LIVE WEBCAM TEST (ENHANCED UI)
#

# --- Helper functions to handle the webcam stream and UI ---

def video_stream():
  js = Javascript('''
    var video;
    var div = null;
    var stream;
    var captureCanvas;
    var predictionContainer;

    var pendingResolve = null;
    var shutdown = false;

    function removeDom() {
      stream.getTracks().forEach(function(track) {
        track.stop();
      });
      video.remove();
      div.remove();
      video = null;
      div = null;
      stream = null;
      captureCanvas = null;
      predictionContainer = null;
    }

    function onAnimationFrame() {
      if (!shutdown) {
        window.requestAnimationFrame(onAnimationFrame);
      }
      if (pendingResolve) {
        var result = "";
        if (!shutdown) {
          captureCanvas.getContext('2d').drawImage(video, 0, 0, 640, 480);
          result = captureCanvas.toDataURL('image/jpeg', 0.8);
        }
        var lp = pendingResolve;
        pendingResolve = null;
        lp(result);
      }
    }

    async function createDom() {
      if (div !== null) {
        return;
      }
      div = document.createElement('div');
      div.style.border = '2px solid black';
      div.style.padding = '3px';
      div.style.width = '100%';
      div.style.maxWidth = '600px';
      document.body.appendChild(div);

      video = document.createElement('video');
      video.style.display = 'block';
      video.width = div.clientWidth - 6;
      video.setAttribute('playsinline', '');
      video.onclick = () => { shutdown = true; };
      stream = await navigator.mediaDevices.getUserMedia(
          {video: { facingMode: "user" }});
      div.appendChild(video);

      const instruction = document.createElement('div');
      instruction.innerHTML = "<b>Click on the video frame to stop</b>";
      instruction.style.textAlign = 'center';
      instruction.style.padding = '5px';
      div.appendChild(instruction);

      // Create the container for prediction bars
      predictionContainer = document.createElement('div');
      div.appendChild(predictionContainer);
      
      // Add some CSS for the progress bars
      const style = document.createElement('style');
      style.innerHTML = `
        .prediction-bar-container {
          display: flex;
          align-items: center;
          margin-bottom: 5px;
        }
        .prediction-label {
          width: 150px; 
          text-align: right;
          margin-right: 10px;
          font-family: sans-serif;
          font-size: 14px;
        }
        .progress-bar-bg {
          flex-grow: 1;
          height: 20px;
          background-color: #e0e0e0;
          border-radius: 5px;
          overflow: hidden;
        }
        .progress-bar {
          height: 100%;
          background-color: #757575; /* Gray for non-winners */
          transition: width 0.2s ease-in-out;
          border-radius: 5px;
        }
        .progress-bar.winner {
          background-color: #4CAF50; /* Green for the winner */
        }
      `;
      document.head.appendChild(style);

      video.srcObject = stream;
      await video.play();

      captureCanvas = document.createElement('canvas');
      captureCanvas.width = 640;
      captureCanvas.height = 480;
      window.requestAnimationFrame(onAnimationFrame);
    }

    async function updatePrediction(predictionHtml) {
      if (shutdown) {
        removeDom();
        shutdown = false;
        return '';
      }
      await createDom();
      predictionContainer.innerHTML = predictionHtml;
      var result = await new Promise(function(resolve, reject) {
        pendingResolve = resolve;
      });
      shutdown = false;
      return result;
    }
    ''')

  display(js)

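import json  # used to safely quote the HTML for JavaScript

def get_frame(prediction_html):
  # Send the latest prediction HTML to the browser and wait for the next webcam frame.
  # json.dumps turns the HTML (which contains quotes and newlines) into a valid JS string literal.
  data = eval_js('updatePrediction({})'.format(json.dumps(prediction_html)))
  return data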

def js_to_image(js_reply):
  """Decodes a base64-encoded image from JavaScript into an OpenCV image."""
  image_bytes = b64decode(js_reply.split(',')[1])
  jpg_as_np = np.frombuffer(image_bytes, dtype=np.uint8)
  img = cv2.imdecode(jpg_as_np, flags=1)
  return img

def create_html_bars(prediction, labels):
    """Creates an HTML string with progress bars for the predictions."""
    html = ""
    top_prediction_index = np.argmax(prediction[0])
    
    for i, (label, confidence) in enumerate(zip(labels, prediction[0])):
        confidence_percent = confidence * 100
        winner_class = 'winner' if i == top_prediction_index else ''
        html += f'''
        <div class="prediction-bar-container">
          <div class="prediction-label">{label}</div>
          <div class="progress-bar-bg">
            <div class="progress-bar {winner_class}" style="width: {confidence_percent}%;"></div>
          </div>
          <div style="margin-left: 10px; font-family: sans-serif; font-size: 14px;">{confidence_percent:.1f}%</div>
        </div>
        '''
    return html

# --- THE MAIN LOOP ---
print("Starting live webcam feed... Allow camera access when prompted.")
video_stream()

prediction_html = "Capturing..."
while True:
    js_reply = get_frame(prediction_html)
    if not js_reply:
        break

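    # Convert the frame to an OpenCV image (OpenCV decodes JPEGs in BGR channel order)
    frame = js_to_image(js_reply)
    # The model was trained on RGB images, so swap the channels to RGB
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Preprocess the frame for our model
    resized_frame = cv2.resize(frame, (224, 224))
    normalized_frame = (resized_frame.astype(np.float32) / 127.5) - 1
    input_frame = np.expand_dims(normalized_frame, axis=0)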

    # Make a prediction with our NEW ROBUST model
    prediction = robust_model.predict(input_frame, verbose=0) 

    # Create the HTML for the prediction bars
    prediction_html = create_html_bars(prediction, class_labels)

print("Webcam feed stopped.")
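`eval_js` runs its argument as JavaScript source, so the HTML passed to `updatePrediction` has to be embedded as a valid JavaScript string literal. The prediction HTML contains double quotes and newlines, and either one would end or break a hand-quoted `"..."` string. A JSON string is also a valid JavaScript string literal, so `json.dumps` does the quoting safely. Below is a sketch; `js_call` is an illustrative helper name:

```python
import json

def js_call(function_name, text):
    """Build a JavaScript call with `text` as a safely quoted string argument."""
    return "{}({})".format(function_name, json.dumps(text))

html = '<div class="bar">\n  50%\n</div>'
expr = js_call("updatePrediction", html)
print(expr)  # updatePrediction("<div class=\"bar\">\n  50%\n</div>")
```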

## 🎉 Congratulations! 🎉

You have successfully:
- Loaded a pre-trained model and fixed compatibility issues.
- Handled a flexible file structure for your data.
- Identified the model's weaknesses using real-world variations (noise and rotation).
- Fixed those weaknesses by **augmenting your data**.
- **Retrained** your model to make it more robust.
- Visualized the model's training improvement with graphs.
- Tested your new and improved model in real-time with a cool UI!

This is a fundamental workflow in Computer Vision and Machine Learning. Well done!