In [1]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from transformers import ViTForImageClassification, ViTFeatureExtractor
from sklearn.model_selection import train_test_split

# Ensure we use a GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [2]:
class TomatoLeafDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

def load_data(data_dir):
    classes = os.listdir(data_dir)
    class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}
    
    image_paths = []
    labels = []

    for cls_name in classes:
        cls_dir = os.path.join(data_dir, cls_name)
        for root, _, files in os.walk(cls_dir):
            for file in files:
                if file.endswith(('.jpg', '.jpeg', '.png')):
                    image_paths.append(os.path.join(root, file))
                    labels.append(class_to_idx[cls_name])
    
    return image_paths, labels, class_to_idx

data_dir = 'D:\\Publish Paper\\Dataset plant\\PlantDiseasesDataset\\train'
image_paths, labels, class_to_idx = load_data(data_dir)

# Split the dataset into training and validation sets
train_paths, val_paths, train_labels, val_labels = train_test_split(image_paths, labels, test_size=0.2, random_state=42)

# Define transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = TomatoLeafDataset(train_paths, train_labels, transform=transform)
val_dataset = TomatoLeafDataset(val_paths, val_labels, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [3]:
num_classes = len(class_to_idx)

# Load pre-trained ViT model
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=num_classes)
model = model.to(device)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

optimizer = AdamW(model.parameters(), lr=1e-4)
criterion = CrossEntropyLoss()

def train(model, train_loader, val_loader, epochs):
    for epoch in range(epochs):
        model.train()
        train_loss = 0

        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        val_loss, val_accuracy = evaluate(model, val_loader)
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss/len(train_loader)}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}')

def evaluate(model, val_loader):
    model.eval()
    val_loss = 0
    correct = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()

    val_loss /= len(val_loader)
    val_accuracy = correct / len(val_loader.dataset)
    return val_loss, val_accuracy

# Train the model
train(model, train_loader, val_loader, epochs=10)


100%|██████████| 4/4 [01:07<00:00, 16.77s/it]


Epoch 1/10, Train Loss: 1.4754558950662613, Val Loss: 0.7114087343215942, Val Accuracy: 1.0


100%|██████████| 4/4 [00:58<00:00, 14.57s/it]


Epoch 2/10, Train Loss: 0.5940137356519699, Val Loss: 0.41889652609825134, Val Accuracy: 1.0


100%|██████████| 4/4 [00:58<00:00, 14.75s/it]


Epoch 3/10, Train Loss: 0.3943185806274414, Val Loss: 0.32043057680130005, Val Accuracy: 1.0


100%|██████████| 4/4 [01:06<00:00, 16.63s/it]


Epoch 4/10, Train Loss: 0.31611649692058563, Val Loss: 0.2702229619026184, Val Accuracy: 1.0


100%|██████████| 4/4 [01:09<00:00, 17.29s/it]


Epoch 5/10, Train Loss: 0.27212150767445564, Val Loss: 0.23731723427772522, Val Accuracy: 1.0


100%|██████████| 4/4 [01:07<00:00, 16.84s/it]


Epoch 6/10, Train Loss: 0.24202125519514084, Val Loss: 0.2136065661907196, Val Accuracy: 1.0


100%|██████████| 4/4 [01:05<00:00, 16.33s/it]


Epoch 7/10, Train Loss: 0.21914341673254967, Val Loss: 0.19432446360588074, Val Accuracy: 1.0


100%|██████████| 4/4 [01:06<00:00, 16.53s/it]


Epoch 8/10, Train Loss: 0.2021772414445877, Val Loss: 0.17806749045848846, Val Accuracy: 1.0


100%|██████████| 4/4 [01:04<00:00, 16.10s/it]


Epoch 9/10, Train Loss: 0.18414605781435966, Val Loss: 0.164119690656662, Val Accuracy: 1.0


100%|██████████| 4/4 [01:04<00:00, 16.13s/it]


Epoch 10/10, Train Loss: 0.169449083507061, Val Loss: 0.1517661213874817, Val Accuracy: 1.0


In [5]:
val_loss, val_accuracy = evaluate(model, val_loader)
print(f'Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}')

Validation Loss: 0.1517661213874817, Validation Accuracy: 1.0
