In [71]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import datasets, transforms, models
from PIL import Image
import os
import numpy as np
from sklearn.model_selection import train_test_split


In [72]:
class PlantDiseaseDataset(data.Dataset):
    def __init__(self, image_paths, mask_paths, labels, transform=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx] if self.mask_paths[idx] is not None else None
        label = self.labels[idx]
        
        # Load and combine the image and mask
        image = Image.open(image_path).convert("RGB")
        mask = Image.open(mask_path).convert("L") if mask_path else Image.new("L", image.size, color=0)  # Create a blank mask if not available

        # Resize the image and mask
        target_size = (256, 256)
        image = image.resize(target_size)
        mask = mask.resize(target_size)

        image = transforms.ToTensor()(image)
        mask = transforms.ToTensor()(mask)

        # Combine image and mask into a 4-channel tensor
        combined_image = torch.cat((image, mask), dim=0)  # 4 channels: 3 for RGB, 1 for mask

        if self.transform:
            combined_image = self.transform(combined_image)

        return combined_image, label


In [74]:
# Example image and mask paths (you should modify these based on your data)
image_dir = './Plant_Disease_Dataset_Unified/train/images/'
mask_dir = './dataset/masks/'

image_paths = []  # List of image file paths
mask_paths = []  # List of corresponding mask file paths
labels = []  # Labels (disease types or categories)

# Populate the lists with paths (you can modify this part based on your data structure)
for class_name in os.listdir(image_dir):
    class_folder = os.path.join(image_dir, class_name)
    for image_file in os.listdir(class_folder):
        image_paths.append(os.path.join(class_folder, image_file))
        mask_paths.append(os.path.join(mask_dir, class_name, image_file.replace(".jpg", ".png")))  # Adjust according to your mask naming
        labels.append(class_name)  # Assuming the class_name is the label

# Split data into train and validation sets
train_image_paths, val_image_paths, train_mask_paths, val_mask_paths, train_labels, val_labels = train_test_split(
    image_paths, mask_paths, labels, test_size=0.2, random_state=42)

# Define transformations (you can add more augmentations if needed)
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor()
])

# Create datasets
train_dataset = PlantDiseaseDataset(train_image_paths, train_mask_paths, train_labels, transform=transform)
val_dataset = PlantDiseaseDataset(val_image_paths, val_mask_paths, val_labels, transform=transform)

# Create data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=16, shuffle=False)


In [75]:
# Load pre-trained ResNet model
model = models.resnet18(pretrained=True)

# Modify the first convolution layer to accept 4 channels instead of 3
model.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

# Replace the final fully connected layer for classification (assuming 3 classes here)
num_classes = len(set(labels))  # Update this based on your dataset
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)




RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), torch.tensor([int(label) for label in labels]).to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {100 * correct/total:.2f}%")

    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), torch.tensor([int(label) for label in labels]).to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f"Validation Loss: {val_loss/len(val_loader):.4f}, Accuracy: {100 * correct/total:.2f}%")

# Save the trained model
torch.save(model.state_dict(), "plant_disease_model.pth")


In [None]:
# Load the saved model
model = models.resnet18(pretrained=True)
model.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model.load_state_dict(torch.load("plant_disease_model.pth"))
model.to(device)
model.eval()

# Predict on a new image
def predict_image(image_path, model, device):
    image = Image.open(image_path).convert("RGB")
    mask = Image.new("L", image.size, color=0)  # Assume no mask
    target_size = (256, 256)
    image = image.resize(target_size)
    mask = mask.resize(target_size)

    image = transforms.ToTensor()(image)
    mask = transforms.ToTensor()(mask)

    combined_image = torch.cat((image, mask), dim=0).unsqueeze(0).to(device)  # Add batch dimension

    with torch.no_grad():
        outputs = model(combined_image)
        _, predicted = torch.max(outputs.data, 1)

    return predicted.item()

# Example prediction
image_path = "path_to_new_leaf_image.jpg"
predicted_class = predict_image(image_path, model, device)
print(f"Predicted class: {predicted_class}")
