In [1]:
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from transformers import ViTModel, ViTConfig

# Load the trained model
class CustomModel(nn.Module):
    def __init__(self, vit_model):
        super(CustomModel, self).__init__()
        self.vit_model = vit_model
        self.fc = nn.Linear(config.hidden_size, 10)  # Assuming 10 classes

    def forward(self, x):
        outputs = self.vit_model(x).pooler_output
        x = self.fc(outputs)
        return x

# Configuration for ViT
config = ViTConfig.from_pretrained('google/vit-base-patch16-224')
vit_model = ViTModel(config)
model = CustomModel(vit_model)

# Load the best model
model.load_state_dict(torch.load('best_model.pth'))
model.eval()

# Define transforms for the test data
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load the test dataset
test_path = 'D:/Publish Paper/Dataset plant/test/test'  # Update with your test data path
test_data = datasets.ImageFolder(test_path, transform=test_transforms)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

# List to store true labels and predictions
all_labels = []
all_preds = []

# Make predictions
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(predicted.cpu().numpy())

# Compute confusion matrix
conf_matrix = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(10, 8))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=test_data.classes, yticklabels=test_data.classes)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.show()

# Print classification report
class_names = test_data.classes
print(classification_report(all_labels, all_preds, target_names=class_names))

# Compute overall accuracy
accuracy = accuracy_score(all_labels, all_preds)
print(f'Overall Accuracy: {accuracy * 100}%')

# Function to predict and visualize a single image
def predict_image(image_path):
    image = Image.open(image_path).convert('RGB')
    image_tensor = test_transforms(image).unsqueeze(0)
    with torch.no_grad():
        output = model(image_tensor)
        _, predicted = torch.max(output.data, 1)
    class_name = class_names[predicted.item()]
    
    plt.imshow(image)
    plt.title(f'Predicted: {class_name}')
    plt.show()

# Test the function with a sample image
sample_image_path = '"D:/Publish Paper/Dataset plant/test/test/TomatoYellowCurlVirus2.JPG"'  # Update with the path to a sample image
predict_image(sample_image_path)


NameError: name 'nn' is not defined