In [None]:
import os
import torch
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from torchvision.models.detection.rpn import AnchorGenerator
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn

In [None]:
# Define your dataset class and transformations
class CustomDataset(torch.utils.data.Dataset):
    """Image-folder dataset with one sub-directory of ``dataset_path`` per class.

    Args:
        dataset_path: Root directory that contains one folder per class.
        class_names: Folder names to read; a class's position in this
            sequence is used as its integer label.
        transform: Optional callable applied to each loaded RGB PIL image.

    Attributes:
        data: List of ``(image_path, label)`` tuples, built once at
            construction time.

    Raises:
        OSError: From ``os.listdir`` if a class folder cannot be listed
            (for example when it does not exist).
    """

    # Extensions accepted as images (matched case-insensitively). Anything
    # else found in a class folder (OS metadata files, notes, ...) is skipped
    # instead of failing later inside ``Image.open``.
    IMAGE_EXTENSIONS = (
        ".jpg", ".jpeg", ".png", ".ppm", ".bmp",
        ".pgm", ".tif", ".tiff", ".webp",
    )

    def __init__(self, dataset_path, class_names, transform=None):
        self.dataset_path = dataset_path
        self.class_names = class_names
        self.transform = transform
        self.data = self.load_dataset()

    def load_dataset(self):
        """Scan every class folder and return ``[(image_path, label), ...]``."""
        data = []
        for idx, class_name in enumerate(self.class_names):
            class_path = os.path.join(self.dataset_path, class_name)
            # os.listdir() returns entries in arbitrary order; sorting keeps
            # sample indices stable across runs and machines.
            for image_name in sorted(os.listdir(class_path)):
                if not image_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue
                image_path = os.path.join(class_path, image_name)
                # A directory may carry an image-like name; only keep files.
                if not os.path.isfile(image_path):
                    continue
                data.append((image_path, idx))  # label == class index
        return data

    def __len__(self):
        """Number of indexed images."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded as RGB and passed through ``transform`` when one
        was supplied.
        """
        image_path, label = self.data[idx]
        # The context manager releases the file handle as soon as the pixel
        # data has been converted.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, label


In [None]:
# Class folders to read, and the train / test roots that contain them
class_names = ["Black Rot", "ESCA", "Healthy", "Leaf Blight"]
train_dataset_path = "../dataset/train/"
test_dataset_path = "../dataset/test/"

# Preprocessing applied to every loaded image: conversion to a tensor only.
# Extend the list with further transformations as needed.
transform = transforms.Compose([transforms.ToTensor()])

In [None]:
# Build the train and test datasets from their respective roots; both share
# the same class list and preprocessing.
train_dataset, test_dataset = [
    CustomDataset(root, class_names, transform=transform)
    for root in (train_dataset_path, test_dataset_path)
]

# Wrap them in loaders: shuffled batches of 16 for training, ordered batches
# of 8 for testing.
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=8, shuffle=False)

In [None]:
# Build a Faster R-CNN detector on a pre-trained ResNet-50 FPN backbone.
#
# fasterrcnn_resnet50_fpn() returns a *complete* detector, not a feature
# extractor, so passing it to FasterRCNN as `backbone` breaks the forward
# pass. Only its `.backbone` is reused here; that module already exposes
# `out_channels == 256`, so no manual override is needed.
# NOTE: `pretrained=True` is kept from the original call; newer torchvision
# releases prefer the `weights=...` argument.
backbone = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    pretrained=True).backbone

# The FPN backbone emits five feature maps ('0', '1', '2', '3', 'pool').
# AnchorGenerator needs one `sizes` tuple and one `aspect_ratios` tuple per
# feature map, so both sequences must have length 5.
rpn_anchor_generator = AnchorGenerator(
    sizes=((32,), (64,), (128,), (256,), (512,)),
    aspect_ratios=((0.5, 1.0, 2.0),) * 5
)
# Pool RoI features from the four pyramid levels (the extra 'pool' level is
# only used by the RPN).
roi_pooler = torchvision.ops.MultiScaleRoIAlign(
    featmap_names=['0', '1', '2', '3'], output_size=7, sampling_ratio=2
)
model = FasterRCNN(
    backbone,
    # torchvision reserves class 0 for the background, so the head needs one
    # output more than there are foreground classes.
    num_classes=len(class_names) + 1,
    rpn_anchor_generator=rpn_anchor_generator,
    box_roi_pool=roi_pooler
)


In [None]:
# Optimizer: SGD with momentum over every model parameter.
# No separate loss function is defined in this cell.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=0.005,
    momentum=0.9,
)


In [None]:
# Training loop
#
# In training mode a torchvision detection model expects
#   * a list of (C, H, W) image tensors, and
#   * a list of dicts with "boxes" (FloatTensor[N, 4], x1 y1 x2 y2) and
#     "labels" (Int64Tensor[N], 0 reserved for background),
# whereas the loader yields (image batch, class indices). Each batch is
# converted before it is fed to the model.
#
# NOTE(review): the dataset only carries image-level labels, so one box
# covering the whole image is synthesised per sample. Replace this with real
# bounding-box annotations if they are available.
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    for images, labels in train_loader:
        images = list(images)  # batch tensor (or tuple) -> list of images
        targets = []
        for sample, class_index in zip(images, labels):
            height, width = sample.shape[-2:]
            targets.append({
                "boxes": torch.tensor(
                    [[0.0, 0.0, float(width), float(height)]],
                    dtype=torch.float32),
                # Shift by one: label 0 means "background" to the detector,
                # so it must be built with len(class_names) + 1 classes.
                "labels": torch.tensor(
                    [int(class_index) + 1], dtype=torch.int64),
            })

        optimizer.zero_grad()
        loss_dict = model(images, targets)  # dict of individual loss terms
        losses = sum(loss for loss in loss_dict.values())
        losses.backward()
        optimizer.step()

        running_loss += losses.item()
        num_batches += 1

    # One line per epoch so that training progress is visible.
    print(f"Epoch {epoch + 1}/{num_epochs} - "
          f"mean loss: {running_loss / max(num_batches, 1):.4f}")


In [None]:
# Inference on a single image
model.eval()  # evaluation mode: the model returns predictions, not losses
image_path = "path_to_your_image.jpg"  # Replace with your image path
image = Image.open(image_path).convert("RGB")
# Preprocess the image and insert a leading batch axis
image_tensor = torch.unsqueeze(transform(image), 0)

# Gradients are not needed for prediction
with torch.no_grad():
    predictions = model(image_tensor)


In [None]:
# Process prediction for visualization
def visualize_prediction(image, predictions, class_names, label_offset=1):
    """Draw the detections of the first prediction entry on top of ``image``.

    Args:
        image: Image accepted by ``Axes.imshow`` (e.g. a PIL image).
        predictions: Sequence whose first element is a dict with the tensors
            ``'boxes'`` (x1, y1, x2, y2 per row), ``'labels'`` and
            ``'scores'``.
        class_names: Names of the foreground classes.
        label_offset: Value subtracted from a predicted label to obtain its
            index in ``class_names``. Defaults to 1 because torchvision
            detection models reserve label 0 for the background, so the first
            foreground class is predicted as label 1. Pass 0 when
            ``class_names`` already contains a background entry at index 0.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    detections = predictions[0]
    boxes = detections['boxes'].detach().cpu().numpy()
    labels = detections['labels'].detach().cpu().numpy()
    scores = detections['scores'].detach().cpu().numpy()

    fig, ax = plt.subplots(1)
    ax.imshow(image)
    # hsv is a cyclic colormap (both ends are red); excluding the endpoint
    # keeps the first and the last class visually distinct.
    colors = plt.cm.hsv(
        np.linspace(0, 1, len(class_names), endpoint=False)).tolist()

    for box, label, score in zip(boxes, labels, scores):
        x_min, y_min, x_max, y_max = (int(b) for b in box)
        class_index = int(label) - label_offset
        if 0 <= class_index < len(class_names):
            name = class_names[class_index]
            color = colors[class_index]
        else:
            # Unknown id: still draw the box rather than raising or silently
            # wrapping around through negative indexing.
            name = f"class {int(label)}"
            color = "gray"
        label_text = f"{name}: {score:.2f}"
        ax.add_patch(plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                                   fill=False, edgecolor=color, linewidth=2))
        ax.text(x_min, y_min, label_text,
                bbox=dict(facecolor=color, alpha=0.5))

    plt.show()


In [None]:
# Show the detections for the image loaded in the inference step
visualize_prediction(
    image=image,
    predictions=predictions,
    class_names=class_names,
)
