In [2]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from transformers import ViTForImageClassification, ViTFeatureExtractor
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Run on CUDA when a GPU is present; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

class TomatoLeafDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs read from disk.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels aligned index-by-index with ``image_paths``.
        transform: Optional callable applied to each RGB PIL image before it
            is returned.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        # A length mismatch would otherwise surface later as an IndexError in
        # a worker, or silently drop the surplus labels.
        if len(image_paths) != len(labels):
            raise ValueError(
                f'image_paths and labels must have the same length '
                f'(got {len(image_paths)} and {len(labels)})'
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to RGB and, when a transform was supplied,
        passed through it.
        """
        # Close the file handle promptly: convert() returns an independent
        # copy, so the image stays usable after the source file is closed.
        with Image.open(self.image_paths[idx]) as source:
            image = source.convert('RGB')
        label = self.labels[idx]

        # Compare against None so that a callable which happens to be falsy
        # (e.g. an empty container-like transform) is still applied.
        if self.transform is not None:
            image = self.transform(image)

        return image, label

def load_data(data_dir):
    """Collect image paths and integer labels from a class-per-folder tree.

    Each immediate sub-directory of ``data_dir`` is treated as one class and
    is searched recursively for ``.jpg``/``.jpeg``/``.png`` files (extension
    match is case-insensitive).

    Args:
        data_dir: Root directory containing one sub-directory per class.

    Returns:
        A tuple ``(image_paths, labels, class_to_idx)`` where ``image_paths``
        and ``labels`` are parallel lists and ``class_to_idx`` maps each class
        directory name to its integer label.
    """
    image_extensions = ('.jpg', '.jpeg', '.png')

    # Keep directories only (stray files must not become classes) and sort
    # them: os.listdir order is arbitrary, which would make the label indices
    # differ between runs and machines.
    classes = sorted(
        entry for entry in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, entry))
    )
    class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}

    image_paths = []
    labels = []

    for cls_name in classes:
        cls_dir = os.path.join(data_dir, cls_name)
        for root, dirs, files in os.walk(cls_dir):
            # Sorting in place makes os.walk descend in a deterministic order.
            dirs.sort()
            for file in sorted(files):
                # Lower-case first so upper-case extensions such as '.JPG'
                # are not silently skipped.
                if file.lower().endswith(image_extensions):
                    image_paths.append(os.path.join(root, file))
                    labels.append(class_to_idx[cls_name])

    return image_paths, labels, class_to_idx

def plot_loss_accuracy(train_losses, val_losses, val_accuracies):
    """Show per-epoch loss curves and validation accuracy side by side.

    Args:
        train_losses: Training loss values, one per epoch.
        val_losses: Validation loss values, one per epoch.
        val_accuracies: Validation accuracy values, one per epoch.
    """
    # Each panel: (y-axis label, sequence of (series, plot keyword arguments)).
    panels = (
        ('Loss', (
            (train_losses, {'label': 'Train Loss'}),
            (val_losses, {'label': 'Validation Loss'}),
        )),
        ('Accuracy', (
            (val_accuracies, {'label': 'Validation Accuracy', 'color': 'orange'}),
        )),
    )

    plt.figure(figsize=(10, 5))
    for position, (y_label, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for series, plot_kwargs in curves:
            plt.plot(series, **plot_kwargs)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
    plt.show()

def plot_confusion_matrix(cm, classes):
    """Display a confusion matrix as a blue heat map.

    Args:
        cm: Confusion matrix to display.
        classes: Display labels for the matrix rows/columns.
    """
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
    # ConfusionMatrixDisplay.plot() creates its own figure unless it is given
    # an Axes. Drawing into an explicit Axes makes the requested figsize take
    # effect and avoids leaving an extra empty figure behind.
    _, ax = plt.subplots(figsize=(8, 6))
    disp.plot(cmap=plt.cm.Blues, ax=ax)
    ax.set_title('Confusion Matrix')
    plt.show()

data_dir = 'D:\\Publish Paper\\Dataset plant\\PlantVillage\\train'
image_paths, labels, class_to_idx = load_data(data_dir)

# Split the dataset into training and validation sets (80/20, fixed seed)
train_paths, val_paths, train_labels, val_labels = train_test_split(image_paths, labels, test_size=0.2, random_state=42)

# Define transformations.
# The google/vit-base-patch16-224-in21k checkpoint loaded below was pre-trained
# on 224x224 inputs normalized with mean 0.5 and std 0.5 per channel (see its
# model card), not with the ImageNet statistics, so match that preprocessing.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

train_dataset = TomatoLeafDataset(train_paths, train_labels, transform=transform)
val_dataset = TomatoLeafDataset(val_paths, val_labels, transform=transform)

# Shuffle only the training data; keep validation order stable.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
num_classes = len(class_to_idx)

# Load the pre-trained ViT backbone with a classification head sized for our
# classes, then move it onto the selected device.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=num_classes,
)
model = model.to(device)
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

# Optimization setup: cross-entropy objective, AdamW over all model parameters.
criterion = CrossEntropyLoss()
optimizer = AdamW(params=model.parameters(), lr=1e-4)

def train(model, train_loader, val_loader, epochs):
    """Fine-tune ``model`` for ``epochs`` epochs, validating after each one.

    Relies on the module-level ``optimizer``, ``criterion`` and ``device``.
    After the final epoch the collected curves are drawn with
    ``plot_loss_accuracy``.

    Args:
        model: Model whose forward output exposes a ``logits`` attribute.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Loader passed to ``evaluate`` after every epoch.
        epochs: Number of passes over ``train_loader``.

    Returns:
        A tuple ``(train_losses, val_losses, val_accuracies)`` holding one
        value per epoch.

    Raises:
        ValueError: If ``train_loader`` has no batches while ``epochs`` is
            positive.
    """
    num_batches = len(train_loader)
    # Fail early with a clear message instead of a ZeroDivisionError at the
    # end of the first epoch.
    if epochs > 0 and num_batches == 0:
        raise ValueError('train_loader is empty; there is nothing to train on.')

    train_losses = []
    val_losses = []
    val_accuracies = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Mean of the per-batch losses for this epoch.
        epoch_train_loss = running_loss / num_batches
        train_losses.append(epoch_train_loss)

        val_loss, val_accuracy = evaluate(model, val_loader)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {epoch_train_loss}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}')

    plot_loss_accuracy(train_losses, val_losses, val_accuracies)
    return train_losses, val_losses, val_accuracies

def evaluate(model, val_loader):
    """Compute the average loss and accuracy of ``model`` over ``val_loader``.

    Relies on the module-level ``criterion`` and ``device``. The model is
    switched to evaluation mode and gradients are disabled.

    Args:
        model: Model whose forward output exposes a ``logits`` attribute.
        val_loader: Iterable of ``(images, labels)`` batches.

    Returns:
        A tuple ``(val_loss, val_accuracy)``: the mean of the per-batch
        losses, and the fraction of evaluated samples predicted correctly.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    num_batches = 0
    num_samples = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            total_loss += criterion(outputs, labels).item()
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            # Count what was actually evaluated so the accuracy stays correct
            # when the loader does not cover the whole dataset (e.g. a custom
            # sampler or drop_last=True).
            num_batches += 1
            num_samples += labels.size(0)

    if num_samples == 0:
        raise ValueError('val_loader is empty; there is nothing to evaluate.')

    val_loss = total_loss / num_batches
    val_accuracy = correct / num_samples
    return val_loss, val_accuracy

def get_predictions(model, data_loader):
    """Run ``model`` over ``data_loader`` and collect predicted and true labels.

    Relies on the module-level ``device``. The model is put in evaluation
    mode and gradients are disabled while predicting.

    Args:
        model: Model whose forward output exposes a ``logits`` attribute.
        data_loader: Iterable of ``(images, labels)`` batches.

    Returns:
        A tuple ``(predictions, ground_truths)`` of flat lists with one entry
        per sample, in loader order.
    """
    model.eval()
    predicted, actual = [], []

    with torch.no_grad():
        for batch_images, batch_labels in data_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_images).logits
            # The arg-max over the class dimension is the predicted class index.
            batch_preds = logits.argmax(dim=1)
            predicted += list(batch_preds.cpu().numpy())
            actual += list(batch_labels.cpu().numpy())

    return predicted, actual

# Train the model
train(model, train_loader, val_loader, epochs=10)

# Evaluate on validation set
val_loss, val_accuracy = evaluate(model, val_loader)
print(f'Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}')

# Get predictions and ground truths
val_predictions, val_ground_truths = get_predictions(model, val_loader)

# Calculate and plot confusion matrix.
# Order the names by their class index so row/column i is labelled with class i.
class_names = sorted(class_to_idx, key=class_to_idx.get)
# Pass every class index explicitly: without `labels`, a class that is absent
# from both the validation labels and the predictions is dropped from the
# matrix, which then no longer lines up with `class_names`.
cm = confusion_matrix(
    val_ground_truths,
    val_predictions,
    labels=list(range(len(class_names))),
)
plot_confusion_matrix(cm, class_names)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 0/4 [00:00<?, ?it/s]


AttributeError: 'tuple' object has no attribute 'to'