In [1]:
import io
import base64
import json
from io import BytesIO
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
from flask import escape
from flask import Flask, request, jsonify

In [2]:
# Rebuild the ResNet-18 architecture (no pretrained download) and replace the
# final fully connected layer with a 2-way head — one logit per class label —
# before loading the fine-tuned weights from disk.
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 2)

# The checkpoint is only consumed by load_state_dict, so it just needs to be a
# state dict of tensors. weights_only=True restricts unpickling to tensors and
# plain containers; weights_only=False would let a tampered .pth file run
# arbitrary code during load.
state_dict = torch.load(
    'pneumonia_classifier.pth',
    map_location=torch.device('cpu'),
    weights_only=True,
)
model.load_state_dict(state_dict)
model.eval()  # inference mode: BatchNorm layers use their running statistics

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [3]:
# Preprocessing applied to every image before it is fed to the model:
# resize to a fixed 224x224, convert to a float tensor, then normalize each
# channel with the given mean/std (the standard ImageNet statistics).
# NOTE(review): presumably must match the training-time transform — confirm.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [4]:
# Output index i of the model's 2-way head maps to class_names[i].
# NOTE(review): assumed to match the label order used during training — confirm.
class_names = [
    'NORMAL',
    'PNEUMONIA',
]

In [5]:
def validate_input(job_input):
    """Validate a request payload and extract the image string.

    Args:
        job_input: Either a dict, or a JSON string that decodes to an object,
            expected to contain an ``'image'`` key holding a Base64 string
            (only the type is checked here, not the encoding itself).

    Returns:
        A ``(payload, error)`` tuple. On success ``payload`` is
        ``{'image': <str>}`` and ``error`` is ``None``; on failure ``payload``
        is ``None`` and ``error`` is a human-readable message.
    """
    if isinstance(job_input, str):
        try:
            job_input = json.loads(job_input)
        except json.JSONDecodeError:
            return None, 'Invalid JSON format'

    # A missing payload (or the JSON literal ``null``) means no image was sent.
    if job_input is None:
        return None, 'Please provide an image'

    # JSON arrays, numbers, strings, etc. have no .get(); reject them with a
    # clear error instead of letting an AttributeError escape.
    if not isinstance(job_input, dict):
        return None, 'Input must be a JSON object'

    image_data = job_input.get('image')
    # An empty string carries no image data, so treat it like a missing image.
    if image_data is None or image_data == '':
        return None, 'Please provide an image'
    if not isinstance(image_data, str):
        return None, 'Image must be Base64 encoded string'

    return {'image': image_data}, None

In [6]:
from PIL import Image
import torch

def predict_image(image_path):
    """Classify a single image with the loaded model and print the label.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts — a filesystem path
            or a binary file-like object (e.g. ``BytesIO``).

    Returns:
        The predicted class name (an entry of ``class_names``), or ``None`` if
        the image could not be read or inference failed. Errors are printed
        rather than raised so a bad input does not abort the notebook run.
    """
    try:
        # Image.open is lazy and keeps the file handle open; the context
        # manager closes it once convert() has materialized the pixel data.
        with Image.open(image_path) as img:
            rgb_image = img.convert('RGB')
        batch = transform(rgb_image).unsqueeze(0)  # add a leading batch dim of size 1
        with torch.no_grad():
            outputs = model(batch)
        predicted_class = class_names[outputs.argmax(dim=1).item()]
        print(f'Prediction: {predicted_class}')
        return predicted_class
    except IOError as e:
        print(f'Invalid image data: {e}')
    except Exception as e:
        print(f'Unexpected error: {str(e)}')
    return None


# Assign the result so the cell shows only the printed prediction, not an Out[] echo.
prediction = predict_image('test1.jpeg')


Prediction: PNEUMONIA
