In [13]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from transformers import ViTForImageClassification
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import numpy as np
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

# Pick the compute device: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cpu


In [14]:
class TomatoLeafDataset(Dataset):
    """Map-style dataset that yields ``(image, label)`` pairs from image files.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels, index-aligned with ``image_paths``.
        transform: Optional callable applied to the RGB PIL image before it
            is returned. ``None`` means the PIL image is returned unchanged.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx`` as RGB, transform it, and pair it with its label."""
        image_path = self.image_paths[idx]
        # Close the underlying file deterministically instead of relying on
        # garbage collection; the converted image is used after the `with`.
        with Image.open(image_path) as source:
            image = source.convert('RGB')
        label = self.labels[idx]

        # Compare against None rather than truthiness so that a transform
        # object that happens to be falsy (e.g. defines an empty __len__)
        # is still applied.
        if self.transform is not None:
            image = self.transform(image)

        return image, label

def load_data(data_dir, extensions=('.jpg', '.jpeg', '.png')):
    """Collect image paths and integer labels from a class-per-folder tree.

    Every sub-directory of ``data_dir`` is treated as one class; images are
    searched recursively inside it.

    Args:
        data_dir: Root directory containing one sub-directory per class.
        extensions: File extensions (matched case-insensitively) that count
            as images.

    Returns:
        A tuple ``(image_paths, labels, class_to_idx)`` where ``image_paths``
        and ``labels`` are index-aligned lists and ``class_to_idx`` maps each
        class directory name to its integer label.
    """
    wanted_exts = tuple(ext.lower() for ext in extensions)

    # Only directories are classes (stray files such as a README must not get
    # a label), and sorting makes the name -> index mapping reproducible:
    # os.listdir returns entries in arbitrary order.
    classes = sorted(
        entry for entry in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, entry))
    )
    class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}

    image_paths = []
    labels = []

    for cls_name in classes:
        cls_dir = os.path.join(data_dir, cls_name)
        for root, dirs, files in os.walk(cls_dir):
            dirs.sort()  # in-place sort fixes os.walk's traversal order
            for file in sorted(files):
                # Lower-case before matching so '.JPG' / '.PNG' are not dropped.
                if file.lower().endswith(wanted_exts):
                    image_paths.append(os.path.join(root, file))
                    labels.append(class_to_idx[cls_name])

    return image_paths, labels, class_to_idx

# Dataset locations. NOTE(review): these are machine-specific absolute paths;
# adjust them (or derive them from a configurable root) before running elsewhere.
data_dir = 'D:\\Publish Paper\\Dataset plant\\PlantDiseasesDataset\\train'
val_data_dir = 'D:\\Publish Paper\\Dataset plant\\PlantDiseasesDataset\\valid'

train_image_paths, train_labels, class_to_idx = load_data(data_dir)
val_image_paths, val_labels_raw, val_class_to_idx = load_data(val_data_dir)
class_names = list(class_to_idx.keys())

# load_data numbers the classes of each directory independently, so the
# validation indices are not guaranteed to match the training ones. Translate
# every validation label to the training index of the same class name.
val_idx_to_class = {idx: name for name, idx in val_class_to_idx.items()}
unknown_classes = sorted(
    {val_idx_to_class[label] for label in val_labels_raw} - set(class_to_idx)
)
if unknown_classes:
    raise ValueError(
        f'Validation classes missing from the training set: {unknown_classes}'
    )
val_labels = [class_to_idx[val_idx_to_class[label]] for label in val_labels_raw]

# Define transformations
# The google/vit-base-patch16-224-in21k checkpoint expects 224x224 inputs
# normalised with mean 0.5 / std 0.5 per channel (see its model card), not the
# ImageNet statistics used by torchvision CNNs.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

train_dataset = TomatoLeafDataset(train_image_paths, train_labels, transform=transform)
val_dataset = TomatoLeafDataset(val_image_paths, val_labels, transform=transform)

# Shuffle only the training data; keep validation order fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
num_classes = len(class_to_idx)

In [15]:
# Build the ViT classifier from the pre-trained checkpoint, sizing the
# classification head to `num_classes`, and place it on the selected device.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=num_classes,
).to(device)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Loss function and optimiser; both are read as globals by train()/evaluate().
criterion = CrossEntropyLoss()
optimizer = AdamW(params=model.parameters(), lr=1e-4)

def train(model, train_loader, val_loader, epochs):
    """Run the optimisation loop and validate after every epoch.

    Uses the module-level ``optimizer``, ``criterion`` and ``device``.

    Args:
        model: Model whose forward output exposes a ``.logits`` attribute.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Loader handed to ``evaluate`` at the end of each epoch.
        epochs: Number of passes over ``train_loader``.

    Returns:
        ``(train_losses, val_losses, val_accuracies)`` - three lists with one
        entry per epoch.
    """
    train_losses, val_losses, val_accuracies = [], [], []

    for epoch in range(epochs):
        model.train()
        train_loss = 0

        for batch_images, batch_labels in tqdm(train_loader):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images).logits
            batch_loss = criterion(logits, batch_labels)

            # Clear stale gradients, back-propagate, then update the weights.
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            train_loss += batch_loss.item()

        # Epoch loss is the running sum divided by the number of batches.
        train_losses.append(train_loss / len(train_loader))

        val_loss, val_accuracy, _preds, _targets = evaluate(model, val_loader)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss/len(train_loader)}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}')

    return train_losses, val_losses, val_accuracies

def evaluate(model, val_loader):
    """Compute loss, accuracy and per-sample predictions on ``val_loader``.

    Uses the module-level ``criterion`` and ``device``. The model is switched
    to eval mode and gradients are disabled for the whole pass.

    Args:
        model: Model whose forward output exposes a ``.logits`` attribute.
        val_loader: Iterable of ``(images, labels)`` batches.

    Returns:
        ``(val_loss, val_accuracy, all_preds, all_labels)`` where ``val_loss``
        is the per-sample mean loss, ``val_accuracy`` the fraction of correct
        predictions, and the two lists hold predicted / true class indices in
        loader order.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    correct = 0

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            batch_size = labels.size(0)

            # criterion is assumed to return the batch *mean* (the default for
            # CrossEntropyLoss); weight it by batch size so a smaller final
            # batch does not count as much as a full one.
            total_loss += criterion(outputs, labels).item() * batch_size
            total_samples += batch_size

            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    if total_samples == 0:
        raise ValueError('evaluate() received a data loader with no samples')

    val_loss = total_loss / total_samples
    val_accuracy = correct / total_samples

    return val_loss, val_accuracy, all_preds, all_labels

# Kick off fine-tuning and keep the per-epoch history for plotting.
epochs = 10
train_losses, val_losses, val_accuracies = train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=epochs,
)

  0%|          | 0/5 [00:00<?, ?it/s]

In [None]:
# Final pass over the validation set; the predictions and labels are kept
# for the confusion matrix.
val_loss, val_accuracy, all_preds, all_labels = evaluate(
    model=model,
    val_loader=val_loader,
)
print(f'Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}')

In [None]:
# Plot training & validation loss and accuracy
# Take the x axis from the recorded history (1-based, matching the epoch
# numbers printed during training) so the plot still works if `epochs` was
# changed after the training cell ran.
epochs_range = range(1, len(train_losses) + 1)

fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(epochs_range, train_losses, label='Training Loss')
loss_ax.plot(epochs_range, val_losses, label='Validation Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend(loc='upper right')
loss_ax.set_title('Training and Validation Loss')

acc_ax.plot(epochs_range, val_accuracies, label='Validation Accuracy')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')
acc_ax.legend(loc='lower right')
acc_ax.set_title('Validation Accuracy')

plt.show()

# Plot confusion matrix
# Pass every class index explicitly: without `labels=`, confusion_matrix only
# includes classes that occur in the labels/predictions, so the matrix could
# be smaller than `class_names` and the tick labels would not line up.
all_class_indices = list(range(len(class_names)))
cm = confusion_matrix(all_labels, all_preds, labels=all_class_indices)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
fig, ax = plt.subplots(figsize=(10, 10))
disp.plot(cmap=plt.cm.Blues, ax=ax, xticks_rotation='vertical')
# Title the matrix axes directly rather than whichever axes is "current".
ax.set_title('Confusion Matrix')
plt.show()