In [None]:
# Google Colab Setup
try:
    import google.colab
    import os
    import sys
    
    print("Running in Google Colab")
    
    # Define base path
    base_path = '/content'
    
    # Create project structure
    src_path = os.path.join(base_path, 'src')
    data_root = os.path.join(src_path, 'data')
    notebooks_path = os.path.join(base_path, 'notebooks')
    
    os.makedirs(data_root, exist_ok=True)
    os.makedirs(notebooks_path, exist_ok=True)
    
    # Install requirements
    if os.path.exists(os.path.join(base_path, 'requirements.txt')):
        print("Installing requirements...")
        !pip install -q -r requirements.txt
        
    # Download dataset
    print("Downloading Flowers102 dataset...")
    from torchvision.datasets import Flowers102
    try:
        Flowers102(root=data_root, split='train', download=True)
        Flowers102(root=data_root, split='val', download=True)
        Flowers102(root=data_root, split='test', download=True)
        print("Dataset downloaded.")
    except Exception as e:
        print(f"Dataset download failed: {e}")
    
    # Change directory to notebooks
    if os.path.exists(notebooks_path):
        os.chdir(notebooks_path)
        print(f"Changed directory to {notebooks_path}")
        
    # Add project root to sys.path so imports work
    project_root = base_path
    if project_root not in sys.path:
        sys.path.append(project_root)
        print(f"Added {project_root} to sys.path")

except ImportError:
    # Local Machine Setup
    import os
    import sys
    
    print("Running on Local Machine")
    
    # Get the path to the project root (assuming running from notebooks/)
    current_dir = os.getcwd()
    if current_dir.endswith('notebooks'):
        project_root = os.path.abspath('..')
    else:
        project_root = os.path.abspath('.')
    
    if os.path.exists(os.path.join(project_root, 'src')):
         if project_root not in sys.path:
            sys.path.append(project_root)
            print(f"Added project root to sys.path: {project_root}")


# Part 6: Gradio Demo

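In this notebook, we build a simple interactive web interface with Gradio to demonstrate our flower classification model. The demo loads trained weights from the `best_model.pt` checkpoint in the working directory, so make sure that file exists before running the cells below.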

In [None]:
import sys
import os
import torch
import gradio as gr
from PIL import Image
from torchvision import transforms

# Make the project root importable so `src.*` resolves. Guard the append so that
# re-running this cell does not keep adding duplicate entries to sys.path.
_project_root = os.path.abspath('../')
if _project_root not in sys.path:
    sys.path.append(_project_root)

from src.models.base_model import get_model
from src.utils.seeds import set_seeds

# Fixed seed for reproducible behavior across runs.
SEED = 42
set_seeds(SEED)

## 1. Load Model

In [None]:
# Run on GPU when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CLASSES = 102  # Flowers102 category count
model = get_model(num_classes=NUM_CLASSES, fine_tune=False)
model_path = 'best_model.pt' # Or best_model_finetuned.pt if available

if os.path.exists(model_path):
    # weights_only=True restricts unpickling to tensors and plain containers, so a
    # tampered checkpoint file cannot execute arbitrary code when loaded.
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    print(f"Loaded {model_path}")
else:
    # Show the absolute path searched: the setup cell may have changed the cwd.
    print(f"Warning: Model checkpoint not found at {os.path.abspath(model_path)}.")

model = model.to(DEVICE)
# Inference mode (disables dropout, uses running batch-norm stats); the trailing
# semicolon suppresses the very long model repr in the cell output.
model.eval();

## 2. Define Transforms

In [None]:
# Standard ImageNet per-channel statistics. Presumably these match the normalization
# used when the backbone was trained — confirm against the training transforms.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Deterministic eval-time preprocessing: resize the shorter side to 256, take the
# central 224x224 crop, convert to a float tensor, then normalize each channel.
_eval_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
]
data_transforms = transforms.Compose(_eval_steps)

## 3. Predict Function

In [None]:
def predict_flower(image, top_k=3):
    """Classify a flower image and return the top-k class scores for `gr.Label`.

    Args:
        image: PIL image from the Gradio input, or None when the input is cleared.
        top_k: Number of highest-probability classes to return (default 3).
            Clamped to the number of classes the model outputs.

    Returns:
        None if `image` is None; otherwise a dict mapping "Class <id>" labels to
        softmax probabilities (Python floats), ordered from highest to lowest.
    """
    if image is None:
        return None

    # Uploaded PNGs can be RGBA and some images are grayscale ("L"); the Normalize
    # step is configured for exactly 3 channels, so coerce to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess into a batch of one on the model's device.
    img_t = data_transforms(image).unsqueeze(0).to(DEVICE)

    # Predict without tracking gradients; softmax turns logits into probabilities.
    with torch.no_grad():
        outputs = model(img_t)
        probs = torch.nn.functional.softmax(outputs, dim=1)[0]

    # Never ask topk for more entries than there are classes.
    k = min(top_k, probs.numel())
    top_prob, top_idx = torch.topk(probs, k)

    # Class names aren't loaded here, so labels use the numeric class index.
    # In a real app, you'd load a JSON mapping ID -> Name.
    return {
        f"Class {class_id}": float(score)
        for score, class_id in zip(top_prob.tolist(), top_idx.tolist())
    }

## 4. Launch Interface

In [None]:
# Build the UI pieces first: a PIL image upload feeding a label widget that shows
# the three highest-scoring classes returned by predict_flower.
# NOTE(review): the title says ResNet50 — confirm get_model builds that architecture.
demo_input = gr.Image(type="pil")
demo_output = gr.Label(num_top_classes=3)

iface = gr.Interface(
    fn=predict_flower,
    inputs=demo_input,
    outputs=demo_output,
    title="Flower Classifier ResNet50",
    description="Upload an image of a flower to classify it into one of 102 categories.",
)

# share=False: do not request a public Gradio share link.
iface.launch(share=False)