In [1]:
import torch # import libraries
from torchvision import transforms
from PIL import Image
import requests
from io import BytesIO

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import torchvision.models as models

# Define the model architecture
num_classes = 102  # Flower102 has 102 classes
model = models.resnet50(pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)  # Replace last layer

# Load the saved state dict
model_path = "/content/drive/MyDrive/flowers102_resnet50_best.pt"
state_dict = torch.load(model_path)
model.load_state_dict(state_dict)

# Set model to evaluation mode
model.eval()



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [7]:
# preprocess before feeding into model
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # Match the input size used in training
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

In [8]:
class_labels = {i: f"Class_{i}" for i in range(102)}

In [9]:
def predict_flower(image_input):
    """
    Predict flower class for a given image path or URL.

    Returns:
        dict: {
            'predicted_class': str,
            'confidence': float,
            'top3_predictions': list of tuples [(class_name, confidence), ...]
        }
    """
    # Load image
    if image_input.startswith("http"):
        response = requests.get(image_input)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_input).convert("RGB")

    # Preprocess
    input_tensor = preprocess(image).unsqueeze(0)  # Add batch dimension

    # Model inference
    with torch.no_grad():
        outputs = model(input_tensor)
        probs = torch.softmax(outputs, dim=1)

    # Predicted class
    confidence, pred_idx = torch.max(probs, dim=1)
    pred_class = class_labels[int(pred_idx)]

    # Top-3 predictions
    top3_conf, top3_idx = torch.topk(probs, 3)
    top3 = [(class_labels[int(idx)], float(conf)) for idx, conf in zip(top3_idx[0], top3_conf[0])]

    return {
        "predicted_class": pred_class,
        "confidence": float(confidence),
        "top3_predictions": top3
    }


In [16]:
result = predict_flower("/content/drive/MyDrive/flowers102/flowers102/flowers-102/jpg/image_00001.jpg")  # or a URL
print(result)

{'predicted_class': 'Class_76', 'confidence': 0.9888086915016174, 'top3_predictions': [('Class_76', 0.9888086915016174), ('Class_34', 0.001065384945832193), ('Class_65', 0.0010294703533872962)]}


Discussion: production deployment

Model Serving: Wrap the inference function in a FastAPI or Flask REST API. Use TorchServe for production-grade serving if high throughput is required.

Performance & Scalability: Batch inference for multiple images.Use GPU acceleration for low latency.Convert the model to TorchScript or ONNX for faster inference.

Monitoring & Maintenance: Log predictions for quality control.Periodically retrain the model with new data.