In [3]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
from collections import OrderedDict

# Define CNN Feature Extractor
class CNNFeatureExtractor(nn.Module):
    """ResNet-50 backbone truncated before global pooling.

    Keeps every child module of torchvision's ``resnet50`` except the last two
    (``avgpool`` and ``fc``), so ``forward`` returns the spatial feature map of
    shape ``(batch, 2048, H, W)`` instead of class logits.

    Args:
        pretrained: If True (the default, matching the previous behaviour),
            initialise the backbone with torchvision's ImageNet weights, which
            may require a download. Pass False when every weight is going to be
            overwritten from a checkpoint anyway.
    """

    def __init__(self, pretrained=True):
        super(CNNFeatureExtractor, self).__init__()
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates ``pretrained`` in favour of ``weights``;
            # IMAGENET1K_V1 is exactly what ``pretrained=True`` used to load.
            base_model = models.resnet50(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)
        else:
            # Older torchvision only understands the legacy flag.
            base_model = models.resnet50(pretrained=pretrained)
        # Drop the final avgpool + fc so the Transformer sees the spatial grid.
        self.feature_extractor = nn.Sequential(*list(base_model.children())[:-2])

    def forward(self, x):
        return self.feature_extractor(x)  # (batch, 2048, H, W)

# Define Transformer Block
class TransformerEncoderBlock(nn.Module):
    """Single post-norm Transformer encoder layer (self-attention + feed-forward).

    Each sub-layer is wrapped in a residual connection followed by LayerNorm.
    The attribute names and the layer order inside ``ffn`` determine the
    state_dict keys (``self_attn.*``, ``norm1.*``, ``ffn.0.*``, ``ffn.2.*``,
    ``norm2.*``), so they must stay in sync with saved checkpoints.

    Args:
        dim: Embedding size of each token.
        heads: Number of attention heads (must divide ``dim``).
        ff_dim: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability used inside the attention module only.
    """

    def __init__(self, dim=2048, heads=8, ff_dim=2048, dropout=0.1):
        super().__init__()
        # batch_first is left at its default (False): inputs are (seq_len, batch, dim).
        self.self_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(in_features=dim, out_features=ff_dim),
            nn.ReLU(),
            nn.Linear(in_features=ff_dim, out_features=dim),
        )
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Run attention then feed-forward; the result has the same shape as ``x``."""
        attended, _ = self.self_attn(x, x, x)
        hidden = self.norm1(x + attended)
        return self.norm2(hidden + self.ffn(hidden))

# Define Full CNN + Transformer Model
class CNN_Transformer_Model(nn.Module):
    """ResNet-50 feature map -> one Transformer encoder block -> mean pool -> linear head.

    The attribute names ``cnn``, ``transformer`` and ``fc`` become the
    state_dict key prefixes, so they must match the saved checkpoint.

    Args:
        num_classes: Number of output logits produced by the final layer.
    """

    def __init__(self, num_classes=38):
        super().__init__()
        self.cnn = CNNFeatureExtractor()
        self.transformer = TransformerEncoderBlock(dim=2048)
        self.fc = nn.Linear(in_features=2048, out_features=num_classes)

    def forward(self, x):
        feature_map = self.cnn(x)                                    # (batch, 2048, H, W)
        tokens = feature_map.flatten(start_dim=2).permute(2, 0, 1)   # (H*W, batch, 2048): one token per spatial cell
        encoded = self.transformer(tokens)                           # same shape, sequence-first
        pooled = encoded.mean(dim=0)                                 # average over the H*W tokens -> (batch, 2048)
        return self.fc(pooled)                                       # (batch, num_classes) logits

# Load Model
import os

# Defaults to the author's local checkpoint; set PLANT_DISEASE_MODEL_PATH to
# load a different file without editing the notebook.
model_path = os.environ.get(
    "PLANT_DISEASE_MODEL_PATH",
    "C:/Users/BIBHAV KUMAR/Desktop/COMPLETE_HYBRID_MODEL_PLANTDISESASE/model/cnn_transformer_model.pth",
)
num_classes = 38  # Set this based on your dataset
model = CNN_Transformer_Model(num_classes)

# Load the checkpoint onto the CPU.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load trusted checkpoints (torch >= 1.13 also accepts weights_only=True
# for plain state_dicts).
state_dict = torch.load(model_path, map_location=torch.device('cpu'))

# Checkpoints saved from an nn.DataParallel wrapper prefix every key with
# "module."; strip only that leading prefix (str.replace would also rewrite
# "module." appearing anywhere inside a key).
dp_prefix = "module."
new_state_dict = OrderedDict(
    (k[len(dp_prefix):] if k.startswith(dp_prefix) else k, v) for k, v in state_dict.items()
)

# strict=False tolerates key mismatches, but ignoring them silently would leave
# the affected layers at their initial (random or ImageNet) weights while still
# reporting success -- surface any mismatch instead.
load_result = model.load_state_dict(new_state_dict, strict=False)
model.eval()

if load_result.missing_keys or load_result.unexpected_keys:
    print("⚠️ Model loaded with key mismatches:")
    print("   missing from checkpoint:", load_result.missing_keys)
    print("   unexpected in checkpoint:", load_result.unexpected_keys)
else:
    print("✅ Model loaded successfully!")

# Image Preprocessing: resize straight to 224x224 (aspect ratio is not
# preserved), convert the PIL image to a float tensor, then normalise each
# channel with the standard ImageNet mean/std used by the ResNet-50 backbone.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Prediction Function
def predict_image(image_path, model, class_names):
    """Classify a single image file and return its class name.

    Args:
        image_path: Path of the image to classify; it is converted to RGB and
            preprocessed with the module-level ``transform``.
        model: Classifier returning logits of shape ``(1, len(class_names))``.
            It is switched to eval mode as a side effect.
        class_names: Labels indexed by the model's output position.

    Returns:
        The entry of ``class_names`` with the highest logit.

    Raises:
        ValueError: If the model's number of outputs differs from ``len(class_names)``.
    """
    # The context manager closes the file handle once the pixels are read;
    # convert() returns a new in-memory image, so it stays valid afterwards.
    with Image.open(image_path) as img:
        batch = transform(img.convert("RGB")).unsqueeze(0)  # add batch dimension

    model.eval()
    with torch.no_grad():
        logits = model(batch)

    if logits.shape[1] != len(class_names):
        raise ValueError(
            f"model produces {logits.shape[1]} outputs but {len(class_names)} class names were given"
        )
    # softmax is monotonic, so argmax over the raw logits picks the same class.
    predicted_class_idx = int(logits.argmax(dim=1).item())
    return class_names[predicted_class_idx]

# Define Class Names
# Entry i is the label for output logit i (predict_image indexes this list with
# the argmax), so the ORDER must match the label encoding used in training.
# NOTE(review): this list is in case-insensitive alphabetical order (e.g.
# "Cherry___healthy" precedes "Cherry___Powdery_mildew"). torchvision's
# ImageFolder assigns class indices with a case-sensitive sorted(), which puts
# upper-case names first ("Cherry___Powdery_mildew" before "Cherry___healthy").
# If the model was trained with ImageFolder, verify this order against
# train_dataset.classes -- otherwise several labels are swapped.
class_names = [
    "Apple___Apple_scab", "Apple___Black_rot", "Apple___Cedar_apple_rust", "Apple___healthy",
    "Blueberry___healthy",
    "Cherry___healthy", "Cherry___Powdery_mildew",
    "Corn___Cercospora_leaf_spot Gray_leaf_spot", "Corn___Common_rust", "Corn___healthy",
    "Corn___Northern_Leaf_Blight",
    "Grape___Black_rot", "Grape___Esca_(Black_Measles)", "Grape___healthy",
    "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
    "Orange___Haunglongbing_(Citrus_greening)",
    "Peach___Bacterial_spot", "Peach___healthy",
    "Pepper,_bell___Bacterial_spot", "Pepper,_bell___healthy",
    "Potato___Early_blight", "Potato___healthy", "Potato___Late_blight",
    "Raspberry___healthy",
    "Soybean___healthy",
    "Squash___Powdery_mildew",
    "Strawberry___healthy", "Strawberry___Leaf_scorch",
    "Tomato___Bacterial_spot", "Tomato___Early_blight", "Tomato___healthy", "Tomato___Late_blight",
    "Tomato___Leaf_Mold", "Tomato___Septoria_leaf_spot",
    "Tomato___Spider_mites Two-spotted_spider_mite", "Tomato___Target_Spot",
    "Tomato___Tomato_mosaic_virus", "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
]

# Example Prediction
# The sample image path is resolved relative to the notebook's working directory.
image_path = "img2.JPG"
predicted_class = predict_image(image_path, model=model, class_names=class_names)
print(f"🔹 Final Prediction: {predicted_class}")


✅ Model loaded successfully!
🔹 Final Prediction: Grape___Black_rot


In [1]:
import tkinter as tk
from tkinter import filedialog
from PIL import Image, ImageTk
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torchvision.models as models
from collections import OrderedDict

# Define CNN Feature Extractor
class CNNFeatureExtractor(nn.Module):
    """ResNet-50 backbone truncated before global pooling.

    Keeps every child module of torchvision's ``resnet50`` except the last two
    (``avgpool`` and ``fc``), so ``forward`` returns the spatial feature map of
    shape ``(batch, 2048, H, W)``.

    NOTE(review): an identical class is defined in the other cell of this
    notebook; keep the two in sync or move them to a shared module.

    Args:
        pretrained: If True (the default, matching the previous behaviour),
            initialise the backbone with torchvision's ImageNet weights, which
            may require a download. Pass False when every weight is going to be
            overwritten from a checkpoint anyway.
    """

    def __init__(self, pretrained=True):
        super(CNNFeatureExtractor, self).__init__()
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates ``pretrained`` in favour of ``weights``;
            # IMAGENET1K_V1 is exactly what ``pretrained=True`` used to load.
            base_model = models.resnet50(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)
        else:
            # Older torchvision only understands the legacy flag.
            base_model = models.resnet50(pretrained=pretrained)
        # Drop the final avgpool + fc so the Transformer sees the spatial grid.
        self.feature_extractor = nn.Sequential(*list(base_model.children())[:-2])

    def forward(self, x):
        return self.feature_extractor(x)  # (batch, 2048, H, W)

# Define Transformer Block
class TransformerEncoderBlock(nn.Module):
    """Single post-norm Transformer encoder layer (self-attention + feed-forward).

    Each sub-layer is wrapped in a residual connection followed by LayerNorm.
    Attribute names and the layer order inside ``ffn`` are state_dict keys and
    must match the checkpoint. NOTE(review): duplicated in the other notebook
    cell; keep the two definitions identical.

    Args:
        dim: Embedding size of each token.
        heads: Number of attention heads (must divide ``dim``).
        ff_dim: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability used inside the attention module only.
    """

    def __init__(self, dim=2048, heads=8, ff_dim=2048, dropout=0.1):
        super().__init__()
        # Sequence-first layout (batch_first defaults to False): (seq_len, batch, dim).
        self.self_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        expand = nn.Linear(in_features=dim, out_features=ff_dim)
        project = nn.Linear(in_features=ff_dim, out_features=dim)
        self.ffn = nn.Sequential(expand, nn.ReLU(), project)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Attention sub-layer, then feed-forward sub-layer; output shape equals ``x``'s."""
        context = self.self_attn(x, x, x)[0]
        residual = self.norm1(x + context)
        return self.norm2(residual + self.ffn(residual))

# Define CNN + Transformer Model
class CNN_Transformer_Model(nn.Module):
    """ResNet-50 features -> Transformer encoder block -> mean pool -> linear classifier.

    ``cnn``, ``transformer`` and ``fc`` are the state_dict key prefixes and must
    match the checkpoint. NOTE(review): duplicated in the other notebook cell.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=38):
        super().__init__()
        self.cnn = CNNFeatureExtractor()
        self.transformer = TransformerEncoderBlock(dim=2048)
        self.fc = nn.Linear(in_features=2048, out_features=num_classes)

    def forward(self, x):
        grid = self.cnn(x)                                   # (batch, 2048, H, W)
        seq = grid.flatten(start_dim=2).permute(2, 0, 1)     # (H*W, batch, 2048)
        summary = self.transformer(seq).mean(dim=0)          # pool over spatial tokens -> (batch, 2048)
        return self.fc(summary)                              # (batch, num_classes) logits

# Load Model
import os

# Defaults to the author's local checkpoint; set PLANT_DISEASE_MODEL_PATH to
# load a different file without editing the notebook.
model_path = os.environ.get(
    "PLANT_DISEASE_MODEL_PATH",
    "C:/Users/BIBHAV KUMAR/Desktop/COMPLETE_HYBRID_MODEL_PLANTDISESASE/model/cnn_transformer_model.pth",
)
num_classes = 38
model = CNN_Transformer_Model(num_classes)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load trusted checkpoints (torch >= 1.13 also accepts weights_only=True
# for plain state_dicts).
state_dict = torch.load(model_path, map_location=torch.device('cpu'))

# DataParallel checkpoints prefix every key with "module."; strip only that
# leading prefix rather than every occurrence inside a key.
dp_prefix = "module."
new_state_dict = OrderedDict(
    (k[len(dp_prefix):] if k.startswith(dp_prefix) else k, v) for k, v in state_dict.items()
)

# strict=False would otherwise hide mismatches that leave layers untrained.
load_result = model.load_state_dict(new_state_dict, strict=False)
model.eval()

if load_result.missing_keys or load_result.unexpected_keys:
    print("⚠️ Model loaded with key mismatches:")
    print("   missing from checkpoint:", load_result.missing_keys)
    print("   unexpected in checkpoint:", load_result.unexpected_keys)

# Define Class Names
# Entry i is the label for output logit i (predict_image indexes this list with
# the argmax), so the ORDER must match the label encoding used in training.
# NOTE(review): this order is case-insensitive alphabetical (e.g.
# "Cherry___healthy" before "Cherry___Powdery_mildew"), whereas torchvision's
# ImageFolder indexes classes with a case-sensitive sorted() (upper-case first).
# If the model was trained with ImageFolder, verify this list against
# train_dataset.classes -- otherwise several labels are swapped.
class_names = [
    "Apple___Apple_scab",
    "Apple___Black_rot",
    "Apple___Cedar_apple_rust",
    "Apple___healthy",
    "Blueberry___healthy",
    "Cherry___healthy",
    "Cherry___Powdery_mildew",
    "Corn___Cercospora_leaf_spot Gray_leaf_spot",
    "Corn___Common_rust",
    "Corn___healthy",
    "Corn___Northern_Leaf_Blight",
    "Grape___Black_rot",
    "Grape___Esca_(Black_Measles)",
    "Grape___healthy",
    "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
    "Orange___Haunglongbing_(Citrus_greening)",
    "Peach___Bacterial_spot",
    "Peach___healthy",
    "Pepper,_bell___Bacterial_spot",
    "Pepper,_bell___healthy",
    "Potato___Early_blight",
    "Potato___healthy",
    "Potato___Late_blight",
    "Raspberry___healthy",
    "Soybean___healthy",
    "Squash___Powdery_mildew",
    "Strawberry___healthy",
    "Strawberry___Leaf_scorch",
    "Tomato___Bacterial_spot",
    "Tomato___Early_blight",
    "Tomato___healthy",
    "Tomato___Late_blight",
    "Tomato___Leaf_Mold",
    "Tomato___Septoria_leaf_spot",
    "Tomato___Spider_mites Two-spotted_spider_mite",
    "Tomato___Target_Spot",
    "Tomato___Tomato_mosaic_virus",
    "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
]

# Image Preprocessing
# Resize to the fixed 224x224 input (no aspect-preserving crop), convert to a
# tensor, and normalise each channel with the standard ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Prediction Function
def predict_image(image_path):
    """Classify a single image file with the module-level ``model``.

    Reads the globals ``model``, ``transform`` and ``class_names``.

    Args:
        image_path: Path of the image to classify; it is converted to RGB first.

    Returns:
        The entry of ``class_names`` with the highest logit.

    Raises:
        ValueError: If the model's number of outputs differs from ``len(class_names)``.
    """
    # Close the file handle as soon as the pixels are read; convert() returns a
    # new in-memory image, so it remains usable after the ``with`` block.
    with Image.open(image_path) as img:
        batch = transform(img.convert("RGB")).unsqueeze(0)  # add batch dimension

    model.eval()
    with torch.no_grad():
        logits = model(batch)

    if logits.shape[1] != len(class_names):
        raise ValueError(
            f"model produces {logits.shape[1]} outputs but {len(class_names)} class names are defined"
        )
    # softmax is monotonic, so argmax over the raw logits picks the same class.
    predicted_class_idx = int(logits.argmax(dim=1).item())
    return class_names[predicted_class_idx]

# Tkinter GUI
class DiseaseDetectionApp:
    """Minimal Tkinter front-end: choose an image, preview it, run ``predict_image``."""

    # Tk expects several patterns for one file type as a tuple (or a
    # space-separated string); a ";"-joined string only works with the native
    # Windows dialog. Upper-case variants are listed because pattern matching
    # is case-sensitive on Linux.
    FILETYPES = [
        ("Image Files", ("*.jpg", "*.jpeg", "*.png", "*.JPG", "*.JPEG", "*.PNG")),
        ("All Files", "*"),
    ]

    def __init__(self, root):
        """Build the widgets inside ``root``; Predict stays disabled until an image is chosen."""
        self.root = root
        self.root.title("Plant Disease Detection")
        self.root.geometry("600x500")

        self.label = tk.Label(root, text="🌱 Plant Disease Detection", font=("Arial", 16, "bold"))
        self.label.pack(pady=10)

        self.upload_button = tk.Button(root, text="📂 Upload Image", command=self.upload_image, font=("Arial", 12))
        self.upload_button.pack()

        self.canvas = tk.Canvas(root, width=300, height=300)
        self.canvas.pack()

        self.result_label = tk.Label(root, text="", font=("Arial", 14, "bold"), fg="green")
        self.result_label.pack()

        self.predict_button = tk.Button(root, text="🔍 Predict", command=self.predict_disease, state=tk.DISABLED, font=("Arial", 12))
        self.predict_button.pack(pady=10)

        self.image_path = None

    def upload_image(self):
        """Ask for an image, show a 300x300 preview and enable the Predict button."""
        file_path = filedialog.askopenfilename(filetypes=self.FILETYPES)
        if not file_path:
            return  # dialog cancelled

        try:
            # resize() returns an in-memory copy, so the file can be closed here.
            with Image.open(file_path) as img:
                preview = img.resize((300, 300))
        except (OSError, ValueError) as exc:  # PIL's UnidentifiedImageError is an OSError
            self.result_label.config(text=f"⚠️ Could not open image: {exc}", fg="red")
            return

        self.image_path = file_path
        # Keep a reference on self: Tk does not hold one, and a garbage-collected
        # PhotoImage would render as a blank canvas.
        self.photo = ImageTk.PhotoImage(preview)
        self.canvas.delete("all")  # drop the previous preview instead of stacking items
        self.canvas.create_image(150, 150, image=self.photo)
        # Any displayed prediction belonged to the previous image.
        self.result_label.config(text="", fg="green")
        self.predict_button.config(state=tk.NORMAL)

    def predict_disease(self):
        """Run the classifier on the chosen image and show the result (or the error)."""
        if not self.image_path:
            return
        try:
            predicted_class = predict_image(self.image_path)
        except Exception as exc:  # show the failure in the UI instead of only on stderr
            self.result_label.config(text=f"⚠️ Prediction failed: {exc}", fg="red")
            return
        self.result_label.config(text=f"🌿 Prediction: {predicted_class}", fg="green")

# Run Tkinter App
# mainloop() blocks this cell until the window is closed.
root = tk.Tk()
app = DiseaseDetectionApp(root=root)
app.root.mainloop()


