In [None]:
import os
import shutil
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import cv2
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from PIL import Image

# Reproducibility: seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs before any
# model is built or any DataLoader shuffles, so a Restart & Run All gives repeatable results.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # silently ignored when CUDA is unavailable

# Check for CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Data transformations, applied identically to the original image and the extracted-leaf image.
# NOTE(review): mean/std of 0.5 differ from the ImageNet statistics the pretrained ResNet18 expects;
# kept unchanged so existing checkpoints/inference code stay compatible — revisit when retraining.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Function to extract leaf from image
def extract_leaf(image_path, lower_hsv=(10, 20, 20), upper_hsv=(85, 255, 255)):
    """Keep the largest colour-masked region of an image and black out everything else.

    Args:
        image_path: Path of the image file to read.
        lower_hsv: Inclusive lower bound (H, S, V) for ``cv2.inRange`` on the OpenCV HSV image
            (8-bit OpenCV hue runs 0-179). The default hue of 10 starts around orange/yellow,
            not pure green — presumably so discoloured leaf tissue is kept too (TODO confirm).
        upper_hsv: Inclusive upper bound (H, S, V) for the same mask.

    Returns:
        A PIL RGB image where pixels outside the filled largest contour are zeroed; the
        unmodified RGB image when no contour is found; or ``None`` when the file cannot be
        opened (best-effort: the caller is expected to fall back to the original image).
    """
    try:
        # Context manager releases the file handle as soon as the RGB copy is decoded.
        with Image.open(image_path) as raw:
            img = raw.convert("RGB")
    except Exception as exc:
        print(f"Error opening image: {image_path} ({exc})")
        return None
    img_np = np.array(img)

    # Convert to HSV and keep pixels inside the configured colour range.
    hsv = cv2.cvtColor(img_np, cv2.COLOR_RGB2HSV)
    mask = cv2.inRange(hsv, np.array(lower_hsv), np.array(upper_hsv))

    # Smooth and refine the mask. The blur makes the mask non-binary; findContours treats any
    # non-zero pixel as foreground, so the blur slightly grows the region rather than shrinking it.
    mask = cv2.GaussianBlur(mask, (5, 5), 0)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)  # fill small holes
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)   # remove small specks

    # Keep only the largest external contour, assumed to be the leaf.
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return img  # No candidate region found: return the original image unchanged.

    largest_contour = max(contours, key=cv2.contourArea)
    leaf_mask = np.zeros_like(mask)
    cv2.drawContours(leaf_mask, [largest_contour], -1, 255, thickness=cv2.FILLED)
    leaf = cv2.bitwise_and(img_np, img_np, mask=leaf_mask)
    return Image.fromarray(leaf)

# Custom Dataset to Load Both Original and Extracted Images
class HybridPlantDiseaseDataset(Dataset):
    """Image-folder dataset whose samples stack the original image and its extracted leaf.

    Layout: ``root_dir/<class_name>/<image files>``. Each item is
    ``(torch.cat((transform(original), transform(leaf)), dim=0), label)``, so ``transform``
    must return tensors (with the 3-channel transform used in this notebook the result has
    6 channels).

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Callable applied separately to the original and the leaf PIL images.
        class_to_idx: Optional fixed ``{class_name: label}`` mapping. Pass the training
            split's mapping when building validation/test splits so labels always agree;
            class folders missing from ``root_dir`` are simply skipped. When omitted, the
            mapping is built from the sorted sub-directory names of ``root_dir``.
    """

    # Files in a class folder whose (lower-cased) extension is not listed here are ignored,
    # so stray files such as notes or OS metadata cannot crash Image.open during training.
    IMG_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".ppm", ".pgm", ".tif", ".tiff", ".webp"}

    def __init__(self, root_dir, transform=None, class_to_idx=None):
        self.root_dir = root_dir
        self.transform = transform
        self.data = []

        if class_to_idx is None:
            # Only directories are classes; plain files at the root (e.g. ".DS_Store") must not
            # consume a label index or inflate the number of output classes.
            classes = sorted(
                entry for entry in os.listdir(root_dir)
                if os.path.isdir(os.path.join(root_dir, entry))
            )
            class_to_idx = {cls: i for i, cls in enumerate(classes)}
        self.class_to_idx = dict(class_to_idx)

        # Collect (path, label) pairs in a deterministic order (by label, then file name).
        for class_name, label in sorted(self.class_to_idx.items(), key=lambda item: item[1]):
            class_path = os.path.join(root_dir, class_name)
            if not os.path.isdir(class_path):
                continue
            for img_name in sorted(os.listdir(class_path)):
                if os.path.splitext(img_name)[1].lower() not in self.IMG_EXTENSIONS:
                    continue
                self.data.append((os.path.join(class_path, img_name), label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, label = self.data[idx]

        # Load original image (file handle closed once decoded).
        with Image.open(img_path) as raw:
            original_image = raw.convert("RGB")

        # Load extracted leaf version. NOTE(review): extract_leaf re-opens and re-decodes the
        # same file; it returns None only if that second open fails.
        leaf_image = extract_leaf(img_path)
        if leaf_image is None:
            leaf_image = original_image  # Fallback in case of error

        # Apply transformations
        if self.transform:
            original_image = self.transform(original_image)
            leaf_image = self.transform(leaf_image)

        # Concatenate both images along the channel dimension (original channels first).
        hybrid_image = torch.cat((original_image, leaf_image), dim=0)

        return hybrid_image, label

# Dataset root: one sub-folder per split, each holding one folder per class.
base_dir = "plant_disease_data"
train_dir = base_dir + "/train"
valid_dir = base_dir + "/valid"

# Both splits share the same preprocessing pipeline.
train_dataset, valid_dataset = (
    HybridPlantDiseaseDataset(split_dir, transform=transform)
    for split_dir in (train_dir, valid_dir)
)

# Shuffle only the training split; validation order stays fixed between epochs.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)



# Load pre-trained ResNet18 model (modify first layer for 6-channel input)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
pretrained_conv1_weight = model.conv1.weight.detach().clone()
model.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)  # Adjust for 6-channel input
with torch.no_grad():
    # A freshly constructed conv1 is randomly initialised, which would throw away the pretrained
    # stem the deeper layers were trained against. Reuse the ImageNet filters for both 3-channel
    # halves (original + extracted leaf) and halve them so the summed response keeps its scale.
    model.conv1.weight.copy_(
        torch.cat((pretrained_conv1_weight, pretrained_conv1_weight), dim=1) * 0.5
    )
model.fc = nn.Linear(model.fc.in_features, len(train_dataset.class_to_idx))
model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
epochs = 15
train_losses, valid_losses, train_accs, valid_accs = [], [], [], []

# Training loop with progress tracking
for epoch in range(epochs):
    model.train()
    correct, total, running_loss = 0, 0, 0.0
    total_batches = len(train_loader)  # Total number of batches

    print(f"\nEpoch {epoch+1}/{epochs}")

    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss (default reduction="mean") returns the batch mean; weight it by the
        # batch size so the epoch loss is a true per-sample average even when the last batch is short.
        running_loss += loss.item() * labels.size(0)
        _, predicted = outputs.max(1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

        # Progress tracking
        progress = (batch_idx + 1) / total_batches
        filled = int(20 * progress)
        progress_bar = "=" * filled + " " * (20 - filled)
        print(f"\rProcessing batch {batch_idx+1}/{total_batches} [{progress_bar}] {int(progress * 100)}% done", end="")

    train_loss = running_loss / total
    train_acc = correct / total
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    print(f"\nTrain Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

    # Validation
    model.eval()
    correct, total, running_loss = 0, 0, 0.0
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * labels.size(0)  # per-sample weighting, as in training
            _, predicted = outputs.max(1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    valid_loss = running_loss / total
    valid_acc = correct / total
    valid_losses.append(valid_loss)
    valid_accs.append(valid_acc)

    print(f"Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_acc:.4f}")

# Plot results
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))
loss_ax.plot(train_losses, label="Train Loss")
loss_ax.plot(valid_losses, label="Valid Loss")
loss_ax.set(title="Loss", xlabel="Epoch (0-based)", ylabel="Cross-entropy")
loss_ax.legend()

acc_ax.plot(train_accs, label="Train Acc")
acc_ax.plot(valid_accs, label="Valid Acc")
acc_ax.set(title="Accuracy", xlabel="Epoch (0-based)", ylabel="Accuracy")
acc_ax.legend()

plt.show()

# Save the final-epoch weights (not necessarily the best validation epoch).
model_path = "plant_disease_resnet18.pth"
torch.save(model.state_dict(), model_path)


In [None]:
# Optional: copy the saved checkpoint to Google Drive and download it — only when running in Colab.
# Outside Colab the google.colab import fails, so the cell is skipped instead of raising.
try:
    from google.colab import drive, files
except ImportError:
    print("google.colab not available; skipping Google Drive copy and download.")
else:
    drive.mount('/content/drive')
    # Copy the file written by the training cell so Drive holds exactly the same checkpoint.
    shutil.copy(model_path, os.path.join("/content/drive/My Drive", model_path))
    files.download(model_path)

In [None]:
# Load model: rebuild the training architecture, then restore the saved weights.
# weights=None because load_state_dict (strict by default) overwrites every parameter and buffer,
# so downloading the ImageNet weights here would be wasted work.
model = models.resnet18(weights=None)
model.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = nn.Linear(model.fc.in_features, len(train_dataset.class_to_idx))
# weights_only=True limits unpickling to tensors and plain containers — all a state_dict needs —
# so a tampered checkpoint file cannot execute arbitrary code on load.
state_dict = torch.load("plant_disease_resnet18.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
print("Model loaded successfully!")