In [4]:
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import numpy as np

# Rebuild the network the checkpoint is expected to fit: a torchvision
# ResNet-18 whose final fully connected layer is replaced by a 4-output head.
model = models.resnet18()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 4)

# map_location="cpu" remaps every stored tensor onto the CPU, so a checkpoint
# that was saved from GPU tensors still loads on a machine without CUDA. The
# model is never moved off the CPU in this file, so nothing is lost by this.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load(
    'D:/PROJECT/Part Time/Tuberculosis detention in chest radiograph using convolutional neural network architecture - deep learning/PRG/chest_xray_classification.pth',
    map_location="cpu",
)
model.load_state_dict(state_dict)

# Inference only: put the module (and all its sub-modules) in evaluation mode.
model.eval()

# Input pipeline applied to every uploaded image before it reaches the model:
#   1. resize to 224x224,
#   2. convert the PIL image to a float tensor,
#   3. normalise each channel with the mean/std values torchvision documents
#      for its ImageNet-pretrained models.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def predict_xray(img):
    """Classify a chest X-ray image and return the predicted label.

    Args:
        img: The uploaded image. The Gradio input in this file is declared
            with ``type="pil"``, so this is normally a ``PIL.Image.Image``.
            Array-like input is also accepted: floating-point arrays are
            assumed to hold intensities in ``[0, 1]`` (TODO confirm with any
            non-Gradio callers), integer arrays are assumed to be in
            ``[0, 255]``. ``None`` means no image was supplied.

    Returns:
        ``"Invalid image"`` when ``img`` is ``None``; otherwise the entry of
        ``class_labels`` selected by the highest-scoring model output.
    """
    if img is None:
        return "Invalid image"

    if isinstance(img, Image.Image):
        # Use the PIL image as-is. Converting it to a uint8 array and
        # multiplying by 255 (as was done before) wraps around modulo 256 and
        # feeds the network an inverted, corrupted image.
        img_pil = img
    else:
        pixels = np.asarray(img)
        if np.issubdtype(pixels.dtype, np.floating):
            # Only float data needs rescaling from [0, 1] to [0, 255].
            pixels = pixels * 255
        # Clip before the cast so out-of-range values saturate, not wrap.
        img_pil = Image.fromarray(np.clip(pixels, 0, 255).astype('uint8'))

    # Normalize is configured with three channel statistics, so grayscale or
    # RGBA input must be brought to three channels first.
    img_pil = img_pil.convert("RGB")

    # Preprocess and add the leading batch dimension the model expects.
    img_tensor = preprocess(img_pil).unsqueeze(0)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(img_tensor)

    # Index of the largest output along the class dimension.
    predicted_index = output.argmax(dim=1).item()

    # NOTE(review): three of the four model outputs map to "Tuberculosis" and
    # only index 2 maps to "Normal" -- confirm this matches the label order
    # used when the checkpoint was trained.
    class_labels = ["Tuberculosis", "Tuberculosis", "Normal", "Tuberculosis"]
    return class_labels[predicted_index]

# Gradio UI: one image upload wired to predict_xray, one text box for the
# returned label.
xray_input = gr.Image(type="pil", label="Upload Chest X-ray Image")
prediction_output = gr.Textbox(label="Prediction:")

# live=True re-runs the prediction whenever the input changes, without a
# submit click.
iface = gr.Interface(
    fn=predict_xray,
    inputs=xray_input,
    outputs=prediction_output,
    live=True,
    theme="compact",
)

# share=True additionally requests a public share link for the local app.
iface.launch(share=True)




Running on local URL:  http://127.0.0.1:7863

Could not create share link. Please check your internet connection or our status page: https://status.gradio.app.


