In [1]:
import gradio as gr
import os
import torch
from torch import nn
from torchvision import models, transforms
from collections import OrderedDict

from PIL import Image

In [2]:
def load_checkpoint(filepath, class_mapping):
    """
    Loads a checkpoint and rebuilds the model.

    Input:
    filepath(str): Relative path to model checkpoint
    """
    if os.path.exists(filepath):
        checkpoint = torch.load(filepath)

        if "resnet50" in checkpoint["arch"]:
            model = models.resnet50(pretrained=True)
            num_ftrs = model.fc.in_features
        
        elif "resnet18" in checkpoint["arch"]:
            model = models.resnet18(pretrained=True)
            num_ftrs = model.classifier[-1].out_features

        elif "vgg16" in checkpoint["arch"]:
            model = models.vgg16(pretrained=True)
            num_ftrs = model.classifier[-1].out_features
        
        else:
            return print("Architecture not recognized.")

        for param in model.parameters():
            param.requires_grad = False

        num_classes = len(class_mapping)
        classifier = nn.Sequential(
            OrderedDict(
                [
                    ("fc", nn.Linear(num_ftrs, num_classes)),
                    ("output", nn.LogSoftmax(dim=1)),
                ]
            )
        )
        model.fc = classifier

        model.class_to_idx = checkpoint["class_to_idx"]        
        model.load_state_dict(checkpoint["model_state_dict"])

        return model
    
    else:
        print("No such checkpoint found.")

In [3]:
# Data structure and model path
class_mapping = {'Basalt': 0,
 'Coal': 1,
 'Granite': 2,
 'Limestone': 3,
 'Marble': 4,
 'Quartzite': 5,
 'Sandstone': 6}

In [4]:
model_path = "../checkpoint/20220909_resnet50.pth"
res50 = load_checkpoint(model_path, class_mapping)



In [5]:
data_transforms = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

def load_image(img, data_transforms=data_transforms, size=(256,256)):
    img = data_transforms(img)
    img = img.unsqueeze(0)
    return img

In [6]:
image = "../img/Dataset/Coal/2.jpg"

def model_predict(image_pil, model, class_mapping):
    img = load_image(image_pil)
    output = model(img)
    probabilities = torch.exp(output)

    prediction = probabilities.max(dim=1)[1].tolist()
    class_prob = probabilities.max(dim=1)[0].tolist()
    
    class_name = [key for key in class_mapping if class_mapping[key] == prediction[0]]
    return class_name[0], class_prob[0]

In [7]:

def model_predict_path(image_path, model, class_mapping):
    image_pil = Image.open(image_path)
    image_pil = image_pil.convert("RGB")
    img = load_image(image_pil)
    output = model.forward(img)
    probabilities = torch.exp(output)
    prediction = probabilities.max(dim=1)[1].tolist()
    class_prob = probabilities.max(dim=1)[0].tolist()
    
    class_name = [key for key in class_mapping if class_mapping[key] == prediction[0]]
    return class_name[0], class_prob[0]

In [8]:
model_path = "../checkpoint/20220909_resnet50.pth"
image = "../img/Dataset/Coal/2.jpg"
res50 = load_checkpoint(model_path, class_mapping)

model_predict_path(image, res50, class_mapping)

('Marble', 0.2975787818431854)

In [9]:
def fn(model_choice, image):
    if model_choice=="resnet_50":
        class_name, class_prob = model_predict(image, res50, class_mapping)
        return image, class_name, class_prob
    elif model_choice=="resnet_18":
        return None
        # return gptj6B(input)
    


In [10]:
title = "Rock Classification"
description = "Rock classification using ResNet50"
article = "Stuff"

gr.Interface(fn, [gr.inputs.Dropdown(["resnet_50", "resnet_18"]), gr.inputs.Image(type='pil',image_mode="RGB")], ["image", "text", "text"], title=title, description=description, article=article).launch(share=True)




Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://10349.gradio.app

This share link expires in 72 hours. For free permanent hosting, check out Spaces: https://huggingface.co/spaces


(<gradio.routes.App at 0x1266f4400>,
 'http://127.0.0.1:7860/',
 'https://10349.gradio.app')