In [None]:
# Mount Google Drive so the saved model weights under /content/drive/MyDrive can be read.
# force_remount=True remounts even if Drive is already mounted (e.g. when re-running this cell).
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F

# Input geometry shared by the model definition:
#   DIM - side length of the square input image (fc1 is sized from DIM // 8)
#   CH  - number of input channels fed to the first convolution
DIM, CH = 32, 3

class Net(nn.Module):
    """VGG-style CNN mapping a batch of CH x DIM x DIM images to class logits.

    Three convolutional stages, each made of two 3x3 same-padding convolutions
    (each followed by ReLU and then BatchNorm), a 2x2 max-pool that halves the
    spatial size, and channel-wise dropout. After three halvings the
    128-channel feature map is DIM // 8 on a side; it is flattened into a
    128-unit hidden layer (ReLU, BatchNorm, dropout) and projected to
    ``num_classes`` logits.

    Submodule attribute names double as ``state_dict`` keys, and their
    registration order fixes parameter order, so neither may change without
    breaking previously saved weights.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=3):
        super().__init__()

        # Stage 1: CH -> 32 channels, spatial size DIM -> DIM // 2.
        self.conv1 = nn.Conv2d(CH, 32, kernel_size=3, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(32)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.dropout1 = nn.Dropout2d(p=0.2)

        # Stage 2: 32 -> 64 channels, spatial size halves again.
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.batchnorm3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.batchnorm4 = nn.BatchNorm2d(64)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.dropout2 = nn.Dropout2d(p=0.3)

        # Stage 3: 64 -> 128 channels, spatial size ends at DIM // 8.
        self.conv5 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.batchnorm5 = nn.BatchNorm2d(128)
        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.batchnorm6 = nn.BatchNorm2d(128)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        self.dropout3 = nn.Dropout2d(p=0.4)

        # Classifier head on the flattened 128 x (DIM // 8) x (DIM // 8) map.
        flat_features = 128 * (DIM // 8) ** 2
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(flat_features, 128)
        self.batchnorm_fc = nn.BatchNorm1d(128)
        self.dropout_fc = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(128, num_classes)

    @staticmethod
    def _conv_stage(x, conv_a, bn_a, conv_b, bn_b, pool, dropout):
        """Run one stage: (conv -> ReLU -> BatchNorm) twice, then max-pool and dropout."""
        x = bn_a(torch.relu(conv_a(x)))
        x = bn_b(torch.relu(conv_b(x)))
        return dropout(pool(x))

    def forward(self, x):
        """Return unnormalised logits of shape (batch, num_classes)."""
        x = self._conv_stage(x, self.conv1, self.batchnorm1, self.conv2,
                             self.batchnorm2, self.maxpool1, self.dropout1)
        x = self._conv_stage(x, self.conv3, self.batchnorm3, self.conv4,
                             self.batchnorm4, self.maxpool2, self.dropout2)
        x = self._conv_stage(x, self.conv5, self.batchnorm5, self.conv6,
                             self.batchnorm6, self.maxpool3, self.dropout3)

        hidden = self.batchnorm_fc(torch.relu(self.fc1(self.flatten(x))))
        return self.fc2(self.dropout_fc(hidden))

# Instantiate the network with freshly initialised weights and 3 output classes.
# Note: the following cell rebinds `model` to a new Net() loaded from saved weights.
model = Net(num_classes=3)

In [None]:
# Load the trained weights from Drive into a fresh Net on the CPU.
# The checkpoint is a plain state_dict, so weights_only=True is sufficient and
# stops a tampered file from executing arbitrary code during unpickling
# (requires torch >= 1.13).
MODEL_PATH = '/content/drive/MyDrive/modelwithtorch5.pth'
model = Net()
model_weights = torch.load(MODEL_PATH, map_location='cpu', weights_only=True)
model.load_state_dict(model_weights)

<All keys matched successfully>

In [None]:
# Inference-time preprocessing, keyed by split name ('test' is used by the API).
# Resize to the DIM x DIM input Net was built for (fc1 is sized from DIM), convert
# to a tensor, then normalise each of the CH channels with mean = std = 0.5.
# NOTE(review): presumably mirrors the training-time test transform; confirm.
data_transforms = {
    'test': transforms.Compose([
        transforms.Resize((DIM, DIM)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * CH, [0.5] * CH),
    ])
}

In [None]:
!pip install flask flask-restful flask-swagger flask-swagger-ui flask-cors pyngrok

In [None]:
# Register your ngrok authtoken; the public tunnel for the Flask app needs it in Colab.
# The token is read from the NGROK_AUTHTOKEN environment variable, or prompted for
# without echo, so it is never hard-coded into the notebook or its saved output.
import os
from getpass import getpass

from pyngrok import ngrok

ngrok.set_auth_token(os.environ.get("NGROK_AUTHTOKEN") or getpass("ngrok authtoken: "))

In [None]:
# REST api
from flask import Flask, request, jsonify
from flask_restful import Api, Resource, reqparse
from flask_swagger_ui import get_swaggerui_blueprint
from flask_swagger import swagger
from pyngrok import ngrok
import io
from PIL import Image
import base64
from flask_cors import CORS
from werkzeug.datastructures import FileStorage
import numpy as np
import torch.nn.functional as F

# Label for each model output index.
# NOTE(review): presumably the same order as the training labels; confirm.
class_names = ['car', 'cat', 'pizza']

# Flask app wrapped by flask-restful. CORS is opened to every origin so any
# browser page can call the API through the public tunnel.
app = Flask(__name__)
api = Api(app)
CORS(app, resources={r"*": {"origins": "*"}})

# Serve from the CPU in inference mode (BatchNorm running stats, dropout disabled).
model = model.to('cpu').eval()

def preprocess_image(image_file):
    """Decode an uploaded image into a model-ready batch of one.

    Args:
        image_file: Anything ``PIL.Image.open`` accepts, e.g. a path or a binary
            file-like object such as the uploaded ``FileStorage``.

    Returns:
        The output of ``data_transforms['test']`` with a leading batch
        dimension of size 1 added.

    Raises:
        PIL.UnidentifiedImageError: if the data is not a recognised image
            format (an ``OSError`` subclass).
    """
    with Image.open(image_file) as image:
        # Uploads may be RGBA (PNG), grayscale or palette images. The test
        # transform normalises exactly 3 channels, so force RGB first;
        # convert() also fully loads the pixels before the file is released.
        rgb_image = image.convert('RGB')
    return data_transforms['test'](rgb_image).unsqueeze(0)

# Define a class for prediction
class Prediction(Resource):
    """flask-restful resource behind ``POST /predict``: classify one uploaded image."""

    def post(self):
        """Classify the image sent in the multipart form field ``file``.

        Returns:
            ``{"class_name": ..., "confidence": "NN.NN%"}`` (status 200) on
            success; ``({"error": ...}, 400)`` when no file or an undecodable
            image is sent; ``({"error": ...}, 500)`` when inference fails.

        Plain dicts are returned instead of ``jsonify(...)`` responses because
        flask-restful serialises the body itself. A ``(Response, status)`` tuple
        is handed to its JSON encoder, which cannot serialise a Response object,
        so the intended 400/500 replies would surface as a generic 500 instead.
        """
        # FileStorage is falsy when no file was chosen (empty filename), and
        # .get() returns None when the field is absent; reject both.
        file = request.files.get("file")
        if not file:
            return {"error": "No image provided"}, 400

        try:
            processed_image = preprocess_image(file).to('cpu')

            with torch.no_grad():
                logits = model(processed_image)
                probabilities = F.softmax(logits, dim=1)
                confidence, predicted_idx = torch.max(probabilities, dim=1)

            response = {
                "class_name": class_names[predicted_idx.item()],
                "confidence": f"{confidence.item() * 100:.2f}%",
            }
        except OSError as e:
            # Pillow raises UnidentifiedImageError (an OSError subclass) for
            # unrecognised data and OSError for truncated files: a client error.
            print("Error decoding image:", e)
            return {"error": "Invalid image file"}, 400
        except Exception as e:
            # Log details server-side only; the endpoint is publicly reachable
            # through the ngrok tunnel, so internal error text is not echoed back.
            print("Error processing request:", e)
            return {"error": "Internal server error"}, 500

        return response


# Add resource to API: route /predict to the Prediction resource
# (flask-restful dispatches each HTTP method to the same-named method, here post()).
api.add_resource(Prediction, '/predict')

# Hand-written Swagger 2.0 spec for the API, served at /spec for Swagger UI.
@app.route("/spec")
def spec():
    """Return the Swagger 2.0 description of ``POST /predict`` (Flask serialises the dict to JSON)."""
    swag = {
        'swagger': '2.0',
        'info': {
            'title': 'Image Prediction API',
            'version': '1.0'
        },
        'paths': {
            '/predict': {
                'post': {
                    'summary': 'Predict class of an image',
                    'description': 'Upload an image file to predict its class.',
                    'consumes': [
                        'multipart/form-data'
                    ],
                    'produces': [
                        'application/json'
                    ],
                    'parameters': [
                        {
                            'name': 'file',
                            'in': 'formData',
                            'type': 'file',
                            'required': True,
                            'description': 'Image file to predict'
                        }
                    ],
                    'responses': {
                        '200': {
                            'description': 'Prediction successful',
                            'schema': {
                                'type': 'object',
                                'properties': {
                                    'class_name': {
                                        'type': 'string',
                                        'description': 'Predicted class name'
                                    },
                                    # A plain percentage string such as "87.50%";
                                    # not 'format: byte', which means base64 data.
                                    'confidence': {
                                        'type': 'string',
                                        'description': 'Softmax probability of the predicted class as a percentage string, e.g. "87.50%"'
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    return swag

# Swagger UI: serve the interactive docs at SWAGGER_URL, reading the spec
# produced by the /spec route.
SWAGGER_URL = '/api/docs'  # must not end with '/'
API_URL = '/spec'

swagger_ui_blueprint = get_swaggerui_blueprint(
    SWAGGER_URL, API_URL, config={'app_name': "Image Prediction API"}
)
app.register_blueprint(swagger_ui_blueprint, url_prefix=SWAGGER_URL)

# Expose the local Flask server through a public ngrok tunnel, then serve
# requests until the cell is interrupted. The tunnel is closed afterwards so
# re-running this cell does not leave stale tunnels open.
PORT = 5000

ngrok_tunnel = ngrok.connect(PORT)
print("Public URL:", ngrok_tunnel.public_url)

try:
    app.run(port=PORT)  # blocks this cell until the server is stopped
finally:
    ngrok.disconnect(ngrok_tunnel.public_url)