In [9]:
!pip install flask pillow torch torchvision



In [None]:
from flask import Flask, request, jsonify
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.models as models

app = Flask(__name__)
model = models.resnet50(pretrained=True)
model.eval()

def preprocess_image(image):
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = preprocess(image)
    return image.unsqueeze(0)

@app.route('/classify', methods=['GET','POST'])
def classify_dish():
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'})

    file = request.files['file']

    if file.filename == '':
        return jsonify({'error': 'No selected file'})

    try:
        image = Image.open(file)
        image_tensor = preprocess_image(image)

        with torch.no_grad():
            output = model(image_tensor)
            probabilities = torch.nn.functional.softmax(output[0], dim=0)
            predicted_class = torch.argmax(probabilities).item()

        return jsonify({'predicted_class': predicted_class, 'probabilities': probabilities.tolist()})
    except Exception as e:
        return jsonify({'error': str(e)})

if __name__ == '__main__':
    app.run()


 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
127.0.0.1 - - [03/Jun/2024 15:02:33] "GET /classify HTTP/1.1" 200 -
127.0.0.1 - - [03/Jun/2024 15:03:36] "GET /classify HTTP/1.1" 200 -
127.0.0.1 - - [03/Jun/2024 15:04:09] "POST /classify HTTP/1.1" 200 -
127.0.0.1 - - [03/Jun/2024 15:04:14] "GET /classify HTTP/1.1" 200 -
