In [2]:
%%writefile app/model.py
import torch
import torchvision

def create_vit(seed: int=42, out_features: int=1):
    """Build a frozen ViT-B/16 feature extractor with a freshly initialised head.

    The backbone is created with torchvision's default pretrained weights
    (``ViT_B_16_Weights.DEFAULT``) and every parameter is frozen. The original
    head is then replaced by ``LayerNorm -> Linear``, which is new and therefore
    the only trainable part of the model.

    Args:
        seed: Passed to ``torch.manual_seed`` right before the new head is
            built, so the head's initial weights are reproducible.
        out_features: Width of the final Linear layer. The default of 1 yields a
            single logit (binary classification with a sigmoid). A different
            value changes the head's parameter shapes, so checkpoints saved with
            another value will not load.

    Returns:
        ``(model, transforms)``: the modified VisionTransformer and the
        preprocessing transforms that belong to the pretrained weights.
    """
    weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    transforms = weights.transforms()
    model = torchvision.models.vit_b_16(weights=weights)

    # Freeze everything; the head assigned below is created afterwards and stays trainable.
    for param in model.parameters():
        param.requires_grad = False

    # Read the encoder's embedding width instead of hard-coding 768, so the head
    # always matches the backbone it sits on.
    hidden_dim = model.hidden_dim

    torch.manual_seed(seed)
    # Keep the order (LayerNorm at index 0, Linear at index 1): saved state_dicts
    # address these layers as "heads.0.*" and "heads.1.*".
    model.heads = torch.nn.Sequential(
        torch.nn.LayerNorm(normalized_shape=hidden_dim),
        torch.nn.Linear(in_features=hidden_dim, out_features=out_features),
    )
    return model, transforms

Writing app/model.py


In [20]:
import shutil
from pathlib import Path

# 1. Create the examples directory inside the app folder (no-op if it already exists)
app = Path("app/")
example_path = app / "examples"
example_path.mkdir(parents=True, exist_ok=True)

# 2. Four hand-picked example X-rays: two NORMAL and two PNEUMONIA.
# NOTE(review): these come from the *train* split, so the model has presumably seen
# them during fine-tuning; images from a held-out split would make a more honest demo.
chest_xray_examples = [
    Path("chest_xray/train/NORMAL/IM-0117-0001.jpeg"),
    Path("chest_xray/train/NORMAL/IM-0154-0001.jpeg"),
    Path("chest_xray/train/PNEUMONIA/person3_bacteria_10.jpeg"),
    Path("chest_xray/train/PNEUMONIA/person16_bacteria_54.jpeg"),
]

# 3. Copy each example image into the examples directory (copy2 also preserves file metadata)
for example in chest_xray_examples:
    destination = example_path / example.name
    print(f"[INFO] Copying {example} to {destination}")
    shutil.copy2(src=example, dst=destination)

[INFO] Copying chest_xray\train\NORMAL\IM-0117-0001.jpeg to app\examples\IM-0117-0001.jpeg
[INFO] Copying chest_xray\train\NORMAL\IM-0154-0001.jpeg to app\examples\IM-0154-0001.jpeg
[INFO] Copying chest_xray\train\PNEUMONIA\person3_bacteria_10.jpeg to app\examples\person3_bacteria_10.jpeg
[INFO] Copying chest_xray\train\PNEUMONIA\person16_bacteria_54.jpeg to app\examples\person16_bacteria_54.jpeg


In [18]:
%%writefile app/app.py
import gradio as gr
import os
import torch

from model import create_vit
from timeit import default_timer as timer
from typing import Tuple, Dict

# Output labels. predict() maps a positive prediction (sigmoid > 0.5) to index 1.
# NOTE(review): presumably this mirrors the class-to-index order used in training — confirm.
class_names = ["NORMAL", "PNEUMONIA"]

# Rebuild the exact architecture that was fine-tuned, then restore its trained weights.
vit_model, vit_transforms = create_vit(seed=42)

# Resolved relative to the current working directory, like the examples folder.
checkpoint_file = "finetuned_vit_b_16_pneumonia_feature_extractor.pth"

# map_location forces every tensor onto the CPU, whatever device it was saved from.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (torch >= 1.13 accepts weights_only=True for tensor-only state dicts).
vit_model.load_state_dict(
    torch.load(checkpoint_file, map_location=torch.device("cpu"))
)

def predict(img):
    """Classify a chest X-ray image as NORMAL or PNEUMONIA.

    Args:
        img: The input image, normally a PIL image supplied by
            ``gr.Image(type="pil")``. PIL images of any mode (e.g. grayscale
            "L") are converted to RGB before preprocessing.

    Returns:
        ``(class_name, pred_time)``: the predicted entry of ``class_names`` and
        the wall-clock seconds the prediction took, rounded to 5 decimals.
    """
    start_timer = timer()

    # The pretrained ViT transforms normalise with 3-channel statistics, so a
    # single-channel image would fail. Gradio's Image defaults to RGB, but direct
    # callers may pass grayscale X-rays straight from disk.
    if hasattr(img, "convert"):
        img = img.convert("RGB")

    # Preprocess and add a batch dimension of size 1.
    img = vit_transforms(img).unsqueeze(0)

    vit_model.eval()
    with torch.inference_mode():
        # The head emits one logit; sigmoid maps it to the probability of class index 1.
        positive_prob = torch.sigmoid(vit_model(img)).squeeze().item()

    # Strict "> 0.5" matches the previous round(): torch rounds exactly 0.5 down to 0.
    class_name = class_names[1] if positive_prob > 0.5 else class_names[0]

    pred_time = round(timer() - start_timer, 5)

    return class_name, pred_time

title = "Detect Pneumonia from chest X-Ray"
description = "A ViT feature extractor Computer Vision model to detect Pneumonia from X-Ray Images."
article = "Access project repository at [GitHub](https://github.com/Ammar2k)"

# Example paths are relative to the current working directory, so launch the app
# from its own folder. Sort them because os.listdir order is arbitrary. Skip hidden
# entries (e.g. .DS_Store, .ipynb_checkpoints), which are not example images.
example_list = [
    ["examples/" + example]
    for example in sorted(os.listdir("examples"))
    if not example.startswith(".")
]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        # There are only len(class_names) classes, so asking for more top classes is meaningless.
        gr.Label(num_top_classes=len(class_names), label="Predictions"),
        gr.Number(label="Prediction time(s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

demo.launch()

Overwriting app/app.py
