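<b> Loading the Model </b>

Build the ResNet-34 defect classifier (with a small two-class head) and load its trained weights from disk.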

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torchvision
import torch.nn as nn

def load_classification_model(weights_path, num_classes=2, hidden_dim=128,
                              dropout=0.4, map_location="cpu"):
    """Build a ResNet-34 with a small MLP head and load trained weights into it.

    torchvision's default ``fc`` layer is replaced by
    ``Linear(in_features, hidden_dim) -> ReLU -> Dropout(dropout) -> Linear(hidden_dim, num_classes)``.
    The defaults reproduce the architecture this notebook's checkpoint was
    saved with; only change them when loading a checkpoint trained with a
    different head, since the state_dict must match the constructed modules.

    Args:
        weights_path: Path (str or os.PathLike) to a saved ``state_dict``.
        num_classes: Number of output logits (default 2).
        hidden_dim: Width of the hidden layer in the head (default 128).
        dropout: Dropout probability in the head (default 0.4).
        map_location: Device the checkpoint tensors are mapped to (default "cpu").

    Returns:
        The model with the checkpoint loaded, switched to eval mode.
    """
    # No pretrained weights are requested here; everything comes from the checkpoint.
    model = torchvision.models.resnet34()
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, hidden_dim),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, num_classes),
    )
    # weights_only=True (PyTorch >= 1.13) restricts unpickling to tensors and plain
    # containers, so a tampered .pth file cannot run arbitrary code on load.
    state_dict = torch.load(weights_path, map_location=map_location, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # turn off dropout for inference
    return model

In [3]:
from pathlib import Path

# Build the checkpoint path portably: a hard-coded backslash separator only
# works on Windows, while Path joins components correctly on every OS.
path = Path("weights") / "resnet34.pth"
model = load_classification_model(path)

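<b> Model Inference </b>

The classifier first decides whether the uploaded item is defective. Only images judged defective are sent to Google Gemini for a short written description, so a Google API key is required.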

In [4]:
# Configure the Gemini client. The key is read from the GOOGLE_API_KEY environment
# variable, or prompted for interactively, so it is never stored in the notebook.
import getpass
import os

import google.generativeai as genai

GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY") or getpass.getpass("Enter your Google API key: ")
genai.configure(api_key=GOOGLE_API_KEY)

In [5]:
def prompt_gemini(img, model_name='gemini-1.5-pro-latest', max_words=50) -> str:
    """Ask Gemini to describe the material defect visible in ``img``.

    Args:
        img: The image to describe, sent as one part of the multimodal prompt
            (callers in this notebook pass a PIL image).
        model_name: Gemini model identifier (default 'gemini-1.5-pro-latest').
        max_words: Word limit stated in the prompt. The model is asked, not
            forced, to respect it.

    Returns:
        The full text response, assembled from the streamed chunks.
    """
    # Named `gemini` so it does not shadow the global classifier `model`.
    gemini = genai.GenerativeModel(model_name)
    response = gemini.generate_content(
        [
            "You are a smart and precise AI engine that can understand the defects in materials.",
            img,
            f"Describe the material defect in this image. Briefly cover the important details, causes and effects. Strictly limit your answer to {max_words} words",
        ],
        stream=True,
    )
    # With stream=True the response is consumed chunk by chunk; join them into one string.
    return "".join(chunk.text for chunk in response)

In [6]:
import math


def inference(PIL_image, img_tensor, model, threshold):
    """Classify one image and describe it with Gemini only if it looks defective.

    The item is treated as non-defective when
    ``logits[0][1] - logits[0][0] >= log(threshold / (1 - threshold))``.
    For two logits this is the same as ``softmax(logits)[0][1] >= threshold``,
    so ``threshold`` is the minimum probability of class 1 (the class this
    function reports as "no visible defects") required to skip Gemini.

    Args:
        PIL_image: The original image, forwarded to ``prompt_gemini`` when a defect is found.
        img_tensor: Preprocessed input batch. Only row 0 of the model output is read.
        model: Callable returning logits indexable as ``[0][0]`` and ``[0][1]``.
        threshold: Probability cut-off, strictly between 0 and 1.

    Returns:
        A fixed "no defects" message, or Gemini's description of the defect.

    Raises:
        ValueError: If ``threshold`` is not strictly between 0 and 1, where the
            log-odds cut-off is undefined.
    """
    # Validate first: 0 gives log(0) and 1 divides by zero, both with cryptic errors.
    if not 0 < threshold < 1:
        raise ValueError(f"threshold must be strictly between 0 and 1, got {threshold!r}")

    logits = model(img_tensor)
    # Log-odds of class 1 over class 0, compared against the threshold's log-odds.
    margin = (logits[0][1] - logits[0][0]).item()
    if margin >= math.log(threshold / (1 - threshold)):
        # Non defective condition
        return "The item in this image does not contain any visible defects"
    return prompt_gemini(PIL_image)
    

In [7]:
from PIL import Image
import torchvision.transforms as transforms

# Preprocessing pipeline, built once at import time instead of on every request.
_PREPROCESS = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize input images to a size larger than the required crop size
    transforms.CenterCrop(227),     # Crop the center region of the resized image
    transforms.ToTensor(),          # Convert the image to a PyTorch tensor
    # Normalize with the standard ImageNet per-channel mean/std
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict_image(image, threshold=0.7):
    """Gradio entry point: classify an uploaded image and, if defective, describe it.

    Args:
        image: Image array as delivered by the Gradio ``Image`` input, or ``None``
            when the user submits without uploading anything.
        threshold: Minimum probability of the non-defective class required to
            skip the Gemini description (forwarded to ``inference``).

    Returns:
        The diagnosis text shown in the output box.
    """
    if image is None:
        return "Please upload an image."

    # Force 3 channels: grayscale or RGBA uploads would not match the
    # 3-value Normalize above or the ResNet input layer.
    PIL_image = Image.fromarray(image).convert("RGB")
    input_batch = _PREPROCESS(PIL_image).unsqueeze(0)  # add batch dimension

    # Make a prediction without building an autograd graph
    with torch.no_grad():
        return inference(PIL_image, input_batch, model, threshold)

In [9]:
import gradio as gr

# Styling for the diagnosis text box.
# NOTE(review): confirm `.gr-output-text` still matches an element class in the
# installed Gradio version; otherwise this CSS has no visible effect.
OUTPUT_BOX_CSS = """
        /* Style the output box */
        .gr-output-text {
        background-color: #f5f5f5;
        padding: 10px;
        border: 1px solid #ddd;
        border-radius: 4px;
        font-size: 16px;
        }
    """

# Single image in, single diagnosis string out, handled by predict_image.
interface_config = dict(
    fn=predict_image,
    inputs=gr.Image(),
    outputs=gr.Textbox(label="Image Diagnosis"),
    title="MatLLM - Defect Classification & Description",
    description="Upload an image and the model will classify/describe it.",
    theme="default",
    css=OUTPUT_BOX_CSS,
)
app = gr.Interface(**interface_config)

# Start the local server; share=True also requests a public share link.
app.launch(share=True)

Running on local URL:  http://127.0.0.1:7861

Could not create share link. Please check your internet connection or our status page: https://status.gradio.app.


2024/04/16 12:58:22 [W] [service.go:132] login to server failed: dial tcp 44.237.78.176:7000: i/o timeout


