# ResNet CNN Deployment with a RESTful API

# In [None]:
import os
import uuid

import torch
import torch.nn as nn
from flask import Flask, request, jsonify
from PIL import Image
from torchvision import models, transforms

app = Flask(__name__)

# Define the model architecture: ResNet-18 with its final fully connected
# layer replaced by a `num_classes`-way classifier head.
num_classes = 3  # Update based on your dataset
# Build the bare architecture. No ImageNet weights are downloaded because every
# parameter is overwritten by load_state_dict() below; `pretrained=True` only
# cost a network download at startup and is deprecated in torchvision >= 0.13.
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Location of the fine-tuned weights. Override with the MODEL_PATH environment
# variable so the service is not tied to one machine's directory layout.
model_path = os.environ.get('MODEL_PATH', 'D:\\McGill留學\\AI4Good\\resnetCNNmodel.pth')
# map_location loads every tensor onto the CPU, so a checkpoint saved on a GPU
# also works on CPU-only hosts.
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()  # switch layers such as BatchNorm to inference behaviour

# Preprocessing applied to every request image. It presumably has to match the
# transform used at training time — TODO confirm against the training code.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # fixed size; aspect ratio is not preserved
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1], shape (C, H, W)
    # Standard ImageNet per-channel mean / std used by torchvision models.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Function to preprocess the image and predict
def model_predict(img_path, model, class_names=('Cross', 'Offset', 'T')):
    """Classify a single image file and return its label.

    The image is converted to RGB, passed through the module-level
    ``test_transform``, given a leading batch dimension and fed to ``model``
    with autograd disabled. This function does not call ``model.eval()``
    itself; callers are expected to have done so.

    Args:
        img_path: Path (or file object) accepted by ``PIL.Image.open``.
        model: Classifier whose output is indexed along dim 1 by class —
            presumably logits of shape (batch, num_classes); TODO confirm.
        class_names: Label for each output index, in index order. Update to
            match your dataset's class ordering.

    Returns:
        The label in ``class_names`` of the highest-scoring class.

    Raises:
        ValueError: If the predicted index has no entry in ``class_names``
            (e.g. the classifier head was resized without updating labels).
        OSError: Propagated from ``Image.open`` when the file is missing or
            is not a recognised image format.
    """
    # A context manager guarantees the file handle is released; Pillow keeps
    # it open after load() for multi-frame formats such as GIF/TIFF.
    with Image.open(img_path) as img:
        rgb = img.convert('RGB')  # new, fully loaded image independent of the file
    batch = test_transform(rgb).unsqueeze(0)  # Add batch dimension

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(batch)
        predicted_class = torch.argmax(logits, dim=1).item()

    # Fail loudly rather than silently mapping unknown indices to a label.
    if predicted_class >= len(class_names):
        raise ValueError(
            f'Model predicted class index {predicted_class}, but only '
            f'{len(class_names)} class names are defined'
        )
    return class_names[predicted_class]

@app.route('/predict', methods=['POST'])
def predict():
    """Classify an image uploaded in the multipart form field ``file``.

    The upload is stored under an ``uploads`` directory next to this script
    (or under the current working directory when ``__file__`` is undefined,
    as inside a Jupyter notebook) and then classified with ``model_predict``.

    Returns:
        ``{'prediction': label}`` with status 200 on success, or
        ``{'error': message}`` with status 400 when the ``file`` field is
        missing, has no filename, or cannot be decoded as an image.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400

    file = request.files['file']
    # Treat a missing (None) filename like an empty one so the view can never
    # fall through and return None, which Flask turns into an HTTP 500.
    if not file.filename:
        return jsonify({'error': 'No selected file'}), 400

    try:
        basepath = os.path.dirname(os.path.abspath(__file__))
    except NameError:  # __file__ is not defined inside a Jupyter notebook
        basepath = os.getcwd()
    upload_folder = os.path.join(basepath, 'uploads')
    os.makedirs(upload_folder, exist_ok=True)  # no exists()/makedirs() race

    # Never build a path from the client-supplied filename: it may contain
    # '../' segments (path traversal), and two uploads with the same name
    # would overwrite each other mid-request. Use a random name and keep the
    # extension only if it is purely alphanumeric.
    ext = os.path.splitext(file.filename)[1].lower()
    if not ext[1:].isalnum():
        ext = ''
    file_path = os.path.join(upload_folder, uuid.uuid4().hex + ext)
    file.save(file_path)

    try:
        prediction = model_predict(file_path, model)
    except OSError:  # Pillow's UnidentifiedImageError subclasses OSError
        return jsonify({'error': 'Uploaded file is not a readable image'}), 400
    return jsonify({'prediction': prediction})

@app.route('/')
def index():
    """Landing route: a plain-text greeting pointing clients at /predict."""
    welcome = (
        "Welcome to the Image Classification API. "
        "Use the /predict endpoint to make predictions."
    )
    return welcome

if __name__ == '__main__':
    # Do not hard-code debug=True: the Werkzeug interactive debugger lets
    # anyone who can reach an error page execute arbitrary Python, so it must
    # never be on in a deployment. Without an explicit `debug` argument,
    # app.run() honours the FLASK_DEBUG environment variable instead — set
    # FLASK_DEBUG=1 for local development.
    app.run()