In [1]:
import torch
import torchvision.models as models
import torch.nn as nn
from torchvision import transforms
from PIL import Image

# Define the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the image transformations for grayscale images
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert grayscale to RGB
    transforms.Resize((224, 224)),                # Resize to 224x224, standard input size for ResNet
    transforms.ToTensor(),                        # Convert image to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize
])

# Define function to load and preprocess image
def load_image(image_path):
    image = Image.open(image_path).convert('RGB')
    image = transform(image)
    return image.unsqueeze(0)  # Add batch dimension

# Define and load the model
model = models.resnet50(pretrained=False)
num_classes = 14
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Load and adjust the state dictionary
state_dict = torch.load('resnet50_scratch.pth')

# If the state_dict keys have a 'resnet.' prefix, remove it
new_state_dict = {}
for k, v in state_dict.items():
    if k.startswith('resnet.'):
        new_key = k[len('resnet.'):]
        new_state_dict[new_key] = v
    else:
        new_state_dict[k] = v

model.load_state_dict(new_state_dict)
model.to(device)
model.eval()

# Perform inference on an example image
image_path = './test2.png'
image = load_image(image_path).to(device)

with torch.no_grad():
    output = model(image)
    predictions = torch.sigmoid(output)
    predicted_labels = (predictions > 0.5).float()


  state_dict = torch.load('resnet50_scratch.pth')


In [2]:
# Pathologies
pathologies = [
    "Atelectasis",
    "Cardiomegaly",
    "Consolidation",
    "Edema",
    "Effusion",
    "Emphysema",
    "Fibrosis",
    "Hernia",
    "Infiltration",
    "Pleural_Thickening",
    "Mass",
    "Nodule",
    "Pneumonia",
    "Pneumothorax"
]

# Convert predictions to readable format
predicted_labels = predicted_labels.squeeze().cpu().numpy()
for i, label in enumerate(predicted_labels):
    print(f'{pathologies[i]}: {"Present" if label else "Absent"}')

Atelectasis: Absent
Cardiomegaly: Absent
Consolidation: Absent
Edema: Absent
Effusion: Absent
Emphysema: Absent
Fibrosis: Absent
Hernia: Absent
Infiltration: Absent
Pleural_Thickening: Absent
Mass: Absent
Nodule: Absent
Pneumonia: Absent
Pneumothorax: Absent


In [3]:
import gradio as gr
import torch
from torchvision import models, transforms
from PIL import Image
import torch.nn as nn

# Define the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the image transformations for grayscale images
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert grayscale to RGB
    transforms.Resize((224, 224)),                # Resize to 224x224, standard input size for ResNet
    transforms.ToTensor(),                        # Convert image to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize
])

# Define pathologies
pathologies = [
    "Atelectasis",
    "Cardiomegaly",
    "Consolidation",
    "Edema",
    "Effusion",
    "Emphysema",
    "Fibrosis",
    "Hernia",
    "Infiltration",
    "Pleural_Thickening",
    "Mass",
    "Nodule",
    "Pneumonia",
    "Pneumothorax"
]

# # Load the model
# model = models.resnet50(pretrained=False)
# num_classes = 14
# model.fc = nn.Linear(model.fc.in_features, num_classes)
# model.load_state_dict(torch.load('resnet50_scratch.pth'))
# model.to(device)
# model.eval()

# Load and adjust the state dictionary
state_dict = torch.load('resnet50_scratch.pth')

# If the state_dict keys have a 'resnet.' prefix, remove it
new_state_dict = {}
for k, v in state_dict.items():
    if k.startswith('resnet.'):
        new_key = k[len('resnet.'):]
        new_state_dict[new_key] = v
    else:
        new_state_dict[k] = v

model.load_state_dict(new_state_dict)
model.to(device)
model.eval()

# Define the prediction function
def predict(image):
    image = Image.fromarray(image).convert('RGB')  # Ensure the image is in RGB format
    image = transform(image).unsqueeze(0).to(device)  # Apply transformations and move to GPU
    with torch.no_grad():
        output = model(image)
        predictions = torch.sigmoid(output)
        predicted_labels = (predictions > 0.5).float()
    result = {pathologies[i]: "Present" if predicted_labels[0, i] else "Absent" for i in range(num_classes)}
    return result

# Create the Gradio interface
iface = gr.Interface(fn=predict, inputs=gr.Image(type="numpy"), outputs=gr.JSON())
iface.launch()


  from .autonotebook import tqdm as notebook_tqdm
  state_dict = torch.load('resnet50_scratch.pth')


Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


