# In [None]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from lime.lime_image import LimeImageExplainer
from skimage.segmentation import mark_boundaries
from PIL import Image

# ResNet9 Model Definition
class ImageClassificationBase(torch.nn.Module):
    """Training/validation helpers shared by image-classification models.

    Subclasses supply ``forward``. Every batch is an ``(images, labels)`` pair,
    and the loss is ``torch.nn.functional.cross_entropy`` applied to the raw
    model outputs.
    """

    def training_step(self, batch):
        """Return the cross-entropy loss for one ``(images, labels)`` batch."""
        inputs, targets = batch
        logits = self(inputs)
        return torch.nn.functional.cross_entropy(logits, targets)

    def validation_step(self, batch):
        """Return the detached loss and the accuracy for one validation batch.

        Returns:
            dict with keys ``"val_loss"`` (detached loss tensor) and
            ``"val_accuracy"`` (0-dim tensor from :meth:`accuracy`).
        """
        inputs, targets = batch
        logits = self(inputs)
        batch_loss = torch.nn.functional.cross_entropy(logits, targets)
        batch_acc = self.accuracy(logits, targets)
        # Detach so that per-batch results kept for the epoch summary do not
        # hold on to the autograd graph.
        return {"val_loss": batch_loss.detach(), "val_accuracy": batch_acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch ``validation_step`` results over an epoch.

        Args:
            outputs: list of dicts returned by :meth:`validation_step`.

        Returns:
            dict with the mean ``"val_loss"`` and mean ``"val_accuracy"``.
        """
        mean_loss = torch.stack([step["val_loss"] for step in outputs]).mean()
        mean_acc = torch.stack([step["val_accuracy"] for step in outputs]).mean()
        return {"val_loss": mean_loss, "val_accuracy": mean_acc}

    def epoch_end(self, epoch, result):
        """Print the epoch index and the summary metrics from ``result``."""
        template = "Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}"
        print(template.format(epoch, result['val_loss'], result['val_accuracy']))

    @staticmethod
    def accuracy(outputs, labels):
        """Return the fraction of rows whose arg-max along dim 1 equals the label.

        The result is a 0-dim tensor built from a Python float.
        """
        predicted = torch.max(outputs, dim=1).indices
        n_correct = (predicted == labels).sum().item()
        return torch.tensor(n_correct / len(predicted))

class ResNet9(ImageClassificationBase):
    """ResNet9-style image classifier.

    Layout: conv1 -> conv2 (downsample) -> residual block res1 ->
    conv3 (downsample) -> conv4 (downsample) -> residual block res2 ->
    classifier (MaxPool2d(4) -> Flatten -> Linear(512, num_classes)).

    Args:
        in_channels: number of channels in the input images (3 for RGB).
        num_classes: number of output logits.
        pool_size: max-pool kernel/stride used by each downsampling conv
            block. With the default of 4, a 256x256 input shrinks to
            64 -> 16 -> 4 pixels per side. The classifier's MaxPool2d(4) then
            reduces that to 1x1, giving exactly the 512 features the final
            Linear layer expects. With a factor of 2, a 256x256 input would
            reach the Linear layer with 512*8*8 features and fail.
            NOTE(review): 4 is assumed to match how the loaded checkpoint was
            trained. Pooling layers have no weights, so load_state_dict cannot
            detect a mismatch here — confirm.
    """

    def __init__(self, in_channels, num_classes, pool_size=4):
        # Previously spelled `_init_` (single underscores). Python never runs
        # that as the constructor, so `ResNet9(3, 38)` fell through to
        # torch.nn.Module.__init__ with extra arguments and no layers were built.
        super().__init__()
        self.conv1 = self.conv_block(in_channels, 64)
        self.conv2 = self.conv_block(64, 128, pool=True, pool_size=pool_size)
        self.res1 = torch.nn.Sequential(self.conv_block(128, 128), self.conv_block(128, 128))

        self.conv3 = self.conv_block(128, 256, pool=True, pool_size=pool_size)
        self.conv4 = self.conv_block(256, 512, pool=True, pool_size=pool_size)
        self.res2 = torch.nn.Sequential(self.conv_block(512, 512), self.conv_block(512, 512))

        # Submodule indices (classifier.0/1/2) are unchanged, so existing
        # state_dict keys such as "classifier.2.weight" still load.
        self.classifier = torch.nn.Sequential(torch.nn.MaxPool2d(4),
                                              torch.nn.Flatten(),
                                              torch.nn.Linear(512, num_classes))

    @staticmethod
    def conv_block(in_channels, out_channels, pool=False, pool_size=2):
        """Build Conv3x3(padding=1) -> BatchNorm2d -> ReLU, plus MaxPool2d(pool_size) if ``pool``.

        The 3x3 convolution with padding 1 keeps the spatial size. Only the
        optional pooling layer shrinks it, by a factor of ``pool_size``.
        """
        layers = [torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                  torch.nn.BatchNorm2d(out_channels),
                  torch.nn.ReLU(inplace=True)]
        if pool:
            layers.append(torch.nn.MaxPool2d(pool_size))
        return torch.nn.Sequential(*layers)

    def forward(self, xb):
        """Return logits of shape (N, num_classes) for a batch ``xb``.

        Spatial sizes in the comments below assume a (N, C, 256, 256) input
        and the default pool_size=4.
        """
        out = self.conv1(xb)              # (N, 64, 256, 256)
        out = self.conv2(out)             # (N, 128, 64, 64)
        out = self.res1(out) + out        # residual connection, same shape
        out = self.conv3(out)             # (N, 256, 16, 16)
        out = self.conv4(out)             # (N, 512, 4, 4)
        out = self.res2(out) + out        # residual connection, same shape
        out = self.classifier(out)        # (N, num_classes)
        return out

# Define the device (GPU/CPU) first so the checkpoint can be mapped onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained model
model = ResNet9(3, 38)  # Ensure the class count matches your dataset
# map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
# Without it, torch.load tries to restore tensors onto their original CUDA device.
model.load_state_dict(torch.load('./plant-disease-model.pth', map_location=device))
model.eval()
model = model.to(device)

# Preprocessing transformations
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),  # exact 256x256 output; aspect ratio is not preserved
        transforms.ToTensor(),               # PIL image -> float CHW tensor scaled to [0, 1]
    ]
)

# Load a sample dataset for visualization
data_dir = "../input/new-plant-diseases-dataset/New Plant Diseases Dataset(Augmented)/New Plant Diseases Dataset(Augmented)"
# NOTE(review): this dataset directory may only contain "train" and "valid"
# class-subfolder splits, and ImageFolder needs class subfolders. Fall back to
# "valid" when no "test" directory exists — confirm against your copy of the data.
test_dir = os.path.join(data_dir, "test")
if not os.path.isdir(test_dir):
    test_dir = os.path.join(data_dir, "valid")
test_dataset = ImageFolder(test_dir, transform=transform)
# batch_size=1 with shuffle=True: each `next(iter(test_loader))` yields one random sample.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

# Define a helper function to predict using the model
def predict(input_images):
    """Return raw model logits for a sequence of images as a NumPy array.

    Each image goes through the module-level ``transform``, so it must be an
    input that transform accepts (e.g. a PIL image). The results are stacked
    into one batch and moved to ``device``.

    Returns:
        ``np.ndarray`` of shape ``(len(input_images), num_classes)`` holding
        unnormalised logits (no softmax is applied).
    """
    batch = torch.stack([transform(img) for img in input_images]).to(device)
    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        outputs = model(batch)
    return outputs.cpu().numpy()

# LIME Integration
explainer = LimeImageExplainer()

# Take one shuffled sample from the test loader and convert its (C, H, W)
# tensor into the (H, W, C) NumPy layout that LIME works with.
image, label = next(iter(test_loader))
image = np.transpose(image[0].numpy(), (1, 2, 0))
original_image = image

# Define a prediction function compatible with LIME
def predict_lime(images):
    """LIME classifier function: map a batch of HWC images to class probabilities.

    Args:
        images: iterable of H x W x C float arrays. Values are assumed to lie
            in [0, 1], like ``original_image``. They are scaled to 0..255 and
            turned into uint8 PIL images so they pass through the same
            ``transform`` as the dataset images.

    Returns:
        ``np.ndarray`` of shape ``(len(images), num_classes)`` holding softmax
        probabilities.
    """
    # Clip before the uint8 cast. Values even slightly outside [0, 1] would
    # otherwise not saturate at 0/255 when converted.
    pil_images = [
        Image.fromarray((np.clip(img, 0.0, 1.0) * 255).astype(np.uint8))
        for img in images
    ]
    batch = torch.stack([transform(pil) for pil in pil_images]).to(device)
    # LIME calls this once per perturbation batch; no autograd graph is needed.
    with torch.no_grad():
        outputs = model(batch)
    return torch.nn.functional.softmax(outputs, dim=1).cpu().numpy()

# Explain the prediction
# Explain the three most likely classes. Hidden superpixels are filled with
# 0 (hide_color=0), and 1000 perturbed samples are drawn.
explanation = explainer.explain_instance(
    image=original_image,
    classifier_fn=predict_lime,
    top_labels=3,
    hide_color=0,
    num_samples=1000,
)

# Visualize the explanation
# One panel per explained class, so the layout follows `top_labels` instead of
# a hard-coded 3. The loop variable is `top_label` so it does not overwrite the
# dataset `label` drawn from the test loader.
plt.figure(figsize=(10, 10))
for panel_idx, top_label in enumerate(explanation.top_labels):
    temp, mask = explanation.get_image_and_mask(
        top_label,
        positive_only=True,
        num_features=10,
        hide_rest=False
    )
    plt.subplot(1, len(explanation.top_labels), panel_idx + 1)
    plt.title(f"Label {top_label}")
    # `temp` comes from the image given to explain_instance (original_image),
    # which is already float in [0, 1]. Dividing by 255 again made the overlay
    # almost black.
    plt.imshow(mark_boundaries(temp, mask))
plt.show()

# Persist the whole pickled model object (architecture + weights), not just
# its state_dict, to the working directory. Loading it later needs the
# ResNet9 class definition to be importable.
PATH = './plime.pth'
torch.save(obj=model, f=PATH)