<a href="https://colab.research.google.com/github/OmdenaAI/KenyaChapter_EarlyDetectionofCropDiseases/blob/main/VITmodel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [32]:
# Import required libraries
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from transformers import ViTForImageClassification, ViTFeatureExtractor
from tqdm import tqdm
from PIL import Image

In [33]:
# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [34]:
# Parameters
batch_size = 16  # Adjust based on your GPU memory
num_epochs = 10   # Reduce for quick testing
learning_rate = 5e-5  # Typical learning rate for ViT fine-tuning
image_size = 224  # Resize all images to 224x224
num_classes = 13  # Number of classes (diseases + healthy)

In [35]:
# Define class labels (ensure this matches your dataset)
class_labels = [
    'Bacterial spot', 'Black mold', 'Gray spot', 'Late blight',
    'Tomato___Early_blight', 'Tomato___Late_blight', 'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot', 'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy', 'healthy'
]

In [36]:
# Define transformations for the dataset
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),  # Resize all images
    transforms.ToTensor(),  # Convert to Tensor
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize
])

In [37]:
# Load the dataset using ImageFolder
dataset_path = "/content/drive/MyDrive/tomato-images"  # Update this path
dataset = datasets.ImageFolder(root=dataset_path, transform=transform)

In [38]:
# Split dataset into training and validation sets (80% train, 20% val)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

In [39]:
# Create DataLoaders for training and validation
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


In [40]:
# Load the pretrained ViT model and modify for classification
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=num_classes
)
model.to(device)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [41]:
# Define the optimizer, loss function, and learning rate scheduler
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)  # Optional scheduler

In [17]:
# Enable mixed precision training
from torch.cuda.amp import GradScaler, autocast
scaler = torch.amp.GradScaler('cuda')


In [42]:
# Training Loop
scaler = torch.amp.GradScaler()  # Proper initialization

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for inputs, labels in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()  # Reset gradients

        # Enable mixed precision
        with torch.amp.autocast(device_type='cuda'):
            outputs = model(pixel_values=inputs)
            loss = criterion(outputs.logits, labels)

        # Backpropagation with scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

    print(f"Epoch {epoch+1} Loss: {running_loss / len(train_loader)}")

    # Step the learning rate scheduler
    scheduler.step()


Training Epoch 1/10: 100%|██████████| 687/687 [02:16<00:00,  5.03it/s]


Epoch 1 Loss: 0.40938722554816603


Training Epoch 2/10: 100%|██████████| 687/687 [02:01<00:00,  5.67it/s]


Epoch 2 Loss: 0.07122949768383777


Training Epoch 3/10: 100%|██████████| 687/687 [02:00<00:00,  5.71it/s]


Epoch 3 Loss: 0.03842757437161266


Training Epoch 4/10: 100%|██████████| 687/687 [02:00<00:00,  5.72it/s]


Epoch 4 Loss: 0.0157784520278331


Training Epoch 5/10: 100%|██████████| 687/687 [01:59<00:00,  5.73it/s]


Epoch 5 Loss: 0.00897352616851702


Training Epoch 6/10: 100%|██████████| 687/687 [01:59<00:00,  5.75it/s]


Epoch 6 Loss: 0.005936562214141611


Training Epoch 7/10: 100%|██████████| 687/687 [01:59<00:00,  5.75it/s]


Epoch 7 Loss: 0.004405399217111705


Training Epoch 8/10: 100%|██████████| 687/687 [02:00<00:00,  5.70it/s]


Epoch 8 Loss: 0.0034379644198763407


Training Epoch 9/10: 100%|██████████| 687/687 [01:59<00:00,  5.75it/s]


Epoch 9 Loss: 0.002766118219340548


Training Epoch 10/10: 100%|██████████| 687/687 [01:59<00:00,  5.74it/s]

Epoch 10 Loss: 0.002277648832013199





In [43]:
# Step the learning rate scheduler
scheduler.step()

In [44]:
# Save the model and feature extractor
model.save_pretrained("./vit-tomato-disease")
print("Model saved!")

Model saved!


In [47]:
# Load the saved model and feature extractor
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image

# Load model and set to evaluation mode
model = ViTForImageClassification.from_pretrained("./vit-tomato-disease").to(device)
model.eval()

# Define the feature extractor
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

# Define test image path
test_image_path = "/content/drive/MyDrive/OIP (3).jpeg"  # Update this path
image = Image.open(test_image_path)

# Preprocess the image
inputs = feature_extractor(images=image, return_tensors="pt").to(device)

# Make prediction
with torch.no_grad():
    outputs = model(**inputs)
    predicted_label_index = outputs.logits.argmax(-1).item()

# Map the prediction to the disease name
predicted_disease = class_labels[predicted_label_index]
print(f"Predicted Disease: {predicted_disease}")




Predicted Disease: Black mold


In [46]:
from sklearn.metrics import accuracy_score

# Evaluate model on validation dataset
def evaluate_model(model, val_loader):
    model.eval()  # Set model to evaluation mode
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Evaluating"):
            inputs, labels = inputs.to(device), labels.to(device)

            # Make predictions
            outputs = model(pixel_values=inputs)
            preds = outputs.logits.argmax(dim=1)  # Get predicted class indices

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Compute accuracy
    accuracy = accuracy_score(all_labels, all_preds)
    return accuracy

# Call the evaluation function and print accuracy
val_accuracy = evaluate_model(model, val_loader)
print(f"Validation Accuracy: {val_accuracy * 100:.2f}%")


Evaluating: 100%|██████████| 172/172 [00:31<00:00,  5.49it/s]

Validation Accuracy: 99.45%



