# In [0]:
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Define the segmentation and classification models and load their trained weights
class UNet(nn.Module):
    """Small fully-convolutional segmentation network.

    Despite its name, the network is a plain stack of three convolutions
    (3 -> 64 -> 64 -> 2 channels) with no activations, pooling, or skip
    connections. The two 3x3 convolutions use stride 1 and padding 1 and
    the last one is 1x1, so the spatial size of the input is preserved.

    The attribute names (``conv1``, ``conv2``, ``final_conv``) are the keys
    of the module's ``state_dict`` and must not be renamed, otherwise saved
    weights will no longer load.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as they are applied.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        # 1x1 projection to the two output channels.
        self.final_conv = nn.Conv2d(64, 2, kernel_size=1)

    def forward(self, x):
        """Return the 2-channel output map for a 3-channel input batch ``x``."""
        features = self.conv2(self.conv1(x))
        return self.final_conv(features)
# Build the segmentation network and restore its saved weights on the CPU.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
segmentation_model = UNet()
segmentation_model.load_state_dict(
    torch.load('model (1).pt', map_location='cpu')
)
# Switch to inference mode (affects layers such as dropout / batch-norm).
segmentation_model.eval()
class VIT(nn.Module):
    """Vision Transformer backbone with a tanh pooler and a linear classifier.

    The backbone is loaded from ``model_checkpoint`` without its built-in
    pooling layer; this module supplies its own pooler (linear + tanh on the
    first token) followed by a linear classification head.

    Args:
        config: Object exposing ``hidden_size`` used to size the pooler and
            classifier. When ``None`` (the default) the configuration of the
            loaded backbone is used, so the head always matches the
            checkpoint that was actually loaded.
        num_labels: Number of output classes.
        model_checkpoint: Name or path passed to ``ViTModel.from_pretrained``.
    """

    def __init__(self, config=None, num_labels=2, model_checkpoint='google/vit-base-patch16-224-in21k'):
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_checkpoint, add_pooling_layer=False)
        # A config instance in the signature would be created once at
        # definition time and shared by every call; resolve it lazily and
        # take the hidden size from the backbone that was really loaded.
        if config is None:
            config = self.vit.config
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
        self.pooler_activation = nn.Tanh()

    def forward(self, x):
        """Return unnormalised class scores of shape ``(batch, num_labels)``."""
        hidden_states = self.vit(x)['last_hidden_state']
        # Pool by taking the first token of the sequence.
        first_token = hidden_states[:, 0, :]
        pooled = self.pooler_activation(self.pooler(first_token))
        return self.classifier(pooled)
# Build the classifier and restore its fine-tuned weights on the CPU.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
classification_model = VIT()
classification_model.load_state_dict(
    torch.load('weed_detection_model.pth', map_location='cpu')
)
# Switch to inference mode (affects layers such as dropout / batch-norm).
classification_model.eval()

# Class names, indexed by the class index predicted by the classifier.
# NOTE(review): the order presumably matches the label order used in training — confirm.
class_names = ["non-weed", "weed-images"]

# Preprocessing applied to every input image before it is fed to the models:
# resize to 224x224, convert to a float tensor, then normalise.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    # NOTE(review): the same Normalize is applied a second time, so the net
    # mapping of a [0, 1] pixel value x is 4*x - 3 (range [-3, 1]) rather than
    # the usual [-1, 1]. This looks like an accidental duplicate, but it must
    # match whatever preprocessing the weights were trained with — confirm
    # against the training code before removing it.
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Define the function to preprocess the input image
def preprocess_image(image):
    """Turn an input image into a single-image batch tensor for the models.

    Args:
        image: Either a NumPy array (converted with ``Image.fromarray``) or
            an object that already provides PIL's ``convert`` method.

    Returns:
        The result of the module-level ``transform`` applied to the RGB
        image, with a leading batch dimension of size 1 added.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    # Force three channels so grayscale / RGBA inputs fit the models.
    tensor = transform(pil_image.convert("RGB"))
    return tensor.unsqueeze(0)

# Define the function to perform the prediction
def predict_image(image):
    """Classify an image and overlay its segmentation mask.

    Args:
        image: Input image as a NumPy array or PIL image (anything accepted
            by ``preprocess_image``).

    Returns:
        A 5-tuple ``(top_class_name, top_confidence, second_class_name,
        second_confidence, segmented_image)``. Confidences are softmax
        probabilities as Python floats. The second-class entries are ``None``
        when the classifier has only one output. ``segmented_image`` is a PIL
        RGB image at the model resolution with pixels predicted as class 1
        painted blue.
    """
    image_tensor = preprocess_image(image)

    # Run both models without tracking gradients.
    with torch.no_grad():
        classification_output = classification_model(image_tensor)
        segmentation_output = segmentation_model(image_tensor)

    # Softmax is monotonic, so ranking the probabilities gives the same order
    # as ranking the raw scores; topk also returns the matching confidences.
    probabilities = torch.softmax(classification_output, dim=1)[0]
    k = min(2, probabilities.numel())
    confidences, predicted_classes = torch.topk(probabilities, k=k)

    top_predicted_class_name = class_names[predicted_classes[0].item()]
    top_confidence = confidences[0].item()

    if k > 1:
        second_predicted_class_name = class_names[predicted_classes[1].item()]
        second_confidence = confidences[1].item()
    else:
        second_predicted_class_name = None
        second_confidence = None

    # Per-pixel class = channel with the highest score. Thresholding the raw
    # scores at 0.5 before the argmax (as was done previously) discards their
    # relative magnitude and silently maps ties to class 0.
    # NOTE(review): assumes channel 1 is the foreground class — confirm.
    mask = segmentation_output.argmax(dim=1)[0].cpu().numpy().astype(bool)

    # Draw the overlay on the original pixels resized to the mask resolution.
    # The model input tensor is normalised (values outside [0, 255]), so
    # casting it to uint8 does not produce a viewable image.
    base = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    base = base.convert("RGB").resize((mask.shape[1], mask.shape[0]), Image.BILINEAR)
    overlay = np.array(base, dtype=np.uint8)  # np.array copies, so `base` is untouched
    overlay[mask] = (0, 0, 255)
    segmented_image = Image.fromarray(overlay)

    return top_predicted_class_name, top_confidence, second_predicted_class_name, second_confidence, segmented_image

# Define the inputs and outputs for the gradio interface
inputs = gr.Image()

# One text box per scalar result, in the order predict_image returns them,
# followed by the image component for the segmentation overlay.
outputs = [
    gr.Textbox(label=text)
    for text in (
        "Top Predicted Class",
        "Top Confidence",
        "Second Predicted Class",
        "Second Confidence",
    )
] + [gr.Image(label="Segmented Image")]

# Build the interface around predict_image and start serving it
gr.Interface(fn=predict_image, inputs=inputs, outputs=outputs).launch()